haiku.rag-slim 0.16.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of haiku.rag-slim might be problematic. Click here for more details.

Files changed (71)
  1. haiku/rag/__init__.py +0 -0
  2. haiku/rag/app.py +542 -0
  3. haiku/rag/chunker.py +65 -0
  4. haiku/rag/cli.py +466 -0
  5. haiku/rag/client.py +731 -0
  6. haiku/rag/config/__init__.py +74 -0
  7. haiku/rag/config/loader.py +94 -0
  8. haiku/rag/config/models.py +99 -0
  9. haiku/rag/embeddings/__init__.py +49 -0
  10. haiku/rag/embeddings/base.py +25 -0
  11. haiku/rag/embeddings/ollama.py +28 -0
  12. haiku/rag/embeddings/openai.py +26 -0
  13. haiku/rag/embeddings/vllm.py +29 -0
  14. haiku/rag/embeddings/voyageai.py +27 -0
  15. haiku/rag/graph/__init__.py +26 -0
  16. haiku/rag/graph/agui/__init__.py +53 -0
  17. haiku/rag/graph/agui/cli_renderer.py +135 -0
  18. haiku/rag/graph/agui/emitter.py +197 -0
  19. haiku/rag/graph/agui/events.py +254 -0
  20. haiku/rag/graph/agui/server.py +310 -0
  21. haiku/rag/graph/agui/state.py +34 -0
  22. haiku/rag/graph/agui/stream.py +86 -0
  23. haiku/rag/graph/common/__init__.py +5 -0
  24. haiku/rag/graph/common/models.py +42 -0
  25. haiku/rag/graph/common/nodes.py +265 -0
  26. haiku/rag/graph/common/prompts.py +46 -0
  27. haiku/rag/graph/common/utils.py +44 -0
  28. haiku/rag/graph/deep_qa/__init__.py +1 -0
  29. haiku/rag/graph/deep_qa/dependencies.py +27 -0
  30. haiku/rag/graph/deep_qa/graph.py +243 -0
  31. haiku/rag/graph/deep_qa/models.py +20 -0
  32. haiku/rag/graph/deep_qa/prompts.py +59 -0
  33. haiku/rag/graph/deep_qa/state.py +56 -0
  34. haiku/rag/graph/research/__init__.py +3 -0
  35. haiku/rag/graph/research/common.py +87 -0
  36. haiku/rag/graph/research/dependencies.py +151 -0
  37. haiku/rag/graph/research/graph.py +295 -0
  38. haiku/rag/graph/research/models.py +166 -0
  39. haiku/rag/graph/research/prompts.py +107 -0
  40. haiku/rag/graph/research/state.py +85 -0
  41. haiku/rag/logging.py +56 -0
  42. haiku/rag/mcp.py +245 -0
  43. haiku/rag/monitor.py +194 -0
  44. haiku/rag/qa/__init__.py +33 -0
  45. haiku/rag/qa/agent.py +93 -0
  46. haiku/rag/qa/prompts.py +60 -0
  47. haiku/rag/reader.py +135 -0
  48. haiku/rag/reranking/__init__.py +63 -0
  49. haiku/rag/reranking/base.py +13 -0
  50. haiku/rag/reranking/cohere.py +34 -0
  51. haiku/rag/reranking/mxbai.py +28 -0
  52. haiku/rag/reranking/vllm.py +44 -0
  53. haiku/rag/reranking/zeroentropy.py +59 -0
  54. haiku/rag/store/__init__.py +4 -0
  55. haiku/rag/store/engine.py +309 -0
  56. haiku/rag/store/models/__init__.py +4 -0
  57. haiku/rag/store/models/chunk.py +17 -0
  58. haiku/rag/store/models/document.py +17 -0
  59. haiku/rag/store/repositories/__init__.py +9 -0
  60. haiku/rag/store/repositories/chunk.py +442 -0
  61. haiku/rag/store/repositories/document.py +261 -0
  62. haiku/rag/store/repositories/settings.py +165 -0
  63. haiku/rag/store/upgrades/__init__.py +62 -0
  64. haiku/rag/store/upgrades/v0_10_1.py +64 -0
  65. haiku/rag/store/upgrades/v0_9_3.py +112 -0
  66. haiku/rag/utils.py +211 -0
  67. haiku_rag_slim-0.16.0.dist-info/METADATA +128 -0
  68. haiku_rag_slim-0.16.0.dist-info/RECORD +71 -0
  69. haiku_rag_slim-0.16.0.dist-info/WHEEL +4 -0
  70. haiku_rag_slim-0.16.0.dist-info/entry_points.txt +2 -0
  71. haiku_rag_slim-0.16.0.dist-info/licenses/LICENSE +7 -0
