jaf-py 2.5.10__py3-none-any.whl → 2.5.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (92) hide show
  1. jaf/__init__.py +154 -57
  2. jaf/a2a/__init__.py +42 -21
  3. jaf/a2a/agent.py +79 -126
  4. jaf/a2a/agent_card.py +87 -78
  5. jaf/a2a/client.py +30 -66
  6. jaf/a2a/examples/client_example.py +12 -12
  7. jaf/a2a/examples/integration_example.py +38 -47
  8. jaf/a2a/examples/server_example.py +56 -53
  9. jaf/a2a/memory/__init__.py +0 -4
  10. jaf/a2a/memory/cleanup.py +28 -21
  11. jaf/a2a/memory/factory.py +155 -133
  12. jaf/a2a/memory/providers/composite.py +21 -26
  13. jaf/a2a/memory/providers/in_memory.py +89 -83
  14. jaf/a2a/memory/providers/postgres.py +117 -115
  15. jaf/a2a/memory/providers/redis.py +128 -121
  16. jaf/a2a/memory/serialization.py +77 -87
  17. jaf/a2a/memory/tests/run_comprehensive_tests.py +112 -83
  18. jaf/a2a/memory/tests/test_cleanup.py +211 -94
  19. jaf/a2a/memory/tests/test_serialization.py +73 -68
  20. jaf/a2a/memory/tests/test_stress_concurrency.py +186 -133
  21. jaf/a2a/memory/tests/test_task_lifecycle.py +138 -120
  22. jaf/a2a/memory/types.py +91 -53
  23. jaf/a2a/protocol.py +95 -125
  24. jaf/a2a/server.py +90 -118
  25. jaf/a2a/standalone_client.py +30 -43
  26. jaf/a2a/tests/__init__.py +16 -33
  27. jaf/a2a/tests/run_tests.py +17 -53
  28. jaf/a2a/tests/test_agent.py +40 -140
  29. jaf/a2a/tests/test_client.py +54 -117
  30. jaf/a2a/tests/test_integration.py +28 -82
  31. jaf/a2a/tests/test_protocol.py +54 -139
  32. jaf/a2a/tests/test_types.py +50 -136
  33. jaf/a2a/types.py +58 -34
  34. jaf/cli.py +21 -41
  35. jaf/core/__init__.py +7 -1
  36. jaf/core/agent_tool.py +93 -72
  37. jaf/core/analytics.py +257 -207
  38. jaf/core/checkpoint.py +223 -0
  39. jaf/core/composition.py +249 -235
  40. jaf/core/engine.py +817 -519
  41. jaf/core/errors.py +55 -42
  42. jaf/core/guardrails.py +276 -202
  43. jaf/core/handoff.py +47 -31
  44. jaf/core/parallel_agents.py +69 -75
  45. jaf/core/performance.py +75 -73
  46. jaf/core/proxy.py +43 -44
  47. jaf/core/proxy_helpers.py +24 -27
  48. jaf/core/regeneration.py +220 -129
  49. jaf/core/state.py +68 -66
  50. jaf/core/streaming.py +115 -108
  51. jaf/core/tool_results.py +111 -101
  52. jaf/core/tools.py +114 -116
  53. jaf/core/tracing.py +269 -210
  54. jaf/core/types.py +371 -151
  55. jaf/core/workflows.py +209 -168
  56. jaf/exceptions.py +46 -38
  57. jaf/memory/__init__.py +1 -6
  58. jaf/memory/approval_storage.py +54 -77
  59. jaf/memory/factory.py +4 -4
  60. jaf/memory/providers/in_memory.py +216 -180
  61. jaf/memory/providers/postgres.py +216 -146
  62. jaf/memory/providers/redis.py +173 -116
  63. jaf/memory/types.py +70 -51
  64. jaf/memory/utils.py +36 -34
  65. jaf/plugins/__init__.py +12 -12
  66. jaf/plugins/base.py +105 -96
  67. jaf/policies/__init__.py +0 -1
  68. jaf/policies/handoff.py +37 -46
  69. jaf/policies/validation.py +76 -52
  70. jaf/providers/__init__.py +6 -3
  71. jaf/providers/mcp.py +97 -51
  72. jaf/providers/model.py +360 -279
  73. jaf/server/__init__.py +1 -1
  74. jaf/server/main.py +7 -11
  75. jaf/server/server.py +514 -359
  76. jaf/server/types.py +208 -52
  77. jaf/utils/__init__.py +17 -18
  78. jaf/utils/attachments.py +111 -116
  79. jaf/utils/document_processor.py +175 -174
  80. jaf/visualization/__init__.py +1 -1
  81. jaf/visualization/example.py +111 -110
  82. jaf/visualization/functional_core.py +46 -71
  83. jaf/visualization/graphviz.py +154 -189
  84. jaf/visualization/imperative_shell.py +7 -16
  85. jaf/visualization/types.py +8 -4
  86. {jaf_py-2.5.10.dist-info → jaf_py-2.5.11.dist-info}/METADATA +2 -2
  87. jaf_py-2.5.11.dist-info/RECORD +97 -0
  88. jaf_py-2.5.10.dist-info/RECORD +0 -96
  89. {jaf_py-2.5.10.dist-info → jaf_py-2.5.11.dist-info}/WHEEL +0 -0
  90. {jaf_py-2.5.10.dist-info → jaf_py-2.5.11.dist-info}/entry_points.txt +0 -0
  91. {jaf_py-2.5.10.dist-info → jaf_py-2.5.11.dist-info}/licenses/LICENSE +0 -0
  92. {jaf_py-2.5.10.dist-info → jaf_py-2.5.11.dist-info}/top_level.txt +0 -0
jaf/server/server.py CHANGED
@@ -19,6 +19,7 @@ from fastapi.responses import StreamingResponse
19
19
  from ..core.engine import run
20
20
  from ..core.streaming import run_streaming
21
21
  from ..core.regeneration import regenerate_conversation, get_regeneration_points
