kalibr 1.0.25__py3-none-any.whl → 1.1.2a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kalibr/__init__.py +170 -3
- kalibr/__main__.py +3 -203
- kalibr/capsule_middleware.py +108 -0
- kalibr/cli/__init__.py +5 -0
- kalibr/cli/capsule_cmd.py +174 -0
- kalibr/cli/deploy_cmd.py +114 -0
- kalibr/cli/main.py +67 -0
- kalibr/cli/run.py +203 -0
- kalibr/cli/serve.py +59 -0
- kalibr/client.py +293 -0
- kalibr/collector.py +173 -0
- kalibr/context.py +132 -0
- kalibr/cost_adapter.py +222 -0
- kalibr/decorators.py +140 -0
- kalibr/instrumentation/__init__.py +13 -0
- kalibr/instrumentation/anthropic_instr.py +282 -0
- kalibr/instrumentation/base.py +108 -0
- kalibr/instrumentation/google_instr.py +281 -0
- kalibr/instrumentation/openai_instr.py +265 -0
- kalibr/instrumentation/registry.py +153 -0
- kalibr/kalibr.py +144 -230
- kalibr/kalibr_app.py +53 -314
- kalibr/middleware/__init__.py +5 -0
- kalibr/middleware/auto_tracer.py +356 -0
- kalibr/models.py +41 -0
- kalibr/redaction.py +44 -0
- kalibr/schemas.py +116 -0
- kalibr/simple_tracer.py +258 -0
- kalibr/tokens.py +52 -0
- kalibr/trace_capsule.py +296 -0
- kalibr/trace_models.py +201 -0
- kalibr/tracer.py +354 -0
- kalibr/types.py +25 -93
- kalibr/utils.py +198 -0
- kalibr-1.1.2a0.dist-info/METADATA +236 -0
- kalibr-1.1.2a0.dist-info/RECORD +48 -0
- kalibr-1.1.2a0.dist-info/entry_points.txt +2 -0
- kalibr-1.1.2a0.dist-info/licenses/LICENSE +21 -0
- kalibr-1.1.2a0.dist-info/top_level.txt +4 -0
- kalibr_crewai/__init__.py +65 -0
- kalibr_crewai/callbacks.py +539 -0
- kalibr_crewai/instrumentor.py +513 -0
- kalibr_langchain/__init__.py +47 -0
- kalibr_langchain/async_callback.py +850 -0
- kalibr_langchain/callback.py +1064 -0
- kalibr_openai_agents/__init__.py +43 -0
- kalibr_openai_agents/processor.py +554 -0
- kalibr/deployment.py +0 -41
- kalibr/packager.py +0 -43
- kalibr/runtime_router.py +0 -138
- kalibr/schema_generators.py +0 -159
- kalibr/validator.py +0 -70
- kalibr-1.0.25.data/data/examples/README.md +0 -173
- kalibr-1.0.25.data/data/examples/basic_kalibr_example.py +0 -66
- kalibr-1.0.25.data/data/examples/enhanced_kalibr_example.py +0 -347
- kalibr-1.0.25.dist-info/METADATA +0 -231
- kalibr-1.0.25.dist-info/RECORD +0 -19
- kalibr-1.0.25.dist-info/entry_points.txt +0 -2
- kalibr-1.0.25.dist-info/licenses/LICENSE +0 -11
- kalibr-1.0.25.dist-info/top_level.txt +0 -1
- {kalibr-1.0.25.dist-info → kalibr-1.1.2a0.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,1064 @@
|
|
|
1
|
+
"""Kalibr Callback Handler for LangChain.
|
|
2
|
+
|
|
3
|
+
This module provides the main callback handler that integrates LangChain
|
|
4
|
+
with Kalibr's observability platform.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import atexit
|
|
8
|
+
import hashlib
|
|
9
|
+
import hmac
|
|
10
|
+
import os
|
|
11
|
+
import queue
|
|
12
|
+
import threading
|
|
13
|
+
import time
|
|
14
|
+
import traceback
|
|
15
|
+
import uuid
|
|
16
|
+
from datetime import datetime, timezone
|
|
17
|
+
from typing import Any, Dict, List, Optional, Sequence, Union
|
|
18
|
+
|
|
19
|
+
import httpx
|
|
20
|
+
from langchain_core.callbacks import BaseCallbackHandler
|
|
21
|
+
from langchain_core.agents import AgentAction, AgentFinish
|
|
22
|
+
from langchain_core.documents import Document
|
|
23
|
+
from langchain_core.messages import BaseMessage
|
|
24
|
+
from langchain_core.outputs import ChatGeneration, Generation, LLMResult
|
|
25
|
+
|
|
26
|
+
# Import Kalibr cost adapters
|
|
27
|
+
try:
|
|
28
|
+
from kalibr.cost_adapter import CostAdapterFactory
|
|
29
|
+
except ImportError:
|
|
30
|
+
CostAdapterFactory = None
|
|
31
|
+
|
|
32
|
+
# Import tiktoken for token counting
|
|
33
|
+
try:
|
|
34
|
+
import tiktoken
|
|
35
|
+
HAS_TIKTOKEN = True
|
|
36
|
+
except ImportError:
|
|
37
|
+
HAS_TIKTOKEN = False
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def _count_tokens(text: str, model: str) -> int:
|
|
41
|
+
"""Count tokens for given text and model."""
|
|
42
|
+
if not text:
|
|
43
|
+
return 0
|
|
44
|
+
|
|
45
|
+
if HAS_TIKTOKEN and "gpt" in model.lower():
|
|
46
|
+
try:
|
|
47
|
+
encoding = tiktoken.encoding_for_model(model)
|
|
48
|
+
return len(encoding.encode(text))
|
|
49
|
+
except Exception:
|
|
50
|
+
pass
|
|
51
|
+
|
|
52
|
+
# Fallback: approximate (1 token ~= 4 chars)
|
|
53
|
+
return len(text) // 4
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def _get_provider_from_model(model: str) -> str:
|
|
57
|
+
"""Infer provider from model name."""
|
|
58
|
+
model_lower = model.lower()
|
|
59
|
+
|
|
60
|
+
if any(x in model_lower for x in ["gpt", "text-davinci", "text-embedding", "whisper", "dall-e"]):
|
|
61
|
+
return "openai"
|
|
62
|
+
elif any(x in model_lower for x in ["claude"]):
|
|
63
|
+
return "anthropic"
|
|
64
|
+
elif any(x in model_lower for x in ["gemini", "palm", "bison"]):
|
|
65
|
+
return "google"
|
|
66
|
+
elif any(x in model_lower for x in ["cohere", "command"]):
|
|
67
|
+
return "cohere"
|
|
68
|
+
else:
|
|
69
|
+
return "custom"
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def _serialize_for_metadata(obj: Any) -> Any:
|
|
73
|
+
"""Serialize objects for JSON metadata."""
|
|
74
|
+
if obj is None:
|
|
75
|
+
return None
|
|
76
|
+
elif isinstance(obj, (str, int, float, bool)):
|
|
77
|
+
return obj
|
|
78
|
+
elif isinstance(obj, (list, tuple)):
|
|
79
|
+
return [_serialize_for_metadata(item) for item in obj]
|
|
80
|
+
elif isinstance(obj, dict):
|
|
81
|
+
return {k: _serialize_for_metadata(v) for k, v in obj.items()}
|
|
82
|
+
elif hasattr(obj, "dict"):
|
|
83
|
+
return obj.dict()
|
|
84
|
+
elif hasattr(obj, "__dict__"):
|
|
85
|
+
return {k: _serialize_for_metadata(v) for k, v in obj.__dict__.items()
|
|
86
|
+
if not k.startswith("_")}
|
|
87
|
+
else:
|
|
88
|
+
return str(obj)
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
class SpanTracker:
    """Thread-safe registry of in-flight spans, keyed by run ID.

    Spans are plain dicts, and every access to the registry happens under a
    single lock. Durations are measured with ``time.monotonic()``, so they
    stay correct and non-negative even if the wall clock is adjusted while a
    span is open. The wall-clock ``ts_start``/``ts_end`` timestamps are still
    recorded on the span.
    """

    def __init__(self):
        self.spans: Dict[str, Dict[str, Any]] = {}
        self._lock = threading.Lock()
        # run_id -> time.monotonic() reading taken at span start. This is
        # stored beside (not inside) the span dicts so span contents are
        # unchanged for callers.
        self._started_at: Dict[str, float] = {}

    def start_span(
        self,
        run_id: str,
        trace_id: str,
        parent_run_id: Optional[str],
        operation: str,
        span_type: str,
        **kwargs
    ) -> Dict[str, Any]:
        """Create, register and return a new span for ``run_id``.

        ``parent_span_id`` is filled in only when ``parent_run_id`` is
        currently tracked. Extra keyword arguments are merged into the span
        last, so they may override the standard fields. New spans start
        with status "success".
        """
        span_id = str(uuid.uuid4())

        with self._lock:
            parent_span_id = None
            if parent_run_id and parent_run_id in self.spans:
                parent_span_id = self.spans[parent_run_id].get("span_id")

            span = {
                "span_id": span_id,
                "trace_id": trace_id,
                "parent_span_id": parent_span_id,
                "operation": operation,
                "span_type": span_type,
                "ts_start": datetime.now(timezone.utc),
                "status": "success",
                **kwargs
            }
            self.spans[run_id] = span
            self._started_at[run_id] = time.monotonic()
            return span

    def end_span(self, run_id: str) -> Optional[Dict[str, Any]]:
        """Remove the span for ``run_id`` and return it with end timing.

        Sets ``ts_end`` (UTC wall clock) and ``duration_ms`` (from the
        monotonic clock, clamped at 0). Returns ``None`` if ``run_id`` is not
        tracked.
        """
        with self._lock:
            span = self.spans.pop(run_id, None)
            mono_start = self._started_at.pop(run_id, None)
            if span is None:
                return None

            span["ts_end"] = datetime.now(timezone.utc)
            if mono_start is not None:
                elapsed = time.monotonic() - mono_start
            else:
                # Span inserted directly into ``self.spans``: fall back to
                # the wall-clock difference.
                elapsed = (span["ts_end"] - span["ts_start"]).total_seconds()
            span["duration_ms"] = max(0, int(elapsed * 1000))
            return span

    def update_span(self, run_id: str, **kwargs):
        """Merge ``kwargs`` into the span for ``run_id`` if it is tracked."""
        with self._lock:
            if run_id in self.spans:
                self.spans[run_id].update(kwargs)

    def get_span(self, run_id: str) -> Optional[Dict[str, Any]]:
        """Return the live span dict for ``run_id``, or ``None``."""
        with self._lock:
            return self.spans.get(run_id)
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
class KalibrCallbackHandler(BaseCallbackHandler):
|
|
153
|
+
"""LangChain callback handler for Kalibr observability.
|
|
154
|
+
|
|
155
|
+
This handler captures telemetry from LangChain components and sends
|
|
156
|
+
them to the Kalibr backend for analysis and visualization.
|
|
157
|
+
|
|
158
|
+
Supported callbacks:
|
|
159
|
+
- LLM start/end/error
|
|
160
|
+
- Chat model start/end
|
|
161
|
+
- Chain start/end/error
|
|
162
|
+
- Tool start/end/error
|
|
163
|
+
- Agent action/finish
|
|
164
|
+
- Retriever start/end
|
|
165
|
+
- Text generation (streaming)
|
|
166
|
+
|
|
167
|
+
Args:
|
|
168
|
+
api_key: Kalibr API key (or KALIBR_API_KEY env var)
|
|
169
|
+
endpoint: Backend endpoint URL (or KALIBR_ENDPOINT env var)
|
|
170
|
+
tenant_id: Tenant identifier (or KALIBR_TENANT_ID env var)
|
|
171
|
+
environment: Environment name (or KALIBR_ENVIRONMENT env var)
|
|
172
|
+
service: Service name (or KALIBR_SERVICE env var)
|
|
173
|
+
workflow_id: Workflow identifier for grouping traces
|
|
174
|
+
secret: HMAC secret for request signing
|
|
175
|
+
batch_size: Max events per batch (default: 100)
|
|
176
|
+
flush_interval: Flush interval in seconds (default: 2.0)
|
|
177
|
+
capture_input: Whether to capture input prompts (default: True)
|
|
178
|
+
capture_output: Whether to capture outputs (default: True)
|
|
179
|
+
max_content_length: Max length for captured content (default: 10000)
|
|
180
|
+
metadata: Additional metadata to include in all events
|
|
181
|
+
"""
|
|
182
|
+
|
|
183
|
+
def __init__(
    self,
    api_key: Optional[str] = None,
    endpoint: Optional[str] = None,
    tenant_id: Optional[str] = None,
    environment: Optional[str] = None,
    service: Optional[str] = None,
    workflow_id: Optional[str] = None,
    secret: Optional[str] = None,
    batch_size: int = 100,
    flush_interval: float = 2.0,
    capture_input: bool = True,
    capture_output: bool = True,
    max_content_length: int = 10000,
    metadata: Optional[Dict[str, Any]] = None,
):
    """Resolve configuration, then start the background event pipeline.

    Each identity setting uses the explicit argument when truthy, otherwise
    an environment variable, otherwise a built-in default:

    * ``api_key``: ``KALIBR_API_KEY``, default "".
    * ``endpoint``: ``KALIBR_ENDPOINT``, then ``KALIBR_API_ENDPOINT``,
      default "http://localhost:8001/api/v1/traces".
    * ``tenant_id``: ``KALIBR_TENANT_ID``, default "default".
    * ``environment``: ``KALIBR_ENVIRONMENT``, default "prod".
    * ``service``: ``KALIBR_SERVICE``, default "langchain-app".
    * ``workflow_id``: ``KALIBR_WORKFLOW_ID``, default "default-workflow".
    * ``secret``: argument only (no environment fallback).

    Side effects: creates an ``httpx.Client`` (10 s timeout) and a bounded
    event queue (5000 items), starts a daemon thread running
    ``_flush_loop``, and registers ``shutdown`` with ``atexit``.
    """
    super().__init__()
    getenv = os.getenv

    # Identity / transport settings: explicit argument > env var > default.
    self.api_key = api_key or getenv("KALIBR_API_KEY", "")
    self.endpoint = endpoint or getenv(
        "KALIBR_ENDPOINT",
        getenv("KALIBR_API_ENDPOINT", "http://localhost:8001/api/v1/traces"),
    )
    self.tenant_id = tenant_id or getenv("KALIBR_TENANT_ID", "default")
    self.environment = environment or getenv("KALIBR_ENVIRONMENT", "prod")
    self.service = service or getenv("KALIBR_SERVICE", "langchain-app")
    self.workflow_id = workflow_id or getenv("KALIBR_WORKFLOW_ID", "default-workflow")
    self.secret = secret  # HMAC signing key for outgoing batches

    # What gets recorded, and how much of it.
    self.capture_input, self.capture_output = capture_input, capture_output
    self.max_content_length = max_content_length
    self.default_metadata = metadata or {}

    # Flush policy for the background sender.
    self.batch_size, self.flush_interval = batch_size, flush_interval

    # In-flight span bookkeeping and the current root trace.
    self._span_tracker = SpanTracker()
    self._root_trace_id: Optional[str] = None
    self._trace_lock = threading.Lock()

    # Everything the flusher thread reads must exist before it starts.
    self._event_queue: queue.Queue = queue.Queue(maxsize=5000)
    self._client = httpx.Client(timeout=10.0)
    self._shutdown = False

    flusher = threading.Thread(target=self._flush_loop, daemon=True)
    self._flush_thread = flusher
    flusher.start()

    # Give the flusher a chance to send its pending batch at interpreter exit.
    atexit.register(self.shutdown)
|
|
243
|
+
|
|
244
|
+
def _get_or_create_trace_id(self, parent_run_id: Optional[str]) -> str:
|
|
245
|
+
"""Get existing trace ID or create a new one for root spans."""
|
|
246
|
+
with self._trace_lock:
|
|
247
|
+
if parent_run_id is None:
|
|
248
|
+
# This is a root span, create new trace
|
|
249
|
+
self._root_trace_id = str(uuid.uuid4())
|
|
250
|
+
return self._root_trace_id or str(uuid.uuid4())
|
|
251
|
+
|
|
252
|
+
def _truncate(self, text: str) -> str:
|
|
253
|
+
"""Truncate text to max length."""
|
|
254
|
+
if len(text) > self.max_content_length:
|
|
255
|
+
return text[:self.max_content_length] + "...[truncated]"
|
|
256
|
+
return text
|
|
257
|
+
|
|
258
|
+
def _compute_cost(
    self,
    provider: str,
    model: str,
    input_tokens: int,
    output_tokens: int
) -> float:
    """Price a call in USD using Kalibr's ``CostAdapterFactory``.

    Returns 0.0 when the cost adapters are not importable. It also returns
    0.0 when pricing raises or yields ``None``; how the adapter treats
    unknown vendors/models, such as the synthetic "langchain"/"chain"/"tool"
    values used for chain and tool spans, is not visible here. Cost is
    best-effort telemetry, so a pricing failure must not prevent the event
    from being emitted.

    The result is coerced to ``float`` so the event stays JSON-encodable
    (``json`` cannot encode, e.g., ``Decimal``).
    """
    if CostAdapterFactory is None:
        return 0.0
    try:
        cost = CostAdapterFactory.compute_cost(
            vendor=provider,
            model_name=model,
            tokens_in=input_tokens,
            tokens_out=output_tokens
        )
        return float(cost) if cost is not None else 0.0
    except Exception:
        return 0.0
|
|
274
|
+
|
|
275
|
+
def _create_event(
|
|
276
|
+
self,
|
|
277
|
+
span: Dict[str, Any],
|
|
278
|
+
input_tokens: int = 0,
|
|
279
|
+
output_tokens: int = 0,
|
|
280
|
+
error_type: Optional[str] = None,
|
|
281
|
+
error_message: Optional[str] = None,
|
|
282
|
+
metadata: Optional[Dict[str, Any]] = None,
|
|
283
|
+
) -> Dict[str, Any]:
|
|
284
|
+
"""Create a standardized trace event from span data."""
|
|
285
|
+
provider = span.get("provider", "custom")
|
|
286
|
+
model = span.get("model", "unknown")
|
|
287
|
+
|
|
288
|
+
# Compute cost
|
|
289
|
+
cost_usd = self._compute_cost(provider, model, input_tokens, output_tokens)
|
|
290
|
+
|
|
291
|
+
# Build event
|
|
292
|
+
event = {
|
|
293
|
+
"schema_version": "1.0",
|
|
294
|
+
"trace_id": span["trace_id"],
|
|
295
|
+
"span_id": span["span_id"],
|
|
296
|
+
"parent_span_id": span.get("parent_span_id"),
|
|
297
|
+
"tenant_id": self.tenant_id,
|
|
298
|
+
"workflow_id": self.workflow_id,
|
|
299
|
+
"provider": provider,
|
|
300
|
+
"model_id": model,
|
|
301
|
+
"model_name": model,
|
|
302
|
+
"operation": span["operation"],
|
|
303
|
+
"endpoint": span.get("endpoint", span["operation"]),
|
|
304
|
+
"duration_ms": span.get("duration_ms", 0),
|
|
305
|
+
"latency_ms": span.get("duration_ms", 0),
|
|
306
|
+
"input_tokens": input_tokens,
|
|
307
|
+
"output_tokens": output_tokens,
|
|
308
|
+
"total_tokens": input_tokens + output_tokens,
|
|
309
|
+
"cost_usd": cost_usd,
|
|
310
|
+
"total_cost_usd": cost_usd,
|
|
311
|
+
"status": span.get("status", "success"),
|
|
312
|
+
"error_type": error_type,
|
|
313
|
+
"error_message": error_message,
|
|
314
|
+
"timestamp": span["ts_start"].isoformat(),
|
|
315
|
+
"ts_start": span["ts_start"].isoformat(),
|
|
316
|
+
"ts_end": span.get("ts_end", datetime.now(timezone.utc)).isoformat(),
|
|
317
|
+
"environment": self.environment,
|
|
318
|
+
"service": self.service,
|
|
319
|
+
"runtime_env": os.getenv("RUNTIME_ENV", "local"),
|
|
320
|
+
"sandbox_id": os.getenv("SANDBOX_ID", "local"),
|
|
321
|
+
"metadata": {
|
|
322
|
+
**self.default_metadata,
|
|
323
|
+
"span_type": span.get("span_type", "llm"),
|
|
324
|
+
"langchain": True,
|
|
325
|
+
**(metadata or {}),
|
|
326
|
+
},
|
|
327
|
+
}
|
|
328
|
+
|
|
329
|
+
return event
|
|
330
|
+
|
|
331
|
+
def _enqueue_event(self, event: Dict[str, Any]):
|
|
332
|
+
"""Add event to queue for batching."""
|
|
333
|
+
try:
|
|
334
|
+
self._event_queue.put_nowait(event)
|
|
335
|
+
except queue.Full:
|
|
336
|
+
# Drop oldest event and retry
|
|
337
|
+
try:
|
|
338
|
+
self._event_queue.get_nowait()
|
|
339
|
+
self._event_queue.put_nowait(event)
|
|
340
|
+
except:
|
|
341
|
+
pass
|
|
342
|
+
|
|
343
|
+
def _flush_loop(self):
|
|
344
|
+
"""Background thread to flush events periodically."""
|
|
345
|
+
batch = []
|
|
346
|
+
last_flush = time.time()
|
|
347
|
+
|
|
348
|
+
while not self._shutdown:
|
|
349
|
+
try:
|
|
350
|
+
try:
|
|
351
|
+
event = self._event_queue.get(timeout=0.1)
|
|
352
|
+
batch.append(event)
|
|
353
|
+
except queue.Empty:
|
|
354
|
+
pass
|
|
355
|
+
|
|
356
|
+
now = time.time()
|
|
357
|
+
should_flush = (
|
|
358
|
+
len(batch) >= self.batch_size or
|
|
359
|
+
(batch and now - last_flush >= self.flush_interval)
|
|
360
|
+
)
|
|
361
|
+
|
|
362
|
+
if should_flush:
|
|
363
|
+
self._send_batch(batch)
|
|
364
|
+
batch = []
|
|
365
|
+
last_flush = now
|
|
366
|
+
|
|
367
|
+
except Exception:
|
|
368
|
+
pass
|
|
369
|
+
|
|
370
|
+
# Final flush on shutdown
|
|
371
|
+
if batch:
|
|
372
|
+
self._send_batch(batch)
|
|
373
|
+
|
|
374
|
+
def _send_batch(self, batch: List[Dict[str, Any]]):
|
|
375
|
+
"""Send batch to Kalibr backend."""
|
|
376
|
+
if not batch:
|
|
377
|
+
return
|
|
378
|
+
|
|
379
|
+
try:
|
|
380
|
+
payload = {"events": batch}
|
|
381
|
+
|
|
382
|
+
headers = {}
|
|
383
|
+
if self.api_key:
|
|
384
|
+
headers["X-API-Key"] = self.api_key
|
|
385
|
+
|
|
386
|
+
if self.secret:
|
|
387
|
+
import json
|
|
388
|
+
body = json.dumps(payload).encode("utf-8")
|
|
389
|
+
signature = hmac.new(
|
|
390
|
+
self.secret.encode(), body, hashlib.sha256
|
|
391
|
+
).hexdigest()
|
|
392
|
+
headers["X-Signature"] = signature
|
|
393
|
+
|
|
394
|
+
response = self._client.post(
|
|
395
|
+
self.endpoint,
|
|
396
|
+
json=payload,
|
|
397
|
+
headers=headers,
|
|
398
|
+
)
|
|
399
|
+
response.raise_for_status()
|
|
400
|
+
|
|
401
|
+
except Exception as e:
|
|
402
|
+
# Log error but don't raise
|
|
403
|
+
pass
|
|
404
|
+
|
|
405
|
+
def shutdown(self):
    """Stop the flusher thread and close the HTTP client.

    Sets the shutdown flag, waits up to 5 seconds for a live flusher thread
    to finish, then closes the client. Calls made after the flag is already
    set return immediately.
    """
    if not self._shutdown:
        self._shutdown = True
        worker = self._flush_thread
        if worker.is_alive():
            worker.join(timeout=5.0)
        self._client.close()
|
|
416
|
+
|
|
417
|
+
# =========================================================================
|
|
418
|
+
# LLM Callbacks
|
|
419
|
+
# =========================================================================
|
|
420
|
+
|
|
421
|
+
def on_llm_start(
    self,
    serialized: Dict[str, Any],
    prompts: List[str],
    *,
    run_id: uuid.UUID,
    parent_run_id: Optional[uuid.UUID] = None,
    tags: Optional[List[str]] = None,
    metadata: Optional[Dict[str, Any]] = None,
    **kwargs: Any,
) -> None:
    """Open an ``llm_call`` span for a completion-style LLM invocation.

    The model name is the first non-empty value among ``invocation_params``
    ``model_name``/``model`` and the serialized constructor kwargs
    ``model_name``/``model``, falling back to "unknown". Input tokens are
    estimated from the prompts joined with newlines. A missing or ``None``
    ``serialized``/``invocation_params`` is treated as empty.
    """
    run_id_str = str(run_id)
    parent_id_str = str(parent_run_id) if parent_run_id else None
    trace_id = self._get_or_create_trace_id(parent_id_str)

    # Extract model info
    invocation_params = kwargs.get("invocation_params") or {}
    serialized_kwargs = (serialized or {}).get("kwargs") or {}
    model = "unknown"
    for source, key in (
        (invocation_params, "model_name"),
        (invocation_params, "model"),
        (serialized_kwargs, "model_name"),
        (serialized_kwargs, "model"),
    ):
        candidate = source.get(key)
        if candidate and candidate != "unknown":
            model = str(candidate)
            break

    provider = _get_provider_from_model(model)

    # Calculate input tokens
    prompt_text = "\n".join(prompts)
    input_tokens = _count_tokens(prompt_text, model)

    # model/provider are passed to start_span explicitly below, so they must
    # NOT also appear in this dict: duplicating them would raise
    # "TypeError: got multiple values for keyword argument 'model'" and the
    # span would never be recorded.
    span_fields: Dict[str, Any] = {
        "tags": tags or [],
        "input_tokens": input_tokens,
    }
    if self.capture_input:
        span_fields["input"] = self._truncate(prompt_text)

    self._span_tracker.start_span(
        run_id=run_id_str,
        trace_id=trace_id,
        parent_run_id=parent_id_str,
        operation="llm_call",
        span_type="llm",
        model=model,
        provider=provider,
        endpoint=f"{provider}.{model}",
        **span_fields,
    )
|
|
469
|
+
|
|
470
|
+
def on_llm_end(
    self,
    response: LLMResult,
    *,
    run_id: uuid.UUID,
    parent_run_id: Optional[uuid.UUID] = None,
    **kwargs: Any,
) -> None:
    """Close the LLM span and queue an event with token usage and output.

    Token counts come from the first source that provides them:

    1. ``response.llm_output["token_usage"]`` (``prompt_tokens`` /
       ``completion_tokens``).
    2. ``response.llm_output["usage"]``, when it is a dict with
       ``input_tokens`` / ``output_tokens``.
    3. ``usage_metadata`` (``input_tokens`` / ``output_tokens``) on the first
       generation message that carries one.

    If no source reports a prompt count, the estimate made at span start is
    kept. If no output count is reported (or it is 0 or ``None``), the output
    is estimated from the generated text.
    """
    run_id_str = str(run_id)
    span = self._span_tracker.end_span(run_id_str)

    if not span:
        return

    generations = response.generations or []
    llm_output = response.llm_output or {}

    # Get token usage from LLM response
    usage = llm_output.get("token_usage")
    if usage:
        reported_in = usage.get("prompt_tokens")
        reported_out = usage.get("completion_tokens")
    else:
        usage = llm_output.get("usage")
        if not isinstance(usage, dict):
            usage = {}
            for gen in (g for gen_list in generations for g in gen_list):
                message_usage = getattr(
                    getattr(gen, "message", None), "usage_metadata", None
                )
                if message_usage:
                    usage = message_usage
                    break
        reported_in = usage.get("input_tokens")
        reported_out = usage.get("output_tokens")

    input_tokens = span.get("input_tokens", 0) if reported_in is None else reported_in
    # `or 0` also guards against an explicit None, which would otherwise
    # break the token total computed for the event.
    output_tokens = reported_out or 0

    # Extract output text
    output_parts: List[str] = []
    for gen_list in generations:
        for gen in gen_list:
            if hasattr(gen, "text"):
                output_parts.append(gen.text)
            elif hasattr(gen, "message") and hasattr(gen.message, "content"):
                content = gen.message.content
                output_parts.append(content if isinstance(content, str) else str(content))
    output_text = "\n".join(output_parts)

    # Fallback token count from output
    if output_tokens == 0:
        output_tokens = _count_tokens(output_text, span.get("model", "unknown"))

    # Build metadata
    event_metadata: Dict[str, Any] = {
        "tags": span.get("tags", []),
    }
    if self.capture_output and output_text:
        event_metadata["output"] = self._truncate(output_text)
    if self.capture_input and "input" in span:
        event_metadata["input"] = span["input"]

    event = self._create_event(
        span=span,
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        metadata=event_metadata,
    )

    self._enqueue_event(event)
|
|
529
|
+
|
|
530
|
+
def on_llm_error(
    self,
    error: BaseException,
    *,
    run_id: uuid.UUID,
    parent_run_id: Optional[uuid.UUID] = None,
    **kwargs: Any,
) -> None:
    """Finish an LLM span as failed and queue an error event.

    Does nothing when the run has no tracked span. Otherwise the span's
    status becomes "error" and the event reports the span's estimated input
    tokens, zero output tokens, the exception class name, the first 512
    characters of its message, and up to 2000 characters of formatted
    traceback.
    """
    failed = self._span_tracker.end_span(str(run_id))
    if not failed:
        return

    failed["status"] = "error"
    formatted = traceback.format_exception(type(error), error, error.__traceback__)
    details = {
        "tags": failed.get("tags", []),
        "stack_trace": "".join(formatted)[:2000],
    }
    self._enqueue_event(
        self._create_event(
            span=failed,
            input_tokens=failed.get("input_tokens", 0),
            output_tokens=0,
            error_type=type(error).__name__,
            error_message=str(error)[:512],
            metadata=details,
        )
    )
|
|
562
|
+
|
|
563
|
+
# =========================================================================
|
|
564
|
+
# Chat Model Callbacks
|
|
565
|
+
# =========================================================================
|
|
566
|
+
|
|
567
|
+
def on_chat_model_start(
    self,
    serialized: Dict[str, Any],
    messages: List[List[BaseMessage]],
    *,
    run_id: uuid.UUID,
    parent_run_id: Optional[uuid.UUID] = None,
    tags: Optional[List[str]] = None,
    metadata: Optional[Dict[str, Any]] = None,
    **kwargs: Any,
) -> None:
    """Open a ``chat_completion`` span for a chat-model invocation.

    The model name is the first non-empty value among ``invocation_params``
    ``model``/``model_name`` and the serialized constructor kwargs
    ``model``/``model_name``, falling back to "unknown". Input tokens are
    estimated from the text content of all messages: string contents, plus
    the string items and string ``"text"`` entries of list-style contents,
    each followed by a newline.
    """
    run_id_str = str(run_id)
    parent_id_str = str(parent_run_id) if parent_run_id else None
    trace_id = self._get_or_create_trace_id(parent_id_str)

    # Extract model info (None-safe: skip missing/empty values).
    invocation_params = kwargs.get("invocation_params") or {}
    serialized_kwargs = (serialized or {}).get("kwargs") or {}
    model = "unknown"
    for source, key in (
        (invocation_params, "model"),
        (invocation_params, "model_name"),
        (serialized_kwargs, "model"),
        (serialized_kwargs, "model_name"),
    ):
        candidate = source.get(key)
        if candidate and candidate != "unknown":
            model = str(candidate)
            break

    provider = _get_provider_from_model(model)

    # Collect text pieces, then join once instead of repeated string +=.
    pieces: List[str] = []
    for msg_list in messages:
        for msg in msg_list:
            content = getattr(msg, "content", None)
            if isinstance(content, str):
                pieces.append(content)
            elif isinstance(content, list):
                for item in content:
                    if isinstance(item, str):
                        pieces.append(item)
                    elif isinstance(item, dict) and isinstance(item.get("text"), str):
                        pieces.append(item["text"])
    message_text = "".join(piece + "\n" for piece in pieces)

    input_tokens = _count_tokens(message_text, model)

    # model/provider are passed to start_span explicitly below, so they must
    # NOT also appear in this dict: duplicating them would raise
    # "TypeError: got multiple values for keyword argument 'model'" and the
    # span would never be recorded.
    span_fields: Dict[str, Any] = {
        "tags": tags or [],
        "input_tokens": input_tokens,
        "message_count": sum(len(msg_list) for msg_list in messages),
    }
    if self.capture_input:
        span_fields["input"] = self._truncate(message_text)

    self._span_tracker.start_span(
        run_id=run_id_str,
        trace_id=trace_id,
        parent_run_id=parent_id_str,
        operation="chat_completion",
        span_type="chat",
        model=model,
        provider=provider,
        endpoint=f"{provider}.chat.completions",
        **span_fields,
    )
|
|
633
|
+
|
|
634
|
+
# =========================================================================
|
|
635
|
+
# Chain Callbacks
|
|
636
|
+
# =========================================================================
|
|
637
|
+
|
|
638
|
+
def on_chain_start(
    self,
    serialized: Dict[str, Any],
    inputs: Dict[str, Any],
    *,
    run_id: uuid.UUID,
    parent_run_id: Optional[uuid.UUID] = None,
    tags: Optional[List[str]] = None,
    metadata: Optional[Dict[str, Any]] = None,
    **kwargs: Any,
) -> None:
    """Open a ``chain:<name>`` span for a chain/runnable invocation.

    The chain name is the first non-empty value among ``serialized["name"]``,
    the last element of ``serialized["id"]``, the ``name`` keyword argument,
    and finally "unknown". A ``None`` ``serialized`` or an empty ``id`` list
    is tolerated instead of raising.
    """
    run_id_str = str(run_id)
    parent_id_str = str(parent_run_id) if parent_run_id else None
    trace_id = self._get_or_create_trace_id(parent_id_str)

    # Get chain name
    info = serialized or {}
    id_path = info.get("id") or []
    chain_name = (
        info.get("name")
        or (id_path[-1] if id_path else None)
        or kwargs.get("name")
        or "unknown"
    )

    span_fields: Dict[str, Any] = {
        "tags": tags or [],
        "chain_name": chain_name,
    }
    if self.capture_input:
        input_str = str(_serialize_for_metadata(inputs))
        span_fields["input"] = self._truncate(input_str)

    self._span_tracker.start_span(
        run_id=run_id_str,
        trace_id=trace_id,
        parent_run_id=parent_id_str,
        operation=f"chain:{chain_name}",
        span_type="chain",
        model="chain",
        provider="langchain",
        endpoint=chain_name,
        **span_fields,
    )
|
|
677
|
+
|
|
678
|
+
def on_chain_end(
    self,
    outputs: Dict[str, Any],
    *,
    run_id: uuid.UUID,
    parent_run_id: Optional[uuid.UUID] = None,
    **kwargs: Any,
) -> None:
    """Close a chain span and queue its event.

    Does nothing when the run has no tracked span. Event metadata always has
    the span's tags and chain name. It also carries the truncated,
    stringified outputs when output capture is on, and the input recorded at
    span start when input capture is on.
    """
    finished = self._span_tracker.end_span(str(run_id))
    if not finished:
        return

    details: Dict[str, Any] = dict(
        tags=finished.get("tags", []),
        chain_name=finished.get("chain_name", "unknown"),
    )
    if self.capture_output:
        rendered = str(_serialize_for_metadata(outputs))
        details["output"] = self._truncate(rendered)
    if self.capture_input and "input" in finished:
        details["input"] = finished["input"]

    self._enqueue_event(self._create_event(span=finished, metadata=details))
|
|
710
|
+
|
|
711
|
+
def on_chain_error(
    self,
    error: BaseException,
    *,
    run_id: uuid.UUID,
    parent_run_id: Optional[uuid.UUID] = None,
    **kwargs: Any,
) -> None:
    """Finish a chain span as failed and queue an error event.

    Does nothing when the run has no tracked span. Otherwise the span's
    status becomes "error" and the event carries the exception class name,
    the first 512 characters of its message, the span's tags and chain name,
    and up to 2000 characters of formatted traceback.
    """
    failed = self._span_tracker.end_span(str(run_id))
    if not failed:
        return

    failed["status"] = "error"
    trace_text = "".join(
        traceback.format_exception(type(error), error, error.__traceback__)
    )
    details = {
        "tags": failed.get("tags", []),
        "chain_name": failed.get("chain_name", "unknown"),
        "stack_trace": trace_text[:2000],
    }
    self._enqueue_event(
        self._create_event(
            span=failed,
            error_type=type(error).__name__,
            error_message=str(error)[:512],
            metadata=details,
        )
    )
|
|
742
|
+
|
|
743
|
+
# =========================================================================
|
|
744
|
+
# Tool Callbacks
|
|
745
|
+
# =========================================================================
|
|
746
|
+
|
|
747
|
+
def on_tool_start(
    self,
    serialized: Dict[str, Any],
    input_str: str,
    *,
    run_id: uuid.UUID,
    parent_run_id: Optional[uuid.UUID] = None,
    tags: Optional[List[str]] = None,
    metadata: Optional[Dict[str, Any]] = None,
    **kwargs: Any,
) -> None:
    """Open a ``tool:<name>`` span for a tool invocation.

    The tool name is ``serialized["name"]``, then the ``name`` keyword
    argument, then "unknown_tool". A ``None`` ``serialized`` is treated as
    empty, and a ``None`` ``input_str`` is captured as "" instead of raising.
    """
    run_id_str = str(run_id)
    parent_id_str = str(parent_run_id) if parent_run_id else None
    trace_id = self._get_or_create_trace_id(parent_id_str)

    tool_name = (serialized or {}).get("name") or kwargs.get("name") or "unknown_tool"

    span_fields: Dict[str, Any] = {
        "tags": tags or [],
        "tool_name": tool_name,
    }
    if self.capture_input:
        text = "" if input_str is None else str(input_str)
        span_fields["input"] = self._truncate(text)

    self._span_tracker.start_span(
        run_id=run_id_str,
        trace_id=trace_id,
        parent_run_id=parent_id_str,
        operation=f"tool:{tool_name}",
        span_type="tool",
        model="tool",
        provider="langchain",
        endpoint=tool_name,
        **span_fields,
    )
|
|
784
|
+
|
|
785
|
+
def on_tool_end(
    self,
    output: Any,
    *,
    run_id: uuid.UUID,
    parent_run_id: Optional[uuid.UUID] = None,
    **kwargs: Any,
) -> None:
    """Close a tool span and queue its event.

    Does nothing when the run has no tracked span. Event metadata always has
    the span's tags and tool name. It also carries the truncated,
    stringified tool output when output capture is on, and the input
    recorded at span start when input capture is on.
    """
    finished = self._span_tracker.end_span(str(run_id))
    if not finished:
        return

    details: Dict[str, Any] = dict(
        tags=finished.get("tags", []),
        tool_name=finished.get("tool_name", "unknown"),
    )
    if self.capture_output:
        details["output"] = self._truncate(str(_serialize_for_metadata(output)))
    if self.capture_input and "input" in finished:
        details["input"] = finished["input"]

    event = self._create_event(span=finished, metadata=details)
    self._enqueue_event(event)
|
|
817
|
+
|
|
818
|
+
def on_tool_error(
    self,
    error: BaseException,
    *,
    run_id: uuid.UUID,
    parent_run_id: Optional[uuid.UUID] = None,
    **kwargs: Any,
) -> None:
    """Close the tool span for ``run_id`` and enqueue an error event.

    The span is marked ``status="error"``. The event records:

    - the exception class name;
    - the first 512 characters of its message;
    - tags and tool name;
    - up to 2000 characters of the formatted traceback.

    Run IDs without a tracked span are ignored.
    """
    finished = self._span_tracker.end_span(str(run_id))
    if not finished:
        return

    finished["status"] = "error"
    formatted_tb = traceback.format_exception(type(error), error, error.__traceback__)
    details = {
        "tags": finished.get("tags", []),
        "tool_name": finished.get("tool_name", "unknown"),
        "stack_trace": "".join(formatted_tb)[:2000],
    }
    self._enqueue_event(
        self._create_event(
            span=finished,
            error_type=type(error).__name__,
            error_message=str(error)[:512],
            metadata=details,
        )
    )
|
|
849
|
+
|
|
850
|
+
# =========================================================================
|
|
851
|
+
# Agent Callbacks
|
|
852
|
+
# =========================================================================
|
|
853
|
+
|
|
854
|
+
def on_agent_action(
    self,
    action: AgentAction,
    *,
    run_id: uuid.UUID,
    parent_run_id: Optional[uuid.UUID] = None,
    **kwargs: Any,
) -> None:
    """Record the agent's latest action on the span of the chain running it.

    Agent actions do not get a span of their own. LangChain emits
    ``on_agent_action`` from the agent executor's chain run manager, so
    ``run_id`` identifies the agent's own chain run and ``parent_run_id`` its
    caller (``None`` for a top-level agent). The action is therefore written
    to the ``run_id`` span when that span is tracked. Otherwise it falls back
    to the parent span so the information is not dropped.

    Args:
        action: The action chosen by the agent; ``tool`` and ``tool_input``
            are recorded.
        run_id: ID of the chain run that produced the action.
        parent_run_id: ID of the enclosing run, if any.
        **kwargs: Extra callback arguments (unused).
    """
    run_id_str = str(run_id)
    if self._span_tracker.get_span(run_id_str):
        target_id = run_id_str
    elif parent_run_id:
        # The agent's own span is not tracked; annotate the caller's span instead.
        target_id = str(parent_run_id)
    else:
        return

    self._span_tracker.update_span(
        target_id,
        last_action=action.tool,
        last_action_input=(
            self._truncate(str(action.tool_input)) if self.capture_input else None
        ),
    )
|
|
873
|
+
|
|
874
|
+
def on_agent_finish(
    self,
    finish: AgentFinish,
    *,
    run_id: uuid.UUID,
    parent_run_id: Optional[uuid.UUID] = None,
    **kwargs: Any,
) -> None:
    """Mark the span of the chain running the agent as finished.

    Like ``on_agent_action``, this callback is emitted from the agent
    executor's chain run manager. ``run_id`` is the agent's own chain run and
    ``parent_run_id`` its caller. The span for ``run_id`` is updated when
    tracked; otherwise the parent span is used as a fallback. The span gets
    ``agent_finish=True`` and, when ``capture_output`` is set, the truncated
    ``return_values``.

    Args:
        finish: The agent's final result; ``return_values`` is recorded.
        run_id: ID of the chain run that finished.
        parent_run_id: ID of the enclosing run, if any.
        **kwargs: Extra callback arguments (unused).
    """
    run_id_str = str(run_id)
    if self._span_tracker.get_span(run_id_str):
        target_id = run_id_str
    elif parent_run_id:
        # The agent's own span is not tracked; annotate the caller's span instead.
        target_id = str(parent_run_id)
    else:
        return

    self._span_tracker.update_span(
        target_id,
        agent_finish=True,
        return_values=(
            self._truncate(str(finish.return_values)) if self.capture_output else None
        ),
    )
|
|
892
|
+
|
|
893
|
+
# =========================================================================
|
|
894
|
+
# Retriever Callbacks
|
|
895
|
+
# =========================================================================
|
|
896
|
+
|
|
897
|
+
def on_retriever_start(
    self,
    serialized: Optional[Dict[str, Any]],
    query: str,
    *,
    run_id: uuid.UUID,
    parent_run_id: Optional[uuid.UUID] = None,
    tags: Optional[List[str]] = None,
    metadata: Optional[Dict[str, Any]] = None,
    **kwargs: Any,
) -> None:
    """Open a ``retriever`` span for ``run_id``.

    The span joins the trace resolved from ``parent_run_id``. It records the
    retriever name, the tags and, when ``capture_input`` is set, the
    truncated query.

    Args:
        serialized: Serialized retriever description. LangChain may pass
            ``None`` here and supply the name through the ``name`` keyword,
            so a missing dict is tolerated.
        query: The retrieval query.
        run_id: ID of this retriever run; used as the span key.
        parent_run_id: ID of the enclosing run, if any.
        tags: Tags attached to the run.
        metadata: Run metadata (not recorded by this handler).
        **kwargs: Extra callback arguments; ``name`` is used when
            ``serialized`` carries no name.
    """
    run_id_str = str(run_id)
    parent_id_str = str(parent_run_id) if parent_run_id else None
    trace_id = self._get_or_create_trace_id(parent_id_str)

    # Resolve the name without assuming ``serialized`` is a dict or that its
    # ``id`` path is non-empty. The class-name fallback is computed only from
    # a non-empty list/tuple, so an empty ``id`` cannot raise IndexError.
    serialized = serialized or {}
    id_path = serialized.get("id")
    class_name = id_path[-1] if isinstance(id_path, (list, tuple)) and id_path else None
    retriever_name = (
        serialized.get("name") or kwargs.get("name") or class_name or "unknown"
    )

    span_metadata = {
        "tags": tags or [],
        "retriever_name": retriever_name,
    }
    if self.capture_input:
        span_metadata["query"] = self._truncate(query)

    self._span_tracker.start_span(
        run_id=run_id_str,
        trace_id=trace_id,
        parent_run_id=parent_id_str,
        operation=f"retriever:{retriever_name}",
        span_type="retriever",
        model="retriever",
        provider="langchain",
        endpoint=retriever_name,
        **span_metadata,
    )
|
|
934
|
+
|
|
935
|
+
def on_retriever_end(
    self,
    documents: Sequence[Document],
    *,
    run_id: uuid.UUID,
    parent_run_id: Optional[uuid.UUID] = None,
    **kwargs: Any,
) -> None:
    """Close the retriever span for ``run_id`` and enqueue its completion event.

    The event metadata always contains tags, the retriever name and the
    number of returned documents. It also contains:

    - the stored query, when ``capture_input`` is set;
    - previews of at most the first five documents, when ``capture_output``
      is set and documents were returned. Each preview holds 200 characters
      of content plus serialized metadata.

    Run IDs without a tracked span are ignored.
    """
    finished = self._span_tracker.end_span(str(run_id))
    if not finished:
        return

    details = {
        "tags": finished.get("tags", []),
        "retriever_name": finished.get("retriever_name", "unknown"),
        "document_count": len(documents),
    }
    if self.capture_input and "query" in finished:
        details["query"] = finished["query"]
    if self.capture_output and documents:
        # Only the first five documents are summarised to bound event size.
        details["documents"] = [
            {
                "content_preview": self._truncate(doc.page_content[:200]),
                "metadata": _serialize_for_metadata(doc.metadata),
            }
            for doc in documents[:5]
        ]

    self._enqueue_event(self._create_event(span=finished, metadata=details))
|
|
976
|
+
|
|
977
|
+
def on_retriever_error(
    self,
    error: BaseException,
    *,
    run_id: uuid.UUID,
    parent_run_id: Optional[uuid.UUID] = None,
    **kwargs: Any,
) -> None:
    """Close the retriever span for ``run_id`` and enqueue an error event.

    The span is marked ``status="error"``. The event records:

    - the exception class name;
    - the first 512 characters of its message;
    - tags and retriever name;
    - up to 2000 characters of the formatted traceback.

    Run IDs without a tracked span are ignored.
    """
    finished = self._span_tracker.end_span(str(run_id))
    if not finished:
        return

    finished["status"] = "error"
    formatted_tb = traceback.format_exception(type(error), error, error.__traceback__)
    details = {
        "tags": finished.get("tags", []),
        "retriever_name": finished.get("retriever_name", "unknown"),
        "stack_trace": "".join(formatted_tb)[:2000],
    }
    self._enqueue_event(
        self._create_event(
            span=finished,
            error_type=type(error).__name__,
            error_message=str(error)[:512],
            metadata=details,
        )
    )
|
|
1008
|
+
|
|
1009
|
+
# =========================================================================
|
|
1010
|
+
# Text/Streaming Callbacks
|
|
1011
|
+
# =========================================================================
|
|
1012
|
+
|
|
1013
|
+
def on_llm_new_token(
    self,
    token: str,
    *,
    run_id: uuid.UUID,
    parent_run_id: Optional[uuid.UUID] = None,
    **kwargs: Any,
) -> None:
    """Append a streamed token to the ``streaming_tokens`` text of the span.

    No event is emitted per token. Tokens for run IDs without a tracked span
    are dropped.
    """
    key = str(run_id)
    tracked = self._span_tracker.get_span(key)
    if not tracked:
        return
    buffered = tracked.get("streaming_tokens", "")
    self._span_tracker.update_span(key, streaming_tokens=buffered + token)
|
|
1032
|
+
|
|
1033
|
+
def on_text(
    self,
    text: str,
    *,
    run_id: uuid.UUID,
    parent_run_id: Optional[uuid.UUID] = None,
    **kwargs: Any,
) -> None:
    """Accept a text callback without recording anything.

    This handler intentionally does not create spans or events for
    ``on_text``; the method exists so the callback is a harmless no-op.
    """
    return None
|
|
1044
|
+
|
|
1045
|
+
# =========================================================================
|
|
1046
|
+
# Utility Methods
|
|
1047
|
+
# =========================================================================
|
|
1048
|
+
|
|
1049
|
+
def flush(self):
    """Drain every event currently in the queue and send them as one batch.

    Nothing is sent when the queue is already empty.
    """
    pending = []
    try:
        while True:
            pending.append(self._event_queue.get_nowait())
    except queue.Empty:
        pass

    if pending:
        self._send_batch(pending)
|
|
1061
|
+
|
|
1062
|
+
def get_trace_id(self) -> Optional[str]:
    """Return the root trace ID held by this handler (may be ``None``)."""
    root_id = self._root_trace_id
    return root_id
|