a2a-lite 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.
a2a_lite/__init__.py CHANGED
@@ -13,8 +13,8 @@ SIMPLE (8 lines):
13
13
  agent.run()
14
14
 
15
15
  TEST IT (3 lines):
16
- from a2a_lite import TestClient
17
- client = TestClient(agent)
16
+ from a2a_lite import AgentTestClient
17
+ client = AgentTestClient(agent)
18
18
  assert client.call("greet", name="World") == "Hello, World!"
19
19
 
20
20
  WITH PYDANTIC:
@@ -64,7 +64,7 @@ WITH FILES (opt-in):
64
64
  from .agent import Agent
65
65
  from .decorators import SkillDefinition
66
66
  from .discovery import AgentDiscovery, DiscoveredAgent
67
- from .testing import TestClient, AsyncTestClient
67
+ from .testing import AgentTestClient, AsyncAgentTestClient, TestResult
68
68
 
69
69
  # Middleware - always available
70
70
  from .middleware import (
@@ -102,7 +102,7 @@ from .auth import (
102
102
  require_auth,
103
103
  )
104
104
 
105
- __version__ = "0.1.0"
105
+ __version__ = "0.2.0"
106
106
 
107
107
  __all__ = [
108
108
  # Core
@@ -111,8 +111,9 @@ __all__ = [
111
111
  "AgentDiscovery",
112
112
  "DiscoveredAgent",
113
113
  # Testing
114
- "TestClient",
115
- "AsyncTestClient",
114
+ "AgentTestClient",
115
+ "AsyncAgentTestClient",
116
+ "TestResult",
116
117
  # Middleware
117
118
  "MiddlewareContext",
118
119
  "MiddlewareChain",
a2a_lite/agent.py CHANGED
@@ -7,9 +7,12 @@ from __future__ import annotations
7
7
 
8
8
  import asyncio
9
9
  import inspect
10
+ import logging
10
11
  from typing import Any, Callable, Optional, Dict, List, Type, Union, get_origin, get_args
11
12
  from dataclasses import dataclass, field
12
13
 
14
+ logger = logging.getLogger(__name__)
15
+
13
16
  import uvicorn
14
17
 
15
18
  from a2a.server.apps import A2AStarletteApplication
@@ -23,7 +26,7 @@ from a2a.types import (
23
26
 
24
27
  from .executor import LiteAgentExecutor
25
28
  from .decorators import SkillDefinition
26
- from .utils import type_to_json_schema, extract_function_schemas
29
+ from .utils import type_to_json_schema, extract_function_schemas, _is_or_subclass
27
30
  from .middleware import MiddlewareChain, MiddlewareContext
28
31
  from .streaming import is_generator_function
29
32
  from .webhooks import NotificationManager, WebhookClient
@@ -85,6 +88,8 @@ class Agent:
85
88
  # Optional enterprise features
86
89
  auth: Optional[Any] = None # AuthProvider
87
90
  task_store: Optional[Any] = None # TaskStore or "memory"
91
+ cors_origins: Optional[List[str]] = None
92
+ production: bool = False
88
93
 
89
94
  def __post_init__(self):
90
95
  # Internal state
@@ -161,16 +166,30 @@ class Agent:
161
166
  if is_streaming:
162
167
  self._has_streaming = True
163
168
 
164
- # Detect special parameter types
165
- hints = getattr(func, '__annotations__', {})
166
- needs_task_context = any(
167
- str(h).endswith("TaskContext") or "TaskContext" in str(h)
168
- for h in hints.values()
169
- )
170
- needs_interaction = any(
171
- str(h).endswith("InteractionContext") or "InteractionContext" in str(h)
172
- for h in hints.values()
173
- )
169
+ # Detect special parameter types using proper type introspection
170
+ import typing
171
+ from .tasks import TaskContext as _TaskContext
172
+ from .human_loop import InteractionContext as _InteractionContext
173
+
174
+ needs_task_context = False
175
+ needs_interaction = False
176
+ task_context_param: str | None = None
177
+ interaction_param: str | None = None
178
+
179
+ try:
180
+ resolved_hints = typing.get_type_hints(func)
181
+ except Exception:
182
+ resolved_hints = getattr(func, '__annotations__', {})
183
+
184
+ for param_name, hint in resolved_hints.items():
185
+ if param_name == 'return':
186
+ continue
187
+ if _is_or_subclass(hint, _TaskContext):
188
+ needs_task_context = True
189
+ task_context_param = param_name
190
+ elif _is_or_subclass(hint, _InteractionContext):
191
+ needs_interaction = True
192
+ interaction_param = param_name
174
193
 
175
194
  # Extract schemas
176
195
  input_schema, output_schema = extract_function_schemas(func)
@@ -186,6 +205,8 @@ class Agent:
186
205
  is_streaming=is_streaming,
187
206
  needs_task_context=needs_task_context,
188
207
  needs_interaction=needs_interaction,
208
+ task_context_param=task_context_param,
209
+ interaction_param=interaction_param,
189
210
  )
190
211
 
191
212
  self._skills[skill_name] = skill_def
@@ -375,10 +396,34 @@ class Agent:
375
396
  )
376
397
  console.print(f"[dim]mDNS discovery enabled for {self.name}[/]")
377
398
 
399
+ # Production mode warning
400
+ if self.production:
401
+ url_str = self.url or f"http://{display_host}:{port}"
402
+ if not url_str.startswith("https://"):
403
+ logger.warning(
404
+ "Running in production mode over HTTP. "
405
+ "Consider using HTTPS for secure communication."
406
+ )
407
+
408
+ # Build the ASGI app
409
+ app = app_builder.build()
410
+
411
+ # Add CORS middleware if configured
412
+ if self.cors_origins is not None:
413
+ from starlette.middleware.cors import CORSMiddleware
414
+ from starlette.middleware import Middleware as StarletteMiddleware
415
+ # Wrap the existing app with CORS
416
+ app.add_middleware(
417
+ CORSMiddleware,
418
+ allow_origins=self.cors_origins,
419
+ allow_methods=["*"],
420
+ allow_headers=["*"],
421
+ )
422
+
378
423
  # Start server
379
424
  try:
380
425
  uvicorn.run(
381
- app_builder.build(),
426
+ app,
382
427
  host=host,
383
428
  port=port,
384
429
  log_level=log_level,
@@ -450,4 +495,16 @@ class Agent:
450
495
  http_handler=request_handler,
451
496
  )
452
497
 
453
- return app_builder.build()
498
+ app = app_builder.build()
499
+
500
+ # Add CORS middleware if configured
501
+ if self.cors_origins is not None:
502
+ from starlette.middleware.cors import CORSMiddleware
503
+ app.add_middleware(
504
+ CORSMiddleware,
505
+ allow_origins=self.cors_origins,
506
+ allow_methods=["*"],
507
+ allow_headers=["*"],
508
+ )
509
+
510
+ return app
a2a_lite/auth.py CHANGED
@@ -116,10 +116,17 @@ class APIKeyAuth(AuthProvider):
116
116
  header: str = "X-API-Key",
117
117
  query_param: Optional[str] = None,
118
118
  ):
119
- self.keys = set(keys)
119
+ # Store only hashes of keys for security
120
+ self._key_hashes = {
121
+ hashlib.sha256(k.encode()).hexdigest() for k in keys
122
+ }
120
123
  self.header = header
121
124
  self.query_param = query_param
122
125
 
126
+ def _hash_key(self, key: str) -> str:
127
+ """Hash a key using SHA-256."""
128
+ return hashlib.sha256(key.encode()).hexdigest()
129
+
123
130
  async def authenticate(self, request: AuthRequest) -> AuthResult:
124
131
  # Check header
125
132
  key = request.headers.get(self.header)
@@ -131,11 +138,12 @@ class APIKeyAuth(AuthProvider):
131
138
  if not key:
132
139
  return AuthResult.failure("API key required")
133
140
 
134
- if key not in self.keys:
141
+ key_hash = self._hash_key(key)
142
+ if key_hash not in self._key_hashes:
135
143
  return AuthResult.failure("Invalid API key")
136
144
 
137
- # Use hash of key as user ID
138
- user_id = hashlib.sha256(key.encode()).hexdigest()[:16]
145
+ # Use hash prefix as user ID
146
+ user_id = key_hash[:16]
139
147
  return AuthResult.success(user_id=user_id)
140
148
 
141
149
  def get_scheme(self) -> Dict[str, Any]:
a2a_lite/cli.py CHANGED
@@ -75,7 +75,7 @@ version = "0.1.0"
75
75
  description = "A2A Agent: {name}"
76
76
  requires-python = ">=3.10"
77
77
  dependencies = [
78
- "a2a-lite>=0.1.0",
78
+ "a2a-lite>=0.2.0",
79
79
  ]
80
80
  '''
81
81
  (project_path / "pyproject.toml").write_text(pyproject)
a2a_lite/decorators.py CHANGED
@@ -2,7 +2,7 @@
2
2
  Decorator definitions and skill metadata.
3
3
  """
4
4
  from dataclasses import dataclass, field
5
- from typing import Any, Callable, Dict, List
5
+ from typing import Any, Callable, Dict, List, Optional
6
6
 
7
7
 
8
8
  @dataclass
@@ -18,6 +18,8 @@ class SkillDefinition:
18
18
  is_streaming: bool = False
19
19
  needs_task_context: bool = False
20
20
  needs_interaction: bool = False
21
+ task_context_param: Optional[str] = None
22
+ interaction_param: Optional[str] = None
21
23
 
22
24
  def to_dict(self) -> Dict[str, Any]:
23
25
  """Convert to dictionary for serialization."""
a2a_lite/discovery.py CHANGED
@@ -4,10 +4,13 @@ mDNS-based local agent discovery using Zeroconf.
4
4
  from __future__ import annotations
5
5
 
6
6
  import asyncio
7
+ import logging
7
8
  import socket
8
9
  from typing import Dict, List, Optional
9
10
  from dataclasses import dataclass
10
11
 
12
+ logger = logging.getLogger(__name__)
13
+
11
14
  from zeroconf import ServiceBrowser, ServiceListener, Zeroconf, ServiceInfo
12
15
  from zeroconf.asyncio import AsyncZeroconf, AsyncServiceBrowser
13
16
 
@@ -113,6 +116,7 @@ class AgentDiscovery:
113
116
  s.close()
114
117
  return ip
115
118
  except Exception:
119
+ logger.debug("Could not detect local IP, falling back to 127.0.0.1")
116
120
  return "127.0.0.1"
117
121
 
118
122
 
a2a_lite/executor.py CHANGED
@@ -6,14 +6,18 @@ from __future__ import annotations
6
6
  import asyncio
7
7
  import inspect
8
8
  import json
9
+ import logging
9
10
  from typing import Any, Callable, Dict, List, Optional
10
11
 
12
+ logger = logging.getLogger(__name__)
13
+
11
14
  from a2a.server.agent_execution import AgentExecutor, RequestContext
12
15
  from a2a.server.events import EventQueue
13
16
 
14
17
  from .decorators import SkillDefinition
15
18
  from .middleware import MiddlewareChain, MiddlewareContext
16
19
  from .streaming import is_generator_function, stream_generator
20
+ from .utils import _is_or_subclass
17
21
 
18
22
 
19
23
  class LiteAgentExecutor(AgentExecutor):
@@ -100,7 +104,7 @@ class LiteAgentExecutor(AgentExecutor):
100
104
  else:
101
105
  hook(skill_name, result, ctx)
102
106
  except Exception:
103
- pass
107
+ logger.warning("Completion hook error for skill '%s'", skill_name, exc_info=True)
104
108
 
105
109
  except Exception as e:
106
110
  await self._handle_error(e, event_queue)
@@ -132,17 +136,19 @@ class LiteAgentExecutor(AgentExecutor):
132
136
  # Inject special contexts if needed
133
137
  if skill_def.needs_task_context and self.task_store:
134
138
  from .tasks import TaskContext, Task, TaskStatus, TaskState
135
- task = self.task_store.create(skill_name, params)
139
+ task = await self.task_store.create(skill_name, params)
136
140
  # Only pass event_queue for streaming skills (status updates go via SSE)
137
141
  eq = event_queue if skill_def.is_streaming else None
138
142
  task_ctx = TaskContext(task, eq)
139
- params["task"] = task_ctx
143
+ param_name = skill_def.task_context_param or "task"
144
+ params[param_name] = task_ctx
140
145
 
141
146
  if skill_def.needs_interaction:
142
147
  from .human_loop import InteractionContext
143
148
  task_id = metadata.get("task_id", "unknown")
144
149
  interaction_ctx = InteractionContext(task_id, event_queue)
145
- params["ctx"] = interaction_ctx
150
+ param_name = skill_def.interaction_param or "ctx"
151
+ params[param_name] = interaction_ctx
146
152
 
147
153
  # Call the handler
148
154
  handler = skill_def.handler
@@ -172,13 +178,14 @@ class LiteAgentExecutor(AgentExecutor):
172
178
  converted[param_name] = value
173
179
  continue
174
180
 
175
- type_name = str(param_type)
176
-
177
181
  # Skip special context types
178
- if "TaskContext" in type_name or "InteractionContext" in type_name:
182
+ from .tasks import TaskContext as _TaskContext
183
+ from .human_loop import InteractionContext as _InteractionContext
184
+ if _is_or_subclass(param_type, _TaskContext) or _is_or_subclass(param_type, _InteractionContext):
179
185
  continue
180
186
 
181
187
  # Convert FilePart
188
+ type_name = str(param_type)
182
189
  if "FilePart" in type_name:
183
190
  from .parts import FilePart
184
191
  if isinstance(value, dict):
@@ -234,7 +241,7 @@ class LiteAgentExecutor(AgentExecutor):
234
241
  if isinstance(data, dict) and 'skill' in data:
235
242
  return data['skill'], data.get('params', {})
236
243
  except json.JSONDecodeError:
237
- pass
244
+ logger.debug("Message is not JSON, treating as plain text")
238
245
 
239
246
  return None, {"message": message}
240
247
 
a2a_lite/tasks.py CHANGED
@@ -27,12 +27,15 @@ Example (with task tracking - opt-in):
27
27
  """
28
28
  from __future__ import annotations
29
29
 
30
+ import asyncio
31
+ import logging
30
32
  from dataclasses import dataclass, field
31
33
  from datetime import datetime
32
34
  from enum import Enum
33
35
  from typing import Any, Callable, Dict, List, Optional
34
36
  from uuid import uuid4
35
- import asyncio
37
+
38
+ logger = logging.getLogger(__name__)
36
39
 
37
40
 
38
41
  class TaskState(str, Enum):
@@ -145,7 +148,7 @@ class TaskContext:
145
148
  else:
146
149
  callback(self._task.status)
147
150
  except Exception:
148
- pass
151
+ logger.warning("Status callback error for task '%s'", self._task.id, exc_info=True)
149
152
 
150
153
  # Send SSE event if streaming
151
154
  if self._event_queue:
@@ -170,52 +173,58 @@ class TaskContext:
170
173
 
171
174
  class TaskStore:
172
175
  """
173
- In-memory task store.
176
+ In-memory task store with async locking for thread safety.
174
177
 
175
178
  For production, extend this with Redis/DB backend.
176
179
  """
177
180
 
178
181
  def __init__(self):
179
182
  self._tasks: Dict[str, Task] = {}
183
+ self._lock = asyncio.Lock()
180
184
 
181
- def create(self, skill: str, params: Dict[str, Any]) -> Task:
185
+ async def create(self, skill: str, params: Dict[str, Any]) -> Task:
182
186
  """Create a new task."""
183
- task = Task(
184
- id=uuid4().hex,
185
- skill=skill,
186
- params=params,
187
- status=TaskStatus(state=TaskState.SUBMITTED),
188
- )
189
- self._tasks[task.id] = task
190
- return task
191
-
192
- def get(self, task_id: str) -> Optional[Task]:
187
+ async with self._lock:
188
+ task = Task(
189
+ id=uuid4().hex,
190
+ skill=skill,
191
+ params=params,
192
+ status=TaskStatus(state=TaskState.SUBMITTED),
193
+ )
194
+ self._tasks[task.id] = task
195
+ return task
196
+
197
+ async def get(self, task_id: str) -> Optional[Task]:
193
198
  """Get task by ID."""
194
- return self._tasks.get(task_id)
199
+ async with self._lock:
200
+ return self._tasks.get(task_id)
195
201
 
196
- def update(self, task: Task) -> None:
202
+ async def update(self, task: Task) -> None:
197
203
  """Update task in store."""
198
- self._tasks[task.id] = task
204
+ async with self._lock:
205
+ self._tasks[task.id] = task
199
206
 
200
- def list(
207
+ async def list(
201
208
  self,
202
209
  state: Optional[TaskState] = None,
203
210
  skill: Optional[str] = None,
204
211
  limit: int = 100,
205
212
  ) -> List[Task]:
206
213
  """List tasks with optional filters."""
207
- tasks = list(self._tasks.values())
214
+ async with self._lock:
215
+ tasks = list(self._tasks.values())
208
216
 
209
- if state:
210
- tasks = [t for t in tasks if t.status.state == state]
211
- if skill:
212
- tasks = [t for t in tasks if t.skill == skill]
217
+ if state:
218
+ tasks = [t for t in tasks if t.status.state == state]
219
+ if skill:
220
+ tasks = [t for t in tasks if t.skill == skill]
213
221
 
214
- return sorted(tasks, key=lambda t: t.created_at, reverse=True)[:limit]
222
+ return sorted(tasks, key=lambda t: t.created_at, reverse=True)[:limit]
215
223
 
216
- def delete(self, task_id: str) -> bool:
224
+ async def delete(self, task_id: str) -> bool:
217
225
  """Delete a task."""
218
- if task_id in self._tasks:
219
- del self._tasks[task_id]
220
- return True
221
- return False
226
+ async with self._lock:
227
+ if task_id in self._tasks:
228
+ del self._tasks[task_id]
229
+ return True
230
+ return False
a2a_lite/testing.py CHANGED
@@ -3,21 +3,61 @@ Testing utilities for A2A Lite agents.
3
3
 
4
4
  Makes testing agents as simple as:
5
5
 
6
- from a2a_lite.testing import TestClient
6
+ from a2a_lite.testing import AgentTestClient
7
7
 
8
8
  def test_my_agent():
9
- client = TestClient(agent)
9
+ client = AgentTestClient(agent)
10
10
  result = client.call("greet", name="World")
11
11
  assert result == "Hello, World!"
12
12
  """
13
13
  from __future__ import annotations
14
14
 
15
15
  import json
16
+ from dataclasses import dataclass
16
17
  from typing import Any, Dict, Optional
17
18
  from uuid import uuid4
18
19
 
19
20
 
20
- class TestClient:
21
+ @dataclass
22
+ class TestResult:
23
+ """
24
+ Structured result from a test client call.
25
+
26
+ Provides multiple ways to access the result:
27
+ - .data — parsed Python object (dict, list, int, str, etc.)
28
+ - .text — raw text string
29
+ - .json() — parse text as JSON (raises on invalid JSON)
30
+ - .raw_response — the full A2A response dict
31
+ """
32
+ _data: Any
33
+ _text: str
34
+ raw_response: Dict[str, Any]
35
+
36
+ @property
37
+ def data(self) -> Any:
38
+ """The parsed result value."""
39
+ return self._data
40
+
41
+ @property
42
+ def text(self) -> str:
43
+ """The raw text representation."""
44
+ return self._text
45
+
46
+ def json(self) -> Any:
47
+ """Parse the text as JSON."""
48
+ return json.loads(self._text)
49
+
50
+ def __eq__(self, other: Any) -> bool:
51
+ """Allow direct comparison with the data value for convenience."""
52
+ if isinstance(other, TestResult):
53
+ return self._data == other._data
54
+ return self._data == other
55
+
56
+ def __repr__(self) -> str:
57
+ return f"TestResult(data={self._data!r})"
58
+
59
+
60
+ class AgentTestClient:
21
61
  """
22
62
  Simple test client for A2A Lite agents.
23
63
 
@@ -29,7 +69,7 @@ class TestClient:
29
69
  return a + b
30
70
 
31
71
  # In your test
32
- client = TestClient(agent)
72
+ client = AgentTestClient(agent)
33
73
  assert client.call("add", a=2, b=3) == 5
34
74
  """
35
75
 
@@ -89,7 +129,7 @@ class TestClient:
89
129
  # Extract the actual result from A2A response
90
130
  return self._extract_result(data)
91
131
 
92
- def _extract_result(self, response: Dict) -> Any:
132
+ def _extract_result(self, response: Dict) -> TestResult:
93
133
  """Extract the skill result from A2A response."""
94
134
  if "error" in response:
95
135
  raise TestClientError(response["error"])
@@ -103,11 +143,12 @@ class TestClient:
103
143
  text = part.get("text", "")
104
144
  # Try to parse as JSON
105
145
  try:
106
- return json.loads(text)
146
+ data = json.loads(text)
107
147
  except json.JSONDecodeError:
108
- return text
148
+ data = text
149
+ return TestResult(_data=data, _text=text, raw_response=response)
109
150
 
110
- return result
151
+ return TestResult(_data=result, _text=json.dumps(result), raw_response=response)
111
152
 
112
153
  def get_agent_card(self) -> Dict[str, Any]:
113
154
  """
@@ -183,13 +224,13 @@ class TestClientError(Exception):
183
224
 
184
225
 
185
226
  # Async version for async tests
186
- class AsyncTestClient:
227
+ class AsyncAgentTestClient:
187
228
  """
188
229
  Async test client for A2A Lite agents.
189
230
 
190
231
  Example:
191
232
  async def test_my_agent():
192
- client = AsyncTestClient(agent)
233
+ client = AsyncAgentTestClient(agent)
193
234
  result = await client.call("greet", name="World")
194
235
  assert result == "Hello, World!"
195
236
  """
@@ -243,7 +284,7 @@ class AsyncTestClient:
243
284
 
244
285
  return self._extract_result(data)
245
286
 
246
- def _extract_result(self, response: Dict) -> Any:
287
+ def _extract_result(self, response: Dict) -> TestResult:
247
288
  """Extract the skill result from A2A response."""
248
289
  if "error" in response:
249
290
  raise TestClientError(response["error"])
@@ -255,11 +296,12 @@ class AsyncTestClient:
255
296
  if part.get("kind") == "text" or part.get("type") == "text":
256
297
  text = part.get("text", "")
257
298
  try:
258
- return json.loads(text)
299
+ data = json.loads(text)
259
300
  except json.JSONDecodeError:
260
- return text
301
+ data = text
302
+ return TestResult(_data=data, _text=text, raw_response=response)
261
303
 
262
- return result
304
+ return TestResult(_data=result, _text=json.dumps(result), raw_response=response)
263
305
 
264
306
  async def close(self):
265
307
  """Close the client."""
a2a_lite/utils.py CHANGED
@@ -5,6 +5,22 @@ from typing import Any, Dict, Type, get_origin, get_args, Union
5
5
  import inspect
6
6
 
7
7
 
8
+ def _is_or_subclass(hint: Any, target_class: Type) -> bool:
9
+ """
10
+ Check if a type hint is, or is a subclass of, the target class.
11
+
12
+ Works with raw classes and string annotations.
13
+ """
14
+ try:
15
+ if hint is target_class:
16
+ return True
17
+ if isinstance(hint, type) and issubclass(hint, target_class):
18
+ return True
19
+ except TypeError:
20
+ pass
21
+ return False
22
+
23
+
8
24
  def type_to_json_schema(python_type: Type) -> Dict[str, Any]:
9
25
  """
10
26
  Convert Python type to JSON Schema.