a2a-lite 0.1.0__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- a2a_lite/__init__.py +7 -6
- a2a_lite/agent.py +75 -13
- a2a_lite/auth.py +27 -8
- a2a_lite/cli.py +1 -1
- a2a_lite/decorators.py +3 -1
- a2a_lite/discovery.py +4 -0
- a2a_lite/executor.py +49 -14
- a2a_lite/middleware.py +6 -1
- a2a_lite/tasks.py +43 -34
- a2a_lite/testing.py +57 -15
- a2a_lite/utils.py +16 -0
- a2a_lite-0.2.1.dist-info/METADATA +526 -0
- a2a_lite-0.2.1.dist-info/RECORD +19 -0
- a2a_lite-0.1.0.dist-info/METADATA +0 -383
- a2a_lite-0.1.0.dist-info/RECORD +0 -19
- {a2a_lite-0.1.0.dist-info → a2a_lite-0.2.1.dist-info}/WHEEL +0 -0
- {a2a_lite-0.1.0.dist-info → a2a_lite-0.2.1.dist-info}/entry_points.txt +0 -0
a2a_lite/__init__.py
CHANGED
|
@@ -13,8 +13,8 @@ SIMPLE (8 lines):
|
|
|
13
13
|
agent.run()
|
|
14
14
|
|
|
15
15
|
TEST IT (3 lines):
|
|
16
|
-
from a2a_lite import
|
|
17
|
-
client =
|
|
16
|
+
from a2a_lite import AgentTestClient
|
|
17
|
+
client = AgentTestClient(agent)
|
|
18
18
|
assert client.call("greet", name="World") == "Hello, World!"
|
|
19
19
|
|
|
20
20
|
WITH PYDANTIC:
|
|
@@ -64,7 +64,7 @@ WITH FILES (opt-in):
|
|
|
64
64
|
from .agent import Agent
|
|
65
65
|
from .decorators import SkillDefinition
|
|
66
66
|
from .discovery import AgentDiscovery, DiscoveredAgent
|
|
67
|
-
from .testing import
|
|
67
|
+
from .testing import AgentTestClient, AsyncAgentTestClient, TestResult
|
|
68
68
|
|
|
69
69
|
# Middleware - always available
|
|
70
70
|
from .middleware import (
|
|
@@ -102,7 +102,7 @@ from .auth import (
|
|
|
102
102
|
require_auth,
|
|
103
103
|
)
|
|
104
104
|
|
|
105
|
-
__version__ = "0.1
|
|
105
|
+
__version__ = "0.2.1"
|
|
106
106
|
|
|
107
107
|
__all__ = [
|
|
108
108
|
# Core
|
|
@@ -111,8 +111,9 @@ __all__ = [
|
|
|
111
111
|
"AgentDiscovery",
|
|
112
112
|
"DiscoveredAgent",
|
|
113
113
|
# Testing
|
|
114
|
-
"
|
|
115
|
-
"
|
|
114
|
+
"AgentTestClient",
|
|
115
|
+
"AsyncAgentTestClient",
|
|
116
|
+
"TestResult",
|
|
116
117
|
# Middleware
|
|
117
118
|
"MiddlewareContext",
|
|
118
119
|
"MiddlewareChain",
|
a2a_lite/agent.py
CHANGED
|
@@ -7,9 +7,12 @@ from __future__ import annotations
|
|
|
7
7
|
|
|
8
8
|
import asyncio
|
|
9
9
|
import inspect
|
|
10
|
+
import logging
|
|
10
11
|
from typing import Any, Callable, Optional, Dict, List, Type, Union, get_origin, get_args
|
|
11
12
|
from dataclasses import dataclass, field
|
|
12
13
|
|
|
14
|
+
logger = logging.getLogger(__name__)
|
|
15
|
+
|
|
13
16
|
import uvicorn
|
|
14
17
|
|
|
15
18
|
from a2a.server.apps import A2AStarletteApplication
|
|
@@ -23,7 +26,7 @@ from a2a.types import (
|
|
|
23
26
|
|
|
24
27
|
from .executor import LiteAgentExecutor
|
|
25
28
|
from .decorators import SkillDefinition
|
|
26
|
-
from .utils import type_to_json_schema, extract_function_schemas
|
|
29
|
+
from .utils import type_to_json_schema, extract_function_schemas, _is_or_subclass
|
|
27
30
|
from .middleware import MiddlewareChain, MiddlewareContext
|
|
28
31
|
from .streaming import is_generator_function
|
|
29
32
|
from .webhooks import NotificationManager, WebhookClient
|
|
@@ -85,6 +88,8 @@ class Agent:
|
|
|
85
88
|
# Optional enterprise features
|
|
86
89
|
auth: Optional[Any] = None # AuthProvider
|
|
87
90
|
task_store: Optional[Any] = None # TaskStore or "memory"
|
|
91
|
+
cors_origins: Optional[List[str]] = None
|
|
92
|
+
production: bool = False
|
|
88
93
|
|
|
89
94
|
def __post_init__(self):
|
|
90
95
|
# Internal state
|
|
@@ -161,16 +166,30 @@ class Agent:
|
|
|
161
166
|
if is_streaming:
|
|
162
167
|
self._has_streaming = True
|
|
163
168
|
|
|
164
|
-
# Detect special parameter types
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
needs_interaction =
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
169
|
+
# Detect special parameter types using proper type introspection
|
|
170
|
+
import typing
|
|
171
|
+
from .tasks import TaskContext as _TaskContext
|
|
172
|
+
from .human_loop import InteractionContext as _InteractionContext
|
|
173
|
+
|
|
174
|
+
needs_task_context = False
|
|
175
|
+
needs_interaction = False
|
|
176
|
+
task_context_param: str | None = None
|
|
177
|
+
interaction_param: str | None = None
|
|
178
|
+
|
|
179
|
+
try:
|
|
180
|
+
resolved_hints = typing.get_type_hints(func)
|
|
181
|
+
except Exception:
|
|
182
|
+
resolved_hints = getattr(func, '__annotations__', {})
|
|
183
|
+
|
|
184
|
+
for param_name, hint in resolved_hints.items():
|
|
185
|
+
if param_name == 'return':
|
|
186
|
+
continue
|
|
187
|
+
if _is_or_subclass(hint, _TaskContext):
|
|
188
|
+
needs_task_context = True
|
|
189
|
+
task_context_param = param_name
|
|
190
|
+
elif _is_or_subclass(hint, _InteractionContext):
|
|
191
|
+
needs_interaction = True
|
|
192
|
+
interaction_param = param_name
|
|
174
193
|
|
|
175
194
|
# Extract schemas
|
|
176
195
|
input_schema, output_schema = extract_function_schemas(func)
|
|
@@ -186,6 +205,8 @@ class Agent:
|
|
|
186
205
|
is_streaming=is_streaming,
|
|
187
206
|
needs_task_context=needs_task_context,
|
|
188
207
|
needs_interaction=needs_interaction,
|
|
208
|
+
task_context_param=task_context_param,
|
|
209
|
+
interaction_param=interaction_param,
|
|
189
210
|
)
|
|
190
211
|
|
|
191
212
|
self._skills[skill_name] = skill_def
|
|
@@ -311,6 +332,10 @@ class Agent:
|
|
|
311
332
|
task_store=self._task_store,
|
|
312
333
|
)
|
|
313
334
|
|
|
335
|
+
# The SDK's InMemoryTaskStore handles protocol-level task lifecycle
|
|
336
|
+
# (task creation, state transitions per the A2A spec). This is separate
|
|
337
|
+
# from self._task_store which provides application-level tracking
|
|
338
|
+
# (progress updates, custom status) exposed via TaskContext to skills.
|
|
314
339
|
request_handler = DefaultRequestHandler(
|
|
315
340
|
agent_executor=executor,
|
|
316
341
|
task_store=InMemoryTaskStore(),
|
|
@@ -375,10 +400,34 @@ class Agent:
|
|
|
375
400
|
)
|
|
376
401
|
console.print(f"[dim]mDNS discovery enabled for {self.name}[/]")
|
|
377
402
|
|
|
403
|
+
# Production mode warning
|
|
404
|
+
if self.production:
|
|
405
|
+
url_str = self.url or f"http://{display_host}:{port}"
|
|
406
|
+
if not url_str.startswith("https://"):
|
|
407
|
+
logger.warning(
|
|
408
|
+
"Running in production mode over HTTP. "
|
|
409
|
+
"Consider using HTTPS for secure communication."
|
|
410
|
+
)
|
|
411
|
+
|
|
412
|
+
# Build the ASGI app
|
|
413
|
+
app = app_builder.build()
|
|
414
|
+
|
|
415
|
+
# Add CORS middleware if configured
|
|
416
|
+
if self.cors_origins is not None:
|
|
417
|
+
from starlette.middleware.cors import CORSMiddleware
|
|
418
|
+
from starlette.middleware import Middleware as StarletteMiddleware
|
|
419
|
+
# Wrap the existing app with CORS
|
|
420
|
+
app.add_middleware(
|
|
421
|
+
CORSMiddleware,
|
|
422
|
+
allow_origins=self.cors_origins,
|
|
423
|
+
allow_methods=["*"],
|
|
424
|
+
allow_headers=["*"],
|
|
425
|
+
)
|
|
426
|
+
|
|
378
427
|
# Start server
|
|
379
428
|
try:
|
|
380
429
|
uvicorn.run(
|
|
381
|
-
|
|
430
|
+
app,
|
|
382
431
|
host=host,
|
|
383
432
|
port=port,
|
|
384
433
|
log_level=log_level,
|
|
@@ -440,6 +489,7 @@ class Agent:
|
|
|
440
489
|
task_store=self._task_store,
|
|
441
490
|
)
|
|
442
491
|
|
|
492
|
+
# SDK task store for protocol-level lifecycle (separate from app-level self._task_store)
|
|
443
493
|
request_handler = DefaultRequestHandler(
|
|
444
494
|
agent_executor=executor,
|
|
445
495
|
task_store=InMemoryTaskStore(),
|
|
@@ -450,4 +500,16 @@ class Agent:
|
|
|
450
500
|
http_handler=request_handler,
|
|
451
501
|
)
|
|
452
502
|
|
|
453
|
-
|
|
503
|
+
app = app_builder.build()
|
|
504
|
+
|
|
505
|
+
# Add CORS middleware if configured
|
|
506
|
+
if self.cors_origins is not None:
|
|
507
|
+
from starlette.middleware.cors import CORSMiddleware
|
|
508
|
+
app.add_middleware(
|
|
509
|
+
CORSMiddleware,
|
|
510
|
+
allow_origins=self.cors_origins,
|
|
511
|
+
allow_methods=["*"],
|
|
512
|
+
allow_headers=["*"],
|
|
513
|
+
)
|
|
514
|
+
|
|
515
|
+
return app
|
a2a_lite/auth.py
CHANGED
|
@@ -60,6 +60,17 @@ class AuthRequest:
|
|
|
60
60
|
method: str = "POST"
|
|
61
61
|
path: str = "/"
|
|
62
62
|
|
|
63
|
+
def get_header(self, name: str) -> Optional[str]:
|
|
64
|
+
"""Get a header value (case-insensitive)."""
|
|
65
|
+
# Try exact match first, then case-insensitive
|
|
66
|
+
if name in self.headers:
|
|
67
|
+
return self.headers[name]
|
|
68
|
+
lower = name.lower()
|
|
69
|
+
for k, v in self.headers.items():
|
|
70
|
+
if k.lower() == lower:
|
|
71
|
+
return v
|
|
72
|
+
return None
|
|
73
|
+
|
|
63
74
|
|
|
64
75
|
@dataclass
|
|
65
76
|
class AuthResult:
|
|
@@ -116,13 +127,20 @@ class APIKeyAuth(AuthProvider):
|
|
|
116
127
|
header: str = "X-API-Key",
|
|
117
128
|
query_param: Optional[str] = None,
|
|
118
129
|
):
|
|
119
|
-
|
|
130
|
+
# Store only hashes of keys for security
|
|
131
|
+
self._key_hashes = {
|
|
132
|
+
hashlib.sha256(k.encode()).hexdigest() for k in keys
|
|
133
|
+
}
|
|
120
134
|
self.header = header
|
|
121
135
|
self.query_param = query_param
|
|
122
136
|
|
|
137
|
+
def _hash_key(self, key: str) -> str:
|
|
138
|
+
"""Hash a key using SHA-256."""
|
|
139
|
+
return hashlib.sha256(key.encode()).hexdigest()
|
|
140
|
+
|
|
123
141
|
async def authenticate(self, request: AuthRequest) -> AuthResult:
|
|
124
|
-
# Check header
|
|
125
|
-
key = request.
|
|
142
|
+
# Check header (case-insensitive)
|
|
143
|
+
key = request.get_header(self.header)
|
|
126
144
|
|
|
127
145
|
# Check query param
|
|
128
146
|
if not key and self.query_param:
|
|
@@ -131,11 +149,12 @@ class APIKeyAuth(AuthProvider):
|
|
|
131
149
|
if not key:
|
|
132
150
|
return AuthResult.failure("API key required")
|
|
133
151
|
|
|
134
|
-
|
|
152
|
+
key_hash = self._hash_key(key)
|
|
153
|
+
if key_hash not in self._key_hashes:
|
|
135
154
|
return AuthResult.failure("Invalid API key")
|
|
136
155
|
|
|
137
|
-
# Use hash
|
|
138
|
-
user_id =
|
|
156
|
+
# Use hash prefix as user ID
|
|
157
|
+
user_id = key_hash[:16]
|
|
139
158
|
return AuthResult.success(user_id=user_id)
|
|
140
159
|
|
|
141
160
|
def get_scheme(self) -> Dict[str, Any]:
|
|
@@ -171,7 +190,7 @@ class BearerAuth(AuthProvider):
|
|
|
171
190
|
self.header = header
|
|
172
191
|
|
|
173
192
|
async def authenticate(self, request: AuthRequest) -> AuthResult:
|
|
174
|
-
auth_header = request.
|
|
193
|
+
auth_header = request.get_header(self.header) or ""
|
|
175
194
|
|
|
176
195
|
if not auth_header.startswith("Bearer "):
|
|
177
196
|
return AuthResult.failure("Bearer token required")
|
|
@@ -220,7 +239,7 @@ class OAuth2Auth(AuthProvider):
|
|
|
220
239
|
self._jwks_client = None
|
|
221
240
|
|
|
222
241
|
async def authenticate(self, request: AuthRequest) -> AuthResult:
|
|
223
|
-
auth_header = request.
|
|
242
|
+
auth_header = request.get_header("Authorization") or ""
|
|
224
243
|
|
|
225
244
|
if not auth_header.startswith("Bearer "):
|
|
226
245
|
return AuthResult.failure("Bearer token required")
|
a2a_lite/cli.py
CHANGED
a2a_lite/decorators.py
CHANGED
|
@@ -2,7 +2,7 @@
|
|
|
2
2
|
Decorator definitions and skill metadata.
|
|
3
3
|
"""
|
|
4
4
|
from dataclasses import dataclass, field
|
|
5
|
-
from typing import Any, Callable, Dict, List
|
|
5
|
+
from typing import Any, Callable, Dict, List, Optional
|
|
6
6
|
|
|
7
7
|
|
|
8
8
|
@dataclass
|
|
@@ -18,6 +18,8 @@ class SkillDefinition:
|
|
|
18
18
|
is_streaming: bool = False
|
|
19
19
|
needs_task_context: bool = False
|
|
20
20
|
needs_interaction: bool = False
|
|
21
|
+
task_context_param: Optional[str] = None
|
|
22
|
+
interaction_param: Optional[str] = None
|
|
21
23
|
|
|
22
24
|
def to_dict(self) -> Dict[str, Any]:
|
|
23
25
|
"""Convert to dictionary for serialization."""
|
a2a_lite/discovery.py
CHANGED
|
@@ -4,10 +4,13 @@ mDNS-based local agent discovery using Zeroconf.
|
|
|
4
4
|
from __future__ import annotations
|
|
5
5
|
|
|
6
6
|
import asyncio
|
|
7
|
+
import logging
|
|
7
8
|
import socket
|
|
8
9
|
from typing import Dict, List, Optional
|
|
9
10
|
from dataclasses import dataclass
|
|
10
11
|
|
|
12
|
+
logger = logging.getLogger(__name__)
|
|
13
|
+
|
|
11
14
|
from zeroconf import ServiceBrowser, ServiceListener, Zeroconf, ServiceInfo
|
|
12
15
|
from zeroconf.asyncio import AsyncZeroconf, AsyncServiceBrowser
|
|
13
16
|
|
|
@@ -113,6 +116,7 @@ class AgentDiscovery:
|
|
|
113
116
|
s.close()
|
|
114
117
|
return ip
|
|
115
118
|
except Exception:
|
|
119
|
+
logger.debug("Could not detect local IP, falling back to 127.0.0.1")
|
|
116
120
|
return "127.0.0.1"
|
|
117
121
|
|
|
118
122
|
|
a2a_lite/executor.py
CHANGED
|
@@ -6,14 +6,18 @@ from __future__ import annotations
|
|
|
6
6
|
import asyncio
|
|
7
7
|
import inspect
|
|
8
8
|
import json
|
|
9
|
+
import logging
|
|
9
10
|
from typing import Any, Callable, Dict, List, Optional
|
|
10
11
|
|
|
12
|
+
logger = logging.getLogger(__name__)
|
|
13
|
+
|
|
11
14
|
from a2a.server.agent_execution import AgentExecutor, RequestContext
|
|
12
15
|
from a2a.server.events import EventQueue
|
|
13
16
|
|
|
14
17
|
from .decorators import SkillDefinition
|
|
15
18
|
from .middleware import MiddlewareChain, MiddlewareContext
|
|
16
19
|
from .streaming import is_generator_function, stream_generator
|
|
20
|
+
from .utils import _is_or_subclass
|
|
17
21
|
|
|
18
22
|
|
|
19
23
|
class LiteAgentExecutor(AgentExecutor):
|
|
@@ -55,6 +59,22 @@ class LiteAgentExecutor(AgentExecutor):
|
|
|
55
59
|
from a2a.utils import new_agent_text_message
|
|
56
60
|
|
|
57
61
|
try:
|
|
62
|
+
# Authenticate the request
|
|
63
|
+
if self.auth_provider:
|
|
64
|
+
from .auth import AuthRequest, NoAuth
|
|
65
|
+
if not isinstance(self.auth_provider, NoAuth):
|
|
66
|
+
headers = {}
|
|
67
|
+
if context.call_context and context.call_context.state:
|
|
68
|
+
headers = context.call_context.state.get('headers', {})
|
|
69
|
+
auth_request = AuthRequest(headers=headers)
|
|
70
|
+
auth_result = await self.auth_provider.authenticate(auth_request)
|
|
71
|
+
if not auth_result.authenticated:
|
|
72
|
+
error_msg = json.dumps({
|
|
73
|
+
"error": auth_result.error or "Authentication failed",
|
|
74
|
+
})
|
|
75
|
+
await event_queue.enqueue_event(new_agent_text_message(error_msg))
|
|
76
|
+
return
|
|
77
|
+
|
|
58
78
|
# Extract message and parts
|
|
59
79
|
message, parts = self._extract_message_and_parts(context)
|
|
60
80
|
|
|
@@ -100,7 +120,7 @@ class LiteAgentExecutor(AgentExecutor):
|
|
|
100
120
|
else:
|
|
101
121
|
hook(skill_name, result, ctx)
|
|
102
122
|
except Exception:
|
|
103
|
-
|
|
123
|
+
logger.warning("Completion hook error for skill '%s'", skill_name, exc_info=True)
|
|
104
124
|
|
|
105
125
|
except Exception as e:
|
|
106
126
|
await self._handle_error(e, event_queue)
|
|
@@ -116,7 +136,14 @@ class LiteAgentExecutor(AgentExecutor):
|
|
|
116
136
|
if skill_name is None:
|
|
117
137
|
if not self.skills:
|
|
118
138
|
return {"error": "No skills registered"}
|
|
119
|
-
|
|
139
|
+
# Only auto-select if there's exactly one skill
|
|
140
|
+
if len(self.skills) == 1:
|
|
141
|
+
skill_name = list(self.skills.keys())[0]
|
|
142
|
+
else:
|
|
143
|
+
return {
|
|
144
|
+
"error": "No skill specified. Use {\"skill\": \"name\", \"params\": {...}} format.",
|
|
145
|
+
"available_skills": list(self.skills.keys()),
|
|
146
|
+
}
|
|
120
147
|
|
|
121
148
|
if skill_name not in self.skills:
|
|
122
149
|
return {
|
|
@@ -132,17 +159,19 @@ class LiteAgentExecutor(AgentExecutor):
|
|
|
132
159
|
# Inject special contexts if needed
|
|
133
160
|
if skill_def.needs_task_context and self.task_store:
|
|
134
161
|
from .tasks import TaskContext, Task, TaskStatus, TaskState
|
|
135
|
-
task = self.task_store.create(skill_name, params)
|
|
162
|
+
task = await self.task_store.create(skill_name, params)
|
|
136
163
|
# Only pass event_queue for streaming skills (status updates go via SSE)
|
|
137
164
|
eq = event_queue if skill_def.is_streaming else None
|
|
138
165
|
task_ctx = TaskContext(task, eq)
|
|
139
|
-
|
|
166
|
+
param_name = skill_def.task_context_param or "task"
|
|
167
|
+
params[param_name] = task_ctx
|
|
140
168
|
|
|
141
169
|
if skill_def.needs_interaction:
|
|
142
170
|
from .human_loop import InteractionContext
|
|
143
171
|
task_id = metadata.get("task_id", "unknown")
|
|
144
172
|
interaction_ctx = InteractionContext(task_id, event_queue)
|
|
145
|
-
|
|
173
|
+
param_name = skill_def.interaction_param or "ctx"
|
|
174
|
+
params[param_name] = interaction_ctx
|
|
146
175
|
|
|
147
176
|
# Call the handler
|
|
148
177
|
handler = skill_def.handler
|
|
@@ -161,26 +190,33 @@ class LiteAgentExecutor(AgentExecutor):
|
|
|
161
190
|
metadata: Dict[str, Any],
|
|
162
191
|
) -> Dict[str, Any]:
|
|
163
192
|
"""Convert parameters to Pydantic models and file parts if needed."""
|
|
193
|
+
import typing
|
|
164
194
|
handler = skill_def.handler
|
|
165
|
-
|
|
195
|
+
try:
|
|
196
|
+
hints = typing.get_type_hints(handler)
|
|
197
|
+
except Exception:
|
|
198
|
+
hints = getattr(handler, '__annotations__', {})
|
|
199
|
+
|
|
200
|
+
from .parts import FilePart, DataPart
|
|
166
201
|
|
|
167
202
|
converted = {}
|
|
168
203
|
for param_name, value in params.items():
|
|
204
|
+
if param_name == 'return':
|
|
205
|
+
continue
|
|
169
206
|
param_type = hints.get(param_name)
|
|
170
207
|
|
|
171
208
|
if param_type is None:
|
|
172
209
|
converted[param_name] = value
|
|
173
210
|
continue
|
|
174
211
|
|
|
175
|
-
type_name = str(param_type)
|
|
176
|
-
|
|
177
212
|
# Skip special context types
|
|
178
|
-
|
|
213
|
+
from .tasks import TaskContext as _TaskContext
|
|
214
|
+
from .human_loop import InteractionContext as _InteractionContext
|
|
215
|
+
if _is_or_subclass(param_type, _TaskContext) or _is_or_subclass(param_type, _InteractionContext):
|
|
179
216
|
continue
|
|
180
217
|
|
|
181
218
|
# Convert FilePart
|
|
182
|
-
if
|
|
183
|
-
from .parts import FilePart
|
|
219
|
+
if _is_or_subclass(param_type, FilePart):
|
|
184
220
|
if isinstance(value, dict):
|
|
185
221
|
# Handle both A2A format and simple dict format
|
|
186
222
|
if "file" in value:
|
|
@@ -201,8 +237,7 @@ class LiteAgentExecutor(AgentExecutor):
|
|
|
201
237
|
continue
|
|
202
238
|
|
|
203
239
|
# Convert DataPart
|
|
204
|
-
if
|
|
205
|
-
from .parts import DataPart
|
|
240
|
+
if _is_or_subclass(param_type, DataPart):
|
|
206
241
|
if isinstance(value, dict):
|
|
207
242
|
# Handle both A2A format and simple dict format
|
|
208
243
|
if "type" in value and value.get("type") == "data":
|
|
@@ -234,7 +269,7 @@ class LiteAgentExecutor(AgentExecutor):
|
|
|
234
269
|
if isinstance(data, dict) and 'skill' in data:
|
|
235
270
|
return data['skill'], data.get('params', {})
|
|
236
271
|
except json.JSONDecodeError:
|
|
237
|
-
|
|
272
|
+
logger.debug("Message is not JSON, treating as plain text")
|
|
238
273
|
|
|
239
274
|
return None, {"message": message}
|
|
240
275
|
|
a2a_lite/middleware.py
CHANGED
|
@@ -159,7 +159,12 @@ def retry_middleware(max_retries: int = 3, delay: float = 1.0):
|
|
|
159
159
|
|
|
160
160
|
def rate_limit_middleware(requests_per_minute: int = 60):
|
|
161
161
|
"""
|
|
162
|
-
Create a simple rate limiting middleware.
|
|
162
|
+
Create a simple in-process rate limiting middleware.
|
|
163
|
+
|
|
164
|
+
Note: This rate limiter is per-process. Under multi-worker uvicorn
|
|
165
|
+
(e.g., ``--workers 4``), each worker tracks limits independently.
|
|
166
|
+
For shared rate limiting across workers, use an external store
|
|
167
|
+
(Redis, etc.) and a custom middleware.
|
|
163
168
|
|
|
164
169
|
Example:
|
|
165
170
|
agent.add_middleware(rate_limit_middleware(requests_per_minute=100))
|
a2a_lite/tasks.py
CHANGED
|
@@ -27,12 +27,15 @@ Example (with task tracking - opt-in):
|
|
|
27
27
|
"""
|
|
28
28
|
from __future__ import annotations
|
|
29
29
|
|
|
30
|
+
import asyncio
|
|
31
|
+
import logging
|
|
30
32
|
from dataclasses import dataclass, field
|
|
31
|
-
from datetime import datetime
|
|
33
|
+
from datetime import datetime, timezone
|
|
32
34
|
from enum import Enum
|
|
33
35
|
from typing import Any, Callable, Dict, List, Optional
|
|
34
36
|
from uuid import uuid4
|
|
35
|
-
|
|
37
|
+
|
|
38
|
+
logger = logging.getLogger(__name__)
|
|
36
39
|
|
|
37
40
|
|
|
38
41
|
class TaskState(str, Enum):
|
|
@@ -52,7 +55,7 @@ class TaskStatus:
|
|
|
52
55
|
state: TaskState
|
|
53
56
|
message: Optional[str] = None
|
|
54
57
|
progress: Optional[float] = None # 0.0 to 1.0
|
|
55
|
-
timestamp: datetime = field(default_factory=datetime.
|
|
58
|
+
timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
|
56
59
|
|
|
57
60
|
def to_dict(self) -> Dict[str, Any]:
|
|
58
61
|
return {
|
|
@@ -74,8 +77,8 @@ class Task:
|
|
|
74
77
|
error: Optional[str] = None
|
|
75
78
|
artifacts: List[Any] = field(default_factory=list)
|
|
76
79
|
history: List[TaskStatus] = field(default_factory=list)
|
|
77
|
-
created_at: datetime = field(default_factory=datetime.
|
|
78
|
-
updated_at: datetime = field(default_factory=datetime.
|
|
80
|
+
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
|
81
|
+
updated_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
|
79
82
|
|
|
80
83
|
def update_status(
|
|
81
84
|
self,
|
|
@@ -86,7 +89,7 @@ class Task:
|
|
|
86
89
|
"""Update task status."""
|
|
87
90
|
self.history.append(self.status)
|
|
88
91
|
self.status = TaskStatus(state=state, message=message, progress=progress)
|
|
89
|
-
self.updated_at = datetime.
|
|
92
|
+
self.updated_at = datetime.now(timezone.utc)
|
|
90
93
|
|
|
91
94
|
|
|
92
95
|
class TaskContext:
|
|
@@ -145,7 +148,7 @@ class TaskContext:
|
|
|
145
148
|
else:
|
|
146
149
|
callback(self._task.status)
|
|
147
150
|
except Exception:
|
|
148
|
-
|
|
151
|
+
logger.warning("Status callback error for task '%s'", self._task.id, exc_info=True)
|
|
149
152
|
|
|
150
153
|
# Send SSE event if streaming
|
|
151
154
|
if self._event_queue:
|
|
@@ -170,52 +173,58 @@ class TaskContext:
|
|
|
170
173
|
|
|
171
174
|
class TaskStore:
|
|
172
175
|
"""
|
|
173
|
-
In-memory task store.
|
|
176
|
+
In-memory task store with async locking for thread safety.
|
|
174
177
|
|
|
175
178
|
For production, extend this with Redis/DB backend.
|
|
176
179
|
"""
|
|
177
180
|
|
|
178
181
|
def __init__(self):
|
|
179
182
|
self._tasks: Dict[str, Task] = {}
|
|
183
|
+
self._lock = asyncio.Lock()
|
|
180
184
|
|
|
181
|
-
def create(self, skill: str, params: Dict[str, Any]) -> Task:
|
|
185
|
+
async def create(self, skill: str, params: Dict[str, Any]) -> Task:
|
|
182
186
|
"""Create a new task."""
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
187
|
+
async with self._lock:
|
|
188
|
+
task = Task(
|
|
189
|
+
id=uuid4().hex,
|
|
190
|
+
skill=skill,
|
|
191
|
+
params=params,
|
|
192
|
+
status=TaskStatus(state=TaskState.SUBMITTED),
|
|
193
|
+
)
|
|
194
|
+
self._tasks[task.id] = task
|
|
195
|
+
return task
|
|
196
|
+
|
|
197
|
+
async def get(self, task_id: str) -> Optional[Task]:
|
|
193
198
|
"""Get task by ID."""
|
|
194
|
-
|
|
199
|
+
async with self._lock:
|
|
200
|
+
return self._tasks.get(task_id)
|
|
195
201
|
|
|
196
|
-
def update(self, task: Task) -> None:
|
|
202
|
+
async def update(self, task: Task) -> None:
|
|
197
203
|
"""Update task in store."""
|
|
198
|
-
self.
|
|
204
|
+
async with self._lock:
|
|
205
|
+
self._tasks[task.id] = task
|
|
199
206
|
|
|
200
|
-
def list(
|
|
207
|
+
async def list(
|
|
201
208
|
self,
|
|
202
209
|
state: Optional[TaskState] = None,
|
|
203
210
|
skill: Optional[str] = None,
|
|
204
211
|
limit: int = 100,
|
|
205
212
|
) -> List[Task]:
|
|
206
213
|
"""List tasks with optional filters."""
|
|
207
|
-
|
|
214
|
+
async with self._lock:
|
|
215
|
+
tasks = list(self._tasks.values())
|
|
208
216
|
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
217
|
+
if state:
|
|
218
|
+
tasks = [t for t in tasks if t.status.state == state]
|
|
219
|
+
if skill:
|
|
220
|
+
tasks = [t for t in tasks if t.skill == skill]
|
|
213
221
|
|
|
214
|
-
|
|
222
|
+
return sorted(tasks, key=lambda t: t.created_at, reverse=True)[:limit]
|
|
215
223
|
|
|
216
|
-
def delete(self, task_id: str) -> bool:
|
|
224
|
+
async def delete(self, task_id: str) -> bool:
|
|
217
225
|
"""Delete a task."""
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
226
|
+
async with self._lock:
|
|
227
|
+
if task_id in self._tasks:
|
|
228
|
+
del self._tasks[task_id]
|
|
229
|
+
return True
|
|
230
|
+
return False
|