Skip to content

Commit 4cf43bc

Browse files
committed
improve readability
1 parent f5392fe commit 4cf43bc

2 files changed

Lines changed: 100 additions & 46 deletions

File tree

lmdeploy/pytorch/engine/inputs_maker.py

Lines changed: 56 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -905,6 +905,53 @@ def __create_short_or_normal_prefill_turn():
905905
self._short_prefill_turns_since_long_chunk += 1
906906
return result
907907

908+
def __is_empty_forward(forward_inputs: 'ModelInputs|None', forward_delta: 'ModelInputsDelta|None'):
909+
return forward_inputs is None and forward_delta is None
910+
911+
def __try_active_long_context_chunk():
912+
nonlocal attempted_long_work
913+
nonlocal active_long_chunk_blocked_by_kv
914+
attempted_long_work = True
915+
result = __create_inputs_long_context_chunk()
916+
_, chunk_inputs, chunk_delta, _ = result
917+
active_long_chunk_blocked_by_kv = __is_empty_forward(chunk_inputs, chunk_delta)
918+
return result
919+
920+
def __should_try_short_prefill_before_active_chunk():
921+
"""Allow short/normal prefill quota before an active non-final
922+
chunk."""
923+
if self.long_context_chunker.is_last_chunk():
924+
return False
925+
if not scheduler.has_waiting():
926+
return False
927+
return not self._is_long_context_chunk_turn_due()
928+
929+
def __has_no_forward():
930+
return __is_empty_forward(inputs, delta)
931+
932+
def __can_fallback_to_short_after_long_work():
933+
if not __has_no_forward():
934+
return False
935+
if not attempted_long_work:
936+
return False
937+
if active_long_chunk_blocked_by_kv:
938+
return False
939+
if attempted_short_or_normal_prefill:
940+
return False
941+
return scheduler.has_waiting()
942+
943+
def __can_try_short_prefill_after_defer():
944+
if not __has_no_forward():
945+
return False
946+
if not deferred_long_context_chunk:
947+
return False
948+
if self._is_long_context_chunk_turn_due():
949+
return False
950+
return scheduler.has_waiting()
951+
952+
def __can_retry_deferred_active_chunk():
953+
return __has_no_forward() and deferred_long_context_chunk and self.long_context_chunker.enabled()
954+
908955
scheduler = self.scheduler
909956
logger.debug(f'Make forward inputs with prefill={prefill}, enable_empty={enable_empty}')
910957

