haiku.rag 0.9.2__py3-none-any.whl → 0.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- haiku/rag/app.py +50 -14
- haiku/rag/cli.py +16 -4
- haiku/rag/client.py +3 -5
- haiku/rag/reranking/mxbai.py +1 -1
- haiku/rag/research/__init__.py +10 -27
- haiku/rag/research/common.py +53 -0
- haiku/rag/research/dependencies.py +5 -3
- haiku/rag/research/graph.py +29 -0
- haiku/rag/research/models.py +70 -0
- haiku/rag/research/nodes/evaluate.py +80 -0
- haiku/rag/research/nodes/plan.py +63 -0
- haiku/rag/research/nodes/search.py +91 -0
- haiku/rag/research/nodes/synthesize.py +51 -0
- haiku/rag/research/prompts.py +97 -113
- haiku/rag/research/state.py +25 -0
- haiku/rag/store/engine.py +42 -17
- haiku/rag/store/models/chunk.py +1 -0
- haiku/rag/store/repositories/chunk.py +60 -39
- haiku/rag/store/repositories/document.py +2 -2
- haiku/rag/store/repositories/settings.py +12 -5
- haiku/rag/store/upgrades/__init__.py +60 -1
- haiku/rag/store/upgrades/v0_9_3.py +112 -0
- {haiku_rag-0.9.2.dist-info → haiku_rag-0.10.0.dist-info}/METADATA +37 -1
- haiku_rag-0.10.0.dist-info/RECORD +53 -0
- haiku/rag/research/base.py +0 -130
- haiku/rag/research/evaluation_agent.py +0 -42
- haiku/rag/research/orchestrator.py +0 -300
- haiku/rag/research/presearch_agent.py +0 -34
- haiku/rag/research/search_agent.py +0 -65
- haiku/rag/research/synthesis_agent.py +0 -40
- haiku_rag-0.9.2.dist-info/RECORD +0 -50
- {haiku_rag-0.9.2.dist-info → haiku_rag-0.10.0.dist-info}/WHEEL +0 -0
- {haiku_rag-0.9.2.dist-info → haiku_rag-0.10.0.dist-info}/entry_points.txt +0 -0
- {haiku_rag-0.9.2.dist-info → haiku_rag-0.10.0.dist-info}/licenses/LICENSE +0 -0
haiku/rag/app.py
CHANGED
@@ -9,7 +9,13 @@ from haiku.rag.client import HaikuRAG
 from haiku.rag.config import Config
 from haiku.rag.mcp import create_mcp_server
 from haiku.rag.monitor import FileWatcher
-from haiku.rag.research.
+from haiku.rag.research.dependencies import ResearchContext
+from haiku.rag.research.graph import (
+    PlanNode,
+    ResearchDeps,
+    ResearchState,
+    build_research_graph,
+)
 from haiku.rag.store.models.chunk import Chunk
 from haiku.rag.store.models.document import Document
@@ -80,30 +86,54 @@ class HaikuRAGApp:
             self.console.print(f"[red]Error: {e}[/red]")
 
     async def research(
-        self,
+        self,
+        question: str,
+        max_iterations: int = 3,
+        confidence_threshold: float = 0.8,
+        max_concurrency: int = 1,
+        verbose: bool = False,
     ):
-        """Run
+        """Run research via the pydantic-graph pipeline (default)."""
         async with HaikuRAG(db_path=self.db_path) as client:
             try:
-                # Create orchestrator with default config or fallback to QA
-                orchestrator = ResearchOrchestrator()
-
                 if verbose:
-                    self.console.print(
-                        f"[bold cyan]Starting research with {orchestrator.provider}:{orchestrator.model}[/bold cyan]"
-                    )
+                    self.console.print("[bold cyan]Starting research[/bold cyan]")
                     self.console.print(f"[bold blue]Question:[/bold blue] {question}")
                     self.console.print()
 
-
-
+                graph = build_research_graph()
+                state = ResearchState(
                     question=question,
-
+                    context=ResearchContext(original_question=question),
                     max_iterations=max_iterations,
-
-
+                    confidence_threshold=confidence_threshold,
+                    max_concurrency=max_concurrency,
+                )
+                deps = ResearchDeps(
+                    client=client, console=self.console if verbose else None
                 )
 
+                start = PlanNode(
+                    provider=Config.RESEARCH_PROVIDER or Config.QA_PROVIDER,
+                    model=Config.RESEARCH_MODEL or Config.QA_MODEL,
+                )
+                # Prefer graph.run; fall back to iter if unavailable
+                report = None
+                try:
+                    result = await graph.run(start, state=state, deps=deps)
+                    report = result.output
+                except Exception:
+                    from pydantic_graph import End
+
+                    async with graph.iter(start, state=state, deps=deps) as run:
+                        node = run.next_node
+                        while not isinstance(node, End):
+                            node = await run.next(node)
+                    if run.result:
+                        report = run.result.output
+                if report is None:
+                    raise RuntimeError("Graph did not produce a report")
+
                 # Display the report
                 self.console.print("[bold green]Research Report[/bold green]")
                 self.console.rule()
@@ -115,6 +145,12 @@ class HaikuRAGApp:
                 self.console.print(report.executive_summary)
                 self.console.print()
 
+                # Confidence (from last evaluation)
+                if state.last_eval:
+                    conf = state.last_eval.confidence_score  # type: ignore[attr-defined]
+                    self.console.print(f"[bold cyan]Confidence:[/bold cyan] {conf:.1%}")
+                    self.console.print()
+
                 # Main Findings
                 if report.main_findings:
                     self.console.print("[bold cyan]Main Findings:[/bold cyan]")
haiku/rag/cli.py
CHANGED
@@ -13,10 +13,10 @@ from haiku.rag.logging import configure_cli_logging
 from haiku.rag.migration import migrate_sqlite_to_lancedb
 from haiku.rag.utils import is_up_to_date
 
-
-logfire.
-
-
+if Config.ENV == "development":
+    logfire.configure(send_to_logfire="if-token-present")
+    logfire.instrument_pydantic_ai()
+else:
     warnings.filterwarnings("ignore")
 
 cli = typer.Typer(
@@ -250,6 +250,16 @@ def research(
         "-n",
         help="Maximum search/analyze iterations",
     ),
+    confidence_threshold: float = typer.Option(
+        0.8,
+        "--confidence-threshold",
+        help="Minimum confidence (0-1) to stop",
+    ),
+    max_concurrency: int = typer.Option(
+        1,
+        "--max-concurrency",
+        help="Max concurrent searches per iteration (planned)",
+    ),
     db: Path = typer.Option(
         Config.DEFAULT_DATA_DIR / "haiku.rag.lancedb",
         "--db",
@@ -266,6 +276,8 @@ def research(
         app.research(
             question=question,
             max_iterations=max_iterations,
+            confidence_threshold=confidence_threshold,
+            max_concurrency=max_concurrency,
             verbose=verbose,
         )
     )
haiku/rag/client.py
CHANGED
@@ -388,7 +388,7 @@ class HaikuRAG:
         all_chunks = adjacent_chunks + [chunk]
 
         # Get the range of orders for this expanded chunk
-        orders = [c.
+        orders = [c.order for c in all_chunks]
         min_order = min(orders)
         max_order = max(orders)
 
@@ -398,9 +398,7 @@ class HaikuRAG:
                 "score": score,
                 "min_order": min_order,
                 "max_order": max_order,
-                "all_chunks": sorted(
-                    all_chunks, key=lambda c: c.metadata.get("order", 0)
-                ),
+                "all_chunks": sorted(all_chunks, key=lambda c: c.order),
             }
         )
 
@@ -459,7 +457,7 @@ class HaikuRAG:
         # Merge all_chunks and deduplicate by order
         all_chunks_dict = {}
         for chunk in current["all_chunks"] + range_info["all_chunks"]:
-            order = chunk.
+            order = chunk.order
            all_chunks_dict[order] = chunk
         current["all_chunks"] = [
             all_chunks_dict[order] for order in sorted(all_chunks_dict.keys())
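The change replaces metadata lookups with the first-class `order` field the chunk model gains in this release (see haiku/rag/store/models/chunk.py, +1). A standalone sketch of the merge-and-dedup-by-order step; Chunk here is a stand-in for haiku.rag's chunk model, not the real class.

from dataclasses import dataclass


@dataclass
class Chunk:
    order: int
    content: str


def merge_by_order(a: list[Chunk], b: list[Chunk]) -> list[Chunk]:
    # Later duplicates (same order) overwrite earlier ones, as in the client.
    by_order = {c.order: c for c in a + b}
    return [by_order[o] for o in sorted(by_order)]


left = [Chunk(0, "intro"), Chunk(1, "body")]
right = [Chunk(1, "body"), Chunk(2, "end")]
print([c.order for c in merge_by_order(left, right)])  # [0, 1, 2]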
haiku/rag/reranking/mxbai.py
CHANGED
haiku/rag/research/__init__.py
CHANGED
@@ -1,37 +1,20 @@
-"""Multi-agent research workflow for advanced RAG queries."""
-
-from haiku.rag.research.base import (
-    BaseResearchAgent,
-    ResearchOutput,
-    SearchAnswer,
-    SearchResult,
-)
 from haiku.rag.research.dependencies import ResearchContext, ResearchDependencies
-from haiku.rag.research.
-
-
+from haiku.rag.research.graph import (
+    PlanNode,
+    ResearchDeps,
+    ResearchState,
+    build_research_graph,
 )
-from haiku.rag.research.
-from haiku.rag.research.presearch_agent import PresearchSurveyAgent
-from haiku.rag.research.search_agent import SearchSpecialistAgent
-from haiku.rag.research.synthesis_agent import ResearchReport, SynthesisAgent
+from haiku.rag.research.models import EvaluationResult, ResearchReport, SearchAnswer
 
 __all__ = [
-    # Base classes
-    "BaseResearchAgent",
     "ResearchDependencies",
     "ResearchContext",
-    "SearchResult",
-    "ResearchOutput",
-    # Specialized agents
     "SearchAnswer",
-    "SearchSpecialistAgent",
-    "PresearchSurveyAgent",
-    "AnalysisEvaluationAgent",
     "EvaluationResult",
-    "SynthesisAgent",
     "ResearchReport",
-
-    "
-    "
+    "ResearchDeps",
+    "ResearchState",
+    "PlanNode",
+    "build_research_graph",
 ]
haiku/rag/research/common.py
ADDED
@@ -0,0 +1,53 @@
+from typing import Any
+
+from pydantic_ai import format_as_xml
+from pydantic_ai.models.openai import OpenAIChatModel
+from pydantic_ai.providers.ollama import OllamaProvider
+from pydantic_ai.providers.openai import OpenAIProvider
+
+from haiku.rag.config import Config
+from haiku.rag.research.dependencies import ResearchContext
+
+
+def get_model(provider: str, model: str) -> Any:
+    if provider == "ollama":
+        return OpenAIChatModel(
+            model_name=model,
+            provider=OllamaProvider(base_url=f"{Config.OLLAMA_BASE_URL}/v1"),
+        )
+    elif provider == "vllm":
+        return OpenAIChatModel(
+            model_name=model,
+            provider=OpenAIProvider(
+                base_url=f"{Config.VLLM_RESEARCH_BASE_URL or Config.VLLM_QA_BASE_URL}/v1",
+                api_key="none",
+            ),
+        )
+    else:
+        return f"{provider}:{model}"
+
+
+def log(console, msg: str) -> None:
+    if console:
+        console.print(msg)
+
+
+def format_context_for_prompt(context: ResearchContext) -> str:
+    """Format the research context as XML for inclusion in prompts."""
+
+    context_data = {
+        "original_question": context.original_question,
+        "unanswered_questions": context.sub_questions,
+        "qa_responses": [
+            {
+                "question": qa.query,
+                "answer": qa.answer,
+                "context_snippets": qa.context,
+                "sources": qa.sources,  # pyright: ignore[reportAttributeAccessIssue]
+            }
+            for qa in context.qa_responses
+        ],
+        "insights": context.insights,
+        "gaps": context.gaps,
+    }
+    return format_as_xml(context_data, root_tag="research_context")
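Note that get_model returns either a configured OpenAIChatModel (for the OpenAI-compatible ollama/vllm endpoints) or a plain "provider:model" string that pydantic-ai resolves itself. A sketch, assuming haiku.rag 0.10.0 and pydantic-ai are installed; the provider and model names are examples only.

from haiku.rag.research.common import get_model

print(get_model("openai", "gpt-4o-mini"))  # falls through to "openai:gpt-4o-mini"
model = get_model("ollama", "qwen3")       # OpenAIChatModel pointed at Ollama's /v1 API
print(type(model).__name__)                # OpenAIChatModel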
haiku/rag/research/dependencies.py
CHANGED
@@ -1,7 +1,8 @@
 from pydantic import BaseModel, Field
+from rich.console import Console
 
 from haiku.rag.client import HaikuRAG
-from haiku.rag.research.
+from haiku.rag.research.models import SearchAnswer
 
 
 class ResearchContext(BaseModel):
@@ -11,7 +12,7 @@ class ResearchContext(BaseModel):
     sub_questions: list[str] = Field(
         default_factory=list, description="Decomposed sub-questions"
     )
-    qa_responses: list[
+    qa_responses: list[SearchAnswer] = Field(
         default_factory=list, description="Structured QA pairs used during research"
     )
     insights: list[str] = Field(
@@ -21,7 +22,7 @@ class ResearchContext(BaseModel):
         default_factory=list, description="Identified information gaps"
     )
 
-    def add_qa_response(self, qa:
+    def add_qa_response(self, qa: SearchAnswer) -> None:
         """Add a structured QA response (minimal context already included)."""
         self.qa_responses.append(qa)
 
@@ -43,3 +44,4 @@ class ResearchDependencies(BaseModel):
 
     client: HaikuRAG = Field(description="RAG client for document operations")
     context: ResearchContext = Field(description="Shared research context")
+    console: Console | None = None
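With qa_responses now typed as list[SearchAnswer], accumulated answers are validated records rather than loose dicts. A sketch (haiku.rag 0.10.0 assumed; the field values are made up for illustration):

from haiku.rag.research.dependencies import ResearchContext
from haiku.rag.research.models import SearchAnswer

ctx = ResearchContext(original_question="What changed in 0.10.0?")
ctx.add_qa_response(
    SearchAnswer(
        query="research pipeline",
        answer="It moved to a pydantic-graph pipeline.",
        context=["from haiku.rag.research.graph import build_research_graph"],
        sources=["haiku/rag/research/graph.py"],
    )
)
print(len(ctx.qa_responses))  # 1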
haiku/rag/research/graph.py
ADDED
@@ -0,0 +1,29 @@
+from pydantic_graph import Graph
+
+from haiku.rag.research.models import ResearchReport
+from haiku.rag.research.nodes.evaluate import EvaluateNode
+from haiku.rag.research.nodes.plan import PlanNode
+from haiku.rag.research.nodes.search import SearchDispatchNode
+from haiku.rag.research.nodes.synthesize import SynthesizeNode
+from haiku.rag.research.state import ResearchDeps, ResearchState
+
+__all__ = [
+    "PlanNode",
+    "SearchDispatchNode",
+    "EvaluateNode",
+    "SynthesizeNode",
+    "ResearchState",
+    "ResearchDeps",
+    "build_research_graph",
+]
+
+
+def build_research_graph() -> Graph[ResearchState, ResearchDeps, ResearchReport]:
+    return Graph(
+        nodes=[
+            PlanNode,
+            SearchDispatchNode,
+            EvaluateNode,
+            SynthesizeNode,
+        ]
+    )
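Distilling the wiring already visible in the app.py hunk above: build the graph, seed state and deps, and run from PlanNode. A sketch assuming haiku.rag 0.10.0 is installed; the db path and provider/model strings are placeholders.

import asyncio

from haiku.rag.client import HaikuRAG
from haiku.rag.research import (
    PlanNode,
    ResearchContext,
    ResearchDeps,
    ResearchState,
    build_research_graph,
)


async def main() -> None:
    question = "What changed in the store engine?"
    async with HaikuRAG(db_path="haiku.rag.lancedb") as client:  # placeholder path
        graph = build_research_graph()
        state = ResearchState(
            question=question,
            context=ResearchContext(original_question=question),
            max_iterations=3,
            confidence_threshold=0.8,
            max_concurrency=1,
        )
        deps = ResearchDeps(client=client, console=None)
        result = await graph.run(
            PlanNode(provider="ollama", model="qwen3"),  # placeholder model
            state=state,
            deps=deps,
        )
        print(result.output.title)  # result.output is a ResearchReport


asyncio.run(main())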
haiku/rag/research/models.py
ADDED
@@ -0,0 +1,70 @@
+from pydantic import BaseModel, Field
+
+
+class ResearchPlan(BaseModel):
+    main_question: str
+    sub_questions: list[str]
+
+
+class SearchAnswer(BaseModel):
+    """Structured output for the SearchSpecialist agent."""
+
+    query: str = Field(description="The search query that was performed")
+    answer: str = Field(description="The answer generated based on the context")
+    context: list[str] = Field(
+        description=(
+            "Only the minimal set of relevant snippets (verbatim) that directly "
+            "support the answer"
+        )
+    )
+    sources: list[str] = Field(
+        description=(
+            "Document URIs corresponding to the snippets actually used in the"
+            " answer (one URI per snippet; omit if none)"
+        ),
+        default_factory=list,
+    )
+
+
+class EvaluationResult(BaseModel):
+    """Result of analysis and evaluation."""
+
+    key_insights: list[str] = Field(
+        description="Main insights extracted from the research so far"
+    )
+    new_questions: list[str] = Field(
+        description="New sub-questions to add to the research (max 3)",
+        max_length=3,
+        default=[],
+    )
+    confidence_score: float = Field(
+        description="Confidence level in the completeness of research (0-1)",
+        ge=0.0,
+        le=1.0,
+    )
+    is_sufficient: bool = Field(
+        description="Whether the research is sufficient to answer the original question"
+    )
+    reasoning: str = Field(
+        description="Explanation of why the research is or isn't complete"
+    )
+
+
+class ResearchReport(BaseModel):
+    """Final research report structure."""
+
+    title: str = Field(description="Concise title for the research")
+    executive_summary: str = Field(description="Brief overview of key findings")
+    main_findings: list[str] = Field(
+        description="Primary research findings with supporting evidence"
+    )
+    conclusions: list[str] = Field(description="Evidence-based conclusions")
+    limitations: list[str] = Field(
+        description="Limitations of the current research", default=[]
+    )
+    recommendations: list[str] = Field(
+        description="Actionable recommendations based on findings", default=[]
+    )
+    sources_summary: str = Field(
+        description="Summary of sources used and their reliability"
+    )
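EvaluationResult constrains confidence_score to [0, 1] and new_questions to at most three entries, so malformed agent output is rejected at the model boundary. A sketch of that validation behavior (field values are made up):

from pydantic import ValidationError

from haiku.rag.research.models import EvaluationResult

ok = EvaluationResult(
    key_insights=["store gained an order column"],
    confidence_score=0.85,
    is_sufficient=True,
    reasoning="All sub-questions answered.",
)

try:
    EvaluationResult(
        key_insights=[],
        confidence_score=1.5,  # rejected: le=1.0
        is_sufficient=False,
        reasoning="out of range",
    )
except ValidationError as e:
    print(e.error_count(), "validation error")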
haiku/rag/research/nodes/evaluate.py
ADDED
@@ -0,0 +1,80 @@
+from dataclasses import dataclass
+
+from pydantic_ai import Agent
+from pydantic_graph import BaseNode, GraphRunContext
+
+from haiku.rag.research.common import format_context_for_prompt, get_model, log
+from haiku.rag.research.dependencies import (
+    ResearchDependencies,
+)
+from haiku.rag.research.models import EvaluationResult, ResearchReport
+from haiku.rag.research.nodes.synthesize import SynthesizeNode
+from haiku.rag.research.prompts import EVALUATION_AGENT_PROMPT
+from haiku.rag.research.state import ResearchDeps, ResearchState
+
+
+@dataclass
+class EvaluateNode(BaseNode[ResearchState, ResearchDeps, ResearchReport]):
+    provider: str
+    model: str
+
+    async def run(
+        self, ctx: GraphRunContext[ResearchState, ResearchDeps]
+    ) -> BaseNode[ResearchState, ResearchDeps, ResearchReport]:
+        state = ctx.state
+        deps = ctx.deps
+
+        log(
+            deps.console,
+            "\n[bold cyan]📊 Analyzing and evaluating research progress...[/bold cyan]",
+        )
+
+        agent = Agent(
+            model=get_model(self.provider, self.model),
+            output_type=EvaluationResult,
+            instructions=EVALUATION_AGENT_PROMPT,
+            retries=3,
+            deps_type=ResearchDependencies,
+        )
+
+        context_xml = format_context_for_prompt(state.context)
+        prompt = (
+            "Analyze gathered information and evaluate completeness for the original question.\n\n"
+            f"{context_xml}"
+        )
+        agent_deps = ResearchDependencies(
+            client=deps.client, context=state.context, console=deps.console
+        )
+        eval_result = await agent.run(prompt, deps=agent_deps)
+        output = eval_result.output
+
+        for insight in output.key_insights:
+            state.context.add_insight(insight)
+        for new_q in output.new_questions:
+            if new_q not in state.sub_questions:
+                state.sub_questions.append(new_q)
+
+        state.last_eval = output
+        state.iterations += 1
+
+        if output.key_insights:
+            log(deps.console, "  [bold]Key insights:[/bold]")
+            for ins in output.key_insights:
+                log(deps.console, f"    • {ins}")
+        log(
+            deps.console,
+            f"  Confidence: [yellow]{output.confidence_score:.1%}[/yellow]",
+        )
+        status = "[green]Yes[/green]" if output.is_sufficient else "[red]No[/red]"
+        log(deps.console, f"  Sufficient: {status}")
+
+        from haiku.rag.research.nodes.search import SearchDispatchNode
+
+        if (
+            output.is_sufficient
+            and output.confidence_score >= state.confidence_threshold
+        ) or state.iterations >= state.max_iterations:
+            log(deps.console, "\n[bold green]✅ Stopping research.[/bold green]")
+            return SynthesizeNode(self.provider, self.model)
+
+        return SearchDispatchNode(self.provider, self.model)
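The routing decision at the end of the node reduces to one predicate; isolated here as a pure function for clarity (the function name is ours, not the package's):

def should_synthesize(
    is_sufficient: bool,
    confidence: float,
    threshold: float,
    iterations: int,
    max_iterations: int,
) -> bool:
    # Stop when the evaluator is satisfied AND confident enough,
    # or when the iteration budget is exhausted.
    return (is_sufficient and confidence >= threshold) or iterations >= max_iterations


assert should_synthesize(True, 0.9, 0.8, 1, 3)
assert should_synthesize(False, 0.2, 0.8, 3, 3)      # budget exhausted
assert not should_synthesize(True, 0.5, 0.8, 1, 3)   # sufficient but not confident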
haiku/rag/research/nodes/plan.py
ADDED
@@ -0,0 +1,63 @@
+from dataclasses import dataclass
+
+from pydantic_ai import Agent, RunContext
+from pydantic_graph import BaseNode, GraphRunContext
+
+from haiku.rag.research.common import get_model, log
+from haiku.rag.research.dependencies import ResearchDependencies
+from haiku.rag.research.models import ResearchPlan, ResearchReport
+from haiku.rag.research.nodes.search import SearchDispatchNode
+from haiku.rag.research.prompts import PLAN_PROMPT
+from haiku.rag.research.state import ResearchDeps, ResearchState
+
+
+@dataclass
+class PlanNode(BaseNode[ResearchState, ResearchDeps, ResearchReport]):
+    provider: str
+    model: str
+
+    async def run(
+        self, ctx: GraphRunContext[ResearchState, ResearchDeps]
+    ) -> BaseNode[ResearchState, ResearchDeps, ResearchReport]:
+        state = ctx.state
+        deps = ctx.deps
+
+        log(deps.console, "\n[bold cyan]📋 Creating research plan...[/bold cyan]")
+
+        plan_agent = Agent(
+            model=get_model(self.provider, self.model),
+            output_type=ResearchPlan,
+            instructions=(
+                PLAN_PROMPT
+                + "\n\nUse the gather_context tool once on the main question before planning."
+            ),
+            retries=3,
+            deps_type=ResearchDependencies,
+        )
+
+        @plan_agent.tool
+        async def gather_context(
+            ctx2: RunContext[ResearchDependencies], query: str, limit: int = 6
+        ) -> str:
+            results = await ctx2.deps.client.search(query, limit=limit)
+            expanded = await ctx2.deps.client.expand_context(results)
+            return "\n\n".join(chunk.content for chunk, _ in expanded)
+
+        prompt = (
+            "Plan a focused research approach for the main question.\n\n"
+            f"Main question: {state.question}"
+        )
+
+        agent_deps = ResearchDependencies(
+            client=deps.client, context=state.context, console=deps.console
+        )
+        plan_result = await plan_agent.run(prompt, deps=agent_deps)
+        state.sub_questions = list(plan_result.output.sub_questions)
+
+        log(deps.console, "\n[bold green]✅ Research Plan Created:[/bold green]")
+        log(deps.console, f"  [bold]Main Question:[/bold] {state.question}")
+        log(deps.console, "  [bold]Sub-questions:[/bold]")
+        for i, sq in enumerate(state.sub_questions, 1):
+            log(deps.console, f"    {i}. {sq}")
+
+        return SearchDispatchNode(self.provider, self.model)
haiku/rag/research/nodes/search.py
ADDED
@@ -0,0 +1,91 @@
+import asyncio
+from dataclasses import dataclass
+from typing import Any
+
+from pydantic_ai import Agent, RunContext
+from pydantic_ai.format_prompt import format_as_xml
+from pydantic_ai.output import ToolOutput
+from pydantic_graph import BaseNode, GraphRunContext
+
+from haiku.rag.research.common import get_model, log
+from haiku.rag.research.dependencies import ResearchDependencies
+from haiku.rag.research.models import ResearchReport, SearchAnswer
+from haiku.rag.research.prompts import SEARCH_AGENT_PROMPT
+from haiku.rag.research.state import ResearchDeps, ResearchState
+
+
+@dataclass
+class SearchDispatchNode(BaseNode[ResearchState, ResearchDeps, ResearchReport]):
+    provider: str
+    model: str
+
+    async def run(
+        self, ctx: GraphRunContext[ResearchState, ResearchDeps]
+    ) -> BaseNode[ResearchState, ResearchDeps, ResearchReport]:
+        state = ctx.state
+        deps = ctx.deps
+        if not state.sub_questions:
+            from haiku.rag.research.nodes.evaluate import EvaluateNode
+
+            return EvaluateNode(self.provider, self.model)
+
+        # Take up to max_concurrency questions and answer them concurrently
+        take = max(1, state.max_concurrency)
+        batch: list[str] = []
+        while state.sub_questions and len(batch) < take:
+            batch.append(state.sub_questions.pop(0))
+
+        async def answer_one(sub_q: str) -> SearchAnswer | None:
+            log(
+                deps.console,
+                f"\n[bold cyan]🔍 Searching & Answering:[/bold cyan] {sub_q}",
+            )
+            agent = Agent(
+                model=get_model(self.provider, self.model),
+                output_type=ToolOutput(SearchAnswer, max_retries=3),
+                instructions=SEARCH_AGENT_PROMPT,
+                retries=3,
+                deps_type=ResearchDependencies,
+            )
+
+            @agent.tool
+            async def search_and_answer(
+                ctx2: RunContext[ResearchDependencies], query: str, limit: int = 5
+            ) -> str:
+                search_results = await ctx2.deps.client.search(query, limit=limit)
+                expanded = await ctx2.deps.client.expand_context(search_results)
+
+                entries: list[dict[str, Any]] = [
+                    {
+                        "text": chunk.content,
+                        "score": score,
+                        "document_uri": (chunk.document_uri or ""),
+                    }
+                    for chunk, score in expanded
+                ]
+                if not entries:
+                    return f"No relevant information found in the knowledge base for: {query}"
+
+                return format_as_xml(entries, root_tag="snippets")
+
+            agent_deps = ResearchDependencies(
+                client=deps.client, context=state.context, console=deps.console
+            )
+            try:
+                result = await agent.run(sub_q, deps=agent_deps)
+            except Exception as e:
+                log(deps.console, f"[red]Search failed:[/red] {e}")
+                return None
+
+            return result.output
+
+        answers = await asyncio.gather(*(answer_one(q) for q in batch))
+        for ans in answers:
+            if ans is None:
+                continue
+            state.context.add_qa_response(ans)
+            if deps.console:
+                preview = ans.answer[:150] + ("…" if len(ans.answer) > 150 else "")
+                log(deps.console, f"  [green]✓[/green] {preview}")
+
+        return SearchDispatchNode(self.provider, self.model)
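The dispatch node pops up to max_concurrency sub-questions per pass and answers them with asyncio.gather, re-entering itself until the queue is empty. A standalone sketch of that batching pattern (the helper names are ours, for illustration):

import asyncio


async def answer(q: str) -> str:
    await asyncio.sleep(0)  # stands in for an agent.run() call
    return f"answer to {q!r}"


async def drain(questions: list[str], max_concurrency: int) -> list[str]:
    answers: list[str] = []
    while questions:
        # Pop a batch, then answer its members concurrently.
        n = min(max_concurrency, len(questions))
        batch = [questions.pop(0) for _ in range(n)]
        answers.extend(await asyncio.gather(*(answer(q) for q in batch)))
    return answers


print(asyncio.run(drain(["a", "b", "c"], max_concurrency=2)))
# two questions answered concurrently, then the remaining one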