judgeval 0.0.35__py3-none-any.whl → 0.0.37__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,337 +1,2000 @@
1
- from typing import Any, Dict, List, Optional, Sequence
1
+ from typing import Any, Dict, List, Optional, Sequence, Callable, TypedDict
2
2
  from uuid import UUID
3
3
  import time
4
4
  import uuid
5
- from contextvars import ContextVar
6
- from judgeval.common.tracer import TraceClient, TraceEntry, Tracer, SpanType
5
+ import traceback # For detailed error logging
6
+ import contextvars # <--- Import contextvars
7
+ from dataclasses import dataclass
8
+
9
+ from judgeval.common.tracer import TraceClient, TraceEntry, Tracer, SpanType, EvaluationConfig
10
+ from judgeval.data import Example # Import Example
11
+ from judgeval.scorers import AnswerRelevancyScorer, JudgevalScorer, APIJudgmentScorer # Import Scorer and base scorer types
7
12
 
8
13
  from langchain_core.language_models import BaseChatModel
9
14
  from langchain_huggingface import ChatHuggingFace
10
15
  from langchain_openai import ChatOpenAI
11
16
  from langchain_anthropic import ChatAnthropic
12
17
  from langchain_core.utils.function_calling import convert_to_openai_tool
13
- from langchain_core.callbacks import CallbackManager, BaseCallbackHandler
18
+ from langchain_core.callbacks import BaseCallbackHandler
19
+ from langchain_core.callbacks.base import AsyncCallbackHandler
14
20
  from langchain_core.agents import AgentAction, AgentFinish
15
21
  from langchain_core.outputs import LLMResult
16
- from langchain_core.tracers.context import register_configure_hook
17
22
  from langchain_core.messages.ai import AIMessage
18
23
  from langchain_core.messages.tool import ToolMessage
19
24
  from langchain_core.messages.base import BaseMessage
20
25
  from langchain_core.documents import Document
21
26
 
27
+ # --- Get context vars from tracer module ---
28
+ # Assuming tracer.py defines these and they are accessible
29
+ # If not, redefine them here or adjust import
30
+ from judgeval.common.tracer import current_span_var, current_trace_var # <-- Import current_trace_var
31
+
32
+ # --- Constants for Logging ---
33
+ HANDLER_LOG_PREFIX = "[JudgevalHandlerLog]"
34
+
35
+ # --- NEW __init__ ---
22
36
  class JudgevalCallbackHandler(BaseCallbackHandler):
37
+ """
38
+ LangChain Callback Handler using run_id/parent_run_id for hierarchy.
39
+ Manages its own internal TraceClient instance created upon first use.
40
+ Includes verbose logging and defensive checks.
41
+ """
42
+ # Make all properties ignored by LangChain's callback system
43
+ # to prevent unexpected serialization issues.
44
+ lc_serializable = False
45
+ lc_kwargs = {}
46
+
47
def __init__(self, tracer: Tracer):
    """Bind the handler to *tracer* and start with empty per-run bookkeeping.

    No TraceClient is created here; ``_ensure_trace_client`` creates one
    lazily on the first callback that needs it.

    Args:
        tracer: Tracer supplying the project name, rules and feature flags
            used when the trace client is created.
    """
    self.tracer = tracer

    # Lazily created client plus root-run state.
    self._trace_client: Optional[TraceClient] = None
    self._root_run_id: Optional[UUID] = None
    self._trace_saved: bool = False  # flipped once the root run's trace is saved

    # Lookup tables keyed by LangChain run_id or by our generated span_id.
    self._run_id_to_span_id: Dict[UUID, str] = {}
    self._run_id_to_start_inputs: Dict[UUID, Dict] = {}
    self._run_id_to_context_token: Dict[UUID, contextvars.Token] = {}
    self._span_id_to_start_time: Dict[str, float] = {}
    self._span_id_to_depth: Dict[str, int] = {}

    # Public attributes exposed to callers; saved trace data lands in `traces`.
    self.executed_nodes: List[str] = []
    self.executed_tools: List[str] = []
    self.executed_node_tools: List[str] = []
    self.traces: List[Dict[str, Any]] = []
37
73
 
