|
23 | 23 | import logging |
24 | 24 | from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, TypedDict |
25 | 25 |
|
| 26 | +from .events import EventHub |
26 | 27 | from .llm.base import BaseLLMBackend, RerankCandidate |
27 | 28 | from .settings import GraphSearchConfig |
28 | 29 |
|
@@ -232,34 +233,65 @@ def postprocess_results_node(state: SearchState) -> SearchState: |
232 | 233 | # --------------------------------------------------------------------------- |
233 | 234 |
|
234 | 235 |
|
235 | | -def build_search_graph(config: GraphSearchConfig, *, embedding_backend, vector_store, llm: BaseLLMBackend): |
| 236 | +def build_search_graph( |
| 237 | + config: GraphSearchConfig, |
| 238 | + *, |
| 239 | + embedding_backend, |
| 240 | + vector_store, |
| 241 | + llm: BaseLLMBackend, |
| 242 | + event_hub: Optional[EventHub] = None, |
| 243 | +): |
236 | 244 | """Build and compile the search graph. |
237 | 245 |
|
238 | 246 | When the ``langgraph`` package is available we return a compiled |
239 | 247 | LangGraph ``StateGraph``. Otherwise we return :class:`_FallbackGraph` so |
240 | 248 | the rest of the code stays identical. |
| 249 | +
|
| 250 | + Pass ``event_hub`` to receive lifecycle events (``query_received``, |
| 251 | + ``query_expanded``, ``vector_search_completed``, ``rerank_completed``, |
| 252 | + ``completed``) — the same hub powers the streaming HTTP endpoint. |
241 | 253 | """ |
242 | 254 | try: |
243 | 255 | from langgraph.graph import END, StateGraph # type: ignore |
244 | 256 | except Exception: # pragma: no cover - exercised when langgraph absent. |
245 | | - return _FallbackGraph(config=config, embedding_backend=embedding_backend, |
246 | | - vector_store=vector_store, llm=llm) |
| 257 | + return _FallbackGraph( |
| 258 | + config=config, |
| 259 | + embedding_backend=embedding_backend, |
| 260 | + vector_store=vector_store, |
| 261 | + llm=llm, |
| 262 | + event_hub=event_hub, |
| 263 | + ) |
| 264 | + |
| 265 | + def _wrap(name: str, fn): |
| 266 | + if event_hub is None: |
| 267 | + return fn |
| 268 | + |
| 269 | + def _wrapped(s): |
| 270 | + event_hub.publish({"type": f"{name}_started", "query": s.get("normalized_query") or s.get("query")}) |
| 271 | + out = fn(s) |
| 272 | + event_hub.publish({ |
| 273 | + "type": f"{name}_completed", |
| 274 | + "candidate_count": len(out.get("merged_results") or out.get("raw_results") or []), |
| 275 | + }) |
| 276 | + return out |
| 277 | + |
| 278 | + return _wrapped |
247 | 279 |
|
248 | 280 | graph: Any = StateGraph(dict) |
249 | | - graph.add_node("analyze_query", lambda s: analyze_query_node(s, config=config)) |
| 281 | + graph.add_node("analyze_query", _wrap("analyze_query", lambda s: analyze_query_node(s, config=config))) |
250 | 282 | graph.add_node( |
251 | 283 | "expand_query", |
252 | | - lambda s: expand_query_node(s, config=config, llm=llm), |
| 284 | + _wrap("expand_query", lambda s: expand_query_node(s, config=config, llm=llm)), |
253 | 285 | ) |
254 | 286 | graph.add_node( |
255 | 287 | "vector_search", |
256 | | - lambda s: vector_search_node(s, embedding_backend=embedding_backend, vector_store=vector_store), |
| 288 | + _wrap("vector_search", lambda s: vector_search_node(s, embedding_backend=embedding_backend, vector_store=vector_store)), |
257 | 289 | ) |
258 | 290 | graph.add_node( |
259 | 291 | "rerank_results", |
260 | | - lambda s: rerank_results_node(s, config=config, llm=llm), |
| 292 | + _wrap("rerank_results", lambda s: rerank_results_node(s, config=config, llm=llm)), |
261 | 293 | ) |
262 | | - graph.add_node("postprocess_results", lambda s: postprocess_results_node(s)) |
| 294 | + graph.add_node("postprocess_results", _wrap("postprocess_results", lambda s: postprocess_results_node(s))) |
263 | 295 |
|
264 | 296 | graph.set_entry_point("analyze_query") |
265 | 297 | graph.add_conditional_edges( |
@@ -295,24 +327,47 @@ def __init__( |
295 | 327 | embedding_backend, |
296 | 328 | vector_store, |
297 | 329 | llm: BaseLLMBackend, |
| 330 | + event_hub: Optional[EventHub] = None, |
298 | 331 | ) -> None: |
299 | 332 | self.config = config |
300 | 333 | self.embedding_backend = embedding_backend |
301 | 334 | self.vector_store = vector_store |
302 | 335 | self.llm = llm |
| 336 | + self.event_hub = event_hub |
| 337 | + |
| 338 | + def _emit(self, event: Dict[str, Any]) -> None: |
| 339 | + if self.event_hub is not None: |
| 340 | + self.event_hub.publish(event) |
303 | 341 |
|
304 | 342 | def invoke(self, state: SearchState) -> SearchState: |
| 343 | + self._emit({"type": "query_received", "query": state.get("query") or ""}) |
305 | 344 | state = analyze_query_node(state, config=self.config) |
306 | 345 | if self.config.langgraph.query_expansion: |
307 | 346 | state = expand_query_node(state, config=self.config, llm=self.llm) |
| 347 | + self._emit({ |
| 348 | + "type": "query_expanded", |
| 349 | + "queries": list(state.get("expanded_queries") or []), |
| 350 | + }) |
308 | 351 | state = vector_search_node( |
309 | 352 | state, |
310 | 353 | embedding_backend=self.embedding_backend, |
311 | 354 | vector_store=self.vector_store, |
312 | 355 | ) |
| 356 | + self._emit({ |
| 357 | + "type": "vector_search_completed", |
| 358 | + "candidate_count": len(state.get("merged_results") or []), |
| 359 | + }) |
313 | 360 | if self.config.langgraph.reranking: |
314 | 361 | state = rerank_results_node(state, config=self.config, llm=self.llm) |
| 362 | + self._emit({ |
| 363 | + "type": "rerank_completed", |
| 364 | + "candidate_count": len(state.get("reranked_results") or []), |
| 365 | + }) |
315 | 366 | state = postprocess_results_node(state) |
| 367 | + self._emit({ |
| 368 | + "type": "completed", |
| 369 | + "total": len(state.get("final_results") or []), |
| 370 | + }) |
316 | 371 | return state |
317 | 372 |
|
318 | 373 |
|
|
0 commit comments