22
+ from ..core.checkpoint import checkpoint_conversation, get_checkpoint_history
22
23
  from ..core.types import (
23
24
  ApprovalValue,
24
25
  CompletedOutcome,
@@ -32,6 +33,7 @@ from ..core.types import (
32
33
  create_trace_id,
33
34
  create_message_id,
34
35
  RegenerationRequest,
36
+ CheckpointRequest,
35
37
  )
36
38
  from ..memory.types import MemoryConfig
37
39
  from .types import (
@@ -61,23 +63,31 @@ from .types import (
61
63
  RegenerationPointData,
62
64
  RegenerationHistoryData,
63
65
  RegenerationHistoryResponse,
66
+ CheckpointHttpRequest,
67
+ CheckpointData,
68
+ CheckpointResponse,
69
+ CheckpointPointData,
70
+ CheckpointHistoryData,
71
+ CheckpointHistoryResponse,
64
72
  ServerConfig,
65
73
  ToolCallInterruption,
66
74
  validate_regeneration_request,
67
75
  )
68
76
 
69
- Ctx = TypeVar('Ctx')
77
+ Ctx = TypeVar("Ctx")
78
+
70
79
 
71
80
  # Helper functions for HITL (moved outside like TypeScript)
72
81
  def stable_stringify(value) -> str:
73
82
  """Create deterministic JSON string for tool call signatures."""
74
83
  try:
75
84
  if isinstance(value, dict):
76
- return json.dumps(value, sort_keys=True, separators=(',', ':'))
77
- return json.dumps(value, separators=(',', ':'))
85
+ return json.dumps(value, sort_keys=True, separators=(",", ":"))
86
+ return json.dumps(value, separators=(",", ":"))
78
87
  except (TypeError, ValueError):
79
88
  return str(value)
80
89
 
90
+
81
91
  def try_parse_json(s: str):
82
92
  """Try to parse JSON, return original string if it fails."""
83
93
  try:
@@ -85,6 +95,7 @@ def try_parse_json(s: str):
85
95
  except (json.JSONDecodeError, TypeError):
86
96
  return s
87
97
 
98
+
88
99
  def compute_tool_call_signature(tool_call) -> str:
89
100
  """Compute deterministic signature for tool call matching."""
90
101
  try:
@@ -93,6 +104,7 @@ def compute_tool_call_signature(tool_call) -> str:
93
104
  except Exception:
94
105
  return f"{tool_call.function.name}:unknown"
95
106
 
107
+
96
108
  def _convert_http_message_to_core(http_msg: HttpMessage) -> Message:
97
109
  """Convert HTTP message format to core Message format."""
98
110
  # Convert content
@@ -102,31 +114,26 @@ def _convert_http_message_to_core(http_msg: HttpMessage) -> Message:
102
114
  # Convert list of content parts
103
115
  content_parts = []
104
116
  for i, part in enumerate(http_msg.content):
105
- if part.type == 'text':
106
- content_parts.append(MessageContentPart(
107
- type='text',
108
- text=part.text,
109
- image_url=None,
110
- file=None
111
- ))
112
- elif part.type == 'image_url':
113
- content_parts.append(MessageContentPart(
114
- type='image_url',
115
- text=None,
116
- image_url=part.image_url,
117
- file=None
118
- ))
119
- elif part.type == 'file':
120
- content_parts.append(MessageContentPart(
121
- type='file',
122
- text=None,
123
- image_url=None,
124
- file=part.file
125
- ))
117
+ if part.type == "text":
118
+ content_parts.append(
119
+ MessageContentPart(type="text", text=part.text, image_url=None, file=None)
120
+ )
121
+ elif part.type == "image_url":
122
+ content_parts.append(
123
+ MessageContentPart(
124
+ type="image_url", text=None, image_url=part.image_url, file=None
125
+ )
126
+ )
127
+ elif part.type == "file":
128
+ content_parts.append(
129
+ MessageContentPart(type="file", text=None, image_url=None, file=part.file)
130
+ )
126
131
  else:
127
132
  # Raise explicit error for unrecognized part types
128
- raise ValueError(f"Unrecognized message content part type: '{part.type}' at index {i}. "
129
- f"Supported types are: 'text', 'image_url', 'file'")
133
+ raise ValueError(
134
+ f"Unrecognized message content part type: '{part.type}' at index {i}. "
135
+ f"Supported types are: 'text', 'image_url', 'file'"
136
+ )
130
137
  content = content_parts
131
138
 
132
139
  # Convert attachments
@@ -140,7 +147,7 @@ def _convert_http_message_to_core(http_msg: HttpMessage) -> Message:
140
147
  url=att.url,
141
148
  data=att.data,
142
149
  format=att.format,
143
- use_litellm_format=att.use_litellm_format
150
+ use_litellm_format=att.use_litellm_format,
144
151
  )
145
152
  for att in http_msg.attachments
146
153
  ]
@@ -150,14 +157,15 @@ def _convert_http_message_to_core(http_msg: HttpMessage) -> Message:
150
157
  content=content,
151
158
  attachments=attachments,
152
159
  tool_call_id=http_msg.tool_call_id,
153
- tool_calls=http_msg.tool_calls
160
+ tool_calls=http_msg.tool_calls,
154
161
  )
155
162
 
163
+
156
164
  def _convert_core_message_to_http(core_msg: Message) -> HttpMessage:
157
165
  """Convert core Message format to HTTP message format."""
158
166
  from .types import HttpAttachment, HttpMessageContentPart
159
167
  from ..core.types import get_text_content
160
-
168
+
161
169
  # Convert content
162
170
  if isinstance(core_msg.content, str):
163
171
  content = core_msg.content
@@ -165,33 +173,28 @@ def _convert_core_message_to_http(core_msg: Message) -> HttpMessage:
165
173
  # Convert content parts to HTTP format
166
174
  http_parts = []
167
175
  for i, part in enumerate(core_msg.content):
168
- if part.type == 'text':
169
- http_parts.append(HttpMessageContentPart(
170
- type='text',
171
- text=part.text,
172
- image_url=None,
173
- file=None
174
- ))
175
- elif part.type == 'image_url':
176
- http_parts.append(HttpMessageContentPart(
177
- type='image_url',
178
- text=None,
179
- image_url=part.image_url,
180
- file=None
181
- ))
182
- elif part.type == 'file':
183
- http_parts.append(HttpMessageContentPart(
184
- type='file',
185
- text=None,
186
- image_url=None,
187
- file=part.file
188
- ))
176
+ if part.type == "text":
177
+ http_parts.append(
178
+ HttpMessageContentPart(type="text", text=part.text, image_url=None, file=None)
179
+ )
180
+ elif part.type == "image_url":
181
+ http_parts.append(
182
+ HttpMessageContentPart(
183
+ type="image_url", text=None, image_url=part.image_url, file=None
184
+ )
185
+ )
186
+ elif part.type == "file":
187
+ http_parts.append(
188
+ HttpMessageContentPart(type="file", text=None, image_url=None, file=part.file)
189
+ )
189
190
  else:
190
191
  # Raise explicit error for unrecognized part types
191
192
  message_info = f"role={core_msg.role}"
192
- raise ValueError(f"Unrecognized core message content part type: '{part.type}' at index {i}. "
193
- f"Message info: {message_info}. "
194
- f"Supported types are: 'text', 'image_url', 'file'")
193
+ raise ValueError(
194
+ f"Unrecognized core message content part type: '{part.type}' at index {i}. "
195
+ f"Message info: {message_info}. "
196
+ f"Supported types are: 'text', 'image_url', 'file'"
197
+ )
195
198
  content = http_parts
196
199
  else:
197
200
  content = get_text_content(core_msg.content)
@@ -207,7 +210,7 @@ def _convert_core_message_to_http(core_msg: Message) -> HttpMessage:
207
210
  url=att.url,
208
211
  data=att.data,
209
212
  format=att.format,
210
- use_litellm_format=att.use_litellm_format
213
+ use_litellm_format=att.use_litellm_format,
211
214
  )
212
215
  for att in core_msg.attachments
213
216
  ]
@@ -217,17 +220,18 @@ def _convert_core_message_to_http(core_msg: Message) -> HttpMessage:
217
220
  content=content,
218
221
  attachments=attachments,
219
222
  tool_call_id=core_msg.tool_call_id,
220
- tool_calls=core_msg.tool_calls
223
+ tool_calls=core_msg.tool_calls,
221
224
  )
222
225
 
226
+
223
227
  def create_jaf_server(config: ServerConfig[Ctx]) -> FastAPI:
224
228
  """Create and configure a JAF server instance."""
225
229
 
226
230
  start_time = time.time()
227
-
231
+
228
232
  # SSE subscribers for approval-related events (matching TypeScript)
229
233
  approval_subscribers = set()
230
-
234
+
231
235
  def sse_send(response, event: str, data: dict):
232
236
  """Send SSE event to client."""
233
237
  try:
@@ -235,39 +239,39 @@ def create_jaf_server(config: ServerConfig[Ctx]) -> FastAPI:
235
239
  response.write(f"data: {json.dumps(data)}\n\n")
236
240
  except Exception:
237
241
  pass # ignore connection errors
238
-
242
+
239
243
  def broadcast_approval_required(payload: dict):
240
244
  """Broadcast approval_required event to SSE clients."""
241
245
  for client in approval_subscribers.copy(): # copy to avoid modification during iteration
242
- filter_conv_id = client.get('filter_conversation_id')
243
- if filter_conv_id and filter_conv_id != payload.get('conversationId'):
246
+ filter_conv_id = client.get("filter_conversation_id")
247
+ if filter_conv_id and filter_conv_id != payload.get("conversationId"):
244
248
  continue
245
-
249
+
246
250
  payload_with_timestamp = {
247
251
  **payload,
248
- 'timestamp': payload.get('timestamp', time.strftime('%Y-%m-%dT%H:%M:%S.%fZ'))
252
+ "timestamp": payload.get("timestamp", time.strftime("%Y-%m-%dT%H:%M:%S.%fZ")),
249
253
  }
250
- sse_send(client['response'], 'approval_required', payload_with_timestamp)
251
-
254
+ sse_send(client["response"], "approval_required", payload_with_timestamp)
255
+
252
256
  def broadcast_approval_decision(payload: dict):
253
257
  """Broadcast approval_decision event to SSE clients."""
254
258
  for client in approval_subscribers.copy(): # copy to avoid modification during iteration
255
- filter_conv_id = client.get('filter_conversation_id')
256
- if filter_conv_id and filter_conv_id != payload.get('conversationId'):
259
+ filter_conv_id = client.get("filter_conversation_id")
260
+ if filter_conv_id and filter_conv_id != payload.get("conversationId"):
257
261
  continue