38
- # Generate a unique ID for *this specific span invocation*
39
- span_id = str(uuid.uuid4())
40
-
41
- parent_span_id = current_trace.get_current_span()
42
- token = current_trace.set_current_span(span_id) # Set *this* span's ID as the current one
43
-
44
- current_depth = 0
45
- if parent_span_id and parent_span_id in current_trace._span_depths:
46
- current_depth = current_trace._span_depths[parent_span_id] + 1
47
-
48
- current_trace._span_depths[span_id] = current_depth # Store depth by span_id
49
- # Record span entry
50
- current_trace.add_entry(TraceEntry(
51
- type="enter",
52
- span_id=span_id,
53
- trace_id=current_trace.trace_id,
54
- parent_span_id=parent_span_id,
55
- function=name,
56
- depth=current_depth,
57
- message=name,
58
- created_at=start_time,
59
- span_type=span_type
60
- ))
61
-
62
- self.previous_spans.append(token)
63
- self._start_time = start_time
64
-
65
- def end_span(self, span_type: SpanType = "span"):
66
- current_trace = self.tracer.get_current_trace()
67
- duration = time.time() - self._start_time
68
- span_id = current_trace.get_current_span()
69
- exit_depth = current_trace._span_depths.get(span_id, 0) # Get depth using this span's ID
70
-
71
- # Record span exit
72
- current_trace.add_entry(TraceEntry(
73
- type="exit",
74
- span_id=span_id,
75
- trace_id=current_trace.trace_id,
76
- depth=exit_depth,
77
- created_at=time.time(),
78
- duration=duration,
79
- span_type=span_type
80
- ))
81
- current_trace.reset_current_span(self.previous_spans.pop())
82
- if exit_depth == 0:
83
- # Save the trace if we are the root, this is when users dont use any @observe decorators
84
- trace_id, trace_data = current_trace.save(overwrite=True)
85
- self._trace_id = trace_id
86
- current_trace = None
87
-
88
- def on_retriever_start(
89
- self,
90
- serialized: Optional[dict[str, Any]],
91
- query: str,
92
- *,
93
- run_id: UUID,
94
- parent_run_id: Optional[UUID] = None,
95
- tags: Optional[list[str]] = None,
96
- metadata: Optional[dict[str, Any]] = None,
97
- **kwargs: Any,
98
- ) -> Any:
99
- name = "RETRIEVER_CALL"
100
- if serialized and "name" in serialized:
101
- name = f"RETRIEVER_{serialized['name'].upper()}"
102
- current_trace = self.tracer.get_current_trace()
103
- self.start_span(name, span_type="retriever")
104
- current_trace.record_input({
105
- 'query': query,
106
- 'tags': tags,
107
- 'metadata': metadata,
108
- 'kwargs': kwargs
109
- })
110
-
111
- def on_retriever_end(
112
- self,
113
- documents: Sequence[Document],
114
- *,
115
- run_id: UUID,
116
- parent_run_id: Optional[UUID] = None,
117
- **kwargs: Any
118
- ) -> Any:
119
- # Process the retrieved documents into a format suitable for logging
120
- current_trace = self.tracer.get_current_trace()
121
- doc_summary = []
122
- for i, doc in enumerate(documents):
123
- # Extract key information from each document
124
- doc_data = {
125
- "index": i,
126
- "page_content": doc.page_content[:100] + "..." if len(doc.page_content) > 100 else doc.page_content,
127
- "metadata": doc.metadata
128
- }
129
- doc_summary.append(doc_data)
130
-
131
- # Record the document data
132
- current_trace.record_output({
133
- "document_count": len(documents),
134
- "documents": doc_summary
135
- })
136
-
137
- # End the retriever span
138
- self.end_span(span_type="retriever")
139
-
140
- def on_chain_start(
141
- self,
142
- serialized: Dict[str, Any],
143
- inputs: Dict[str, Any],
144
- *,
145
- run_id: UUID,
146
- parent_run_id: Optional[UUID] = None,
147
- tags: Optional[List[str]] = None,
148
- metadata: Optional[Dict[str, Any]] = None,
149
- **kwargs: Any
150
- ) -> None:
151
- # If the user doesnt use any @observe decorators, the first action in LangGraph workflows seems tohave this attribute, so we intialize our trace client here
152
- current_trace = self.tracer.get_current_trace()
153
- if kwargs.get('name') == 'LangGraph':
154
- if not current_trace:
155
- self.created_trace = True
156
- trace_id = str(uuid.uuid4())
157
- project = self.tracer.project_name
158
- trace = TraceClient(self.tracer, trace_id, "Langgraph", project_name=project, overwrite=False, rules=self.tracer.rules, enable_monitoring=self.tracer.enable_monitoring, enable_evaluations=self.tracer.enable_evaluations)
159
- self.tracer.set_current_trace(trace)
160
- self.start_span("LangGraph", span_type="Main Function")
161
-
162
- node = metadata.get("langgraph_node")
163
- if node != None and node != self.previous_node:
164
- self.start_span(node, span_type="node")
165
- self.executed_node_tools.append(node)
166
- self.executed_nodes.append(node)
167
- current_trace.record_input({
168
- 'args': inputs,
169
- 'kwargs': kwargs
170
- })
171
- self.previous_node = node
172
-
173
- def on_chain_end(
74
def _ensure_trace_client(self, run_id: UUID, parent_run_id: Optional[UUID], event_name: str) -> Optional[TraceClient]:
    """Return this handler's TraceClient, creating it on first use.

    The first call builds a TraceClient with a fresh trace id and
    *event_name* as its initial name. It records *run_id* as the tentative
    root run, clears the saved flag, and registers the client as
    ``tracer._active_trace_client``. Later calls return the same client
    unchanged.

    Args:
        run_id: LangChain run id of the triggering event; it becomes the
            root run id only when a new client is created here.
        parent_run_id: Not used by this method; accepted for call-site
            symmetry with the span helpers.
        event_name: Initial trace name for a newly created client.

    Returns:
        The TraceClient, or None if constructing it raised.
    """
    # Compare against None explicitly. Truthiness of a TraceClient is not
    # controlled here: a __len__/__bool__ on it would make a freshly created,
    # still-empty client look "missing".
    if self._trace_client is not None:
        return self._trace_client

    trace_id = str(uuid.uuid4())
    project = self.tracer.project_name
    try:
        client = TraceClient(
            self.tracer, trace_id, event_name, project_name=project,
            overwrite=False, rules=self.tracer.rules,
            enable_monitoring=self.tracer.enable_monitoring,
            enable_evaluations=self.tracer.enable_evaluations,
        )
        self._trace_client = client
        self._root_run_id = run_id  # first run seen is the tentative root
        self._trace_saved = False
        # Expose the client on the Tracer as well.
        self.tracer._active_trace_client = client
        return client
    except Exception:
        # Best-effort: a tracing failure must not break the traced chain.
        self._trace_client = None
        self._root_run_id = None
        return None
124
+
125
+ def _log(self, message: str):
126
+ """Helper for consistent logging format."""
127
+ pass
128
+
129
def _start_span_tracking(
    self,
    trace_client: TraceClient,
    run_id: UUID,
    parent_run_id: Optional[UUID],
    name: str,
    span_type: SpanType = "span",
    inputs: Optional[Dict[str, Any]] = None
):
    """Open a span for LangChain run *run_id* and record one 'enter' entry.

    A new span id is generated and mapped to *run_id*. If *parent_run_id*
    maps to a tracked span, that span becomes the parent, and the depth is
    the parent's depth + 1 when that depth is known. Otherwise the span
    sits at depth 0. For ``span_type == "chain"``, the span id is also set
    on ``current_span_var`` and the returned token is stored per run id.

    Args:
        trace_client: Client receiving the entries; the call is a no-op if None.
        run_id: LangChain run id of the event being opened.
        parent_run_id: LangChain run id of the parent event, if any.
        name: Function/display name stored on the 'enter' entry.
        span_type: Span type stored on the entry.
        inputs: Optional inputs, recorded as an 'input' entry when truthy.
    """
    if trace_client is None:
        return

    start_time = time.time()
    span_id = str(uuid.uuid4())
    parent_span_id: Optional[str] = None
    current_depth = 0

    if parent_run_id is not None and parent_run_id in self._run_id_to_span_id:
        parent_span_id = self._run_id_to_span_id[parent_run_id]
        parent_depth = self._span_id_to_depth.get(parent_span_id)
        if parent_depth is not None:
            current_depth = parent_depth + 1

    self._run_id_to_span_id[run_id] = span_id
    self._span_id_to_start_time[span_id] = start_time
    self._span_id_to_depth[span_id] = current_depth

    # Exactly one 'enter' entry per span. The previous version appended it
    # (and the inputs) twice, duplicating every span in the trace.
    try:
        trace_client.add_entry(TraceEntry(
            type="enter", span_id=span_id, trace_id=trace_client.trace_id,
            parent_span_id=parent_span_id, function=name, depth=current_depth,
            message=name, created_at=start_time, span_type=span_type
        ))
    except Exception:
        pass  # best-effort: tracing must not interrupt the traced chain

    if inputs:
        self._record_input_data(trace_client, run_id, inputs)

    # Only chain (node) spans become the "current span" in the context var.
    if span_type == "chain":
        try:
            self._run_id_to_context_token[run_id] = current_span_var.set(span_id)
        except Exception:
            pass
225
+
226
def _end_span_tracking(
    self,
    trace_client: TraceClient,
    run_id: UUID,
    span_type: SpanType = "span",
    outputs: Optional[Any] = None,
    error: Optional[BaseException] = None
):
    """Close the span opened for *run_id*; save the trace when the root run ends.

    Steps:
      1. Record *error* (if given), or else *outputs* (if not None), via
         ``_record_output_data``.
      2. Append an 'exit' entry carrying the span's depth and duration.
         Then drop the span's start-time/depth bookkeeping and pop its
         context token.
      3. If *run_id* is the root run: clear root-run state, save
         ``self._trace_client`` once, and append the saved data to
         ``self.traces``. Finally, clear ``tracer._active_trace_client``
         if it is this handler's client.

    A run that was never tracked is ignored unless it is the root run, in
    which case only step 3 runs.

    Args:
        trace_client: Client receiving the entries; no-op if None.
        run_id: LangChain run id of the event that finished.
        span_type: Span type stored on the 'exit' entry.
        outputs: Result of the run; recorded when not None and no error.
        error: Exception raised by the run; recorded instead of *outputs*.
    """
    if trace_client is None:
        return

    span_id = self._run_id_to_span_id.get(run_id)
    if span_id is None and run_id != self._root_run_id:
        return  # untracked, non-root run: nothing to close

    # Measure the duration before doing any recording work.
    start_time = self._span_id_to_start_time.get(span_id) if span_id is not None else None
    depth = self._span_id_to_depth.get(span_id, 0) if span_id is not None else 0
    duration = time.time() - start_time if start_time is not None else None

    if error is not None:
        self._record_output_data(trace_client, run_id, error)
    elif outputs is not None:
        self._record_output_data(trace_client, run_id, outputs)

    if span_id is not None:
        function_name = "unknown"
        try:
            function_name = next(
                (e.function for e in reversed(trace_client.entries)
                 if e.span_id == span_id and e.type == "enter"),
                "unknown",
            )
        except Exception:
            pass  # keep the "unknown" placeholder

        try:
            trace_client.add_entry(TraceEntry(
                type="exit", span_id=span_id, trace_id=trace_client.trace_id,
                depth=depth, created_at=time.time(), duration=duration,
                span_type=span_type, function=function_name
            ))
        except Exception:
            pass  # best-effort: tracing must not interrupt the traced chain

        self._span_id_to_start_time.pop(span_id, None)
        self._span_id_to_depth.pop(span_id, None)
        # The token is discarded without calling current_span_var.reset(),
        # as in the original sync handler.
        # NOTE(review): the run_id -> span_id mapping is intentionally kept
        # (other callbacks look it up); it grows across invocations.
        self._run_id_to_context_token.pop(run_id, None)

    if run_id != self._root_run_id:
        return

    # Root run finished: reset root-run state, then save the trace once.
    self._root_run_id = None
    self._run_id_to_start_inputs = {}
    try:
        client = self._trace_client
        if client is None:
            self._log(f" WARNING: Root run {run_id} ended, but trace client was None. Cannot save trace.")
        elif self._trace_saved:
            self._log(f" Trace {client.trace_id} already saved. Skipping save.")
        else:
            try:
                trace_id, trace_data = client.save(overwrite=client.overwrite)
            except Exception as e:
                self._log(f"ERROR saving trace {client.trace_id}: {e}")
            else:
                self.traces.append(trace_data)
                self._trace_saved = True
                self._log(f"Trace {trace_id} successfully saved.")
    finally:
        # Release the tracer's reference only if it points at our client.
        if self.tracer._active_trace_client == self._trace_client:
            self.tracer._active_trace_client = None
393
+
394
def _record_input_data(self,
                       trace_client: TraceClient,
                       run_id: UUID,
                       inputs: Dict[str, Any]):
    """Append an 'input' entry for the span tracked under *run_id*.

    Function name, span type and parent span id are copied from the span's
    most recent 'enter' entry. Placeholders ("unknown", "span", None) are
    used if it cannot be found. No-op when *run_id* is untracked or
    *trace_client* is None.

    Args:
        trace_client: Client receiving the entry.
        run_id: LangChain run id whose span receives the inputs.
        inputs: Inputs to store on the entry.
    """
    span_id = self._run_id_to_span_id.get(run_id)
    if span_id is None or trace_client is None:
        return
    depth = self._span_id_to_depth.get(span_id, 0)

    # One reverse scan supplies all metadata. Previously a second scan ran
    # inside the entry-building try, so a lookup failure dropped the entry.
    function_name = "unknown"
    span_type: SpanType = "span"
    parent_span_id = None
    try:
        enter_entry = next(
            (e for e in reversed(trace_client.entries)
             if e.span_id == span_id and e.type == "enter"),
            None,
        )
        if enter_entry is not None:
            function_name = enter_entry.function
            span_type = enter_entry.span_type
            parent_span_id = enter_entry.parent_span_id
    except Exception:
        pass  # fall back to the placeholders above

    try:
        trace_client.add_entry(TraceEntry(
            type="input",
            span_id=span_id,
            trace_id=trace_client.trace_id,
            parent_span_id=parent_span_id,
            function=function_name,
            depth=depth,
            message=f"Input to {function_name}",
            created_at=time.time(),
            inputs=inputs,
            span_type=span_type,
        ))
    except Exception:
        pass  # best-effort: tracing must not interrupt the traced chain
446
+
447
def _record_output_data(self,
                        trace_client: TraceClient,
                        run_id: UUID,
                        output: Any):
    """Append an 'output' entry for the span tracked under *run_id*.

    Function name, span type and parent span id are copied from the span's
    most recent 'enter' entry. Placeholders ("unknown", "span", None) are
    used if it cannot be found. No-op when *run_id* is untracked or
    *trace_client* is None.

    Args:
        trace_client: Client receiving the entry.
        run_id: LangChain run id whose span receives the output.
        output: Value stored on the entry's ``output`` field.
    """
    span_id = self._run_id_to_span_id.get(run_id)
    if span_id is None or trace_client is None:
        return
    depth = self._span_id_to_depth.get(span_id, 0)

    # One reverse scan supplies all metadata. Previously a second scan ran
    # inside the entry-building try, so a lookup failure dropped the entry.
    function_name = "unknown"
    span_type: SpanType = "span"
    parent_span_id = None
    try:
        enter_entry = next(
            (e for e in reversed(trace_client.entries)
             if e.span_id == span_id and e.type == "enter"),
            None,
        )
        if enter_entry is not None:
            function_name = enter_entry.function
            span_type = enter_entry.span_type
            parent_span_id = enter_entry.parent_span_id
    except Exception:
        pass  # fall back to the placeholders above

    try:
        trace_client.add_entry(TraceEntry(
            type="output",
            span_id=span_id,
            trace_id=trace_client.trace_id,
            parent_span_id=parent_span_id,
            function=function_name,
            depth=depth,
            message=f"Output from {function_name}",
            created_at=time.time(),
            output=output,
            span_type=span_type,
        ))
        self._log(f" Added 'output' entry directly for span_id={span_id}")
    except Exception as e:
        self._log(f" ERROR adding 'output' entry directly for span_id {span_id}: {e}")
498
+
499
def _record_error(self,
                  trace_client: TraceClient,
                  run_id: UUID,
                  error: Any):
    """Append an 'error' entry for the span tracked under *run_id*.

    The error is stored as ``str(error)``. Function name, span type and
    parent span id are copied from the span's most recent 'enter' entry.
    Placeholders ("unknown", "span", None) are used if it cannot be found.
    No-op when *run_id* is untracked or *trace_client* is None.

    Args:
        trace_client: Client receiving the entry.
        run_id: LangChain run id whose span receives the error.
        error: Error object or message; converted with ``str``.
    """
    span_id = self._run_id_to_span_id.get(run_id)
    if span_id is None or trace_client is None:
        return
    depth = self._span_id_to_depth.get(span_id, 0)

    # One reverse scan supplies all metadata. Previously a second scan ran
    # inside the entry-building try, so a lookup failure dropped the entry.
    function_name = "unknown"
    span_type: SpanType = "span"
    parent_span_id = None
    try:
        enter_entry = next(
            (e for e in reversed(trace_client.entries)
             if e.span_id == span_id and e.type == "enter"),
            None,
        )
        if enter_entry is not None:
            function_name = enter_entry.function
            span_type = enter_entry.span_type
            parent_span_id = enter_entry.parent_span_id
    except Exception:
        pass  # fall back to the placeholders above

    try:
        trace_client.add_entry(TraceEntry(
            type="error",
            span_id=span_id,
            trace_id=trace_client.trace_id,
            parent_span_id=parent_span_id,
            function=function_name,
            depth=depth,
            message=f"Error in {function_name}",
            created_at=time.time(),
            error=str(error),  # stringified for serialization
            span_type=span_type,
        ))
    except Exception:
        pass  # best-effort: tracing must not interrupt the traced chain
551
+
552
+ # --- Callback Methods ---
553
+ # Each method now ensures the trace client exists before proceeding
554
+
555
def on_retriever_start(self, serialized: Dict[str, Any], query: str, *, run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, **kwargs: Any) -> Any:
    """Open a ``"retriever"`` span when a LangChain retriever starts.

    The span is named ``RETRIEVER_<NAME>``, where NAME is the upper-cased
    serialized retriever name. The name falls back to "Unknown" when the name
    is missing or empty, and "Unknown (Serialized=None)" when ``serialized``
    is falsy.

    The span inputs record the query with the tags, metadata, extra kwargs
    and the serialized retriever. Exceptions are logged rather than
    propagated, so tracing never interrupts retrieval.
    """
    handler_instance_id = id(self)
    log_prefix = f"{HANDLER_LOG_PREFIX} [Handler {handler_instance_id}]"
    if serialized:
        # A 'name' key present but None used to crash `.upper()` below and
        # silently drop the span; treat it like a missing name.
        serialized_name = serialized.get('name') or 'Unknown'
    else:
        serialized_name = "Unknown (Serialized=None)"

    try:
        name = f"RETRIEVER_{str(serialized_name).upper()}"
        trace_client = self._ensure_trace_client(run_id, parent_run_id, name)
        if not trace_client:
            return

        inputs = {'query': query, 'tags': tags, 'metadata': metadata, 'kwargs': kwargs, 'serialized': serialized}
        self._start_span_tracking(trace_client, run_id, parent_run_id, name, span_type="retriever", inputs=inputs)
    except Exception as e:
        tc_id_on_error = id(self._trace_client) if self._trace_client else 'None'
        self._log(f"{log_prefix} UNCAUGHT EXCEPTION in on_retriever_start for run_id={run_id} (TraceClient ID: {tc_id_on_error}): {e}")
575
+
576
def on_retriever_end(self, documents: Sequence[Document], *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any) -> Any:
    """Close the retriever span for ``run_id`` with a summary of the documents.

    Each document is summarised as its index, its metadata, and its
    ``page_content``. The content is cut to the first 100 characters plus
    "..." when longer than 100.

    The span outputs hold the document count, these summaries and any extra
    kwargs. Nothing is recorded when no trace client is available.
    Exceptions are logged rather than propagated.
    """
    log_prefix = f"{HANDLER_LOG_PREFIX} [Handler {id(self)}]"

    try:
        trace_client = self._ensure_trace_client(run_id, parent_run_id, "RetrieverEnd")
        if trace_client:
            summaries = []
            for position, document in enumerate(documents):
                text = document.page_content
                preview = text if len(text) <= 100 else text[:100] + "..."
                summaries.append({"index": position, "page_content": preview, "metadata": document.metadata})

            span_outputs = {"document_count": len(documents), "documents": summaries, "kwargs": kwargs}
            self._end_span_tracking(trace_client, run_id, span_type="retriever", outputs=span_outputs)
    except Exception as e:
        tc_id_on_error = id(self._trace_client) if self._trace_client else 'None'
        self._log(f"{log_prefix} UNCAUGHT EXCEPTION in on_retriever_end for run_id={run_id} (TraceClient ID: {tc_id_on_error}): {e}")
595
+
596
def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any], *, run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, **kwargs: Any) -> None:
    """Open a ``"chain"`` span when a LangChain chain or LangGraph node starts.

    Span name, in priority order:
    - the ``langgraph_node`` metadata value, when present; that node is also
      appended once to ``self.executed_nodes``;
    - "LangGraph", for a parentless run whose ``name`` kwarg is "LangGraph";
    - the serialized chain name, or "Unknown Chain" when it is missing.

    When this run is the handler's root run, the trace is renamed to match.
    The raw ``inputs`` are recorded on the span and stored in
    ``self._run_id_to_start_inputs`` for later evaluation.

    ``tags`` may be None (its declared default). Exceptions are logged rather
    than propagated.
    """
    handler_instance_id = id(self)
    log_prefix = f"{HANDLER_LOG_PREFIX} [Handler {handler_instance_id}]"

    try:
        # --- Determine name and span type ---
        serialized_name = serialized.get('name') if serialized else "Unknown (Serialized=None)"
        span_type: SpanType = "chain"
        name = serialized_name if serialized_name else "Unknown Chain"
        node_name = metadata.get("langgraph_node") if metadata else None
        # The first chain event with no parent is the candidate graph root.
        is_potential_root_event = parent_run_id is None

        if node_name:
            name = node_name
            self._log(f" LangGraph Node Start: '{name}', run_id={run_id}")
            if name not in self.executed_nodes:
                self.executed_nodes.append(name)
        elif kwargs.get('name') == 'LangGraph' and is_potential_root_event:
            name = "LangGraph"
            self._log(f" LangGraph Root Start Detected (kwargs): run_id={run_id}")

        trace_client = self._ensure_trace_client(run_id, parent_run_id, name)
        if not trace_client:
            return

        # Keep the trace name in sync with the root run's resolved name.
        if is_potential_root_event and run_id == self._root_run_id and trace_client.name != name:
            self._log(f" Updating trace name from '{trace_client.name}' to '{name}' for root run {run_id}")
            trace_client.name = name

        self._start_span_tracking(trace_client, run_id, parent_run_id, name, span_type=span_type, inputs=inputs)
        # Remember the raw inputs so on_chain_end can build evaluations from them.
        self._run_id_to_start_inputs[run_id] = inputs
        self._log(f" Stored inputs for run_id {run_id}")
    except Exception as e:
        tc_id_on_error = id(self._trace_client) if self._trace_client else 'None'
        self._log(f"{log_prefix} UNCAUGHT EXCEPTION in on_chain_start for run_id={run_id} (TraceClient ID: {tc_id_on_error}): {e}")
