a2a-lite 0.1.0__py3-none-any.whl
This diff shows the contents of publicly released package versions as they appear in their respective public registries; it is provided for informational purposes only.
- a2a_lite/__init__.py +151 -0
- a2a_lite/agent.py +453 -0
- a2a_lite/auth.py +344 -0
- a2a_lite/cli.py +336 -0
- a2a_lite/decorators.py +32 -0
- a2a_lite/discovery.py +148 -0
- a2a_lite/executor.py +317 -0
- a2a_lite/human_loop.py +284 -0
- a2a_lite/middleware.py +193 -0
- a2a_lite/parts.py +218 -0
- a2a_lite/streaming.py +89 -0
- a2a_lite/tasks.py +221 -0
- a2a_lite/testing.py +268 -0
- a2a_lite/utils.py +117 -0
- a2a_lite/webhooks.py +232 -0
- a2a_lite-0.1.0.dist-info/METADATA +383 -0
- a2a_lite-0.1.0.dist-info/RECORD +19 -0
- a2a_lite-0.1.0.dist-info/WHEEL +4 -0
- a2a_lite-0.1.0.dist-info/entry_points.txt +2 -0
a2a_lite/streaming.py
ADDED
@@ -0,0 +1,89 @@
"""
Streaming support for A2A Lite agents.

Enables generator-based streaming for LLM-style responses:

    @agent.skill("chat", streaming=True)
    async def chat(message: str):
        async for chunk in llm.stream(message):
            yield chunk
"""
from __future__ import annotations

from typing import Any, AsyncGenerator, Callable, Generator, Union
import asyncio
import inspect


def is_generator_function(func: Callable) -> bool:
    """Check if a function is a generator (sync or async)."""
    return (
        inspect.isgeneratorfunction(func) or
        inspect.isasyncgenfunction(func)
    )


async def collect_generator(
    gen: Union[Generator, AsyncGenerator]
) -> list[Any]:
    """Collect all items from a generator into a list."""
    items = []
    if inspect.isasyncgen(gen):
        async for item in gen:
            items.append(item)
    else:
        for item in gen:
            items.append(item)
    return items


async def stream_generator(
    gen: Union[Generator, AsyncGenerator],
    event_queue,
) -> None:
    """
    Stream generator output through the A2A event queue.

    Each yielded item becomes a separate event in the stream.
    """
    from a2a.utils import new_agent_text_message

    if inspect.isasyncgen(gen):
        async for chunk in gen:
            text = str(chunk) if not isinstance(chunk, str) else chunk
            await event_queue.enqueue_event(new_agent_text_message(text))
    else:
        for chunk in gen:
            text = str(chunk) if not isinstance(chunk, str) else chunk
            await event_queue.enqueue_event(new_agent_text_message(text))


class StreamingResponse:
    """
    Helper class for building streaming responses.

    Example:
        @agent.skill("count")
        async def count(n: int):
            stream = StreamingResponse()
            for i in range(n):
                stream.write(f"Count: {i}")
            return stream
    """

    def __init__(self):
        self._chunks: list[str] = []

    def write(self, chunk: str) -> None:
        """Add a chunk to the stream."""
        self._chunks.append(chunk)

    def __iter__(self):
        return iter(self._chunks)

    def __aiter__(self):
        return self._async_iter()

    async def _async_iter(self):
        for chunk in self._chunks:
            yield chunk
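A minimal sketch of how the helpers above can be exercised on their own, assuming a2a-lite 0.1.0 is installed; the event-queue path (stream_generator) additionally needs the a2a SDK, so this sketch sticks to the pure-Python helpers:

    import asyncio
    from a2a_lite.streaming import collect_generator, is_generator_function

    async def count(n: int):
        # Async generator skill body, as in the module docstring above.
        for i in range(n):
            yield f"chunk {i}"

    async def main():
        assert is_generator_function(count)          # True for async generators
        chunks = await collect_generator(count(3))   # drain the generator into a list
        print(chunks)                                # ['chunk 0', 'chunk 1', 'chunk 2']

    asyncio.run(main())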
a2a_lite/tasks.py
ADDED
@@ -0,0 +1,221 @@
"""
Task lifecycle management (OPTIONAL).

By default, A2A Lite just returns results (simple mode).
Enable task tracking when you need:
- Progress updates
- Long-running tasks
- Task status visibility

Example (simple - no task tracking needed):
    @agent.skill("greet")
    async def greet(name: str) -> str:
        return f"Hello, {name}!"

Example (with task tracking - opt-in):
    agent = Agent(name="Bot", task_store="memory")  # Enable tracking

    @agent.skill("process")
    async def process(data: str, task: TaskContext) -> str:
        await task.update("working", "Starting...", progress=0.0)

        for i in range(10):
            await task.update("working", f"Step {i+1}/10", progress=i/10)
            await do_work(i)

        return "Done!"
"""
from __future__ import annotations

from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any, Callable, Dict, List, Optional
from uuid import uuid4
import asyncio


class TaskState(str, Enum):
    """A2A Protocol task states."""
    SUBMITTED = "submitted"
    WORKING = "working"
    INPUT_REQUIRED = "input-required"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELED = "canceled"
    AUTH_REQUIRED = "auth-required"


@dataclass
class TaskStatus:
    """Current status of a task."""
    state: TaskState
    message: Optional[str] = None
    progress: Optional[float] = None  # 0.0 to 1.0
    timestamp: datetime = field(default_factory=datetime.utcnow)

    def to_dict(self) -> Dict[str, Any]:
        return {
            "state": self.state.value,
            "message": self.message,
            "progress": self.progress,
            "timestamp": self.timestamp.isoformat(),
        }


@dataclass
class Task:
    """Represents an A2A task."""
    id: str
    skill: str
    params: Dict[str, Any]
    status: TaskStatus
    result: Any = None
    error: Optional[str] = None
    artifacts: List[Any] = field(default_factory=list)
    history: List[TaskStatus] = field(default_factory=list)
    created_at: datetime = field(default_factory=datetime.utcnow)
    updated_at: datetime = field(default_factory=datetime.utcnow)

    def update_status(
        self,
        state: TaskState,
        message: Optional[str] = None,
        progress: Optional[float] = None,
    ) -> None:
        """Update task status."""
        self.history.append(self.status)
        self.status = TaskStatus(state=state, message=message, progress=progress)
        self.updated_at = datetime.utcnow()


class TaskContext:
    """
    Context passed to skills when task tracking is enabled.

    Provides methods to update task status and request user input.

    Example:
        @agent.skill("process")
        async def process(data: str, task: TaskContext) -> str:
            await task.update("working", "Processing...", progress=0.5)
            result = await heavy_computation(data)
            return result
    """

    def __init__(self, task: Task, event_queue=None, input_handler=None):
        self._task = task
        self._event_queue = event_queue
        self._input_handler = input_handler
        self._status_callbacks: List[Callable] = []

    @property
    def task_id(self) -> str:
        return self._task.id

    @property
    def state(self) -> TaskState:
        return self._task.status.state

    async def update(
        self,
        state: str = "working",
        message: Optional[str] = None,
        progress: Optional[float] = None,
    ) -> None:
        """
        Update task status.

        Args:
            state: Task state (working, completed, failed, etc.)
            message: Human-readable status message
            progress: Progress from 0.0 to 1.0

        Example:
            await task.update("working", "Processing item 5/10", progress=0.5)
        """
        task_state = TaskState(state) if isinstance(state, str) else state
        self._task.update_status(task_state, message, progress)

        # Notify callbacks
        for callback in self._status_callbacks:
            try:
                if asyncio.iscoroutinefunction(callback):
                    await callback(self._task.status)
                else:
                    callback(self._task.status)
            except Exception:
                pass

        # Send SSE event if streaming
        if self._event_queue:
            await self._send_status_event()

    async def _send_status_event(self) -> None:
        """Send status update via SSE."""
        if self._event_queue:
            from a2a.utils import new_agent_text_message
            import json
            status_msg = json.dumps({
                "_type": "status_update",
                "task_id": self._task.id,
                "status": self._task.status.to_dict(),
            })
            await self._event_queue.enqueue_event(new_agent_text_message(status_msg))

    def on_status_change(self, callback: Callable) -> None:
        """Register callback for status changes."""
        self._status_callbacks.append(callback)


class TaskStore:
    """
    In-memory task store.

    For production, extend this with Redis/DB backend.
    """

    def __init__(self):
        self._tasks: Dict[str, Task] = {}

    def create(self, skill: str, params: Dict[str, Any]) -> Task:
        """Create a new task."""
        task = Task(
            id=uuid4().hex,
            skill=skill,
            params=params,
            status=TaskStatus(state=TaskState.SUBMITTED),
        )
        self._tasks[task.id] = task
        return task

    def get(self, task_id: str) -> Optional[Task]:
        """Get task by ID."""
        return self._tasks.get(task_id)

    def update(self, task: Task) -> None:
        """Update task in store."""
        self._tasks[task.id] = task

    def list(
        self,
        state: Optional[TaskState] = None,
        skill: Optional[str] = None,
        limit: int = 100,
    ) -> List[Task]:
        """List tasks with optional filters."""
        tasks = list(self._tasks.values())

        if state:
            tasks = [t for t in tasks if t.status.state == state]
        if skill:
            tasks = [t for t in tasks if t.skill == skill]

        return sorted(tasks, key=lambda t: t.created_at, reverse=True)[:limit]

    def delete(self, task_id: str) -> bool:
        """Delete a task."""
        if task_id in self._tasks:
            del self._tasks[task_id]
            return True
        return False
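For orientation, a small sketch that drives the task-tracking pieces above directly, outside a running agent; it assumes only this module (the task_store="memory" wiring shown in the docstring lives in agent.py, which is not reproduced here):

    import asyncio
    from a2a_lite.tasks import TaskContext, TaskState, TaskStore

    async def main():
        store = TaskStore()
        task = store.create("process", {"data": "abc"})      # starts in SUBMITTED
        ctx = TaskContext(task)
        ctx.on_status_change(lambda status: print(status.to_dict()))

        await ctx.update("working", "Step 1/2", progress=0.5)
        await ctx.update("completed", "Done", progress=1.0)

        assert ctx.state is TaskState.COMPLETED
        assert len(task.history) == 2                         # earlier statuses are archived
        assert store.list(skill="process")[0].id == task.id

    asyncio.run(main())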
a2a_lite/testing.py
ADDED
@@ -0,0 +1,268 @@
"""
Testing utilities for A2A Lite agents.

Makes testing agents as simple as:

    from a2a_lite.testing import TestClient

    def test_my_agent():
        client = TestClient(agent)
        result = client.call("greet", name="World")
        assert result == "Hello, World!"
"""
from __future__ import annotations

import json
from typing import Any, Dict, Optional
from uuid import uuid4


class TestClient:
    """
    Simple test client for A2A Lite agents.

    Example:
        agent = Agent(name="Test", description="Test")

        @agent.skill("add")
        async def add(a: int, b: int) -> int:
            return a + b

        # In your test
        client = TestClient(agent)
        assert client.call("add", a=2, b=3) == 5
    """

    def __init__(self, agent):
        """
        Create a test client for an agent.

        Args:
            agent: The A2A Lite Agent instance to test
        """
        self.agent = agent
        self._app = None
        self._client = None

    def _get_client(self):
        """Lazily create the test client."""
        if self._client is None:
            from starlette.testclient import TestClient as StarletteTestClient
            self._app = self.agent.get_app()
            self._client = StarletteTestClient(self._app)
        return self._client

    def call(self, skill: str, **params) -> Any:
        """
        Call a skill and return the result.

        Args:
            skill: Name of the skill to call
            **params: Parameters to pass to the skill

        Returns:
            The skill's return value (parsed from JSON if possible)

        Example:
            result = client.call("greet", name="World")
        """
        client = self._get_client()

        message = json.dumps({"skill": skill, "params": params})
        request_body = {
            "jsonrpc": "2.0",
            "method": "message/send",
            "id": uuid4().hex,
            "params": {
                "message": {
                    "role": "user",
                    "parts": [{"type": "text", "text": message}],
                    "messageId": uuid4().hex,
                }
            }
        }

        response = client.post("/", json=request_body)
        response.raise_for_status()
        data = response.json()

        # Extract the actual result from A2A response
        return self._extract_result(data)

    def _extract_result(self, response: Dict) -> Any:
        """Extract the skill result from A2A response."""
        if "error" in response:
            raise TestClientError(response["error"])

        result = response.get("result", {})

        # Get text from message parts
        parts = result.get("parts", [])
        for part in parts:
            if part.get("kind") == "text" or part.get("type") == "text":
                text = part.get("text", "")
                # Try to parse as JSON
                try:
                    return json.loads(text)
                except json.JSONDecodeError:
                    return text

        return result

    def get_agent_card(self) -> Dict[str, Any]:
        """
        Fetch the agent card.

        Returns:
            The agent card as a dictionary
        """
        client = self._get_client()
        response = client.get("/.well-known/agent.json")
        response.raise_for_status()
        return response.json()

    def list_skills(self) -> list[str]:
        """
        Get list of available skill names.

        Returns:
            List of skill names
        """
        card = self.get_agent_card()
        return [s.get("name", s.get("id")) for s in card.get("skills", [])]

    def stream(self, skill: str, **params) -> list[Any]:
        """
        Call a streaming skill and collect all results.

        Args:
            skill: Name of the skill to call
            **params: Parameters to pass to the skill

        Returns:
            List of all streamed values

        Example:
            results = client.stream("count", limit=3)
            assert len(results) == 3
        """
        import asyncio

        # Access skills directly from agent
        skill_def = self.agent._skills.get(skill)

        if not skill_def:
            raise TestClientError(f"Unknown skill: {skill}")

        # Call handler directly and collect results
        results = []

        async def run_handler():
            handler = skill_def.handler
            gen = handler(**params)

            # Handle both async and sync generators
            if hasattr(gen, '__anext__'):
                async for value in gen:
                    results.append(value)
            elif hasattr(gen, '__next__'):
                for value in gen:
                    results.append(value)
            else:
                # Not a generator, just a coroutine
                result = await gen
                results.append(result)

        asyncio.get_event_loop().run_until_complete(run_handler())
        return results


class TestClientError(Exception):
    """Error from test client."""
    pass


# Async version for async tests
class AsyncTestClient:
    """
    Async test client for A2A Lite agents.

    Example:
        async def test_my_agent():
            client = AsyncTestClient(agent)
            result = await client.call("greet", name="World")
            assert result == "Hello, World!"
    """

    def __init__(self, agent):
        self.agent = agent
        self._app = None
        self._client = None

    async def _get_client(self):
        """Lazily create the async test client."""
        if self._client is None:
            import httpx
            self._app = self.agent.get_app()
            self._client = httpx.AsyncClient(
                app=self._app,
                base_url="http://testserver"
            )
        return self._client

    async def call(self, skill: str, **params) -> Any:
        """
        Call a skill and return the result.

        Args:
            skill: Name of the skill to call
            **params: Parameters to pass to the skill

        Returns:
            The skill's return value
        """
        client = await self._get_client()

        message = json.dumps({"skill": skill, "params": params})
        request_body = {
            "jsonrpc": "2.0",
            "method": "message/send",
            "id": uuid4().hex,
            "params": {
                "message": {
                    "role": "user",
                    "parts": [{"type": "text", "text": message}],
                    "messageId": uuid4().hex,
                }
            }
        }

        response = await client.post("/", json=request_body)
        response.raise_for_status()
        data = response.json()

        return self._extract_result(data)

    def _extract_result(self, response: Dict) -> Any:
        """Extract the skill result from A2A response."""
        if "error" in response:
            raise TestClientError(response["error"])

        result = response.get("result", {})
        parts = result.get("parts", [])

        for part in parts:
            if part.get("kind") == "text" or part.get("type") == "text":
                text = part.get("text", "")
                try:
                    return json.loads(text)
                except json.JSONDecodeError:
                    return text

        return result

    async def close(self):
        """Close the client."""
        if self._client:
            await self._client.aclose()
            self._client = None
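A sketch of the synchronous client in use, assembled from the docstrings above; it assumes Agent is exported from the package root (otherwise import it from a2a_lite.agent) and that the underlying a2a/starlette dependencies are installed:

    from a2a_lite import Agent
    from a2a_lite.testing import TestClient

    agent = Agent(name="Calc", description="A tiny test agent")

    @agent.skill("add")
    async def add(a: int, b: int) -> int:
        return a + b

    def test_add():
        client = TestClient(agent)
        assert client.call("add", a=2, b=3) == 5   # result parsed from the JSON text part
        print(client.list_skills())                # skill names taken from the agent card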
a2a_lite/utils.py
ADDED
@@ -0,0 +1,117 @@
"""
Helper functions for A2A Lite.
"""
from typing import Any, Dict, Type, get_origin, get_args, Union
import inspect


def type_to_json_schema(python_type: Type) -> Dict[str, Any]:
    """
    Convert Python type to JSON Schema.

    Handles basic types, generics (List, Dict, Optional), and Pydantic models.
    """
    # Handle None type
    if python_type is type(None):
        return {"type": "null"}

    # Basic type mapping
    type_map = {
        str: {"type": "string"},
        int: {"type": "integer"},
        float: {"type": "number"},
        bool: {"type": "boolean"},
        list: {"type": "array"},
        dict: {"type": "object"},
        Any: {"type": "object"},
    }

    # Check basic types first
    if python_type in type_map:
        return type_map[python_type]

    # Handle generic types
    origin = get_origin(python_type)
    args = get_args(python_type)

    # Handle Optional (Union[X, None])
    if origin is Union:
        non_none_args = [a for a in args if a is not type(None)]
        if len(non_none_args) == 1:
            # This is Optional[X]
            return type_to_json_schema(non_none_args[0])
        # Union of multiple types
        return {"oneOf": [type_to_json_schema(a) for a in args]}

    # Handle List[X]
    if origin is list and args:
        return {
            "type": "array",
            "items": type_to_json_schema(args[0])
        }

    # Handle Dict[K, V]
    if origin is dict and len(args) >= 2:
        return {
            "type": "object",
            "additionalProperties": type_to_json_schema(args[1])
        }

    # Handle Pydantic models
    if hasattr(python_type, 'model_json_schema'):
        return python_type.model_json_schema()

    # Handle dataclasses
    if hasattr(python_type, '__dataclass_fields__'):
        properties = {}
        required = []
        for field_name, field_info in python_type.__dataclass_fields__.items():
            properties[field_name] = type_to_json_schema(field_info.type)
            if field_info.default is inspect.Parameter.empty and field_info.default_factory is inspect.Parameter.empty:
                required.append(field_name)
        return {
            "type": "object",
            "properties": properties,
            "required": required,
        }

    # Fallback for unknown types
    return {"type": "object"}


def extract_function_schemas(func) -> tuple[Dict[str, Any], Dict[str, Any]]:
    """
    Extract input and output JSON schemas from a function's type hints.

    Returns:
        Tuple of (input_schema, output_schema)
    """
    sig = inspect.signature(func)
    hints = getattr(func, '__annotations__', {})

    # Build input schema from parameters
    properties = {}
    required = []

    for param_name, param in sig.parameters.items():
        if param_name in ('self', 'cls'):
            continue

        param_type = hints.get(param_name, Any)
        properties[param_name] = type_to_json_schema(param_type)

        # Parameter is required if it has no default value
        if param.default is inspect.Parameter.empty:
            required.append(param_name)

    input_schema = {
        "type": "object",
        "properties": properties,
        "required": required,
    }

    # Build output schema from return type
    return_type = hints.get('return', Any)
    output_schema = type_to_json_schema(return_type)

    return input_schema, output_schema
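As a quick illustration of what the schema extraction yields for a typed skill function (the expected output for these particular hints is shown in the trailing comments):

    from typing import List, Optional
    from a2a_lite.utils import extract_function_schemas

    async def search(query: str, limit: int = 10, tags: Optional[List[str]] = None) -> List[str]:
        ...

    input_schema, output_schema = extract_function_schemas(search)
    # input_schema == {
    #     "type": "object",
    #     "properties": {
    #         "query": {"type": "string"},
    #         "limit": {"type": "integer"},
    #         "tags": {"type": "array", "items": {"type": "string"}},
    #     },
    #     "required": ["query"],
    # }
    # output_schema == {"type": "array", "items": {"type": "string"}}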