judgeval 0.0.34__py3-none-any.whl → 0.0.36__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,337 +1,1999 @@
1
- from typing import Any, Dict, List, Optional, Sequence
1
+ from typing import Any, Dict, List, Optional, Sequence, Callable, TypedDict
2
2
  from uuid import UUID
3
3
  import time
4
4
  import uuid
5
- from contextvars import ContextVar
6
- from judgeval.common.tracer import TraceClient, TraceEntry, Tracer, SpanType
5
+ import traceback # For detailed error logging
6
+ import contextvars # <--- Import contextvars
7
+ from dataclasses import dataclass
8
+
9
+ from judgeval.common.tracer import TraceClient, TraceEntry, Tracer, SpanType, EvaluationConfig
10
+ from judgeval.data import Example # Import Example
11
+ from judgeval.scorers import AnswerRelevancyScorer, JudgevalScorer, APIJudgmentScorer # Import Scorer and base scorer types
7
12
 
8
13
  from langchain_core.language_models import BaseChatModel
9
14
  from langchain_huggingface import ChatHuggingFace
10
15
  from langchain_openai import ChatOpenAI
11
16
  from langchain_anthropic import ChatAnthropic
12
17
  from langchain_core.utils.function_calling import convert_to_openai_tool
13
- from langchain_core.callbacks import CallbackManager, BaseCallbackHandler
18
+ from langchain_core.callbacks import BaseCallbackHandler
19
+ from langchain_core.callbacks.base import AsyncCallbackHandler
14
20
  from langchain_core.agents import AgentAction, AgentFinish
15
21
  from langchain_core.outputs import LLMResult
16
- from langchain_core.tracers.context import register_configure_hook
17
22
  from langchain_core.messages.ai import AIMessage
18
23
  from langchain_core.messages.tool import ToolMessage
19
24
  from langchain_core.messages.base import BaseMessage
20
25
  from langchain_core.documents import Document
21
26
 
27
+ # --- Get context vars from tracer module ---
28
+ # Assuming tracer.py defines these and they are accessible
29
+ # If not, redefine them here or adjust import
30
+ from judgeval.common.tracer import current_span_var, current_trace_var # <-- Import current_trace_var
31
+
32
+ # --- Constants for Logging ---
33
+ HANDLER_LOG_PREFIX = "[JudgevalHandlerLog]"
34
+
35
+ # --- NEW __init__ ---
22
36
  class JudgevalCallbackHandler(BaseCallbackHandler):
37
+ """
38
+ LangChain Callback Handler using run_id/parent_run_id for hierarchy.
39
+ Manages its own internal TraceClient instance created upon first use.
40
+ Includes verbose logging and defensive checks.
41
+ """
42
+ # Make all properties ignored by LangChain's callback system
43
+ # to prevent unexpected serialization issues.
44
+ lc_serializable = False
45
+ lc_kwargs = {}
46
+
47
+ # --- NEW __init__ ---
23
48
  def __init__(self, tracer: Tracer):
49
+ # --- Enhanced Logging ---
50
+ # instance_id = id(self)
51
+ # # print(f"{HANDLER_LOG_PREFIX} *** Handler instance {instance_id} __init__ called. ***")
52
+ # --- End Enhanced Logging ---
24
53
  self.tracer = tracer
25
- self.previous_spans = [] # stack of previous spans
26
- self.created_trace = False
54
+ self._trace_client: Optional[TraceClient] = None
55
+ self._run_id_to_span_id: Dict[UUID, str] = {}
56
+ self._span_id_to_start_time: Dict[str, float] = {}
57
+ self._span_id_to_depth: Dict[str, int] = {}
58
+ self._run_id_to_context_token: Dict[UUID, contextvars.Token] = {}
59
+ self._root_run_id: Optional[UUID] = None
60
+ self._trace_saved: bool = False # Flag to prevent actions after trace is saved
61
+ self._run_id_to_start_inputs: Dict[UUID, Dict] = {} # <<< ADDED input storage
27
62
 
28
- # Attributes for users to access
29
- self.previous_node = None
30
- self.executed_node_tools = []
31
- self.executed_nodes = []
32
- self.executed_tools = []
63
+ # --- Token Count Accumulators ---
64
+ # self._current_prompt_tokens = 0
65
+ # self._current_completion_tokens = 0
66
+ # --- End Token Count Accumulators ---
33
67
 
34
- def start_span(self, name: str, span_type: SpanType = "span"):
35
- current_trace = self.tracer.get_current_trace()
36
- start_time = time.time()
68
+ self.executed_nodes: List[str] = []
69
+ self.executed_tools: List[str] = []
70
+ self.executed_node_tools: List[str] = []
71
+ # --- END NEW __init__ ---
37
72
 
