Skip to content

Commit 4e18b3e

Browse files
committed
better readability
1 parent 37dd886 commit 4e18b3e

2 files changed

Lines changed: 338 additions & 14 deletions

File tree

lmdeploy/pytorch/paging/scheduler.py

Lines changed: 150 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
import logging
4141
import time
4242
from collections import OrderedDict
43+
from collections.abc import Callable
4344
from contextlib import contextmanager
4445
from dataclasses import dataclass
4546

@@ -72,6 +73,28 @@ class SchedulerOutput:
7273
copy_map: MapType
7374

7475

76+
_PREFILL_GATE_SKIP = 'skip'
77+
_PREFILL_GATE_BREAK = 'break'
78+
79+
80+
@dataclass
81+
class _PrefixMatchForPrefillGate:
82+
"""Tentative prefix match kept only because it passes a prefill gate."""
83+
84+
stats_snapshot: object
85+
prefill_token_count: int
86+
is_nonfinal_long_prefill: bool
87+
88+
89+
@dataclass
class _PrefillGateCheck:
    """Result of prefill-gate checks before final resource admission."""

    # Tentative prefix match that let the candidate pass a gate, if any.
    prefix_match: _PrefixMatchForPrefillGate | None = None
    # Action the caller applies if later resource admission rolls that
    # tentative match back.
    rollback_action: str | None = None
    # Set when a gate rejects the candidate outright; None means it passed.
    reject_action: str | None = None
