Skip to content

Commit 91ad7e2

Browse files
committed
fix
1 parent 4cf43bc commit 91ad7e2

2 files changed

Lines changed: 36 additions & 0 deletions

File tree

lmdeploy/pytorch/paging/scheduler.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -544,6 +544,11 @@ def __rollback_prefix_match(reason: str):
544544
evicted, alloc_prealloc_size = __prepare_and_evict(seq, evictable_waiting)
545545
if not evicted:
546546
break
547+
# Prefix-cache matching can advance the sequence step and shrink
548+
# the remaining prefill tail. Charge the admitted batch with the
549+
# post-match/post-rollback cost, not the conservative pre-match
550+
# estimate used to decide whether this sequence is worth trying.
551+
prefill_token_count = self._prefill_admission_token_count(seq)
547552
self.block_manager.allocate(seq, alloc_prealloc_size)
548553
if self.block_trie.enable:
549554
self.block_trie.allocate(seq)

tests/pytorch/paging/test_scheduler.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -518,6 +518,37 @@ def test_scheduler_publishes_cached_tokens_for_accepted_prefix_hit():
518518
assert seq.prefix_cache.match_start_step == -1
519519

520520

521+
def test_scheduler_recomputes_prefill_budget_after_prefix_hit():
522+
from lmdeploy.pytorch.strategies.ar.sequence import ARSequenceStrategy
523+
block_size = 16
524+
seq_meta = SequenceMeta(block_size, strategy=ARSequenceStrategy())
525+
cache_config = CacheConfig(max_batches=2,
526+
block_size=block_size,
527+
num_cpu_blocks=0,
528+
num_gpu_blocks=8,
529+
max_prefill_token_num=block_size,
530+
enable_prefix_caching=True)
531+
scheduler_config = SchedulerConfig(max_batches=2,
532+
max_session_len=128,
533+
max_request_output_len=64,
534+
eviction_type='recompute')
535+
scheduler = Scheduler(scheduler_config=scheduler_config, cache_config=cache_config, seq_meta=seq_meta)
536+
537+
cached = scheduler.add_session(0).add_sequence([1] * block_size + [2])
538+
scheduler.schedule(is_prefill=True)
539+
cached.state.stop()
540+
541+
cache_hit_tail = scheduler.add_session(1).add_sequence([1] * block_size + [3])
542+
short = scheduler.add_session(2).add_sequence([4])
543+
544+
output = scheduler.schedule(is_prefill=True)
545+
546+
assert output.running == [cache_hit_tail, short]
547+
assert cache_hit_tail.num_history_ids == block_size
548+
assert cache_hit_tail.num_token_ids == 1
549+
assert short.status == MessageStatus.READY
550+
551+
521552
def test_scheduler_reports_zero_cached_tokens_for_prefix_miss():
522553
from lmdeploy.pytorch.strategies.ar.sequence import ARSequenceStrategy
523554
block_size = 16

0 commit comments

Comments
 (0)