haiku.rag 0.11.4__py3-none-any.whl → 0.12.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of haiku.rag might be problematic. Click here for more details.
- haiku/rag/a2a/__init__.py +176 -0
- haiku/rag/a2a/client.py +268 -0
- haiku/rag/a2a/context.py +68 -0
- haiku/rag/a2a/models.py +21 -0
- haiku/rag/a2a/prompts.py +59 -0
- haiku/rag/a2a/skills.py +75 -0
- haiku/rag/a2a/storage.py +71 -0
- haiku/rag/a2a/worker.py +320 -0
- haiku/rag/app.py +87 -19
- haiku/rag/cli.py +81 -71
- haiku/rag/client.py +47 -4
- haiku/rag/config.py +4 -0
- haiku/rag/embeddings/base.py +8 -0
- haiku/rag/embeddings/ollama.py +8 -0
- haiku/rag/embeddings/openai.py +8 -0
- haiku/rag/embeddings/vllm.py +8 -0
- haiku/rag/embeddings/voyageai.py +8 -0
- haiku/rag/mcp.py +99 -0
- haiku/rag/qa/agent.py +0 -3
- {haiku_rag-0.11.4.dist-info → haiku_rag-0.12.1.dist-info}/METADATA +33 -10
- {haiku_rag-0.11.4.dist-info → haiku_rag-0.12.1.dist-info}/RECORD +24 -16
- {haiku_rag-0.11.4.dist-info → haiku_rag-0.12.1.dist-info}/WHEEL +0 -0
- {haiku_rag-0.11.4.dist-info → haiku_rag-0.12.1.dist-info}/entry_points.txt +0 -0
- {haiku_rag-0.11.4.dist-info → haiku_rag-0.12.1.dist-info}/licenses/LICENSE +0 -0
haiku/rag/a2a/storage.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from collections import OrderedDict
|
|
3
|
+
|
|
4
|
+
try:
|
|
5
|
+
from fasta2a.schema import Artifact, Message, TaskState # type: ignore
|
|
6
|
+
from fasta2a.storage import InMemoryStorage, Storage # type: ignore
|
|
7
|
+
except ImportError as e:
|
|
8
|
+
raise ImportError(
|
|
9
|
+
"A2A support requires the 'a2a' extra. "
|
|
10
|
+
"Install with: uv pip install 'haiku.rag[a2a]'"
|
|
11
|
+
) from e
|
|
12
|
+
|
|
13
|
+
logger = logging.getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class LRUMemoryStorage(Storage[list[Message]]):  # type: ignore
    """Storage wrapper that bounds the number of stored contexts with LRU eviction.

    Task operations are delegated unchanged to the wrapped ``InMemoryStorage``.
    Context reads (that find a context) and context writes refresh a recency
    order; once more than ``max_contexts`` contexts are tracked, the least
    recently used ones are removed from both the recency order and the wrapped
    storage.

    NOTE: only contexts are evicted; tasks stored in the wrapped storage are
    not touched here.
    """

    def __init__(self, storage: InMemoryStorage, max_contexts: int):
        self.storage = storage
        self.max_contexts = max_contexts
        # Context ids ordered from least (first) to most (last) recently used.
        self.context_order: OrderedDict[str, None] = OrderedDict()

    def _touch(self, context_id: str) -> None:
        """Mark ``context_id`` as the most recently used context."""
        self.context_order[context_id] = None
        self.context_order.move_to_end(context_id)

    def _evict_from_storage(self, context_id: str) -> None:
        """Drop a context's data from the wrapped storage.

        fasta2a's ``InMemoryStorage`` keeps contexts in a ``contexts`` dict;
        the lookup is guarded in case a different storage object is wrapped.
        """
        contexts = getattr(self.storage, "contexts", None)
        if isinstance(contexts, dict):
            contexts.pop(context_id, None)

    async def load_context(self, context_id: str) -> list[Message] | None:
        """Load a context and, if it exists, mark it as recently used."""
        result = await self.storage.load_context(context_id)
        if result is not None:
            self._touch(context_id)
        return result

    async def update_context(self, context_id: str, context: list[Message]) -> None:
        """Store a context, mark it recently used, and evict past the limit."""
        await self.storage.update_context(context_id, context)
        self._touch(context_id)

        # Enforce max contexts limit (LRU eviction). Removing only the
        # bookkeeping entry would leave the wrapped storage growing unbounded,
        # so the context data is dropped as well.
        while len(self.context_order) > self.max_contexts:
            oldest_context_id, _ = self.context_order.popitem(last=False)
            self._evict_from_storage(oldest_context_id)
            logger.debug(
                "Evicted context %s (LRU, limit=%s)",
                oldest_context_id,
                self.max_contexts,
            )

    async def load_task(self, task_id: str, history_length: int | None = None):
        """Delegate to underlying storage."""
        return await self.storage.load_task(task_id, history_length)

    async def update_task(
        self,
        task_id: str,
        state: TaskState,
        new_artifacts: list[Artifact] | None = None,
        new_messages: list[Message] | None = None,
    ):
        """Delegate to underlying storage."""
        return await self.storage.update_task(
            task_id, state, new_artifacts, new_messages
        )

    async def submit_task(self, context_id: str, message: Message):
        """Delegate to underlying storage."""
        return await self.storage.submit_task(context_id, message)