38
- # Generate a unique ID for *this specific span invocation*
39
- span_id = str(uuid.uuid4())
40
-
41
- parent_span_id = current_trace.get_current_span()
42
- token = current_trace.set_current_span(span_id) # Set *this* span's ID as the current one
43
-
44
- current_depth = 0
45
- if parent_span_id and parent_span_id in current_trace._span_depths:
46
- current_depth = current_trace._span_depths[parent_span_id] + 1
47
-
48
- current_trace._span_depths[span_id] = current_depth # Store depth by span_id
49
- # Record span entry
50
- current_trace.add_entry(TraceEntry(
51
- type="enter",
52
- span_id=span_id,
53
- trace_id=current_trace.trace_id,
54
- parent_span_id=parent_span_id,
55
- function=name,
56
- depth=current_depth,
57
- message=name,
58
- created_at=start_time,
59
- span_type=span_type
60
- ))
61
-
62
- self.previous_spans.append(token)
63
- self._start_time = start_time
64
-
65
- def end_span(self, span_type: SpanType = "span"):
66
- current_trace = self.tracer.get_current_trace()
67
- duration = time.time() - self._start_time
68
- span_id = current_trace.get_current_span()
69
- exit_depth = current_trace._span_depths.get(span_id, 0) # Get depth using this span's ID
70
-
71
- # Record span exit
72
- current_trace.add_entry(TraceEntry(
73
- type="exit",
74
- span_id=span_id,
75
- trace_id=current_trace.trace_id,
76
- depth=exit_depth,
77
- created_at=time.time(),
78
- duration=duration,
79
- span_type=span_type
80
- ))
81
- current_trace.reset_current_span(self.previous_spans.pop())
82
- if exit_depth == 0:
83
- # Save the trace if we are the root, this is when users dont use any @observe decorators
84
- trace_id, trace_data = current_trace.save(overwrite=True)
85
- self._trace_id = trace_id
86
- current_trace = None
87
-
88
- def on_retriever_start(
89
- self,
90
- serialized: Optional[dict[str, Any]],
91
- query: str,
92
- *,
93
- run_id: UUID,
94
- parent_run_id: Optional[UUID] = None,
95
- tags: Optional[list[str]] = None,
96
- metadata: Optional[dict[str, Any]] = None,
97
- **kwargs: Any,
98
- ) -> Any:
99
- name = "RETRIEVER_CALL"
100
- if serialized and "name" in serialized:
101
- name = f"RETRIEVER_{serialized['name'].upper()}"
102
- current_trace = self.tracer.get_current_trace()
103
- self.start_span(name, span_type="retriever")
104
- current_trace.record_input({
105
- 'query': query,
106
- 'tags': tags,
107
- 'metadata': metadata,
108
- 'kwargs': kwargs
109
- })
110
-
111
- def on_retriever_end(
112
- self,
113
- documents: Sequence[Document],
114
- *,
115
- run_id: UUID,
116
- parent_run_id: Optional[UUID] = None,
117
- **kwargs: Any
118
- ) -> Any:
119
- # Process the retrieved documents into a format suitable for logging
120
- current_trace = self.tracer.get_current_trace()
121
- doc_summary = []
122
- for i, doc in enumerate(documents):
123
- # Extract key information from each document
124
- doc_data = {
125
- "index": i,
126
- "page_content": doc.page_content[:100] + "..." if len(doc.page_content) > 100 else doc.page_content,
127
- "metadata": doc.metadata
128
- }
129
- doc_summary.append(doc_data)
130
-
131
- # Record the document data
132
- current_trace.record_output({
133
- "document_count": len(documents),
134
- "documents": doc_summary
135
- })
136
-
137
- # End the retriever span
138
- self.end_span(span_type="retriever")
139
-
140
- def on_chain_start(
141
- self,
142
- serialized: Dict[str, Any],
143
- inputs: Dict[str, Any],
144
- *,
145
- run_id: UUID,
146
- parent_run_id: Optional[UUID] = None,
147
- tags: Optional[List[str]] = None,
148
- metadata: Optional[Dict[str, Any]] = None,
149
- **kwargs: Any
150
- ) -> None:
151
- # If the user doesnt use any @observe decorators, the first action in LangGraph workflows seems tohave this attribute, so we intialize our trace client here
152
- current_trace = self.tracer.get_current_trace()
153
- if kwargs.get('name') == 'LangGraph':
154
- if not current_trace:
155
- self.created_trace = True
156
- trace_id = str(uuid.uuid4())
157
- project = self.tracer.project_name
158
- trace = TraceClient(self.tracer, trace_id, "Langgraph", project_name=project, overwrite=False, rules=self.tracer.rules, enable_monitoring=self.tracer.enable_monitoring, enable_evaluations=self.tracer.enable_evaluations)
159
- self.tracer.set_current_trace(trace)
160
- self.start_span("LangGraph", span_type="Main Function")
161
-
162
- node = metadata.get("langgraph_node")
163
- if node != None and node != self.previous_node:
164
- self.start_span(node, span_type="node")
165
- self.executed_node_tools.append(node)
166
- self.executed_nodes.append(node)
167
- current_trace.record_input({
168
- 'args': inputs,
169
- 'kwargs': kwargs
170
- })
171
- self.previous_node = node
172
-
173
- def on_chain_end(
73
+ # --- MODIFIED _ensure_trace_client ---
74
def _ensure_trace_client(self, run_id: UUID, parent_run_id: Optional[UUID], event_name: str) -> Optional[TraceClient]:
    """Return this handler's TraceClient, building it on the first call.

    On the first call, a TraceClient named *event_name* is constructed with:
    a fresh uuid4 trace id, the tracer's project name, rules and
    monitoring/evaluation flags, and ``overwrite=False``.

    Also on the first call, *run_id* is remembered as the root run,
    ``_trace_saved`` is cleared, and the client is published as
    ``tracer._active_trace_client``.

    Every later call returns the cached client unchanged, including after the
    trace has been saved. *parent_run_id* is accepted but not used.

    Returns:
        The client. Returns None if construction raised (the client and root
        run id are then cleared), or if the constructed client is falsy.
    """
    cached = self._trace_client
    if cached:
        return cached

    # These two lines run outside the try, so errors here propagate to the caller.
    new_trace_id = str(uuid.uuid4())
    project_name = self.tracer.project_name

    try:
        self._trace_client = TraceClient(
            self.tracer,
            new_trace_id,
            event_name,
            project_name=project_name,
            overwrite=False,
            rules=self.tracer.rules,
            enable_monitoring=self.tracer.enable_monitoring,
            enable_evaluations=self.tracer.enable_evaluations,
        )
        if not self._trace_client:
            return None

        # The first run seen becomes the (tentative) root run.
        self._root_run_id = run_id
        self._trace_saved = False
        # Expose the client on the tracer as well.
        self.tracer._active_trace_client = self._trace_client
        return self._trace_client
    except Exception:
        self._trace_client = None
        self._root_run_id = None
        return None
122
+ # --- END MODIFIED _ensure_trace_client ---
123
+
124
+ def _log(self, message: str):
125
+ """Helper for consistent logging format."""
126
+ pass
127
+
128
def _start_span_tracking(
    self,
    trace_client: TraceClient,  # Expect a valid client
    run_id: UUID,
    parent_run_id: Optional[UUID],
    name: str,
    span_type: SpanType = "span",
    inputs: Optional[Dict[str, Any]] = None
):
    """Open a span for *run_id* and record exactly one 'enter' entry for it.

    Nesting: if *parent_run_id* is tracked, the new span's parent is that run's
    span. Its depth is the parent's depth + 1 when the parent depth is known,
    otherwise 0. Untracked or absent parents give a parentless depth-0 span.

    Args:
        trace_client: Client the entries are written to. If falsy, nothing is done.
        run_id: LangChain run id of the span being opened.
        parent_run_id: LangChain run id of the enclosing run, if any.
        name: Recorded as both the entry's ``function`` and ``message``.
        span_type: Span type recorded on the entry. For ``"chain"`` the new
            span id is also set on ``current_span_var``, and the returned token
            is stored under *run_id*.
        inputs: When non-empty, recorded once as an 'input' entry.
    """
    if not trace_client:
        self._log(f"trace_client is None for name='{name}', run_id={run_id}; span not started.")
        return

    start_time = time.time()
    span_id = str(uuid.uuid4())

    parent_span_id: Optional[str] = None
    current_depth = 0
    if parent_run_id and parent_run_id in self._run_id_to_span_id:
        parent_span_id = self._run_id_to_span_id[parent_run_id]
        parent_depth = self._span_id_to_depth.get(parent_span_id)
        if parent_depth is not None:
            current_depth = parent_depth + 1

    self._run_id_to_span_id[run_id] = span_id
    self._span_id_to_start_time[span_id] = start_time
    self._span_id_to_depth[span_id] = current_depth

    # One 'enter' entry and at most one 'input' entry per span; writing them
    # more than once duplicates the span in the saved trace.
    try:
        trace_client.add_entry(TraceEntry(
            type="enter", span_id=span_id, trace_id=trace_client.trace_id,
            parent_span_id=parent_span_id, function=name, depth=current_depth,
            message=name, created_at=start_time, span_type=span_type,
        ))
    except Exception as e:
        # Best-effort: tracing must not break the user's chain.
        self._log(f" ERROR adding 'enter' entry for span_id {span_id}: {e}")

    if inputs:
        self._record_input_data(trace_client, run_id, inputs)

    # Only chain (node) spans become the current span in the context variable.
    if span_type == "chain":
        try:
            self._run_id_to_context_token[run_id] = current_span_var.set(span_id)
        except Exception as e:
            self._log(f" ERROR setting current_span_var for run_id {run_id}: {e}")
+
225
+ # --- NEW _end_span_tracking ---
226
def _end_span_tracking(
    self,
    trace_client: TraceClient,  # Expect a valid client
    run_id: UUID,
    span_type: SpanType = "span",
    outputs: Optional[Any] = None,
    error: Optional[BaseException] = None
):
    """Close the span tracked for *run_id*; if it is the root run, save the trace.

    Steps:
      1. Record the outcome through ``_record_output_data``. *error* takes
         precedence over *outputs*; *outputs* is skipped when None.
      2. If the span is tracked: add an 'exit' entry carrying the duration and
         the function name of the span's 'enter' entry. Then drop the span's
         start-time/depth bookkeeping and its stored context token. The
         run_id -> span_id mapping is left in place.
      3. If *run_id* is the root run: save ``self._trace_client`` at most once,
         using the client's own ``overwrite`` setting. Then always reset the
         root run id and the stored inputs, and detach
         ``tracer._active_trace_client`` if it is this handler's client.

    An untracked run is ignored unless it is the root run, whose save/cleanup
    still runs. A falsy *trace_client* makes this a no-op.
    """
    if not trace_client:
        self._log(f"trace_client is None for run_id={run_id}; span not ended.")
        return

    span_id = self._run_id_to_span_id.get(run_id)
    if span_id is None and run_id != self._root_run_id:
        # Nothing was started for this run and it is not the root: nothing to do.
        return

    start_time = self._span_id_to_start_time.get(span_id) if span_id else None
    depth = self._span_id_to_depth.get(span_id, 0) if span_id else 0
    duration = time.time() - start_time if start_time is not None else None

    if error:
        self._record_output_data(trace_client, run_id, error)
    elif outputs is not None:
        self._record_output_data(trace_client, run_id, outputs)

    if span_id:
        entry_function_name = "unknown"
        try:
            if getattr(trace_client, "entries", None):
                entry_function_name = next(
                    (e.function for e in reversed(trace_client.entries)
                     if e.span_id == span_id and e.type == "enter"),
                    "unknown",
                )
        except Exception as e:
            self._log(f" ERROR finding function name for exit span_id {span_id}: {e}")

        try:
            trace_client.add_entry(TraceEntry(
                type="exit", span_id=span_id, trace_id=trace_client.trace_id,
                depth=depth, created_at=time.time(), duration=duration,
                span_type=span_type, function=entry_function_name,
            ))
        except Exception as e:
            self._log(f" ERROR adding 'exit' entry for span_id {span_id}: {e}")

        self._span_id_to_start_time.pop(span_id, None)
        self._span_id_to_depth.pop(span_id, None)
        # The token is discarded without current_span_var.reset(token).
        # NOTE(review): presumably because the end callback may run in a different
        # contextvars Context, where reset() raises ValueError -- confirm.
        self._run_id_to_context_token.pop(run_id, None)

    if run_id != self._root_run_id:
        return

    # Root run finished: save at most once, then reset per-run state whether or
    # not the save succeeded.
    try:
        self._log(f"Root run {run_id} finished. Attempting to save trace...")
        if self._trace_client and not self._trace_saved:
            try:
                trace_id, _ = self._trace_client.save(overwrite=self._trace_client.overwrite)
                self._trace_saved = True
                self._log(f"Trace {trace_id} successfully saved.")
            except Exception as e:
                self._log(f"ERROR saving trace {self._trace_client.trace_id}: {e}")
        elif self._trace_client:
            self._log(f" Trace {self._trace_client.trace_id} already saved. Skipping save.")
        else:
            self._log(f" WARNING: Root run {run_id} ended, but trace client was None. Cannot save trace.")
    finally:
        self._root_run_id = None
        self._run_id_to_start_inputs = {}
        # Only detach the tracer's active client if it is the one this handler owns.
        if self.tracer._active_trace_client == self._trace_client:
            self.tracer._active_trace_client = None
        self._log(f" Cleanup complete for root run {run_id}.")
391
+
392
def _record_input_data(self,
                       trace_client: TraceClient,
                       run_id: UUID,
                       inputs: Dict[str, Any]):
    """Append an 'input' TraceEntry carrying *inputs* to the span of *run_id*.

    The entry copies function name, span type and parent span id from the
    span's most recent 'enter' entry. If no such entry can be found, it falls
    back to "unknown", "span" and None.

    This is a no-op when *run_id* is not tracked or *trace_client* is falsy.
    Failures while building or adding the entry are logged, not raised.
    """
    if run_id not in self._run_id_to_span_id:
        return
    if not trace_client:
        return

    span_id = self._run_id_to_span_id[run_id]
    depth = self._span_id_to_depth.get(span_id, 0)

    function_name = "unknown"
    span_type: SpanType = "span"
    parent_span_id = None
    try:
        # One scan of the (growing) entries list supplies all three fields.
        enter_entry = next(
            (e for e in reversed(trace_client.entries)
             if e.span_id == span_id and e.type == "enter"),
            None,
        )
        if enter_entry is not None:
            function_name = enter_entry.function
            span_type = enter_entry.span_type
            parent_span_id = enter_entry.parent_span_id
    except Exception as e:
        self._log(f" ERROR finding enter entry for input span_id {span_id}: {e}")

    try:
        trace_client.add_entry(TraceEntry(
            type="input",
            span_id=span_id,
            trace_id=trace_client.trace_id,
            parent_span_id=parent_span_id,
            function=function_name,
            depth=depth,
            message=f"Input to {function_name}",
            created_at=time.time(),
            inputs=inputs,
            span_type=span_type,
        ))
    except Exception as e:
        self._log(f" ERROR adding 'input' entry for span_id {span_id}: {e}")
444
+
445
def _record_output_data(self,
                        trace_client: TraceClient,
                        run_id: UUID,
                        output: Any):
    """Append an 'output' TraceEntry carrying *output* to the span of *run_id*.

    *output* is stored as given, not converted. Function name, span type and
    parent span id come from the span's most recent 'enter' entry, falling
    back to "unknown", "span" and None.

    This is a no-op when *run_id* is not tracked or *trace_client* is falsy.
    Failures while building or adding the entry are logged, not raised.
    """
    if run_id not in self._run_id_to_span_id:
        return
    if not trace_client:
        return

    span_id = self._run_id_to_span_id[run_id]
    depth = self._span_id_to_depth.get(span_id, 0)

    function_name = "unknown"
    span_type: SpanType = "span"
    parent_span_id = None
    try:
        # One scan of the (growing) entries list supplies all three fields.
        enter_entry = next(
            (e for e in reversed(trace_client.entries)
             if e.span_id == span_id and e.type == "enter"),
            None,
        )
        if enter_entry is not None:
            function_name = enter_entry.function
            span_type = enter_entry.span_type
            parent_span_id = enter_entry.parent_span_id
    except Exception as e:
        self._log(f" ERROR finding enter entry for output span_id {span_id}: {e}")

    try:
        trace_client.add_entry(TraceEntry(
            type="output",
            span_id=span_id,
            trace_id=trace_client.trace_id,
            parent_span_id=parent_span_id,
            function=function_name,
            depth=depth,
            message=f"Output from {function_name}",
            created_at=time.time(),
            output=output,
            span_type=span_type,
        ))
        self._log(f" Added 'output' entry directly for span_id={span_id}")
    except Exception as e:
        self._log(f" ERROR adding 'output' entry directly for span_id {span_id}: {e}")
496
+
497
def _record_error(self,
                  trace_client: TraceClient,
                  run_id: UUID,
                  error: Any):
    """Append an 'error' TraceEntry for the span of *run_id*.

    The error is stored as ``str(error)``. Function name, span type and parent
    span id come from the span's most recent 'enter' entry, falling back to
    "unknown", "span" and None.

    This is a no-op when *run_id* is not tracked or *trace_client* is falsy.
    Failures while building or adding the entry are logged, not raised.
    """
    if run_id not in self._run_id_to_span_id:
        return
    if not trace_client:
        return

    span_id = self._run_id_to_span_id[run_id]
    depth = self._span_id_to_depth.get(span_id, 0)

    function_name = "unknown"
    span_type: SpanType = "span"
    parent_span_id = None
    try:
        # One scan of the (growing) entries list supplies all three fields.
        enter_entry = next(
            (e for e in reversed(trace_client.entries)
             if e.span_id == span_id and e.type == "enter"),
            None,
        )
        if enter_entry is not None:
            function_name = enter_entry.function
            span_type = enter_entry.span_type
            parent_span_id = enter_entry.parent_span_id
    except Exception as e:
        self._log(f" ERROR finding enter entry for error span_id {span_id}: {e}")

    try:
        trace_client.add_entry(TraceEntry(
            type="error",
            span_id=span_id,
            trace_id=trace_client.trace_id,
            parent_span_id=parent_span_id,
            function=function_name,
            depth=depth,
            message=f"Error in {function_name}",
            created_at=time.time(),
            error=str(error),  # stringified so the entry stays serializable
            span_type=span_type,
        ))
    except Exception as e:
        self._log(f" ERROR adding 'error' entry for span_id {span_id}: {e}")
549
+
550
+ # --- Callback Methods ---
551
+ # Each method now ensures the trace client exists before proceeding
552
+
553
def on_retriever_start(self, serialized: Dict[str, Any], query: str, *, run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, **kwargs: Any) -> Any:
    """Open a ``retriever`` span named ``RETRIEVER_<NAME>`` for a retriever run.

    The recorded inputs are the query, tags, metadata, extra kwargs and the
    serialized retriever description. Any exception is caught and logged so
    tracing never breaks the retrieval itself.
    """
    log_prefix = f"{HANDLER_LOG_PREFIX} [Handler {id(self)}]"
    if serialized:
        # `or` rather than a .get() default: an explicit {'name': None} used to
        # make `.upper()` raise, so the retriever span was silently never opened.
        serialized_name = serialized.get('name') or 'Unknown'
    else:
        serialized_name = "Unknown (Serialized=None)"

    try:
        name = f"RETRIEVER_{serialized_name.upper()}"
        trace_client = self._ensure_trace_client(run_id, parent_run_id, name)
        if not trace_client:
            return

        inputs = {'query': query, 'tags': tags, 'metadata': metadata, 'kwargs': kwargs, 'serialized': serialized}
        self._start_span_tracking(trace_client, run_id, parent_run_id, name, span_type="retriever", inputs=inputs)
    except Exception as e:
        tc_id_on_error = id(self._trace_client) if self._trace_client else 'None'
        self._log(f"{log_prefix} UNCAUGHT EXCEPTION in on_retriever_start for run_id={run_id} (TraceClient ID: {tc_id_on_error}): {e}")
573
+
574
def on_retriever_end(self, documents: Sequence[Document], *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any) -> Any:
    """Close the ``retriever`` span for ``run_id`` with a summary of the documents.

    Each document is summarised as its index, its ``page_content`` truncated to
    100 characters (with ``"..."`` appended when cut) and its metadata. Any
    exception is caught and logged.
    """
    prefix = f"{HANDLER_LOG_PREFIX} [Handler {id(self)}]"

    try:
        client = self._ensure_trace_client(run_id, parent_run_id, "RetrieverEnd")
        if not client:
            return

        summaries = []
        for idx, doc in enumerate(documents):
            text = doc.page_content
            preview = text[:100] + "..." if len(text) > 100 else text
            summaries.append({"index": idx, "page_content": preview, "metadata": doc.metadata})

        self._end_span_tracking(
            client,
            run_id,
            span_type="retriever",
            outputs={"document_count": len(documents), "documents": summaries, "kwargs": kwargs},
        )
    except Exception as exc:
        client_id = id(self._trace_client) if self._trace_client else 'None'
        self._log(f"{prefix} UNCAUGHT EXCEPTION in on_retriever_end for run_id={run_id} (TraceClient ID: {client_id}): {exc}")
593
+
594
def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any], *, run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, **kwargs: Any) -> None:
    """Open a ``chain`` span for a chain / LangGraph node run.

    Span naming priority:
      1. ``metadata["langgraph_node"]``. The node is also appended once to
         ``self.executed_nodes``.
      2. ``"LangGraph"`` when ``kwargs["name"] == "LangGraph"`` and the run has no parent.
      3. ``serialized["name"]``, falling back to ``"Unknown Chain"``.

    For the root run, the trace is renamed to the span name if they differ. The
    raw ``inputs`` are stored in ``self._run_id_to_start_inputs`` for later use.
    Exceptions raised after naming are caught and logged.
    """
    log_prefix = f"{HANDLER_LOG_PREFIX} [Handler {id(self)}]"
    serialized_name = serialized.get('name') if serialized else "Unknown (Serialized=None)"

    span_type: SpanType = "chain"
    node_name = metadata.get("langgraph_node") if metadata else None
    # A run without a parent is treated as a candidate root event.
    is_root_event = parent_run_id is None

    if node_name:
        name = node_name
        self._log(f" LangGraph Node Start: '{name}', run_id={run_id}")
        if name not in self.executed_nodes:
            self.executed_nodes.append(name)
    elif kwargs.get('name') == 'LangGraph' and is_root_event:
        name = "LangGraph"
        self._log(f" LangGraph Root Start Detected (kwargs): run_id={run_id}")
    else:
        name = serialized_name or "Unknown Chain"

    try:
        trace_client = self._ensure_trace_client(run_id, parent_run_id, name)
        if not trace_client:
            return

        # The trace may have been created under a provisional name; make the
        # root run's span name the trace name.
        if is_root_event and run_id == self._root_run_id and trace_client.name != name:
            self._log(f" Updating trace name from '{trace_client.name}' to '{name}' for root run {run_id}")
            trace_client.name = name

        combined_inputs = {'inputs': inputs, 'tags': tags, 'metadata': metadata, 'kwargs': kwargs, 'serialized': serialized}
        self._start_span_tracking(trace_client, run_id, parent_run_id, name, span_type=span_type, inputs=combined_inputs)

        # Keep the raw inputs around for evaluation when the chain ends.
        self._run_id_to_start_inputs[run_id] = inputs
        self._log(f" Stored inputs for run_id {run_id}")
    except Exception as err:
        client_id = id(self._trace_client) if self._trace_client else 'None'
        self._log(f"{log_prefix} UNCAUGHT EXCEPTION in on_chain_start for run_id={run_id} (TraceClient ID: {client_id}): {err}")
