jaf-py 2.4.1__py3-none-any.whl → 2.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaf/__init__.py +15 -0
- jaf/core/agent_tool.py +6 -4
- jaf/core/analytics.py +4 -3
- jaf/core/engine.py +401 -37
- jaf/core/state.py +156 -0
- jaf/core/tracing.py +114 -23
- jaf/core/types.py +113 -3
- jaf/memory/approval_storage.py +306 -0
- jaf/memory/types.py +1 -0
- jaf/memory/utils.py +1 -1
- jaf/providers/model.py +277 -17
- jaf/server/__init__.py +2 -0
- jaf/server/server.py +665 -22
- jaf/server/types.py +149 -4
- jaf/utils/__init__.py +50 -0
- jaf/utils/attachments.py +401 -0
- jaf/utils/document_processor.py +561 -0
- {jaf_py-2.4.1.dist-info → jaf_py-2.4.2.dist-info}/METADATA +10 -2
- {jaf_py-2.4.1.dist-info → jaf_py-2.4.2.dist-info}/RECORD +23 -18
- {jaf_py-2.4.1.dist-info → jaf_py-2.4.2.dist-info}/WHEEL +0 -0
- {jaf_py-2.4.1.dist-info → jaf_py-2.4.2.dist-info}/entry_points.txt +0 -0
- {jaf_py-2.4.1.dist-info → jaf_py-2.4.2.dist-info}/licenses/LICENSE +0 -0
- {jaf_py-2.4.1.dist-info → jaf_py-2.4.2.dist-info}/top_level.txt +0 -0
jaf/server/server.py
CHANGED
|
@@ -7,17 +7,25 @@ via REST API endpoints with proper error handling and validation.
|
|
|
7
7
|
|
|
8
8
|
import time
|
|
9
9
|
import uuid
|
|
10
|
+
import asyncio
|
|
11
|
+
import json
|
|
10
12
|
from dataclasses import asdict, replace
|
|
11
|
-
from typing import TypeVar
|
|
13
|
+
from typing import TypeVar, Dict, Set
|
|
12
14
|
|
|
13
|
-
from fastapi import FastAPI, HTTPException
|
|
15
|
+
from fastapi import FastAPI, HTTPException, Request
|
|
14
16
|
from fastapi.middleware.cors import CORSMiddleware
|
|
17
|
+
from fastapi.responses import StreamingResponse
|
|
15
18
|
|
|
16
19
|
from ..core.engine import run
|
|
20
|
+
from ..core.streaming import run_streaming
|
|
17
21
|
from ..core.types import (
|
|
22
|
+
ApprovalValue,
|
|
18
23
|
CompletedOutcome,
|
|
19
24
|
ErrorOutcome,
|
|
25
|
+
InterruptedOutcome,
|
|
20
26
|
Message,
|
|
27
|
+
MessageContentPart,
|
|
28
|
+
Attachment,
|
|
21
29
|
RunState,
|
|
22
30
|
create_run_id,
|
|
23
31
|
create_trace_id,
|
|
@@ -27,6 +35,8 @@ from .types import (
|
|
|
27
35
|
AgentInfo,
|
|
28
36
|
AgentListData,
|
|
29
37
|
AgentListResponse,
|
|
38
|
+
ApprovalMessage,
|
|
39
|
+
BaseOutcomeData,
|
|
30
40
|
ChatRequest,
|
|
31
41
|
ChatResponse,
|
|
32
42
|
CompletedChatData,
|
|
@@ -36,20 +46,215 @@ from .types import (
|
|
|
36
46
|
DeleteConversationResponse,
|
|
37
47
|
HealthResponse,
|
|
38
48
|
HttpMessage,
|
|
49
|
+
InterruptedOutcomeData,
|
|
50
|
+
InterruptionData,
|
|
39
51
|
MemoryHealthResponse,
|
|
52
|
+
PendingApprovalData,
|
|
53
|
+
PendingApprovalsData,
|
|
54
|
+
PendingApprovalsResponse,
|
|
40
55
|
ServerConfig,
|
|
56
|
+
ToolCallInterruption,
|
|
41
57
|
)
|
|
42
58
|
|
|
43
59
|
Ctx = TypeVar('Ctx')
|
|
44
60
|
|
|
61
|
+
# Helper functions for HITL (moved outside like TypeScript)
|
|
62
|
+
def stable_stringify(value) -> str:
    """Create a deterministic JSON string for tool call signatures.

    Keys of every dict are sorted at every nesting level. This includes dicts
    nested inside a top-level list. Compact separators are used, so equal
    values always produce identical strings.

    Args:
        value: Any JSON-encodable value (typically decoded tool arguments).

    Returns:
        The canonical JSON text. If the value cannot be JSON-encoded, or a
        dict has keys that cannot be sorted together, ``str(value)`` is
        returned instead.
    """
    try:
        # sort_keys applies recursively and is a no-op for non-dict values, so
        # it is safe to pass unconditionally. Previously it was only passed for
        # a top-level dict, which left dicts inside lists in insertion order.
        return json.dumps(value, sort_keys=True, separators=(',', ':'))
    except (TypeError, ValueError):
        return str(value)
|
|
70
|
+
|
|
71
|
+
def try_parse_json(s: str):
    """Decode ``s`` as JSON, falling back to the input itself.

    Args:
        s: Text that may contain JSON (non-string inputs are tolerated).

    Returns:
        The decoded value when ``s`` is valid JSON. Otherwise ``s`` unchanged,
        including when ``s`` is not a str/bytes object (e.g. None).
    """
    try:
        decoded = json.loads(s)
    except (TypeError, json.JSONDecodeError):
        return s
    return decoded
|
|
77
|
+
|
|
78
|
+
def compute_tool_call_signature(tool_call) -> str:
    """Build a deterministic ``"<tool name>:<canonical args>"`` key for a tool call.

    The call's JSON arguments are decoded when possible and re-encoded with
    ``stable_stringify``, so argument key order does not affect the result.
    If producing the canonical arguments fails, the argument part is the
    literal ``"unknown"``.

    Args:
        tool_call: Object exposing ``function.name`` and ``function.arguments``.
    """
    try:
        canonical_args = stable_stringify(try_parse_json(tool_call.function.arguments))
    except Exception:
        canonical_args = 'unknown'
    return f"{tool_call.function.name}:{canonical_args}"
|
|
85
|
+
|
|
86
|
+
def _convert_http_message_to_core(http_msg: HttpMessage) -> Message:
    """Translate an HTTP-layer message into the core ``Message`` model.

    String content is passed through unchanged. List content is converted part
    by part into ``MessageContentPart`` objects. Only the field matching the
    part's type (``text``, ``image_url`` or ``file``) is copied; the others are
    set to None. Attachments are copied field by field into core
    ``Attachment`` objects. An empty or missing attachment list becomes None.

    Raises:
        ValueError: If a content part has a type other than 'text',
            'image_url' or 'file'.
    """
    supported_types = ('text', 'image_url', 'file')

    def to_core_part(index, part):
        # Resolve the part type to one of the canonical literals (or None).
        part_type = next((name for name in supported_types if part.type == name), None)
        if part_type is None:
            raise ValueError(f"Unrecognized message content part type: '{part.type}' at index {index}. "
                             f"Supported types are: 'text', 'image_url', 'file'")
        return MessageContentPart(
            type=part_type,
            text=part.text if part_type == 'text' else None,
            image_url=part.image_url if part_type == 'image_url' else None,
            file=part.file if part_type == 'file' else None,
        )

    source_content = http_msg.content
    if isinstance(source_content, str):
        content = source_content
    else:
        content = [to_core_part(index, part) for index, part in enumerate(source_content)]

    source_attachments = http_msg.attachments
    attachments = (
        [
            Attachment(
                kind=att.kind,
                mime_type=att.mime_type,
                name=att.name,
                url=att.url,
                data=att.data,
                format=att.format,
                use_litellm_format=att.use_litellm_format,
            )
            for att in source_attachments
        ]
        if source_attachments
        else None
    )

    return Message(
        role=http_msg.role,
        content=content,
        attachments=attachments,
        tool_call_id=http_msg.tool_call_id,
        tool_calls=http_msg.tool_calls,
    )
|
|
145
|
+
|
|
146
|
+
def _convert_core_message_to_http(core_msg: Message) -> HttpMessage:
    """Translate a core ``Message`` into the HTTP-layer ``HttpMessage`` model.

    Content handling:
    - String content is passed through unchanged.
    - List content is converted part by part into ``HttpMessageContentPart``
      objects, copying only the field that matches each part's type.
    - Any other content is flattened with ``get_text_content``.

    Attachments are copied field by field into ``HttpAttachment`` objects. An
    empty or missing attachment list becomes None.

    Raises:
        ValueError: If a content part has a type other than 'text',
            'image_url' or 'file'. The message includes the message role.
    """
    from .types import HttpAttachment, HttpMessageContentPart
    from ..core.types import get_text_content

    supported_types = ('text', 'image_url', 'file')

    def to_http_part(index, part):
        # Resolve the part type to one of the canonical literals (or None).
        part_type = next((name for name in supported_types if part.type == name), None)
        if part_type is None:
            message_info = f"role={core_msg.role}"
            raise ValueError(f"Unrecognized core message content part type: '{part.type}' at index {index}. "
                             f"Message info: {message_info}. "
                             f"Supported types are: 'text', 'image_url', 'file'")
        return HttpMessageContentPart(
            type=part_type,
            text=part.text if part_type == 'text' else None,
            image_url=part.image_url if part_type == 'image_url' else None,
            file=part.file if part_type == 'file' else None,
        )

    source_content = core_msg.content
    if isinstance(source_content, str):
        content = source_content
    elif isinstance(source_content, list):
        content = [to_http_part(index, part) for index, part in enumerate(source_content)]
    else:
        content = get_text_content(source_content)

    source_attachments = core_msg.attachments
    attachments = (
        [
            HttpAttachment(
                kind=att.kind,
                mime_type=att.mime_type,
                name=att.name,
                url=att.url,
                data=att.data,
                format=att.format,
                use_litellm_format=att.use_litellm_format,
            )
            for att in source_attachments
        ]
        if source_attachments
        else None
    )

    return HttpMessage(
        role=core_msg.role,
        content=content,
        attachments=attachments,
        tool_call_id=core_msg.tool_call_id,
        tool_calls=core_msg.tool_calls,
    )
|
|
212
|
+
|
|
45
213
|
def create_jaf_server(config: ServerConfig[Ctx]) -> FastAPI:
|
|
46
214
|
"""Create and configure a JAF server instance."""
|
|
47
215
|
|
|
48
216
|
start_time = time.time()
|
|
217
|
+
|
|
218
|
+
# SSE subscribers for approval-related events (matching TypeScript)
|
|
219
|
+
approval_subscribers = set()
|
|
220
|
+
|
|
221
|
+
def sse_send(response, event: str, data: dict):
    """Write one Server-Sent Event frame to ``response``.

    The payload is serialized before anything is written. The whole frame
    (``event:`` line, ``data:`` line and blank-line terminator) is then
    written in a single ``write`` call. As a result, a serialization failure
    can no longer leave a dangling ``event:`` line that would corrupt the next
    frame on the stream.

    Errors (unserializable data, failing sink) are ignored. Broadcasters call
    this once per subscriber in a loop, and one bad subscriber must not abort
    delivery to the rest.

    Args:
        response: Sink exposing ``write(str)``.
        event: SSE event name.
        data: JSON-serializable payload.
    """
    try:
        frame = f"event: {event}\ndata: {json.dumps(data)}\n\n"
        response.write(frame)
    except Exception:
        pass  # ignore serialization and connection errors
|
|
228
|
+
|
|
229
|
+
def broadcast_approval_required(payload: dict):
    """Broadcast an ``approval_required`` event to approval SSE subscribers.

    Filtering: subscribers registered with a ``filter_conversation_id`` only
    receive payloads whose ``conversationId`` equals that filter.

    Timestamp: unless the payload already has a ``timestamp`` key, the current
    UTC time is added as an ISO-8601 string with millisecond precision and a
    ``Z`` suffix.

    Args:
        payload: JSON-serializable event body (``conversationId`` is used for
            filtering).
    """
    from datetime import datetime, timezone

    # time.strftime has no documented %f directive and formats *local* time,
    # so the former '%Y-%m-%dT%H:%M:%S.%fZ' pattern did not yield a valid UTC
    # timestamp. Compute it once so every subscriber receives the same value.
    if 'timestamp' in payload:
        timestamp = payload['timestamp']
    else:
        timestamp = datetime.now(timezone.utc).isoformat(timespec='milliseconds').replace('+00:00', 'Z')
    payload_with_timestamp = {**payload, 'timestamp': timestamp}

    for client in approval_subscribers.copy():  # iterate a snapshot of the subscriber set
        filter_conv_id = client.get('filter_conversation_id')
        if filter_conv_id and filter_conv_id != payload.get('conversationId'):
            continue
        sse_send(client['response'], 'approval_required', payload_with_timestamp)
|
|
241
|
+
|
|
242
|
+
def broadcast_approval_decision(payload: dict):
    """Broadcast an ``approval_decision`` event to approval SSE subscribers.

    Filtering: subscribers registered with a ``filter_conversation_id`` only
    receive payloads whose ``conversationId`` equals that filter.

    Timestamp: unless the payload already has a ``timestamp`` key, the current
    UTC time is added as an ISO-8601 string with millisecond precision and a
    ``Z`` suffix.

    Args:
        payload: JSON-serializable event body (``conversationId`` is used for
            filtering).
    """
    from datetime import datetime, timezone

    # time.strftime has no documented %f directive and formats *local* time,
    # so the former '%Y-%m-%dT%H:%M:%S.%fZ' pattern did not yield a valid UTC
    # timestamp. Compute it once so every subscriber receives the same value.
    if 'timestamp' in payload:
        timestamp = payload['timestamp']
    else:
        timestamp = datetime.now(timezone.utc).isoformat(timespec='milliseconds').replace('+00:00', 'Z')
    payload_with_timestamp = {**payload, 'timestamp': timestamp}

    for client in approval_subscribers.copy():  # iterate a snapshot of the subscriber set
        filter_conv_id = client.get('filter_conversation_id')
        if filter_conv_id and filter_conv_id != payload.get('conversationId'):
            continue
        sse_send(client['response'], 'approval_decision', payload_with_timestamp)
|
|
49
254
|
|
|
50
255
|
app = FastAPI(
|
|
51
256
|
title="JAF Agent Framework Server",
|
|
52
|
-
description="HTTP API for JAF agents",
|
|
257
|
+
description="HTTP API for JAF agents with HITL support",
|
|
53
258
|
version="2.0.0",
|
|
54
259
|
docs_url="/docs",
|
|
55
260
|
redoc_url="/redoc"
|
|
@@ -91,12 +296,36 @@ def create_jaf_server(config: ServerConfig[Ctx]) -> FastAPI:
|
|
|
91
296
|
|
|
92
297
|
@app.post("/chat", response_model=ChatResponse)
|
|
93
298
|
async def chat_completion(request: ChatRequest):
|
|
94
|
-
request_start_time = time.time()
|
|
299
|
+
request_start_time = time.time()
|
|
95
300
|
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
301
|
+
try:
|
|
302
|
+
# Validate request (matching TypeScript approach)
|
|
303
|
+
validated_request = request # Already validated by FastAPI, but keeping TypeScript structure
|
|
304
|
+
|
|
305
|
+
# Check if agent exists (matching TypeScript response pattern)
|
|
306
|
+
if validated_request.agent_name not in config.agent_registry:
|
|
307
|
+
return ChatResponse(
|
|
308
|
+
success=False,
|
|
309
|
+
error=f"Agent '{validated_request.agent_name}' not found. Available agents: {', '.join(config.agent_registry.keys())}"
|
|
310
|
+
)
|
|
311
|
+
|
|
312
|
+
# Convert HTTP messages to JAF messages (matching TypeScript)
|
|
313
|
+
jaf_messages = [
|
|
314
|
+
Message(
|
|
315
|
+
role='user' if msg.role == 'system' else msg.role,
|
|
316
|
+
content=msg.content
|
|
317
|
+
)
|
|
318
|
+
for msg in validated_request.messages
|
|
319
|
+
]
|
|
320
|
+
|
|
321
|
+
# Create initial state (matching TypeScript)
|
|
322
|
+
run_id = create_run_id(str(uuid.uuid4()))
|
|
323
|
+
trace_id = create_trace_id(str(uuid.uuid4()))
|
|
324
|
+
|
|
325
|
+
# Generate conversationId if not provided (matching TypeScript)
|
|
326
|
+
conversation_id = validated_request.conversation_id or f"conv-{uuid.uuid4()}"
|
|
327
|
+
except Exception as e:
|
|
328
|
+
return ChatResponse(success=False, error=f"Invalid request: {str(e)}")
|
|
100
329
|
|
|
101
330
|
# Load conversation history to get correct turn count
|
|
102
331
|
initial_turn_count = 0
|
|
@@ -111,45 +340,300 @@ def create_jaf_server(config: ServerConfig[Ctx]) -> FastAPI:
|
|
|
111
340
|
except Exception as e:
|
|
112
341
|
print(f"[JAF:SERVER] Warning: Failed to load conversation history: {e}")
|
|
113
342
|
|
|
343
|
+
# Handle approval message(s) if present (matching TypeScript approach)
|
|
344
|
+
initial_approvals = {} # Will act like TypeScript's Map
|
|
345
|
+
initial_state_messages = jaf_messages
|
|
346
|
+
|
|
347
|
+
approvals_list = validated_request.approvals or []
|
|
348
|
+
|
|
349
|
+
async def persist_approval(conv_id: str, appr: ApprovalMessage):
    """Persist an approval decision in conversation metadata, then broadcast it.

    Does nothing when no memory provider is configured. Otherwise:
    - The decision is merged into ``metadata['toolApprovals']`` under the key
      ``"<session_id>:<tool_call_id>"``.
    - If the matching assistant tool call is found in the stored conversation,
      the entry is enriched with ``toolName`` and ``signature``.
    - If the conversation does not exist yet, it is created as an empty shell
      carrying only this metadata.
    - Persistence is best-effort: provider errors are swallowed so the chat
      request can proceed.
    - Finally, an ``approval_decision`` event is broadcast to approval SSE
      subscribers.

    Args:
        conv_id: Conversation whose metadata receives the approval.
        appr: Approval decision taken from the chat request.
    """
    from datetime import datetime, timezone

    if not config.default_memory_provider:
        return

    provider = config.default_memory_provider

    def _utc_timestamp():
        # time.strftime has no documented %f directive and formats local time,
        # so the former '...%S.%fZ' pattern did not produce a valid UTC stamp.
        return datetime.now(timezone.utc).isoformat(timespec='milliseconds').replace('+00:00', 'Z')

    def _provider_method(snake_name, camel_name):
        # The provider is called with snake_case elsewhere (get_conversation).
        # With a snake_case-only provider, the former camelCase calls raised an
        # AttributeError that the except below swallowed, so nothing was
        # persisted. camelCase is kept only as a fallback.
        return getattr(provider, snake_name, None) or getattr(provider, camel_name)

    status = 'approved' if appr.approved else 'rejected'
    # Keyed by previous run/session id + tool call id for uniqueness.
    approval_key = f"{appr.session_id}:{appr.tool_call_id}"
    base_entry = {
        'approved': appr.approved,
        'status': status,
        'additionalContext': appr.additional_context,
        'sessionId': appr.session_id,
        'toolCallId': appr.tool_call_id,
    }

    try:
        existing = await provider.get_conversation(conv_id)
        if existing.success and existing.data:
            # Enrich the entry with tool name and signature for robust matching.
            try:
                for msg in reversed(existing.data.messages):
                    if msg.role != 'assistant' or not getattr(msg, 'tool_calls', None):
                        continue
                    match = next((tc for tc in msg.tool_calls if tc.id == appr.tool_call_id), None)
                    if match:
                        base_entry['toolName'] = match.function.name
                        base_entry['signature'] = compute_tool_call_signature(match)
                        break
            except Exception:
                pass  # enrichment is best-effort

            metadata = existing.data.metadata
            existing_approvals = (metadata.get('toolApprovals') if metadata else {}) or {}
            prev = existing_approvals.get(approval_key)

            # A previously stored additionalContext may be None; unpacking it
            # with ** would raise TypeError and abort persistence, so use {}.
            merged_additional = {
                **((prev or {}).get('additionalContext') or {}),
                **(base_entry.get('additionalContext') or {}),
            }

            # Keep the earliest timestamp when the decision did not effectively change.
            same_decision = bool(prev) and (
                prev.get('status') == status
                and stable_stringify(prev.get('additionalContext')) == stable_stringify(merged_additional)
            )
            next_entry = {
                **(prev or {}),
                **base_entry,
                'additionalContext': merged_additional,
                'timestamp': prev.get('timestamp') if same_decision else _utc_timestamp(),
            }

            no_change = bool(prev) and (
                prev.get('status') == next_entry['status']
                and stable_stringify(prev.get('additionalContext')) == stable_stringify(next_entry['additionalContext'])
                and (prev.get('toolName') or None) == (next_entry.get('toolName') or None)
                and (prev.get('signature') or None) == (next_entry.get('signature') or None)
            )

            if not no_change:
                merged_approvals = {**existing_approvals, approval_key: next_entry}
                append_messages = _provider_method('append_messages', 'appendMessages')
                await append_messages(conv_id, [], {'toolApprovals': merged_approvals, 'traceId': trace_id})

        elif existing.success and not existing.data:
            # Create a conversation shell that only carries the approval metadata.
            entry = {**base_entry, 'timestamp': _utc_timestamp()}
            store_messages = _provider_method('store_messages', 'storeMessages')
            await store_messages(conv_id, [], {'toolApprovals': {approval_key: entry}, 'traceId': trace_id})
        # A failed provider lookup is intentionally not raised; the run proceeds.
    except Exception:
        # Never let persistence errors break the request path.
        pass

    try:
        broadcast_approval_decision({
            'conversationId': conv_id,
            'sessionId': appr.session_id,
            'toolCallId': appr.tool_call_id,
            'status': status,
            'additionalContext': appr.additional_context,
        })
    except Exception:
        pass  # broadcasting is best-effort
|
|
436
|
+
|
|
437
|
+
if len(approvals_list) > 0:
|
|
438
|
+
for approval in approvals_list:
|
|
439
|
+
if approval.session_id: # Matching TypeScript condition
|
|
440
|
+
initial_approvals[approval.tool_call_id] = {
|
|
441
|
+
'status': 'approved' if approval.approved else 'rejected',
|
|
442
|
+
'approved': approval.approved,
|
|
443
|
+
'additionalContext': approval.additional_context
|
|
444
|
+
}
|
|
445
|
+
await persist_approval(conversation_id, approval)
|
|
446
|
+
|
|
447
|
+
# Seed approvals from persisted conversation metadata
|
|
448
|
+
if config.default_memory_provider:
|
|
449
|
+
try:
|
|
450
|
+
conv_result = await config.default_memory_provider.get_conversation(conversation_id)
|
|
451
|
+
if hasattr(conv_result, 'data') and conv_result.data:
|
|
452
|
+
conversation_data = conv_result.data
|
|
453
|
+
tool_approvals = getattr(conversation_data.metadata, 'tool_approvals', {}) if conversation_data.metadata else {}
|
|
454
|
+
|
|
455
|
+
if tool_approvals:
|
|
456
|
+
# Find latest assistant message with tool calls for matching
|
|
457
|
+
assistant_msg = None
|
|
458
|
+
for msg in reversed(conversation_data.messages):
|
|
459
|
+
if hasattr(msg, 'role') and msg.role == 'assistant' and hasattr(msg, 'tool_calls') and msg.tool_calls:
|
|
460
|
+
assistant_msg = msg
|
|
461
|
+
break
|
|
462
|
+
|
|
463
|
+
if assistant_msg:
|
|
464
|
+
candidate_ids = {tc.id for tc in assistant_msg.tool_calls}
|
|
465
|
+
candidate_signatures = {tc.id: compute_tool_call_signature(tc) for tc in assistant_msg.tool_calls}
|
|
466
|
+
|
|
467
|
+
# Load persisted approvals that aren't already in initial_approvals
|
|
468
|
+
for approval_entry in tool_approvals.values():
|
|
469
|
+
if not isinstance(approval_entry, dict):
|
|
470
|
+
continue
|
|
471
|
+
|
|
472
|
+
persisted_tool_call_id = approval_entry.get('tool_call_id')
|
|
473
|
+
persisted_signature = approval_entry.get('signature')
|
|
474
|
+
|
|
475
|
+
# Try direct ID match first
|
|
476
|
+
target_id = None
|
|
477
|
+
if persisted_tool_call_id and persisted_tool_call_id in candidate_ids:
|
|
478
|
+
target_id = persisted_tool_call_id
|
|
479
|
+
elif persisted_signature:
|
|
480
|
+
# Signature fallback
|
|
481
|
+
for tc_id, sig in candidate_signatures.items():
|
|
482
|
+
if sig == persisted_signature:
|
|
483
|
+
target_id = tc_id
|
|
484
|
+
break
|
|
485
|
+
|
|
486
|
+
if target_id and target_id not in initial_approvals:
|
|
487
|
+
status = approval_entry.get('status', 'pending')
|
|
488
|
+
if approval_entry.get('approved') is True:
|
|
489
|
+
status = 'approved'
|
|
490
|
+
elif approval_entry.get('approved') is False:
|
|
491
|
+
status = 'rejected'
|
|
492
|
+
|
|
493
|
+
initial_approvals[target_id] = ApprovalValue(
|
|
494
|
+
status=status,
|
|
495
|
+
approved=approval_entry.get('approved', False),
|
|
496
|
+
additional_context=approval_entry.get('additional_context')
|
|
497
|
+
)
|
|
498
|
+
|
|
499
|
+
except Exception as e:
|
|
500
|
+
print(f"[JAF:SERVER] Warning: Failed to seed approvals from metadata: {e}")
|
|
501
|
+
|
|
114
502
|
initial_state = RunState(
|
|
115
503
|
run_id=create_run_id(str(uuid.uuid4())),
|
|
116
504
|
trace_id=create_trace_id(str(uuid.uuid4())),
|
|
117
|
-
messages=[
|
|
505
|
+
messages=[_convert_http_message_to_core(msg) for msg in request.messages],
|
|
118
506
|
current_agent_name=request.agent_name,
|
|
119
507
|
context=request.context or {},
|
|
120
|
-
turn_count=initial_turn_count # Use loaded turn count instead of always 0
|
|
508
|
+
turn_count=initial_turn_count, # Use loaded turn count instead of always 0
|
|
509
|
+
approvals=initial_approvals
|
|
121
510
|
)
|
|
122
511
|
|
|
123
512
|
run_config_with_memory = config.run_config
|
|
124
513
|
if config.default_memory_provider:
|
|
514
|
+
# Handle memory configuration with request overrides (matching TypeScript)
|
|
515
|
+
memory_config = MemoryConfig(
|
|
516
|
+
provider=config.default_memory_provider,
|
|
517
|
+
auto_store=request.memory.get('auto_store', True) if request.memory else True,
|
|
518
|
+
max_messages=request.memory.get('max_messages') if request.memory else None,
|
|
519
|
+
compression_threshold=request.memory.get('compression_threshold') if request.memory else None,
|
|
520
|
+
store_on_completion=request.store_on_completion if request.store_on_completion is not None else True
|
|
521
|
+
)
|
|
125
522
|
run_config_with_memory = replace(
|
|
126
523
|
run_config_with_memory,
|
|
127
|
-
memory=
|
|
524
|
+
memory=memory_config,
|
|
128
525
|
conversation_id=conversation_id
|
|
129
526
|
)
|
|
130
527
|
|
|
131
528
|
if request.max_turns is not None:
|
|
132
529
|
run_config_with_memory = replace(run_config_with_memory, max_turns=request.max_turns)
|
|
133
530
|
|
|
531
|
+
# Handle streaming vs non-streaming (matching TypeScript)
|
|
532
|
+
if request.stream:
|
|
533
|
+
async def event_stream():
    """Stream one chat run to the client as Server-Sent Events.

    Event sequence:
    1. ``stream_start`` with run, trace and conversation ids plus the agent name.
    2. One event per engine event from ``run_streaming``.
    3. An ``error`` event if the run raises.
    4. A final ``stream_end`` event.

    When a ``complete`` event carries an interrupted outcome, every
    ``tool_approval`` interruption is also broadcast to the approval SSE
    subscribers.
    """
    try:
        # Build the payload outside the f-string. A replacement field spanning
        # several lines inside a single-quoted f-string is a SyntaxError before
        # Python 3.12 (PEP 701) and would make this whole module fail to import.
        start_payload = {
            'runId': str(initial_state.run_id),
            'traceId': str(initial_state.trace_id),
            'conversationId': conversation_id,
            'agent': request.agent_name,
        }
        yield f"event: stream_start\ndata: {json.dumps(start_payload)}\n\n"

        async for event in run_streaming(initial_state, run_config_with_memory):
            yield f"event: {event.type}\ndata: {json.dumps(asdict(event))}\n\n"

            if event.type != 'complete' or not hasattr(event, 'data'):
                continue
            outcome = getattr(event.data, 'outcome', None)
            if not outcome or getattr(outcome, 'status', None) != 'interrupted':
                continue

            # Broadcast every pending tool approval, as the non-streaming path does.
            for intr in getattr(outcome, 'interruptions', []):
                tool_call = getattr(intr, 'tool_call', None)
                if getattr(intr, 'type', None) != 'tool_approval' or not tool_call:
                    continue
                broadcast_approval_required({
                    'conversationId': conversation_id,
                    'sessionId': getattr(intr, 'session_id', None) or str(initial_state.run_id),
                    'toolCallId': tool_call.id,
                    'toolName': tool_call.function.name,
                    'args': try_parse_json(tool_call.function.arguments),
                    'signature': compute_tool_call_signature(tool_call),
                })
    except Exception as e:
        yield f"event: error\ndata: {json.dumps({'message': str(e)})}\n\n"

    # Deliberately not in a ``finally``: yielding while the async generator is
    # being closed or cancelled (e.g. client disconnect) raises RuntimeError
    # ("async generator ignored GeneratorExit") instead of shutting down cleanly.
    yield f"event: stream_end\ndata: {json.dumps({'ended': True})}\n\n"
|
|
570
|
+
|
|
571
|
+
return StreamingResponse(
|
|
572
|
+
event_stream(),
|
|
573
|
+
media_type="text/event-stream",
|
|
574
|
+
headers={
|
|
575
|
+
"Cache-Control": "no-cache, no-transform",
|
|
576
|
+
"Connection": "keep-alive",
|
|
577
|
+
"X-Accel-Buffering": "no",
|
|
578
|
+
"Access-Control-Allow-Origin": "*",
|
|
579
|
+
"Access-Control-Allow-Headers": "*"
|
|
580
|
+
}
|
|
581
|
+
)
|
|
582
|
+
|
|
583
|
+
# Non-streaming execution
|
|
134
584
|
result = await run(initial_state, run_config_with_memory)
|
|
135
585
|
|
|
136
|
-
http_messages = [
|
|
586
|
+
http_messages = [_convert_core_message_to_http(msg) for msg in result.final_state.messages]
|
|
137
587
|
|
|
138
|
-
|
|
588
|
+
# Create proper outcome object
|
|
139
589
|
if isinstance(result.outcome, CompletedOutcome):
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
590
|
+
outcome_data = BaseOutcomeData(
|
|
591
|
+
status='completed',
|
|
592
|
+
output=result.outcome.output
|
|
593
|
+
)
|
|
144
594
|
elif isinstance(result.outcome, ErrorOutcome):
|
|
145
595
|
error_info = result.outcome.error
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
596
|
+
outcome_data = BaseOutcomeData(
|
|
597
|
+
status='error',
|
|
598
|
+
error={
|
|
149
599
|
'type': error_info.__class__.__name__,
|
|
150
600
|
'message': str(error_info)
|
|
151
601
|
}
|
|
152
|
-
|
|
602
|
+
)
|
|
603
|
+
elif isinstance(result.outcome, InterruptedOutcome):
|
|
604
|
+
# Convert interruptions to response format
|
|
605
|
+
interruptions = []
|
|
606
|
+
for interruption in result.outcome.interruptions:
|
|
607
|
+
if hasattr(interruption, 'tool_call') and hasattr(interruption, 'type'):
|
|
608
|
+
tool_call_data = ToolCallInterruption(
|
|
609
|
+
id=interruption.tool_call.id,
|
|
610
|
+
function={
|
|
611
|
+
'name': interruption.tool_call.function.name,
|
|
612
|
+
'arguments': interruption.tool_call.function.arguments
|
|
613
|
+
}
|
|
614
|
+
)
|
|
615
|
+
interruptions.append(InterruptionData(
|
|
616
|
+
type='tool_approval',
|
|
617
|
+
tool_call=tool_call_data,
|
|
618
|
+
session_id=interruption.session_id or str(result.final_state.run_id)
|
|
619
|
+
))
|
|
620
|
+
|
|
621
|
+
# Broadcast approval request via SSE
|
|
622
|
+
broadcast_approval_required({
|
|
623
|
+
'conversationId': conversation_id,
|
|
624
|
+
'sessionId': interruption.session_id or str(result.final_state.run_id),
|
|
625
|
+
'toolCallId': interruption.tool_call.id,
|
|
626
|
+
'toolName': interruption.tool_call.function.name,
|
|
627
|
+
'args': try_parse_json(interruption.tool_call.function.arguments),
|
|
628
|
+
'signature': compute_tool_call_signature(interruption.tool_call)
|
|
629
|
+
})
|
|
630
|
+
|
|
631
|
+
outcome_data = InterruptedOutcomeData(
|
|
632
|
+
status='interrupted',
|
|
633
|
+
interruptions=interruptions
|
|
634
|
+
)
|
|
635
|
+
else:
|
|
636
|
+
outcome_data = BaseOutcomeData(status='error', error='Unknown outcome type')
|
|
153
637
|
|
|
154
638
|
return ChatResponse(
|
|
155
639
|
success=True,
|
|
@@ -157,7 +641,7 @@ def create_jaf_server(config: ServerConfig[Ctx]) -> FastAPI:
|
|
|
157
641
|
run_id=str(result.final_state.run_id),
|
|
158
642
|
trace_id=str(result.final_state.trace_id),
|
|
159
643
|
messages=http_messages,
|
|
160
|
-
outcome=
|
|
644
|
+
outcome=outcome_data,
|
|
161
645
|
turn_count=result.final_state.turn_count,
|
|
162
646
|
execution_time_ms=int((time.time() - request_start_time) * 1000), # Use request start time
|
|
163
647
|
conversation_id=conversation_id
|
|
@@ -211,4 +695,163 @@ def create_jaf_server(config: ServerConfig[Ctx]) -> FastAPI:
|
|
|
211
695
|
|
|
212
696
|
return MemoryHealthResponse(success=True, data=result.data)
|
|
213
697
|
|
|
698
|
+
# Approval endpoints for HITL functionality
|
|
699
|
+
@app.get("/approvals/pending", response_model=PendingApprovalsResponse)
async def get_pending_approvals(conversation_id: str = None):
    """Get pending approvals for a conversation.

    Inspects the most recent assistant message that requested tool calls. A
    call from that message is reported when both of these hold:
    - it has not been executed (no later ``tool`` message carries its id);
    - it has no approved/rejected entry in the conversation's persisted tool
      approvals metadata.

    Args:
        conversation_id: Conversation to inspect (query parameter, required).

    Returns:
        ``PendingApprovalsResponse`` listing the pending tool calls. The
        response has ``success=False`` when no memory provider is configured
        or the lookup fails.

    Raises:
        HTTPException: 400 when ``conversation_id`` is missing.
    """
    # Validate before the try block. The generic ``except Exception`` below
    # would otherwise turn this 400 into a 200 response with success=False.
    if not conversation_id:
        raise HTTPException(status_code=400, detail="conversation_id is required")

    def _approvals_from_metadata(metadata):
        """Return the persisted tool-approval mapping from conversation metadata.

        ``persist_approval`` writes a metadata dict with a ``'toolApprovals'``
        key. Object-style metadata exposing ``tool_approvals`` is still
        accepted.
        """
        if not metadata:
            return {}
        if isinstance(metadata, dict):
            found = metadata.get('toolApprovals') or metadata.get('tool_approvals')
        else:
            found = getattr(metadata, 'tool_approvals', None)
        return found if isinstance(found, dict) else {}

    def _metadata_value(metadata, name):
        """Read ``name`` from dict- or attribute-style metadata (None if absent)."""
        if not metadata:
            return None
        if isinstance(metadata, dict):
            return metadata.get(name)
        return getattr(metadata, name, None)

    def _find_entry(approvals_meta, tool_call, signature, legacy_key):
        """Locate the persisted approval entry for ``tool_call``.

        Lookup order:
        1. The ``<conversation_id>:<tool_call_id>`` key used previously.
        2. An entry whose recorded tool call id matches (``persist_approval``
           keys entries by session id, so the key alone does not match).
        3. An entry with the same call signature, mirroring the id-then-
           signature matching used when the chat endpoint seeds approvals.
        """
        legacy_entry = approvals_meta.get(legacy_key)
        if isinstance(legacy_entry, dict):
            return legacy_entry
        entries = [e for e in approvals_meta.values() if isinstance(e, dict)]
        for entry in entries:
            if tool_call.id in (entry.get('toolCallId'), entry.get('tool_call_id')):
                return entry
        for entry in entries:
            if entry.get('signature') and entry.get('signature') == signature:
                return entry
        return None

    try:
        if not config.default_memory_provider:
            return PendingApprovalsResponse(
                success=False,
                error="Memory provider not configured"
            )

        conv_result = await config.default_memory_provider.get_conversation(conversation_id)
        # Only a result that actually carries an error is a failure; a result
        # object exposing ``error=None`` is a successful lookup.
        error = getattr(conv_result, 'error', None)
        if error is not None:
            return PendingApprovalsResponse(success=False, error=str(getattr(error, 'message', error)))

        conversation = getattr(conv_result, 'data', None)
        if not conversation:
            return PendingApprovalsResponse(
                success=True,
                data=PendingApprovalsData(pending=[])
            )

        messages = conversation.messages
        approvals_meta = _approvals_from_metadata(conversation.metadata)

        # Find the most recent assistant message with tool calls.
        assistant_msg = None
        assistant_index = -1
        for i in range(len(messages) - 1, -1, -1):
            msg = messages[i]
            if getattr(msg, 'role', None) == 'assistant' and getattr(msg, 'tool_calls', None):
                assistant_msg = msg
                assistant_index = i
                break

        if assistant_msg is None:
            return PendingApprovalsResponse(
                success=True,
                data=PendingApprovalsData(pending=[])
            )

        # Tool calls answered by a later tool message were already executed.
        tool_ids = {tc.id for tc in assistant_msg.tool_calls}
        executed = {
            msg.tool_call_id
            for msg in messages[assistant_index + 1:]
            if getattr(msg, 'role', None) == 'tool' and getattr(msg, 'tool_call_id', None) in tool_ids
        }

        session_id = _metadata_value(conversation.metadata, 'run_id')
        pending_approvals = []
        for tc in assistant_msg.tool_calls:
            if tc.id in executed:
                continue

            signature = compute_tool_call_signature(tc)
            approval_entry = _find_entry(
                approvals_meta, tc, signature, f"{conversation.conversation_id}:{tc.id}"
            )

            status = 'pending'
            if approval_entry:
                status = approval_entry.get('status', 'pending')
                if approval_entry.get('approved') is True:
                    status = 'approved'
                elif approval_entry.get('approved') is False:
                    status = 'rejected'

            if status == 'pending':
                pending_approvals.append(PendingApprovalData(
                    conversation_id=conversation_id,
                    tool_call_id=tc.id,
                    tool_name=tc.function.name,
                    args=try_parse_json(tc.function.arguments),
                    signature=signature,
                    status='pending',
                    session_id=session_id
                ))

        return PendingApprovalsResponse(
            success=True,
            data=PendingApprovalsData(pending=pending_approvals)
        )

    except Exception as e:
        return PendingApprovalsResponse(success=False, error=str(e))
|
|
788
|
+
|
|
789
|
+
# Agent-specific chat endpoint (convenience - matching TypeScript)
|
|
790
|
+
@app.post("/agents/{agent_name}/chat", response_model=ChatResponse)
async def agent_chat_completion(agent_name: str, request_body: ChatRequest):
    """Chat with the agent named in the URL path.

    Convenience wrapper around ``chat_completion``. The request body is copied
    with ``agent_name`` taken from the path, overriding any value in the body,
    and then delegated.

    Args:
        agent_name: Agent to run (path parameter).
        request_body: Chat request payload.
    """
    modified_request = ChatRequest(
        messages=request_body.messages,
        agent_name=agent_name,
        context=request_body.context,
        max_turns=request_body.max_turns,
        stream=request_body.stream,
        conversation_id=request_body.conversation_id,
        memory=request_body.memory,
        approvals=request_body.approvals,
        # chat_completion reads store_on_completion when building the memory
        # config. Omitting it here silently dropped the client's value.
        store_on_completion=request_body.store_on_completion
    )

    # Delegate to the main chat endpoint logic.
    return await chat_completion(modified_request)
|
|
807
|
+
|
|
808
|
+
# Approvals SSE stream endpoint (matching TypeScript placement inside start function)
|
|
809
|
+
@app.get("/approvals/stream")
async def stream_approval_updates(request: Request, conversation_id: str = None):
    """Stream real-time approval updates via Server-Sent Events.

    Registers a subscriber in ``approval_subscribers`` so the approval
    broadcasters can deliver ``approval_required`` / ``approval_decision``
    events to this client. Stream behaviour:
    - ``stream_start`` is sent first.
    - Buffered events are forwarded about once per second.
    - A ``ping`` heartbeat is sent roughly every 15 seconds.
    - The subscriber is removed when the client disconnects or the stream ends.

    Args:
        request: Incoming request, polled to detect client disconnection.
        conversation_id: Optional filter. When set, only events whose
            ``conversationId`` matches are delivered.
    """

    class _ApprovalSubscriber:
        """Hashable subscriber record with a buffering ``write`` sink.

        The broadcasters call ``client.get('filter_conversation_id')`` and
        ``client['response']``, and ``sse_send`` calls ``.write(text)`` on the
        latter. This object provides that interface and buffers the written
        text for the stream to drain. A plain dict cannot be used because
        dicts are unhashable and cannot be added to a set. The Starlette
        ``Request`` cannot serve as the sink because it has no ``write``.
        """

        def __init__(self, filter_conversation_id):
            self.pending_frames = []
            self._fields = {
                'response': self,
                'filter_conversation_id': filter_conversation_id,
            }

        def get(self, key, default=None):
            return self._fields.get(key, default)

        def __getitem__(self, key):
            return self._fields[key]

        def write(self, text):
            self.pending_frames.append(text)

        def drain(self):
            # write() and drain() are both synchronous, so within one event
            # loop a drain never splits a frame written by a single sse_send.
            frames = ''.join(self.pending_frames)
            self.pending_frames.clear()
            return frames

    async def event_stream():
        client = _ApprovalSubscriber(conversation_id)
        approval_subscribers.add(client)

        try:
            yield f"event: stream_start\ndata: {json.dumps({'conversationId': conversation_id})}\n\n"

            last_heartbeat = time.time()
            while True:
                if await request.is_disconnected():
                    break

                frames = client.drain()
                if frames:
                    yield frames

                current_time = time.time()
                if current_time - last_heartbeat >= 15:
                    yield f"event: ping\ndata: {json.dumps({'ts': int(current_time * 1000)})}\n\n"
                    last_heartbeat = current_time

                await asyncio.sleep(1)

        except Exception as e:
            yield f"event: error\ndata: {json.dumps({'message': str(e)})}\n\n"
        finally:
            approval_subscribers.discard(client)

    return StreamingResponse(
        event_stream(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache, no-transform",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
            "Access-Control-Allow-Origin": "*",
            "Access-Control-Allow-Headers": "*"
        }
    )
|
|
856
|
+
|
|
214
857
|
return app
|