jaf-py 2.5.10__py3-none-any.whl → 2.5.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (92) hide show
  1. jaf/__init__.py +154 -57
  2. jaf/a2a/__init__.py +42 -21
  3. jaf/a2a/agent.py +79 -126
  4. jaf/a2a/agent_card.py +87 -78
  5. jaf/a2a/client.py +30 -66
  6. jaf/a2a/examples/client_example.py +12 -12
  7. jaf/a2a/examples/integration_example.py +38 -47
  8. jaf/a2a/examples/server_example.py +56 -53
  9. jaf/a2a/memory/__init__.py +0 -4
  10. jaf/a2a/memory/cleanup.py +28 -21
  11. jaf/a2a/memory/factory.py +155 -133
  12. jaf/a2a/memory/providers/composite.py +21 -26
  13. jaf/a2a/memory/providers/in_memory.py +89 -83
  14. jaf/a2a/memory/providers/postgres.py +117 -115
  15. jaf/a2a/memory/providers/redis.py +128 -121
  16. jaf/a2a/memory/serialization.py +77 -87
  17. jaf/a2a/memory/tests/run_comprehensive_tests.py +112 -83
  18. jaf/a2a/memory/tests/test_cleanup.py +211 -94
  19. jaf/a2a/memory/tests/test_serialization.py +73 -68
  20. jaf/a2a/memory/tests/test_stress_concurrency.py +186 -133
  21. jaf/a2a/memory/tests/test_task_lifecycle.py +138 -120
  22. jaf/a2a/memory/types.py +91 -53
  23. jaf/a2a/protocol.py +95 -125
  24. jaf/a2a/server.py +90 -118
  25. jaf/a2a/standalone_client.py +30 -43
  26. jaf/a2a/tests/__init__.py +16 -33
  27. jaf/a2a/tests/run_tests.py +17 -53
  28. jaf/a2a/tests/test_agent.py +40 -140
  29. jaf/a2a/tests/test_client.py +54 -117
  30. jaf/a2a/tests/test_integration.py +28 -82
  31. jaf/a2a/tests/test_protocol.py +54 -139
  32. jaf/a2a/tests/test_types.py +50 -136
  33. jaf/a2a/types.py +58 -34
  34. jaf/cli.py +21 -41
  35. jaf/core/__init__.py +7 -1
  36. jaf/core/agent_tool.py +93 -72
  37. jaf/core/analytics.py +257 -207
  38. jaf/core/checkpoint.py +223 -0
  39. jaf/core/composition.py +249 -235
  40. jaf/core/engine.py +817 -519
  41. jaf/core/errors.py +55 -42
  42. jaf/core/guardrails.py +276 -202
  43. jaf/core/handoff.py +47 -31
  44. jaf/core/parallel_agents.py +69 -75
  45. jaf/core/performance.py +75 -73
  46. jaf/core/proxy.py +43 -44
  47. jaf/core/proxy_helpers.py +24 -27
  48. jaf/core/regeneration.py +220 -129
  49. jaf/core/state.py +68 -66
  50. jaf/core/streaming.py +115 -108
  51. jaf/core/tool_results.py +111 -101
  52. jaf/core/tools.py +114 -116
  53. jaf/core/tracing.py +269 -210
  54. jaf/core/types.py +371 -151
  55. jaf/core/workflows.py +209 -168
  56. jaf/exceptions.py +46 -38
  57. jaf/memory/__init__.py +1 -6
  58. jaf/memory/approval_storage.py +54 -77
  59. jaf/memory/factory.py +4 -4
  60. jaf/memory/providers/in_memory.py +216 -180
  61. jaf/memory/providers/postgres.py +216 -146
  62. jaf/memory/providers/redis.py +173 -116
  63. jaf/memory/types.py +70 -51
  64. jaf/memory/utils.py +36 -34
  65. jaf/plugins/__init__.py +12 -12
  66. jaf/plugins/base.py +105 -96
  67. jaf/policies/__init__.py +0 -1
  68. jaf/policies/handoff.py +37 -46
  69. jaf/policies/validation.py +76 -52
  70. jaf/providers/__init__.py +6 -3
  71. jaf/providers/mcp.py +97 -51
  72. jaf/providers/model.py +360 -279
  73. jaf/server/__init__.py +1 -1
  74. jaf/server/main.py +7 -11
  75. jaf/server/server.py +514 -359
  76. jaf/server/types.py +208 -52
  77. jaf/utils/__init__.py +17 -18
  78. jaf/utils/attachments.py +111 -116
  79. jaf/utils/document_processor.py +175 -174
  80. jaf/visualization/__init__.py +1 -1
  81. jaf/visualization/example.py +111 -110
  82. jaf/visualization/functional_core.py +46 -71
  83. jaf/visualization/graphviz.py +154 -189
  84. jaf/visualization/imperative_shell.py +7 -16
  85. jaf/visualization/types.py +8 -4
  86. {jaf_py-2.5.10.dist-info → jaf_py-2.5.11.dist-info}/METADATA +2 -2
  87. jaf_py-2.5.11.dist-info/RECORD +97 -0
  88. jaf_py-2.5.10.dist-info/RECORD +0 -96
  89. {jaf_py-2.5.10.dist-info → jaf_py-2.5.11.dist-info}/WHEEL +0 -0
  90. {jaf_py-2.5.10.dist-info → jaf_py-2.5.11.dist-info}/entry_points.txt +0 -0
  91. {jaf_py-2.5.10.dist-info → jaf_py-2.5.11.dist-info}/licenses/LICENSE +0 -0
  92. {jaf_py-2.5.10.dist-info → jaf_py-2.5.11.dist-info}/top_level.txt +0 -0
