a2a-lite 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
a2a_lite/discovery.py ADDED
@@ -0,0 +1,148 @@
1
+ """
2
+ mDNS-based local agent discovery using Zeroconf.
3
+ """
4
+ from __future__ import annotations
5
+
6
+ import asyncio
7
+ import socket
8
+ from typing import Dict, List, Optional
9
+ from dataclasses import dataclass
10
+
11
+ from zeroconf import ServiceBrowser, ServiceListener, Zeroconf, ServiceInfo
12
+ from zeroconf.asyncio import AsyncZeroconf, AsyncServiceBrowser
13
+
14
+
15
# DNS-SD service type ("_<service>._<protocol>.local.") under which A2A agents
# are advertised by AgentDiscovery.register() and browsed by AgentDiscovery.discover().
SERVICE_TYPE = "_a2a-agent._tcp.local."
16
+
17
+
18
@dataclass()
class DiscoveredAgent:
    """
    A single A2A agent found while browsing the local network.

    Fields are positional in the order declared below.
    """

    # Instance name as populated by the discovery listener (service type suffix stripped)
    name: str
    # Address string the agent was reached at
    host: str
    # TCP port advertised in the mDNS SRV data
    port: int
    # Base URL built from host and port
    url: str
    # TXT-record properties, decoded to text
    properties: Dict[str, str]
26
+
27
+
28
class AgentDiscovery:
    """
    Discover A2A agents on the local network using mDNS.

    One instance can browse for peers (``discover``) and advertise a single
    local agent at a time (``register`` / ``unregister``).
    """

    def __init__(self):
        # Results of the most recent discover() call, keyed by full service name.
        self._discovered: Dict[str, DiscoveredAgent] = {}
        # Zeroconf instance and ServiceInfo of the agent advertised by
        # register(); both are None while nothing is registered.
        self._zeroconf: Optional[Zeroconf] = None
        self._service_info: Optional[ServiceInfo] = None

    async def discover(self, timeout: float = 5.0) -> List[DiscoveredAgent]:
        """
        Discover agents on the local network.

        Browses for ``SERVICE_TYPE`` for ``timeout`` seconds and returns every
        agent resolved in that window.

        The thread-based ``Zeroconf`` / ``ServiceBrowser`` pair is used
        because the listener resolves services with the blocking
        ``Zeroconf.get_service_info``. zeroconf refuses to run that call on
        its own event loop, which is where ``AsyncServiceBrowser`` callbacks
        execute. ``ServiceBrowser`` delivers callbacks on its own thread
        instead. Blocking setup and teardown calls run in the default
        executor so the caller's event loop is not stalled.

        Args:
            timeout: How long to wait for discoveries, in seconds.

        Returns:
            List of discovered agents.
        """
        self._discovered.clear()
        loop = asyncio.get_running_loop()

        # Constructed off the event loop so zeroconf runs on its own loop thread
        zc = await loop.run_in_executor(None, Zeroconf)
        try:
            listener = _DiscoveryListener(self._discovered)
            browser = ServiceBrowser(zc, SERVICE_TYPE, listener=listener)
            try:
                await asyncio.sleep(timeout)
            finally:
                # cancel() joins the browser thread; also runs if we are cancelled
                await loop.run_in_executor(None, browser.cancel)
        finally:
            await loop.run_in_executor(None, zc.close)

        # The browser thread has been joined, so the dict is no longer mutated
        return list(self._discovered.values())

    def register(
        self,
        name: str,
        port: int,
        properties: Optional[Dict[str, str]] = None,
    ) -> ServiceInfo:
        """
        Register this agent for discovery.

        Only one registration is tracked per instance. A previous
        registration is withdrawn first so its Zeroconf instance is not
        leaked.

        Args:
            name: Agent name. Characters other than alphanumerics and '-'
                are replaced with '-'.
            port: Port the agent is running on.
            properties: Additional properties to advertise in the TXT record.

        Returns:
            The registered ServiceInfo (also kept for ``unregister``).

        Raises:
            Whatever ``Zeroconf.register_service`` raises. The Zeroconf
            instance is closed before the error propagates.
        """
        self.unregister()

        hostname = socket.gethostname()
        local_ip = self._get_local_ip()

        # Sanitize name for mDNS: every char that is not alphanumeric or '-' becomes '-'
        safe_name = "".join(c if c.isalnum() or c == "-" else "-" for c in name)

        info = ServiceInfo(
            SERVICE_TYPE,
            f"{safe_name}.{SERVICE_TYPE}",
            addresses=[socket.inet_aton(local_ip)],
            port=port,
            properties=properties or {},
            server=f"{hostname}.local.",
        )

        zc = Zeroconf()
        try:
            zc.register_service(info)
        except Exception:
            # Don't leak the Zeroconf sockets/threads when registration fails
            zc.close()
            raise

        self._zeroconf = zc
        self._service_info = info
        return info

    def unregister(self) -> None:
        """
        Unregister this agent from discovery.

        Safe to call when nothing is registered. The Zeroconf instance is
        closed even if withdrawing the service raises.
        """
        zc, info = self._zeroconf, self._service_info
        self._zeroconf = None
        self._service_info = None
        if zc is None:
            return
        try:
            if info is not None:
                zc.unregister_service(info)
        finally:
            zc.close()

    def _get_local_ip(self) -> str:
        """
        Return the IPv4 address of the interface used for outbound traffic.

        Connecting a UDP socket sends no packets; it only makes the OS choose
        a route, whose source address is read back. Falls back to
        "127.0.0.1" when no route is available.
        """
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
                s.connect(("8.8.8.8", 80))
                return s.getsockname()[0]
        except OSError:
            return "127.0.0.1"
