jaf-py 2.4.1__py3-none-any.whl → 2.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jaf/core/state.py ADDED
@@ -0,0 +1,156 @@
1
+ """
2
+ State management functions for approval handling in HITL scenarios.
3
+
4
+ This module provides functions to manage approval state transitions
5
+ and integrate with approval storage systems.
6
+ """
7
+
8
+ from typing import Dict, Any, Optional
9
+ from dataclasses import replace
10
+
11
+ from .types import RunState, RunConfig, Interruption, ApprovalValue
12
+
13
+
14
async def approve(
    state: RunState[Any],
    interruption: Interruption,
    additional_context: Optional[Dict[str, Any]] = None,
    config: Optional[RunConfig[Any]] = None
) -> RunState[Any]:
    """
    Record an approval for an interrupted tool call.

    Only interruptions whose ``type`` is ``'tool_approval'`` are handled; for
    any other interruption the state is returned unchanged.

    Args:
        state: Current run state
        interruption: The interruption being approved
        additional_context: Extra key/value pairs stored with the decision;
            a ``'status'`` key is always set to ``'approved'``
        config: Optional run configuration; when it carries an
            ``approval_storage`` the decision is also written there

    Returns:
        A copy of ``state`` whose ``approvals`` map contains the new decision
        keyed by the tool call id, or ``state`` itself when the interruption
        is not a tool approval.
    """
    # Guard clause: nothing to record for other interruption kinds.
    if interruption.type != 'tool_approval':
        return state

    # 'status' is written last so it always wins over a caller-supplied key.
    context = dict(additional_context or {})
    context['status'] = 'approved'
    approval_value = ApprovalValue(
        status='approved',
        approved=True,
        additional_context=context
    )

    # Persist the decision when a storage backend is configured. Any failure
    # here is only reported; the in-memory update below still happens.
    storage = config.approval_storage if config else None
    if storage:
        try:
            print(f"[JAF:APPROVAL] Storing approval for tool_call_id {interruption.tool_call.id}: {approval_value}")
            result = await storage.store_approval(
                state.run_id,
                interruption.tool_call.id,
                approval_value
            )
            if result.success:
                print("[JAF:APPROVAL] Successfully stored approval in storage")
            else:
                print(f"[JAF:APPROVAL] Failed to store approval: {result.error}")
        except Exception as e:
            print(f"[JAF:APPROVAL] Approval storage error: {e}")

    # Build a new approvals map instead of mutating the one held by `state`.
    updated_approvals = {**state.approvals, interruption.tool_call.id: approval_value}
    return replace(state, approvals=updated_approvals)
67
+
68
+
69
async def reject(
    state: RunState[Any],
    interruption: Interruption,
    additional_context: Optional[Dict[str, Any]] = None,
    config: Optional[RunConfig[Any]] = None
) -> RunState[Any]:
    """
    Record a rejection for an interrupted tool call.

    Only interruptions whose ``type`` is ``'tool_approval'`` are handled; for
    any other interruption the state is returned unchanged.

    Args:
        state: Current run state
        interruption: The interruption being rejected
        additional_context: Extra key/value pairs stored with the decision;
            a ``'status'`` key is always set to ``'rejected'``
        config: Optional run configuration; when it carries an
            ``approval_storage`` the decision is also written there

    Returns:
        A copy of ``state`` whose ``approvals`` map contains the new decision
        keyed by the tool call id, or ``state`` itself when the interruption
        is not a tool approval.
    """
    # Guard clause: nothing to record for other interruption kinds.
    if interruption.type != 'tool_approval':
        return state

    # 'status' is written last so it always wins over a caller-supplied key.
    context = dict(additional_context or {})
    context['status'] = 'rejected'
    approval_value = ApprovalValue(
        status='rejected',
        approved=False,
        additional_context=context
    )

    # Persist the decision when a storage backend is configured. Any failure
    # here is only reported; the in-memory update below still happens.
    storage = config.approval_storage if config else None
    if storage:
        try:
            print(f"[JAF:APPROVAL] Storing approval for tool_call_id {interruption.tool_call.id}: {approval_value}")
            result = await storage.store_approval(
                state.run_id,
                interruption.tool_call.id,
                approval_value
            )
            if result.success:
                print("[JAF:APPROVAL] Successfully stored approval in storage")
            else:
                print(f"[JAF:APPROVAL] Failed to store approval: {result.error}")
        except Exception as e:
            print(f"[JAF:APPROVAL] Approval storage error: {e}")

    # Build a new approvals map instead of mutating the one held by `state`.
    updated_approvals = {**state.approvals, interruption.tool_call.id: approval_value}
    return replace(state, approvals=updated_approvals)
