jaf-py 2.5.2__py3-none-any.whl → 2.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jaf/core/tracing.py CHANGED
@@ -443,10 +443,18 @@ class LangfuseTraceCollector:
443
443
  public_key=public_key,
444
444
  secret_key=secret_key,
445
445
  host=host,
446
- release="jaf-py-v2.5.2",
446
+ release="jaf-py-v2.5.4",
447
447
  httpx_client=client
448
448
  )
449
449
  self._httpx_client = client
450
+
451
+ # Detect Langfuse version (v2 has trace() method, v3 does not)
452
+ self._is_langfuse_v3 = not hasattr(self.langfuse, 'trace')
453
+ if self._is_langfuse_v3:
454
+ print("[LANGFUSE] Detected Langfuse v3.x - using OpenTelemetry-based API")
455
+ else:
456
+ print("[LANGFUSE] Detected Langfuse v2.x - using legacy API")
457
+
450
458
  self.active_spans: Dict[str, Any] = {}
451
459
  self.trace_spans: Dict[TraceId, Any] = {}
452
460
  # Track tool calls and results for each trace
@@ -465,6 +473,113 @@ class LangfuseTraceCollector:
465
473
  except Exception as e:
466
474
  print(f"[LANGFUSE] Warning: Failed to close httpx client: {e}")
467
475
 
476
def _get_event_data(self, event: TraceEvent, key: str, default: Any = None) -> Any:
    """Look up ``key`` in ``event.data`` for mapping- or object-style payloads.

    Args:
        event: Trace event whose ``data`` attribute is inspected.
        key: Dictionary key (for dict payloads) or attribute name (otherwise).
        default: Value returned when the event has no ``data`` attribute or
            the key/attribute is absent.

    Returns:
        The stored value, or ``default`` when it cannot be found.
    """
    missing = object()
    payload = getattr(event, 'data', missing)
    if payload is missing:
        return default

    # Dict payloads are read by key; anything else is read by attribute.
    if isinstance(payload, dict):
        return payload.get(key, default)
    return getattr(payload, key, default)
487
+
488
def _create_trace(self, trace_id: TraceId, **kwargs) -> Any:
    """Create a trace through whichever Langfuse API was detected.

    Args:
        trace_id: Identifier of the trace being created. It is accepted for
            call-site symmetry and is not forwarded to Langfuse by this method.
        **kwargs: Trace parameters. On the v2 path they are forwarded
            unchanged to ``self.langfuse.trace``. On the v3 path only
            ``name``, ``input``, ``metadata``, ``user_id``, ``session_id``
            and ``tags`` are read.

    Returns:
        The object returned by ``self.langfuse.trace`` (v2) or by
        ``self.langfuse.start_span`` (v3).
    """
    if not self._is_langfuse_v3:
        # Langfuse v2: the client exposes trace() directly.
        return self.langfuse.trace(**kwargs)

    # Langfuse v3: a root span is started and used as the trace handle.
    name = kwargs.get('name', 'trace')
    input_data = kwargs.get('input')

    # Only truthy trace-level attributes are propagated.
    trace_attrs = {
        key: value
        for key, value in (
            ('user_id', kwargs.get('user_id')),
            ('session_id', kwargs.get('session_id')),
            ('tags', kwargs.get('tags', [])),
        )
        if value
    }

    # Work on a copy so the caller's metadata dict is never mutated, and
    # tolerate an explicit metadata=None.
    metadata = dict(kwargs.get('metadata') or {})
    metadata.update(trace_attrs)

    trace = self.langfuse.start_span(
        name=name,
        input=input_data,
        metadata=metadata
    )

    # Mirror the same attributes onto the trace itself.
    if trace_attrs:
        trace.update_trace(**trace_attrs)

    return trace
