sunholo 0.143.1__py3-none-any.whl → 0.143.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,480 @@
1
+ # Copyright [2024] [Holosun ApS]
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ A2A Task Management for VAC interactions.
17
+ Handles the lifecycle of A2A tasks and their state transitions.
18
+ """
19
+
20
+ import asyncio
21
+ import uuid
22
+ from datetime import datetime, timezone, timedelta
23
+ from typing import Dict, List, Any, Optional, Callable, AsyncGenerator
24
+ from enum import Enum
25
+ import json
26
+ from ..custom_logging import log
27
+
28
+
29
class TaskState(Enum):
    """Lifecycle states of an A2A task.

    Each member's value is the string emitted in the ``state`` field of a
    serialized task.
    """

    # State assigned to a task when it is first created.
    SUBMITTED = "submitted"
    # The task is currently being processed.
    WORKING = "working"
    # Processing is waiting for further input (not assigned in this module).
    INPUT_REQUIRED = "input-required"
    # Terminal: processing finished successfully.
    COMPLETED = "completed"
    # Terminal: the task was canceled on request.
    CANCELED = "canceled"
    # Terminal: processing raised an error.
    FAILED = "failed"
    # Fallback value (not assigned anywhere in this module).
    UNKNOWN = "unknown"
38
+
39
+
40
class A2ATask:
    """A single A2A task: its identity, lifecycle state and accumulated output.

    The task records the messages exchanged, the artifacts produced, a
    progress value in ``[0.0, 1.0]`` and timezone-aware UTC timestamps for
    creation, last update and completion.
    """

    def __init__(self, task_id: str, skill_name: str, input_data: Dict[str, Any],
                 client_metadata: Optional[Dict[str, Any]] = None):
        """
        Initialize an A2A task in the ``SUBMITTED`` state.

        Args:
            task_id: Unique task identifier
            skill_name: Name of the skill being invoked
            input_data: Input parameters for the task
            client_metadata: Optional metadata from the client
        """
        self.task_id = task_id
        self.skill_name = skill_name
        self.input_data = input_data
        self.client_metadata = client_metadata or {}

        self.state = TaskState.SUBMITTED
        self.created_at = datetime.now(timezone.utc)
        self.updated_at = self.created_at
        self.completed_at: Optional[datetime] = None

        self.messages: List[Dict[str, Any]] = []
        self.artifacts: List[Dict[str, Any]] = []
        self.error: Optional[Dict[str, Any]] = None
        self.progress: float = 0.0

        # For streaming tasks
        self.stream_queue: Optional[asyncio.Queue] = None
        self.is_streaming = False

    def to_dict(self) -> Dict[str, Any]:
        """Return a snapshot of the task in the A2A response format.

        The message list, artifact list and client metadata are shallow
        copies, so a dictionary that was already handed out (for example one
        sitting in a subscriber queue) does not change when the task is
        modified afterwards.
        """
        return {
            "taskId": self.task_id,
            "state": self.state.value,
            "createdAt": self.created_at.isoformat(),
            "updatedAt": self.updated_at.isoformat(),
            "completedAt": self.completed_at.isoformat() if self.completed_at else None,
            "messages": list(self.messages),
            "artifacts": list(self.artifacts),
            "error": self.error,
            "progress": self.progress,
            "metadata": {
                "skillName": self.skill_name,
                "isStreaming": self.is_streaming,
                "clientMetadata": dict(self.client_metadata),
            },
        }

    def add_message(self, role: str, content: str, message_type: str = "text"):
        """Append a single-part message to the task.

        Args:
            role: Sender of the message (this module uses "user" and "agent")
            content: Text of the message
            message_type: Value stored in the part's ``type`` field
        """
        # One clock reading keeps the message timestamp and updated_at identical.
        now = datetime.now(timezone.utc)
        self.messages.append({
            "role": role,
            "parts": [{
                "type": message_type,
                "text": content,
            }],
            "timestamp": now.isoformat(),
        })
        self.updated_at = now

    def add_artifact(self, name: str, content: Any, artifact_type: str = "text"):
        """Append an artifact to the task.

        Args:
            name: Artifact name
            content: Artifact payload, stored as given
            artifact_type: Value stored in the artifact's ``type`` field
        """
        # One clock reading keeps createdAt and updated_at identical.
        now = datetime.now(timezone.utc)
        self.artifacts.append({
            "name": name,
            "type": artifact_type,
            "content": content,
            "createdAt": now.isoformat(),
        })
        self.updated_at = now

    def update_state(self, new_state: "TaskState", error_message: Optional[str] = None):
        """Move the task to ``new_state``.

        Terminal states (completed, failed, canceled) are final: once the
        task is in one of them, further calls are ignored. This keeps a
        cancellation from being overwritten by a processor that is still
        running.

        Args:
            new_state: State to transition to
            error_message: Stored in ``self.error`` when the new state is FAILED
        """
        terminal_states = (TaskState.COMPLETED, TaskState.FAILED, TaskState.CANCELED)
        if self.state in terminal_states:
            return

        self.state = new_state
        self.updated_at = datetime.now(timezone.utc)

        if new_state in terminal_states:
            self.completed_at = self.updated_at

        if new_state == TaskState.FAILED and error_message:
            self.error = {
                "message": error_message,
                "timestamp": self.updated_at.isoformat(),
            }

    def update_progress(self, progress: float):
        """Set the task progress, clamped to the range 0.0 to 1.0."""
        self.progress = max(0.0, min(1.0, progress))
        self.updated_at = datetime.now(timezone.utc)
134
+
135
+
136
class A2ATaskManager:
    """Creates, runs, tracks and cleans up A2A tasks.

    Each task is processed by a background asyncio job. Subscribers receive
    task snapshots through per-subscriber queues.
    """

    # Operation names that themselves contain an underscore. They have to be
    # matched as a whole before the generic "split on '_'" parsing, otherwise
    # "vac_memory_search_x" would be read as operation "memory".
    _COMPOUND_OPERATIONS = ("memory_search",)

    def __init__(self, stream_interpreter: Callable, vac_interpreter: Optional[Callable] = None):
        """
        Initialize the task manager.

        Args:
            stream_interpreter: Function for streaming VAC interactions
            vac_interpreter: Function for static VAC interactions
        """
        self.stream_interpreter = stream_interpreter
        self.vac_interpreter = vac_interpreter
        self.tasks: Dict[str, "A2ATask"] = {}
        self.task_subscribers: Dict[str, List[asyncio.Queue]] = {}
        # Strong references to the background jobs. The event loop only keeps
        # weak references, so an unreferenced job could be garbage-collected
        # before it finishes. Also used to cancel a running job.
        self._jobs: Dict[str, asyncio.Task] = {}

    @staticmethod
    def _terminal_states() -> tuple:
        """Return the states after which a task receives no further updates."""
        return (TaskState.COMPLETED, TaskState.FAILED, TaskState.CANCELED)

    async def create_task(self, skill_name: str, input_data: Dict[str, Any],
                          client_metadata: Optional[Dict[str, Any]] = None) -> "A2ATask":
        """
        Create a new A2A task and start processing it in the background.

        Args:
            skill_name: Name of the skill to invoke
            input_data: Input parameters for the task
            client_metadata: Optional client metadata

        Returns:
            Created A2ATask instance
        """
        task_id = str(uuid.uuid4())
        task = A2ATask(task_id, skill_name, input_data, client_metadata)

        self.tasks[task_id] = task
        self.task_subscribers[task_id] = []

        log.info(f"Created A2A task {task_id} for skill {skill_name}")

        # Start processing the task and keep a reference to the job until it
        # is done.
        job = asyncio.create_task(self._process_task(task))
        self._jobs[task_id] = job
        job.add_done_callback(
            lambda finished, tid=task_id: self._on_job_done(tid, finished)
        )

        return task

    def _on_job_done(self, task_id: str, job: "asyncio.Task") -> None:
        """Drop the reference to a finished job and log any unhandled error."""
        if self._jobs.get(task_id) is job:
            del self._jobs[task_id]
        if job.cancelled():
            return
        error = job.exception()
        if error is not None:
            log.error(f"Unhandled error in A2A task {task_id}: {error}")

    async def get_task(self, task_id: str) -> Optional["A2ATask"]:
        """Get a task by ID, or None when the ID is unknown."""
        return self.tasks.get(task_id)

    async def cancel_task(self, task_id: str) -> bool:
        """
        Cancel a task and stop its background processing.

        Args:
            task_id: ID of the task to cancel

        Returns:
            True if task was canceled, False if not found or already completed
        """
        task = self.tasks.get(task_id)
        if not task:
            return False

        if task.state in self._terminal_states():
            return False

        task.update_state(TaskState.CANCELED)

        # Stop the processing coroutine as well; otherwise it would keep
        # running and later replace the CANCELED state with its own result.
        job = self._jobs.get(task_id)
        if job is not None and not job.done():
            job.cancel()

        await self._notify_subscribers(task_id, task.to_dict())

        log.info(f"Canceled A2A task {task_id}")
        return True

    async def subscribe_to_task(self, task_id: str) -> AsyncGenerator[Dict[str, Any], None]:
        """
        Subscribe to task updates via async generator.

        The current task snapshot is yielded first. The generator ends after
        it yields a snapshot in a terminal state (completed, failed or
        canceled) or when a ``None`` end signal is received. For an unknown
        task ID it yields nothing.

        Args:
            task_id: ID of the task to subscribe to

        Yields:
            Task update dictionaries
        """
        if task_id not in self.tasks:
            return  # Early return for async generator

        queue = asyncio.Queue()
        self.task_subscribers.setdefault(task_id, []).append(queue)
        terminal_values = {state.value for state in self._terminal_states()}

        try:
            # Send current state immediately
            await queue.put(self.tasks[task_id].to_dict())

            while True:
                update = await queue.get()
                if update is None:  # End signal
                    break
                yield update
                # No updates follow a terminal state, so stop here instead of
                # waiting on the queue forever.
                if isinstance(update, dict) and update.get("state") in terminal_values:
                    break
        finally:
            # Clean up subscription
            subscribers = self.task_subscribers.get(task_id, [])
            if queue in subscribers:
                subscribers.remove(queue)

    async def _process_task(self, task: "A2ATask"):
        """
        Process a task by invoking the handler for its skill's operation.

        Any exception raised while processing marks the task as FAILED and
        is reported to subscribers; it is not re-raised.

        Args:
            task: The task to process
        """
        try:
            task.update_state(TaskState.WORKING)
            await self._notify_subscribers(task.task_id, task.to_dict())

            # Parse skill name to extract VAC name and operation
            vac_name, operation = self._parse_skill_name(task.skill_name)

            handlers = {
                "query": self._process_query_task,
                "stream": self._process_stream_task,
                "memory_search": self._process_memory_search_task,
            }
            handler = handlers.get(operation)
            if handler is None:
                raise ValueError(f"Unknown operation: {operation}")

            await handler(task, vac_name)

        except Exception as e:
            log.error(f"Error processing task {task.task_id}: {e}")
            task.update_state(TaskState.FAILED, str(e))
            await self._notify_subscribers(task.task_id, task.to_dict())

    async def _process_query_task(self, task: "A2ATask", vac_name: str):
        """Process a static query task with ``vac_interpreter``.

        Reads ``query``, ``chat_history`` and ``context`` from the task
        input. ``context`` entries are passed to the interpreter as extra
        keyword arguments.

        Raises:
            ValueError: If no VAC interpreter is configured or ``context``
                is not a dictionary.
        """
        if not self.vac_interpreter:
            raise ValueError("VAC interpreter not available for query tasks")

        query = task.input_data.get("query", "")
        chat_history = task.input_data.get("chat_history", [])
        context = task.input_data.get("context") or {}
        if not isinstance(context, dict):
            raise ValueError("Task 'context' must be a dictionary")

        task.add_message("user", query)
        task.update_progress(0.3)
        await self._notify_subscribers(task.task_id, task.to_dict())

        # Convert A2A chat history to VAC format
        vac_chat_history = self._convert_chat_history(chat_history)

        # The explicit arguments are listed last so they win over same-named
        # keys in the client-supplied context. Passing both separately would
        # raise "got multiple values for keyword argument".
        call_kwargs = {
            **context,
            "question": query,
            "vector_name": vac_name,
            "chat_history": vac_chat_history,
        }

        # Execute VAC query
        if asyncio.iscoroutinefunction(self.vac_interpreter):
            result = await self.vac_interpreter(**call_kwargs)
        else:
            # Run sync function in executor so it does not block the event loop
            loop = asyncio.get_running_loop()
            result = await loop.run_in_executor(
                None,
                lambda: self.vac_interpreter(**call_kwargs)
            )

        # Process result
        if isinstance(result, dict):
            answer = result.get("answer", str(result))
            source_docs = result.get("source_documents", [])
        else:
            answer = str(result)
            source_docs = []

        task.add_message("agent", answer)
        task.add_artifact("response", {
            "answer": answer,
            "source_documents": source_docs,
            "metadata": result if isinstance(result, dict) else {}
        }, "json")

        task.update_progress(1.0)
        task.update_state(TaskState.COMPLETED)
        await self._notify_subscribers(task.task_id, task.to_dict())

    async def _process_stream_task(self, task: "A2ATask", vac_name: str):
        """Process a streaming task with ``stream_interpreter``.

        String chunks are appended to the response. A dictionary chunk with
        an ``answer`` key replaces the response collected so far. Errors
        raised while streaming mark the task as FAILED.
        """
        query = task.input_data.get("query", "")
        chat_history = task.input_data.get("chat_history", [])
        stream_settings = task.input_data.get("stream_settings") or {}

        task.add_message("user", query)
        task.is_streaming = True
        task.update_progress(0.1)
        await self._notify_subscribers(task.task_id, task.to_dict())

        # Convert chat history
        vac_chat_history = self._convert_chat_history(chat_history)

        try:
            # Import streaming function
            from ..streaming import start_streaming_chat_async

            # Start streaming
            pieces: List[str] = []
            async for chunk in start_streaming_chat_async(
                question=query,
                vector_name=vac_name,
                qna_func_async=self.stream_interpreter,
                chat_history=vac_chat_history,
                wait_time=stream_settings.get("wait_time", 7),
                timeout=stream_settings.get("timeout", 120)
            ):
                if isinstance(chunk, dict) and 'answer' in chunk:
                    pieces = [chunk['answer']]
                    task.update_progress(0.9)
                elif isinstance(chunk, str):
                    pieces.append(chunk)
                    task.update_progress(min(0.8, task.progress + 0.1))

                # Send intermediate updates
                await self._notify_subscribers(task.task_id, task.to_dict())

            # Final response
            full_response = "".join(pieces)
            task.add_message("agent", full_response)
            task.add_artifact("streaming_response", {
                "final_answer": full_response,
                "stream_completed": True
            }, "json")

            task.update_progress(1.0)
            task.update_state(TaskState.COMPLETED)

        except Exception as e:
            log.error(f"Streaming error in task {task.task_id}: {e}")
            task.update_state(TaskState.FAILED, f"Streaming error: {str(e)}")

        await self._notify_subscribers(task.task_id, task.to_dict())

    async def _process_memory_search_task(self, task: "A2ATask", vac_name: str):
        """Process a memory search task.

        This is a placeholder: it returns one fixed result instead of
        querying the VAC's vector store.
        """
        search_query = task.input_data.get("search_query", "")
        limit = task.input_data.get("limit", 10)
        similarity_threshold = task.input_data.get("similarity_threshold", 0.7)

        task.add_message("agent", f"Searching memory for: {search_query}")
        task.update_progress(0.5)
        await self._notify_subscribers(task.task_id, task.to_dict())

        # TODO: Implement actual memory search
        # For now, return a placeholder response
        results = [{
            "content": f"Memory search result for '{search_query}' (placeholder)",
            "score": 0.8,
            "metadata": {"vac_name": vac_name, "query": search_query}
        }]

        task.add_artifact("search_results", {
            "results": results,
            "total_results": len(results),
            "query": search_query,
            "limit": limit,
            "similarity_threshold": similarity_threshold
        }, "json")

        task.update_progress(1.0)
        task.update_state(TaskState.COMPLETED)
        await self._notify_subscribers(task.task_id, task.to_dict())

    def _parse_skill_name(self, skill_name: str) -> tuple[str, str]:
        """
        Parse skill name to extract VAC name and operation.

        Expected format: "vac_{operation}_{vac_name}". Operations listed in
        ``_COMPOUND_OPERATIONS`` (such as "memory_search") are matched as a
        whole; any other operation is the single segment after "vac".

        Args:
            skill_name: The skill name to parse

        Returns:
            Tuple of (vac_name, operation)

        Raises:
            ValueError: If the name does not start with "vac_" or has fewer
                than three "_"-separated segments.
        """
        prefix = "vac_"
        if skill_name.startswith(prefix):
            remainder = skill_name[len(prefix):]
            for operation in self._COMPOUND_OPERATIONS:
                marker = operation + "_"
                if remainder.startswith(marker) and len(remainder) > len(marker):
                    return remainder[len(marker):], operation

        parts = skill_name.split("_")
        if len(parts) < 3 or parts[0] != "vac":
            raise ValueError(f"Invalid skill name format: {skill_name}")

        operation = parts[1]  # query, stream
        vac_name = "_".join(parts[2:])  # Handle VAC names with underscores

        return vac_name, operation

    def _convert_chat_history(self, a2a_history: List[Dict[str, Any]]) -> List[Dict[str, str]]:
        """
        Convert A2A chat history format to VAC format.

        "user" messages become ``{"human": text}``; "assistant" and "agent"
        messages become ``{"ai": text}``; other roles are skipped. The text
        is the message's ``content`` when that key is present, otherwise the
        concatenated ``text`` of its ``parts`` (the shape this module's own
        task messages use).

        Args:
            a2a_history: A2A format chat history

        Returns:
            VAC format chat history
        """
        vac_history = []

        for msg in a2a_history or []:
            if not isinstance(msg, dict):
                continue

            role = msg.get("role", "")
            if "content" in msg:
                content = msg["content"]
            else:
                parts = msg.get("parts") or []
                content = "".join(
                    part["text"]
                    for part in parts
                    if isinstance(part, dict) and isinstance(part.get("text"), str)
                )

            if role == "user":
                vac_history.append({"human": content})
            elif role in ("assistant", "agent"):
                vac_history.append({"ai": content})

        return vac_history

    async def _notify_subscribers(self, task_id: str, task_data: Dict[str, Any]):
        """Put ``task_data`` on the queue of every subscriber of ``task_id``."""
        # Iterate over a copy: a subscriber may unregister while we notify.
        for queue in list(self.task_subscribers.get(task_id, ())):
            try:
                await queue.put(task_data)
            except Exception as e:
                log.warning(f"Failed to notify subscriber for task {task_id}: {e}")

    def cleanup_completed_tasks(self, max_age_hours: int = 24):
        """
        Clean up completed tasks older than specified age.

        Removes tasks in a terminal state whose completion time is older
        than the cutoff, together with their subscriber lists.

        Args:
            max_age_hours: Maximum age in hours for completed tasks
        """
        cutoff_time = datetime.now(timezone.utc) - timedelta(hours=max_age_hours)
        terminal_states = self._terminal_states()

        tasks_to_remove = [
            task_id
            for task_id, task in self.tasks.items()
            if task.state in terminal_states
            and task.completed_at
            and task.completed_at < cutoff_time
        ]

        for task_id in tasks_to_remove:
            del self.tasks[task_id]
            self.task_subscribers.pop(task_id, None)
            self._jobs.pop(task_id, None)

        if tasks_to_remove:
            log.info(f"Cleaned up {len(tasks_to_remove)} completed tasks")