haiku.rag 0.11.4__py3-none-any.whl → 0.12.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of haiku.rag might be problematic. Click here for more details.

@@ -0,0 +1,176 @@
1
+ import logging
2
+ from contextlib import asynccontextmanager
3
+ from pathlib import Path
4
+
5
+ import logfire
6
+ from pydantic_ai import Agent, RunContext
7
+
8
+ from haiku.rag.config import Config
9
+ from haiku.rag.graph.common import get_model
10
+
11
+ from .context import load_message_history, save_message_history
12
+ from .models import AgentDependencies, SearchResult
13
+ from .prompts import A2A_SYSTEM_PROMPT
14
+ from .skills import extract_question_from_task, get_agent_skills
15
+ from .storage import LRUMemoryStorage
16
+ from .worker import ConversationalWorker
17
+
18
+ try:
19
+ from fasta2a import FastA2A # type: ignore
20
+ from fasta2a.broker import InMemoryBroker # type: ignore
21
+ from fasta2a.storage import InMemoryStorage # type: ignore
22
+ except ImportError as e:
23
+ raise ImportError(
24
+ "A2A support requires the 'a2a' extra. "
25
+ "Install with: uv pip install 'haiku.rag[a2a]'"
26
+ ) from e
27
+
28
logger = logging.getLogger(__name__)

# Observability: Logfire exports spans only when a token is configured
# ("if-token-present"), and pydantic-ai agent runs are instrumented.
logfire.configure(send_to_logfire="if-token-present", service_name="a2a")
logfire.instrument_pydantic_ai()

# Public API of this package.
__all__ = [
    "create_a2a_app",
    "load_message_history",
    "save_message_history",
    "extract_question_from_task",
    "get_agent_skills",
    "LRUMemoryStorage",
]
41
+
42
+
43
def _install_agent_card_security(
    app,
    security_schemes: dict | None,
    security: list[dict[str, list[str]]] | None,
) -> None:
    """Make ``app`` serve an agent card that includes the given security settings.

    FastA2A registers its agent-card route with the bound method
    ``app._agent_card_endpoint`` while the app is being constructed, so merely
    reassigning that attribute afterwards does not change what the router
    dispatches to. The replacement endpoint is therefore also swapped into
    ``app.router.routes``.

    Args:
        app: The FastA2A application to patch
        security_schemes: Security scheme definitions to advertise, if any
        security: Security requirements to advertise, if any
    """
    from fasta2a.schema import AgentCapabilities, AgentCard, agent_card_ta
    from starlette.responses import Response
    from starlette.routing import Route

    async def _agent_card_endpoint_with_security(request):
        # Build and cache the serialized card on first request, mirroring
        # FastA2A's own endpoint but adding the security fields.
        if app._agent_card_json_schema is None:
            agent_card = AgentCard(
                name=app.name,
                description=app.description or "An AI agent exposed as an A2A agent.",
                url=app.url,
                version=app.version,
                protocol_version="0.3.0",
                skills=app.skills,
                default_input_modes=app.default_input_modes,
                default_output_modes=app.default_output_modes,
                capabilities=AgentCapabilities(
                    streaming=False,
                    push_notifications=False,
                    state_transition_history=False,
                ),
            )
            if app.provider is not None:
                agent_card["provider"] = app.provider
            if security_schemes:
                agent_card["security_schemes"] = security_schemes
            if security:
                agent_card["security"] = security
            app._agent_card_json_schema = agent_card_ta.dump_json(
                agent_card, by_alias=True
            )
        return Response(
            content=app._agent_card_json_schema, media_type="application/json"
        )

    original_endpoint = app._agent_card_endpoint
    app._agent_card_endpoint = _agent_card_endpoint_with_security

    # Re-point the already-registered route; otherwise the old endpoint
    # (without security information) keeps being served.
    replaced = False
    for index, route in enumerate(app.router.routes):
        if isinstance(route, Route) and route.endpoint == original_endpoint:
            app.router.routes[index] = Route(
                route.path,
                _agent_card_endpoint_with_security,
                methods=route.methods,
                name=route.name,
            )
            replaced = True

    if not replaced:
        logger.warning(
            "Agent card route not found; security settings will not be advertised"
        )