258
-
262
+
259
263
  payload_with_timestamp = {
260
264
  **payload,
261
- 'timestamp': payload.get('timestamp', time.strftime('%Y-%m-%dT%H:%M:%S.%fZ'))
265
+ "timestamp": payload.get("timestamp", time.strftime("%Y-%m-%dT%H:%M:%S.%fZ")),
262
266
  }
263
- sse_send(client['response'], 'approval_decision', payload_with_timestamp)
267
+ sse_send(client["response"], "approval_decision", payload_with_timestamp)
264
268
 
265
269
  app = FastAPI(
266
270
  title="JAF Agent Framework Server",
267
271
  description="HTTP API for JAF agents with HITL support",
268
272
  version="2.0.0",
269
273
  docs_url="/docs",
270
- redoc_url="/redoc"
274
+ redoc_url="/redoc",
271
275
  )
272
276
 
273
277
  # Setup middleware
@@ -286,7 +290,7 @@ def create_jaf_server(config: ServerConfig[Ctx]) -> FastAPI:
286
290
  status="healthy",
287
291
  timestamp=time.strftime("%Y-%m-%dT%H:%M:%S.%fZ"),
288
292
  version="2.0.0",
289
- uptime=int((time.time() - start_time) * 1000)
293
+ uptime=int((time.time() - start_time) * 1000),
290
294
  )
291
295
 
292
296
  @app.get("/agents", response_model=AgentListResponse)
@@ -296,7 +300,7 @@ def create_jaf_server(config: ServerConfig[Ctx]) -> FastAPI:
296
300
  AgentInfo(
297
301
  name=name,
298
302
  description=agent.instructions(None) if agent.instructions else "",
299
- tools=[tool.schema.name for tool in agent.tools or []]
303
+ tools=[tool.schema.name for tool in agent.tools or []],
300
304
  )
301
305
  for name, agent in config.agent_registry.items()
302
306
  ]
@@ -307,31 +311,30 @@ def create_jaf_server(config: ServerConfig[Ctx]) -> FastAPI:
307
311
  @app.post("/chat", response_model=ChatResponse)
308
312
  async def chat_completion(request: ChatRequest):
309
313
  request_start_time = time.time()
310
-
314
+
311
315
  try:
312
316
  # Validate request (matching TypeScript approach)
313
- validated_request = request # Already validated by FastAPI, but keeping TypeScript structure
314
-
317
+ validated_request = (
318
+ request # Already validated by FastAPI, but keeping TypeScript structure
319
+ )
320
+
315
321
  # Check if agent exists (matching TypeScript response pattern)
316
322
  if validated_request.agent_name not in config.agent_registry:
317
323
  return ChatResponse(
318
324
  success=False,
319
- error=f"Agent '{validated_request.agent_name}' not found. Available agents: {', '.join(config.agent_registry.keys())}"
325
+ error=f"Agent '{validated_request.agent_name}' not found. Available agents: {', '.join(config.agent_registry.keys())}",
320
326
  )
321
-
327
+
322
328
  # Convert HTTP messages to JAF messages (matching TypeScript)
323
329
  jaf_messages = [
324
- Message(
325
- role='user' if msg.role == 'system' else msg.role,
326
- content=msg.content
327
- )
330
+ Message(role="user" if msg.role == "system" else msg.role, content=msg.content)
328
331
  for msg in validated_request.messages
329
332
  ]
330
-
333
+
331
334
  # Create initial state (matching TypeScript)
332
335
  run_id = create_run_id(str(uuid.uuid4()))
333
336
  trace_id = create_trace_id(str(uuid.uuid4()))
334
-
337
+
335
338
  # Generate conversationId if not provided (matching TypeScript)
336
339
  conversation_id = validated_request.conversation_id or f"conv-{uuid.uuid4()}"
337
340
  except Exception as e:
@@ -341,8 +344,10 @@ def create_jaf_server(config: ServerConfig[Ctx]) -> FastAPI:
341
344
  initial_turn_count = 0
342
345
  if config.default_memory_provider and conversation_id:
343
346
  try:
344
- conversation_result = await config.default_memory_provider.get_conversation(conversation_id)
345
- if hasattr(conversation_result, 'data') and conversation_result.data:
347
+ conversation_result = await config.default_memory_provider.get_conversation(
348
+ conversation_id
349
+ )
350
+ if hasattr(conversation_result, "data") and conversation_result.data:
346
351
  conversation_data = conversation_result.data
347
352
  if conversation_data.metadata and "turn_count" in conversation_data.metadata:
348
353
  initial_turn_count = conversation_data.metadata["turn_count"]
@@ -353,25 +358,25 @@ def create_jaf_server(config: ServerConfig[Ctx]) -> FastAPI:
353
358
  # Handle approval message(s) if present (matching TypeScript approach)
354
359
  initial_approvals = {} # Will act like TypeScript's Map
355
360
  initial_state_messages = jaf_messages
356
-
361
+
357
362
  approvals_list = validated_request.approvals or []
358
-
363
+
359
364
  async def persist_approval(conv_id: str, appr: ApprovalMessage):
360
365
  """Persist approval to memory provider with metadata (matching TypeScript)."""
361
366
  if not config.default_memory_provider:
362
367
  return
363
-
368
+
364
369
  provider = config.default_memory_provider
365
370
  # Keyed by previous run/session id + toolCallId for uniqueness (matching TypeScript)
366
371
  approval_key = f"{appr.session_id}:{appr.tool_call_id}"
367
372
  base_entry = {
368
- 'approved': appr.approved,
369
- 'status': 'approved' if appr.approved else 'rejected',
370
- 'additionalContext': appr.additional_context,
371
- 'sessionId': appr.session_id,
372
- 'toolCallId': appr.tool_call_id,
373
+ "approved": appr.approved,
374
+ "status": "approved" if appr.approved else "rejected",
375
+ "additionalContext": appr.additional_context,
376
+ "sessionId": appr.session_id,
377
+ "toolCallId": appr.tool_call_id,
373
378
  }
374
-
379
+
375
380
  try:
376
381
  existing = await provider.get_conversation(conv_id)
377
382
  if existing.success and existing.data:
@@ -380,111 +385,143 @@ def create_jaf_server(config: ServerConfig[Ctx]) -> FastAPI:
380
385
  msgs = existing.data.messages
381
386
  for i in range(len(msgs) - 1, -1, -1):
382
387
  m = msgs[i]
383
- if m.role == 'assistant' and hasattr(m, 'tool_calls') and m.tool_calls:
384
- match = next((tc for tc in m.tool_calls if tc.id == appr.tool_call_id), None)
388
+ if m.role == "assistant" and hasattr(m, "tool_calls") and m.tool_calls:
389
+ match = next(
390
+ (tc for tc in m.tool_calls if tc.id == appr.tool_call_id), None
391
+ )
385
392
  if match:
386
- base_entry['toolName'] = match.function.name
387
- base_entry['signature'] = compute_tool_call_signature(match)
393
+ base_entry["toolName"] = match.function.name
394
+ base_entry["signature"] = compute_tool_call_signature(match)
388
395
  break
389
396
  except Exception:
390
397
  pass # best-effort
391
-
392
- existing_approvals = (existing.data.metadata.get('toolApprovals') if existing.data.metadata else {}) or {}
398
+
399
+ existing_approvals = (
400
+ existing.data.metadata.get("toolApprovals")
401
+ if existing.data.metadata
402
+ else {}
403
+ ) or {}
393
404
  prev = existing_approvals.get(approval_key)
394
-
405
+
395
406
  # Merge additionalContext shallowly and avoid regressions (exactly matching TypeScript)
396
407
  merged_additional = {
397
- **(prev.get('additionalContext') if prev else {}),
398
- **(base_entry.get('additionalContext') or {}),
408
+ **(prev.get("additionalContext") if prev else {}),
409
+ **(base_entry.get("additionalContext") or {}),
399
410
  }