@@ -0,0 +1,310 @@
1
+ """AG-UI HTTP server implementation for graph execution."""
2
+
3
+ import json
4
+ from collections.abc import AsyncIterator, Callable
5
+ from pathlib import Path
6
+ from typing import TYPE_CHECKING, Any, Protocol
7
+
8
+ if TYPE_CHECKING:
9
+ from haiku.rag.config.models import AppConfig
10
+
11
+ from pydantic import BaseModel, Field
12
+ from pydantic_graph.beta import Graph
13
+ from starlette.applications import Starlette
14
+ from starlette.middleware import Middleware
15
+ from starlette.middleware.cors import CORSMiddleware
16
+ from starlette.requests import Request
17
+ from starlette.responses import JSONResponse, StreamingResponse
18
+ from starlette.routing import Route
19
+
20
+ from haiku.rag.config.models import AGUIConfig
21
+ from haiku.rag.graph.agui.emitter import AGUIEmitter
22
+ from haiku.rag.graph.agui.events import AGUIEvent
23
+ from haiku.rag.graph.agui.stream import stream_graph
24
+
25
+
26
class GraphDeps(Protocol):
    """Structural type for dependency objects that can carry an AG-UI emitter.

    Any object exposing an ``agui_emitter`` attribute satisfies this protocol;
    the streaming helper assigns the emitter before running the graph.
    """

    # Emitter used to publish AG-UI events, or None when none is attached.
    agui_emitter: AGUIEmitter[Any, Any] | None
30
+
31
+
32
class RunAgentInput(BaseModel):
    """Request body accepted by the AG-UI streaming endpoints.

    Mirrors the AG-UI ``RunAgentInput`` shape. The camelCase wire keys
    (``threadId``, ``runId``) map onto snake_case attributes through aliases;
    every field is optional.

    See: https://docs.ag-ui.com/concepts/agents#runagentinput
    """

    # Conversation / run identifiers supplied by the client.
    thread_id: str | None = Field(default=None, alias="threadId")
    run_id: str | None = Field(default=None, alias="runId")
    # Free-form client state; the server hands it to its state factory.
    state: dict[str, Any] = Field(default_factory=dict)
    # Chat history, kept as raw dicts.
    messages: list[dict[str, Any]] = Field(default_factory=list)
    # Per-request settings; the server hands them to its deps factory.
    config: dict[str, Any] = Field(default_factory=dict)