def create_a2a_app(
    db_path: Path,
    security_schemes: dict | None = None,
    security: list[dict[str, list[str]]] | None = None,
):
    """Create an A2A app for the conversational QA agent.

    Wires LRU-bounded in-memory task storage, an in-memory broker, a
    pydantic-ai agent with two retrieval tools and a ConversationalWorker into
    a FastA2A ASGI application whose lifespan runs the task manager and worker.

    Args:
        db_path: Path to the LanceDB database
        security_schemes: Optional security scheme definitions for the AgentCard
        security: Optional security requirements for the AgentCard

    Returns:
        A FastA2A ASGI application
    """
    base_storage = InMemoryStorage()
    storage = LRUMemoryStorage(
        storage=base_storage, max_contexts=Config.A2A_MAX_CONTEXTS
    )
    broker = InMemoryBroker()

    # Create the agent with native search tool
    model = get_model(Config.QA_PROVIDER, Config.QA_MODEL)
    agent = Agent(
        model=model,
        deps_type=AgentDependencies,
        system_prompt=A2A_SYSTEM_PROMPT,
        retries=3,
    )

    # NOTE: the tool docstrings below are exposed to the model as tool
    # descriptions, so they are part of runtime behavior.
    @agent.tool
    async def search_documents(
        ctx: RunContext[AgentDependencies],
        query: str,
        limit: int = 3,
    ) -> list[SearchResult]:
        """Search the knowledge base for relevant documents.

        Returns chunks of text with their relevance scores and document URIs.
        Use get_full_document if you need to see the complete document content.
        """
        search_results = await ctx.deps.client.search(query, limit=limit)
        expanded_results = await ctx.deps.client.expand_context(search_results)

        return [
            SearchResult(
                content=chunk.content,
                score=score,
                document_title=chunk.document_title,
                document_uri=(chunk.document_uri or ""),
            )
            for chunk, score in expanded_results
        ]

    @agent.tool
    async def get_full_document(
        ctx: RunContext[AgentDependencies],
        document_uri: str,
    ) -> str:
        """Retrieve the complete content of a document by its URI.

        Use this when you need more context than what's in a search result chunk.
        The document_uri comes from search_documents results.
        """
        document = await ctx.deps.client.get_document_by_uri(document_uri)
        if document is None:
            return f"Document not found: {document_uri}"

        return document.content

    worker = ConversationalWorker(
        storage=storage,
        broker=broker,
        db_path=db_path,
        agent=agent,  # type: ignore
    )

    # Create FastA2A app with custom worker lifecycle
    @asynccontextmanager
    async def lifespan(app):
        logger.info("Started A2A server (max contexts: %s)", Config.A2A_MAX_CONTEXTS)
        async with app.task_manager:
            async with worker.run():
                yield

    app = FastA2A(
        storage=storage,
        broker=broker,
        name="haiku-rag",
        description="Conversational question answering agent powered by haiku.rag RAG system",
        skills=get_agent_skills(),
        lifespan=lifespan,
    )

    # Add security configuration if provided
    if security_schemes or security:
        _install_agent_card_security(app, security_schemes, security)

    return app