649
+
650
+
651
def on_chain_end(self, outputs: Dict[str, Any], *, run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, **kwargs: Any) -> Any:
    """Close the chain span for ``run_id``, then save the trace and run evaluations.

    Steps (exceptions are logged rather than propagated):

    1. Resolve the span for ``run_id``. An untracked run is ignored unless it
       is the handler's root run.
    2. Demo hook: a dict output with a truthy ``"recommendations"`` value
       submits an AnswerRelevancyScorer evaluation for the span.
    3. End the span, recording ``outputs``, ``tags`` and extra kwargs.
    4. For the root run:
       - save the trace once;
       - clear the tracer's active client if it is this trace;
       - reset the root run id and the stored start inputs.
    5. A dict output carrying ``"_judgeval_eval"`` submits that evaluation
       for the span. The value may be an EvaluationConfig, or a dict with
       ``scorers`` and ``example``.

    ``tags`` may be None. ``outputs`` is annotated as a dict, but the
    key-based hooks (steps 2 and 5) only run for real dicts. Other output
    types therefore can no longer abort the handler before the span is closed.
    """
    handler_instance_id = id(self)
    log_prefix = f"{HANDLER_LOG_PREFIX} [Handler {handler_instance_id}]"
    # `"recommendations" in "<some string>"` is a substring test; the `.get()`
    # that followed used to raise and skip closing the span. Only dicts qualify.
    outputs_is_dict = isinstance(outputs, dict)

    try:
        trace_client = self._ensure_trace_client(run_id, parent_run_id, "ChainEnd")
        if not trace_client:
            return

        # --- 1. Resolve the span; its 'enter' entry gives span type and node name ---
        span_id = self._run_id_to_span_id.get(run_id)
        enter_entry = None
        if span_id:
            try:
                if hasattr(trace_client, 'entries') and trace_client.entries:
                    enter_entry = next((e for e in reversed(trace_client.entries) if e.span_id == span_id and e.type == "enter"), None)
                else:
                    self._log(f" WARNING: trace_client.entries empty/missing for on_chain_end span_id={span_id}")
            except Exception as e:
                self._log(f" ERROR finding enter entry for span_id {span_id} in on_chain_end: {e}")
        else:
            self._log(f" WARNING: No span_id found for run_id {run_id} in on_chain_end.")
            if run_id != self._root_run_id:
                return  # Untracked non-root run: there is no span to close.
            self._log(f" Run ID {run_id} matches root. Proceeding to _end_span_tracking for potential save.")

        span_type: SpanType = enter_entry.span_type if enter_entry else "chain"
        node_name = enter_entry.function if enter_entry else "unknown_node"

        # --- 2. Demo evaluation hook ---
        # NOTE(review): demo-specific logic. Any dict output with a
        # "recommendations" key triggers an evaluation with placeholder input
        # "Unknown Input" and model "gpt-4". Consider removing or gating it.
        if outputs_is_dict and "recommendations" in outputs:
            if not span_id:
                self._log(f"[Async Handler {handler_instance_id}] Skipping evaluation for run_id {run_id}: Span ID not found.")
            else:
                self._log(f"[Async Handler {handler_instance_id}] Chain end for run_id {run_id} (span_id={span_id}) identified as recommendation node. Attempting evaluation.")
                recommendations_output = outputs.get("recommendations")
                if recommendations_output:
                    self._log(f"[Async Handler {handler_instance_id}] Submitting evaluation for span_id={span_id}")
                    try:
                        # Built inside the try so an Example that fails validation
                        # cannot abort the handler before the span is closed.
                        eval_example = Example(
                            input="Unknown Input",  # Placeholder: original prompt isn't tracked yet.
                            actual_output=recommendations_output
                        )
                        trace_client.async_evaluate(
                            scorers=[AnswerRelevancyScorer(threshold=0.5)],
                            example=eval_example,
                            model="gpt-4",  # Placeholder model name.
                            span_id=span_id
                        )
                        self._log(f"[Async Handler {handler_instance_id}] Evaluation submitted successfully for span_id={span_id}.")
                    except Exception as eval_e:
                        self._log(f"[Async Handler {handler_instance_id}] ERROR submitting evaluation for span_id={span_id}: {eval_e}")
                else:
                    self._log(f"[Async Handler {handler_instance_id}] Skipping evaluation for run_id {run_id} (span_id={span_id}): Missing recommendations output.")

        # --- 3. End the span ---
        combined_outputs = {"outputs": outputs, "tags": tags, "kwargs": kwargs}
        self._end_span_tracking(trace_client, run_id, span_type=span_type, outputs=combined_outputs)

        # --- 4. Root run: save the trace once and reset per-trace state ---
        if run_id == self._root_run_id:
            self._log(f"Root run {run_id} finished. Attempting to save trace...")
            if self._trace_saved:
                self._log(f"Trace {trace_client.trace_id} already saved. Skipping save for root run {run_id}.")
            else:
                try:
                    trace_id_saved, trace_data = trace_client.save(overwrite=trace_client.overwrite)
                    self.traces.append(trace_data)
                    self._trace_saved = True
                    self._log(f"Trace {trace_id_saved} successfully saved.")
                    # Only detach the tracer's active client after a successful save.
                    if self.tracer._active_trace_client == trace_client:
                        self.tracer._active_trace_client = None
                        self._log("Reset active_trace_client on Tracer.")
                except Exception as e:
                    self._log(f"ERROR saving trace {trace_client.trace_id}: {e}")

            self._root_run_id = None
            self._run_id_to_start_inputs = {}
            self._log(f"Reset root run ID and input storage for handler {handler_instance_id}.")

        # --- 5. Explicit evaluation request carried in the outputs ---
        if outputs_is_dict and "_judgeval_eval" in outputs:
            if not span_id:
                self._log(f"{log_prefix} WARNING: Found _judgeval_eval in outputs, but span_id for run_id {run_id} was not found. Cannot submit evaluation.")
            else:
                eval_config: Optional[EvaluationConfig] = None
                raw_eval_config = outputs.get("_judgeval_eval")
                if isinstance(raw_eval_config, EvaluationConfig):
                    eval_config = raw_eval_config
                    self._log(f"{log_prefix} Found valid EvaluationConfig in outputs for node='{node_name}'.")
                elif isinstance(raw_eval_config, dict):
                    try:
                        if "scorers" in raw_eval_config and "example" in raw_eval_config:
                            example_data = raw_eval_config["example"]
                            reconstructed_example = Example(**example_data) if isinstance(example_data, dict) else example_data
                            if isinstance(reconstructed_example, Example):
                                eval_config = EvaluationConfig(
                                    scorers=raw_eval_config["scorers"],
                                    example=reconstructed_example,
                                    model=raw_eval_config.get("model"),
                                    log_results=raw_eval_config.get("log_results", True)
                                )
                                self._log(f"{log_prefix} Reconstructed EvaluationConfig from dict in outputs for node='{node_name}'.")
                            else:
                                self._log(f"{log_prefix} Could not reconstruct Example from dict in _judgeval_eval for node='{node_name}'. Skipping evaluation.")
                        else:
                            self._log(f"{log_prefix} Dict in _judgeval_eval missing required keys ('scorers', 'example') for node='{node_name}'. Skipping evaluation.")
                    except Exception as recon_e:
                        self._log(f"{log_prefix} ERROR attempting to reconstruct EvaluationConfig from dict for node='{node_name}': {recon_e}")
                else:
                    self._log(f"{log_prefix} Found '_judgeval_eval' key in outputs for node='{node_name}', but it wasn't an EvaluationConfig object or reconstructable dict. Skipping evaluation.")

                if eval_config:
                    self._log(f"{log_prefix} Submitting evaluation for span_id={span_id}")
                    try:
                        trace_client.async_evaluate(
                            scorers=eval_config.scorers,
                            example=eval_config.example,
                            model=eval_config.model,
                            log_results=eval_config.log_results,
                            span_id=span_id  # Attach results to this node's span.
                        )
                        self._log(f"{log_prefix} Evaluation submitted successfully for span_id={span_id}.")
                    except Exception as eval_e:
                        self._log(f"{log_prefix} ERROR submitting evaluation for span_id={span_id}: {eval_e}")

    except Exception as e:
        tc_id_on_error = id(self._trace_client) if self._trace_client else 'None'
        self._log(f"{log_prefix} UNCAUGHT EXCEPTION in on_chain_end for run_id={run_id} (TraceClient ID: {tc_id_on_error}): {e}")
