kalibr-1.0.28-py3-none-any.whl → kalibr-1.1.2a0-py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (61)
  1. kalibr/__init__.py +170 -3
  2. kalibr/__main__.py +3 -203
  3. kalibr/capsule_middleware.py +108 -0
  4. kalibr/cli/__init__.py +5 -0
  5. kalibr/cli/capsule_cmd.py +174 -0
  6. kalibr/cli/deploy_cmd.py +114 -0
  7. kalibr/cli/main.py +67 -0
  8. kalibr/cli/run.py +203 -0
  9. kalibr/cli/serve.py +59 -0
  10. kalibr/client.py +293 -0
  11. kalibr/collector.py +173 -0
  12. kalibr/context.py +132 -0
  13. kalibr/cost_adapter.py +222 -0
  14. kalibr/decorators.py +140 -0
  15. kalibr/instrumentation/__init__.py +13 -0
  16. kalibr/instrumentation/anthropic_instr.py +282 -0
  17. kalibr/instrumentation/base.py +108 -0
  18. kalibr/instrumentation/google_instr.py +281 -0
  19. kalibr/instrumentation/openai_instr.py +265 -0
  20. kalibr/instrumentation/registry.py +153 -0
  21. kalibr/kalibr.py +144 -230
  22. kalibr/kalibr_app.py +53 -314
  23. kalibr/middleware/__init__.py +5 -0
  24. kalibr/middleware/auto_tracer.py +356 -0
  25. kalibr/models.py +41 -0
  26. kalibr/redaction.py +44 -0
  27. kalibr/schemas.py +116 -0
  28. kalibr/simple_tracer.py +258 -0
  29. kalibr/tokens.py +52 -0
  30. kalibr/trace_capsule.py +296 -0
  31. kalibr/trace_models.py +201 -0
  32. kalibr/tracer.py +354 -0
  33. kalibr/types.py +25 -93
  34. kalibr/utils.py +198 -0
  35. kalibr-1.1.2a0.dist-info/METADATA +236 -0
  36. kalibr-1.1.2a0.dist-info/RECORD +48 -0
  37. kalibr-1.1.2a0.dist-info/entry_points.txt +2 -0
  38. kalibr-1.1.2a0.dist-info/licenses/LICENSE +21 -0
  39. kalibr-1.1.2a0.dist-info/top_level.txt +4 -0
  40. kalibr_crewai/__init__.py +65 -0
  41. kalibr_crewai/callbacks.py +539 -0
  42. kalibr_crewai/instrumentor.py +513 -0
  43. kalibr_langchain/__init__.py +47 -0
  44. kalibr_langchain/async_callback.py +850 -0
  45. kalibr_langchain/callback.py +1064 -0
  46. kalibr_openai_agents/__init__.py +43 -0
  47. kalibr_openai_agents/processor.py +554 -0
  48. kalibr/deployment.py +0 -41
  49. kalibr/packager.py +0 -43
  50. kalibr/runtime_router.py +0 -138
  51. kalibr/schema_generators.py +0 -159
  52. kalibr/validator.py +0 -70
  53. kalibr-1.0.28.data/data/examples/README.md +0 -173
  54. kalibr-1.0.28.data/data/examples/basic_kalibr_example.py +0 -66
  55. kalibr-1.0.28.data/data/examples/enhanced_kalibr_example.py +0 -347
  56. kalibr-1.0.28.dist-info/METADATA +0 -175
  57. kalibr-1.0.28.dist-info/RECORD +0 -19
  58. kalibr-1.0.28.dist-info/entry_points.txt +0 -2
  59. kalibr-1.0.28.dist-info/licenses/LICENSE +0 -11
  60. kalibr-1.0.28.dist-info/top_level.txt +0 -1
  61. {kalibr-1.0.28.dist-info → kalibr-1.1.2a0.dist-info}/WHEEL +0 -0
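
Most of the new surface area is the async LangChain integration added in kalibr_langchain/async_callback.py; its full diff follows below. For orientation, here is a minimal, hypothetical configuration sketch based on the KALIBR_* environment variables the new handler's constructor reads. The values shown are placeholders, not package defaults except where noted.

import os

# Placeholder values; when no constructor arguments are given, the handler
# falls back to these variables (the defaults are visible in __init__ below).
os.environ["KALIBR_API_KEY"] = "kal-..."                               # default: "" (no auth header sent)
os.environ["KALIBR_ENDPOINT"] = "http://localhost:8001/api/v1/traces"  # same as the built-in default
os.environ["KALIBR_TENANT_ID"] = "acme"                                # default: "default"
os.environ["KALIBR_ENVIRONMENT"] = "staging"                           # default: "prod"
os.environ["KALIBR_SERVICE"] = "support-bot"                           # default: "langchain-app"
os.environ["KALIBR_WORKFLOW_ID"] = "ticket-triage"                     # default: "default-workflow"

from kalibr_langchain.async_callback import AsyncKalibrCallbackHandler

handler = AsyncKalibrCallbackHandler()  # picks up the KALIBR_* variables set above
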
kalibr_langchain/async_callback.py (new file)
@@ -0,0 +1,850 @@
+ """Async Kalibr Callback Handler for LangChain.
+
+ This module provides an async-compatible callback handler for use with
+ async LangChain operations.
+ """
+
+ import asyncio
+ import os
+ import traceback
+ import uuid
+ from datetime import datetime, timezone
+ from typing import Any, Dict, List, Optional, Sequence, Union
+
+ import httpx
+ from langchain_core.callbacks import AsyncCallbackHandler
+ from langchain_core.agents import AgentAction, AgentFinish
+ from langchain_core.documents import Document
+ from langchain_core.messages import BaseMessage
+ from langchain_core.outputs import LLMResult
+
+ from .callback import (
+     SpanTracker,
+     _count_tokens,
+     _get_provider_from_model,
+     _serialize_for_metadata,
+     CostAdapterFactory,
+ )
+
+
+ class AsyncKalibrCallbackHandler(AsyncCallbackHandler):
+     """Async LangChain callback handler for Kalibr observability.
+
+     This handler is designed for async LangChain operations and uses
+     async HTTP calls to send telemetry.
+
+     Args:
+         api_key: Kalibr API key (or KALIBR_API_KEY env var)
+         endpoint: Backend endpoint URL (or KALIBR_ENDPOINT env var)
+         tenant_id: Tenant identifier (or KALIBR_TENANT_ID env var)
+         environment: Environment name (or KALIBR_ENVIRONMENT env var)
+         service: Service name (or KALIBR_SERVICE env var)
+         workflow_id: Workflow identifier for grouping traces
+         capture_input: Whether to capture input prompts (default: True)
+         capture_output: Whether to capture outputs (default: True)
+         max_content_length: Max length for captured content (default: 10000)
+         metadata: Additional metadata to include in all events
+     """
+
+     def __init__(
+         self,
+         api_key: Optional[str] = None,
+         endpoint: Optional[str] = None,
+         tenant_id: Optional[str] = None,
+         environment: Optional[str] = None,
+         service: Optional[str] = None,
+         workflow_id: Optional[str] = None,
+         capture_input: bool = True,
+         capture_output: bool = True,
+         max_content_length: int = 10000,
+         metadata: Optional[Dict[str, Any]] = None,
+     ):
+         super().__init__()
+
+         # Configuration
+         self.api_key = api_key or os.getenv("KALIBR_API_KEY", "")
+         self.endpoint = endpoint or os.getenv(
+             "KALIBR_ENDPOINT",
+             os.getenv("KALIBR_API_ENDPOINT", "http://localhost:8001/api/v1/traces")
+         )
+         self.tenant_id = tenant_id or os.getenv("KALIBR_TENANT_ID", "default")
+         self.environment = environment or os.getenv("KALIBR_ENVIRONMENT", "prod")
+         self.service = service or os.getenv("KALIBR_SERVICE", "langchain-app")
+         self.workflow_id = workflow_id or os.getenv("KALIBR_WORKFLOW_ID", "default-workflow")
+
+         # Content capture settings
+         self.capture_input = capture_input
+         self.capture_output = capture_output
+         self.max_content_length = max_content_length
+         self.default_metadata = metadata or {}
+
+         # Span tracking
+         self._span_tracker = SpanTracker()
+
+         # Root trace ID
+         self._root_trace_id: Optional[str] = None
+         self._trace_lock = asyncio.Lock()
+
+         # Async HTTP client
+         self._client: Optional[httpx.AsyncClient] = None
+
+         # Event buffer for batching
+         self._event_buffer: List[Dict[str, Any]] = []
+         self._buffer_lock = asyncio.Lock()
+
+     async def _get_client(self) -> httpx.AsyncClient:
+         """Get or create async HTTP client."""
+         if self._client is None:
+             self._client = httpx.AsyncClient(timeout=10.0)
+         return self._client
+
+     async def _get_or_create_trace_id(self, parent_run_id: Optional[str]) -> str:
+         """Get existing trace ID or create a new one for root spans."""
+         async with self._trace_lock:
+             if parent_run_id is None:
+                 self._root_trace_id = str(uuid.uuid4())
+             return self._root_trace_id or str(uuid.uuid4())
+
+     def _truncate(self, text: str) -> str:
+         """Truncate text to max length."""
+         if len(text) > self.max_content_length:
+             return text[:self.max_content_length] + "...[truncated]"
+         return text
+
+     def _compute_cost(
+         self,
+         provider: str,
+         model: str,
+         input_tokens: int,
+         output_tokens: int
+     ) -> float:
+         """Compute cost using Kalibr cost adapters."""
+         if CostAdapterFactory is not None:
+             return CostAdapterFactory.compute_cost(
+                 vendor=provider,
+                 model_name=model,
+                 tokens_in=input_tokens,
+                 tokens_out=output_tokens
+             )
+         return 0.0
+
+     def _create_event(
+         self,
+         span: Dict[str, Any],
+         input_tokens: int = 0,
+         output_tokens: int = 0,
+         error_type: Optional[str] = None,
+         error_message: Optional[str] = None,
+         metadata: Optional[Dict[str, Any]] = None,
+     ) -> Dict[str, Any]:
+         """Create a standardized trace event from span data."""
+         provider = span.get("provider", "custom")
+         model = span.get("model", "unknown")
+
+         cost_usd = self._compute_cost(provider, model, input_tokens, output_tokens)
+
+         event = {
+             "schema_version": "1.0",
+             "trace_id": span["trace_id"],
+             "span_id": span["span_id"],
+             "parent_span_id": span.get("parent_span_id"),
+             "tenant_id": self.tenant_id,
+             "workflow_id": self.workflow_id,
+             "provider": provider,
+             "model_id": model,
+             "model_name": model,
+             "operation": span["operation"],
+             "endpoint": span.get("endpoint", span["operation"]),
+             "duration_ms": span.get("duration_ms", 0),
+             "latency_ms": span.get("duration_ms", 0),
+             "input_tokens": input_tokens,
+             "output_tokens": output_tokens,
+             "total_tokens": input_tokens + output_tokens,
+             "cost_usd": cost_usd,
+             "total_cost_usd": cost_usd,
+             "status": span.get("status", "success"),
+             "error_type": error_type,
+             "error_message": error_message,
+             "timestamp": span["ts_start"].isoformat(),
+             "ts_start": span["ts_start"].isoformat(),
+             "ts_end": span.get("ts_end", datetime.now(timezone.utc)).isoformat(),
+             "environment": self.environment,
+             "service": self.service,
+             "runtime_env": os.getenv("RUNTIME_ENV", "local"),
+             "sandbox_id": os.getenv("SANDBOX_ID", "local"),
+             "metadata": {
+                 **self.default_metadata,
+                 "span_type": span.get("span_type", "llm"),
+                 "langchain": True,
+                 "async": True,
+                 **(metadata or {}),
+             },
+         }
+
+         return event
+
+     async def _send_event(self, event: Dict[str, Any]):
+         """Send a single event to Kalibr backend."""
+         async with self._buffer_lock:
+             self._event_buffer.append(event)
+
+             # Flush if buffer is large enough
+             if len(self._event_buffer) >= 10:
+                 await self._flush_buffer()
+
+     async def _flush_buffer(self):
+         """Flush event buffer to backend."""
+         if not self._event_buffer:
+             return
+
+         events = self._event_buffer.copy()
+         self._event_buffer.clear()
+
+         try:
+             client = await self._get_client()
+             payload = {"events": events}
+
+             headers = {}
+             if self.api_key:
+                 headers["X-API-Key"] = self.api_key
+
+             await client.post(
+                 self.endpoint,
+                 json=payload,
+                 headers=headers,
+             )
+         except Exception:
+             # Log error but don't raise
+             pass
+
+     async def close(self):
+         """Close the handler and flush remaining events."""
+         async with self._buffer_lock:
+             await self._flush_buffer()
+
+         if self._client:
+             await self._client.aclose()
+             self._client = None
+
+     # =========================================================================
+     # LLM Callbacks
+     # =========================================================================
+
+     async def on_llm_start(
+         self,
+         serialized: Dict[str, Any],
+         prompts: List[str],
+         *,
+         run_id: uuid.UUID,
+         parent_run_id: Optional[uuid.UUID] = None,
+         tags: Optional[List[str]] = None,
+         metadata: Optional[Dict[str, Any]] = None,
+         **kwargs: Any,
+     ) -> None:
+         """Called when LLM starts generating."""
+         run_id_str = str(run_id)
+         parent_id_str = str(parent_run_id) if parent_run_id else None
+         trace_id = await self._get_or_create_trace_id(parent_id_str)
+
+         model = kwargs.get("invocation_params", {}).get("model_name", "unknown")
+         if model == "unknown":
+             model = serialized.get("kwargs", {}).get("model_name", "unknown")
+
+         provider = _get_provider_from_model(model)
+
+         prompt_text = "\n".join(prompts)
+         input_tokens = _count_tokens(prompt_text, model)
+
+         span_metadata = {
+             "tags": tags or [],
+             "model": model,
+             "provider": provider,
+             "input_tokens": input_tokens,
+         }
+
+         if self.capture_input:
+             span_metadata["input"] = self._truncate(prompt_text)
+
+         self._span_tracker.start_span(
+             run_id=run_id_str,
+             trace_id=trace_id,
+             parent_run_id=parent_id_str,
+             operation="llm_call",
+             span_type="llm",
+             model=model,
+             provider=provider,
+             endpoint=f"{provider}.{model}",
+             **span_metadata,
+         )
+
+     async def on_llm_end(
+         self,
+         response: LLMResult,
+         *,
+         run_id: uuid.UUID,
+         parent_run_id: Optional[uuid.UUID] = None,
+         **kwargs: Any,
+     ) -> None:
+         """Called when LLM finishes generating."""
+         run_id_str = str(run_id)
+         span = self._span_tracker.end_span(run_id_str)
+
+         if not span:
+             return
+
+         input_tokens = span.get("input_tokens", 0)
+         output_tokens = 0
+         output_text = ""
+
+         if response.llm_output:
+             token_usage = response.llm_output.get("token_usage", {})
+             if token_usage:
+                 input_tokens = token_usage.get("prompt_tokens", input_tokens)
+                 output_tokens = token_usage.get("completion_tokens", 0)
+
+         if response.generations:
+             output_parts = []
+             for gen_list in response.generations:
+                 for gen in gen_list:
+                     if hasattr(gen, "text"):
+                         output_parts.append(gen.text)
+                     elif hasattr(gen, "message") and hasattr(gen.message, "content"):
+                         output_parts.append(gen.message.content)
+             output_text = "\n".join(output_parts)
+
+         if output_tokens == 0:
+             output_tokens = _count_tokens(output_text, span.get("model", "unknown"))
+
+         event_metadata = {
+             "tags": span.get("tags", []),
+         }
+         if self.capture_output and output_text:
+             event_metadata["output"] = self._truncate(output_text)
+         if self.capture_input and "input" in span:
+             event_metadata["input"] = span["input"]
+
+         event = self._create_event(
+             span=span,
+             input_tokens=input_tokens,
+             output_tokens=output_tokens,
+             metadata=event_metadata,
+         )
+
+         await self._send_event(event)
+
+     async def on_llm_error(
+         self,
+         error: BaseException,
+         *,
+         run_id: uuid.UUID,
+         parent_run_id: Optional[uuid.UUID] = None,
+         **kwargs: Any,
+     ) -> None:
+         """Called when LLM errors."""
+         run_id_str = str(run_id)
+         span = self._span_tracker.end_span(run_id_str)
+
+         if not span:
+             return
+
+         span["status"] = "error"
+
+         event = self._create_event(
+             span=span,
+             input_tokens=span.get("input_tokens", 0),
+             output_tokens=0,
+             error_type=type(error).__name__,
+             error_message=str(error)[:512],
+             metadata={
+                 "tags": span.get("tags", []),
+                 "stack_trace": "".join(traceback.format_exception(
+                     type(error), error, error.__traceback__
+                 ))[:2000],
+             },
+         )
+
+         await self._send_event(event)
+
+     # =========================================================================
+     # Chat Model Callbacks
+     # =========================================================================
+
+     async def on_chat_model_start(
+         self,
+         serialized: Dict[str, Any],
+         messages: List[List[BaseMessage]],
+         *,
+         run_id: uuid.UUID,
+         parent_run_id: Optional[uuid.UUID] = None,
+         tags: Optional[List[str]] = None,
+         metadata: Optional[Dict[str, Any]] = None,
+         **kwargs: Any,
+     ) -> None:
+         """Called when chat model starts."""
+         run_id_str = str(run_id)
+         parent_id_str = str(parent_run_id) if parent_run_id else None
+         trace_id = await self._get_or_create_trace_id(parent_id_str)
+
+         model = kwargs.get("invocation_params", {}).get("model", "unknown")
+         if model == "unknown":
+             model = kwargs.get("invocation_params", {}).get("model_name", "unknown")
+         if model == "unknown":
+             model = serialized.get("kwargs", {}).get("model", "unknown")
+
+         provider = _get_provider_from_model(model)
+
+         message_text = ""
+         for msg_list in messages:
+             for msg in msg_list:
+                 if hasattr(msg, "content"):
+                     content = msg.content
+                     if isinstance(content, str):
+                         message_text += content + "\n"
+                     elif isinstance(content, list):
+                         for item in content:
+                             if isinstance(item, str):
+                                 message_text += item + "\n"
+                             elif isinstance(item, dict) and "text" in item:
+                                 message_text += item["text"] + "\n"
+
+         input_tokens = _count_tokens(message_text, model)
+
+         span_metadata = {
+             "tags": tags or [],
+             "model": model,
+             "provider": provider,
+             "input_tokens": input_tokens,
+             "message_count": sum(len(msg_list) for msg_list in messages),
+         }
+
+         if self.capture_input:
+             span_metadata["input"] = self._truncate(message_text)
+
+         self._span_tracker.start_span(
+             run_id=run_id_str,
+             trace_id=trace_id,
+             parent_run_id=parent_id_str,
+             operation="chat_completion",
+             span_type="chat",
+             model=model,
+             provider=provider,
+             endpoint=f"{provider}.chat.completions",
+             **span_metadata,
+         )
+
+     # =========================================================================
+     # Chain Callbacks
+     # =========================================================================
+
+     async def on_chain_start(
+         self,
+         serialized: Dict[str, Any],
+         inputs: Dict[str, Any],
+         *,
+         run_id: uuid.UUID,
+         parent_run_id: Optional[uuid.UUID] = None,
+         tags: Optional[List[str]] = None,
+         metadata: Optional[Dict[str, Any]] = None,
+         **kwargs: Any,
+     ) -> None:
+         """Called when chain starts."""
+         run_id_str = str(run_id)
+         parent_id_str = str(parent_run_id) if parent_run_id else None
+         trace_id = await self._get_or_create_trace_id(parent_id_str)
+
+         chain_name = serialized.get("name", serialized.get("id", ["unknown"])[-1])
+
+         span_metadata = {
+             "tags": tags or [],
+             "chain_name": chain_name,
+         }
+
+         if self.capture_input:
+             input_str = str(_serialize_for_metadata(inputs))
+             span_metadata["input"] = self._truncate(input_str)
+
+         self._span_tracker.start_span(
+             run_id=run_id_str,
+             trace_id=trace_id,
+             parent_run_id=parent_id_str,
+             operation=f"chain:{chain_name}",
+             span_type="chain",
+             model="chain",
+             provider="langchain",
+             endpoint=chain_name,
+             **span_metadata,
+         )
+
+     async def on_chain_end(
+         self,
+         outputs: Dict[str, Any],
+         *,
+         run_id: uuid.UUID,
+         parent_run_id: Optional[uuid.UUID] = None,
+         **kwargs: Any,
+     ) -> None:
+         """Called when chain ends."""
+         run_id_str = str(run_id)
+         span = self._span_tracker.end_span(run_id_str)
+
+         if not span:
+             return
+
+         event_metadata = {
+             "tags": span.get("tags", []),
+             "chain_name": span.get("chain_name", "unknown"),
+         }
+
+         if self.capture_output:
+             output_str = str(_serialize_for_metadata(outputs))
+             event_metadata["output"] = self._truncate(output_str)
+         if self.capture_input and "input" in span:
+             event_metadata["input"] = span["input"]
+
+         event = self._create_event(
+             span=span,
+             metadata=event_metadata,
+         )
+
+         await self._send_event(event)
+
+     async def on_chain_error(
+         self,
+         error: BaseException,
+         *,
+         run_id: uuid.UUID,
+         parent_run_id: Optional[uuid.UUID] = None,
+         **kwargs: Any,
+     ) -> None:
+         """Called when chain errors."""
+         run_id_str = str(run_id)
+         span = self._span_tracker.end_span(run_id_str)
+
+         if not span:
+             return
+
+         span["status"] = "error"
+
+         event = self._create_event(
+             span=span,
+             error_type=type(error).__name__,
+             error_message=str(error)[:512],
+             metadata={
+                 "tags": span.get("tags", []),
+                 "chain_name": span.get("chain_name", "unknown"),
+                 "stack_trace": "".join(traceback.format_exception(
+                     type(error), error, error.__traceback__
+                 ))[:2000],
+             },
+         )
+
+         await self._send_event(event)
+
+     # =========================================================================
+     # Tool Callbacks
+     # =========================================================================
+
+     async def on_tool_start(
+         self,
+         serialized: Dict[str, Any],
+         input_str: str,
+         *,
+         run_id: uuid.UUID,
+         parent_run_id: Optional[uuid.UUID] = None,
+         tags: Optional[List[str]] = None,
+         metadata: Optional[Dict[str, Any]] = None,
+         **kwargs: Any,
+     ) -> None:
+         """Called when tool starts."""
+         run_id_str = str(run_id)
+         parent_id_str = str(parent_run_id) if parent_run_id else None
+         trace_id = await self._get_or_create_trace_id(parent_id_str)
+
+         tool_name = serialized.get("name", "unknown_tool")
+
+         span_metadata = {
+             "tags": tags or [],
+             "tool_name": tool_name,
+         }
+
+         if self.capture_input:
+             span_metadata["input"] = self._truncate(input_str)
+
+         self._span_tracker.start_span(
+             run_id=run_id_str,
+             trace_id=trace_id,
+             parent_run_id=parent_id_str,
+             operation=f"tool:{tool_name}",
+             span_type="tool",
+             model="tool",
+             provider="langchain",
+             endpoint=tool_name,
+             **span_metadata,
+         )
+
+     async def on_tool_end(
+         self,
+         output: Any,
+         *,
+         run_id: uuid.UUID,
+         parent_run_id: Optional[uuid.UUID] = None,
+         **kwargs: Any,
+     ) -> None:
+         """Called when tool ends."""
+         run_id_str = str(run_id)
+         span = self._span_tracker.end_span(run_id_str)
+
+         if not span:
+             return
+
+         event_metadata = {
+             "tags": span.get("tags", []),
+             "tool_name": span.get("tool_name", "unknown"),
+         }
+
+         if self.capture_output:
+             output_str = str(_serialize_for_metadata(output))
+             event_metadata["output"] = self._truncate(output_str)
+         if self.capture_input and "input" in span:
+             event_metadata["input"] = span["input"]
+
+         event = self._create_event(
+             span=span,
+             metadata=event_metadata,
+         )
+
+         await self._send_event(event)
+
+     async def on_tool_error(
+         self,
+         error: BaseException,
+         *,
+         run_id: uuid.UUID,
+         parent_run_id: Optional[uuid.UUID] = None,
+         **kwargs: Any,
+     ) -> None:
+         """Called when tool errors."""
+         run_id_str = str(run_id)
+         span = self._span_tracker.end_span(run_id_str)
+
+         if not span:
+             return
+
+         span["status"] = "error"
+
+         event = self._create_event(
+             span=span,
+             error_type=type(error).__name__,
+             error_message=str(error)[:512],
+             metadata={
+                 "tags": span.get("tags", []),
+                 "tool_name": span.get("tool_name", "unknown"),
+                 "stack_trace": "".join(traceback.format_exception(
+                     type(error), error, error.__traceback__
+                 ))[:2000],
+             },
+         )
+
+         await self._send_event(event)
+
+     # =========================================================================
+     # Agent Callbacks
+     # =========================================================================
+
+     async def on_agent_action(
+         self,
+         action: AgentAction,
+         *,
+         run_id: uuid.UUID,
+         parent_run_id: Optional[uuid.UUID] = None,
+         **kwargs: Any,
+     ) -> None:
+         """Called when agent takes an action."""
+         parent_id_str = str(parent_run_id) if parent_run_id else None
+         if parent_id_str:
+             self._span_tracker.update_span(
+                 parent_id_str,
+                 last_action=action.tool,
+                 last_action_input=self._truncate(str(action.tool_input))
+                 if self.capture_input else None,
+             )
+
+     async def on_agent_finish(
+         self,
+         finish: AgentFinish,
+         *,
+         run_id: uuid.UUID,
+         parent_run_id: Optional[uuid.UUID] = None,
+         **kwargs: Any,
+     ) -> None:
+         """Called when agent finishes."""
+         parent_id_str = str(parent_run_id) if parent_run_id else None
+         if parent_id_str:
+             self._span_tracker.update_span(
+                 parent_id_str,
+                 agent_finish=True,
+                 return_values=self._truncate(str(finish.return_values))
+                 if self.capture_output else None,
+             )
+
+     # =========================================================================
+     # Retriever Callbacks
+     # =========================================================================
+
+     async def on_retriever_start(
+         self,
+         serialized: Dict[str, Any],
+         query: str,
+         *,
+         run_id: uuid.UUID,
+         parent_run_id: Optional[uuid.UUID] = None,
+         tags: Optional[List[str]] = None,
+         metadata: Optional[Dict[str, Any]] = None,
+         **kwargs: Any,
+     ) -> None:
+         """Called when retriever starts."""
+         run_id_str = str(run_id)
+         parent_id_str = str(parent_run_id) if parent_run_id else None
+         trace_id = await self._get_or_create_trace_id(parent_id_str)
+
+         retriever_name = serialized.get("name", serialized.get("id", ["unknown"])[-1])
+
+         span_metadata = {
+             "tags": tags or [],
+             "retriever_name": retriever_name,
+         }
+
+         if self.capture_input:
+             span_metadata["query"] = self._truncate(query)
+
+         self._span_tracker.start_span(
+             run_id=run_id_str,
+             trace_id=trace_id,
+             parent_run_id=parent_id_str,
+             operation=f"retriever:{retriever_name}",
+             span_type="retriever",
+             model="retriever",
+             provider="langchain",
+             endpoint=retriever_name,
+             **span_metadata,
+         )
+
+     async def on_retriever_end(
+         self,
+         documents: Sequence[Document],
+         *,
+         run_id: uuid.UUID,
+         parent_run_id: Optional[uuid.UUID] = None,
+         **kwargs: Any,
+     ) -> None:
+         """Called when retriever ends."""
+         run_id_str = str(run_id)
+         span = self._span_tracker.end_span(run_id_str)
+
+         if not span:
+             return
+
+         event_metadata = {
+             "tags": span.get("tags", []),
+             "retriever_name": span.get("retriever_name", "unknown"),
+             "document_count": len(documents),
+         }
+
+         if self.capture_input and "query" in span:
+             event_metadata["query"] = span["query"]
+
+         if self.capture_output and documents:
+             doc_summaries = []
+             for doc in documents[:5]:
+                 summary = {
+                     "content_preview": self._truncate(doc.page_content[:200]),
+                     "metadata": _serialize_for_metadata(doc.metadata),
+                 }
+                 doc_summaries.append(summary)
+             event_metadata["documents"] = doc_summaries
+
+         event = self._create_event(
+             span=span,
+             metadata=event_metadata,
+         )
+
+         await self._send_event(event)
+
+     async def on_retriever_error(
+         self,
+         error: BaseException,
+         *,
+         run_id: uuid.UUID,
+         parent_run_id: Optional[uuid.UUID] = None,
+         **kwargs: Any,
+     ) -> None:
+         """Called when retriever errors."""
+         run_id_str = str(run_id)
+         span = self._span_tracker.end_span(run_id_str)
+
+         if not span:
+             return
+
+         span["status"] = "error"
+
+         event = self._create_event(
+             span=span,
+             error_type=type(error).__name__,
+             error_message=str(error)[:512],
+             metadata={
+                 "tags": span.get("tags", []),
+                 "retriever_name": span.get("retriever_name", "unknown"),
+                 "stack_trace": "".join(traceback.format_exception(
+                     type(error), error, error.__traceback__
+                 ))[:2000],
+             },
+         )
+
+         await self._send_event(event)
+
+     # =========================================================================
+     # Streaming Callbacks
+     # =========================================================================
+
+     async def on_llm_new_token(
+         self,
+         token: str,
+         *,
+         run_id: uuid.UUID,
+         parent_run_id: Optional[uuid.UUID] = None,
+         **kwargs: Any,
+     ) -> None:
+         """Called on each new token during streaming."""
+         run_id_str = str(run_id)
+         span = self._span_tracker.get_span(run_id_str)
+
+         if span:
+             current_tokens = span.get("streaming_tokens", "")
+             self._span_tracker.update_span(
+                 run_id_str,
+                 streaming_tokens=current_tokens + token,
+             )
+
+     async def on_text(
+         self,
+         text: str,
+         *,
+         run_id: uuid.UUID,
+         parent_run_id: Optional[uuid.UUID] = None,
+         **kwargs: Any,
+     ) -> None:
+         """Called when text is generated."""
+         pass
+
+     # =========================================================================
+     # Utility Methods
+     # =========================================================================
+
+     async def flush(self):
+         """Force flush all pending events."""
+         async with self._buffer_lock:
+             await self._flush_buffer()
+
+     def get_trace_id(self) -> Optional[str]:
+         """Get current trace ID."""
+         return self._root_trace_id
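
For context, a minimal usage sketch of the handler defined above. AsyncKalibrCallbackHandler, flush(), and close() come from this diff; the ChatOpenAI model, its name, and the prompt are illustrative assumptions (langchain-openai is a separate dependency, not part of this package).

import asyncio

from langchain_openai import ChatOpenAI  # assumed external integration, for illustration only

from kalibr_langchain.async_callback import AsyncKalibrCallbackHandler


async def main() -> None:
    # Constructor arguments fall back to the KALIBR_* environment variables read in __init__.
    handler = AsyncKalibrCallbackHandler(service="demo-app", workflow_id="demo-workflow")

    llm = ChatOpenAI(model="gpt-4o-mini")  # illustrative model name
    # Callbacks passed via config apply to this call and its child runs,
    # so LLM, chain, tool, and retriever spans are all captured.
    result = await llm.ainvoke("Say hello", config={"callbacks": [handler]})
    print(result.content)

    # _send_event only posts once ten events are buffered, so short-lived
    # scripts should flush and close explicitly before exiting.
    await handler.flush()
    await handler.close()


asyncio.run(main())

Because events are batched in groups of ten and _flush_buffer swallows HTTP errors, delivery is best-effort; flush() and close() are the only explicit delivery controls the handler exposes.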