|
2 | 2 | import asyncio |
3 | 3 | import logging |
4 | 4 | import time |
| 5 | +from collections.abc import Callable |
5 | 6 | from dataclasses import dataclass |
6 | 7 | from typing import TYPE_CHECKING, Any, Optional |
7 | 8 |
|
@@ -55,25 +56,32 @@ def clear(self): |
55 | 56 | class RunableEventAsync: |
56 | 57 | """Awaitable async runable event.""" |
57 | 58 |
|
58 | | - def __init__(self, scheduler: 'Scheduler'): |
| 59 | + def __init__(self, scheduler: 'Scheduler', extra_runable_checker: Callable[[], bool] | None = None): |
59 | 60 | self.scheduler = scheduler |
| 61 | + self.extra_runable_checker = extra_runable_checker |
60 | 62 | self.event = asyncio.Event() |
61 | 63 |
|
| 64 | + def has_unfinished(self): |
| 65 | + """Check whether scheduler or engine-local state has runnable work.""" |
| 66 | + if self.scheduler.has_unfinished(): |
| 67 | + return True |
| 68 | + return self.extra_runable_checker is not None and self.extra_runable_checker() |
| 69 | + |
62 | 70 | async def wait(self): |
63 | 71 | """Wait event.""" |
64 | 72 | await self.event.wait() |
65 | 73 |
|
66 | 74 | def set(self): |
67 | 75 | """Set event.""" |
68 | | - if self.scheduler.has_unfinished(): |
| 76 | + if self.has_unfinished(): |
69 | 77 | self.event.set() |
70 | 78 | else: |
71 | 79 | self.event.clear() |
72 | 80 |
|
73 | 81 |
|
74 | | -def build_runable_event(scheduler: 'Scheduler'): |
| 82 | +def build_runable_event(scheduler: 'Scheduler', extra_runable_checker: Callable[[], bool] | None = None): |
75 | 83 | """Build runable event.""" |
76 | | - return RunableEventAsync(scheduler) |
| 84 | + return RunableEventAsync(scheduler, extra_runable_checker) |
77 | 85 |
|
78 | 86 |
|
79 | 87 | @dataclass |
@@ -128,7 +136,9 @@ def __init__(self, |
128 | 136 | self.resp_queue = asyncio.Queue() |
129 | 137 | self.forward_event = CounterEvent() |
130 | 138 | self.migration_event = asyncio.Event() |
131 | | - self.has_runable_event = RunableEventAsync(self.scheduler) |
| 139 | + # Active long-context chunks are owned by InputsMaker, not the |
| 140 | + # scheduler WAITING/READY queues, so include them in the runnable gate. |
| 141 | + self.has_runable_event = RunableEventAsync(self.scheduler, self.inputs_maker.has_pending_long_context_chunk) |
132 | 142 | # Sleep uses a small handshake with the scheduling loops: |
133 | 143 | # 1. sleep() sets _sleep_requested and waits for main/migration drain events. |
134 | 144 | # 2. main_loop and migration_loop reach safe boundaries, acknowledge |
@@ -383,13 +393,12 @@ def __get_logprobs(batched_outputs: 'BatchedOutputs'): |
383 | 393 |
|
384 | 394 | async def _main_loop_try_send_next_inputs(self): |
385 | 395 | """Try send next inputs.""" |
386 | | - scheduler = self.scheduler |
387 | | - if not scheduler.has_unfinished(): |
| 396 | + if not self.has_runable_event.has_unfinished(): |
388 | 397 | await self.has_runable_event.wait() |
389 | 398 | if self._sleep_requested: |
390 | 399 | return None, None |
391 | 400 |
|
392 | | - scheduler.collect_migration_done() |
| 401 | + self.scheduler.collect_migration_done() |
393 | 402 | return await self.inputs_maker.send_next_inputs() |
394 | 403 |
|
395 | 404 | @staticmethod |
|
0 commit comments