117
+
118
+
119
class _DiscoveryListener(ServiceListener):
    """
    Internal listener for mDNS service discovery.

    Resolved agents are written into a dict shared with the caller, keyed by
    the full mDNS service name.
    """

    def __init__(self, discovered: Dict[str, DiscoveredAgent]):
        # Mutated in place; owned by the caller.
        self.discovered = discovered

    @staticmethod
    def _to_text(value: object) -> str:
        """Convert a TXT-record key/value to text.

        A value-less TXT key (None) maps to "" rather than the string "None".
        Undecodable bytes are replaced instead of raising, so one malformed
        property does not drop the whole agent.
        """
        if value is None:
            return ""
        if isinstance(value, bytes):
            return value.decode("utf-8", errors="replace")
        return str(value)

    def add_service(self, zc: Zeroconf, type_: str, name: str) -> None:
        """Resolve *name* and record it; services that fail to resolve are ignored."""
        info = zc.get_service_info(type_, name)
        if not info:
            return

        # Strip only the trailing ".<SERVICE_TYPE>" to get the instance name
        suffix = f".{SERVICE_TYPE}"
        agent_name = name[: -len(suffix)] if name.endswith(suffix) else name
        host = socket.inet_ntoa(info.addresses[0]) if info.addresses else "localhost"
        port = info.port

        self.discovered[name] = DiscoveredAgent(
            name=agent_name,
            host=host,
            port=port,
            url=f"http://{host}:{port}",
            properties={
                self._to_text(k): self._to_text(v)
                for k, v in info.properties.items()
            },
        )

    def remove_service(self, zc: Zeroconf, type_: str, name: str) -> None:
        """Forget a service that has left the network."""
        self.discovered.pop(name, None)

    def update_service(self, zc: Zeroconf, type_: str, name: str) -> None:
        """Re-resolve a changed service, replacing its previous entry."""
        self.add_service(zc, type_, name)