jaf/a2a/protocol.py CHANGED
@@ -12,29 +12,23 @@ from .types import A2AAgent, A2AErrorCodes
12
12
  def validate_jsonrpc_request(request: Dict[str, Any]) -> bool:
13
13
  """Pure function to validate JSON-RPC request"""
14
14
  return (
15
- isinstance(request, dict) and
16
- request.get("jsonrpc") == "2.0" and
17
- "id" in request and
18
- isinstance(request.get("method"), str)
15
+ isinstance(request, dict)
16
+ and request.get("jsonrpc") == "2.0"
17
+ and "id" in request
18
+ and isinstance(request.get("method"), str)
19
19
  )
20
20
 
21
21
 
22
22
  def create_jsonrpc_success_response_dict(id: Union[str, int, None], result: Any) -> Dict[str, Any]:
23
23
  """Pure function to create JSON-RPC success response as dict"""
24
- return {
25
- "jsonrpc": "2.0",
26
- "id": id,
27
- "result": result
28
- }
24
+ return {"jsonrpc": "2.0", "id": id, "result": result}
29
25
 
30
26
 
31
- def create_jsonrpc_error_response_dict(id: Union[str, int, None], error: Dict[str, Any]) -> Dict[str, Any]:
27
+ def create_jsonrpc_error_response_dict(
28
+ id: Union[str, int, None], error: Dict[str, Any]
29
+ ) -> Dict[str, Any]:
32
30
  """Pure function to create JSON-RPC error response as dict"""
33
- return {
34
- "jsonrpc": "2.0",
35
- "id": id,
36
- "error": error
37
- }
31
+ return {"jsonrpc": "2.0", "id": id, "error": error}
38
32
 
39
33
 
40
34
  def map_error_to_a2a_error(error: Exception) -> Dict[str, Any]:
@@ -43,13 +37,10 @@ def map_error_to_a2a_error(error: Exception) -> Dict[str, Any]:
43
37
  return {
44
38
  "code": A2AErrorCodes.INTERNAL_ERROR.value,
45
39
  "message": str(error),
46
- "data": {"type": type(error).__name__}
40
+ "data": {"type": type(error).__name__},
47
41
  }
48
42
 
49
- return {
50
- "code": A2AErrorCodes.INTERNAL_ERROR.value,
51
- "message": "Unknown error occurred"
52
- }
43
+ return {"code": A2AErrorCodes.INTERNAL_ERROR.value, "message": "Unknown error occurred"}
53
44
 
54
45
 
55
46
  def validate_send_message_request(request: Dict[str, Any]) -> Dict[str, Any]:
@@ -61,8 +52,8 @@ def validate_send_message_request(request: Dict[str, Any]) -> Dict[str, Any]:
61
52
  "is_valid": False,
