haiku.rag 0.10.1__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of haiku.rag has been flagged as possibly problematic.
- haiku/rag/app.py +152 -28
- haiku/rag/cli.py +72 -2
- haiku/rag/migration.py +2 -2
- haiku/rag/research/__init__.py +8 -0
- haiku/rag/research/common.py +71 -6
- haiku/rag/research/dependencies.py +179 -11
- haiku/rag/research/graph.py +5 -3
- haiku/rag/research/models.py +134 -1
- haiku/rag/research/nodes/analysis.py +181 -0
- haiku/rag/research/nodes/plan.py +16 -9
- haiku/rag/research/nodes/search.py +14 -11
- haiku/rag/research/nodes/synthesize.py +7 -3
- haiku/rag/research/prompts.py +67 -28
- haiku/rag/research/state.py +11 -4
- haiku/rag/research/stream.py +177 -0
- haiku/rag/store/__init__.py +1 -1
- haiku/rag/store/models/__init__.py +1 -1
- haiku/rag/utils.py +34 -0
- {haiku_rag-0.10.1.dist-info → haiku_rag-0.11.0.dist-info}/METADATA +34 -14
- {haiku_rag-0.10.1.dist-info → haiku_rag-0.11.0.dist-info}/RECORD +23 -22
- haiku/rag/research/nodes/evaluate.py +0 -80
- {haiku_rag-0.10.1.dist-info → haiku_rag-0.11.0.dist-info}/WHEEL +0 -0
- {haiku_rag-0.10.1.dist-info → haiku_rag-0.11.0.dist-info}/entry_points.txt +0 -0
- {haiku_rag-0.10.1.dist-info → haiku_rag-0.11.0.dist-info}/licenses/LICENSE +0 -0
haiku/rag/research/dependencies.py
CHANGED
@@ -1,8 +1,16 @@
+from collections.abc import Iterable
+
 from pydantic import BaseModel, Field
 from rich.console import Console

 from haiku.rag.client import HaikuRAG
-from haiku.rag.research.models import
+from haiku.rag.research.models import (
+    GapRecord,
+    InsightAnalysis,
+    InsightRecord,
+    SearchAnswer,
+)
+from haiku.rag.research.stream import ResearchStream


 class ResearchContext(BaseModel):
@@ -15,10 +23,10 @@ class ResearchContext(BaseModel):
     qa_responses: list[SearchAnswer] = Field(
         default_factory=list, description="Structured QA pairs used during research"
     )
-    insights: list[
+    insights: list[InsightRecord] = Field(
         default_factory=list, description="Key insights discovered"
     )
-    gaps: list[
+    gaps: list[GapRecord] = Field(
         default_factory=list, description="Identified information gaps"
     )

@@ -26,15 +34,147 @@ class ResearchContext(BaseModel):
         """Add a structured QA response (minimal context already included)."""
         self.qa_responses.append(qa)

-    def
-    """
-
-
+    def upsert_insights(self, records: Iterable[InsightRecord]) -> list[InsightRecord]:
+        """Merge one or more insights into the shared context with deduplication."""
+
+        merged: list[InsightRecord] = []
+        for record in records:
+            candidate = InsightRecord.model_validate(record)
+            existing = next(
+                (ins for ins in self.insights if ins.id == candidate.id), None
+            )
+            if not existing:
+                existing = next(
+                    (ins for ins in self.insights if ins.summary == candidate.summary),
+                    None,
+                )
+
+            if existing:
+                existing.summary = candidate.summary
+                existing.status = candidate.status
+                if candidate.notes:
+                    existing.notes = candidate.notes
+                existing.supporting_sources = _merge_unique(
+                    existing.supporting_sources, candidate.supporting_sources
+                )
+                existing.originating_questions = _merge_unique(
+                    existing.originating_questions, candidate.originating_questions
+                )
+                merged.append(existing)
+            else:
+                candidate = candidate.model_copy(deep=True)
+                if candidate.id is None:  # pragma: no cover - defensive
+                    raise ValueError(
+                        "InsightRecord.id must be populated after validation"
+                    )
+                candidate_id: str = candidate.id
+                candidate.id = self._allocate_insight_id(candidate_id)
+                self.insights.append(candidate)
+                merged.append(candidate)
+
+        return merged
+
+    def upsert_gaps(self, records: Iterable[GapRecord]) -> list[GapRecord]:
+        """Merge one or more gap records into the shared context with deduplication."""
+
+        merged: list[GapRecord] = []
+        for record in records:
+            candidate = GapRecord.model_validate(record)
+            existing = next((gap for gap in self.gaps if gap.id == candidate.id), None)
+            if not existing:
+                existing = next(
+                    (
+                        gap
+                        for gap in self.gaps
+                        if gap.description == candidate.description
+                    ),
+                    None,
+                )
+
+            if existing:
+                existing.description = candidate.description
+                existing.severity = candidate.severity
+                existing.blocking = candidate.blocking
+                existing.resolved = candidate.resolved
+                if candidate.notes:
+                    existing.notes = candidate.notes
+                existing.supporting_sources = _merge_unique(
+                    existing.supporting_sources, candidate.supporting_sources
+                )
+                existing.resolved_by = _merge_unique(
+                    existing.resolved_by, candidate.resolved_by
+                )
+                merged.append(existing)
+            else:
+                candidate = candidate.model_copy(deep=True)
+                if candidate.id is None:  # pragma: no cover - defensive
+                    raise ValueError("GapRecord.id must be populated after validation")
+                candidate_id: str = candidate.id
+                candidate.id = self._allocate_gap_id(candidate_id)
+                self.gaps.append(candidate)
+                merged.append(candidate)
+
+        return merged
+
+    def mark_gap_resolved(
+        self, identifier: str, resolved_by: Iterable[str] | None = None
+    ) -> GapRecord | None:
+        """Mark a gap as resolved by identifier (id or description)."""
+
+        gap = self._find_gap(identifier)
+        if gap is None:
+            return None
+
+        gap.resolved = True
+        gap.blocking = False
+        if resolved_by:
+            gap.resolved_by = _merge_unique(gap.resolved_by, list(resolved_by))
+        return gap

-    def
-    """
-
-
+    def integrate_analysis(self, analysis: InsightAnalysis) -> None:
+        """Apply an analysis result to the shared context."""
+
+        merged_insights: list[InsightRecord] = []
+        if analysis.highlights:
+            merged_insights = self.upsert_insights(analysis.highlights)
+            analysis.highlights = merged_insights
+        if analysis.gap_assessments:
+            merged_gaps = self.upsert_gaps(analysis.gap_assessments)
+            analysis.gap_assessments = merged_gaps
+        if analysis.resolved_gaps:
+            resolved_by_list = (
+                [ins.id for ins in merged_insights if ins.id is not None]
+                if merged_insights
+                else None
+            )
+            for resolved in analysis.resolved_gaps:
+                self.mark_gap_resolved(resolved, resolved_by=resolved_by_list)
+        for question in analysis.new_questions:
+            if question not in self.sub_questions:
+                self.sub_questions.append(question)
+
+    def _allocate_insight_id(self, candidate_id: str) -> str:
+        taken: set[str] = set()
+        for ins in self.insights:
+            if ins.id is not None:
+                taken.add(ins.id)
+        return _allocate_sequential_id(candidate_id, taken)
+
+    def _allocate_gap_id(self, candidate_id: str) -> str:
+        taken: set[str] = set()
+        for gap in self.gaps:
+            if gap.id is not None:
+                taken.add(gap.id)
+        return _allocate_sequential_id(candidate_id, taken)
+
+    def _find_gap(self, identifier: str) -> GapRecord | None:
+        normalized = identifier.lower().strip()
+        for gap in self.gaps:
+            if gap.id is not None and gap.id == normalized:
+                return gap
+            if gap.description.lower().strip() == normalized:
+                return gap
+        return None


 class ResearchDependencies(BaseModel):
@@ -45,3 +185,31 @@ class ResearchDependencies(BaseModel):
     client: HaikuRAG = Field(description="RAG client for document operations")
     context: ResearchContext = Field(description="Shared research context")
     console: Console | None = None
+    stream: ResearchStream | None = Field(
+        default=None, description="Optional research event stream"
+    )
+
+
+def _merge_unique(existing: list[str], incoming: Iterable[str]) -> list[str]:
+    """Merge two iterables preserving order while removing duplicates."""
+
+    merged = list(existing)
+    seen = {item for item in existing if item}
+    for item in incoming:
+        if item and item not in seen:
+            merged.append(item)
+            seen.add(item)
+    return merged
+
+
+def _allocate_sequential_id(candidate: str, taken: set[str]) -> str:
+    slug = candidate
+    if slug not in taken:
+        return slug
+    base = slug
+    counter = 2
+    while True:
+        slug = f"{base}-{counter}"
+        if slug not in taken:
+            return slug
+        counter += 1
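A rough usage sketch of the dedup-aware ledger API added above. It is illustrative only: it assumes ResearchContext can be constructed with just original_question (that field is referenced elsewhere in this diff, but the constructor requirements are not shown), and the example strings are made up.

from haiku.rag.research.dependencies import ResearchContext
from haiku.rag.research.models import GapRecord, InsightRecord

# Assumed construction; other ResearchContext fields shown here default to empty lists.
ctx = ResearchContext(original_question="Example research question")

# Two records with the same summary collapse into one; supporting sources are merged in order.
ctx.upsert_insights(
    [
        InsightRecord(summary="Example finding repeated across sources", supporting_sources=["doc-1"]),
        InsightRecord(summary="Example finding repeated across sources", supporting_sources=["doc-2"]),
    ]
)
assert len(ctx.insights) == 1
assert ctx.insights[0].supporting_sources == ["doc-1", "doc-2"]

# Gaps can later be marked resolved by slug id or by description.
ctx.upsert_gaps([GapRecord(description="Example missing detail")])
ctx.mark_gap_resolved("example missing detail")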
haiku/rag/research/graph.py
CHANGED
@@ -1,7 +1,7 @@
 from pydantic_graph import Graph

 from haiku.rag.research.models import ResearchReport
-from haiku.rag.research.nodes.
+from haiku.rag.research.nodes.analysis import AnalyzeInsightsNode, DecisionNode
 from haiku.rag.research.nodes.plan import PlanNode
 from haiku.rag.research.nodes.search import SearchDispatchNode
 from haiku.rag.research.nodes.synthesize import SynthesizeNode
@@ -10,7 +10,8 @@ from haiku.rag.research.state import ResearchDeps, ResearchState
 __all__ = [
     "PlanNode",
     "SearchDispatchNode",
-    "
+    "AnalyzeInsightsNode",
+    "DecisionNode",
     "SynthesizeNode",
     "ResearchState",
     "ResearchDeps",
@@ -23,7 +24,8 @@ def build_research_graph() -> Graph[ResearchState, ResearchDeps, ResearchReport]
         nodes=[
             PlanNode,
             SearchDispatchNode,
-
+            AnalyzeInsightsNode,
+            DecisionNode,
             SynthesizeNode,
         ]
     )
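For orientation, a minimal sketch of the rebuilt graph and the node flow it now encodes (inferred from the node definitions in this diff, not taken from the package docs):

from haiku.rag.research.graph import build_research_graph

graph = build_research_graph()
# Flow implied by the nodes in this release:
#   PlanNode -> SearchDispatchNode -> AnalyzeInsightsNode -> DecisionNode
#   DecisionNode -> SynthesizeNode once confidence_threshold or max_iterations is reached,
#   otherwise back to SearchDispatchNode for another round.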
haiku/rag/research/models.py
CHANGED
@@ -1,4 +1,134 @@
-
+import re
+from enum import Enum
+
+from pydantic import BaseModel, Field, model_validator
+
+_SLUG_RE = re.compile(r"[^a-z0-9]+")
+
+
+def _make_slug(text: str, prefix: str) -> str:
+    """Generate a lowercase slug with the given prefix as fallback."""
+
+    base = _SLUG_RE.sub("-", text.lower()).strip("-")
+    if not base:
+        base = prefix
+    # Trim overly long slugs but keep enough entropy for readability
+    return base[:48]
+
+
+class InsightStatus(str, Enum):
+    OPEN = "open"
+    VALIDATED = "validated"
+    TENTATIVE = "tentative"
+
+
+class GapSeverity(str, Enum):
+    LOW = "low"
+    MEDIUM = "medium"
+    HIGH = "high"
+
+
+class InsightRecord(BaseModel):
+    """Structured insight with provenance and lifecycle metadata."""
+
+    id: str | None = Field(
+        default=None,
+        description="Stable slug identifier for the insight (auto-generated if omitted)",
+    )
+    summary: str = Field(description="Concise description of the insight")
+    status: InsightStatus = Field(
+        default=InsightStatus.OPEN,
+        description="Lifecycle status for the insight",
+    )
+    supporting_sources: list[str] = Field(
+        default_factory=list,
+        description="Source identifiers backing the insight",
+    )
+    originating_questions: list[str] = Field(
+        default_factory=list,
+        description="Research sub-questions that produced this insight",
+    )
+    notes: str | None = Field(
+        default=None,
+        description="Optional elaboration or caveats for the insight",
+    )
+
+    @model_validator(mode="after")
+    def _set_defaults(self) -> "InsightRecord":
+        if not self.id:
+            self.id = _make_slug(self.summary, "insight")
+        self.id = self.id.lower()
+        self.supporting_sources = list(dict.fromkeys(self.supporting_sources))
+        self.originating_questions = list(dict.fromkeys(self.originating_questions))
+        return self
+
+
+class GapRecord(BaseModel):
+    """Structured representation of an identified research gap."""
+
+    id: str | None = Field(
+        default=None,
+        description="Stable slug identifier for the gap (auto-generated if omitted)",
+    )
+    description: str = Field(description="Concrete statement of what is missing")
+    severity: GapSeverity = Field(
+        default=GapSeverity.MEDIUM,
+        description="Severity of the gap for answering the main question",
+    )
+    blocking: bool = Field(
+        default=True,
+        description="Whether this gap blocks a confident answer",
+    )
+    resolved: bool = Field(
+        default=False,
+        description="Flag indicating if the gap has been resolved",
+    )
+    resolved_by: list[str] = Field(
+        default_factory=list,
+        description="Insight IDs or notes explaining how the gap was closed",
+    )
+    supporting_sources: list[str] = Field(
+        default_factory=list,
+        description="Sources confirming the gap status (e.g., evidence of absence)",
+    )
+    notes: str | None = Field(
+        default=None,
+        description="Optional clarification about the gap or follow-up actions",
+    )
+
+    @model_validator(mode="after")
+    def _set_defaults(self) -> "GapRecord":
+        if not self.id:
+            self.id = _make_slug(self.description, "gap")
+        self.id = self.id.lower()
+        self.resolved_by = list(dict.fromkeys(self.resolved_by))
+        self.supporting_sources = list(dict.fromkeys(self.supporting_sources))
+        return self
+
+
+class InsightAnalysis(BaseModel):
+    """Output of the insight aggregation agent."""
+
+    highlights: list[InsightRecord] = Field(
+        default_factory=list,
+        description="New or updated insights discovered this iteration",
+    )
+    gap_assessments: list[GapRecord] = Field(
+        default_factory=list,
+        description="New or updated gap records based on current evidence",
+    )
+    resolved_gaps: list[str] = Field(
+        default_factory=list,
+        description="Gap identifiers or descriptions considered resolved",
+    )
+    new_questions: list[str] = Field(
+        default_factory=list,
+        max_length=3,
+        description="Up to three follow-up sub-questions to pursue next",
+    )
+    commentary: str = Field(
+        description="Short narrative summary of the incremental findings",
+    )


 class ResearchPlan(BaseModel):
@@ -37,6 +167,9 @@ class EvaluationResult(BaseModel)
         max_length=3,
         default=[],
     )
+    gaps: list[str] = Field(
+        description="Concrete information gaps that remain", default_factory=list
+    )
     confidence_score: float = Field(
         description="Confidence level in the completeness of research (0-1)",
         ge=0.0,
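A small sketch of the validator behaviour defined in the new models above (illustrative; the example strings are made up):

from haiku.rag.research.models import GapRecord, InsightRecord

rec = InsightRecord(summary="Example finding about the corpus")
print(rec.id)      # auto slug: "example-finding-about-the-corpus"
print(rec.status)  # InsightStatus.OPEN (the default)

gap = GapRecord(description="Example missing detail", supporting_sources=["src", "src"])
print(gap.severity)            # GapSeverity.MEDIUM (the default)
print(gap.supporting_sources)  # duplicates removed: ["src"]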
haiku/rag/research/nodes/analysis.py
ADDED
@@ -0,0 +1,181 @@
+from dataclasses import dataclass
+
+from pydantic_ai import Agent
+from pydantic_graph import BaseNode, GraphRunContext
+
+from haiku.rag.research.common import (
+    format_analysis_for_prompt,
+    format_context_for_prompt,
+    get_model,
+    log,
+)
+from haiku.rag.research.dependencies import ResearchDependencies
+from haiku.rag.research.models import EvaluationResult, InsightAnalysis, ResearchReport
+from haiku.rag.research.nodes.synthesize import SynthesizeNode
+from haiku.rag.research.prompts import DECISION_AGENT_PROMPT, INSIGHT_AGENT_PROMPT
+from haiku.rag.research.state import ResearchDeps, ResearchState
+
+
+@dataclass
+class AnalyzeInsightsNode(BaseNode[ResearchState, ResearchDeps, ResearchReport]):
+    provider: str
+    model: str
+
+    async def run(
+        self, ctx: GraphRunContext[ResearchState, ResearchDeps]
+    ) -> BaseNode[ResearchState, ResearchDeps, ResearchReport]:
+        state = ctx.state
+        deps = ctx.deps
+
+        log(
+            deps,
+            state,
+            "\n[bold cyan]🧭 Synthesizing new insights and gap status...[/bold cyan]",
+        )
+
+        agent = Agent(
+            model=get_model(self.provider, self.model),
+            output_type=InsightAnalysis,
+            instructions=INSIGHT_AGENT_PROMPT,
+            retries=3,
+            deps_type=ResearchDependencies,
+        )
+
+        context_xml = format_context_for_prompt(state.context)
+        prompt = (
+            "Review the latest research context and update the shared ledger of insights, gaps,"
+            " and follow-up questions.\n\n"
+            f"{context_xml}"
+        )
+        agent_deps = ResearchDependencies(
+            client=deps.client,
+            context=state.context,
+            console=deps.console,
+            stream=deps.stream,
+        )
+        result = await agent.run(prompt, deps=agent_deps)
+        analysis: InsightAnalysis = result.output
+
+        state.context.integrate_analysis(analysis)
+        state.last_analysis = analysis
+
+        if analysis.commentary:
+            log(deps, state, f" Summary: {analysis.commentary}")
+        if analysis.highlights:
+            log(deps, state, " [bold]Updated insights:[/bold]")
+            for insight in analysis.highlights:
+                label = insight.status.value
+                log(
+                    deps,
+                    state,
+                    f" • ({label}) {insight.summary}",
+                )
+        if analysis.gap_assessments:
+            log(deps, state, " [bold yellow]Gap updates:[/bold yellow]")
+            for gap in analysis.gap_assessments:
+                status = "resolved" if gap.resolved else "open"
+                severity = gap.severity.value
+                log(
+                    deps,
+                    state,
+                    f" • ({severity}/{status}) {gap.description}",
+                )
+        if analysis.resolved_gaps:
+            log(deps, state, " [green]Resolved gaps:[/green]")
+            for resolved in analysis.resolved_gaps:
+                log(deps, state, f" • {resolved}")
+        if analysis.new_questions:
+            log(deps, state, " [cyan]Proposed follow-ups:[/cyan]")
+            for question in analysis.new_questions:
+                log(deps, state, f" • {question}")
+
+        return DecisionNode(self.provider, self.model)
+
+
+@dataclass
+class DecisionNode(BaseNode[ResearchState, ResearchDeps, ResearchReport]):
+    provider: str
+    model: str
+
+    async def run(
+        self, ctx: GraphRunContext[ResearchState, ResearchDeps]
+    ) -> BaseNode[ResearchState, ResearchDeps, ResearchReport]:
+        state = ctx.state
+        deps = ctx.deps
+
+        log(
+            deps,
+            state,
+            "\n[bold cyan]📊 Evaluating research sufficiency...[/bold cyan]",
+        )
+
+        agent = Agent(
+            model=get_model(self.provider, self.model),
+            output_type=EvaluationResult,
+            instructions=DECISION_AGENT_PROMPT,
+            retries=3,
+            deps_type=ResearchDependencies,
+        )
+
+        context_xml = format_context_for_prompt(state.context)
+        analysis_xml = format_analysis_for_prompt(state.last_analysis)
+        prompt_parts = [
+            "Assess whether the research now answers the original question with adequate confidence.",
+            context_xml,
+            analysis_xml,
+        ]
+        if state.last_eval is not None:
+            prev = state.last_eval
+            prompt_parts.append(
+                "<previous_evaluation>"
+                f"<confidence>{prev.confidence_score:.2f}</confidence>"
+                f"<is_sufficient>{str(prev.is_sufficient).lower()}</is_sufficient>"
+                f"<reasoning>{prev.reasoning}</reasoning>"
+                "</previous_evaluation>"
+            )
+        prompt = "\n\n".join(part for part in prompt_parts if part)
+
+        agent_deps = ResearchDependencies(
+            client=deps.client,
+            context=state.context,
+            console=deps.console,
+            stream=deps.stream,
+        )
+        decision_result = await agent.run(prompt, deps=agent_deps)
+        output = decision_result.output
+
+        state.last_eval = output
+        state.iterations += 1
+
+        for new_q in output.new_questions:
+            if new_q not in state.context.sub_questions:
+                state.context.sub_questions.append(new_q)
+
+        if output.key_insights:
+            log(deps, state, " [bold]Key insights:[/bold]")
+            for insight in output.key_insights:
+                log(deps, state, f" • {insight}")
+
+        if output.gaps:
+            log(deps, state, " [bold yellow]Remaining gaps:[/bold yellow]")
+            for gap in output.gaps:
+                log(deps, state, f" • {gap}")
+
+        log(
+            deps,
+            state,
+            f" Confidence: [yellow]{output.confidence_score:.1%}[/yellow]",
+        )
+        status = "[green]Yes[/green]" if output.is_sufficient else "[red]No[/red]"
+        log(deps, state, f" Sufficient: {status}")
+
+        from haiku.rag.research.nodes.search import SearchDispatchNode
+
+        if (
+            output.is_sufficient
+            and output.confidence_score >= state.confidence_threshold
+        ) or state.iterations >= state.max_iterations:
+            log(deps, state, "\n[bold green]✅ Stopping research.[/bold green]")
+            return SynthesizeNode(self.provider, self.model)
+
+        return SearchDispatchNode(self.provider, self.model)
haiku/rag/research/nodes/plan.py
CHANGED
@@ -22,7 +22,7 @@ class PlanNode(BaseNode[ResearchState, ResearchDeps, ResearchReport]):
         state = ctx.state
         deps = ctx.deps

-        log(deps
+        log(deps, state, "\n[bold cyan]📋 Creating research plan...[/bold cyan]")

         plan_agent = Agent(
             model=get_model(self.provider, self.model),
@@ -45,19 +45,26 @@ class PlanNode(BaseNode[ResearchState, ResearchDeps, ResearchReport]):

         prompt = (
             "Plan a focused research approach for the main question.\n\n"
-            f"Main question: {state.
+            f"Main question: {state.context.original_question}"
         )

         agent_deps = ResearchDependencies(
-            client=deps.client,
+            client=deps.client,
+            context=state.context,
+            console=deps.console,
+            stream=deps.stream,
         )
         plan_result = await plan_agent.run(prompt, deps=agent_deps)
-        state.sub_questions = list(plan_result.output.sub_questions)
+        state.context.sub_questions = list(plan_result.output.sub_questions)

-        log(deps
-        log(
-
-
-
+        log(deps, state, "\n[bold green]✅ Research Plan Created:[/bold green]")
+        log(
+            deps,
+            state,
+            f" [bold]Main Question:[/bold] {state.context.original_question}",
+        )
+        log(deps, state, " [bold]Sub-questions:[/bold]")
+        for i, sq in enumerate(state.context.sub_questions, 1):
+            log(deps, state, f" {i}. {sq}")

         return SearchDispatchNode(self.provider, self.model)
haiku/rag/research/nodes/search.py
CHANGED
@@ -24,20 +24,21 @@ class SearchDispatchNode(BaseNode[ResearchState, ResearchDeps, ResearchReport]):
     ) -> BaseNode[ResearchState, ResearchDeps, ResearchReport]:
         state = ctx.state
         deps = ctx.deps
-        if not state.sub_questions:
-            from haiku.rag.research.nodes.
+        if not state.context.sub_questions:
+            from haiku.rag.research.nodes.analysis import AnalyzeInsightsNode

-            return
+            return AnalyzeInsightsNode(self.provider, self.model)

         # Take up to max_concurrency questions and answer them concurrently
         take = max(1, state.max_concurrency)
         batch: list[str] = []
-        while state.sub_questions and len(batch) < take:
-            batch.append(state.sub_questions.pop(0))
+        while state.context.sub_questions and len(batch) < take:
+            batch.append(state.context.sub_questions.pop(0))

         async def answer_one(sub_q: str) -> SearchAnswer | None:
             log(
-                deps
+                deps,
+                state,
                 f"\n[bold cyan]🔍 Searching & Answering:[/bold cyan] {sub_q}",
             )
             agent = Agent(
@@ -71,12 +72,15 @@ class SearchDispatchNode(BaseNode[ResearchState, ResearchDeps, ResearchReport]):
                 return format_as_xml(entries, root_tag="snippets")

             agent_deps = ResearchDependencies(
-                client=deps.client,
+                client=deps.client,
+                context=state.context,
+                console=deps.console,
+                stream=deps.stream,
             )
             try:
                 result = await agent.run(sub_q, deps=agent_deps)
             except Exception as e:
-                log(deps
+                log(deps, state, f"[red]Search failed:[/red] {e}")
                 return None

             return result.output
@@ -86,8 +90,7 @@ class SearchDispatchNode(BaseNode[ResearchState, ResearchDeps, ResearchReport]):
             if ans is None:
                 continue
             state.context.add_qa_response(ans)
-            if
-
-            log(deps.console, f" [green]✓[/green] {preview}")
+            preview = ans.answer[:150] + ("…" if len(ans.answer) > 150 else "")
+            log(deps, state, f" [green]✓[/green] {preview}")

         return SearchDispatchNode(self.provider, self.model)