jaf-py 2.3.1__py3-none-any.whl → 2.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaf/__init__.py +15 -0
- jaf/core/agent_tool.py +6 -4
- jaf/core/analytics.py +4 -3
- jaf/core/engine.py +512 -39
- jaf/core/state.py +156 -0
- jaf/core/tools.py +4 -6
- jaf/core/tracing.py +114 -23
- jaf/core/types.py +157 -4
- jaf/memory/approval_storage.py +306 -0
- jaf/memory/types.py +1 -0
- jaf/memory/utils.py +1 -1
- jaf/providers/model.py +436 -13
- jaf/server/__init__.py +2 -0
- jaf/server/server.py +665 -22
- jaf/server/types.py +149 -4
- jaf/utils/__init__.py +50 -0
- jaf/utils/attachments.py +401 -0
- jaf/utils/document_processor.py +561 -0
- {jaf_py-2.3.1.dist-info → jaf_py-2.4.2.dist-info}/METADATA +128 -120
- {jaf_py-2.3.1.dist-info → jaf_py-2.4.2.dist-info}/RECORD +24 -19
- {jaf_py-2.3.1.dist-info → jaf_py-2.4.2.dist-info}/WHEEL +0 -0
- {jaf_py-2.3.1.dist-info → jaf_py-2.4.2.dist-info}/entry_points.txt +0 -0
- {jaf_py-2.3.1.dist-info → jaf_py-2.4.2.dist-info}/licenses/LICENSE +0 -0
- {jaf_py-2.3.1.dist-info → jaf_py-2.4.2.dist-info}/top_level.txt +0 -0
jaf/core/state.py
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
1
|
+
"""
|
|
2
|
+
State management functions for approval handling in HITL scenarios.
|
|
3
|
+
|
|
4
|
+
This module provides functions to manage approval state transitions
|
|
5
|
+
and integrate with approval storage systems.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from typing import Dict, Any, Optional
|
|
9
|
+
from dataclasses import replace
|
|
10
|
+
|
|
11
|
+
from .types import RunState, RunConfig, Interruption, ApprovalValue
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
async def _record_decision(
    state: RunState[Any],
    interruption: Interruption,
    status: str,
    approved: bool,
    additional_context: Optional[Dict[str, Any]],
    config: Optional[RunConfig[Any]],
) -> RunState[Any]:
    """
    Record an approve/reject decision for a tool-approval interruption.

    Shared implementation behind ``approve`` and ``reject``; the two differ
    only in the ``status`` string and the ``approved`` flag.

    Args:
        state: Current run state
        interruption: The interruption the decision applies to
        status: Status string stored on the decision ('approved' or 'rejected')
        approved: Boolean form of the decision
        additional_context: Optional caller-supplied context for the decision
        config: Optional run configuration providing ``approval_storage``

    Returns:
        A copy of ``state`` whose ``approvals`` maps the tool call id to the
        new decision, or ``state`` itself when the interruption is not a
        tool approval.
    """
    if interruption.type != 'tool_approval':
        return state

    approval_value = ApprovalValue(
        status=status,
        approved=approved,
        # 'status' is written last so it always overrides a caller-supplied key.
        additional_context={
            **(additional_context or {}),
            'status': status
        }
    )

    # Persist the decision when a storage backend is configured. Persistence
    # is best effort: a failed or raising store is only reported, and the
    # decision is still recorded in the returned state below.
    if config and config.approval_storage:
        try:
            print(f"[JAF:APPROVAL] Storing approval for tool_call_id {interruption.tool_call.id}: {approval_value}")
            result = await config.approval_storage.store_approval(
                state.run_id,
                interruption.tool_call.id,
                approval_value
            )
            if not result.success:
                print(f"[JAF:APPROVAL] Failed to store approval: {result.error}")
            else:
                print("[JAF:APPROVAL] Successfully stored approval in storage")
        except Exception as e:
            print(f"[JAF:APPROVAL] Approval storage error: {e}")

    # Build a new mapping instead of mutating the one held by the input state.
    new_approvals = {**state.approvals, interruption.tool_call.id: approval_value}
    return replace(state, approvals=new_approvals)


async def approve(
    state: RunState[Any],
    interruption: Interruption,
    additional_context: Optional[Dict[str, Any]] = None,
    config: Optional[RunConfig[Any]] = None
) -> RunState[Any]:
    """
    Approve a tool call interruption and update the run state.

    Args:
        state: Current run state
        interruption: The interruption to approve
        additional_context: Optional additional context for the approval
        config: Optional run configuration for approval storage

    Returns:
        Updated run state with approval recorded; the input state is returned
        unchanged when the interruption is not of type 'tool_approval'.
    """
    return await _record_decision(
        state, interruption, 'approved', True, additional_context, config
    )


async def reject(
    state: RunState[Any],
    interruption: Interruption,
    additional_context: Optional[Dict[str, Any]] = None,
    config: Optional[RunConfig[Any]] = None
) -> RunState[Any]:
    """
    Reject a tool call interruption and update the run state.

    Args:
        state: Current run state
        interruption: The interruption to reject
        additional_context: Optional additional context for the rejection
        config: Optional run configuration for approval storage

    Returns:
        Updated run state with rejection recorded; the input state is returned
        unchanged when the interruption is not of type 'tool_approval'.
    """
    return await _record_decision(
        state, interruption, 'rejected', False, additional_context, config
    )
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
async def load_approvals_into_state(
    state: RunState[Any],
    config: Optional[RunConfig[Any]] = None
) -> RunState[Any]:
    """
    Load approvals from storage into the run state.

    Stored approvals are merged over the approvals already present in
    ``state``; on a key conflict the stored value wins. Merging (rather than
    replacing) keeps decisions that exist only in memory, e.g. ones recorded
    while the storage backend was failing.

    Args:
        state: Current run state
        config: Optional run configuration with approval storage

    Returns:
        Updated run state with loaded approvals, or the input state unchanged
        when no storage is configured, the lookup fails or raises, or storage
        holds no approvals for this run.
    """
    if not config or not config.approval_storage:
        print(f"[JAF:APPROVAL] No approval storage configured, using existing approvals: {state.approvals}")
        return state

    # Loading is best effort: any error is reported and the caller keeps the
    # state it passed in.
    try:
        print(f"[JAF:APPROVAL] Loading approvals from storage for run_id: {state.run_id}")
        result = await config.approval_storage.get_run_approvals(state.run_id)

        if not result.success:
            print(f"[JAF:APPROVAL] Failed to load approvals: {result.error}")
            return state

        if not result.data:
            print(f"[JAF:APPROVAL] No approvals found in storage for run_id: {state.run_id}")
            return state

        print(f"[JAF:APPROVAL] Loaded {len(result.data)} approvals from storage: {result.data}")
        # New dict: neither the input state's mapping nor the storage result
        # is mutated or aliased by the returned state.
        merged_approvals = {**state.approvals, **result.data}
        return replace(state, approvals=merged_approvals)
    except Exception as e:
        print(f"[JAF:APPROVAL] Approval loading error: {e}")
        return state
|
jaf/core/tools.py
CHANGED
|
@@ -89,14 +89,12 @@ def create_function_tool(config: FunctionToolConfig) -> Tool:
|
|
|
89
89
|
# Validate schema generation (cached for performance)
|
|
90
90
|
if not hasattr(parameters, '_schema_validated'):
|
|
91
91
|
try:
|
|
92
|
+
# Generate schema once to validate the model is well-formed.
|
|
93
|
+
# Allow empty object schemas (no parameters) for tools that take no args.
|
|
92
94
|
if hasattr(parameters, 'model_json_schema'):
|
|
93
|
-
|
|
94
|
-
if not test_schema.get('properties'):
|
|
95
|
-
raise ValueError(f"Tool '{tool_name}' has no properties in schema. Check your Pydantic model fields.")
|
|
95
|
+
_ = parameters.model_json_schema()
|
|
96
96
|
elif hasattr(parameters, 'schema'):
|
|
97
|
-
|
|
98
|
-
if not test_schema.get('properties'):
|
|
99
|
-
raise ValueError(f"Tool '{tool_name}' has no properties in schema. Check your Pydantic model fields.")
|
|
97
|
+
_ = parameters.schema()
|
|
100
98
|
parameters._schema_validated = True
|
|
101
99
|
except Exception as e:
|
|
102
100
|
logger.error(f"Tool {tool_name} schema generation failed: {e}")
|
jaf/core/tracing.py
CHANGED
|
@@ -344,6 +344,9 @@ class LangfuseTraceCollector:
|
|
|
344
344
|
)
|
|
345
345
|
self.active_spans: Dict[str, Any] = {}
|
|
346
346
|
self.trace_spans: Dict[TraceId, Any] = {}
|
|
347
|
+
# Track tool calls and results for each trace
|
|
348
|
+
self.trace_tool_calls: Dict[TraceId, List[Dict[str, Any]]] = {}
|
|
349
|
+
self.trace_tool_results: Dict[TraceId, List[Dict[str, Any]]] = {}
|
|
347
350
|
|
|
348
351
|
def collect(self, event: TraceEvent) -> None:
|
|
349
352
|
"""Collect a trace event and send it to Langfuse."""
|
|
@@ -359,9 +362,14 @@ class LangfuseTraceCollector:
|
|
|
359
362
|
# Start a new trace for the entire run
|
|
360
363
|
print(f"[LANGFUSE] Starting trace for run: {trace_id}")
|
|
361
364
|
|
|
365
|
+
# Initialize tracking for this trace
|
|
366
|
+
self.trace_tool_calls[trace_id] = []
|
|
367
|
+
self.trace_tool_results[trace_id] = []
|
|
368
|
+
|
|
362
369
|
# Extract user query from the run_start data
|
|
363
370
|
user_query = None
|
|
364
371
|
user_id = None
|
|
372
|
+
conversation_history = []
|
|
365
373
|
|
|
366
374
|
# Debug: Print the event data structure to understand what we're working with
|
|
367
375
|
if event.data.get("context"):
|
|
@@ -401,24 +409,51 @@ class LangfuseTraceCollector:
|
|
|
401
409
|
user_id = token_response.email
|
|
402
410
|
print(f"[LANGFUSE DEBUG] Extracted user_id from attr: {user_id}")
|
|
403
411
|
|
|
404
|
-
#
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
# Find the last user message
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
|
|
412
|
+
# Extract conversation history and current user query from messages
|
|
413
|
+
messages = event.data.get("messages", [])
|
|
414
|
+
if messages:
|
|
415
|
+
print(f"[LANGFUSE DEBUG] Processing {len(messages)} messages")
|
|
416
|
+
|
|
417
|
+
# Find the last user message (current query) and extract conversation history (excluding current)
|
|
418
|
+
current_user_message_found = False
|
|
419
|
+
for i in range(len(messages) - 1, -1, -1):
|
|
420
|
+
msg = messages[i]
|
|
421
|
+
|
|
422
|
+
if isinstance(msg, dict):
|
|
423
|
+
role = msg.get("role")
|
|
424
|
+
content = msg.get("content", "")
|
|
425
|
+
elif hasattr(msg, 'role'):
|
|
426
|
+
role = msg.role
|
|
427
|
+
content = getattr(msg, 'content', "")
|
|
428
|
+
# Handle both string content and complex content structures
|
|
429
|
+
if not isinstance(content, str):
|
|
430
|
+
# Try to extract text from complex content
|
|
431
|
+
if hasattr(content, '__iter__') and not isinstance(content, str):
|
|
432
|
+
try:
|
|
433
|
+
# If it's a list, try to join text parts
|
|
434
|
+
content = " ".join(str(item) for item in content if item)
|
|
435
|
+
except:
|
|
436
|
+
content = str(content)
|
|
437
|
+
else:
|
|
438
|
+
content = str(content)
|
|
439
|
+
else:
|
|
440
|
+
continue
|
|
441
|
+
|
|
442
|
+
# If we haven't found the current user message yet and this is a user message
|
|
443
|
+
if not current_user_message_found and (role == "user" or role == 'user'):
|
|
444
|
+
user_query = content
|
|
445
|
+
current_user_message_found = True
|
|
446
|
+
print(f"[LANGFUSE DEBUG] Found current user query: {user_query}")
|
|
447
|
+
elif current_user_message_found:
|
|
448
|
+
# Add to conversation history (excluding the current user message)
|
|
449
|
+
conversation_history.insert(0, {
|
|
450
|
+
"role": role,
|
|
451
|
+
"content": content,
|
|
452
|
+
"timestamp": datetime.now().isoformat() if not hasattr(msg, 'timestamp') else getattr(msg, 'timestamp', datetime.now().isoformat())
|
|
453
|
+
})
|
|
420
454
|
|
|
421
455
|
print(f"[LANGFUSE DEBUG] Final extracted - user_query: {user_query}, user_id: {user_id}")
|
|
456
|
+
print(f"[LANGFUSE DEBUG] Conversation history length: {len(conversation_history)}")
|
|
422
457
|
|
|
423
458
|
# Create comprehensive input data for the trace
|
|
424
459
|
trace_input = {
|
|
@@ -437,31 +472,58 @@ class LangfuseTraceCollector:
|
|
|
437
472
|
session_id=event.data.get("session_id"),
|
|
438
473
|
input=trace_input,
|
|
439
474
|
metadata={
|
|
440
|
-
"framework": "jaf",
|
|
441
|
-
"event_type": "run_start",
|
|
475
|
+
"framework": "jaf",
|
|
476
|
+
"event_type": "run_start",
|
|
442
477
|
"trace_id": str(trace_id),
|
|
443
478
|
"user_query": user_query,
|
|
444
479
|
"user_id": user_id or event.data.get("user_id"),
|
|
445
|
-
"agent_name": event.data.get("agent_name", "analytics_agent_jaf")
|
|
480
|
+
"agent_name": event.data.get("agent_name", "analytics_agent_jaf"),
|
|
481
|
+
"conversation_history": conversation_history,
|
|
482
|
+
"tool_calls": [],
|
|
483
|
+
"tool_results": []
|
|
446
484
|
}
|
|
447
485
|
)
|
|
448
486
|
self.trace_spans[trace_id] = trace
|
|
449
|
-
# Store user_id and
|
|
487
|
+
# Store user_id, user_query, and conversation_history for later use
|
|
450
488
|
trace._user_id = user_id or event.data.get("user_id")
|
|
451
489
|
trace._user_query = user_query
|
|
490
|
+
trace._conversation_history = conversation_history
|
|
452
491
|
print(f"[LANGFUSE] Created trace with user query: {user_query[:100] if user_query else 'None'}...")
|
|
453
492
|
|
|
454
493
|
elif event.type == "run_end":
|
|
455
494
|
if trace_id in self.trace_spans:
|
|
456
495
|
print(f"[LANGFUSE] Ending trace for run: {trace_id}")
|
|
457
|
-
|
|
458
|
-
|
|
496
|
+
|
|
497
|
+
# Update the trace metadata with final tool calls and results
|
|
498
|
+
final_metadata = {
|
|
499
|
+
"framework": "jaf",
|
|
500
|
+
"event_type": "run_end",
|
|
501
|
+
"trace_id": str(trace_id),
|
|
502
|
+
"user_query": getattr(self.trace_spans[trace_id], '_user_query', None),
|
|
503
|
+
"user_id": getattr(self.trace_spans[trace_id], '_user_id', None),
|
|
504
|
+
"agent_name": event.data.get("agent_name", "analytics_agent_jaf"),
|
|
505
|
+
"conversation_history": getattr(self.trace_spans[trace_id], '_conversation_history', []),
|
|
506
|
+
"tool_calls": self.trace_tool_calls.get(trace_id, []),
|
|
507
|
+
"tool_results": self.trace_tool_results.get(trace_id, [])
|
|
508
|
+
}
|
|
509
|
+
|
|
510
|
+
# End the trace with updated metadata
|
|
511
|
+
self.trace_spans[trace_id].update(
|
|
512
|
+
output=event.data,
|
|
513
|
+
metadata=final_metadata
|
|
514
|
+
)
|
|
515
|
+
|
|
459
516
|
# Flush to ensure data is sent
|
|
460
517
|
print(f"[LANGFUSE] Flushing data to Langfuse...")
|
|
461
518
|
self.langfuse.flush()
|
|
462
519
|
print(f"[LANGFUSE] Flush completed")
|
|
520
|
+
|
|
463
521
|
# Clean up
|
|
464
522
|
del self.trace_spans[trace_id]
|
|
523
|
+
if trace_id in self.trace_tool_calls:
|
|
524
|
+
del self.trace_tool_calls[trace_id]
|
|
525
|
+
if trace_id in self.trace_tool_results:
|
|
526
|
+
del self.trace_tool_results[trace_id]
|
|
465
527
|
else:
|
|
466
528
|
print(f"[LANGFUSE] No trace found for run_end: {trace_id}")
|
|
467
529
|
|
|
@@ -549,6 +611,20 @@ class LangfuseTraceCollector:
|
|
|
549
611
|
|
|
550
612
|
print(f"[LANGFUSE] Starting span for tool call: {tool_name}")
|
|
551
613
|
|
|
614
|
+
# Track this tool call for the trace
|
|
615
|
+
tool_call_data = {
|
|
616
|
+
"tool_name": tool_name,
|
|
617
|
+
"arguments": tool_args,
|
|
618
|
+
"call_id": event.data.get("call_id"),
|
|
619
|
+
"timestamp": datetime.now().isoformat()
|
|
620
|
+
}
|
|
621
|
+
|
|
622
|
+
# Ensure trace_id exists in tracking
|
|
623
|
+
if trace_id not in self.trace_tool_calls:
|
|
624
|
+
self.trace_tool_calls[trace_id] = []
|
|
625
|
+
|
|
626
|
+
self.trace_tool_calls[trace_id].append(tool_call_data)
|
|
627
|
+
|
|
552
628
|
# Create comprehensive input data for the tool call
|
|
553
629
|
tool_input = {
|
|
554
630
|
"tool_name": tool_name,
|
|
@@ -579,13 +655,28 @@ class LangfuseTraceCollector:
|
|
|
579
655
|
|
|
580
656
|
print(f"[LANGFUSE] Ending span for tool call: {tool_name}")
|
|
581
657
|
|
|
658
|
+
# Track this tool result for the trace
|
|
659
|
+
tool_result_data = {
|
|
660
|
+
"tool_name": tool_name,
|
|
661
|
+
"result": tool_result,
|
|
662
|
+
"call_id": event.data.get("call_id"),
|
|
663
|
+
"timestamp": datetime.now().isoformat(),
|
|
664
|
+
"status": event.data.get("status", "completed"),
|
|
665
|
+
"tool_result": event.data.get("tool_result")
|
|
666
|
+
}
|
|
667
|
+
|
|
668
|
+
if trace_id not in self.trace_tool_results:
|
|
669
|
+
self.trace_tool_results[trace_id] = []
|
|
670
|
+
|
|
671
|
+
self.trace_tool_results[trace_id].append(tool_result_data)
|
|
672
|
+
|
|
582
673
|
# Create comprehensive output data for the tool call
|
|
583
674
|
tool_output = {
|
|
584
675
|
"tool_name": tool_name,
|
|
585
676
|
"result": tool_result,
|
|
586
677
|
"call_id": event.data.get("call_id"),
|
|
587
678
|
"timestamp": datetime.now().isoformat(),
|
|
588
|
-
"status": "completed"
|
|
679
|
+
"status": event.data.get("status", "completed")
|
|
589
680
|
}
|
|
590
681
|
|
|
591
682
|
# End the span with detailed output
|
jaf/core/types.py
CHANGED
|
@@ -5,12 +5,13 @@ This module defines all the fundamental data structures and types used throughou
|
|
|
5
5
|
the framework, maintaining immutability and type safety.
|
|
6
6
|
"""
|
|
7
7
|
|
|
8
|
-
from collections.abc import Awaitable
|
|
8
|
+
from collections.abc import Awaitable, AsyncIterator
|
|
9
9
|
|
|
10
10
|
# ReadOnly is only available in Python 3.13+, so we'll use a simpler approach
|
|
11
11
|
from dataclasses import dataclass, field
|
|
12
12
|
from typing import (
|
|
13
13
|
Any,
|
|
14
|
+
Awaitable,
|
|
14
15
|
Callable,
|
|
15
16
|
Dict,
|
|
16
17
|
Generic,
|
|
@@ -28,6 +29,8 @@ from enum import Enum
|
|
|
28
29
|
|
|
29
30
|
if TYPE_CHECKING:
|
|
30
31
|
from .tool_results import ToolResult
|
|
32
|
+
from ..memory.approval_storage import ApprovalStorage
|
|
33
|
+
from ..memory.types import MemoryConfig
|
|
31
34
|
|
|
32
35
|
|
|
33
36
|
# Comprehensive enums for type safety and improved developer experience
|
|
@@ -143,13 +146,88 @@ class ToolCallFunction:
|
|
|
143
146
|
name: str
|
|
144
147
|
arguments: str
|
|
145
148
|
|
|
149
|
+
@dataclass(frozen=True)
|
|
150
|
+
class Attachment:
|
|
151
|
+
"""Represents an attachment with various content types."""
|
|
152
|
+
kind: Literal['image', 'document', 'file']
|
|
153
|
+
mime_type: Optional[str] = None # e.g. image/png, application/pdf
|
|
154
|
+
name: Optional[str] = None # Optional filename
|
|
155
|
+
url: Optional[str] = None # Remote URL or data URL
|
|
156
|
+
data: Optional[str] = None # Base64 without data: prefix
|
|
157
|
+
format: Optional[str] = None # Optional short format like 'pdf', 'txt'
|
|
158
|
+
use_litellm_format: Optional[bool] = None # Use LiteLLM native file format
|
|
159
|
+
|
|
160
|
+
def __post_init__(self):
|
|
161
|
+
"""Validate that at least one of url or data is provided."""
|
|
162
|
+
if self.url is None and self.data is None:
|
|
163
|
+
raise ValueError("At least one of 'url' or 'data' must be provided for an Attachment.")
|
|
164
|
+
|
|
165
|
+
@dataclass(frozen=True)
|
|
166
|
+
class MessageContentPart:
|
|
167
|
+
"""Part of multi-part message content."""
|
|
168
|
+
type: Literal['text', 'image_url', 'file']
|
|
169
|
+
text: Optional[str] = None
|
|
170
|
+
image_url: Optional[Dict[str, Any]] = None # Contains url and optional detail
|
|
171
|
+
file: Optional[Dict[str, Any]] = None # Contains file_id and optional format
|
|
172
|
+
|
|
146
173
|
@dataclass(frozen=True)
|
|
147
174
|
class Message:
|
|
148
|
-
"""
|
|
175
|
+
"""
|
|
176
|
+
A message in the conversation.
|
|
177
|
+
|
|
178
|
+
BACKWARDS COMPATIBILITY:
|
|
179
|
+
- Messages created with string content remain fully backwards compatible
|
|
180
|
+
- Direct access to .content returns the original string when created with string
|
|
181
|
+
- Use .text_content property for guaranteed string access in all cases
|
|
182
|
+
- Use get_text_content() function to extract text from any content type
|
|
183
|
+
|
|
184
|
+
Examples:
|
|
185
|
+
# Original usage - still works exactly the same
|
|
186
|
+
msg = Message(role='user', content='Hello')
|
|
187
|
+
text = msg.content # Returns 'Hello' as string
|
|
188
|
+
|
|
189
|
+
# Guaranteed string access (recommended for new code)
|
|
190
|
+
text = msg.text_content # Always returns string
|
|
191
|
+
|
|
192
|
+
# Universal text extraction
|
|
193
|
+
text = get_text_content(msg.content) # Works with any content type
|
|
194
|
+
"""
|
|
149
195
|
role: ContentRole
|
|
150
|
-
content: str
|
|
196
|
+
content: Union[str, List[MessageContentPart]]
|
|
197
|
+
attachments: Optional[List[Attachment]] = None
|
|
151
198
|
tool_call_id: Optional[str] = None
|
|
152
199
|
tool_calls: Optional[List[ToolCall]] = None
|
|
200
|
+
|
|
201
|
+
@property
|
|
202
|
+
def text_content(self) -> str:
|
|
203
|
+
"""Get text content as string for backwards compatibility."""
|
|
204
|
+
return get_text_content(self.content)
|
|
205
|
+
|
|
206
|
+
@classmethod
|
|
207
|
+
def create(
|
|
208
|
+
cls,
|
|
209
|
+
role: ContentRole,
|
|
210
|
+
content: str,
|
|
211
|
+
attachments: Optional[List[Attachment]] = None,
|
|
212
|
+
tool_call_id: Optional[str] = None,
|
|
213
|
+
tool_calls: Optional[List[ToolCall]] = None
|
|
214
|
+
) -> 'Message':
|
|
215
|
+
"""Create a message with string content and optional attachments."""
|
|
216
|
+
return cls(
|
|
217
|
+
role=role,
|
|
218
|
+
content=content,
|
|
219
|
+
attachments=attachments,
|
|
220
|
+
tool_call_id=tool_call_id,
|
|
221
|
+
tool_calls=tool_calls
|
|
222
|
+
)
|
|
223
|
+
|
|
224
|
+
def get_text_content(content: Union[str, List[MessageContentPart]]) -> str:
|
|
225
|
+
"""Extract text content from message content."""
|
|
226
|
+
if isinstance(content, str):
|
|
227
|
+
return content
|
|
228
|
+
|
|
229
|
+
text_parts = [part.text for part in content if part.type == 'text' and part.text]
|
|
230
|
+
return ' '.join(text_parts)
|
|
153
231
|
|
|
154
232
|
@dataclass(frozen=True)
|
|
155
233
|
class ModelConfig:
|
|
@@ -179,6 +257,11 @@ class Tool(Protocol[Args, Ctx]):
|
|
|
179
257
|
"""Execute the tool with given arguments and context."""
|
|
180
258
|
...
|
|
181
259
|
|
|
260
|
+
@property
|
|
261
|
+
def needs_approval(self) -> Union[bool, Callable[[Ctx, Args], Union[bool, Awaitable[bool]]]]:
|
|
262
|
+
"""Whether this tool requires approval before execution."""
|
|
263
|
+
return False
|
|
264
|
+
|
|
182
265
|
|
|
183
266
|
# Function tool configuration for improved DX
|
|
184
267
|
class FunctionToolConfig(TypedDict):
|
|
@@ -248,6 +331,13 @@ class Agent(Generic[Ctx, Out]):
|
|
|
248
331
|
# Guardrail type
|
|
249
332
|
Guardrail = Callable[[Any], Union[ValidationResult, Awaitable[ValidationResult]]]
|
|
250
333
|
|
|
334
|
+
@dataclass(frozen=True)
|
|
335
|
+
class ApprovalValue:
|
|
336
|
+
"""Represents an approval decision with context."""
|
|
337
|
+
status: str # 'pending', 'approved', 'rejected'
|
|
338
|
+
approved: bool
|
|
339
|
+
additional_context: Optional[Dict[str, Any]] = None
|
|
340
|
+
|
|
251
341
|
@dataclass(frozen=True)
|
|
252
342
|
class RunState(Generic[Ctx]):
|
|
253
343
|
"""Immutable state of a run."""
|
|
@@ -257,6 +347,7 @@ class RunState(Generic[Ctx]):
|
|
|
257
347
|
current_agent_name: str
|
|
258
348
|
context: Ctx
|
|
259
349
|
turn_count: int
|
|
350
|
+
approvals: Dict[str, ApprovalValue] = field(default_factory=dict)
|
|
260
351
|
|
|
261
352
|
# Error types using dataclasses for immutability
|
|
262
353
|
@dataclass(frozen=True)
|
|
@@ -335,6 +426,18 @@ class NetworkError:
|
|
|
335
426
|
is_retryable: bool = True
|
|
336
427
|
endpoint: Optional[str] = None
|
|
337
428
|
|
|
429
|
+
# Interruption types for HITL
|
|
430
|
+
@dataclass(frozen=True)
|
|
431
|
+
class ToolApprovalInterruption(Generic[Ctx]):
|
|
432
|
+
"""Interruption for tool approval."""
|
|
433
|
+
type: Literal['tool_approval'] = 'tool_approval'
|
|
434
|
+
tool_call: ToolCall = field(default_factory=lambda: ToolCall("", "function", ToolCallFunction("", "")))
|
|
435
|
+
agent: 'Agent[Ctx, Any]' = None
|
|
436
|
+
session_id: Optional[str] = None
|
|
437
|
+
|
|
438
|
+
# Union type for all interruptions
|
|
439
|
+
Interruption = Union[ToolApprovalInterruption[Any]]
|
|
440
|
+
|
|
338
441
|
# Union type for all possible errors
|
|
339
442
|
JAFError = Union[
|
|
340
443
|
MaxTurnsExceeded,
|
|
@@ -363,8 +466,14 @@ class ErrorOutcome:
|
|
|
363
466
|
status: Literal['error'] = 'error'
|
|
364
467
|
error: JAFError = field(default=None)
|
|
365
468
|
|
|
469
|
+
@dataclass(frozen=True)
|
|
470
|
+
class InterruptedOutcome:
|
|
471
|
+
"""Interrupted outcome for HITL."""
|
|
472
|
+
status: Literal['interrupted'] = 'interrupted'
|
|
473
|
+
interruptions: List[Interruption] = field(default_factory=list)
|
|
474
|
+
|
|
366
475
|
# Union type for outcomes
|
|
367
|
-
RunOutcome = Union[CompletedOutcome[Out], ErrorOutcome]
|
|
476
|
+
RunOutcome = Union[CompletedOutcome[Out], ErrorOutcome, InterruptedOutcome]
|
|
368
477
|
|
|
369
478
|
@dataclass(frozen=True)
|
|
370
479
|
class RunResult(Generic[Out]):
|
|
@@ -416,6 +525,15 @@ class LLMCallEndEvent:
|
|
|
416
525
|
data: LLMCallEndEventData = field(default_factory=lambda: LLMCallEndEventData(None, TraceId(""), RunId("")))
|
|
417
526
|
|
|
418
527
|
@dataclass(frozen=True)
|
|
528
|
+
class AssistantMessageEventData:
|
|
529
|
+
"""Data for assistant message events (partial or complete)."""
|
|
530
|
+
message: Message
|
|
531
|
+
|
|
532
|
+
@dataclass(frozen=True)
|
|
533
|
+
class AssistantMessageEvent:
|
|
534
|
+
type: Literal['assistant_message'] = 'assistant_message'
|
|
535
|
+
data: AssistantMessageEventData = field(default_factory=lambda: AssistantMessageEventData(Message(role=ContentRole.ASSISTANT, content="")))
|
|
536
|
+
@dataclass(frozen=True)
|
|
419
537
|
class ToolCallStartEventData:
|
|
420
538
|
"""Data for tool call start events."""
|
|
421
539
|
tool_name: str
|
|
@@ -515,6 +633,7 @@ TraceEvent = Union[
|
|
|
515
633
|
OutputParseEvent,
|
|
516
634
|
LLMCallStartEvent,
|
|
517
635
|
LLMCallEndEvent,
|
|
636
|
+
AssistantMessageEvent,
|
|
518
637
|
ToolCallStartEvent,
|
|
519
638
|
ToolCallEndEvent,
|
|
520
639
|
HandoffEvent,
|
|
@@ -532,6 +651,30 @@ class ModelCompletionResponse:
|
|
|
532
651
|
"""Response structure from model completion."""
|
|
533
652
|
message: Optional[ModelCompletionMessage] = None
|
|
534
653
|
|
|
654
|
+
# Streaming chunk structures for provider-level streaming support
|
|
655
|
+
@dataclass(frozen=True)
|
|
656
|
+
class ToolCallFunctionDelta:
|
|
657
|
+
"""Function fields that may stream as deltas."""
|
|
658
|
+
name: Optional[str] = None
|
|
659
|
+
arguments_delta: Optional[str] = None
|
|
660
|
+
|
|
661
|
+
@dataclass(frozen=True)
|
|
662
|
+
class ToolCallDelta:
|
|
663
|
+
"""Represents a partial tool call delta in a streamed response."""
|
|
664
|
+
index: int
|
|
665
|
+
id: Optional[str] = None
|
|
666
|
+
type: Literal['function'] = 'function'
|
|
667
|
+
function: Optional[ToolCallFunctionDelta] = None
|
|
668
|
+
|
|
669
|
+
@dataclass(frozen=True)
|
|
670
|
+
class CompletionStreamChunk:
|
|
671
|
+
"""A streamed chunk from the model provider."""
|
|
672
|
+
delta: Optional[str] = None
|
|
673
|
+
tool_call_delta: Optional[ToolCallDelta] = None
|
|
674
|
+
is_done: Optional[bool] = False
|
|
675
|
+
finish_reason: Optional[str] = None
|
|
676
|
+
raw: Optional[Any] = None
|
|
677
|
+
|
|
535
678
|
@runtime_checkable
|
|
536
679
|
class ModelProvider(Protocol[Ctx]):
|
|
537
680
|
"""Protocol for model providers."""
|
|
@@ -545,6 +688,15 @@ class ModelProvider(Protocol[Ctx]):
|
|
|
545
688
|
"""Get completion from the model."""
|
|
546
689
|
...
|
|
547
690
|
|
|
691
|
+
async def get_completion_stream(
|
|
692
|
+
self,
|
|
693
|
+
state: RunState[Ctx],
|
|
694
|
+
agent: Agent[Ctx, Any],
|
|
695
|
+
config: 'RunConfig[Ctx]'
|
|
696
|
+
) -> AsyncIterator[CompletionStreamChunk]:
|
|
697
|
+
"""Optional streaming API: yields incremental deltas while generating."""
|
|
698
|
+
...
|
|
699
|
+
|
|
548
700
|
@dataclass(frozen=True)
|
|
549
701
|
class RunConfig(Generic[Ctx]):
|
|
550
702
|
"""Configuration for running agents."""
|
|
@@ -558,3 +710,4 @@ class RunConfig(Generic[Ctx]):
|
|
|
558
710
|
memory: Optional['MemoryConfig'] = None
|
|
559
711
|
conversation_id: Optional[str] = None
|
|
560
712
|
default_tool_timeout: Optional[float] = 30.0 # Default timeout for tool execution in seconds
|
|
713
|
+
approval_storage: Optional['ApprovalStorage'] = None # Storage for approval decisions
|