43
+
44
+
45
def create_agui_app(
    graph_factory: Callable[[], Graph],
    state_factory: Callable[[dict[str, Any]], BaseModel],
    deps_factory: Callable[[dict[str, Any]], GraphDeps],
    config: AGUIConfig,
) -> Starlette:
    """Create a Starlette app exposing a single AG-UI streaming endpoint.

    Endpoints:
        POST /v1/agent/stream -- run the graph and stream AG-UI events as SSE
        GET  /health          -- liveness probe

    The POST endpoint answers ``400`` with a JSON error body when the request
    body is not valid JSON or does not match ``RunAgentInput``, instead of
    failing with an unhandled server error.

    Args:
        graph_factory: Builds a fresh graph instance for each request.
        state_factory: Builds the initial graph state from the request ``state``.
        deps_factory: Builds graph dependencies from the request ``config``.
        config: AG-UI server configuration (used here for the CORS settings).

    Returns:
        Starlette application with the AG-UI endpoints.
    """
    sse_headers = {
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "X-Accel-Buffering": "no",  # Disable buffering in nginx
    }

    async def event_stream(
        input_data: RunAgentInput,
    ) -> AsyncIterator[str]:
        """Run the graph for one request and yield SSE-formatted event strings."""
        graph = graph_factory()
        initial_state = state_factory(input_data.state)
        # Dependencies may be configured from the request's ``config``.
        deps = deps_factory(input_data.config)

        async for event in stream_graph(graph, initial_state, deps):
            yield format_sse_event(event)

    async def stream_agent(request: Request) -> StreamingResponse | JSONResponse:
        """AG-UI agent stream endpoint.

        Accepts an AG-UI RunAgentInput and streams events via SSE; responds
        with 400 when the body cannot be parsed or validated.
        """
        try:
            body = await request.json()
            input_data = RunAgentInput.model_validate(body)
        except ValueError as exc:
            # json.JSONDecodeError and pydantic's ValidationError both derive
            # from ValueError; a non-object body fails validation here too.
            return JSONResponse(
                {"error": "invalid_request", "detail": str(exc)},
                status_code=400,
            )

        return StreamingResponse(
            event_stream(input_data),
            media_type="text/event-stream",
            headers=sse_headers,
        )

    async def health_check(_: Request) -> JSONResponse:
        """Health check endpoint."""
        return JSONResponse({"status": "healthy"})

    routes = [
        Route("/v1/agent/stream", stream_agent, methods=["POST"]),
        Route("/health", health_check, methods=["GET"]),
    ]

    middleware = [
        Middleware(
            CORSMiddleware,
            allow_origins=config.cors_origins,
            allow_credentials=config.cors_credentials,
            allow_methods=config.cors_methods,
            allow_headers=config.cors_headers,
        )
    ]

    return Starlette(
        routes=routes,
        middleware=middleware,
        debug=False,
    )
135
+
136
+
137
def format_sse_event(event: AGUIEvent) -> str:
    """Serialize one AG-UI event into a Server-Sent Events ``data:`` frame.

    Args:
        event: The AG-UI event mapping to send.

    Returns:
        The JSON payload on a single ``data:`` line, followed by the blank
        line that terminates an SSE event.
    """
    # ensure_ascii=False keeps non-ASCII text readable in the stream. With no
    # indent, json.dumps escapes embedded newlines, so the payload is one line.
    payload = json.dumps(event, ensure_ascii=False)
    return "".join(("data: ", payload, "\n\n"))
