kalibr 1.0.28__py3-none-any.whl → 1.1.3a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. kalibr/__init__.py +129 -4
  2. kalibr/__main__.py +3 -203
  3. kalibr/capsule_middleware.py +108 -0
  4. kalibr/cli/__init__.py +5 -0
  5. kalibr/cli/capsule_cmd.py +174 -0
  6. kalibr/cli/deploy_cmd.py +114 -0
  7. kalibr/cli/main.py +67 -0
  8. kalibr/cli/run.py +203 -0
  9. kalibr/cli/serve.py +59 -0
  10. kalibr/client.py +293 -0
  11. kalibr/collector.py +173 -0
  12. kalibr/context.py +132 -0
  13. kalibr/cost_adapter.py +222 -0
  14. kalibr/decorators.py +140 -0
  15. kalibr/instrumentation/__init__.py +13 -0
  16. kalibr/instrumentation/anthropic_instr.py +282 -0
  17. kalibr/instrumentation/base.py +108 -0
  18. kalibr/instrumentation/google_instr.py +281 -0
  19. kalibr/instrumentation/openai_instr.py +265 -0
  20. kalibr/instrumentation/registry.py +153 -0
  21. kalibr/kalibr.py +144 -230
  22. kalibr/kalibr_app.py +53 -314
  23. kalibr/middleware/__init__.py +5 -0
  24. kalibr/middleware/auto_tracer.py +356 -0
  25. kalibr/models.py +41 -0
  26. kalibr/redaction.py +44 -0
  27. kalibr/schemas.py +116 -0
  28. kalibr/simple_tracer.py +258 -0
  29. kalibr/tokens.py +52 -0
  30. kalibr/trace_capsule.py +296 -0
  31. kalibr/trace_models.py +201 -0
  32. kalibr/tracer.py +354 -0
  33. kalibr/types.py +25 -93
  34. kalibr/utils.py +198 -0
  35. kalibr-1.1.3a0.dist-info/METADATA +236 -0
  36. kalibr-1.1.3a0.dist-info/RECORD +48 -0
  37. kalibr-1.1.3a0.dist-info/entry_points.txt +2 -0
  38. kalibr-1.1.3a0.dist-info/licenses/LICENSE +21 -0
  39. kalibr-1.1.3a0.dist-info/top_level.txt +4 -0
  40. kalibr_crewai/__init__.py +65 -0
  41. kalibr_crewai/callbacks.py +539 -0
  42. kalibr_crewai/instrumentor.py +513 -0
  43. kalibr_langchain/__init__.py +47 -0
  44. kalibr_langchain/async_callback.py +850 -0
  45. kalibr_langchain/callback.py +1064 -0
  46. kalibr_openai_agents/__init__.py +43 -0
  47. kalibr_openai_agents/processor.py +554 -0
  48. kalibr/deployment.py +0 -41
  49. kalibr/packager.py +0 -43
  50. kalibr/runtime_router.py +0 -138
  51. kalibr/schema_generators.py +0 -159
  52. kalibr/validator.py +0 -70
  53. kalibr-1.0.28.data/data/examples/README.md +0 -173
  54. kalibr-1.0.28.data/data/examples/basic_kalibr_example.py +0 -66
  55. kalibr-1.0.28.data/data/examples/enhanced_kalibr_example.py +0 -347
  56. kalibr-1.0.28.dist-info/METADATA +0 -175
  57. kalibr-1.0.28.dist-info/RECORD +0 -19
  58. kalibr-1.0.28.dist-info/entry_points.txt +0 -2
  59. kalibr-1.0.28.dist-info/licenses/LICENSE +0 -11
  60. kalibr-1.0.28.dist-info/top_level.txt +0 -1
  61. {kalibr-1.0.28.dist-info → kalibr-1.1.3a0.dist-info}/WHEEL +0 -0
