haiku_rag_slim-0.16.0-py3-none-any.whl
This diff shows the content of a publicly available package version as released to its public registry, and is provided for informational purposes only.
- haiku/rag/__init__.py +0 -0
- haiku/rag/app.py +542 -0
- haiku/rag/chunker.py +65 -0
- haiku/rag/cli.py +466 -0
- haiku/rag/client.py +731 -0
- haiku/rag/config/__init__.py +74 -0
- haiku/rag/config/loader.py +94 -0
- haiku/rag/config/models.py +99 -0
- haiku/rag/embeddings/__init__.py +49 -0
- haiku/rag/embeddings/base.py +25 -0
- haiku/rag/embeddings/ollama.py +28 -0
- haiku/rag/embeddings/openai.py +26 -0
- haiku/rag/embeddings/vllm.py +29 -0
- haiku/rag/embeddings/voyageai.py +27 -0
- haiku/rag/graph/__init__.py +26 -0
- haiku/rag/graph/agui/__init__.py +53 -0
- haiku/rag/graph/agui/cli_renderer.py +135 -0
- haiku/rag/graph/agui/emitter.py +197 -0
- haiku/rag/graph/agui/events.py +254 -0
- haiku/rag/graph/agui/server.py +310 -0
- haiku/rag/graph/agui/state.py +34 -0
- haiku/rag/graph/agui/stream.py +86 -0
- haiku/rag/graph/common/__init__.py +5 -0
- haiku/rag/graph/common/models.py +42 -0
- haiku/rag/graph/common/nodes.py +265 -0
- haiku/rag/graph/common/prompts.py +46 -0
- haiku/rag/graph/common/utils.py +44 -0
- haiku/rag/graph/deep_qa/__init__.py +1 -0
- haiku/rag/graph/deep_qa/dependencies.py +27 -0
- haiku/rag/graph/deep_qa/graph.py +243 -0
- haiku/rag/graph/deep_qa/models.py +20 -0
- haiku/rag/graph/deep_qa/prompts.py +59 -0
- haiku/rag/graph/deep_qa/state.py +56 -0
- haiku/rag/graph/research/__init__.py +3 -0
- haiku/rag/graph/research/common.py +87 -0
- haiku/rag/graph/research/dependencies.py +151 -0
- haiku/rag/graph/research/graph.py +295 -0
- haiku/rag/graph/research/models.py +166 -0
- haiku/rag/graph/research/prompts.py +107 -0
- haiku/rag/graph/research/state.py +85 -0
- haiku/rag/logging.py +56 -0
- haiku/rag/mcp.py +245 -0
- haiku/rag/monitor.py +194 -0
- haiku/rag/qa/__init__.py +33 -0
- haiku/rag/qa/agent.py +93 -0
- haiku/rag/qa/prompts.py +60 -0
- haiku/rag/reader.py +135 -0
- haiku/rag/reranking/__init__.py +63 -0
- haiku/rag/reranking/base.py +13 -0
- haiku/rag/reranking/cohere.py +34 -0
- haiku/rag/reranking/mxbai.py +28 -0
- haiku/rag/reranking/vllm.py +44 -0
- haiku/rag/reranking/zeroentropy.py +59 -0
- haiku/rag/store/__init__.py +4 -0
- haiku/rag/store/engine.py +309 -0
- haiku/rag/store/models/__init__.py +4 -0
- haiku/rag/store/models/chunk.py +17 -0
- haiku/rag/store/models/document.py +17 -0
- haiku/rag/store/repositories/__init__.py +9 -0
- haiku/rag/store/repositories/chunk.py +442 -0
- haiku/rag/store/repositories/document.py +261 -0
- haiku/rag/store/repositories/settings.py +165 -0
- haiku/rag/store/upgrades/__init__.py +62 -0
- haiku/rag/store/upgrades/v0_10_1.py +64 -0
- haiku/rag/store/upgrades/v0_9_3.py +112 -0
- haiku/rag/utils.py +211 -0
- haiku_rag_slim-0.16.0.dist-info/METADATA +128 -0
- haiku_rag_slim-0.16.0.dist-info/RECORD +71 -0
- haiku_rag_slim-0.16.0.dist-info/WHEEL +4 -0
- haiku_rag_slim-0.16.0.dist-info/entry_points.txt +2 -0
- haiku_rag_slim-0.16.0.dist-info/licenses/LICENSE +7 -0
haiku/rag/graph/agui/server.py
@@ -0,0 +1,310 @@
"""AG-UI HTTP server implementation for graph execution."""

import json
from collections.abc import AsyncIterator, Callable
from pathlib import Path
from typing import TYPE_CHECKING, Any, Protocol

if TYPE_CHECKING:
    from haiku.rag.config.models import AppConfig

from pydantic import BaseModel, Field
from pydantic_graph.beta import Graph
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.cors import CORSMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse, StreamingResponse
from starlette.routing import Route

from haiku.rag.config.models import AGUIConfig
from haiku.rag.graph.agui.emitter import AGUIEmitter
from haiku.rag.graph.agui.events import AGUIEvent
from haiku.rag.graph.agui.stream import stream_graph


class GraphDeps(Protocol):
    """Protocol for graph dependencies that support AG-UI emission."""

    agui_emitter: AGUIEmitter[Any, Any] | None


class RunAgentInput(BaseModel):
    """AG-UI protocol run agent input.

    See: https://docs.ag-ui.com/concepts/agents#runagentinput
    """

    thread_id: str | None = Field(None, alias="threadId")
    run_id: str | None = Field(None, alias="runId")
    state: dict[str, Any] = Field(default_factory=dict)
    messages: list[dict[str, Any]] = Field(default_factory=list)
    config: dict[str, Any] = Field(default_factory=dict)


def create_agui_app(
    graph_factory: Callable[[], Graph],
    state_factory: Callable[[dict[str, Any]], BaseModel],
    deps_factory: Callable[[dict[str, Any]], GraphDeps],
    config: AGUIConfig,
) -> Starlette:
    """Create Starlette app with AG-UI endpoint.

    Args:
        graph_factory: Factory function to create graph instance
        state_factory: Factory to create initial state from input
        deps_factory: Factory to create graph dependencies
        config: AG-UI server configuration

    Returns:
        Starlette application with AG-UI endpoints
    """

    async def event_stream(
        input_data: RunAgentInput,
    ) -> AsyncIterator[str]:
        """Generate SSE event stream from graph execution.

        Yields:
            Server-Sent Events formatted strings
        """
        # Create graph, state, and dependencies
        graph = graph_factory()

        # Create initial state from input
        initial_state = state_factory(input_data.state)

        # Create dependencies (may use config from input)
        deps = deps_factory(input_data.config)

        # Execute graph and stream events
        async for event in stream_graph(graph, initial_state, deps):
            # Format as SSE event
            event_data = format_sse_event(event)
            yield event_data

    async def stream_agent(request: Request) -> StreamingResponse:
        """AG-UI agent stream endpoint.

        Accepts AG-UI RunAgentInput and streams events via SSE.
        """
        # Parse request body
        body = await request.json()
        input_data = RunAgentInput(**body)

        # Return SSE stream
        return StreamingResponse(
            event_stream(input_data),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no",  # Disable buffering in nginx
            },
        )

    async def health_check(_: Request) -> JSONResponse:
        """Health check endpoint."""
        return JSONResponse({"status": "healthy"})

    # Define routes
    routes = [
        Route("/v1/agent/stream", stream_agent, methods=["POST"]),
        Route("/health", health_check, methods=["GET"]),
    ]

    # Configure CORS middleware
    middleware = [
        Middleware(
            CORSMiddleware,
            allow_origins=config.cors_origins,
            allow_credentials=config.cors_credentials,
            allow_methods=config.cors_methods,
            allow_headers=config.cors_headers,
        )
    ]

    # Create Starlette app
    app = Starlette(
        routes=routes,
        middleware=middleware,
        debug=False,
    )

    return app


def format_sse_event(event: AGUIEvent) -> str:
    """Format AG-UI event as Server-Sent Event.

    Args:
        event: AG-UI event dictionary

    Returns:
        SSE formatted string with event data
    """
    # Convert event to JSON
    event_json = json.dumps(event, ensure_ascii=False)

    # Format as SSE
    # Each event is: data: <json>\n\n
    return f"data: {event_json}\n\n"


def create_agui_server(config: "AppConfig", db_path: Path | None = None) -> Starlette:
    """Create AG-UI server with both research and deep ask endpoints.

    Args:
        config: Application config with research and qa settings
        db_path: Optional database path override

    Returns:
        Starlette app with research and deep ask endpoints
    """
    from haiku.rag.client import HaikuRAG
    from haiku.rag.graph.deep_qa.dependencies import DeepQAContext
    from haiku.rag.graph.deep_qa.graph import build_deep_qa_graph
    from haiku.rag.graph.deep_qa.state import DeepQADeps, DeepQAState
    from haiku.rag.graph.research.dependencies import ResearchContext
    from haiku.rag.graph.research.graph import build_research_graph
    from haiku.rag.graph.research.state import ResearchDeps, ResearchState

    # Store client reference for proper lifecycle management
    _client_cache: dict[str, HaikuRAG] = {}

    def get_client(effective_db_path: Path) -> HaikuRAG:
        """Get or create cached client."""
        path_key = str(effective_db_path)
        if path_key not in _client_cache:
            _client_cache[path_key] = HaikuRAG(db_path=effective_db_path, config=config)
        return _client_cache[path_key]

    # Research graph factories
    def research_graph_factory() -> Graph:
        return build_research_graph(config)

    def research_state_factory(input_state: dict[str, Any]) -> ResearchState:
        question = input_state.get("question", "")
        if not question:
            messages = input_state.get("messages", [])
            if messages:
                question = messages[0].get("content", "")
        context = ResearchContext(original_question=question)
        return ResearchState.from_config(context=context, config=config)

    def research_deps_factory(input_config: dict[str, Any]) -> ResearchDeps:
        effective_db_path = (
            db_path
            or input_config.get("db_path")
            or config.storage.data_dir / "haiku.rag.lancedb"
        )
        return ResearchDeps(client=get_client(effective_db_path))

    # Deep ask graph factories
    def deep_ask_graph_factory() -> Graph:
        return build_deep_qa_graph(config)

    def deep_ask_state_factory(input_state: dict[str, Any]) -> DeepQAState:
        question = input_state.get("question", "")
        if not question:
            messages = input_state.get("messages", [])
            if messages:
                question = messages[0].get("content", "")
        use_citations = input_state.get("use_citations", False)
        context = DeepQAContext(original_question=question, use_citations=use_citations)
        return DeepQAState.from_config(context=context, config=config)

    def deep_ask_deps_factory(input_config: dict[str, Any]) -> DeepQADeps:
        effective_db_path = (
            db_path
            or input_config.get("db_path")
            or config.storage.data_dir / "haiku.rag.lancedb"
        )
        return DeepQADeps(client=get_client(effective_db_path))

    # Create event stream functions for each graph type
    async def research_event_stream(
        input_data: RunAgentInput,
    ) -> AsyncIterator[str]:
        """Generate SSE event stream from research graph execution."""
        graph = research_graph_factory()
        initial_state = research_state_factory(input_data.state)
        deps = research_deps_factory(input_data.config)

        async for event in stream_graph(graph, initial_state, deps):
            event_data = format_sse_event(event)
            yield event_data

    async def deep_ask_event_stream(
        input_data: RunAgentInput,
    ) -> AsyncIterator[str]:
        """Generate SSE event stream from deep ask graph execution."""
        graph = deep_ask_graph_factory()
        initial_state = deep_ask_state_factory(input_data.state)
        deps = deep_ask_deps_factory(input_data.config)

        async for event in stream_graph(graph, initial_state, deps):
            event_data = format_sse_event(event)
            yield event_data

    # Endpoint handlers
    async def stream_research(request: Request) -> StreamingResponse:
        """Research graph streaming endpoint."""
        body = await request.json()
        input_data = RunAgentInput(**body)

        return StreamingResponse(
            research_event_stream(input_data),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no",
            },
        )

    async def stream_deep_ask(request: Request) -> StreamingResponse:
        """Deep ask graph streaming endpoint."""
        body = await request.json()
        input_data = RunAgentInput(**body)

        return StreamingResponse(
            deep_ask_event_stream(input_data),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no",
            },
        )

    async def health_check(_: Request) -> JSONResponse:
        """Health check endpoint."""
        return JSONResponse({"status": "healthy"})

    # Define routes
    routes = [
        Route("/v1/research/stream", stream_research, methods=["POST"]),
        Route("/v1/deep-ask/stream", stream_deep_ask, methods=["POST"]),
        Route("/health", health_check, methods=["GET"]),
    ]

    # Configure CORS middleware
    middleware = [
        Middleware(
            CORSMiddleware,
            allow_origins=config.agui.cors_origins,
            allow_credentials=config.agui.cors_credentials,
            allow_methods=config.agui.cors_methods,
            allow_headers=config.agui.cors_headers,
        )
    ]

    # Create Starlette app
    app = Starlette(
        routes=routes,
        middleware=middleware,
        debug=False,
    )

    return app
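For context, a minimal client-side sketch of consuming one of the SSE endpoints above. This is not part of the package: httpx, the host/port, and the question text are assumptions; the endpoint path, the payload keys, and the "data: <json>" framing come from server.py. The app itself can be served with any ASGI server such as uvicorn.

# Sketch (not part of this package): POST a RunAgentInput payload to the
# research endpoint and print each SSE "data:" line as it arrives.
# Assumes httpx is installed and the server listens on localhost:8000.
import asyncio

import httpx


async def consume_research_stream(question: str) -> None:
    payload = {
        "threadId": "thread-1",  # aliases match RunAgentInput above
        "runId": "run-1",
        "state": {"question": question},  # read by research_state_factory
        "messages": [],
        "config": {},
    }
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream(
            "POST", "http://localhost:8000/v1/research/stream", json=payload
        ) as response:
            async for line in response.aiter_lines():
                # format_sse_event emits one "data: <json>" line per event
                if line.startswith("data: "):
                    print(line.removeprefix("data: "))


asyncio.run(consume_research_stream("What is hybrid search?"))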
haiku/rag/graph/agui/state.py
@@ -0,0 +1,34 @@
"""Generic AG-UI state utilities for any Pydantic BaseModel."""

from typing import Any

from pydantic import BaseModel


def compute_state_delta(
    old_state: BaseModel, new_state: BaseModel
) -> list[dict[str, Any]]:
    """Compute JSON Patch (RFC 6902) operations from old state to new state.

    Args:
        old_state: Previous state (any Pydantic BaseModel)
        new_state: Current state (same type as old_state)

    Returns:
        List of JSON Patch operations
    """
    operations: list[dict[str, Any]] = []

    # Convert states to dicts for comparison
    old_dict = old_state.model_dump()
    new_dict = new_state.model_dump()

    # Compare each field and generate patches
    for key, new_value in new_dict.items():
        old_value = old_dict.get(key)

        if old_value != new_value:
            # Simple replace operation
            operations.append({"op": "replace", "path": f"/{key}", "value": new_value})

    return operations
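A short sketch of what compute_state_delta returns; the Progress model here is a made-up stand-in, not a type from this package.

# Sketch: compute_state_delta on a toy model (Progress is hypothetical,
# used only to illustrate the shape of the output).
from pydantic import BaseModel, Field

from haiku.rag.graph.agui.state import compute_state_delta


class Progress(BaseModel):
    step: int = 0
    notes: list[str] = Field(default_factory=list)


old = Progress(step=1, notes=["planned"])
new = Progress(step=2, notes=["planned", "searched"])

# Only changed top-level fields produce operations:
# [{"op": "replace", "path": "/step", "value": 2},
#  {"op": "replace", "path": "/notes", "value": ["planned", "searched"]}]
print(compute_state_delta(old, new))

Note that the diff is shallow: a change anywhere inside a nested field is reported as a replace of the entire top-level field.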
haiku/rag/graph/agui/stream.py
@@ -0,0 +1,86 @@
"""Generic graph streaming with AG-UI events."""

import asyncio
from collections.abc import AsyncIterator
from contextlib import suppress
from typing import Protocol, TypeVar

from pydantic import BaseModel
from pydantic_graph.beta import Graph

from haiku.rag.graph.agui.emitter import AGUIEmitter
from haiku.rag.graph.agui.events import AGUIEvent

StateT = TypeVar("StateT", bound=BaseModel)
ResultT = TypeVar("ResultT")


class GraphDeps[StateT: BaseModel, ResultT](Protocol):
    """Protocol for graph dependencies that support AG-UI emission."""

    agui_emitter: AGUIEmitter[StateT, ResultT] | None


async def stream_graph[StateT: BaseModel, DepsT: GraphDeps, ResultT](
    graph: Graph[StateT, DepsT, None, ResultT],
    state: StateT,
    deps: DepsT,
    use_deltas: bool = True,
) -> AsyncIterator[AGUIEvent]:
    """Run a graph and yield AG-UI events as they occur.

    This is a generic streaming function that works with any pydantic-graph
    that follows the AG-UI pattern:
    - State must be a Pydantic BaseModel
    - Deps must have an optional agui_emitter attribute
    - Graph must be a pydantic-graph Graph instance

    Args:
        graph: The pydantic-graph Graph to execute
        state: Initial state (Pydantic BaseModel)
        deps: Graph dependencies with agui_emitter support
        use_deltas: Whether to emit state deltas instead of full snapshots (default: True)

    Yields:
        AG-UI event dictionaries

    Raises:
        TypeError: If deps doesn't support agui_emitter
        RuntimeError: If graph doesn't produce a result
    """
    if not hasattr(deps, "agui_emitter"):
        raise TypeError("deps must have an 'agui_emitter' attribute")

    # Create AG-UI emitter
    emitter: AGUIEmitter[StateT, ResultT] = AGUIEmitter(use_deltas=use_deltas)
    deps.agui_emitter = emitter  # type: ignore[assignment]

    async def _execute() -> None:
        try:
            # Start the run with initial state
            emitter.start_run(initial_state=state)

            # Execute the graph
            result = await graph.run(state=state, deps=deps)

            if result is None:
                raise RuntimeError("Graph did not produce a result")

            # Finish the run with the result
            emitter.finish_run(result)
        except Exception as exc:
            # Emit error event
            emitter.error(exc)
        finally:
            await emitter.close()

    runner = asyncio.create_task(_execute())

    try:
        async for event in emitter:
            yield event
    finally:
        if not runner.done():
            runner.cancel()
            with suppress(asyncio.CancelledError):
                await runner
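For context, a sketch of driving a graph through stream_graph directly, mirroring research_event_stream in server.py above. How the AppConfig instance and the database path are obtained is left to the caller; the constructor and factory calls are the ones used elsewhere in this diff.

# Sketch: run the research graph via stream_graph and print raw AG-UI
# events. Obtaining `config` (an AppConfig) and `db_path` is assumed to
# be handled by the caller.
from pathlib import Path
from typing import Any

from haiku.rag.client import HaikuRAG
from haiku.rag.graph.agui.stream import stream_graph
from haiku.rag.graph.research.dependencies import ResearchContext
from haiku.rag.graph.research.graph import build_research_graph
from haiku.rag.graph.research.state import ResearchDeps, ResearchState


async def run_research(config: Any, db_path: Path) -> None:
    graph = build_research_graph(config)
    state = ResearchState.from_config(
        context=ResearchContext(original_question="What is hybrid search?"),
        config=config,
    )
    deps = ResearchDeps(client=HaikuRAG(db_path=db_path, config=config))

    # stream_graph executes the graph in a background task and yields
    # AG-UI event dicts as the emitter produces them.
    async for event in stream_graph(graph, state, deps):
        print(event)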
haiku/rag/graph/common/models.py
@@ -0,0 +1,42 @@
"""Common models used across different graph implementations."""

from pydantic import BaseModel, Field, field_validator


class ResearchPlan(BaseModel):
    """A structured research plan with sub-questions to explore."""

    sub_questions: list[str] = Field(
        ...,
        description="Specific questions to research, phrased as complete questions",
    )

    @field_validator("sub_questions")
    @classmethod
    def validate_sub_questions(cls, v: list[str]) -> list[str]:
        if len(v) < 1:
            raise ValueError("Must have at least 1 sub-question")
        if len(v) > 12:
            raise ValueError("Cannot have more than 12 sub-questions")
        return v


class SearchAnswer(BaseModel):
    """Answer from a search operation with sources."""

    query: str = Field(..., description="The question that was answered")
    answer: str = Field(..., description="The comprehensive answer to the question")
    context: list[str] = Field(
        default_factory=list,
        description="Relevant snippets that directly support the answer",
    )
    sources: list[str] = Field(
        default_factory=list,
        description="Source URIs or titles that contributed to this answer",
    )
    confidence: float = Field(
        default=1.0,
        description="Confidence score for this answer (0-1)",
        ge=0.0,
        le=1.0,
    )
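A short sketch exercising the validators above; all field values are illustrative only.

# Sketch: the models above in use (values are made up for illustration).
from pydantic import ValidationError

from haiku.rag.graph.common.models import ResearchPlan, SearchAnswer

plan = ResearchPlan(sub_questions=["What is hybrid search?", "How are results ranked?"])

try:
    ResearchPlan(sub_questions=[])  # rejected: fewer than 1 sub-question
except ValidationError as exc:
    print(exc.errors()[0]["msg"])  # "Value error, Must have at least 1 sub-question"

answer = SearchAnswer(
    query="What is hybrid search?",
    answer="It combines vector similarity with full-text matching.",
    context=["...vector and full-text scores are fused..."],
    sources=["docs/retrieval.md"],
    confidence=0.8,
)
print(answer.model_dump_json(indent=2))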