152
+
153
+
154
def create_agui_server(config: "AppConfig", db_path: Path | None = None) -> Starlette:
    """Create an AG-UI server exposing the research and deep-ask graphs.

    Endpoints:
        POST /v1/research/stream  -- run the research graph, stream SSE events
        POST /v1/deep-ask/stream  -- run the deep QA graph, stream SSE events
        GET  /health              -- liveness probe

    Both POST endpoints accept an AG-UI ``RunAgentInput`` body. They answer
    ``400`` with a JSON error body when the body is not valid JSON, does not
    match that schema, or (without a server-side ``db_path``) carries a
    non-string ``config.db_path``.

    The question is taken from ``state["question"]``, else the first entry
    of ``state["messages"]``, else the most recent user message in the
    top-level AG-UI ``messages`` list.

    Args:
        config: Application config with research, qa, storage and agui settings.
        db_path: Optional database path override. When given, it always wins
            over any ``db_path`` a client sends in the request ``config``.

    Returns:
        Starlette app with research and deep ask endpoints.
    """
    from haiku.rag.client import HaikuRAG
    from haiku.rag.graph.deep_qa.dependencies import DeepQAContext
    from haiku.rag.graph.deep_qa.graph import build_deep_qa_graph
    from haiku.rag.graph.deep_qa.state import DeepQADeps, DeepQAState
    from haiku.rag.graph.research.dependencies import ResearchContext
    from haiku.rag.graph.research.graph import build_research_graph
    from haiku.rag.graph.research.state import ResearchDeps, ResearchState

    # One client per database path, reused across requests.
    # NOTE(review): entries are never evicted or closed; when clients may pick
    # the db path, this grows with every distinct path requested.
    _client_cache: dict[str, HaikuRAG] = {}

    def get_client(effective_db_path: Path) -> HaikuRAG:
        """Return the cached client for ``effective_db_path``, creating it once."""
        path_key = str(effective_db_path)
        if path_key not in _client_cache:
            _client_cache[path_key] = HaikuRAG(db_path=effective_db_path, config=config)
        return _client_cache[path_key]

    def resolve_db_path(input_config: dict[str, Any]) -> Path:
        """Pick the database: server override, then client request, then default."""
        if db_path is not None:
            return db_path
        requested = input_config.get("db_path")
        if requested:
            # JSON delivers a str; HaikuRAG is handed a Path like everywhere else.
            # NOTE(security): this lets any HTTP client choose a server-side
            # database location; pass ``db_path`` to pin it when clients are
            # untrusted.
            return Path(requested)
        return config.storage.data_dir / "haiku.rag.lancedb"

    def extract_question(
        input_state: dict[str, Any], messages: list[dict[str, Any]]
    ) -> str:
        """Find the question: state question, first state message, last user message."""
        question = input_state.get("question", "")
        if question:
            return question
        state_messages = input_state.get("messages") or []
        if state_messages:
            first = state_messages[0]
            # Guard against non-dict entries instead of raising AttributeError.
            return (first.get("content") or "") if isinstance(first, dict) else ""
        # Standard AG-UI clients send the conversation in the top-level
        # ``messages`` field; the latest user turn is the question.
        for message in reversed(messages):
            if isinstance(message, dict) and message.get("role") == "user":
                return message.get("content") or ""
        return ""

    # Research graph factories
    def research_graph_factory() -> Graph:
        return build_research_graph(config)

    def research_state_factory(
        input_state: dict[str, Any],
        messages: list[dict[str, Any]] | None = None,
    ) -> ResearchState:
        context = ResearchContext(
            original_question=extract_question(input_state, messages or [])
        )
        return ResearchState.from_config(context=context, config=config)

    def research_deps_factory(input_config: dict[str, Any]) -> ResearchDeps:
        return ResearchDeps(client=get_client(resolve_db_path(input_config)))

    # Deep ask graph factories
    def deep_ask_graph_factory() -> Graph:
        return build_deep_qa_graph(config)

    def deep_ask_state_factory(
        input_state: dict[str, Any],
        messages: list[dict[str, Any]] | None = None,
    ) -> DeepQAState:
        use_citations = input_state.get("use_citations", False)
        context = DeepQAContext(
            original_question=extract_question(input_state, messages or []),
            use_citations=use_citations,
        )
        return DeepQAState.from_config(context=context, config=config)

    def deep_ask_deps_factory(input_config: dict[str, Any]) -> DeepQADeps:
        return DeepQADeps(client=get_client(resolve_db_path(input_config)))

    sse_headers = {
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "X-Accel-Buffering": "no",  # Disable buffering in nginx
    }

    def bad_request(detail: str) -> JSONResponse:
        """Build the 400 response returned for unusable request bodies."""
        return JSONResponse(
            {"error": "invalid_request", "detail": detail}, status_code=400
        )

    def make_stream_endpoint(
        graph_factory: Callable[[], Graph],
        state_factory: Callable[..., BaseModel],
        deps_factory: Callable[[dict[str, Any]], Any],
    ) -> Callable[[Request], Any]:
        """Build a POST handler that runs one graph type and streams SSE events."""

        async def event_stream(input_data: RunAgentInput) -> AsyncIterator[str]:
            graph = graph_factory()
            initial_state = state_factory(input_data.state, input_data.messages)
            deps = deps_factory(input_data.config)

            async for event in stream_graph(graph, initial_state, deps):
                yield format_sse_event(event)

        async def endpoint(request: Request) -> StreamingResponse | JSONResponse:
            try:
                body = await request.json()
                input_data = RunAgentInput.model_validate(body)
            except ValueError as exc:
                # json.JSONDecodeError and pydantic's ValidationError both
                # derive from ValueError; non-object bodies fail validation.
                return bad_request(str(exc))

            requested_db = input_data.config.get("db_path")
            if (
                db_path is None
                and requested_db is not None
                and not isinstance(requested_db, str)
            ):
                return bad_request("config.db_path must be a string")

            return StreamingResponse(
                event_stream(input_data),
                media_type="text/event-stream",
                headers=sse_headers,
            )

        return endpoint

    async def health_check(_: Request) -> JSONResponse:
        """Health check endpoint."""
        return JSONResponse({"status": "healthy"})

    # Explicit route names keep the names the original per-endpoint handlers had.
    routes = [
        Route(
            "/v1/research/stream",
            make_stream_endpoint(
                research_graph_factory, research_state_factory, research_deps_factory
            ),
            methods=["POST"],
            name="stream_research",
        ),
        Route(
            "/v1/deep-ask/stream",
            make_stream_endpoint(
                deep_ask_graph_factory, deep_ask_state_factory, deep_ask_deps_factory
            ),
            methods=["POST"],
            name="stream_deep_ask",
        ),
        Route("/health", health_check, methods=["GET"]),
    ]

    middleware = [
        Middleware(
            CORSMiddleware,
            allow_origins=config.agui.cors_origins,
            allow_credentials=config.agui.cors_credentials,
            allow_methods=config.agui.cors_methods,
            allow_headers=config.agui.cors_headers,
        )
    ]

    return Starlette(
        routes=routes,
        middleware=middleware,
        debug=False,
    )