644
+
645
+
646
def on_chain_end(self, outputs: Dict[str, Any], *, run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, **kwargs: Any) -> Any:
    """Close the chain span for ``run_id``; evaluate and save the trace where requested.

    Processing order:
      1. Resolve the trace client and the span id tracked for ``run_id``.
         Untracked runs are ignored unless ``run_id`` is the root run.
      2. Demo path: if ``outputs`` has a ``"recommendations"`` key, submit an
         ``AnswerRelevancyScorer(threshold=0.5)`` evaluation for this span.
      3. End span tracking, using the span type from the span's ``enter`` entry
         (default ``"chain"``).
      4. Root run only:
         - save the trace once;
         - clear ``tracer._active_trace_client`` if it is this trace;
         - reset ``_root_run_id`` and the stored start inputs.
      5. If ``outputs["_judgeval_eval"]`` is an ``EvaluationConfig``, or a dict
         with ``scorers`` and ``example`` keys, submit it for this span.

    Args:
        outputs: Whatever the chain returned. Usually a dict. Non-dict values are
            still recorded on the span but skip both evaluation paths.
        run_id: LangChain run id of the finishing chain.
        parent_run_id: LangChain run id of the parent run, if any.
        tags: Run tags, recorded with the outputs.
        **kwargs: Extra callback arguments, recorded with the outputs.

    Every exception is caught and logged via ``self._log``, so tracing failures
    never propagate into the traced application.
    """
    log_prefix = f"{HANDLER_LOG_PREFIX} [Handler {id(self)}]"
    instance_id = id(self)

    try:
        trace_client = self._ensure_trace_client(run_id, parent_run_id, "ChainEnd")
        if not trace_client:
            return

        span_id = self._run_id_to_span_id.get(run_id)
        if not span_id:
            self._log(f" WARNING: No span_id found for run_id {run_id} in on_chain_end.")
            if run_id != self._root_run_id:
                return  # Nothing was tracked for this run and it is not the root.
            self._log(f" Run ID {run_id} matches root. Proceeding to _end_span_tracking for potential save.")

        # Find the span's 'enter' entry once. It supplies both the span type used
        # to close the span and the node name used in evaluation log messages
        # (previously the same search ran three times).
        enter_entry = None
        if span_id:
            try:
                if hasattr(trace_client, 'entries') and trace_client.entries:
                    enter_entry = next((e for e in reversed(trace_client.entries) if e.span_id == span_id and e.type == "enter"), None)
                else:
                    self._log(f" WARNING: trace_client.entries empty/missing for on_chain_end span_id={span_id}")
            except Exception as e:
                self._log(f" ERROR finding enter entry for span_id {span_id} in on_chain_end: {e}")
        end_span_type: SpanType = enter_entry.span_type if enter_entry else "chain"
        node_name = enter_entry.function if enter_entry else "unknown_node"

        # Runnables may return any value; only dicts can carry the evaluation keys.
        # `in` on None/ints raises and on strings does substring matching. Either
        # used to abort this handler before the span was closed, and for the root
        # run before the trace was saved.
        output_map: Dict[str, Any] = outputs if isinstance(outputs, dict) else {}

        # NOTE(review): demo-specific evaluation with a placeholder input and model
        # name. It fires for any chain whose outputs contain "recommendations".
        # Kept for backward compatibility; consider removing or gating it.
        if "recommendations" in output_map:
            if span_id:
                self._log(f"[Async Handler {instance_id}] Chain end for run_id {run_id} (span_id={span_id}) identified as recommendation node. Attempting evaluation.")
                recommendations_output = output_map.get("recommendations")
                if recommendations_output:
                    self._log(f"[Async Handler {instance_id}] Submitting evaluation for span_id={span_id}")
                    try:
                        # Built inside the try so a bad payload cannot prevent the
                        # span from being closed below.
                        eval_example = Example(
                            input="Unknown Input",  # TODO: derive the real user prompt
                            actual_output=recommendations_output,
                        )
                        trace_client.async_evaluate(
                            scorers=[AnswerRelevancyScorer(threshold=0.5)],
                            example=eval_example,
                            model="gpt-4",  # TODO: resolve the model name dynamically
                            span_id=span_id,
                        )
                        self._log(f"[Async Handler {instance_id}] Evaluation submitted successfully for span_id={span_id}.")
                    except Exception as eval_e:
                        self._log(f"[Async Handler {instance_id}] ERROR submitting evaluation for span_id={span_id}: {eval_e}")
                else:
                    self._log(f"[Async Handler {instance_id}] Skipping evaluation for run_id {run_id} (span_id={span_id}): Missing recommendations output.")
            else:
                self._log(f"[Async Handler {instance_id}] Skipping evaluation for run_id {run_id}: Span ID not found.")

        combined_outputs = {"outputs": outputs, "tags": tags, "kwargs": kwargs}
        self._end_span_tracking(trace_client, run_id, span_type=end_span_type, outputs=combined_outputs)

        if run_id == self._root_run_id:
            self._log(f"Root run {run_id} finished. Attempting to save trace...")
            if not self._trace_saved:
                try:
                    trace_id_saved, _ = trace_client.save(overwrite=trace_client.overwrite)
                    self._trace_saved = True
                    self._log(f"Trace {trace_id_saved} successfully saved.")
                    # Only detach the tracer's active client after a successful save.
                    if self.tracer._active_trace_client == trace_client:
                        self.tracer._active_trace_client = None
                        self._log("Reset active_trace_client on Tracer.")
                except Exception as e:
                    self._log(f"ERROR saving trace {trace_client.trace_id}: {e}")
            else:
                self._log(f"Trace {trace_client.trace_id} already saved. Skipping save for root run {run_id}.")

            self._root_run_id = None
            self._run_id_to_start_inputs = {}
            self._log(f"Reset root run ID and input storage for handler {instance_id}.")

        # --- Evaluation requested through outputs["_judgeval_eval"] ---
        if "_judgeval_eval" in output_map:
            if not span_id:
                self._log(f"{log_prefix} WARNING: Found _judgeval_eval in outputs, but span_id for run_id {run_id} was not found. Cannot submit evaluation.")
            else:
                eval_config: Optional[EvaluationConfig] = None
                raw_eval_config = output_map.get("_judgeval_eval")
                if isinstance(raw_eval_config, EvaluationConfig):
                    eval_config = raw_eval_config
                    self._log(f"{log_prefix} Found valid EvaluationConfig in outputs for node='{node_name}'.")
                elif isinstance(raw_eval_config, dict):
                    try:
                        if "scorers" in raw_eval_config and "example" in raw_eval_config:
                            example_data = raw_eval_config["example"]
                            reconstructed_example = Example(**example_data) if isinstance(example_data, dict) else example_data
                            if isinstance(reconstructed_example, Example):
                                eval_config = EvaluationConfig(
                                    scorers=raw_eval_config["scorers"],
                                    example=reconstructed_example,
                                    model=raw_eval_config.get("model"),
                                    log_results=raw_eval_config.get("log_results", True),
                                )
                                self._log(f"{log_prefix} Reconstructed EvaluationConfig from dict in outputs for node='{node_name}'.")
                            else:
                                self._log(f"{log_prefix} Could not reconstruct Example from dict in _judgeval_eval for node='{node_name}'. Skipping evaluation.")
                        else:
                            self._log(f"{log_prefix} Dict in _judgeval_eval missing required keys ('scorers', 'example') for node='{node_name}'. Skipping evaluation.")
                    except Exception as recon_e:
                        self._log(f"{log_prefix} ERROR attempting to reconstruct EvaluationConfig from dict for node='{node_name}': {recon_e}")
                else:
                    self._log(f"{log_prefix} Found '_judgeval_eval' key in outputs for node='{node_name}', but it wasn't an EvaluationConfig object or reconstructable dict. Skipping evaluation.")

                if eval_config:
                    self._log(f"{log_prefix} Submitting evaluation for span_id={span_id}")
                    try:
                        # Tie the example to this trace unless the caller already did.
                        if not getattr(eval_config.example, 'trace_id', None):
                            eval_config.example.trace_id = trace_client.trace_id
                            self._log(f"{log_prefix} Set trace_id={trace_client.trace_id} on evaluation example.")
                        trace_client.async_evaluate(
                            scorers=eval_config.scorers,
                            example=eval_config.example,
                            model=eval_config.model,
                            log_results=eval_config.log_results,
                            span_id=span_id,
                        )
                        self._log(f"{log_prefix} Evaluation submitted successfully for span_id={span_id}.")
                    except Exception as eval_e:
                        self._log(f"{log_prefix} ERROR submitting evaluation for span_id={span_id}: {eval_e}")

    except Exception as e:
        tc_id_on_error = id(self._trace_client) if self._trace_client else 'None'
        self._log(f"{log_prefix} UNCAUGHT EXCEPTION in on_chain_end for run_id={run_id} (TraceClient ID: {tc_id_on_error}): {e}")