846
+
847
def on_chain_error(self, error: BaseException, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any) -> Any:
    """Close the chain span for ``run_id`` with ``error`` recorded.

    The span type comes from the span's 'enter' entry and defaults to
    "chain". A run without a tracked span is ignored unless it is the
    handler's root run. In that case ``_end_span_tracking`` is still called,
    so root cleanup can happen. Exceptions are logged rather than propagated.
    """
    log_prefix = f"{HANDLER_LOG_PREFIX} [Handler {id(self)}]"

    try:
        trace_client = self._ensure_trace_client(run_id, parent_run_id, "ChainError")
        if not trace_client:
            return

        span_type: SpanType = "chain"
        span_id = self._run_id_to_span_id.get(run_id)
        if not span_id:
            self._log(f" WARNING: No span_id found for run_id {run_id} in on_chain_error.")
            if run_id != self._root_run_id:
                return
        else:
            try:
                if not (hasattr(trace_client, 'entries') and trace_client.entries):
                    self._log(f" WARNING: trace_client.entries empty/missing for on_chain_error span_id={span_id}")
                else:
                    candidates = (e for e in reversed(trace_client.entries) if e.span_id == span_id and e.type == "enter")
                    matched_entry = next(candidates, None)
                    if matched_entry:
                        span_type = matched_entry.span_type
            except Exception as e:
                self._log(f" ERROR finding enter entry for span_id {span_id} in on_chain_error: {e}")

        self._end_span_tracking(trace_client, run_id, span_type=span_type, error=error)
    except Exception as e:
        tc_id_on_error = id(self._trace_client) if self._trace_client else 'None'
        self._log(f"{log_prefix} UNCAUGHT EXCEPTION in on_chain_error for run_id={run_id} (TraceClient ID: {tc_id_on_error}): {e}")
880
+
881
def on_tool_start(self, serialized: Dict[str, Any], input_str: str, *, run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, inputs: Optional[Dict[str, Any]] = None, **kwargs: Any) -> Any:
    """Open a ``"tool"`` span when a LangChain tool starts, and track tool usage.

    The span records the structured ``inputs`` LangChain passes. When that
    argument is None (LangChain does not always supply it), the raw
    ``input_str`` is recorded instead, so the tool's input is not lost.

    Tool usage tracking:
    - the tool name is appended once to ``self.executed_tools``;
    - ``"<parent chain node>:<tool>"`` is appended once to
      ``self.executed_node_tools``, or just the tool name when the parent
      chain node can't be resolved.

    Exceptions are logged rather than propagated.
    """
    handler_instance_id = id(self)
    log_prefix = f"{HANDLER_LOG_PREFIX} [Handler {handler_instance_id}]"
    name = serialized.get("name", "Unnamed Tool") if serialized else "Unknown Tool (Serialized=None)"

    try:
        trace_client = self._ensure_trace_client(run_id, parent_run_id, name)
        if not trace_client:
            return

        span_inputs = inputs if inputs is not None else input_str
        self._start_span_tracking(trace_client, run_id, parent_run_id, name, span_type="tool", inputs=span_inputs)

        # --- Track executed tools ---
        if name not in self.executed_tools:
            self.executed_tools.append(name)

        # The parent chain span's function name is the node that called the tool.
        parent_node_name = None
        if parent_run_id and parent_run_id in self._run_id_to_span_id:
            parent_span_id = self._run_id_to_span_id[parent_run_id]
            try:
                if hasattr(trace_client, 'entries') and trace_client.entries:
                    parent_enter_entry = next((e for e in reversed(trace_client.entries) if e.span_id == parent_span_id and e.type == "enter" and e.span_type == "chain"), None)
                    if parent_enter_entry:
                        parent_node_name = parent_enter_entry.function
                else:
                    self._log(f" WARNING: trace_client.entries missing for tool start parent {parent_span_id}")
            except Exception as e:
                self._log(f" ERROR finding parent node name for tool start span_id {parent_span_id}: {e}")

        node_tool = f"{parent_node_name}:{name}" if parent_node_name else name
        if node_tool not in self.executed_node_tools:
            self.executed_node_tools.append(node_tool)
        self._log(f" Tracked node_tool: {node_tool}")
    except Exception as e:
        tc_id_on_error = id(self._trace_client) if self._trace_client else 'None'
        self._log(f"{log_prefix} UNCAUGHT EXCEPTION in on_tool_start for run_id={run_id} (TraceClient ID: {tc_id_on_error}): {e}")
919
+
223
920
 
224
921
def on_tool_end(self, output: Any, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any) -> Any:
    """Close the tool span for ``run_id``, recording ``output`` and extra kwargs.

    Does nothing when no trace client is available. Exceptions are logged
    rather than propagated.
    """
    log_prefix = f"{HANDLER_LOG_PREFIX} [Handler {id(self)}]"

    try:
        trace_client = self._ensure_trace_client(run_id, parent_run_id, "ToolEnd")
        if trace_client:
            self._end_span_tracking(
                trace_client,
                run_id,
                span_type="tool",
                outputs={"output": output, "kwargs": kwargs},
            )
    except Exception as e:
        tc_id_on_error = id(self._trace_client) if self._trace_client else 'None'
        self._log(f"{log_prefix} UNCAUGHT EXCEPTION in on_tool_end for run_id={run_id} (TraceClient ID: {tc_id_on_error}): {e}")
250
938
 
251
- def on_agent_finish(
252
- self,
253
- finish: AgentFinish,
254
- *,
255
- run_id: UUID,
256
- parent_run_id: Optional[UUID] = None,
257
- **kwargs: Any,
258
- ) -> Any:
259
-
939
def on_tool_error(self, error: BaseException, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any) -> Any:
    """Close the tool span for ``run_id`` with ``error`` recorded.

    Does nothing when no trace client is available. Exceptions raised while
    tracing are logged rather than propagated.
    """
    log_prefix = f"{HANDLER_LOG_PREFIX} [Handler {id(self)}]"

    try:
        trace_client = self._ensure_trace_client(run_id, parent_run_id, "ToolError")
        if trace_client:
            self._end_span_tracking(trace_client, run_id, span_type="tool", error=error)
    except Exception as e:
        tc_id_on_error = id(self._trace_client) if self._trace_client else 'None'
        self._log(f"{log_prefix} UNCAUGHT EXCEPTION in on_tool_error for run_id={run_id} (TraceClient ID: {tc_id_on_error}): {e}")
955
+
956
def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], *, run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, invocation_params: Optional[Dict[str, Any]] = None, options: Optional[Dict[str, Any]] = None, name: Optional[str] = None, **kwargs: Any) -> Any:
    """Open an ``"llm"`` span when a (non-chat) LLM call starts.

    The span name is, in order of preference, ``name``, the serialized model
    name, or "LLM Call". The ``prompts`` list is recorded as the span inputs.

    ``serialized`` may be None; it was previously dereferenced outside the
    try block and crashed the callback. Exceptions are logged rather than
    propagated.
    """
    handler_instance_id = id(self)
    log_prefix = f"{HANDLER_LOG_PREFIX} [Handler {handler_instance_id}]"

    try:
        serialized_name = serialized.get("name") if serialized else None
        llm_name = name or serialized_name or "LLM Call"
        trace_client = self._ensure_trace_client(run_id, parent_run_id, llm_name)
        if not trace_client:
            return
        self._start_span_tracking(trace_client, run_id, parent_run_id, llm_name, span_type="llm", inputs=prompts)
    except Exception as e:
        tc_id_on_error = id(self._trace_client) if self._trace_client else 'None'
        self._log(f"{log_prefix} UNCAUGHT EXCEPTION in on_llm_start for run_id={run_id} (TraceClient ID: {tc_id_on_error}): {e}")
974
+
975
def on_llm_end(self, response: LLMResult, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any) -> Any:
    """Close the LLM span for ``run_id``, attaching token usage when available.

    Span outputs are ``{"response": response, "kwargs": kwargs}``. When token
    counts can be found, they also include ``"usage"`` with
    ``prompt_tokens``, ``completion_tokens`` and ``total_tokens``.

    Usage sources, tried in order:
    - ``response.llm_output["token_usage"]`` (OpenAI style:
      prompt/completion/total);
    - ``response.llm_output["usage"]`` (Anthropic style: input/output
      tokens, with the total computed when both are present);
    - the first generation's ``message.usage_metadata`` (LangChain's
      provider-neutral input/output/total counts). This covers providers
      that leave ``llm_output`` without usage.

    Problems extracting usage are logged and never prevent the span from
    ending. Other exceptions are logged rather than propagated.
    """
    handler_instance_id = id(self)
    log_prefix = f"{HANDLER_LOG_PREFIX} [Handler {handler_instance_id}]"

    try:
        trace_client = self._ensure_trace_client(run_id, parent_run_id, "LLMEnd")
        if not trace_client:
            return
        outputs = {"response": response, "kwargs": kwargs}

        # --- Token usage extraction ---
        prompt_tokens = None
        completion_tokens = None
        total_tokens = None
        try:
            llm_output = response.llm_output
            if llm_output and isinstance(llm_output, dict):
                if 'token_usage' in llm_output:
                    token_usage = llm_output.get('token_usage')
                    if token_usage and isinstance(token_usage, dict):
                        self._log(f" Extracted OpenAI token usage for run_id={run_id}: {token_usage}")
                        prompt_tokens = token_usage.get('prompt_tokens')
                        completion_tokens = token_usage.get('completion_tokens')
                        total_tokens = token_usage.get('total_tokens')
                elif 'usage' in llm_output:
                    token_usage = llm_output.get('usage')
                    if token_usage and isinstance(token_usage, dict):
                        self._log(f" Extracted Anthropic token usage for run_id={run_id}: {token_usage}")
                        prompt_tokens = token_usage.get('input_tokens')
                        completion_tokens = token_usage.get('output_tokens')
                        # Anthropic reports no total; derive it when both parts exist.
                        if prompt_tokens is not None and completion_tokens is not None:
                            total_tokens = prompt_tokens + completion_tokens
                        else:
                            self._log(f" Could not calculate total_tokens from Anthropic usage: input={prompt_tokens}, output={completion_tokens}")
            else:
                self._log(f" llm_output not available/dict for run_id={run_id}")

            if prompt_tokens is None and completion_tokens is None:
                # Fallback: AIMessage.usage_metadata on the first generation.
                generations = getattr(response, "generations", None) or []
                first_generation = generations[0][0] if generations and generations[0] else None
                usage_metadata = getattr(getattr(first_generation, "message", None), "usage_metadata", None)
                if isinstance(usage_metadata, dict):
                    self._log(f" Extracted usage_metadata token usage for run_id={run_id}: {usage_metadata}")
                    prompt_tokens = usage_metadata.get('input_tokens')
                    completion_tokens = usage_metadata.get('output_tokens')
                    total_tokens = usage_metadata.get('total_tokens')

            if prompt_tokens is not None or completion_tokens is not None:
                outputs['usage'] = {
                    'prompt_tokens': prompt_tokens,
                    'completion_tokens': completion_tokens,
                    'total_tokens': total_tokens
                }
            else:
                self._log(f" Could not extract token usage structure from llm_output for run_id={run_id}")
        except Exception as e:
            self._log(f" ERROR extracting/accumulating token usage for run_id={run_id}: {e}")

        self._end_span_tracking(trace_client, run_id, span_type="llm", outputs=outputs)
    except Exception as e:
        tc_id_on_error = id(self._trace_client) if self._trace_client else 'None'
        self._log(f"{log_prefix} UNCAUGHT EXCEPTION in on_llm_end for run_id={run_id} (TraceClient ID: {tc_id_on_error}): {e}")
1053
+
1054
def on_llm_error(self, error: BaseException, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any) -> Any:
    """Close the LLM span for ``run_id`` with ``error`` recorded.

    Does nothing when no trace client is available. Exceptions raised while
    tracing are logged rather than propagated.
    """
    log_prefix = f"{HANDLER_LOG_PREFIX} [Handler {id(self)}]"

    try:
        trace_client = self._ensure_trace_client(run_id, parent_run_id, "LLMError")
        if trace_client:
            self._end_span_tracking(trace_client, run_id, span_type="llm", error=error)
    except Exception as e:
        tc_id_on_error = id(self._trace_client) if self._trace_client else 'None'
        self._log(f"{log_prefix} UNCAUGHT EXCEPTION in on_llm_error for run_id={run_id} (TraceClient ID: {tc_id_on_error}): {e}")