|
haiku/rag/a2a/worker.py
ADDED
|
@@ -0,0 +1,320 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import logging
|
|
3
|
+
import uuid
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
|
|
6
|
+
from pydantic_ai import Agent
|
|
7
|
+
|
|
8
|
+
from haiku.rag.a2a.context import load_message_history, save_message_history
|
|
9
|
+
from haiku.rag.a2a.models import AgentDependencies
|
|
10
|
+
from haiku.rag.a2a.skills import extract_question_from_task
|
|
11
|
+
from haiku.rag.client import HaikuRAG
|
|
12
|
+
|
|
13
|
+
try:
|
|
14
|
+
from fasta2a import Worker # type: ignore
|
|
15
|
+
from fasta2a.schema import ( # type: ignore
|
|
16
|
+
Artifact,
|
|
17
|
+
Message,
|
|
18
|
+
TaskIdParams,
|
|
19
|
+
TaskSendParams,
|
|
20
|
+
TextPart,
|
|
21
|
+
)
|
|
22
|
+
except ImportError as e:
|
|
23
|
+
raise ImportError(
|
|
24
|
+
"A2A support requires the 'a2a' extra. "
|
|
25
|
+
"Install with: uv pip install 'haiku.rag[a2a]'"
|
|
26
|
+
) from e
|
|
27
|
+
|
|
28
|
+
logger = logging.getLogger(__name__)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class ConversationalWorker(Worker[list[Message]]):
    """Worker that answers A2A tasks with a pydantic-ai agent backed by HaikuRAG.

    Each task's question is run through ``self.agent`` with a ``HaikuRAG``
    client opened on ``db_path``. The complete pydantic-ai message history is
    persisted in the task's context as a single message produced by
    ``save_message_history``, so later tasks in the same context continue the
    conversation. Structured tool output is returned as task artifacts.
    """

    def __init__(
        self,
        storage,
        broker,
        db_path: Path,
        agent: "Agent[AgentDependencies, str]",
    ):
        super().__init__(storage=storage, broker=broker)
        self.db_path = db_path
        self.agent = agent

    async def run_task(self, params: TaskSendParams) -> None:
        """Execute one submitted task and record its outcome in storage.

        Raises:
            ValueError: If the task does not exist or is not in ``submitted``.
            Exception: Any error raised while answering is logged, the task is
                marked ``failed``, and the error is re-raised.
        """
        task = await self.storage.load_task(params["id"])
        if task is None:
            raise ValueError(f"Task {params['id']} not found")

        if task["status"]["state"] != "submitted":
            raise ValueError(
                f"Task {params['id']} already processed: {task['status']['state']}"
            )

        await self.storage.update_task(task["id"], state="working")

        question = extract_question_from_task(task.get("history", []))
        if not question:
            await self.storage.update_task(task["id"], state="failed")
            return

        try:
            async with HaikuRAG(self.db_path) as client:
                context = await self.storage.load_context(task["context_id"]) or []
                message_history = load_message_history(context)

                result = await self.agent.run(
                    question,
                    deps=AgentDependencies(client=client),
                    message_history=message_history,
                )

                skill_type = self._detect_skill(result)
                response_messages = self._build_response_messages(result, skill_type)

                # Persist prior turns plus this run as the context's only entry.
                state_message = save_message_history(
                    message_history + result.new_messages()
                )
                await self.storage.update_context(task["context_id"], [state_message])

                artifacts = self.build_artifacts(result, skill_type, question)

                await self.storage.update_task(
                    task["id"],
                    state="completed",
                    new_messages=response_messages,
                    new_artifacts=artifacts,
                )
        except Exception as e:
            logger.error(
                "Task execution failed: task_id=%s, question=%s, error=%s",
                task["id"],
                question,
                str(e),
                exc_info=True,
            )
            await self.storage.update_task(task["id"], state="failed")
            raise

    async def cancel_task(self, params: TaskIdParams) -> None:
        """Cancel a task - not implemented for this worker."""
        pass

    def build_message_history(self, history: list[Message]) -> list[Message]:
        """Required by Worker interface but unused - history stored in context."""
        return history

    @staticmethod
    def _tool_call_names(result) -> list[str]:
        """Return the names of all tools the agent called during this run, in order."""
        from pydantic_ai.messages import ModelResponse, ToolCallPart

        return [
            part.tool_name
            for msg in result.new_messages()
            if isinstance(msg, ModelResponse)
            for part in msg.parts
            if isinstance(part, ToolCallPart)
        ]

    @staticmethod
    def _first_tool_return(result, tool_name: str):
        """Return the content of the first ``tool_name`` tool return in this run, or None."""
        from pydantic_ai.messages import ModelRequest, ToolReturnPart

        for msg in result.new_messages():
            if isinstance(msg, ModelRequest):
                for part in msg.parts:
                    if isinstance(part, ToolReturnPart) and part.tool_name == tool_name:
                        return part.content
        return None

    @staticmethod
    def _as_text(content) -> str:
        """Coerce tool-return content to the ``str`` a ``TextPart`` requires."""
        if content is None:
            return ""
        return content if isinstance(content, str) else str(content)

    @staticmethod
    def _agent_message(text: str) -> Message:
        """Wrap ``text`` in a single-part agent ``Message`` with a fresh id."""
        return Message(
            role="agent",
            parts=[TextPart(kind="text", text=text)],
            kind="message",
            message_id=str(uuid.uuid4()),
        )

    @staticmethod
    def _extract_query(args) -> str:
        """Extract the ``query`` argument from a tool call's args as a string.

        Accepts a dict (or dict-like), a JSON-encoded object string, or an
        object exposing a ``query`` attribute; anything else yields "".
        """
        if isinstance(args, str):
            try:
                args = json.loads(args)
            except json.JSONDecodeError:
                return ""
        if callable(getattr(args, "get", None)):
            query = args.get("query", "")
        elif hasattr(args, "query"):
            query = args.query
        else:
            return ""
        return "" if query is None else str(query)

    def _detect_skill(self, result) -> str:
        """Detect which skill was used based on tool calls and response pattern.

        Returns:
            "search", "retrieve", or "qa"
        """
        tool_calls = self._tool_call_names(result)

        # Check for either format: "Found N relevant results" or "**Search results for"
        output_str = str(result.output).strip()
        is_search_format = (
            output_str.startswith("Found ") and "relevant results" in output_str[:100]
        ) or output_str.startswith("**Search results for")

        # Search requires at least one call and only search_documents calls;
        # an empty list would otherwise satisfy all() vacuously.
        if (
            is_search_format
            and tool_calls
            and all(tc == "search_documents" for tc in tool_calls)
        ):
            return "search"
        if tool_calls == ["get_full_document"]:
            return "retrieve"
        return "qa"

    def _build_response_messages(self, result, skill_type: str) -> list[Message]:
        """Build response messages based on skill type.

        All skills return a single agent text message. For "retrieve" the text
        is the content returned by ``get_full_document`` ("" if none); for
        "search" and "qa" it is the agent's output. Structured data is provided
        via artifacts.
        """
        if skill_type == "retrieve":
            text = self._as_text(self._first_tool_return(result, "get_full_document"))
        else:
            text = str(result.output)
        return [self._agent_message(text)]

    def build_artifacts(
        self, result, skill_type: str | None = None, question: str | None = None
    ) -> list[Artifact]:
        """Build artifacts from agent result based on tool calls.

        Creates artifacts for:
        - Each tool call to search_documents / get_full_document that returned
          non-empty content
        - Q&A operations: an additional ``qa_result`` artifact with the
          question and answer whenever ``question`` is given, regardless of
          whether any tools were called

        Args:
            result: The pydantic-ai run result.
            skill_type: Skill as returned by ``_detect_skill``; detected if None.
            question: The user's question, used for the Q&A artifact.
        """
        if skill_type is None:
            skill_type = self._detect_skill(result)

        artifacts = self._build_all_tool_artifacts(result)

        # Includes follow-up questions, clarifications and conversational replies.
        if skill_type == "qa" and question:
            from fasta2a.schema import DataPart

            artifacts.append(
                Artifact(
                    artifact_id=str(uuid.uuid4()),
                    name="qa_result",
                    parts=[
                        DataPart(
                            kind="data",
                            data={
                                "question": question,
                                "answer": str(result.output),
                                "skill": "document-qa",
                            },
                            metadata={"skill": "document-qa"},
                        )
                    ],
                )
            )

        return artifacts

    def _build_all_tool_artifacts(self, result) -> list[Artifact]:
        """Build one artifact per tool call whose return content is non-empty."""
        from fasta2a.schema import DataPart
        from pydantic_ai.messages import (
            ModelRequest,
            ModelResponse,
            ToolCallPart,
            ToolReturnPart,
        )

        messages = result.new_messages()

        # Map tool_call_id -> (tool_name, content) from the tool returns.
        tool_returns = {}
        for msg in messages:
            if not isinstance(msg, ModelRequest):
                continue
            for part in msg.parts:
                if isinstance(part, ToolReturnPart):
                    result_count = (
                        len(part.content) if isinstance(part.content, list) else 1
                    )
                    logger.info(
                        "Tool return: tool_call_id=%s, tool_name=%s, result_count=%s",
                        part.tool_call_id,
                        part.tool_name,
                        result_count,
                    )
                    tool_returns[part.tool_call_id] = (part.tool_name, part.content)

        artifacts: list[Artifact] = []
        for msg in messages:
            if not isinstance(msg, ModelResponse):
                continue
            for part in msg.parts:
                if not isinstance(part, ToolCallPart):
                    continue
                tool_name, content = tool_returns.get(part.tool_call_id, (None, None))
                if not content:
                    continue

                if tool_name == "search_documents":
                    query = self._extract_query(part.args)
                    artifacts.append(
                        Artifact(
                            artifact_id=str(uuid.uuid4()),
                            name="search_results",
                            parts=[
                                DataPart(
                                    kind="data",
                                    data={"results": content, "query": query},
                                    metadata={"query": query},
                                )
                            ],
                        )
                    )
                elif tool_name == "get_full_document":
                    artifacts.append(
                        Artifact(
                            artifact_id=str(uuid.uuid4()),
                            name="document",
                            parts=[TextPart(kind="text", text=self._as_text(content))],
                        )
                    )

        return artifacts