842
+
843
def on_chain_error(self, error: BaseException, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any) -> Any:
    """Close the span for a failed chain run, recording ``error``.

    The span type is taken from the span's ``enter`` entry (default ``"chain"``).
    An untracked run is ignored unless it is the root run, which is still passed
    to ``_end_span_tracking``. Any exception is caught and logged.
    """
    prefix = f"{HANDLER_LOG_PREFIX} [Handler {id(self)}]"

    try:
        client = self._ensure_trace_client(run_id, parent_run_id, "ChainError")
        if not client:
            return

        kind: SpanType = "chain"
        tracked_span = self._run_id_to_span_id.get(run_id)
        if not tracked_span:
            self._log(f" WARNING: No span_id found for run_id {run_id} in on_chain_error.")
            # Only the root run proceeds without a tracked span, so that
            # _end_span_tracking can perform its root cleanup.
            if run_id != self._root_run_id:
                return
        else:
            try:
                if not (hasattr(client, 'entries') and client.entries):
                    self._log(f" WARNING: trace_client.entries empty/missing for on_chain_error span_id={tracked_span}")
                else:
                    opening = next((entry for entry in reversed(client.entries) if entry.span_id == tracked_span and entry.type == "enter"), None)
                    if opening:
                        kind = opening.span_type
            except Exception as lookup_err:
                self._log(f" ERROR finding enter entry for span_id {tracked_span} in on_chain_error: {lookup_err}")

        self._end_span_tracking(client, run_id, span_type=kind, error=error)
    except Exception as exc:
        client_id = id(self._trace_client) if self._trace_client else 'None'
        self._log(f"{prefix} UNCAUGHT EXCEPTION in on_chain_error for run_id={run_id} (TraceClient ID: {client_id}): {exc}")
876
+
877
def on_tool_start(self, serialized: Dict[str, Any], input_str: str, *, run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, inputs: Optional[Dict[str, Any]] = None, **kwargs: Any) -> Any:
    """Open a ``tool`` span and record which tool ran under which node.

    The tool name is added once to ``self.executed_tools``. ``"<node>:<tool>"`` is
    added once to ``self.executed_node_tools`` when the parent run's ``enter``
    entry is a ``chain`` span; otherwise the bare tool name is added. Any
    exception is caught and logged.
    """
    prefix = f"{HANDLER_LOG_PREFIX} [Handler {id(self)}]"
    if serialized:
        name = serialized.get("name", "Unnamed Tool")
    else:
        name = "Unknown Tool (Serialized=None)"

    try:
        client = self._ensure_trace_client(run_id, parent_run_id, name)
        if not client:
            return

        recorded_inputs = {
            'input_str': input_str,
            'inputs': inputs,
            'tags': tags,
            'metadata': metadata,
            'kwargs': kwargs,
            'serialized': serialized,
        }
        self._start_span_tracking(client, run_id, parent_run_id, name, span_type="tool", inputs=recorded_inputs)

        if name not in self.executed_tools:
            self.executed_tools.append(name)

        # Resolve the enclosing node (a 'chain' span) so the tool can be tagged with it.
        owner_node = None
        if parent_run_id and parent_run_id in self._run_id_to_span_id:
            parent_span = self._run_id_to_span_id[parent_run_id]
            try:
                if not (hasattr(client, 'entries') and client.entries):
                    self._log(f" WARNING: trace_client.entries missing for tool start parent {parent_span}")
                else:
                    parent_enter = next((entry for entry in reversed(client.entries) if entry.span_id == parent_span and entry.type == "enter" and entry.span_type == "chain"), None)
                    if parent_enter:
                        owner_node = parent_enter.function
            except Exception as lookup_err:
                self._log(f" ERROR finding parent node name for tool start span_id {parent_span}: {lookup_err}")

        node_tool = f"{owner_node}:{name}" if owner_node else name
        if node_tool not in self.executed_node_tools:
            self.executed_node_tools.append(node_tool)
        self._log(f" Tracked node_tool: {node_tool}")
    except Exception as exc:
        client_id = id(self._trace_client) if self._trace_client else 'None'
        self._log(f"{prefix} UNCAUGHT EXCEPTION in on_tool_start for run_id={run_id} (TraceClient ID: {client_id}): {exc}")
915
+
223
916
 
224
917
def on_tool_end(self, output: Any, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any) -> Any:
    """Close the ``tool`` span for ``run_id``, recording the tool output and extra kwargs.

    Any exception is caught and logged so tracing never breaks the tool call.
    """
    prefix = f"{HANDLER_LOG_PREFIX} [Handler {id(self)}]"

    try:
        client = self._ensure_trace_client(run_id, parent_run_id, "ToolEnd")
        if client:
            self._end_span_tracking(client, run_id, span_type="tool", outputs={"output": output, "kwargs": kwargs})
    except Exception as exc:
        client_id = id(self._trace_client) if self._trace_client else 'None'
        self._log(f"{prefix} UNCAUGHT EXCEPTION in on_tool_end for run_id={run_id} (TraceClient ID: {client_id}): {exc}")
250
934
 
251
- def on_agent_finish(
252
- self,
253
- finish: AgentFinish,
254
- *,
255
- run_id: UUID,
256
- parent_run_id: Optional[UUID] = None,
257
- **kwargs: Any,
258
- ) -> Any:
259
-
935
def on_tool_error(self, error: BaseException, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any) -> Any:
    """Close the ``tool`` span for ``run_id`` as failed, recording ``error``.

    Any exception raised while doing so is caught and logged.
    """
    prefix = f"{HANDLER_LOG_PREFIX} [Handler {id(self)}]"

    try:
        client = self._ensure_trace_client(run_id, parent_run_id, "ToolError")
        if client:
            self._end_span_tracking(client, run_id, span_type="tool", error=error)
    except Exception as exc:
        client_id = id(self._trace_client) if self._trace_client else 'None'
        self._log(f"{prefix} UNCAUGHT EXCEPTION in on_tool_error for run_id={run_id} (TraceClient ID: {client_id}): {exc}")
951
+
952
def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], *, run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, invocation_params: Optional[Dict[str, Any]] = None, options: Optional[Dict[str, Any]] = None, name: Optional[str] = None, **kwargs: Any) -> Any:
    """Open an ``llm`` span for a completion-style LLM run.

    The span is named ``name``, else ``serialized["name"]``, else ``"LLM Call"``.

    Recorded inputs:
      * the prompts;
      * ``invocation_params``, falling back to the extra ``kwargs`` when it is
        not given;
      * ``options``, ``tags``, ``metadata`` and ``serialized``.

    Any exception is caught and logged so tracing never breaks the LLM call.
    """
    log_prefix = f"{HANDLER_LOG_PREFIX} [Handler {id(self)}]"

    try:
        # `serialized` may be None (the other start callbacks guard for this). The
        # lookup used to run outside the try, so a None raised into LangChain; a
        # present-but-None name also produced a None span name.
        llm_name = name or (serialized or {}).get("name") or "LLM Call"

        trace_client = self._ensure_trace_client(run_id, parent_run_id, llm_name)
        if not trace_client:
            return

        inputs = {'prompts': prompts, 'invocation_params': invocation_params or kwargs, 'options': options, 'tags': tags, 'metadata': metadata, 'serialized': serialized}
        self._start_span_tracking(trace_client, run_id, parent_run_id, llm_name, span_type="llm", inputs=inputs)
    except Exception as e:
        tc_id_on_error = id(self._trace_client) if self._trace_client else 'None'
        self._log(f"{log_prefix} UNCAUGHT EXCEPTION in on_llm_start for run_id={run_id} (TraceClient ID: {tc_id_on_error}): {e}")
970
+
971
def on_llm_end(self, response: LLMResult, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any) -> Any:
    """Close the ``llm`` span for ``run_id``, attaching the response and token usage.

    Token usage is read from ``response.llm_output`` when it is a dict:

    * ``llm_output["token_usage"]`` with OpenAI-style keys ``prompt_tokens``,
      ``completion_tokens`` and ``total_tokens``; otherwise
    * ``llm_output["usage"]`` with Anthropic-style keys ``input_tokens`` and
      ``output_tokens``.

    When found, usage is stored on the span outputs under ``"usage"`` using the
    OpenAI-style key names. ``total_tokens`` is derived from the two parts
    whenever the provider did not report it.

    Errors while extracting usage are logged and do not prevent the span from
    being closed. Any other exception is caught and logged.
    """
    log_prefix = f"{HANDLER_LOG_PREFIX} [Handler {id(self)}]"

    try:
        trace_client = self._ensure_trace_client(run_id, parent_run_id, "LLMEnd")
        if not trace_client:
            return

        outputs = {"response": response, "kwargs": kwargs}

        # --- Token usage extraction ---
        prompt_tokens = None
        completion_tokens = None
        total_tokens = None
        try:
            llm_output = response.llm_output
            if llm_output and isinstance(llm_output, dict):
                if 'token_usage' in llm_output:
                    token_usage = llm_output.get('token_usage')
                    if token_usage and isinstance(token_usage, dict):
                        self._log(f" Extracted OpenAI token usage for run_id={run_id}: {token_usage}")
                        prompt_tokens = token_usage.get('prompt_tokens')
                        completion_tokens = token_usage.get('completion_tokens')
                        total_tokens = token_usage.get('total_tokens')
                elif 'usage' in llm_output:
                    token_usage = llm_output.get('usage')
                    if token_usage and isinstance(token_usage, dict):
                        self._log(f" Extracted Anthropic token usage for run_id={run_id}: {token_usage}")
                        prompt_tokens = token_usage.get('input_tokens')
                        completion_tokens = token_usage.get('output_tokens')

                has_usage = prompt_tokens is not None or completion_tokens is not None
                # Anthropic never reports a total, and OpenAI-style payloads may omit
                # it; derive it from the two parts whenever both are known.
                if has_usage and total_tokens is None:
                    if prompt_tokens is not None and completion_tokens is not None:
                        total_tokens = prompt_tokens + completion_tokens
                    else:
                        self._log(f" Could not calculate total_tokens from usage: input={prompt_tokens}, output={completion_tokens}")

                if has_usage:
                    outputs['usage'] = {
                        'prompt_tokens': prompt_tokens,
                        'completion_tokens': completion_tokens,
                        'total_tokens': total_tokens,
                    }
                else:
                    self._log(f" Could not extract token usage structure from llm_output for run_id={run_id}")
            else:
                self._log(f" llm_output not available/dict for run_id={run_id}")
        except Exception as e:
            self._log(f" ERROR extracting token usage for run_id={run_id}: {e}")

        self._end_span_tracking(trace_client, run_id, span_type="llm", outputs=outputs)
    except Exception as e:
        tc_id_on_error = id(self._trace_client) if self._trace_client else 'None'
        self._log(f"{log_prefix} UNCAUGHT EXCEPTION in on_llm_end for run_id={run_id} (TraceClient ID: {tc_id_on_error}): {e}")
1049
+
1050
def on_llm_error(self, error: BaseException, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any) -> Any:
    """Close the ``llm`` span for ``run_id`` as failed, recording ``error``.

    Any exception raised while doing so is caught and logged.
    """
    prefix = f"{HANDLER_LOG_PREFIX} [Handler {id(self)}]"

    try:
        client = self._ensure_trace_client(run_id, parent_run_id, "LLMError")
        if client:
            self._end_span_tracking(client, run_id, span_type="llm", error=error)
    except Exception as exc:
        client_id = id(self._trace_client) if self._trace_client else 'None'
        self._log(f"{prefix} UNCAUGHT EXCEPTION in on_llm_error for run_id={run_id} (TraceClient ID: {client_id}): {exc}")