@@ -0,0 +1,1064 @@
1
+ """Kalibr Callback Handler for LangChain.
2
+
3
+ This module provides the main callback handler that integrates LangChain
4
+ with Kalibr's observability platform.
5
+ """
6
+
7
+ import atexit
8
+ import hashlib
9
+ import hmac
10
+ import os
11
+ import queue
12
+ import threading
13
+ import time
14
+ import traceback
15
+ import uuid
16
+ from datetime import datetime, timezone
17
+ from typing import Any, Dict, List, Optional, Sequence, Union
18
+
19
+ import httpx
20
+ from langchain_core.callbacks import BaseCallbackHandler
21
+ from langchain_core.agents import AgentAction, AgentFinish
22
+ from langchain_core.documents import Document
23
+ from langchain_core.messages import BaseMessage
24
+ from langchain_core.outputs import ChatGeneration, Generation, LLMResult
25
+
26
+ # Import Kalibr cost adapters
27
+ try:
28
+ from kalibr.cost_adapter import CostAdapterFactory
29
+ except ImportError:
30
+ CostAdapterFactory = None
31
+
32
+ # Import tiktoken for token counting
33
+ try:
34
+ import tiktoken
35
+ HAS_TIKTOKEN = True
36
+ except ImportError:
37
+ HAS_TIKTOKEN = False
38
+
39
+
40
+ def _count_tokens(text: str, model: str) -> int:
41
+ """Count tokens for given text and model."""
42
+ if not text:
43
+ return 0
44
+
45
+ if HAS_TIKTOKEN and "gpt" in model.lower():
46
+ try:
47
+ encoding = tiktoken.encoding_for_model(model)
48
+ return len(encoding.encode(text))
49
+ except Exception:
50
+ pass
51
+
52
+ # Fallback: approximate (1 token ~= 4 chars)
53
+ return len(text) // 4
54
+
55
+
56
+ def _get_provider_from_model(model: str) -> str:
57
+ """Infer provider from model name."""
58
+ model_lower = model.lower()
59
+
60
+ if any(x in model_lower for x in ["gpt", "text-davinci", "text-embedding", "whisper", "dall-e"]):
61
+ return "openai"
62
+ elif any(x in model_lower for x in ["claude"]):
63
+ return "anthropic"
64
+ elif any(x in model_lower for x in ["gemini", "palm", "bison"]):
65
+ return "google"
66
+ elif any(x in model_lower for x in ["cohere", "command"]):
67
+ return "cohere"
68
+ else:
69
+ return "custom"
70
+
71
+
72
+ def _serialize_for_metadata(obj: Any) -> Any:
73
+ """Serialize objects for JSON metadata."""
74
+ if obj is None:
75
+ return None
76
+ elif isinstance(obj, (str, int, float, bool)):
77
+ return obj
78
+ elif isinstance(obj, (list, tuple)):
79
+ return [_serialize_for_metadata(item) for item in obj]
80
+ elif isinstance(obj, dict):
81
+ return {k: _serialize_for_metadata(v) for k, v in obj.items()}
82
+ elif hasattr(obj, "dict"):
83
+ return obj.dict()
84
+ elif hasattr(obj, "__dict__"):
85
+ return {k: _serialize_for_metadata(v) for k, v in obj.__dict__.items()
86
+ if not k.startswith("_")}
87
+ else:
88
+ return str(obj)
89
+
90
+
91
class SpanTracker:
    """Thread-safe registry of in-flight spans keyed by LangChain run id."""

    def __init__(self):
        # run_id -> span payload; _lock guards cross-thread access.
        self.spans: Dict[str, Dict[str, Any]] = {}
        self._lock = threading.Lock()

    def start_span(
        self,
        run_id: str,
        trace_id: str,
        parent_run_id: Optional[str],
        operation: str,
        span_type: str,
        **kwargs
    ) -> Dict[str, Any]:
        """Register and return a new span for *run_id*.

        If *parent_run_id* names an active span, the new span is linked to
        it via ``parent_span_id``. Extra keyword arguments are merged into
        the span payload (and may override the defaults).
        """
        new_span_id = str(uuid.uuid4())

        with self._lock:
            parent = self.spans.get(parent_run_id) if parent_run_id else None
            record: Dict[str, Any] = {
                "span_id": new_span_id,
                "trace_id": trace_id,
                "parent_span_id": parent.get("span_id") if parent else None,
                "operation": operation,
                "span_type": span_type,
                "ts_start": datetime.now(timezone.utc),
                "status": "success",
            }
            record.update(kwargs)
            self.spans[run_id] = record
        return record

    def end_span(self, run_id: str) -> Optional[Dict[str, Any]]:
        """Close, remove, and return the span for *run_id* (None if unknown).

        Stamps ``ts_end`` and the integer ``duration_ms``.
        """
        with self._lock:
            record = self.spans.pop(run_id, None)
            if record is None:
                return None
            finished = datetime.now(timezone.utc)
            record["ts_end"] = finished
            elapsed = (finished - record["ts_start"]).total_seconds()
            record["duration_ms"] = int(elapsed * 1000)
            return record

    def update_span(self, run_id: str, **kwargs):
        """Merge *kwargs* into an active span; no-op if the span is gone."""
        with self._lock:
            record = self.spans.get(run_id)
            if record is not None:
                record.update(kwargs)

    def get_span(self, run_id: str) -> Optional[Dict[str, Any]]:
        """Return the active span for *run_id*, or None."""
        with self._lock:
            return self.spans.get(run_id)