531
+
532
+ def _create_generation(self, parent_span: Any, **kwargs) -> Any:
533
+ """Create a generation using the appropriate API for the Langfuse version."""
534
+ if self._is_langfuse_v3:
535
+ # Langfuse v3: Use start_generation() method
536
+ return parent_span.start_generation(**kwargs)
537
+ else:
538
+ # Langfuse v2: Use generation() method
539
+ return parent_span.generation(**kwargs)
540
+
541
+ def _create_span(self, parent_span: Any, **kwargs) -> Any:
542
+ """Create a span using the appropriate API for the Langfuse version."""
543
+ if self._is_langfuse_v3:
544
+ # Langfuse v3: Use start_span() method
545
+ return parent_span.start_span(**kwargs)
546
+ else:
547
+ # Langfuse v2: Use span() method
548
+ return parent_span.span(**kwargs)
549
+
550
+ def _create_event(self, parent_span: Any, **kwargs) -> Any:
551
+ """Create an event using the appropriate API for the Langfuse version."""
552
+ if self._is_langfuse_v3:
553
+ # Langfuse v3: Use create_event() method
554
+ return parent_span.create_event(**kwargs)
555
+ else:
556
+ # Langfuse v2: Use event() method
557
+ return parent_span.event(**kwargs)
558
+
559
+ def _end_span(self, span: Any, **kwargs) -> None:
560
+ """End a span/generation using the appropriate API for the Langfuse version."""
561
+ if self._is_langfuse_v3:
562
+ # Langfuse v3: Call update() first with output/metadata, then end()
563
+ update_params = {}
564
+ end_params = {}
565
+
566
+ # Separate parameters for update() vs end()
567
+ for key, value in kwargs.items():
568
+ if key in ['output', 'metadata', 'model', 'usage']:
569
+ update_params[key] = value
570
+ elif key == 'end_time':
571
+ end_params[key] = value
572
+
573
+ # Update first if there are parameters
574
+ if update_params:
575
+ span.update(**update_params)
576
+
577
+ # Then end
578
+ span.end(**end_params)
579
+ else:
580
+ # Langfuse v2: Call end() directly with all parameters
581
+ span.end(**kwargs)
582
+
468
583
  def collect(self, event: TraceEvent) -> None:
469
584
  """Collect a trace event and send it to Langfuse."""
470
585
  try:
@@ -489,15 +604,15 @@ class LangfuseTraceCollector:
489
604
  conversation_history = []
490
605
 
491
606
  # Debug: Print the event data structure to understand what we're working with
492
- if event.data.get("context"):
493
- context = event.data["context"]
607
+ if self._get_event_data(event, "context"):
608
+ context = self._get_event_data(event, "context")
494
609
  print(f"[LANGFUSE DEBUG] Context type: {type(context)}")
495
610
  print(f"[LANGFUSE DEBUG] Context attributes: {dir(context) if hasattr(context, '__dict__') else 'Not an object'}")
496
611
  if hasattr(context, '__dict__'):
497
612
  print(f"[LANGFUSE DEBUG] Context dict: {context.__dict__}")
498
613
 
499
614
  # Try to extract from context first
500
- context = event.data.get("context")
615
+ context = self._get_event_data(event, "context")
501
616
  if context:
502
617
  # Try direct attribute access
503
618
  if hasattr(context, 'query'):
@@ -527,7 +642,7 @@ class LangfuseTraceCollector:
527
642
  print(f"[LANGFUSE DEBUG] Extracted user_id from attr: {user_id}")
528
643
 
529
644
  # Extract conversation history and current user query from messages
530
- messages = event.data.get("messages", [])
645
+ messages = self._get_event_data(event, "messages", [])
531
646
  if messages:
532
647
  print(f"[LANGFUSE DEBUG] Processing {len(messages)} messages")
533
648
 
@@ -619,20 +734,22 @@ class LangfuseTraceCollector:
619
734
  trace_input = {
620
735
  "user_query": user_query,
621
736
  "run_id": str(trace_id),
622
- "agent_name": event.data.get("agent_name", "analytics_agent_jaf"),
737
+ "agent_name": self._get_event_data(event, "agent_name", "analytics_agent_jaf"),
623
738
  "session_info": {
624
- "session_id": event.data.get("session_id"),
625
- "user_id": user_id or event.data.get("user_id")
739
+ "session_id": self._get_event_data(event, "session_id"),
740
+ "user_id": user_id or self._get_event_data(event, "user_id")
626
741
  }
627
742
  }
628
743
 
629
744
  # Extract agent_name for tagging