@@ -926,11 +973,9 @@ def __create_short_or_normal_prefill_turn():
926973
# long prefill through the scheduler.
927974
self.long_context_chunker.check_enable()
928975
if self.long_context_chunker.enabled():
929-
# long context chunking
930976
if self._should_defer_long_context_chunk(prefill):
931977
deferred_long_context_chunk = True
932-
elif (not self.long_context_chunker.is_last_chunk() and scheduler.has_waiting()
933-
and not self._is_long_context_chunk_turn_due()):
978+
elif __should_try_short_prefill_before_active_chunk():
934979
# After a decode turn, keep the short/normal prefill quota in
935980
# front of active long chunks; otherwise decode -> long can
936981
# repeat and small waiting requests remain gated by the active
@@ -943,14 +988,10 @@ def __create_short_or_normal_prefill_turn():
943988
swap_in_map,
944989
swap_out_map,
945990
) = __create_short_or_normal_prefill_turn()
946-
if inputs is None and delta is None:
947-
attempted_long_work = True
948-
running, inputs, delta, extra_inputs = __create_inputs_long_context_chunk()
949-
active_long_chunk_blocked_by_kv = inputs is None and delta is None
991+
if __is_empty_forward(inputs, delta):
992+
running, inputs, delta, extra_inputs = __try_active_long_context_chunk()
950993
else:
951-
attempted_long_work = True
952-
running, inputs, delta, extra_inputs = __create_inputs_long_context_chunk()
953-
active_long_chunk_blocked_by_kv = inputs is None and delta is None
994+
running, inputs, delta, extra_inputs = __try_active_long_context_chunk()
954995
elif prefill:
955996
# prefill
956997
has_waiting_long_prefill = scheduler.has_waiting_long_prefill()
@@ -963,7 +1004,7 @@ def __create_short_or_normal_prefill_turn():
9631004
swap_in_map,
9641005
swap_out_map,
9651006
) = __create_short_or_normal_prefill_turn()
966-
if inputs is None and delta is None:
1007+
if __has_no_forward():
9671008
(
9681009
running,
9691010
inputs,
@@ -986,8 +1027,7 @@ def __create_short_or_normal_prefill_turn():
9861027
# Waiting-long admission failure can still fall back to short prefills.
9871028
# Active-long reservation failure means KV is pinned by running work;
9881029
# admit decode only so existing requests can drain blocks.
989-
if (inputs is None and delta is None and attempted_long_work and not active_long_chunk_blocked_by_kv
990-
and scheduler.has_waiting() and not attempted_short_or_normal_prefill):
1030+
if __can_fallback_to_short_after_long_work():
9911031
(
9921032
running,
9931033
inputs,
@@ -1004,8 +1044,7 @@ def __create_short_or_normal_prefill_turn():
10041044
self.to_evict_seqs = invalid_seqs
10051045
extra_inputs = None
10061046

1007-
if (inputs is None and delta is None and deferred_long_context_chunk and scheduler.has_waiting()
1008-
and not self._is_long_context_chunk_turn_due()):
1047+
if __can_try_short_prefill_after_defer():
10091048
(
10101049
running,
10111050
inputs,
@@ -1015,8 +1054,8 @@ def __create_short_or_normal_prefill_turn():
10151054
swap_out_map,
10161055
) = __create_short_or_normal_prefill_turn()
10171056

1018-
if inputs is None and delta is None and deferred_long_context_chunk and self.long_context_chunker.enabled():
1019-
running, inputs, delta, extra_inputs = __create_inputs_long_context_chunk()
1057+
if __can_retry_deferred_active_chunk():
1058+
running, inputs, delta, extra_inputs = __try_active_long_context_chunk()
10201059

10211060
# reset decode count when non-decoding inputs are produced
10221061
if inputs is not None and not inputs.is_decoding:

lmdeploy/pytorch/paging/scheduler.py

Lines changed: 44 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -250,8 +250,8 @@ def _long_prefill_priority_key(self, seq: SchedulerSequence, now: float):
250250
estimated_chunks = self._long_prefill_estimated_chunks(seq)
251251
wait_age = max(0.0, now - seq.arrive_time)
252252
age_credit = int(wait_age // self._long_prefill_aging_seconds_per_chunk)
253-
virtual_chunks = estimated_chunks - age_credit
254-
return virtual_chunks, estimated_chunks, seq.arrive_time
253+
age_adjusted_chunks = estimated_chunks - age_credit
254+
return age_adjusted_chunks, estimated_chunks, seq.arrive_time
255255

256256
def _prepare_prefill_allocation(self, seq: SchedulerSequence, prealloc_size: int):
257257
"""Apply chunk KV limit and return the effective prealloc size."""
@@ -423,44 +423,59 @@ def __prepare_and_evict(seq: SchedulerSequence, waiting):
423423
seq.kv_token_limit = None
424424
return False, alloc_prealloc_size
425425

426+
def _split_waiting_by_prefill_kind(waiting: SeqList):
427+
"""Split waiting requests into normal/final and non-final long
428+
prefill."""
429+
normal_waiting: SeqList = []
430+
long_waiting: SeqList = []
431+
for seq in waiting:
432+
if self._prefill_kv_token_limit(seq) is None:
433+
normal_waiting.append(seq)
434+
else:
435+
long_waiting.append(seq)
436+
return normal_waiting, long_waiting
437+
438+
def _sort_normal_prefills(waiting: SeqList):
439+
return sorted(waiting, key=lambda seq: (self._prefill_admission_token_count(seq), seq.arrive_time))
440+
441+
def _sort_long_prefills_for_long_turn(waiting: SeqList):
442+
if self._long_prefill_policy != 'size':
443+
return waiting
444+
now = time.perf_counter()
445+
return sorted(waiting, key=lambda seq: self._long_prefill_priority_key(seq, now))
446+
447+
def _reorder_waiting_for_long_turn(waiting: SeqList):
448+
"""Choose one long waiter, then fill the turn with normal
449+
prefills."""
450+
normal_waiting, long_waiting = _split_waiting_by_prefill_kind(waiting)
451+
if len(long_waiting) == 0:
452+
return None
453+
454+
long_waiting = _sort_long_prefills_for_long_turn(long_waiting)
455+
normal_waiting = _sort_normal_prefills(normal_waiting)
456+
return [long_waiting[0]] + normal_waiting + long_waiting[1:]
457+
458+
def _reorder_waiting_for_short_turn(waiting: SeqList):
459+
"""Prioritize normal/final prefills while preserving long
460+
waiters."""
461+
normal_waiting, long_waiting = _split_waiting_by_prefill_kind(waiting)
462+
return _sort_normal_prefills(normal_waiting) + long_waiting
463+
426464
def _reorder_waiting():
427465
"""Reorder waiting."""
428466
waiting = sorted(self.waiting, key=lambda seq: seq.arrive_time)
429467
if prefer_long_prefill:
430468
# Long-work turns choose one long waiter first. The size policy
431469
# only reorders this long lane; it is not global
432470
# shortest-prefill-first admission.
433-
long_waiting: SeqList = []
434-
normal_waiting: SeqList = []
435-
for seq in waiting:
436-
if self._prefill_kv_token_limit(seq) is None:
437-
normal_waiting.append(seq)
438-
else:
439-
long_waiting.append(seq)
440-
if len(long_waiting) > 0:
441-
if self._long_prefill_policy == 'size':
442-
now = time.perf_counter()
443-
long_waiting = sorted(long_waiting,
444-
key=lambda seq: self._long_prefill_priority_key(seq, now))
445-
normal_waiting = sorted(normal_waiting,
446-
key=lambda seq: (self._prefill_admission_token_count(seq),
447-
seq.arrive_time))
448-
return [long_waiting[0]] + normal_waiting + long_waiting[1:]
471+
long_turn_waiting = _reorder_waiting_for_long_turn(waiting)
472+
if long_turn_waiting is not None:
473+
return long_turn_waiting
449474

450475
if allow_long_prefill:
451476
return waiting
452477

453-
normal_waiting: SeqList = []
454-
long_waiting: SeqList = []
455-
for seq in waiting:
456-
if self._prefill_kv_token_limit(seq) is None:
457-
normal_waiting.append(seq)
458-
else:
459-
long_waiting.append(seq)
460-
461-
normal_waiting = sorted(normal_waiting, key=lambda seq: (self._prefill_admission_token_count(seq),
462-
seq.arrive_time))
463-
return normal_waiting + long_waiting
478+
return _reorder_waiting_for_short_turn(waiting)
464479

465480
num_waiting = self.seq_manager.num_sequences(MessageStatus.WAITING)
466481
if (len(running) >= max_batches or num_waiting == 0):

0 commit comments

Comments
 (0)