a2a-lite 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- a2a_lite/__init__.py +7 -6
- a2a_lite/agent.py +70 -13
- a2a_lite/auth.py +12 -4
- a2a_lite/cli.py +1 -1
- a2a_lite/decorators.py +3 -1
- a2a_lite/discovery.py +4 -0
- a2a_lite/executor.py +15 -8
- a2a_lite/tasks.py +38 -29
- a2a_lite/testing.py +56 -14
- a2a_lite/utils.py +16 -0
- a2a_lite-0.2.0.dist-info/METADATA +526 -0
- a2a_lite-0.2.0.dist-info/RECORD +19 -0
- a2a_lite-0.1.0.dist-info/METADATA +0 -383
- a2a_lite-0.1.0.dist-info/RECORD +0 -19
- {a2a_lite-0.1.0.dist-info → a2a_lite-0.2.0.dist-info}/WHEEL +0 -0
- {a2a_lite-0.1.0.dist-info → a2a_lite-0.2.0.dist-info}/entry_points.txt +0 -0
a2a_lite/__init__.py
CHANGED
|
@@ -13,8 +13,8 @@ SIMPLE (8 lines):
|
|
|
13
13
|
agent.run()
|
|
14
14
|
|
|
15
15
|
TEST IT (3 lines):
|
|
16
|
-
from a2a_lite import
|
|
17
|
-
client =
|
|
16
|
+
from a2a_lite import AgentTestClient
|
|
17
|
+
client = AgentTestClient(agent)
|
|
18
18
|
assert client.call("greet", name="World") == "Hello, World!"
|
|
19
19
|
|
|
20
20
|
WITH PYDANTIC:
|
|
@@ -64,7 +64,7 @@ WITH FILES (opt-in):
|
|
|
64
64
|
from .agent import Agent
|
|
65
65
|
from .decorators import SkillDefinition
|
|
66
66
|
from .discovery import AgentDiscovery, DiscoveredAgent
|
|
67
|
-
from .testing import
|
|
67
|
+
from .testing import AgentTestClient, AsyncAgentTestClient, TestResult
|
|
68
68
|
|
|
69
69
|
# Middleware - always available
|
|
70
70
|
from .middleware import (
|
|
@@ -102,7 +102,7 @@ from .auth import (
|
|
|
102
102
|
require_auth,
|
|
103
103
|
)
|
|
104
104
|
|
|
105
|
-
__version__ = "0.
|
|
105
|
+
__version__ = "0.2.0"
|
|
106
106
|
|
|
107
107
|
__all__ = [
|
|
108
108
|
# Core
|
|
@@ -111,8 +111,9 @@ __all__ = [
|
|
|
111
111
|
"AgentDiscovery",
|
|
112
112
|
"DiscoveredAgent",
|
|
113
113
|
# Testing
|
|
114
|
-
"
|
|
115
|
-
"
|
|
114
|
+
"AgentTestClient",
|
|
115
|
+
"AsyncAgentTestClient",
|
|
116
|
+
"TestResult",
|
|
116
117
|
# Middleware
|
|
117
118
|
"MiddlewareContext",
|
|
118
119
|
"MiddlewareChain",
|
a2a_lite/agent.py
CHANGED
|
@@ -7,9 +7,12 @@ from __future__ import annotations
|
|
|
7
7
|
|
|
8
8
|
import asyncio
|
|
9
9
|
import inspect
|
|
10
|
+
import logging
|
|
10
11
|
from typing import Any, Callable, Optional, Dict, List, Type, Union, get_origin, get_args
|
|
11
12
|
from dataclasses import dataclass, field
|
|
12
13
|
|
|
14
|
+
logger = logging.getLogger(__name__)
|
|
15
|
+
|
|
13
16
|
import uvicorn
|
|
14
17
|
|
|
15
18
|
from a2a.server.apps import A2AStarletteApplication
|
|
@@ -23,7 +26,7 @@ from a2a.types import (
|
|
|
23
26
|
|
|
24
27
|
from .executor import LiteAgentExecutor
|
|
25
28
|
from .decorators import SkillDefinition
|
|
26
|
-
from .utils import type_to_json_schema, extract_function_schemas
|
|
29
|
+
from .utils import type_to_json_schema, extract_function_schemas, _is_or_subclass
|
|
27
30
|
from .middleware import MiddlewareChain, MiddlewareContext
|
|
28
31
|
from .streaming import is_generator_function
|
|
29
32
|
from .webhooks import NotificationManager, WebhookClient
|
|
@@ -85,6 +88,8 @@ class Agent:
|
|
|
85
88
|
# Optional enterprise features
|
|
86
89
|
auth: Optional[Any] = None # AuthProvider
|
|
87
90
|
task_store: Optional[Any] = None # TaskStore or "memory"
|
|
91
|
+
cors_origins: Optional[List[str]] = None
|
|
92
|
+
production: bool = False
|
|
88
93
|
|
|
89
94
|
def __post_init__(self):
|
|
90
95
|
# Internal state
|
|
@@ -161,16 +166,30 @@ class Agent:
|
|
|
161
166
|
if is_streaming:
|
|
162
167
|
self._has_streaming = True
|
|
163
168
|
|
|
164
|
-
# Detect special parameter types
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
needs_interaction =
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
169
|
+
# Detect special parameter types using proper type introspection
|
|
170
|
+
import typing
|
|
171
|
+
from .tasks import TaskContext as _TaskContext
|
|
172
|
+
from .human_loop import InteractionContext as _InteractionContext
|
|
173
|
+
|
|
174
|
+
needs_task_context = False
|
|
175
|
+
needs_interaction = False
|
|
176
|
+
task_context_param: str | None = None
|
|
177
|
+
interaction_param: str | None = None
|
|
178
|
+
|
|
179
|
+
try:
|
|
180
|
+
resolved_hints = typing.get_type_hints(func)
|
|
181
|
+
except Exception:
|
|
182
|
+
resolved_hints = getattr(func, '__annotations__', {})
|
|
183
|
+
|
|
184
|
+
for param_name, hint in resolved_hints.items():
|
|
185
|
+
if param_name == 'return':
|
|
186
|
+
continue
|
|
187
|
+
if _is_or_subclass(hint, _TaskContext):
|
|
188
|
+
needs_task_context = True
|
|
189
|
+
task_context_param = param_name
|
|
190
|
+
elif _is_or_subclass(hint, _InteractionContext):
|
|
191
|
+
needs_interaction = True
|
|
192
|
+
interaction_param = param_name
|
|
174
193
|
|
|
175
194
|
# Extract schemas
|
|
176
195
|
input_schema, output_schema = extract_function_schemas(func)
|
|
@@ -186,6 +205,8 @@ class Agent:
|
|
|
186
205
|
is_streaming=is_streaming,
|
|
187
206
|
needs_task_context=needs_task_context,
|
|
188
207
|
needs_interaction=needs_interaction,
|
|
208
|
+
task_context_param=task_context_param,
|
|
209
|
+
interaction_param=interaction_param,
|
|
189
210
|
)
|
|
190
211
|
|
|
191
212
|
self._skills[skill_name] = skill_def
|
|
@@ -375,10 +396,34 @@ class Agent:
|
|
|
375
396
|
)
|
|
376
397
|
console.print(f"[dim]mDNS discovery enabled for {self.name}[/]")
|
|
377
398
|
|
|
399
|
+
# Production mode warning
|
|
400
|
+
if self.production:
|
|
401
|
+
url_str = self.url or f"http://{display_host}:{port}"
|
|
402
|
+
if not url_str.startswith("https://"):
|
|
403
|
+
logger.warning(
|
|
404
|
+
"Running in production mode over HTTP. "
|
|
405
|
+
"Consider using HTTPS for secure communication."
|
|
406
|
+
)
|
|
407
|
+
|
|
408
|
+
# Build the ASGI app
|
|
409
|
+
app = app_builder.build()
|
|
410
|
+
|
|
411
|
+
# Add CORS middleware if configured
|
|
412
|
+
if self.cors_origins is not None:
|
|
413
|
+
from starlette.middleware.cors import CORSMiddleware
|
|
414
|
+
from starlette.middleware import Middleware as StarletteMiddleware
|
|
415
|
+
# Wrap the existing app with CORS
|
|
416
|
+
app.add_middleware(
|
|
417
|
+
CORSMiddleware,
|
|
418
|
+
allow_origins=self.cors_origins,
|
|
419
|
+
allow_methods=["*"],
|
|
420
|
+
allow_headers=["*"],
|
|
421
|
+
)
|
|
422
|
+
|
|
378
423
|
# Start server
|
|
379
424
|
try:
|
|
380
425
|
uvicorn.run(
|
|
381
|
-
|
|
426
|
+
app,
|
|
382
427
|
host=host,
|
|
383
428
|
port=port,
|
|
384
429
|
log_level=log_level,
|
|
@@ -450,4 +495,16 @@ class Agent:
|
|
|
450
495
|
http_handler=request_handler,
|
|
451
496
|
)
|
|
452
497
|
|
|
453
|
-
|
|
498
|
+
app = app_builder.build()
|
|
499
|
+
|
|
500
|
+
# Add CORS middleware if configured
|
|
501
|
+
if self.cors_origins is not None:
|
|
502
|
+
from starlette.middleware.cors import CORSMiddleware
|
|
503
|
+
app.add_middleware(
|
|
504
|
+
CORSMiddleware,
|
|
505
|
+
allow_origins=self.cors_origins,
|
|
506
|
+
allow_methods=["*"],
|
|
507
|
+
allow_headers=["*"],
|
|
508
|
+
)
|
|
509
|
+
|
|
510
|
+
return app
|
a2a_lite/auth.py
CHANGED
|
@@ -116,10 +116,17 @@ class APIKeyAuth(AuthProvider):
|
|
|
116
116
|
header: str = "X-API-Key",
|
|
117
117
|
query_param: Optional[str] = None,
|
|
118
118
|
):
|
|
119
|
-
|
|
119
|
+
# Store only hashes of keys for security
|
|
120
|
+
self._key_hashes = {
|
|
121
|
+
hashlib.sha256(k.encode()).hexdigest() for k in keys
|
|
122
|
+
}
|
|
120
123
|
self.header = header
|
|
121
124
|
self.query_param = query_param
|
|
122
125
|
|
|
126
|
+
def _hash_key(self, key: str) -> str:
|
|
127
|
+
"""Hash a key using SHA-256."""
|
|
128
|
+
return hashlib.sha256(key.encode()).hexdigest()
|
|
129
|
+
|
|
123
130
|
async def authenticate(self, request: AuthRequest) -> AuthResult:
|
|
124
131
|
# Check header
|
|
125
132
|
key = request.headers.get(self.header)
|
|
@@ -131,11 +138,12 @@ class APIKeyAuth(AuthProvider):
|
|
|
131
138
|
if not key:
|
|
132
139
|
return AuthResult.failure("API key required")
|
|
133
140
|
|
|
134
|
-
|
|
141
|
+
key_hash = self._hash_key(key)
|
|
142
|
+
if key_hash not in self._key_hashes:
|
|
135
143
|
return AuthResult.failure("Invalid API key")
|
|
136
144
|
|
|
137
|
-
# Use hash
|
|
138
|
-
user_id =
|
|
145
|
+
# Use hash prefix as user ID
|
|
146
|
+
user_id = key_hash[:16]
|
|
139
147
|
return AuthResult.success(user_id=user_id)
|
|
140
148
|
|
|
141
149
|
def get_scheme(self) -> Dict[str, Any]:
|
a2a_lite/cli.py
CHANGED
a2a_lite/decorators.py
CHANGED
|
@@ -2,7 +2,7 @@
|
|
|
2
2
|
Decorator definitions and skill metadata.
|
|
3
3
|
"""
|
|
4
4
|
from dataclasses import dataclass, field
|
|
5
|
-
from typing import Any, Callable, Dict, List
|
|
5
|
+
from typing import Any, Callable, Dict, List, Optional
|
|
6
6
|
|
|
7
7
|
|
|
8
8
|
@dataclass
|
|
@@ -18,6 +18,8 @@ class SkillDefinition:
|
|
|
18
18
|
is_streaming: bool = False
|
|
19
19
|
needs_task_context: bool = False
|
|
20
20
|
needs_interaction: bool = False
|
|
21
|
+
task_context_param: Optional[str] = None
|
|
22
|
+
interaction_param: Optional[str] = None
|
|
21
23
|
|
|
22
24
|
def to_dict(self) -> Dict[str, Any]:
|
|
23
25
|
"""Convert to dictionary for serialization."""
|
a2a_lite/discovery.py
CHANGED
|
@@ -4,10 +4,13 @@ mDNS-based local agent discovery using Zeroconf.
|
|
|
4
4
|
from __future__ import annotations
|
|
5
5
|
|
|
6
6
|
import asyncio
|
|
7
|
+
import logging
|
|
7
8
|
import socket
|
|
8
9
|
from typing import Dict, List, Optional
|
|
9
10
|
from dataclasses import dataclass
|
|
10
11
|
|
|
12
|
+
logger = logging.getLogger(__name__)
|
|
13
|
+
|
|
11
14
|
from zeroconf import ServiceBrowser, ServiceListener, Zeroconf, ServiceInfo
|
|
12
15
|
from zeroconf.asyncio import AsyncZeroconf, AsyncServiceBrowser
|
|
13
16
|
|
|
@@ -113,6 +116,7 @@ class AgentDiscovery:
|
|
|
113
116
|
s.close()
|
|
114
117
|
return ip
|
|
115
118
|
except Exception:
|
|
119
|
+
logger.debug("Could not detect local IP, falling back to 127.0.0.1")
|
|
116
120
|
return "127.0.0.1"
|
|
117
121
|
|
|
118
122
|
|
a2a_lite/executor.py
CHANGED
|
@@ -6,14 +6,18 @@ from __future__ import annotations
|
|
|
6
6
|
import asyncio
|
|
7
7
|
import inspect
|
|
8
8
|
import json
|
|
9
|
+
import logging
|
|
9
10
|
from typing import Any, Callable, Dict, List, Optional
|
|
10
11
|
|
|
12
|
+
logger = logging.getLogger(__name__)
|
|
13
|
+
|
|
11
14
|
from a2a.server.agent_execution import AgentExecutor, RequestContext
|
|
12
15
|
from a2a.server.events import EventQueue
|
|
13
16
|
|
|
14
17
|
from .decorators import SkillDefinition
|
|
15
18
|
from .middleware import MiddlewareChain, MiddlewareContext
|
|
16
19
|
from .streaming import is_generator_function, stream_generator
|
|
20
|
+
from .utils import _is_or_subclass
|
|
17
21
|
|
|
18
22
|
|
|
19
23
|
class LiteAgentExecutor(AgentExecutor):
|
|
@@ -100,7 +104,7 @@ class LiteAgentExecutor(AgentExecutor):
|
|
|
100
104
|
else:
|
|
101
105
|
hook(skill_name, result, ctx)
|
|
102
106
|
except Exception:
|
|
103
|
-
|
|
107
|
+
logger.warning("Completion hook error for skill '%s'", skill_name, exc_info=True)
|
|
104
108
|
|
|
105
109
|
except Exception as e:
|
|
106
110
|
await self._handle_error(e, event_queue)
|
|
@@ -132,17 +136,19 @@ class LiteAgentExecutor(AgentExecutor):
|
|
|
132
136
|
# Inject special contexts if needed
|
|
133
137
|
if skill_def.needs_task_context and self.task_store:
|
|
134
138
|
from .tasks import TaskContext, Task, TaskStatus, TaskState
|
|
135
|
-
task = self.task_store.create(skill_name, params)
|
|
139
|
+
task = await self.task_store.create(skill_name, params)
|
|
136
140
|
# Only pass event_queue for streaming skills (status updates go via SSE)
|
|
137
141
|
eq = event_queue if skill_def.is_streaming else None
|
|
138
142
|
task_ctx = TaskContext(task, eq)
|
|
139
|
-
|
|
143
|
+
param_name = skill_def.task_context_param or "task"
|
|
144
|
+
params[param_name] = task_ctx
|
|
140
145
|
|
|
141
146
|
if skill_def.needs_interaction:
|
|
142
147
|
from .human_loop import InteractionContext
|
|
143
148
|
task_id = metadata.get("task_id", "unknown")
|
|
144
149
|
interaction_ctx = InteractionContext(task_id, event_queue)
|
|
145
|
-
|
|
150
|
+
param_name = skill_def.interaction_param or "ctx"
|
|
151
|
+
params[param_name] = interaction_ctx
|
|
146
152
|
|
|
147
153
|
# Call the handler
|
|
148
154
|
handler = skill_def.handler
|
|
@@ -172,13 +178,14 @@ class LiteAgentExecutor(AgentExecutor):
|
|
|
172
178
|
converted[param_name] = value
|
|
173
179
|
continue
|
|
174
180
|
|
|
175
|
-
type_name = str(param_type)
|
|
176
|
-
|
|
177
181
|
# Skip special context types
|
|
178
|
-
|
|
182
|
+
from .tasks import TaskContext as _TaskContext
|
|
183
|
+
from .human_loop import InteractionContext as _InteractionContext
|
|
184
|
+
if _is_or_subclass(param_type, _TaskContext) or _is_or_subclass(param_type, _InteractionContext):
|
|
179
185
|
continue
|
|
180
186
|
|
|
181
187
|
# Convert FilePart
|
|
188
|
+
type_name = str(param_type)
|
|
182
189
|
if "FilePart" in type_name:
|
|
183
190
|
from .parts import FilePart
|
|
184
191
|
if isinstance(value, dict):
|
|
@@ -234,7 +241,7 @@ class LiteAgentExecutor(AgentExecutor):
|
|
|
234
241
|
if isinstance(data, dict) and 'skill' in data:
|
|
235
242
|
return data['skill'], data.get('params', {})
|
|
236
243
|
except json.JSONDecodeError:
|
|
237
|
-
|
|
244
|
+
logger.debug("Message is not JSON, treating as plain text")
|
|
238
245
|
|
|
239
246
|
return None, {"message": message}
|
|
240
247
|
|
a2a_lite/tasks.py
CHANGED
|
@@ -27,12 +27,15 @@ Example (with task tracking - opt-in):
|
|
|
27
27
|
"""
|
|
28
28
|
from __future__ import annotations
|
|
29
29
|
|
|
30
|
+
import asyncio
|
|
31
|
+
import logging
|
|
30
32
|
from dataclasses import dataclass, field
|
|
31
33
|
from datetime import datetime
|
|
32
34
|
from enum import Enum
|
|
33
35
|
from typing import Any, Callable, Dict, List, Optional
|
|
34
36
|
from uuid import uuid4
|
|
35
|
-
|
|
37
|
+
|
|
38
|
+
logger = logging.getLogger(__name__)
|
|
36
39
|
|
|
37
40
|
|
|
38
41
|
class TaskState(str, Enum):
|
|
@@ -145,7 +148,7 @@ class TaskContext:
|
|
|
145
148
|
else:
|
|
146
149
|
callback(self._task.status)
|
|
147
150
|
except Exception:
|
|
148
|
-
|
|
151
|
+
logger.warning("Status callback error for task '%s'", self._task.id, exc_info=True)
|
|
149
152
|
|
|
150
153
|
# Send SSE event if streaming
|
|
151
154
|
if self._event_queue:
|
|
@@ -170,52 +173,58 @@ class TaskContext:
|
|
|
170
173
|
|
|
171
174
|
class TaskStore:
|
|
172
175
|
"""
|
|
173
|
-
In-memory task store.
|
|
176
|
+
In-memory task store with async locking for thread safety.
|
|
174
177
|
|
|
175
178
|
For production, extend this with Redis/DB backend.
|
|
176
179
|
"""
|
|
177
180
|
|
|
178
181
|
def __init__(self):
|
|
179
182
|
self._tasks: Dict[str, Task] = {}
|
|
183
|
+
self._lock = asyncio.Lock()
|
|
180
184
|
|
|
181
|
-
def create(self, skill: str, params: Dict[str, Any]) -> Task:
|
|
185
|
+
async def create(self, skill: str, params: Dict[str, Any]) -> Task:
|
|
182
186
|
"""Create a new task."""
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
187
|
+
async with self._lock:
|
|
188
|
+
task = Task(
|
|
189
|
+
id=uuid4().hex,
|
|
190
|
+
skill=skill,
|
|
191
|
+
params=params,
|
|
192
|
+
status=TaskStatus(state=TaskState.SUBMITTED),
|
|
193
|
+
)
|
|
194
|
+
self._tasks[task.id] = task
|
|
195
|
+
return task
|
|
196
|
+
|
|
197
|
+
async def get(self, task_id: str) -> Optional[Task]:
|
|
193
198
|
"""Get task by ID."""
|
|
194
|
-
|
|
199
|
+
async with self._lock:
|
|
200
|
+
return self._tasks.get(task_id)
|
|
195
201
|
|
|
196
|
-
def update(self, task: Task) -> None:
|
|
202
|
+
async def update(self, task: Task) -> None:
|
|
197
203
|
"""Update task in store."""
|
|
198
|
-
self.
|
|
204
|
+
async with self._lock:
|
|
205
|
+
self._tasks[task.id] = task
|
|
199
206
|
|
|
200
|
-
def list(
|
|
207
|
+
async def list(
|
|
201
208
|
self,
|
|
202
209
|
state: Optional[TaskState] = None,
|
|
203
210
|
skill: Optional[str] = None,
|
|
204
211
|
limit: int = 100,
|
|
205
212
|
) -> List[Task]:
|
|
206
213
|
"""List tasks with optional filters."""
|
|
207
|
-
|
|
214
|
+
async with self._lock:
|
|
215
|
+
tasks = list(self._tasks.values())
|
|
208
216
|
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
217
|
+
if state:
|
|
218
|
+
tasks = [t for t in tasks if t.status.state == state]
|
|
219
|
+
if skill:
|
|
220
|
+
tasks = [t for t in tasks if t.skill == skill]
|
|
213
221
|
|
|
214
|
-
|
|
222
|
+
return sorted(tasks, key=lambda t: t.created_at, reverse=True)[:limit]
|
|
215
223
|
|
|
216
|
-
def delete(self, task_id: str) -> bool:
|
|
224
|
+
async def delete(self, task_id: str) -> bool:
|
|
217
225
|
"""Delete a task."""
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
226
|
+
async with self._lock:
|
|
227
|
+
if task_id in self._tasks:
|
|
228
|
+
del self._tasks[task_id]
|
|
229
|
+
return True
|
|
230
|
+
return False
|
a2a_lite/testing.py
CHANGED
|
@@ -3,21 +3,61 @@ Testing utilities for A2A Lite agents.
|
|
|
3
3
|
|
|
4
4
|
Makes testing agents as simple as:
|
|
5
5
|
|
|
6
|
-
from a2a_lite.testing import
|
|
6
|
+
from a2a_lite.testing import AgentTestClient
|
|
7
7
|
|
|
8
8
|
def test_my_agent():
|
|
9
|
-
client =
|
|
9
|
+
client = AgentTestClient(agent)
|
|
10
10
|
result = client.call("greet", name="World")
|
|
11
11
|
assert result == "Hello, World!"
|
|
12
12
|
"""
|
|
13
13
|
from __future__ import annotations
|
|
14
14
|
|
|
15
15
|
import json
|
|
16
|
+
from dataclasses import dataclass
|
|
16
17
|
from typing import Any, Dict, Optional
|
|
17
18
|
from uuid import uuid4
|
|
18
19
|
|
|
19
20
|
|
|
20
|
-
|
|
21
|
+
@dataclass
|
|
22
|
+
class TestResult:
|
|
23
|
+
"""
|
|
24
|
+
Structured result from a test client call.
|
|
25
|
+
|
|
26
|
+
Provides multiple ways to access the result:
|
|
27
|
+
- .data — parsed Python object (dict, list, int, str, etc.)
|
|
28
|
+
- .text — raw text string
|
|
29
|
+
- .json() — parse text as JSON (raises on invalid JSON)
|
|
30
|
+
- .raw_response — the full A2A response dict
|
|
31
|
+
"""
|
|
32
|
+
_data: Any
|
|
33
|
+
_text: str
|
|
34
|
+
raw_response: Dict[str, Any]
|
|
35
|
+
|
|
36
|
+
@property
|
|
37
|
+
def data(self) -> Any:
|
|
38
|
+
"""The parsed result value."""
|
|
39
|
+
return self._data
|
|
40
|
+
|
|
41
|
+
@property
|
|
42
|
+
def text(self) -> str:
|
|
43
|
+
"""The raw text representation."""
|
|
44
|
+
return self._text
|
|
45
|
+
|
|
46
|
+
def json(self) -> Any:
|
|
47
|
+
"""Parse the text as JSON."""
|
|
48
|
+
return json.loads(self._text)
|
|
49
|
+
|
|
50
|
+
def __eq__(self, other: Any) -> bool:
|
|
51
|
+
"""Allow direct comparison with the data value for convenience."""
|
|
52
|
+
if isinstance(other, TestResult):
|
|
53
|
+
return self._data == other._data
|
|
54
|
+
return self._data == other
|
|
55
|
+
|
|
56
|
+
def __repr__(self) -> str:
|
|
57
|
+
return f"TestResult(data={self._data!r})"
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class AgentTestClient:
|
|
21
61
|
"""
|
|
22
62
|
Simple test client for A2A Lite agents.
|
|
23
63
|
|
|
@@ -29,7 +69,7 @@ class TestClient:
|
|
|
29
69
|
return a + b
|
|
30
70
|
|
|
31
71
|
# In your test
|
|
32
|
-
client =
|
|
72
|
+
client = AgentTestClient(agent)
|
|
33
73
|
assert client.call("add", a=2, b=3) == 5
|
|
34
74
|
"""
|
|
35
75
|
|
|
@@ -89,7 +129,7 @@ class TestClient:
|
|
|
89
129
|
# Extract the actual result from A2A response
|
|
90
130
|
return self._extract_result(data)
|
|
91
131
|
|
|
92
|
-
def _extract_result(self, response: Dict) ->
|
|
132
|
+
def _extract_result(self, response: Dict) -> TestResult:
|
|
93
133
|
"""Extract the skill result from A2A response."""
|
|
94
134
|
if "error" in response:
|
|
95
135
|
raise TestClientError(response["error"])
|
|
@@ -103,11 +143,12 @@ class TestClient:
|
|
|
103
143
|
text = part.get("text", "")
|
|
104
144
|
# Try to parse as JSON
|
|
105
145
|
try:
|
|
106
|
-
|
|
146
|
+
data = json.loads(text)
|
|
107
147
|
except json.JSONDecodeError:
|
|
108
|
-
|
|
148
|
+
data = text
|
|
149
|
+
return TestResult(_data=data, _text=text, raw_response=response)
|
|
109
150
|
|
|
110
|
-
return result
|
|
151
|
+
return TestResult(_data=result, _text=json.dumps(result), raw_response=response)
|
|
111
152
|
|
|
112
153
|
def get_agent_card(self) -> Dict[str, Any]:
|
|
113
154
|
"""
|
|
@@ -183,13 +224,13 @@ class TestClientError(Exception):
|
|
|
183
224
|
|
|
184
225
|
|
|
185
226
|
# Async version for async tests
|
|
186
|
-
class
|
|
227
|
+
class AsyncAgentTestClient:
|
|
187
228
|
"""
|
|
188
229
|
Async test client for A2A Lite agents.
|
|
189
230
|
|
|
190
231
|
Example:
|
|
191
232
|
async def test_my_agent():
|
|
192
|
-
client =
|
|
233
|
+
client = AsyncAgentTestClient(agent)
|
|
193
234
|
result = await client.call("greet", name="World")
|
|
194
235
|
assert result == "Hello, World!"
|
|
195
236
|
"""
|
|
@@ -243,7 +284,7 @@ class AsyncTestClient:
|
|
|
243
284
|
|
|
244
285
|
return self._extract_result(data)
|
|
245
286
|
|
|
246
|
-
def _extract_result(self, response: Dict) ->
|
|
287
|
+
def _extract_result(self, response: Dict) -> TestResult:
|
|
247
288
|
"""Extract the skill result from A2A response."""
|
|
248
289
|
if "error" in response:
|
|
249
290
|
raise TestClientError(response["error"])
|
|
@@ -255,11 +296,12 @@ class AsyncTestClient:
|
|
|
255
296
|
if part.get("kind") == "text" or part.get("type") == "text":
|
|
256
297
|
text = part.get("text", "")
|
|
257
298
|
try:
|
|
258
|
-
|
|
299
|
+
data = json.loads(text)
|
|
259
300
|
except json.JSONDecodeError:
|
|
260
|
-
|
|
301
|
+
data = text
|
|
302
|
+
return TestResult(_data=data, _text=text, raw_response=response)
|
|
261
303
|
|
|
262
|
-
return result
|
|
304
|
+
return TestResult(_data=result, _text=json.dumps(result), raw_response=response)
|
|
263
305
|
|
|
264
306
|
async def close(self):
|
|
265
307
|
"""Close the client."""
|
a2a_lite/utils.py
CHANGED
|
@@ -5,6 +5,22 @@ from typing import Any, Dict, Type, get_origin, get_args, Union
|
|
|
5
5
|
import inspect
|
|
6
6
|
|
|
7
7
|
|
|
8
|
+
def _is_or_subclass(hint: Any, target_class: Type) -> bool:
|
|
9
|
+
"""
|
|
10
|
+
Check if a type hint is, or is a subclass of, the target class.
|
|
11
|
+
|
|
12
|
+
Works with raw classes and string annotations.
|
|
13
|
+
"""
|
|
14
|
+
try:
|
|
15
|
+
if hint is target_class:
|
|
16
|
+
return True
|
|
17
|
+
if isinstance(hint, type) and issubclass(hint, target_class):
|
|
18
|
+
return True
|
|
19
|
+
except TypeError:
|
|
20
|
+
pass
|
|
21
|
+
return False
|
|
22
|
+
|
|
23
|
+
|
|
8
24
|
def type_to_json_schema(python_type: Type) -> Dict[str, Any]:
|
|
9
25
|
"""
|
|
10
26
|
Convert Python type to JSON Schema.
|