1070
+
1071
def on_chat_model_start(self, serialized: Dict[str, Any], messages: List[List[BaseMessage]], *, run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, invocation_params: Optional[Dict[str, Any]] = None, options: Optional[Dict[str, Any]] = None, name: Optional[str] = None, **kwargs: Any) -> Any:
    """Open an ``llm`` span for a chat-model call.

    The span name is ``name``, falling back to ``serialized["name"]`` and then to
    ``"ChatModel Call"``. When a provider is recognised -- from a key prefix in
    ``serialized["secrets"]`` or from a substring of the lower-cased model name -- a
    marker such as ``OPENAI_API_CALL`` is appended unless the name already contains
    it. Only the first matching provider (OpenAI, Anthropic, Together, Google, in that
    order) is applied.

    The span input is a dict with the messages, the invocation params (or the extra
    ``kwargs`` when none are given), options, tags, metadata and ``serialized``.
    Any exception is reported through ``self._log`` and swallowed.
    """
    log_prefix = f"{HANDLER_LOG_PREFIX} [Handler {id(self)}]"
    # (marker, secrets-key prefix, lower-case model-name hints), checked in this order.
    provider_markers = (
        ("OPENAI_API_CALL", "openai", ("openai",)),
        ("ANTHROPIC_API_CALL", "anthropic", ("anthropic", "claude")),
        ("TOGETHER_API_CALL", "together", ("together",)),
        ("GOOGLE_API_CALL", "google", ("google", "gemini")),
    )
    try:
        # Name detection lives inside the try so a malformed payload cannot raise out
        # of the callback; `or` fallbacks also cover keys that are present but None.
        serialized_dict = serialized or {}
        chat_model_name = name or serialized_dict.get("name") or "ChatModel Call"
        secret_keys = list((serialized_dict.get("secrets") or {}).keys())
        lowered_name = chat_model_name.lower()
        for marker, key_prefix, name_hints in provider_markers:
            detected = (
                any(key.startswith(key_prefix) for key in secret_keys)
                or any(hint in lowered_name for hint in name_hints)
            )
            if detected and marker not in chat_model_name:
                chat_model_name = f"{chat_model_name} {marker}"
                break

        trace_client = self._ensure_trace_client(run_id, parent_run_id, chat_model_name)
        if not trace_client:
            return
        inputs = {
            'messages': messages,
            'invocation_params': invocation_params or kwargs,
            'options': options,
            'tags': tags,
            'metadata': metadata,
            'serialized': serialized,
        }
        # Pass the inputs dict (the span helpers take Dict[str, Any]), not the bare list.
        self._start_span_tracking(trace_client, run_id, parent_run_id, chat_model_name, span_type="llm", inputs=inputs)
    except Exception as e:
        tc_id_on_error = id(self._trace_client) if self._trace_client else 'None'
        self._log(f"{log_prefix} UNCAUGHT EXCEPTION in on_chat_model_start for run_id={run_id} (TraceClient ID: {tc_id_on_error}): {e}")
1105
+
1106
+ # --- Agent Methods (sync handler; agent events are currently not traced) ---
1107
def on_agent_action(self, action: AgentAction, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any) -> Any:
    """Handle an agent action event.

    Agent actions are not traced by this handler yet; the hook is an intentional
    no-op and returns None. Detailed tracing can be added here if needed.
    """
    return None
1120
+
1121
def on_agent_finish(self, finish: AgentFinish, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any) -> Any:
    """Handle an agent finish event.

    Agent completion is not traced by this handler yet; the hook is an intentional
    no-op and returns None. Detailed tracing can be added here if needed.
    """
    return None
1134
+
1135
+ # --- Async Handler ---
1136
+
1137
+ # --- NEW Fully Functional Async Handler ---
1138
+ class AsyncJudgevalCallbackHandler(AsyncCallbackHandler):
1139
+ """
1140
+ Async LangChain Callback Handler using run_id/parent_run_id for hierarchy.
1141
+ Manages its own internal TraceClient instance created upon first use.
1142
+ Includes verbose logging and defensive checks.
1143
+ """
1144
+ lc_serializable = False
1145
+ lc_kwargs = {}
1146
+
1147
def __init__(self, tracer: Tracer):
    """Create an async handler that reports to ``tracer``.

    No TraceClient is created here; ``_ensure_trace_client`` builds one on the
    first callback that needs it.
    """
    self.tracer = tracer
    self._trace_client: Optional[TraceClient] = None
    # Set by _ensure_trace_client when the client is created; initialised here so the
    # attribute always exists.
    self._current_trace_id: Optional[str] = None

    # Per-run / per-span bookkeeping.
    self._run_id_to_span_id: Dict[UUID, str] = {}
    self._span_id_to_start_time: Dict[str, float] = {}
    self._span_id_to_depth: Dict[str, int] = {}
    # Tokens returned by current_span_var.set() for chain spans, keyed by run id.
    self._run_id_to_context_token: Dict[UUID, contextvars.Token] = {}
    self._root_run_id: Optional[UUID] = None
    # Token from current_trace_var.set(); set once by the first started span.
    self._trace_context_token: Optional[contextvars.Token] = None
    self._trace_saved: bool = False
    # Raw chain inputs stored by on_chain_start, keyed by run id.
    self._run_id_to_start_inputs: Dict[UUID, Dict] = {}

    # Execution records and saved trace payloads.
    self.executed_nodes: List[str] = []
    self.executed_tools: List[str] = []
    self.executed_node_tools: List[str] = []
    self.traces: List[Dict[str, Any]] = []
1170
+
1171
+ # NOTE: _ensure_trace_client remains synchronous as it doesn't involve async I/O
1172
def _ensure_trace_client(self, run_id: UUID, event_name: str) -> Optional[TraceClient]:
    """Return this handler's TraceClient, creating it on first use.

    Args:
        run_id: Run id of the triggering event. If the client is created here and
            no root run is known yet, this run becomes the root run.
        event_name: Trace name used when the client is created.

    Returns:
        The handler's TraceClient, or None if it could not be created.

    Note: this method takes no ``parent_run_id``; call it as ``(run_id, event_name)``.
    """
    if self._trace_client is not None:
        return self._trace_client

    try:
        client = TraceClient(
            self.tracer,
            str(uuid.uuid4()),
            event_name,
            project_name=self.tracer.project_name,
            overwrite=False,
            rules=self.tracer.rules,
            enable_monitoring=self.tracer.enable_monitoring,
            enable_evaluations=self.tracer.enable_evaluations,
        )
        self._trace_client = client
        if not client:
            return None
        # Register the new client as the tracer's active trace.
        self.tracer._active_trace_client = client
        self._current_trace_id = client.trace_id
        if self._root_run_id is None:
            self._root_run_id = run_id
    except Exception as e:
        # Report the failure instead of hiding it; the event then proceeds untraced.
        self._log(f"{HANDLER_LOG_PREFIX} [Async Handler {id(self)}] ERROR creating TraceClient for run_id={run_id} ('{event_name}'): {e}")
        self._trace_client = None
        return None
    return client
1198
+
1199
def _log(self, message: str):
    """Sink for the handler's diagnostic messages.

    Every diagnostic in this handler is routed through here; the body currently
    discards ``message`` and returns None.
    """
    return None
261
1202
 