1066
+
1067
def on_chat_model_start(self, serialized: Dict[str, Any], messages: List[List[BaseMessage]], *, run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, invocation_params: Optional[Dict[str, Any]] = None, options: Optional[Dict[str, Any]] = None, name: Optional[str] = None, **kwargs: Any) -> Any:
    """Open an ``llm`` span for a chat-model invocation.

    The span name is ``name`` (the run name), else ``serialized["name"]``,
    else ``"ChatModel Call"``. A provider suffix is appended to it:
    ``OPENAI_API_CALL``, ``ANTHROPIC_API_CALL``, ``TOGETHER_API_CALL`` or
    ``GOOGLE_API_CALL``. A provider is detected from the serialized secret
    keys or from markers in the name. The first detected provider whose
    suffix is not already present wins.

    ``serialized`` may be ``None``; it is then treated as an empty dict. All
    work, including deriving the name, runs inside the ``try``. A failure is
    therefore logged via ``self._log`` instead of escaping the callback.

    Args:
        serialized: LangChain's serialized description of the chat model.
        messages: Prompt message batches sent to the model.
        run_id: LangChain run id of this call.
        parent_run_id: Run id of the enclosing run, if any.
        tags, metadata, invocation_params, options: Recorded as span inputs.
        name: Optional run name; takes precedence over the serialized name.
    """
    log_prefix = f"{HANDLER_LOG_PREFIX} [Handler {id(self)}]"
    try:
        serialized_dict = serialized or {}
        chat_model_name = name or serialized_dict.get("name") or "ChatModel Call"
        secret_keys = list((serialized_dict.get('secrets') or {}).keys())
        lowered_name = chat_model_name.lower()

        def _detected(key_prefix: str, *name_markers: str) -> bool:
            # A provider matches if a secret key starts with its prefix or
            # the model name mentions one of its markers.
            return (any(key.startswith(key_prefix) for key in secret_keys)
                    or any(marker in lowered_name for marker in name_markers))

        # Equivalent to an if/elif chain: the first provider that is detected
        # AND whose suffix is not yet in the name gets its suffix appended.
        for detected, suffix in (
            (_detected('openai', 'openai'), "OPENAI_API_CALL"),
            (_detected('anthropic', 'anthropic', 'claude'), "ANTHROPIC_API_CALL"),
            (_detected('together', 'together'), "TOGETHER_API_CALL"),
            (_detected('google', 'google', 'gemini'), "GOOGLE_API_CALL"),
        ):
            if detected and suffix not in chat_model_name:
                chat_model_name = f"{chat_model_name} {suffix}"
                break

        trace_client = self._ensure_trace_client(run_id, parent_run_id, chat_model_name)
        if not trace_client:
            return
        inputs = {'messages': messages, 'invocation_params': invocation_params or kwargs, 'options': options, 'tags': tags, 'metadata': metadata, 'serialized': serialized}
        # 'llm' span type keeps chat-model spans consistent with on_llm_start.
        self._start_span_tracking(trace_client, run_id, parent_run_id, chat_model_name, span_type="llm", inputs=inputs)
    except Exception as e:
        tc_id_on_error = id(self._trace_client) if self._trace_client else 'None'
        self._log(f"{log_prefix} UNCAUGHT EXCEPTION in on_chat_model_start for run_id={run_id} (TraceClient ID: {tc_id_on_error}): {e}")
1101
+
1102
+ # --- Agent Methods (sync handler; currently no-op placeholders) ---
1103
def on_agent_action(self, action: AgentAction, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any) -> Any:
    """Accept an agent-action event without recording anything.

    Agent actions are not traced by this handler yet. The method is defined
    so that the no-op is explicit.

    Args:
        action: The agent action LangChain is about to execute (unused).
        run_id: LangChain run id of the agent run (unused).
        parent_run_id: Run id of the enclosing run, if any (unused).
    """
    # TODO: record a span or trace entry for `action` if agent-level detail is needed.
    return None
1116
+
1117
def on_agent_finish(self, finish: AgentFinish, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any) -> Any:
    """Accept an agent-finish event without recording anything.

    Agent completion is not traced by this handler yet. The method is defined
    so that the no-op is explicit.

    Args:
        finish: The agent's final result (unused).
        run_id: LangChain run id of the agent run (unused).
        parent_run_id: Run id of the enclosing run, if any (unused).
    """
    # TODO: record the agent's final output if agent-level detail is needed.
    return None
1130
+
1131
+ # --- Async Handler ---
1132
+
1133
+ # --- NEW Fully Functional Async Handler ---
1134
+ class AsyncJudgevalCallbackHandler(AsyncCallbackHandler):
1135
+ """
1136
+ Async LangChain Callback Handler using run_id/parent_run_id for hierarchy.
1137
+ Manages its own internal TraceClient instance created upon first use.
1138
+ Includes verbose logging and defensive checks.
1139
+ """
1140
+ lc_serializable = False
1141
+ lc_kwargs = {}
1142
+
1143
def __init__(self, tracer: Tracer):
    """Create an async handler bound to ``tracer``.

    No TraceClient is created here. ``_ensure_trace_client`` creates one
    lazily on the first callback that needs it.

    Args:
        tracer: Tracer whose project name, rules and monitoring/evaluation
            flags configure the TraceClient this handler creates.
    """
    self.tracer = tracer
    # Lazily created TraceClient shared by every span this handler records.
    self._trace_client: Optional[TraceClient] = None
    # Trace id of _trace_client. It is assigned by _ensure_trace_client and
    # initialized here so the attribute always exists.
    self._current_trace_id: Optional[str] = None
    # LangChain run_id -> span_id of the span opened for that run.
    self._run_id_to_span_id: Dict[UUID, str] = {}
    # span_id -> time.time() at span start, used to compute durations.
    self._span_id_to_start_time: Dict[str, float] = {}
    # span_id -> nesting depth (spans without a tracked parent are depth 0).
    self._span_id_to_depth: Dict[str, int] = {}
    # run_id -> token returned by current_span_var.set() for 'chain' spans.
    self._run_id_to_context_token: Dict[UUID, contextvars.Token] = {}
    # run_id treated as the trace root; its end triggers saving the trace.
    self._root_run_id: Optional[UUID] = None
    # Token from current_trace_var.set(); non-None once the var has been set.
    self._trace_context_token: Optional[contextvars.Token] = None
    # True once the trace has been saved; prevents saving it twice.
    self._trace_saved: bool = False
    # run_id -> raw `inputs` dict passed to on_chain_start.
    self._run_id_to_start_inputs: Dict[UUID, Dict] = {}

    # Unique LangGraph node names, in order of first start (filled by on_chain_start).
    self.executed_nodes: List[str] = []
    self.executed_tools: List[str] = []
    self.executed_node_tools: List[str] = []
1165
+
1166
# NOTE: _ensure_trace_client remains synchronous as it doesn't involve async I/O
def _ensure_trace_client(self, run_id: UUID, event_name: str) -> Optional[TraceClient]:
    """Return this handler's TraceClient, creating it on first use.

    On creation, the client:
      * gets a fresh UUID4 trace id and is named ``event_name``;
      * is configured from ``self.tracer`` (project name, rules, and the
        monitoring/evaluation flags);
      * is registered as ``self.tracer._active_trace_client``.
    If no root run has been recorded yet, ``run_id`` becomes the root run.

    Args:
        run_id: LangChain run id of the event that triggered the call.
        event_name: Name given to the TraceClient if one is created.

    Returns:
        The TraceClient, or ``None`` if setting it up raised. The failure is
        logged via ``self._log``.
    """
    if self._trace_client is not None:
        return self._trace_client

    trace_id = str(uuid.uuid4())
    project = self.tracer.project_name
    try:
        client = TraceClient(
            self.tracer, trace_id, event_name, project_name=project,
            overwrite=False, rules=self.tracer.rules,
            enable_monitoring=self.tracer.enable_monitoring,
            enable_evaluations=self.tracer.enable_evaluations
        )
        self._trace_client = client
        self.tracer._active_trace_client = client
        self._current_trace_id = client.trace_id
        if self._root_run_id is None:
            self._root_run_id = run_id
    except Exception as e:
        # Best effort: tracing must never break the traced run. Report the
        # failure and continue without a client.
        self._log(f"{HANDLER_LOG_PREFIX} [Async Handler {id(self)}] ERROR creating TraceClient for run_id={run_id}: {e}")
        self._trace_client = None
        return None
    return client
1193
+
1194
+ def _log(self, message: str):
1195
+ """Helper for consistent logging format."""
260
1196
  pass
261
1197
 
262
# NOTE: _start_span_tracking remains mostly synchronous, TraceClient.add_entry might become async later
def _start_span_tracking(
    self,
    trace_client: TraceClient,
    run_id: UUID,
    parent_run_id: Optional[UUID],
    name: str,
    span_type: SpanType = "span",
    inputs: Optional[Dict[str, Any]] = None
):
    """Open a span for ``run_id`` and record its ``enter`` (and optional ``input``) entries.

    Steps:
      1. The first time this handler opens a span, set ``current_trace_var``
         to ``trace_client`` and keep the token in ``_trace_context_token``.
      2. Allocate a UUID4 span id. If ``parent_run_id`` has a tracked span,
         that span is the parent and the depth is the parent's depth + 1.
         Otherwise the depth is 0.
      3. For ``"chain"`` spans only, set ``current_span_var`` to the new span
         id and keep the token keyed by ``run_id``.
      4. Add an ``enter`` TraceEntry. If ``inputs`` is truthy, add an
         ``input`` entry via ``_record_input_data``.
    Errors from the context variables and from ``add_entry`` are logged via
    ``self._log`` and do not abort the span.

    Args:
        trace_client: Client that receives the entries; if falsy, nothing happens.
        run_id: LangChain run id the span belongs to.
        parent_run_id: Run id of the enclosing run, if any.
        name: Function/message name written on the ``enter`` entry.
        span_type: Span type written on the ``enter`` entry.
        inputs: Optional inputs recorded as an ``input`` entry.
    """
    self._log(f"_start_span_tracking called for: name='{name}', run_id={run_id}, parent_run_id={parent_run_id}, span_type={span_type}")
    if not trace_client:
        log_prefix = f"{HANDLER_LOG_PREFIX} [Async Handler {id(self)}]"
        self._log(f"{log_prefix} FATAL ERROR in _start_span_tracking: trace_client argument is None for name='{name}', run_id={run_id}. Aborting span start.")
        return

    # Set the trace context variable once per handler (the token stays non-None afterwards).
    if self._trace_context_token is None:
        try:
            self._trace_context_token = current_trace_var.set(trace_client)
            self._log(f" Set current_trace_var for trace_id {trace_client.trace_id}")
        except Exception as e:
            self._log(f" ERROR setting current_trace_var for trace_id {trace_client.trace_id}: {e}")

    start_time = time.time()
    span_id = str(uuid.uuid4())
    parent_span_id: Optional[str] = None
    current_depth = 0

    if parent_run_id and parent_run_id in self._run_id_to_span_id:
        parent_span_id = self._run_id_to_span_id[parent_run_id]
        if parent_span_id in self._span_id_to_depth:
            current_depth = self._span_id_to_depth[parent_span_id] + 1
        else:
            self._log(f" WARNING: Parent span depth not found for parent_span_id: {parent_span_id}. Setting depth to 0.")
    elif parent_run_id:
        self._log(f" WARNING: parent_run_id {parent_run_id} provided for '{name}' ({run_id}) but parent span not tracked. Treating as depth 0.")
    else:
        self._log(f" No parent_run_id provided. Treating '{name}' as depth 0.")

    self._run_id_to_span_id[run_id] = span_id
    self._span_id_to_start_time[span_id] = start_time
    self._span_id_to_depth[span_id] = current_depth
    self._log(f" Tracking new span: span_id={span_id}, depth={current_depth}")

    # Only chain (node) spans become the current span for code running inside them.
    if span_type == "chain":
        try:
            self._run_id_to_context_token[run_id] = current_span_var.set(span_id)
            self._log(f" Set current_span_var to {span_id} for run_id {run_id} (type: chain)")
        except Exception as e:
            self._log(f" ERROR setting current_span_var for run_id {run_id}: {e}")

    try:
        # TODO: Check if trace_client.add_entry needs await if TraceClient becomes async
        trace_client.add_entry(TraceEntry(
            type="enter", span_id=span_id, trace_id=trace_client.trace_id,
            parent_span_id=parent_span_id, function=name, depth=current_depth,
            message=name, created_at=start_time, span_type=span_type
        ))
        self._log(f" Added 'enter' entry for span_id={span_id}")
    except Exception as e:
        self._log(f" ERROR adding 'enter' entry for span_id {span_id}: {e}")

    if inputs:
        self._record_input_data(trace_client, run_id, inputs)
