haiku.rag-slim 0.16.0__py3-none-any.whl → 0.24.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of haiku.rag-slim might be problematic.

Files changed (94)
  1. haiku/rag/app.py +430 -72
  2. haiku/rag/chunkers/__init__.py +31 -0
  3. haiku/rag/chunkers/base.py +31 -0
  4. haiku/rag/chunkers/docling_local.py +164 -0
  5. haiku/rag/chunkers/docling_serve.py +179 -0
  6. haiku/rag/cli.py +207 -24
  7. haiku/rag/cli_chat.py +489 -0
  8. haiku/rag/client.py +1251 -266
  9. haiku/rag/config/__init__.py +16 -10
  10. haiku/rag/config/loader.py +5 -44
  11. haiku/rag/config/models.py +126 -17
  12. haiku/rag/converters/__init__.py +31 -0
  13. haiku/rag/converters/base.py +63 -0
  14. haiku/rag/converters/docling_local.py +193 -0
  15. haiku/rag/converters/docling_serve.py +229 -0
  16. haiku/rag/converters/text_utils.py +237 -0
  17. haiku/rag/embeddings/__init__.py +123 -24
  18. haiku/rag/embeddings/voyageai.py +175 -20
  19. haiku/rag/graph/__init__.py +0 -11
  20. haiku/rag/graph/agui/__init__.py +8 -2
  21. haiku/rag/graph/agui/cli_renderer.py +1 -1
  22. haiku/rag/graph/agui/emitter.py +219 -31
  23. haiku/rag/graph/agui/server.py +20 -62
  24. haiku/rag/graph/agui/stream.py +1 -2
  25. haiku/rag/graph/research/__init__.py +5 -2
  26. haiku/rag/graph/research/dependencies.py +12 -126
  27. haiku/rag/graph/research/graph.py +390 -135
  28. haiku/rag/graph/research/models.py +91 -112
  29. haiku/rag/graph/research/prompts.py +99 -91
  30. haiku/rag/graph/research/state.py +35 -27
  31. haiku/rag/inspector/__init__.py +8 -0
  32. haiku/rag/inspector/app.py +259 -0
  33. haiku/rag/inspector/widgets/__init__.py +6 -0
  34. haiku/rag/inspector/widgets/chunk_list.py +100 -0
  35. haiku/rag/inspector/widgets/context_modal.py +89 -0
  36. haiku/rag/inspector/widgets/detail_view.py +130 -0
  37. haiku/rag/inspector/widgets/document_list.py +75 -0
  38. haiku/rag/inspector/widgets/info_modal.py +209 -0
  39. haiku/rag/inspector/widgets/search_modal.py +183 -0
  40. haiku/rag/inspector/widgets/visual_modal.py +126 -0
  41. haiku/rag/mcp.py +106 -102
  42. haiku/rag/monitor.py +33 -9
  43. haiku/rag/providers/__init__.py +5 -0
  44. haiku/rag/providers/docling_serve.py +108 -0
  45. haiku/rag/qa/__init__.py +12 -10
  46. haiku/rag/qa/agent.py +43 -61
  47. haiku/rag/qa/prompts.py +35 -57
  48. haiku/rag/reranking/__init__.py +9 -6
  49. haiku/rag/reranking/base.py +1 -1
  50. haiku/rag/reranking/cohere.py +5 -4
  51. haiku/rag/reranking/mxbai.py +5 -2
  52. haiku/rag/reranking/vllm.py +3 -4
  53. haiku/rag/reranking/zeroentropy.py +6 -5
  54. haiku/rag/store/__init__.py +2 -1
  55. haiku/rag/store/engine.py +242 -42
  56. haiku/rag/store/exceptions.py +4 -0
  57. haiku/rag/store/models/__init__.py +8 -2
  58. haiku/rag/store/models/chunk.py +190 -0
  59. haiku/rag/store/models/document.py +46 -0
  60. haiku/rag/store/repositories/chunk.py +141 -121
  61. haiku/rag/store/repositories/document.py +25 -84
  62. haiku/rag/store/repositories/settings.py +11 -14
  63. haiku/rag/store/upgrades/__init__.py +19 -3
  64. haiku/rag/store/upgrades/v0_10_1.py +1 -1
  65. haiku/rag/store/upgrades/v0_19_6.py +65 -0
  66. haiku/rag/store/upgrades/v0_20_0.py +68 -0
  67. haiku/rag/store/upgrades/v0_23_1.py +100 -0
  68. haiku/rag/store/upgrades/v0_9_3.py +3 -3
  69. haiku/rag/utils.py +371 -146
  70. {haiku_rag_slim-0.16.0.dist-info → haiku_rag_slim-0.24.0.dist-info}/METADATA +15 -12
  71. haiku_rag_slim-0.24.0.dist-info/RECORD +78 -0
  72. {haiku_rag_slim-0.16.0.dist-info → haiku_rag_slim-0.24.0.dist-info}/WHEEL +1 -1
  73. haiku/rag/chunker.py +0 -65
  74. haiku/rag/embeddings/base.py +0 -25
  75. haiku/rag/embeddings/ollama.py +0 -28
  76. haiku/rag/embeddings/openai.py +0 -26
  77. haiku/rag/embeddings/vllm.py +0 -29
  78. haiku/rag/graph/agui/events.py +0 -254
  79. haiku/rag/graph/common/__init__.py +0 -5
  80. haiku/rag/graph/common/models.py +0 -42
  81. haiku/rag/graph/common/nodes.py +0 -265
  82. haiku/rag/graph/common/prompts.py +0 -46
  83. haiku/rag/graph/common/utils.py +0 -44
  84. haiku/rag/graph/deep_qa/__init__.py +0 -1
  85. haiku/rag/graph/deep_qa/dependencies.py +0 -27
  86. haiku/rag/graph/deep_qa/graph.py +0 -243
  87. haiku/rag/graph/deep_qa/models.py +0 -20
  88. haiku/rag/graph/deep_qa/prompts.py +0 -59
  89. haiku/rag/graph/deep_qa/state.py +0 -56
  90. haiku/rag/graph/research/common.py +0 -87
  91. haiku/rag/reader.py +0 -135
  92. haiku_rag_slim-0.16.0.dist-info/RECORD +0 -71
  93. {haiku_rag_slim-0.16.0.dist-info → haiku_rag_slim-0.24.0.dist-info}/entry_points.txt +0 -0
  94. {haiku_rag_slim-0.16.0.dist-info → haiku_rag_slim-0.24.0.dist-info}/licenses/LICENSE +0 -0
