@@ -518,6 +518,37 @@ def test_scheduler_publishes_cached_tokens_for_accepted_prefix_hit():
518518 assert seq .prefix_cache .match_start_step == - 1
519519
520520
521+ def test_scheduler_recomputes_prefill_budget_after_prefix_hit ():
522+ from lmdeploy .pytorch .strategies .ar .sequence import ARSequenceStrategy
523+ block_size = 16
524+ seq_meta = SequenceMeta (block_size , strategy = ARSequenceStrategy ())
525+ cache_config = CacheConfig (max_batches = 2 ,
526+ block_size = block_size ,
527+ num_cpu_blocks = 0 ,
528+ num_gpu_blocks = 8 ,
529+ max_prefill_token_num = block_size ,
530+ enable_prefix_caching = True )
531+ scheduler_config = SchedulerConfig (max_batches = 2 ,
532+ max_session_len = 128 ,
533+ max_request_output_len = 64 ,
534+ eviction_type = 'recompute' )
535+ scheduler = Scheduler (scheduler_config = scheduler_config , cache_config = cache_config , seq_meta = seq_meta )
536+
537+ cached = scheduler .add_session (0 ).add_sequence ([1 ] * block_size + [2 ])
538+ scheduler .schedule (is_prefill = True )
539+ cached .state .stop ()
540+
541+ cache_hit_tail = scheduler .add_session (1 ).add_sequence ([1 ] * block_size + [3 ])
542+ short = scheduler .add_session (2 ).add_sequence ([4 ])
543+
544+ output = scheduler .schedule (is_prefill = True )
545+
546+ assert output .running == [cache_hit_tail , short ]
547+ assert cache_hit_tail .num_history_ids == block_size
548+ assert cache_hit_tail .num_token_ids == 1
549+ assert short .status == MessageStatus .READY
550+
551+
521552def test_scheduler_reports_zero_cached_tokens_for_prefix_miss ():
522553 from lmdeploy .pytorch .strategies .ar .sequence import ARSequenceStrategy
523554 block_size = 16
0 commit comments