|
haiku/rag/app.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
import asyncio
|
|
2
2
|
import json
|
|
3
|
+
import logging
|
|
3
4
|
from importlib.metadata import version as pkg_version
|
|
4
5
|
from pathlib import Path
|
|
5
6
|
|
|
@@ -22,6 +23,8 @@ from haiku.rag.research.stream import stream_research_graph
|
|
|
22
23
|
from haiku.rag.store.models.chunk import Chunk
|
|
23
24
|
from haiku.rag.store.models.document import Document
|
|
24
25
|
|
|
26
|
+
logger = logging.getLogger(__name__)
|
|
27
|
+
|
|
25
28
|
|
|
26
29
|
class HaikuRAGApp:
|
|
27
30
|
def __init__(self, db_path: Path):
|
|
@@ -157,13 +160,20 @@ class HaikuRAGApp:
|
|
|
157
160
|
self, source: str, title: str | None = None, metadata: dict | None = None
|
|
158
161
|
):
|
|
159
162
|
async with HaikuRAG(db_path=self.db_path) as self.client:
|
|
160
|
-
|
|
163
|
+
result = await self.client.create_document_from_source(
|
|
161
164
|
source, title=title, metadata=metadata
|
|
162
165
|
)
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
166
|
+
if isinstance(result, list):
|
|
167
|
+
for doc in result:
|
|
168
|
+
self._rich_print_document(doc, truncate=True)
|
|
169
|
+
self.console.print(
|
|
170
|
+
f"[bold green]{len(result)} documents added successfully.[/bold green]"
|
|
171
|
+
)
|
|
172
|
+
else:
|
|
173
|
+
self._rich_print_document(result, truncate=True)
|
|
174
|
+
self.console.print(
|
|
175
|
+
f"[bold green]Document {result.id} added successfully.[/bold green]"
|
|
176
|
+
)
|
|
167
177
|
|
|
168
178
|
async def get_document(self, doc_id: str):
|
|
169
179
|
async with HaikuRAG(db_path=self.db_path) as self.client:
|
|
@@ -448,23 +458,81 @@ class HaikuRAGApp:
|
|
|
448
458
|
self.console.print(content)
|
|
449
459
|
self.console.rule()
|
|
450
460
|
|
|
451
|
-
async def serve(
    self,
    enable_monitor: bool = True,
    enable_mcp: bool = True,
    mcp_transport: str | None = None,
    mcp_port: int = 8001,
    enable_a2a: bool = False,
    a2a_host: str = "127.0.0.1",
    a2a_port: int = 8000,
):
    """Start the server with selected services and run them until done.

    Args:
        enable_monitor: Run a ``FileWatcher`` over ``Config.MONITOR_DIRECTORIES``.
        enable_mcp: Run the MCP server.
        mcp_transport: ``"stdio"`` serves MCP over stdio; any other value
            serves streamable HTTP on ``mcp_port``.
        mcp_port: Port for the MCP HTTP transport.
        enable_a2a: Run the A2A server via uvicorn.
        a2a_host: Bind host for the A2A server.
        a2a_port: Bind port for the A2A server.

    Returns without starting anything if A2A is requested but its optional
    dependencies cannot be imported, or if no service is enabled.
    """
    # Resolve optional A2A dependencies before starting any service, so an
    # import failure cannot leave already-started monitor/MCP tasks running
    # unowned while the client below is being closed.
    if enable_a2a:
        try:
            import uvicorn

            from haiku.rag.a2a import create_a2a_app
        except ImportError as e:
            logger.error("Failed to import A2A: %s", e)
            return

    async with HaikuRAG(self.db_path) as client:
        tasks: list[asyncio.Task] = []

        # Start file monitor if enabled
        if enable_monitor:
            monitor = FileWatcher(paths=Config.MONITOR_DIRECTORIES, client=client)
            tasks.append(asyncio.create_task(monitor.observe()))

        # Start MCP server if enabled
        if enable_mcp:
            mcp_server = create_mcp_server(self.db_path)

            async def run_mcp():
                if mcp_transport == "stdio":
                    await mcp_server.run_stdio_async()
                else:
                    logger.info("Starting MCP server on port %s", mcp_port)
                    await mcp_server.run_http_async(
                        transport="streamable-http", port=mcp_port
                    )

            tasks.append(asyncio.create_task(run_mcp()))

        # Start A2A server if enabled
        if enable_a2a:
            logger.info("Starting A2A server on %s:%s", a2a_host, a2a_port)

            async def run_a2a():
                app = create_a2a_app(db_path=self.db_path)
                config = uvicorn.Config(
                    app,
                    host=a2a_host,
                    port=a2a_port,
                    log_level="warning",
                    access_log=False,
                )
                await uvicorn.Server(config).serve()

            tasks.append(asyncio.create_task(run_a2a()))

        if not tasks:
            logger.warning("No services enabled")
            return

        try:
            # Wait until every task finishes; gather re-raises the first
            # failure (or stops on KeyboardInterrupt) and the finally below
            # cancels whatever is still running.
            await asyncio.gather(*tasks)
        except KeyboardInterrupt:
            pass
        finally:
            # Cancel all tasks
            for task in tasks:
                task.cancel()
            # Wait for cancellation
            await asyncio.gather(*tasks, return_exceptions=True)
|