judgeval 0.0.36__py3-none-any.whl → 0.0.38__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,36 +1,24 @@
1
- from typing import Any, Dict, List, Optional, Sequence, Callable, TypedDict
1
+ from typing import Any, Dict, List, Optional, Sequence
2
2
  from uuid import UUID
3
3
  import time
4
4
  import uuid
5
- import traceback # For detailed error logging
6
5
  import contextvars # <--- Import contextvars
7
- from dataclasses import dataclass
8
6
 
9
- from judgeval.common.tracer import TraceClient, TraceEntry, Tracer, SpanType, EvaluationConfig
7
+ from judgeval.common.tracer import TraceClient, TraceSpan, Tracer, SpanType, EvaluationConfig
10
8
  from judgeval.data import Example # Import Example
11
- from judgeval.scorers import AnswerRelevancyScorer, JudgevalScorer, APIJudgmentScorer # Import Scorer and base scorer types
12
9
 
13
- from langchain_core.language_models import BaseChatModel
14
- from langchain_huggingface import ChatHuggingFace
15
- from langchain_openai import ChatOpenAI
16
- from langchain_anthropic import ChatAnthropic
17
- from langchain_core.utils.function_calling import convert_to_openai_tool
18
10
  from langchain_core.callbacks import BaseCallbackHandler
19
- from langchain_core.callbacks.base import AsyncCallbackHandler
20
11
  from langchain_core.agents import AgentAction, AgentFinish
21
12
  from langchain_core.outputs import LLMResult
22
- from langchain_core.messages.ai import AIMessage
23
- from langchain_core.messages.tool import ToolMessage
24
13
  from langchain_core.messages.base import BaseMessage
25
14
  from langchain_core.documents import Document
26
15
 
27
16
  # --- Get context vars from tracer module ---
28
17
  # Assuming tracer.py defines these and they are accessible
29
18
  # If not, redefine them here or adjust import
30
- from judgeval.common.tracer import current_span_var, current_trace_var # <-- Import current_trace_var
31
19
 
32
- # --- Constants for Logging ---
33
- HANDLER_LOG_PREFIX = "[JudgevalHandlerLog]"
20
+ # from judgeval.common.tracer import current_span_var
21
+ # TODO: Figure out how to handle context variables. Current solution is to keep track of current span id in Tracer class
34
22
 
35
23
  # --- NEW __init__ ---
36
24
  class JudgevalCallbackHandler(BaseCallbackHandler):
@@ -46,28 +34,19 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
46
34
 
47
35
  # --- NEW __init__ ---
48
36
  def __init__(self, tracer: Tracer):
49
- # --- Enhanced Logging ---
50
- # instance_id = id(self)
51
- # # print(f"{HANDLER_LOG_PREFIX} *** Handler instance {instance_id} __init__ called. ***")
52
- # --- End Enhanced Logging ---
37
+
53
38
  self.tracer = tracer
54
39
  self._trace_client: Optional[TraceClient] = None
55
40
  self._run_id_to_span_id: Dict[UUID, str] = {}
56
41
  self._span_id_to_start_time: Dict[str, float] = {}
57
42
  self._span_id_to_depth: Dict[str, int] = {}
58
- self._run_id_to_context_token: Dict[UUID, contextvars.Token] = {}
59
43
  self._root_run_id: Optional[UUID] = None
60
44
  self._trace_saved: bool = False # Flag to prevent actions after trace is saved
61
- self._run_id_to_start_inputs: Dict[UUID, Dict] = {} # <<< ADDED input storage
62
-
63
- # --- Token Count Accumulators ---
64
- # self._current_prompt_tokens = 0
65
- # self._current_completion_tokens = 0
66
- # --- End Token Count Accumulators ---
67
45
 
68
- self.executed_nodes: List[str] = []
46
+ self.executed_nodes: List[str] = [] # These last four members are only appended to and never accessed; can probably be removed but still might be useful for future reference?
69
47
  self.executed_tools: List[str] = []
70
48
  self.executed_node_tools: List[str] = []
49
+ self.traces: List[Dict[str, Any]] = []
71
50
  # --- END NEW __init__ ---
72
51
 
73
52
  # --- MODIFIED _ensure_trace_client ---
@@ -77,21 +56,12 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
77
56
  per handler instance lifecycle (effectively per graph invocation).
78
57
  Returns the client or None.