122
+
123
+
124
async def load_approvals_into_state(
    state: RunState[Any],
    config: Optional[RunConfig[Any]] = None
) -> RunState[Any]:
    """
    Refresh the run state's approvals from the configured approval storage.

    Args:
        state: Current run state
        config: Optional run configuration with approval storage

    Returns:
        A copy of ``state`` whose ``approvals`` are replaced by the stored
        ones when the lookup succeeds and returns a non-empty result.
        In every other case (no storage configured, unsuccessful lookup,
        empty result, or an exception during the lookup) ``state`` is
        returned unchanged.
    """
    storage = config.approval_storage if config else None
    if not storage:
        print(f"[JAF:APPROVAL] No approval storage configured, using existing approvals: {state.approvals}")
        return state

    try:
        print(f"[JAF:APPROVAL] Loading approvals from storage for run_id: {state.run_id}")
        result = await storage.get_run_approvals(state.run_id)

        if not result.success:
            print(f"[JAF:APPROVAL] Failed to load approvals: {result.error}")
            return state

        stored = result.data
        if not stored:
            print(f"[JAF:APPROVAL] No approvals found in storage for run_id: {state.run_id}")
            return state

        print(f"[JAF:APPROVAL] Loaded {len(stored)} approvals from storage: {stored}")
        # The stored approvals replace the in-memory map; they are not merged.
        return replace(state, approvals=stored)
    except Exception as e:
        # Best effort: a storage problem must not break the run.
        print(f"[JAF:APPROVAL] Approval loading error: {e}")
        return state
jaf/core/tracing.py CHANGED
@@ -344,6 +344,9 @@ class LangfuseTraceCollector:
344
344
  )
345
345
  self.active_spans: Dict[str, Any] = {}
346
346
  self.trace_spans: Dict[TraceId, Any] = {}
347
+ # Track tool calls and results for each trace
348
+ self.trace_tool_calls: Dict[TraceId, List[Dict[str, Any]]] = {}
349
+ self.trace_tool_results: Dict[TraceId, List[Dict[str, Any]]] = {}
347
350
 
348
351
  def collect(self, event: TraceEvent) -> None:
349
352
  """Collect a trace event and send it to Langfuse."""
@@ -359,9 +362,14 @@ class LangfuseTraceCollector:
359
362
  # Start a new trace for the entire run
360
363
  print(f"[LANGFUSE] Starting trace for run: {trace_id}")
361
364
 
365
+ # Initialize tracking for this trace
366
+ self.trace_tool_calls[trace_id] = []
367
+ self.trace_tool_results[trace_id] = []
368
+
362
369
  # Extract user query from the run_start data
363
370
  user_query = None
364
371
  user_id = None
372
+ conversation_history = []
365
373
 
366
374
  # Debug: Print the event data structure to understand what we're working with
367
375
  if event.data.get("context"):
@@ -401,24 +409,51 @@ class LangfuseTraceCollector:
401
409
  user_id = token_response.email
402
410
  print(f"[LANGFUSE DEBUG] Extracted user_id from attr: {user_id}")
403
411
 
