sunholo 0.143.1__py3-none-any.whl → 0.143.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sunholo/__init__.py +4 -1
- sunholo/a2a/__init__.py +27 -0
- sunholo/a2a/agent_card.py +345 -0
- sunholo/a2a/task_manager.py +480 -0
- sunholo/a2a/vac_a2a_agent.py +383 -0
- sunholo/agents/flask/vac_routes.py +256 -19
- sunholo/mcp/__init__.py +11 -2
- sunholo/mcp/mcp_manager.py +66 -28
- sunholo/mcp/vac_mcp_server.py +3 -9
- {sunholo-0.143.1.dist-info → sunholo-0.143.7.dist-info}/METADATA +4 -2
- {sunholo-0.143.1.dist-info → sunholo-0.143.7.dist-info}/RECORD +15 -11
- {sunholo-0.143.1.dist-info → sunholo-0.143.7.dist-info}/WHEEL +0 -0
- {sunholo-0.143.1.dist-info → sunholo-0.143.7.dist-info}/entry_points.txt +0 -0
- {sunholo-0.143.1.dist-info → sunholo-0.143.7.dist-info}/licenses/LICENSE.txt +0 -0
- {sunholo-0.143.1.dist-info → sunholo-0.143.7.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,480 @@
|
|
1
|
+
# Copyright [2024] [Holosun ApS]
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
"""
|
16
|
+
A2A Task Management for VAC interactions.
|
17
|
+
Handles the lifecycle of A2A tasks and their state transitions.
|
18
|
+
"""
|
19
|
+
|
20
|
+
import asyncio
|
21
|
+
import uuid
|
22
|
+
from datetime import datetime, timezone, timedelta
|
23
|
+
from typing import Dict, List, Any, Optional, Callable, AsyncGenerator
|
24
|
+
from enum import Enum
|
25
|
+
import json
|
26
|
+
from ..custom_logging import log
|
27
|
+
|
28
|
+
|
29
|
+
class TaskState(Enum):
    """Lifecycle states of an A2A task.

    Each member's value is the string written to the ``"state"`` field when a
    task is serialized.
    """

    # States a task can be in while it is still open.
    SUBMITTED = "submitted"
    WORKING = "working"
    INPUT_REQUIRED = "input-required"

    # States that stamp the task's completion time.
    COMPLETED = "completed"
    CANCELED = "canceled"
    FAILED = "failed"

    UNKNOWN = "unknown"
|
38
|
+
|
39
|
+
|
40
|
+
class A2ATask:
|
41
|
+
"""Represents an A2A task with its state and data."""
|
42
|
+
|
43
|
+
def __init__(self, task_id: str, skill_name: str, input_data: Dict[str, Any],
|
44
|
+
client_metadata: Optional[Dict[str, Any]] = None):
|
45
|
+
"""
|
46
|
+
Initialize an A2A task.
|
47
|
+
|
48
|
+
Args:
|
49
|
+
task_id: Unique task identifier
|
50
|
+
skill_name: Name of the skill being invoked
|
51
|
+
input_data: Input parameters for the task
|
52
|
+
client_metadata: Optional metadata from the client
|
53
|
+
"""
|
54
|
+
self.task_id = task_id
|
55
|
+
self.skill_name = skill_name
|
56
|
+
self.input_data = input_data
|
57
|
+
self.client_metadata = client_metadata or {}
|
58
|
+
|
59
|
+
self.state = TaskState.SUBMITTED
|
60
|
+
self.created_at = datetime.now(timezone.utc)
|
61
|
+
self.updated_at = self.created_at
|
62
|
+
self.completed_at: Optional[datetime] = None
|
63
|
+
|
64
|
+
self.messages: List[Dict[str, Any]] = []
|
65
|
+
self.artifacts: List[Dict[str, Any]] = []
|
66
|
+
self.error: Optional[Dict[str, Any]] = None
|
67
|
+
self.progress: float = 0.0
|
68
|
+
|
69
|
+
# For streaming tasks
|
70
|
+
self.stream_queue: Optional[asyncio.Queue] = None
|
71
|
+
self.is_streaming = False
|
72
|
+
|
73
|
+
def to_dict(self) -> Dict[str, Any]:
|
74
|
+
"""Convert task to dictionary format for A2A responses."""
|
75
|
+
return {
|
76
|
+
"taskId": self.task_id,
|
77
|
+
"state": self.state.value,
|
78
|
+
"createdAt": self.created_at.isoformat(),
|
79
|
+
"updatedAt": self.updated_at.isoformat(),
|
80
|
+
"completedAt": self.completed_at.isoformat() if self.completed_at else None,
|
81
|
+
"messages": self.messages,
|
82
|
+
"artifacts": self.artifacts,
|
83
|
+
"error": self.error,
|
84
|
+
"progress": self.progress,
|
85
|
+
"metadata": {
|
86
|
+
"skillName": self.skill_name,
|
87
|
+
"isStreaming": self.is_streaming,
|
88
|
+
"clientMetadata": self.client_metadata
|
89
|
+
}
|
90
|
+
}
|
91
|
+
|
92
|
+
def add_message(self, role: str, content: str, message_type: str = "text"):
|
93
|
+
"""Add a message to the task."""
|
94
|
+
message = {
|
95
|
+
"role": role,
|
96
|
+
"parts": [{
|
97
|
+
"type": message_type,
|
98
|
+
"text": content
|
99
|
+
}],
|
100
|
+
"timestamp": datetime.now(timezone.utc).isoformat()
|
101
|
+
}
|
102
|
+
self.messages.append(message)
|
103
|
+
self.updated_at = datetime.now(timezone.utc)
|
104
|
+
|
105
|
+
def add_artifact(self, name: str, content: Any, artifact_type: str = "text"):
|
106
|
+
"""Add an artifact to the task."""
|
107
|
+
artifact = {
|
108
|
+
"name": name,
|
109
|
+
"type": artifact_type,
|
110
|
+
"content": content,
|
111
|
+
"createdAt": datetime.now(timezone.utc).isoformat()
|
112
|
+
}
|
113
|
+
self.artifacts.append(artifact)
|
114
|
+
self.updated_at = datetime.now(timezone.utc)
|
115
|
+
|
116
|
+
def update_state(self, new_state: TaskState, error_message: str = None):
|
117
|
+
"""Update the task state."""
|
118
|
+
self.state = new_state
|
119
|
+
self.updated_at = datetime.now(timezone.utc)
|
120
|
+
|
121
|
+
if new_state in [TaskState.COMPLETED, TaskState.FAILED, TaskState.CANCELED]:
|
122
|
+
self.completed_at = self.updated_at
|
123
|
+
|
124
|
+
if new_state == TaskState.FAILED and error_message:
|
125
|
+
self.error = {
|
126
|
+
"message": error_message,
|
127
|
+
"timestamp": self.updated_at.isoformat()
|
128
|
+
}
|
129
|
+
|
130
|
+
def update_progress(self, progress: float):
|
131
|
+
"""Update task progress (0.0 to 1.0)."""
|
132
|
+
self.progress = max(0.0, min(1.0, progress))
|
133
|
+
self.updated_at = datetime.now(timezone.utc)
|
134
|
+
|
135
|
+
|
136
|
+
class A2ATaskManager:
    """Manages A2A tasks and their lifecycle.

    Tasks are created with ``create_task``, processed by a background asyncio
    task, and observed through ``subscribe_to_task``. All bookkeeping lives in
    memory on this instance.
    """

    # Operations a skill name may carry ("vac_{operation}_{vac_name}"). An
    # operation can itself contain underscores ("memory_search"), so skill
    # names are matched against this list instead of being split on "_" alone.
    _KNOWN_OPERATIONS = ("memory_search", "query", "stream")

    def __init__(self, stream_interpreter: Callable, vac_interpreter: Optional[Callable] = None):
        """
        Initialize the task manager.

        Args:
            stream_interpreter: Function for streaming VAC interactions
            vac_interpreter: Function for static VAC interactions
        """
        self.stream_interpreter = stream_interpreter
        self.vac_interpreter = vac_interpreter
        self.tasks: Dict[str, "A2ATask"] = {}
        self.task_subscribers: Dict[str, List[asyncio.Queue]] = {}
        # Strong references to the background asyncio tasks. The event loop
        # only keeps weak references, so an unreferenced task may be garbage
        # collected before it finishes; the handle is also needed to cancel.
        self._runners: Dict[str, asyncio.Task] = {}

    @staticmethod
    def _is_terminal(state: "TaskState") -> bool:
        """Return True when ``state`` is COMPLETED, FAILED or CANCELED."""
        return state in (TaskState.COMPLETED, TaskState.FAILED, TaskState.CANCELED)

    def _mark_completed(self, task: "A2ATask") -> None:
        """Set full progress and COMPLETED, unless the task already ended."""
        if self._is_terminal(task.state):
            return
        task.update_progress(1.0)
        task.update_state(TaskState.COMPLETED)

    def _mark_failed(self, task: "A2ATask", message: str) -> None:
        """Set FAILED with ``message``, unless the task already ended."""
        if self._is_terminal(task.state):
            return
        task.update_state(TaskState.FAILED, message)

    async def create_task(self, skill_name: str, input_data: Dict[str, Any],
                          client_metadata: Optional[Dict[str, Any]] = None) -> "A2ATask":
        """
        Create a new A2A task and start processing it in the background.

        Must be called while an event loop is running.

        Args:
            skill_name: Name of the skill to invoke
            input_data: Input parameters for the task
            client_metadata: Optional client metadata

        Returns:
            Created A2ATask instance
        """
        task_id = str(uuid.uuid4())
        task = A2ATask(task_id, skill_name, input_data, client_metadata)

        self.tasks[task_id] = task
        self.task_subscribers[task_id] = []

        log.info(f"Created A2A task {task_id} for skill {skill_name}")

        # Start processing the task, keeping the handle until it is done.
        runner = asyncio.create_task(self._process_task(task))
        self._runners[task_id] = runner
        runner.add_done_callback(
            lambda _finished, _task_id=task_id: self._runners.pop(_task_id, None)
        )

        return task

    async def get_task(self, task_id: str) -> Optional["A2ATask"]:
        """Get a task by ID, or None when the ID is unknown."""
        return self.tasks.get(task_id)

    async def cancel_task(self, task_id: str) -> bool:
        """
        Cancel a task and stop its background processing.

        Args:
            task_id: ID of the task to cancel

        Returns:
            True if task was canceled, False if not found or already completed
        """
        task = self.tasks.get(task_id)
        if not task:
            return False

        if self._is_terminal(task.state):
            return False

        task.update_state(TaskState.CANCELED)

        # Stop the worker so it cannot later overwrite CANCELED. Never cancel
        # the task that is executing this very call.
        runner = self._runners.get(task_id)
        if runner is not None and runner is not asyncio.current_task() and not runner.done():
            runner.cancel()

        await self._notify_subscribers(task_id, task.to_dict())

        log.info(f"Canceled A2A task {task_id}")
        return True

    async def subscribe_to_task(self, task_id: str) -> AsyncGenerator[Dict[str, Any], None]:
        """
        Subscribe to task updates via async generator.

        The current state is yielded first. The generator finishes after the
        update that reports a COMPLETED, FAILED or CANCELED state, or at once
        when the task ID is unknown.

        Args:
            task_id: ID of the task to subscribe to

        Yields:
            Task update dictionaries
        """
        if task_id not in self.tasks:
            return  # Early return for async generator

        queue: asyncio.Queue = asyncio.Queue()
        self.task_subscribers.setdefault(task_id, []).append(queue)

        # Send current state immediately
        current_task = self.tasks[task_id]
        await queue.put(current_task.to_dict())
        if self._is_terminal(current_task.state):
            # No further update will arrive for a finished task.
            await queue.put(None)

        try:
            while True:
                update = await queue.get()
                if update is None:  # End signal
                    break
                yield update
        finally:
            # Clean up subscription
            listeners = self.task_subscribers.get(task_id)
            if listeners is not None and queue in listeners:
                listeners.remove(queue)

    async def _process_task(self, task: "A2ATask"):
        """
        Process a task by invoking the appropriate VAC functionality.

        Any exception raised while processing is logged and recorded on the
        task as FAILED rather than propagated.

        Args:
            task: The task to process
        """
        try:
            task.update_state(TaskState.WORKING)
            await self._notify_subscribers(task.task_id, task.to_dict())

            # Parse skill name to extract VAC name and operation
            vac_name, operation = self._parse_skill_name(task.skill_name)

            handlers = {
                "query": self._process_query_task,
                "stream": self._process_stream_task,
                "memory_search": self._process_memory_search_task,
            }
            handler = handlers.get(operation)
            if handler is None:
                raise ValueError(f"Unknown operation: {operation}")
            await handler(task, vac_name)

        except asyncio.CancelledError:
            # Raised when cancel_task() (or loop shutdown) stops the worker.
            if not self._is_terminal(task.state):
                task.update_state(TaskState.CANCELED)
            raise
        except Exception as e:
            log.error(f"Error processing task {task.task_id}: {e}")
            if not self._is_terminal(task.state):
                task.update_state(TaskState.FAILED, str(e))
                await self._notify_subscribers(task.task_id, task.to_dict())

    async def _process_query_task(self, task: "A2ATask", vac_name: str):
        """
        Process a static query task.

        Reads "query", "chat_history" and "context" from the task input and
        calls ``vac_interpreter``; "context" entries are passed on as extra
        keyword arguments.

        Args:
            task: The task to process
            vac_name: Name passed to the interpreter as ``vector_name``

        Raises:
            ValueError: If no ``vac_interpreter`` was configured
        """
        if not self.vac_interpreter:
            raise ValueError("VAC interpreter not available for query tasks")

        query = task.input_data.get("query", "")
        chat_history = task.input_data.get("chat_history", [])
        # An explicit null context is treated like a missing one.
        context = task.input_data.get("context") or {}

        task.add_message("user", query)
        task.update_progress(0.3)
        await self._notify_subscribers(task.task_id, task.to_dict())

        # Convert A2A chat history to VAC format
        vac_chat_history = self._convert_chat_history(chat_history)

        # Execute VAC query
        if asyncio.iscoroutinefunction(self.vac_interpreter):
            result = await self.vac_interpreter(
                question=query,
                vector_name=vac_name,
                chat_history=vac_chat_history,
                **context
            )
        else:
            # Run sync function in executor so the event loop is not blocked
            loop = asyncio.get_running_loop()
            result = await loop.run_in_executor(
                None,
                lambda: self.vac_interpreter(
                    question=query,
                    vector_name=vac_name,
                    chat_history=vac_chat_history,
                    **context
                )
            )

        if self._is_terminal(task.state):
            # The task ended (e.g. was canceled) while the interpreter ran.
            return

        # Process result
        if isinstance(result, dict):
            answer = result.get("answer", str(result))
            source_docs = result.get("source_documents", [])
        else:
            answer = str(result)
            source_docs = []

        task.add_message("agent", answer)
        task.add_artifact("response", {
            "answer": answer,
            "source_documents": source_docs,
            "metadata": result if isinstance(result, dict) else {}
        }, "json")

        self._mark_completed(task)
        await self._notify_subscribers(task.task_id, task.to_dict())

    async def _process_stream_task(self, task: "A2ATask", vac_name: str):
        """
        Process a streaming task.

        Reads "query", "chat_history" and "stream_settings" ("wait_time",
        "timeout") from the task input. Streaming errors are recorded on the
        task as FAILED instead of being raised.

        Args:
            task: The task to process
            vac_name: Name passed to the streaming function as ``vector_name``
        """
        query = task.input_data.get("query", "")
        chat_history = task.input_data.get("chat_history", [])
        stream_settings = task.input_data.get("stream_settings") or {}

        task.add_message("user", query)
        task.is_streaming = True
        task.update_progress(0.1)
        await self._notify_subscribers(task.task_id, task.to_dict())

        # Convert chat history
        vac_chat_history = self._convert_chat_history(chat_history)

        try:
            # Import streaming function
            from ..streaming import start_streaming_chat_async

            # Start streaming
            full_response = ""
            async for chunk in start_streaming_chat_async(
                question=query,
                vector_name=vac_name,
                qna_func_async=self.stream_interpreter,
                chat_history=vac_chat_history,
                wait_time=stream_settings.get("wait_time", 7),
                timeout=stream_settings.get("timeout", 120)
            ):
                if isinstance(chunk, dict) and 'answer' in chunk:
                    # A dict carrying 'answer' replaces the accumulated text.
                    full_response = chunk['answer']
                    task.update_progress(0.9)
                elif isinstance(chunk, str):
                    full_response += chunk
                    task.update_progress(min(0.8, task.progress + 0.1))

                # Send intermediate updates
                await self._notify_subscribers(task.task_id, task.to_dict())

            # Final response
            if not self._is_terminal(task.state):
                task.add_message("agent", full_response)
                task.add_artifact("streaming_response", {
                    "final_answer": full_response,
                    "stream_completed": True
                }, "json")
                self._mark_completed(task)

        except Exception as e:
            log.error(f"Streaming error for task {task.task_id}: {e}")
            self._mark_failed(task, f"Streaming error: {str(e)}")

        await self._notify_subscribers(task.task_id, task.to_dict())

    async def _process_memory_search_task(self, task: "A2ATask", vac_name: str):
        """
        Process a memory search task.

        Placeholder: no search is performed; a single canned result is
        attached to the task.

        Args:
            task: The task to process
            vac_name: VAC name echoed into the placeholder result metadata
        """
        # This is a placeholder for memory search functionality
        # In a real implementation, this would interface with the VAC's vector store

        search_query = task.input_data.get("search_query", "")
        limit = task.input_data.get("limit", 10)
        similarity_threshold = task.input_data.get("similarity_threshold", 0.7)

        task.add_message("agent", f"Searching memory for: {search_query}")
        task.update_progress(0.5)
        await self._notify_subscribers(task.task_id, task.to_dict())

        # TODO: Implement actual memory search
        # For now, return a placeholder response
        results = [{
            "content": f"Memory search result for '{search_query}' (placeholder)",
            "score": 0.8,
            "metadata": {"vac_name": vac_name, "query": search_query}
        }]

        task.add_artifact("search_results", {
            "results": results,
            "total_results": len(results),
            "query": search_query,
            "limit": limit,
            "similarity_threshold": similarity_threshold
        }, "json")

        self._mark_completed(task)
        await self._notify_subscribers(task.task_id, task.to_dict())

    def _parse_skill_name(self, skill_name: str) -> tuple[str, str]:
        """
        Parse skill name to extract VAC name and operation.

        Expected format: "vac_{operation}_{vac_name}"

        Known operations (including the multi-word "memory_search") are
        matched as a prefix. Anything else falls back to taking the first
        segment after "vac" as the operation.

        Args:
            skill_name: The skill name to parse

        Returns:
            Tuple of (vac_name, operation)

        Raises:
            ValueError: If the name does not start with "vac_" or has fewer
                than three "_"-separated segments
        """
        parts = skill_name.split("_")
        if len(parts) < 3 or parts[0] != "vac":
            raise ValueError(f"Invalid skill name format: {skill_name}")

        remainder = skill_name[len("vac_"):]
        for operation in self._KNOWN_OPERATIONS:
            prefix = operation + "_"
            if remainder.startswith(prefix):
                return remainder[len(prefix):], operation

        # Unknown operation: first segment is the operation, the rest the
        # VAC name (which may contain underscores).
        return "_".join(parts[2:]), parts[1]

    def _convert_chat_history(self, a2a_history: List[Dict[str, Any]]) -> List[Dict[str, str]]:
        """
        Convert A2A chat history format to VAC format.

        "user" messages become {"human": text}; "assistant" and "agent"
        messages become {"ai": text}; other roles are skipped. The text is
        the message's "content", or the concatenated "text" of its "parts"
        when "content" is missing or empty.

        Args:
            a2a_history: A2A format chat history (None is treated as empty)

        Returns:
            VAC format chat history
        """
        vac_history = []

        for msg in a2a_history or []:
            role = msg.get("role", "")
            content = msg.get("content", "")
            if not content:
                parts = msg.get("parts") or []
                content = "".join(
                    part.get("text", "") for part in parts if isinstance(part, dict)
                )

            if role == "user":
                vac_history.append({"human": content})
            elif role in ("assistant", "agent"):
                vac_history.append({"ai": content})

        return vac_history

    async def _notify_subscribers(self, task_id: str, task_data: Dict[str, Any]):
        """
        Notify all subscribers of a task update.

        When the update reports a COMPLETED, FAILED or CANCELED state, the
        ``None`` end signal is queued after it so subscriber generators stop.

        Args:
            task_id: ID of the task the update belongs to
            task_data: Serialized task to deliver
        """
        listeners = self.task_subscribers.get(task_id)
        if not listeners:
            return

        final_values = {
            TaskState.COMPLETED.value,
            TaskState.FAILED.value,
            TaskState.CANCELED.value,
        }
        is_final = task_data.get("state") in final_values

        # Iterate over a copy: subscribers remove themselves when they finish.
        for queue in list(listeners):
            try:
                await queue.put(task_data)
                if is_final:
                    await queue.put(None)
            except Exception as e:
                log.warning(f"Failed to notify subscriber for task {task_id}: {e}")

    def cleanup_completed_tasks(self, max_age_hours: int = 24):
        """
        Clean up completed tasks older than specified age.

        Applies to COMPLETED, FAILED and CANCELED tasks. Subscribers still
        attached to a removed task are released with the end signal.

        Args:
            max_age_hours: Maximum age in hours for completed tasks
        """
        cutoff_time = datetime.now(timezone.utc) - timedelta(hours=max_age_hours)

        tasks_to_remove = [
            task_id
            for task_id, task in self.tasks.items()
            if self._is_terminal(task.state)
            and task.completed_at
            and task.completed_at < cutoff_time
        ]

        for task_id in tasks_to_remove:
            del self.tasks[task_id]
            for queue in self.task_subscribers.pop(task_id, []):
                queue.put_nowait(None)
            self._runners.pop(task_id, None)

        if tasks_to_remove:
            log.info(f"Cleaned up {len(tasks_to_remove)} completed tasks")
|