262
- def on_llm_start(
263
- self,
264
- serialized: Optional[dict[str, Any]],
265
- prompts: list[str],
266
- *,
267
- run_id: UUID,
268
- parent_run_id: Optional[UUID] = None,
269
- **kwargs: Any,
270
- ) -> Any:
271
- name = "LLM call"
272
- self.start_span(name, span_type="llm")
273
- current_trace = self.tracer.get_current_trace()
274
- current_trace.record_input({
275
- 'args': prompts,
276
- 'kwargs': kwargs
277
- })
278
-
279
- def on_llm_end(self, response: LLMResult, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any):
280
- current_trace = self.tracer.get_current_trace()
281
- current_trace.record_output(response.generations[0][0].text)
282
- self.end_span(span_type="llm")
283
-
284
- def on_llm_error(
1203
+ # NOTE: _start_span_tracking remains mostly synchronous, TraceClient.add_entry might become async later
1204
def _start_span_tracking(
    self,
    trace_client: TraceClient,
    run_id: UUID,
    parent_run_id: Optional[UUID],
    name: str,
    span_type: SpanType = "span",
    inputs: Optional[Dict[str, Any]] = None
):
    """Open a span for ``run_id`` and write its ``enter`` entry to ``trace_client``.

    Effects:
      * the first time this handler starts a span, binds ``current_trace_var`` to
        ``trace_client`` and keeps the token in ``self._trace_context_token``;
      * allocates a new span id and records its start time and depth. The depth is
        the parent span's depth + 1 when ``parent_run_id`` maps to a tracked span
        with a known depth, otherwise 0;
      * for ``chain`` spans only, binds ``current_span_var`` to the new span id and
        keeps the token in ``self._run_id_to_context_token``;
      * records non-empty ``inputs`` through ``_record_input_data``.

    Failures while touching context variables or adding the entry are logged and
    swallowed.
    """
    self._log(f"_start_span_tracking called for: name='{name}', run_id={run_id}, parent_run_id={parent_run_id}, span_type={span_type}")
    if not trace_client:
        self._log(f"{HANDLER_LOG_PREFIX} [Async Handler {id(self)}] FATAL ERROR in _start_span_tracking: trace_client argument is None for name='{name}', run_id={run_id}. Aborting span start.")
        return

    # Bind the trace context variable once per handler.
    if self._trace_context_token is None:
        try:
            self._trace_context_token = current_trace_var.set(trace_client)
            self._log(f" Set current_trace_var for trace_id {trace_client.trace_id}")
        except Exception as exc:
            self._log(f" ERROR setting current_trace_var for trace_id {trace_client.trace_id}: {exc}")

    started_at = time.time()
    new_span_id = str(uuid.uuid4())
    parent_span_id: Optional[str] = None
    depth = 0

    if not parent_run_id:
        self._log(f" No parent_run_id provided. Treating '{name}' as depth 0.")
    elif parent_run_id not in self._run_id_to_span_id:
        self._log(f" WARNING: parent_run_id {parent_run_id} provided for '{name}' ({run_id}) but parent span not tracked. Treating as depth 0.")
    else:
        parent_span_id = self._run_id_to_span_id[parent_run_id]
        parent_depth = self._span_id_to_depth.get(parent_span_id)
        if parent_depth is None:
            self._log(f" WARNING: Parent span depth not found for parent_span_id: {parent_span_id}. Setting depth to 0.")
        else:
            depth = parent_depth + 1

    self._run_id_to_span_id[run_id] = new_span_id
    self._span_id_to_start_time[new_span_id] = started_at
    self._span_id_to_depth[new_span_id] = depth
    self._log(f" Tracking new span: span_id={new_span_id}, depth={depth}")

    # Only chain (node) spans become the current span in the context variable.
    if span_type == "chain":
        try:
            self._run_id_to_context_token[run_id] = current_span_var.set(new_span_id)
            self._log(f" Set current_span_var to {new_span_id} for run_id {run_id} (type: chain)")
        except Exception as exc:
            self._log(f" ERROR setting current_span_var for run_id {run_id}: {exc}")

    try:
        trace_client.add_entry(TraceEntry(
            type="enter",
            span_id=new_span_id,
            trace_id=trace_client.trace_id,
            parent_span_id=parent_span_id,
            function=name,
            depth=depth,
            message=name,
            created_at=started_at,
            span_type=span_type,
        ))
        self._log(f" Added 'enter' entry for span_id={new_span_id}")
    except Exception as exc:
        self._log(f" ERROR adding 'enter' entry for span_id {new_span_id}: {exc}")

    if inputs:
        self._record_input_data(trace_client, run_id, inputs)
1281
+
1282
+ # NOTE: _end_span_tracking remains mostly synchronous, TraceClient.save might become async later
1283
def _end_span_tracking(
    self,
    trace_client: TraceClient,
    run_id: UUID,
    span_type: SpanType = "span",
    outputs: Optional[Any] = None,
    error: Optional[BaseException] = None
):
    """Close the span opened for ``run_id``; for the root run, save the trace.

    Steps:
      1. Record ``error`` if given, otherwise ``outputs`` if not None, as an
         ``output`` entry.
      2. If the run has a tracked span, append an ``exit`` entry with its duration
         and depth, reusing the function name from its ``enter`` entry. Then drop
         its start-time/depth bookkeeping and its ``current_span_var`` token.
      3. If ``run_id`` is the root run, save the trace unless ``self._trace_saved``
         is already set, append the saved data to ``self.traces``, and reset the
         per-run state in a ``finally`` so cleanup happens even if saving fails.

    Runs that are neither tracked nor the root run are ignored. Errors while
    writing entries or saving are logged via ``self._log`` and swallowed.
    """
    if not trace_client:
        return

    span_id = self._run_id_to_span_id.get(run_id)
    if span_id is None and run_id != self._root_run_id:
        # Untracked non-root run: nothing to close. An untracked *root* run still
        # falls through so the trace can be saved and state reset.
        return

    start_time = self._span_id_to_start_time.get(span_id) if span_id else None
    depth = self._span_id_to_depth.get(span_id, 0) if span_id else 0
    duration = time.time() - start_time if start_time is not None else None

    if error is not None:
        self._record_output_data(trace_client, run_id, error)
    elif outputs is not None:
        self._record_output_data(trace_client, run_id, outputs)

    if span_id:
        entry_function_name = "unknown"
        try:
            entries = getattr(trace_client, "entries", None)
            if entries:
                entry_function_name = next(
                    (e.function for e in reversed(entries) if e.span_id == span_id and e.type == "enter"),
                    "unknown",
                )
        except Exception as e:
            self._log(f" ERROR finding function name for exit entry span_id {span_id}: {e}")

        try:
            trace_client.add_entry(TraceEntry(
                type="exit", span_id=span_id, trace_id=trace_client.trace_id,
                depth=depth, created_at=time.time(), duration=duration,
                span_type=span_type, function=entry_function_name
            ))
        except Exception as e:
            self._log(f" ERROR adding 'exit' entry for span_id {span_id}: {e}")

        self._span_id_to_start_time.pop(span_id, None)
        self._span_id_to_depth.pop(span_id, None)
        # NOTE(review): the current_span_var token is discarded without
        # ContextVar.reset(); presumably because async callbacks may not run in the
        # Context that created it (reset() raises ValueError then) -- confirm.
        self._run_id_to_context_token.pop(run_id, None)

    if run_id != self._root_run_id:
        return

    self._log(f"Root run {run_id} finished. Attempting to save trace...")
    try:
        if self._trace_client and not self._trace_saved:
            try:
                trace_id, trace_data = self._trace_client.save(overwrite=self._trace_client.overwrite)
                self.traces.append(trace_data)
                self._log(f"Trace {trace_id} successfully saved.")
                self._trace_saved = True  # only after a successful save
            except Exception as e:
                self._log(f"ERROR saving trace {self._trace_client.trace_id}: {e}")
        elif self._trace_client:
            self._log(f" Trace {self._trace_client.trace_id} already saved. Skipping save.")
        else:
            self._log(f" WARNING: Root run {run_id} ended, but trace client was None. Cannot save trace.")
    finally:
        # Single place for root-run cleanup, whether or not the save succeeded.
        self._log(f" Performing cleanup for root run {run_id} in handler {id(self)}.")
        self._root_run_id = None
        self._run_id_to_start_inputs = {}
        # Clear the tracer's active client only if it is still this handler's client.
        if self.tracer._active_trace_client == self._trace_client:
            self.tracer._active_trace_client = None
            self._log(" Reset active_trace_client on Tracer.")
        # NOTE(review): this method leaves _trace_client, _trace_context_token and
        # _trace_saved unchanged. Unless something else resets them, a reused handler
        # keeps appending to the same trace and will not save it again -- confirm intent.
        self._log(f" Cleanup complete for root run {run_id}.")
1416
+
1417
+ # NOTE: _record_input_data remains synchronous for now
1418
def _record_input_data(self,
                       trace_client: TraceClient,
                       run_id: UUID,
                       inputs: Dict[str, Any]):
    """Append an ``input`` entry carrying ``inputs`` to the span tracked for ``run_id``.

    The entry copies the function name, span type and parent span id from the span's
    ``enter`` entry. If that entry cannot be found it falls back to ``"unknown"``,
    ``"span"`` and ``None``. Untracked runs and a missing trace client are ignored;
    errors are logged and swallowed.
    """
    span_id = self._run_id_to_span_id.get(run_id)
    if span_id is None or not trace_client:
        return

    depth = self._span_id_to_depth.get(span_id, 0)

    # Look the enter entry up once and reuse it for name, type and parent id.
    enter_entry = None
    try:
        enter_entry = next(
            (e for e in reversed(trace_client.entries) if e.span_id == span_id and e.type == "enter"),
            None,
        )
    except Exception as e:
        self._log(f" ERROR finding enter entry for input span_id {span_id}: {e}")

    function_name = enter_entry.function if enter_entry else "unknown"
    span_type: SpanType = enter_entry.span_type if enter_entry else "span"

    try:
        trace_client.add_entry(TraceEntry(
            type="input",
            span_id=span_id,
            trace_id=trace_client.trace_id,
            parent_span_id=enter_entry.parent_span_id if enter_entry else None,
            function=function_name,
            depth=depth,
            message=f"Input to {function_name}",
            created_at=time.time(),
            inputs=inputs,
            span_type=span_type,
        ))
    except Exception as e:
        self._log(f" ERROR adding 'input' entry directly for span_id {span_id}: {e}")
1470
+
1471
+ # NOTE: _record_output_data remains synchronous for now
1472
def _record_output_data(self,
                        trace_client: TraceClient,
                        run_id: UUID,
                        output: Any):
    """Append an ``output`` entry carrying ``output`` to the span tracked for ``run_id``.

    ``output`` may be a normal result or an exception (``_end_span_tracking`` passes
    either). Function name, span type and parent span id are copied from the span's
    ``enter`` entry, with fallbacks ``"unknown"``, ``"span"`` and ``None``. Untracked
    runs and a missing trace client are ignored; errors are logged and swallowed.
    """
    self._log(f"_record_output_data called for run_id={run_id}")
    span_id = self._run_id_to_span_id.get(run_id)
    if span_id is None or not trace_client:
        return

    depth = self._span_id_to_depth.get(span_id, 0)

    # Look the enter entry up once and reuse it for name, type and parent id.
    enter_entry = None
    try:
        enter_entry = next(
            (e for e in reversed(trace_client.entries) if e.span_id == span_id and e.type == "enter"),
            None,
        )
    except Exception as e:
        self._log(f" ERROR finding enter entry for output span_id {span_id}: {e}")

    function_name = enter_entry.function if enter_entry else "unknown"
    span_type: SpanType = enter_entry.span_type if enter_entry else "span"

    try:
        trace_client.add_entry(TraceEntry(
            type="output",
            span_id=span_id,
            trace_id=trace_client.trace_id,
            parent_span_id=enter_entry.parent_span_id if enter_entry else None,
            function=function_name,
            depth=depth,
            message=f"Output from {function_name}",
            created_at=time.time(),
            output=output,
            span_type=span_type,
        ))
        self._log(f" Added 'output' entry directly for span_id={span_id}")
    except Exception as e:
        self._log(f" ERROR adding 'output' entry directly for span_id {span_id}: {e}")
1523
+
1524
+ # --- Async Callback Methods ---
1525
+
1526
async def on_retriever_start(self, serialized: Dict[str, Any], query: str, *, run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, **kwargs: Any) -> Any:
    """Open a ``retriever`` span named ``RETRIEVER_<NAME>`` for a retriever query.

    The span input holds the query, tags, metadata, extra kwargs and ``serialized``.
    Exceptions are reported through ``self._log`` and swallowed.
    """
    log_prefix = f"{HANDLER_LOG_PREFIX} [Async Handler {id(self)}]"
    # `or` also covers a 'name' key that is present but None/empty.
    serialized_name = (serialized.get('name') or 'Unknown') if serialized else "Unknown (Serialized=None)"
    try:
        name = f"RETRIEVER_{serialized_name.upper()}"
        # This handler's _ensure_trace_client takes (run_id, event_name); passing
        # parent_run_id too raises TypeError, which the except below would swallow.
        trace_client = self._ensure_trace_client(run_id, name)
        if not trace_client:
            return
        inputs = {'query': query, 'tags': tags, 'metadata': metadata, 'kwargs': kwargs, 'serialized': serialized}
        self._start_span_tracking(trace_client, run_id, parent_run_id, name, span_type="retriever", inputs=inputs)
    except Exception as e:
        tc_id_on_error = id(self._trace_client) if self._trace_client else 'None'
        self._log(f"{log_prefix} UNCAUGHT EXCEPTION in on_retriever_start for run_id={run_id} (TraceClient ID: {tc_id_on_error}): {e}")
1543
+
1544
async def on_retriever_end(self, documents: Sequence[Document], *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any) -> Any:
    """Close the retriever span for ``run_id`` with a summary of the returned documents.

    Each document's ``page_content`` is truncated to its first 100 characters plus
    ``"..."`` when longer; metadata is included as-is. Exceptions are reported
    through ``self._log`` and swallowed.
    """
    log_prefix = f"{HANDLER_LOG_PREFIX} [Async Handler {id(self)}]"
    try:
        # This handler's _ensure_trace_client takes (run_id, event_name); passing
        # parent_run_id too raises TypeError, which the except below would swallow.
        trace_client = self._ensure_trace_client(run_id, "RetrieverEnd")
        if not trace_client:
            return
        doc_summary = [
            {
                "index": i,
                "page_content": (doc.page_content[:100] + "...") if len(doc.page_content) > 100 else doc.page_content,
                "metadata": doc.metadata,
            }
            for i, doc in enumerate(documents)
        ]
        outputs = {"document_count": len(documents), "documents": doc_summary, "kwargs": kwargs}
        self._end_span_tracking(trace_client, run_id, span_type="retriever", outputs=outputs)
    except Exception as e:
        tc_id_on_error = id(self._trace_client) if self._trace_client else 'None'
        self._log(f"{log_prefix} UNCAUGHT EXCEPTION in on_retriever_end for run_id={run_id} (TraceClient ID: {tc_id_on_error}): {e}")
1560
+
1561
async def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any], *, run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, **kwargs: Any) -> None:
    """Open a ``chain`` span for a chain or LangGraph node and remember its raw inputs.

    Span name, in priority order:
      1. ``metadata["langgraph_node"]`` (also appended once to ``self.executed_nodes``);
      2. ``serialized["name"]`` when it is not a generic wrapper name;
      3. ``"LangGraph"`` for a top-level run (``parent_run_id is None``);
      4. ``serialized["name"]`` even if generic;
      5. ``"Unknown Chain"``.

    A top-level run becomes the handler's root run when none is set yet, whichever
    naming rule applied. When it is the root, the trace is renamed to the span name.
    The raw ``inputs`` are stored in ``self._run_id_to_start_inputs``. Exceptions are
    reported through ``self._log`` and swallowed.
    """
    log_prefix = f"{HANDLER_LOG_PREFIX} [Async Handler {id(self)}]"
    # Serialized names of wrapper runnables that say nothing about the step being run.
    generic_names = frozenset(("RunnableSequence", "RunnableParallel", "RunnableLambda", "LangGraph", "__start__", "__end__"))
    serialized_name = serialized.get('name') if serialized else None

    try:
        span_type: SpanType = "chain"
        node_name = metadata.get("langgraph_node") if metadata else None
        is_root_event = parent_run_id is None

        if node_name:
            name = node_name
            self._log(f" LangGraph Node Start Detected: '{name}', run_id={run_id}, parent_run_id={parent_run_id}")
            if name not in self.executed_nodes:
                self.executed_nodes.append(name)
        elif serialized_name and serialized_name not in generic_names:
            name = serialized_name
            self._log(f" LangGraph Functional Step (Router?): '{name}', run_id={run_id}, parent_run_id={parent_run_id}")
        elif is_root_event:
            # Only reached when serialized_name is missing or generic.
            name = "LangGraph"
            self._log(f" LangGraph Root Start Detected (parent_run_id=None): Name='{name}', run_id={run_id}")
        elif serialized_name:
            name = serialized_name
            self._log(f" Fallback to serialized_name: '{name}', run_id={run_id}")
        else:
            name = "Unknown Chain"

        # Register the root for every naming rule: _ensure_trace_client only sets the
        # root when it creates a new client, so a top-level run named by rule 1 or 2
        # on a handler that already has a client would otherwise never become root.
        if is_root_event and self._root_run_id is None:
            self._log(f" Setting root run ID to {run_id}")
            self._root_run_id = run_id

        trace_client = self._ensure_trace_client(run_id, name)
        if not trace_client:
            return

        if is_root_event and run_id == self._root_run_id and trace_client.name != name:
            self._log(f" Updating trace name from '{trace_client.name}' to '{name}' for root run {run_id}")
            trace_client.name = name

        combined_inputs = {'inputs': inputs, 'tags': tags, 'metadata': metadata, 'kwargs': kwargs, 'serialized': serialized}
        self._start_span_tracking(trace_client, run_id, parent_run_id, name, span_type=span_type, inputs=combined_inputs)
        # Keep the raw inputs for later use (e.g. evaluation when the chain ends).
        self._run_id_to_start_inputs[run_id] = inputs
        self._log(f" Stored inputs for run_id {run_id}")
    except Exception as e:
        tc_id_on_error = id(self._trace_client) if self._trace_client else 'None'
        self._log(f"{log_prefix} UNCAUGHT EXCEPTION in on_chain_start for run_id={run_id} (TraceClient ID: {tc_id_on_error}): {e}")
