agentreplay-0.1.2-py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.
@@ -0,0 +1,385 @@
+ # Copyright 2025 Sushanth (https://github.com/sushanthpy)
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """
+ LangChain callback handler that emits OpenTelemetry spans with proper hierarchy.
+
+ This creates parent-child span relationships for:
+ - Chains → LLM calls
+ - Agents → Tool calls → LLM calls
+ - RAG pipelines → Retrieval → LLM synthesis
+
+ Integrates with Agentreplay's zero-code instrumentation.
+ """
+
+ from typing import Any, Dict, List, Optional, Union
+ from uuid import UUID
+ import time
+
+ from langchain_core.callbacks import BaseCallbackHandler
+ from langchain_core.outputs import LLMResult
+ from langchain_core.agents import AgentAction, AgentFinish
+ from langchain_core.documents import Document
+
+ try:
+     from opentelemetry import trace
+     from opentelemetry.trace import Status, StatusCode
+     OTEL_AVAILABLE = True
+ except ImportError:
+     OTEL_AVAILABLE = False
+
+
+ class AgentreplayCallbackHandler(BaseCallbackHandler):
+     """LangChain callback handler that creates hierarchical OTEL spans."""
+
+     def __init__(self):
+         super().__init__()
+         self.tracer = trace.get_tracer(__name__) if OTEL_AVAILABLE else None
+         self.spans: Dict[str, Any] = {}  # run_id -> span
+         self.parent_map: Dict[str, str] = {}  # run_id -> parent_run_id
+
+     def _get_parent_span(self, parent_run_id: Optional[UUID]) -> Optional[Any]:
+         """Get parent span from run_id."""
+         if not parent_run_id or not self.tracer:
+             return None
+         parent_id = str(parent_run_id)
+         return self.spans.get(parent_id)
+
+     def _start_span(self, name: str, run_id: UUID, parent_run_id: Optional[UUID] = None, **attributes) -> Any:
+         """Start a new OTEL span with optional parent."""
+         if not self.tracer:
+             return None
+
+         run_id_str = str(run_id)
+         parent_span = self._get_parent_span(parent_run_id)
+
+         # Create span with parent context
+         if parent_span:
+             ctx = trace.set_span_in_context(parent_span)
+             span = self.tracer.start_span(name, context=ctx)
+         else:
+             span = self.tracer.start_span(name)
+
+         # Set attributes
+         for key, value in attributes.items():
+             if value is not None:
+                 span.set_attribute(key, str(value))
+
+         self.spans[run_id_str] = span
+         if parent_run_id:
+             self.parent_map[run_id_str] = str(parent_run_id)
+
+         return span
+
+     def _end_span(self, run_id: UUID, status: Optional[StatusCode] = None, error: Optional[str] = None):
+         """End a span and clean up."""
+         if not self.tracer:
+             return
+
+         run_id_str = str(run_id)
+         span = self.spans.pop(run_id_str, None)
+
+         if span:
+             if error:
+                 span.set_status(Status(StatusCode.ERROR, error))
+                 span.record_exception(Exception(error))
+             elif status:
+                 span.set_status(Status(status))
+             else:
+                 span.set_status(Status(StatusCode.OK))
+             span.end()
+
+         self.parent_map.pop(run_id_str, None)
+
+     # ===== Chain Callbacks =====
+
+     def on_chain_start(
+         self,
+         serialized: Dict[str, Any],
+         inputs: Dict[str, Any],
+         *,
+         run_id: UUID,
+         parent_run_id: Optional[UUID] = None,
+         tags: Optional[List[str]] = None,
+         metadata: Optional[Dict[str, Any]] = None,
+         **kwargs: Any,
+     ) -> Any:
+         """Called when a chain starts running."""
+         chain_name = serialized.get("name", serialized.get("id", ["unknown"])[-1])
+
+         self._start_span(
+             f"chain.{chain_name}",
+             run_id,
+             parent_run_id,
+             **{
+                 "chain.name": chain_name,
+                 "chain.type": serialized.get("id", ["unknown"])[0],
+                 "chain.inputs": str(inputs)[:1000],  # Truncate
+                 "span.type": "chain",
+             }
+         )
+
+     def on_chain_end(
+         self,
+         outputs: Dict[str, Any],
+         *,
+         run_id: UUID,
+         parent_run_id: Optional[UUID] = None,
+         **kwargs: Any,
+     ) -> Any:
+         """Called when a chain finishes running."""
+         span = self.spans.get(str(run_id))
+         if span:
+             span.set_attribute("chain.outputs", str(outputs)[:1000])
+         self._end_span(run_id, StatusCode.OK)
+
+     def on_chain_error(
+         self,
+         error: BaseException,
+         *,
+         run_id: UUID,
+         parent_run_id: Optional[UUID] = None,
+         **kwargs: Any,
+     ) -> Any:
+         """Called when a chain errors."""
+         self._end_span(run_id, StatusCode.ERROR, str(error))
+
+     # ===== LLM Callbacks =====
+
+     def on_llm_start(
+         self,
+         serialized: Dict[str, Any],
+         prompts: List[str],
+         *,
+         run_id: UUID,
+         parent_run_id: Optional[UUID] = None,
+         tags: Optional[List[str]] = None,
+         metadata: Optional[Dict[str, Any]] = None,
+         **kwargs: Any,
+     ) -> Any:
+         """Called when LLM starts running."""
+         model_name = serialized.get("name", "unknown")
+
+         self._start_span(
+             f"llm.{model_name}",
+             run_id,
+             parent_run_id,
+             **{
+                 "llm.model": model_name,
+                 "llm.prompts": str(prompts)[:2000],
+                 "llm.prompt_count": len(prompts),
+                 "span.type": "llm",
+             }
+         )
+
+     def on_llm_end(
+         self,
+         response: LLMResult,
+         *,
+         run_id: UUID,
+         parent_run_id: Optional[UUID] = None,
+         **kwargs: Any,
+     ) -> Any:
+         """Called when LLM finishes running."""
+         span = self.spans.get(str(run_id))
+         if span:
+             # Extract token usage
+             if hasattr(response, "llm_output") and response.llm_output:
+                 token_usage = response.llm_output.get("token_usage", {})
+                 if token_usage:
+                     span.set_attribute("llm.tokens.prompt", token_usage.get("prompt_tokens", 0))
+                     span.set_attribute("llm.tokens.completion", token_usage.get("completion_tokens", 0))
+                     span.set_attribute("llm.tokens.total", token_usage.get("total_tokens", 0))
+
+             # Extract generations
+             if response.generations:
+                 first_gen = response.generations[0][0] if response.generations[0] else None
+                 if first_gen:
+                     span.set_attribute("llm.response", str(first_gen.text)[:2000])
+
+         self._end_span(run_id, StatusCode.OK)
+
+     def on_llm_error(
+         self,
+         error: BaseException,
+         *,
+         run_id: UUID,
+         parent_run_id: Optional[UUID] = None,
+         **kwargs: Any,
+     ) -> Any:
+         """Called when LLM errors."""
+         self._end_span(run_id, StatusCode.ERROR, str(error))
+
+     # ===== Tool Callbacks =====
+
+     def on_tool_start(
+         self,
+         serialized: Dict[str, Any],
+         input_str: str,
+         *,
+         run_id: UUID,
+         parent_run_id: Optional[UUID] = None,
+         tags: Optional[List[str]] = None,
+         metadata: Optional[Dict[str, Any]] = None,
+         inputs: Optional[Dict[str, Any]] = None,
+         **kwargs: Any,
+     ) -> Any:
+         """Called when a tool starts running."""
+         tool_name = serialized.get("name", "unknown_tool")
+
+         self._start_span(
+             f"tool.{tool_name}",
+             run_id,
+             parent_run_id,
+             **{
+                 "tool.name": tool_name,
+                 "tool.description": serialized.get("description", ""),
+                 "tool.input": input_str[:1000],
+                 "span.type": "tool",
+             }
+         )
+
+     def on_tool_end(
+         self,
+         output: str,
+         *,
+         run_id: UUID,
+         parent_run_id: Optional[UUID] = None,
+         **kwargs: Any,
+     ) -> Any:
+         """Called when a tool finishes running."""
+         span = self.spans.get(str(run_id))
+         if span:
+             span.set_attribute("tool.output", str(output)[:2000])
+         self._end_span(run_id, StatusCode.OK)
+
+     def on_tool_error(
+         self,
+         error: BaseException,
+         *,
+         run_id: UUID,
+         parent_run_id: Optional[UUID] = None,
+         **kwargs: Any,
+     ) -> Any:
+         """Called when a tool errors."""
+         self._end_span(run_id, StatusCode.ERROR, str(error))
+
+     # ===== Agent Callbacks =====
+
+     def on_agent_action(
+         self,
+         action: AgentAction,
+         *,
+         run_id: UUID,
+         parent_run_id: Optional[UUID] = None,
+         **kwargs: Any,
+     ) -> Any:
+         """Called when an agent takes an action."""
+         span = self.spans.get(str(parent_run_id) if parent_run_id else str(run_id))
+         if span:
+             span.add_event(
+                 "agent.action",
+                 attributes={
+                     "tool": action.tool,
+                     "tool_input": str(action.tool_input)[:1000],
+                     "log": action.log[:500] if action.log else "",
+                 }
+             )
+
+     def on_agent_finish(
+         self,
+         finish: AgentFinish,
+         *,
+         run_id: UUID,
+         parent_run_id: Optional[UUID] = None,
+         **kwargs: Any,
+     ) -> Any:
+         """Called when an agent finishes."""
+         span = self.spans.get(str(parent_run_id) if parent_run_id else str(run_id))
+         if span:
+             span.add_event(
+                 "agent.finish",
+                 attributes={
+                     "return_values": str(finish.return_values)[:1000],
+                     "log": finish.log[:500] if finish.log else "",
+                 }
+             )
+
+     # ===== Retriever Callbacks =====
+
+     def on_retriever_start(
+         self,
+         serialized: Dict[str, Any],
+         query: str,
+         *,
+         run_id: UUID,
+         parent_run_id: Optional[UUID] = None,
+         tags: Optional[List[str]] = None,
+         metadata: Optional[Dict[str, Any]] = None,
+         **kwargs: Any,
+     ) -> Any:
+         """Called when retriever starts."""
+         self._start_span(
+             "retriever.search",
+             run_id,
+             parent_run_id,
+             **{
+                 "retriever.query": query[:1000],
+                 "retriever.type": serialized.get("name", "unknown"),
+                 "span.type": "retrieval",
+             }
+         )
+
+     def on_retriever_end(
+         self,
+         documents: List[Document],
+         *,
+         run_id: UUID,
+         parent_run_id: Optional[UUID] = None,
+         **kwargs: Any,
+     ) -> Any:
+         """Called when retriever finishes."""
+         span = self.spans.get(str(run_id))
+         if span:
+             span.set_attribute("retriever.document_count", len(documents))
+             # Store top 3 document snippets
+             for i, doc in enumerate(documents[:3]):
+                 span.set_attribute(
+                     f"retriever.doc_{i+1}",
+                     doc.page_content[:500]
+                 )
+                 if doc.metadata:
+                     span.set_attribute(
+                         f"retriever.doc_{i+1}_metadata",
+                         str(doc.metadata)[:200]
+                     )
+         self._end_span(run_id, StatusCode.OK)
+
+     def on_retriever_error(
+         self,
+         error: BaseException,
+         *,
+         run_id: UUID,
+         parent_run_id: Optional[UUID] = None,
+         **kwargs: Any,
+     ) -> Any:
+         """Called when retriever errors."""
+         self._end_span(run_id, StatusCode.ERROR, str(error))
+
+
+ def get_agentreplay_callback() -> Optional[AgentreplayCallbackHandler]:
+     """Get Agentreplay callback handler if OTEL is available."""
+     if not OTEL_AVAILABLE:
+         return None
+     return AgentreplayCallbackHandler()
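
For orientation, the sketch below shows one way this handler could be wired into a LangChain run, with spans exported through the OpenTelemetry SDK. It is an illustrative sketch and not part of the package: the module path agentreplay.langchain_callback is an assumption (the diff does not show this file's name), and the ConsoleSpanExporter plus the toy RunnableLambda chain are stand-ins for a real exporter and a real chain.

# Illustrative usage sketch; see the assumptions noted above.
from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import ConsoleSpanExporter, SimpleSpanProcessor
from langchain_core.runnables import RunnableLambda

# Assumed module path; the wheel's actual file name is not shown in this diff.
from agentreplay.langchain_callback import get_agentreplay_callback

# Configure an OTel tracer provider so the handler's spans are exported somewhere.
provider = TracerProvider()
provider.add_span_processor(SimpleSpanProcessor(ConsoleSpanExporter()))
trace.set_tracer_provider(provider)

handler = get_agentreplay_callback()  # returns None if opentelemetry is not installed

# Any runnable accepts callbacks via its config; even a toy two-step chain
# exercises the chain callbacks above.
chain = RunnableLambda(lambda x: x.upper()) | RunnableLambda(lambda x: f"echo: {x}")
print(chain.invoke("hello", config={"callbacks": [handler]} if handler else None))

With a real agent or RAG pipeline the same handler produces the chain → tool → LLM hierarchy described in the module docstring above.
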
agentreplay/models.py ADDED
@@ -0,0 +1,120 @@
+ # Copyright 2025 Sushanth (https://github.com/sushanthpy)
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Data models for Agentreplay SDK."""
+
+ from enum import IntEnum
+ from typing import Optional
+ from pydantic import BaseModel, Field
+ import time
+
+
+ class SpanType(IntEnum):
+     """Agent execution span types."""
+
+     ROOT = 0
+     PLANNING = 1
+     REASONING = 2
+     TOOL_CALL = 3
+     TOOL_RESPONSE = 4
+     SYNTHESIS = 5
+     RESPONSE = 6
+     ERROR = 7
+     CUSTOM = 255
+
+     # Backward compatibility aliases
+     AGENT = 0  # Alias for ROOT
+     TOOL = 3  # Alias for TOOL_CALL
+
+
+ class SensitivityFlags(IntEnum):
+     """Sensitivity flags for PII and redaction control."""
+
+     NONE = 0
+     PII = 1 << 0  # Contains personally identifiable information
+     SECRET = 1 << 1  # Contains secrets/credentials
+     INTERNAL = 1 << 2  # Internal-only data
+     NO_EMBED = 1 << 3  # Never embed in vector index
+
+
+ class AgentFlowEdge(BaseModel):
+     """AgentFlow Edge - represents one step in agent execution.
+
+     This is the fundamental unit of data in Agentreplay.
+     Fixed 128-byte format when serialized.
+     """
+
+     # Identity & Causality
+     edge_id: int = Field(default=0, description="Unique edge identifier (u128)")
+     causal_parent: int = Field(default=0, description="Parent edge ID (0 for root)")
+
+     # Temporal
+     timestamp_us: int = Field(
+         default_factory=lambda: int(time.time() * 1_000_000),
+         description="Timestamp in microseconds since epoch",
+     )
+     logical_clock: int = Field(default=0, description="Lamport logical clock")
+
+     # Multi-tenancy
+     tenant_id: int = Field(description="Tenant identifier")
+     project_id: int = Field(default=0, description="Project identifier within tenant")
+     schema_version: int = Field(default=2, description="AFF schema version")
+     sensitivity_flags: int = Field(
+         default=SensitivityFlags.NONE, description="Sensitivity/privacy flags"
+     )
+
+     # Context
+     agent_id: int = Field(description="Agent identifier")
+     session_id: int = Field(description="Session/conversation identifier")
+     span_type: SpanType = Field(description="Type of agent execution span")
+     parent_count: int = Field(default=1, description="Number of parents (>1 for DAG fan-in)")
+
+     # Probabilistic / Cost
+     confidence: float = Field(default=1.0, ge=0.0, le=1.0, description="Confidence score")
+     token_count: int = Field(default=0, ge=0, description="Number of tokens used")
+     duration_us: int = Field(default=0, ge=0, description="Duration in microseconds")
+     sampling_rate: float = Field(default=1.0, ge=0.0, le=1.0, description="Sampling rate")
+
+     # Payload metadata
+     compression_type: int = Field(default=0, description="Compression type (0=None, 1=LZ4, 2=ZSTD)")
+     has_payload: bool = Field(default=False, description="Whether payload data exists")
+
+     # Metadata
+     flags: int = Field(default=0, description="General purpose flags")
+     checksum: int = Field(default=0, description="BLAKE3 checksum for integrity")
+
+     class Config:
+         """Pydantic config."""
+
+         use_enum_values = True
+
+
+ class QueryFilter(BaseModel):
+     """Filter for querying edges."""
+
+     tenant_id: Optional[int] = None
+     project_id: Optional[int] = None
+     agent_id: Optional[int] = None
+     session_id: Optional[int] = None
+     span_type: Optional[SpanType] = None
+     min_confidence: Optional[float] = None
+     exclude_pii: bool = False
+
+
+ class QueryResponse(BaseModel):
+     """Response from query operations."""
+
+     edges: list[AgentFlowEdge]
+     total_count: int
+     has_more: bool = False
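
As a quick illustration of how these models fit together, the sketch below builds a single AgentFlowEdge for a tool call and a QueryFilter that would match it. Only tenant_id, agent_id, session_id and span_type are required; every other field falls back to the defaults declared above. The numeric IDs are made-up example values.

# Illustrative sketch using the models defined in agentreplay/models.py above.
from agentreplay.models import AgentFlowEdge, QueryFilter, SensitivityFlags, SpanType

edge = AgentFlowEdge(
    tenant_id=1,                             # required
    agent_id=42,                             # required
    session_id=7,                            # required
    span_type=SpanType.TOOL_CALL,            # required; stored as its integer value
    token_count=128,
    duration_us=2_500,
    sensitivity_flags=SensitivityFlags.PII,  # flag the edge as carrying PII
)
print(edge.timestamp_us)  # defaults to the current time in microseconds
print(edge.span_type)     # 3, because Config.use_enum_values coerces enums to their values

query = QueryFilter(
    tenant_id=1,
    span_type=SpanType.TOOL_CALL,
    min_confidence=0.5,
    exclude_pii=True,  # a consumer honoring this filter would drop PII-flagged edges
)

A QueryResponse would then wrap the matching edges together with total_count and has_more for pagination.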