@@ -0,0 +1,34 @@
1
+ """Generic AG-UI state utilities for any Pydantic BaseModel."""
2
+
3
+ from typing import Any
4
+
5
+ from pydantic import BaseModel
6
+
7
+
8
def compute_state_delta(
    old_state: BaseModel, new_state: BaseModel
) -> list[dict[str, Any]]:
    """Compute JSON Patch (RFC 6902) operations from old state to new state.

    The diff is shallow. A top-level field whose value changed becomes a
    single ``replace`` of the whole field. A field present only in
    ``new_state`` becomes ``add``; RFC 6902 ``replace`` requires the target
    to exist already. A field present only in ``old_state`` becomes
    ``remove``. For two instances of the same model without extra fields,
    every change is a ``replace``.

    Both states are dumped in JSON mode so patch values are plain JSON types
    (e.g. datetimes become ISO strings) and can be serialized directly.

    Args:
        old_state: Previous state (any Pydantic BaseModel)
        new_state: Current state (normally the same type as old_state)

    Returns:
        List of JSON Patch operations
    """
    old_dict = old_state.model_dump(mode="json")
    new_dict = new_state.model_dump(mode="json")

    def _pointer(key: Any) -> str:
        # RFC 6901 escaping: "~" must be escaped before "/".
        return "/" + str(key).replace("~", "~0").replace("/", "~1")

    operations: list[dict[str, Any]] = []

    for key, new_value in new_dict.items():
        if key not in old_dict:
            operations.append({"op": "add", "path": _pointer(key), "value": new_value})
        elif old_dict[key] != new_value:
            operations.append(
                {"op": "replace", "path": _pointer(key), "value": new_value}
            )

    for key in old_dict:
        if key not in new_dict:
            operations.append({"op": "remove", "path": _pointer(key)})

    return operations
