haiku.rag 0.11.4__py3-none-any.whl → 0.12.1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.


@@ -0,0 +1,71 @@
+ import logging
+ from collections import OrderedDict
+ 
+ try:
+     from fasta2a.schema import Artifact, Message, TaskState  # type: ignore
+     from fasta2a.storage import InMemoryStorage, Storage  # type: ignore
+ except ImportError as e:
+     raise ImportError(
+         "A2A support requires the 'a2a' extra. "
+         "Install with: uv pip install 'haiku.rag[a2a]'"
+     ) from e
+ 
+ logger = logging.getLogger(__name__)
+ 
+ 
+ class LRUMemoryStorage(Storage[list[Message]]):  # type: ignore
+     """Storage wrapper with LRU eviction for contexts.
+ 
+     Enforces a maximum context limit using LRU (Least Recently Used) eviction.
+     """
+ 
+     def __init__(self, storage: InMemoryStorage, max_contexts: int):
+         self.storage = storage
+         self.max_contexts = max_contexts
+         # Track context access order (LRU cache)
+         self.context_order: OrderedDict[str, None] = OrderedDict()
+ 
+     async def load_context(self, context_id: str) -> list[Message] | None:
+         """Load context and update access order."""
+         result = await self.storage.load_context(context_id)
+         if result is not None:
+             # Move to end (most recently used)
+             self.context_order.pop(context_id, None)
+             self.context_order[context_id] = None
+         return result
+ 
+     async def update_context(self, context_id: str, context: list[Message]) -> None:
+         """Update context and enforce LRU limit."""
+         await self.storage.update_context(context_id, context)
+         # Move to end (most recently used)
+         self.context_order.pop(context_id, None)
+         self.context_order[context_id] = None
+ 
+         # Enforce max contexts limit (LRU eviction)
+         while len(self.context_order) > self.max_contexts:
+             # Remove oldest (first item in OrderedDict)
+             oldest_context_id = next(iter(self.context_order))
+             self.context_order.pop(oldest_context_id)
+             logger.debug(
+                 f"Evicted context {oldest_context_id} (LRU, limit={self.max_contexts})"
+             )
+ 
+     async def load_task(self, task_id: str, history_length: int | None = None):
+         """Delegate to underlying storage."""
+         return await self.storage.load_task(task_id, history_length)
+ 
+     async def update_task(
+         self,
+         task_id: str,
+         state: TaskState,
+         new_artifacts: list[Artifact] | None = None,
+         new_messages: list[Message] | None = None,
+     ):
+         """Delegate to underlying storage."""
+         return await self.storage.update_task(
+             task_id, state, new_artifacts, new_messages
+         )
+ 
+     async def submit_task(self, context_id: str, message: Message):
+         """Delegate to underlying storage."""
+         return await self.storage.submit_task(context_id, message)
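
A minimal usage sketch for the new wrapper (the no-argument InMemoryStorage construction and the max_contexts value are illustrative assumptions, not taken from this diff):

# Hypothetical wiring: cap retained conversation contexts at 100.
from fasta2a.storage import InMemoryStorage

storage = LRUMemoryStorage(storage=InMemoryStorage(), max_contexts=100)

# Each load_context/update_context moves the context to the most-recently-used
# end; writing a 101st context evicts the oldest id. Note that eviction as
# written only drops the id from context_order - the entry in the wrapped
# storage itself is not deleted.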
@@ -0,0 +1,320 @@
+ import json
+ import logging
+ import uuid
+ from pathlib import Path
+ 
+ from pydantic_ai import Agent
+ 
+ from haiku.rag.a2a.context import load_message_history, save_message_history
+ from haiku.rag.a2a.models import AgentDependencies
+ from haiku.rag.a2a.skills import extract_question_from_task
+ from haiku.rag.client import HaikuRAG
+ 
+ try:
+     from fasta2a import Worker  # type: ignore
+     from fasta2a.schema import (  # type: ignore
+         Artifact,
+         Message,
+         TaskIdParams,
+         TaskSendParams,
+         TextPart,
+     )
+ except ImportError as e:
+     raise ImportError(
+         "A2A support requires the 'a2a' extra. "
+         "Install with: uv pip install 'haiku.rag[a2a]'"
+     ) from e
+ 
+ logger = logging.getLogger(__name__)
+ 
+ 
+ class ConversationalWorker(Worker[list[Message]]):
+     """Worker that handles conversational QA tasks."""
+ 
+     def __init__(
+         self,
+         storage,
+         broker,
+         db_path: Path,
+         agent: "Agent[AgentDependencies, str]",
+     ):
+         super().__init__(storage=storage, broker=broker)
+         self.db_path = db_path
+         self.agent = agent
+ 
+     async def run_task(self, params: TaskSendParams) -> None:
+         task = await self.storage.load_task(params["id"])
+         if task is None:
+             raise ValueError(f"Task {params['id']} not found")
+ 
+         if task["status"]["state"] != "submitted":
+             raise ValueError(
+                 f"Task {params['id']} already processed: {task['status']['state']}"
+             )
+ 
+         await self.storage.update_task(task["id"], state="working")
+ 
+         task_history = task.get("history", [])
+         question = extract_question_from_task(task_history)
+ 
+         if not question:
+             await self.storage.update_task(task["id"], state="failed")
+             return
+ 
+         try:
+             async with HaikuRAG(self.db_path) as client:
+                 context = await self.storage.load_context(task["context_id"]) or []
+                 message_history = load_message_history(context)
+ 
+                 deps = AgentDependencies(client=client)
+ 
+                 result = await self.agent.run(
+                     question, deps=deps, message_history=message_history
+                 )
+ 
+                 # Detect which skill was used
+                 skill_type = self._detect_skill(result)
+ 
+                 # Build messages based on skill type
+                 response_messages = self._build_response_messages(result, skill_type)
+ 
+                 # Update context with complete conversation state
+                 updated_history = message_history + result.new_messages()
+                 state_message = save_message_history(updated_history)
+ 
+                 await self.storage.update_context(task["context_id"], [state_message])
+ 
+                 artifacts = self.build_artifacts(result, skill_type, question)
+ 
+                 await self.storage.update_task(
+                     task["id"],
+                     state="completed",
+                     new_messages=response_messages,
+                     new_artifacts=artifacts,
+                 )
+         except Exception as e:
+             logger.error(
+                 "Task execution failed: task_id=%s, question=%s, error=%s",
+                 task["id"],
+                 question,
+                 str(e),
+                 exc_info=True,
+             )
+             await self.storage.update_task(task["id"], state="failed")
+             raise
+ 
+     async def cancel_task(self, params: TaskIdParams) -> None:
+         """Cancel a task - not implemented for this worker."""
+         pass
+ 
+     def build_message_history(self, history: list[Message]) -> list[Message]:
+         """Required by Worker interface but unused - history stored in context."""
+         return history
+ 
+     def _detect_skill(self, result) -> str:
+         """Detect which skill was used based on tool calls and response pattern.
+ 
+         Returns:
+             "search", "retrieve", or "qa"
+         """
+         from pydantic_ai.messages import ModelResponse, ToolCallPart
+ 
+         tool_calls = []
+         for msg in result.new_messages():
+             if isinstance(msg, ModelResponse):
+                 for part in msg.parts:
+                     if isinstance(part, ToolCallPart):
+                         tool_calls.append(part.tool_name)
+ 
+         # Check if output looks like formatted search results
+         output_str = str(result.output).strip()
+         # Check for either format: "Found N relevant results" or "**Search results for"
+         is_search_format = (
+             output_str.startswith("Found ") and "relevant results" in output_str[:100]
+         ) or output_str.startswith("**Search results for")
+ 
+         skill_type = "qa"
+         # If output is in search format and only search tools were used, it's a search
+         if is_search_format and all(tc == "search_documents" for tc in tool_calls):
+             skill_type = "search"
+         elif "get_full_document" in tool_calls and len(tool_calls) == 1:
+             skill_type = "retrieve"
+ 
+         return skill_type
+ 
+     def _build_response_messages(self, result, skill_type: str) -> list[Message]:
+         """Build response messages based on skill type.
+ 
+         All skills return a single text message with the LLM's response.
+         Structured data is provided via artifacts for search/retrieve.
+         """
+         if skill_type == "search":
+             # Return LLM's formatted response
+             return [
+                 Message(
+                     role="agent",
+                     parts=[TextPart(kind="text", text=str(result.output))],
+                     kind="message",
+                     message_id=str(uuid.uuid4()),
+                 )
+             ]
+         elif skill_type == "retrieve":
+             # Extract document content
+             from pydantic_ai.messages import ModelRequest, ToolReturnPart
+ 
+             document_content = ""
+             for msg in result.new_messages():
+                 if isinstance(msg, ModelRequest):
+                     for part in msg.parts:
+                         if (
+                             isinstance(part, ToolReturnPart)
+                             and part.tool_name == "get_full_document"
+                         ):
+                             document_content = part.content
+                             break
+ 
+             return [
+                 Message(
+                     role="agent",
+                     parts=[TextPart(kind="text", text=document_content)],
+                     kind="message",
+                     message_id=str(uuid.uuid4()),
+                 )
+             ]
+         else:
+             # Conversational Q&A - use agent's answer
+             return [
+                 Message(
+                     role="agent",
+                     parts=[TextPart(kind="text", text=str(result.output))],
+                     kind="message",
+                     message_id=str(uuid.uuid4()),
+                 )
+             ]
+ 
+     def build_artifacts(
+         self, result, skill_type: str | None = None, question: str | None = None
+     ) -> list[Artifact]:
+         """Build artifacts from agent result based on tool calls.
+ 
+         Creates artifacts for:
+         - Each tool call (search_documents, get_full_document)
+         - Q&A operations: additional artifact with question and answer (only if tools were used)
+         """
+         if skill_type is None:
+             skill_type = self._detect_skill(result)
+ 
+         artifacts = []
+ 
+         # Always create artifacts for all tool calls
+         tool_artifacts = self._build_all_tool_artifacts(result)
+         artifacts.extend(tool_artifacts)
+ 
+         # For Q&A, always add a Q&A artifact with question and answer
+         # This includes follow-up questions, clarifications, and conversational responses
+         if skill_type == "qa" and question:
+             from fasta2a.schema import DataPart
+ 
+             artifacts.append(
+                 Artifact(
+                     artifact_id=str(uuid.uuid4()),
+                     name="qa_result",
+                     parts=[
+                         DataPart(
+                             kind="data",
+                             data={
+                                 "question": question,
+                                 "answer": str(result.output),
+                                 "skill": "document-qa",
+                             },
+                             metadata={"skill": "document-qa"},
+                         )
+                     ],
+                 )
+             )
+ 
+         return artifacts
+ 
+     def _build_all_tool_artifacts(self, result) -> list[Artifact]:
+         """Build artifacts for all tool calls."""
+         from pydantic_ai.messages import (
+             ModelRequest,
+             ModelResponse,
+             ToolCallPart,
+             ToolReturnPart,
+         )
+ 
+         artifacts = []
+ 
+         # Track tool calls and their returns by call_id
+         tool_returns = {}
+         for msg in result.new_messages():
+             if isinstance(msg, ModelRequest):
+                 for part in msg.parts:
+                     if isinstance(part, ToolReturnPart):
+                         result_count = (
+                             len(part.content) if isinstance(part.content, list) else 1
+                         )
+                         logger.info(
+                             "Tool return: tool_call_id=%s, tool_name=%s, result_count=%s",
+                             part.tool_call_id,
+                             part.tool_name,
+                             result_count,
+                         )
+                         tool_returns[part.tool_call_id] = (part.tool_name, part.content)
+ 
+         # Create artifacts for each tool call
+         for msg in result.new_messages():
+             if isinstance(msg, ModelResponse):
+                 for part in msg.parts:
+                     if isinstance(part, ToolCallPart):
+                         tool_name, content = tool_returns.get(
+                             part.tool_call_id, (None, None)
+                         )
+ 
+                         if tool_name == "search_documents" and content:
+                             from fasta2a.schema import DataPart
+ 
+                             # Extract query from tool call arguments
+                             query = ""
+                             if isinstance(part.args, dict):
+                                 query = part.args.get("query", "")
+                             elif isinstance(part.args, str):
+                                 # Args is a JSON string - parse it
+                                 try:
+                                     args_dict = json.loads(part.args)
+                                     query = args_dict.get("query", "")
+                                 except (json.JSONDecodeError, AttributeError):
+                                     query = ""
+                             elif hasattr(part.args, "get") and callable(
+                                 getattr(part.args, "get", None)
+                             ):
+                                 # ArgsDict or dict-like object
+                                 query = part.args.get("query", "")  # type: ignore
+                             elif hasattr(part.args, "query"):
+                                 # Object with query attribute
+                                 query = str(part.args.query)  # type: ignore
+ 
+                             artifacts.append(
+                                 Artifact(
+                                     artifact_id=str(uuid.uuid4()),
+                                     name="search_results",
+                                     parts=[
+                                         DataPart(
+                                             kind="data",
+                                             data={"results": content, "query": query},
+                                             metadata={"query": query},
+                                         )
+                                     ],
+                                 )
+                             )
+                         elif tool_name == "get_full_document" and content:
+                             artifacts.append(
+                                 Artifact(
+                                     artifact_id=str(uuid.uuid4()),
+                                     name="document",
+                                     parts=[TextPart(kind="text", text=content)],
+                                 )
+                             )
+ 
+         return artifacts
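
A hedged construction sketch for this worker (the InMemoryBroker import path, the database path, and qa_agent are assumptions; the real wiring lives in haiku.rag's A2A app factory, which is not part of this hunk):

from pathlib import Path

from fasta2a.broker import InMemoryBroker  # assumed import path - verify against fasta2a
from fasta2a.storage import InMemoryStorage

worker = ConversationalWorker(
    storage=LRUMemoryStorage(storage=InMemoryStorage(), max_contexts=100),
    broker=InMemoryBroker(),
    db_path=Path("rag.db"),  # illustrative path
    agent=qa_agent,  # a pydantic_ai Agent[AgentDependencies, str], built elsewhere
)

run_task then drives each task through submitted -> working -> completed (or failed), emitting response messages plus artifacts keyed to the detected skill.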
haiku/rag/app.py CHANGED
@@ -1,5 +1,6 @@
  import asyncio
  import json
+ import logging
  from importlib.metadata import version as pkg_version
  from pathlib import Path
  
@@ -22,6 +23,8 @@ from haiku.rag.research.stream import stream_research_graph
  from haiku.rag.store.models.chunk import Chunk
  from haiku.rag.store.models.document import Document
  
+ logger = logging.getLogger(__name__)
+ 
  
  class HaikuRAGApp:
      def __init__(self, db_path: Path):
@@ -157,13 +160,20 @@ class HaikuRAGApp:
          self, source: str, title: str | None = None, metadata: dict | None = None
      ):
          async with HaikuRAG(db_path=self.db_path) as self.client:
-             doc = await self.client.create_document_from_source(
+             result = await self.client.create_document_from_source(
                  source, title=title, metadata=metadata
              )
-             self._rich_print_document(doc, truncate=True)
-             self.console.print(
-                 f"[bold green]Document {doc.id} added successfully.[/bold green]"
-             )
+             if isinstance(result, list):
+                 for doc in result:
+                     self._rich_print_document(doc, truncate=True)
+                 self.console.print(
+                     f"[bold green]{len(result)} documents added successfully.[/bold green]"
+                 )
+             else:
+                 self._rich_print_document(result, truncate=True)
+                 self.console.print(
+                     f"[bold green]Document {result.id} added successfully.[/bold green]"
+                 )
  
      async def get_document(self, doc_id: str):
          async with HaikuRAG(db_path=self.db_path) as self.client:
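
Since create_document_from_source can now return either a single Document or a list of them, callers outside the CLI need the same shape check the CLI performs above. A minimal caller-side sketch (the source path and helper name are illustrative):

# Normalize the single-or-list return shape before iterating.
async def add_and_list_ids(client: HaikuRAG) -> list:
    result = await client.create_document_from_source("docs/")  # illustrative source
    docs = result if isinstance(result, list) else [result]
    return [doc.id for doc in docs]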
@@ -448,23 +458,81 @@ class HaikuRAGApp:
          self.console.print(content)
          self.console.rule()
  
-     async def serve(self, transport: str | None = None):
-         """Start the MCP server."""
+     async def serve(
+         self,
+         enable_monitor: bool = True,
+         enable_mcp: bool = True,
+         mcp_transport: str | None = None,
+         mcp_port: int = 8001,
+         enable_a2a: bool = False,
+         a2a_host: str = "127.0.0.1",
+         a2a_port: int = 8000,
+     ):
+         """Start the server with selected services."""
          async with HaikuRAG(self.db_path) as client:
-             monitor = FileWatcher(paths=Config.MONITOR_DIRECTORIES, client=client)
-             monitor_task = asyncio.create_task(monitor.observe())
-             server = create_mcp_server(self.db_path)
+             tasks = []
+ 
+             # Start file monitor if enabled
+             if enable_monitor:
+                 monitor = FileWatcher(paths=Config.MONITOR_DIRECTORIES, client=client)
+                 monitor_task = asyncio.create_task(monitor.observe())
+                 tasks.append(monitor_task)
+ 
+             # Start MCP server if enabled
+             if enable_mcp:
+                 server = create_mcp_server(self.db_path)
+ 
+                 async def run_mcp():
+                     if mcp_transport == "stdio":
+                         await server.run_stdio_async()
+                     else:
+                         logger.info(f"Starting MCP server on port {mcp_port}")
+                         await server.run_http_async(
+                             transport="streamable-http", port=mcp_port
+                         )
+ 
+                 mcp_task = asyncio.create_task(run_mcp())
+                 tasks.append(mcp_task)
+ 
+             # Start A2A server if enabled
+             if enable_a2a:
+                 try:
+                     from haiku.rag.a2a import create_a2a_app
+                 except ImportError as e:
+                     logger.error(f"Failed to import A2A: {e}")
+                     return
+ 
+                 import uvicorn
+ 
+                 logger.info(f"Starting A2A server on {a2a_host}:{a2a_port}")
+ 
+                 async def run_a2a():
+                     app = create_a2a_app(db_path=self.db_path)
+                     config = uvicorn.Config(
+                         app,
+                         host=a2a_host,
+                         port=a2a_port,
+                         log_level="warning",
+                         access_log=False,
+                     )
+                     server = uvicorn.Server(config)
+                     await server.serve()
+ 
+                 a2a_task = asyncio.create_task(run_a2a())
+                 tasks.append(a2a_task)
+ 
+             if not tasks:
+                 logger.warning("No services enabled")
+                 return
  
              try:
-                 if transport == "stdio":
-                     await server.run_stdio_async()
-                 else:
-                     await server.run_http_async(transport="streamable-http")
+                 # Wait for the tasks to complete (or KeyboardInterrupt)
+                 await asyncio.gather(*tasks)
              except KeyboardInterrupt:
                  pass
              finally:
-                 monitor_task.cancel()
-                 try:
-                     await monitor_task
-                 except asyncio.CancelledError:
-                     pass
+                 # Cancel all tasks
+                 for task in tasks:
+                     task.cancel()
+                 # Wait for cancellation
+                 await asyncio.gather(*tasks, return_exceptions=True)
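
Under the reworked signature each service is opt-in; an invocation sketch grounded in the parameters above (the path is illustrative, and the CLI flag plumbing that normally calls serve() is not part of this hunk):

import asyncio
from pathlib import Path

app = HaikuRAGApp(db_path=Path("rag.db"))  # illustrative path

# Run the file monitor, MCP over streamable HTTP on port 8001, and A2A on 8000.
asyncio.run(
    app.serve(
        enable_monitor=True,
        enable_mcp=True,
        mcp_port=8001,
        enable_a2a=True,
        a2a_host="127.0.0.1",
        a2a_port=8000,
    )
)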