150
+
151
+
152
class KalibrCallbackHandler(BaseCallbackHandler):
    """LangChain callback handler for Kalibr observability.

    This handler captures telemetry from LangChain components and sends
    them to the Kalibr backend for analysis and visualization.

    Supported callbacks:
    - LLM start/end/error
    - Chat model start/end
    - Chain start/end/error
    - Tool start/end/error
    - Agent action/finish
    - Retriever start/end
    - Text generation (streaming)

    Args:
        api_key: Kalibr API key (or KALIBR_API_KEY env var)
        endpoint: Backend endpoint URL (or KALIBR_ENDPOINT env var)
        tenant_id: Tenant identifier (or KALIBR_TENANT_ID env var)
        environment: Environment name (or KALIBR_ENVIRONMENT env var)
        service: Service name (or KALIBR_SERVICE env var)
        workflow_id: Workflow identifier for grouping traces
        secret: HMAC secret for request signing
        batch_size: Max events per batch (default: 100)
        flush_interval: Flush interval in seconds (default: 2.0)
        capture_input: Whether to capture input prompts (default: True)
        capture_output: Whether to capture outputs (default: True)
        max_content_length: Max length for captured content (default: 10000)
        metadata: Additional metadata to include in all events
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        endpoint: Optional[str] = None,
        tenant_id: Optional[str] = None,
        environment: Optional[str] = None,
        service: Optional[str] = None,
        workflow_id: Optional[str] = None,
        secret: Optional[str] = None,
        batch_size: int = 100,
        flush_interval: float = 2.0,
        capture_input: bool = True,
        capture_output: bool = True,
        max_content_length: int = 10000,
        metadata: Optional[Dict[str, Any]] = None,
    ):
        super().__init__()

        # Configuration: explicit arguments win over environment variables.
        self.api_key = api_key or os.getenv("KALIBR_API_KEY", "")
        self.endpoint = endpoint or os.getenv(
            "KALIBR_ENDPOINT",
            os.getenv("KALIBR_API_ENDPOINT", "https://api.kalibr.systems/api/v1/traces")
        )
        self.tenant_id = tenant_id or os.getenv("KALIBR_TENANT_ID", "default")
        self.environment = environment or os.getenv("KALIBR_ENVIRONMENT", "prod")
        self.service = service or os.getenv("KALIBR_SERVICE", "langchain-app")
        self.workflow_id = workflow_id or os.getenv("KALIBR_WORKFLOW_ID", "default-workflow")
        self.secret = secret

        # Content capture settings
        self.capture_input = capture_input
        self.capture_output = capture_output
        self.max_content_length = max_content_length
        self.default_metadata = metadata or {}

        # Batching configuration
        self.batch_size = batch_size
        self.flush_interval = flush_interval

        # Span tracking
        self._span_tracker = SpanTracker()

        # Root trace ID (created per top-level chain/llm call)
        self._root_trace_id: Optional[str] = None
        self._trace_lock = threading.Lock()

        # Event queue for batching; bounded so a dead backend cannot grow
        # memory without limit (overflow drops oldest, see _enqueue_event).
        self._event_queue: queue.Queue = queue.Queue(maxsize=5000)

        # HTTP client
        self._client = httpx.Client(timeout=10.0)

        # Background flusher thread
        self._shutdown = False
        self._flush_thread = threading.Thread(target=self._flush_loop, daemon=True)
        self._flush_thread.start()

        # Register cleanup so pending events are flushed on interpreter exit.
        atexit.register(self.shutdown)

    def _get_or_create_trace_id(self, parent_run_id: Optional[str]) -> str:
        """Get existing trace ID or create a new one for root spans.

        A root span (no parent) starts a fresh trace; child spans reuse the
        most recent root trace ID.
        """
        with self._trace_lock:
            if parent_run_id is None:
                # This is a root span, create new trace
                self._root_trace_id = str(uuid.uuid4())
            # Defensive fallback: a child may arrive before any root span.
            return self._root_trace_id or str(uuid.uuid4())

    def _truncate(self, text: str) -> str:
        """Truncate text to max_content_length, marking the cut."""
        if len(text) > self.max_content_length:
            return text[:self.max_content_length] + "...[truncated]"
        return text

    def _compute_cost(
        self,
        provider: str,
        model: str,
        input_tokens: int,
        output_tokens: int
    ) -> float:
        """Compute USD cost via Kalibr cost adapters (0.0 if unavailable)."""
        if CostAdapterFactory is not None:
            return CostAdapterFactory.compute_cost(
                vendor=provider,
                model_name=model,
                tokens_in=input_tokens,
                tokens_out=output_tokens
            )
        return 0.0

    def _create_event(
        self,
        span: Dict[str, Any],
        input_tokens: int = 0,
        output_tokens: int = 0,
        error_type: Optional[str] = None,
        error_message: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Create a standardized trace event from span data.

        Args:
            span: Closed span payload from the SpanTracker.
            input_tokens: Prompt-side token count.
            output_tokens: Completion-side token count.
            error_type: Exception class name for failed spans.
            error_message: Truncated error description for failed spans.
            metadata: Per-event metadata merged over the handler defaults.

        Returns:
            A dict matching the Kalibr trace-event schema (v1.0).
        """
        provider = span.get("provider", "custom")
        model = span.get("model", "unknown")

        # Compute cost
        cost_usd = self._compute_cost(provider, model, input_tokens, output_tokens)

        # Build event
        event = {
            "schema_version": "1.0",
            "trace_id": span["trace_id"],
            "span_id": span["span_id"],
            "parent_span_id": span.get("parent_span_id"),
            "tenant_id": self.tenant_id,
            "workflow_id": self.workflow_id,
            "provider": provider,
            "model_id": model,
            "model_name": model,
            "operation": span["operation"],
            "endpoint": span.get("endpoint", span["operation"]),
            "duration_ms": span.get("duration_ms", 0),
            "latency_ms": span.get("duration_ms", 0),
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
            "total_tokens": input_tokens + output_tokens,
            "cost_usd": cost_usd,
            "total_cost_usd": cost_usd,
            "status": span.get("status", "success"),
            "error_type": error_type,
            "error_message": error_message,
            "timestamp": span["ts_start"].isoformat(),
            "ts_start": span["ts_start"].isoformat(),
            "ts_end": span.get("ts_end", datetime.now(timezone.utc)).isoformat(),
            "environment": self.environment,
            "service": self.service,
            "runtime_env": os.getenv("RUNTIME_ENV", "local"),
            "sandbox_id": os.getenv("SANDBOX_ID", "local"),
            "metadata": {
                **self.default_metadata,
                "span_type": span.get("span_type", "llm"),
                "langchain": True,
                **(metadata or {}),
            },
        }

        return event

    def _enqueue_event(self, event: Dict[str, Any]):
        """Add event to the batching queue, dropping the oldest on overflow."""
        try:
            self._event_queue.put_nowait(event)
        except queue.Full:
            # Drop oldest event and retry once; a race with the flusher
            # thread may make either call fail, in which case the event
            # is dropped (telemetry is best-effort).
            try:
                self._event_queue.get_nowait()
                self._event_queue.put_nowait(event)
            except (queue.Empty, queue.Full):
                pass

    def _flush_loop(self):
        """Background thread to flush events periodically.

        Flushes when either batch_size is reached or flush_interval has
        elapsed with pending events; drains the final batch on shutdown.
        """
        batch = []
        last_flush = time.time()

        while not self._shutdown:
            try:
                try:
                    event = self._event_queue.get(timeout=0.1)
                    batch.append(event)
                except queue.Empty:
                    pass

                now = time.time()
                should_flush = (
                    len(batch) >= self.batch_size or
                    (batch and now - last_flush >= self.flush_interval)
                )

                if should_flush:
                    self._send_batch(batch)
                    batch = []
                    last_flush = now

            except Exception:
                # Never let the flusher thread die on a transient error.
                pass

        # Final flush on shutdown
        if batch:
            self._send_batch(batch)

    def _send_batch(self, batch: List[Dict[str, Any]]):
        """Send a batch of events to the Kalibr backend (best-effort).

        When an HMAC secret is configured, the signature is computed over
        the exact bytes posted: the request uses ``content=`` rather than
        ``json=`` so httpx cannot re-serialize the body (e.g. with
        different whitespace) and invalidate the signature.
        """
        if not batch:
            return

        try:
            payload = {"events": batch}

            headers = {}
            if self.api_key:
                headers["X-API-Key"] = self.api_key

            if self.secret:
                import json
                body = json.dumps(payload).encode("utf-8")
                signature = hmac.new(
                    self.secret.encode(), body, hashlib.sha256
                ).hexdigest()
                headers["X-Signature"] = signature
                headers["Content-Type"] = "application/json"
                response = self._client.post(
                    self.endpoint,
                    content=body,
                    headers=headers,
                )
            else:
                response = self._client.post(
                    self.endpoint,
                    json=payload,
                    headers=headers,
                )
            response.raise_for_status()

        except Exception:
            # Deliberate best-effort: telemetry delivery failures must
            # never break the host application.
            pass

    def shutdown(self):
        """Shutdown handler and flush remaining events. Idempotent."""
        if self._shutdown:
            return

        self._shutdown = True

        if self._flush_thread.is_alive():
            self._flush_thread.join(timeout=5.0)

        self._client.close()

    # =========================================================================
    # LLM Callbacks
    # =========================================================================

    def on_llm_start(
        self,
        serialized: Dict[str, Any],
        prompts: List[str],
        *,
        run_id: uuid.UUID,
        parent_run_id: Optional[uuid.UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        """Called when LLM starts generating; opens an "llm_call" span."""
        run_id_str = str(run_id)
        parent_id_str = str(parent_run_id) if parent_run_id else None
        trace_id = self._get_or_create_trace_id(parent_id_str)

        # Extract model info: prefer invocation params, fall back to the
        # serialized constructor kwargs.
        model = kwargs.get("invocation_params", {}).get("model_name", "unknown")
        if model == "unknown":
            model = serialized.get("kwargs", {}).get("model_name", "unknown")

        provider = _get_provider_from_model(model)

        # Calculate input tokens
        prompt_text = "\n".join(prompts)
        input_tokens = _count_tokens(prompt_text, model)

        span_metadata = {
            "tags": tags or [],
            "model": model,
            "provider": provider,
            "input_tokens": input_tokens,
        }

        if self.capture_input:
            span_metadata["input"] = self._truncate(prompt_text)

        self._span_tracker.start_span(
            run_id=run_id_str,
            trace_id=trace_id,
            parent_run_id=parent_id_str,
            operation="llm_call",
            span_type="llm",
            model=model,
            provider=provider,
            endpoint=f"{provider}.{model}",
            **span_metadata,
        )

    def on_llm_end(
        self,
        response: LLMResult,
        *,
        run_id: uuid.UUID,
        parent_run_id: Optional[uuid.UUID] = None,
        **kwargs: Any,
    ) -> None:
        """Called when LLM finishes; closes the span and emits an event."""
        run_id_str = str(run_id)
        span = self._span_tracker.end_span(run_id_str)

        if not span:
            return

        # Extract token usage from response; start from the counts
        # estimated at span start.
        input_tokens = span.get("input_tokens", 0)
        output_tokens = 0
        output_text = ""

        # Get exact token usage from the LLM response when the provider
        # reports it.
        if response.llm_output:
            token_usage = response.llm_output.get("token_usage", {})
            if token_usage:
                input_tokens = token_usage.get("prompt_tokens", input_tokens)
                output_tokens = token_usage.get("completion_tokens", 0)

        # Extract output text from plain generations or chat messages.
        if response.generations:
            output_parts = []
            for gen_list in response.generations:
                for gen in gen_list:
                    if hasattr(gen, "text"):
                        output_parts.append(gen.text)
                    elif hasattr(gen, "message") and hasattr(gen.message, "content"):
                        output_parts.append(gen.message.content)
            output_text = "\n".join(output_parts)

        # Fallback token count from output
        if output_tokens == 0:
            output_tokens = _count_tokens(output_text, span.get("model", "unknown"))

        # Build metadata
        event_metadata = {
            "tags": span.get("tags", []),
        }
        if self.capture_output and output_text:
            event_metadata["output"] = self._truncate(output_text)
        if self.capture_input and "input" in span:
            event_metadata["input"] = span["input"]

        event = self._create_event(
            span=span,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            metadata=event_metadata,
        )

        self._enqueue_event(event)

    def on_llm_error(
        self,
        error: BaseException,
        *,
        run_id: uuid.UUID,
        parent_run_id: Optional[uuid.UUID] = None,
        **kwargs: Any,
    ) -> None:
        """Called when LLM errors; closes the span with error details."""
        run_id_str = str(run_id)
        span = self._span_tracker.end_span(run_id_str)

        if not span:
            return

        span["status"] = "error"

        event = self._create_event(
            span=span,
            input_tokens=span.get("input_tokens", 0),
            output_tokens=0,
            error_type=type(error).__name__,
            error_message=str(error)[:512],
            metadata={
                "tags": span.get("tags", []),
                "stack_trace": "".join(traceback.format_exception(
                    type(error), error, error.__traceback__
                ))[:2000],
            },
        )

        self._enqueue_event(event)

    # =========================================================================
    # Chat Model Callbacks
    # =========================================================================

    def on_chat_model_start(
        self,
        serialized: Dict[str, Any],
        messages: List[List[BaseMessage]],
        *,
        run_id: uuid.UUID,
        parent_run_id: Optional[uuid.UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        """Called when chat model starts; opens a "chat_completion" span.

        Completion is reported through on_llm_end/on_llm_error.
        """
        run_id_str = str(run_id)
        parent_id_str = str(parent_run_id) if parent_run_id else None
        trace_id = self._get_or_create_trace_id(parent_id_str)

        # Extract model info, checking all the places LangChain may put it.
        model = kwargs.get("invocation_params", {}).get("model", "unknown")
        if model == "unknown":
            model = kwargs.get("invocation_params", {}).get("model_name", "unknown")
        if model == "unknown":
            model = serialized.get("kwargs", {}).get("model", "unknown")
        if model == "unknown":
            model = serialized.get("kwargs", {}).get("model_name", "unknown")

        provider = _get_provider_from_model(model)

        # Flatten message content (plain strings and multimodal text parts)
        # to estimate input tokens.
        message_text = ""
        for msg_list in messages:
            for msg in msg_list:
                if hasattr(msg, "content"):
                    content = msg.content
                    if isinstance(content, str):
                        message_text += content + "\n"
                    elif isinstance(content, list):
                        for item in content:
                            if isinstance(item, str):
                                message_text += item + "\n"
                            elif isinstance(item, dict) and "text" in item:
                                message_text += item["text"] + "\n"

        input_tokens = _count_tokens(message_text, model)

        span_metadata = {
            "tags": tags or [],
            "model": model,
            "provider": provider,
            "input_tokens": input_tokens,
            "message_count": sum(len(msg_list) for msg_list in messages),
        }

        if self.capture_input:
            span_metadata["input"] = self._truncate(message_text)

        self._span_tracker.start_span(
            run_id=run_id_str,
            trace_id=trace_id,
            parent_run_id=parent_id_str,
            operation="chat_completion",
            span_type="chat",
            model=model,
            provider=provider,
            endpoint=f"{provider}.chat.completions",
            **span_metadata,
        )

    # =========================================================================
    # Chain Callbacks
    # =========================================================================

    def on_chain_start(
        self,
        serialized: Dict[str, Any],
        inputs: Dict[str, Any],
        *,
        run_id: uuid.UUID,
        parent_run_id: Optional[uuid.UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        """Called when chain starts; opens a "chain:<name>" span."""
        run_id_str = str(run_id)
        parent_id_str = str(parent_run_id) if parent_run_id else None
        trace_id = self._get_or_create_trace_id(parent_id_str)

        # Get chain name; "id" is a module path list, so take its last part.
        chain_name = serialized.get("name", serialized.get("id", ["unknown"])[-1])

        span_metadata = {
            "tags": tags or [],
            "chain_name": chain_name,
        }

        if self.capture_input:
            input_str = str(_serialize_for_metadata(inputs))
            span_metadata["input"] = self._truncate(input_str)

        self._span_tracker.start_span(
            run_id=run_id_str,
            trace_id=trace_id,
            parent_run_id=parent_id_str,
            operation=f"chain:{chain_name}",
            span_type="chain",
            model="chain",
            provider="langchain",
            endpoint=chain_name,
            **span_metadata,
        )

    def on_chain_end(
        self,
        outputs: Dict[str, Any],
        *,
        run_id: uuid.UUID,
        parent_run_id: Optional[uuid.UUID] = None,
        **kwargs: Any,
    ) -> None:
        """Called when chain ends; closes the span and emits an event."""
        run_id_str = str(run_id)
        span = self._span_tracker.end_span(run_id_str)

        if not span:
            return

        event_metadata = {
            "tags": span.get("tags", []),
            "chain_name": span.get("chain_name", "unknown"),
        }

        if self.capture_output:
            output_str = str(_serialize_for_metadata(outputs))
            event_metadata["output"] = self._truncate(output_str)
        if self.capture_input and "input" in span:
            event_metadata["input"] = span["input"]

        event = self._create_event(
            span=span,
            metadata=event_metadata,
        )

        self._enqueue_event(event)

    def on_chain_error(
        self,
        error: BaseException,
        *,
        run_id: uuid.UUID,
        parent_run_id: Optional[uuid.UUID] = None,
        **kwargs: Any,
    ) -> None:
        """Called when chain errors; closes the span with error details."""
        run_id_str = str(run_id)
        span = self._span_tracker.end_span(run_id_str)

        if not span:
            return

        span["status"] = "error"

        event = self._create_event(
            span=span,
            error_type=type(error).__name__,
            error_message=str(error)[:512],
            metadata={
                "tags": span.get("tags", []),
                "chain_name": span.get("chain_name", "unknown"),
                "stack_trace": "".join(traceback.format_exception(
                    type(error), error, error.__traceback__
                ))[:2000],
            },
        )

        self._enqueue_event(event)

    # =========================================================================
    # Tool Callbacks
    # =========================================================================

    def on_tool_start(
        self,
        serialized: Dict[str, Any],
        input_str: str,
        *,
        run_id: uuid.UUID,
        parent_run_id: Optional[uuid.UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        """Called when tool starts; opens a "tool:<name>" span."""
        run_id_str = str(run_id)
        parent_id_str = str(parent_run_id) if parent_run_id else None
        trace_id = self._get_or_create_trace_id(parent_id_str)

        tool_name = serialized.get("name", "unknown_tool")

        span_metadata = {
            "tags": tags or [],
            "tool_name": tool_name,
        }

        if self.capture_input:
            span_metadata["input"] = self._truncate(input_str)

        self._span_tracker.start_span(
            run_id=run_id_str,
            trace_id=trace_id,
            parent_run_id=parent_id_str,
            operation=f"tool:{tool_name}",
            span_type="tool",
            model="tool",
            provider="langchain",
            endpoint=tool_name,
            **span_metadata,
        )

    def on_tool_end(
        self,
        output: Any,
        *,
        run_id: uuid.UUID,
        parent_run_id: Optional[uuid.UUID] = None,
        **kwargs: Any,
    ) -> None:
        """Called when tool ends; closes the span and emits an event."""
        run_id_str = str(run_id)
        span = self._span_tracker.end_span(run_id_str)

        if not span:
            return

        event_metadata = {
            "tags": span.get("tags", []),
            "tool_name": span.get("tool_name", "unknown"),
        }

        if self.capture_output:
            output_str = str(_serialize_for_metadata(output))
            event_metadata["output"] = self._truncate(output_str)
        if self.capture_input and "input" in span:
            event_metadata["input"] = span["input"]

        event = self._create_event(
            span=span,
            metadata=event_metadata,
        )

        self._enqueue_event(event)

    def on_tool_error(
        self,
        error: BaseException,
        *,
        run_id: uuid.UUID,
        parent_run_id: Optional[uuid.UUID] = None,
        **kwargs: Any,
    ) -> None:
        """Called when tool errors; closes the span with error details."""
        run_id_str = str(run_id)
        span = self._span_tracker.end_span(run_id_str)

        if not span:
            return

        span["status"] = "error"

        event = self._create_event(
            span=span,
            error_type=type(error).__name__,
            error_message=str(error)[:512],
            metadata={
                "tags": span.get("tags", []),
                "tool_name": span.get("tool_name", "unknown"),
                "stack_trace": "".join(traceback.format_exception(
                    type(error), error, error.__traceback__
                ))[:2000],
            },
        )

        self._enqueue_event(event)

    # =========================================================================
    # Agent Callbacks
    # =========================================================================

    def on_agent_action(
        self,
        action: AgentAction,
        *,
        run_id: uuid.UUID,
        parent_run_id: Optional[uuid.UUID] = None,
        **kwargs: Any,
    ) -> None:
        """Called when agent takes an action.

        Agent actions are tracked as metadata on the parent chain span
        rather than as separate events.
        """
        parent_id_str = str(parent_run_id) if parent_run_id else None
        if parent_id_str:
            self._span_tracker.update_span(
                parent_id_str,
                last_action=action.tool,
                last_action_input=self._truncate(str(action.tool_input))
                if self.capture_input else None,
            )

    def on_agent_finish(
        self,
        finish: AgentFinish,
        *,
        run_id: uuid.UUID,
        parent_run_id: Optional[uuid.UUID] = None,
        **kwargs: Any,
    ) -> None:
        """Called when agent finishes.

        Recorded as metadata on the parent chain span.
        """
        parent_id_str = str(parent_run_id) if parent_run_id else None
        if parent_id_str:
            self._span_tracker.update_span(
                parent_id_str,
                agent_finish=True,
                return_values=self._truncate(str(finish.return_values))
                if self.capture_output else None,
            )

    # =========================================================================
    # Retriever Callbacks
    # =========================================================================

    def on_retriever_start(
        self,
        serialized: Dict[str, Any],
        query: str,
        *,
        run_id: uuid.UUID,
        parent_run_id: Optional[uuid.UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        """Called when retriever starts; opens a "retriever:<name>" span."""
        run_id_str = str(run_id)
        parent_id_str = str(parent_run_id) if parent_run_id else None
        trace_id = self._get_or_create_trace_id(parent_id_str)

        retriever_name = serialized.get("name", serialized.get("id", ["unknown"])[-1])

        span_metadata = {
            "tags": tags or [],
            "retriever_name": retriever_name,
        }

        if self.capture_input:
            span_metadata["query"] = self._truncate(query)

        self._span_tracker.start_span(
            run_id=run_id_str,
            trace_id=trace_id,
            parent_run_id=parent_id_str,
            operation=f"retriever:{retriever_name}",
            span_type="retriever",
            model="retriever",
            provider="langchain",
            endpoint=retriever_name,
            **span_metadata,
        )

    def on_retriever_end(
        self,
        documents: Sequence[Document],
        *,
        run_id: uuid.UUID,
        parent_run_id: Optional[uuid.UUID] = None,
        **kwargs: Any,
    ) -> None:
        """Called when retriever ends; closes the span and emits an event."""
        run_id_str = str(run_id)
        span = self._span_tracker.end_span(run_id_str)

        if not span:
            return

        event_metadata = {
            "tags": span.get("tags", []),
            "retriever_name": span.get("retriever_name", "unknown"),
            "document_count": len(documents),
        }

        if self.capture_input and "query" in span:
            event_metadata["query"] = span["query"]

        if self.capture_output and documents:
            # Capture document summaries (first 5 docs, 200-char previews)
            # to bound event size.
            doc_summaries = []
            for doc in documents[:5]:
                summary = {
                    "content_preview": self._truncate(doc.page_content[:200]),
                    "metadata": _serialize_for_metadata(doc.metadata),
                }
                doc_summaries.append(summary)
            event_metadata["documents"] = doc_summaries

        event = self._create_event(
            span=span,
            metadata=event_metadata,
        )

        self._enqueue_event(event)

    def on_retriever_error(
        self,
        error: BaseException,
        *,
        run_id: uuid.UUID,
        parent_run_id: Optional[uuid.UUID] = None,
        **kwargs: Any,
    ) -> None:
        """Called when retriever errors; closes the span with error details."""
        run_id_str = str(run_id)
        span = self._span_tracker.end_span(run_id_str)

        if not span:
            return

        span["status"] = "error"

        event = self._create_event(
            span=span,
            error_type=type(error).__name__,
            error_message=str(error)[:512],
            metadata={
                "tags": span.get("tags", []),
                "retriever_name": span.get("retriever_name", "unknown"),
                "stack_trace": "".join(traceback.format_exception(
                    type(error), error, error.__traceback__
                ))[:2000],
            },
        )

        self._enqueue_event(event)

    # =========================================================================
    # Text/Streaming Callbacks
    # =========================================================================

    def on_llm_new_token(
        self,
        token: str,
        *,
        run_id: uuid.UUID,
        parent_run_id: Optional[uuid.UUID] = None,
        **kwargs: Any,
    ) -> None:
        """Called on each new token during streaming.

        Tokens are accumulated on the span; no per-token event is sent.
        """
        run_id_str = str(run_id)
        span = self._span_tracker.get_span(run_id_str)

        if span:
            current_tokens = span.get("streaming_tokens", "")
            self._span_tracker.update_span(
                run_id_str,
                streaming_tokens=current_tokens + token,
            )

    def on_text(
        self,
        text: str,
        *,
        run_id: uuid.UUID,
        parent_run_id: Optional[uuid.UUID] = None,
        **kwargs: Any,
    ) -> None:
        """Called when text is generated. Intentionally a no-op."""
        pass

    # =========================================================================
    # Utility Methods
    # =========================================================================

    def flush(self):
        """Force flush all pending events synchronously.

        Drains the queue on the calling thread; the background flusher may
        concurrently consume events, in which case those are delivered by
        it instead.
        """
        events = []
        while True:
            try:
                event = self._event_queue.get_nowait()
                events.append(event)
            except queue.Empty:
                break

        if events:
            self._send_batch(events)

    def get_trace_id(self) -> Optional[str]:
        """Get the current root trace ID (None before the first root span)."""
        return self._root_trace_id