@@ -0,0 +1,86 @@
1
+ """Generic graph streaming with AG-UI events."""
2
+
3
+ import asyncio
4
+ from collections.abc import AsyncIterator
5
+ from contextlib import suppress
6
+ from typing import Protocol, TypeVar
7
+
8
+ from pydantic import BaseModel
9
+ from pydantic_graph.beta import Graph
10
+
11
+ from haiku.rag.graph.agui.emitter import AGUIEmitter
12
+ from haiku.rag.graph.agui.events import AGUIEvent
13
+
14
+ StateT = TypeVar("StateT", bound=BaseModel)
15
+ ResultT = TypeVar("ResultT")
16
+
17
+
18
+ class GraphDeps[StateT: BaseModel, ResultT](Protocol):
19
+ """Protocol for graph dependencies that support AG-UI emission."""
20
+
21
+ agui_emitter: AGUIEmitter[StateT, ResultT] | None
22
+
23
+
24
+ async def stream_graph[StateT: BaseModel, DepsT: GraphDeps, ResultT](
25
+ graph: Graph[StateT, DepsT, None, ResultT],
26
+ state: StateT,
27
+ deps: DepsT,
28
+ use_deltas: bool = True,
29
+ ) -> AsyncIterator[AGUIEvent]:
30
+ """Run a graph and yield AG-UI events as they occur.
31
+
32
+ This is a generic streaming function that works with any pydantic-graph
33
+ that follows the AG-UI pattern:
34
+ - State must be a Pydantic BaseModel
35
+ - Deps must have an optional agui_emitter attribute
36
+ - Graph must be a pydantic-graph Graph instance
37
+
38
+ Args:
39
+ graph: The pydantic-graph Graph to execute
40
+ state: Initial state (Pydantic BaseModel)
41
+ deps: Graph dependencies with agui_emitter support
42
+ use_deltas: Whether to emit state deltas instead of full snapshots (default: True)
43
+
44
+ Yields:
45
+ AG-UI event dictionaries
46
+
47
+ Raises:
48
+ TypeError: If deps doesn't support agui_emitter
49
+ RuntimeError: If graph doesn't produce a result
50
+ """
51
+ if not hasattr(deps, "agui_emitter"):
52
+ raise TypeError("deps must have an 'agui_emitter' attribute")
53
+
54
+ # Create AG-UI emitter
55
+ emitter: AGUIEmitter[StateT, ResultT] = AGUIEmitter(use_deltas=use_deltas)
56
+ deps.agui_emitter = emitter # type: ignore[assignment]
57
+
58
+ async def _execute() -> None:
59
+ try:
60
+ # Start the run with initial state
61
+ emitter.start_run(initial_state=state)
62
+
63
+ # Execute the graph
64
+ result = await graph.run(state=state, deps=deps)
65
+
66
+ if result is None:
67
+ raise RuntimeError("Graph did not produce a result")
68
+
69
+ # Finish the run with the result
70
+ emitter.finish_run(result)
71
+ except Exception as exc:
72
+ # Emit error event
73
+ emitter.error(exc)
74
+ finally:
75
+ await emitter.close()
76
+
77
+ runner = asyncio.create_task(_execute())
78
+
79
+ try:
80
+ async for event in emitter:
81
+ yield event
82
+ finally:
83
+ if not runner.done():
84
+ runner.cancel()
85
+ with suppress(asyncio.CancelledError):
86
+ await runner
@@ -0,0 +1,5 @@
1
+ """Common utilities for graph implementations."""
2
+
3
+ from haiku.rag.graph.common.utils import get_model
4
+
5
+ __all__ = ["get_model"]
@@ -0,0 +1,42 @@
1
+ """Common models used across different graph implementations."""
2
+
3
+ from pydantic import BaseModel, Field, field_validator
4
+
5
+
6
class ResearchPlan(BaseModel):
    """Research plan: the sub-questions the agent intends to investigate."""

    sub_questions: list[str] = Field(
        description="Specific questions to research, phrased as complete questions",
    )

    @field_validator("sub_questions")
    @classmethod
    def validate_sub_questions(cls, questions: list[str]) -> list[str]:
        """Require between 1 and 12 sub-questions, inclusive."""
        count = len(questions)
        if count == 0:
            raise ValueError("Must have at least 1 sub-question")
        if count > 12:
            raise ValueError("Cannot have more than 12 sub-questions")
        return questions
22
+
23
+
24
class SearchAnswer(BaseModel):
    """A question paired with its answer plus the evidence and sources behind it."""

    query: str = Field(description="The question that was answered")
    answer: str = Field(description="The comprehensive answer to the question")
    # Supporting evidence; both lists default to empty.
    context: list[str] = Field(
        default_factory=list,
        description="Relevant snippets that directly support the answer",
    )
    sources: list[str] = Field(
        default_factory=list,
        description="Source URIs or titles that contributed to this answer",
    )
    # Constrained to the closed interval [0, 1]; defaults to full confidence.
    confidence: float = Field(
        1.0,
        ge=0.0,
        le=1.0,
        description="Confidence score for this answer (0-1)",
    )