400
-
411
+
401
412
  next_entry = {
402
413
  **(prev or {}),
403
414
  **base_entry,
404
- 'additionalContext': merged_additional,
415
+ "additionalContext": merged_additional,
405
416
  # Preserve earliest timestamp if no effective change; else update (exactly matching TypeScript)
406
- 'timestamp': (
407
- prev.get('timestamp') if prev and (
408
- prev.get('status') == base_entry['status'] and
409
- stable_stringify(prev.get('additionalContext')) == stable_stringify(merged_additional)
410
- ) else time.strftime('%Y-%m-%dT%H:%M:%S.%fZ')
411
- )
417
+ "timestamp": (
418
+ prev.get("timestamp")
419
+ if prev
420
+ and (
421
+ prev.get("status") == base_entry["status"]
422
+ and stable_stringify(prev.get("additionalContext"))
423
+ == stable_stringify(merged_additional)
424
+ )
425
+ else time.strftime("%Y-%m-%dT%H:%M:%S.%fZ")
426
+ ),
412
427
  }
413
-
428
+
414
429
  # Check if there's actually a change (exactly matching TypeScript)
415
430
  no_change = prev and (
416
- prev.get('status') == next_entry['status'] and
417
- stable_stringify(prev.get('additionalContext')) == stable_stringify(next_entry['additionalContext']) and
418
- (prev.get('toolName') or None) == (next_entry.get('toolName') or None) and
419
- (prev.get('signature') or None) == (next_entry.get('signature') or None)
431
+ prev.get("status") == next_entry["status"]
432
+ and stable_stringify(prev.get("additionalContext"))
433
+ == stable_stringify(next_entry["additionalContext"])
434
+ and (prev.get("toolName") or None) == (next_entry.get("toolName") or None)
435
+ and (prev.get("signature") or None) == (next_entry.get("signature") or None)
420
436
  )
421
-
437
+
422
438
  if not no_change:
423
439
  merged_approvals = {**existing_approvals, approval_key: next_entry}
424
- await provider.appendMessages(conv_id, [], {'toolApprovals': merged_approvals, 'traceId': trace_id})
425
-
440
+ await provider.appendMessages(
441
+ conv_id, [], {"toolApprovals": merged_approvals, "traceId": trace_id}
442
+ )
443
+
426
444
  elif existing.success and not existing.data:
427
445
  # Create conversation shell with just metadata if not present (exactly matching TypeScript)
428
- entry = {**base_entry, 'timestamp': time.strftime('%Y-%m-%dT%H:%M:%S.%fZ')}
429
- await provider.storeMessages(conv_id, [], {'toolApprovals': {approval_key: entry}, 'traceId': trace_id})
446
+ entry = {**base_entry, "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S.%fZ")}
447
+ await provider.storeMessages(
448
+ conv_id, [], {"toolApprovals": {approval_key: entry}, "traceId": trace_id}
449
+ )
430
450
  # If provider call failed, we intentionally do not throw; run will proceed
431
451
  except Exception:
432
452
  # Ignore persistence errors here to avoid breaking the request path (exactly matching TypeScript)
433
453
  pass
434
-
454
+
435
455
  # Broadcast decision to approvals SSE (exactly matching TypeScript)
436
456
  try:
437
- broadcast_approval_decision({
438
- 'conversationId': conv_id,
439
- 'sessionId': appr.session_id,
440
- 'toolCallId': appr.tool_call_id,
441
- 'status': 'approved' if appr.approved else 'rejected',
442
- 'additionalContext': appr.additional_context
443
- })
457
+ broadcast_approval_decision(
458
+ {
459
+ "conversationId": conv_id,
460
+ "sessionId": appr.session_id,
461
+ "toolCallId": appr.tool_call_id,
462
+ "status": "approved" if appr.approved else "rejected",
463
+ "additionalContext": appr.additional_context,
464
+ }
465
+ )
444
466
  except Exception:
445
467
  pass # ignore
446
-
468
+
447
469
  if len(approvals_list) > 0:
448
470
  for approval in approvals_list:
449
471
  if approval.session_id: # Matching TypeScript condition
450
472
  initial_approvals[approval.tool_call_id] = {
451
- 'status': 'approved' if approval.approved else 'rejected',
452
- 'approved': approval.approved,
453
- 'additionalContext': approval.additional_context
473
+ "status": "approved" if approval.approved else "rejected",
474
+ "approved": approval.approved,
475
+ "additionalContext": approval.additional_context,
454
476
  }
455
477
  await persist_approval(conversation_id, approval)
456
-
478
+
457
479
  # Seed approvals from persisted conversation metadata
458
480
  if config.default_memory_provider:
459
481
  try:
460
482
  conv_result = await config.default_memory_provider.get_conversation(conversation_id)
461
- if hasattr(conv_result, 'data') and conv_result.data:
483
+ if hasattr(conv_result, "data") and conv_result.data:
462
484
  conversation_data = conv_result.data
463
- tool_approvals = getattr(conversation_data.metadata, 'tool_approvals', {}) if conversation_data.metadata else {}
464
-
485
+ tool_approvals = (
486
+ getattr(conversation_data.metadata, "tool_approvals", {})
487
+ if conversation_data.metadata
488
+ else {}
489
+ )
490
+
465
491
  if tool_approvals:
466
492
  # Find latest assistant message with tool calls for matching
467
493
  assistant_msg = None
468
494
  for msg in reversed(conversation_data.messages):
469
- if hasattr(msg, 'role') and msg.role == 'assistant' and hasattr(msg, 'tool_calls') and msg.tool_calls:
495
+ if (
496
+ hasattr(msg, "role")
497
+ and msg.role == "assistant"
498
+ and hasattr(msg, "tool_calls")
499
+ and msg.tool_calls
500
+ ):
470
501
  assistant_msg = msg
471
502
  break
472
-
503
+
473
504
  if assistant_msg:
474
505
  candidate_ids = {tc.id for tc in assistant_msg.tool_calls}
475
- candidate_signatures = {tc.id: compute_tool_call_signature(tc) for tc in assistant_msg.tool_calls}
476
-
506
+ candidate_signatures = {
507
+ tc.id: compute_tool_call_signature(tc)
508
+ for tc in assistant_msg.tool_calls
509
+ }
510
+
477
511
  # Load persisted approvals that aren't already in initial_approvals
478
512
  for approval_entry in tool_approvals.values():
479
513
  if not isinstance(approval_entry, dict):
480
514
  continue
481
-
482
- persisted_tool_call_id = approval_entry.get('tool_call_id')
483
- persisted_signature = approval_entry.get('signature')
484
-
515
+
516
+ persisted_tool_call_id = approval_entry.get("tool_call_id")
517
+ persisted_signature = approval_entry.get("signature")
518
+
485
519
  # Try direct ID match first
486
520
  target_id = None
487
- if persisted_tool_call_id and persisted_tool_call_id in candidate_ids:
521
+ if (
522
+ persisted_tool_call_id
523
+ and persisted_tool_call_id in candidate_ids
524
+ ):
488
525
  target_id = persisted_tool_call_id
489
526
  elif persisted_signature:
490
527
  # Signature fallback
@@ -492,20 +529,20 @@ def create_jaf_server(config: ServerConfig[Ctx]) -> FastAPI:
492
529
  if sig == persisted_signature:
493
530
  target_id = tc_id
494
531
  break
495
-
532
+
496
533
  if target_id and target_id not in initial_approvals:
497
- status = approval_entry.get('status', 'pending')
498
- if approval_entry.get('approved') is True:
499
- status = 'approved'
500
- elif approval_entry.get('approved') is False:
501
- status = 'rejected'
502
-
534
+ status = approval_entry.get("status", "pending")
535
+ if approval_entry.get("approved") is True:
536
+ status = "approved"
537
+ elif approval_entry.get("approved") is False:
538
+ status = "rejected"
539
+
503
540
  initial_approvals[target_id] = ApprovalValue(
504
541
  status=status,
505
- approved=approval_entry.get('approved', False),
506
- additional_context=approval_entry.get('additional_context')
542
+ approved=approval_entry.get("approved", False),
543
+ additional_context=approval_entry.get("additional_context"),
507
544
  )
508
-
545
+
509
546
  except Exception as e:
510
547
  print(f"[JAF:SERVER] Warning: Failed to seed approvals from metadata: {e}")
511
548
 
@@ -516,7 +553,7 @@ def create_jaf_server(config: ServerConfig[Ctx]) -> FastAPI:
516
553
  current_agent_name=request.agent_name,
517
554
  context=request.context or {},
518
555
  turn_count=initial_turn_count, # Use loaded turn count instead of always 0