a2a_lite/executor.py ADDED
@@ -0,0 +1,317 @@
1
+ """
2
+ Wrapper around A2A's AgentExecutor that dispatches to registered skill handlers.
3
+ """
4
+ from __future__ import annotations
5
+
6
+ import asyncio
7
+ import inspect
8
+ import json
9
+ from typing import Any, Callable, Dict, List, Optional
10
+
11
+ from a2a.server.agent_execution import AgentExecutor, RequestContext
12
+ from a2a.server.events import EventQueue
13
+
14
+ from .decorators import SkillDefinition
15
+ from .middleware import MiddlewareChain, MiddlewareContext
16
+ from .streaming import is_generator_function, stream_generator
17
+
18
+
19
class LiteAgentExecutor(AgentExecutor):
    """
    Simplified AgentExecutor with optional enterprise features.

    Features (all optional):
    - Middleware chain
    - Streaming support
    - Pydantic model conversion
    - Task context injection
    - Interaction context injection
    - Authentication
    - File part handling

    Request text is interpreted as a JSON skill call of the form
    ``{"skill": "<name>", "params": {...}}``. Any other text is passed to
    the first registered skill as ``{"message": <text>}``.
    """

    def __init__(
        self,
        skills: Dict[str, SkillDefinition],
        error_handler: Optional[Callable] = None,
        middleware: Optional[MiddlewareChain] = None,
        on_complete: Optional[List[Callable]] = None,
        auth_provider: Optional[Any] = None,
        task_store: Optional[Any] = None,
    ):
        """
        Args:
            skills: Skill definitions keyed by skill name. The first entry is
                the fallback for requests that do not name a skill.
            error_handler: Optional sync or async callable invoked with the
                raised exception. Its return value is JSON-encoded and sent
                to the client.
            middleware: Chain wrapping skill execution (empty by default).
            on_complete: Hooks called as ``hook(skill_name, result, ctx)``
                after a request completes without raising.
            auth_provider: Stored on the instance; not read by this class.
            task_store: Object exposing ``create(skill_name, params)``, used
                to build a TaskContext for skills that request one.
        """
        self.skills = skills
        self.error_handler = error_handler
        self.middleware = middleware or MiddlewareChain()
        self.on_complete = on_complete or []
        self.auth_provider = auth_provider
        self.task_store = task_store

    async def execute(
        self,
        context: RequestContext,
        event_queue: EventQueue,
    ) -> None:
        """
        Execute a skill based on the incoming request.

        Any exception raised while parsing the request, running middleware,
        or running the skill is reported to the client via ``_handle_error``.
        """
        from a2a.utils import new_agent_text_message

        try:
            message, parts = self._extract_message_and_parts(context)
            skill_name, params = self._parse_message(message)

            ctx = MiddlewareContext(
                skill=skill_name,
                params=params,
                message=message,
            )
            # Expose non-text parts and the queue to middleware and skills
            ctx.metadata["parts"] = parts
            ctx.metadata["event_queue"] = event_queue

            async def final_handler(ctx: MiddlewareContext) -> Any:
                # Middleware may have rewritten ctx.skill / ctx.params by now
                return await self._execute_skill(
                    ctx.skill,
                    ctx.params,
                    event_queue,
                    ctx.metadata,
                )

            result = await self.middleware.execute(ctx, final_handler)

            # Streaming skills return None: their output was already enqueued
            if result is not None:
                if isinstance(result, (dict, list)):
                    response_text = json.dumps(result, indent=2, default=str)
                else:
                    response_text = str(result)
                await event_queue.enqueue_event(new_agent_text_message(response_text))

            await self._run_completion_hooks(skill_name, result, ctx)

        except Exception as e:
            await self._handle_error(e, event_queue)

    async def _run_completion_hooks(
        self,
        skill_name: Optional[str],
        result: Any,
        ctx: MiddlewareContext,
    ) -> None:
        """
        Call every ``on_complete`` hook as ``hook(skill_name, result, ctx)``.

        A failing hook is logged rather than raised. The response has already
        been sent, so a hook error must not be reported as a skill failure.
        """
        import logging

        for hook in self.on_complete:
            try:
                if asyncio.iscoroutinefunction(hook):
                    await hook(skill_name, result, ctx)
                else:
                    hook(skill_name, result, ctx)
            except Exception:
                logging.getLogger(__name__).exception(
                    "on_complete hook %r failed for skill %r", hook, skill_name
                )

    async def _execute_skill(
        self,
        skill_name: Optional[str],
        params: Dict[str, Any],
        event_queue: EventQueue,
        metadata: Dict[str, Any],
    ) -> Any:
        """
        Run one skill and return its result (None for streaming skills).

        A missing skill name selects the first registered skill. An unknown
        name yields an error dict rather than raising.
        """
        if skill_name is None:
            if not self.skills:
                return {"error": "No skills registered"}
            skill_name = next(iter(self.skills))

        if skill_name not in self.skills:
            return {
                "error": f"Unknown skill: {skill_name}",
                "available_skills": list(self.skills.keys()),
            }

        skill_def = self.skills[skill_name]

        # Convert Pydantic models and file parts in params
        params = self._convert_params(skill_def, params, metadata)

        # Inject special contexts if needed
        if skill_def.needs_task_context and self.task_store:
            from .tasks import TaskContext
            task = self.task_store.create(skill_name, params)
            # Only pass event_queue for streaming skills (status updates go via SSE)
            eq = event_queue if skill_def.is_streaming else None
            params["task"] = TaskContext(task, eq)

        if skill_def.needs_interaction:
            from .human_loop import InteractionContext
            # NOTE(review): nothing in this class sets metadata["task_id"];
            # presumably middleware does, otherwise "unknown" is used.
            task_id = metadata.get("task_id", "unknown")
            params["ctx"] = InteractionContext(task_id, event_queue)

        handler = skill_def.handler

        if skill_def.is_streaming or is_generator_function(handler):
            gen = handler(**params)
            await stream_generator(gen, event_queue)
            return None
        return await self._call_handler(handler, **params)

    @staticmethod
    def _resolve_type_hints(handler: Callable) -> Dict[str, Any]:
        """
        Return *handler*'s annotations with string annotations evaluated.

        Handlers in modules using ``from __future__ import annotations`` have
        only string annotations. Strings have no ``model_validate``, so
        Pydantic conversion would otherwise silently never happen.
        Non-string annotations are returned unchanged. Strings that cannot be
        evaluated are kept as raw strings.
        """
        import typing

        raw = getattr(handler, "__annotations__", None) or {}
        hints = dict(raw)
        if any(isinstance(v, str) for v in raw.values()):
            try:
                resolved = typing.get_type_hints(handler)
            except Exception:
                # e.g. names only imported under TYPE_CHECKING
                resolved = {}
            for key, value in raw.items():
                if isinstance(value, str) and key in resolved:
                    hints[key] = resolved[key]
        return hints

    def _convert_params(
        self,
        skill_def: SkillDefinition,
        params: Dict[str, Any],
        metadata: Dict[str, Any],
    ) -> Dict[str, Any]:
        """
        Convert parameters to Pydantic models and file/data parts if needed.

        Params annotated with TaskContext or InteractionContext are dropped;
        those contexts are injected by ``_execute_skill``. Returns a new dict;
        *params* is not mutated.
        """
        hints = self._resolve_type_hints(skill_def.handler)

        converted = {}
        for param_name, value in params.items():
            param_type = hints.get(param_name)

            if param_type is None:
                converted[param_name] = value
                continue

            type_name = str(param_type)

            # Skip special context types
            if "TaskContext" in type_name or "InteractionContext" in type_name:
                continue

            # Convert FilePart
            if "FilePart" in type_name:
                from .parts import FilePart
                if isinstance(value, dict):
                    # Handle both A2A format and simple dict format
                    if "file" in value:
                        converted[param_name] = FilePart.from_a2a(value)
                    else:
                        # Simple format: {name, data, mime_type}
                        data = value.get("data")
                        if isinstance(data, str):
                            data = data.encode("utf-8")
                        converted[param_name] = FilePart(
                            name=value.get("name", "unknown"),
                            mime_type=value.get("mime_type", "application/octet-stream"),
                            data=data,
                            uri=value.get("uri"),
                        )
                else:
                    converted[param_name] = value
                continue

            # Convert DataPart
            if "DataPart" in type_name:
                from .parts import DataPart
                if isinstance(value, dict):
                    # Handle both A2A format and simple dict format
                    if "type" in value and value.get("type") == "data":
                        converted[param_name] = DataPart.from_a2a(value)
                    else:
                        # Simple format: pass the dict directly as data
                        converted[param_name] = DataPart(data=value)
                else:
                    converted[param_name] = value
                continue

            # Convert Pydantic models
            if hasattr(param_type, 'model_validate'):
                if isinstance(value, dict):
                    converted[param_name] = param_type.model_validate(value)
                else:
                    converted[param_name] = value
                continue

            # Default: keep as-is
            converted[param_name] = value

        return converted

    def _parse_message(self, message: str) -> tuple[Optional[str], Dict[str, Any]]:
        """
        Parse message text into ``(skill_name, params)``.

        A JSON object with a "skill" key is a skill call. A missing or null
        "params" means no params. Anything else is free text, returned as
        ``(None, {"message": message})``.

        Raises:
            TypeError: "params" is present but is not a JSON object.
        """
        try:
            data = json.loads(message)
        except json.JSONDecodeError:
            data = None

        if isinstance(data, dict) and 'skill' in data:
            params = data.get('params')
            if params is None:
                params = {}
            if not isinstance(params, dict):
                raise TypeError(
                    f"'params' must be a JSON object, got {type(params).__name__}"
                )
            return data['skill'], params

        return None, {"message": message}

    def _extract_message_and_parts(self, context: RequestContext) -> tuple[str, List[Any]]:
        """
        Extract message text and any file/data parts.

        If several text parts are present, the last one wins. Only dict-shaped
        parts whose "type"/"kind" is "file" or "data" are collected.
        """
        text = ""
        parts = []

        if hasattr(context, 'message') and context.message:
            message = context.message
            if hasattr(message, 'parts'):
                raw_parts = message.parts
            else:
                raw_parts = message.get('parts', [])

            for part in raw_parts:
                # Unwrap RootModel-style wrappers exposing the payload as .root
                if hasattr(part, 'root'):
                    part = part.root

                if hasattr(part, 'text'):
                    text = part.text
                elif isinstance(part, dict):
                    part_type = part.get('type') or part.get('kind')
                    if part_type == 'text':
                        text = part.get('text', '')
                    elif part_type in ('file', 'data'):
                        parts.append(part)
                # NOTE(review): non-dict parts without .text (e.g. typed
                # file/data part objects, if the SDK delivers them) are not
                # collected here — confirm this is intended.

        return text, parts

    async def _handle_error(self, e: Exception, event_queue: EventQueue) -> None:
        """
        Report *e* to the client as a JSON text message.

        Uses ``error_handler``'s return value when one is configured. If the
        handler itself fails, both errors are reported.
        """
        from a2a.utils import new_agent_text_message

        if self.error_handler:
            try:
                result = await self._call_handler(self.error_handler, e)
                await event_queue.enqueue_event(
                    new_agent_text_message(json.dumps(result, default=str))
                )
                return
            except Exception as handler_error:
                await event_queue.enqueue_event(
                    new_agent_text_message(json.dumps({
                        "error": str(e),
                        "handler_error": str(handler_error),
                        "type": type(e).__name__,
                    }))
                )
                return

        await event_queue.enqueue_event(
            new_agent_text_message(json.dumps({
                "error": str(e),
                "type": type(e).__name__,
            }))
        )

    async def cancel(
        self,
        context: RequestContext,
        event_queue: EventQueue,
    ) -> None:
        """Handle cancellation requests by emitting a "cancelled" status message."""
        from a2a.utils import new_agent_text_message
        await event_queue.enqueue_event(
            new_agent_text_message(json.dumps({"status": "cancelled"}))
        )

    async def _call_handler(self, handler: Callable, *args, **kwargs) -> Any:
        """
        Call *handler*, handling both sync and async functions.

        Coroutine functions are awaited directly. Other callables run in the
        default executor so blocking code does not stall the event loop. If
        such a callable returns an awaitable (e.g. an object whose
        ``__call__`` is ``async def``), that awaitable is awaited too.
        """
        if asyncio.iscoroutinefunction(handler):
            return await handler(*args, **kwargs)

        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(None, lambda: handler(*args, **kwargs))
        if inspect.isawaitable(result):
            result = await result
        return result