1631
+
1632
async def on_chain_end(self, outputs: Dict[str, Any], *, run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, **kwargs: Any) -> Any:
    """
    End span tracking for a chain/node run and submit an evaluation if one was requested.

    The span type and node name are read from the span's "enter" entry on the
    trace client, scanned once. If ``outputs`` is a dict with a
    ``"_judgeval_eval"`` key holding either an ``EvaluationConfig`` or a dict
    that can be rebuilt into one, the evaluation is submitted through
    ``client.async_evaluate`` for this run's span.

    Args:
        outputs: Outputs reported for the run. The annotation says dict, but any
            value is tolerated; only dict outputs are inspected for
            ``"_judgeval_eval"``.
        run_id: ID of the run that ended.
        parent_run_id: ID of the parent run (not used here).
        tags: Run tags; recorded with the span outputs.
        **kwargs: Extra callback arguments; recorded with the span outputs.

    Errors raised while tracing are logged and swallowed, matching the other
    callbacks of this handler, so tracing never breaks the traced chain.
    """
    instance_id = id(self)
    log_prefix = f"{HANDLER_LOG_PREFIX} [Async Handler {instance_id}]"
    self._log(f"{log_prefix} ENTERING on_chain_end: run_id={run_id}. Current TraceClient ID: {id(self._trace_client) if self._trace_client else 'None'}")
    try:
        client = self._ensure_trace_client(run_id, "on_chain_end")
        if not client:
            self._log(f"{log_prefix} No TraceClient found for on_chain_end ({run_id}). Aborting.")
            return

        span_id = self._run_id_to_span_id.get(run_id)

        # A single scan of the entries yields both the span type (so the exit
        # matches the start) and the node name used in the evaluation logs.
        end_span_type: SpanType = "chain"
        node_name = "unknown_node"
        if span_id:
            try:
                if hasattr(client, 'entries') and client.entries:
                    enter_entry = next((e for e in reversed(client.entries) if e.span_id == span_id and e.type == "enter"), None)
                    if enter_entry:
                        end_span_type = enter_entry.span_type
                        node_name = enter_entry.function
                else:
                    self._log(f" WARNING: trace_client.entries empty/missing for on_chain_end span_id={span_id}")
            except Exception as e:
                self._log(f" ERROR finding enter entry for span_id {span_id} in on_chain_end: {e}")
        else:
            self._log(f" WARNING: No span_id found for run_id {run_id} in on_chain_end, using default span_type='chain'.")

        combined_outputs = {"outputs": outputs, "tags": tags, "kwargs": kwargs}
        self._end_span_tracking(client, run_id, span_type=end_span_type, outputs=combined_outputs)

        # Chain outputs are not guaranteed to be dicts. On a str, `in` would do a
        # substring test; on other types, `in` or `.get` may raise.
        has_eval_key = isinstance(outputs, dict) and "_judgeval_eval" in outputs
        if not has_eval_key:
            return
        if not span_id:
            self._log(f"{log_prefix} WARNING: Found _judgeval_eval in outputs, but span_id for run_id {run_id} was not found. Cannot submit evaluation.")
            return

        eval_config: Optional[EvaluationConfig] = None
        raw_eval_config = outputs.get("_judgeval_eval")
        if isinstance(raw_eval_config, EvaluationConfig):
            eval_config = raw_eval_config
            self._log(f"{log_prefix} Found valid EvaluationConfig in outputs for node='{node_name}'.")
        elif isinstance(raw_eval_config, dict):
            # Rebuild from a plain dict (e.g. when graph state was serialized).
            try:
                if "scorers" in raw_eval_config and "example" in raw_eval_config:
                    example_data = raw_eval_config["example"]
                    reconstructed_example = Example(**example_data) if isinstance(example_data, dict) else example_data
                    if isinstance(reconstructed_example, Example):
                        eval_config = EvaluationConfig(
                            scorers=raw_eval_config["scorers"],  # passed through as-is
                            example=reconstructed_example,
                            model=raw_eval_config.get("model"),
                            log_results=raw_eval_config.get("log_results", True),
                        )
                        self._log(f"{log_prefix} Reconstructed EvaluationConfig from dict in outputs for node='{node_name}'.")
                    else:
                        self._log(f"{log_prefix} Could not reconstruct Example from dict in _judgeval_eval for node='{node_name}'. Skipping evaluation.")
                else:
                    self._log(f"{log_prefix} Dict in _judgeval_eval missing required keys ('scorers', 'example') for node='{node_name}'. Skipping evaluation.")
            except Exception as recon_e:
                self._log(f"{log_prefix} ERROR attempting to reconstruct EvaluationConfig from dict for node='{node_name}': {recon_e}")
        else:
            self._log(f"{log_prefix} Found '_judgeval_eval' key in outputs for node='{node_name}', but it wasn't an EvaluationConfig object or reconstructable dict. Skipping evaluation.")

        if eval_config is not None:
            self._log(f"{log_prefix} Submitting evaluation for span_id={span_id}")
            try:
                client.async_evaluate(
                    scorers=eval_config.scorers,
                    example=eval_config.example,
                    model=eval_config.model,
                    log_results=eval_config.log_results,
                    span_id=span_id,  # attach the result to this node's span
                )
                self._log(f"{log_prefix} Evaluation submitted successfully for span_id={span_id}.")
            except Exception as eval_e:
                self._log(f"{log_prefix} ERROR submitting evaluation for span_id={span_id}: {eval_e}")
    except Exception as e:
        tc_id_on_error = id(self._trace_client) if self._trace_client else 'None'
        self._log(f"{log_prefix} UNCAUGHT EXCEPTION in on_chain_end for run_id={run_id} (TraceClient ID: {tc_id_on_error}): {e}")
1736
+
1737
async def on_chain_error(self, error: BaseException, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any) -> Any:
    """
    Close the span of a chain/node run that raised ``error``.

    The span type stored on the span's "enter" entry is reused so the exit
    matches the start; ``"chain"`` is used when no such entry is found.
    Errors raised while tracing are logged and swallowed.
    """
    log_prefix = f"{HANDLER_LOG_PREFIX} [Async Handler {id(self)}]"
    try:
        trace_client = self._ensure_trace_client(run_id, "ChainError")
        if not trace_client:
            return

        span_id = self._run_id_to_span_id.get(run_id)
        span_type: SpanType = "chain"
        if span_id:
            try:
                if hasattr(trace_client, 'entries') and trace_client.entries:
                    enter_entry = next((e for e in reversed(trace_client.entries) if e.span_id == span_id and e.type == "enter"), None)
                    if enter_entry:
                        span_type = enter_entry.span_type
                else:
                    self._log(f" WARNING: trace_client.entries not available for on_chain_error span_id={span_id}")
            except Exception as e:
                self._log(f" ERROR finding enter entry for span_id {span_id} in on_chain_error: {e}")

        self._end_span_tracking(trace_client, run_id, span_type=span_type, error=error)
    except Exception as e:
        tc_id_on_error = id(self._trace_client) if self._trace_client else 'None'
        self._log(f"{log_prefix} UNCAUGHT EXCEPTION in on_chain_error for run_id={run_id} (TraceClient ID: {tc_id_on_error}): {e}")
1763
+
1764
async def on_tool_start(self, serialized: Dict[str, Any], input_str: str, *, run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, inputs: Optional[Dict[str, Any]] = None, **kwargs: Any) -> Any:
    """
    Open a "tool" span and record which tools (and node:tool pairs) ran.

    The tool name comes from ``serialized["name"]``; ``serialized`` may be None.
    Each distinct name is appended once to ``self.executed_tools``. When the
    parent run's "enter" entry is a "chain" span, the pair
    ``"<parent node>:<tool>"`` is appended once to ``self.executed_node_tools``;
    otherwise the bare tool name is appended there.
    Errors raised while tracing are logged and swallowed.
    """
    log_prefix = f"{HANDLER_LOG_PREFIX} [Async Handler {id(self)}]"
    name = serialized.get("name", "Unnamed Tool") if serialized else "Unknown Tool (Serialized=None)"
    try:
        trace_client = self._ensure_trace_client(run_id, name)
        if not trace_client:
            return

        combined_inputs = {'input_str': input_str, 'inputs': inputs, 'tags': tags, 'metadata': metadata, 'kwargs': kwargs, 'serialized': serialized}
        self._start_span_tracking(trace_client, run_id, parent_run_id, name, span_type="tool", inputs=combined_inputs)

        # --- Track executed tools ---
        if name not in self.executed_tools:
            self.executed_tools.append(name)

        parent_node_name = None
        if parent_run_id and parent_run_id in self._run_id_to_span_id:
            parent_span_id = self._run_id_to_span_id[parent_run_id]
            try:
                if hasattr(trace_client, 'entries') and trace_client.entries:
                    parent_enter_entry = next((e for e in reversed(trace_client.entries) if e.span_id == parent_span_id and e.type == "enter" and e.span_type == "chain"), None)
                    if parent_enter_entry:
                        parent_node_name = parent_enter_entry.function
                else:
                    self._log(f" WARNING: trace_client.entries not available for parent node {parent_span_id}")
            except Exception as e:
                self._log(f" ERROR finding parent node name for tool start span_id {parent_span_id}: {e}")

        node_tool = f"{parent_node_name}:{name}" if parent_node_name else name
        if node_tool not in self.executed_node_tools:
            self.executed_node_tools.append(node_tool)
        self._log(f" Tracked node_tool: {node_tool}")
    except Exception as e:
        tc_id_on_error = id(self._trace_client) if self._trace_client else 'None'
        self._log(f"{log_prefix} UNCAUGHT EXCEPTION in on_tool_start for run_id={run_id} (TraceClient ID: {tc_id_on_error}): {e}")
1801
+
1802
async def on_tool_end(self, output: Any, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any) -> Any:
    """
    Close the "tool" span for ``run_id``, recording ``output`` and ``kwargs``.

    Errors raised while tracing are logged and swallowed.
    """
    log_prefix = f"{HANDLER_LOG_PREFIX} [Async Handler {id(self)}]"
    try:
        trace_client = self._ensure_trace_client(run_id, "ToolEnd")
        if not trace_client:
            return
        outputs = {"output": output, "kwargs": kwargs}
        self._end_span_tracking(trace_client, run_id, span_type="tool", outputs=outputs)
    except Exception as e:
        tc_id_on_error = id(self._trace_client) if self._trace_client else 'None'
        self._log(f"{log_prefix} UNCAUGHT EXCEPTION in on_tool_end for run_id={run_id} (TraceClient ID: {tc_id_on_error}): {e}")