630
- agent_name = event.data.get("agent_name") or "analytics_agent_jaf"
745
+ agent_name = self._get_event_data(event, "agent_name") or "analytics_agent_jaf"
631
746
 
632
- trace = self.langfuse.trace(
747
+ # Use compatibility layer to create trace (works with both v2 and v3)
748
+ trace = self._create_trace(
749
+ trace_id=trace_id,
633
750
  name=agent_name,
634
- user_id=user_id or event.data.get("user_id"),
635
- session_id=event.data.get("session_id"),
751
+ user_id=user_id or self._get_event_data(event, "user_id"),
752
+ session_id=self._get_event_data(event, "session_id"),
636
753
  input=trace_input,
637
754
  tags=[agent_name], # Add agent_name as a tag for dashboard filtering
638
755
  metadata={
@@ -640,17 +757,17 @@ class LangfuseTraceCollector:
640
757
  "event_type": "run_start",
641
758
  "trace_id": str(trace_id),
642
759
  "user_query": user_query,
643
- "user_id": user_id or event.data.get("user_id"),
760
+ "user_id": user_id or self._get_event_data(event, "user_id"),
644
761
  "agent_name": agent_name,
645
762
  "conversation_history": conversation_history,
646
763
  "tool_calls": [],
647
764
  "tool_results": [],
648
- "user_info": event.data.get("context").user_info if event.data.get("context") and hasattr(event.data.get("context"), 'user_info') else None
765
+ "user_info": self._get_event_data(event, "context").user_info if self._get_event_data(event, "context") and hasattr(self._get_event_data(event, "context"), 'user_info') else None
649
766
  }
650
767
  )
651
768
  self.trace_spans[trace_id] = trace
652
769
  # Store user_id, user_query, and conversation_history for later use
653
- trace._user_id = user_id or event.data.get("user_id")
770
+ trace._user_id = user_id or self._get_event_data(event, "user_id")
654
771
  trace._user_query = user_query
655
772
  trace._conversation_history = conversation_history
656
773
  print(f"[LANGFUSE] Created trace with user query: {user_query[:100] if user_query else 'None'}...")
@@ -667,7 +784,7 @@ class LangfuseTraceCollector:
667
784
  "trace_id": str(trace_id),
668
785
  "user_query": getattr(self.trace_spans[trace_id], '_user_query', None),
669
786
  "user_id": getattr(self.trace_spans[trace_id], '_user_id', None),
670
- "agent_name": event.data.get("agent_name", "analytics_agent_jaf"),
787
+ "agent_name": self._get_event_data(event, "agent_name", "analytics_agent_jaf"),
671
788
  "conversation_history": conversation_history,
672
789
  "tool_calls": self.trace_tool_calls.get(trace_id, []),
673
790
  "tool_results": self.trace_tool_results.get(trace_id, [])
@@ -695,7 +812,7 @@ class LangfuseTraceCollector:
695
812
 
696
813
  elif event.type == "llm_call_start":
697
814
  # Start a generation for LLM calls
698
- model = event.data.get("model", "unknown")
815
+ model = self._get_event_data(event, "model", "unknown")
699
816
  print(f"[LANGFUSE] Starting generation for LLM call with model: {model}")
700
817
 
701
818
  # Get stored user information from the trace
@@ -703,11 +820,13 @@ class LangfuseTraceCollector:
703
820
  user_id = getattr(trace, '_user_id', None)
704
821
  user_query = getattr(trace, '_user_query', None)
705
822
 
706
- generation = trace.generation(
823
+ # Use compatibility layer to create generation (works with both v2 and v3)
824
+ generation = self._create_generation(
825
+ parent_span=trace,
707
826
  name=f"llm-call-{model}",
708
- input=event.data.get("messages"),
827
+ input=self._get_event_data(event, "messages"),
709
828
  metadata={
710
- "agent_name": event.data.get("agent_name"),
829
+ "agent_name": self._get_event_data(event, "agent_name"),
711
830
  "model": model,
712
831
  "user_id": user_id,
713
832
  "user_query": user_query
@@ -723,10 +842,10 @@ class LangfuseTraceCollector:
723
842
  print(f"[LANGFUSE] Ending generation for LLM call")
724
843
  # End the generation
725
844
  generation = self.active_spans[span_id]
726
- choice = event.data.get("choice", {})
727
-
845
+ choice = self._get_event_data(event, "choice", {})
846
+
728
847
  # Extract usage from the event data
729
- usage = event.data.get("usage", {})
848
+ usage = self._get_event_data(event, "usage", {})
730
849
 
731
850
  # Extract model information from choice data or event data
732
851
  model = choice.get("model", "unknown")
@@ -752,8 +871,10 @@ class LangfuseTraceCollector:
752
871
  print(f"[LANGFUSE] Usage data for automatic cost calculation: {langfuse_usage}")
753
872
 
754
873
  # Include model information in the generation end - Langfuse will calculate costs automatically
755
- generation.end(
756
- output=choice,
874
+ # Use compatibility wrapper for ending spans/generations
875
+ self._end_span(
876
+ span=generation,
877
+ output=choice,
757
878
  usage=langfuse_usage,
758
879
  model=model, # Pass model directly for automatic cost calculation
759
880
  metadata={
@@ -772,9 +893,9 @@ class LangfuseTraceCollector:
772
893
 
773
894
  elif event.type == "tool_call_start":
774
895
  # Start a span for tool calls with detailed input information
775
- tool_name = event.data.get('tool_name', 'unknown')
776
- tool_args = event.data.get("args", {})
777
- call_id = event.data.get("call_id")
896
+ tool_name = self._get_event_data(event, 'tool_name', 'unknown')
897
+ tool_args = self._get_event_data(event, "args", {})
898
+ call_id = self._get_event_data(event, "call_id")
778
899
  if not call_id:
779
900
  call_id = f"{tool_name}-{uuid.uuid4().hex[:8]}"
780
901
  try:
@@ -807,7 +928,9 @@ class LangfuseTraceCollector:
807
928
  "timestamp": datetime.now().isoformat()
808
929
  }
809
930
 
810
- span = self.trace_spans[trace_id].span(
931
+ # Use compatibility layer to create span (works with both v2 and v3)
932
+ span = self._create_span(
933
+ parent_span=self.trace_spans[trace_id],
811
934
  name=f"tool-{tool_name}",
812
935
  input=tool_input,
813
936
  metadata={
@@ -824,9 +947,9 @@ class LangfuseTraceCollector:
824
947
  elif event.type == "tool_call_end":
825
948
  span_id = self._get_span_id(event)
826
949
  if span_id in self.active_spans:
827
- tool_name = event.data.get('tool_name', 'unknown')
828
- tool_result = event.data.get("result")
829
- call_id = event.data.get("call_id")
950
+ tool_name = self._get_event_data(event, 'tool_name', 'unknown')
951
+ tool_result = self._get_event_data(event, "result")
952
+ call_id = self._get_event_data(event, "call_id")
830
953
 
831
954
  print(f"[LANGFUSE] Ending span for tool call: {tool_name} ({call_id})")
832
955
 
@@ -836,9 +959,9 @@ class LangfuseTraceCollector:
836
959
  "result": tool_result,
837
960
  "call_id": call_id,
838
961
  "timestamp": datetime.now().isoformat(),
839
- "execution_status": event.data.get("execution_status", "completed"),
840
- "status": event.data.get("execution_status", "completed"), # DEPRECATED: backward compatibility
841
- "tool_result": event.data.get("tool_result")
962
+ "execution_status": self._get_event_data(event, "execution_status", "completed"),
963
+ "status": self._get_event_data(event, "execution_status", "completed"), # DEPRECATED: backward compatibility
964
+ "tool_result": self._get_event_data(event, "tool_result")
842
965
  }
843
966
 
844
967
  if trace_id not in self.trace_tool_results:
@@ -852,13 +975,15 @@ class LangfuseTraceCollector:
852
975
  "result": tool_result,
853
976
  "call_id": call_id,
854
977
  "timestamp": datetime.now().isoformat(),
855
- "execution_status": event.data.get("execution_status", "completed"),
856
- "status": event.data.get("execution_status", "completed") # DEPRECATED: backward compatibility
978
+ "execution_status": self._get_event_data(event, "execution_status", "completed"),
979
+ "status": self._get_event_data(event, "execution_status", "completed") # DEPRECATED: backward compatibility
857
980
  }
858
981
 
859
982
  # End the span with detailed output
983
+ # Use compatibility wrapper for ending spans/generations
860
984
  span = self.active_spans[span_id]
861
- span.end(
985
+ self._end_span(
986
+ span=span,
862
987
  output=tool_output,
863
988
  metadata={
864
989
  "tool_name": tool_name,
@@ -878,17 +1003,21 @@ class LangfuseTraceCollector:
878
1003
  elif event.type == "handoff":
879
1004
  # Create an event for handoffs
880
1005
  print(f"[LANGFUSE] Creating event for handoff")
881
- self.trace_spans[trace_id].event(
1006
+ # Use compatibility layer to create event (works with both v2 and v3)
1007
+ self._create_event(
1008
+ parent_span=self.trace_spans[trace_id],
882
1009
  name="agent-handoff",
883
- input={"from": event.data.get("from"), "to": event.data.get("to")},
1010
+ input={"from": self._get_event_data(event, "from"), "to": self._get_event_data(event, "to")},
884
1011
  metadata=event.data
885
1012
  )
886
1013
  print(f"[LANGFUSE] Handoff event created")
887
-
1014
+
888
1015
  else:
889
1016
  # Create a generic event for other event types
890
1017
  print(f"[LANGFUSE] Creating generic event for: {event.type}")
891
- self.trace_spans[trace_id].event(
1018
+ # Use compatibility layer to create event (works with both v2 and v3)
1019
+ self._create_event(
1020
+ parent_span=self.trace_spans[trace_id],
892
1021
  name=event.type,
893
1022
  input=event.data,
894
1023
  metadata={"framework": "jaf", "event_type": event.type}
@@ -902,20 +1031,28 @@ class LangfuseTraceCollector:
902
1031
  traceback.print_exc()
903
1032
 
904
1033
  def _get_trace_id(self, event: TraceEvent) -> Optional[TraceId]:
905
- """Extract trace ID from event data."""
906
- if hasattr(event, 'data') and isinstance(event.data, dict):
907
- # Try snake_case first (Python convention)
908
- if 'trace_id' in event.data:
909
- return event.data['trace_id']
910
- elif 'run_id' in event.data:
911
- return TraceId(event.data['run_id'])
912
- # Fallback to camelCase (for compatibility)
913
- elif 'traceId' in event.data:
914
- return event.data['traceId']
915
- elif 'runId' in event.data:
916
- return TraceId(event.data['runId'])
917
-
918
- # Debug: print what's actually in the event data
1034
+ """Extract trace ID from event data, handling both dict and dataclass."""
1035
+ if not hasattr(event, 'data'):
1036
+ return None
1037
+
1038
+ # Try snake_case first (Python convention)
1039
+ trace_id = self._get_event_data(event, 'trace_id')
1040
+ if trace_id:
1041
+ return trace_id
1042
+
1043
+ run_id = self._get_event_data(event, 'run_id')
1044
+ if run_id:
1045
+ return TraceId(run_id)
1046
+
1047
+ # Fallback to camelCase (for compatibility)
1048
+ trace_id = self._get_event_data(event, 'traceId')
1049
+ if trace_id:
1050
+ return trace_id
1051
+
1052
+ run_id = self._get_event_data(event, 'runId')
1053
+ if run_id:
1054
+ return TraceId(run_id)
1055
+
919
1056
  return None
920
1057
 
921
1058
  def _get_span_id(self, event: TraceEvent) -> str:
@@ -924,15 +1061,15 @@ class LangfuseTraceCollector:
924
1061
 
925
1062
  # Use consistent identifiers that don't depend on timestamp
926
1063
  if event.type.startswith('tool_call'):
927
- call_id = event.data.get('call_id') or event.data.get('tool_call_id')
1064
+ call_id = self._get_event_data(event, 'call_id') or self._get_event_data(event, 'tool_call_id')
928
1065
  if call_id:
929
1066
  return f"tool-{trace_id}-{call_id}"
930
- tool_name = event.data.get('tool_name') or event.data.get('toolName', 'unknown')
1067
+ tool_name = self._get_event_data(event, 'tool_name') or self._get_event_data(event, 'toolName', 'unknown')
931
1068
  return f"tool-{tool_name}-{trace_id}"
932
1069
  elif event.type.startswith('llm_call'):
933
1070
  # For LLM calls, use a simpler consistent ID that matches between start and end
934
1071
  # Get run_id for more consistent matching
935
- run_id = event.data.get('run_id') or event.data.get('runId', trace_id)
1072
+ run_id = self._get_event_data(event, 'run_id') or self._get_event_data(event, 'runId', trace_id)
936
1073
  return f"llm-{run_id}"
937
1074
  else:
938
1075
  return f"{event.type}-{trace_id}"
jaf/core/types.py CHANGED
@@ -94,6 +94,11 @@ class RunId(str):
94
94
  def __new__(cls, value: str) -> 'RunId':
95
95
  return str.__new__(cls, value)
96
96
 
97
class MessageId(str):
    """String subclass used to mark a value as a message identifier."""

    def __new__(cls, value: str) -> 'MessageId':
        return super().__new__(cls, value)
101
+
97
102
  def create_trace_id(id_str: str) -> TraceId:
98
103
  """Create a TraceId from a string."""
99
104
  return TraceId(id_str)
@@ -102,6 +107,36 @@ def create_run_id(id_str: str) -> RunId:
102
107
  """Create a RunId from a string."""
103
108
  return RunId(id_str)
104
109
 
110
def create_message_id(id_str: Union[str, MessageId]) -> MessageId:
    """
    Normalise a value into a MessageId.

    Args:
        id_str: A plain string to convert, or an existing MessageId.

    Returns:
        MessageId: The input itself when it already is a MessageId,
        otherwise a new MessageId built from the whitespace-stripped string.

    Raises:
        ValueError: If the input is None, is not a string, or is empty
            or whitespace-only.
    """
    if id_str is None:
        raise ValueError("Message ID cannot be None")

    # Existing IDs pass through untouched (no stripping, no re-validation).
    if isinstance(id_str, MessageId):
        return id_str

    if not isinstance(id_str, str):
        raise ValueError(f"Message ID must be a string or MessageId, got {type(id_str)}")

    stripped = id_str.strip()
    if not stripped:
        raise ValueError("Message ID cannot be empty or whitespace")
    return MessageId(stripped)
139
+
105
140
  def generate_run_id() -> RunId:
106
141
  """Generate a new unique run ID."""
107
142
  import time
@@ -114,6 +149,12 @@ def generate_trace_id() -> TraceId:
114
149
  import uuid
115
150
  return TraceId(f"trace_{int(time.time() * 1000)}_{uuid.uuid4().hex[:8]}")
116
151
 
152
def generate_message_id() -> MessageId:
    """Build a MessageId of the form ``msg_<epoch millis>_<8 hex chars>``."""
    import time
    import uuid

    millis = int(time.time() * 1000)
    suffix = uuid.uuid4().hex[:8]
    return MessageId(f"msg_{millis}_{suffix}")
157
+
117
158
  # Type variables for generic contexts and outputs
118
159
  Ctx = TypeVar('Ctx')
119
160
  Out = TypeVar('Out')
@@ -180,12 +221,16 @@ class Message:
180
221
  - Direct access to .content returns the original string when created with string
181
222
  - Use .text_content property for guaranteed string access in all cases
182
223
  - Use get_text_content() function to extract text from any content type
224
+ - message_id is optional for backward compatibility
183
225
 
184
226
  Examples:
185
227
  # Original usage - still works exactly the same
186
228
  msg = Message(role='user', content='Hello')
187
229
  text = msg.content # Returns 'Hello' as string
188
230
 
231
+ # New usage with message ID
232
+ msg = Message(role='user', content='Hello', message_id='msg_123')
233
+
189
234
  # Guaranteed string access (recommended for new code)
190
235
  text = msg.text_content # Always returns string
191
236
 
@@ -197,6 +242,27 @@ class Message:
197
242
  attachments: Optional[List[Attachment]] = None
198
243
  tool_call_id: Optional[str] = None
199
244
  tool_calls: Optional[List[ToolCall]] = None
245
+ message_id: Optional[MessageId] = None # Optional for backward compatibility
246
+
247
def __post_init__(self):
    """
    Fill in ``message_id`` when the caller did not supply one.

    A message constructed with ``message_id=None`` receives a freshly
    generated ID, so the field is not None once construction finishes;
    an explicitly supplied ID is kept as given.

    The assignment goes through ``object.__setattr__`` so that it also
    succeeds when the dataclass is frozen and normal attribute
    assignment would be rejected.
    """
    if self.message_id is not None:
        return
    object.__setattr__(self, 'message_id', generate_message_id())
200
266
 
201
267
  @property
202
268
  def text_content(self) -> str:
@@ -210,7 +276,8 @@ class Message:
210
276
  content: str,
211
277
  attachments: Optional[List[Attachment]] = None,
212
278
  tool_call_id: Optional[str] = None,
213
- tool_calls: Optional[List[ToolCall]] = None
279
+ tool_calls: Optional[List[ToolCall]] = None,
280
+ message_id: Optional[MessageId] = None
214
281
  ) -> 'Message':
215
282
  """Create a message with string content and optional attachments."""
216
283
  return cls(
@@ -218,7 +285,8 @@ class Message:
218
285
  content=content,
219
286
  attachments=attachments,
220
287
  tool_call_id=tool_call_id,
221
- tool_calls=tool_calls
288
+ tool_calls=tool_calls,
289
+ message_id=message_id
222
290
  )
223
291
 
224
292
  def get_text_content(content: Union[str, List[MessageContentPart]]) -> str:
@@ -824,3 +892,42 @@ class RunConfig(Generic[Ctx]):
824
892
  default_fast_model: Optional[str] = None # Default model for fast operations like guardrails
825
893
  default_tool_timeout: Optional[float] = 300.0 # Default timeout for tool execution in seconds
826
894
  approval_storage: Optional['ApprovalStorage'] = None # Storage for approval decisions
895
+
896
+ # Regeneration types for conversation management
897
@dataclass(frozen=True)
class RegenerationRequest:
    """Immutable request to regenerate a conversation from a given message.

    Attributes:
        conversation_id: Identifier of the conversation to regenerate.
        message_id: ID of the message to regenerate from.
        context: Optional context override; ``None`` when not supplied.
    """
    conversation_id: str
    message_id: MessageId
    context: Optional[Dict[str, Any]] = None
903
+
904
@dataclass(frozen=True)
class RegenerationContext:
    """Immutable record describing one regeneration operation.

    Attributes:
        original_message_count: Message count recorded for the original
            conversation.
        truncated_at_index: Index at which the conversation was truncated.
        regenerated_message_id: ID of the message that was regenerated.
        regeneration_id: Unique ID for this regeneration operation.
        timestamp: Unix timestamp in milliseconds.
    """
    original_message_count: int
    truncated_at_index: int
    regenerated_message_id: MessageId
    regeneration_id: str
    timestamp: int
912
+
913
+ # Message utility functions
914
def find_message_index(messages: List[Message], message_id: MessageId) -> Optional[int]:
    """Return the position of the first message whose ID equals ``message_id``.

    Returns ``None`` when no message in ``messages`` matches.
    """
    matches = (
        position
        for position, candidate in enumerate(messages)
        if candidate.message_id == message_id
    )
    return next(matches, None)
920
+
921
def truncate_messages_after(messages: List[Message], message_id: MessageId) -> List[Message]:
    """Drop the message with ``message_id`` and everything that follows it.

    When the ID is found, a new list holding only the messages before it
    is returned. When it is not found, the input list object itself is
    returned unchanged.
    """
    cut = find_message_index(messages, message_id)
    return messages if cut is None else messages[:cut]
927
+
928
def get_message_by_id(messages: List[Message], message_id: MessageId) -> Optional[Message]:
    """Return the first message whose ID equals ``message_id``, else ``None``."""
    return next(
        (candidate for candidate in messages if candidate.message_id == message_id),
        None,
    )
933
+ return None