@@ -0,0 +1,268 @@
1
+ import asyncio
2
+ import uuid
3
+ from typing import Any
4
+
5
+ import httpx
6
+ from rich.console import Console
7
+ from rich.markdown import Markdown
8
+ from rich.prompt import Prompt
9
+
10
+ try:
11
+ from fasta2a.client import A2AClient as FastA2AClient
12
+ from fasta2a.schema import Message, TextPart
13
+ except ImportError as e:
14
+ raise ImportError(
15
+ "A2A support requires the 'a2a' extra. "
16
+ "Install with: uv pip install 'haiku.rag[a2a]'"
17
+ ) from e
18
+
19
+
20
class A2AClient:
    """Interactive A2A protocol client.

    May also be used as an async context manager
    (``async with A2AClient(url) as client: ...``) so the underlying HTTP
    client is always closed.
    """

    def __init__(self, base_url: str = "http://localhost:8000"):
        """Initialize A2A client.

        Args:
            base_url: Base URL of the A2A server
        """
        self.base_url = base_url.rstrip("/")
        http_client = httpx.AsyncClient(timeout=60.0)
        # Use the same normalized URL for JSON-RPC calls and card fetches.
        self._client = FastA2AClient(base_url=self.base_url, http_client=http_client)

    async def __aenter__(self) -> "A2AClient":
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        await self.close()

    async def close(self):
        """Close the HTTP client."""
        await self._client.http_client.aclose()

    async def get_agent_card(self) -> dict[str, Any]:
        """Fetch the agent card from the A2A server.

        Returns:
            Agent card dictionary with agent capabilities and metadata

        Raises:
            httpx.HTTPStatusError: If the server answers with an error status
        """
        response = await self._client.http_client.get(
            f"{self.base_url}/.well-known/agent-card.json"
        )
        response.raise_for_status()
        return response.json()

    async def send_message(
        self,
        text: str,
        context_id: str | None = None,
        skill_id: str | None = None,
    ) -> dict[str, Any]:
        """Send a message to the A2A agent and wait for completion.

        Args:
            text: Message text to send
            context_id: Optional conversation context ID (creates new if None)
            skill_id: Optional skill ID to use (defaults to document-qa)

        Returns:
            Completed task with response messages and artifacts, a direct
            message reply as ``{"result": message}``, or ``{"error": ...}``
        """
        if context_id is None:
            context_id = str(uuid.uuid4())

        message = Message(
            kind="message",
            role="user",
            message_id=str(uuid.uuid4()),
            parts=[TextPart(kind="text", text=text)],
            # The A2A protocol ties messages to a conversation through the
            # message's own contextId; without it the server (e.g. fasta2a's
            # TaskManager) starts a fresh context and history is lost.
            context_id=context_id,
        )

        # Kept for backward compatibility with servers reading metadata.
        metadata: dict[str, Any] = {"contextId": context_id}
        if skill_id:
            metadata["skillId"] = skill_id

        response = await self._client.send_message(message, metadata=metadata)

        if "error" in response:
            return {"error": response["error"]}

        result = response.get("result")
        if not result:
            return {"result": result}

        # Result can be either Task or Message - check if it's a Task with an id
        if result.get("kind") == "task":
            task_id = result.get("id")
            if task_id:
                # Poll for task completion
                return await self.wait_for_task(task_id)

        # Return the message directly
        return {"result": result}

    async def wait_for_task(
        self, task_id: str, max_wait: int = 120, poll_interval: float = 0.5
    ) -> dict[str, Any]:
        """Poll for task completion.

        Args:
            task_id: Task ID to poll for
            max_wait: Maximum time to wait in seconds
            poll_interval: Interval between polls in seconds

        Returns:
            ``{"result": task}`` once completed, or ``{"error": ...}`` if the
            server reports a JSON-RPC error

        Raises:
            RuntimeError: If the response has no task, or the task ends in a
                failed, canceled or rejected state
            TimeoutError: If the task does not complete within ``max_wait``
        """
        import time

        # Monotonic clock: immune to wall-clock adjustments while polling.
        deadline = time.monotonic() + max_wait

        while time.monotonic() < deadline:
            task_response = await self._client.get_task(task_id)

            if "error" in task_response:
                return {"error": task_response["error"]}

            task = task_response.get("result")
            if not task:
                raise RuntimeError("No task in response")

            state = task.get("status", {}).get("state")

            if state == "completed":
                return {"result": task}
            if state == "failed":
                raise RuntimeError(f"Task failed: {task}")
            if state in ("canceled", "rejected"):
                # Terminal as well; polling further would only hit the timeout.
                raise RuntimeError(f"Task {state}: {task}")

            await asyncio.sleep(poll_interval)

        raise TimeoutError(f"Task {task_id} did not complete within {max_wait}s")
135
+
136
+
137
def print_agent_card(card: dict[str, Any], console: Console):
    """Pretty print the agent card using Rich.

    Card values come from the remote server, so they are escaped before being
    embedded in Rich markup. Otherwise text such as ``[beta]`` would be
    interpreted as a style tag, and a stray closing tag would raise an error.

    Args:
        card: Agent card dictionary as returned by the server
        console: Rich console to print to
    """
    from rich.markup import escape

    def _attribute(label: str, value: Any) -> None:
        console.print(
            f" [repr.attrib_name]{label}[/repr.attrib_name]: {escape(str(value))}"
        )

    console.print()
    console.print("[bold]Agent Card[/bold]")
    console.rule()

    _attribute("name", card.get("name"))
    _attribute("description", card.get("description"))
    _attribute("version", card.get("version"))
    _attribute("protocol version", card.get("protocolVersion"))

    skills = card.get("skills", [])
    console.print(f"\n[bold cyan]Skills ({len(skills)}):[/bold cyan]")
    for skill in skills:
        skill_id = escape(str(skill.get("id")))
        skill_name = escape(str(skill.get("name")))
        console.print(f" • {skill_id}: {skill_name}")
        console.print(f" [dim]{escape(str(skill.get('description')))}[/dim]")
        examples = skill.get("examples", [])
        if examples:
            # Only the first two examples, to keep the card compact.
            shown = escape(", ".join(examples[:2]))
            console.print(f" [dim]Examples: {shown}[/dim]")
    console.print()