96+
97+
7598
class Scheduler:
7699
"""Tools to schedule next step.
77100
@@ -154,6 +177,96 @@ def _rollback_unscheduled_prefix_match(self, seq: SchedulerSequence, stats_snaps
154177
prefix_cache.match_start_step = -1
155178
seq.cached_tokens = 0
156179

180+
def _rollback_prefix_match_for_prefill_gate(self, seq: SchedulerSequence, stats_snapshot, reason: str) -> None:
    """Rollback a prefix match tried only to re-check prefill gates.

    Args:
        seq: Sequence whose tentative prefix match is undone.
        stats_snapshot: Block-trie stats snapshot taken before the match.
        reason: Short explanation included in the debug log line.
    """
    # Only build the log message when debug logging is actually enabled.
    if logger.isEnabledFor(logging.DEBUG):
        logger.debug(f'Rollback tentative prefix-cache gate match: session_id={seq.session_id} '
                     f'seq_id={seq.seq_id} reason={reason} num_history_ids={seq.num_history_ids} '
                     f'restore_state={seq.prefix_cache.restore_state}')
    self._rollback_unscheduled_prefix_match(seq, stats_snapshot)
187+
188+
def _try_prefix_match_for_prefill_gate(
    self,
    seq: SchedulerSequence,
    accept_match: Callable[[_PrefixMatchForPrefillGate], bool],
    rollback_reason: str,
) -> _PrefixMatchForPrefillGate | None:
    """Tentatively match prefix cache before rejecting a prefill candidate.

    This helper is intentionally limited to pre-admission gates. It does not evict, allocate, acquire SSM restore
    state, or publish cache state. The caller either continues into the normal admission path with the returned
    match, or the helper rolls every match side effect back.

    Args:
        seq: Candidate sequence to match against the block trie.
        accept_match: Predicate deciding whether the post-match numbers pass the gate being re-checked.
        rollback_reason: Reason logged when ``accept_match`` rejects the match.

    Returns:
        The accepted tentative match, or ``None`` when the block trie is disabled or the match was rolled back.
    """
    if not self.block_trie.enable:
        return None

    # Snapshot first so every rollback path below restores the pre-match stats.
    stats_snapshot = self.block_trie.snapshot_stats()
    self.block_trie.match(seq)
    if self._prefix_hit_starts_middle_long_context_chunk(seq):
        self._rollback_prefix_match_for_prefill_gate(seq, stats_snapshot,
                                                     'long-context chunk starts after prefix hit')
        return None

    # Re-measure the candidate after the match so the gate sees the reduced prefill.
    prefix_match = _PrefixMatchForPrefillGate(
        stats_snapshot=stats_snapshot,
        prefill_token_count=self._prefill_admission_token_count(seq),
        is_nonfinal_long_prefill=self._prefill_kv_token_limit(seq) is not None,
    )
    if accept_match(prefix_match):
        return prefix_match

    self._rollback_prefix_match_for_prefill_gate(seq, stats_snapshot, rollback_reason)
    return None
220+
221+
def _check_prefill_admission_gates(self, seq: SchedulerSequence, token_count: int, has_admitted: bool,
                                   allow_long_prefill: bool) -> _PrefillGateCheck:
    """Check prefill policy gates before resource admission.

    A prefix-cache hit can shrink a request enough to pass a short-turn or
    token-budget gate. When that happens, the returned prefix match is
    still tentative; if later resource admission rolls it back, the caller
    must reject this candidate with ``rollback_action`` for the current
    scheduler turn.

    Args:
        seq: Candidate sequence taken from the waiting queue.
        token_count: Prefill tokens already counted for this turn.
        has_admitted: Whether a sequence has already been admitted; the
            token-budget gate is only applied when this is true.
        allow_long_prefill: Whether non-final long prefills may be admitted
            on this turn.

    Returns:
        A ``_PrefillGateCheck`` with ``reject_action`` set when a gate
        rejects the candidate, otherwise one carrying the tentative
        ``prefix_match`` (possibly ``None``) and its ``rollback_action``.
    """
    prefill_token_count = self._prefill_admission_token_count(seq)
    is_nonfinal_long_prefill = self._prefill_kv_token_limit(seq) is not None
    prefix_match = None
    rollback_action = None

    # Gate 1: a non-final long prefill on a short turn is only admitted when a
    # prefix hit makes it no longer a non-final long prefill.
    if is_nonfinal_long_prefill and not allow_long_prefill:
        prefix_match = self._try_prefix_match_for_prefill_gate(
            seq,
            accept_match=lambda match: not match.is_nonfinal_long_prefill,
            rollback_reason='still non-final long prefill on short turn')
        if prefix_match is None:
            return _PrefillGateCheck(reject_action=_PREFILL_GATE_SKIP)
        prefill_token_count = prefix_match.prefill_token_count
        rollback_action = _PREFILL_GATE_SKIP

    # Gate 2: token budget, enforced only once something has been admitted.
    exceeds_token_budget = (has_admitted
                            and token_count + prefill_token_count > self.cache_config.max_prefill_token_num)
    if exceeds_token_budget:
        # Short turns skip the candidate; long-prefill turns stop admitting.
        budget_action = _PREFILL_GATE_SKIP if not allow_long_prefill else _PREFILL_GATE_BREAK

        if prefix_match is None:
            # No tentative match yet: see whether a prefix hit fits the budget.
            prefix_match = self._try_prefix_match_for_prefill_gate(
                seq,
                accept_match=lambda match: token_count + match.prefill_token_count <= self.cache_config.
                max_prefill_token_num,
                rollback_reason='still exceeds prefill token budget')
            if prefix_match is not None:
                prefill_token_count = prefix_match.prefill_token_count
                rollback_action = budget_action

        still_exceeds_token_budget = token_count + prefill_token_count > self.cache_config.max_prefill_token_num
        if prefix_match is None or still_exceeds_token_budget:
            if prefix_match is not None:
                # The match kept from gate 1 did not shrink the request enough.
                self._rollback_prefix_match_for_prefill_gate(seq, prefix_match.stats_snapshot,
                                                             'still exceeds prefill token budget')
            return _PrefillGateCheck(reject_action=budget_action)

    return _PrefillGateCheck(prefix_match=prefix_match, rollback_action=rollback_action)
269+
157270
@staticmethod
158271
def _finalize_prefix_cache_match(seq: SchedulerSequence):
159272
"""Publish accepted cached-token count within the current prompt."""
@@ -485,23 +598,33 @@ def _reorder_waiting():
485598
skipped_waiting: SeqList = []
486599
while len(waiting) > 0 and len(running) < max_batches:
487600
seq = waiting.pop(0)
488-
prefill_token_count = self._prefill_admission_token_count(seq)
489-
is_nonfinal_long_prefill = self._prefill_kv_token_limit(seq) is not None
490-
491-
if is_nonfinal_long_prefill and not allow_long_prefill:
492-
skipped_waiting.append(seq)
493-
continue
601+
gate_check = self._check_prefill_admission_gates(seq,
602+
token_count=token_count,
603+
has_admitted=len(running) > 0,
604+
allow_long_prefill=allow_long_prefill)
605+
606+
def __reject_after_prefill_gate_match_rollback():
607+
"""Reject if resource admission rolled back a gate-only hit."""
608+
if gate_check.prefix_match is None:
609+
return False
610+
if gate_check.rollback_action == _PREFILL_GATE_SKIP:
611+
skipped_waiting.append(seq)
612+
return True
613+
return False
494614

495-
if (len(running) > 0 and token_count + prefill_token_count > self.cache_config.max_prefill_token_num):
496-
if not allow_long_prefill:
615+
if gate_check.reject_action is not None:
616+
if gate_check.reject_action == _PREFILL_GATE_SKIP:
497617
skipped_waiting.append(seq)
498618
continue
499619
break
500620

501621
evictable_waiting = skipped_waiting + waiting
502622

503623
if self.block_trie.enable:
504-
stats_snapshot = self.block_trie.snapshot_stats()
624+
if gate_check.prefix_match is None:
625+
stats_snapshot = self.block_trie.snapshot_stats()
626+
else:
627+
stats_snapshot = gate_check.prefix_match.stats_snapshot
505628

506629
def __rollback_prefix_match(reason: str):
507630
if logger.isEnabledFor(logging.DEBUG):
@@ -510,30 +633,45 @@ def __rollback_prefix_match(reason: str):
510633
f'restore_state={seq.prefix_cache.restore_state}')
511634
self._rollback_unscheduled_prefix_match(seq, stats_snapshot)
512635

513-
self.block_trie.match(seq)
514-
if self._prefix_hit_starts_middle_long_context_chunk(seq):
515-
__rollback_prefix_match('long-context chunk starts after prefix hit')
636+
if gate_check.prefix_match is None:
637+
self.block_trie.match(seq)
638+
if self._prefix_hit_starts_middle_long_context_chunk(seq):
639+
__rollback_prefix_match('long-context chunk starts after prefix hit')
516640

517641
had_ssm_restore = self.is_ssm and seq.prefix_cache.restore_state >= 0
518642
if not self._acquire_ssm_restore_if_needed(seq):
519643
__rollback_prefix_match('failed to acquire SSM restore checkpoint')
644+
if gate_check.prefix_match is not None:
645+
if __reject_after_prefill_gate_match_rollback():
646+
continue
647+
break
520648

521649
evicted, alloc_prealloc_size = __prepare_and_evict(seq, evictable_waiting)
522650
if not evicted:
523651
if not had_ssm_restore:
524652
__rollback_prefix_match('eviction failed')
653+
if __reject_after_prefill_gate_match_rollback():
654+
continue
525655
break
526656
# A matched SSM restore may be pinning the only checkpoint
527657
# state that eviction would otherwise free. Roll it back once
528658
# and retry eviction before declaring the sequence unschedulable.
529659
__rollback_prefix_match('eviction failed with pinned SSM restore')
660+
if __reject_after_prefill_gate_match_rollback():
661+
continue
662+
if gate_check.prefix_match is not None:
663+
break
530664
evicted, alloc_prealloc_size = __prepare_and_evict(seq, evictable_waiting)
531665
if not evicted:
532666
break
533667

534668
# allocate session memory
535669
if self.is_ssm and not self._ensure_runtime_state_available():
536670
__rollback_prefix_match('no runtime SSM state available')
671+
if __reject_after_prefill_gate_match_rollback():
672+
continue
673+
if gate_check.prefix_match is not None:
674+
break
537675
evicted, alloc_prealloc_size = __prepare_and_evict(seq, evictable_waiting)
538676
if not evicted:
539677
break

0 commit comments

Comments
 (0)