haiku.rag 0.11.3__py3-none-any.whl → 0.12.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of haiku.rag might be problematic. Click here for more details.
- haiku/rag/a2a/__init__.py +176 -0
- haiku/rag/a2a/client.py +271 -0
- haiku/rag/a2a/context.py +68 -0
- haiku/rag/a2a/models.py +21 -0
- haiku/rag/a2a/prompts.py +59 -0
- haiku/rag/a2a/skills.py +75 -0
- haiku/rag/a2a/storage.py +71 -0
- haiku/rag/a2a/worker.py +320 -0
- haiku/rag/app.py +75 -14
- haiku/rag/cli.py +79 -69
- haiku/rag/client.py +10 -4
- haiku/rag/config.py +9 -0
- haiku/rag/mcp.py +99 -0
- haiku/rag/migration.py +3 -3
- haiku/rag/qa/__init__.py +6 -1
- haiku/rag/qa/agent.py +6 -6
- haiku/rag/store/engine.py +33 -5
- haiku/rag/store/repositories/chunk.py +0 -28
- haiku/rag/store/repositories/document.py +7 -0
- {haiku_rag-0.11.3.dist-info → haiku_rag-0.12.0.dist-info}/METADATA +31 -10
- {haiku_rag-0.11.3.dist-info → haiku_rag-0.12.0.dist-info}/RECORD +24 -16
- {haiku_rag-0.11.3.dist-info → haiku_rag-0.12.0.dist-info}/WHEEL +0 -0
- {haiku_rag-0.11.3.dist-info → haiku_rag-0.12.0.dist-info}/entry_points.txt +0 -0
- {haiku_rag-0.11.3.dist-info → haiku_rag-0.12.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,176 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from contextlib import asynccontextmanager
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
|
|
5
|
+
import logfire
|
|
6
|
+
from pydantic_ai import Agent, RunContext
|
|
7
|
+
|
|
8
|
+
from haiku.rag.config import Config
|
|
9
|
+
from haiku.rag.graph.common import get_model
|
|
10
|
+
|
|
11
|
+
from .context import load_message_history, save_message_history
|
|
12
|
+
from .models import AgentDependencies, SearchResult
|
|
13
|
+
from .prompts import A2A_SYSTEM_PROMPT
|
|
14
|
+
from .skills import extract_question_from_task, get_agent_skills
|
|
15
|
+
from .storage import LRUMemoryStorage
|
|
16
|
+
from .worker import ConversationalWorker
|
|
17
|
+
|
|
18
|
+
try:
|
|
19
|
+
from fasta2a import FastA2A # type: ignore
|
|
20
|
+
from fasta2a.broker import InMemoryBroker # type: ignore
|
|
21
|
+
from fasta2a.storage import InMemoryStorage # type: ignore
|
|
22
|
+
except ImportError as e:
|
|
23
|
+
raise ImportError(
|
|
24
|
+
"A2A support requires the 'a2a' extra. "
|
|
25
|
+
"Install with: uv pip install 'haiku.rag[a2a]'"
|
|
26
|
+
) from e
|
|
27
|
+
|
|
28
|
+
logger = logging.getLogger(__name__)

# Observability: send data to Logfire only if a token is present, and enable
# Logfire's pydantic-ai instrumentation for the agent used by this package.
logfire.configure(send_to_logfire="if-token-present", service_name="a2a")
logfire.instrument_pydantic_ai()

# Names re-exported as the public API of this package.
__all__ = [
    "create_a2a_app",
    "load_message_history",
    "save_message_history",
    "extract_question_from_task",
    "get_agent_skills",
    "LRUMemoryStorage",
]
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def create_a2a_app(
    db_path: Path,
    security_schemes: dict | None = None,
    security: list[dict[str, list[str]]] | None = None,
):
    """Create an A2A app for the conversational QA agent.

    Args:
        db_path: Path to the LanceDB database
        security_schemes: Optional security scheme definitions for the AgentCard
        security: Optional security requirements for the AgentCard

    Returns:
        A FastA2A ASGI application
    """
    # In-memory task/context storage, bounded to the configured number of
    # conversation contexts by the LRU wrapper.
    storage = LRUMemoryStorage(
        storage=InMemoryStorage(), max_contexts=Config.A2A_MAX_CONTEXTS
    )
    broker = InMemoryBroker()

    worker = ConversationalWorker(
        storage=storage,
        broker=broker,
        db_path=db_path,
        agent=_build_agent(),  # type: ignore
    )

    @asynccontextmanager
    async def lifespan(app):
        # The worker runs inside the task manager's context for the whole
        # lifetime of the ASGI application.
        async with app.task_manager:
            async with worker.run():
                logger.info(
                    "Started A2A server (max contexts: %s)", Config.A2A_MAX_CONTEXTS
                )
                yield

    app = FastA2A(
        storage=storage,
        broker=broker,
        name="haiku-rag",
        description="Conversational question answering agent powered by haiku.rag RAG system",
        skills=get_agent_skills(),
        lifespan=lifespan,
    )

    if security_schemes or security:
        _install_agent_card_security(app, security_schemes, security)

    return app


def _build_agent() -> Agent:
    """Create the QA agent and register its document search/retrieval tools."""
    model = get_model(Config.QA_PROVIDER, Config.QA_MODEL)
    agent = Agent(
        model=model,
        deps_type=AgentDependencies,
        system_prompt=A2A_SYSTEM_PROMPT,
        retries=3,
    )

    # NOTE: tool docstrings are exposed to the model as tool descriptions;
    # keep their wording stable.
    @agent.tool
    async def search_documents(
        ctx: RunContext[AgentDependencies],
        query: str,
        limit: int = 3,
    ) -> list[SearchResult]:
        """Search the knowledge base for relevant documents.

        Returns chunks of text with their relevance scores and document URIs.
        Use get_full_document if you need to see the complete document content.
        """
        search_results = await ctx.deps.client.search(query, limit=limit)
        expanded_results = await ctx.deps.client.expand_context(search_results)

        return [
            SearchResult(
                content=chunk.content,
                score=score,
                document_title=chunk.document_title,
                document_uri=(chunk.document_uri or ""),
            )
            for chunk, score in expanded_results
        ]

    @agent.tool
    async def get_full_document(
        ctx: RunContext[AgentDependencies],
        document_uri: str,
    ) -> str:
        """Retrieve the complete content of a document by its URI.

        Use this when you need more context than what's in a search result chunk.
        The document_uri comes from search_documents results.
        """
        document = await ctx.deps.client.get_document_by_uri(document_uri)
        if document is None:
            return f"Document not found: {document_uri}"

        return document.content

    return agent


def _install_agent_card_security(
    app: FastA2A,
    security_schemes: dict | None,
    security: list[dict[str, list[str]]] | None,
) -> None:
    """Make the served agent card include the given security configuration.

    FastA2A registers the ``/.well-known/agent-card.json`` route with its bound
    ``_agent_card_endpoint`` method inside ``FastA2A.__init__``. Reassigning that
    attribute afterwards therefore never reaches the router. The stock endpoint
    only builds the card when ``_agent_card_json_schema`` is ``None``. Pre-rendering
    the card into that cache makes the stock endpoint serve it, security fields
    included.
    """
    from fasta2a.schema import AgentCapabilities, AgentCard, agent_card_ta

    agent_card = AgentCard(
        name=app.name,
        description=app.description or "An AI agent exposed as an A2A agent.",
        url=app.url,
        version=app.version,
        protocol_version="0.3.0",
        skills=app.skills,
        default_input_modes=app.default_input_modes,
        default_output_modes=app.default_output_modes,
        capabilities=AgentCapabilities(
            streaming=False,
            push_notifications=False,
            state_transition_history=False,
        ),
    )
    if app.provider is not None:
        agent_card["provider"] = app.provider
    if security_schemes:
        agent_card["security_schemes"] = security_schemes
    if security:
        agent_card["security"] = security

    app._agent_card_json_schema = agent_card_ta.dump_json(agent_card, by_alias=True)
|
haiku/rag/a2a/client.py
ADDED
|
@@ -0,0 +1,271 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import uuid
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
import httpx
|
|
6
|
+
from rich.console import Console
|
|
7
|
+
from rich.markdown import Markdown
|
|
8
|
+
from rich.prompt import Prompt
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class A2AClient:
|
|
12
|
+
"""Simple A2A protocol client."""
|
|
13
|
+
|
|
14
|
+
def __init__(self, base_url: str = "http://localhost:8000"):
|
|
15
|
+
"""Initialize A2A client.
|
|
16
|
+
|
|
17
|
+
Args:
|
|
18
|
+
base_url: Base URL of the A2A server
|
|
19
|
+
"""
|
|
20
|
+
self.base_url = base_url.rstrip("/")
|
|
21
|
+
self.client = httpx.AsyncClient(timeout=60.0)
|
|
22
|
+
|
|
23
|
+
async def close(self):
|
|
24
|
+
"""Close the HTTP client."""
|
|
25
|
+
await self.client.aclose()
|
|
26
|
+
|
|
27
|
+
async def get_agent_card(self) -> dict[str, Any]:
|
|
28
|
+
"""Fetch the agent card from the A2A server.
|
|
29
|
+
|
|
30
|
+
Returns:
|
|
31
|
+
Agent card dictionary with agent capabilities and metadata
|
|
32
|
+
"""
|
|
33
|
+
response = await self.client.get(f"{self.base_url}/.well-known/agent-card.json")
|
|
34
|
+
response.raise_for_status()
|
|
35
|
+
return response.json()
|
|
36
|
+
|
|
37
|
+
async def send_message(
|
|
38
|
+
self,
|
|
39
|
+
text: str,
|
|
40
|
+
context_id: str | None = None,
|
|
41
|
+
skill_id: str | None = None,
|
|
42
|
+
) -> dict[str, Any]:
|
|
43
|
+
"""Send a message to the A2A agent and wait for completion.
|
|
44
|
+
|
|
45
|
+
Args:
|
|
46
|
+
text: Message text to send
|
|
47
|
+
context_id: Optional conversation context ID (creates new if None)
|
|
48
|
+
skill_id: Optional skill ID to use (defaults to document-qa)
|
|
49
|
+
|
|
50
|
+
Returns:
|
|
51
|
+
Completed task with response messages and artifacts
|
|
52
|
+
"""
|
|
53
|
+
if context_id is None:
|
|
54
|
+
context_id = str(uuid.uuid4())
|
|
55
|
+
|
|
56
|
+
message_id = str(uuid.uuid4())
|
|
57
|
+
|
|
58
|
+
payload: dict[str, Any] = {
|
|
59
|
+
"jsonrpc": "2.0",
|
|
60
|
+
"method": "message/send",
|
|
61
|
+
"params": {
|
|
62
|
+
"contextId": context_id,
|
|
63
|
+
"message": {
|
|
64
|
+
"kind": "message",
|
|
65
|
+
"role": "user",
|
|
66
|
+
"messageId": message_id,
|
|
67
|
+
"parts": [{"kind": "text", "text": text}],
|
|
68
|
+
},
|
|
69
|
+
},
|
|
70
|
+
"id": 1,
|
|
71
|
+
}
|
|
72
|
+
|
|
73
|
+
if skill_id:
|
|
74
|
+
payload["params"]["skillId"] = skill_id
|
|
75
|
+
|
|
76
|
+
response = await self.client.post(
|
|
77
|
+
self.base_url,
|
|
78
|
+
json=payload,
|
|
79
|
+
headers={"Content-Type": "application/json"},
|
|
80
|
+
)
|
|
81
|
+
response.raise_for_status()
|
|
82
|
+
initial_response = response.json()
|
|
83
|
+
|
|
84
|
+
# Extract task ID from response
|
|
85
|
+
result = initial_response.get("result", {})
|
|
86
|
+
task_id = result.get("id")
|
|
87
|
+
|
|
88
|
+
if not task_id:
|
|
89
|
+
return initial_response
|
|
90
|
+
|
|
91
|
+
# Poll for task completion
|
|
92
|
+
return await self.wait_for_task(task_id)
|
|
93
|
+
|
|
94
|
+
async def wait_for_task(
|
|
95
|
+
self, task_id: str, max_wait: int = 60, poll_interval: float = 0.5
|
|
96
|
+
) -> dict[str, Any]:
|
|
97
|
+
"""Poll for task completion.
|
|
98
|
+
|
|
99
|
+
Args:
|
|
100
|
+
task_id: Task ID to poll for
|
|
101
|
+
max_wait: Maximum time to wait in seconds
|
|
102
|
+
poll_interval: Interval between polls in seconds
|
|
103
|
+
|
|
104
|
+
Returns:
|
|
105
|
+
Completed task result
|
|
106
|
+
"""
|
|
107
|
+
import time
|
|
108
|
+
|
|
109
|
+
start_time = time.time()
|
|
110
|
+
|
|
111
|
+
while time.time() - start_time < max_wait:
|
|
112
|
+
payload = {
|
|
113
|
+
"jsonrpc": "2.0",
|
|
114
|
+
"method": "tasks/get",
|
|
115
|
+
"params": {"id": task_id},
|
|
116
|
+
"id": 2,
|
|
117
|
+
}
|
|
118
|
+
|
|
119
|
+
response = await self.client.post(
|
|
120
|
+
self.base_url,
|
|
121
|
+
json=payload,
|
|
122
|
+
headers={"Content-Type": "application/json"},
|
|
123
|
+
)
|
|
124
|
+
response.raise_for_status()
|
|
125
|
+
task = response.json()
|
|
126
|
+
|
|
127
|
+
result = task.get("result", {})
|
|
128
|
+
status = result.get("status", {})
|
|
129
|
+
state = status.get("state")
|
|
130
|
+
|
|
131
|
+
if state == "completed":
|
|
132
|
+
return task
|
|
133
|
+
elif state == "failed":
|
|
134
|
+
raise Exception(f"Task failed: {task}")
|
|
135
|
+
|
|
136
|
+
await asyncio.sleep(poll_interval)
|
|
137
|
+
|
|
138
|
+
raise TimeoutError(f"Task {task_id} did not complete within {max_wait}s")
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
def print_agent_card(card: dict[str, Any], console: Console):
    """Pretty print the agent card using Rich.

    Every value taken from ``card`` passes through ``rich.markup.escape``
    before it is embedded in a markup string. The card comes from the remote
    server, and text containing square brackets (e.g. ``[/x]``) would otherwise
    be parsed as Rich markup. That would restyle the output, or raise
    ``MarkupError`` for an unmatched closing tag.

    Args:
        card: Agent card dictionary, e.g. from ``A2AClient.get_agent_card``.
        console: Rich console to print to.
    """
    from rich.markup import escape

    def plain(value: Any) -> str:
        # str() first so a missing field still renders as "None".
        return escape(str(value))

    console.print()
    console.print("[bold]Agent Card[/bold]")
    console.rule()

    for label, key in (
        ("name", "name"),
        ("description", "description"),
        ("version", "version"),
        ("protocol version", "protocolVersion"),
    ):
        console.print(
            f" [repr.attrib_name]{label}[/repr.attrib_name]: {plain(card.get(key))}"
        )

    skills = card.get("skills", [])
    console.print(f"\n[bold cyan]Skills ({len(skills)}):[/bold cyan]")
    for skill in skills:
        console.print(f" • {plain(skill.get('id'))}: {plain(skill.get('name'))}")
        console.print(f" [dim]{plain(skill.get('description'))}[/dim]")
        examples = skill.get("examples", [])
        if examples:
            shown = ", ".join(str(example) for example in examples[:2])
            console.print(f" [dim]Examples: {plain(shown)}[/dim]")
    console.print()
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
def print_response(response: dict[str, Any], console: Console):
    """Pretty print the A2A response using Rich.

    Agent text parts from the task history are rendered as Markdown. Known
    artifacts are condensed into one dim summary line. Server-provided strings
    embedded in Rich markup (the error payload and the artifact summary, which
    contains the search query) are escaped with ``rich.markup.escape``, so
    brackets in them cannot be parsed as markup.

    Args:
        response: JSON-RPC response dict, e.g. from ``A2AClient.send_message``.
        console: Rich console to print to.
    """
    from rich.markup import escape

    if "error" in response:
        console.print(f"[red]Error: {escape(str(response['error']))}[/red]")
        return

    result = response.get("result", {})

    # Agent replies live in the task history; render each text part as Markdown.
    for msg in result.get("history", []):
        if msg.get("role") != "agent":
            continue
        for part in msg.get("parts", []):
            if part.get("kind") == "text":
                console.print()
                console.print("[bold green]Answer:[/bold green]")
                console.print(Markdown(part.get("text", "")))

    artifacts = result.get("artifacts", [])
    if artifacts:
        summary_lines = [
            line for line in map(_summarize_artifact, artifacts) if line is not None
        ]
        if summary_lines:
            console.print(f"[dim]{escape(' • '.join(summary_lines))}[/dim]")

    console.print()


def _summarize_artifact(artifact: dict[str, Any]) -> str | None:
    """Return a one-line description of a known artifact, or None to skip it."""
    name = artifact.get("name", "")
    parts = artifact.get("parts", [])
    if not parts:
        return None
    first = parts[0]

    if name == "search_results":
        data = first.get("data", {})
        query = data.get("query", "")
        results = data.get("results", [])
        return f"🔍 search: '{query}' ({len(results)} results)"
    if name == "document":
        if first.get("kind") == "text":
            return f"📄 document ({len(first.get('text', ''))} chars)"
        return None
    if name == "qa_result":
        skill = first.get("data", {}).get("skill", "unknown")
        return f"💬 {skill}"
    return None
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
async def run_interactive_client(url: str = "http://localhost:8000"):
    """Run the interactive A2A client.

    Fetches and prints the agent card. Then it reads questions from the
    terminal and sends them all under one conversation context id. It stops on
    ``quit``/``exit``/``q``, Ctrl-C, or end of input (Ctrl-D). The HTTP client
    is closed on every exit path.

    Args:
        url: Base URL of the A2A server
    """
    from rich.markup import escape

    console = Console()
    client = A2AClient(url)

    try:
        console.print("[bold]haiku.rag A2A interactive client[/bold]")
        console.print()

        console.print("[dim]Fetching agent card...[/dim]")
        try:
            card = await client.get_agent_card()
            print_agent_card(card, console)
        except Exception as e:
            console.print(f"[red]Error fetching agent card: {escape(str(e))}[/red]")
            return

        # A single context id for the session, so every question is sent as
        # part of the same conversation.
        context_id = str(uuid.uuid4())
        console.print(f"[dim]context id: {context_id}[/dim]")
        console.print("[dim]Type your questions (or 'quit' to exit)[/dim]\n")

        while True:
            try:
                question = Prompt.ask("[bold blue]Question[/bold blue]").strip()
                if not question:
                    continue

                if question.lower() in ("quit", "exit", "q"):
                    console.print("\n[dim]Goodbye![/dim]")
                    break

                response = await client.send_message(question, context_id=context_id)
                print_response(response, console)

            except (KeyboardInterrupt, EOFError):
                # EOFError means stdin is closed (Ctrl-D). It must end the loop;
                # the generic handler below would re-prompt and fail forever.
                console.print("\n\n[dim]Exiting...[/dim]")
                break
            except Exception as e:
                console.print(f"\n[red]Error: {escape(str(e))}[/red]\n")
    finally:
        await client.close()
|
haiku/rag/a2a/context.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
import uuid
|
|
2
|
+
|
|
3
|
+
from pydantic import TypeAdapter
|
|
4
|
+
from pydantic_ai.messages import ModelMessage
|
|
5
|
+
from pydantic_core import to_jsonable_python
|
|
6
|
+
|
|
7
|
+
try:
|
|
8
|
+
from fasta2a.schema import DataPart, Message # type: ignore
|
|
9
|
+
except ImportError as e:
|
|
10
|
+
raise ImportError(
|
|
11
|
+
"A2A support requires the 'a2a' extra. "
|
|
12
|
+
"Install with: uv pip install 'haiku.rag[a2a]'"
|
|
13
|
+
) from e
|
|
14
|
+
|
|
15
|
+
# Turns the JSON-compatible data produced by ``to_jsonable_python`` back into
# pydantic-ai ``ModelMessage`` objects (used when loading stored history).
ModelMessagesTypeAdapter: TypeAdapter[list[ModelMessage]] = TypeAdapter(
    list[ModelMessage]
)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def load_message_history(context: list[Message]) -> list[ModelMessage]:
    """Restore pydantic-ai message history from an A2A context.

    Scans the context messages in order for a ``data`` part whose metadata
    ``type`` is ``"conversation_state"`` and holds a non-empty
    ``message_history``. The first such history is validated back into
    pydantic-ai messages.

    Args:
        context: A2A context messages (may be empty or None).

    Returns:
        The deserialized ``ModelMessage`` list, or ``[]`` when no stored state
        is found.
    """
    for message in context or []:
        for part in message.get("parts", []):
            if part.get("kind") != "data":
                continue
            if part.get("metadata", {}).get("type") != "conversation_state":
                continue
            history = part.get("data", {}).get("message_history", [])
            if history:
                return ModelMessagesTypeAdapter.validate_python(history)
    return []
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def save_message_history(message_history: list[ModelMessage]) -> Message:
    """Pack a pydantic-ai message history into one A2A context message.

    The full history is converted to JSON-compatible data and placed in a single
    ``DataPart`` tagged ``metadata={"type": "conversation_state"}``.

    Args:
        message_history: Complete pydantic-ai message history to persist.

    Returns:
        An A2A ``Message`` with role ``"agent"`` carrying the serialized state
        under ``data["message_history"]``, with a freshly generated message id.
    """
    state_part = DataPart(
        kind="data",
        data={"message_history": to_jsonable_python(message_history)},
        metadata={"type": "conversation_state"},
    )
    return Message(
        role="agent",
        parts=[state_part],
        kind="message",
        message_id=str(uuid.uuid4()),
    )
|
haiku/rag/a2a/models.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
from pydantic import BaseModel, Field
|
|
2
|
+
|
|
3
|
+
from haiku.rag.client import HaikuRAG
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class SearchResult(BaseModel):
    """A retrieved text chunk carrying both a display title and a fetchable URI.

    This is the item type returned by the A2A agent's ``search_documents`` tool.
    """

    # Chunk text.
    content: str = Field(description="The document text content")
    # Relevance score for the chunk.
    score: float = Field(description="Relevance score (higher is more relevant)")
    # Optional human-friendly title of the source document.
    document_title: str | None = Field(
        default=None, description="Human-readable document title"
    )
    # URI to pass to get_full_document.
    document_uri: str = Field(description="Document URI/path for get_full_document")
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class AgentDependencies(BaseModel):
    """Runtime dependencies handed to the A2A conversational agent's tools."""

    # HaikuRAG is not a pydantic type, so arbitrary types must be permitted.
    model_config = {"arbitrary_types_allowed": True}

    # RAG client the tools use for search and document lookup.
    client: HaikuRAG
|
haiku/rag/a2a/prompts.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
# System prompt for the A2A conversational agent. The formatting rules here
# must agree with the worked examples, since the model imitates the examples.
A2A_SYSTEM_PROMPT = """You are Haiku.rag, an AI assistant that helps users find information from a document knowledge base.

IMPORTANT: You are NOT any person mentioned in the documents. You retrieve and present information about them.

Tools available:
- search_documents: Query for relevant text chunks (returns SearchResult objects with content, score, document_title, document_uri)
- get_full_document: Get complete document content by document_uri

Your behavior depends on the operation:

## For direct search requests:
When the user is explicitly searching (e.g., "search for X", "find documents about Y"):
- Use search_documents tool ONLY
- Format results as a numbered list using markdown formatting
- For each result show:
* First line: *Score in italic* | **title in bold** followed by the URI in parentheses (if the result has no title, show the URI in bold instead)
* Second line: The FULL chunk content (do not summarize or truncate)
- Present results in order of relevance
- Be concise - just present the search results, do not synthesize or add commentary

Example format:
Found 2 relevant results:

1. *Score: 0.95* | **Python Documentation** (/guides/python.md)
Python is a high-level, general-purpose programming language. Its design philosophy emphasizes code readability with the use of significant indentation.

2. *Score: 0.87* | **/guides/python-basics.md**
Python supports multiple programming paradigms, including structured, object-oriented and functional programming.

## For question-answering:
When the user asks a question (e.g., "What is Python?", "How does X work?"):
- For complex questions, use search_documents MULTIPLE TIMES with DIFFERENT queries to gather comprehensive information
- Example: For "What are the benefits and drawbacks of Python?", search separately for:
* "Python benefits advantages"
* "Python drawbacks disadvantages limitations"
- Synthesize information from all searches into a comprehensive answer
- Include "Sources:" section at the end listing sources used

Sources Format:
List each source with its title and URI (or just the URI if there is no title) and the relevant chunk content (NOT the score).
Format: "- **[title]** ([URI]): [chunk content]", or "- **[URI]**: [chunk content]" when there is no title

Example:
[Your synthesized answer here]

Sources:
- **Python Documentation** (/guides/python.md): Python is a high-level, general-purpose programming language. Its design philosophy emphasizes code readability.
- **/guides/python-basics.md**: Python supports multiple programming paradigms, including structured, object-oriented and functional programming.

Critical rules:
- ONLY answer based on information found via search_documents
- For comprehensive questions, perform MULTIPLE searches with different query angles
- NEVER fabricate or assume information
- If not found, say: "I cannot find information about this in the knowledge base."
- For follow-ups, understand context (pronouns like "he", "it") but always search for facts
- In Sources, include the actual chunk content from your search results, not summaries

Note: When using get_full_document, always use document_uri (not document_title).
"""
|
haiku/rag/a2a/skills.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
try:
|
|
2
|
+
from fasta2a.schema import Message, Skill # type: ignore
|
|
3
|
+
except ImportError as e:
|
|
4
|
+
raise ImportError(
|
|
5
|
+
"A2A support requires the 'a2a' extra. "
|
|
6
|
+
"Install with: uv pip install 'haiku.rag[a2a]'"
|
|
7
|
+
) from e
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def get_agent_skills() -> list[Skill]:
    """Build the list of skills advertised in the A2A agent card.

    Three skills are exposed: question answering, chunk search, and whole-document
    retrieval. Each accepts and produces ``application/json``.

    Returns:
        A new list of ``Skill`` entries in the order QA, search, retrieval.
    """
    # (id, name, description, tags, examples) for each advertised skill.
    catalogue = (
        (
            "document-qa",
            "Document Question Answering",
            "Answer questions based on a knowledge base of documents using semantic search and retrieval",
            ["question-answering", "search", "knowledge-base", "rag"],
            [
                "What does the documentation say about authentication?",
                "Find information about Python best practices",
                "Show me the full API documentation",
            ],
        ),
        (
            "document-search",
            "Document Search",
            "Search for relevant document chunks in the knowledge base using hybrid (semantic and BM25) search",
            ["search", "retrieval", "semantic-search"],
            [
                "Search for Python best practices",
                "Find documents about authentication",
                "Look for API documentation",
            ],
        ),
        (
            "document-retrieve",
            "Document Retrieval",
            "Retrieve the complete content of a specific document by its URI",
            ["retrieval", "fetch", "document"],
            [
                "Get the full content of document X",
                "Retrieve document by URI",
                "Show me the complete document",
            ],
        ),
    )

    skills: list[Skill] = []
    for skill_id, title, summary, tags, examples in catalogue:
        skills.append(
            Skill(
                id=skill_id,
                name=title,
                description=summary,
                tags=tags,
                # Fresh lists per skill so entries never share mutable state.
                input_modes=["application/json"],
                output_modes=["application/json"],
                examples=examples,
            )
        )
    return skills
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def extract_question_from_task(task_history: list[Message]) -> str | None:
    """Return the first non-blank text the user sent in a task's history.

    Only ``text`` parts of messages with role ``"user"`` are considered. Text is
    stripped of surrounding whitespace, and whitespace-only text is skipped.

    Args:
        task_history: Task history messages

    Returns:
        The stripped question text, or None if no user text part has content.
    """
    user_texts = (
        part.get("text", "").strip()
        for message in task_history
        if message.get("role") == "user"
        for part in message.get("parts", [])
        if part.get("kind") == "text"
    )
    return next((text for text in user_texts if text), None)
|