a2a-lite 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
a2a_lite/streaming.py ADDED
@@ -0,0 +1,89 @@
1
+ """
2
+ Streaming support for A2A Lite agents.
3
+
4
+ Enables generator-based streaming for LLM-style responses:
5
+
6
+ @agent.skill("chat", streaming=True)
7
+ async def chat(message: str):
8
+ async for chunk in llm.stream(message):
9
+ yield chunk
10
+ """
11
+ from __future__ import annotations
12
+
13
+ from typing import Any, AsyncGenerator, Callable, Generator, Union
14
+ import asyncio
15
+ import inspect
16
+
17
+
18
def is_generator_function(func: Callable) -> bool:
    """
    Return True when *func* is a generator function, sync or async.

    Ordinary functions and ``async def`` coroutine functions that never
    ``yield`` give False.
    """
    detectors = (inspect.isgeneratorfunction, inspect.isasyncgenfunction)
    return any(detect(func) for detect in detectors)
24
+
25
+
26
async def collect_generator(
    gen: Union[Generator, AsyncGenerator]
) -> list[Any]:
    """
    Drain *gen* and return every item it produced, in order.

    Async generators are consumed with ``async for``; anything else is
    iterated as an ordinary (synchronous) iterable.
    """
    if not inspect.isasyncgen(gen):
        return list(gen)
    return [item async for item in gen]
38
+
39
+
40
async def stream_generator(
    gen: Union[Generator, AsyncGenerator],
    event_queue,
) -> None:
    """
    Push each item of *gen* to *event_queue* as its own agent text message.

    Non-string items are converted with ``str()`` first. Async generators
    are consumed with ``async for``; anything else is iterated normally.
    """
    from a2a.utils import new_agent_text_message

    async def _emit(chunk: Any) -> None:
        # One queue event per produced item.
        text = chunk if isinstance(chunk, str) else str(chunk)
        await event_queue.enqueue_event(new_agent_text_message(text))

    if not inspect.isasyncgen(gen):
        for chunk in gen:
            await _emit(chunk)
        return

    async for chunk in gen:
        await _emit(chunk)
59
+
60
+
61
class StreamingResponse:
    """
    Buffer of text chunks that a skill can return instead of yielding.

    ``write()`` appends a chunk. The object can then be consumed with a
    plain ``for`` loop or with ``async for``; both replay the chunks in
    the order they were written.

    Example:
        @agent.skill("count")
        async def count(n: int):
            stream = StreamingResponse()
            for i in range(n):
                stream.write(f"Count: {i}")
            return stream
    """

    def __init__(self):
        # Chunks in write order; both iteration protocols read this list.
        self._chunks: list[str] = list()

    def write(self, chunk: str) -> None:
        """Append *chunk* to the end of the buffer, stored as-is."""
        buffer = self._chunks
        buffer.append(chunk)

    def __iter__(self):
        """Iterate synchronously over the buffered chunks."""
        return iter(self._chunks)

    def __aiter__(self):
        """Iterate asynchronously over the buffered chunks."""
        return self._async_iter()

    async def _async_iter(self):
        # Walk the same live list that __iter__ uses, one chunk per step.
        for piece in self._chunks:
            yield piece
a2a_lite/tasks.py ADDED
@@ -0,0 +1,221 @@
1
+ """
2
+ Task lifecycle management (OPTIONAL).
3
+
4
+ By default, A2A Lite just returns results (simple mode).
5
+ Enable task tracking when you need:
6
+ - Progress updates
7
+ - Long-running tasks
8
+ - Task status visibility
9
+
10
+ Example (simple - no task tracking needed):
11
+ @agent.skill("greet")
12
+ async def greet(name: str) -> str:
13
+ return f"Hello, {name}!"
14
+
15
+ Example (with task tracking - opt-in):
16
+ agent = Agent(name="Bot", task_store="memory") # Enable tracking
17
+
18
+ @agent.skill("process")
19
+ async def process(data: str, task: TaskContext) -> str:
20
+ await task.update("working", "Starting...", progress=0.0)
21
+
22
+ for i in range(10):
23
+ await task.update("working", f"Step {i+1}/10", progress=i/10)
24
+ await do_work(i)
25
+
26
+ return "Done!"
27
+ """
28
from __future__ import annotations

