a2a-lite 0.2.0__tar.gz → 0.2.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {a2a_lite-0.2.0 → a2a_lite-0.2.2}/PKG-INFO +3 -1
- {a2a_lite-0.2.0 → a2a_lite-0.2.2}/pyproject.toml +4 -1
- {a2a_lite-0.2.0 → a2a_lite-0.2.2}/src/a2a_lite/__init__.py +1 -1
- {a2a_lite-0.2.0 → a2a_lite-0.2.2}/src/a2a_lite/agent.py +34 -10
- {a2a_lite-0.2.0 → a2a_lite-0.2.2}/src/a2a_lite/auth.py +17 -6
- {a2a_lite-0.2.0 → a2a_lite-0.2.2}/src/a2a_lite/cli.py +1 -1
- {a2a_lite-0.2.0 → a2a_lite-0.2.2}/src/a2a_lite/decorators.py +2 -0
- {a2a_lite-0.2.0 → a2a_lite-0.2.2}/src/a2a_lite/executor.py +45 -10
- {a2a_lite-0.2.0 → a2a_lite-0.2.2}/src/a2a_lite/middleware.py +6 -1
- {a2a_lite-0.2.0 → a2a_lite-0.2.2}/src/a2a_lite/tasks.py +5 -5
- {a2a_lite-0.2.0 → a2a_lite-0.2.2}/src/a2a_lite/testing.py +10 -1
- {a2a_lite-0.2.0 → a2a_lite-0.2.2}/src/a2a_lite/utils.py +5 -1
- {a2a_lite-0.2.0 → a2a_lite-0.2.2}/src/a2a_lite/webhooks.py +2 -6
- a2a_lite-0.2.2/tests/test_auth.py +378 -0
- a2a_lite-0.2.0/tests/test_auth.py +0 -177
- {a2a_lite-0.2.0 → a2a_lite-0.2.2}/.claude/settings.local.json +0 -0
- {a2a_lite-0.2.0 → a2a_lite-0.2.2}/.gitignore +0 -0
- {a2a_lite-0.2.0 → a2a_lite-0.2.2}/README.md +0 -0
- {a2a_lite-0.2.0 → a2a_lite-0.2.2}/examples/01_hello_world.py +0 -0
- {a2a_lite-0.2.0 → a2a_lite-0.2.2}/examples/02_calculator.py +0 -0
- {a2a_lite-0.2.0 → a2a_lite-0.2.2}/examples/03_async_agent.py +0 -0
- {a2a_lite-0.2.0 → a2a_lite-0.2.2}/examples/04_multi_agent/finance_agent.py +0 -0
- {a2a_lite-0.2.0 → a2a_lite-0.2.2}/examples/04_multi_agent/reporter_agent.py +0 -0
- {a2a_lite-0.2.0 → a2a_lite-0.2.2}/examples/04_multi_agent/run_demo.py +0 -0
- {a2a_lite-0.2.0 → a2a_lite-0.2.2}/examples/05_with_llm.py +0 -0
- {a2a_lite-0.2.0 → a2a_lite-0.2.2}/examples/06_pydantic_models.py +0 -0
- {a2a_lite-0.2.0 → a2a_lite-0.2.2}/examples/07_middleware.py +0 -0
- {a2a_lite-0.2.0 → a2a_lite-0.2.2}/examples/08_streaming.py +0 -0
- {a2a_lite-0.2.0 → a2a_lite-0.2.2}/examples/09_testing.py +0 -0
- {a2a_lite-0.2.0 → a2a_lite-0.2.2}/examples/10_webhooks.py +0 -0
- {a2a_lite-0.2.0 → a2a_lite-0.2.2}/examples/11_human_in_the_loop.py +0 -0
- {a2a_lite-0.2.0 → a2a_lite-0.2.2}/examples/12_file_handling.py +0 -0
- {a2a_lite-0.2.0 → a2a_lite-0.2.2}/examples/13_task_tracking.py +0 -0
- {a2a_lite-0.2.0 → a2a_lite-0.2.2}/examples/14_with_auth.py +0 -0
- {a2a_lite-0.2.0 → a2a_lite-0.2.2}/src/a2a_lite/discovery.py +0 -0
- {a2a_lite-0.2.0 → a2a_lite-0.2.2}/src/a2a_lite/human_loop.py +0 -0
- {a2a_lite-0.2.0 → a2a_lite-0.2.2}/src/a2a_lite/parts.py +0 -0
- {a2a_lite-0.2.0 → a2a_lite-0.2.2}/src/a2a_lite/streaming.py +0 -0
- {a2a_lite-0.2.0 → a2a_lite-0.2.2}/tests/__init__.py +0 -0
- {a2a_lite-0.2.0 → a2a_lite-0.2.2}/tests/test_agent.py +0 -0
- {a2a_lite-0.2.0 → a2a_lite-0.2.2}/tests/test_decorators.py +0 -0
- {a2a_lite-0.2.0 → a2a_lite-0.2.2}/tests/test_discovery.py +0 -0
- {a2a_lite-0.2.0 → a2a_lite-0.2.2}/tests/test_human_loop.py +0 -0
- {a2a_lite-0.2.0 → a2a_lite-0.2.2}/tests/test_integration.py +0 -0
- {a2a_lite-0.2.0 → a2a_lite-0.2.2}/tests/test_middleware.py +0 -0
- {a2a_lite-0.2.0 → a2a_lite-0.2.2}/tests/test_parts.py +0 -0
- {a2a_lite-0.2.0 → a2a_lite-0.2.2}/tests/test_pydantic.py +0 -0
- {a2a_lite-0.2.0 → a2a_lite-0.2.2}/tests/test_tasks.py +0 -0
- {a2a_lite-0.2.0 → a2a_lite-0.2.2}/tests/test_testing.py +0 -0
- {a2a_lite-0.2.0 → a2a_lite-0.2.2}/tests/test_utils.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: a2a-lite
|
|
3
|
-
Version: 0.2.0
|
|
3
|
+
Version: 0.2.2
|
|
4
4
|
Summary: Simplified wrapper for Google's A2A Protocol SDK
|
|
5
5
|
Author: A2A Lite Contributors
|
|
6
6
|
License-Expression: Apache-2.0
|
|
@@ -27,6 +27,8 @@ Provides-Extra: dev
|
|
|
27
27
|
Requires-Dist: httpx>=0.25; extra == 'dev'
|
|
28
28
|
Requires-Dist: pytest-asyncio>=0.21; extra == 'dev'
|
|
29
29
|
Requires-Dist: pytest>=7.0; extra == 'dev'
|
|
30
|
+
Provides-Extra: oauth
|
|
31
|
+
Requires-Dist: pyjwt[crypto]>=2.0; extra == 'oauth'
|
|
30
32
|
Description-Content-Type: text/markdown
|
|
31
33
|
|
|
32
34
|
# A2A Lite - Python
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[project]
|
|
2
2
|
name = "a2a-lite"
|
|
3
|
-
version = "0.2.
|
|
3
|
+
version = "0.2.2"
|
|
4
4
|
description = "Simplified wrapper for Google's A2A Protocol SDK"
|
|
5
5
|
readme = "README.md"
|
|
6
6
|
license = "Apache-2.0"
|
|
@@ -32,6 +32,9 @@ dependencies = [
|
|
|
32
32
|
]
|
|
33
33
|
|
|
34
34
|
[project.optional-dependencies]
|
|
35
|
+
oauth = [
|
|
36
|
+
"pyjwt[crypto]>=2.0",
|
|
37
|
+
]
|
|
35
38
|
dev = [
|
|
36
39
|
"pytest>=7.0",
|
|
37
40
|
"pytest-asyncio>=0.21",
|
|
@@ -170,11 +170,14 @@ class Agent:
|
|
|
170
170
|
import typing
|
|
171
171
|
from .tasks import TaskContext as _TaskContext
|
|
172
172
|
from .human_loop import InteractionContext as _InteractionContext
|
|
173
|
+
from .auth import AuthResult as _AuthResult
|
|
173
174
|
|
|
174
175
|
needs_task_context = False
|
|
175
176
|
needs_interaction = False
|
|
177
|
+
needs_auth = False
|
|
176
178
|
task_context_param: str | None = None
|
|
177
179
|
interaction_param: str | None = None
|
|
180
|
+
auth_param: str | None = None
|
|
178
181
|
|
|
179
182
|
try:
|
|
180
183
|
resolved_hints = typing.get_type_hints(func)
|
|
@@ -190,6 +193,14 @@ class Agent:
|
|
|
190
193
|
elif _is_or_subclass(hint, _InteractionContext):
|
|
191
194
|
needs_interaction = True
|
|
192
195
|
interaction_param = param_name
|
|
196
|
+
elif _is_or_subclass(hint, _AuthResult):
|
|
197
|
+
needs_auth = True
|
|
198
|
+
auth_param = param_name
|
|
199
|
+
|
|
200
|
+
# Also detect require_auth decorator
|
|
201
|
+
if getattr(func, '__requires_auth__', False) and not needs_auth:
|
|
202
|
+
needs_auth = True
|
|
203
|
+
auth_param = auth_param or "auth"
|
|
193
204
|
|
|
194
205
|
# Extract schemas
|
|
195
206
|
input_schema, output_schema = extract_function_schemas(func)
|
|
@@ -205,8 +216,10 @@ class Agent:
|
|
|
205
216
|
is_streaming=is_streaming,
|
|
206
217
|
needs_task_context=needs_task_context,
|
|
207
218
|
needs_interaction=needs_interaction,
|
|
219
|
+
needs_auth=needs_auth,
|
|
208
220
|
task_context_param=task_context_param,
|
|
209
221
|
interaction_param=interaction_param,
|
|
222
|
+
auth_param=auth_param,
|
|
210
223
|
)
|
|
211
224
|
|
|
212
225
|
self._skills[skill_name] = skill_def
|
|
@@ -332,6 +345,10 @@ class Agent:
|
|
|
332
345
|
task_store=self._task_store,
|
|
333
346
|
)
|
|
334
347
|
|
|
348
|
+
# The SDK's InMemoryTaskStore handles protocol-level task lifecycle
|
|
349
|
+
# (task creation, state transitions per the A2A spec). This is separate
|
|
350
|
+
# from self._task_store which provides application-level tracking
|
|
351
|
+
# (progress updates, custom status) exposed via TaskContext to skills.
|
|
335
352
|
request_handler = DefaultRequestHandler(
|
|
336
353
|
agent_executor=executor,
|
|
337
354
|
task_store=InMemoryTaskStore(),
|
|
@@ -379,11 +396,14 @@ class Agent:
|
|
|
379
396
|
))
|
|
380
397
|
|
|
381
398
|
# Run startup hooks
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
asyncio.
|
|
385
|
-
|
|
386
|
-
|
|
399
|
+
async def _run_startup():
|
|
400
|
+
for hook in self._on_startup:
|
|
401
|
+
if asyncio.iscoroutinefunction(hook):
|
|
402
|
+
await hook()
|
|
403
|
+
else:
|
|
404
|
+
hook()
|
|
405
|
+
if self._on_startup:
|
|
406
|
+
asyncio.run(_run_startup())
|
|
387
407
|
|
|
388
408
|
# Enable discovery if requested
|
|
389
409
|
if enable_discovery:
|
|
@@ -430,11 +450,14 @@ class Agent:
|
|
|
430
450
|
)
|
|
431
451
|
finally:
|
|
432
452
|
# Run shutdown hooks
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
asyncio.
|
|
436
|
-
|
|
437
|
-
|
|
453
|
+
async def _run_shutdown():
|
|
454
|
+
for hook in self._on_shutdown:
|
|
455
|
+
if asyncio.iscoroutinefunction(hook):
|
|
456
|
+
await hook()
|
|
457
|
+
else:
|
|
458
|
+
hook()
|
|
459
|
+
if self._on_shutdown:
|
|
460
|
+
asyncio.run(_run_shutdown())
|
|
438
461
|
|
|
439
462
|
# Unregister discovery
|
|
440
463
|
if self._discovery:
|
|
@@ -485,6 +508,7 @@ class Agent:
|
|
|
485
508
|
task_store=self._task_store,
|
|
486
509
|
)
|
|
487
510
|
|
|
511
|
+
# SDK task store for protocol-level lifecycle (separate from app-level self._task_store)
|
|
488
512
|
request_handler = DefaultRequestHandler(
|
|
489
513
|
agent_executor=executor,
|
|
490
514
|
task_store=InMemoryTaskStore(),
|
|
@@ -60,6 +60,17 @@ class AuthRequest:
|
|
|
60
60
|
method: str = "POST"
|
|
61
61
|
path: str = "/"
|
|
62
62
|
|
|
63
|
+
def get_header(self, name: str) -> Optional[str]:
|
|
64
|
+
"""Get a header value (case-insensitive)."""
|
|
65
|
+
# Try exact match first, then case-insensitive
|
|
66
|
+
if name in self.headers:
|
|
67
|
+
return self.headers[name]
|
|
68
|
+
lower = name.lower()
|
|
69
|
+
for k, v in self.headers.items():
|
|
70
|
+
if k.lower() == lower:
|
|
71
|
+
return v
|
|
72
|
+
return None
|
|
73
|
+
|
|
63
74
|
|
|
64
75
|
@dataclass
|
|
65
76
|
class AuthResult:
|
|
@@ -128,8 +139,8 @@ class APIKeyAuth(AuthProvider):
|
|
|
128
139
|
return hashlib.sha256(key.encode()).hexdigest()
|
|
129
140
|
|
|
130
141
|
async def authenticate(self, request: AuthRequest) -> AuthResult:
|
|
131
|
-
# Check header
|
|
132
|
-
key = request.
|
|
142
|
+
# Check header (case-insensitive)
|
|
143
|
+
key = request.get_header(self.header)
|
|
133
144
|
|
|
134
145
|
# Check query param
|
|
135
146
|
if not key and self.query_param:
|
|
@@ -179,7 +190,7 @@ class BearerAuth(AuthProvider):
|
|
|
179
190
|
self.header = header
|
|
180
191
|
|
|
181
192
|
async def authenticate(self, request: AuthRequest) -> AuthResult:
|
|
182
|
-
auth_header = request.
|
|
193
|
+
auth_header = request.get_header(self.header) or ""
|
|
183
194
|
|
|
184
195
|
if not auth_header.startswith("Bearer "):
|
|
185
196
|
return AuthResult.failure("Bearer token required")
|
|
@@ -211,7 +222,7 @@ class OAuth2Auth(AuthProvider):
|
|
|
211
222
|
audience="my-agent",
|
|
212
223
|
)
|
|
213
224
|
|
|
214
|
-
Requires: pip install
|
|
225
|
+
Requires: pip install a2a-lite[oauth]
|
|
215
226
|
"""
|
|
216
227
|
|
|
217
228
|
def __init__(
|
|
@@ -228,7 +239,7 @@ class OAuth2Auth(AuthProvider):
|
|
|
228
239
|
self._jwks_client = None
|
|
229
240
|
|
|
230
241
|
async def authenticate(self, request: AuthRequest) -> AuthResult:
|
|
231
|
-
auth_header = request.
|
|
242
|
+
auth_header = request.get_header("Authorization") or ""
|
|
232
243
|
|
|
233
244
|
if not auth_header.startswith("Bearer "):
|
|
234
245
|
return AuthResult.failure("Bearer token required")
|
|
@@ -266,7 +277,7 @@ class OAuth2Auth(AuthProvider):
|
|
|
266
277
|
|
|
267
278
|
except ImportError:
|
|
268
279
|
return AuthResult.failure(
|
|
269
|
-
"OAuth2 requires pyjwt: pip install
|
|
280
|
+
"OAuth2 requires pyjwt: pip install a2a-lite[oauth]"
|
|
270
281
|
)
|
|
271
282
|
except Exception as e:
|
|
272
283
|
return AuthResult.failure(f"Token validation failed: {str(e)}")
|
|
@@ -18,8 +18,10 @@ class SkillDefinition:
|
|
|
18
18
|
is_streaming: bool = False
|
|
19
19
|
needs_task_context: bool = False
|
|
20
20
|
needs_interaction: bool = False
|
|
21
|
+
needs_auth: bool = False
|
|
21
22
|
task_context_param: Optional[str] = None
|
|
22
23
|
interaction_param: Optional[str] = None
|
|
24
|
+
auth_param: Optional[str] = None
|
|
23
25
|
|
|
24
26
|
def to_dict(self) -> Dict[str, Any]:
|
|
25
27
|
"""Convert to dictionary for serialization."""
|
|
@@ -59,6 +59,23 @@ class LiteAgentExecutor(AgentExecutor):
|
|
|
59
59
|
from a2a.utils import new_agent_text_message
|
|
60
60
|
|
|
61
61
|
try:
|
|
62
|
+
# Authenticate the request (always run to produce auth_result for injection)
|
|
63
|
+
auth_result = None
|
|
64
|
+
if self.auth_provider:
|
|
65
|
+
from .auth import AuthRequest, NoAuth
|
|
66
|
+
headers = {}
|
|
67
|
+
if context.call_context and context.call_context.state:
|
|
68
|
+
headers = context.call_context.state.get('headers', {})
|
|
69
|
+
auth_request = AuthRequest(headers=headers)
|
|
70
|
+
auth_result = await self.auth_provider.authenticate(auth_request)
|
|
71
|
+
# Reject unauthenticated requests (unless NoAuth)
|
|
72
|
+
if not isinstance(self.auth_provider, NoAuth) and not auth_result.authenticated:
|
|
73
|
+
error_msg = json.dumps({
|
|
74
|
+
"error": auth_result.error or "Authentication failed",
|
|
75
|
+
})
|
|
76
|
+
await event_queue.enqueue_event(new_agent_text_message(error_msg))
|
|
77
|
+
return
|
|
78
|
+
|
|
62
79
|
# Extract message and parts
|
|
63
80
|
message, parts = self._extract_message_and_parts(context)
|
|
64
81
|
|
|
@@ -72,9 +89,10 @@ class LiteAgentExecutor(AgentExecutor):
|
|
|
72
89
|
message=message,
|
|
73
90
|
)
|
|
74
91
|
|
|
75
|
-
# Store parts in metadata for skill access
|
|
92
|
+
# Store parts and auth result in metadata for skill access
|
|
76
93
|
ctx.metadata["parts"] = parts
|
|
77
94
|
ctx.metadata["event_queue"] = event_queue
|
|
95
|
+
ctx.metadata["auth_result"] = auth_result
|
|
78
96
|
|
|
79
97
|
# Define final handler
|
|
80
98
|
async def final_handler(ctx: MiddlewareContext) -> Any:
|
|
@@ -120,7 +138,14 @@ class LiteAgentExecutor(AgentExecutor):
|
|
|
120
138
|
if skill_name is None:
|
|
121
139
|
if not self.skills:
|
|
122
140
|
return {"error": "No skills registered"}
|
|
123
|
-
|
|
141
|
+
# Only auto-select if there's exactly one skill
|
|
142
|
+
if len(self.skills) == 1:
|
|
143
|
+
skill_name = list(self.skills.keys())[0]
|
|
144
|
+
else:
|
|
145
|
+
return {
|
|
146
|
+
"error": "No skill specified. Use {\"skill\": \"name\", \"params\": {...}} format.",
|
|
147
|
+
"available_skills": list(self.skills.keys()),
|
|
148
|
+
}
|
|
124
149
|
|
|
125
150
|
if skill_name not in self.skills:
|
|
126
151
|
return {
|
|
@@ -150,6 +175,10 @@ class LiteAgentExecutor(AgentExecutor):
|
|
|
150
175
|
param_name = skill_def.interaction_param or "ctx"
|
|
151
176
|
params[param_name] = interaction_ctx
|
|
152
177
|
|
|
178
|
+
if skill_def.needs_auth:
|
|
179
|
+
param_name = skill_def.auth_param or "auth"
|
|
180
|
+
params[param_name] = metadata.get("auth_result")
|
|
181
|
+
|
|
153
182
|
# Call the handler
|
|
154
183
|
handler = skill_def.handler
|
|
155
184
|
|
|
@@ -167,11 +196,19 @@ class LiteAgentExecutor(AgentExecutor):
|
|
|
167
196
|
metadata: Dict[str, Any],
|
|
168
197
|
) -> Dict[str, Any]:
|
|
169
198
|
"""Convert parameters to Pydantic models and file parts if needed."""
|
|
199
|
+
import typing
|
|
170
200
|
handler = skill_def.handler
|
|
171
|
-
|
|
201
|
+
try:
|
|
202
|
+
hints = typing.get_type_hints(handler)
|
|
203
|
+
except Exception:
|
|
204
|
+
hints = getattr(handler, '__annotations__', {})
|
|
205
|
+
|
|
206
|
+
from .parts import FilePart, DataPart
|
|
172
207
|
|
|
173
208
|
converted = {}
|
|
174
209
|
for param_name, value in params.items():
|
|
210
|
+
if param_name == 'return':
|
|
211
|
+
continue
|
|
175
212
|
param_type = hints.get(param_name)
|
|
176
213
|
|
|
177
214
|
if param_type is None:
|
|
@@ -181,13 +218,12 @@ class LiteAgentExecutor(AgentExecutor):
|
|
|
181
218
|
# Skip special context types
|
|
182
219
|
from .tasks import TaskContext as _TaskContext
|
|
183
220
|
from .human_loop import InteractionContext as _InteractionContext
|
|
184
|
-
|
|
221
|
+
from .auth import AuthResult as _AuthResult
|
|
222
|
+
if _is_or_subclass(param_type, _TaskContext) or _is_or_subclass(param_type, _InteractionContext) or _is_or_subclass(param_type, _AuthResult):
|
|
185
223
|
continue
|
|
186
224
|
|
|
187
225
|
# Convert FilePart
|
|
188
|
-
|
|
189
|
-
if "FilePart" in type_name:
|
|
190
|
-
from .parts import FilePart
|
|
226
|
+
if _is_or_subclass(param_type, FilePart):
|
|
191
227
|
if isinstance(value, dict):
|
|
192
228
|
# Handle both A2A format and simple dict format
|
|
193
229
|
if "file" in value:
|
|
@@ -208,8 +244,7 @@ class LiteAgentExecutor(AgentExecutor):
|
|
|
208
244
|
continue
|
|
209
245
|
|
|
210
246
|
# Convert DataPart
|
|
211
|
-
if
|
|
212
|
-
from .parts import DataPart
|
|
247
|
+
if _is_or_subclass(param_type, DataPart):
|
|
213
248
|
if isinstance(value, dict):
|
|
214
249
|
# Handle both A2A format and simple dict format
|
|
215
250
|
if "type" in value and value.get("type") == "data":
|
|
@@ -317,7 +352,7 @@ class LiteAgentExecutor(AgentExecutor):
|
|
|
317
352
|
if asyncio.iscoroutinefunction(handler):
|
|
318
353
|
return await handler(*args, **kwargs)
|
|
319
354
|
else:
|
|
320
|
-
loop = asyncio.
|
|
355
|
+
loop = asyncio.get_running_loop()
|
|
321
356
|
return await loop.run_in_executor(
|
|
322
357
|
None,
|
|
323
358
|
lambda: handler(*args, **kwargs)
|
|
@@ -159,7 +159,12 @@ def retry_middleware(max_retries: int = 3, delay: float = 1.0):
|
|
|
159
159
|
|
|
160
160
|
def rate_limit_middleware(requests_per_minute: int = 60):
|
|
161
161
|
"""
|
|
162
|
-
Create a simple rate limiting middleware.
|
|
162
|
+
Create a simple in-process rate limiting middleware.
|
|
163
|
+
|
|
164
|
+
Note: This rate limiter is per-process. Under multi-worker uvicorn
|
|
165
|
+
(e.g., ``--workers 4``), each worker tracks limits independently.
|
|
166
|
+
For shared rate limiting across workers, use an external store
|
|
167
|
+
(Redis, etc.) and a custom middleware.
|
|
163
168
|
|
|
164
169
|
Example:
|
|
165
170
|
agent.add_middleware(rate_limit_middleware(requests_per_minute=100))
|
|
@@ -30,7 +30,7 @@ from __future__ import annotations
|
|
|
30
30
|
import asyncio
|
|
31
31
|
import logging
|
|
32
32
|
from dataclasses import dataclass, field
|
|
33
|
-
from datetime import datetime
|
|
33
|
+
from datetime import datetime, timezone
|
|
34
34
|
from enum import Enum
|
|
35
35
|
from typing import Any, Callable, Dict, List, Optional
|
|
36
36
|
from uuid import uuid4
|
|
@@ -55,7 +55,7 @@ class TaskStatus:
|
|
|
55
55
|
state: TaskState
|
|
56
56
|
message: Optional[str] = None
|
|
57
57
|
progress: Optional[float] = None # 0.0 to 1.0
|
|
58
|
-
timestamp: datetime = field(default_factory=datetime.
|
|
58
|
+
timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
|
59
59
|
|
|
60
60
|
def to_dict(self) -> Dict[str, Any]:
|
|
61
61
|
return {
|
|
@@ -77,8 +77,8 @@ class Task:
|
|
|
77
77
|
error: Optional[str] = None
|
|
78
78
|
artifacts: List[Any] = field(default_factory=list)
|
|
79
79
|
history: List[TaskStatus] = field(default_factory=list)
|
|
80
|
-
created_at: datetime = field(default_factory=datetime.
|
|
81
|
-
updated_at: datetime = field(default_factory=datetime.
|
|
80
|
+
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
|
81
|
+
updated_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
|
82
82
|
|
|
83
83
|
def update_status(
|
|
84
84
|
self,
|
|
@@ -89,7 +89,7 @@ class Task:
|
|
|
89
89
|
"""Update task status."""
|
|
90
90
|
self.history.append(self.status)
|
|
91
91
|
self.status = TaskStatus(state=state, message=message, progress=progress)
|
|
92
|
-
self.updated_at = datetime.
|
|
92
|
+
self.updated_at = datetime.now(timezone.utc)
|
|
93
93
|
|
|
94
94
|
|
|
95
95
|
class TaskContext:
|
|
@@ -214,7 +214,16 @@ class AgentTestClient:
|
|
|
214
214
|
result = await gen
|
|
215
215
|
results.append(result)
|
|
216
216
|
|
|
217
|
-
|
|
217
|
+
# Handle both sync and async calling contexts
|
|
218
|
+
try:
|
|
219
|
+
asyncio.get_running_loop()
|
|
220
|
+
# Already in an async context — run in a separate thread
|
|
221
|
+
import concurrent.futures
|
|
222
|
+
with concurrent.futures.ThreadPoolExecutor(1) as pool:
|
|
223
|
+
pool.submit(asyncio.run, run_handler()).result()
|
|
224
|
+
except RuntimeError:
|
|
225
|
+
# No running loop — safe to use asyncio.run()
|
|
226
|
+
asyncio.run(run_handler())
|
|
218
227
|
return results
|
|
219
228
|
|
|
220
229
|
|
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
"""
|
|
2
2
|
Helper functions for A2A Lite.
|
|
3
3
|
"""
|
|
4
|
+
import typing
|
|
4
5
|
from typing import Any, Dict, Type, get_origin, get_args, Union
|
|
5
6
|
import inspect
|
|
6
7
|
|
|
@@ -103,7 +104,10 @@ def extract_function_schemas(func) -> tuple[Dict[str, Any], Dict[str, Any]]:
|
|
|
103
104
|
Tuple of (input_schema, output_schema)
|
|
104
105
|
"""
|
|
105
106
|
sig = inspect.signature(func)
|
|
106
|
-
|
|
107
|
+
try:
|
|
108
|
+
hints = typing.get_type_hints(func)
|
|
109
|
+
except Exception:
|
|
110
|
+
hints = getattr(func, '__annotations__', {})
|
|
107
111
|
|
|
108
112
|
# Build input schema from parameters
|
|
109
113
|
properties = {}
|
|
@@ -9,7 +9,7 @@ Simplifies sending notifications when tasks complete:
|
|
|
9
9
|
"""
|
|
10
10
|
from __future__ import annotations
|
|
11
11
|
|
|
12
|
-
from dataclasses import dataclass
|
|
12
|
+
from dataclasses import dataclass, field
|
|
13
13
|
from typing import Any, Callable, Dict, List, Optional
|
|
14
14
|
import asyncio
|
|
15
15
|
import json
|
|
@@ -19,15 +19,11 @@ import json
|
|
|
19
19
|
class WebhookConfig:
|
|
20
20
|
"""Configuration for a webhook endpoint."""
|
|
21
21
|
url: str
|
|
22
|
-
headers: Dict[str, str] =
|
|
22
|
+
headers: Dict[str, str] = field(default_factory=dict)
|
|
23
23
|
retry_count: int = 3
|
|
24
24
|
retry_delay: float = 1.0
|
|
25
25
|
timeout: float = 30.0
|
|
26
26
|
|
|
27
|
-
def __post_init__(self):
|
|
28
|
-
if self.headers is None:
|
|
29
|
-
self.headers = {}
|
|
30
|
-
|
|
31
27
|
|
|
32
28
|
class WebhookClient:
|
|
33
29
|
"""
|
|
@@ -0,0 +1,378 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Tests for authentication providers.
|
|
3
|
+
"""
|
|
4
|
+
import pytest
|
|
5
|
+
from a2a_lite.auth import (
|
|
6
|
+
AuthRequest,
|
|
7
|
+
AuthResult,
|
|
8
|
+
NoAuth,
|
|
9
|
+
APIKeyAuth,
|
|
10
|
+
BearerAuth,
|
|
11
|
+
CompositeAuth,
|
|
12
|
+
require_auth,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class TestAuthResult:
|
|
17
|
+
def test_success(self):
|
|
18
|
+
result = AuthResult.success(user_id="user-123", scopes={"read", "write"})
|
|
19
|
+
|
|
20
|
+
assert result.authenticated is True
|
|
21
|
+
assert result.user_id == "user-123"
|
|
22
|
+
assert "read" in result.scopes
|
|
23
|
+
|
|
24
|
+
def test_failure(self):
|
|
25
|
+
result = AuthResult.failure("Invalid token")
|
|
26
|
+
|
|
27
|
+
assert result.authenticated is False
|
|
28
|
+
assert result.error == "Invalid token"
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class TestNoAuth:
|
|
32
|
+
@pytest.mark.asyncio
|
|
33
|
+
async def test_always_succeeds(self):
|
|
34
|
+
auth = NoAuth()
|
|
35
|
+
request = AuthRequest(headers={})
|
|
36
|
+
|
|
37
|
+
result = await auth.authenticate(request)
|
|
38
|
+
|
|
39
|
+
assert result.authenticated is True
|
|
40
|
+
assert result.user_id == "anonymous"
|
|
41
|
+
|
|
42
|
+
def test_empty_scheme(self):
|
|
43
|
+
auth = NoAuth()
|
|
44
|
+
assert auth.get_scheme() == {}
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class TestAPIKeyAuth:
|
|
48
|
+
@pytest.mark.asyncio
|
|
49
|
+
async def test_valid_key_in_header(self):
|
|
50
|
+
auth = APIKeyAuth(keys=["secret-key", "another-key"])
|
|
51
|
+
request = AuthRequest(headers={"X-API-Key": "secret-key"})
|
|
52
|
+
|
|
53
|
+
result = await auth.authenticate(request)
|
|
54
|
+
|
|
55
|
+
assert result.authenticated is True
|
|
56
|
+
assert result.user_id is not None
|
|
57
|
+
|
|
58
|
+
@pytest.mark.asyncio
|
|
59
|
+
async def test_invalid_key(self):
|
|
60
|
+
auth = APIKeyAuth(keys=["secret-key"])
|
|
61
|
+
request = AuthRequest(headers={"X-API-Key": "wrong-key"})
|
|
62
|
+
|
|
63
|
+
result = await auth.authenticate(request)
|
|
64
|
+
|
|
65
|
+
assert result.authenticated is False
|
|
66
|
+
assert "Invalid" in result.error
|
|
67
|
+
|
|
68
|
+
@pytest.mark.asyncio
|
|
69
|
+
async def test_missing_key(self):
|
|
70
|
+
auth = APIKeyAuth(keys=["secret-key"])
|
|
71
|
+
request = AuthRequest(headers={})
|
|
72
|
+
|
|
73
|
+
result = await auth.authenticate(request)
|
|
74
|
+
|
|
75
|
+
assert result.authenticated is False
|
|
76
|
+
assert "required" in result.error.lower()
|
|
77
|
+
|
|
78
|
+
@pytest.mark.asyncio
|
|
79
|
+
async def test_custom_header(self):
|
|
80
|
+
auth = APIKeyAuth(keys=["key123"], header="Authorization")
|
|
81
|
+
request = AuthRequest(headers={"Authorization": "key123"})
|
|
82
|
+
|
|
83
|
+
result = await auth.authenticate(request)
|
|
84
|
+
|
|
85
|
+
assert result.authenticated is True
|
|
86
|
+
|
|
87
|
+
@pytest.mark.asyncio
|
|
88
|
+
async def test_query_param(self):
|
|
89
|
+
auth = APIKeyAuth(keys=["key123"], query_param="api_key")
|
|
90
|
+
request = AuthRequest(headers={}, query_params={"api_key": "key123"})
|
|
91
|
+
|
|
92
|
+
result = await auth.authenticate(request)
|
|
93
|
+
|
|
94
|
+
assert result.authenticated is True
|
|
95
|
+
|
|
96
|
+
def test_scheme(self):
|
|
97
|
+
auth = APIKeyAuth(keys=["key"], header="X-API-Key")
|
|
98
|
+
scheme = auth.get_scheme()
|
|
99
|
+
|
|
100
|
+
assert scheme["type"] == "apiKey"
|
|
101
|
+
assert scheme["in"] == "header"
|
|
102
|
+
assert scheme["name"] == "X-API-Key"
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
class TestBearerAuth:
|
|
106
|
+
@pytest.mark.asyncio
|
|
107
|
+
async def test_valid_token(self):
|
|
108
|
+
def validator(token):
|
|
109
|
+
if token == "valid-token":
|
|
110
|
+
return "user-123"
|
|
111
|
+
return None
|
|
112
|
+
|
|
113
|
+
auth = BearerAuth(validator=validator)
|
|
114
|
+
request = AuthRequest(headers={"Authorization": "Bearer valid-token"})
|
|
115
|
+
|
|
116
|
+
result = await auth.authenticate(request)
|
|
117
|
+
|
|
118
|
+
assert result.authenticated is True
|
|
119
|
+
assert result.user_id == "user-123"
|
|
120
|
+
|
|
121
|
+
@pytest.mark.asyncio
|
|
122
|
+
async def test_invalid_token(self):
|
|
123
|
+
def validator(token):
|
|
124
|
+
return None
|
|
125
|
+
|
|
126
|
+
auth = BearerAuth(validator=validator)
|
|
127
|
+
request = AuthRequest(headers={"Authorization": "Bearer invalid"})
|
|
128
|
+
|
|
129
|
+
result = await auth.authenticate(request)
|
|
130
|
+
|
|
131
|
+
assert result.authenticated is False
|
|
132
|
+
|
|
133
|
+
@pytest.mark.asyncio
|
|
134
|
+
async def test_missing_bearer_prefix(self):
|
|
135
|
+
auth = BearerAuth(validator=lambda t: "user")
|
|
136
|
+
request = AuthRequest(headers={"Authorization": "token-without-bearer"})
|
|
137
|
+
|
|
138
|
+
result = await auth.authenticate(request)
|
|
139
|
+
|
|
140
|
+
assert result.authenticated is False
|
|
141
|
+
|
|
142
|
+
def test_scheme(self):
|
|
143
|
+
auth = BearerAuth(validator=lambda t: None)
|
|
144
|
+
scheme = auth.get_scheme()
|
|
145
|
+
|
|
146
|
+
assert scheme["type"] == "http"
|
|
147
|
+
assert scheme["scheme"] == "bearer"
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
class TestCompositeAuth:
|
|
151
|
+
@pytest.mark.asyncio
|
|
152
|
+
async def test_first_match_wins(self):
|
|
153
|
+
auth = CompositeAuth([
|
|
154
|
+
APIKeyAuth(keys=["api-key"]),
|
|
155
|
+
BearerAuth(validator=lambda t: "bearer-user" if t == "token" else None),
|
|
156
|
+
])
|
|
157
|
+
|
|
158
|
+
# API key should work
|
|
159
|
+
request1 = AuthRequest(headers={"X-API-Key": "api-key"})
|
|
160
|
+
result1 = await auth.authenticate(request1)
|
|
161
|
+
assert result1.authenticated is True
|
|
162
|
+
|
|
163
|
+
# Bearer should also work
|
|
164
|
+
request2 = AuthRequest(headers={"Authorization": "Bearer token"})
|
|
165
|
+
result2 = await auth.authenticate(request2)
|
|
166
|
+
assert result2.authenticated is True
|
|
167
|
+
|
|
168
|
+
@pytest.mark.asyncio
|
|
169
|
+
async def test_all_fail(self):
|
|
170
|
+
auth = CompositeAuth([
|
|
171
|
+
APIKeyAuth(keys=["key1"]),
|
|
172
|
+
APIKeyAuth(keys=["key2"]),
|
|
173
|
+
])
|
|
174
|
+
|
|
175
|
+
request = AuthRequest(headers={"X-API-Key": "wrong"})
|
|
176
|
+
result = await auth.authenticate(request)
|
|
177
|
+
|
|
178
|
+
assert result.authenticated is False
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
class TestAuthIntegration:
|
|
182
|
+
"""Test that auth is actually enforced in the HTTP request pipeline."""
|
|
183
|
+
|
|
184
|
+
def _make_agent_with_auth(self):
|
|
185
|
+
from a2a_lite import Agent
|
|
186
|
+
agent = Agent(
|
|
187
|
+
name="SecureAgent",
|
|
188
|
+
description="Auth integration test",
|
|
189
|
+
auth=APIKeyAuth(keys=["valid-key"]),
|
|
190
|
+
)
|
|
191
|
+
|
|
192
|
+
@agent.skill("secret")
|
|
193
|
+
async def secret(data: str) -> str:
|
|
194
|
+
return f"secret: {data}"
|
|
195
|
+
|
|
196
|
+
return agent
|
|
197
|
+
|
|
198
|
+
def test_unauthenticated_request_rejected(self):
|
|
199
|
+
"""Requests without a valid API key should be rejected."""
|
|
200
|
+
from starlette.testclient import TestClient
|
|
201
|
+
import json
|
|
202
|
+
from uuid import uuid4
|
|
203
|
+
|
|
204
|
+
agent = self._make_agent_with_auth()
|
|
205
|
+
app = agent.get_app()
|
|
206
|
+
client = TestClient(app)
|
|
207
|
+
|
|
208
|
+
request_body = {
|
|
209
|
+
"jsonrpc": "2.0",
|
|
210
|
+
"method": "message/send",
|
|
211
|
+
"id": uuid4().hex,
|
|
212
|
+
"params": {
|
|
213
|
+
"message": {
|
|
214
|
+
"role": "user",
|
|
215
|
+
"parts": [{"type": "text", "text": json.dumps({"skill": "secret", "params": {"data": "hello"}})}],
|
|
216
|
+
"messageId": uuid4().hex,
|
|
217
|
+
}
|
|
218
|
+
}
|
|
219
|
+
}
|
|
220
|
+
|
|
221
|
+
response = client.post("/", json=request_body)
|
|
222
|
+
data = response.json()
|
|
223
|
+
|
|
224
|
+
# The response should contain an auth error, not the skill result
|
|
225
|
+
result_text = data.get("result", {}).get("parts", [{}])[0].get("text", "")
|
|
226
|
+
assert "error" in result_text.lower() or "auth" in result_text.lower() or "key" in result_text.lower()
|
|
227
|
+
assert "secret: hello" not in result_text
|
|
228
|
+
|
|
229
|
+
def test_authenticated_request_succeeds(self):
|
|
230
|
+
"""Requests with a valid API key should succeed."""
|
|
231
|
+
from starlette.testclient import TestClient
|
|
232
|
+
import json
|
|
233
|
+
from uuid import uuid4
|
|
234
|
+
|
|
235
|
+
agent = self._make_agent_with_auth()
|
|
236
|
+
app = agent.get_app()
|
|
237
|
+
client = TestClient(app)
|
|
238
|
+
|
|
239
|
+
request_body = {
|
|
240
|
+
"jsonrpc": "2.0",
|
|
241
|
+
"method": "message/send",
|
|
242
|
+
"id": uuid4().hex,
|
|
243
|
+
"params": {
|
|
244
|
+
"message": {
|
|
245
|
+
"role": "user",
|
|
246
|
+
"parts": [{"type": "text", "text": json.dumps({"skill": "secret", "params": {"data": "hello"}})}],
|
|
247
|
+
"messageId": uuid4().hex,
|
|
248
|
+
}
|
|
249
|
+
}
|
|
250
|
+
}
|
|
251
|
+
|
|
252
|
+
response = client.post(
|
|
253
|
+
"/",
|
|
254
|
+
json=request_body,
|
|
255
|
+
headers={"X-API-Key": "valid-key"},
|
|
256
|
+
)
|
|
257
|
+
data = response.json()
|
|
258
|
+
|
|
259
|
+
result_text = data.get("result", {}).get("parts", [{}])[0].get("text", "")
|
|
260
|
+
assert "secret: hello" in result_text
|
|
261
|
+
|
|
262
|
+
def test_wrong_key_rejected(self):
|
|
263
|
+
"""Requests with an invalid API key should be rejected."""
|
|
264
|
+
from starlette.testclient import TestClient
|
|
265
|
+
import json
|
|
266
|
+
from uuid import uuid4
|
|
267
|
+
|
|
268
|
+
agent = self._make_agent_with_auth()
|
|
269
|
+
app = agent.get_app()
|
|
270
|
+
client = TestClient(app)
|
|
271
|
+
|
|
272
|
+
request_body = {
|
|
273
|
+
"jsonrpc": "2.0",
|
|
274
|
+
"method": "message/send",
|
|
275
|
+
"id": uuid4().hex,
|
|
276
|
+
"params": {
|
|
277
|
+
"message": {
|
|
278
|
+
"role": "user",
|
|
279
|
+
"parts": [{"type": "text", "text": json.dumps({"skill": "secret", "params": {"data": "hello"}})}],
|
|
280
|
+
"messageId": uuid4().hex,
|
|
281
|
+
}
|
|
282
|
+
}
|
|
283
|
+
}
|
|
284
|
+
|
|
285
|
+
response = client.post(
|
|
286
|
+
"/",
|
|
287
|
+
json=request_body,
|
|
288
|
+
headers={"X-API-Key": "wrong-key"},
|
|
289
|
+
)
|
|
290
|
+
data = response.json()
|
|
291
|
+
|
|
292
|
+
result_text = data.get("result", {}).get("parts", [{}])[0].get("text", "")
|
|
293
|
+
assert "secret: hello" not in result_text
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
class TestRequireAuth:
    """Test that require_auth decorator receives AuthResult from executor."""

    @staticmethod
    def _call_skill(client, skill, params, headers=None):
        """Send a ``message/send`` request invoking *skill* and return the first text part.

        A fresh JSON-RPC ``id`` and ``messageId`` are generated on every call,
        so each request stands on its own instead of reusing an earlier
        message's identifiers. A missing or null ``result``, or an empty
        ``parts`` list, yields ``""`` instead of raising. That way the
        caller's assertion decides the outcome.
        """
        import json
        from uuid import uuid4

        request_body = {
            "jsonrpc": "2.0",
            "method": "message/send",
            "id": uuid4().hex,
            "params": {
                "message": {
                    "role": "user",
                    "parts": [{"type": "text", "text": json.dumps({"skill": skill, "params": params})}],
                    "messageId": uuid4().hex,
                }
            },
        }
        response = client.post("/", json=request_body, headers=headers)
        result = response.json().get("result") or {}
        parts = result.get("parts") or [{}]
        return parts[0].get("text", "")

    def test_require_auth_receives_auth_result(self):
        """Skills decorated with require_auth should receive the AuthResult."""
        from a2a_lite import Agent
        from starlette.testclient import TestClient

        agent = Agent(
            name="AuthTest",
            description="require_auth test",
            auth=APIKeyAuth(keys=["my-key"]),
        )

        @agent.skill("admin")
        @require_auth(scopes=["admin"])
        async def admin_action(data: str, auth: AuthResult) -> str:
            return f"admin:{auth.user_id}:{data}"

        client = TestClient(agent.get_app())
        params = {"data": "hello"}

        # Without auth — should be rejected at the gate
        result_text = self._call_skill(client, "admin", params)
        assert "admin:" not in result_text

        # With valid auth — require_auth checks scopes (no admin scope, so should fail).
        result_text = self._call_skill(client, "admin", params, headers={"X-API-Key": "my-key"})
        # The skill body always returns "admin:<user>:<data>". It must not have run.
        assert not result_text.startswith("admin:")
        assert "Insufficient permissions" in result_text or "error" in result_text.lower()

    def test_auth_param_injected_without_decorator(self):
        """Skills with auth: AuthResult parameter should receive it directly."""
        from a2a_lite import Agent
        from starlette.testclient import TestClient

        agent = Agent(
            name="AuthTest2",
            description="auth param test",
            auth=APIKeyAuth(keys=["my-key"]),
        )

        @agent.skill("whoami")
        async def whoami(auth: AuthResult) -> str:
            return f"user:{auth.user_id}"

        client = TestClient(agent.get_app())

        result_text = self._call_skill(client, "whoami", {}, headers={"X-API-Key": "my-key"})
        assert result_text.startswith("user:")
        # A valid API key authenticates with a concrete user id. A "None" here
        # would mean an unauthenticated/default AuthResult was injected instead.
        assert result_text != "user:None"
|
|
@@ -1,177 +0,0 @@
|
|
|
1
|
-
"""
|
|
2
|
-
Tests for authentication providers.
|
|
3
|
-
"""
|
|
4
|
-
import pytest
|
|
5
|
-
from a2a_lite.auth import (
|
|
6
|
-
AuthRequest,
|
|
7
|
-
AuthResult,
|
|
8
|
-
NoAuth,
|
|
9
|
-
APIKeyAuth,
|
|
10
|
-
BearerAuth,
|
|
11
|
-
CompositeAuth,
|
|
12
|
-
)
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
class TestAuthResult:
    """AuthResult factory constructors: success() and failure()."""

    def test_success(self):
        outcome = AuthResult.success(user_id="user-123", scopes={"read", "write"})

        assert outcome.authenticated is True
        assert outcome.user_id == "user-123"
        assert "read" in outcome.scopes

    def test_failure(self):
        outcome = AuthResult.failure("Invalid token")

        assert outcome.authenticated is False
        assert outcome.error == "Invalid token"
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
class TestNoAuth:
    """NoAuth accepts every request and advertises no security scheme."""

    @pytest.mark.asyncio
    async def test_always_succeeds(self):
        outcome = await NoAuth().authenticate(AuthRequest(headers={}))

        assert outcome.authenticated is True
        assert outcome.user_id == "anonymous"

    def test_empty_scheme(self):
        assert NoAuth().get_scheme() == {}
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
class TestAPIKeyAuth:
    """APIKeyAuth: key lookup via header (default or custom) or query parameter."""

    @staticmethod
    async def _authenticate(provider, **request_fields):
        """Build an AuthRequest from *request_fields* and run *provider* against it."""
        return await provider.authenticate(AuthRequest(**request_fields))

    @pytest.mark.asyncio
    async def test_valid_key_in_header(self):
        outcome = await self._authenticate(
            APIKeyAuth(keys=["secret-key", "another-key"]),
            headers={"X-API-Key": "secret-key"},
        )

        assert outcome.authenticated is True
        assert outcome.user_id is not None

    @pytest.mark.asyncio
    async def test_invalid_key(self):
        outcome = await self._authenticate(
            APIKeyAuth(keys=["secret-key"]),
            headers={"X-API-Key": "wrong-key"},
        )

        assert outcome.authenticated is False
        assert "Invalid" in outcome.error

    @pytest.mark.asyncio
    async def test_missing_key(self):
        outcome = await self._authenticate(APIKeyAuth(keys=["secret-key"]), headers={})

        assert outcome.authenticated is False
        assert "required" in outcome.error.lower()

    @pytest.mark.asyncio
    async def test_custom_header(self):
        outcome = await self._authenticate(
            APIKeyAuth(keys=["key123"], header="Authorization"),
            headers={"Authorization": "key123"},
        )

        assert outcome.authenticated is True

    @pytest.mark.asyncio
    async def test_query_param(self):
        outcome = await self._authenticate(
            APIKeyAuth(keys=["key123"], query_param="api_key"),
            headers={},
            query_params={"api_key": "key123"},
        )

        assert outcome.authenticated is True

    def test_scheme(self):
        scheme = APIKeyAuth(keys=["key"], header="X-API-Key").get_scheme()

        expected = {"type": "apiKey", "in": "header", "name": "X-API-Key"}
        for field, value in expected.items():
            assert scheme[field] == value
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
class TestBearerAuth:
    """BearerAuth: ``Authorization: Bearer <token>`` checked by a validator callable."""

    @pytest.mark.asyncio
    async def test_valid_token(self):
        provider = BearerAuth(
            validator=lambda token: "user-123" if token == "valid-token" else None
        )
        outcome = await provider.authenticate(
            AuthRequest(headers={"Authorization": "Bearer valid-token"})
        )

        assert outcome.authenticated is True
        assert outcome.user_id == "user-123"

    @pytest.mark.asyncio
    async def test_invalid_token(self):
        provider = BearerAuth(validator=lambda token: None)
        outcome = await provider.authenticate(
            AuthRequest(headers={"Authorization": "Bearer invalid"})
        )

        assert outcome.authenticated is False

    @pytest.mark.asyncio
    async def test_missing_bearer_prefix(self):
        provider = BearerAuth(validator=lambda t: "user")
        outcome = await provider.authenticate(
            AuthRequest(headers={"Authorization": "token-without-bearer"})
        )

        assert outcome.authenticated is False

    def test_scheme(self):
        scheme = BearerAuth(validator=lambda t: None).get_scheme()

        assert scheme["type"] == "http"
        assert scheme["scheme"] == "bearer"
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
class TestCompositeAuth:
    """CompositeAuth: a request passes when any wrapped provider accepts it."""

    @pytest.mark.asyncio
    async def test_first_match_wins(self):
        composite = CompositeAuth([
            APIKeyAuth(keys=["api-key"]),
            BearerAuth(validator=lambda t: "bearer-user" if t == "token" else None),
        ])

        # First an API key, then a bearer token: each must be accepted on its own.
        for headers in ({"X-API-Key": "api-key"}, {"Authorization": "Bearer token"}):
            outcome = await composite.authenticate(AuthRequest(headers=headers))
            assert outcome.authenticated is True

    @pytest.mark.asyncio
    async def test_all_fail(self):
        composite = CompositeAuth([APIKeyAuth(keys=[key]) for key in ("key1", "key2")])

        outcome = await composite.authenticate(AuthRequest(headers={"X-API-Key": "wrong"}))

        assert outcome.authenticated is False
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|