4040import logging
4141import time
4242from collections import OrderedDict
43+ from collections .abc import Callable
4344from contextlib import contextmanager
4445from dataclasses import dataclass
4546
@@ -72,6 +73,28 @@ class SchedulerOutput:
7273 copy_map : MapType
7374
7475
76+ _PREFILL_GATE_SKIP = 'skip'
77+ _PREFILL_GATE_BREAK = 'break'
78+
79+
80+ @dataclass
81+ class _PrefixMatchForPrefillGate :
82+ """Tentative prefix match kept only because it passes a prefill gate."""
83+
84+ stats_snapshot : object
85+ prefill_token_count : int
86+ is_nonfinal_long_prefill : bool
87+
88+
89+ @dataclass
90+ class _PrefillGateCheck :
91+ """Result of prefill-gate checks before final resource admission."""
92+
93+ prefix_match : _PrefixMatchForPrefillGate | None = None
94+ rollback_action : str | None = None
95+ reject_action : str | None = None
96+
97+
7598class Scheduler :
7699 """Tools to schedule next step.
77100
@@ -154,6 +177,96 @@ def _rollback_unscheduled_prefix_match(self, seq: SchedulerSequence, stats_snaps
154177 prefix_cache .match_start_step = - 1
155178 seq .cached_tokens = 0
156179
180+ def _rollback_prefix_match_for_prefill_gate (self , seq : SchedulerSequence , stats_snapshot , reason : str ):
181+ """Rollback a prefix match tried only to re-check prefill gates."""
182+ if logger .isEnabledFor (logging .DEBUG ):
183+ logger .debug (f'Rollback tentative prefix-cache gate match: session_id={ seq .session_id } '
184+ f'seq_id={ seq .seq_id } reason={ reason } num_history_ids={ seq .num_history_ids } '
185+ f'restore_state={ seq .prefix_cache .restore_state } ' )
186+ self ._rollback_unscheduled_prefix_match (seq , stats_snapshot )
187+
188+ def _try_prefix_match_for_prefill_gate (
189+ self ,
190+ seq : SchedulerSequence ,
191+ accept_match : Callable [[_PrefixMatchForPrefillGate ], bool ],
192+ rollback_reason : str ,
193+ ):
194+ """Tentatively match prefix cache before rejecting a prefill candidate.
195+
196+ This helper is intentionally limited to pre-admission gates. It does not evict, allocate, acquire SSM restore
197+ state, or publish cache state. The caller either continues into the normal admission path with the returned
198+ match, or the helper rolls every match side effect back.
199+ """
200+ if not self .block_trie .enable :
201+ return None
202+
203+ stats_snapshot = self .block_trie .snapshot_stats ()
204+ self .block_trie .match (seq )
205+ if self ._prefix_hit_starts_middle_long_context_chunk (seq ):
206+ self ._rollback_prefix_match_for_prefill_gate (seq , stats_snapshot ,
207+ 'long-context chunk starts after prefix hit' )
208+ return None
209+
210+ prefix_match = _PrefixMatchForPrefillGate (
211+ stats_snapshot = stats_snapshot ,
212+ prefill_token_count = self ._prefill_admission_token_count (seq ),
213+ is_nonfinal_long_prefill = self ._prefill_kv_token_limit (seq ) is not None ,
214+ )
215+ if accept_match (prefix_match ):
216+ return prefix_match
217+
218+ self ._rollback_prefix_match_for_prefill_gate (seq , stats_snapshot , rollback_reason )
219+ return None
220+
221+ def _check_prefill_admission_gates (self , seq : SchedulerSequence , token_count : int , has_admitted : bool ,
222+ allow_long_prefill : bool ):
223+ """Check prefill policy gates before resource admission.
224+
225+ A prefix-cache hit can shrink a request enough to pass a short-turn or
226+ token-budget gate. When that happens, the returned prefix match is
227+ still tentative; if later resource admission rolls it back, the caller
228+ must reject this candidate with ``rollback_action`` for the current
229+ scheduler turn.
230+ """
231+ prefill_token_count = self ._prefill_admission_token_count (seq )
232+ is_nonfinal_long_prefill = self ._prefill_kv_token_limit (seq ) is not None
233+ prefix_match = None
234+ rollback_action = None
235+
236+ if is_nonfinal_long_prefill and not allow_long_prefill :
237+ prefix_match = self ._try_prefix_match_for_prefill_gate (
238+ seq ,
239+ accept_match = lambda match : not match .is_nonfinal_long_prefill ,
240+ rollback_reason = 'still non-final long prefill on short turn' )
241+ if prefix_match is None :
242+ return _PrefillGateCheck (reject_action = _PREFILL_GATE_SKIP )
243+ prefill_token_count = prefix_match .prefill_token_count
244+ rollback_action = _PREFILL_GATE_SKIP
245+
246+ exceeds_token_budget = (has_admitted
247+ and token_count + prefill_token_count > self .cache_config .max_prefill_token_num )
248+ if exceeds_token_budget :
249+ if prefix_match is None :
250+ prefix_match = self ._try_prefix_match_for_prefill_gate (
251+ seq ,
252+ accept_match = lambda match : token_count +
253+ match .prefill_token_count <= self .cache_config .max_prefill_token_num ,
254+ rollback_reason = 'still exceeds prefill token budget' )
255+ if prefix_match is not None :
256+ prefill_token_count = prefix_match .prefill_token_count
257+ rollback_action = _PREFILL_GATE_SKIP if not allow_long_prefill else _PREFILL_GATE_BREAK
258+
259+ still_exceeds_token_budget = token_count + prefill_token_count > self .cache_config .max_prefill_token_num
260+ if prefix_match is None or still_exceeds_token_budget :
261+ if prefix_match is not None :
262+ self ._rollback_prefix_match_for_prefill_gate (seq , prefix_match .stats_snapshot ,
263+ 'still exceeds prefill token budget' )
264+ reject_action = _PREFILL_GATE_SKIP if not allow_long_prefill else _PREFILL_GATE_BREAK
265+ return _PrefillGateCheck (reject_action = reject_action )
266+
267+ return _PrefillGateCheck (prefix_match = prefix_match ,
268+ rollback_action = rollback_action )
269+
157270 @staticmethod
158271 def _finalize_prefix_cache_match (seq : SchedulerSequence ):
159272 """Publish accepted cached-token count within the current prompt."""
@@ -485,23 +598,33 @@ def _reorder_waiting():
485598 skipped_waiting : SeqList = []
486599 while len (waiting ) > 0 and len (running ) < max_batches :
487600 seq = waiting .pop (0 )
488- prefill_token_count = self ._prefill_admission_token_count (seq )
489- is_nonfinal_long_prefill = self ._prefill_kv_token_limit (seq ) is not None
490-
491- if is_nonfinal_long_prefill and not allow_long_prefill :
492- skipped_waiting .append (seq )
493- continue
601+ gate_check = self ._check_prefill_admission_gates (seq ,
602+ token_count = token_count ,
603+ has_admitted = len (running ) > 0 ,
604+ allow_long_prefill = allow_long_prefill )
605+
606+ def __reject_after_prefill_gate_match_rollback ():
607+ """Reject if resource admission rolled back a gate-only hit."""
608+ if gate_check .prefix_match is None :
609+ return False
610+ if gate_check .rollback_action == _PREFILL_GATE_SKIP :
611+ skipped_waiting .append (seq )
612+ return True
613+ return False
494614
495- if ( len ( running ) > 0 and token_count + prefill_token_count > self . cache_config . max_prefill_token_num ) :
496- if not allow_long_prefill :
615+ if gate_check . reject_action is not None :
616+ if gate_check . reject_action == _PREFILL_GATE_SKIP :
497617 skipped_waiting .append (seq )
498618 continue
499619 break
500620
501621 evictable_waiting = skipped_waiting + waiting
502622
503623 if self .block_trie .enable :
504- stats_snapshot = self .block_trie .snapshot_stats ()
624+ if gate_check .prefix_match is None :
625+ stats_snapshot = self .block_trie .snapshot_stats ()
626+ else :
627+ stats_snapshot = gate_check .prefix_match .stats_snapshot
505628
506629 def __rollback_prefix_match (reason : str ):
507630 if logger .isEnabledFor (logging .DEBUG ):
@@ -510,30 +633,45 @@ def __rollback_prefix_match(reason: str):
510633 f'restore_state={ seq .prefix_cache .restore_state } ' )
511634 self ._rollback_unscheduled_prefix_match (seq , stats_snapshot )
512635
513- self .block_trie .match (seq )
514- if self ._prefix_hit_starts_middle_long_context_chunk (seq ):
515- __rollback_prefix_match ('long-context chunk starts after prefix hit' )
636+ if gate_check .prefix_match is None :
637+ self .block_trie .match (seq )
638+ if self ._prefix_hit_starts_middle_long_context_chunk (seq ):
639+ __rollback_prefix_match ('long-context chunk starts after prefix hit' )
516640
517641 had_ssm_restore = self .is_ssm and seq .prefix_cache .restore_state >= 0
518642 if not self ._acquire_ssm_restore_if_needed (seq ):
519643 __rollback_prefix_match ('failed to acquire SSM restore checkpoint' )
644+ if gate_check .prefix_match is not None :
645+ if __reject_after_prefill_gate_match_rollback ():
646+ continue
647+ break
520648
521649 evicted , alloc_prealloc_size = __prepare_and_evict (seq , evictable_waiting )
522650 if not evicted :
523651 if not had_ssm_restore :
524652 __rollback_prefix_match ('eviction failed' )
653+ if __reject_after_prefill_gate_match_rollback ():
654+ continue
525655 break
526656 # A matched SSM restore may be pinning the only checkpoint
527657 # state that eviction would otherwise free. Roll it back once
528658 # and retry eviction before declaring the sequence unschedulable.
529659 __rollback_prefix_match ('eviction failed with pinned SSM restore' )
660+ if __reject_after_prefill_gate_match_rollback ():
661+ continue
662+ if gate_check .prefix_match is not None :
663+ break
530664 evicted , alloc_prealloc_size = __prepare_and_evict (seq , evictable_waiting )
531665 if not evicted :
532666 break
533667
534668 # allocate session memory
535669 if self .is_ssm and not self ._ensure_runtime_state_available ():
536670 __rollback_prefix_match ('no runtime SSM state available' )
671+ if __reject_after_prefill_gate_match_rollback ():
672+ continue
673+ if gate_check .prefix_match is not None :
674+ break
537675 evicted , alloc_prealloc_size = __prepare_and_evict (seq , evictable_waiting )
538676 if not evicted :
539677 break
0 commit comments