1817
+
1818
async def on_tool_error(self, error: BaseException, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any) -> Any:
    """
    Close the "tool" span for ``run_id`` with ``error`` attached.

    Errors raised while tracing are logged and swallowed.
    """
    log_prefix = f"{HANDLER_LOG_PREFIX} [Async Handler {id(self)}]"
    try:
        trace_client = self._ensure_trace_client(run_id, "ToolError")
        if not trace_client:
            return
        self._end_span_tracking(trace_client, run_id, span_type="tool", error=error)
    except Exception as e:
        tc_id_on_error = id(self._trace_client) if self._trace_client else 'None'
        self._log(f"{log_prefix} UNCAUGHT EXCEPTION in on_tool_error for run_id={run_id} (TraceClient ID: {tc_id_on_error}): {e}")
1832
+
1833
async def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], *, run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, invocation_params: Optional[Dict[str, Any]] = None, options: Optional[Dict[str, Any]] = None, name: Optional[str] = None, **kwargs: Any) -> Any:
    """
    Open an "llm" span for a completion-style LLM call.

    The span is named ``name`` if given, else ``serialized["name"]``, else
    "LLM Call". ``serialized`` may be None; ``on_tool_start`` also guards
    against that. ``invocation_params`` falls back to ``kwargs`` when not
    provided. Errors raised while tracing are logged and swallowed.
    """
    log_prefix = f"{HANDLER_LOG_PREFIX} [Async Handler {id(self)}]"
    try:
        # Resolved inside the try so a None/odd `serialized` cannot escape the handler.
        llm_name = name or (serialized or {}).get("name", "LLM Call")
        trace_client = self._ensure_trace_client(run_id, llm_name)
        if not trace_client:
            return
        inputs = {'prompts': prompts, 'invocation_params': invocation_params or kwargs, 'options': options, 'tags': tags, 'metadata': metadata, 'serialized': serialized}
        self._start_span_tracking(trace_client, run_id, parent_run_id, llm_name, span_type="llm", inputs=inputs)
    except Exception as e:
        tc_id_on_error = id(self._trace_client) if self._trace_client else 'None'
        self._log(f"{log_prefix} UNCAUGHT EXCEPTION in on_llm_start for run_id={run_id} (TraceClient ID: {tc_id_on_error}): {e}")
1849
+
1850
async def on_llm_end(self, response: LLMResult, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any) -> Any:
    """
    Close an "llm" span, attaching normalized token usage when it can be found.

    Token counts are looked up in this order:
      1. ``response.llm_output["token_usage"]`` (keys ``prompt_tokens``,
         ``completion_tokens``, ``total_tokens``);
      2. ``response.llm_output["usage"]`` (keys ``input_tokens``, ``output_tokens``);
      3. ``usage_metadata`` on the generated messages (keys ``input_tokens``,
         ``output_tokens``, ``total_tokens``), only if 1 and 2 yielded nothing.
    A missing total is computed from the prompt and completion counts when
    both are known. Anything found is stored in ``outputs["usage"]`` under
    ``prompt_tokens`` / ``completion_tokens`` / ``total_tokens``.
    Errors raised while tracing are logged and swallowed.
    """
    log_prefix = f"{HANDLER_LOG_PREFIX} [Async Handler {id(self)}]"

    def _first_usage_metadata(candidates):
        """Return the first dict ``usage_metadata`` on a candidate's message, else None."""
        for generation in candidates or []:
            usage_md = getattr(getattr(generation, "message", None), "usage_metadata", None)
            if isinstance(usage_md, dict):
                return usage_md
        return None

    try:
        trace_client = self._ensure_trace_client(run_id, "LLMEnd")
        if not trace_client:
            return

        outputs = {"response": response, "kwargs": kwargs}
        prompt_tokens = completion_tokens = total_tokens = None
        try:
            llm_output = response.llm_output if isinstance(response.llm_output, dict) else None
            if llm_output is not None:
                token_usage = llm_output.get('token_usage')
                provider_usage = llm_output.get('usage')
                if isinstance(token_usage, dict) and token_usage:
                    self._log(f" Extracted OpenAI token usage for run_id={run_id}: {token_usage}")
                    prompt_tokens = token_usage.get('prompt_tokens')
                    completion_tokens = token_usage.get('completion_tokens')
                    total_tokens = token_usage.get('total_tokens')
                elif isinstance(provider_usage, dict) and provider_usage:
                    self._log(f" Extracted Anthropic token usage for run_id={run_id}: {provider_usage}")
                    prompt_tokens = provider_usage.get('input_tokens')
                    completion_tokens = provider_usage.get('output_tokens')
            else:
                self._log(f" llm_output not available/dict for run_id={run_id}")

            # Some chat models report usage only on the returned messages. Use it
            # only when llm_output gave nothing. Count one message per prompt,
            # because the candidates of a single prompt may share one call's usage.
            if prompt_tokens is None and completion_tokens is None and total_tokens is None:
                per_prompt_usage = [u for u in map(_first_usage_metadata, getattr(response, "generations", None) or []) if u]
                if per_prompt_usage:
                    self._log(f" Extracted usage_metadata token usage for run_id={run_id}: {per_prompt_usage}")
                    prompt_tokens = sum(u.get('input_tokens') or 0 for u in per_prompt_usage)
                    completion_tokens = sum(u.get('output_tokens') or 0 for u in per_prompt_usage)
                    total_tokens = sum(
                        u.get('total_tokens') or ((u.get('input_tokens') or 0) + (u.get('output_tokens') or 0))
                        for u in per_prompt_usage
                    )

            if total_tokens is None:
                if prompt_tokens is not None and completion_tokens is not None:
                    total_tokens = prompt_tokens + completion_tokens
                elif prompt_tokens is not None or completion_tokens is not None:
                    self._log(f" Could not calculate total_tokens: input={prompt_tokens}, output={completion_tokens}")

            if prompt_tokens is not None or completion_tokens is not None or total_tokens is not None:
                outputs['usage'] = {
                    'prompt_tokens': prompt_tokens,
                    'completion_tokens': completion_tokens,
                    'total_tokens': total_tokens,
                }
            elif llm_output is not None:
                self._log(f" Could not extract token usage structure from llm_output for run_id={run_id}")
        except Exception as e:
            self._log(f" ERROR extracting token usage for run_id={run_id}: {e}")

        self._end_span_tracking(trace_client, run_id, span_type="llm", outputs=outputs)
    except Exception as e:
        tc_id_on_error = id(self._trace_client) if self._trace_client else 'None'
        self._log(f"{log_prefix} UNCAUGHT EXCEPTION in on_llm_end for run_id={run_id} (TraceClient ID: {tc_id_on_error}): {e}")
1907
+
1908
async def on_llm_error(self, error: BaseException, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any) -> Any:
    """
    Close the "llm" span for ``run_id`` with ``error`` attached.

    Errors raised while tracing are logged and swallowed.
    """
    log_prefix = f"{HANDLER_LOG_PREFIX} [Async Handler {id(self)}]"
    try:
        trace_client = self._ensure_trace_client(run_id, "LLMError")
        if not trace_client:
            return
        self._end_span_tracking(trace_client, run_id, span_type="llm", error=error)
    except Exception as e:
        tc_id_on_error = id(self._trace_client) if self._trace_client else 'None'
        self._log(f"{log_prefix} UNCAUGHT EXCEPTION in on_llm_error for run_id={run_id} (TraceClient ID: {tc_id_on_error}): {e}")
1922
+
1923
async def on_chat_model_start(self, serialized: Dict[str, Any], messages: List[List[BaseMessage]], *, run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, invocation_params: Optional[Dict[str, Any]] = None, options: Optional[Dict[str, Any]] = None, name: Optional[str] = None, **kwargs: Any) -> Any:
    """
    Open an "llm" span for a chat-model call, tagging the span name with its provider.

    The base name is ``name``, else ``serialized["name"]``, else
    "ChatModel Call". A provider is detected from keys in
    ``serialized["secrets"]`` or keywords in the name. The first provider that
    matches and is not already tagged appends one of the suffixes
    ``OPENAI_API_CALL``, ``ANTHROPIC_API_CALL``, ``TOGETHER_API_CALL`` or
    ``GOOGLE_API_CALL``, in that order. ``serialized`` may be None.
    Errors raised while tracing are logged and swallowed.
    """
    log_prefix = f"{HANDLER_LOG_PREFIX} [Async Handler {id(self)}]"
    try:
        serialized_info = serialized or {}
        chat_model_name = name or serialized_info.get("name", "ChatModel Call")
        secret_keys = list((serialized_info.get('secrets') or {}).keys())
        lowered_name = chat_model_name.lower()

        # (suffix, secret-key prefix, name keywords). The first matching provider
        # that is not already tagged wins, exactly like an if/elif chain.
        provider_markers = (
            ("OPENAI_API_CALL", "openai", ("openai",)),
            ("ANTHROPIC_API_CALL", "anthropic", ("anthropic", "claude")),
            ("TOGETHER_API_CALL", "together", ("together",)),
            ("GOOGLE_API_CALL", "google", ("google", "gemini")),
        )
        for suffix, secret_prefix, keywords in provider_markers:
            matches = any(key.startswith(secret_prefix) for key in secret_keys) or any(kw in lowered_name for kw in keywords)
            if matches and suffix not in chat_model_name:
                chat_model_name = f"{chat_model_name} {suffix}"
                break

        trace_client = self._ensure_trace_client(run_id, chat_model_name)
        if not trace_client:
            return
        inputs = {'messages': messages, 'invocation_params': invocation_params or kwargs, 'options': options, 'tags': tags, 'metadata': metadata, 'serialized': serialized}
        # Chat-model calls use the same 'llm' span type as plain LLM calls.
        self._start_span_tracking(trace_client, run_id, parent_run_id, chat_model_name, span_type="llm", inputs=inputs)
    except Exception as e:
        tc_id_on_error = id(self._trace_client) if self._trace_client else 'None'
        self._log(f"{log_prefix} UNCAUGHT EXCEPTION in on_chat_model_start for run_id={run_id} (TraceClient ID: {tc_id_on_error}): {e}")
1959
+
1960
# --- Agent methods (async versions) ---
async def on_agent_action(self, action: AgentAction, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any) -> Any:
    """
    Open an "agent_action" span named "AgentAction" for an agent step.

    ``action`` and ``kwargs`` are recorded as the span inputs. Errors raised
    while tracing are logged and swallowed.

    NOTE(review): LangChain agent executors appear to report agent actions
    under the executor's own chain run_id rather than a fresh one. If so, this
    span is registered under a run_id that already has a chain span. Confirm
    how ``_start_span_tracking`` handles an already-tracked run_id.
    """
    log_prefix = f"{HANDLER_LOG_PREFIX} [Async Handler {id(self)}]"
    try:
        trace_client = self._ensure_trace_client(run_id, "AgentAction")
        if not trace_client:
            return
        inputs = {"action": action, "kwargs": kwargs}
        self._start_span_tracking(trace_client, run_id, parent_run_id, name="AgentAction", span_type="agent_action", inputs=inputs)
    except Exception as e:
        tc_id_on_error = id(self._trace_client) if self._trace_client else 'None'
        self._log(f'{log_prefix} UNCAUGHT EXCEPTION in on_agent_action for run_id={run_id} (TraceClient ID: {tc_id_on_error}): {e}')
1980
+
1981
+
1982
async def on_agent_finish(self, finish: AgentFinish, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any) -> Any:
    """
    Close the "agent_action" span for ``run_id``, recording ``finish`` and ``kwargs``.

    Errors raised while tracing are logged and swallowed.

    NOTE(review): this assumes the span being closed is the one opened by
    ``on_agent_action`` for the same run_id. If LangChain reports both under
    the executor's chain run_id, this may close the chain's span instead.
    Verify against ``_end_span_tracking``.
    """
    log_prefix = f"{HANDLER_LOG_PREFIX} [Async Handler {id(self)}]"
    try:
        trace_client = self._ensure_trace_client(run_id, "AgentFinish")
        if not trace_client:
            return
        outputs = {"finish": finish, "kwargs": kwargs}
        self._end_span_tracking(trace_client, run_id, span_type="agent_action", outputs=outputs)
    except Exception as e:
        tc_id_on_error = id(self._trace_client) if self._trace_client else 'None'
        self._log(f'{log_prefix} UNCAUGHT EXCEPTION in on_agent_finish for run_id={run_id} (TraceClient ID: {tc_id_on_error}): {e}')