haiku.rag-slim 0.16.0-py3-none-any.whl → 0.24.0-py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.

Files changed (94)
  1. haiku/rag/app.py +430 -72
  2. haiku/rag/chunkers/__init__.py +31 -0
  3. haiku/rag/chunkers/base.py +31 -0
  4. haiku/rag/chunkers/docling_local.py +164 -0
  5. haiku/rag/chunkers/docling_serve.py +179 -0
  6. haiku/rag/cli.py +207 -24
  7. haiku/rag/cli_chat.py +489 -0
  8. haiku/rag/client.py +1251 -266
  9. haiku/rag/config/__init__.py +16 -10
  10. haiku/rag/config/loader.py +5 -44
  11. haiku/rag/config/models.py +126 -17
  12. haiku/rag/converters/__init__.py +31 -0
  13. haiku/rag/converters/base.py +63 -0
  14. haiku/rag/converters/docling_local.py +193 -0
  15. haiku/rag/converters/docling_serve.py +229 -0
  16. haiku/rag/converters/text_utils.py +237 -0
  17. haiku/rag/embeddings/__init__.py +123 -24
  18. haiku/rag/embeddings/voyageai.py +175 -20
  19. haiku/rag/graph/__init__.py +0 -11
  20. haiku/rag/graph/agui/__init__.py +8 -2
  21. haiku/rag/graph/agui/cli_renderer.py +1 -1
  22. haiku/rag/graph/agui/emitter.py +219 -31
  23. haiku/rag/graph/agui/server.py +20 -62
  24. haiku/rag/graph/agui/stream.py +1 -2
  25. haiku/rag/graph/research/__init__.py +5 -2
  26. haiku/rag/graph/research/dependencies.py +12 -126
  27. haiku/rag/graph/research/graph.py +390 -135
  28. haiku/rag/graph/research/models.py +91 -112
  29. haiku/rag/graph/research/prompts.py +99 -91
  30. haiku/rag/graph/research/state.py +35 -27
  31. haiku/rag/inspector/__init__.py +8 -0
  32. haiku/rag/inspector/app.py +259 -0
  33. haiku/rag/inspector/widgets/__init__.py +6 -0
  34. haiku/rag/inspector/widgets/chunk_list.py +100 -0
  35. haiku/rag/inspector/widgets/context_modal.py +89 -0
  36. haiku/rag/inspector/widgets/detail_view.py +130 -0
  37. haiku/rag/inspector/widgets/document_list.py +75 -0
  38. haiku/rag/inspector/widgets/info_modal.py +209 -0
  39. haiku/rag/inspector/widgets/search_modal.py +183 -0
  40. haiku/rag/inspector/widgets/visual_modal.py +126 -0
  41. haiku/rag/mcp.py +106 -102
  42. haiku/rag/monitor.py +33 -9
  43. haiku/rag/providers/__init__.py +5 -0
  44. haiku/rag/providers/docling_serve.py +108 -0
  45. haiku/rag/qa/__init__.py +12 -10
  46. haiku/rag/qa/agent.py +43 -61
  47. haiku/rag/qa/prompts.py +35 -57
  48. haiku/rag/reranking/__init__.py +9 -6
  49. haiku/rag/reranking/base.py +1 -1
  50. haiku/rag/reranking/cohere.py +5 -4
  51. haiku/rag/reranking/mxbai.py +5 -2
  52. haiku/rag/reranking/vllm.py +3 -4
  53. haiku/rag/reranking/zeroentropy.py +6 -5
  54. haiku/rag/store/__init__.py +2 -1
  55. haiku/rag/store/engine.py +242 -42
  56. haiku/rag/store/exceptions.py +4 -0
  57. haiku/rag/store/models/__init__.py +8 -2
  58. haiku/rag/store/models/chunk.py +190 -0
  59. haiku/rag/store/models/document.py +46 -0
  60. haiku/rag/store/repositories/chunk.py +141 -121
  61. haiku/rag/store/repositories/document.py +25 -84
  62. haiku/rag/store/repositories/settings.py +11 -14
  63. haiku/rag/store/upgrades/__init__.py +19 -3
  64. haiku/rag/store/upgrades/v0_10_1.py +1 -1
  65. haiku/rag/store/upgrades/v0_19_6.py +65 -0
  66. haiku/rag/store/upgrades/v0_20_0.py +68 -0
  67. haiku/rag/store/upgrades/v0_23_1.py +100 -0
  68. haiku/rag/store/upgrades/v0_9_3.py +3 -3
  69. haiku/rag/utils.py +371 -146
  70. {haiku_rag_slim-0.16.0.dist-info → haiku_rag_slim-0.24.0.dist-info}/METADATA +15 -12
  71. haiku_rag_slim-0.24.0.dist-info/RECORD +78 -0
  72. {haiku_rag_slim-0.16.0.dist-info → haiku_rag_slim-0.24.0.dist-info}/WHEEL +1 -1
  73. haiku/rag/chunker.py +0 -65
  74. haiku/rag/embeddings/base.py +0 -25
  75. haiku/rag/embeddings/ollama.py +0 -28
  76. haiku/rag/embeddings/openai.py +0 -26
  77. haiku/rag/embeddings/vllm.py +0 -29
  78. haiku/rag/graph/agui/events.py +0 -254
  79. haiku/rag/graph/common/__init__.py +0 -5
  80. haiku/rag/graph/common/models.py +0 -42
  81. haiku/rag/graph/common/nodes.py +0 -265
  82. haiku/rag/graph/common/prompts.py +0 -46
  83. haiku/rag/graph/common/utils.py +0 -44
  84. haiku/rag/graph/deep_qa/__init__.py +0 -1
  85. haiku/rag/graph/deep_qa/dependencies.py +0 -27
  86. haiku/rag/graph/deep_qa/graph.py +0 -243
  87. haiku/rag/graph/deep_qa/models.py +0 -20
  88. haiku/rag/graph/deep_qa/prompts.py +0 -59
  89. haiku/rag/graph/deep_qa/state.py +0 -56
  90. haiku/rag/graph/research/common.py +0 -87
  91. haiku/rag/reader.py +0 -135
  92. haiku_rag_slim-0.16.0.dist-info/RECORD +0 -71
  93. {haiku_rag_slim-0.16.0.dist-info → haiku_rag_slim-0.24.0.dist-info}/entry_points.txt +0 -0
  94. {haiku_rag_slim-0.16.0.dist-info → haiku_rag_slim-0.24.0.dist-info}/licenses/LICENSE +0 -0
haiku/rag/graph/research/graph.py
@@ -1,172 +1,305 @@
-from pydantic_ai import Agent
+import asyncio
+from typing import Literal
+from uuid import uuid4
+
+from pydantic_ai import Agent, RunContext, format_as_xml
+from pydantic_ai.output import ToolOutput
 from pydantic_graph.beta import Graph, GraphBuilder, StepContext
 from pydantic_graph.beta.join import reduce_list_append
 
 from haiku.rag.config import Config
 from haiku.rag.config.models import AppConfig
-from haiku.rag.graph.common import get_model
-from haiku.rag.graph.common.models import SearchAnswer
-from haiku.rag.graph.common.nodes import create_plan_node, create_search_node
-from haiku.rag.graph.research.common import (
-    format_analysis_for_prompt,
-    format_context_for_prompt,
+from haiku.rag.graph.agui.emitter import (
+    emit_text_message_end,
+    emit_text_message_start,
+    emit_tool_call_args,
+    emit_tool_call_end,
+    emit_tool_call_start,
 )
-from haiku.rag.graph.research.dependencies import ResearchDependencies
+from haiku.rag.graph.research.dependencies import ResearchContext, ResearchDependencies
 from haiku.rag.graph.research.models import (
     EvaluationResult,
-    InsightAnalysis,
+    RawSearchAnswer,
+    ResearchPlan,
     ResearchReport,
+    SearchAnswer,
 )
 from haiku.rag.graph.research.prompts import (
-    DECISION_AGENT_PROMPT,
-    INSIGHT_AGENT_PROMPT,
-    SYNTHESIS_AGENT_PROMPT,
+    DECISION_PROMPT,
+    PLAN_PROMPT,
+    SEARCH_PROMPT,
+    SYNTHESIS_PROMPT,
 )
 from haiku.rag.graph.research.state import ResearchDeps, ResearchState
+from haiku.rag.utils import build_prompt, get_model
+
+
+def format_context_for_prompt(context: ResearchContext) -> str:
+    """Format the research context as XML for inclusion in prompts."""
+    context_data = {
+        "original_question": context.original_question,
+        "unanswered_questions": context.sub_questions,
+        "qa_responses": [
+            {
+                "question": qa.query,
+                "answer": qa.answer,
+                "confidence": qa.confidence,
+                "sources": [
+                    {
+                        "document_uri": c.document_uri,
+                        "document_title": c.document_title,
+                        "page_numbers": c.page_numbers,
+                        "headings": c.headings,
+                    }
+                    for c in qa.citations
+                ],
+            }
+            for qa in context.qa_responses
+        ],
+    }
+    return format_as_xml(context_data, root_tag="research_context")
 
 
 def build_research_graph(
     config: AppConfig = Config,
+    include_plan: bool = True,
+    interactive: bool = False,
 ) -> Graph[ResearchState, ResearchDeps, None, ResearchReport]:
     """Build the Research graph.
 
     Args:
         config: AppConfig object (uses config.research for provider, model, and graph parameters)
+        include_plan: Whether to include the planning step (False for execute-only mode)
+        interactive: Whether to include human decision nodes for HIL
 
     Returns:
         Configured Research graph
     """
-    provider = config.research.provider
-    model = config.research.model
+    model_config = config.research.model
+
+    # Build prompts with system_context if configured
+    plan_prompt = build_prompt(
+        PLAN_PROMPT
+        + "\n\nUse the gather_context tool once on the main question before planning.",
+        config,
+    )
+    search_prompt = build_prompt(SEARCH_PROMPT, config)
+    decision_prompt = build_prompt(DECISION_PROMPT, config)
+    synthesis_prompt = build_prompt(
+        config.prompts.synthesis or SYNTHESIS_PROMPT, config
+    )
     g = GraphBuilder(
         state_type=ResearchState,
         deps_type=ResearchDeps,
         output_type=ResearchReport,
     )
 
-    # Create and register the plan node using the factory
-    plan = g.step(
-        create_plan_node(
-            provider=provider,
-            model=model,
-            deps_type=ResearchDependencies,  # type: ignore[arg-type]
-            activity_message="Creating research plan",
-            output_retries=3,
-        )
-    )  # type: ignore[arg-type]
-
-    # Create and register the search_one node using the factory
-    search_one = g.step(
-        create_search_node(
-            provider=provider,
-            model=model,
-            deps_type=ResearchDependencies,  # type: ignore[arg-type]
-            with_step_wrapper=True,
-            success_message_format="Found answer with {confidence:.0%} confidence",
-            handle_exceptions=True,
-        )
-    )  # type: ignore[arg-type]
-
     @g.step
-    async def get_batch(
-        ctx: StepContext[ResearchState, ResearchDeps, None | bool],
-    ) -> list[str] | None:
-        """Get all remaining questions for this iteration."""
-        state = ctx.state
-
-        if not state.context.sub_questions:
-            return None
-
-        # Take ALL remaining questions and process them in parallel
-        batch = list(state.context.sub_questions)
-        state.context.sub_questions.clear()
-        return batch
-
-    @g.step
-    async def analyze_insights(
-        ctx: StepContext[ResearchState, ResearchDeps, list[SearchAnswer]],
-    ) -> None:
+    async def plan(ctx: StepContext[ResearchState, ResearchDeps, None]) -> None:
+        """Create research plan with sub-questions."""
         state = ctx.state
         deps = ctx.deps
 
         if deps.agui_emitter:
-            deps.agui_emitter.start_step("analyze_insights")
+            deps.agui_emitter.start_step("plan")
             deps.agui_emitter.update_activity(
-                "analyzing", "Synthesizing insights and gaps"
+                "planning", {"stepName": "plan", "message": "Creating research plan"}
             )
 
         try:
-            agent = Agent(
-                model=get_model(provider, model),
-                output_type=InsightAnalysis,
-                instructions=INSIGHT_AGENT_PROMPT,
+            plan_agent = Agent(
+                model=get_model(model_config, config),
+                output_type=ResearchPlan,
+                instructions=plan_prompt,
                 retries=3,
                 output_retries=3,
                 deps_type=ResearchDependencies,
             )
 
-            context_xml = format_context_for_prompt(state.context)
+            search_filter = state.search_filter
+
+            @plan_agent.tool
+            async def gather_context(
+                ctx2: RunContext[ResearchDependencies],
+                query: str,
+                limit: int | None = None,
+            ) -> str:
+                results = await ctx2.deps.client.search(
+                    query, limit=limit, filter=search_filter
+                )
+                results = await ctx2.deps.client.expand_context(results)
+                return "\n\n".join(r.content for r in results)
+
+            _ = gather_context
+
             prompt = (
-                "Review the latest research context and update the shared ledger of insights, gaps,"
-                " and follow-up questions.\n\n"
-                f"{context_xml}"
-            )
-            agent_deps = ResearchDependencies(
-                client=deps.client,
-                context=state.context,
+                "Plan a focused approach for the main question.\n\n"
+                f"Main question: {state.context.original_question}"
             )
-            result = await agent.run(prompt, deps=agent_deps)
-            analysis: InsightAnalysis = result.output
 
-            state.context.integrate_analysis(analysis)
-            state.last_analysis = analysis
+            agent_deps = ResearchDependencies(client=deps.client, context=state.context)
+            plan_result = await plan_agent.run(prompt, deps=agent_deps)
+            state.context.sub_questions = list(plan_result.output.sub_questions)
 
-            # State updated with insights/gaps - emit state update and narrate
             if deps.agui_emitter:
                 deps.agui_emitter.update_state(state)
-                highlights = len(analysis.highlights) if analysis.highlights else 0
-                gaps = len(analysis.gap_assessments) if analysis.gap_assessments else 0
-                resolved = len(analysis.resolved_gaps) if analysis.resolved_gaps else 0
-                parts = []
-                if highlights:
-                    parts.append(f"{highlights} insights")
-                if gaps:
-                    parts.append(f"{gaps} gaps")
-                if resolved:
-                    parts.append(f"{resolved} resolved")
-                summary = ", ".join(parts) if parts else "No updates"
-                deps.agui_emitter.update_activity("analyzing", f"Analysis: {summary}")
+                count = len(state.context.sub_questions)
+                deps.agui_emitter.update_activity(
+                    "planning",
+                    {
+                        "stepName": "plan",
+                        "message": f"Created plan with {count} sub-questions",
+                        "sub_questions": list(state.context.sub_questions),
+                    },
+                )
         finally:
             if deps.agui_emitter:
-                deps.agui_emitter.finish_step()
+                deps.agui_emitter.finish_step("plan")
 
     @g.step
-    async def decide(ctx: StepContext[ResearchState, ResearchDeps, None]) -> bool:
+    async def search_one(
+        ctx: StepContext[ResearchState, ResearchDeps, str],
+    ) -> SearchAnswer:
+        """Answer a single sub-question using the knowledge base."""
+        state = ctx.state
+        deps = ctx.deps
+        sub_q = ctx.inputs
+        step_name = f"search: {sub_q}"
+
+        if deps.agui_emitter:
+            deps.agui_emitter.start_step(step_name)
+
+        try:
+            if deps.semaphore is None:
+                deps.semaphore = asyncio.Semaphore(state.max_concurrency)
+
+            async with deps.semaphore:
+                if deps.agui_emitter:
+                    deps.agui_emitter.update_activity(
+                        "searching",
+                        {
+                            "stepName": "search_one",
+                            "message": f"Searching: {sub_q}",
+                            "query": sub_q,
+                        },
+                    )
+
+                agent = Agent(
+                    model=get_model(model_config, config),
+                    output_type=ToolOutput(RawSearchAnswer, max_retries=3),
+                    instructions=search_prompt,
+                    retries=3,
+                    deps_type=ResearchDependencies,
+                )
+
+                search_filter = state.search_filter
+
+                @agent.tool
+                async def search_and_answer(
+                    ctx2: RunContext[ResearchDependencies],
+                    query: str,
+                    limit: int | None = None,
+                ) -> str:
+                    """Search the knowledge base for relevant documents."""
+                    results = await ctx2.deps.client.search(
+                        query, limit=limit, filter=search_filter
+                    )
+                    results = await ctx2.deps.client.expand_context(results)
+                    ctx2.deps.search_results = results
+                    parts = [r.format_for_agent() for r in results]
+                    if not parts:
+                        return f"No relevant information found for: {query}"
+                    return "\n\n".join(parts)
+
+                _ = search_and_answer
+
+                agent_deps = ResearchDependencies(
+                    client=deps.client, context=state.context
+                )
+
+                try:
+                    result = await agent.run(sub_q, deps=agent_deps)
+                    raw_answer = result.output
+                    if raw_answer:
+                        answer = SearchAnswer.from_raw(
+                            raw_answer, agent_deps.search_results
+                        )
+                        state.context.add_qa_response(answer)
+                        if deps.agui_emitter:
+                            deps.agui_emitter.update_state(state)
+                            deps.agui_emitter.update_activity(
+                                "searching",
+                                {
+                                    "stepName": "search_one",
+                                    "message": f"Found answer with {answer.confidence:.0%} confidence",
+                                    "query": sub_q,
+                                    "confidence": answer.confidence,
+                                },
+                            )
+                        return answer
+                    return SearchAnswer(query=sub_q, answer="", confidence=0.0)
+                except Exception as e:
+                    if deps.agui_emitter:
+                        deps.agui_emitter.update_activity(
+                            "searching",
+                            {
+                                "stepName": "search_one",
+                                "message": f"Search failed: {e}",
+                                "query": sub_q,
+                                "error": str(e),
+                            },
+                        )
+                    return SearchAnswer(
+                        query=sub_q,
+                        answer=f"Search failed: {str(e)}",
+                        confidence=0.0,
+                    )
+        finally:
+            if deps.agui_emitter:
+                deps.agui_emitter.finish_step(step_name)
+
+    @g.step
+    async def get_batch(
+        ctx: StepContext[ResearchState, ResearchDeps, None | bool | str],
+    ) -> list[str] | None:
+        """Get all remaining questions for this iteration."""
+        state = ctx.state
+
+        if not state.context.sub_questions:
+            return None
+
+        batch = list(state.context.sub_questions)
+        state.context.sub_questions.clear()
+        return batch
+
+    @g.step
+    async def decide(
+        ctx: StepContext[ResearchState, ResearchDeps, list[SearchAnswer]],
+    ) -> bool:
+        """Evaluate research sufficiency and decide whether to continue."""
         state = ctx.state
         deps = ctx.deps
 
         if deps.agui_emitter:
             deps.agui_emitter.start_step("decide")
             deps.agui_emitter.update_activity(
-                "evaluating", "Evaluating research sufficiency"
+                "evaluating", {"message": "Evaluating research sufficiency"}
            )
 
         try:
             agent = Agent(
-                model=get_model(provider, model),
+                model=get_model(model_config, config),
                 output_type=EvaluationResult,
-                instructions=DECISION_AGENT_PROMPT,
+                instructions=decision_prompt,
                 retries=3,
                 output_retries=3,
                 deps_type=ResearchDependencies,
             )
 
             context_xml = format_context_for_prompt(state.context)
-            analysis_xml = format_analysis_for_prompt(state.last_analysis)
             prompt_parts = [
                 "Assess whether the research now answers the original question with adequate confidence.",
                 context_xml,
-                analysis_xml,
             ]
             if state.last_eval is not None:
                 prev = state.last_eval
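
Note: format_context_for_prompt is now a module-level helper built on pydantic_ai's format_as_xml (imported at the top of this file). A minimal standalone sketch of the same call pattern, with a plain dict standing in for a real ResearchContext (the sample values are illustrative only):

from pydantic_ai import format_as_xml

context_data = {
    "original_question": "What changed between 0.16.0 and 0.24.0?",
    "unanswered_questions": ["Which graph modules were removed?"],
    "qa_responses": [
        {
            "question": "Was deep_qa removed?",
            "answer": "Yes, the deep_qa graph package was dropped.",
            "confidence": 0.9,
            "sources": [],
        }
    ],
}

# Renders the nested dicts/lists as <research_context>...</research_context>
print(format_as_xml(context_data, root_tag="research_context"))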
@@ -189,17 +322,28 @@ def build_research_graph(
             state.last_eval = output
             state.iterations += 1
 
+            # Get already-answered questions to avoid duplicates
+            answered_queries = {qa.query.lower() for qa in state.context.qa_responses}
+
             for new_q in output.new_questions:
-                if new_q not in state.context.sub_questions:
-                    state.context.sub_questions.append(new_q)
+                # Skip if already in pending or already answered
+                if new_q in state.context.sub_questions:
+                    continue
+                if new_q.lower() in answered_queries:
+                    continue
+                state.context.sub_questions.append(new_q)
 
-            # State updated with evaluation - emit state update and narrate
             if deps.agui_emitter:
                 deps.agui_emitter.update_state(state)
                 sufficient = "Yes" if output.is_sufficient else "No"
                 deps.agui_emitter.update_activity(
                     "evaluating",
-                    f"Confidence: {output.confidence_score:.0%}, Sufficient: {sufficient}",
+                    {
+                        "stepName": "decide",
+                        "message": f"Confidence: {output.confidence_score:.0%}, Sufficient: {sufficient}",
+                        "confidence": output.confidence_score,
+                        "is_sufficient": output.is_sufficient,
+                    },
                 )
 
             should_continue = (
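
Note: the deduplication above compares pending sub-questions exactly but answered queries case-insensitively. A self-contained sketch of that filtering (the sample questions are invented for illustration):

answered_queries = {"what is hybrid search?"}  # already answered, lowercased
sub_questions = ["What ranking model is used?"]  # still pending
new_questions = [
    "What is Hybrid Search?",       # dropped: already answered (case-insensitive)
    "What ranking model is used?",  # dropped: exact pending duplicate
    "How are citations stored?",    # kept
]

for new_q in new_questions:
    if new_q in sub_questions:
        continue
    if new_q.lower() in answered_queries:
        continue
    sub_questions.append(new_q)

assert sub_questions == ["What ranking model is used?", "How are citations stored?"]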
@@ -210,26 +354,100 @@ def build_research_graph(
             return should_continue
         finally:
             if deps.agui_emitter:
-                deps.agui_emitter.finish_step()
+                deps.agui_emitter.finish_step("decide")
+
+    @g.step
+    async def human_decide(
+        ctx: StepContext[ResearchState, ResearchDeps, list[SearchAnswer] | None | bool],
+    ) -> Literal["search", "synthesize"]:
+        """Wait for human decision on whether to continue searching or synthesize."""
+        state = ctx.state
+        deps = ctx.deps
+
+        if deps.agui_emitter:
+            deps.agui_emitter.start_step("human_decide")
+            deps.agui_emitter.update_state(state)
+
+        try:
+            # Emit tool call for human input wrapped in a message context
+            # This makes the tool call appear as if emitted by the LLM
+            message_id = str(uuid4())
+            tool_call_id = str(uuid4())
+
+            if deps.agui_emitter:
+                # Start an assistant message to contain the tool call
+                deps.agui_emitter.emit(emit_text_message_start(message_id))
+                # Emit tool call with parent message reference
+                deps.agui_emitter.emit(
+                    emit_tool_call_start(tool_call_id, "human_decision", message_id)
+                )
+                # Include full state for display
+                qa_responses = [
+                    {
+                        "query": qa.query,
+                        "answer": qa.answer,
+                        "confidence": qa.confidence,
+                        "citations_count": len(qa.citations),
+                    }
+                    for qa in state.context.qa_responses
+                ]
+                deps.agui_emitter.emit(
+                    emit_tool_call_args(
+                        tool_call_id,
+                        {
+                            "original_question": state.context.original_question,
+                            "sub_questions": list(state.context.sub_questions),
+                            "qa_responses": qa_responses,
+                            "iterations": state.iterations,
+                        },
+                    )
+                )
+                deps.agui_emitter.emit(emit_tool_call_end(tool_call_id))
+                # End the message after tool call
+                deps.agui_emitter.emit(emit_text_message_end(message_id))
+
+            # Wait for human input
+            if deps.human_input_queue is None:
+                raise RuntimeError("human_input_queue is required for interactive mode")
+
+            decision = await deps.human_input_queue.get()
+
+            # Process decision
+            if decision.action == "modify_questions" and decision.questions:
+                state.context.sub_questions = list(decision.questions)
+            elif decision.action == "add_questions" and decision.questions:
+                state.context.sub_questions.extend(decision.questions)
+
+            if deps.agui_emitter:
+                deps.agui_emitter.update_state(state)
+
+            if decision.action in ("search", "modify_questions", "add_questions"):
+                return "search"
+            else:
+                return "synthesize"
+        finally:
+            if deps.agui_emitter:
+                deps.agui_emitter.finish_step("human_decide")
 
     @g.step
     async def synthesize(
-        ctx: StepContext[ResearchState, ResearchDeps, None | bool],
+        ctx: StepContext[ResearchState, ResearchDeps, None | bool | str],
     ) -> ResearchReport:
+        """Generate final research report."""
         state = ctx.state
         deps = ctx.deps
 
         if deps.agui_emitter:
             deps.agui_emitter.start_step("synthesize")
             deps.agui_emitter.update_activity(
-                "synthesizing", "Generating final research report"
+                "synthesizing", {"message": "Generating final research report"}
             )
 
         try:
             agent = Agent(
-                model=get_model(provider, model),
+                model=get_model(model_config, config),
                 output_type=ResearchReport,
-                instructions=SYNTHESIS_AGENT_PROMPT,
+                instructions=synthesis_prompt,
                 retries=3,
                 output_retries=3,
                 deps_type=ResearchDependencies,
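
Note: human_decide only reads two attributes from whatever object deps.human_input_queue yields: .action and .questions. A compatible payload might look like the sketch below; the HumanDecision name and the use of asyncio.Queue are assumptions, since the diff does not show the queue's element type:

import asyncio
from dataclasses import dataclass
from typing import Literal

@dataclass
class HumanDecision:  # hypothetical stand-in; only .action and .questions are read
    action: Literal["search", "synthesize", "modify_questions", "add_questions"]
    questions: list[str] | None = None

# The graph awaits deps.human_input_queue.get(), so an asyncio.Queue fits:
queue: asyncio.Queue = asyncio.Queue()
queue.put_nowait(HumanDecision("add_questions", ["How is the HIL branch wired?"]))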
@@ -249,7 +467,7 @@ def build_research_graph(
             return result.output
         finally:
             if deps.agui_emitter:
-                deps.agui_emitter.finish_step()
+                deps.agui_emitter.finish_step("synthesize")
 
     # Build the graph structure
     collect_answers = g.join(
@@ -257,39 +475,76 @@ def build_research_graph(
         initial_factory=list[SearchAnswer],
     )
 
-    g.add(
-        g.edge_from(g.start_node).to(plan),
-        g.edge_from(plan).to(get_batch),
-    )
-
-    # Branch based on whether we have questions
-    g.add(
-        g.edge_from(get_batch).to(
-            g.decision()
-            .branch(g.match(list).label("Has questions").map().to(search_one))
-            .branch(g.match(type(None)).label("No questions").to(synthesize))
-        ),
-        g.edge_from(search_one).to(collect_answers),
-        g.edge_from(collect_answers).to(analyze_insights),
-        g.edge_from(analyze_insights).to(decide),
-    )
-
-    # Branch based on decision
-    g.add(
-        g.edge_from(decide).to(
-            g.decision()
-            .branch(
-                g.match(bool, matches=lambda x: x)
-                .label("Continue research")
-                .to(get_batch)
+    if interactive:
+        # Interactive mode: human decides after plan and after evaluation
+        if include_plan:
+            g.add(
+                g.edge_from(g.start_node).to(plan),
+                g.edge_from(plan).to(human_decide),
             )
-            .branch(
-                g.match(bool, matches=lambda x: not x)
-                .label("Done researching")
-                .to(synthesize)
+        else:
+            g.add(g.edge_from(g.start_node).to(human_decide))
+
+        g.add(
+            g.edge_from(human_decide).to(
+                g.decision()
+                .branch(
+                    g.match(str, matches=lambda x: x == "search")
+                    .label("Search")
+                    .to(get_batch)
+                )
+                .branch(
+                    g.match(str, matches=lambda x: x == "synthesize")
+                    .label("Synthesize")
+                    .to(synthesize)
+                )
+            ),
+            g.edge_from(get_batch).to(
+                g.decision()
+                .branch(g.match(list).label("Has questions").map().to(search_one))
+                .branch(g.match(type(None)).label("No questions").to(human_decide))
+            ),
+            g.edge_from(search_one).to(collect_answers),
+            # After search, evaluate to suggest new questions, then human decides
+            g.edge_from(collect_answers).to(decide),
+            g.edge_from(decide).to(human_decide),
+            g.edge_from(synthesize).to(g.end_node),
+        )
+    else:
+        # Non-interactive mode: automatic decision based on confidence/iterations
+        if include_plan:
+            g.add(
+                g.edge_from(g.start_node).to(plan),
+                g.edge_from(plan).to(get_batch),
             )
-        ),
-        g.edge_from(synthesize).to(g.end_node),
-    )
+        else:
+            g.add(g.edge_from(g.start_node).to(get_batch))
+
+        g.add(
+            g.edge_from(get_batch).to(
+                g.decision()
+                .branch(g.match(list).label("Has questions").map().to(search_one))
+                .branch(g.match(type(None)).label("No questions").to(synthesize))
+            ),
+            g.edge_from(search_one).to(collect_answers),
+            g.edge_from(collect_answers).to(decide),
+        )
+
+        g.add(
+            g.edge_from(decide).to(
+                g.decision()
+                .branch(
+                    g.match(bool, matches=lambda x: x)
+                    .label("Continue research")
+                    .to(get_batch)
+                )
+                .branch(
+                    g.match(bool, matches=lambda x: not x)
+                    .label("Done researching")
+                    .to(synthesize)
+                )
+            ),
+            g.edge_from(synthesize).to(g.end_node),
+        )
 
     return g.build()
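
Note: the rewritten builder wires three variants of the same graph. A usage sketch, assuming a loaded AppConfig; actually running the graph additionally requires ResearchState/ResearchDeps instances, which are not shown here:

from haiku.rag.config import Config
from haiku.rag.graph.research.graph import build_research_graph

# Default: plan -> batched parallel search -> decide loop -> synthesize
auto_graph = build_research_graph(Config)

# Execute-only: skip planning and start from already-pending sub-questions
execute_graph = build_research_graph(Config, include_plan=False)

# Human-in-the-loop: human_decide gates each search/synthesize turn;
# ResearchDeps.human_input_queue must be populated while this graph runs
hil_graph = build_research_graph(Config, interactive=True)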