62
53
  "error": {
63
54
  "code": A2AErrorCodes.INVALID_REQUEST.value,
64
- "message": "Invalid JSON-RPC request"
65
- }
55
+ "message": "Invalid JSON-RPC request",
56
+ },
66
57
  }
67
58
 
68
59
  if request.get("method") not in ["message/send", "message/stream"]:
@@ -70,8 +61,8 @@ def validate_send_message_request(request: Dict[str, Any]) -> Dict[str, Any]:
70
61
  "is_valid": False,
71
62
  "error": {
72
63
  "code": A2AErrorCodes.METHOD_NOT_FOUND.value,
73
- "message": f"Method {request.get('method')} not supported"
74
- }
64
+ "message": f"Method {request.get('method')} not supported",
65
+ },
75
66
  }
76
67
 
77
68
  params = request.get("params", {})
@@ -81,8 +72,8 @@ def validate_send_message_request(request: Dict[str, Any]) -> Dict[str, Any]:
81
72
  "error": {
82
73
  "code": A2AErrorCodes.INVALID_PARAMS.value,
83
74
  "message": "Invalid params format - must be an object",
84
- "data": {"expected": "object", "received": type(params).__name__}
85
- }
75
+ "data": {"expected": "object", "received": type(params).__name__},
76
+ },
86
77
  }
87
78
 
88
79
  message = params.get("message")
@@ -92,8 +83,11 @@ def validate_send_message_request(request: Dict[str, Any]) -> Dict[str, Any]:
92
83
  "error": {
93
84
  "code": A2AErrorCodes.INVALID_PARAMS.value,
94
85
  "message": "Invalid message format - message must be an object",
95
- "data": {"expected": "object", "received": type(message).__name__ if message is not None else "null"}
96
- }
86
+ "data": {
87
+ "expected": "object",
88
+ "received": type(message).__name__ if message is not None else "null",
89
+ },
90
+ },
97
91
  }
98
92
 
99
93
  # Validate required message fields
@@ -105,8 +99,8 @@ def validate_send_message_request(request: Dict[str, Any]) -> Dict[str, Any]:
105
99
  "error": {
106
100
  "code": A2AErrorCodes.INVALID_PARAMS.value,
107
101
  "message": f"Missing required message fields: {', '.join(missing_fields)}",
108
- "data": {"missing_fields": missing_fields}
109
- }
102
+ "data": {"missing_fields": missing_fields},
103
+ },
110
104
  }
111
105
 
112
106
  # Validate message structure
@@ -116,8 +110,8 @@ def validate_send_message_request(request: Dict[str, Any]) -> Dict[str, Any]:
116
110
  "error": {
117
111
  "code": A2AErrorCodes.INVALID_PARAMS.value,
118
112
  "message": "Message kind must be 'message'",
119
- "data": {"expected": "message", "received": message.get("kind")}
120
- }
113
+ "data": {"expected": "message", "received": message.get("kind")},
114
+ },
121
115
  }
122
116
 
123
117
  if message.get("role") not in ["user", "agent"]:
@@ -126,8 +120,8 @@ def validate_send_message_request(request: Dict[str, Any]) -> Dict[str, Any]:
126
120
  "error": {
127
121
  "code": A2AErrorCodes.INVALID_PARAMS.value,
128
122
  "message": "Message role must be 'user' or 'agent'",
129
- "data": {"expected": ["user", "agent"], "received": message.get("role")}
130
- }
123
+ "data": {"expected": ["user", "agent"], "received": message.get("role")},
124
+ },
131
125
  }
132
126
 
133
127
  if not isinstance(message.get("parts"), list):
@@ -136,8 +130,8 @@ def validate_send_message_request(request: Dict[str, Any]) -> Dict[str, Any]:
136
130
  "error": {
137
131
  "code": A2AErrorCodes.INVALID_PARAMS.value,
138
132
  "message": "Message parts must be a list",
139
- "data": {"expected": "array", "received": type(message.get("parts")).__name__}
140
- }
133
+ "data": {"expected": "array", "received": type(message.get("parts")).__name__},
134
+ },
141
135
  }
142
136
 
143
137
  if len(message.get("parts", [])) == 0:
@@ -146,8 +140,8 @@ def validate_send_message_request(request: Dict[str, Any]) -> Dict[str, Any]:
146
140
  "error": {
147
141
  "code": A2AErrorCodes.INVALID_PARAMS.value,
148
142
  "message": "Message parts cannot be empty",
149
- "data": {"minimum_parts": 1}
150
- }
143
+ "data": {"minimum_parts": 1},
144
+ },
151
145
  }
152
146
 
153
147
  # Validate parts structure
@@ -158,8 +152,12 @@ def validate_send_message_request(request: Dict[str, Any]) -> Dict[str, Any]:
158
152
  "error": {
159
153
  "code": A2AErrorCodes.INVALID_PARAMS.value,
160
154
  "message": f"Message part {i} must be an object",
161
- "data": {"part_index": i, "expected": "object", "received": type(part).__name__}
162
- }
155
+ "data": {
156
+ "part_index": i,
157
+ "expected": "object",
158
+ "received": type(part).__name__,
159
+ },
160
+ },
163
161
  }
164
162
 
165
163
  if "kind" not in part:
@@ -168,8 +166,8 @@ def validate_send_message_request(request: Dict[str, Any]) -> Dict[str, Any]:
168
166
  "error": {
169
167
  "code": A2AErrorCodes.INVALID_PARAMS.value,
170
168
  "message": f"Message part {i} missing 'kind' field",
171
- "data": {"part_index": i, "missing_field": "kind"}
172
- }
169
+ "data": {"part_index": i, "missing_field": "kind"},
170
+ },
173
171
  }
174
172
 
175
173
  kind = part.get("kind")
@@ -179,8 +177,8 @@ def validate_send_message_request(request: Dict[str, Any]) -> Dict[str, Any]:
179
177
  "error": {
180
178
  "code": A2AErrorCodes.INVALID_PARAMS.value,
181
179
  "message": f"Text part {i} missing 'text' field",
182
- "data": {"part_index": i, "part_kind": "text", "missing_field": "text"}
183
- }
180
+ "data": {"part_index": i, "part_kind": "text", "missing_field": "text"},
181
+ },
184
182
  }
185
183
  elif kind == "data" and "data" not in part:
186
184
  return {
@@ -188,8 +186,8 @@ def validate_send_message_request(request: Dict[str, Any]) -> Dict[str, Any]:
188
186
  "error": {
189
187
  "code": A2AErrorCodes.INVALID_PARAMS.value,
190
188
  "message": f"Data part {i} missing 'data' field",
191
- "data": {"part_index": i, "part_kind": "data", "missing_field": "data"}
192
- }
189
+ "data": {"part_index": i, "part_kind": "data", "missing_field": "data"},
190
+ },
193
191
  }
194
192
  elif kind == "file" and "file" not in part:
195
193
  return {
@@ -197,8 +195,8 @@ def validate_send_message_request(request: Dict[str, Any]) -> Dict[str, Any]:
197
195
  "error": {
198
196
  "code": A2AErrorCodes.INVALID_PARAMS.value,
199
197
  "message": f"File part {i} missing 'file' field",
200
- "data": {"part_index": i, "part_kind": "file", "missing_field": "file"}
201
- }
198
+ "data": {"part_index": i, "part_kind": "file", "missing_field": "file"},
199
+ },
202
200
  }
203
201
  elif kind not in ["text", "data", "file"]:
204
202
  return {
@@ -206,14 +204,15 @@ def validate_send_message_request(request: Dict[str, Any]) -> Dict[str, Any]:
206
204
  "error": {
207
205
  "code": A2AErrorCodes.INVALID_PARAMS.value,
208
206
  "message": f"Unknown part kind '{kind}' in part {i}",
209
- "data": {"part_index": i, "supported_kinds": ["text", "data", "file"], "received": kind}
210
- }
207
+ "data": {
208
+ "part_index": i,
209
+ "supported_kinds": ["text", "data", "file"],
210
+ "received": kind,
211
+ },
212
+ },
211
213
  }
212
214
 
213
- return {
214
- "is_valid": True,
215
- "data": request
216
- }
215
+ return {"is_valid": True, "data": request}
217
216
 
218
217
  except Exception as e:
