kalibr 1.0.28__py3-none-any.whl → 1.1.2a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kalibr/__init__.py +170 -3
- kalibr/__main__.py +3 -203
- kalibr/capsule_middleware.py +108 -0
- kalibr/cli/__init__.py +5 -0
- kalibr/cli/capsule_cmd.py +174 -0
- kalibr/cli/deploy_cmd.py +114 -0
- kalibr/cli/main.py +67 -0
- kalibr/cli/run.py +203 -0
- kalibr/cli/serve.py +59 -0
- kalibr/client.py +293 -0
- kalibr/collector.py +173 -0
- kalibr/context.py +132 -0
- kalibr/cost_adapter.py +222 -0
- kalibr/decorators.py +140 -0
- kalibr/instrumentation/__init__.py +13 -0
- kalibr/instrumentation/anthropic_instr.py +282 -0
- kalibr/instrumentation/base.py +108 -0
- kalibr/instrumentation/google_instr.py +281 -0
- kalibr/instrumentation/openai_instr.py +265 -0
- kalibr/instrumentation/registry.py +153 -0
- kalibr/kalibr.py +144 -230
- kalibr/kalibr_app.py +53 -314
- kalibr/middleware/__init__.py +5 -0
- kalibr/middleware/auto_tracer.py +356 -0
- kalibr/models.py +41 -0
- kalibr/redaction.py +44 -0
- kalibr/schemas.py +116 -0
- kalibr/simple_tracer.py +258 -0
- kalibr/tokens.py +52 -0
- kalibr/trace_capsule.py +296 -0
- kalibr/trace_models.py +201 -0
- kalibr/tracer.py +354 -0
- kalibr/types.py +25 -93
- kalibr/utils.py +198 -0
- kalibr-1.1.2a0.dist-info/METADATA +236 -0
- kalibr-1.1.2a0.dist-info/RECORD +48 -0
- kalibr-1.1.2a0.dist-info/entry_points.txt +2 -0
- kalibr-1.1.2a0.dist-info/licenses/LICENSE +21 -0
- kalibr-1.1.2a0.dist-info/top_level.txt +4 -0
- kalibr_crewai/__init__.py +65 -0
- kalibr_crewai/callbacks.py +539 -0
- kalibr_crewai/instrumentor.py +513 -0
- kalibr_langchain/__init__.py +47 -0
- kalibr_langchain/async_callback.py +850 -0
- kalibr_langchain/callback.py +1064 -0
- kalibr_openai_agents/__init__.py +43 -0
- kalibr_openai_agents/processor.py +554 -0
- kalibr/deployment.py +0 -41
- kalibr/packager.py +0 -43
- kalibr/runtime_router.py +0 -138
- kalibr/schema_generators.py +0 -159
- kalibr/validator.py +0 -70
- kalibr-1.0.28.data/data/examples/README.md +0 -173
- kalibr-1.0.28.data/data/examples/basic_kalibr_example.py +0 -66
- kalibr-1.0.28.data/data/examples/enhanced_kalibr_example.py +0 -347
- kalibr-1.0.28.dist-info/METADATA +0 -175
- kalibr-1.0.28.dist-info/RECORD +0 -19
- kalibr-1.0.28.dist-info/entry_points.txt +0 -2
- kalibr-1.0.28.dist-info/licenses/LICENSE +0 -11
- kalibr-1.0.28.dist-info/top_level.txt +0 -1
- {kalibr-1.0.28.dist-info → kalibr-1.1.2a0.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,850 @@
|
|
|
1
|
+
"""Async Kalibr Callback Handler for LangChain.
|
|
2
|
+
|
|
3
|
+
This module provides an async-compatible callback handler for use with
|
|
4
|
+
async LangChain operations.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import asyncio
|
|
8
|
+
import os
|
|
9
|
+
import traceback
|
|
10
|
+
import uuid
|
|
11
|
+
from datetime import datetime, timezone
|
|
12
|
+
from typing import Any, Dict, List, Optional, Sequence, Union
|
|
13
|
+
|
|
14
|
+
import httpx
|
|
15
|
+
from langchain_core.callbacks import AsyncCallbackHandler
|
|
16
|
+
from langchain_core.agents import AgentAction, AgentFinish
|
|
17
|
+
from langchain_core.documents import Document
|
|
18
|
+
from langchain_core.messages import BaseMessage
|
|
19
|
+
from langchain_core.outputs import LLMResult
|
|
20
|
+
|
|
21
|
+
from .callback import (
|
|
22
|
+
SpanTracker,
|
|
23
|
+
_count_tokens,
|
|
24
|
+
_get_provider_from_model,
|
|
25
|
+
_serialize_for_metadata,
|
|
26
|
+
CostAdapterFactory,
|
|
27
|
+
)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class AsyncKalibrCallbackHandler(AsyncCallbackHandler):
|
|
31
|
+
"""Async LangChain callback handler for Kalibr observability.
|
|
32
|
+
|
|
33
|
+
This handler is designed for async LangChain operations and uses
|
|
34
|
+
async HTTP calls to send telemetry.
|
|
35
|
+
|
|
36
|
+
Args:
|
|
37
|
+
api_key: Kalibr API key (or KALIBR_API_KEY env var)
|
|
38
|
+
endpoint: Backend endpoint URL (or KALIBR_ENDPOINT env var)
|
|
39
|
+
tenant_id: Tenant identifier (or KALIBR_TENANT_ID env var)
|
|
40
|
+
environment: Environment name (or KALIBR_ENVIRONMENT env var)
|
|
41
|
+
service: Service name (or KALIBR_SERVICE env var)
|
|
42
|
+
workflow_id: Workflow identifier for grouping traces
|
|
43
|
+
capture_input: Whether to capture input prompts (default: True)
|
|
44
|
+
capture_output: Whether to capture outputs (default: True)
|
|
45
|
+
max_content_length: Max length for captured content (default: 10000)
|
|
46
|
+
metadata: Additional metadata to include in all events
|
|
47
|
+
"""
|
|
48
|
+
|
|
49
|
+
def __init__(
    self,
    api_key: Optional[str] = None,
    endpoint: Optional[str] = None,
    tenant_id: Optional[str] = None,
    environment: Optional[str] = None,
    service: Optional[str] = None,
    workflow_id: Optional[str] = None,
    capture_input: bool = True,
    capture_output: bool = True,
    max_content_length: int = 10000,
    metadata: Optional[Dict[str, Any]] = None,
):
    """Set up configuration, span tracking and the event buffer.

    Every identity/connection argument falls back to an environment
    variable when it is falsy (``None`` or empty), and then to a
    hard-coded default.

    Args:
        api_key: Kalibr API key; falls back to ``KALIBR_API_KEY``.
        endpoint: Backend URL; falls back to ``KALIBR_ENDPOINT``, then
            ``KALIBR_API_ENDPOINT``, then a localhost default.
        tenant_id: Tenant identifier; falls back to ``KALIBR_TENANT_ID``.
        environment: Environment name; falls back to
            ``KALIBR_ENVIRONMENT``.
        service: Service name; falls back to ``KALIBR_SERVICE``.
        workflow_id: Workflow identifier; falls back to
            ``KALIBR_WORKFLOW_ID``.
        capture_input: Whether inputs are recorded on spans/events.
        capture_output: Whether outputs are recorded on events.
        max_content_length: Character limit applied by ``_truncate``.
        metadata: Extra metadata merged into every event.
    """
    super().__init__()

    read_env = os.getenv

    # Configuration: explicit argument first, environment second.
    self.api_key = api_key if api_key else read_env("KALIBR_API_KEY", "")
    if endpoint:
        self.endpoint = endpoint
    else:
        # KALIBR_ENDPOINT wins over KALIBR_API_ENDPOINT when both are set.
        secondary_endpoint = read_env(
            "KALIBR_API_ENDPOINT", "http://localhost:8001/api/v1/traces"
        )
        self.endpoint = read_env("KALIBR_ENDPOINT", secondary_endpoint)
    self.tenant_id = tenant_id if tenant_id else read_env("KALIBR_TENANT_ID", "default")
    self.environment = environment if environment else read_env("KALIBR_ENVIRONMENT", "prod")
    self.service = service if service else read_env("KALIBR_SERVICE", "langchain-app")
    self.workflow_id = (
        workflow_id if workflow_id else read_env("KALIBR_WORKFLOW_ID", "default-workflow")
    )

    # Content capture settings
    self.capture_input = capture_input
    self.capture_output = capture_output
    self.max_content_length = max_content_length
    self.default_metadata = metadata if metadata else {}

    # In-flight spans, keyed by run id inside the tracker.
    self._span_tracker = SpanTracker()

    # Trace id of the most recent root run, guarded by its own lock.
    self._root_trace_id: Optional[str] = None
    self._trace_lock = asyncio.Lock()

    # The HTTP client is created lazily by _get_client().
    self._client: Optional[httpx.AsyncClient] = None

    # Events wait here until a batch is flushed to the backend.
    self._event_buffer: List[Dict[str, Any]] = []
    self._buffer_lock = asyncio.Lock()
|
|
94
|
+
|
|
95
|
+
async def _get_client(self) -> httpx.AsyncClient:
    """Return the handler's async HTTP client, building it on first use.

    The client is created with a 10 second timeout and cached on
    ``self._client`` so later calls reuse it.
    """
    client = self._client
    if client is None:
        client = httpx.AsyncClient(timeout=10.0)
        self._client = client
    return client
|
|
100
|
+
|
|
101
|
+
async def _get_or_create_trace_id(self, parent_run_id: Optional[str]) -> str:
    """Return the trace id that a starting run should be recorded under.

    A run without a parent is a root run and always starts a new trace.
    A child run reuses the current root trace id. If a child arrives
    before any root id has been recorded, a new id is generated and
    stored, so later children share it instead of each receiving an
    unrelated random id.

    Args:
        parent_run_id: String id of the parent run, or ``None`` for a
            root run.

    Returns:
        The trace id as a UUID4 string.
    """
    async with self._trace_lock:
        # Only one root id is kept: each new root run replaces it.
        if parent_run_id is None or self._root_trace_id is None:
            self._root_trace_id = str(uuid.uuid4())
        return self._root_trace_id
|
|
107
|
+
|
|
108
|
+
def _truncate(self, text: str) -> str:
    """Clip ``text`` to ``self.max_content_length`` characters.

    Text within the limit is returned unchanged; longer text is cut at
    the limit and suffixed with a ``...[truncated]`` marker.
    """
    limit = self.max_content_length
    if len(text) <= limit:
        return text
    return text[:limit] + "...[truncated]"
|
|
113
|
+
|
|
114
|
+
def _compute_cost(
    self,
    provider: str,
    model: str,
    input_tokens: int,
    output_tokens: int
) -> float:
    """Price a call through ``CostAdapterFactory``.

    Returns ``0.0`` when ``CostAdapterFactory`` is ``None``; otherwise
    returns whatever ``CostAdapterFactory.compute_cost`` returns.
    """
    if CostAdapterFactory is None:
        return 0.0

    pricing_args = {
        "vendor": provider,
        "model_name": model,
        "tokens_in": input_tokens,
        "tokens_out": output_tokens,
    }
    return CostAdapterFactory.compute_cost(**pricing_args)
|
|
130
|
+
|
|
131
|
+
def _create_event(
    self,
    span: Dict[str, Any],
    input_tokens: int = 0,
    output_tokens: int = 0,
    error_type: Optional[str] = None,
    error_message: Optional[str] = None,
    metadata: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Build the trace-event dictionary for a finished span.

    Args:
        span: Span data; must contain ``trace_id``, ``span_id``,
            ``operation`` and ``ts_start`` (an object with
            ``isoformat()``). Other keys are optional.
        input_tokens: Prompt token count.
        output_tokens: Completion token count.
        error_type: Exception class name for failed spans.
        error_message: Exception text for failed spans.
        metadata: Per-event metadata; overrides handler defaults.

    Returns:
        The event dictionary that is later posted to the backend.

    Raises:
        KeyError: If a required span key is missing.
    """
    provider = span.get("provider", "custom")
    model = span.get("model", "unknown")
    cost_usd = self._compute_cost(provider, model, input_tokens, output_tokens)

    # Required span fields.
    trace_id = span["trace_id"]
    span_id = span["span_id"]
    operation = span["operation"]
    started_at = span["ts_start"].isoformat()
    # A span with no recorded end time is stamped with the current UTC time.
    ended_at = span.get("ts_end", datetime.now(timezone.utc)).isoformat()
    duration_ms = span.get("duration_ms", 0)

    # Later layers override earlier ones: defaults < fixed markers < caller.
    event_metadata = dict(self.default_metadata)
    event_metadata["span_type"] = span.get("span_type", "llm")
    event_metadata["langchain"] = True
    event_metadata["async"] = True
    if metadata:
        event_metadata.update(metadata)

    event: Dict[str, Any] = {}
    event["schema_version"] = "1.0"
    event["trace_id"] = trace_id
    event["span_id"] = span_id
    event["parent_span_id"] = span.get("parent_span_id")
    event["tenant_id"] = self.tenant_id
    event["workflow_id"] = self.workflow_id
    event["provider"] = provider
    event["model_id"] = model
    event["model_name"] = model
    event["operation"] = operation
    event["endpoint"] = span.get("endpoint", operation)
    # Duration, cost and start time are each emitted under two keys.
    event["duration_ms"] = duration_ms
    event["latency_ms"] = duration_ms
    event["input_tokens"] = input_tokens
    event["output_tokens"] = output_tokens
    event["total_tokens"] = input_tokens + output_tokens
    event["cost_usd"] = cost_usd
    event["total_cost_usd"] = cost_usd
    event["status"] = span.get("status", "success")
    event["error_type"] = error_type
    event["error_message"] = error_message
    event["timestamp"] = started_at
    event["ts_start"] = started_at
    event["ts_end"] = ended_at
    event["environment"] = self.environment
    event["service"] = self.service
    event["runtime_env"] = os.getenv("RUNTIME_ENV", "local")
    event["sandbox_id"] = os.getenv("SANDBOX_ID", "local")
    event["metadata"] = event_metadata
    return event
|
|
185
|
+
|
|
186
|
+
async def _send_event(self, event: Dict[str, Any]):
    """Queue one event and flush once ten events are buffered.

    The event is only appended to the buffer here; the HTTP post
    happens in ``_flush_buffer``, which runs while the buffer lock is
    still held.
    """
    async with self._buffer_lock:
        pending = self._event_buffer
        pending.append(event)

        # Below the batch size: leave the event queued for a later flush.
        if len(pending) < 10:
            return
        await self._flush_buffer()
|
|
194
|
+
|
|
195
|
+
async def _flush_buffer(self):
    """Post all buffered events to the backend in one request.

    The buffer is emptied before the request is made, so a failed send
    drops that batch instead of retrying it. Failures (connection
    errors and non-2xx responses) never propagate to the caller, but
    they are logged as warnings rather than silently ignored.

    Callers in this class invoke this while holding ``_buffer_lock``.
    """
    import logging

    if not self._event_buffer:
        return

    # Snapshot then clear in place, so the same list object keeps being used.
    events = list(self._event_buffer)
    self._event_buffer.clear()

    headers = {"X-API-Key": self.api_key} if self.api_key else {}

    try:
        client = await self._get_client()
        response = await client.post(
            self.endpoint,
            json={"events": events},
            headers=headers,
        )
        # Turn 4xx/5xx answers into exceptions so rejected batches are logged.
        response.raise_for_status()
    except Exception as exc:
        # Best effort: telemetry problems must not break the traced application.
        logging.getLogger(__name__).warning(
            "Kalibr: dropped %d trace event(s); sending to %s failed: %s",
            len(events),
            self.endpoint,
            exc,
        )
|
|
219
|
+
|
|
220
|
+
async def close(self):
    """Flush buffered events, then close and discard the HTTP client."""
    async with self._buffer_lock:
        await self._flush_buffer()

    client = self._client
    if not client:
        # No client was ever created, so there is nothing to close.
        return
    await client.aclose()
    self._client = None
|
|
228
|
+
|
|
229
|
+
# =========================================================================
|
|
230
|
+
# LLM Callbacks
|
|
231
|
+
# =========================================================================
|
|
232
|
+
|
|
233
|
+
async def on_llm_start(
    self,
    serialized: Dict[str, Any],
    prompts: List[str],
    *,
    run_id: uuid.UUID,
    parent_run_id: Optional[uuid.UUID] = None,
    tags: Optional[List[str]] = None,
    metadata: Optional[Dict[str, Any]] = None,
    **kwargs: Any,
) -> None:
    """Open an ``llm`` span when a completion-style LLM call starts.

    Args:
        serialized: Serialized LLM description; ``None`` is tolerated.
        prompts: Prompt strings sent to the model.
        run_id: Id of this run; the span is tracked under it.
        parent_run_id: Id of the parent run, if any.
        tags: Tags stored on the span.
        metadata: Accepted for interface compatibility; not used here.
        **kwargs: ``invocation_params`` is read for the model name.
    """
    run_id_str = str(run_id)
    parent_id_str = str(parent_run_id) if parent_run_id else None
    trace_id = await self._get_or_create_trace_id(parent_id_str)

    # Either mapping may be missing or None; treat both cases as empty.
    invocation_params = kwargs.get("invocation_params") or {}
    serialized_kwargs = (serialized or {}).get("kwargs") or {}
    model = (
        invocation_params.get("model_name")
        or invocation_params.get("model")
        or serialized_kwargs.get("model_name")
        or serialized_kwargs.get("model")
        or "unknown"
    )

    provider = _get_provider_from_model(model)

    prompt_text = "\n".join(prompts)
    input_tokens = _count_tokens(prompt_text, model)

    # `model` and `provider` are passed explicitly below and must NOT be
    # repeated here: the same keyword given both explicitly and through
    # ** makes the call raise TypeError, so no LLM span would be opened.
    span_extras = {
        "tags": tags or [],
        "input_tokens": input_tokens,
    }

    if self.capture_input:
        span_extras["input"] = self._truncate(prompt_text)

    # NOTE(review): assumes start_span keeps extra keywords on the span,
    # since on_llm_end reads "input_tokens"/"tags"/"input" back. Confirm.
    self._span_tracker.start_span(
        run_id=run_id_str,
        trace_id=trace_id,
        parent_run_id=parent_id_str,
        operation="llm_call",
        span_type="llm",
        model=model,
        provider=provider,
        endpoint=f"{provider}.{model}",
        **span_extras,
    )
|
|
279
|
+
|
|
280
|
+
async def on_llm_end(
    self,
    response: LLMResult,
    *,
    run_id: uuid.UUID,
    parent_run_id: Optional[uuid.UUID] = None,
    **kwargs: Any,
) -> None:
    """Close the span for a finished LLM call and queue its event.

    Token counts come from ``response.llm_output["token_usage"]`` when
    present. Otherwise the input count recorded at start is kept and
    the output count is estimated from the generated text. Nothing is
    emitted when no span is tracked for ``run_id``.
    """
    span = self._span_tracker.end_span(str(run_id))
    if not span:
        return

    input_tokens = span.get("input_tokens", 0)
    output_tokens = 0

    # Provider-reported usage, when the response carries any.
    usage = response.llm_output.get("token_usage", {}) if response.llm_output else None
    if usage:
        input_tokens = usage.get("prompt_tokens", input_tokens)
        output_tokens = usage.get("completion_tokens", 0)

    output_text = ""
    if response.generations:
        pieces = []
        for batch in response.generations:
            for generation in batch:
                if hasattr(generation, "text"):
                    pieces.append(generation.text)
                elif hasattr(generation, "message") and hasattr(generation.message, "content"):
                    pieces.append(generation.message.content)
        output_text = "\n".join(pieces)

    # No reported completion count: estimate it from the text instead.
    if output_tokens == 0:
        output_tokens = _count_tokens(output_text, span.get("model", "unknown"))

    event_metadata = {"tags": span.get("tags", [])}
    if self.capture_output and output_text:
        event_metadata["output"] = self._truncate(output_text)
    if self.capture_input and "input" in span:
        event_metadata["input"] = span["input"]

    await self._send_event(
        self._create_event(
            span=span,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            metadata=event_metadata,
        )
    )
|
|
334
|
+
|
|
335
|
+
async def on_llm_error(
    self,
    error: BaseException,
    *,
    run_id: uuid.UUID,
    parent_run_id: Optional[uuid.UUID] = None,
    **kwargs: Any,
) -> None:
    """Close the span of a failed LLM call and queue an error event.

    The event carries the exception class name, the first 512
    characters of its message and the first 2000 characters of its
    formatted traceback. Nothing is emitted for an untracked run.
    """
    span = self._span_tracker.end_span(str(run_id))
    if not span:
        return

    span["status"] = "error"

    formatted = traceback.format_exception(type(error), error, error.__traceback__)
    failure_metadata = {
        "tags": span.get("tags", []),
        "stack_trace": "".join(formatted)[:2000],
    }

    event = self._create_event(
        span=span,
        input_tokens=span.get("input_tokens", 0),
        output_tokens=0,
        error_type=type(error).__name__,
        error_message=str(error)[:512],
        metadata=failure_metadata,
    )
    await self._send_event(event)
|
|
367
|
+
|
|
368
|
+
# =========================================================================
|
|
369
|
+
# Chat Model Callbacks
|
|
370
|
+
# =========================================================================
|
|
371
|
+
|
|
372
|
+
async def on_chat_model_start(
    self,
    serialized: Dict[str, Any],
    messages: List[List[BaseMessage]],
    *,
    run_id: uuid.UUID,
    parent_run_id: Optional[uuid.UUID] = None,
    tags: Optional[List[str]] = None,
    metadata: Optional[Dict[str, Any]] = None,
    **kwargs: Any,
) -> None:
    """Open a ``chat`` span when a chat-model call starts.

    Args:
        serialized: Serialized model description; ``None`` is tolerated.
        messages: Batches of chat messages sent to the model.
        run_id: Id of this run; the span is tracked under it.
        parent_run_id: Id of the parent run, if any.
        tags: Tags stored on the span.
        metadata: Accepted for interface compatibility; not used here.
        **kwargs: ``invocation_params`` is read for the model name.
    """
    run_id_str = str(run_id)
    parent_id_str = str(parent_run_id) if parent_run_id else None
    trace_id = await self._get_or_create_trace_id(parent_id_str)

    # Either mapping may be missing or None; treat both cases as empty.
    invocation_params = kwargs.get("invocation_params") or {}
    serialized_kwargs = (serialized or {}).get("kwargs") or {}
    model = (
        invocation_params.get("model")
        or invocation_params.get("model_name")
        or serialized_kwargs.get("model")
        or serialized_kwargs.get("model_name")
        or "unknown"
    )

    provider = _get_provider_from_model(model)

    # Collect the textual parts of every message. Content is either a
    # plain string or a list mixing strings and {"text": ...} dicts;
    # anything else (e.g. image parts) is skipped.
    text_parts: List[str] = []
    for msg_list in messages:
        for msg in msg_list:
            content = getattr(msg, "content", None)
            if isinstance(content, str):
                text_parts.append(content)
            elif isinstance(content, list):
                for item in content:
                    if isinstance(item, str):
                        text_parts.append(item)
                    elif isinstance(item, dict) and isinstance(item.get("text"), str):
                        text_parts.append(item["text"])
    # Each part is followed by a newline.
    message_text = "".join(part + "\n" for part in text_parts)

    input_tokens = _count_tokens(message_text, model)

    # `model` and `provider` are passed explicitly below and must NOT be
    # repeated here: the same keyword given both explicitly and through
    # ** makes the call raise TypeError, so no chat span would be opened.
    span_extras = {
        "tags": tags or [],
        "input_tokens": input_tokens,
        "message_count": sum(len(msg_list) for msg_list in messages),
    }

    if self.capture_input:
        span_extras["input"] = self._truncate(message_text)

    # NOTE(review): assumes start_span keeps extra keywords on the span,
    # since on_llm_end reads "input_tokens"/"tags"/"input" back. Confirm.
    self._span_tracker.start_span(
        run_id=run_id_str,
        trace_id=trace_id,
        parent_run_id=parent_id_str,
        operation="chat_completion",
        span_type="chat",
        model=model,
        provider=provider,
        endpoint=f"{provider}.chat.completions",
        **span_extras,
    )
|
|
434
|
+
|
|
435
|
+
# =========================================================================
|
|
436
|
+
# Chain Callbacks
|
|
437
|
+
# =========================================================================
|
|
438
|
+
|
|
439
|
+
async def on_chain_start(
    self,
    serialized: Dict[str, Any],
    inputs: Dict[str, Any],
    *,
    run_id: uuid.UUID,
    parent_run_id: Optional[uuid.UUID] = None,
    tags: Optional[List[str]] = None,
    metadata: Optional[Dict[str, Any]] = None,
    **kwargs: Any,
) -> None:
    """Open a ``chain`` span when a chain starts.

    The chain name is taken from ``serialized["name"]``, then the last
    element of ``serialized["id"]``, then a ``name`` keyword argument,
    and finally ``"unknown"``. A ``None`` ``serialized`` or an empty
    ``id`` list is tolerated.

    Args:
        serialized: Serialized chain description, possibly ``None``.
        inputs: Chain inputs; recorded only when input capture is on.
        run_id: Id of this run; the span is tracked under it.
        parent_run_id: Id of the parent run, if any.
        tags: Tags stored on the span.
        metadata: Accepted for interface compatibility; not used here.
        **kwargs: ``name`` is used as a last-resort chain name.
    """
    run_id_str = str(run_id)
    parent_id_str = str(parent_run_id) if parent_run_id else None
    trace_id = await self._get_or_create_trace_id(parent_id_str)

    description = serialized or {}
    chain_name = description.get("name")
    if not chain_name:
        # Look at the id path only when needed; it may be absent or empty.
        id_path = description.get("id") or []
        chain_name = id_path[-1] if id_path else (kwargs.get("name") or "unknown")

    span_metadata = {
        "tags": tags or [],
        "chain_name": chain_name,
    }

    if self.capture_input:
        input_str = str(_serialize_for_metadata(inputs))
        span_metadata["input"] = self._truncate(input_str)

    self._span_tracker.start_span(
        run_id=run_id_str,
        trace_id=trace_id,
        parent_run_id=parent_id_str,
        operation=f"chain:{chain_name}",
        span_type="chain",
        model="chain",
        provider="langchain",
        endpoint=chain_name,
        **span_metadata,
    )
|
|
477
|
+
|
|
478
|
+
async def on_chain_end(
    self,
    outputs: Dict[str, Any],
    *,
    run_id: uuid.UUID,
    parent_run_id: Optional[uuid.UUID] = None,
    **kwargs: Any,
) -> None:
    """Close a chain's span and queue its event.

    The outputs are recorded only when output capture is on, and the
    input stored at start only when input capture is on. Nothing is
    emitted for an untracked run.
    """
    span = self._span_tracker.end_span(str(run_id))
    if not span:
        return

    event_metadata = {
        "tags": span.get("tags", []),
        "chain_name": span.get("chain_name", "unknown"),
    }

    if self.capture_output:
        event_metadata["output"] = self._truncate(str(_serialize_for_metadata(outputs)))
    if self.capture_input and "input" in span:
        event_metadata["input"] = span["input"]

    await self._send_event(self._create_event(span=span, metadata=event_metadata))
|
|
510
|
+
|
|
511
|
+
async def on_chain_error(
    self,
    error: BaseException,
    *,
    run_id: uuid.UUID,
    parent_run_id: Optional[uuid.UUID] = None,
    **kwargs: Any,
) -> None:
    """Close the span of a failed chain and queue an error event.

    The event carries the exception class name, the first 512
    characters of its message and the first 2000 characters of its
    formatted traceback. Nothing is emitted for an untracked run.
    """
    span = self._span_tracker.end_span(str(run_id))
    if not span:
        return

    span["status"] = "error"

    formatted = traceback.format_exception(type(error), error, error.__traceback__)
    failure_metadata = {
        "tags": span.get("tags", []),
        "chain_name": span.get("chain_name", "unknown"),
        "stack_trace": "".join(formatted)[:2000],
    }

    event = self._create_event(
        span=span,
        error_type=type(error).__name__,
        error_message=str(error)[:512],
        metadata=failure_metadata,
    )
    await self._send_event(event)
|
|
542
|
+
|
|
543
|
+
# =========================================================================
|
|
544
|
+
# Tool Callbacks
|
|
545
|
+
# =========================================================================
|
|
546
|
+
|
|
547
|
+
async def on_tool_start(
    self,
    serialized: Dict[str, Any],
    input_str: str,
    *,
    run_id: uuid.UUID,
    parent_run_id: Optional[uuid.UUID] = None,
    tags: Optional[List[str]] = None,
    metadata: Optional[Dict[str, Any]] = None,
    **kwargs: Any,
) -> None:
    """Open a ``tool`` span when a tool invocation starts.

    The tool name is taken from ``serialized["name"]``, then a ``name``
    keyword argument, and finally ``"unknown_tool"``. A ``None``
    ``serialized`` is tolerated.

    Args:
        serialized: Serialized tool description, possibly ``None``.
        input_str: Tool input; recorded only when input capture is on.
        run_id: Id of this run; the span is tracked under it.
        parent_run_id: Id of the parent run, if any.
        tags: Tags stored on the span.
        metadata: Accepted for interface compatibility; not used here.
        **kwargs: ``name`` is used as a fallback tool name.
    """
    run_id_str = str(run_id)
    parent_id_str = str(parent_run_id) if parent_run_id else None
    trace_id = await self._get_or_create_trace_id(parent_id_str)

    # Missing or None names fall through to the next candidate.
    tool_name = (serialized or {}).get("name") or kwargs.get("name") or "unknown_tool"

    span_metadata = {
        "tags": tags or [],
        "tool_name": tool_name,
    }

    if self.capture_input:
        span_metadata["input"] = self._truncate(input_str)

    self._span_tracker.start_span(
        run_id=run_id_str,
        trace_id=trace_id,
        parent_run_id=parent_id_str,
        operation=f"tool:{tool_name}",
        span_type="tool",
        model="tool",
        provider="langchain",
        endpoint=tool_name,
        **span_metadata,
    )
|
|
584
|
+
|
|
585
|
+
async def on_tool_end(
    self,
    output: Any,
    *,
    run_id: uuid.UUID,
    parent_run_id: Optional[uuid.UUID] = None,
    **kwargs: Any,
) -> None:
    """Close a tool's span and queue its event.

    The output is recorded only when output capture is on, and the
    input stored at start only when input capture is on. Nothing is
    emitted for an untracked run.
    """
    span = self._span_tracker.end_span(str(run_id))
    if not span:
        return

    event_metadata = {
        "tags": span.get("tags", []),
        "tool_name": span.get("tool_name", "unknown"),
    }

    if self.capture_output:
        event_metadata["output"] = self._truncate(str(_serialize_for_metadata(output)))
    if self.capture_input and "input" in span:
        event_metadata["input"] = span["input"]

    await self._send_event(self._create_event(span=span, metadata=event_metadata))
|
|
617
|
+
|
|
618
|
+
async def on_tool_error(
    self,
    error: BaseException,
    *,
    run_id: uuid.UUID,
    parent_run_id: Optional[uuid.UUID] = None,
    **kwargs: Any,
) -> None:
    """Close the span of a failed tool call and queue an error event.

    The event carries the exception class name, the first 512
    characters of its message and the first 2000 characters of its
    formatted traceback. Nothing is emitted for an untracked run.
    """
    span = self._span_tracker.end_span(str(run_id))
    if not span:
        return

    span["status"] = "error"

    formatted = traceback.format_exception(type(error), error, error.__traceback__)
    failure_metadata = {
        "tags": span.get("tags", []),
        "tool_name": span.get("tool_name", "unknown"),
        "stack_trace": "".join(formatted)[:2000],
    }

    event = self._create_event(
        span=span,
        error_type=type(error).__name__,
        error_message=str(error)[:512],
        metadata=failure_metadata,
    )
    await self._send_event(event)
|
|
649
|
+
|
|
650
|
+
# =========================================================================
|
|
651
|
+
# Agent Callbacks
|
|
652
|
+
# =========================================================================
|
|
653
|
+
|
|
654
|
+
async def on_agent_action(
    self,
    action: AgentAction,
    *,
    run_id: uuid.UUID,
    parent_run_id: Optional[uuid.UUID] = None,
    **kwargs: Any,
) -> None:
    """Record the agent's chosen tool on the parent run's span.

    No span is opened and no event is sent: the parent span is updated
    with the tool name and, when input capture is on, the truncated
    tool input (``None`` otherwise). Without a parent run this is a
    no-op.
    """
    parent_id_str = str(parent_run_id) if parent_run_id else None
    if not parent_id_str:
        return

    tool = action.tool
    captured_input = None
    if self.capture_input:
        captured_input = self._truncate(str(action.tool_input))

    self._span_tracker.update_span(
        parent_id_str,
        last_action=tool,
        last_action_input=captured_input,
    )
|
|
671
|
+
|
|
672
|
+
async def on_agent_finish(
    self,
    finish: AgentFinish,
    *,
    run_id: uuid.UUID,
    parent_run_id: Optional[uuid.UUID] = None,
    **kwargs: Any,
) -> None:
    """Mark the parent run's span as finished by the agent.

    No event is sent: the parent span is flagged with
    ``agent_finish=True`` and, when output capture is on, the truncated
    return values (``None`` otherwise). Without a parent run this is a
    no-op.
    """
    parent_id_str = str(parent_run_id) if parent_run_id else None
    if not parent_id_str:
        return

    captured_values = None
    if self.capture_output:
        captured_values = self._truncate(str(finish.return_values))

    self._span_tracker.update_span(
        parent_id_str,
        agent_finish=True,
        return_values=captured_values,
    )
|
|
689
|
+
|
|
690
|
+
# =========================================================================
|
|
691
|
+
# Retriever Callbacks
|
|
692
|
+
# =========================================================================
|
|
693
|
+
|
|
694
|
+
async def on_retriever_start(
    self,
    serialized: Dict[str, Any],
    query: str,
    *,
    run_id: uuid.UUID,
    parent_run_id: Optional[uuid.UUID] = None,
    tags: Optional[List[str]] = None,
    metadata: Optional[Dict[str, Any]] = None,
    **kwargs: Any,
) -> None:
    """Open a ``retriever`` span when a retrieval starts.

    The retriever name is taken from ``serialized["name"]``, then the
    last element of ``serialized["id"]``, then a ``name`` keyword
    argument, and finally ``"unknown"``. A ``None`` ``serialized`` or
    an empty ``id`` list is tolerated.

    Args:
        serialized: Serialized retriever description, possibly ``None``.
        query: Query text; recorded only when input capture is on.
        run_id: Id of this run; the span is tracked under it.
        parent_run_id: Id of the parent run, if any.
        tags: Tags stored on the span.
        metadata: Accepted for interface compatibility; not used here.
        **kwargs: ``name`` is used as a last-resort retriever name.
    """
    run_id_str = str(run_id)
    parent_id_str = str(parent_run_id) if parent_run_id else None
    trace_id = await self._get_or_create_trace_id(parent_id_str)

    description = serialized or {}
    retriever_name = description.get("name")
    if not retriever_name:
        # Look at the id path only when needed; it may be absent or empty.
        id_path = description.get("id") or []
        retriever_name = id_path[-1] if id_path else (kwargs.get("name") or "unknown")

    span_metadata = {
        "tags": tags or [],
        "retriever_name": retriever_name,
    }

    if self.capture_input:
        span_metadata["query"] = self._truncate(query)

    self._span_tracker.start_span(
        run_id=run_id_str,
        trace_id=trace_id,
        parent_run_id=parent_id_str,
        operation=f"retriever:{retriever_name}",
        span_type="retriever",
        model="retriever",
        provider="langchain",
        endpoint=retriever_name,
        **span_metadata,
    )
|
|
731
|
+
|
|
732
|
+
async def on_retriever_end(
    self,
    documents: Sequence[Document],
    *,
    run_id: uuid.UUID,
    parent_run_id: Optional[uuid.UUID] = None,
    **kwargs: Any,
) -> None:
    """Close a retriever's span and queue its event.

    The event always records the number of documents returned. With
    output capture on, it also carries a preview (first 200 characters
    of content plus metadata) of at most the first five documents.
    Nothing is emitted for an untracked run.
    """
    span = self._span_tracker.end_span(str(run_id))
    if not span:
        return

    event_metadata = {
        "tags": span.get("tags", []),
        "retriever_name": span.get("retriever_name", "unknown"),
        "document_count": len(documents),
    }

    if self.capture_input and "query" in span:
        event_metadata["query"] = span["query"]

    if self.capture_output and documents:
        # Only the leading documents are summarised, to keep events small.
        event_metadata["documents"] = [
            {
                "content_preview": self._truncate(doc.page_content[:200]),
                "metadata": _serialize_for_metadata(doc.metadata),
            }
            for doc in documents[:5]
        ]

    await self._send_event(self._create_event(span=span, metadata=event_metadata))
|
|
772
|
+
|
|
773
|
+
async def on_retriever_error(
    self,
    error: BaseException,
    *,
    run_id: uuid.UUID,
    parent_run_id: Optional[uuid.UUID] = None,
    **kwargs: Any,
) -> None:
    """Close the span of a failed retrieval and queue an error event.

    The event carries the exception class name, the first 512
    characters of its message and the first 2000 characters of its
    formatted traceback. Nothing is emitted for an untracked run.
    """
    span = self._span_tracker.end_span(str(run_id))
    if not span:
        return

    span["status"] = "error"

    formatted = traceback.format_exception(type(error), error, error.__traceback__)
    failure_metadata = {
        "tags": span.get("tags", []),
        "retriever_name": span.get("retriever_name", "unknown"),
        "stack_trace": "".join(formatted)[:2000],
    }

    event = self._create_event(
        span=span,
        error_type=type(error).__name__,
        error_message=str(error)[:512],
        metadata=failure_metadata,
    )
    await self._send_event(event)
|
|
804
|
+
|
|
805
|
+
# =========================================================================
|
|
806
|
+
# Streaming Callbacks
|
|
807
|
+
# =========================================================================
|
|
808
|
+
|
|
809
|
+
async def on_llm_new_token(
    self,
    token: str,
    *,
    run_id: uuid.UUID,
    parent_run_id: Optional[uuid.UUID] = None,
    **kwargs: Any,
) -> None:
    """Append a streamed token to the run's span.

    The accumulated text is kept under the span's ``streaming_tokens``
    key. Tokens for a run with no tracked span are ignored.
    """
    run_key = str(run_id)
    span = self._span_tracker.get_span(run_key)
    if not span:
        return

    streamed_so_far = span.get("streaming_tokens", "")
    self._span_tracker.update_span(run_key, streaming_tokens=streamed_so_far + token)
|
|
827
|
+
|
|
828
|
+
async def on_text(
    self,
    text: str,
    *,
    run_id: uuid.UUID,
    parent_run_id: Optional[uuid.UUID] = None,
    **kwargs: Any,
) -> None:
    """Accept arbitrary text callbacks without recording anything."""
    return None
|
|
838
|
+
|
|
839
|
+
# =========================================================================
|
|
840
|
+
# Utility Methods
|
|
841
|
+
# =========================================================================
|
|
842
|
+
|
|
843
|
+
async def flush(self):
    """Send every buffered event now, without waiting for a full batch."""
    buffer_lock = self._buffer_lock
    async with buffer_lock:
        await self._flush_buffer()
|
|
847
|
+
|
|
848
|
+
def get_trace_id(self) -> Optional[str]:
    """Return the stored root trace id, or ``None`` if none is set."""
    current_trace_id = self._root_trace_id
    return current_trace_id
|