519
- approvals=initial_approvals
556
+ approvals=initial_approvals,
520
557
  )
521
558
 
522
559
  run_config_with_memory = config.run_config
@@ -524,15 +561,17 @@ def create_jaf_server(config: ServerConfig[Ctx]) -> FastAPI:
524
561
  # Handle memory configuration with request overrides (matching TypeScript)
525
562
  memory_config = MemoryConfig(
526
563
  provider=config.default_memory_provider,
527
- auto_store=request.memory.get('auto_store', True) if request.memory else True,
528
- max_messages=request.memory.get('max_messages') if request.memory else None,
529
- compression_threshold=request.memory.get('compression_threshold') if request.memory else None,
530
- store_on_completion=request.store_on_completion if request.store_on_completion is not None else True
564
+ auto_store=request.memory.get("auto_store", True) if request.memory else True,
565
+ max_messages=request.memory.get("max_messages") if request.memory else None,
566
+ compression_threshold=request.memory.get("compression_threshold")
567
+ if request.memory
568
+ else None,
569
+ store_on_completion=request.store_on_completion
570
+ if request.store_on_completion is not None
571
+ else True,
531
572
  )
532
573
  run_config_with_memory = replace(
533
- run_config_with_memory,
534
- memory=memory_config,
535
- conversation_id=conversation_id
574
+ run_config_with_memory, memory=memory_config, conversation_id=conversation_id
536
575
  )
537
576
 
538
577
  if request.max_turns is not None:
@@ -540,44 +579,56 @@ def create_jaf_server(config: ServerConfig[Ctx]) -> FastAPI:
540
579
 
541
580
  # Handle streaming vs non-streaming (matching TypeScript)
542
581
  if request.stream:
582
+
543
583
  async def event_stream():
544
584
  try:
545
585
  # Send initial metadata
546
- yield f"""event: stream_start data: {json.dumps({
547
- 'runId': str(initial_state.run_id),
548
- 'traceId': str(initial_state.trace_id),
549
- 'conversationId': conversation_id,
550
- 'agent': request.agent_name
551
- })}"""
552
-
586
+ yield f"""event: stream_start data: {
587
+ json.dumps(
588
+ {
589
+ "runId": str(initial_state.run_id),
590
+ "traceId": str(initial_state.trace_id),
591
+ "conversationId": conversation_id,
592
+ "agent": request.agent_name,
593
+ }
594
+ )
595
+ }"""
596
+
553
597
  # Stream events from the engine
554
598
  async for event in run_streaming(initial_state, run_config_with_memory):
555
599
  yield f"event: {event.type}\ndata: {json.dumps(asdict(event))}\n\n"
556
-
600
+
557
601
  # Check for run end and handle approval broadcasts
558
- if event.type == 'complete' and hasattr(event, 'data'):
559
- outcome = getattr(event.data, 'outcome', None)
560
- if outcome and getattr(outcome, 'status', None) == 'interrupted':
561
- interruptions = getattr(outcome, 'interruptions', [])
602
+ if event.type == "complete" and hasattr(event, "data"):
603
+ outcome = getattr(event.data, "outcome", None)
604
+ if outcome and getattr(outcome, "status", None) == "interrupted":
605
+ interruptions = getattr(outcome, "interruptions", [])
562
606
  for intr in interruptions:
563
- if getattr(intr, 'type', None) == 'tool_approval':
564
- tool_call = getattr(intr, 'tool_call', None)
607
+ if getattr(intr, "type", None) == "tool_approval":
608
+ tool_call = getattr(intr, "tool_call", None)
565
609
  if tool_call:
566
- broadcast_approval_required({
567
- 'conversationId': conversation_id,
568
- 'sessionId': getattr(intr, 'session_id', None) or str(initial_state.run_id),
569
- 'toolCallId': tool_call.id,
570
- 'toolName': tool_call.function.name,
571
- 'args': try_parse_json(tool_call.function.arguments),
572
- 'signature': compute_tool_call_signature(tool_call)
573
- })
610
+ broadcast_approval_required(
611
+ {
612
+ "conversationId": conversation_id,
613
+ "sessionId": getattr(intr, "session_id", None)
614
+ or str(initial_state.run_id),
615
+ "toolCallId": tool_call.id,
616
+ "toolName": tool_call.function.name,
617
+ "args": try_parse_json(
618
+ tool_call.function.arguments
619
+ ),
620
+ "signature": compute_tool_call_signature(
621
+ tool_call
622
+ ),
623
+ }
624
+ )
574
625
  break
575
-
626
+
576
627
  except Exception as e:
577
628
  yield f"event: error\ndata: {json.dumps({'message': str(e)})}\n\n"
578
629
  finally:
579
630
  yield f"event: stream_end\ndata: {json.dumps({'ended': True})}\n\n"
580
-
631
+
581
632
  return StreamingResponse(
582
633
  event_stream(),
583
634
  media_type="text/event-stream",
@@ -586,8 +637,8 @@ def create_jaf_server(config: ServerConfig[Ctx]) -> FastAPI:
586
637
  "Connection": "keep-alive",
587
638
  "X-Accel-Buffering": "no",
588
639
  "Access-Control-Allow-Origin": "*",
589
- "Access-Control-Allow-Headers": "*"
590
- }
640
+ "Access-Control-Allow-Headers": "*",
641
+ },
591
642
  )
592
643
 
593
644
  # Non-streaming execution
@@ -597,53 +648,48 @@ def create_jaf_server(config: ServerConfig[Ctx]) -> FastAPI:
597
648
 
598
649
  # Create proper outcome object
599
650
  if isinstance(result.outcome, CompletedOutcome):
600
- outcome_data = BaseOutcomeData(
601
- status='completed',
602
- output=result.outcome.output
603
- )
651
+ outcome_data = BaseOutcomeData(status="completed", output=result.outcome.output)
604
652
  elif isinstance(result.outcome, ErrorOutcome):
605
653
  error_info = result.outcome.error
606
654
  outcome_data = BaseOutcomeData(
607
- status='error',
608
- error={
609
- 'type': error_info.__class__.__name__,
610
- 'message': str(error_info)
611
- }
655
+ status="error",
656
+ error={"type": error_info.__class__.__name__, "message": str(error_info)},
612
657
  )
613
658
  elif isinstance(result.outcome, InterruptedOutcome):
614
659
  # Convert interruptions to response format
615
660
  interruptions = []
616
661
  for interruption in result.outcome.interruptions:
617
- if hasattr(interruption, 'tool_call') and hasattr(interruption, 'type'):
662
+ if hasattr(interruption, "tool_call") and hasattr(interruption, "type"):
618
663
  tool_call_data = ToolCallInterruption(
619
664
  id=interruption.tool_call.id,
620
665
  function={
621
- 'name': interruption.tool_call.function.name,
622
- 'arguments': interruption.tool_call.function.arguments
623
- }
666
+ "name": interruption.tool_call.function.name,
667
+ "arguments": interruption.tool_call.function.arguments,
668
+ },
669
+ )
670
+ interruptions.append(
671
+ InterruptionData(
672
+ type="tool_approval",
673
+ tool_call=tool_call_data,
674
+ session_id=interruption.session_id or str(result.final_state.run_id),
675
+ )
624
676
  )
625
- interruptions.append(InterruptionData(
626
- type='tool_approval',
627
- tool_call=tool_call_data,
628
- session_id=interruption.session_id or str(result.final_state.run_id)
629
- ))
630
-
677
+
631
678
  # Broadcast approval request via SSE
632
- broadcast_approval_required({
633
- 'conversationId': conversation_id,
634
- 'sessionId': interruption.session_id or str(result.final_state.run_id),
635
- 'toolCallId': interruption.tool_call.id,
636
- 'toolName': interruption.tool_call.function.name,
637
- 'args': try_parse_json(interruption.tool_call.function.arguments),
638
- 'signature': compute_tool_call_signature(interruption.tool_call)
639
- })
640
-
641
- outcome_data = InterruptedOutcomeData(
642
- status='interrupted',
643
- interruptions=interruptions
644
- )
679
+ broadcast_approval_required(
680
+ {
681
+ "conversationId": conversation_id,
682
+ "sessionId": interruption.session_id or str(result.final_state.run_id),
683
+ "toolCallId": interruption.tool_call.id,
684
+ "toolName": interruption.tool_call.function.name,
685
+ "args": try_parse_json(interruption.tool_call.function.arguments),
686
+ "signature": compute_tool_call_signature(interruption.tool_call),
687
+ }
688
+ )
689
+
690
+ outcome_data = InterruptedOutcomeData(status="interrupted", interruptions=interruptions)
645
691
  else:
646
- outcome_data = BaseOutcomeData(status='error', error='Unknown outcome type')
692
+ outcome_data = BaseOutcomeData(status="error", error="Unknown outcome type")
647
693
 
648
694
  return ChatResponse(
649
695
  success=True,
@@ -653,19 +699,22 @@ def create_jaf_server(config: ServerConfig[Ctx]) -> FastAPI:
653
699
  messages=http_messages,
654
700
  outcome=outcome_data,
655
701
  turn_count=result.final_state.turn_count,
656
- execution_time_ms=int((time.time() - request_start_time) * 1000), # Use request start time
657
- conversation_id=conversation_id
658
- )
702
+ execution_time_ms=int(
703
+ (time.time() - request_start_time) * 1000
704
+ ), # Use request start time
705
+ conversation_id=conversation_id,
706
+ ),
659
707
  )
660
708
 
661
709
  # Memory endpoints
662
710
  if config.default_memory_provider:
711
+
663
712
  @app.get("/conversations/{conversation_id}", response_model=ConversationResponse)
664
713
  async def get_conversation(conversation_id: str):
665
714
  result = await config.default_memory_provider.get_conversation(conversation_id)
666
715
 
667
716
  # Handle Result type properly
668
- if hasattr(result, 'error'): # Failure case
717
+ if hasattr(result, "error"): # Failure case
669
718
  raise HTTPException(status_code=500, detail=str(result.error.message))
670
719
 
671
720
  conversation = result.data
@@ -677,7 +726,7 @@ def create_jaf_server(config: ServerConfig[Ctx]) -> FastAPI:
677
726
  conversation_id=conversation.conversation_id,
678
727
  user_id=conversation.user_id,
679
728
  messages=[asdict(msg) for msg in conversation.messages],
680
- metadata=conversation.metadata
729
+ metadata=conversation.metadata,
681
730
  )
682
731
 
683
732
  return ConversationResponse(success=True, data=conversation_data)
@@ -687,12 +736,12 @@ def create_jaf_server(config: ServerConfig[Ctx]) -> FastAPI:
687
736
  result = await config.default_memory_provider.delete_conversation(conversation_id)
688
737
 
689
738
  # Handle Result type properly
690
- if hasattr(result, 'error'): # Failure case
739
+ if hasattr(result, "error"): # Failure case
691
740
  raise HTTPException(status_code=500, detail=str(result.error.message))
692
741
 
693
742
  return DeleteConversationResponse(
694
743
  success=True,
695
- data=DeleteConversationData(conversation_id=conversation_id, deleted=result.data)
744
+ data=DeleteConversationData(conversation_id=conversation_id, deleted=result.data),
696
745
  )
697
746
 
698
747
  @app.get("/memory/health", response_model=MemoryHealthResponse)
@@ -700,7 +749,7 @@ def create_jaf_server(config: ServerConfig[Ctx]) -> FastAPI:
700
749
  result = await config.default_memory_provider.health_check()
701
750
 
702
751
  # Handle Result type properly
703
- if hasattr(result, 'error'): # Failure case
752
+ if hasattr(result, "error"): # Failure case
704
753
  raise HTTPException(status_code=500, detail=str(result.error.message))
705
754
 
706
755
  return MemoryHealthResponse(success=True, data=result.data)
@@ -712,87 +761,92 @@ def create_jaf_server(config: ServerConfig[Ctx]) -> FastAPI:
712
761
  try:
713
762
  if not conversation_id:
714
763
  raise HTTPException(status_code=400, detail="conversation_id is required")
715
-
764
+
716
765
  if not config.default_memory_provider:
717
766
  return PendingApprovalsResponse(
718
- success=False,
719
- error="Memory provider not configured"
767
+ success=False, error="Memory provider not configured"
720
768
  )
721
-
769
+
722
770
  # Get conversation to analyze pending approvals
723
771
  conv_result = await config.default_memory_provider.get_conversation(conversation_id)
724
- if hasattr(conv_result, 'error'):
772
+ if hasattr(conv_result, "error"):
725
773
  return PendingApprovalsResponse(success=False, error=str(conv_result.error.message))
726
-
774
+
727
775
  if not conv_result.data:
728
- return PendingApprovalsResponse(
729
- success=True,
730
- data=PendingApprovalsData(pending=[])
731
- )
732
-
776
+ return PendingApprovalsResponse(success=True, data=PendingApprovalsData(pending=[]))
777
+
733
778
  conversation = conv_result.data
734
779
  messages = conversation.messages
735
- approvals_meta = getattr(conversation.metadata, 'tool_approvals', {}) if conversation.metadata else {}
736
-
780
+ approvals_meta = (
781
+ getattr(conversation.metadata, "tool_approvals", {})
782
+ if conversation.metadata
783
+ else {}
784
+ )
785
+
737
786
  # Find most recent assistant message with tool calls
738
787
  assistant_msg = None
739
788
  assistant_index = -1
740
789
  for i in range(len(messages) - 1, -1, -1):
741
790
  msg = messages[i]
742
- if hasattr(msg, 'role') and msg.role == 'assistant' and hasattr(msg, 'tool_calls') and msg.tool_calls:
791
+ if (
792
+ hasattr(msg, "role")
793
+ and msg.role == "assistant"
794
+ and hasattr(msg, "tool_calls")
795
+ and msg.tool_calls
796
+ ):
743
797
  assistant_msg = msg
744
798
  assistant_index = i
745
799
  break
746
-
800
+
747
801
  if not assistant_msg:
748
- return PendingApprovalsResponse(
749
- success=True,
750
- data=PendingApprovalsData(pending=[])
751
- )
752
-
802
+ return PendingApprovalsResponse(success=True, data=PendingApprovalsData(pending=[]))
803
+
753
804
  # Check which tool calls have already been executed
754
805
  tool_ids = {tc.id for tc in assistant_msg.tool_calls}
755
806
  executed = set()
756
807
  for j in range(assistant_index + 1, len(messages)):
757
808
  msg = messages[j]
758
- if hasattr(msg, 'role') and msg.role == 'tool' and hasattr(msg, 'tool_call_id'):
809
+ if hasattr(msg, "role") and msg.role == "tool" and hasattr(msg, "tool_call_id"):
759
810
  if msg.tool_call_id in tool_ids:
760
811
  executed.add(msg.tool_call_id)
761
-
812
+
762
813
  # Build pending approvals list
763
814
  pending_approvals = []
764
815
  for tc in assistant_msg.tool_calls:
765
816
  if tc.id in executed:
766
817
  continue # Already executed
767
-
818
+
768
819
  # Check approval status
769
820
  approval_key = f"{conversation.conversation_id}:{tc.id}"
770
821
  approval_entry = approvals_meta.get(approval_key)
771
-
772
- status = 'pending'
822
+
823
+ status = "pending"
773
824
  if approval_entry:
774
- status = approval_entry.get('status', 'pending')
775
- if approval_entry.get('approved') is True:
776
- status = 'approved'
777
- elif approval_entry.get('approved') is False:
778
- status = 'rejected'
779
-
780
- if status == 'pending':
781
- pending_approvals.append(PendingApprovalData(
782
- conversation_id=conversation_id,
783
- tool_call_id=tc.id,
784
- tool_name=tc.function.name,
785
- args=try_parse_json(tc.function.arguments),
786
- signature=compute_tool_call_signature(tc),
787
- status='pending',
788
- session_id=getattr(conversation.metadata, 'run_id', None) if conversation.metadata else None
789
- ))
790
-
825
+ status = approval_entry.get("status", "pending")
826
+ if approval_entry.get("approved") is True:
827
+ status = "approved"
828
+ elif approval_entry.get("approved") is False:
829
+ status = "rejected"
830
+
831
+ if status == "pending":
832
+ pending_approvals.append(
833
+ PendingApprovalData(
834
+ conversation_id=conversation_id,
835
+ tool_call_id=tc.id,
836
+ tool_name=tc.function.name,
837
+ args=try_parse_json(tc.function.arguments),
838
+ signature=compute_tool_call_signature(tc),
839
+ status="pending",
840
+ session_id=getattr(conversation.metadata, "run_id", None)
841
+ if conversation.metadata
842
+ else None,
843
+ )
844
+ )
845
+
791
846
  return PendingApprovalsResponse(
792
- success=True,
793
- data=PendingApprovalsData(pending=pending_approvals)
847
+ success=True, data=PendingApprovalsData(pending=pending_approvals)
794
848
  )