219
218
  return {
@@ -221,16 +220,16 @@ def validate_send_message_request(request: Dict[str, Any]) -> Dict[str, Any]:
221
220
  "error": {
222
221
  "code": A2AErrorCodes.INTERNAL_ERROR.value,
223
222
  "message": f"Request validation failed: {e!s}",
224
- "data": {"error_type": type(e).__name__}
225
- }
223
+ "data": {"error_type": type(e).__name__},
224
+ },
226
225
  }
227
226
 
228
227
 
229
228
  async def handle_message_send(
230
229
  request: Dict[str, Any],
231
- agent: 'A2AAgent', # Forward reference to avoid circular imports
230
+ agent: "A2AAgent", # Forward reference to avoid circular imports
232
231
  model_provider: Any,
233
- executor_func: Callable
232
+ executor_func: Callable,
234
233
  ) -> Dict[str, Any]:
235
234
  """Pure function to handle message/send method"""
236
235
  try:
@@ -240,7 +239,7 @@ async def handle_message_send(
240
239
  context = {
241
240
  "message": message,
242
241
  "session_id": message.get("contextId", f"session_{id(request)}"),
243
- "metadata": params.get("metadata")
242
+ "metadata": params.get("metadata"),
244
243
  }
245
244
 
246
245
  result = await executor_func(context, agent, model_provider)
@@ -248,29 +247,22 @@ async def handle_message_send(
248
247
  if result.get("error"):
249
248
  return create_jsonrpc_error_response_dict(
250
249
  request.get("id"),
251
- {
252
- "code": A2AErrorCodes.INTERNAL_ERROR.value,
253
- "message": result["error"]
254
- }
250
+ {"code": A2AErrorCodes.INTERNAL_ERROR.value, "message": result["error"]},
255
251
  )
256
252
 
257
253
  return create_jsonrpc_success_response_dict(
258
- request.get("id"),
259
- result.get("final_task", {"message": "No result available"})
254
+ request.get("id"), result.get("final_task", {"message": "No result available"})
260
255
  )
261
256
 
262
257
  except Exception as error:
263
- return create_jsonrpc_error_response_dict(
264
- request.get("id"),
265
- map_error_to_a2a_error(error)
266
- )
258
+ return create_jsonrpc_error_response_dict(request.get("id"), map_error_to_a2a_error(error))
267
259
 
268
260
 
269
261
  async def handle_message_stream(
270
262
  request: Dict[str, Any],
271
- agent: 'A2AAgent', # Forward reference
263
+ agent: "A2AAgent", # Forward reference
272
264
  model_provider: Any,
273
- executor_func: Callable
265
+ executor_func: Callable,
274
266
  ) -> AsyncGenerator[Dict[str, Any], None]:
275
267
  """Pure function to handle message/stream method"""
276
268
  try:
@@ -280,22 +272,18 @@ async def handle_message_stream(
280
272
  context = {
281
273
  "message": message,
282
274
  "session_id": message.get("contextId", f"session_{id(request)}"),
283
- "metadata": params.get("metadata")
275
+ "metadata": params.get("metadata"),
284
276
  }
285
277
 
286
278
  async for event in executor_func(context, agent, model_provider):
287
279
  yield create_jsonrpc_success_response_dict(request.get("id"), event)
288
280
 
289
281
  except Exception as error:
290
- yield create_jsonrpc_error_response_dict(
291
- request.get("id"),
292
- map_error_to_a2a_error(error)
293
- )
282
+ yield create_jsonrpc_error_response_dict(request.get("id"), map_error_to_a2a_error(error))
294
283
 
295
284
 
296
285
  async def handle_tasks_get(
297
- request: Dict[str, Any],
298
- task_storage: Dict[str, Dict[str, Any]]
286
+ request: Dict[str, Any], task_storage: Dict[str, Dict[str, Any]]
299
287
  ) -> Dict[str, Any]:
300
288
  """Pure function to handle tasks/get method"""
301
289
  try:
@@ -305,10 +293,7 @@ async def handle_tasks_get(
305
293
  if not task_id:
306
294
  return create_jsonrpc_error_response_dict(
307
295
  request.get("id"),
308
- {
309
- "code": A2AErrorCodes.INVALID_PARAMS.value,
310
- "message": "Task ID is required"
311
- }
296
+ {"code": A2AErrorCodes.INVALID_PARAMS.value, "message": "Task ID is required"},
312
297
  )
313
298
 
314
299
  task = task_storage.get(task_id)
@@ -318,8 +303,8 @@ async def handle_tasks_get(
318
303
  request.get("id"),
319
304
  {
320
305
  "code": A2AErrorCodes.TASK_NOT_FOUND.value,
321
- "message": f"Task with id {task_id} not found"
322
- }
306
+ "message": f"Task with id {task_id} not found",
307
+ },
323
308
  )
324
309
 
325
310
  # Apply history length limit if specified
@@ -331,15 +316,11 @@ async def handle_tasks_get(
331
316
  return create_jsonrpc_success_response_dict(request.get("id"), result_task)
332
317
 
333
318
  except Exception as error:
334
- return create_jsonrpc_error_response_dict(
335
- request.get("id"),
336
- map_error_to_a2a_error(error)
337
- )
319
+ return create_jsonrpc_error_response_dict(request.get("id"), map_error_to_a2a_error(error))
338
320
 
339
321
 
340
322
  async def handle_tasks_cancel(
341
- request: Dict[str, Any],
342
- task_storage: Dict[str, Dict[str, Any]]
323
+ request: Dict[str, Any], task_storage: Dict[str, Dict[str, Any]]
343
324
  ) -> Dict[str, Any]:
344
325
  """Pure function to handle tasks/cancel method"""
345
326
  try:
@@ -349,10 +330,7 @@ async def handle_tasks_cancel(
349
330
  if not task_id:
350
331
  return create_jsonrpc_error_response_dict(
351
332
  request.get("id"),
352
- {
353
- "code": A2AErrorCodes.INVALID_PARAMS.value,
354
- "message": "Task ID is required"
355
- }
333
+ {"code": A2AErrorCodes.INVALID_PARAMS.value, "message": "Task ID is required"},
356
334
  )
357
335
 
358
336
  task = task_storage.get(task_id)
@@ -362,8 +340,8 @@ async def handle_tasks_cancel(
362
340
  request.get("id"),
363
341
  {
364
342
  "code": A2AErrorCodes.TASK_NOT_FOUND.value,
365
- "message": f"Task with id {task_id} not found"
366
- }
343
+ "message": f"Task with id {task_id} not found",
344
+ },
367
345
  )
368
346
 
369
347
  # Check if task can be canceled
@@ -373,29 +351,25 @@ async def handle_tasks_cancel(
373
351
  request.get("id"),
374
352
  {
375
353
  "code": A2AErrorCodes.TASK_NOT_CANCELABLE.value,
376
- "message": f"Task {task_id} cannot be canceled in state {current_state}"
377
- }
354
+ "message": f"Task {task_id} cannot be canceled in state {current_state}",
355
+ },
378
356
  )
379
357
 
380
358
  # Create canceled task
381
359
  canceled_task = task.copy()
382
360
  canceled_task["status"] = {
383
361
  "state": "canceled",
384
- "timestamp": None # Would be set by the system
362
+ "timestamp": None, # Would be set by the system
385
363
  }
386
364
 
387
365
  return create_jsonrpc_success_response_dict(request.get("id"), canceled_task)
388
366
 
389
367
  except Exception as error:
390
- return create_jsonrpc_error_response_dict(
391
- request.get("id"),
392
- map_error_to_a2a_error(error)
393
- )
368
+ return create_jsonrpc_error_response_dict(request.get("id"), map_error_to_a2a_error(error))
394
369
 
395
370
 
396
371
  async def handle_get_authenticated_extended_card(
397
- request: Dict[str, Any],
398
- agent_card: Dict[str, Any]
372
+ request: Dict[str, Any], agent_card: Dict[str, Any]
399
373
  ) -> Dict[str, Any]:
400
374
  """Pure function to handle agent/getAuthenticatedExtendedCard method"""
401
375
  try:
@@ -404,20 +378,17 @@ async def handle_get_authenticated_extended_card(
404
378
  return create_jsonrpc_success_response_dict(request.get("id"), agent_card)
405
379
 
406
380
  except Exception as error:
407
- return create_jsonrpc_error_response_dict(
408
- request.get("id"),
409
- map_error_to_a2a_error(error)
410
- )
381
+ return create_jsonrpc_error_response_dict(request.get("id"), map_error_to_a2a_error(error))
411
382
 
412
383
 
413
384
  def route_a2a_request(
414
385
  request: Dict[str, Any],
415
- agent: 'A2AAgent', # Forward reference
386
+ agent: "A2AAgent", # Forward reference
416
387
  model_provider: Any,
417
388
  task_storage: Dict[str, Dict[str, Any]],
418
389
  agent_card: Dict[str, Any],
419
390
  executor_func: Callable,
420
- streaming_executor_func: Callable
391
+ streaming_executor_func: Callable,
421
392
  ) -> Union[Dict[str, Any], AsyncGenerator[Dict[str, Any], None]]:
422
393
  """Pure function to route A2A requests"""
423
394
 
@@ -427,8 +398,8 @@ def route_a2a_request(
427
398
  request.get("id"),
428
399
  {
429
400
  "code": A2AErrorCodes.INVALID_REQUEST.value,
430
- "message": "Invalid JSON-RPC request"
431
- }
401
+ "message": "Invalid JSON-RPC request",
402
+ },
432
403
  )
433
404
 
434
405
  method = request.get("method")
@@ -436,17 +407,16 @@ def route_a2a_request(
436
407
  if method == "message/send":
437
408
  validation = validate_send_message_request(request)
438
409
  if not validation["is_valid"]:
439
- return create_jsonrpc_error_response_dict(
440
- request.get("id"),
441
- validation["error"]
442
- )
410
+ return create_jsonrpc_error_response_dict(request.get("id"), validation["error"])
443
411
  return await handle_message_send(request, agent, model_provider, executor_func)
444
412
 
445
413
  elif method == "message/stream":
446
414
  validation = validate_send_message_request(request)
447
415
  if not validation["is_valid"]:
416
+
448
417
  async def error_generator():
449
418
  yield create_jsonrpc_error_response_dict(request.get("id"), validation["error"])
419
+
450
420
  return error_generator()
451
421
  # Return the async generator directly
452
422
  return handle_message_stream(request, agent, model_provider, streaming_executor_func)
@@ -465,8 +435,8 @@ def route_a2a_request(
465
435
  request.get("id"),
466
436
  {
467
437
  "code": A2AErrorCodes.METHOD_NOT_FOUND.value,
468
- "message": f"Method {method} not found"
469
- }
438
+ "message": f"Method {method} not found",
439
+ },
470
440
  )
471
441
 
472
442
  # Handle streaming vs non-streaming
@@ -479,11 +449,11 @@ def route_a2a_request(
479
449
 
480
450
 
481
451
  def create_protocol_handler_config(
482
- agents: Dict[str, 'A2AAgent'], # Forward reference
452
+ agents: Dict[str, "A2AAgent"], # Forward reference
483
453
  model_provider: Any,
484
454
  agent_card: Dict[str, Any],
485
455
  executor_func: Callable,
486
- streaming_executor_func: Callable
456
+ streaming_executor_func: Callable,
487
457
  ) -> Dict[str, Any]:
488
458
  """Pure function to create protocol handler configuration"""
489
459
 
@@ -502,8 +472,8 @@ def create_protocol_handler_config(
502
472
  request.get("id"),
503
473
  {
504
474
  "code": A2AErrorCodes.INVALID_PARAMS.value,
505
- "message": f"Agent {agent_name or 'default'} not found"
506
- }
475
+ "message": f"Agent {agent_name or 'default'} not found",
476
+ },
507
477
  )
508
478
 
509
479
  return route_a2a_request(
@@ -513,7 +483,7 @@ def create_protocol_handler_config(
513
483
  task_storage,
514
484
  agent_card,
515
485
  executor_func,
516
- streaming_executor_func
486
+ streaming_executor_func,
517
487
  )
518
488
 
519
489
  return {
@@ -521,5 +491,5 @@ def create_protocol_handler_config(
521
491
  "model_provider": model_provider,
522
492
  "agent_card": agent_card,
523
493
  "task_storage": task_storage,
524
- "handle_request": handle_request
494
+ "handle_request": handle_request,
525
495
  }