79
58
  """
80
- # handler_instance_id = id(self)
81
- # log_prefix = f"{HANDLER_LOG_PREFIX} [Handler {handler_instance_id}]"
82
-
83
- # If trace already saved, do nothing.
84
- # if self._trace_saved:
85
- # # print(f"{log_prefix} Trace already saved. Skipping client check for {event_name} ({run_id}).")
86
- # return None
87
59
 
88
60
  # If a client already exists, return it.
89
61
  if self._trace_client:
90
- # # print(f"{log_prefix} Reusing existing TraceClient (ID: {self._trace_client.trace_id}) for {event_name} ({run_id}).")
91
62
  return self._trace_client
92
63
 
93
64
  # If no client exists, initialize it NOW.
94
- # # print(f"{log_prefix} No TraceClient exists. Initializing for first event: {event_name} ({run_id})...")
95
65
  trace_id = str(uuid.uuid4())
96
66
  project = self.tracer.project_name
97
67
  try:
@@ -106,51 +76,27 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
106
76
  if self._trace_client:
107
77
  self._root_run_id = run_id # Assign the first run_id encountered as the tentative root
108
78
  self._trace_saved = False # Ensure flag is reset
109
- # # print(f"{log_prefix} Initialized NEW TraceClient: ID={self._trace_client.trace_id}, InitialName='{event_name}', Root Run ID={self._root_run_id}")
110
79
  # Set active client on Tracer (important for potential fallbacks)
111
80
  self.tracer._active_trace_client = self._trace_client
112
81
  return self._trace_client
113
82
  else:
114
- # # print(f"{log_prefix} FATAL: TraceClient creation failed unexpectedly for {event_name} ({run_id}).")
115
83
  return None
116
84
  except Exception as e:
117
- # # print(f"{log_prefix} FATAL: Exception initializing TraceClient for {event_name} ({run_id}): {e}")
118
- # # print(traceback.format_exc())
85
+
119
86
  self._trace_client = None
120
87
  self._root_run_id = None
121
88
  return None
122
- # --- END MODIFIED _ensure_trace_client ---
123
-
124
- def _log(self, message: str):
125
- """Helper for consistent logging format."""
126
- pass
127
89
 
128
90
  def _start_span_tracking(
129
91
  self,
130
- trace_client: TraceClient, # Expect a valid client
92
+ trace_client: TraceClient,
131
93
  run_id: UUID,
132
94
  parent_run_id: Optional[UUID],
133
95
  name: str,
134
96
  span_type: SpanType = "span",
135
97
  inputs: Optional[Dict[str, Any]] = None
136
- ):
137
- # self._log(f"_start_span_tracking called for: name='{name}', run_id={run_id}, parent_run_id={parent_run_id}, span_type={span_type}")
138
-
139
- # --- Add explicit check for trace_client ---
140
- if not trace_client:
141
- # --- Enhanced Logging ---
142
- # handler_instance_id = id(self)
143
- # log_prefix = f"{HANDLER_LOG_PREFIX} [Handler {handler_instance_id}]"
144
- # --- End Enhanced Logging ---
145
- # self._log(f"{log_prefix} FATAL ERROR in _start_span_tracking: trace_client argument is None for name='{name}', run_id={run_id}. Aborting span start.")
146
- return
147
- # --- End check ---
148
- # --- Enhanced Logging ---
149
- # handler_instance_id = id(self)
150
- # log_prefix = f"{HANDLER_LOG_PREFIX} [Handler {handler_instance_id}]"
151
- # trace_client_instance_id = id(trace_client) if trace_client else 'None'
152
- # # print(f"{log_prefix} _start_span_tracking: Using TraceClient ID: {trace_client_instance_id}")
153
- # --- End Enhanced Logging ---
98
+ ) -> None:
99
+ """Start tracking a span, ensuring trace client exists"""
154
100
 
155
101
  start_time = time.time()
156
102
  span_id = str(uuid.uuid4())
@@ -160,442 +106,102 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
160
106
  if parent_run_id and parent_run_id in self._run_id_to_span_id:
161
107
  parent_span_id = self._run_id_to_span_id[parent_run_id]
162
108
  if parent_span_id in self._span_id_to_depth:
163
- parent_depth = self._span_id_to_depth[parent_span_id]
164
- current_depth = parent_depth + 1
165
- # self._log(f" Found parent span_id={parent_span_id} with depth={parent_depth}. New depth={current_depth}.")
166
- else:
167
- # self._log(f" WARNING: Parent span depth not found for parent_span_id: {parent_span_id}. Setting depth to 0.")
168
- current_depth = 0
169
- elif parent_run_id:
170
- # self._log(f" WARNING: parent_run_id {parent_run_id} provided for '{name}' ({run_id}) but parent span not tracked. Treating as depth 0.")
171
- pass
172
- else:
173
- # self._log(f" No parent_run_id provided. Treating '{name}' as depth 0.")
174
- pass
109
+ current_depth = self._span_id_to_depth[parent_span_id] + 1
175
110
 
176
111
  self._run_id_to_span_id[run_id] = span_id
177
112
  self._span_id_to_start_time[span_id] = start_time
178
113
  self._span_id_to_depth[span_id] = current_depth
179
- # self._log(f" Tracking new span: span_id={span_id}, depth={current_depth}")
180
114
 
181
- try:
182
- trace_client.add_entry(TraceEntry(
183
- type="enter", span_id=span_id, trace_id=trace_client.trace_id,
184
- parent_span_id=parent_span_id, function=name, depth=current_depth,
185
- message=name, created_at=start_time, span_type=span_type
186
- ))
187
- # self._log(f" Added 'enter' entry for span_id={span_id}")
188
- except Exception as e:
189
- # self._log(f" ERROR adding 'enter' entry for span_id {span_id}: {e}")
190
- # # print(traceback.format_exc())
191
- pass
192
-
193
- if inputs:
194
- # Pass the already validated trace_client
195
- self._record_input_data(trace_client, run_id, inputs)
196
115
 
197
116
  # --- Set SPAN context variable ONLY for chain (node) spans (Sync version) ---
198
117
  if span_type == "chain":
199
- try:
200
- token = current_span_var.set(span_id)
201
- self._run_id_to_context_token[run_id] = token
202
- # self._log(f" Set current_span_var to {span_id} for run_id {run_id} (type: chain) in Sync Handler")
203
- except Exception as e:
204
- # self._log(f" ERROR setting current_span_var for run_id {run_id} in Sync Handler: {e}")
205
- pass
206
- # --- END ---
118
+ self.tracer.set_current_span(span_id)
207
119
 
208
- try:
209
- # TODO: Check if trace_client.add_entry needs await if TraceClient becomes async
210
- trace_client.add_entry(TraceEntry(
211
- type="enter", span_id=span_id, trace_id=trace_client.trace_id,
212
- parent_span_id=parent_span_id, function=name, depth=current_depth,
213
- message=name, created_at=start_time, span_type=span_type
214
- ))
215
- # self._log(f" Added 'enter' entry for span_id={span_id}")
216
- except Exception as e:
217
- # self._log(f" ERROR adding 'enter' entry for span_id {span_id}: {e}")
218
- # # print(traceback.format_exc())
219
- pass
120
+ new_trace = TraceSpan(
121
+ span_id=span_id,
122
+ trace_id=trace_client.trace_id,
123
+ parent_span_id=parent_span_id,
124
+ function=name,
125
+ depth=current_depth,
126
+ created_at=start_time,
127
+ span_type=span_type
128
+ )
220
129
 
221
- if inputs:
222
- # _record_input_data is also sync for now
223
- self._record_input_data(trace_client, run_id, inputs)
130
+ new_trace.inputs = inputs
131
+
132
+ trace_client.add_span(new_trace)
224
133
 
225
- # --- NEW _end_span_tracking ---
226
134
  def _end_span_tracking(
227
135
  self,
228
- trace_client: TraceClient, # Expect a valid client
136
+ trace_client: TraceClient,
229
137
  run_id: UUID,
230
- span_type: SpanType = "span",
231
138
  outputs: Optional[Any] = None,
232
139
  error: Optional[BaseException] = None
233
- ):
234
- # self._log(f"_end_span_tracking called for: run_id={run_id}, span_type={span_type}")
140
+ ) -> None:
141
+ """End tracking a span, ensuring trace client exists"""
235
142
 
236
- # --- Define instance_id early for logging/cleanup ---
237
- instance_id = id(self)
238
-
239
- if not trace_client:
240
- # Use instance_id defined above
241
- # log_prefix = f"{HANDLER_LOG_PREFIX} [Handler {instance_id}]"
242
- # self._log(f"{log_prefix} FATAL ERROR in _end_span_tracking: trace_client argument is None for run_id={run_id}. Aborting span end.")
243
- return
244
-
245
- # Use instance_id defined above
246
- # log_prefix = f"{HANDLER_LOG_PREFIX} [Handler {instance_id}]"
247
- # trace_client_instance_id = id(trace_client) if trace_client else 'None'
248
- # # print(f"{log_prefix} _end_span_tracking: Using TraceClient ID: {trace_client_instance_id}")
249
-
250
- if run_id not in self._run_id_to_span_id:
251
- # self._log(f" WARNING: Attempting to end span for untracked run_id: {run_id}")
252
- # Allow root run end to proceed for cleanup/save attempt even if span wasn't tracked
253
- if run_id != self._root_run_id:
254
- return
255
- else:
256
- # self._log(f" Allowing root run {run_id} end logic to proceed despite untracked span.")
257
- span_id = None # Indicate span wasn't found for duration/metadata lookup
258
- else:
259
- span_id = self._run_id_to_span_id[run_id]
143
+ # Get span ID and check if it exists
144
+ span_id = self._run_id_to_span_id.get(run_id)
260
145
 
261
146
  start_time = self._span_id_to_start_time.get(span_id) if span_id else None
262
- depth = self._span_id_to_depth.get(span_id, 0) if span_id else 0 # Use 0 depth if span_id is None
263
147
  duration = time.time() - start_time if start_time is not None else None
264
- # self._log(f" Ending span for run_id={run_id} (span_id={span_id}). Start time={start_time}, Duration={duration}, Depth={depth}")
265
-
266
- # Record output/error first
267
- if error:
268
- # self._log(f" Recording error for run_id={run_id} (span_id={span_id}): {error}")
269
- self._record_output_data(trace_client, run_id, error)
270
- elif outputs is not None:
271
- # output_repr = repr(outputs)
272
- # log_output = (output_repr[:100] + '...') if len(output_repr) > 103 else output_repr
273
- # self._log(f" Recording output for run_id={run_id} (span_id={span_id}): {log_output}")
274
- self._record_output_data(trace_client, run_id, outputs)
275
148
 
276
149
  # Add exit entry (only if span was tracked)
277
150
  if span_id:
278
- entry_function_name = "unknown"
279
- try:
280
- if hasattr(trace_client, 'entries') and trace_client.entries:
281
- entry_function_name = next((e.function for e in reversed(trace_client.entries) if e.span_id == span_id and e.type == "enter"), "unknown")
282
- else:
283
- # self._log(f" WARNING: Cannot determine function name for exit span_id {span_id}, trace_client.entries missing or empty.")
284
- pass
285
- except Exception as e:
286
- # self._log(f" ERROR finding function name for exit entry span_id {span_id}: {e}")
287
- # # print(traceback.format_exc())
288
- pass
289
-
290
- try:
291
- trace_client.add_entry(TraceEntry(
292
- type="exit", span_id=span_id, trace_id=trace_client.trace_id,
293
- depth=depth, created_at=time.time(), duration=duration,
294
- span_type=span_type, function=entry_function_name
295
- ))
296
- # self._log(f" Added 'exit' entry for span_id={span_id}, function='{entry_function_name}'")
297
- except Exception as e:
298
- # self._log(f" ERROR adding 'exit' entry for span_id {span_id}: {e}")
299
- # # print(traceback.format_exc())
300
- pass
151
+ trace_span = trace_client.span_id_to_span.get(span_id)
152
+ if trace_span:
153
+ trace_span.duration = duration
154
+ trace_span.output = error if error else outputs
301
155
 
302
156
  # Clean up dictionaries for this specific span
303
157
  if span_id in self._span_id_to_start_time: del self._span_id_to_start_time[span_id]
304
158
  if span_id in self._span_id_to_depth: del self._span_id_to_depth[span_id]
305
159
 
306
- # Pop context token (Sync version) but don't reset
307
- token = self._run_id_to_context_token.pop(run_id, None)
308
- if token:
309
- # self._log(f" Popped token for run_id {run_id} (was {span_id}), not resetting context var.")
310
- pass
311
- else:
312
- # self._log(f" Skipping exit entry and cleanup for run_id {run_id} as span_id was not found.")
313
- pass
314
-
315
160
  # Check if this is the root run ending
316
161
  if run_id == self._root_run_id:
317
- trace_saved_successfully = False # Track save success
318
162
  try:
319
- # --- Aggregate and Set Token Counts BEFORE Saving ---
320
- # if self._trace_client and not self._trace_saved:
321
- # total_tokens = self._current_prompt_tokens + self._current_completion_tokens
322
- # aggregated_token_counts = {
323
- # 'prompt_tokens': self._current_prompt_tokens,
324
- # 'completion_tokens': self._current_completion_tokens,
325
- # 'total_tokens': total_tokens
326
- # }
327
- # # Assuming TraceClient has an attribute to hold the trace data being built
328
- # try:
329
- # # Attempt to set the attribute directly
330
- # # Check if the attribute exists and is meant for this purpose
331
- # if hasattr(self._trace_client, 'token_counts'):
332
- # self._trace_client.token_counts = aggregated_token_counts
333
- # self._log(f"Set aggregated token_counts on TraceClient for trace {self._trace_client.trace_id}: {aggregated_token_counts}")
334
- # else:
335
- # # If the attribute doesn't exist, maybe update the trace data dict directly if possible?
336
- # # This part is speculative without knowing TraceClient internals.
337
- # # E.g., if trace_client has a `_trace_data` dict:
338
- # # if hasattr(self._trace_client, '_trace_data') and isinstance(self._trace_client._trace_data, dict):
339
- # # self._trace_client._trace_data['token_counts'] = aggregated_token_counts
340
- # # self._log(f"Updated _trace_data['token_counts'] on TraceClient for trace {self._trace_client.trace_id}: {aggregated_token_counts}")
341
- # # else:
342
- # self._log(f"WARNING: Could not set 'token_counts' on TraceClient for trace {self._trace_client.trace_id}. Aggregated counts might be lost.")
343
- # except Exception as set_tc_e:
344
- # self._log(f"ERROR setting token_counts on TraceClient for trace {self._trace_client.trace_id}: {set_tc_e}")
345
- # --- End Token Count Aggregation ---
346
-
347
163
  # Reset root run id after attempt
348
164
  self._root_run_id = None
349
165
  # Reset input storage for this handler instance
350
- self._run_id_to_start_inputs = {}
351
- self._log(f"Reset root run ID and input storage for handler {instance_id}.")
352
166
 
353
- self._log(f"Root run {run_id} finished. Attempting to save trace...")
354
167
  if self._trace_client and not self._trace_saved: # Check if not already saved
355
- try:
356
- # TODO: Check if trace_client.save needs await if TraceClient becomes async
357
- trace_id, _ = self._trace_client.save(overwrite=self._trace_client.overwrite) # Use client's overwrite setting
358
- self._log(f"Trace {trace_id} successfully saved.")
359
- self._trace_saved = True # Set flag only after successful save
360
- trace_saved_successfully = True # Mark success
361
- except Exception as e:
362
- self._log(f"ERROR saving trace {self._trace_client.trace_id}: {e}")
363
- # print(traceback.format_exc())
364
- # REMOVED FINALLY BLOCK THAT RESET STATE HERE
365
- elif self._trace_client and self._trace_saved:
366
- self._log(f" Trace {self._trace_client.trace_id} already saved. Skipping save.")
367
- else:
368
- self._log(f" WARNING: Root run {run_id} ended, but trace client was None. Cannot save trace.")
168
+ # TODO: Check if trace_client.save needs await if TraceClient becomes async
169
+ trace_id, trace_data = self._trace_client.save(overwrite=self._trace_client.overwrite) # Use client's overwrite setting
170
+ self.traces.append(trace_data) # Leaving this in for now but can probably be removed
171
+ self._trace_saved = True # Set flag only after successful save
369
172
  finally:
370
173
  # --- NEW: Consolidated Cleanup Logic ---
371
174
  # This block executes regardless of save success/failure
372
- self._log(f" Performing cleanup for root run {run_id} in handler {instance_id}.")
373
175
  # Reset root run id
374
176
  self._root_run_id = None
375
177
  # Reset input storage for this handler instance
376
- self._run_id_to_start_inputs = {}
377
- # --- Reset Token Counters ---
378
- # self._current_prompt_tokens = 0
379
- # self._current_completion_tokens = 0
380
- # self._log(" Reset token counters.")
381
- # --- End Reset Token Counters ---
382
- # Reset tracer's active client ONLY IF it was this handler's client
383
178
  if self.tracer._active_trace_client == self._trace_client:
384
179
  self.tracer._active_trace_client = None
385
- self._log(" Reset active_trace_client on Tracer.")
386
- # Completely remove trace_context_token cleanup as it's not used in sync handler
387
- # Optionally: Reset the entire trace client instance for this handler?
388
- # self._trace_client = None # Uncomment if handler should reset client completely after root run
389
- self._log(f" Cleanup complete for root run {run_id}.")
390
180
  # --- End Cleanup Logic ---
391
181
 
392
- def _record_input_data(self,
393
- trace_client: TraceClient,
394
- run_id: UUID,
395
- inputs: Dict[str, Any]):
396
- # self._log(f"_record_input_data called for run_id={run_id}")
397
- if run_id not in self._run_id_to_span_id:
398
- # self._log(f" WARNING: Attempting to record input for untracked run_id: {run_id}")
399
- return
400
- if not trace_client:
401
- # self._log(f" ERROR: TraceClient is None when trying to record input for run_id={run_id}")
402
- return
403
-
404
- span_id = self._run_id_to_span_id[run_id]
405
- depth = self._span_id_to_depth.get(span_id, 0)
406
- # self._log(f" Found span_id={span_id}, depth={depth} for run_id={run_id}")
407
-
408
- function_name = "unknown"
409
- span_type: SpanType = "span"
410
- try:
411
- # Find the corresponding 'enter' entry to get the function name and span type
412
- enter_entry = next((e for e in reversed(trace_client.entries) if e.span_id == span_id and e.type == "enter"), None)
413
- if enter_entry:
414
- function_name = enter_entry.function
415
- span_type = enter_entry.span_type
416
- # self._log(f" Found function='{function_name}', span_type='{span_type}' for input span_id={span_id}")
417
- else:
418
- # self._log(f" WARNING: Could not find 'enter' entry for input span_id={span_id}")
419
- pass
420
- except Exception as e:
421
- # self._log(f" ERROR finding enter entry for input span_id {span_id}: {e}")
422
- # # print(traceback.format_exc())
423
- pass
424
-
425
- try:
426
- input_entry = TraceEntry(
427
- type="input",
428
- span_id=span_id,
429
- trace_id=trace_client.trace_id,
430
- parent_span_id=next((e.parent_span_id for e in reversed(trace_client.entries) if e.span_id == span_id and e.type == "enter"), None), # Get parent from enter entry
431
- function=function_name,
432
- depth=depth,
433
- message=f"Input to {function_name}",
434
- created_at=time.time(),
435
- inputs=inputs,
436
- span_type=span_type
437
- )
438
- trace_client.add_entry(input_entry)
439
- # self._log(f" Added 'input' entry directly for span_id={span_id}")
440
- except Exception as e:
441
- # self._log(f" ERROR adding 'input' entry directly for span_id {span_id}: {e}")
442
- # # print(traceback.format_exc())
443
- pass
444
-
445
- def _record_output_data(self,
446
- trace_client: TraceClient,
447
- run_id: UUID,
448
- output: Any):
449
- # self._log(f"_record_output_data called for run_id={run_id}")
450
- if run_id not in self._run_id_to_span_id:
451
- # self._log(f" WARNING: Attempting to record output for untracked run_id: {run_id}")
452
- return
453
- if not trace_client:
454
- # self._log(f" ERROR: TraceClient is None when trying to record output for run_id={run_id}")
455
- return
456
-
457
- span_id = self._run_id_to_span_id[run_id]
458
- depth = self._span_id_to_depth.get(span_id, 0)
459
- # self._log(f" Found span_id={span_id}, depth={depth} for run_id={run_id}")
460
-
461
- function_name = "unknown"
462
- span_type: SpanType = "span"
463
- try:
464
- # Find the corresponding 'enter' entry to get the function name and span type
465
- enter_entry = next((e for e in reversed(trace_client.entries) if e.span_id == span_id and e.type == "enter"), None)
466
- if enter_entry:
467
- function_name = enter_entry.function
468
- span_type = enter_entry.span_type
469
- # self._log(f" Found function='{function_name}', span_type='{span_type}' for output span_id={span_id}")
470
- else:
471
- # self._log(f" WARNING: Could not find 'enter' entry for output span_id={span_id}")
472
- pass
473
- except Exception as e:
474
- # self._log(f" ERROR finding enter entry for output span_id {span_id}: {e}")
475
- # # print(traceback.format_exc())
476
- pass
477
-
478
- try:
479
- output_entry = TraceEntry(
480
- type="output",
481
- span_id=span_id,
482
- trace_id=trace_client.trace_id,
483
- parent_span_id=next((e.parent_span_id for e in reversed(trace_client.entries) if e.span_id == span_id and e.type == "enter"), None), # Get parent from enter entry
484
- function=function_name,
485
- depth=depth,
486
- message=f"Output from {function_name}",
487
- created_at=time.time(),
488
- output=output, # Langchain outputs are typically serializable directly
489
- span_type=span_type
490
- )
491
- trace_client.add_entry(output_entry)
492
- self._log(f" Added 'output' entry directly for span_id={span_id}")
493
- except Exception as e:
494
- self._log(f" ERROR adding 'output' entry directly for span_id {span_id}: {e}")
495
- # print(traceback.format_exc())
496
-
497
- def _record_error(self,
498
- trace_client: TraceClient,
499
- run_id: UUID,
500
- error: Any):
501
- # self._log(f"_record_error called for run_id={run_id}")
502
- if run_id not in self._run_id_to_span_id:
503
- # self._log(f" WARNING: Attempting to record error for untracked run_id: {run_id}")
504
- return
505
- if not trace_client:
506
- # self._log(f" ERROR: TraceClient is None when trying to record error for run_id={run_id}")
507
- return
508
-
509
- span_id = self._run_id_to_span_id[run_id]
510
- depth = self._span_id_to_depth.get(span_id, 0)
511
- # self._log(f" Found span_id={span_id}, depth={depth} for run_id={run_id}")
512
-
513
- function_name = "unknown"
514
- span_type: SpanType = "span"
515
- try:
516
- # Find the corresponding 'enter' entry to get the function name and span type
517
- enter_entry = next((e for e in reversed(trace_client.entries) if e.span_id == span_id and e.type == "enter"), None)
518
- if enter_entry:
519
- function_name = enter_entry.function
520
- span_type = enter_entry.span_type
521
- # self._log(f" Found function='{function_name}', span_type='{span_type}' for error span_id={span_id}")
522
- else:
523
- # self._log(f" WARNING: Could not find 'enter' entry for error span_id={span_id}")
524
- pass
525
- except Exception as e:
526
- # self._log(f" ERROR finding enter entry for error span_id {span_id}: {e}")
527
- # # print(traceback.format_exc())
528
- pass
529
-
530
- try:
531
- error_entry = TraceEntry(
532
- type="error",
533
- span_id=span_id,
534
- trace_id=trace_client.trace_id,
535
- parent_span_id=next((e.parent_span_id for e in reversed(trace_client.entries) if e.span_id == span_id and e.type == "enter"), None), # Get parent from enter entry
536
- function=function_name,
537
- depth=depth,
538
- message=f"Error in {function_name}",
539
- created_at=time.time(),
540
- error=str(error), # Convert error to string for serialization
541
- span_type=span_type
542
- )
543
- trace_client.add_entry(error_entry)
544
- # self._log(f" Added 'error' entry directly for span_id={span_id}")
545
- except Exception as e:
546
- # self._log(f" ERROR adding 'error' entry directly for span_id {span_id}: {e}")
547
- # # print(traceback.format_exc())
548
- pass
549
-
550
182
  # --- Callback Methods ---
551
183
  # Each method now ensures the trace client exists before proceeding
552
184
 
553
185
  def on_retriever_start(self, serialized: Dict[str, Any], query: str, *, run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, **kwargs: Any) -> Any:
554
- handler_instance_id = id(self)
555
- log_prefix = f"{HANDLER_LOG_PREFIX} [Handler {handler_instance_id}]"
556
186
  serialized_name = serialized.get('name', 'Unknown') if serialized else "Unknown (Serialized=None)"
557
- # print(f"{log_prefix} ENTERING on_retriever_start: name='{serialized_name}', run_id={run_id}. Parent: {parent_run_id}")
558
187
 
559
- try:
560
- name = f"RETRIEVER_{(serialized_name).upper()}"
561
- # Pass parent_run_id to _ensure_trace_client
562
- trace_client = self._ensure_trace_client(run_id, parent_run_id, name) # Corrected call
563
- if not trace_client:
564
- # print(f"{log_prefix} No trace client obtained in on_retriever_start for {run_id}.")
565
- return
188
+ name = f"RETRIEVER_{(serialized_name).upper()}"
189
+ # Pass parent_run_id
190
+ trace_client = self._ensure_trace_client(run_id, parent_run_id, name) # Corrected call
191
+ if not trace_client: return
566
192
 
567
- inputs = {'query': query, 'tags': tags, 'metadata': metadata, 'kwargs': kwargs, 'serialized': serialized}
568
- self._start_span_tracking(trace_client, run_id, parent_run_id, name, span_type="retriever", inputs=inputs)
569
- except Exception as e:
570
- tc_id_on_error = id(self._trace_client) if self._trace_client else 'None'
571
- self._log(f"{log_prefix} UNCAUGHT EXCEPTION in on_retriever_start for run_id={run_id} (TraceClient ID: {tc_id_on_error}): {e}")
572
- # print(traceback.format_exc())
193
+ inputs = {'query': query, 'tags': tags, 'metadata': metadata, 'kwargs': kwargs, 'serialized': serialized}
194
+ self._start_span_tracking(trace_client, run_id, parent_run_id, name, span_type="retriever", inputs=inputs)
573
195
 
574
196
  def on_retriever_end(self, documents: Sequence[Document], *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any) -> Any:
575
- handler_instance_id = id(self)
576
- log_prefix = f"{HANDLER_LOG_PREFIX} [Handler {handler_instance_id}]"
577
- # print(f"{log_prefix} ENTERING on_retriever_end: run_id={run_id}. Parent: {parent_run_id}")
578
-
579
- try:
580
- # Pass parent_run_id to _ensure_trace_client (though less critical on end events)
581
- trace_client = self._ensure_trace_client(run_id, parent_run_id, "RetrieverEnd") # Corrected call
582
- if not trace_client:
583
- # print(f"{log_prefix} No trace client obtained in on_retriever_end for {run_id}.")
584
- return
585
-
586
- doc_summary = [{"index": i, "page_content": doc.page_content[:100] + "..." if len(doc.page_content) > 100 else doc.page_content, "metadata": doc.metadata} for i, doc in enumerate(documents)]
587
- outputs = {"document_count": len(documents), "documents": doc_summary, "kwargs": kwargs}
588
- self._end_span_tracking(trace_client, run_id, span_type="retriever", outputs=outputs)
589
- except Exception as e:
590
- tc_id_on_error = id(self._trace_client) if self._trace_client else 'None'
591
- self._log(f"{log_prefix} UNCAUGHT EXCEPTION in on_retriever_end for run_id={run_id} (TraceClient ID: {tc_id_on_error}): {e}")
592
- # print(traceback.format_exc())
197
+ trace_client = self._ensure_trace_client(run_id, parent_run_id, "RetrieverEnd") # Corrected call
198
+ if not trace_client: return
199
+ doc_summary = [{"index": i, "page_content": doc.page_content[:100] + "..." if len(doc.page_content) > 100 else doc.page_content, "metadata": doc.metadata} for i, doc in enumerate(documents)]
200
+ outputs = {"document_count": len(documents), "documents": doc_summary, "kwargs": kwargs}
201
+ self._end_span_tracking(trace_client, run_id, outputs=outputs)
593
202
 
594
203
  def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any], *, run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, **kwargs: Any) -> None:
595
- handler_instance_id = id(self)
596
- log_prefix = f"{HANDLER_LOG_PREFIX} [Handler {handler_instance_id}]"
597
204
  serialized_name = serialized.get('name') if serialized else "Unknown (Serialized=None)"
598
- # print(f"{log_prefix} ENTERING on_chain_start: name='{serialized_name}', run_id={run_id}. Parent: {parent_run_id}")
599
205
 
600
206
  # --- Determine Name and Span Type ---
601
207
  span_type: SpanType = "chain"
@@ -607,467 +213,170 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
607
213
 
608
214
  if node_name:
609
215
  name = node_name # Use node name if available
610
- self._log(f" LangGraph Node Start: '{name}', run_id={run_id}")
611
- if name not in self.executed_nodes: self.executed_nodes.append(name)
216
+ if name not in self.executed_nodes: self.executed_nodes.append(name) # Leaving this in for now but can probably be removed
612
217
  elif is_langgraph_root_kwarg and is_potential_root_event:
613
218
  name = "LangGraph" # Explicit root detected
614
- self._log(f" LangGraph Root Start Detected (kwargs): run_id={run_id}")
615
219
  # Add handling for other potential LangChain internal chains if needed, e.g., "RunnableSequence"
616
220
 
617
221
  # --- Ensure Trace Client ---
618
- try:
619
222
  # Pass parent_run_id to _ensure_trace_client
620
- trace_client = self._ensure_trace_client(run_id, parent_run_id, name) # Corrected call
621
- if not trace_client:
622
- # print(f"{log_prefix} No trace client obtained in on_chain_start for {run_id} ('{name}').")
623
- return
624
-
625
- # --- Update Trace Name if Root ---
626
- # If this is the root event (parent_run_id is None) and the trace client was just created,
627
- # ensure the trace name reflects the graph's name ('LangGraph' usually).
628
- if is_potential_root_event and run_id == self._root_run_id and trace_client.name != name:
629
- self._log(f" Updating trace name from '{trace_client.name}' to '{name}' for root run {run_id}")
630
- trace_client.name = name # Update trace name to the determined root name
631
-
632
- # --- Start Span Tracking ---
633
- combined_inputs = {'inputs': inputs, 'tags': tags, 'metadata': metadata, 'kwargs': kwargs, 'serialized': serialized}
634
- self._start_span_tracking(trace_client, run_id, parent_run_id, name, span_type=span_type, inputs=combined_inputs)
635
- # --- Store inputs for potential evaluation later ---
636
- self._run_id_to_start_inputs[run_id] = inputs # Store the raw inputs dict
637
- self._log(f" Stored inputs for run_id {run_id}")
638
- # --- End Store inputs ---
223
+ trace_client = self._ensure_trace_client(run_id, parent_run_id, name) # Corrected call
224
+ if not trace_client: return
639
225
 
640
- except Exception as e:
641
- tc_id_on_error = id(self._trace_client) if self._trace_client else 'None'
642
- self._log(f"{log_prefix} UNCAUGHT EXCEPTION in on_chain_start for run_id={run_id} (TraceClient ID: {tc_id_on_error}): {e}")
643
- # print(traceback.format_exc())
226
+ # --- Update Trace Name if Root ---
227
+ # If this is the root event (parent_run_id is None) and the trace client was just created,
228
+ # ensure the trace name reflects the graph's name ('LangGraph' usually).
229
+ if is_potential_root_event and run_id == self._root_run_id and trace_client.name != name:
230
+ trace_client.name = name # Update trace name to the determined root name
231
+
232
+ # --- Start Span Tracking ---
233
+ combined_inputs = {'inputs': inputs, 'tags': tags, 'metadata': metadata, 'kwargs': kwargs, 'serialized': serialized}
234
+ self._start_span_tracking(trace_client, run_id, parent_run_id, name, span_type=span_type, inputs=combined_inputs)
644
235
 
645
236
 
646
237
  def on_chain_end(self, outputs: Dict[str, Any], *, run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, **kwargs: Any) -> Any:
647
- handler_instance_id = id(self)
648
- log_prefix = f"{HANDLER_LOG_PREFIX} [Handler {handler_instance_id}]"
649
- # print(f"{log_prefix} ENTERING on_chain_end: run_id={run_id}. Parent: {parent_run_id}")
650
238
 
651
- # --- Define instance_id for logging ---
652
- instance_id = handler_instance_id # Use the already obtained id
239
+ # Pass parent_run_id
240
+ trace_client = self._ensure_trace_client(run_id, parent_run_id, "ChainEnd") # Corrected call
241
+ if not trace_client: return
653
242
 
654
- try:
655
- # Pass parent_run_id
656
- trace_client = self._ensure_trace_client(run_id, parent_run_id, "ChainEnd") # Corrected call
657
- if not trace_client:
658
- # print(f"{log_prefix} No trace client obtained in on_chain_end for {run_id}.")
659
- return
660
-
661
- span_id = self._run_id_to_span_id.get(run_id)
662
- span_type: SpanType = "chain" # Default
663
- if span_id:
664
- try:
665
- if hasattr(trace_client, 'entries') and trace_client.entries:
666
- enter_entry = next((e for e in reversed(trace_client.entries) if e.span_id == span_id and e.type == "enter"), None)
667
- if enter_entry: span_type = enter_entry.span_type
668
- else: self._log(f" WARNING: trace_client.entries empty/missing for on_chain_end span_id={span_id}")
669
- except Exception as e:
670
- self._log(f" ERROR finding enter entry for span_id {span_id} in on_chain_end: {e}")
671
- else:
672
- self._log(f" WARNING: No span_id found for run_id {run_id} in on_chain_end.")
673
- # If it's the root run ending, _end_span_tracking will handle cleanup/save
674
- if run_id == self._root_run_id:
675
- self._log(f" Run ID {run_id} matches root. Proceeding to _end_span_tracking for potential save.")
676
- else:
677
- return # Don't call end tracking if it's not the root and span wasn't tracked
678
-
679
- # --- Store input in on_chain_start if needed for evaluation ---
680
- # Retrieve stored inputs
681
- start_inputs = self._run_id_to_start_inputs.get(run_id, {})
682
- # TODO: Determine how to reliably extract the original user prompt from start_inputs
683
- # For the demo, the 'generate_recommendations' node receives the full state, not the initial prompt.
684
- # Using a placeholder for now.
685
- user_prompt_for_eval = "Unknown Input" # Placeholder - Needs refinement based on graph structure
686
- # user_prompt_for_eval = start_inputs.get("messages", [{}])[-1].get("content", "Unknown Input") # Example: If input has messages list
687
-
688
- # --- Trigger evaluation ---
689
- if "recommendations" in outputs and span_id: # Ensure span_id exists
690
- self._log(f"[Async Handler {instance_id}] Chain end for run_id {run_id} (span_id={span_id}) identified as recommendation node. Attempting evaluation.")
691
- recommendations_output = outputs.get("recommendations")
692
-
693
- if recommendations_output:
694
- eval_example = Example(
695
- input=user_prompt_for_eval, # Requires modification to store/retrieve this
696
- actual_output=recommendations_output
697
- )
698
- # TODO: Get model name dynamically if possible
699
- model_name = "gpt-4" # Placeholder
700
-
701
- self._log(f"[Async Handler {instance_id}] Submitting evaluation for span_id={span_id}")
702
- try:
703
- # Call evaluate on the trace client, passing the specific span_id
704
- # The TraceClient.async_evaluate now accepts and prioritizes this span_id.
705
- trace_client.async_evaluate(
706
- scorers=[AnswerRelevancyScorer(threshold=0.5)], # Ensure this scorer is imported
707
- example=eval_example,
708
- model=model_name,
709
- span_id=span_id # Pass the specific span_id for this node run
710
- )
711
- self._log(f"[Async Handler {instance_id}] Evaluation submitted successfully for span_id={span_id}.")
712
- except Exception as eval_e:
713
- self._log(f"[Async Handler {instance_id}] ERROR submitting evaluation for span_id={span_id}: {eval_e}")
714
- # print(traceback.format_exc()) # Print traceback for evaluation errors
715
- else:
716
- self._log(f"[Async Handler {instance_id}] Skipping evaluation for run_id {run_id} (span_id={span_id}): Missing recommendations output.")
717
- elif "recommendations" in outputs:
718
- self._log(f"[Async Handler {instance_id}] Skipping evaluation for run_id {run_id}: Span ID not found.")
719
-
720
- # --- Existing span ending logic ---
721
- # Determine span_type for end_span_tracking (copied from sync handler)
722
- end_span_type: SpanType = "chain" # Default
723
- if span_id: # Check if span_id was actually found
724
- try:
725
- if hasattr(trace_client, 'entries') and trace_client.entries:
726
- enter_entry = next((e for e in reversed(trace_client.entries) if e.span_id == span_id and e.type == "enter"), None)
727
- if enter_entry: end_span_type = enter_entry.span_type
728
- else: self._log(f" WARNING: trace_client.entries empty/missing for on_chain_end span_id={span_id}")
729
- except Exception as e:
730
- self._log(f" ERROR finding enter entry for span_id {span_id} in on_chain_end: {e}")
731
- else:
732
- self._log(f" WARNING: No span_id found for run_id {run_id} in on_chain_end, using default span_type='chain'.")
733
-
734
- # Prepare outputs for end tracking (moved down)
735
- combined_outputs = {"outputs": outputs, "tags": tags, "kwargs": kwargs}
736
-
737
- # Call end_span_tracking with potentially determined span_type
738
- self._end_span_tracking(trace_client, run_id, span_type=end_span_type, outputs=combined_outputs)
739
-
740
- # --- Root node cleanup (Existing logic - slightly modified save call) ---
741
- if run_id == self._root_run_id:
742
- self._log(f"Root run {run_id} finished. Attempting to save trace...")
743
- if trace_client and not self._trace_saved:
744
- try:
745
- # Save might need to be async if TraceClient methods become async
746
- # Pass overwrite=True based on client's setting
747
- trace_id_saved, _ = trace_client.save(overwrite=trace_client.overwrite)
748
- self._trace_saved = True
749
- self._log(f"Trace {trace_id_saved} successfully saved.")
750
- # Reset tracer's active client *after* successful save
751
- if self.tracer._active_trace_client == trace_client:
752
- self.tracer._active_trace_client = None
753
- self._log("Reset active_trace_client on Tracer.")
754
- except Exception as e:
755
- self._log(f"ERROR saving trace {trace_client.trace_id}: {e}")
756
- # print(traceback.format_exc())
757
- elif trace_client and self._trace_saved:
758
- self._log(f"Trace {trace_client.trace_id} already saved. Skipping save for root run {run_id}.")
759
- elif not trace_client:
760
- self._log(f"Skipping trace save for root run {run_id}: No client available.")
761
-
762
- # Reset root run id after attempt
763
- self._root_run_id = None
764
- # Reset input storage for this handler instance
765
- self._run_id_to_start_inputs = {}
766
- self._log(f"Reset root run ID and input storage for handler {instance_id}.")
767
-
768
- # --- SYNC: Attempt Evaluation by checking output metadata ---
769
- eval_config: Optional[EvaluationConfig] = None
770
- node_name = "unknown_node" # Default node name
771
- # Ensure trace_client exists before proceeding with eval logic that uses it
772
- if trace_client:
773
- if span_id: # Try to find the node name from the 'enter' entry
774
- try:
775
- if hasattr(trace_client, 'entries') and trace_client.entries:
776
- enter_entry = next((e for e in reversed(trace_client.entries) if e.span_id == span_id and e.type == "enter"), None)
777
- if enter_entry: node_name = enter_entry.function
778
- except Exception as e:
779
- self._log(f" ERROR finding node name for span_id {span_id} in on_chain_end: {e}")
780
-
781
- if span_id and "_judgeval_eval" in outputs: # Only attempt if span exists and key is present
782
- raw_eval_config = outputs.get("_judgeval_eval")
783
- if isinstance(raw_eval_config, EvaluationConfig):
784
- eval_config = raw_eval_config
785
- self._log(f"{log_prefix} Found valid EvaluationConfig in outputs for node='{node_name}'.")
786
- elif isinstance(raw_eval_config, dict):
787
- # Attempt to reconstruct from dict
788
- try:
789
- if "scorers" in raw_eval_config and "example" in raw_eval_config:
790
- example_data = raw_eval_config["example"]
791
- reconstructed_example = Example(**example_data) if isinstance(example_data, dict) else example_data
792
-
793
- if isinstance(reconstructed_example, Example):
794
- eval_config = EvaluationConfig(
795
- scorers=raw_eval_config["scorers"],
796
- example=reconstructed_example,
797
- model=raw_eval_config.get("model"),
798
- log_results=raw_eval_config.get("log_results", True)
799
- )
800
- self._log(f"{log_prefix} Reconstructed EvaluationConfig from dict in outputs for node='{node_name}'.")
801
- else:
802
- self._log(f"{log_prefix} Could not reconstruct Example from dict in _judgeval_eval for node='{node_name}'. Skipping evaluation.")
803
- else:
804
- self._log(f"{log_prefix} Dict in _judgeval_eval missing required keys ('scorers', 'example') for node='{node_name}'. Skipping evaluation.")
805
- except Exception as recon_e:
806
- self._log(f"{log_prefix} ERROR attempting to reconstruct EvaluationConfig from dict for node='{node_name}': {recon_e}")
807
- # print(traceback.format_exc()) # Print traceback for reconstruction errors
808
- else:
809
- self._log(f"{log_prefix} Found '_judgeval_eval' key in outputs for node='{node_name}', but it wasn't an EvaluationConfig object or reconstructable dict. Skipping evaluation.")
810
-
811
- # Check eval_config *and* span_id again (this should be indented correctly)
812
- if eval_config and span_id:
813
- self._log(f"{log_prefix} Submitting evaluation for span_id={span_id}")
814
- try:
815
- # Ensure example has trace_id set if not already present
816
- if not hasattr(eval_config.example, 'trace_id') or not eval_config.example.trace_id:
817
- # Use the correct variable name 'trace_client' here
818
- eval_config.example.trace_id = trace_client.trace_id
819
- self._log(f"{log_prefix} Set trace_id={trace_client.trace_id} on evaluation example.")
820
-
821
- # Call async_evaluate on the TraceClient instance ('trace_client')
822
- # Use the correct variable name 'trace_client' here
823
- trace_client.async_evaluate( # <-- Fix: Use trace_client
824
- scorers=eval_config.scorers,
825
- example=eval_config.example,
826
- model=eval_config.model,
827
- log_results=eval_config.log_results,
828
- span_id=span_id # Pass the specific span_id for this node run
829
- )
830
- self._log(f"{log_prefix} Evaluation submitted successfully for span_id={span_id}.")
831
- except Exception as eval_e:
832
- self._log(f"{log_prefix} ERROR submitting evaluation for span_id={span_id}: {eval_e}")
833
- # print(traceback.format_exc()) # Print traceback for evaluation errors
834
- elif "_judgeval_eval" in outputs and not span_id:
835
- self._log(f"{log_prefix} WARNING: Found _judgeval_eval in outputs, but span_id for run_id {run_id} was not found. Cannot submit evaluation.")
836
- # --- End SYNC Evaluation Logic ---
243
+ span_id = self._run_id_to_span_id.get(run_id)
244
+ # If it's the root run ending, _end_span_tracking will handle cleanup/save
245
+ if not span_id and run_id != self._root_run_id:
246
+ return # Don't call end tracking if it's not the root and span wasn't tracked
837
247
 
838
- except Exception as e:
839
- tc_id_on_error = id(self._trace_client) if self._trace_client else 'None'
840
- self._log(f"{log_prefix} UNCAUGHT EXCEPTION in on_chain_end for run_id={run_id} (TraceClient ID: {tc_id_on_error}): {e}")
841
- # print(traceback.format_exc())
248
+ # Prepare outputs for end tracking (moved down)
249
+ combined_outputs = {"outputs": outputs, "tags": tags, "kwargs": kwargs}
250
+
251
+ # Call end_span_tracking with potentially determined span_type
252
+ self._end_span_tracking(trace_client, run_id, outputs=combined_outputs)
253
+
254
+ # --- Root node cleanup (Existing logic - slightly modified save call) ---
255
+ if run_id == self._root_run_id:
256
+ if trace_client and not self._trace_saved:
257
+ # Save might need to be async if TraceClient methods become async
258
+ # Pass overwrite=True based on client's setting
259
+ trace_id_saved, trace_data = trace_client.save(overwrite=trace_client.overwrite)
260
+ self.traces.append(trace_data) # Leaving this in for now but can probably be removed
261
+ self._trace_saved = True
262
+ # Reset tracer's active client *after* successful save
263
+ if self.tracer._active_trace_client == trace_client:
264
+ self.tracer._active_trace_client = None
265
+
266
+ # Reset root run id after attempt
267
+ self._root_run_id = None
268
+ # Reset input storage for this handler instance
842
269
 
843
270
  def on_chain_error(self, error: BaseException, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any) -> Any:
844
- handler_instance_id = id(self)
845
- log_prefix = f"{HANDLER_LOG_PREFIX} [Handler {handler_instance_id}]"
846
- # print(f"{log_prefix} ENTERING on_chain_error: run_id={run_id}. Parent: {parent_run_id}")
271
+ # Pass parent_run_id
272
+ trace_client = self._ensure_trace_client(run_id, parent_run_id, "ChainError") # Corrected call
273
+ if not trace_client:
274
+ return
847
275
 
848
- try:
849
- # Pass parent_run_id
850
- trace_client = self._ensure_trace_client(run_id, parent_run_id, "ChainError") # Corrected call
851
- if not trace_client:
852
- # print(f"{log_prefix} No trace client obtained in on_chain_error for {run_id}.")
853
- return
854
-
855
- span_id = self._run_id_to_span_id.get(run_id)
856
- span_type: SpanType = "chain" # Default
857
- if span_id:
858
- try:
859
- if hasattr(trace_client, 'entries') and trace_client.entries:
860
- enter_entry = next((e for e in reversed(trace_client.entries) if e.span_id == span_id and e.type == "enter"), None)
861
- if enter_entry: span_type = enter_entry.span_type
862
- else: self._log(f" WARNING: trace_client.entries empty/missing for on_chain_error span_id={span_id}")
863
- except Exception as e:
864
- self._log(f" ERROR finding enter entry for span_id {span_id} in on_chain_error: {e}")
865
- else:
866
- self._log(f" WARNING: No span_id found for run_id {run_id} in on_chain_error.")
867
- # Let _end_span_tracking handle potential root run cleanup
868
- if run_id != self._root_run_id:
869
- return
276
+ span_id = self._run_id_to_span_id.get(run_id)
277
+
278
+ # Let _end_span_tracking handle potential root run cleanup
279
+ if not span_id and run_id != self._root_run_id:
280
+ return
870
281
 
871
- self._end_span_tracking(trace_client, run_id, span_type=span_type, error=error)
872
- except Exception as e:
873
- tc_id_on_error = id(self._trace_client) if self._trace_client else 'None'
874
- self._log(f"{log_prefix} UNCAUGHT EXCEPTION in on_chain_error for run_id={run_id} (TraceClient ID: {tc_id_on_error}): {e}")
875
- # print(traceback.format_exc())
282
+ self._end_span_tracking(trace_client, run_id, error=error)
876
283
 
877
284
  def on_tool_start(self, serialized: Dict[str, Any], input_str: str, *, run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, inputs: Optional[Dict[str, Any]] = None, **kwargs: Any) -> Any:
878
- handler_instance_id = id(self)
879
- log_prefix = f"{HANDLER_LOG_PREFIX} [Handler {handler_instance_id}]"
880
285
  name = serialized.get("name", "Unnamed Tool") if serialized else "Unknown Tool (Serialized=None)"
881
- # print(f"{log_prefix} ENTERING on_tool_start: name='{name}', run_id={run_id}. Parent: {parent_run_id}")
882
286
 
883
- try:
884
- # Pass parent_run_id
885
- trace_client = self._ensure_trace_client(run_id, parent_run_id, name) # Corrected call
886
- if not trace_client:
887
- # print(f"{log_prefix} No trace client obtained in on_tool_start for {run_id}.")
888
- return
889
-
890
- combined_inputs = {'input_str': input_str, 'inputs': inputs, 'tags': tags, 'metadata': metadata, 'kwargs': kwargs, 'serialized': serialized}
891
- self._start_span_tracking(trace_client, run_id, parent_run_id, name, span_type="tool", inputs=combined_inputs)
892
-
893
- # --- Track executed tools (remains the same) ---
894
- if name not in self.executed_tools: self.executed_tools.append(name)
895
- parent_node_name = None
896
- if parent_run_id and parent_run_id in self._run_id_to_span_id:
897
- parent_span_id = self._run_id_to_span_id[parent_run_id]
898
- try:
899
- if hasattr(trace_client, 'entries') and trace_client.entries:
900
- parent_enter_entry = next((e for e in reversed(trace_client.entries) if e.span_id == parent_span_id and e.type == "enter" and e.span_type == "chain"), None)
901
- if parent_enter_entry:
902
- parent_node_name = parent_enter_entry.function
903
- else: self._log(f" WARNING: trace_client.entries missing for tool start parent {parent_span_id}")
904
- except Exception as e:
905
- self._log(f" ERROR finding parent node name for tool start span_id {parent_span_id}: {e}")
906
-
907
- node_tool = f"{parent_node_name}:{name}" if parent_node_name else name
908
- if node_tool not in self.executed_node_tools: self.executed_node_tools.append(node_tool)
909
- self._log(f" Tracked node_tool: {node_tool}")
910
- # --- End Track executed tools ---
911
- except Exception as e:
912
- tc_id_on_error = id(self._trace_client) if self._trace_client else 'None'
913
- self._log(f"{log_prefix} UNCAUGHT EXCEPTION in on_tool_start for run_id={run_id} (TraceClient ID: {tc_id_on_error}): {e}")
914
- # print(traceback.format_exc())
287
+ # Pass parent_run_id
288
+ trace_client = self._ensure_trace_client(run_id, parent_run_id, name) # Corrected call
289
+ if not trace_client: return
915
290
 
291
+ combined_inputs = {'input_str': input_str, 'inputs': inputs, 'tags': tags, 'metadata': metadata, 'kwargs': kwargs, 'serialized': serialized}
292
+ self._start_span_tracking(trace_client, run_id, parent_run_id, name, span_type="tool", inputs=combined_inputs)
916
293
 
917
- def on_tool_end(self, output: Any, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any) -> Any:
918
- handler_instance_id = id(self)
919
- log_prefix = f"{HANDLER_LOG_PREFIX} [Handler {handler_instance_id}]"
920
- # print(f"{log_prefix} ENTERING on_tool_end: run_id={run_id}. Parent: {parent_run_id}")
294
+ # --- Track executed tools (remains the same) ---
295
+ if name not in self.executed_tools: self.executed_tools.append(name) # Leaving this in for now but can probably be removed
296
+ parent_node_name = None
297
+ if parent_run_id and parent_run_id in self._run_id_to_span_id:
298
+ parent_span_id = self._run_id_to_span_id[parent_run_id]
299
+ parent_node_name = trace_client.span_id_to_span[parent_span_id].function
921
300
 
922
- try:
923
- # Pass parent_run_id
924
- trace_client = self._ensure_trace_client(run_id, parent_run_id, "ToolEnd") # Corrected call
925
- if not trace_client:
926
- # print(f"{log_prefix} No trace client obtained in on_tool_end for {run_id}.")
927
- return
928
- outputs = {"output": output, "kwargs": kwargs}
929
- self._end_span_tracking(trace_client, run_id, span_type="tool", outputs=outputs)
930
- except Exception as e:
931
- tc_id_on_error = id(self._trace_client) if self._trace_client else 'None'
932
- self._log(f"{log_prefix} UNCAUGHT EXCEPTION in on_tool_end for run_id={run_id} (TraceClient ID: {tc_id_on_error}): {e}")
933
- # print(traceback.format_exc())
301
+ node_tool = f"{parent_node_name}:{name}" if parent_node_name else name
302
+ if node_tool not in self.executed_node_tools: self.executed_node_tools.append(node_tool) # Leaving this in for now but can probably be removed
303
+ # --- End Track executed tools ---
304
+
305
+
306
+ def on_tool_end(self, output: Any, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any) -> Any:
307
+ # Pass parent_run_id
308
+ trace_client = self._ensure_trace_client(run_id, parent_run_id, "ToolEnd") # Corrected call
309
+ if not trace_client: return
310
+ outputs = {"output": output, "kwargs": kwargs}
311
+ self._end_span_tracking(trace_client, run_id, outputs=outputs)
934
312
 
935
313
  def on_tool_error(self, error: BaseException, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any) -> Any:
936
- handler_instance_id = id(self)
937
- log_prefix = f"{HANDLER_LOG_PREFIX} [Handler {handler_instance_id}]"
938
- # print(f"{log_prefix} ENTERING on_tool_error: run_id={run_id}. Parent: {parent_run_id}")
939
314
 
940
- try:
941
- # Pass parent_run_id
942
- trace_client = self._ensure_trace_client(run_id, parent_run_id, "ToolError") # Corrected call
943
- if not trace_client:
944
- # print(f"{log_prefix} No trace client obtained in on_tool_error for {run_id}.")
945
- return
946
- self._end_span_tracking(trace_client, run_id, span_type="tool", error=error)
947
- except Exception as e:
948
- tc_id_on_error = id(self._trace_client) if self._trace_client else 'None'
949
- self._log(f"{log_prefix} UNCAUGHT EXCEPTION in on_tool_error for run_id={run_id} (TraceClient ID: {tc_id_on_error}): {e}")
950
- # print(traceback.format_exc())
315
+ # Pass parent_run_id
316
+ trace_client = self._ensure_trace_client(run_id, parent_run_id, "ToolError") # Corrected call
317
+ if not trace_client: return
318
+ self._end_span_tracking(trace_client, run_id, error=error)
951
319
 
952
320
  def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], *, run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, invocation_params: Optional[Dict[str, Any]] = None, options: Optional[Dict[str, Any]] = None, name: Optional[str] = None, **kwargs: Any) -> Any:
953
- handler_instance_id = id(self)
954
- log_prefix = f"{HANDLER_LOG_PREFIX} [Handler {handler_instance_id}]"
321
+
955
322
  llm_name = name or serialized.get("name", "LLM Call")
956
- # print(f"{log_prefix} ENTERING on_llm_start: name='{llm_name}', run_id={run_id}. Parent: {parent_run_id}")
957
323
 
958
- try:
959
- # Pass parent_run_id
960
- trace_client = self._ensure_trace_client(run_id, parent_run_id, llm_name) # Corrected call
961
- if not trace_client:
962
- # print(f"{log_prefix} No trace client obtained in on_llm_start for {run_id}.")
963
- return
964
- inputs = {'prompts': prompts, 'invocation_params': invocation_params or kwargs, 'options': options, 'tags': tags, 'metadata': metadata, 'serialized': serialized}
965
- self._start_span_tracking(trace_client, run_id, parent_run_id, llm_name, span_type="llm", inputs=inputs)
966
- except Exception as e:
967
- tc_id_on_error = id(self._trace_client) if self._trace_client else 'None'
968
- self._log(f"{log_prefix} UNCAUGHT EXCEPTION in on_llm_start for run_id={run_id} (TraceClient ID: {tc_id_on_error}): {e}")
969
- # print(traceback.format_exc())
324
+ trace_client = self._ensure_trace_client(run_id, parent_run_id, llm_name) # Corrected call
325
+ if not trace_client: return
326
+ inputs = {'prompts': prompts, 'invocation_params': invocation_params or kwargs, 'options': options, 'tags': tags, 'metadata': metadata, 'serialized': serialized}
327
+ self._start_span_tracking(trace_client, run_id, parent_run_id, llm_name, span_type="llm", inputs=inputs)
970
328
 
971
329
  def on_llm_end(self, response: LLMResult, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any) -> Any:
972
- handler_instance_id = id(self)
973
- log_prefix = f"{HANDLER_LOG_PREFIX} [Handler {handler_instance_id}]"
974
- # print(f"{log_prefix} ENTERING on_llm_end: run_id={run_id}. Parent: {parent_run_id}")
975
-
976
- # --- Debugging unchanged ---
977
- # # print(f"{log_prefix} [DEBUG on_llm_end] Received response object for run_id={run_id}:")
978
- # try:
979
- # from rich import print as rprint
980
- # r# print(response)
981
- # except ImportError: # print(response)
982
- # # print(f"{log_prefix} [DEBUG on_llm_end] response.llm_output type: {type(response.llm_output)}")
983
- # # print(f"{log_prefix} [DEBUG on_llm_end] response.llm_output content:")
984
- # try:
985
- # from rich import print as rprint
986
-
987
- # except ImportError: # print(response.llm_output)
988
330
 
989
- try:
990
- # Pass parent_run_id
991
- trace_client = self._ensure_trace_client(run_id, parent_run_id, "LLMEnd") # Corrected call
992
- if not trace_client:
993
- # print(f"{log_prefix} No trace client obtained in on_llm_end for {run_id}.")
994
- return
995
- outputs = {"response": response, "kwargs": kwargs}
996
- # --- Token Usage Extraction and Accumulation ---
997
- token_usage = None
998
- prompt_tokens = None # Use standard name
999
- completion_tokens = None # Use standard name
1000
- total_tokens = None
1001
- try:
1002
- if response.llm_output and isinstance(response.llm_output, dict):
1003
- # Check for OpenAI/standard 'token_usage' first
1004
- if 'token_usage' in response.llm_output:
1005
- token_usage = response.llm_output.get('token_usage')
1006
- if token_usage and isinstance(token_usage, dict):
1007
- self._log(f" Extracted OpenAI token usage for run_id={run_id}: {token_usage}")
1008
- prompt_tokens = token_usage.get('prompt_tokens')
1009
- completion_tokens = token_usage.get('completion_tokens')
1010
- total_tokens = token_usage.get('total_tokens') # OpenAI provides total
1011
- # Check for Anthropic 'usage'
1012
- elif 'usage' in response.llm_output:
1013
- token_usage = response.llm_output.get('usage')
1014
- if token_usage and isinstance(token_usage, dict):
1015
- self._log(f" Extracted Anthropic token usage for run_id={run_id}: {token_usage}")
1016
- prompt_tokens = token_usage.get('input_tokens') # Anthropic uses input_tokens
1017
- completion_tokens = token_usage.get('output_tokens') # Anthropic uses output_tokens
1018
- # Calculate total if possible
1019
- if prompt_tokens is not None and completion_tokens is not None:
1020
- total_tokens = prompt_tokens + completion_tokens
1021
- else:
1022
- self._log(f" Could not calculate total_tokens from Anthropic usage: input={prompt_tokens}, output={completion_tokens}")
1023
-
1024
- # --- Store individual usage in span output and Accumulate ---
1025
- if prompt_tokens is not None or completion_tokens is not None:
1026
- # Store individual usage for this span
1027
- outputs['usage'] = {
1028
- 'prompt_tokens': prompt_tokens,
1029
- 'completion_tokens': completion_tokens,
1030
- 'total_tokens': total_tokens
1031
- }
1032
- # Accumulate tokens for the entire trace
1033
- # if isinstance(prompt_tokens, int):
1034
- # self._current_prompt_tokens += prompt_tokens
1035
- # if isinstance(completion_tokens, int):
1036
- # self._current_completion_tokens += completion_tokens
1037
- # self._log(f" Accumulated tokens for run_id={run_id}. Current totals: Prompt={self._current_prompt_tokens}, Completion={self._current_completion_tokens}")
1038
- else:
1039
- self._log(f" Could not extract token usage structure from llm_output for run_id={run_id}")
1040
- else: self._log(f" llm_output not available/dict for run_id={run_id}")
1041
- except Exception as e:
1042
- self._log(f" ERROR extracting/accumulating token usage for run_id={run_id}: {e}")
1043
- # --- End Token Usage ---
1044
- self._end_span_tracking(trace_client, run_id, span_type="llm", outputs=outputs)
1045
- except Exception as e:
1046
- tc_id_on_error = id(self._trace_client) if self._trace_client else 'None'
1047
- self._log(f"{log_prefix} UNCAUGHT EXCEPTION in on_llm_end for run_id={run_id} (TraceClient ID: {tc_id_on_error}): {e}")
1048
- # print(traceback.format_exc())
331
+ # Pass parent_run_id
332
+ trace_client = self._ensure_trace_client(run_id, parent_run_id, "LLMEnd") # Corrected call
333
+ if not trace_client:
334
+ return
335
+ outputs = {"response": response, "kwargs": kwargs}
336
+ # --- Token Usage Extraction and Accumulation ---
337
+ token_usage = None
338
+ prompt_tokens = None # Use standard name
339
+ completion_tokens = None # Use standard name
340
+ total_tokens = None
341
+ if response.llm_output and isinstance(response.llm_output, dict):
342
+ # Check for OpenAI/standard 'token_usage' first
343
+ if 'token_usage' in response.llm_output:
344
+ token_usage = response.llm_output.get('token_usage')
345
+ if token_usage and isinstance(token_usage, dict):
346
+ prompt_tokens = token_usage.get('prompt_tokens')
347
+ completion_tokens = token_usage.get('completion_tokens')
348
+ total_tokens = token_usage.get('total_tokens') # OpenAI provides total
349
+ # Check for Anthropic 'usage'
350
+ elif 'usage' in response.llm_output:
351
+ token_usage = response.llm_output.get('usage')
352
+ if token_usage and isinstance(token_usage, dict):
353
+ prompt_tokens = token_usage.get('input_tokens') # Anthropic uses input_tokens
354
+ completion_tokens = token_usage.get('output_tokens') # Anthropic uses output_tokens
355
+ # Calculate total if possible
356
+ if prompt_tokens is not None and completion_tokens is not None:
357
+ total_tokens = prompt_tokens + completion_tokens
358
+
359
+ # --- Store individual usage in span output and Accumulate ---
360
+ if prompt_tokens is not None or completion_tokens is not None:
361
+ # Store individual usage for this span
362
+ outputs['usage'] = {
363
+ 'prompt_tokens': prompt_tokens,
364
+ 'completion_tokens': completion_tokens,
365
+ 'total_tokens': total_tokens
366
+ }
367
+
368
+ self._end_span_tracking(trace_client, run_id, outputs=outputs)
369
+ # --- End Token Usage ---
1049
370
 
1050
371
  def on_llm_error(self, error: BaseException, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any) -> Any:
1051
- handler_instance_id = id(self)
1052
- log_prefix = f"{HANDLER_LOG_PREFIX} [Handler {handler_instance_id}]"
1053
- # print(f"{log_prefix} ENTERING on_llm_error: run_id={run_id}. Parent: {parent_run_id}")
1054
372
 
1055
- try:
1056
- # Pass parent_run_id
1057
- trace_client = self._ensure_trace_client(run_id, parent_run_id, "LLMError") # Corrected call
1058
- if not trace_client:
1059
- # print(f"{log_prefix} No trace client obtained in on_llm_error for {run_id}.")
1060
- return
1061
- self._end_span_tracking(trace_client, run_id, span_type="llm", error=error)
1062
- except Exception as e:
1063
- tc_id_on_error = id(self._trace_client) if self._trace_client else 'None'
1064
- self._log(f"{log_prefix} UNCAUGHT EXCEPTION in on_llm_error for run_id={run_id} (TraceClient ID: {tc_id_on_error}): {e}")
1065
- # print(traceback.format_exc())
373
+ # Pass parent_run_id
374
+ trace_client = self._ensure_trace_client(run_id, parent_run_id, "LLMError") # Corrected call
375
+ if not trace_client: return
376
+ self._end_span_tracking(trace_client, run_id, error=error)
1066
377
 
1067
378
  def on_chat_model_start(self, serialized: Dict[str, Any], messages: List[List[BaseMessage]], *, run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, invocation_params: Optional[Dict[str, Any]] = None, options: Optional[Dict[str, Any]] = None, name: Optional[str] = None, **kwargs: Any) -> Any:
1068
379
  # Reuse on_llm_start logic, adding message formatting if needed
1069
- handler_instance_id = id(self)
1070
- log_prefix = f"{HANDLER_LOG_PREFIX} [Handler {handler_instance_id}]"
1071
380
  chat_model_name = name or serialized.get("name", "ChatModel Call")
1072
381
  # Add OPENAI_API_CALL suffix if model is OpenAI and not present
1073
382
  is_openai = any(key.startswith('openai') for key in serialized.get('secrets', {}).keys()) or 'openai' in chat_model_name.lower()
@@ -1086,914 +395,25 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
1086
395
  elif is_google and "GOOGLE_API_CALL" not in chat_model_name:
1087
396
  chat_model_name = f"{chat_model_name} GOOGLE_API_CALL"
1088
397
 
1089
- tc_id_on_entry = id(self._trace_client) if self._trace_client else 'None'
1090
- # print(f"{log_prefix} ENTERING on_chat_model_start: name='{chat_model_name}', run_id={run_id}. Current TraceClient ID: {tc_id_on_entry}")
1091
- try:
1092
- # The call below was missing parent_run_id
1093
- trace_client = self._ensure_trace_client(run_id, parent_run_id, chat_model_name) # Corrected call with parent_run_id
1094
- if not trace_client: return
1095
- inputs = {'messages': messages, 'invocation_params': invocation_params or kwargs, 'options': options, 'tags': tags, 'metadata': metadata, 'serialized': serialized}
1096
- self._start_span_tracking(trace_client, run_id, parent_run_id, chat_model_name, span_type="llm", inputs=inputs) # Use 'llm' span_type for consistency
1097
- except Exception as e:
1098
- tc_id_on_error = id(self._trace_client) if self._trace_client else 'None'
1099
- self._log(f"{log_prefix} UNCAUGHT EXCEPTION in on_chat_model_start for run_id={run_id} (TraceClient ID: {tc_id_on_error}): {e}")
1100
- # print(traceback.format_exc())
398
+ trace_client = self._ensure_trace_client(run_id, parent_run_id, chat_model_name) # Corrected call with parent_run_id
399
+ if not trace_client: return
400
+ inputs = {'messages': messages, 'invocation_params': invocation_params or kwargs, 'options': options, 'tags': tags, 'metadata': metadata, 'serialized': serialized}
401
+ self._start_span_tracking(trace_client, run_id, parent_run_id, chat_model_name, span_type="llm", inputs=inputs) # Use 'llm' span_type for consistency
1101
402
 
1102
- # --- Agent Methods (Async versions - ensure parent_run_id passed if needed) ---
1103
403
  def on_agent_action(self, action: AgentAction, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any) -> Any:
1104
- handler_instance_id = id(self)
1105
- log_prefix = f"{HANDLER_LOG_PREFIX} [Handler {handler_instance_id}]"
1106
- tc_id_on_entry = id(self._trace_client) if self._trace_client else 'None'
1107
- # print(f"{log_prefix} ENTERING on_agent_action: tool={action.tool}, run_id={run_id}. Current TraceClient ID: {tc_id_on_entry}")
404
+ action_tool = action.tool
405
+ name = f"AGENT_ACTION_{(action_tool).upper()}"
406
+ # Pass parent_run_id
407
+ trace_client = self._ensure_trace_client(run_id, parent_run_id, name) # Corrected call
408
+ if not trace_client: return
1108
409
 
1109
- try:
1110
- # Optional: Implement detailed tracing if needed
1111
- pass
1112
- except Exception as e:
1113
- tc_id_on_error = id(self._trace_client) if self._trace_client else 'None'
1114
- self._log(f"{log_prefix} UNCAUGHT EXCEPTION in on_agent_action for run_id={run_id} (TraceClient ID: {tc_id_on_error}): {e}")
1115
- # print(traceback.format_exc())
410
+ inputs = {'tool_input': action.tool_input, 'log': action.log, 'messages': action.messages, 'kwargs': kwargs}
411
+ self._start_span_tracking(trace_client, run_id, parent_run_id, name, span_type="agent", inputs=inputs)
1116
412
 
1117
413
  def on_agent_finish(self, finish: AgentFinish, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any) -> Any:
1118
- handler_instance_id = id(self)
1119
- log_prefix = f"{HANDLER_LOG_PREFIX} [Handler {handler_instance_id}]"
1120
- tc_id_on_entry = id(self._trace_client) if self._trace_client else 'None'
1121
- # print(f"{log_prefix} ENTERING on_agent_finish: run_id={run_id}. Current TraceClient ID: {tc_id_on_entry}")
1122
-
1123
- try:
1124
- # Optional: Implement detailed tracing if needed
1125
- pass
1126
- except Exception as e:
1127
- tc_id_on_error = id(self._trace_client) if self._trace_client else 'None'
1128
- self._log(f"{log_prefix} UNCAUGHT EXCEPTION in on_agent_finish for run_id={run_id} (TraceClient ID: {tc_id_on_error}): {e}")
1129
- # print(traceback.format_exc())
1130
-
1131
- # --- Async Handler ---
1132
-
1133
- # --- NEW Fully Functional Async Handler ---
1134
- class AsyncJudgevalCallbackHandler(AsyncCallbackHandler):
1135
- """
1136
- Async LangChain Callback Handler using run_id/parent_run_id for hierarchy.
1137
- Manages its own internal TraceClient instance created upon first use.
1138
- Includes verbose logging and defensive checks.
1139
- """
1140
- lc_serializable = False
1141
- lc_kwargs = {}
1142
-
1143
- def __init__(self, tracer: Tracer):
1144
- instance_id = id(self)
1145
- # print(f"{HANDLER_LOG_PREFIX} *** Async Handler instance {instance_id} __init__ called. ***")
1146
- self.tracer = tracer
1147
- self._trace_client: Optional[TraceClient] = None
1148
- self._run_id_to_span_id: Dict[UUID, str] = {}
1149
- self._span_id_to_start_time: Dict[str, float] = {}
1150
- self._span_id_to_depth: Dict[str, int] = {}
1151
- self._run_id_to_context_token: Dict[UUID, contextvars.Token] = {} # Initialize missing attribute
1152
- self._root_run_id: Optional[UUID] = None
1153
- self._trace_context_token: Optional[contextvars.Token] = None # Restore missing attribute
1154
- self._trace_saved: bool = False # <<< ADDED MISSING ATTRIBUTE
1155
- self._run_id_to_start_inputs: Dict[UUID, Dict] = {} # <<< ADDED input storage
1156
-
1157
- # --- Token Count Accumulators ---
1158
- # self._current_prompt_tokens = 0
1159
- # self._current_completion_tokens = 0
1160
- # --- End Token Count Accumulators ---
1161
-
1162
- self.executed_nodes: List[str] = []
1163
- self.executed_tools: List[str] = []
1164
- self.executed_node_tools: List[str] = []
1165
-
1166
- # NOTE: _ensure_trace_client remains synchronous as it doesn't involve async I/O
1167
- def _ensure_trace_client(self, run_id: UUID, event_name: str) -> Optional[TraceClient]:
1168
- """Ensures the internal trace client is initialized. Returns client or None."""
1169
- handler_instance_id = id(self)
1170
- log_prefix = f"{HANDLER_LOG_PREFIX} [Async Handler {handler_instance_id}]"
1171
- if self._trace_client is None:
1172
- trace_id = str(uuid.uuid4())
1173
- project = self.tracer.project_name
1174
- try:
1175
- client_instance = TraceClient(
1176
- self.tracer, trace_id, event_name, project_name=project,
1177
- overwrite=False, rules=self.tracer.rules,
1178
- enable_monitoring=self.tracer.enable_monitoring,
1179
- enable_evaluations=self.tracer.enable_evaluations
1180
- )
1181
- self._trace_client = client_instance
1182
- if self._trace_client:
1183
- self.tracer._active_trace_client = self._trace_client
1184
- self._current_trace_id = self._trace_client.trace_id
1185
- if self._root_run_id is None:
1186
- self._root_run_id = run_id
1187
- else:
1188
- return None
1189
- except Exception as e:
1190
- self._trace_client = None
1191
- return None
1192
- return self._trace_client
1193
-
1194
- def _log(self, message: str):
1195
- """Helper for consistent logging format."""
1196
- pass
1197
-
1198
- # NOTE: _start_span_tracking remains mostly synchronous, TraceClient.add_entry might become async later
1199
- def _start_span_tracking(
1200
- self,
1201
- trace_client: TraceClient,
1202
- run_id: UUID,
1203
- parent_run_id: Optional[UUID],
1204
- name: str,
1205
- span_type: SpanType = "span",
1206
- inputs: Optional[Dict[str, Any]] = None
1207
- ):
1208
- self._log(f"_start_span_tracking called for: name='{name}', run_id={run_id}, parent_run_id={parent_run_id}, span_type={span_type}")
1209
- if not trace_client:
1210
- handler_instance_id = id(self)
1211
- log_prefix = f"{HANDLER_LOG_PREFIX} [Async Handler {handler_instance_id}]"
1212
- self._log(f"{log_prefix} FATAL ERROR in _start_span_tracking: trace_client argument is None for name='{name}', run_id={run_id}. Aborting span start.")
1213
- return
1214
-
1215
- # --- NEW: Set trace context variable if not already set for this trace ---
1216
- if self._trace_context_token is None:
1217
- try:
1218
- self._trace_context_token = current_trace_var.set(trace_client)
1219
- self._log(f" Set current_trace_var for trace_id {trace_client.trace_id}")
1220
- except Exception as e:
1221
- self._log(f" ERROR setting current_trace_var for trace_id {trace_client.trace_id}: {e}")
1222
- # --- END NEW ---
1223
-
1224
- handler_instance_id = id(self)
1225
- log_prefix = f"{HANDLER_LOG_PREFIX} [Async Handler {handler_instance_id}]"
1226
- trace_client_instance_id = id(trace_client) if trace_client else 'None'
1227
- # print(f"{log_prefix} _start_span_tracking: Using TraceClient ID: {trace_client_instance_id}")
1228
-
1229
- start_time = time.time()
1230
- span_id = str(uuid.uuid4())
1231
- parent_span_id: Optional[str] = None
1232
- current_depth = 0
1233
-
1234
- if parent_run_id and parent_run_id in self._run_id_to_span_id:
1235
- parent_span_id = self._run_id_to_span_id[parent_run_id]
1236
- if parent_span_id in self._span_id_to_depth:
1237
- current_depth = self._span_id_to_depth[parent_span_id] + 1
1238
- else:
1239
- self._log(f" WARNING: Parent span depth not found for parent_span_id: {parent_span_id}. Setting depth to 0.")
1240
- elif parent_run_id:
1241
- self._log(f" WARNING: parent_run_id {parent_run_id} provided for '{name}' ({run_id}) but parent span not tracked. Treating as depth 0.")
1242
- else:
1243
- self._log(f" No parent_run_id provided. Treating '{name}' as depth 0.")
1244
-
1245
- self._run_id_to_span_id[run_id] = span_id
1246
- self._span_id_to_start_time[span_id] = start_time
1247
- self._span_id_to_depth[span_id] = current_depth
1248
- self._log(f" Tracking new span: span_id={span_id}, depth={current_depth}")
1249
-
1250
- # --- Set SPAN context variable ONLY for chain (node) spans ---
1251
- if span_type == "chain":
1252
- try:
1253
- # Set current_span_var for the node's execution context
1254
- token = current_span_var.set(span_id) # Store the token
1255
- self._run_id_to_context_token[run_id] = token # Store token in the dictionary
1256
- self._log(f" Set current_span_var to {span_id} for run_id {run_id} (type: chain)")
1257
- except Exception as e:
1258
- self._log(f" ERROR setting current_span_var for run_id {run_id}: {e}")
1259
- # --- END Span Context Var Logic ---
1260
-
1261
- try:
1262
- # TODO: Check if trace_client.add_entry needs await if TraceClient becomes async
1263
- trace_client.add_entry(TraceEntry(
1264
- type="enter", span_id=span_id, trace_id=trace_client.trace_id,
1265
- parent_span_id=parent_span_id, function=name, depth=current_depth,
1266
- message=name, created_at=start_time, span_type=span_type
1267
- ))
1268
- self._log(f" Added 'enter' entry for span_id={span_id}")
1269
- except Exception as e:
1270
- self._log(f" ERROR adding 'enter' entry for span_id {span_id}: {e}")
1271
- # print(traceback.format_exc())
1272
-
1273
- if inputs:
1274
- # _record_input_data is also sync for now
1275
- self._record_input_data(trace_client, run_id, inputs)
1276
-
1277
- # NOTE: _end_span_tracking remains mostly synchronous, TraceClient.save might become async later
1278
- def _end_span_tracking(
1279
- self,
1280
- trace_client: TraceClient,
1281
- run_id: UUID,
1282
- span_type: SpanType = "span",
1283
- outputs: Optional[Any] = None,
1284
- error: Optional[BaseException] = None
1285
- ):
1286
- # self._log(f"_end_span_tracking called for: run_id={run_id}, span_type={span_type}")
1287
-
1288
- # --- Define instance_id early for logging/cleanup ---
1289
- instance_id = id(self)
1290
-
1291
- if not trace_client:
1292
- # Use instance_id defined above
1293
- # log_prefix = f"{HANDLER_LOG_PREFIX} [Handler {instance_id}]"
1294
- # self._log(f"{log_prefix} FATAL ERROR in _end_span_tracking: trace_client argument is None for run_id={run_id}. Aborting span end.")
1295
- return
1296
-
1297
- # Use instance_id defined above
1298
- # log_prefix = f"{HANDLER_LOG_PREFIX} [Handler {instance_id}]"
1299
- # trace_client_instance_id = id(trace_client) if trace_client else 'None'
1300
- # # print(f"{log_prefix} _end_span_tracking: Using TraceClient ID: {trace_client_instance_id}")
1301
-
1302
- if run_id not in self._run_id_to_span_id:
1303
- # self._log(f" WARNING: Attempting to end span for untracked run_id: {run_id}")
1304
- # Allow root run end to proceed for cleanup/save attempt even if span wasn't tracked
1305
- if run_id != self._root_run_id:
1306
- return
1307
- else:
1308
- # self._log(f" Allowing root run {run_id} end logic to proceed despite untracked span.")
1309
- span_id = None # Indicate span wasn't found for duration/metadata lookup
1310
- else:
1311
- span_id = self._run_id_to_span_id[run_id]
1312
-
1313
- start_time = self._span_id_to_start_time.get(span_id) if span_id else None
1314
- depth = self._span_id_to_depth.get(span_id, 0) if span_id else 0 # Use 0 depth if span_id is None
1315
- duration = time.time() - start_time if start_time is not None else None
1316
- # self._log(f" Ending span for run_id={run_id} (span_id={span_id}). Start time={start_time}, Duration={duration}, Depth={depth}")
1317
-
1318
- # Record output/error first
1319
- if error:
1320
- # self._log(f" Recording error for run_id={run_id} (span_id={span_id}): {error}")
1321
- self._record_output_data(trace_client, run_id, error)
1322
- elif outputs is not None:
1323
- # output_repr = repr(outputs)
1324
- # log_output = (output_repr[:100] + '...') if len(output_repr) > 103 else output_repr
1325
- # self._log(f" Recording output for run_id={run_id} (span_id={span_id}): {log_output}")
1326
- self._record_output_data(trace_client, run_id, outputs)
1327
-
1328
- # Add exit entry (only if span was tracked)
1329
- if span_id:
1330
- entry_function_name = "unknown"
1331
- try:
1332
- if hasattr(trace_client, 'entries') and trace_client.entries:
1333
- entry_function_name = next((e.function for e in reversed(trace_client.entries) if e.span_id == span_id and e.type == "enter"), "unknown")
1334
- else:
1335
- # self._log(f" WARNING: Cannot determine function name for exit span_id {span_id}, trace_client.entries missing or empty.")
1336
- pass
1337
- except Exception as e:
1338
- # self._log(f" ERROR finding function name for exit entry span_id {span_id}: {e}")
1339
- # # print(traceback.format_exc())
1340
- pass
414
+ # Pass parent_run_id
415
+ trace_client = self._ensure_trace_client(run_id, parent_run_id, "AgentFinish") # Corrected call
416
+ if not trace_client: return
1341
417
 
1342
- try:
1343
- trace_client.add_entry(TraceEntry(
1344
- type="exit", span_id=span_id, trace_id=trace_client.trace_id,
1345
- depth=depth, created_at=time.time(), duration=duration,
1346
- span_type=span_type, function=entry_function_name
1347
- ))
1348
- # self._log(f" Added 'exit' entry for span_id={span_id}, function='{entry_function_name}'")
1349
- except Exception as e:
1350
- # self._log(f" ERROR adding 'exit' entry for span_id {span_id}: {e}")
1351
- # # print(traceback.format_exc())
1352
- pass
1353
-
1354
- # Clean up dictionaries for this specific span
1355
- if span_id in self._span_id_to_start_time: del self._span_id_to_start_time[span_id]
1356
- if span_id in self._span_id_to_depth: del self._span_id_to_depth[span_id]
1357
-
1358
- # Pop context token (Sync version) but don't reset
1359
- token = self._run_id_to_context_token.pop(run_id, None)
1360
- if token:
1361
- # self._log(f" Popped token for run_id {run_id} (was {span_id}), not resetting context var.")
1362
- pass
1363
- else:
1364
- # self._log(f" Skipping exit entry and cleanup for run_id {run_id} as span_id was not found.")
1365
- pass
1366
-
1367
- # Check if this is the root run ending
1368
- if run_id == self._root_run_id:
1369
- trace_saved_successfully = False # Track save success
1370
- try:
1371
- # Reset root run id after attempt
1372
- self._root_run_id = None
1373
- # Reset input storage for this handler instance
1374
- self._run_id_to_start_inputs = {}
1375
- self._log(f"Reset root run ID and input storage for handler {instance_id}.")
1376
-
1377
- self._log(f"Root run {run_id} finished. Attempting to save trace...")
1378
- if self._trace_client and not self._trace_saved: # Check if not already saved
1379
- try:
1380
- # TODO: Check if trace_client.save needs await if TraceClient becomes async
1381
- trace_id, _ = self._trace_client.save(overwrite=self._trace_client.overwrite) # Use client's overwrite setting
1382
- self._log(f"Trace {trace_id} successfully saved.")
1383
- self._trace_saved = True # Set flag only after successful save
1384
- trace_saved_successfully = True # Mark success
1385
- except Exception as e:
1386
- self._log(f"ERROR saving trace {self._trace_client.trace_id}: {e}")
1387
- # print(traceback.format_exc())
1388
- # REMOVED FINALLY BLOCK THAT RESET STATE HERE
1389
- elif self._trace_client and self._trace_saved:
1390
- self._log(f" Trace {self._trace_client.trace_id} already saved. Skipping save.")
1391
- else:
1392
- self._log(f" WARNING: Root run {run_id} ended, but trace client was None. Cannot save trace.")
1393
- finally:
1394
- # --- NEW: Consolidated Cleanup Logic ---
1395
- # This block executes regardless of save success/failure
1396
- self._log(f" Performing cleanup for root run {run_id} in handler {instance_id}.")
1397
- # Reset root run id
1398
- self._root_run_id = None
1399
- # Reset input storage for this handler instance
1400
- self._run_id_to_start_inputs = {}
1401
- # Reset tracer's active client ONLY IF it was this handler's client
1402
- if self.tracer._active_trace_client == self._trace_client:
1403
- self.tracer._active_trace_client = None
1404
- self._log(" Reset active_trace_client on Tracer.")
1405
- # Completely remove trace_context_token cleanup as it's not used in sync handler
1406
- # Optionally: Reset the entire trace client instance for this handler?
1407
- # self._trace_client = None # Uncomment if handler should reset client completely after root run
1408
- self._log(f" Cleanup complete for root run {run_id}.")
1409
- # --- End Cleanup Logic ---
1410
-
1411
- # NOTE: _record_input_data remains synchronous for now
1412
- def _record_input_data(self,
1413
- trace_client: TraceClient,
1414
- run_id: UUID,
1415
- inputs: Dict[str, Any]):
1416
- # self._log(f"_record_input_data called for run_id={run_id}")
1417
- if run_id not in self._run_id_to_span_id:
1418
- # self._log(f" WARNING: Attempting to record input for untracked run_id: {run_id}")
1419
- return
1420
- if not trace_client:
1421
- # self._log(f" ERROR: TraceClient is None when trying to record input for run_id={run_id}")
1422
- return
1423
-
1424
- span_id = self._run_id_to_span_id[run_id]
1425
- depth = self._span_id_to_depth.get(span_id, 0)
1426
- # self._log(f" Found span_id={span_id}, depth={depth} for run_id={run_id}")
1427
-
1428
- function_name = "unknown"
1429
- span_type: SpanType = "span"
1430
- try:
1431
- # Find the corresponding 'enter' entry to get the function name and span type
1432
- enter_entry = next((e for e in reversed(trace_client.entries) if e.span_id == span_id and e.type == "enter"), None)
1433
- if enter_entry:
1434
- function_name = enter_entry.function
1435
- span_type = enter_entry.span_type
1436
- # self._log(f" Found function='{function_name}', span_type='{span_type}' for input span_id={span_id}")
1437
- else:
1438
- # self._log(f" WARNING: Could not find 'enter' entry for input span_id={span_id}")
1439
- pass
1440
- except Exception as e:
1441
- # self._log(f" ERROR finding enter entry for input span_id {span_id}: {e}")
1442
- # # print(traceback.format_exc())
1443
- pass
1444
-
1445
- try:
1446
- input_entry = TraceEntry(
1447
- type="input",
1448
- span_id=span_id,
1449
- trace_id=trace_client.trace_id,
1450
- parent_span_id=next((e.parent_span_id for e in reversed(trace_client.entries) if e.span_id == span_id and e.type == "enter"), None), # Get parent from enter entry
1451
- function=function_name,
1452
- depth=depth,
1453
- message=f"Input to {function_name}",
1454
- created_at=time.time(),
1455
- inputs=inputs,
1456
- span_type=span_type
1457
- )
1458
- trace_client.add_entry(input_entry)
1459
- # self._log(f" Added 'input' entry directly for span_id={span_id}")
1460
- except Exception as e:
1461
- # self._log(f" ERROR adding 'input' entry directly for span_id {span_id}: {e}")
1462
- # # print(traceback.format_exc())
1463
- pass
1464
-
1465
- # NOTE: _record_output_data remains synchronous for now
1466
- def _record_output_data(self,
1467
- trace_client: TraceClient,
1468
- run_id: UUID,
1469
- output: Any):
1470
- self._log(f"_record_output_data called for run_id={run_id}")
1471
- if run_id not in self._run_id_to_span_id:
1472
- # self._log(f" WARNING: Attempting to record output for untracked run_id: {run_id}")
1473
- return
1474
- if not trace_client:
1475
- # self._log(f" ERROR: TraceClient is None when trying to record output for run_id={run_id}")
1476
- return
1477
-
1478
- span_id = self._run_id_to_span_id[run_id]
1479
- depth = self._span_id_to_depth.get(span_id, 0)
1480
- # self._log(f" Found span_id={span_id}, depth={depth} for run_id={run_id}")
1481
-
1482
- function_name = "unknown"
1483
- span_type: SpanType = "span"
1484
- try:
1485
- # Find the corresponding 'enter' entry to get the function name and span type
1486
- enter_entry = next((e for e in reversed(trace_client.entries) if e.span_id == span_id and e.type == "enter"), None)
1487
- if enter_entry:
1488
- function_name = enter_entry.function
1489
- span_type = enter_entry.span_type
1490
- # self._log(f" Found function='{function_name}', span_type='{span_type}' for output span_id={span_id}")
1491
- else:
1492
- # self._log(f" WARNING: Could not find 'enter' entry for output span_id={span_id}")
1493
- pass
1494
- except Exception as e:
1495
- # self._log(f" ERROR finding enter entry for output span_id {span_id}: {e}")
1496
- # # print(traceback.format_exc())
1497
- pass
1498
-
1499
- try:
1500
- output_entry = TraceEntry(
1501
- type="output",
1502
- span_id=span_id,
1503
- trace_id=trace_client.trace_id,
1504
- parent_span_id=next((e.parent_span_id for e in reversed(trace_client.entries) if e.span_id == span_id and e.type == "enter"), None), # Get parent from enter entry
1505
- function=function_name,
1506
- depth=depth,
1507
- message=f"Output from {function_name}",
1508
- created_at=time.time(),
1509
- output=output, # Langchain outputs are typically serializable directly
1510
- span_type=span_type
1511
- )
1512
- trace_client.add_entry(output_entry)
1513
- self._log(f" Added 'output' entry directly for span_id={span_id}")
1514
- except Exception as e:
1515
- self._log(f" ERROR adding 'output' entry directly for span_id {span_id}: {e}")
1516
- # print(traceback.format_exc())
1517
-
1518
- # --- Async Callback Methods ---
1519
-
1520
- async def on_retriever_start(self, serialized: Dict[str, Any], query: str, *, run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, **kwargs: Any) -> Any:
1521
- handler_instance_id = id(self)
1522
- log_prefix = f"{HANDLER_LOG_PREFIX} [Async Handler {handler_instance_id}]"
1523
- serialized_name = serialized.get('name', 'Unknown') if serialized else "Unknown (Serialized=None)"
1524
- tc_id_on_entry = id(self._trace_client) if self._trace_client else 'None'
1525
- # print(f"{log_prefix} ENTERING on_retriever_start: name='{serialized_name}', run_id={run_id}. Current TraceClient ID: {tc_id_on_entry}")
1526
- try:
1527
- name = f"RETRIEVER_{(serialized_name).upper()}"
1528
- # Pass parent_run_id
1529
- trace_client = self._ensure_trace_client(run_id, parent_run_id, name) # Corrected call
1530
- if not trace_client: return
1531
- inputs = {'query': query, 'tags': tags, 'metadata': metadata, 'kwargs': kwargs, 'serialized': serialized}
1532
- self._start_span_tracking(trace_client, run_id, parent_run_id, name, span_type="retriever", inputs=inputs)
1533
- except Exception as e:
1534
- tc_id_on_error = id(self._trace_client) if self._trace_client else 'None'
1535
- self._log(f"{log_prefix} UNCAUGHT EXCEPTION in on_retriever_start for run_id={run_id} (TraceClient ID: {tc_id_on_error}): {e}")
1536
- # print(traceback.format_exc())
1537
-
1538
- async def on_retriever_end(self, documents: Sequence[Document], *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any) -> Any:
1539
- handler_instance_id = id(self)
1540
- log_prefix = f"{HANDLER_LOG_PREFIX} [Async Handler {handler_instance_id}]"
1541
- tc_id_on_entry = id(self._trace_client) if self._trace_client else 'None'
1542
- # print(f"{log_prefix} ENTERING on_retriever_end: run_id={run_id}. Current TraceClient ID: {tc_id_on_entry}")
1543
- try:
1544
- # Pass parent_run_id
1545
- trace_client = self._ensure_trace_client(run_id, parent_run_id, "RetrieverEnd") # Corrected call
1546
- if not trace_client: return
1547
- doc_summary = [{"index": i, "page_content": doc.page_content[:100] + "..." if len(doc.page_content) > 100 else doc.page_content, "metadata": doc.metadata} for i, doc in enumerate(documents)]
1548
- outputs = {"document_count": len(documents), "documents": doc_summary, "kwargs": kwargs}
1549
- self._end_span_tracking(trace_client, run_id, span_type="retriever", outputs=outputs)
1550
- except Exception as e:
1551
- tc_id_on_error = id(self._trace_client) if self._trace_client else 'None'
1552
- self._log(f"{log_prefix} UNCAUGHT EXCEPTION in on_retriever_end for run_id={run_id} (TraceClient ID: {tc_id_on_error}): {e}")
1553
- # print(traceback.format_exc())
1554
-
1555
- async def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any], *, run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, **kwargs: Any) -> None:
1556
- handler_instance_id = id(self)
1557
- log_prefix = f"{HANDLER_LOG_PREFIX} [Async Handler {handler_instance_id}]"
1558
- # Handle potential None for serialized safely
1559
- serialized_name = serialized.get('name') if serialized else None
1560
- tc_id_on_entry = id(self._trace_client) if self._trace_client else 'None'
1561
- # Log the potentially generic or specific name found in serialized
1562
- log_name = serialized_name if serialized_name else "Unknown (Serialized=None)"
1563
- # print(f"{HANDLER_LOG_PREFIX} [Async Handler {handler_instance_id}] ENTERING on_chain_start: serialized_name='{log_name}', run_id={run_id}. Current TraceClient ID: {tc_id_on_entry}")
1564
-
1565
- try:
1566
- # Determine the best name and span type
1567
- name = "Unknown Chain" # Default
1568
- span_type: SpanType = "chain"
1569
- node_name = metadata.get("langgraph_node") if metadata else None
1570
- is_langgraph_root = kwargs.get('name') == 'LangGraph' # Check kwargs
1571
- is_potential_root_event = parent_run_id is None
1572
-
1573
- # Define generic names to ignore if node_name is not present
1574
- GENERIC_NAMES = ["RunnableSequence", "RunnableParallel", "RunnableLambda", "LangGraph", "__start__", "__end__"]
1575
-
1576
- if node_name:
1577
- name = node_name
1578
- self._log(f" LangGraph Node Start Detected: '{name}', run_id={run_id}, parent_run_id={parent_run_id}")
1579
- if name not in self.executed_nodes: self.executed_nodes.append(name)
1580
- elif serialized_name and serialized_name not in GENERIC_NAMES:
1581
- name = serialized_name
1582
- self._log(f" LangGraph Functional Step (Router?): '{name}', run_id={run_id}, parent_run_id={parent_run_id}")
1583
- # Correct root detection: Should primarily rely on parent_run_id being None for the *first* event.
1584
- # kwargs name='LangGraph' might appear later.
1585
- elif is_potential_root_event: # Check if it's the potential root event
1586
- # Use the serialized name if available and not generic, otherwise default to 'LangGraph'
1587
- if serialized_name and serialized_name not in GENERIC_NAMES:
1588
- name = serialized_name
1589
- else:
1590
- name = "LangGraph"
1591
- self._log(f" LangGraph Root Start Detected (parent_run_id=None): Name='{name}', run_id={run_id}")
1592
- if self._root_run_id is None: # Only set root_run_id once
1593
- self._log(f" Setting root run ID to {run_id}")
1594
- self._root_run_id = run_id
1595
- # Defer trace client name update until client is ensured
1596
- elif serialized_name: # Fallback if node_name missing and serialized_name was generic or root wasn't detected
1597
- name = serialized_name
1598
- self._log(f" Fallback to serialized_name: '{name}', run_id={run_id}")
1599
-
1600
- # Ensure trace client exists (using the determined name for initialization if needed)
1601
- # Pass parent_run_id
1602
- trace_client = self._ensure_trace_client(run_id, name) # FIXED: Removed parent_run_id
1603
- if not trace_client:
1604
- # print(f"{log_prefix} No trace client obtained in on_chain_start for {run_id} ('{name}').")
1605
- return
1606
-
1607
- # --- Update Trace Name if Root (Moved After Client Ensure) ---
1608
- if is_potential_root_event and run_id == self._root_run_id and trace_client.name != name:
1609
- self._log(f" Updating trace name from '{trace_client.name}' to '{name}' for root run {run_id}")
1610
- trace_client.name = name
1611
- # --- End Update Trace Name ---
1612
-
1613
- # Start span tracking using the determined name and span_type
1614
- combined_inputs = {'inputs': inputs, 'tags': tags, 'metadata': metadata, 'kwargs': kwargs, 'serialized': serialized}
1615
- self._start_span_tracking(trace_client, run_id, parent_run_id, name, span_type=span_type, inputs=combined_inputs)
1616
- # --- Store inputs for potential evaluation later ---
1617
- self._run_id_to_start_inputs[run_id] = inputs # Store the raw inputs dict
1618
- self._log(f" Stored inputs for run_id {run_id}")
1619
- # --- End Store inputs ---
1620
-
1621
- except Exception as e:
1622
- tc_id_on_error = id(self._trace_client) if self._trace_client else 'None'
1623
- self._log(f"{log_prefix} UNCAUGHT EXCEPTION in on_chain_start for run_id={run_id} (TraceClient ID: {tc_id_on_error}): {e}")
1624
- # print(traceback.format_exc())
1625
-
1626
- async def on_chain_end(self, outputs: Dict[str, Any], *, run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, **kwargs: Any) -> Any:
1627
- """
1628
- Ends span tracking for a chain/node and attempts evaluation if applicable.
1629
- """
1630
- # --- Existing logging and client check ---
1631
- instance_id = id(self)
1632
- log_prefix = f"{HANDLER_LOG_PREFIX} [Async Handler {instance_id}]"
1633
- self._log(f"{log_prefix} ENTERING on_chain_end: run_id={run_id}. Current TraceClient ID: {id(self._trace_client) if self._trace_client else 'None'}")
1634
- client = self._ensure_trace_client(run_id, "on_chain_end") # Ensure client exists
1635
- if not client:
1636
- self._log(f"{log_prefix} No TraceClient found for on_chain_end ({run_id}). Aborting.")
1637
- return # Early exit if no client
1638
-
1639
- # --- Get span_id associated with this chain run ---
1640
- span_id = self._run_id_to_span_id.get(run_id)
1641
-
1642
- # --- Existing span ending logic ---
1643
- # Determine span_type for end_span_tracking (copied from sync handler)
1644
- end_span_type: SpanType = "chain" # Default
1645
- if span_id: # Check if span_id was actually found
1646
- try:
1647
- if hasattr(client, 'entries') and client.entries:
1648
- enter_entry = next((e for e in reversed(client.entries) if e.span_id == span_id and e.type == "enter"), None)
1649
- if enter_entry: end_span_type = enter_entry.span_type
1650
- else: self._log(f" WARNING: trace_client.entries empty/missing for on_chain_end span_id={span_id}")
1651
- except Exception as e:
1652
- self._log(f" ERROR finding enter entry for span_id {span_id} in on_chain_end: {e}")
1653
- else:
1654
- self._log(f" WARNING: No span_id found for run_id {run_id} in on_chain_end, using default span_type='chain'.")
1655
-
1656
- # Prepare outputs for end tracking (moved down)
1657
- combined_outputs = {"outputs": outputs, "tags": tags, "kwargs": kwargs}
1658
-
1659
- # Call end_span_tracking with potentially determined span_type
1660
- self._end_span_tracking(client, run_id, span_type=end_span_type, outputs=combined_outputs)
1661
-
1662
- # --- Root node cleanup REMOVED - Now handled in _end_span_tracking ---
1663
-
1664
- # --- NEW: Attempt Evaluation by checking output metadata ---
1665
- eval_config: Optional[EvaluationConfig] = None
1666
- node_name = "unknown_node" # Default node name
1667
- # Ensure client exists before proceeding with eval logic that uses it
1668
- if client:
1669
- if span_id: # Try to find the node name from the 'enter' entry
1670
- try:
1671
- if hasattr(client, 'entries') and client.entries:
1672
- enter_entry = next((e for e in reversed(client.entries) if e.span_id == span_id and e.type == "enter"), None)
1673
- if enter_entry: node_name = enter_entry.function
1674
- except Exception as e:
1675
- self._log(f" ERROR finding node name for span_id {span_id} in on_chain_end: {e}")
1676
-
1677
- if span_id and "_judgeval_eval" in outputs: # Only attempt if span exists and key is present
1678
- raw_eval_config = outputs.get("_judgeval_eval")
1679
- if isinstance(raw_eval_config, EvaluationConfig):
1680
- eval_config = raw_eval_config
1681
- self._log(f"{log_prefix} Found valid EvaluationConfig in outputs for node='{node_name}'.")
1682
- elif isinstance(raw_eval_config, dict):
1683
- # Attempt to reconstruct from dict if needed (e.g., if state serialization occurred)
1684
- try:
1685
- # Basic check for required keys before attempting reconstruction
1686
- if "scorers" in raw_eval_config and "example" in raw_eval_config:
1687
- # Example might also be a dict, try reconstructing it
1688
- example_data = raw_eval_config["example"]
1689
- reconstructed_example = Example(**example_data) if isinstance(example_data, dict) else example_data
1690
-
1691
- if isinstance(reconstructed_example, Example):
1692
- eval_config = EvaluationConfig(
1693
- scorers=raw_eval_config["scorers"], # Assumes scorers are serializable or passed correctly
1694
- example=reconstructed_example,
1695
- model=raw_eval_config.get("model"),
1696
- log_results=raw_eval_config.get("log_results", True)
1697
- )
1698
- self._log(f"{log_prefix} Reconstructed EvaluationConfig from dict in outputs for node='{node_name}'.")
1699
- else:
1700
- self._log(f"{log_prefix} Could not reconstruct Example from dict in _judgeval_eval for node='{node_name}'. Skipping evaluation.")
1701
- else:
1702
- self._log(f"{log_prefix} Dict in _judgeval_eval missing required keys ('scorers', 'example') for node='{node_name}'. Skipping evaluation.")
1703
- except Exception as recon_e:
1704
- self._log(f"{log_prefix} ERROR attempting to reconstruct EvaluationConfig from dict for node='{node_name}': {recon_e}")
1705
- # print(traceback.format_exc()) # Print traceback for reconstruction errors
1706
- else:
1707
- self._log(f"{log_prefix} Found '_judgeval_eval' key in outputs for node='{node_name}', but it wasn't an EvaluationConfig object or reconstructable dict. Skipping evaluation.")
1708
-
1709
-
1710
- if eval_config and span_id: # Check eval_config *and* span_id again
1711
- self._log(f"{log_prefix} Submitting evaluation for span_id={span_id}")
1712
- try:
1713
- # Ensure example has trace_id set if not already present
1714
- if not hasattr(eval_config.example, 'trace_id') or not eval_config.example.trace_id:
1715
- # Use the correct variable name 'client' here for the async handler
1716
- eval_config.example.trace_id = client.trace_id
1717
- self._log(f"{log_prefix} Set trace_id={client.trace_id} on evaluation example.")
1718
-
1719
- # Call async_evaluate on the TraceClient instance ('client')
1720
- # Use the correct variable name 'client' here for the async handler
1721
- client.async_evaluate(
1722
- scorers=eval_config.scorers,
1723
- example=eval_config.example,
1724
- model=eval_config.model,
1725
- log_results=eval_config.log_results,
1726
- span_id=span_id # Pass the specific span_id for this node run
1727
- )
1728
- self._log(f"{log_prefix} Evaluation submitted successfully for span_id={span_id}.")
1729
- except Exception as eval_e:
1730
- self._log(f"{log_prefix} ERROR submitting evaluation for span_id={span_id}: {eval_e}")
1731
- # print(traceback.format_exc()) # Print traceback for evaluation errors
1732
- elif "_judgeval_eval" in outputs and not span_id:
1733
- self._log(f"{log_prefix} WARNING: Found _judgeval_eval in outputs, but span_id for run_id {run_id} was not found. Cannot submit evaluation.")
1734
- # --- End NEW Evaluation Logic ---
1735
-
1736
async def on_chain_error(self, error: BaseException, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any) -> Any:
    """Close the span of a chain run that raised ``error``.

    The span type is taken from the most recent "enter" entry recorded for
    this run's span_id, falling back to "chain" when no span_id or matching
    entry is found. Exceptions raised inside the tracing logic are logged
    rather than re-raised.
    """
    handler_instance_id = id(self)
    log_prefix = f"{HANDLER_LOG_PREFIX} [Async Handler {handler_instance_id}]"
    try:
        trace_client = self._ensure_trace_client(run_id, "ChainError")
        if not trace_client:
            return

        span_id = self._run_id_to_span_id.get(run_id)
        span_type: SpanType = "chain"  # default when the enter entry cannot be found
        if span_id:
            try:
                if hasattr(trace_client, 'entries') and trace_client.entries:
                    enter_entry = next(
                        (e for e in reversed(trace_client.entries) if e.span_id == span_id and e.type == "enter"),
                        None,
                    )
                    if enter_entry:
                        span_type = enter_entry.span_type
                else:
                    self._log(f" WARNING: trace_client.entries not available for on_chain_error span_id={span_id}")
            except Exception as e:
                self._log(f" ERROR finding enter entry for span_id {span_id} in on_chain_error: {e}")

        self._end_span_tracking(trace_client, run_id, span_type=span_type, error=error)
    except Exception as e:
        tc_id_on_error = id(self._trace_client) if self._trace_client else 'None'
        self._log(f"{log_prefix} UNCAUGHT EXCEPTION in on_chain_error for run_id={run_id} (TraceClient ID: {tc_id_on_error}): {e}")

async def on_tool_start(self, serialized: Dict[str, Any], input_str: str, *, run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, inputs: Optional[Dict[str, Any]] = None, **kwargs: Any) -> Any:
    """Open a "tool" span for a tool run and record which tool ran.

    The span is named after ``serialized["name"]``. It falls back to
    "Unnamed Tool" when the key is missing, and to
    "Unknown Tool (Serialized=None)" when ``serialized`` is None or empty.

    Side effects:
      - The tool name is appended to ``self.executed_tools`` if not already
        present.
      - ``"<parent node>:<tool>"`` is appended to ``self.executed_node_tools``
        if not already present. When the parent run has no "enter" entry of
        span_type "chain", just the tool name is used.

    Exceptions raised inside the tracing logic are logged rather than
    re-raised.
    """
    handler_instance_id = id(self)
    log_prefix = f"{HANDLER_LOG_PREFIX} [Async Handler {handler_instance_id}]"
    name = serialized.get("name", "Unnamed Tool") if serialized else "Unknown Tool (Serialized=None)"
    try:
        trace_client = self._ensure_trace_client(run_id, name)
        if not trace_client:
            return

        combined_inputs = {'input_str': input_str, 'inputs': inputs, 'tags': tags, 'metadata': metadata, 'kwargs': kwargs, 'serialized': serialized}
        self._start_span_tracking(trace_client, run_id, parent_run_id, name, span_type="tool", inputs=combined_inputs)

        # --- Track executed tools ---
        if name not in self.executed_tools:
            self.executed_tools.append(name)

        # Qualify the tool with its parent node's name, but only when the
        # parent span has an "enter" entry whose span_type is "chain".
        parent_node_name = None
        if parent_run_id and parent_run_id in self._run_id_to_span_id:
            parent_span_id = self._run_id_to_span_id[parent_run_id]
            try:
                if hasattr(trace_client, 'entries') and trace_client.entries:
                    parent_enter_entry = next(
                        (
                            e for e in reversed(trace_client.entries)
                            if e.span_id == parent_span_id and e.type == "enter" and e.span_type == "chain"
                        ),
                        None,
                    )
                    if parent_enter_entry:
                        parent_node_name = parent_enter_entry.function
                else:
                    self._log(f" WARNING: trace_client.entries not available for parent node {parent_span_id}")
            except Exception as e:
                self._log(f" ERROR finding parent node name for tool start span_id {parent_span_id}: {e}")

        node_tool = f"{parent_node_name}:{name}" if parent_node_name else name
        if node_tool not in self.executed_node_tools:
            self.executed_node_tools.append(node_tool)
        self._log(f" Tracked node_tool: {node_tool}")
    except Exception as e:
        tc_id_on_error = id(self._trace_client) if self._trace_client else 'None'
        self._log(f"{log_prefix} UNCAUGHT EXCEPTION in on_tool_start for run_id={run_id} (TraceClient ID: {tc_id_on_error}): {e}")

async def on_tool_end(self, output: Any, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any) -> Any:
    """Close the "tool" span for ``run_id``, recording ``output`` and any extra kwargs.

    Exceptions raised inside the tracing logic are logged rather than
    re-raised.
    """
    handler_instance_id = id(self)
    log_prefix = f"{HANDLER_LOG_PREFIX} [Async Handler {handler_instance_id}]"
    try:
        trace_client = self._ensure_trace_client(run_id, "ToolEnd")
        if not trace_client:
            return
        outputs = {"output": output, "kwargs": kwargs}
        self._end_span_tracking(trace_client, run_id, span_type="tool", outputs=outputs)
    except Exception as e:
        tc_id_on_error = id(self._trace_client) if self._trace_client else 'None'
        self._log(f"{log_prefix} UNCAUGHT EXCEPTION in on_tool_end for run_id={run_id} (TraceClient ID: {tc_id_on_error}): {e}")

async def on_tool_error(self, error: BaseException, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any) -> Any:
    """Close the "tool" span for ``run_id``, recording ``error``.

    Exceptions raised inside the tracing logic are logged rather than
    re-raised.
    """
    handler_instance_id = id(self)
    log_prefix = f"{HANDLER_LOG_PREFIX} [Async Handler {handler_instance_id}]"
    try:
        trace_client = self._ensure_trace_client(run_id, "ToolError")
        if not trace_client:
            return
        self._end_span_tracking(trace_client, run_id, span_type="tool", error=error)
    except Exception as e:
        tc_id_on_error = id(self._trace_client) if self._trace_client else 'None'
        self._log(f"{log_prefix} UNCAUGHT EXCEPTION in on_tool_error for run_id={run_id} (TraceClient ID: {tc_id_on_error}): {e}")

async def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], *, run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, invocation_params: Optional[Dict[str, Any]] = None, options: Optional[Dict[str, Any]] = None, name: Optional[str] = None, **kwargs: Any) -> Any:
    """Open an "llm" span for a completion-style LLM run.

    The span is named ``name``, else ``serialized["name"]``, else "LLM Call".
    ``serialized`` may be None.

    The span inputs record:
      - the prompts;
      - the invocation params, falling back to the extra kwargs when none
        are given;
      - the options, tags and metadata;
      - the serialized model.

    Exceptions raised inside the tracing logic are logged rather than
    re-raised.
    """
    handler_instance_id = id(self)
    log_prefix = f"{HANDLER_LOG_PREFIX} [Async Handler {handler_instance_id}]"
    try:
        # `serialized` may be None; fall back to an empty mapping so the
        # name lookup cannot raise.
        llm_name = name or (serialized or {}).get("name", "LLM Call")
        trace_client = self._ensure_trace_client(run_id, llm_name)
        if not trace_client:
            return
        inputs = {'prompts': prompts, 'invocation_params': invocation_params or kwargs, 'options': options, 'tags': tags, 'metadata': metadata, 'serialized': serialized}
        self._start_span_tracking(trace_client, run_id, parent_run_id, llm_name, span_type="llm", inputs=inputs)
    except Exception as e:
        tc_id_on_error = id(self._trace_client) if self._trace_client else 'None'
        self._log(f"{log_prefix} UNCAUGHT EXCEPTION in on_llm_start for run_id={run_id} (TraceClient ID: {tc_id_on_error}): {e}")

async def on_llm_end(self, response: LLMResult, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any) -> Any:
    """Close the "llm" span for ``run_id``, attaching normalized token usage.

    ``response.llm_output`` is inspected when it is a dict:
      - An OpenAI-style ``token_usage`` dict (``prompt_tokens`` /
        ``completion_tokens`` / ``total_tokens``) is checked first.
      - Otherwise an Anthropic-style ``usage`` dict (``input_tokens`` /
        ``output_tokens``) is used.

    If any count is found, the span outputs gain a ``"usage"`` entry with the
    keys ``prompt_tokens``, ``completion_tokens`` and ``total_tokens``. A
    missing total is derived from the other two counts when both are present.

    Token-extraction errors are logged and the span is still closed. Any
    other exception inside the tracing logic is logged rather than re-raised.
    """
    handler_instance_id = id(self)
    log_prefix = f"{HANDLER_LOG_PREFIX} [Async Handler {handler_instance_id}]"

    try:
        trace_client = self._ensure_trace_client(run_id, "LLMEnd")
        if not trace_client:
            return

        outputs = {"response": response, "kwargs": kwargs}

        # --- Token usage extraction, normalized to prompt/completion/total ---
        prompt_tokens = None
        completion_tokens = None
        total_tokens = None
        try:
            llm_output = response.llm_output
            if llm_output and isinstance(llm_output, dict):
                if 'token_usage' in llm_output:
                    # OpenAI-style usage block
                    token_usage = llm_output.get('token_usage')
                    if token_usage and isinstance(token_usage, dict):
                        self._log(f" Extracted OpenAI token usage for run_id={run_id}: {token_usage}")
                        prompt_tokens = token_usage.get('prompt_tokens')
                        completion_tokens = token_usage.get('completion_tokens')
                        total_tokens = token_usage.get('total_tokens')
                elif 'usage' in llm_output:
                    # Anthropic-style usage block: input/output tokens, no total
                    token_usage = llm_output.get('usage')
                    if token_usage and isinstance(token_usage, dict):
                        self._log(f" Extracted Anthropic token usage for run_id={run_id}: {token_usage}")
                        prompt_tokens = token_usage.get('input_tokens')
                        completion_tokens = token_usage.get('output_tokens')
                        if prompt_tokens is None or completion_tokens is None:
                            self._log(f" Could not calculate total_tokens from Anthropic usage: input={prompt_tokens}, output={completion_tokens}")

                # Derive a missing total from its parts for both styles. The
                # Anthropic block never carries a total, and an OpenAI-style
                # block may omit it.
                if total_tokens is None and prompt_tokens is not None and completion_tokens is not None:
                    total_tokens = prompt_tokens + completion_tokens

                if prompt_tokens is not None or completion_tokens is not None or total_tokens is not None:
                    outputs['usage'] = {
                        'prompt_tokens': prompt_tokens,
                        'completion_tokens': completion_tokens,
                        'total_tokens': total_tokens,
                    }
                else:
                    self._log(f" Could not extract token usage structure from llm_output for run_id={run_id}")
            else:
                self._log(f" llm_output not available/dict for run_id={run_id}")
        except Exception as e:
            self._log(f" ERROR extracting token usage for run_id={run_id}: {e}")

        self._end_span_tracking(trace_client, run_id, span_type="llm", outputs=outputs)
    except Exception as e:
        tc_id_on_error = id(self._trace_client) if self._trace_client else 'None'
        self._log(f"{log_prefix} UNCAUGHT EXCEPTION in on_llm_end for run_id={run_id} (TraceClient ID: {tc_id_on_error}): {e}")

async def on_llm_error(self, error: BaseException, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any) -> Any:
    """Close the "llm" span for ``run_id``, recording ``error``.

    Exceptions raised inside the tracing logic are logged rather than
    re-raised.
    """
    handler_instance_id = id(self)
    log_prefix = f"{HANDLER_LOG_PREFIX} [Async Handler {handler_instance_id}]"
    try:
        trace_client = self._ensure_trace_client(run_id, "LLMError")
        if not trace_client:
            return
        self._end_span_tracking(trace_client, run_id, span_type="llm", error=error)
    except Exception as e:
        tc_id_on_error = id(self._trace_client) if self._trace_client else 'None'
        self._log(f"{log_prefix} UNCAUGHT EXCEPTION in on_llm_error for run_id={run_id} (TraceClient ID: {tc_id_on_error}): {e}")

async def on_chat_model_start(self, serialized: Dict[str, Any], messages: List[List[BaseMessage]], *, run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, invocation_params: Optional[Dict[str, Any]] = None, options: Optional[Dict[str, Any]] = None, name: Optional[str] = None, **kwargs: Any) -> Any:
    """Open an "llm" span for a chat-model run, tagging its name with the provider.

    The base span name is ``name``, else ``serialized["name"]``, else
    "ChatModel Call".

    Provider detection:
      - A provider is detected from key prefixes in ``serialized["secrets"]``
        or from substrings of the lower-cased name.
      - Checks run in the order OpenAI, Anthropic, Together, Google.
      - The first detected provider whose marker (e.g. "OPENAI_API_CALL") is
        not already in the name gets that marker appended.
      - ``serialized`` and its "secrets" entry may be None.

    The span inputs record:
      - the messages;
      - the invocation params, falling back to the extra kwargs when none
        are given;
      - the options, tags and metadata;
      - the serialized model.

    Exceptions raised inside the tracing logic, including name/provider
    detection, are logged rather than re-raised.
    """
    handler_instance_id = id(self)
    log_prefix = f"{HANDLER_LOG_PREFIX} [Async Handler {handler_instance_id}]"
    try:
        serialized_data = serialized or {}
        chat_model_name = name or serialized_data.get("name", "ChatModel Call")
        # Computed once instead of re-reading serialized["secrets"] for each
        # provider check; a None/missing "secrets" entry means no secrets.
        secret_keys = list((serialized_data.get('secrets') or {}).keys())
        lowered_name = chat_model_name.lower()

        is_openai = any(key.startswith('openai') for key in secret_keys) or 'openai' in lowered_name
        is_anthropic = any(key.startswith('anthropic') for key in secret_keys) or 'anthropic' in lowered_name or 'claude' in lowered_name
        is_together = any(key.startswith('together') for key in secret_keys) or 'together' in lowered_name
        is_google = any(key.startswith('google') for key in secret_keys) or 'google' in lowered_name or 'gemini' in lowered_name

        # First matching provider wins. A provider whose marker is already
        # present falls through to the next check.
        if is_openai and "OPENAI_API_CALL" not in chat_model_name:
            chat_model_name = f"{chat_model_name} OPENAI_API_CALL"
        elif is_anthropic and "ANTHROPIC_API_CALL" not in chat_model_name:
            chat_model_name = f"{chat_model_name} ANTHROPIC_API_CALL"
        elif is_together and "TOGETHER_API_CALL" not in chat_model_name:
            chat_model_name = f"{chat_model_name} TOGETHER_API_CALL"
        elif is_google and "GOOGLE_API_CALL" not in chat_model_name:
            chat_model_name = f"{chat_model_name} GOOGLE_API_CALL"

        trace_client = self._ensure_trace_client(run_id, chat_model_name)
        if not trace_client:
            return
        inputs = {'messages': messages, 'invocation_params': invocation_params or kwargs, 'options': options, 'tags': tags, 'metadata': metadata, 'serialized': serialized}
        # Chat models share the 'llm' span type with completion-style LLM calls.
        self._start_span_tracking(trace_client, run_id, parent_run_id, chat_model_name, span_type="llm", inputs=inputs)
    except Exception as e:
        tc_id_on_error = id(self._trace_client) if self._trace_client else 'None'
        self._log(f"{log_prefix} UNCAUGHT EXCEPTION in on_chat_model_start for run_id={run_id} (TraceClient ID: {tc_id_on_error}): {e}")

# --- Agent Methods (async versions) ---
async def on_agent_action(self, action: AgentAction, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any) -> Any:
    """Open an "agent_action" span named "AgentAction" for ``run_id``.

    The span inputs record the ``action`` and any extra kwargs. Exceptions
    raised inside the tracing logic are logged rather than re-raised.
    """
    handler_instance_id = id(self)
    log_prefix = f"{HANDLER_LOG_PREFIX} [Async Handler {handler_instance_id}]"
    try:
        trace_client = self._ensure_trace_client(run_id, "AgentAction")
        if not trace_client:
            return
        inputs = {"action": action, "kwargs": kwargs}
        # NOTE(review): agent callbacks may arrive with the run_id of the
        # enclosing agent chain run, possibly several times per run. Confirm
        # that _start_span_tracking copes with a run_id that already has a span.
        self._start_span_tracking(trace_client, run_id, parent_run_id, name="AgentAction", span_type="agent_action", inputs=inputs)
    except Exception as e:
        tc_id_on_error = id(self._trace_client) if self._trace_client else 'None'
        self._log(f'{log_prefix} UNCAUGHT EXCEPTION in on_agent_action for run_id={run_id} (TraceClient ID: {tc_id_on_error}): {e}')

async def on_agent_finish(self, finish: AgentFinish, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any) -> Any:
    """Close the span for ``run_id`` as an "agent_action" span, recording ``finish``.

    The span outputs record the ``finish`` object and any extra kwargs.
    Exceptions raised inside the tracing logic are logged rather than
    re-raised.
    """
    handler_instance_id = id(self)
    log_prefix = f"{HANDLER_LOG_PREFIX} [Async Handler {handler_instance_id}]"
    try:
        trace_client = self._ensure_trace_client(run_id, "AgentFinish")
        if not trace_client:
            return
        outputs = {"finish": finish, "kwargs": kwargs}
        # NOTE(review): assumes the open span for this run_id was started by
        # on_agent_action (hence span_type="agent_action"). It may instead be
        # the enclosing chain span; confirm.
        self._end_span_tracking(trace_client, run_id, span_type="agent_action", outputs=outputs)
    except Exception as e:
        tc_id_on_error = id(self._trace_client) if self._trace_client else 'None'
        self._log(f'{log_prefix} UNCAUGHT EXCEPTION in on_agent_finish for run_id={run_id} (TraceClient ID: {tc_id_on_error}): {e}')
418
+ outputs = {'return_values': finish.return_values, 'log': finish.log, 'messages': finish.messages, 'kwargs': kwargs}
419
+ self._end_span_tracking(trace_client, run_id, outputs=outputs)