import asyncio
import inspect
import logging
from dataclasses import dataclass, field
from datetime import datetime, timezone
from enum import Enum
from typing import Any, Callable, Dict, List, Optional
from uuid import uuid4
36
+
37
+
38
class TaskState(str, Enum):
    """
    A2A Protocol task states.

    Each value is the wire string used by the protocol, so
    ``TaskState("working")`` parses a state received as text. Because the
    class also subclasses ``str``, members compare equal to their values.
    """
    SUBMITTED = "submitted"
    WORKING = "working"
    INPUT_REQUIRED = "input-required"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELED = "canceled"
    AUTH_REQUIRED = "auth-required"
    # Also defined by the A2A protocol. Appended last so the existing
    # members keep their iteration order.
    REJECTED = "rejected"
    UNKNOWN = "unknown"
47
+
48
+
49
@dataclass
class TaskStatus:
    """
    Status of a task at a single point in time.

    ``timestamp`` is a naive ``datetime`` holding the current UTC time.
    This is the same value ``datetime.utcnow()`` used to produce, but it
    is obtained through the non-deprecated ``datetime.now(timezone.utc)``.
    """
    state: TaskState
    message: Optional[str] = None
    progress: Optional[float] = None  # 0.0 to 1.0 (not validated here)
    # tzinfo is stripped so the value stays naive, exactly as before, and
    # comparisons with other naive timestamps keep working.
    timestamp: datetime = field(
        default_factory=lambda: datetime.now(timezone.utc).replace(tzinfo=None)
    )

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-serializable dict: state value, message, progress, ISO timestamp."""
        return {
            "state": self.state.value,
            "message": self.message,
            "progress": self.progress,
            "timestamp": self.timestamp.isoformat(),
        }
64
+
65
+
66
def _utcnow_naive() -> datetime:
    """Return the current UTC time as a naive datetime (non-deprecated ``utcnow()``)."""
    return datetime.now(timezone.utc).replace(tzinfo=None)


@dataclass
class Task:
    """
    Represents an A2A task.

    ``status`` is the current status. ``history`` holds the earlier
    ``TaskStatus`` objects, oldest first. ``created_at`` and ``updated_at``
    are naive UTC datetimes.
    """
    id: str
    skill: str
    params: Dict[str, Any]
    status: TaskStatus
    result: Any = None
    error: Optional[str] = None
    artifacts: List[Any] = field(default_factory=list)
    history: List[TaskStatus] = field(default_factory=list)
    created_at: datetime = field(default_factory=_utcnow_naive)
    updated_at: datetime = field(default_factory=_utcnow_naive)

    def update_status(
        self,
        state: TaskState,
        message: Optional[str] = None,
        progress: Optional[float] = None,
    ) -> None:
        """
        Replace the current status, archiving the previous one in ``history``.

        Args:
            state: The new task state.
            message: Optional human-readable status message.
            progress: Optional progress value (0.0 to 1.0 by convention).
        """
        self.history.append(self.status)
        self.status = TaskStatus(state=state, message=message, progress=progress)
        self.updated_at = _utcnow_naive()
90
+
91
+
92
class TaskContext:
    """
    Context passed to skills when task tracking is enabled.

    Wraps a ``Task`` and lets a skill report status changes. Each call to
    ``update()`` does three things:

    - records the new status on the task;
    - invokes the registered status callbacks (sync or async);
    - enqueues a JSON status message on the event queue, when one was
      supplied.

    Example:
        @agent.skill("process")
        async def process(data: str, task: TaskContext) -> str:
            await task.update("working", "Processing...", progress=0.5)
            result = await heavy_computation(data)
            return result
    """

    def __init__(self, task: Task, event_queue=None, input_handler=None):
        """
        Args:
            task: The task whose status this context updates.
            event_queue: Optional object exposing an async
                ``enqueue_event``; status updates are pushed to it.
            input_handler: Stored on the instance; not used by any method
                of this class yet.
        """
        self._task = task
        self._event_queue = event_queue
        self._input_handler = input_handler
        self._status_callbacks: List[Callable] = []

    @property
    def task_id(self) -> str:
        """ID of the wrapped task."""
        return self._task.id

    @property
    def state(self) -> TaskState:
        """Current state of the wrapped task."""
        return self._task.status.state

    async def update(
        self,
        state: str = "working",
        message: Optional[str] = None,
        progress: Optional[float] = None,
    ) -> None:
        """
        Update task status.

        Args:
            state: Task state, as a ``TaskState`` or its string value
                (e.g. "working", "completed", "failed").
            message: Human-readable status message.
            progress: Progress from 0.0 to 1.0. This value is not validated.

        Raises:
            ValueError: If ``state`` is a string that is not a valid
                ``TaskState`` value.

        Example:
            await task.update("working", "Processing item 5/10", progress=0.5)
        """
        task_state = TaskState(state) if isinstance(state, str) else state
        self._task.update_status(task_state, message, progress)

        await self._notify_callbacks()

        # Compare with None rather than relying on truthiness: a queue
        # object may define __len__ and therefore be falsy while empty.
        if self._event_queue is not None:
            await self._send_status_event()

    async def _notify_callbacks(self) -> None:
        """Invoke each status callback; a failing callback is logged, not raised."""
        for callback in self._status_callbacks:
            try:
                # Call first, then await any awaitable result. This covers
                # async functions, async callable objects and partials.
                outcome = callback(self._task.status)
                if inspect.isawaitable(outcome):
                    await outcome
            except Exception:
                # Best effort: one broken observer must not abort the
                # skill, but the failure must not disappear silently.
                logging.getLogger(__name__).exception(
                    "Status callback %r failed for task %s",
                    callback,
                    self._task.id,
                )

    async def _send_status_event(self) -> None:
        """Enqueue the current status as a JSON text message on the event queue."""
        if self._event_queue is not None:
            from a2a.utils import new_agent_text_message
            import json
            status_msg = json.dumps({
                "_type": "status_update",
                "task_id": self._task.id,
                "status": self._task.status.to_dict(),
            })
            await self._event_queue.enqueue_event(new_agent_text_message(status_msg))

    def on_status_change(self, callback: Callable) -> None:
        """Register *callback* to be called with the new status after each update."""
        self._status_callbacks.append(callback)
169
+
170
+
171
class TaskStore:
    """
    In-memory task store keyed by task ID.

    Everything lives in a plain dict, so nothing persists across process
    restarts. For production, extend this with a Redis/DB backend.
    """

    def __init__(self):
        self._tasks: Dict[str, Task] = {}

    def create(self, skill: str, params: Dict[str, Any]) -> Task:
        """Build a new SUBMITTED task with a random hex ID, store it and return it."""
        new_task = Task(
            id=uuid4().hex,
            skill=skill,
            params=params,
            status=TaskStatus(state=TaskState.SUBMITTED),
        )
        self._tasks[new_task.id] = new_task
        return new_task

    def get(self, task_id: str) -> Optional[Task]:
        """Return the task with *task_id*, or None if it is unknown."""
        return self._tasks.get(task_id)

    def update(self, task: Task) -> None:
        """Insert or overwrite *task* under its own ID."""
        self._tasks[task.id] = task

    def list(
        self,
        state: Optional[TaskState] = None,
        skill: Optional[str] = None,
        limit: int = 100,
    ) -> List[Task]:
        """
        Return up to *limit* tasks, newest ``created_at`` first.

        Falsy ``state`` / ``skill`` arguments mean "no filter".
        """
        matches = [
            candidate
            for candidate in self._tasks.values()
            if (not state or candidate.status.state == state)
            and (not skill or candidate.skill == skill)
        ]
        matches.sort(key=lambda candidate: candidate.created_at, reverse=True)
        return matches[:limit]

    def delete(self, task_id: str) -> bool:
        """Remove the task with *task_id*; return whether it existed."""
        try:
            del self._tasks[task_id]
        except KeyError:
            return False
        return True
a2a_lite/testing.py ADDED
@@ -0,0 +1,268 @@
1
+ """
2
+ Testing utilities for A2A Lite agents.
3
+
4
+ Makes testing agents as simple as:
5
+
6
+ from a2a_lite.testing import TestClient
7
+
8
+ def test_my_agent():
9
+ client = TestClient(agent)
10
+ result = client.call("greet", name="World")
11
+ assert result == "Hello, World!"
12
+ """
13
+ from __future__ import annotations
14
+
15
+ import json
16
+ from typing import Any, Dict, Optional
17
+ from uuid import uuid4
18
+
19
+
20
class TestClient:
    """
    Simple synchronous test client for A2A Lite agents.

    ``call()`` goes through the agent's real HTTP app (via Starlette's
    ``TestClient``) using a JSON-RPC ``message/send`` request. ``stream()``
    invokes the skill handler directly.

    Example:
        agent = Agent(name="Test", description="Test")

        @agent.skill("add")
        async def add(a: int, b: int) -> int:
            return a + b

        # In your test
        client = TestClient(agent)
        assert client.call("add", a=2, b=3) == 5
    """

    # The name starts with "Test", so pytest would otherwise try (and warn)
    # to collect this helper as a test class wherever it is imported.
    __test__ = False

    def __init__(self, agent):
        """
        Create a test client for an agent.

        Args:
            agent: The A2A Lite Agent instance to test
        """
        self.agent = agent
        self._app = None
        self._client = None

    def _get_client(self):
        """Lazily build a Starlette test client around ``agent.get_app()``."""
        if self._client is None:
            from starlette.testclient import TestClient as StarletteTestClient
            self._app = self.agent.get_app()
            self._client = StarletteTestClient(self._app)
        return self._client

    def call(self, skill: str, **params) -> Any:
        """
        Call a skill and return the result.

        Args:
            skill: Name of the skill to call
            **params: Parameters to pass to the skill

        Returns:
            The skill's return value (parsed from JSON if possible)

        Raises:
            TestClientError: If the JSON-RPC response contains an error.

        Example:
            result = client.call("greet", name="World")
        """
        client = self._get_client()

        # The skill name and params travel as a JSON text part of a user message.
        message = json.dumps({"skill": skill, "params": params})
        request_body = {
            "jsonrpc": "2.0",
            "method": "message/send",
            "id": uuid4().hex,
            "params": {
                "message": {
                    "role": "user",
                    "parts": [{"type": "text", "text": message}],
                    "messageId": uuid4().hex,
                }
            }
        }

        response = client.post("/", json=request_body)
        response.raise_for_status()
        data = response.json()

        # Extract the actual result from A2A response
        return self._extract_result(data)

    def _extract_result(self, response: Dict) -> Any:
        """
        Extract the skill result from a JSON-RPC response body.

        Returns:
            The text of the first text part, JSON-decoded when it is valid
            JSON and returned raw otherwise. If there is no text part, or
            the result is not an object, the result itself is returned.

        Raises:
            TestClientError: If the response carries a non-null ``error``.
        """
        if response.get("error") is not None:
            raise TestClientError(response["error"])

        result = response.get("result", {})
        if not isinstance(result, dict):
            return result

        # Get text from message parts; accept both "kind" and legacy "type".
        for part in result.get("parts", []):
            if part.get("kind") == "text" or part.get("type") == "text":
                text = part.get("text", "")
                try:
                    return json.loads(text)
                except json.JSONDecodeError:
                    return text

        return result

    def get_agent_card(self) -> Dict[str, Any]:
        """
        Fetch the agent card.

        Returns:
            The agent card as a dictionary
        """
        client = self._get_client()
        response = client.get("/.well-known/agent.json")
        response.raise_for_status()
        return response.json()

    def list_skills(self) -> list[str]:
        """
        Get list of available skill names.

        Returns:
            List of skill names (falling back to each skill's ``id``)
        """
        card = self.get_agent_card()
        return [s.get("name", s.get("id")) for s in card.get("skills", [])]

    def stream(self, skill: str, **params) -> list[Any]:
        """
        Call a streaming skill and collect all results.

        The handler is invoked directly, bypassing HTTP. Async and sync
        generators contribute every yielded value. A coroutine or plain
        return value contributes a single item.

        Must not be called while an event loop is already running in this
        thread; use ``AsyncTestClient`` from async tests instead.

        Args:
            skill: Name of the skill to call
            **params: Parameters to pass to the skill

        Returns:
            List of all streamed values

        Raises:
            TestClientError: If the agent has no skill named *skill*.

        Example:
            results = client.stream("count", limit=3)
            assert len(results) == 3
        """
        import asyncio
        import inspect

        # Access skills directly from agent
        skill_def = self.agent._skills.get(skill)

        if not skill_def:
            raise TestClientError(f"Unknown skill: {skill}")

        results = []

        async def run_handler():
            produced = skill_def.handler(**params)

            if hasattr(produced, '__anext__'):
                async for value in produced:
                    results.append(value)
            elif hasattr(produced, '__next__'):
                for value in produced:
                    results.append(value)
            elif inspect.isawaitable(produced):
                results.append(await produced)
            else:
                # Plain synchronous handler: the return value is the result.
                results.append(produced)

        # asyncio.run creates and closes a fresh loop; the older
        # get_event_loop().run_until_complete pattern is deprecated.
        asyncio.run(run_handler())
        return results
178
+
179
+
180
class TestClientError(Exception):
    """Error raised by the test clients (e.g. a JSON-RPC error or an unknown skill)."""

    # The name starts with "Test": stop pytest from trying to collect it.
    __test__ = False
183
+
184
+
185
+ # Async version for async tests
186
class AsyncTestClient:
    """
    Async test client for A2A Lite agents.

    Sends JSON-RPC ``message/send`` requests to the agent's app in-process
    through an ``httpx.AsyncClient`` backed by an ASGI transport. Release
    the client with ``await client.close()``, or use the client as an
    async context manager.

    Example:
        async def test_my_agent():
            async with AsyncTestClient(agent) as client:
                result = await client.call("greet", name="World")
                assert result == "Hello, World!"
    """

    # The name starts with "Test": stop pytest from trying to collect it.
    __test__ = False

    def __init__(self, agent):
        """
        Args:
            agent: The A2A Lite Agent instance to test
        """
        self.agent = agent
        self._app = None
        self._client = None

    async def __aenter__(self) -> AsyncTestClient:
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        await self.close()

    async def _get_client(self):
        """Lazily create the async HTTP client around ``agent.get_app()``."""
        if self._client is None:
            import httpx
            self._app = self.agent.get_app()
            # httpx 0.28 removed the ``AsyncClient(app=...)`` shortcut;
            # an explicit ASGITransport is the supported equivalent.
            self._client = httpx.AsyncClient(
                transport=httpx.ASGITransport(app=self._app),
                base_url="http://testserver",
            )
        return self._client

    async def call(self, skill: str, **params) -> Any:
        """
        Call a skill and return the result.

        Args:
            skill: Name of the skill to call
            **params: Parameters to pass to the skill

        Returns:
            The skill's return value (parsed from JSON if possible)

        Raises:
            TestClientError: If the JSON-RPC response contains an error.
        """
        client = await self._get_client()

        # The skill name and params travel as a JSON text part of a user message.
        message = json.dumps({"skill": skill, "params": params})
        request_body = {
            "jsonrpc": "2.0",
            "method": "message/send",
            "id": uuid4().hex,
            "params": {
                "message": {
                    "role": "user",
                    "parts": [{"type": "text", "text": message}],
                    "messageId": uuid4().hex,
                }
            }
        }

        response = await client.post("/", json=request_body)
        response.raise_for_status()
        data = response.json()

        return self._extract_result(data)

    def _extract_result(self, response: Dict) -> Any:
        """
        Extract the skill result from a JSON-RPC response body.

        Returns:
            The text of the first text part, JSON-decoded when it is valid
            JSON and returned raw otherwise. If there is no text part, or
            the result is not an object, the result itself is returned.

        Raises:
            TestClientError: If the response carries a non-null ``error``.
        """
        if response.get("error") is not None:
            raise TestClientError(response["error"])

        result = response.get("result", {})
        if not isinstance(result, dict):
            return result

        # Accept both "kind" and legacy "type" discriminators on parts.
        for part in result.get("parts", []):
            if part.get("kind") == "text" or part.get("type") == "text":
                text = part.get("text", "")
                try:
                    return json.loads(text)
                except json.JSONDecodeError:
                    return text

        return result

    async def close(self):
        """Close the underlying HTTP client, if one was created."""
        if self._client:
            await self._client.aclose()
            self._client = None
a2a_lite/utils.py ADDED
@@ -0,0 +1,117 @@
1
+ """
2
+ Helper functions for A2A Lite.
3
+ """
4
+ from typing import Any, Dict, Type, get_origin, get_args, Union
5
+ import inspect
6
+
7
+
8
+ def type_to_json_schema(python_type: Type) -> Dict[str, Any]:
9
+ """
10
+ Convert Python type to JSON Schema.
11
+
12
+ Handles basic types, generics (List, Dict, Optional), and Pydantic models.
13
+ """
14
+ # Handle None type
15
+ if python_type is type(None):
16
+ return {"type": "null"}
17
+
18
+ # Basic type mapping
19
+ type_map = {
20
+ str: {"type": "string"},
21
+ int: {"type": "integer"},
22
+ float: {"type": "number"},
23
+ bool: {"type": "boolean"},
24
+ list: {"type": "array"},
25
+ dict: {"type": "object"},
26
+ Any: {"type": "object"},
27
+ }
28
+
29
+ # Check basic types first
30
+ if python_type in type_map:
31
+ return type_map[python_type]
32
+
33
+ # Handle generic types
34
+ origin = get_origin(python_type)
35
+ args = get_args(python_type)
36
+
37
+ # Handle Optional (Union[X, None])
38
+ if origin is Union:
39
+ non_none_args = [a for a in args if a is not type(None)]
40
+ if len(non_none_args) == 1:
41
+ # This is Optional[X]
42
+ return type_to_json_schema(non_none_args[0])
43
+ # Union of multiple types
44
+ return {"oneOf": [type_to_json_schema(a) for a in args]}
45
+
46
+ # Handle List[X]
47
+ if origin is list and args:
48
+ return {
49
+ "type": "array",
50
+ "items": type_to_json_schema(args[0])
51
+ }
52
+
53
+ # Handle Dict[K, V]
54
+ if origin is dict and len(args) >= 2:
55
+ return {
56
+ "type": "object",
57
+ "additionalProperties": type_to_json_schema(args[1])
58
+ }
59
+
60
+ # Handle Pydantic models
61
+ if hasattr(python_type, 'model_json_schema'):
62
+ return python_type.model_json_schema()
63
+
64
+ # Handle dataclasses
65
+ if hasattr(python_type, '__dataclass_fields__'):
66
+ properties = {}
67
+ required = []
68
+ for field_name, field_info in python_type.__dataclass_fields__.items():
69
+ properties[field_name] = type_to_json_schema(field_info.type)
70
+ if field_info.default is inspect.Parameter.empty and field_info.default_factory is inspect.Parameter.empty:
71
+ required.append(field_name)
72
+ return {
73
+ "type": "object",
74
+ "properties": properties,
75
+ "required": required,
76
+ }
77
+
78
+ # Fallback for unknown types
79
+ return {"type": "object"}
80
+
81
+
82
def extract_function_schemas(func) -> tuple[Dict[str, Any], Dict[str, Any]]:
    """
    Extract input and output JSON schemas from a function's type hints.

    Annotations are resolved with ``typing.get_type_hints``. String
    annotations, such as those produced in modules using
    ``from __future__ import annotations``, therefore map to real types.
    If resolution fails, the raw ``__annotations__`` are used instead.
    ``self``/``cls`` and ``*args``/``**kwargs`` parameters are skipped.

    Args:
        func: The callable whose signature is inspected.

    Returns:
        Tuple of (input_schema, output_schema). ``input_schema`` is an
        object schema with one property per parameter; parameters without
        a default are listed in ``required``. Missing annotations map to
        ``Any``.
    """
    from typing import get_type_hints

    sig = inspect.signature(func)
    try:
        hints = get_type_hints(func)
    except Exception:
        # Unresolvable forward references, callable objects, etc.: fall
        # back to the unevaluated annotations.
        hints = getattr(func, '__annotations__', {})

    # Variadic parameters cannot be expressed as named JSON properties.
    variadic_kinds = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)

    # Build input schema from parameters.
    # NOTE(review): framework-injected parameters (e.g. a task context) are
    # not filtered here; confirm that callers strip them if needed.
    properties = {}
    required = []

    for param_name, param in sig.parameters.items():
        if param_name in ('self', 'cls') or param.kind in variadic_kinds:
            continue

        param_type = hints.get(param_name, Any)
        properties[param_name] = type_to_json_schema(param_type)

        # Parameter is required if it has no default value
        if param.default is inspect.Parameter.empty:
            required.append(param_name)

    input_schema = {
        "type": "object",
        "properties": properties,
        "required": required,
    }

    # Build output schema from return type
    return_type = hints.get('return', Any)
    output_schema = type_to_json_schema(return_type)

    return input_schema, output_schema