haiku.rag 0.9.3__py3-none-any.whl → 0.10.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of haiku.rag might be problematic.
- haiku/rag/app.py +64 -18
- haiku/rag/cli.py +67 -30
- haiku/rag/client.py +63 -21
- haiku/rag/config.py +4 -0
- haiku/rag/mcp.py +18 -6
- haiku/rag/qa/agent.py +4 -2
- haiku/rag/qa/prompts.py +2 -2
- haiku/rag/reranking/mxbai.py +1 -1
- haiku/rag/research/__init__.py +10 -27
- haiku/rag/research/common.py +53 -0
- haiku/rag/research/dependencies.py +3 -25
- haiku/rag/research/graph.py +29 -0
- haiku/rag/research/models.py +70 -0
- haiku/rag/research/nodes/evaluate.py +80 -0
- haiku/rag/research/nodes/plan.py +63 -0
- haiku/rag/research/nodes/search.py +93 -0
- haiku/rag/research/nodes/synthesize.py +51 -0
- haiku/rag/research/prompts.py +98 -113
- haiku/rag/research/state.py +25 -0
- haiku/rag/store/engine.py +14 -0
- haiku/rag/store/models/chunk.py +1 -0
- haiku/rag/store/models/document.py +1 -0
- haiku/rag/store/repositories/chunk.py +4 -0
- haiku/rag/store/repositories/document.py +3 -0
- haiku/rag/store/upgrades/__init__.py +2 -0
- haiku/rag/store/upgrades/v0_10_1.py +64 -0
- haiku/rag/utils.py +8 -5
- {haiku_rag-0.9.3.dist-info → haiku_rag-0.10.1.dist-info}/METADATA +37 -1
- haiku_rag-0.10.1.dist-info/RECORD +54 -0
- haiku/rag/research/base.py +0 -130
- haiku/rag/research/evaluation_agent.py +0 -85
- haiku/rag/research/orchestrator.py +0 -170
- haiku/rag/research/presearch_agent.py +0 -39
- haiku/rag/research/search_agent.py +0 -69
- haiku/rag/research/synthesis_agent.py +0 -60
- haiku_rag-0.9.3.dist-info/RECORD +0 -51
- {haiku_rag-0.9.3.dist-info → haiku_rag-0.10.1.dist-info}/WHEEL +0 -0
- {haiku_rag-0.9.3.dist-info → haiku_rag-0.10.1.dist-info}/entry_points.txt +0 -0
- {haiku_rag-0.9.3.dist-info → haiku_rag-0.10.1.dist-info}/licenses/LICENSE +0 -0
haiku/rag/qa/agent.py
CHANGED
@@ -12,7 +12,9 @@ from haiku.rag.qa.prompts import QA_SYSTEM_PROMPT, QA_SYSTEM_PROMPT_WITH_CITATIO
 class SearchResult(BaseModel):
     content: str = Field(description="The document text content")
     score: float = Field(description="Relevance score (higher is more relevant)")
-    document_uri: str = Field(
+    document_uri: str = Field(
+        description="Source title (if available) or URI/path of the document"
+    )
 
 
 class Dependencies(BaseModel):
@@ -59,7 +61,7 @@ class QuestionAnswerAgent:
             SearchResult(
                 content=chunk.content,
                 score=score,
-                document_uri=chunk.document_uri or "",
+                document_uri=(chunk.document_title or chunk.document_uri or ""),
             )
             for chunk, score in expanded_results
         ]
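
Both hunks apply the same display-name fallback: prefer the document title, fall back to the URI, and finally to an empty string. A minimal sketch of the pattern (the Chunk class here is a hypothetical stand-in for the library's chunk model, reduced to the fields involved):

class Chunk:
    def __init__(self, content: str, document_title: str | None = None,
                 document_uri: str | None = None):
        self.content = content
        self.document_title = document_title
        self.document_uri = document_uri

def display_source(chunk: Chunk) -> str:
    # Same precedence as the diff: title, then URI, then empty string.
    return chunk.document_title or chunk.document_uri or ""

assert display_source(Chunk("...", "User Guide", "docs/guide.md")) == "User Guide"
assert display_source(Chunk("...", None, "docs/guide.md")) == "docs/guide.md"
assert display_source(Chunk("...")) == ""
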
haiku/rag/qa/prompts.py
CHANGED
@@ -44,9 +44,9 @@ Guidelines:
 
 Citation Format:
 After your answer, include a "Citations:" section that lists:
-- The document URI from each search result used
+- The document title (if available) or URI from each search result used
 - A brief excerpt (first 50-100 characters) of the content that supported your answer
-- Format: "Citations:\n- [
+- Format: "Citations:\n- [document title or URI]: [content_excerpt]..."
 
 Example response format:
 [Your answer here]
haiku/rag/reranking/mxbai.py
CHANGED
haiku/rag/research/__init__.py
CHANGED
@@ -1,37 +1,20 @@
-"""Multi-agent research workflow for advanced RAG queries."""
-
-from haiku.rag.research.base import (
-    BaseResearchAgent,
-    ResearchOutput,
-    SearchAnswer,
-    SearchResult,
-)
 from haiku.rag.research.dependencies import ResearchContext, ResearchDependencies
-from haiku.rag.research.
-
-
+from haiku.rag.research.graph import (
+    PlanNode,
+    ResearchDeps,
+    ResearchState,
+    build_research_graph,
 )
-from haiku.rag.research.
-from haiku.rag.research.presearch_agent import PresearchSurveyAgent
-from haiku.rag.research.search_agent import SearchSpecialistAgent
-from haiku.rag.research.synthesis_agent import ResearchReport, SynthesisAgent
+from haiku.rag.research.models import EvaluationResult, ResearchReport, SearchAnswer
 
 __all__ = [
-    # Base classes
-    "BaseResearchAgent",
     "ResearchDependencies",
     "ResearchContext",
-    "SearchResult",
-    "ResearchOutput",
-    # Specialized agents
     "SearchAnswer",
-    "SearchSpecialistAgent",
-    "PresearchSurveyAgent",
-    "AnalysisEvaluationAgent",
     "EvaluationResult",
-    "SynthesisAgent",
     "ResearchReport",
-
-    "
-    "
+    "ResearchDeps",
+    "ResearchState",
+    "PlanNode",
+    "build_research_graph",
 ]
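
Taken together, the reorganization removes the agent classes from the public surface and exposes the graph entry points instead. A quick orientation sketch of the new imports (names taken from the new __all__):

from haiku.rag.research import (
    PlanNode,
    ResearchReport,
    SearchAnswer,
    build_research_graph,
)

graph = build_research_graph()  # pydantic-graph Graph wiring the four nodes
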
haiku/rag/research/common.py
ADDED

@@ -0,0 +1,53 @@
+from typing import Any
+
+from pydantic_ai import format_as_xml
+from pydantic_ai.models.openai import OpenAIChatModel
+from pydantic_ai.providers.ollama import OllamaProvider
+from pydantic_ai.providers.openai import OpenAIProvider
+
+from haiku.rag.config import Config
+from haiku.rag.research.dependencies import ResearchContext
+
+
+def get_model(provider: str, model: str) -> Any:
+    if provider == "ollama":
+        return OpenAIChatModel(
+            model_name=model,
+            provider=OllamaProvider(base_url=f"{Config.OLLAMA_BASE_URL}/v1"),
+        )
+    elif provider == "vllm":
+        return OpenAIChatModel(
+            model_name=model,
+            provider=OpenAIProvider(
+                base_url=f"{Config.VLLM_RESEARCH_BASE_URL or Config.VLLM_QA_BASE_URL}/v1",
+                api_key="none",
+            ),
+        )
+    else:
+        return f"{provider}:{model}"
+
+
+def log(console, msg: str) -> None:
+    if console:
+        console.print(msg)
+
+
+def format_context_for_prompt(context: ResearchContext) -> str:
+    """Format the research context as XML for inclusion in prompts."""
+
+    context_data = {
+        "original_question": context.original_question,
+        "unanswered_questions": context.sub_questions,
+        "qa_responses": [
+            {
+                "question": qa.query,
+                "answer": qa.answer,
+                "context_snippets": qa.context,
+                "sources": qa.sources,  # pyright: ignore[reportAttributeAccessIssue]
+            }
+            for qa in context.qa_responses
+        ],
+        "insights": context.insights,
+        "gaps": context.gaps,
+    }
+    return format_as_xml(context_data, root_tag="research_context")
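
get_model centralizes provider routing: "ollama" and "vllm" are wrapped in OpenAI-compatible chat models pointed at the configured base URLs, and anything else falls through to pydantic-ai's "provider:model" string form. A short usage sketch (model names are illustrative):

from haiku.rag.research.common import get_model

m1 = get_model("ollama", "qwen3")  # OpenAIChatModel over {OLLAMA_BASE_URL}/v1
m2 = get_model("anthropic", "claude-3-5-haiku-latest")
assert m2 == "anthropic:claude-3-5-haiku-latest"  # passed to pydantic-ai as-is
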
haiku/rag/research/dependencies.py
CHANGED

@@ -1,9 +1,8 @@
 from pydantic import BaseModel, Field
-from pydantic_ai import format_as_xml
 from rich.console import Console
 
 from haiku.rag.client import HaikuRAG
-from haiku.rag.research.
+from haiku.rag.research.models import SearchAnswer
 
 
 class ResearchContext(BaseModel):
@@ -13,7 +12,7 @@ class ResearchContext(BaseModel):
     sub_questions: list[str] = Field(
         default_factory=list, description="Decomposed sub-questions"
     )
-    qa_responses: list[
+    qa_responses: list[SearchAnswer] = Field(
         default_factory=list, description="Structured QA pairs used during research"
     )
     insights: list[str] = Field(
@@ -23,7 +22,7 @@ class ResearchContext(BaseModel):
         default_factory=list, description="Identified information gaps"
     )
 
-    def add_qa_response(self, qa:
+    def add_qa_response(self, qa: SearchAnswer) -> None:
         """Add a structured QA response (minimal context already included)."""
         self.qa_responses.append(qa)
 
@@ -46,24 +45,3 @@ class ResearchDependencies(BaseModel):
     client: HaikuRAG = Field(description="RAG client for document operations")
     context: ResearchContext = Field(description="Shared research context")
     console: Console | None = None
-
-
-def _format_context_for_prompt(context: ResearchContext) -> str:
-    """Format the research context as XML for inclusion in prompts."""
-
-    context_data = {
-        "original_question": context.original_question,
-        "unanswered_questions": context.sub_questions,
-        "qa_responses": [
-            {
-                "question": qa.query,
-                "answer": qa.answer,
-                "context_snippets": qa.context,
-                "sources": qa.sources,
-            }
-            for qa in context.qa_responses
-        ],
-        "insights": context.insights,
-        "gaps": context.gaps,
-    }
-    return format_as_xml(context_data, root_tag="research_context")
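
With qa_responses now typed as list[SearchAnswer], the shared context is structured end to end. A hedged sketch of adding one answer (assuming original_question is a plain required string field; its definition sits outside the hunks shown here, and the answer text is illustrative):

from haiku.rag.research.dependencies import ResearchContext
from haiku.rag.research.models import SearchAnswer

context = ResearchContext(original_question="How does chunk expansion work?")
context.add_qa_response(
    SearchAnswer(
        query="How does chunk expansion work?",
        answer="Neighboring chunks are merged around each hit.",  # illustrative
        context=["...verbatim supporting snippet..."],
        sources=["User Guide"],
    )
)
assert len(context.qa_responses) == 1
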
haiku/rag/research/graph.py
ADDED

@@ -0,0 +1,29 @@
+from pydantic_graph import Graph
+
+from haiku.rag.research.models import ResearchReport
+from haiku.rag.research.nodes.evaluate import EvaluateNode
+from haiku.rag.research.nodes.plan import PlanNode
+from haiku.rag.research.nodes.search import SearchDispatchNode
+from haiku.rag.research.nodes.synthesize import SynthesizeNode
+from haiku.rag.research.state import ResearchDeps, ResearchState
+
+__all__ = [
+    "PlanNode",
+    "SearchDispatchNode",
+    "EvaluateNode",
+    "SynthesizeNode",
+    "ResearchState",
+    "ResearchDeps",
+    "build_research_graph",
+]
+
+
+def build_research_graph() -> Graph[ResearchState, ResearchDeps, ResearchReport]:
+    return Graph(
+        nodes=[
+            PlanNode,
+            SearchDispatchNode,
+            EvaluateNode,
+            SynthesizeNode,
+        ]
+    )
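
A hedged sketch of driving the graph end to end. ResearchState and ResearchDeps are defined in state.py, which this diff does not show, so the constructor arguments below (question, client, console) are assumptions inferred from how the nodes read them:

import asyncio

from haiku.rag.client import HaikuRAG
from haiku.rag.research.graph import PlanNode, build_research_graph
from haiku.rag.research.state import ResearchDeps, ResearchState

async def main() -> None:
    graph = build_research_graph()
    async with HaikuRAG("knowledge.db") as client:  # illustrative db path
        state = ResearchState(question="How does chunk expansion work?")
        deps = ResearchDeps(client=client, console=None)
        result = await graph.run(
            PlanNode(provider="ollama", model="qwen3"), state=state, deps=deps
        )
        print(result.output.title)  # result.output is a ResearchReport

asyncio.run(main())
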
haiku/rag/research/models.py
ADDED

@@ -0,0 +1,70 @@
+from pydantic import BaseModel, Field
+
+
+class ResearchPlan(BaseModel):
+    main_question: str
+    sub_questions: list[str]
+
+
+class SearchAnswer(BaseModel):
+    """Structured output for the SearchSpecialist agent."""
+
+    query: str = Field(description="The search query that was performed")
+    answer: str = Field(description="The answer generated based on the context")
+    context: list[str] = Field(
+        description=(
+            "Only the minimal set of relevant snippets (verbatim) that directly "
+            "support the answer"
+        )
+    )
+    sources: list[str] = Field(
+        description=(
+            "Document titles (if available) or URIs corresponding to the"
+            " snippets actually used in the answer (one per snippet; omit if none)"
+        ),
+        default_factory=list,
+    )
+
+
+class EvaluationResult(BaseModel):
+    """Result of analysis and evaluation."""
+
+    key_insights: list[str] = Field(
+        description="Main insights extracted from the research so far"
+    )
+    new_questions: list[str] = Field(
+        description="New sub-questions to add to the research (max 3)",
+        max_length=3,
+        default=[],
+    )
+    confidence_score: float = Field(
+        description="Confidence level in the completeness of research (0-1)",
+        ge=0.0,
+        le=1.0,
+    )
+    is_sufficient: bool = Field(
+        description="Whether the research is sufficient to answer the original question"
+    )
+    reasoning: str = Field(
+        description="Explanation of why the research is or isn't complete"
+    )
+
+
+class ResearchReport(BaseModel):
+    """Final research report structure."""
+
+    title: str = Field(description="Concise title for the research")
+    executive_summary: str = Field(description="Brief overview of key findings")
+    main_findings: list[str] = Field(
+        description="Primary research findings with supporting evidence"
+    )
+    conclusions: list[str] = Field(description="Evidence-based conclusions")
+    limitations: list[str] = Field(
+        description="Limitations of the current research", default=[]
+    )
+    recommendations: list[str] = Field(
+        description="Actionable recommendations based on findings", default=[]
+    )
+    sources_summary: str = Field(
+        description="Summary of sources used and their reliability"
+    )
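
These models double as guardrails: because the agents use them as output types, out-of-range values are rejected and retried rather than silently accepted. For example, EvaluationResult enforces the confidence bounds and the sub-question cap:

from haiku.rag.research.models import EvaluationResult

result = EvaluationResult(
    key_insights=["The store keeps document titles alongside URIs."],
    confidence_score=0.8,  # must satisfy ge=0.0, le=1.0
    is_sufficient=True,
    reasoning="The gathered snippets directly answer the original question.",
)
assert result.new_questions == []  # defaults empty; max_length caps it at 3
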
haiku/rag/research/nodes/evaluate.py
ADDED

@@ -0,0 +1,80 @@
+from dataclasses import dataclass
+
+from pydantic_ai import Agent
+from pydantic_graph import BaseNode, GraphRunContext
+
+from haiku.rag.research.common import format_context_for_prompt, get_model, log
+from haiku.rag.research.dependencies import (
+    ResearchDependencies,
+)
+from haiku.rag.research.models import EvaluationResult, ResearchReport
+from haiku.rag.research.nodes.synthesize import SynthesizeNode
+from haiku.rag.research.prompts import EVALUATION_AGENT_PROMPT
+from haiku.rag.research.state import ResearchDeps, ResearchState
+
+
+@dataclass
+class EvaluateNode(BaseNode[ResearchState, ResearchDeps, ResearchReport]):
+    provider: str
+    model: str
+
+    async def run(
+        self, ctx: GraphRunContext[ResearchState, ResearchDeps]
+    ) -> BaseNode[ResearchState, ResearchDeps, ResearchReport]:
+        state = ctx.state
+        deps = ctx.deps
+
+        log(
+            deps.console,
+            "\n[bold cyan]📊 Analyzing and evaluating research progress...[/bold cyan]",
+        )
+
+        agent = Agent(
+            model=get_model(self.provider, self.model),
+            output_type=EvaluationResult,
+            instructions=EVALUATION_AGENT_PROMPT,
+            retries=3,
+            deps_type=ResearchDependencies,
+        )
+
+        context_xml = format_context_for_prompt(state.context)
+        prompt = (
+            "Analyze gathered information and evaluate completeness for the original question.\n\n"
+            f"{context_xml}"
+        )
+        agent_deps = ResearchDependencies(
+            client=deps.client, context=state.context, console=deps.console
+        )
+        eval_result = await agent.run(prompt, deps=agent_deps)
+        output = eval_result.output
+
+        for insight in output.key_insights:
+            state.context.add_insight(insight)
+        for new_q in output.new_questions:
+            if new_q not in state.sub_questions:
+                state.sub_questions.append(new_q)
+
+        state.last_eval = output
+        state.iterations += 1
+
+        if output.key_insights:
+            log(deps.console, " [bold]Key insights:[/bold]")
+            for ins in output.key_insights:
+                log(deps.console, f" • {ins}")
+        log(
+            deps.console,
+            f" Confidence: [yellow]{output.confidence_score:.1%}[/yellow]",
+        )
+        status = "[green]Yes[/green]" if output.is_sufficient else "[red]No[/red]"
+        log(deps.console, f" Sufficient: {status}")
+
+        from haiku.rag.research.nodes.search import SearchDispatchNode
+
+        if (
+            output.is_sufficient
+            and output.confidence_score >= state.confidence_threshold
+        ) or state.iterations >= state.max_iterations:
+            log(deps.console, "\n[bold green]✅ Stopping research.[/bold green]")
+            return SynthesizeNode(self.provider, self.model)
+
+        return SearchDispatchNode(self.provider, self.model)
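
The node's exit logic is the heart of the loop: keep cycling search and evaluation until the evaluator is both satisfied and confident enough, or the iteration budget is spent. Isolated as a predicate (names mirror the state fields used above):

def should_stop(is_sufficient: bool, confidence: float, threshold: float,
                iterations: int, max_iterations: int) -> bool:
    return (is_sufficient and confidence >= threshold) or iterations >= max_iterations

assert should_stop(True, 0.9, 0.8, 1, 3)      # sufficient and confident
assert not should_stop(True, 0.5, 0.8, 1, 3)  # sufficient but under-confident
assert should_stop(False, 0.2, 0.8, 3, 3)     # budget exhausted
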
haiku/rag/research/nodes/plan.py
ADDED

@@ -0,0 +1,63 @@
+from dataclasses import dataclass
+
+from pydantic_ai import Agent, RunContext
+from pydantic_graph import BaseNode, GraphRunContext
+
+from haiku.rag.research.common import get_model, log
+from haiku.rag.research.dependencies import ResearchDependencies
+from haiku.rag.research.models import ResearchPlan, ResearchReport
+from haiku.rag.research.nodes.search import SearchDispatchNode
+from haiku.rag.research.prompts import PLAN_PROMPT
+from haiku.rag.research.state import ResearchDeps, ResearchState
+
+
+@dataclass
+class PlanNode(BaseNode[ResearchState, ResearchDeps, ResearchReport]):
+    provider: str
+    model: str
+
+    async def run(
+        self, ctx: GraphRunContext[ResearchState, ResearchDeps]
+    ) -> BaseNode[ResearchState, ResearchDeps, ResearchReport]:
+        state = ctx.state
+        deps = ctx.deps
+
+        log(deps.console, "\n[bold cyan]📋 Creating research plan...[/bold cyan]")
+
+        plan_agent = Agent(
+            model=get_model(self.provider, self.model),
+            output_type=ResearchPlan,
+            instructions=(
+                PLAN_PROMPT
+                + "\n\nUse the gather_context tool once on the main question before planning."
+            ),
+            retries=3,
+            deps_type=ResearchDependencies,
+        )
+
+        @plan_agent.tool
+        async def gather_context(
+            ctx2: RunContext[ResearchDependencies], query: str, limit: int = 6
+        ) -> str:
+            results = await ctx2.deps.client.search(query, limit=limit)
+            expanded = await ctx2.deps.client.expand_context(results)
+            return "\n\n".join(chunk.content for chunk, _ in expanded)
+
+        prompt = (
+            "Plan a focused research approach for the main question.\n\n"
+            f"Main question: {state.question}"
+        )
+
+        agent_deps = ResearchDependencies(
+            client=deps.client, context=state.context, console=deps.console
+        )
+        plan_result = await plan_agent.run(prompt, deps=agent_deps)
+        state.sub_questions = list(plan_result.output.sub_questions)
+
+        log(deps.console, "\n[bold green]✅ Research Plan Created:[/bold green]")
+        log(deps.console, f" [bold]Main Question:[/bold] {state.question}")
+        log(deps.console, " [bold]Sub-questions:[/bold]")
+        for i, sq in enumerate(state.sub_questions, 1):
+            log(deps.console, f" {i}. {sq}")
+
+        return SearchDispatchNode(self.provider, self.model)
haiku/rag/research/nodes/search.py
ADDED

@@ -0,0 +1,93 @@
+import asyncio
+from dataclasses import dataclass
+from typing import Any
+
+from pydantic_ai import Agent, RunContext
+from pydantic_ai.format_prompt import format_as_xml
+from pydantic_ai.output import ToolOutput
+from pydantic_graph import BaseNode, GraphRunContext
+
+from haiku.rag.research.common import get_model, log
+from haiku.rag.research.dependencies import ResearchDependencies
+from haiku.rag.research.models import ResearchReport, SearchAnswer
+from haiku.rag.research.prompts import SEARCH_AGENT_PROMPT
+from haiku.rag.research.state import ResearchDeps, ResearchState
+
+
+@dataclass
+class SearchDispatchNode(BaseNode[ResearchState, ResearchDeps, ResearchReport]):
+    provider: str
+    model: str
+
+    async def run(
+        self, ctx: GraphRunContext[ResearchState, ResearchDeps]
+    ) -> BaseNode[ResearchState, ResearchDeps, ResearchReport]:
+        state = ctx.state
+        deps = ctx.deps
+        if not state.sub_questions:
+            from haiku.rag.research.nodes.evaluate import EvaluateNode
+
+            return EvaluateNode(self.provider, self.model)
+
+        # Take up to max_concurrency questions and answer them concurrently
+        take = max(1, state.max_concurrency)
+        batch: list[str] = []
+        while state.sub_questions and len(batch) < take:
+            batch.append(state.sub_questions.pop(0))
+
+        async def answer_one(sub_q: str) -> SearchAnswer | None:
+            log(
+                deps.console,
+                f"\n[bold cyan]🔍 Searching & Answering:[/bold cyan] {sub_q}",
+            )
+            agent = Agent(
+                model=get_model(self.provider, self.model),
+                output_type=ToolOutput(SearchAnswer, max_retries=3),
+                instructions=SEARCH_AGENT_PROMPT,
+                retries=3,
+                deps_type=ResearchDependencies,
+            )
+
+            @agent.tool
+            async def search_and_answer(
+                ctx2: RunContext[ResearchDependencies], query: str, limit: int = 5
+            ) -> str:
+                search_results = await ctx2.deps.client.search(query, limit=limit)
+                expanded = await ctx2.deps.client.expand_context(search_results)
+
+                entries: list[dict[str, Any]] = [
+                    {
+                        "text": chunk.content,
+                        "score": score,
+                        "document_uri": (
+                            chunk.document_title or chunk.document_uri or ""
+                        ),
+                    }
+                    for chunk, score in expanded
+                ]
+                if not entries:
+                    return f"No relevant information found in the knowledge base for: {query}"
+
+                return format_as_xml(entries, root_tag="snippets")
+
+            agent_deps = ResearchDependencies(
+                client=deps.client, context=state.context, console=deps.console
+            )
+            try:
+                result = await agent.run(sub_q, deps=agent_deps)
+            except Exception as e:
+                log(deps.console, f"[red]Search failed:[/red] {e}")
+                return None
+
+            return result.output
+
+        answers = await asyncio.gather(*(answer_one(q) for q in batch))
+        for ans in answers:
+            if ans is None:
+                continue
+            state.context.add_qa_response(ans)
+            if deps.console:
+                preview = ans.answer[:150] + ("…" if len(ans.answer) > 150 else "")
+                log(deps.console, f" [green]✓[/green] {preview}")
+
+        return SearchDispatchNode(self.provider, self.model)
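
SearchDispatchNode re-enters itself: each visit pops up to max_concurrency pending sub-questions, answers them concurrently via asyncio.gather, and loops until the queue is empty, at which point it hands off to EvaluateNode. The batching skeleton in isolation (answer_one is a trivial stand-in for the per-question agent run):

import asyncio

async def answer_one(sub_q: str) -> str:
    await asyncio.sleep(0)  # stand-in for the real agent call
    return f"answered: {sub_q}"

async def drain(questions: list[str], max_concurrency: int) -> list[str]:
    answers: list[str] = []
    while questions:  # one loop pass per SearchDispatchNode visit
        take = max(1, max_concurrency)
        batch = [questions.pop(0) for _ in range(min(take, len(questions)))]
        answers.extend(await asyncio.gather(*(answer_one(q) for q in batch)))
    return answers

print(asyncio.run(drain(["q1", "q2", "q3"], max_concurrency=2)))
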
haiku/rag/research/nodes/synthesize.py
ADDED

@@ -0,0 +1,51 @@
+from dataclasses import dataclass
+
+from pydantic_ai import Agent
+from pydantic_graph import BaseNode, End, GraphRunContext
+
+from haiku.rag.research.common import format_context_for_prompt, get_model, log
+from haiku.rag.research.dependencies import (
+    ResearchDependencies,
+)
+from haiku.rag.research.models import ResearchReport
+from haiku.rag.research.prompts import SYNTHESIS_AGENT_PROMPT
+from haiku.rag.research.state import ResearchDeps, ResearchState
+
+
+@dataclass
+class SynthesizeNode(BaseNode[ResearchState, ResearchDeps, ResearchReport]):
+    provider: str
+    model: str
+
+    async def run(
+        self, ctx: GraphRunContext[ResearchState, ResearchDeps]
+    ) -> End[ResearchReport]:
+        state = ctx.state
+        deps = ctx.deps
+
+        log(
+            deps.console,
+            "\n[bold cyan]📝 Generating final research report...[/bold cyan]",
+        )
+
+        agent = Agent(
+            model=get_model(self.provider, self.model),
+            output_type=ResearchReport,
+            instructions=SYNTHESIS_AGENT_PROMPT,
+            retries=3,
+            deps_type=ResearchDependencies,
+        )
+
+        context_xml = format_context_for_prompt(state.context)
+        prompt = (
+            "Generate a comprehensive research report based on all gathered information.\n\n"
+            f"{context_xml}\n\n"
+            "Create a detailed report that synthesizes all findings into a coherent response."
+        )
+        agent_deps = ResearchDependencies(
+            client=deps.client, context=state.context, console=deps.console
+        )
+        result = await agent.run(prompt, deps=agent_deps)
+
+        log(deps.console, "[bold green]✅ Research complete![/bold green]")
+        return End(result.output)