haiku/rag/graph/common/nodes.py
@@ -1,265 +0,0 @@
- """Common node implementations for graph workflows."""
-
- import asyncio
- from collections.abc import Awaitable, Callable
- from typing import Any, Protocol
-
- from pydantic_ai import Agent, RunContext
- from pydantic_ai.format_prompt import format_as_xml
- from pydantic_ai.output import ToolOutput
- from pydantic_graph.beta import StepContext
-
- from haiku.rag.client import HaikuRAG
- from haiku.rag.graph.agui.emitter import AGUIEmitter
- from haiku.rag.graph.common import get_model
- from haiku.rag.graph.common.models import ResearchPlan, SearchAnswer
- from haiku.rag.graph.common.prompts import PLAN_PROMPT, SEARCH_AGENT_PROMPT
-
-
- class GraphContext(Protocol):
-     """Protocol for graph context objects."""
-
-     original_question: str
-     sub_questions: list[str]
-
-     def add_qa_response(self, qa: SearchAnswer) -> None:
-         """Add a QA response to context."""
-         ...
-
-
- class GraphState(Protocol):
-     """Protocol for graph state objects."""
-
-     context: GraphContext
-     max_concurrency: int
-
-
- class GraphDeps(Protocol):
-     """Protocol for graph dependencies."""
-
-     client: HaikuRAG
-     agui_emitter: AGUIEmitter[Any, Any] | None
-     semaphore: asyncio.Semaphore | None
-
-
- class GraphAgentDeps(Protocol):
-     """Protocol for agent dependencies."""
-
-     client: HaikuRAG
-     context: GraphContext
-
-
- def create_plan_node[AgentDepsT: GraphAgentDeps](
-     provider: str,
-     model: str,
-     deps_type: type[AgentDepsT],
-     activity_message: str = "Creating plan",
-     output_retries: int | None = None,
- ) -> Callable[[StepContext[Any, Any, None]], Awaitable[None]]:
-     """Create a plan node for any graph.
-
-     Args:
-         provider: Model provider (e.g., 'openai', 'anthropic')
-         model: Model name
-         deps_type: Type of dependencies for the agent (e.g., ResearchDependencies, DeepQADependencies)
-         activity_message: Message to show during planning activity
-         output_retries: Number of output retries for the agent (optional)
-
-     Returns:
-         Async function that can be used as a graph step
-     """
-
-     async def plan(ctx: StepContext[Any, Any, None], /) -> None:
-         state: GraphState = ctx.state  # type: ignore[assignment]
-         deps: GraphDeps = ctx.deps  # type: ignore[assignment]
-
-         if deps.agui_emitter:
-             deps.agui_emitter.start_step("plan")
-             deps.agui_emitter.update_activity("planning", activity_message)
-
-         try:
-             # Build agent configuration
-             agent_config = {
-                 "model": get_model(provider, model),
-                 "output_type": ResearchPlan,
-                 "instructions": (
-                     PLAN_PROMPT
-                     + "\n\nUse the gather_context tool once on the main question before planning."
-                 ),
-                 "retries": 3,
-                 "deps_type": deps_type,
-             }
-             if output_retries is not None:
-                 agent_config["output_retries"] = output_retries
-
-             plan_agent = Agent(**agent_config)
-
-             @plan_agent.tool
-             async def gather_context(
-                 ctx2: RunContext[AgentDepsT], query: str, limit: int = 6
-             ) -> str:
-                 results = await ctx2.deps.client.search(query, limit=limit)
-                 expanded = await ctx2.deps.client.expand_context(results)
-                 return "\n\n".join(chunk.content for chunk, _ in expanded)
-
-             # Tool is registered via decorator above
-             _ = gather_context
-
-             prompt = (
-                 "Plan a focused approach for the main question.\n\n"
-                 f"Main question: {state.context.original_question}"
-             )
-
-             # Create agent dependencies
-             agent_deps = deps_type(client=deps.client, context=state.context)  # type: ignore[call-arg]
-             plan_result = await plan_agent.run(prompt, deps=agent_deps)
-             state.context.sub_questions = list(plan_result.output.sub_questions)
-
-             # State now contains the plan - emit state update and narrate
-             if deps.agui_emitter:
-                 deps.agui_emitter.update_state(state)
-                 count = len(state.context.sub_questions)
-                 deps.agui_emitter.update_activity(
-                     "planning", f"Created plan with {count} sub-questions"
-                 )
-         finally:
-             if deps.agui_emitter:
-                 deps.agui_emitter.finish_step()
-
-     return plan
-
-
- def create_search_node[AgentDepsT: GraphAgentDeps](
-     provider: str,
-     model: str,
-     deps_type: type[AgentDepsT],
-     with_step_wrapper: bool = True,
-     success_message_format: str = "Answered: {sub_q}",
-     handle_exceptions: bool = False,
- ) -> Callable[[StepContext[Any, Any, str]], Awaitable[SearchAnswer]]:
-     """Create a search_one node for any graph.
-
-     Args:
-         provider: Model provider
-         model: Model name
-         deps_type: Type of dependencies for the agent
-         with_step_wrapper: Whether to wrap with agui_emitter start/finish step
-         success_message_format: Format string for success activity message
-         handle_exceptions: Whether to handle exceptions with fallback answer
-
-     Returns:
-         Async function that can be used as a graph step
-     """
-
-     async def search_one(ctx: StepContext[Any, Any, str], /) -> SearchAnswer:
-         state: GraphState = ctx.state  # type: ignore[assignment]
-         deps: GraphDeps = ctx.deps  # type: ignore[assignment]
-         sub_q = ctx.inputs
-
-         # Create unique step name from question text
-         step_name = f"search: {sub_q}"
-
-         if deps.agui_emitter and with_step_wrapper:
-             deps.agui_emitter.start_step(step_name)
-
-         try:
-             # Create semaphore if not already provided
-             if deps.semaphore is None:
-                 deps.semaphore = asyncio.Semaphore(state.max_concurrency)
-
-             # Use semaphore to control concurrency
-             async with deps.semaphore:
-                 return await _do_search(
-                     state,
-                     deps,
-                     sub_q,
-                     provider,
-                     model,
-                     deps_type,
-                     success_message_format,
-                     handle_exceptions,
-                 )
-         finally:
-             if deps.agui_emitter and with_step_wrapper:
-                 deps.agui_emitter.finish_step()
-
-     return search_one
-
-
- async def _do_search[AgentDepsT: GraphAgentDeps](
-     state: GraphState,
-     deps: GraphDeps,
-     sub_q: str,
-     provider: str,
-     model: str,
-     deps_type: type[AgentDepsT],
-     success_message_format: str,
-     handle_exceptions: bool,
- ) -> SearchAnswer:
-     """Internal search implementation."""
-     if deps.agui_emitter:
-         deps.agui_emitter.update_activity("searching", f"Searching: {sub_q}")
-
-     agent = Agent(
-         model=get_model(provider, model),
-         output_type=ToolOutput(SearchAnswer, max_retries=3),
-         instructions=SEARCH_AGENT_PROMPT,
-         retries=3,
-         deps_type=deps_type,
-     )
-
-     @agent.tool
-     async def search_and_answer(
-         ctx2: RunContext[AgentDepsT], query: str, limit: int = 5
-     ) -> str:
-         search_results = await ctx2.deps.client.search(query, limit=limit)
-         expanded = await ctx2.deps.client.expand_context(search_results)
-
-         entries: list[dict[str, Any]] = [
-             {
-                 "text": chunk.content,
-                 "score": score,
-                 "document_uri": (chunk.document_title or chunk.document_uri or ""),
-             }
-             for chunk, score in expanded
-         ]
-         if not entries:
-             return f"No relevant information found in the knowledge base for: {query}"
-
-         return format_as_xml(entries, root_tag="snippets")
-
-     # Tool is registered via decorator above
-     _ = search_and_answer
-
-     agent_deps = deps_type(client=deps.client, context=state.context)  # type: ignore[call-arg]
-
-     try:
-         result = await agent.run(sub_q, deps=agent_deps)
-         answer = result.output
-         if answer:
-             state.context.add_qa_response(answer)
-             # State updated with new answer - emit state update and narrate
-             if deps.agui_emitter:
-                 deps.agui_emitter.update_state(state)
-                 # Format the success message
-                 if "{confidence" in success_message_format:
-                     message = success_message_format.format(
-                         sub_q=sub_q, confidence=answer.confidence
-                     )
-                 else:
-                     message = success_message_format.format(sub_q=sub_q)
-                 deps.agui_emitter.update_activity("searching", message)
-         return answer
-     except Exception as e:
-         if handle_exceptions:
-             # Narrate the error
-             if deps.agui_emitter:
-                 deps.agui_emitter.update_activity("searching", f"Search failed: {e}")
-             failure_answer = SearchAnswer(
-                 query=sub_q,
-                 answer=f"Search failed after retries: {str(e)}",
-                 confidence=0.0,
-             )
-             return failure_answer
-         else:
-             raise
haiku/rag/graph/common/prompts.py
@@ -1,46 +0,0 @@
- """Common prompts used across different graph implementations."""
-
- PLAN_PROMPT = """You are the research orchestrator for a focused, iterative workflow.
-
- Responsibilities:
- 1. Understand and decompose the main question
- 2. Propose a minimal, high-leverage plan
- 3. Coordinate specialized agents to gather evidence
- 4. Iterate based on gaps and new findings
-
- Plan requirements:
- Produce at most 3 sub_questions that together cover the main question.
- Each sub_question must be a standalone, self-contained query that can run
- without extra context. Include concrete entities, scope, timeframe, and any
- qualifiers. Avoid ambiguous pronouns (it/they/this/that).
- Prioritize the highest-value aspects first; avoid redundancy and overlap.
- Prefer questions that are likely answerable from the current knowledge base;
- if coverage is uncertain, make scopes narrower and specific.
- Order sub_questions by execution priority (most valuable first)."""
-
- SEARCH_AGENT_PROMPT = """You are a search and question-answering specialist.
-
- Tasks:
- 1. Search the knowledge base for relevant evidence.
- 2. Analyze retrieved snippets.
- 3. Provide an answer strictly grounded in that evidence.
-
- Tool usage:
- Always call search_and_answer before drafting any answer.
- The tool returns snippets with verbatim `text`, a relevance `score`, and the
- originating document identifier (document title if available, otherwise URI).
- You may call the tool multiple times to refine or broaden context, but do not
- exceed 3 total calls. Favor precision over volume.
- Use scores to prioritize evidence, but include only the minimal subset of
- snippet texts (verbatim) in SearchAnswer.context (typically 1-4).
- Set SearchAnswer.sources to the corresponding document identifiers for the
- snippets you used (title if available, otherwise URI; one per snippet; same
- order as context). Context must be text-only.
- If no relevant information is found, clearly say so and return an empty
- context list and sources list.
-
- Answering rules:
- Be direct and specific; avoid meta commentary about the process.
- Do not include any claims not supported by the provided snippets.
- Prefer concise phrasing; avoid copying long passages.
- When evidence is partial, state the limits explicitly in the answer."""
haiku/rag/graph/common/utils.py
@@ -1,44 +0,0 @@
- """Common utilities for all graph implementations."""
-
- from pydantic_ai.models.openai import OpenAIChatModel
- from pydantic_ai.providers.ollama import OllamaProvider
- from pydantic_ai.providers.openai import OpenAIProvider
-
- from haiku.rag.config import Config
-
-
- def get_model(provider: str, model: str) -> OpenAIChatModel | str:
-     """
-     Get a model instance for the specified provider and model name.
-
-     Args:
-         provider: The model provider ("ollama", "vllm", or other)
-         model: The model name
-
-     Returns:
-         A configured model instance
-
-     Raises:
-         ValueError: If the provider is unknown
-     """
-     if provider == "ollama":
-         return OpenAIChatModel(
-             model_name=model,
-             provider=OllamaProvider(base_url=f"{Config.providers.ollama.base_url}/v1"),
-         )
-     elif provider == "vllm":
-         return OpenAIChatModel(
-             model_name=model,
-             provider=OpenAIProvider(
-                 base_url=f"{Config.providers.vllm.research_base_url or Config.providers.vllm.qa_base_url}/v1",
-                 api_key="none",
-             ),
-         )
-     elif provider in ("openai", "anthropic", "gemini", "groq", "bedrock"):
-         # These providers use string format
-         return f"{provider}:{model}"
-     else:
-         raise ValueError(
-             f"Unknown model provider: {provider}. "
-             f"Supported providers: ollama, vllm, openai, anthropic, gemini, groq, bedrock"
-         )
haiku/rag/graph/deep_qa/__init__.py
@@ -1 +0,0 @@
- from haiku.rag.graph.deep_qa.models import DeepQAAnswer
haiku/rag/graph/deep_qa/dependencies.py
@@ -1,27 +0,0 @@
- from pydantic import BaseModel, Field
-
- from haiku.rag.client import HaikuRAG
- from haiku.rag.graph.common.models import SearchAnswer
-
-
- class DeepQAContext(BaseModel):
-     original_question: str = Field(description="The original question")
-     sub_questions: list[str] = Field(
-         default_factory=list, description="Decomposed sub-questions"
-     )
-     qa_responses: list[SearchAnswer] = Field(
-         default_factory=list, description="QA pairs collected during answering"
-     )
-     use_citations: bool = Field(
-         default=False, description="Whether to include citations in the answer"
-     )
-
-     def add_qa_response(self, qa: SearchAnswer) -> None:
-         self.qa_responses.append(qa)
-
-
- class DeepQADependencies(BaseModel):
-     model_config = {"arbitrary_types_allowed": True}
-
-     client: HaikuRAG = Field(description="RAG client for document operations")
-     context: DeepQAContext = Field(description="Shared QA context")
haiku/rag/graph/deep_qa/graph.py
@@ -1,243 +0,0 @@
- from pydantic_ai import Agent
- from pydantic_ai.format_prompt import format_as_xml
- from pydantic_graph.beta import Graph, GraphBuilder, StepContext
- from pydantic_graph.beta.join import reduce_list_append
-
- from haiku.rag.config import Config
- from haiku.rag.config.models import AppConfig
- from haiku.rag.graph.common import get_model
- from haiku.rag.graph.common.models import SearchAnswer
- from haiku.rag.graph.common.nodes import create_plan_node, create_search_node
- from haiku.rag.graph.deep_qa.dependencies import DeepQADependencies
- from haiku.rag.graph.deep_qa.models import DeepQAAnswer, DeepQAEvaluation
- from haiku.rag.graph.deep_qa.prompts import (
-     DECISION_PROMPT,
-     SYNTHESIS_PROMPT,
-     SYNTHESIS_PROMPT_WITH_CITATIONS,
- )
- from haiku.rag.graph.deep_qa.state import DeepQADeps, DeepQAState
-
-
- def build_deep_qa_graph(
-     config: AppConfig = Config,
- ) -> Graph[DeepQAState, DeepQADeps, None, DeepQAAnswer]:
-     """Build the Deep QA graph.
-
-     Args:
-         config: AppConfig object (uses config.qa for provider, model, and graph parameters)
-
-     Returns:
-         Configured Deep QA graph
-     """
-     provider = config.qa.provider
-     model = config.qa.model
-     g = GraphBuilder(
-         state_type=DeepQAState,
-         deps_type=DeepQADeps,
-         output_type=DeepQAAnswer,
-     )
-
-     # Create and register the plan node using the factory
-     plan = g.step(
-         create_plan_node(
-             provider=provider,
-             model=model,
-             deps_type=DeepQADependencies,  # type: ignore[arg-type]
-             activity_message="Planning approach",
-             output_retries=None,  # Deep QA doesn't use output_retries
-         )
-     )  # type: ignore[arg-type]
-
-     # Create and register the search_one node using the factory
-     search_one = g.step(
-         create_search_node(
-             provider=provider,
-             model=model,
-             deps_type=DeepQADependencies,  # type: ignore[arg-type]
-             with_step_wrapper=False,  # Deep QA doesn't wrap with agui_emitter step
-             success_message_format="Answered: {sub_q}",
-             handle_exceptions=True,
-         )
-     )  # type: ignore[arg-type]
-
-     @g.step
-     async def get_batch(
-         ctx: StepContext[DeepQAState, DeepQADeps, None | bool],
-     ) -> list[str] | None:
-         """Get all remaining questions for this iteration."""
-         state = ctx.state
-
-         if not state.context.sub_questions:
-             return None
-
-         # Take ALL remaining questions - max_concurrency controls parallel execution within .map()
-         batch = list(state.context.sub_questions)
-         state.context.sub_questions.clear()
-         return batch
-
-     @g.step
-     async def decide(
-         ctx: StepContext[DeepQAState, DeepQADeps, list[SearchAnswer]],
-     ) -> bool:
-         state = ctx.state
-         deps = ctx.deps
-
-         if deps.agui_emitter:
-             deps.agui_emitter.start_step("decide")
-             deps.agui_emitter.update_activity(
-                 "evaluating", "Evaluating information sufficiency"
-             )
-
-         try:
-             agent = Agent(
-                 model=get_model(provider, model),
-                 output_type=DeepQAEvaluation,
-                 instructions=DECISION_PROMPT,
-                 retries=3,
-                 deps_type=DeepQADependencies,
-             )
-
-             context_data = {
-                 "original_question": state.context.original_question,
-                 "gathered_answers": [
-                     {
-                         "question": qa.query,
-                         "answer": qa.answer,
-                         "sources": qa.sources,
-                     }
-                     for qa in state.context.qa_responses
-                 ],
-             }
-             context_xml = format_as_xml(context_data, root_tag="gathered_information")
-
-             prompt = (
-                 "Evaluate whether we have sufficient information to answer the question.\n\n"
-                 f"{context_xml}"
-             )
-
-             agent_deps = DeepQADependencies(
-                 client=deps.client,
-                 context=state.context,
-             )
-             result = await agent.run(prompt, deps=agent_deps)
-             evaluation = result.output
-
-             state.iterations += 1
-
-             for new_q in evaluation.new_questions:
-                 if new_q not in state.context.sub_questions:
-                     state.context.sub_questions.append(new_q)
-
-             if deps.agui_emitter:
-                 deps.agui_emitter.update_state(state)
-                 status = "sufficient" if evaluation.is_sufficient else "insufficient"
-                 deps.agui_emitter.update_activity(
-                     "evaluating",
-                     f"Information {status} after {state.iterations} iteration(s)",
-                 )
-
-             should_continue = (
-                 not evaluation.is_sufficient and state.iterations < state.max_iterations
-             )
-
-             return should_continue
-         finally:
-             if deps.agui_emitter:
-                 deps.agui_emitter.finish_step()
-
-     @g.step
-     async def synthesize(
-         ctx: StepContext[DeepQAState, DeepQADeps, None | bool],
-     ) -> DeepQAAnswer:
-         state = ctx.state
-         deps = ctx.deps
-
-         if deps.agui_emitter:
-             deps.agui_emitter.start_step("synthesize")
-             deps.agui_emitter.update_activity(
-                 "synthesizing", "Synthesizing final answer"
-             )
-
-         try:
-             prompt_template = (
-                 SYNTHESIS_PROMPT_WITH_CITATIONS
-                 if state.context.use_citations
-                 else SYNTHESIS_PROMPT
-             )
-
-             agent = Agent(
-                 model=get_model(provider, model),
-                 output_type=DeepQAAnswer,
-                 instructions=prompt_template,
-                 retries=3,
-                 deps_type=DeepQADependencies,
-             )
-
-             context_data = {
-                 "original_question": state.context.original_question,
-                 "sub_answers": [
-                     {
-                         "question": qa.query,
-                         "answer": qa.answer,
-                         "sources": qa.sources,
-                     }
-                     for qa in state.context.qa_responses
-                 ],
-             }
-             context_xml = format_as_xml(context_data, root_tag="gathered_information")
-
-             prompt = f"Synthesize a comprehensive answer to the original question.\n\n{context_xml}"
-
-             agent_deps = DeepQADependencies(
-                 client=deps.client,
-                 context=state.context,
-             )
-             result = await agent.run(prompt, deps=agent_deps)
-
-             if deps.agui_emitter:
-                 deps.agui_emitter.update_activity("synthesizing", "Answer complete")
-
-             return result.output
-         finally:
-             if deps.agui_emitter:
-                 deps.agui_emitter.finish_step()
-
-     # Build the graph structure
-     collect_answers = g.join(
-         reduce_list_append,
-         initial_factory=list[SearchAnswer],
-     )
-
-     g.add(
-         g.edge_from(g.start_node).to(plan),
-         g.edge_from(plan).to(get_batch),
-     )
-
-     # Branch based on whether we have questions
-     g.add(
-         g.edge_from(get_batch).to(
-             g.decision()
-             .branch(g.match(list).label("Has questions").map().to(search_one))
-             .branch(g.match(type(None)).label("No questions").to(synthesize))
-         ),
-         g.edge_from(search_one).to(collect_answers),
-         g.edge_from(collect_answers).to(decide),
-     )
-
-     # Branch based on decision
-     g.add(
-         g.edge_from(decide).to(
-             g.decision()
-             .branch(
-                 g.match(bool, matches=lambda x: x).label("Continue QA").to(get_batch)
-             )
-             .branch(
-                 g.match(bool, matches=lambda x: not x)
-                 .label("Done with QA")
-                 .to(synthesize)
-             )
-         ),
-         g.edge_from(synthesize).to(g.end_node),
-     )
-
-     return g.build()
haiku/rag/graph/deep_qa/models.py
@@ -1,20 +0,0 @@
- from pydantic import BaseModel, Field
-
-
- class DeepQAEvaluation(BaseModel):
-     is_sufficient: bool = Field(
-         description="Whether we have sufficient information to answer the question"
-     )
-     reasoning: str = Field(description="Explanation of the sufficiency assessment")
-     new_questions: list[str] = Field(
-         description="Additional sub-questions needed if insufficient",
-         default_factory=list,
-     )
-
-
- class DeepQAAnswer(BaseModel):
-     answer: str = Field(description="The comprehensive answer to the question")
-     sources: list[str] = Field(
-         description="Document titles or URIs used to generate the answer",
-         default_factory=list,
-     )