@@ -905,6 +905,53 @@ def __create_short_or_normal_prefill_turn():
905905 self ._short_prefill_turns_since_long_chunk += 1
906906 return result
907907
908+ def __is_empty_forward (forward_inputs : 'ModelInputs|None' , forward_delta : 'ModelInputsDelta|None' ):
909+ return forward_inputs is None and forward_delta is None
910+
911+ def __try_active_long_context_chunk ():
912+ nonlocal attempted_long_work
913+ nonlocal active_long_chunk_blocked_by_kv
914+ attempted_long_work = True
915+ result = __create_inputs_long_context_chunk ()
916+ _ , chunk_inputs , chunk_delta , _ = result
917+ active_long_chunk_blocked_by_kv = __is_empty_forward (chunk_inputs , chunk_delta )
918+ return result
919+
920+ def __should_try_short_prefill_before_active_chunk ():
921+ """Allow short/normal prefill quota before an active non-final
922+ chunk."""
923+ if self .long_context_chunker .is_last_chunk ():
924+ return False
925+ if not scheduler .has_waiting ():
926+ return False
927+ return not self ._is_long_context_chunk_turn_due ()
928+
929+ def __has_no_forward ():
930+ return __is_empty_forward (inputs , delta )
931+
932+ def __can_fallback_to_short_after_long_work ():
933+ if not __has_no_forward ():
934+ return False
935+ if not attempted_long_work :
936+ return False
937+ if active_long_chunk_blocked_by_kv :
938+ return False
939+ if attempted_short_or_normal_prefill :
940+ return False
941+ return scheduler .has_waiting ()
942+
943+ def __can_try_short_prefill_after_defer ():
944+ if not __has_no_forward ():
945+ return False
946+ if not deferred_long_context_chunk :
947+ return False
948+ if self ._is_long_context_chunk_turn_due ():
949+ return False
950+ return scheduler .has_waiting ()
951+
952+ def __can_retry_deferred_active_chunk ():
953+ return __has_no_forward () and deferred_long_context_chunk and self .long_context_chunker .enabled ()
954+
908955 scheduler = self .scheduler
909956 logger .debug (f'Make forward inputs with prefill={ prefill } , enable_empty={ enable_empty } ' )
910957
@@ -926,11 +973,9 @@ def __create_short_or_normal_prefill_turn():
926973 # long prefill through the scheduler.
927974 self .long_context_chunker .check_enable ()
928975 if self .long_context_chunker .enabled ():
929- # long context chunking
930976 if self ._should_defer_long_context_chunk (prefill ):
931977 deferred_long_context_chunk = True
932- elif (not self .long_context_chunker .is_last_chunk () and scheduler .has_waiting ()
933- and not self ._is_long_context_chunk_turn_due ()):
978+ elif __should_try_short_prefill_before_active_chunk ():
934979 # After a decode turn, keep the short/normal prefill quota in
935980 # front of active long chunks; otherwise decode -> long can
936981 # repeat and small waiting requests remain gated by the active
@@ -943,14 +988,10 @@ def __create_short_or_normal_prefill_turn():
943988 swap_in_map ,
944989 swap_out_map ,
945990 ) = __create_short_or_normal_prefill_turn ()
946- if inputs is None and delta is None :
947- attempted_long_work = True
948- running , inputs , delta , extra_inputs = __create_inputs_long_context_chunk ()
949- active_long_chunk_blocked_by_kv = inputs is None and delta is None
991+ if __is_empty_forward (inputs , delta ):
992+ running , inputs , delta , extra_inputs = __try_active_long_context_chunk ()
950993 else :
951- attempted_long_work = True
952- running , inputs , delta , extra_inputs = __create_inputs_long_context_chunk ()
953- active_long_chunk_blocked_by_kv = inputs is None and delta is None
994+ running , inputs , delta , extra_inputs = __try_active_long_context_chunk ()
954995 elif prefill :
955996 # prefill
956997 has_waiting_long_prefill = scheduler .has_waiting_long_prefill ()
@@ -963,7 +1004,7 @@ def __create_short_or_normal_prefill_turn():
9631004 swap_in_map ,
9641005 swap_out_map ,
9651006 ) = __create_short_or_normal_prefill_turn ()
966- if inputs is None and delta is None :
1007+ if __has_no_forward () :
9671008 (
9681009 running ,
9691010 inputs ,
@@ -986,8 +1027,7 @@ def __create_short_or_normal_prefill_turn():
9861027 # Waiting-long admission failure can still fall back to short prefills.
9871028 # Active-long reservation failure means KV is pinned by running work;
9881029 # admit decode only so existing requests can drain blocks.
989- if (inputs is None and delta is None and attempted_long_work and not active_long_chunk_blocked_by_kv
990- and scheduler .has_waiting () and not attempted_short_or_normal_prefill ):
1030+ if __can_fallback_to_short_after_long_work ():
9911031 (
9921032 running ,
9931033 inputs ,
@@ -1004,8 +1044,7 @@ def __create_short_or_normal_prefill_turn():
10041044 self .to_evict_seqs = invalid_seqs
10051045 extra_inputs = None
10061046
1007- if (inputs is None and delta is None and deferred_long_context_chunk and scheduler .has_waiting ()
1008- and not self ._is_long_context_chunk_turn_due ()):
1047+ if __can_try_short_prefill_after_defer ():
10091048 (
10101049 running ,
10111050 inputs ,
@@ -1015,8 +1054,8 @@ def __create_short_or_normal_prefill_turn():
10151054 swap_out_map ,
10161055 ) = __create_short_or_normal_prefill_turn ()
10171056
1018- if inputs is None and delta is None and deferred_long_context_chunk and self . long_context_chunker . enabled ():
1019- running , inputs , delta , extra_inputs = __create_inputs_long_context_chunk ()
1057+ if __can_retry_deferred_active_chunk ():
1058+ running , inputs , delta , extra_inputs = __try_active_long_context_chunk ()
10201059
10211060 # reset decode count when non-decoding inputs are produced
10221061 if inputs is not None and not inputs .is_decoding :
0 commit comments