1276
+
1277
# NOTE: _end_span_tracking remains mostly synchronous, TraceClient.save might become async later
def _end_span_tracking(
    self,
    trace_client: TraceClient,
    run_id: UUID,
    span_type: SpanType = "span",
    outputs: Optional[Any] = None,
    error: Optional[BaseException] = None
):
    """Close the span opened for ``run_id``; if it is the root run, save the trace.

    Processing:
      * If ``trace_client`` is falsy, do nothing.
      * A run without a tracked span is ignored unless it is the root run.
      * ``error`` (if truthy), or else ``outputs`` (if not None), is recorded
        as an ``output`` entry via ``_record_output_data``.
      * For a tracked span, an ``exit`` TraceEntry is added. It carries the
        span's depth, its duration since start, ``span_type`` and the function
        name of its ``enter`` entry. The span's start time, depth and context
        token are then discarded.
      * For the root run, the handler's TraceClient is saved at most once
        (guarded by ``_trace_saved``). Per-run state is then reset in a
        ``finally`` block, so it happens even if saving fails.
    Errors while building the exit entry or saving are logged via ``self._log``.

    Args:
        trace_client: Client that receives the trace entries.
        run_id: LangChain run id whose span ends.
        span_type: Span type written on the exit entry.
        outputs: Output recorded when ``error`` is not set.
        error: Exception recorded instead of ``outputs`` when set.
    """
    instance_id = id(self)
    if not trace_client:
        return

    span_id = self._run_id_to_span_id.get(run_id)
    if span_id is None and run_id != self._root_run_id:
        # Untracked, non-root run: nothing to close or save.
        return

    start_time = self._span_id_to_start_time.get(span_id) if span_id else None
    depth = self._span_id_to_depth.get(span_id, 0) if span_id else 0
    duration = time.time() - start_time if start_time is not None else None

    # Record output/error before adding the exit entry.
    if error:
        self._record_output_data(trace_client, run_id, error)
    elif outputs is not None:
        self._record_output_data(trace_client, run_id, outputs)

    if span_id:
        entry_function_name = "unknown"
        try:
            if hasattr(trace_client, 'entries') and trace_client.entries:
                entry_function_name = next((e.function for e in reversed(trace_client.entries) if e.span_id == span_id and e.type == "enter"), "unknown")
        except Exception as e:
            self._log(f" ERROR finding function name for exit entry span_id {span_id}: {e}")

        try:
            trace_client.add_entry(TraceEntry(
                type="exit", span_id=span_id, trace_id=trace_client.trace_id,
                depth=depth, created_at=time.time(), duration=duration,
                span_type=span_type, function=entry_function_name
            ))
        except Exception as e:
            self._log(f" ERROR adding 'exit' entry for span_id {span_id}: {e}")

        self._span_id_to_start_time.pop(span_id, None)
        self._span_id_to_depth.pop(span_id, None)
        # The current_span_var token is dropped without ContextVar.reset().
        # Presumably this is because async callbacks may run in a different
        # contextvars.Context than the one that set it, where reset() raises
        # ValueError. NOTE(review): confirm.
        self._run_id_to_context_token.pop(run_id, None)
        # NOTE(review): _run_id_to_span_id[run_id] is not removed here, so the
        # map grows with every run. Confirm callers no longer need it before
        # popping it.

    if run_id == self._root_run_id:
        try:
            self._log(f"Root run {run_id} finished. Attempting to save trace...")
            if self._trace_client and not self._trace_saved:
                try:
                    # TODO: Check if trace_client.save needs await if TraceClient becomes async
                    trace_id, _ = self._trace_client.save(overwrite=self._trace_client.overwrite)
                    self._log(f"Trace {trace_id} successfully saved.")
                    self._trace_saved = True  # only after a successful save
                except Exception as e:
                    self._log(f"ERROR saving trace {self._trace_client.trace_id}: {e}")
            elif self._trace_client and self._trace_saved:
                self._log(f" Trace {self._trace_client.trace_id} already saved. Skipping save.")
            else:
                self._log(f" WARNING: Root run {run_id} ended, but trace client was None. Cannot save trace.")
        finally:
            # Cleanup runs regardless of save success/failure.
            self._log(f" Performing cleanup for root run {run_id} in handler {instance_id}.")
            self._root_run_id = None
            self._run_id_to_start_inputs = {}
            self._log(f"Reset root run ID and input storage for handler {instance_id}.")
            # Only clear the tracer's active client if it is this handler's client.
            if self.tracer._active_trace_client == self._trace_client:
                self.tracer._active_trace_client = None
                self._log(" Reset active_trace_client on Tracer.")
            # NOTE(review): _trace_client, _trace_saved and _trace_context_token
            # are not reset here. A handler reused for another root run keeps
            # the same TraceClient, will not save again, and will not set
            # current_trace_var again. Confirm this is intended.
            self._log(f" Cleanup complete for root run {run_id}.")
1410
+
1411
# NOTE: _record_input_data remains synchronous for now
def _record_input_data(self,
                       trace_client: TraceClient,
                       run_id: UUID,
                       inputs: Dict[str, Any]):
    """Add an ``input`` TraceEntry carrying ``inputs`` to the span of ``run_id``.

    The entry copies the function name, span type and parent span id from the
    span's most recent ``enter`` entry. If no such entry exists, the defaults
    are ``"unknown"``, ``"span"`` and ``None``.

    Does nothing if ``run_id`` has no tracked span or ``trace_client`` is
    falsy. If the entries cannot be scanned, or adding the entry fails, the
    error is logged via ``self._log`` and no entry is added.

    Args:
        trace_client: Client that receives the entry.
        run_id: LangChain run id whose span receives the inputs.
        inputs: Input payload stored on the entry.
    """
    if run_id not in self._run_id_to_span_id or not trace_client:
        return

    span_id = self._run_id_to_span_id[run_id]
    depth = self._span_id_to_depth.get(span_id, 0)

    # One backwards scan supplies function name, span type and parent span id.
    try:
        enter_entry = next((e for e in reversed(trace_client.entries) if e.span_id == span_id and e.type == "enter"), None)
        function_name = enter_entry.function if enter_entry else "unknown"
        span_type: SpanType = enter_entry.span_type if enter_entry else "span"
        parent_span_id = enter_entry.parent_span_id if enter_entry else None
    except Exception as e:
        self._log(f" ERROR finding enter entry for input span_id {span_id}: {e}")
        return

    try:
        trace_client.add_entry(TraceEntry(
            type="input",
            span_id=span_id,
            trace_id=trace_client.trace_id,
            parent_span_id=parent_span_id,
            function=function_name,
            depth=depth,
            message=f"Input to {function_name}",
            created_at=time.time(),
            inputs=inputs,
            span_type=span_type
        ))
    except Exception as e:
        self._log(f" ERROR adding 'input' entry directly for span_id {span_id}: {e}")
1464
+
1465
# NOTE: _record_output_data remains synchronous for now
def _record_output_data(self,
                        trace_client: TraceClient,
                        run_id: UUID,
                        output: Any):
    """Add an ``output`` TraceEntry carrying ``output`` to the span of ``run_id``.

    The entry copies the function name, span type and parent span id from the
    span's most recent ``enter`` entry. If no such entry exists, the defaults
    are ``"unknown"``, ``"span"`` and ``None``. ``output`` is stored as
    given; callers pass either a result or an exception.

    Does nothing if ``run_id`` has no tracked span or ``trace_client`` is
    falsy. If the entries cannot be scanned, or adding the entry fails, the
    error is logged via ``self._log`` and no entry is added.

    Args:
        trace_client: Client that receives the entry.
        run_id: LangChain run id whose span receives the output.
        output: Output value (or exception) stored on the entry.
    """
    self._log(f"_record_output_data called for run_id={run_id}")
    if run_id not in self._run_id_to_span_id or not trace_client:
        return

    span_id = self._run_id_to_span_id[run_id]
    depth = self._span_id_to_depth.get(span_id, 0)

    # One backwards scan supplies function name, span type and parent span id.
    try:
        enter_entry = next((e for e in reversed(trace_client.entries) if e.span_id == span_id and e.type == "enter"), None)
        function_name = enter_entry.function if enter_entry else "unknown"
        span_type: SpanType = enter_entry.span_type if enter_entry else "span"
        parent_span_id = enter_entry.parent_span_id if enter_entry else None
    except Exception as e:
        self._log(f" ERROR finding enter entry for output span_id {span_id}: {e}")
        return

    try:
        trace_client.add_entry(TraceEntry(
            type="output",
            span_id=span_id,
            trace_id=trace_client.trace_id,
            parent_span_id=parent_span_id,
            function=function_name,
            depth=depth,
            message=f"Output from {function_name}",
            created_at=time.time(),
            output=output,
            span_type=span_type
        ))
        self._log(f" Added 'output' entry directly for span_id={span_id}")
    except Exception as e:
        self._log(f" ERROR adding 'output' entry directly for span_id {span_id}: {e}")
1517
+
1518
+ # --- Async Callback Methods ---
1519
+
1520
async def on_retriever_start(self, serialized: Dict[str, Any], query: str, *, run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, **kwargs: Any) -> Any:
    """Open a ``retriever`` span named ``RETRIEVER_<NAME>`` for a retriever run.

    ``<NAME>`` is ``serialized["name"]`` upper-cased. It becomes ``UNKNOWN``
    when the name is missing or empty, and ``UNKNOWN (SERIALIZED=NONE)``
    when ``serialized`` is falsy. The query, tags, metadata, extra kwargs and
    the serialized payload are recorded as the span's inputs. Exceptions are
    logged via ``self._log`` and suppressed.

    Args:
        serialized: LangChain's serialized description of the retriever.
        query: The retrieval query.
        run_id: LangChain run id of this retriever run.
        parent_run_id: Run id of the enclosing run, if any.
        tags, metadata: Recorded as span inputs.
    """
    log_prefix = f"{HANDLER_LOG_PREFIX} [Async Handler {id(self)}]"
    serialized_name = (serialized.get('name') or 'Unknown') if serialized else "Unknown (Serialized=None)"
    try:
        name = f"RETRIEVER_{serialized_name.upper()}"
        # This handler's _ensure_trace_client takes (run_id, event_name) only.
        # Passing parent_run_id as well raises TypeError.
        trace_client = self._ensure_trace_client(run_id, name)
        if not trace_client:
            return
        inputs = {'query': query, 'tags': tags, 'metadata': metadata, 'kwargs': kwargs, 'serialized': serialized}
        self._start_span_tracking(trace_client, run_id, parent_run_id, name, span_type="retriever", inputs=inputs)
    except Exception as e:
        tc_id_on_error = id(self._trace_client) if self._trace_client else 'None'
        self._log(f"{log_prefix} UNCAUGHT EXCEPTION in on_retriever_start for run_id={run_id} (TraceClient ID: {tc_id_on_error}): {e}")
1537
+
1538
async def on_retriever_end(self, documents: Sequence[Document], *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any) -> Any:
    """Close the retriever span for ``run_id`` with a summary of ``documents``.

    The recorded output contains:
      * the document count;
      * per document, its index, its metadata, and its ``page_content``
        truncated to 100 characters (plus ``"..."`` when longer);
      * any extra kwargs.
    Exceptions are logged via ``self._log`` and suppressed.

    Args:
        documents: Documents returned by the retriever.
        run_id: LangChain run id of the retriever run.
        parent_run_id: Run id of the enclosing run, if any (unused).
    """
    log_prefix = f"{HANDLER_LOG_PREFIX} [Async Handler {id(self)}]"
    try:
        # This handler's _ensure_trace_client takes (run_id, event_name) only.
        # Passing parent_run_id as well raises TypeError.
        trace_client = self._ensure_trace_client(run_id, "RetrieverEnd")
        if not trace_client:
            return
        doc_summary = [
            {
                "index": i,
                "page_content": (doc.page_content[:100] + "...") if len(doc.page_content) > 100 else doc.page_content,
                "metadata": doc.metadata,
            }
            for i, doc in enumerate(documents)
        ]
        outputs = {"document_count": len(documents), "documents": doc_summary, "kwargs": kwargs}
        self._end_span_tracking(trace_client, run_id, span_type="retriever", outputs=outputs)
    except Exception as e:
        tc_id_on_error = id(self._trace_client) if self._trace_client else 'None'
        self._log(f"{log_prefix} UNCAUGHT EXCEPTION in on_retriever_end for run_id={run_id} (TraceClient ID: {tc_id_on_error}): {e}")
1554
+
1555
async def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any], *, run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, **kwargs: Any) -> None:
    """Open a ``chain`` span for a chain or LangGraph node start event.

    The span name is chosen by the first rule that applies:
      1. ``metadata["langgraph_node"]``; the node name is also recorded once
         in ``executed_nodes``.
      2. ``serialized["name"]``, when it is not a generic runnable name.
      3. ``"LangGraph"``, for a run without a parent. Such a run becomes the
         root run if none is set yet.
      4. ``serialized["name"]``, even if it is generic.
      5. ``"Unknown Chain"``.
    For the root run, the TraceClient's name is updated to the span name.
    The raw ``inputs`` are stored in ``_run_id_to_start_inputs``. Exceptions
    are logged via ``self._log`` and suppressed.
    """
    prefix = f"{HANDLER_LOG_PREFIX} [Async Handler {id(self)}]"
    chain_name = serialized.get('name') if serialized else None

    try:
        span_kind: SpanType = "chain"
        node = metadata.get("langgraph_node") if metadata else None
        top_level = parent_run_id is None
        generic = ("RunnableSequence", "RunnableParallel", "RunnableLambda", "LangGraph", "__start__", "__end__")
        specific_name = bool(chain_name) and chain_name not in generic

        if node:
            label = node
            self._log(f" LangGraph Node Start Detected: '{label}', run_id={run_id}, parent_run_id={parent_run_id}")
            if label not in self.executed_nodes:
                self.executed_nodes.append(label)
        elif specific_name:
            label = chain_name
            self._log(f" LangGraph Functional Step (Router?): '{label}', run_id={run_id}, parent_run_id={parent_run_id}")
        elif top_level:
            # A non-generic serialized name was handled above, so a parentless
            # run reaching this point is always labelled "LangGraph".
            label = "LangGraph"
            self._log(f" LangGraph Root Start Detected (parent_run_id=None): Name='{label}', run_id={run_id}")
            if self._root_run_id is None:
                self._log(f" Setting root run ID to {run_id}")
                self._root_run_id = run_id
        elif chain_name:
            label = chain_name
            self._log(f" Fallback to serialized_name: '{label}', run_id={run_id}")
        else:
            label = "Unknown Chain"

        client = self._ensure_trace_client(run_id, label)
        if not client:
            return

        # The trace takes the root run's name, once the client exists.
        if top_level and run_id == self._root_run_id and client.name != label:
            self._log(f" Updating trace name from '{client.name}' to '{label}' for root run {run_id}")
            client.name = label

        recorded_inputs = {'inputs': inputs, 'tags': tags, 'metadata': metadata, 'kwargs': kwargs, 'serialized': serialized}
        self._start_span_tracking(client, run_id, parent_run_id, label, span_type=span_kind, inputs=recorded_inputs)
        # Raw inputs are kept per run for later use (e.g. evaluation on chain end).
        self._run_id_to_start_inputs[run_id] = inputs
        self._log(f" Stored inputs for run_id {run_id}")
    except Exception as exc:
        client_ref = id(self._trace_client) if self._trace_client else 'None'
        self._log(f"{prefix} UNCAUGHT EXCEPTION in on_chain_start for run_id={run_id} (TraceClient ID: {client_ref}): {exc}")