404
- # Fallback: try to extract from messages if context didn't work
405
- if not user_query and event.data.get("messages"):
406
- print(f"[LANGFUSE DEBUG] Trying fallback from messages")
407
- messages = event.data["messages"]
408
- print(f"[LANGFUSE DEBUG] Found {len(messages)} messages")
409
- # Find the last user message which should be the current query
410
- for i, msg in enumerate(reversed(messages)):
411
- print(f"[LANGFUSE DEBUG] Message {i}: {msg}")
412
- if isinstance(msg, dict) and msg.get("role") == "user":
413
- user_query = msg.get("content", "")
414
- print(f"[LANGFUSE DEBUG] Found user_query from messages: {user_query}")
415
- break
416
- elif hasattr(msg, 'role') and msg.role == 'user':
417
- user_query = msg.content
418
- print(f"[LANGFUSE DEBUG] Found user_query from message attr: {user_query}")
419
- break
412
+ # Extract conversation history and current user query from messages
413
+ messages = event.data.get("messages", [])
414
+ if messages:
415
+ print(f"[LANGFUSE DEBUG] Processing {len(messages)} messages")
416
+
417
+ # Find the last user message (current query) and extract conversation history (excluding current)
418
+ current_user_message_found = False
419
+ for i in range(len(messages) - 1, -1, -1):
420
+ msg = messages[i]
421
+
422
+ if isinstance(msg, dict):
423
+ role = msg.get("role")
424
+ content = msg.get("content", "")
425
+ elif hasattr(msg, 'role'):
426
+ role = msg.role
427
+ content = getattr(msg, 'content', "")
428
+ # Handle both string content and complex content structures
429
+ if not isinstance(content, str):
430
+ # Try to extract text from complex content
431
+ if hasattr(content, '__iter__') and not isinstance(content, str):
432
+ try:
433
+ # If it's a list, try to join text parts
434
+ content = " ".join(str(item) for item in content if item)
435
+ except:
436
+ content = str(content)
437
+ else:
438
+ content = str(content)
439
+ else:
440
+ continue
441
+
442
+ # If we haven't found the current user message yet and this is a user message
443
+ if not current_user_message_found and (role == "user" or role == 'user'):
444
+ user_query = content
445
+ current_user_message_found = True
446
+ print(f"[LANGFUSE DEBUG] Found current user query: {user_query}")
447
+ elif current_user_message_found:
448
+ # Add to conversation history (excluding the current user message)
449
+ conversation_history.insert(0, {
450
+ "role": role,
451
+ "content": content,
452
+ "timestamp": datetime.now().isoformat() if not hasattr(msg, 'timestamp') else getattr(msg, 'timestamp', datetime.now().isoformat())
453
+ })
420
454
 
421
455
  print(f"[LANGFUSE DEBUG] Final extracted - user_query: {user_query}, user_id: {user_id}")
456
+ print(f"[LANGFUSE DEBUG] Conversation history length: {len(conversation_history)}")
422
457
 
423
458
  # Create comprehensive input data for the trace
424
459
  trace_input = {
@@ -437,31 +472,58 @@ class LangfuseTraceCollector:
437
472
  session_id=event.data.get("session_id"),
438
473
  input=trace_input,
439
474
  metadata={
440
- "framework": "jaf",
441
- "event_type": "run_start",
475
+ "framework": "jaf",
476
+ "event_type": "run_start",
442
477
  "trace_id": str(trace_id),
443
478
  "user_query": user_query,
444
479
  "user_id": user_id or event.data.get("user_id"),
445
- "agent_name": event.data.get("agent_name", "analytics_agent_jaf")
480
+ "agent_name": event.data.get("agent_name", "analytics_agent_jaf"),
481
+ "conversation_history": conversation_history,
482
+ "tool_calls": [],
483
+ "tool_results": []
446
484
  }
447
485
  )
448
486
  self.trace_spans[trace_id] = trace
449
- # Store user_id and user_query for later use in generations
487
+ # Store user_id, user_query, and conversation_history for later use
450
488
  trace._user_id = user_id or event.data.get("user_id")
451
489
  trace._user_query = user_query
490
+ trace._conversation_history = conversation_history
452
491
  print(f"[LANGFUSE] Created trace with user query: {user_query[:100] if user_query else 'None'}...")
453
492
 
454
493
  elif event.type == "run_end":
455
494
  if trace_id in self.trace_spans:
456
495
  print(f"[LANGFUSE] Ending trace for run: {trace_id}")