163
+
164
+
165
def print_response(response: dict[str, Any], console: Console):
    """Pretty print the A2A response using Rich.

    Handles the three shapes produced by ``A2AClient.send_message``:
    ``{"error": ...}``, ``{"result": task}`` (agent messages live in the
    task's ``history``), and ``{"result": message}`` (a direct message reply).

    Args:
        response: Response dictionary from ``A2AClient.send_message``
        console: Rich console to print to
    """
    from rich.markup import escape

    if "error" in response:
        console.print(f"[red]Error: {escape(str(response['error']))}[/red]")
        return

    # send_message may return {"result": None}; treat it as an empty result.
    result = response.get("result") or {}

    # A completed task keeps its messages in "history"; a direct message
    # reply is itself the only message to render.
    if result.get("kind") == "message":
        messages = [result]
    else:
        messages = result.get("history") or []
    artifacts = result.get("artifacts") or []

    # Print agent messages with markdown rendering
    for msg in messages:
        if msg.get("role") != "agent":
            continue
        for part in msg.get("parts", []):
            if part.get("kind") == "text":
                console.print()
                console.print("[bold green]Answer:[/bold green]")
                console.print(Markdown(part.get("text", "")))

    # Print artifacts summary with details
    if artifacts:
        console.rule("[dim]Artifacts generated[/dim]")
        summary_lines = []

        for artifact in artifacts:
            name = artifact.get("name", "")
            parts = artifact.get("parts", [])
            if not parts:
                continue

            if name == "search_results":
                data = parts[0].get("data", {})
                query = data.get("query", "")
                results = data.get("results", [])
                summary_lines.append(f"🔍 search: '{query}' ({len(results)} results)")

            elif name == "document":
                part = parts[0]
                if part.get("kind") == "text":
                    length = len(part.get("text", ""))
                    summary_lines.append(f"📄 document ({length} chars)")

            elif name == "qa_result":
                data = parts[0].get("data", {})
                skill = data.get("skill", "unknown")
                summary_lines.append(f"💬 {skill}")

        if summary_lines:
            # Queries are user text and may contain brackets; escape them.
            console.print(f"[dim]{escape(' • '.join(summary_lines))}[/dim]")

    console.print()
219
+
220
+
221
async def run_interactive_client(url: str = "http://localhost:8000"):
    """Run the interactive A2A client.

    Fetches and prints the agent card, then reads questions from the terminal
    and sends them within a single conversation context. It exits on
    'quit'/'exit'/'q', Ctrl-C or end of input (Ctrl-D / closed stdin).

    Args:
        url: Base URL of the A2A server
    """
    from rich.markup import escape

    console = Console()
    client = A2AClient(url)

    console.print("[bold]haiku.rag A2A interactive client[/bold]")
    console.print()

    try:
        # Fetch and display agent card
        console.print("[dim]Fetching agent card...[/dim]")
        try:
            card = await client.get_agent_card()
            print_agent_card(card, console)
        except Exception as e:
            console.print(f"[red]Error fetching agent card: {escape(str(e))}[/red]")
            return

        # Create a conversation context
        context_id = str(uuid.uuid4())
        console.print(f"[dim]context id: {context_id}[/dim]")
        console.print("[dim]Type your questions (or 'quit' to exit)[/dim]\n")

        while True:
            try:
                question = Prompt.ask("[bold blue]Question[/bold blue]").strip()
                if not question:
                    continue

                if question.lower() in ("quit", "exit", "q"):
                    console.print("\n[dim]Goodbye![/dim]")
                    break

                response = await client.send_message(question, context_id=context_id)
                print_response(response, console)

            except (KeyboardInterrupt, EOFError):
                # EOFError must stop the loop too: once stdin is exhausted,
                # every further prompt would fail again, looping forever.
                console.print("\n\n[dim]Exiting...[/dim]")
                break
            except Exception as e:
                console.print(f"\n[red]Error: {escape(str(e))}[/red]\n")
    finally:
        await client.close()