1625
+
1626
async def on_chain_end(self, outputs: Dict[str, Any], *, run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, **kwargs: Any) -> Any:
    """
    Ends span tracking for a chain/node and attempts evaluation if applicable.

    The span type and node name come from the span's most recent "enter"
    entry on the trace client, defaulting to "chain" and "unknown_node".
    After the span is ended, if ``outputs`` is a dict containing
    "_judgeval_eval" (an ``EvaluationConfig``, or a dict with "scorers" and
    "example" that can be rebuilt into one), the evaluation is submitted for
    this span via ``TraceClient.async_evaluate``.

    Args:
        outputs: Output of the finished run. Annotated as a dict, but chain
            callbacks may deliver other values (strings, lists, None); only
            dict outputs are inspected for "_judgeval_eval".
        run_id: Run ID of the finished chain/node.
        parent_run_id: Parent run ID (not used here).
        tags: Run tags, recorded alongside the outputs.
        **kwargs: Extra callback data, recorded alongside the outputs.

    Any exception raised while recording is logged and not propagated,
    matching the other callbacks of this handler.
    """
    instance_id = id(self)
    log_prefix = f"{HANDLER_LOG_PREFIX} [Async Handler {instance_id}]"
    try:
        self._log(f"{log_prefix} ENTERING on_chain_end: run_id={run_id}. Current TraceClient ID: {id(self._trace_client) if self._trace_client else 'None'}")
        client = self._ensure_trace_client(run_id, "on_chain_end")
        if not client:
            self._log(f"{log_prefix} No TraceClient found for on_chain_end ({run_id}). Aborting.")
            return

        span_id = self._run_id_to_span_id.get(run_id)

        # One scan of the entries yields both the span type (needed to end the
        # span) and the node name (used only in evaluation log messages).
        end_span_type: SpanType = "chain"
        node_name = "unknown_node"
        if span_id:
            try:
                if hasattr(client, 'entries') and client.entries:
                    enter_entry = next((e for e in reversed(client.entries) if e.span_id == span_id and e.type == "enter"), None)
                    if enter_entry:
                        end_span_type = enter_entry.span_type
                        node_name = enter_entry.function
                else:
                    self._log(f" WARNING: trace_client.entries empty/missing for on_chain_end span_id={span_id}")
            except Exception as e:
                self._log(f" ERROR finding enter entry for span_id {span_id} in on_chain_end: {e}")
        else:
            self._log(f" WARNING: No span_id found for run_id {run_id} in on_chain_end, using default span_type='chain'.")

        combined_outputs = {"outputs": outputs, "tags": tags, "kwargs": kwargs}
        self._end_span_tracking(client, run_id, span_type=end_span_type, outputs=combined_outputs)

        # Only a dict can carry the evaluation key. On a str, `in` would be a
        # substring test (then `.get` fails); on None, `in` raises TypeError.
        if not (isinstance(outputs, dict) and "_judgeval_eval" in outputs):
            return
        if not span_id:
            self._log(f"{log_prefix} WARNING: Found _judgeval_eval in outputs, but span_id for run_id {run_id} was not found. Cannot submit evaluation.")
            return

        eval_config: Optional[EvaluationConfig] = None
        raw_eval_config = outputs.get("_judgeval_eval")
        if isinstance(raw_eval_config, EvaluationConfig):
            eval_config = raw_eval_config
            self._log(f"{log_prefix} Found valid EvaluationConfig in outputs for node='{node_name}'.")
        elif isinstance(raw_eval_config, dict):
            # The config may arrive as a plain dict, e.g. after state serialization.
            try:
                if "scorers" in raw_eval_config and "example" in raw_eval_config:
                    example_data = raw_eval_config["example"]
                    reconstructed_example = Example(**example_data) if isinstance(example_data, dict) else example_data
                    if isinstance(reconstructed_example, Example):
                        eval_config = EvaluationConfig(
                            scorers=raw_eval_config["scorers"],  # assumed already usable scorer objects
                            example=reconstructed_example,
                            model=raw_eval_config.get("model"),
                            log_results=raw_eval_config.get("log_results", True),
                        )
                        self._log(f"{log_prefix} Reconstructed EvaluationConfig from dict in outputs for node='{node_name}'.")
                    else:
                        self._log(f"{log_prefix} Could not reconstruct Example from dict in _judgeval_eval for node='{node_name}'. Skipping evaluation.")
                else:
                    self._log(f"{log_prefix} Dict in _judgeval_eval missing required keys ('scorers', 'example') for node='{node_name}'. Skipping evaluation.")
            except Exception as recon_e:
                self._log(f"{log_prefix} ERROR attempting to reconstruct EvaluationConfig from dict for node='{node_name}': {recon_e}")
        else:
            self._log(f"{log_prefix} Found '_judgeval_eval' key in outputs for node='{node_name}', but it wasn't an EvaluationConfig object or reconstructable dict. Skipping evaluation.")

        if not eval_config:
            return

        self._log(f"{log_prefix} Submitting evaluation for span_id={span_id}")
        try:
            # Tie the example to this trace unless the caller already set a trace_id.
            if not hasattr(eval_config.example, 'trace_id') or not eval_config.example.trace_id:
                eval_config.example.trace_id = client.trace_id
                self._log(f"{log_prefix} Set trace_id={client.trace_id} on evaluation example.")

            client.async_evaluate(
                scorers=eval_config.scorers,
                example=eval_config.example,
                model=eval_config.model,
                log_results=eval_config.log_results,
                span_id=span_id,  # attach the evaluation to this node's span
            )
            self._log(f"{log_prefix} Evaluation submitted successfully for span_id={span_id}.")
        except Exception as eval_e:
            self._log(f"{log_prefix} ERROR submitting evaluation for span_id={span_id}: {eval_e}")
    except Exception as e:
        tc_id_on_error = id(self._trace_client) if self._trace_client else 'None'
        self._log(f"{log_prefix} UNCAUGHT EXCEPTION in on_chain_end for run_id={run_id} (TraceClient ID: {tc_id_on_error}): {e}")
1735
+
1736
async def on_chain_error(self, error: BaseException, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any) -> Any:
    """
    Ends the span of a chain/node run that raised ``error``.

    The span type is taken from the span's most recent "enter" entry on the
    trace client, defaulting to "chain". Any exception raised while
    recording is logged and not propagated.
    """
    handler_instance_id = id(self)
    log_prefix = f"{HANDLER_LOG_PREFIX} [Async Handler {handler_instance_id}]"
    try:
        trace_client = self._ensure_trace_client(run_id, "ChainError")
        if not trace_client:
            return

        span_id = self._run_id_to_span_id.get(run_id)
        span_type: SpanType = "chain"
        if span_id:
            try:
                if hasattr(trace_client, 'entries') and trace_client.entries:
                    enter_entry = next((e for e in reversed(trace_client.entries) if e.span_id == span_id and e.type == "enter"), None)
                    if enter_entry:
                        span_type = enter_entry.span_type
                else:
                    self._log(f" WARNING: trace_client.entries not available for on_chain_error span_id={span_id}")
            except Exception as e:
                self._log(f" ERROR finding enter entry for span_id {span_id} in on_chain_error: {e}")

        self._end_span_tracking(trace_client, run_id, span_type=span_type, error=error)
    except Exception as e:
        tc_id_on_error = id(self._trace_client) if self._trace_client else 'None'
        self._log(f"{log_prefix} UNCAUGHT EXCEPTION in on_chain_error for run_id={run_id} (TraceClient ID: {tc_id_on_error}): {e}")
1762
+
1763
async def on_tool_start(self, serialized: Dict[str, Any], input_str: str, *, run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, inputs: Optional[Dict[str, Any]] = None, **kwargs: Any) -> Any:
    """
    Starts a "tool" span and records which tools (and node:tool pairs) ran.

    The tool name comes from ``serialized["name"]``. It falls back to
    "Unnamed Tool" when the name is missing or empty, and to
    "Unknown Tool (Serialized=None)" when ``serialized`` is falsy. The name is
    appended to ``self.executed_tools``. When the parent run's span is a
    "chain", "<parent function>:<tool>" is appended to
    ``self.executed_node_tools``; otherwise just the tool name is appended.
    Any exception raised while recording is logged and not propagated.
    """
    handler_instance_id = id(self)
    log_prefix = f"{HANDLER_LOG_PREFIX} [Async Handler {handler_instance_id}]"
    try:
        if serialized:
            # `or` also covers a present-but-None name, which would otherwise
            # leak into the span name and into "node:None" tracking keys.
            name = serialized.get("name") or "Unnamed Tool"
        else:
            name = "Unknown Tool (Serialized=None)"

        trace_client = self._ensure_trace_client(run_id, name)
        if not trace_client:
            return

        combined_inputs = {'input_str': input_str, 'inputs': inputs, 'tags': tags, 'metadata': metadata, 'kwargs': kwargs, 'serialized': serialized}
        self._start_span_tracking(trace_client, run_id, parent_run_id, name, span_type="tool", inputs=combined_inputs)

        if name not in self.executed_tools:
            self.executed_tools.append(name)

        # Resolve the enclosing node's name from the parent run's "chain" enter entry.
        parent_node_name = None
        if parent_run_id and parent_run_id in self._run_id_to_span_id:
            parent_span_id = self._run_id_to_span_id[parent_run_id]
            try:
                if hasattr(trace_client, 'entries') and trace_client.entries:
                    parent_enter_entry = next((e for e in reversed(trace_client.entries) if e.span_id == parent_span_id and e.type == "enter" and e.span_type == "chain"), None)
                    if parent_enter_entry:
                        parent_node_name = parent_enter_entry.function
                else:
                    self._log(f" WARNING: trace_client.entries not available for parent node {parent_span_id}")
            except Exception as e:
                self._log(f" ERROR finding parent node name for tool start span_id {parent_span_id}: {e}")

        node_tool = f"{parent_node_name}:{name}" if parent_node_name else name
        if node_tool not in self.executed_node_tools:
            self.executed_node_tools.append(node_tool)
        self._log(f" Tracked node_tool: {node_tool}")
    except Exception as e:
        tc_id_on_error = id(self._trace_client) if self._trace_client else 'None'
        self._log(f"{log_prefix} UNCAUGHT EXCEPTION in on_tool_start for run_id={run_id} (TraceClient ID: {tc_id_on_error}): {e}")
1800
+
1801
async def on_tool_end(self, output: Any, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any) -> Any:
    """
    Ends the "tool" span for ``run_id``, recording ``output`` and ``kwargs``.

    Any exception raised while recording is logged and not propagated.
    """
    handler_instance_id = id(self)
    log_prefix = f"{HANDLER_LOG_PREFIX} [Async Handler {handler_instance_id}]"
    try:
        trace_client = self._ensure_trace_client(run_id, "ToolEnd")
        if not trace_client:
            return
        outputs = {"output": output, "kwargs": kwargs}
        self._end_span_tracking(trace_client, run_id, span_type="tool", outputs=outputs)
    except Exception as e:
        tc_id_on_error = id(self._trace_client) if self._trace_client else 'None'
        self._log(f"{log_prefix} UNCAUGHT EXCEPTION in on_tool_end for run_id={run_id} (TraceClient ID: {tc_id_on_error}): {e}")