795
-
849
+
796
850
  except Exception as e:
797
851
  return PendingApprovalsResponse(success=False, error=str(e))
798
852
 
@@ -809,9 +863,9 @@ def create_jaf_server(config: ServerConfig[Ctx]) -> FastAPI:
809
863
  stream=request_body.stream,
810
864
  conversation_id=request_body.conversation_id,
811
865
  memory=request_body.memory,
812
- approvals=request_body.approvals
866
+ approvals=request_body.approvals,
813
867
  )
814
-
868
+
815
869
  # Delegate to main chat endpoint logic
816
870
  return await chat_completion(modified_request)
817
871
 
@@ -819,39 +873,40 @@ def create_jaf_server(config: ServerConfig[Ctx]) -> FastAPI:
819
873
  @app.get("/approvals/stream")
820
874
  async def stream_approval_updates(request: Request, conversation_id: str = None):
821
875
  """Stream real-time approval updates via Server-Sent Events."""
876
+
822
877
  async def event_stream():
823
878
  # Simple client structure matching TypeScript
824
879
  client = {
825
- 'response': request, # Store request for disconnection check
826
- 'filter_conversation_id': conversation_id
880
+ "response": request, # Store request for disconnection check
881
+ "filter_conversation_id": conversation_id,
827
882
  }
828
883
  approval_subscribers.add(client)
829
-
884
+
830
885
  try:
831
886
  # Initial greeting (matching TypeScript)
832
887
  yield f"event: stream_start\ndata: {json.dumps({'conversationId': conversation_id})}\n\n"
833
-
888
+
834
889
  # Heartbeat like TypeScript (15 second interval)
835
890
  last_heartbeat = time.time()
836
-
891
+
837
892
  while True:
838
893
  # Check client disconnection
839
894
  if await request.is_disconnected():
840
895
  break
841
-
896
+
842
897
  # Send heartbeat every 15 seconds
843
898
  current_time = time.time()
844
899
  if current_time - last_heartbeat >= 15:
845
900
  yield f"event: ping\ndata: {json.dumps({'ts': int(current_time * 1000)})}\n\n"
846
901
  last_heartbeat = current_time
847
-
902
+
848
903
  await asyncio.sleep(1)
849
-
904
+
850
905
  except Exception as e:
851
906
  yield f"event: error\ndata: {json.dumps({'message': str(e)})}\n\n"
852
907
  finally:
853
908
  approval_subscribers.discard(client)
854
-
909
+
855
910
  return StreamingResponse(
856
911
  event_stream(),
857
912
  media_type="text/event-stream",
@@ -860,111 +915,119 @@ def create_jaf_server(config: ServerConfig[Ctx]) -> FastAPI:
860
915
  "Connection": "keep-alive",
861
916
  "X-Accel-Buffering": "no",
862
917
  "Access-Control-Allow-Origin": "*",
863
- "Access-Control-Allow-Headers": "*"
864
- }
918
+ "Access-Control-Allow-Headers": "*",
919
+ },
865
920
  )
866
921
 
867
922
  # Regeneration endpoints
868
923
  if config.default_memory_provider:
869
- @app.post("/conversations/{conversation_id}/regenerate", response_model=RegenerationResponse)
870
- async def regenerate_conversation_endpoint(conversation_id: str, request: RegenerationHttpRequest):
924
+
925
+ @app.post(
926
+ "/conversations/{conversation_id}/regenerate", response_model=RegenerationResponse
927
+ )
928
+ async def regenerate_conversation_endpoint(
929
+ conversation_id: str, request: RegenerationHttpRequest
930
+ ):
871
931
  """Regenerate conversation from a specific message."""
872
932
  request_start_time = time.time()
873
-
933
+
874
934
  try:
875
935
  # Validate agent exists
876
936
  if request.agent_name not in config.agent_registry:
877
937
  return RegenerationResponse(
878
938
  success=False,
879
- error=f"Agent '{request.agent_name}' not found. Available agents: {', '.join(config.agent_registry.keys())}"
939
+ error=f"Agent '{request.agent_name}' not found. Available agents: {', '.join(config.agent_registry.keys())}",
880
940
  )
881
-
941
+
882
942
  # Create regeneration request
883
943
  regen_request = RegenerationRequest(
884
944
  conversation_id=conversation_id,
885
945
  message_id=create_message_id(request.message_id),
886
- context=request.context
946
+ context=request.context,
887
947
  )
888
-
948
+
889
949
  # Create run config with memory
890
950
  memory_config = MemoryConfig(
891
951
  provider=config.default_memory_provider,
892
952
  auto_store=True,
893
- store_on_completion=True
953
+ store_on_completion=True,
894
954
  )
895
-
955
+
896
956
  run_config_with_memory = replace(
897
957
  config.run_config,
898
958
  memory=memory_config,
899
959
  conversation_id=conversation_id,
900
- max_turns=request.max_turns or 10
960
+ max_turns=request.max_turns or 10,
901
961
  )
902
-
962
+
903
963
  # Execute regeneration
904
964
  result = await regenerate_conversation(
905
- regen_request,
906
- run_config_with_memory,
907
- request.context or {},
908
- request.agent_name
965
+ regen_request, run_config_with_memory, request.context or {}, request.agent_name
909
966
  )
910
-
967
+
911
968
  # Convert result to HTTP format
912
- http_messages = [_convert_core_message_to_http(msg) for msg in result.final_state.messages]
913
-
969
+ http_messages = [
970
+ _convert_core_message_to_http(msg) for msg in result.final_state.messages
971
+ ]
972
+
914
973
  # Create outcome data
915
974
  if isinstance(result.outcome, CompletedOutcome):
916
- outcome_data = BaseOutcomeData(
917
- status='completed',
918
- output=result.outcome.output
919
- )
975
+ outcome_data = BaseOutcomeData(status="completed", output=result.outcome.output)
920
976
  elif isinstance(result.outcome, ErrorOutcome):
921
977
  error_info = result.outcome.error
922
978
  outcome_data = BaseOutcomeData(
923
- status='error',
924
- error={
925
- 'type': error_info.__class__.__name__,
926
- 'message': str(error_info)
927
- }
979
+ status="error",
980
+ error={"type": error_info.__class__.__name__, "message": str(error_info)},
928
981
  )
929
982
  elif isinstance(result.outcome, InterruptedOutcome):
930
983
  interruptions = []
931
984
  for interruption in result.outcome.interruptions:
932
- if hasattr(interruption, 'tool_call') and hasattr(interruption, 'type'):
985
+ if hasattr(interruption, "tool_call") and hasattr(interruption, "type"):
933
986
  tool_call_data = ToolCallInterruption(
934
987
  id=interruption.tool_call.id,
935
988
  function={
936
- 'name': interruption.tool_call.function.name,
937
- 'arguments': interruption.tool_call.function.arguments
938
- }
989
+ "name": interruption.tool_call.function.name,
990
+ "arguments": interruption.tool_call.function.arguments,
991
+ },
992
+ )
993
+ interruptions.append(
994
+ InterruptionData(
995
+ type="tool_approval",
996
+ tool_call=tool_call_data,
997
+ session_id=interruption.session_id
998
+ or str(result.final_state.run_id),
999
+ )
939
1000
  )
940
- interruptions.append(InterruptionData(
941
- type='tool_approval',
942
- tool_call=tool_call_data,
943
- session_id=interruption.session_id or str(result.final_state.run_id)
944
- ))
945
-
1001
+
946
1002
  outcome_data = InterruptedOutcomeData(
947
- status='interrupted',
948
- interruptions=interruptions
1003
+ status="interrupted", interruptions=interruptions
949
1004
  )
950
1005
  else:
951
- outcome_data = BaseOutcomeData(status='error', error='Unknown outcome type')
952
-
1006
+ outcome_data = BaseOutcomeData(status="error", error="Unknown outcome type")
1007
+
953
1008
  # Get regeneration metadata from conversation
