Skip to content

Commit 37dd886

Browse files
committed
fix unfinished
1 parent 91ad7e2 commit 37dd886

3 files changed

Lines changed: 55 additions & 9 deletions

File tree

lmdeploy/pytorch/engine/engine_loop.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import asyncio
33
import logging
44
import time
5+
from collections.abc import Callable
56
from dataclasses import dataclass
67
from typing import TYPE_CHECKING, Any, Optional
78

@@ -55,25 +56,32 @@ def clear(self):
5556
class RunableEventAsync:
5657
"""Awaitable async runable event."""
5758

58-
def __init__(self, scheduler: 'Scheduler'):
59+
def __init__(self, scheduler: 'Scheduler', extra_runable_checker: Callable[[], bool] | None = None):
5960
self.scheduler = scheduler
61+
self.extra_runable_checker = extra_runable_checker
6062
self.event = asyncio.Event()
6163

64+
def has_unfinished(self):
65+
"""Check whether scheduler or engine-local state has runnable work."""
66+
if self.scheduler.has_unfinished():
67+
return True
68+
return self.extra_runable_checker is not None and self.extra_runable_checker()
69+
6270
async def wait(self):
6371
"""Wait event."""
6472
await self.event.wait()
6573

6674
def set(self):
6775
"""Set event."""
68-
if self.scheduler.has_unfinished():
76+
if self.has_unfinished():
6977
self.event.set()
7078
else:
7179
self.event.clear()
7280

7381

74-
def build_runable_event(scheduler: 'Scheduler'):
82+
def build_runable_event(scheduler: 'Scheduler', extra_runable_checker: Callable[[], bool] | None = None):
7583
"""Build runable event."""
76-
return RunableEventAsync(scheduler)
84+
return RunableEventAsync(scheduler, extra_runable_checker)
7785

7886

7987
@dataclass
@@ -128,7 +136,9 @@ def __init__(self,
128136
self.resp_queue = asyncio.Queue()
129137
self.forward_event = CounterEvent()
130138
self.migration_event = asyncio.Event()
131-
self.has_runable_event = RunableEventAsync(self.scheduler)
139+
# Active long-context chunks are owned by InputsMaker, not the
140+
# scheduler WAITING/READY queues, so include them in the runnable gate.
141+
self.has_runable_event = RunableEventAsync(self.scheduler, self.inputs_maker.has_pending_long_context_chunk)
132142
# Sleep uses a small handshake with the scheduling loops:
133143
# 1. sleep() sets _sleep_requested and waits for main/migration drain events.
134144
# 2. main_loop and migration_loop reach safe boundaries, acknowledge
@@ -383,13 +393,12 @@ def __get_logprobs(batched_outputs: 'BatchedOutputs'):
383393

384394
async def _main_loop_try_send_next_inputs(self):
385395
"""Try send next inputs."""
386-
scheduler = self.scheduler
387-
if not scheduler.has_unfinished():
396+
if not self.has_runable_event.has_unfinished():
388397
await self.has_runable_event.wait()
389398
if self._sleep_requested:
390399
return None, None
391400

392-
scheduler.collect_migration_done()
401+
self.scheduler.collect_migration_done()
393402
return await self.inputs_maker.send_next_inputs()
394403

395404
@staticmethod

lmdeploy/pytorch/engine/inputs_maker.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,11 @@ def _has_pending_last_long_context_chunk(self):
346346
left."""
347347
return self.long_context_chunker.enabled() and self.long_context_chunker.is_last_chunk()
348348

349+
def has_pending_long_context_chunk(self):
350+
"""Check whether engine-local long-context chunk work can run."""
351+
self.long_context_chunker.check_enable()
352+
return self.long_context_chunker.enabled()
353+
349354
def _should_defer_long_context_chunk(self, prefill: bool):
350355
"""Check whether the active long-context chunk should yield this
351356
loop."""

tests/pytorch/engine/test_inputs_maker.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import lmdeploy.pytorch.engine.inputs_maker as inputs_maker_module
99
from lmdeploy.pytorch.disagg.config import EngineRole
10-
from lmdeploy.pytorch.engine.engine_loop import EngineLoop
10+
from lmdeploy.pytorch.engine.engine_loop import EngineLoop, RunableEventAsync
1111
from lmdeploy.pytorch.engine.inputs_maker import (
1212
InputsMakerAsync,
1313
InputsMakerConfig,
@@ -206,6 +206,38 @@ async def get_output_async(self):
206206
assert not block_trie.pinned
207207

208208

209+
def test_engine_loop_treats_pending_long_context_chunk_as_runnable():
210+
events = []
211+
212+
class _Scheduler:
213+
214+
def has_unfinished(self):
215+
return False
216+
217+
def collect_migration_done(self):
218+
events.append('collect_migration_done')
219+
220+
class _InputsMaker:
221+
222+
def has_pending_long_context_chunk(self):
223+
return True
224+
225+
async def send_next_inputs(self):
226+
events.append('send_next_inputs')
227+
return 'forward_inputs', ['long-seq']
228+
229+
loop = EngineLoop.__new__(EngineLoop)
230+
loop.scheduler = _Scheduler()
231+
loop.inputs_maker = _InputsMaker()
232+
loop.has_runable_event = RunableEventAsync(loop.scheduler, loop.inputs_maker.has_pending_long_context_chunk)
233+
loop._sleep_requested = False
234+
235+
result = asyncio.run(asyncio.wait_for(loop._main_loop_try_send_next_inputs(), timeout=1.0))
236+
237+
assert result == ('forward_inputs', ['long-seq'])
238+
assert events == ['collect_migration_done', 'send_next_inputs']
239+
240+
209241
def _make_policy_maker(long_seq, decode_seq=None):
210242
maker = InputsMakerAsync.__new__(InputsMakerAsync)
211243
maker.config = SimpleNamespace(role=EngineRole.Decode)

0 commit comments

Comments
 (0)