1816
+
1817
async def on_tool_error(self, error: BaseException, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any) -> Any:
    """
    Ends the "tool" span for ``run_id``, recording ``error``.

    Any exception raised while recording is logged and not propagated.
    """
    handler_instance_id = id(self)
    log_prefix = f"{HANDLER_LOG_PREFIX} [Async Handler {handler_instance_id}]"
    try:
        trace_client = self._ensure_trace_client(run_id, "ToolError")
        if not trace_client:
            return
        self._end_span_tracking(trace_client, run_id, span_type="tool", error=error)
    except Exception as e:
        tc_id_on_error = id(self._trace_client) if self._trace_client else 'None'
        self._log(f"{log_prefix} UNCAUGHT EXCEPTION in on_tool_error for run_id={run_id} (TraceClient ID: {tc_id_on_error}): {e}")
1831
+
1832
async def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], *, run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, invocation_params: Optional[Dict[str, Any]] = None, options: Optional[Dict[str, Any]] = None, name: Optional[str] = None, **kwargs: Any) -> Any:
    """
    Starts an "llm" span for a completion-style LLM call.

    The span name is ``name``, else ``serialized["name"]``, else "LLM Call".
    The recorded inputs hold the prompts, ``invocation_params`` (or ``kwargs``
    when none were given), options, tags, metadata and the serialized model.
    Any exception raised while recording is logged and not propagated.
    """
    handler_instance_id = id(self)
    log_prefix = f"{HANDLER_LOG_PREFIX} [Async Handler {handler_instance_id}]"
    try:
        # Resolved inside the try: `serialized` may be None, and a None
        # "name" value must not become the span name.
        llm_name = name or (serialized or {}).get("name") or "LLM Call"
        trace_client = self._ensure_trace_client(run_id, llm_name)
        if not trace_client:
            return
        inputs = {'prompts': prompts, 'invocation_params': invocation_params or kwargs, 'options': options, 'tags': tags, 'metadata': metadata, 'serialized': serialized}
        self._start_span_tracking(trace_client, run_id, parent_run_id, llm_name, span_type="llm", inputs=inputs)
    except Exception as e:
        tc_id_on_error = id(self._trace_client) if self._trace_client else 'None'
        self._log(f"{log_prefix} UNCAUGHT EXCEPTION in on_llm_start for run_id={run_id} (TraceClient ID: {tc_id_on_error}): {e}")
1848
+
1849
async def on_llm_end(self, response: LLMResult, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any) -> Any:
    """
    Ends the "llm" span for ``run_id``, recording the response and token usage.

    Token usage is looked up in this order:
      1. ``response.llm_output["token_usage"]`` (prompt/completion/total keys);
      2. ``response.llm_output["usage"]`` (input/output keys; the total is
         their sum when both are present);
      3. only if neither yielded anything, the ``usage_metadata`` of the first
         generation message that has one (input/output/total keys).
    When any count is found, it is stored under ``outputs["usage"]`` with the
    keys prompt_tokens / completion_tokens / total_tokens. Failures during
    extraction are logged and do not prevent the span from ending. Any other
    exception is logged and not propagated.
    """
    handler_instance_id = id(self)
    log_prefix = f"{HANDLER_LOG_PREFIX} [Async Handler {handler_instance_id}]"
    try:
        trace_client = self._ensure_trace_client(run_id, "LLMEnd")
        if not trace_client:
            return

        outputs = {"response": response, "kwargs": kwargs}
        prompt_tokens = None
        completion_tokens = None
        total_tokens = None
        try:
            llm_output = response.llm_output
            if llm_output and isinstance(llm_output, dict):
                if 'token_usage' in llm_output:
                    # OpenAI-style usage block.
                    token_usage = llm_output.get('token_usage')
                    if token_usage and isinstance(token_usage, dict):
                        self._log(f" Extracted OpenAI token usage for run_id={run_id}: {token_usage}")
                        prompt_tokens = token_usage.get('prompt_tokens')
                        completion_tokens = token_usage.get('completion_tokens')
                        total_tokens = token_usage.get('total_tokens')
                elif 'usage' in llm_output:
                    # Anthropic-style usage block (input/output tokens, no total).
                    token_usage = llm_output.get('usage')
                    if token_usage and isinstance(token_usage, dict):
                        self._log(f" Extracted Anthropic token usage for run_id={run_id}: {token_usage}")
                        prompt_tokens = token_usage.get('input_tokens')
                        completion_tokens = token_usage.get('output_tokens')
                        if prompt_tokens is not None and completion_tokens is not None:
                            total_tokens = prompt_tokens + completion_tokens
                        else:
                            self._log(f" Could not calculate total_tokens from Anthropic usage: input={prompt_tokens}, output={completion_tokens}")
            else:
                self._log(f" llm_output not available/dict for run_id={run_id}")

            if prompt_tokens is None and completion_tokens is None and total_tokens is None:
                # Fallback: chat models may report usage only on the message's
                # `usage_metadata`. Take the first one rather than summing, since
                # multiple generations of one call can carry the same metadata.
                usage_metadata = None
                for generation_list in response.generations or []:
                    for generation in generation_list:
                        message = getattr(generation, "message", None)
                        candidate = getattr(message, "usage_metadata", None)
                        if candidate:
                            usage_metadata = candidate
                            break
                    if usage_metadata:
                        break
                if isinstance(usage_metadata, dict):
                    self._log(f" Extracted usage_metadata token usage for run_id={run_id}: {usage_metadata}")
                    prompt_tokens = usage_metadata.get('input_tokens')
                    completion_tokens = usage_metadata.get('output_tokens')
                    total_tokens = usage_metadata.get('total_tokens')

            if prompt_tokens is not None or completion_tokens is not None or total_tokens is not None:
                outputs['usage'] = {
                    'prompt_tokens': prompt_tokens,
                    'completion_tokens': completion_tokens,
                    'total_tokens': total_tokens,
                }
            else:
                self._log(f" Could not extract token usage structure from llm_output for run_id={run_id}")
        except Exception as e:
            self._log(f" ERROR extracting token usage for run_id={run_id}: {e}")

        self._end_span_tracking(trace_client, run_id, span_type="llm", outputs=outputs)
    except Exception as e:
        tc_id_on_error = id(self._trace_client) if self._trace_client else 'None'
        self._log(f"{log_prefix} UNCAUGHT EXCEPTION in on_llm_end for run_id={run_id} (TraceClient ID: {tc_id_on_error}): {e}")
1906
+
1907
async def on_llm_error(self, error: BaseException, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any) -> Any:
    """
    Ends the "llm" span for ``run_id``, recording ``error``.

    Any exception raised while recording is logged and not propagated.
    """
    handler_instance_id = id(self)
    log_prefix = f"{HANDLER_LOG_PREFIX} [Async Handler {handler_instance_id}]"
    try:
        trace_client = self._ensure_trace_client(run_id, "LLMError")
        if not trace_client:
            return
        self._end_span_tracking(trace_client, run_id, span_type="llm", error=error)
    except Exception as e:
        tc_id_on_error = id(self._trace_client) if self._trace_client else 'None'
        self._log(f"{log_prefix} UNCAUGHT EXCEPTION in on_llm_error for run_id={run_id} (TraceClient ID: {tc_id_on_error}): {e}")
1921
+
1922
async def on_chat_model_start(self, serialized: Dict[str, Any], messages: List[List[BaseMessage]], *, run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, invocation_params: Optional[Dict[str, Any]] = None, options: Optional[Dict[str, Any]] = None, name: Optional[str] = None, **kwargs: Any) -> Any:
    """
    Starts an "llm" span for a chat-model call.

    The span name is ``name``, else ``serialized["name"]``, else
    "ChatModel Call". A provider suffix is then appended for the first
    provider that matches, checked in this order: OPENAI_API_CALL,
    ANTHROPIC_API_CALL, TOGETHER_API_CALL, GOOGLE_API_CALL. A provider
    matches when a key of ``serialized["secrets"]`` starts with its prefix,
    or when one of its markers appears in the lower-cased name. The suffix is
    not appended if it is already present.
    Any exception raised while recording is logged and not propagated.
    """
    handler_instance_id = id(self)
    log_prefix = f"{HANDLER_LOG_PREFIX} [Async Handler {handler_instance_id}]"
    # (secret-key prefix, lower-cased name markers, suffix appended to the span name)
    # NOTE(review): the suffix is presumably consumed downstream to identify
    # provider API calls — confirm before renaming.
    provider_suffixes = (
        ("openai", ("openai",), "OPENAI_API_CALL"),
        ("anthropic", ("anthropic", "claude"), "ANTHROPIC_API_CALL"),
        ("together", ("together",), "TOGETHER_API_CALL"),
        ("google", ("google", "gemini"), "GOOGLE_API_CALL"),
    )
    try:
        # Resolved inside the try: `serialized`, its "secrets", or its "name"
        # may be None, which previously raised outside the guard.
        serialized_info = serialized or {}
        chat_model_name = name or serialized_info.get("name") or "ChatModel Call"
        secret_keys = [str(key) for key in (serialized_info.get("secrets") or {})]
        lowered_name = chat_model_name.lower()

        for secret_prefix, name_markers, suffix in provider_suffixes:
            matches_provider = any(key.startswith(secret_prefix) for key in secret_keys) or any(marker in lowered_name for marker in name_markers)
            if matches_provider:
                # Tag with the first matching provider only, so a name never
                # collects suffixes from several providers.
                if suffix not in chat_model_name:
                    chat_model_name = f"{chat_model_name} {suffix}"
                break

        trace_client = self._ensure_trace_client(run_id, chat_model_name)
        if not trace_client:
            return
        inputs = {'messages': messages, 'invocation_params': invocation_params or kwargs, 'options': options, 'tags': tags, 'metadata': metadata, 'serialized': serialized}
        # Chat calls share the 'llm' span type with completion-style calls.
        self._start_span_tracking(trace_client, run_id, parent_run_id, chat_model_name, span_type="llm", inputs=inputs)
    except Exception as e:
        tc_id_on_error = id(self._trace_client) if self._trace_client else 'None'
        self._log(f"{log_prefix} UNCAUGHT EXCEPTION in on_chat_model_start for run_id={run_id} (TraceClient ID: {tc_id_on_error}): {e}")
1958
+
1959
+ # --- Agent Methods (Async versions - ensure parent_run_id passed if needed) ---
1960
async def on_agent_action(self, action: AgentAction, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any) -> Any:
    """
    Starts an "agent_action" span named "AgentAction" for ``run_id``.

    The recorded inputs are the ``action`` and ``kwargs``. Any exception
    raised while recording is logged and not propagated.

    NOTE(review): LangChain's AgentExecutor appears to dispatch agent
    callbacks with the executor chain's own run_id, which on_chain_start also
    uses to start a span. Confirm that _start_span_tracking tolerates a
    second span for the same run_id.
    """
    handler_instance_id = id(self)
    log_prefix = f"{HANDLER_LOG_PREFIX} [Async Handler {handler_instance_id}]"
    try:
        trace_client = self._ensure_trace_client(run_id, "AgentAction")
        if not trace_client:
            return
        inputs = {"action": action, "kwargs": kwargs}
        self._start_span_tracking(trace_client, run_id, parent_run_id, name="AgentAction", span_type="agent_action", inputs=inputs)
    except Exception as e:
        tc_id_on_error = id(self._trace_client) if self._trace_client else 'None'
        self._log(f'{log_prefix} UNCAUGHT EXCEPTION in on_agent_action for run_id={run_id} (TraceClient ID: {tc_id_on_error}): {e}')
1979
+
1980
+
1981
async def on_agent_finish(self, finish: AgentFinish, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any) -> Any:
    """
    Ends the "agent_action" span for ``run_id``, recording ``finish`` and ``kwargs``.

    Any exception raised while recording is logged and not propagated.

    NOTE(review): this assumes the span being closed is the one opened by
    on_agent_action for the same run_id; it is unclear whether it could
    instead be the chain span started by on_chain_start. Confirm against
    _end_span_tracking.
    """
    handler_instance_id = id(self)
    log_prefix = f"{HANDLER_LOG_PREFIX} [Async Handler {handler_instance_id}]"
    try:
        trace_client = self._ensure_trace_client(run_id, "AgentFinish")
        if not trace_client:
            return
        outputs = {"finish": finish, "kwargs": kwargs}
        self._end_span_tracking(trace_client, run_id, span_type="agent_action", outputs=outputs)
    except Exception as e:
        tc_id_on_error = id(self._trace_client) if self._trace_client else 'None'
        self._log(f'{log_prefix} UNCAUGHT EXCEPTION in on_agent_finish for run_id={run_id} (TraceClient ID: {tc_id_on_error}): {e}')