457
- # End the trace
458
- self.trace_spans[trace_id].update(output=event.data)
496
+
497
+ # Update the trace metadata with final tool calls and results
498
+ final_metadata = {
499
+ "framework": "jaf",
500
+ "event_type": "run_end",
501
+ "trace_id": str(trace_id),
502
+ "user_query": getattr(self.trace_spans[trace_id], '_user_query', None),
503
+ "user_id": getattr(self.trace_spans[trace_id], '_user_id', None),
504
+ "agent_name": event.data.get("agent_name", "analytics_agent_jaf"),
505
+ "conversation_history": getattr(self.trace_spans[trace_id], '_conversation_history', []),
506
+ "tool_calls": self.trace_tool_calls.get(trace_id, []),
507
+ "tool_results": self.trace_tool_results.get(trace_id, [])
508
+ }
509
+
510
+ # End the trace with updated metadata
511
+ self.trace_spans[trace_id].update(
512
+ output=event.data,
513
+ metadata=final_metadata
514
+ )
515
+
459
516
  # Flush to ensure data is sent
460
517
  print(f"[LANGFUSE] Flushing data to Langfuse...")
461
518
  self.langfuse.flush()
462
519
  print(f"[LANGFUSE] Flush completed")
520
+
463
521
  # Clean up
464
522
  del self.trace_spans[trace_id]
523
+ if trace_id in self.trace_tool_calls:
524
+ del self.trace_tool_calls[trace_id]
525
+ if trace_id in self.trace_tool_results:
526
+ del self.trace_tool_results[trace_id]
465
527
  else:
466
528
  print(f"[LANGFUSE] No trace found for run_end: {trace_id}")
467
529
 
@@ -549,6 +611,20 @@ class LangfuseTraceCollector:
549
611
 
550
612
  print(f"[LANGFUSE] Starting span for tool call: {tool_name}")
551
613
 
614
+ # Track this tool call for the trace
615
+ tool_call_data = {
616
+ "tool_name": tool_name,
617
+ "arguments": tool_args,
618
+ "call_id": event.data.get("call_id"),
619
+ "timestamp": datetime.now().isoformat()
620
+ }
621
+
622
+ # Ensure trace_id exists in tracking
623
+ if trace_id not in self.trace_tool_calls:
624
+ self.trace_tool_calls[trace_id] = []
625
+
626
+ self.trace_tool_calls[trace_id].append(tool_call_data)
627
+
552
628
  # Create comprehensive input data for the tool call