954
- conversation_result = await config.default_memory_provider.get_conversation(conversation_id)
1009
+ conversation_result = await config.default_memory_provider.get_conversation(
1010
+ conversation_id
1011
+ )
955
1012
  regeneration_id = f"regen_{int(time.time() * 1000)}_{request.message_id}"
956
1013
  original_message_count = 0
957
1014
  truncated_at_index = 0
958
-
959
- if hasattr(conversation_result, 'data') and conversation_result.data:
1015
+
1016
+ if hasattr(conversation_result, "data") and conversation_result.data:
960
1017
  conversation_data = conversation_result.data
961
- regeneration_points = conversation_data.metadata.get('regeneration_points', []) if conversation_data.metadata else []
1018
+ regeneration_points = (
1019
+ conversation_data.metadata.get("regeneration_points", [])
1020
+ if conversation_data.metadata
1021
+ else []
1022
+ )
962
1023
  if regeneration_points:
963
1024
  latest_regen = regeneration_points[-1]
964
- original_message_count = latest_regen.get('original_message_count', len(conversation_data.messages))
965
- truncated_at_index = latest_regen.get('truncated_at_index', 0)
966
- regeneration_id = latest_regen.get('regeneration_id', regeneration_id)
967
-
1025
+ original_message_count = latest_regen.get(
1026
+ "original_message_count", len(conversation_data.messages)
1027
+ )
1028
+ truncated_at_index = latest_regen.get("truncated_at_index", 0)
1029
+ regeneration_id = latest_regen.get("regeneration_id", regeneration_id)
1030
+
968
1031
  return RegenerationResponse(
969
1032
  success=True,
970
1033
  data=RegenerationData(
@@ -976,45 +1039,137 @@ def create_jaf_server(config: ServerConfig[Ctx]) -> FastAPI:
976
1039
  messages=http_messages,
977
1040
  outcome=outcome_data,
978
1041
  turn_count=result.final_state.turn_count,
979
- execution_time_ms=int((time.time() - request_start_time) * 1000)
980
- )
1042
+ execution_time_ms=int((time.time() - request_start_time) * 1000),
1043
+ ),
981
1044
  )
982
-
1045
+
983
1046
  except Exception as e:
984
1047
  return RegenerationResponse(success=False, error=str(e))
985
1048
 
986
- @app.get("/conversations/{conversation_id}/regeneration-history", response_model=RegenerationHistoryResponse)
1049
+ @app.get(
1050
+ "/conversations/{conversation_id}/regeneration-history",
1051
+ response_model=RegenerationHistoryResponse,
1052
+ )
987
1053
  async def get_regeneration_history(conversation_id: str):
988
1054
  """Get regeneration history for a conversation."""
989
1055
  try:
990
- regeneration_points = await get_regeneration_points(conversation_id, config.run_config)
991
-
1056
+ regeneration_points = await get_regeneration_points(
1057
+ conversation_id, config.run_config
1058
+ )
1059
+
992
1060
  if regeneration_points is None:
993
1061
  return RegenerationHistoryResponse(
994
- success=False,
995
- error="Failed to get regeneration history"
1062
+ success=False, error="Failed to get regeneration history"
996
1063
  )
997
-
1064
+
998
1065
  # Convert to response format
999
1066
  regeneration_data = []
1000
1067
  for point in regeneration_points:
1001
- regeneration_data.append(RegenerationPointData(
1002
- regeneration_id=point.get('regeneration_id', ''),
1003
- message_id=point.get('message_id', ''),
1004
- timestamp=point.get('timestamp', 0),
1005
- original_message_count=point.get('original_message_count', 0),
1006
- truncated_at_index=point.get('truncated_at_index', 0)
1007
- ))
1008
-
1068
+ regeneration_data.append(
1069
+ RegenerationPointData(
1070
+ regeneration_id=point.get("regeneration_id", ""),
1071
+ message_id=point.get("message_id", ""),
1072
+ timestamp=point.get("timestamp", 0),
1073
+ original_message_count=point.get("original_message_count", 0),
1074
+ truncated_at_index=point.get("truncated_at_index", 0),
1075
+ )
1076
+ )
1077
+
1009
1078
  return RegenerationHistoryResponse(
1010
1079
  success=True,
1011
1080
  data=RegenerationHistoryData(
1012
- conversation_id=conversation_id,
1013
- regeneration_points=regeneration_data
1014
- )
1081
+ conversation_id=conversation_id, regeneration_points=regeneration_data
1082
+ ),
1015
1083
  )
1016
-
1084
+
1017
1085
  except Exception as e:
1018
1086
  return RegenerationHistoryResponse(success=False, error=str(e))
1019
1087
 
1088
+ # Checkpoint endpoints
1089
+ @app.post("/conversations/{conversation_id}/checkpoint", response_model=CheckpointResponse)
1090
+ async def checkpoint_conversation_endpoint(
1091
+ conversation_id: str, request: CheckpointHttpRequest
1092
+ ):
1093
+ """Checkpoint conversation after a specific message."""
1094
+ request_start_time = time.time()
1095
+
1096
+ try:
1097
+ # Create checkpoint request
1098
+ chk_request = CheckpointRequest(
1099
+ conversation_id=conversation_id,
1100
+ message_id=create_message_id(request.message_id),
1101
+ context=request.context,
1102
+ )
1103
+
1104
+ # Create run config with memory
1105
+ memory_config = MemoryConfig(
1106
+ provider=config.default_memory_provider,
1107
+ auto_store=True,
1108
+ store_on_completion=True,
1109
+ )
1110
+
1111
+ run_config_with_memory = replace(
1112
+ config.run_config, memory=memory_config, conversation_id=conversation_id
1113
+ )
1114
+
1115
+ # Execute checkpoint
1116
+ result = await checkpoint_conversation(chk_request, run_config_with_memory)
1117
+
1118
+ # Convert result to HTTP format
1119
+ http_messages = [_convert_core_message_to_http(msg) for msg in result.messages]
1120
+
1121
+ return CheckpointResponse(
1122
+ success=True,
1123
+ data=CheckpointData(
1124
+ checkpoint_id=result.checkpoint_id,
1125
+ conversation_id=result.conversation_id,
1126
+ original_message_count=result.original_message_count,
1127
+ checkpointed_at_index=result.checkpointed_at_index,
1128
+ checkpointed_message_id=str(result.checkpointed_message_id),
1129
+ messages=http_messages,
1130
+ execution_time_ms=result.execution_time_ms,
1131
+ ),
1132
+ )
1133
+
1134
+ except Exception as e:
1135
+ return CheckpointResponse(success=False, error=str(e))
1136
+
1137
+ @app.get(
1138
+ "/conversations/{conversation_id}/checkpoint-history",
1139
+ response_model=CheckpointHistoryResponse,
1140
+ )
1141
+ async def get_checkpoint_history_endpoint(conversation_id: str):
1142
+ """Get checkpoint history for a conversation."""
1143
+ try:
1144
+ checkpoint_points = await get_checkpoint_history(conversation_id, config.run_config)
1145
+
1146
+ if checkpoint_points is None:
1147
+ return CheckpointHistoryResponse(
1148
+ success=False, error="Failed to get checkpoint history"
1149
+ )
1150
+
1151
+ # Convert to response format
1152
+ checkpoint_data = []
1153
+ for point in checkpoint_points:
1154
+ checkpoint_data.append(
1155
+ CheckpointPointData(
1156
+ checkpoint_id=point.get("checkpoint_id", ""),
1157
+ checkpoint_point=point.get("checkpoint_point", ""),
1158
+ timestamp=point.get("timestamp", 0),
1159
+ original_message_count=point.get("original_message_count", 0),
1160
+ checkpointed_at_index=point.get("checkpointed_at_index", 0),
1161
+ checkpointed_messages=point.get("checkpointed_messages", 0),
1162
+ )
1163
+ )
1164
+
1165
+ return CheckpointHistoryResponse(
1166
+ success=True,
1167
+ data=CheckpointHistoryData(
1168
+ conversation_id=conversation_id, checkpoint_points=checkpoint_data
1169
+ ),
1170
+ )
1171
+
1172
+ except Exception as e:
1173
+ return CheckpointHistoryResponse(success=False, error=str(e))
1174
+
1020
1175
  return app