@@ -0,0 +1,68 @@
1
+ import uuid
2
+
3
+ from pydantic import TypeAdapter
4
+ from pydantic_ai.messages import ModelMessage
5
+ from pydantic_core import to_jsonable_python
6
+
7
+ try:
8
+ from fasta2a.schema import DataPart, Message # type: ignore
9
+ except ImportError as e:
10
+ raise ImportError(
11
+ "A2A support requires the 'a2a' extra. "
12
+ "Install with: uv pip install 'haiku.rag[a2a]'"
13
+ ) from e
14
+
15
# Validates/serializes a complete pydantic-ai conversation transcript.
ModelMessagesTypeAdapter: TypeAdapter[list[ModelMessage]] = TypeAdapter(
    list[ModelMessage]
)
16
+
17
+
18
def load_message_history(context: list[Message]) -> list[ModelMessage]:
    """Rebuild the pydantic-ai message history stored in an A2A context.

    Looks for data parts tagged with ``metadata.type == "conversation_state"``
    and deserializes the first non-empty ``message_history`` found.

    Args:
        context: A2A context messages

    Returns:
        List of pydantic-ai ModelMessage objects (empty if no state is stored)
    """
    if not context:
        return []

    state_parts = (
        part
        for message in context
        for part in message.get("parts", [])
        if part.get("kind") == "data"
        and part.get("metadata", {}).get("type") == "conversation_state"
    )

    for state_part in state_parts:
        history = state_part.get("data", {}).get("message_history", [])
        if history:
            return ModelMessagesTypeAdapter.validate_python(history)

    return []
45
+
46
+
47
def save_message_history(message_history: list[ModelMessage]) -> Message:
    """Pack pydantic-ai message history into an A2A message for the context.

    The history is converted to JSON-compatible data and stored in a single
    data part tagged ``conversation_state``, so ``load_message_history`` can
    find it again.

    Args:
        message_history: Full pydantic-ai message history

    Returns:
        A2A Message (role "agent") carrying the serialized state
    """
    state_part = DataPart(
        kind="data",
        data={"message_history": to_jsonable_python(message_history)},
        metadata={"type": "conversation_state"},
    )
    return Message(
        role="agent",
        parts=[state_part],
        kind="message",
        message_id=str(uuid.uuid4()),
    )
@@ -0,0 +1,21 @@
1
+ from pydantic import BaseModel, Field
2
+
3
+ from haiku.rag.client import HaikuRAG
4
+
5
+
6
class SearchResult(BaseModel):
    """Search result with both title and URI for A2A agent."""

    # NOTE: pydantic uses the class docstring and field descriptions in the
    # generated JSON schema, so their wording is part of runtime behavior.
    content: str = Field(description="The document text content")
    score: float = Field(description="Relevance score (higher is more relevant)")
    # Optional: documents may not carry a title; the URI is always present.
    document_title: str | None = Field(
        default=None, description="Human-readable document title"
    )
    document_uri: str = Field(description="Document URI/path for get_full_document")
15
+
16
+
17
class AgentDependencies(BaseModel):
    """Dependencies for the A2A conversational agent."""

    # HaikuRAG is a plain class, not a pydantic type, so pydantic must be told
    # to accept arbitrary types for this field.
    model_config = dict(arbitrary_types_allowed=True)

    # RAG client used by the agent tools for search and document lookup.
    client: HaikuRAG