553
629
  tool_input = {
554
630
  "tool_name": tool_name,
@@ -579,13 +655,28 @@ class LangfuseTraceCollector:
579
655
 
580
656
  print(f"[LANGFUSE] Ending span for tool call: {tool_name}")
581
657
 
658
+ # Track this tool result for the trace
659
+ tool_result_data = {
660
+ "tool_name": tool_name,
661
+ "result": tool_result,
662
+ "call_id": event.data.get("call_id"),
663
+ "timestamp": datetime.now().isoformat(),
664
+ "status": event.data.get("status", "completed"),
665
+ "tool_result": event.data.get("tool_result")
666
+ }
667
+
668
+ if trace_id not in self.trace_tool_results:
669
+ self.trace_tool_results[trace_id] = []
670
+
671
+ self.trace_tool_results[trace_id].append(tool_result_data)
672
+
582
673
  # Create comprehensive output data for the tool call
583
674
  tool_output = {
584
675
  "tool_name": tool_name,
585
676
  "result": tool_result,
586
677
  "call_id": event.data.get("call_id"),
587
678
  "timestamp": datetime.now().isoformat(),
588
- "status": "completed"
679
+ "status": event.data.get("status", "completed")
589
680
  }
590
681
 
591
682
  # End the span with detailed output
jaf/core/types.py CHANGED
@@ -11,6 +11,7 @@ from collections.abc import Awaitable, AsyncIterator
11
11
  from dataclasses import dataclass, field
12
12
  from typing import (
13
13
  Any,
14
+ Awaitable,
14
15
  Callable,
15
16
  Dict,
16
17
  Generic,
@@ -28,6 +29,8 @@ from enum import Enum
28
29
 
29
30
  if TYPE_CHECKING:
30
31
  from .tool_results import ToolResult
32
+ from ..memory.approval_storage import ApprovalStorage
33
+ from ..memory.types import MemoryConfig
31
34
 
32
35
 
33
36
  # Comprehensive enums for type safety and improved developer experience
@@ -143,13 +146,88 @@ class ToolCallFunction:
143
146
  name: str
144
147
  arguments: str
145
148
 
149
@dataclass(frozen=True)
class Attachment:
    """
    An immutable attachment (image, document or generic file).

    The payload is supplied through ``url`` and/or ``data``; construction
    fails with ``ValueError`` when both are ``None``.
    """
    kind: Literal['image', 'document', 'file']
    # MIME type such as image/png or application/pdf.
    mime_type: Optional[str] = None
    # Filename, if any.
    name: Optional[str] = None
    # Remote URL or data URL.
    url: Optional[str] = None
    # Base64 payload without the "data:" prefix.
    data: Optional[str] = None
    # Short format tag such as 'pdf' or 'txt'.
    format: Optional[str] = None
    # Whether to use the LiteLLM native file format.
    use_litellm_format: Optional[bool] = None

    def __post_init__(self):
        """Ensure the attachment carries a URL, inline data, or both."""
        has_payload = self.url is not None or self.data is not None
        if not has_payload:
            raise ValueError("At least one of 'url' or 'data' must be provided for an Attachment.")
164
+
165
@dataclass(frozen=True)
class MessageContentPart:
    """One immutable element of a multi-part message body."""
    # Discriminator selecting which of the payload fields below is relevant.
    type: Literal['text', 'image_url', 'file']
    # Plain text payload.
    text: Optional[str] = None
    # Image payload: holds the url and an optional detail setting.
    image_url: Optional[Dict[str, Any]] = None
    # File payload: holds the file_id and an optional format.
    file: Optional[Dict[str, Any]] = None
172
+
146
173
  @dataclass(frozen=True)
147
174
  class Message:
148
- """A message in the conversation."""
175
+ """
176
+ A message in the conversation.
177
+
178
+ BACKWARDS COMPATIBILITY:
179
+ - Messages created with string content remain fully backwards compatible
180
+ - Direct access to .content returns the original string when created with string
181
+ - Use .text_content property for guaranteed string access in all cases
182
+ - Use get_text_content() function to extract text from any content type
183
+
184
+ Examples:
185
+ # Original usage - still works exactly the same
186
+ msg = Message(role='user', content='Hello')
187
+ text = msg.content # Returns 'Hello' as string
188
+
189
+ # Guaranteed string access (recommended for new code)
190
+ text = msg.text_content # Always returns string
191
+
192
+ # Universal text extraction
193
+ text = get_text_content(msg.content) # Works with any content type
194
+ """
149
195
  role: ContentRole
150
- content: str
196
+ content: Union[str, List[MessageContentPart]]
197
+ attachments: Optional[List[Attachment]] = None
151
198
  tool_call_id: Optional[str] = None
152
199
  tool_calls: Optional[List[ToolCall]] = None
200
+
201
+ @property
202
+ def text_content(self) -> str:
203
+ """Get text content as string for backwards compatibility."""
204
+ return get_text_content(self.content)
205
+
206
+ @classmethod
207
+ def create(
208
+ cls,
209
+ role: ContentRole,
210
+ content: str,
211
+ attachments: Optional[List[Attachment]] = None,
212
+ tool_call_id: Optional[str] = None,
213
+ tool_calls: Optional[List[ToolCall]] = None
214
+ ) -> 'Message':
215
+ """Create a message with string content and optional attachments."""
216
+ return cls(
217
+ role=role,
218
+ content=content,
219
+ attachments=attachments,
220
+ tool_call_id=tool_call_id,
221
+ tool_calls=tool_calls
222
+ )
223
+
224
def get_text_content(content: Optional[Union[str, List[MessageContentPart]]]) -> str:
    """
    Extract the plain text from message content.

    Args:
        content: A plain string, a list of content parts, or ``None``.
            Parts may be objects exposing ``type``/``text`` attributes or
            dicts with ``'type'``/``'text'`` keys.

    Returns:
        The string itself for string content; otherwise the non-empty text of
        every ``'text'`` part joined by single spaces. ``None`` or an empty
        list yields ``''``.
    """
    # Missing content used to raise TypeError; treat it as "no text".
    if content is None:
        return ''
    if isinstance(content, str):
        return content

    text_parts = []
    for part in content:
        # Accept dict-shaped parts as well as attribute-style part objects.
        if isinstance(part, dict):
            part_type, text = part.get('type'), part.get('text')
        else:
            part_type, text = part.type, part.text
        if part_type == 'text' and text:
            text_parts.append(text)
    return ' '.join(text_parts)
153
231
 
154
232
  @dataclass(frozen=True)
155
233
  class ModelConfig:
@@ -179,6 +257,11 @@ class Tool(Protocol[Args, Ctx]):
179
257
  """Execute the tool with given arguments and context."""
180
258
  ...
181
259
 
260
+ @property
261
+ def needs_approval(self) -> Union[bool, Callable[[Ctx, Args], Union[bool, Awaitable[bool]]]]:
262
+ """Whether this tool requires approval before execution."""
263
+ return False
264
+
182
265
 
183
266
  # Function tool configuration for improved DX
184
267
  class FunctionToolConfig(TypedDict):
@@ -248,6 +331,13 @@ class Agent(Generic[Ctx, Out]):
248
331
  # Guardrail type
249
332
  Guardrail = Callable[[Any], Union[ValidationResult, Awaitable[ValidationResult]]]
250
333
 
334
@dataclass(frozen=True)
class ApprovalValue:
    """Immutable record of an approval decision and its context."""
    # One of 'pending', 'approved', 'rejected'.
    status: str
    # Boolean form of the decision.
    approved: bool
    # Free-form extra data attached to the decision, if any.
    additional_context: Optional[Dict[str, Any]] = None
340
+
251
341
  @dataclass(frozen=True)
252
342
  class RunState(Generic[Ctx]):
253
343
  """Immutable state of a run."""
@@ -257,6 +347,7 @@ class RunState(Generic[Ctx]):
257
347
  current_agent_name: str
258
348
  context: Ctx
259
349
  turn_count: int
350
+ approvals: Dict[str, ApprovalValue] = field(default_factory=dict)
260
351
 
261
352
  # Error types using dataclasses for immutability
262
353
  @dataclass(frozen=True)
@@ -335,6 +426,18 @@ class NetworkError:
335
426
  is_retryable: bool = True
336
427
  endpoint: Optional[str] = None
337
428
 
429
+ # Interruption types for HITL
430
@dataclass(frozen=True)
class ToolApprovalInterruption(Generic[Ctx]):
    """
    Interruption raised when a tool call is waiting for approval.

    Every field has a default, so an instance can be built with no arguments.
    """
    # Discriminator identifying this interruption kind.
    type: Literal['tool_approval'] = 'tool_approval'
    # The tool call awaiting a decision; defaults to a placeholder call with
    # empty id, name and arguments.
    tool_call: ToolCall = field(default_factory=lambda: ToolCall("", "function", ToolCallFunction("", "")))
    # Agent associated with the interruption. Declared Optional because the
    # default is None.
    agent: Optional['Agent[Ctx, Any]'] = None
    # Session identifier, if one is available.
    session_id: Optional[str] = None
437
+
438
+ # Union type for all interruptions
439
+ Interruption = Union[ToolApprovalInterruption[Any]]
440
+
338
441
  # Union type for all possible errors
339
442
  JAFError = Union[
340
443
  MaxTurnsExceeded,
@@ -363,8 +466,14 @@ class ErrorOutcome:
363
466
  status: Literal['error'] = 'error'
364
467
  error: JAFError = field(default=None)
365
468
 
469
@dataclass(frozen=True)
class InterruptedOutcome:
    """Run outcome for a human-in-the-loop interruption."""
    # Discriminator for this outcome; always 'interrupted'.
    status: Literal['interrupted'] = 'interrupted'
    # Interruptions carried by this outcome; a fresh empty list by default.
    interruptions: List[Interruption] = field(default_factory=list)
474
+
366
475
  # Union type for outcomes
367
- RunOutcome = Union[CompletedOutcome[Out], ErrorOutcome]
476
+ RunOutcome = Union[CompletedOutcome[Out], ErrorOutcome, InterruptedOutcome]
368
477
 
369
478
  @dataclass(frozen=True)
370
479
  class RunResult(Generic[Out]):
@@ -601,3 +710,4 @@ class RunConfig(Generic[Ctx]):
601
710
  memory: Optional['MemoryConfig'] = None
602
711
  conversation_id: Optional[str] = None
603
712
  default_tool_timeout: Optional[float] = 30.0 # Default timeout for tool execution in seconds
713
+ approval_storage: Optional['ApprovalStorage'] = None # Storage for approval decisions