haiku.rag-slim 0.16.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of haiku.rag-slim might be problematic.
- haiku/rag/__init__.py +0 -0
- haiku/rag/app.py +542 -0
- haiku/rag/chunker.py +65 -0
- haiku/rag/cli.py +466 -0
- haiku/rag/client.py +731 -0
- haiku/rag/config/__init__.py +74 -0
- haiku/rag/config/loader.py +94 -0
- haiku/rag/config/models.py +99 -0
- haiku/rag/embeddings/__init__.py +49 -0
- haiku/rag/embeddings/base.py +25 -0
- haiku/rag/embeddings/ollama.py +28 -0
- haiku/rag/embeddings/openai.py +26 -0
- haiku/rag/embeddings/vllm.py +29 -0
- haiku/rag/embeddings/voyageai.py +27 -0
- haiku/rag/graph/__init__.py +26 -0
- haiku/rag/graph/agui/__init__.py +53 -0
- haiku/rag/graph/agui/cli_renderer.py +135 -0
- haiku/rag/graph/agui/emitter.py +197 -0
- haiku/rag/graph/agui/events.py +254 -0
- haiku/rag/graph/agui/server.py +310 -0
- haiku/rag/graph/agui/state.py +34 -0
- haiku/rag/graph/agui/stream.py +86 -0
- haiku/rag/graph/common/__init__.py +5 -0
- haiku/rag/graph/common/models.py +42 -0
- haiku/rag/graph/common/nodes.py +265 -0
- haiku/rag/graph/common/prompts.py +46 -0
- haiku/rag/graph/common/utils.py +44 -0
- haiku/rag/graph/deep_qa/__init__.py +1 -0
- haiku/rag/graph/deep_qa/dependencies.py +27 -0
- haiku/rag/graph/deep_qa/graph.py +243 -0
- haiku/rag/graph/deep_qa/models.py +20 -0
- haiku/rag/graph/deep_qa/prompts.py +59 -0
- haiku/rag/graph/deep_qa/state.py +56 -0
- haiku/rag/graph/research/__init__.py +3 -0
- haiku/rag/graph/research/common.py +87 -0
- haiku/rag/graph/research/dependencies.py +151 -0
- haiku/rag/graph/research/graph.py +295 -0
- haiku/rag/graph/research/models.py +166 -0
- haiku/rag/graph/research/prompts.py +107 -0
- haiku/rag/graph/research/state.py +85 -0
- haiku/rag/logging.py +56 -0
- haiku/rag/mcp.py +245 -0
- haiku/rag/monitor.py +194 -0
- haiku/rag/qa/__init__.py +33 -0
- haiku/rag/qa/agent.py +93 -0
- haiku/rag/qa/prompts.py +60 -0
- haiku/rag/reader.py +135 -0
- haiku/rag/reranking/__init__.py +63 -0
- haiku/rag/reranking/base.py +13 -0
- haiku/rag/reranking/cohere.py +34 -0
- haiku/rag/reranking/mxbai.py +28 -0
- haiku/rag/reranking/vllm.py +44 -0
- haiku/rag/reranking/zeroentropy.py +59 -0
- haiku/rag/store/__init__.py +4 -0
- haiku/rag/store/engine.py +309 -0
- haiku/rag/store/models/__init__.py +4 -0
- haiku/rag/store/models/chunk.py +17 -0
- haiku/rag/store/models/document.py +17 -0
- haiku/rag/store/repositories/__init__.py +9 -0
- haiku/rag/store/repositories/chunk.py +442 -0
- haiku/rag/store/repositories/document.py +261 -0
- haiku/rag/store/repositories/settings.py +165 -0
- haiku/rag/store/upgrades/__init__.py +62 -0
- haiku/rag/store/upgrades/v0_10_1.py +64 -0
- haiku/rag/store/upgrades/v0_9_3.py +112 -0
- haiku/rag/utils.py +211 -0
- haiku_rag_slim-0.16.0.dist-info/METADATA +128 -0
- haiku_rag_slim-0.16.0.dist-info/RECORD +71 -0
- haiku_rag_slim-0.16.0.dist-info/WHEEL +4 -0
- haiku_rag_slim-0.16.0.dist-info/entry_points.txt +2 -0
- haiku_rag_slim-0.16.0.dist-info/licenses/LICENSE +7 -0
haiku/rag/graph/common/nodes.py
@@ -0,0 +1,265 @@
+"""Common node implementations for graph workflows."""
+
+import asyncio
+from collections.abc import Awaitable, Callable
+from typing import Any, Protocol
+
+from pydantic_ai import Agent, RunContext
+from pydantic_ai.format_prompt import format_as_xml
+from pydantic_ai.output import ToolOutput
+from pydantic_graph.beta import StepContext
+
+from haiku.rag.client import HaikuRAG
+from haiku.rag.graph.agui.emitter import AGUIEmitter
+from haiku.rag.graph.common import get_model
+from haiku.rag.graph.common.models import ResearchPlan, SearchAnswer
+from haiku.rag.graph.common.prompts import PLAN_PROMPT, SEARCH_AGENT_PROMPT
+
+
+class GraphContext(Protocol):
+    """Protocol for graph context objects."""
+
+    original_question: str
+    sub_questions: list[str]
+
+    def add_qa_response(self, qa: SearchAnswer) -> None:
+        """Add a QA response to context."""
+        ...
+
+
+class GraphState(Protocol):
+    """Protocol for graph state objects."""
+
+    context: GraphContext
+    max_concurrency: int
+
+
+class GraphDeps(Protocol):
+    """Protocol for graph dependencies."""
+
+    client: HaikuRAG
+    agui_emitter: AGUIEmitter[Any, Any] | None
+    semaphore: asyncio.Semaphore | None
+
+
+class GraphAgentDeps(Protocol):
+    """Protocol for agent dependencies."""
+
+    client: HaikuRAG
+    context: GraphContext
+
+
+def create_plan_node[AgentDepsT: GraphAgentDeps](
+    provider: str,
+    model: str,
+    deps_type: type[AgentDepsT],
+    activity_message: str = "Creating plan",
+    output_retries: int | None = None,
+) -> Callable[[StepContext[Any, Any, None]], Awaitable[None]]:
+    """Create a plan node for any graph.
+
+    Args:
+        provider: Model provider (e.g., 'openai', 'anthropic')
+        model: Model name
+        deps_type: Type of dependencies for the agent (e.g., ResearchDependencies, DeepQADependencies)
+        activity_message: Message to show during planning activity
+        output_retries: Number of output retries for the agent (optional)
+
+    Returns:
+        Async function that can be used as a graph step
+    """
+
+    async def plan(ctx: StepContext[Any, Any, None], /) -> None:
+        state: GraphState = ctx.state  # type: ignore[assignment]
+        deps: GraphDeps = ctx.deps  # type: ignore[assignment]
+
+        if deps.agui_emitter:
+            deps.agui_emitter.start_step("plan")
+            deps.agui_emitter.update_activity("planning", activity_message)
+
+        try:
+            # Build agent configuration
+            agent_config = {
+                "model": get_model(provider, model),
+                "output_type": ResearchPlan,
+                "instructions": (
+                    PLAN_PROMPT
+                    + "\n\nUse the gather_context tool once on the main question before planning."
+                ),
+                "retries": 3,
+                "deps_type": deps_type,
+            }
+            if output_retries is not None:
+                agent_config["output_retries"] = output_retries
+
+            plan_agent = Agent(**agent_config)
+
+            @plan_agent.tool
+            async def gather_context(
+                ctx2: RunContext[AgentDepsT], query: str, limit: int = 6
+            ) -> str:
+                results = await ctx2.deps.client.search(query, limit=limit)
+                expanded = await ctx2.deps.client.expand_context(results)
+                return "\n\n".join(chunk.content for chunk, _ in expanded)
+
+            # Tool is registered via decorator above
+            _ = gather_context
+
+            prompt = (
+                "Plan a focused approach for the main question.\n\n"
+                f"Main question: {state.context.original_question}"
+            )
+
+            # Create agent dependencies
+            agent_deps = deps_type(client=deps.client, context=state.context)  # type: ignore[call-arg]
+            plan_result = await plan_agent.run(prompt, deps=agent_deps)
+            state.context.sub_questions = list(plan_result.output.sub_questions)
+
+            # State now contains the plan - emit state update and narrate
+            if deps.agui_emitter:
+                deps.agui_emitter.update_state(state)
+                count = len(state.context.sub_questions)
+                deps.agui_emitter.update_activity(
+                    "planning", f"Created plan with {count} sub-questions"
+                )
+        finally:
+            if deps.agui_emitter:
+                deps.agui_emitter.finish_step()
+
+    return plan
+
+
+def create_search_node[AgentDepsT: GraphAgentDeps](
+    provider: str,
+    model: str,
+    deps_type: type[AgentDepsT],
+    with_step_wrapper: bool = True,
+    success_message_format: str = "Answered: {sub_q}",
+    handle_exceptions: bool = False,
+) -> Callable[[StepContext[Any, Any, str]], Awaitable[SearchAnswer]]:
+    """Create a search_one node for any graph.
+
+    Args:
+        provider: Model provider
+        model: Model name
+        deps_type: Type of dependencies for the agent
+        with_step_wrapper: Whether to wrap with agui_emitter start/finish step
+        success_message_format: Format string for success activity message
+        handle_exceptions: Whether to handle exceptions with fallback answer
+
+    Returns:
+        Async function that can be used as a graph step
+    """
+
+    async def search_one(ctx: StepContext[Any, Any, str], /) -> SearchAnswer:
+        state: GraphState = ctx.state  # type: ignore[assignment]
+        deps: GraphDeps = ctx.deps  # type: ignore[assignment]
+        sub_q = ctx.inputs
+
+        # Create unique step name from question text
+        step_name = f"search: {sub_q}"
+
+        if deps.agui_emitter and with_step_wrapper:
+            deps.agui_emitter.start_step(step_name)
+
+        try:
+            # Create semaphore if not already provided
+            if deps.semaphore is None:
+                deps.semaphore = asyncio.Semaphore(state.max_concurrency)
+
+            # Use semaphore to control concurrency
+            async with deps.semaphore:
+                return await _do_search(
+                    state,
+                    deps,
+                    sub_q,
+                    provider,
+                    model,
+                    deps_type,
+                    success_message_format,
+                    handle_exceptions,
+                )
+        finally:
+            if deps.agui_emitter and with_step_wrapper:
+                deps.agui_emitter.finish_step()
+
+    return search_one
+
+
+async def _do_search[AgentDepsT: GraphAgentDeps](
+    state: GraphState,
+    deps: GraphDeps,
+    sub_q: str,
+    provider: str,
+    model: str,
+    deps_type: type[AgentDepsT],
+    success_message_format: str,
+    handle_exceptions: bool,
+) -> SearchAnswer:
+    """Internal search implementation."""
+    if deps.agui_emitter:
+        deps.agui_emitter.update_activity("searching", f"Searching: {sub_q}")
+
+    agent = Agent(
+        model=get_model(provider, model),
+        output_type=ToolOutput(SearchAnswer, max_retries=3),
+        instructions=SEARCH_AGENT_PROMPT,
+        retries=3,
+        deps_type=deps_type,
+    )
+
+    @agent.tool
+    async def search_and_answer(
+        ctx2: RunContext[AgentDepsT], query: str, limit: int = 5
+    ) -> str:
+        search_results = await ctx2.deps.client.search(query, limit=limit)
+        expanded = await ctx2.deps.client.expand_context(search_results)
+
+        entries: list[dict[str, Any]] = [
+            {
+                "text": chunk.content,
+                "score": score,
+                "document_uri": (chunk.document_title or chunk.document_uri or ""),
+            }
+            for chunk, score in expanded
+        ]
+        if not entries:
+            return f"No relevant information found in the knowledge base for: {query}"
+
+        return format_as_xml(entries, root_tag="snippets")
+
+    # Tool is registered via decorator above
+    _ = search_and_answer
+
+    agent_deps = deps_type(client=deps.client, context=state.context)  # type: ignore[call-arg]
+
+    try:
+        result = await agent.run(sub_q, deps=agent_deps)
+        answer = result.output
+        if answer:
+            state.context.add_qa_response(answer)
+            # State updated with new answer - emit state update and narrate
+            if deps.agui_emitter:
+                deps.agui_emitter.update_state(state)
+                # Format the success message
+                if "{confidence" in success_message_format:
+                    message = success_message_format.format(
+                        sub_q=sub_q, confidence=answer.confidence
+                    )
+                else:
+                    message = success_message_format.format(sub_q=sub_q)
+                deps.agui_emitter.update_activity("searching", message)
+        return answer
+    except Exception as e:
+        if handle_exceptions:
+            # Narrate the error
+            if deps.agui_emitter:
+                deps.agui_emitter.update_activity("searching", f"Search failed: {e}")
+            failure_answer = SearchAnswer(
+                query=sub_q,
+                answer=f"Search failed after retries: {str(e)}",
+                confidence=0.0,
+            )
+            return failure_answer
+        else:
+            raise
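
The two factories above return plain async callables, so any graph whose state, deps, and agent-deps objects structurally satisfy the protocols can register them as steps. A minimal sketch under those assumptions (the MyContext/MyState/MyAgentDeps/MyDeps names and the model strings are hypothetical; the real deep-QA wiring appears in haiku/rag/graph/deep_qa/graph.py below):

import asyncio
from dataclasses import dataclass, field
from typing import Any

from haiku.rag.client import HaikuRAG
from haiku.rag.graph.agui.emitter import AGUIEmitter
from haiku.rag.graph.common.models import SearchAnswer


@dataclass
class MyContext:
    # Structurally satisfies GraphContext.
    original_question: str
    sub_questions: list[str] = field(default_factory=list)
    qa_responses: list[SearchAnswer] = field(default_factory=list)

    def add_qa_response(self, qa: SearchAnswer) -> None:
        self.qa_responses.append(qa)


@dataclass
class MyState:
    # Structurally satisfies GraphState.
    context: MyContext
    max_concurrency: int = 3


@dataclass
class MyAgentDeps:
    # Structurally satisfies GraphAgentDeps; the nodes construct it via
    # deps_type(client=..., context=...).
    client: HaikuRAG
    context: MyContext


@dataclass
class MyDeps:
    # Structurally satisfies GraphDeps.
    client: HaikuRAG
    agui_emitter: AGUIEmitter[Any, Any] | None = None
    semaphore: asyncio.Semaphore | None = None


# Registration on a pydantic_graph.beta GraphBuilder `g`, mirroring deep_qa/graph.py:
#   plan = g.step(create_plan_node("openai", "gpt-4o-mini", deps_type=MyAgentDeps))
#   search_one = g.step(create_search_node("openai", "gpt-4o-mini", deps_type=MyAgentDeps))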
haiku/rag/graph/common/prompts.py
@@ -0,0 +1,46 @@
+"""Common prompts used across different graph implementations."""
+
+PLAN_PROMPT = """You are the research orchestrator for a focused, iterative workflow.
+
+Responsibilities:
+1. Understand and decompose the main question
+2. Propose a minimal, high-leverage plan
+3. Coordinate specialized agents to gather evidence
+4. Iterate based on gaps and new findings
+
+Plan requirements:
+- Produce at most 3 sub_questions that together cover the main question.
+- Each sub_question must be a standalone, self-contained query that can run
+  without extra context. Include concrete entities, scope, timeframe, and any
+  qualifiers. Avoid ambiguous pronouns (it/they/this/that).
+- Prioritize the highest-value aspects first; avoid redundancy and overlap.
+- Prefer questions that are likely answerable from the current knowledge base;
+  if coverage is uncertain, make scopes narrower and specific.
+- Order sub_questions by execution priority (most valuable first)."""
+
+SEARCH_AGENT_PROMPT = """You are a search and question-answering specialist.
+
+Tasks:
+1. Search the knowledge base for relevant evidence.
+2. Analyze retrieved snippets.
+3. Provide an answer strictly grounded in that evidence.
+
+Tool usage:
+- Always call search_and_answer before drafting any answer.
+- The tool returns snippets with verbatim `text`, a relevance `score`, and the
+  originating document identifier (document title if available, otherwise URI).
+- You may call the tool multiple times to refine or broaden context, but do not
+  exceed 3 total calls. Favor precision over volume.
+- Use scores to prioritize evidence, but include only the minimal subset of
+  snippet texts (verbatim) in SearchAnswer.context (typically 1-4).
+- Set SearchAnswer.sources to the corresponding document identifiers for the
+  snippets you used (title if available, otherwise URI; one per snippet; same
+  order as context). Context must be text-only.
+- If no relevant information is found, clearly say so and return an empty
+  context list and sources list.
+
+Answering rules:
+- Be direct and specific; avoid meta commentary about the process.
+- Do not include any claims not supported by the provided snippets.
+- Prefer concise phrasing; avoid copying long passages.
+- When evidence is partial, state the limits explicitly in the answer."""
haiku/rag/graph/common/utils.py
@@ -0,0 +1,44 @@
+"""Common utilities for all graph implementations."""
+
+from pydantic_ai.models.openai import OpenAIChatModel
+from pydantic_ai.providers.ollama import OllamaProvider
+from pydantic_ai.providers.openai import OpenAIProvider
+
+from haiku.rag.config import Config
+
+
+def get_model(provider: str, model: str) -> OpenAIChatModel | str:
+    """
+    Get a model instance for the specified provider and model name.
+
+    Args:
+        provider: The model provider ("ollama", "vllm", or other)
+        model: The model name
+
+    Returns:
+        A configured model instance
+
+    Raises:
+        ValueError: If the provider is unknown
+    """
+    if provider == "ollama":
+        return OpenAIChatModel(
+            model_name=model,
+            provider=OllamaProvider(base_url=f"{Config.providers.ollama.base_url}/v1"),
+        )
+    elif provider == "vllm":
+        return OpenAIChatModel(
+            model_name=model,
+            provider=OpenAIProvider(
+                base_url=f"{Config.providers.vllm.research_base_url or Config.providers.vllm.qa_base_url}/v1",
+                api_key="none",
+            ),
+        )
+    elif provider in ("openai", "anthropic", "gemini", "groq", "bedrock"):
+        # These providers use string format
+        return f"{provider}:{model}"
+    else:
+        raise ValueError(
+            f"Unknown model provider: {provider}. "
+            f"Supported providers: ollama, vllm, openai, anthropic, gemini, groq, bedrock"
+        )
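
For illustration, a short sketch of how get_model behaves for the different provider branches (the model names below are placeholders, not defaults shipped with the package):

from haiku.rag.graph.common import get_model

# Ollama and vLLM are reached through OpenAI-compatible endpoints, so a
# configured OpenAIChatModel instance is returned:
local_model = get_model("ollama", "qwen3")

# Hosted providers come back as "provider:model" strings that pydantic_ai
# resolves itself:
hosted_model = get_model("openai", "gpt-4o-mini")  # == "openai:gpt-4o-mini"

# Anything else raises:
# get_model("foo", "bar")  -> ValueError: Unknown model provider: foo. ...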
haiku/rag/graph/deep_qa/__init__.py
@@ -0,0 +1 @@
+from haiku.rag.graph.deep_qa.models import DeepQAAnswer
haiku/rag/graph/deep_qa/dependencies.py
@@ -0,0 +1,27 @@
+from pydantic import BaseModel, Field
+
+from haiku.rag.client import HaikuRAG
+from haiku.rag.graph.common.models import SearchAnswer
+
+
+class DeepQAContext(BaseModel):
+    original_question: str = Field(description="The original question")
+    sub_questions: list[str] = Field(
+        default_factory=list, description="Decomposed sub-questions"
+    )
+    qa_responses: list[SearchAnswer] = Field(
+        default_factory=list, description="QA pairs collected during answering"
+    )
+    use_citations: bool = Field(
+        default=False, description="Whether to include citations in the answer"
+    )
+
+    def add_qa_response(self, qa: SearchAnswer) -> None:
+        self.qa_responses.append(qa)
+
+
+class DeepQADependencies(BaseModel):
+    model_config = {"arbitrary_types_allowed": True}
+
+    client: HaikuRAG = Field(description="RAG client for document operations")
+    context: DeepQAContext = Field(description="Shared QA context")
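
DeepQAContext and DeepQADependencies are the concrete counterparts of the GraphContext and GraphAgentDeps protocols from common/nodes.py, which is why the node factories can instantiate them as deps_type(client=..., context=...). A small sketch (how the HaikuRAG client is opened is outside this hunk, so it is left as a placeholder; the question string is illustrative):

from haiku.rag.client import HaikuRAG
from haiku.rag.graph.deep_qa.dependencies import DeepQAContext, DeepQADependencies

client: HaikuRAG = ...  # obtained however the application opens its store

context = DeepQAContext(
    original_question="What embedding providers does haiku.rag support?",
    use_citations=True,
)
deps = DeepQADependencies(client=client, context=context)

# This matches what the factory nodes build internally:
#   agent_deps = deps_type(client=deps.client, context=state.context)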
haiku/rag/graph/deep_qa/graph.py
@@ -0,0 +1,243 @@
+from pydantic_ai import Agent
+from pydantic_ai.format_prompt import format_as_xml
+from pydantic_graph.beta import Graph, GraphBuilder, StepContext
+from pydantic_graph.beta.join import reduce_list_append
+
+from haiku.rag.config import Config
+from haiku.rag.config.models import AppConfig
+from haiku.rag.graph.common import get_model
+from haiku.rag.graph.common.models import SearchAnswer
+from haiku.rag.graph.common.nodes import create_plan_node, create_search_node
+from haiku.rag.graph.deep_qa.dependencies import DeepQADependencies
+from haiku.rag.graph.deep_qa.models import DeepQAAnswer, DeepQAEvaluation
+from haiku.rag.graph.deep_qa.prompts import (
+    DECISION_PROMPT,
+    SYNTHESIS_PROMPT,
+    SYNTHESIS_PROMPT_WITH_CITATIONS,
+)
+from haiku.rag.graph.deep_qa.state import DeepQADeps, DeepQAState
+
+
+def build_deep_qa_graph(
+    config: AppConfig = Config,
+) -> Graph[DeepQAState, DeepQADeps, None, DeepQAAnswer]:
+    """Build the Deep QA graph.
+
+    Args:
+        config: AppConfig object (uses config.qa for provider, model, and graph parameters)
+
+    Returns:
+        Configured Deep QA graph
+    """
+    provider = config.qa.provider
+    model = config.qa.model
+    g = GraphBuilder(
+        state_type=DeepQAState,
+        deps_type=DeepQADeps,
+        output_type=DeepQAAnswer,
+    )
+
+    # Create and register the plan node using the factory
+    plan = g.step(
+        create_plan_node(
+            provider=provider,
+            model=model,
+            deps_type=DeepQADependencies,  # type: ignore[arg-type]
+            activity_message="Planning approach",
+            output_retries=None,  # Deep QA doesn't use output_retries
+        )
+    )  # type: ignore[arg-type]
+
+    # Create and register the search_one node using the factory
+    search_one = g.step(
+        create_search_node(
+            provider=provider,
+            model=model,
+            deps_type=DeepQADependencies,  # type: ignore[arg-type]
+            with_step_wrapper=False,  # Deep QA doesn't wrap with agui_emitter step
+            success_message_format="Answered: {sub_q}",
+            handle_exceptions=True,
+        )
+    )  # type: ignore[arg-type]
+
+    @g.step
+    async def get_batch(
+        ctx: StepContext[DeepQAState, DeepQADeps, None | bool],
+    ) -> list[str] | None:
+        """Get all remaining questions for this iteration."""
+        state = ctx.state
+
+        if not state.context.sub_questions:
+            return None
+
+        # Take ALL remaining questions - max_concurrency controls parallel execution within .map()
+        batch = list(state.context.sub_questions)
+        state.context.sub_questions.clear()
+        return batch
+
+    @g.step
+    async def decide(
+        ctx: StepContext[DeepQAState, DeepQADeps, list[SearchAnswer]],
+    ) -> bool:
+        state = ctx.state
+        deps = ctx.deps
+
+        if deps.agui_emitter:
+            deps.agui_emitter.start_step("decide")
+            deps.agui_emitter.update_activity(
+                "evaluating", "Evaluating information sufficiency"
+            )
+
+        try:
+            agent = Agent(
+                model=get_model(provider, model),
+                output_type=DeepQAEvaluation,
+                instructions=DECISION_PROMPT,
+                retries=3,
+                deps_type=DeepQADependencies,
+            )
+
+            context_data = {
+                "original_question": state.context.original_question,
+                "gathered_answers": [
+                    {
+                        "question": qa.query,
+                        "answer": qa.answer,
+                        "sources": qa.sources,
+                    }
+                    for qa in state.context.qa_responses
+                ],
+            }
+            context_xml = format_as_xml(context_data, root_tag="gathered_information")
+
+            prompt = (
+                "Evaluate whether we have sufficient information to answer the question.\n\n"
+                f"{context_xml}"
+            )
+
+            agent_deps = DeepQADependencies(
+                client=deps.client,
+                context=state.context,
+            )
+            result = await agent.run(prompt, deps=agent_deps)
+            evaluation = result.output
+
+            state.iterations += 1
+
+            for new_q in evaluation.new_questions:
+                if new_q not in state.context.sub_questions:
+                    state.context.sub_questions.append(new_q)
+
+            if deps.agui_emitter:
+                deps.agui_emitter.update_state(state)
+                status = "sufficient" if evaluation.is_sufficient else "insufficient"
+                deps.agui_emitter.update_activity(
+                    "evaluating",
+                    f"Information {status} after {state.iterations} iteration(s)",
+                )
+
+            should_continue = (
+                not evaluation.is_sufficient and state.iterations < state.max_iterations
+            )
+
+            return should_continue
+        finally:
+            if deps.agui_emitter:
+                deps.agui_emitter.finish_step()
+
+    @g.step
+    async def synthesize(
+        ctx: StepContext[DeepQAState, DeepQADeps, None | bool],
+    ) -> DeepQAAnswer:
+        state = ctx.state
+        deps = ctx.deps
+
+        if deps.agui_emitter:
+            deps.agui_emitter.start_step("synthesize")
+            deps.agui_emitter.update_activity(
+                "synthesizing", "Synthesizing final answer"
+            )
+
+        try:
+            prompt_template = (
+                SYNTHESIS_PROMPT_WITH_CITATIONS
+                if state.context.use_citations
+                else SYNTHESIS_PROMPT
+            )
+
+            agent = Agent(
+                model=get_model(provider, model),
+                output_type=DeepQAAnswer,
+                instructions=prompt_template,
+                retries=3,
+                deps_type=DeepQADependencies,
+            )
+
+            context_data = {
+                "original_question": state.context.original_question,
+                "sub_answers": [
+                    {
+                        "question": qa.query,
+                        "answer": qa.answer,
+                        "sources": qa.sources,
+                    }
+                    for qa in state.context.qa_responses
+                ],
+            }
+            context_xml = format_as_xml(context_data, root_tag="gathered_information")
+
+            prompt = f"Synthesize a comprehensive answer to the original question.\n\n{context_xml}"
+
+            agent_deps = DeepQADependencies(
+                client=deps.client,
+                context=state.context,
+            )
+            result = await agent.run(prompt, deps=agent_deps)
+
+            if deps.agui_emitter:
+                deps.agui_emitter.update_activity("synthesizing", "Answer complete")
+
+            return result.output
+        finally:
+            if deps.agui_emitter:
+                deps.agui_emitter.finish_step()
+
+    # Build the graph structure
+    collect_answers = g.join(
+        reduce_list_append,
+        initial_factory=list[SearchAnswer],
+    )
+
+    g.add(
+        g.edge_from(g.start_node).to(plan),
+        g.edge_from(plan).to(get_batch),
+    )
+
+    # Branch based on whether we have questions
+    g.add(
+        g.edge_from(get_batch).to(
+            g.decision()
+            .branch(g.match(list).label("Has questions").map().to(search_one))
+            .branch(g.match(type(None)).label("No questions").to(synthesize))
+        ),
+        g.edge_from(search_one).to(collect_answers),
+        g.edge_from(collect_answers).to(decide),
+    )
+
+    # Branch based on decision
+    g.add(
+        g.edge_from(decide).to(
+            g.decision()
+            .branch(
+                g.match(bool, matches=lambda x: x).label("Continue QA").to(get_batch)
+            )
+            .branch(
+                g.match(bool, matches=lambda x: not x)
+                .label("Done with QA")
+                .to(synthesize)
+            )
+        ),
+        g.edge_from(synthesize).to(g.end_node),
+    )
+
+    return g.build()
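
build_deep_qa_graph only assembles the graph; executing it also needs DeepQAState and DeepQADeps from haiku/rag/graph/deep_qa/state.py, which are not part of this hunk. A minimal construction sketch that deliberately stops short of the run call, since the beta run API is not shown here:

from haiku.rag.config import Config
from haiku.rag.graph.deep_qa.graph import build_deep_qa_graph

# Uses Config.qa.provider / Config.qa.model by default; pass a custom AppConfig to override.
graph = build_deep_qa_graph(Config)

# To execute, supply DeepQAState / DeepQADeps instances (defined in
# haiku/rag/graph/deep_qa/state.py) through the pydantic_graph beta run API.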
haiku/rag/graph/deep_qa/models.py
@@ -0,0 +1,20 @@
+from pydantic import BaseModel, Field
+
+
+class DeepQAEvaluation(BaseModel):
+    is_sufficient: bool = Field(
+        description="Whether we have sufficient information to answer the question"
+    )
+    reasoning: str = Field(description="Explanation of the sufficiency assessment")
+    new_questions: list[str] = Field(
+        description="Additional sub-questions needed if insufficient",
+        default_factory=list,
+    )
+
+
+class DeepQAAnswer(BaseModel):
+    answer: str = Field(description="The comprehensive answer to the question")
+    sources: list[str] = Field(
+        description="Document titles or URIs used to generate the answer",
+        default_factory=list,
+    )