@@ -0,0 +1,59 @@
1
# System prompt for the A2A conversational agent. The formats described in the
# instructions must match the worked examples, or the model receives
# contradictory guidance about how to cite sources.
A2A_SYSTEM_PROMPT = """You are Haiku.rag, an AI assistant that helps users find information from a document knowledge base.

IMPORTANT: You are NOT any person mentioned in the documents. You retrieve and present information about them.

Tools available:
- search_documents: Query for relevant text chunks (returns SearchResult objects with content, score, document_title, document_uri)
- get_full_document: Get complete document content by document_uri

Your behavior depends on the operation:

## For direct search requests:
When the user is explicitly searching (e.g., "search for X", "find documents about Y"):
- Use search_documents tool ONLY
- Format results as a numbered list using markdown formatting
- For each result show:
  * First line: *Score in italic* | **source in bold** (the title followed by the URI in parentheses when a title is available, otherwise just the URI)
  * Second line: The FULL chunk content (do not summarize or truncate)
- Present results in order of relevance
- Be concise - just present the search results, do not synthesize or add commentary

Example format:
Found 2 relevant results:

1. *Score: 0.95* | **Python Documentation** (/guides/python.md)
   Python is a high-level, general-purpose programming language. Its design philosophy emphasizes code readability with the use of significant indentation.

2. *Score: 0.87* | **/guides/python-basics.md**
   Python supports multiple programming paradigms, including structured, object-oriented and functional programming.

## For question-answering:
When the user asks a question (e.g., "What is Python?", "How does X work?"):
- For complex questions, use search_documents MULTIPLE TIMES with DIFFERENT queries to gather comprehensive information
- Example: For "What are the benefits and drawbacks of Python?", search separately for:
  * "Python benefits advantages"
  * "Python drawbacks disadvantages limitations"
- Synthesize information from all searches into a comprehensive answer
- Include "Sources:" section at the end listing sources used

Sources Format:
List each source with its title and URI (or just the URI when there is no title) and the relevant chunk content (NOT the score).
Format: "- **[title]** ([URI]): [chunk content]" or, without a title, "- **[URI]**: [chunk content]"

Example:
[Your synthesized answer here]

Sources:
- **Python Documentation** (/guides/python.md): Python is a high-level, general-purpose programming language. Its design philosophy emphasizes code readability.
- **/guides/python-basics.md**: Python supports multiple programming paradigms, including structured, object-oriented and functional programming.

Critical rules:
- ONLY answer based on information found via search_documents
- For comprehensive questions, perform MULTIPLE searches with different query angles
- NEVER fabricate or assume information
- If not found, say: "I cannot find information about this in the knowledge base."
- For follow-ups, understand context (pronouns like "he", "it") but always search for facts
- In Sources, include the actual chunk content from your search results, not summaries

Note: When using get_full_document, always use document_uri (not document_title).
"""
@@ -0,0 +1,75 @@
1
+ try:
2
+ from fasta2a.schema import Message, Skill # type: ignore
3
+ except ImportError as e:
4
+ raise ImportError(
5
+ "A2A support requires the 'a2a' extra. "
6
+ "Install with: uv pip install 'haiku.rag[a2a]'"
7
+ ) from e
8
+
9
+
10
def get_agent_skills() -> list[Skill]:
    """List the skills advertised in the haiku.rag agent card.

    Every skill accepts and produces ``application/json``; they differ in id,
    name, description, tags and example prompts.

    Returns:
        A freshly built list of Skill entries (QA, search, retrieval)
    """
    # (id, name, description, tags, examples) — built per call so callers
    # never share mutable lists.
    skill_table = (
        (
            "document-qa",
            "Document Question Answering",
            "Answer questions based on a knowledge base of documents using semantic search and retrieval",
            ["question-answering", "search", "knowledge-base", "rag"],
            [
                "What does the documentation say about authentication?",
                "Find information about Python best practices",
                "Show me the full API documentation",
            ],
        ),
        (
            "document-search",
            "Document Search",
            "Search for relevant document chunks in the knowledge base using hybrid (semantic and BM25) search",
            ["search", "retrieval", "semantic-search"],
            [
                "Search for Python best practices",
                "Find documents about authentication",
                "Look for API documentation",
            ],
        ),
        (
            "document-retrieve",
            "Document Retrieval",
            "Retrieve the complete content of a specific document by its URI",
            ["retrieval", "fetch", "document"],
            [
                "Get the full content of document X",
                "Retrieve document by URI",
                "Show me the complete document",
            ],
        ),
    )

    return [
        Skill(
            id=skill_id,
            name=skill_name,
            description=skill_description,
            tags=skill_tags,
            input_modes=["application/json"],
            output_modes=["application/json"],
            examples=skill_examples,
        )
        for skill_id, skill_name, skill_description, skill_tags, skill_examples in skill_table
    ]
57
+
58
+
59
def extract_question_from_task(task_history: list[Message]) -> str | None:
    """Return the first non-blank user text found in a task's history.

    Scans messages in order, considering only ``role == "user"`` messages and
    their ``kind == "text"`` parts; surrounding whitespace is stripped.

    Args:
        task_history: Task history messages

    Returns:
        The question text if found, None otherwise
    """
    candidate_texts = (
        part.get("text", "").strip()
        for message in task_history
        if message.get("role") == "user"
        for part in message.get("parts", [])
        if part.get("kind") == "text"
    )
    return next((text for text in candidate_texts if text), None)