agentreplay 0.1.2__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,545 @@
1
+ # Copyright 2025 Sushanth (https://github.com/sushanthpy)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Decorator-based tracing for Agentreplay.
17
+
18
+ Provides @traceable and @observe decorators for easy function instrumentation.
19
+
20
+ Example:
21
+ >>> from agentreplay import init, traceable
22
+ >>>
23
+ >>> init()
24
+ >>>
25
+ >>> @traceable
26
+ ... def my_function(query: str) -> str:
27
+ ... return f"Result for {query}"
28
+ >>>
29
+ >>> result = my_function("hello") # Automatically traced!
30
+ """
31
+
32
+ import functools
33
+ import inspect
34
+ import time
35
+ import logging
36
+ from typing import (
37
+ Optional, Callable, TypeVar, Any, Dict, Union,
38
+ overload, Awaitable
39
+ )
40
+ try:
41
+ from typing import ParamSpec
42
+ except ImportError:
43
+ from typing_extensions import ParamSpec
44
+ from contextvars import ContextVar
45
+
46
+ logger = logging.getLogger(__name__)
47
+
48
+ # Type variables for generic decorators
49
+ P = ParamSpec("P")
50
+ R = TypeVar("R")
51
+
52
+ # Context variable for current span
53
+ _current_span: ContextVar[Optional[Any]] = ContextVar("current_span", default=None)
54
+
55
+
56
+ # =============================================================================
57
+ # Span Kind
58
+ # =============================================================================
59
+
60
class SpanKind:
    """Namespace of string constants naming the category of a span.

    The values are plain ``str`` (not an Enum), so callers may pass either the
    constant or the literal string, e.g. ``kind="tool"``. ``CHAIN`` is the
    default kind used by ``traceable``, ``trace`` and ``start_span``.

    NOTE: when a span is sent, ``ActiveSpan._send`` maps only CHAIN, LLM,
    TOOL, RETRIEVER and EMBEDDING to a backend span type; every other kind
    falls back to the same backend type as CHAIN.
    """

    CHAIN: str = "chain"
    LLM: str = "llm"
    TOOL: str = "tool"
    RETRIEVER: str = "retriever"
    EMBEDDING: str = "embedding"
    GUARDRAIL: str = "guardrail"
    CACHE: str = "cache"
    HTTP: str = "http"
    DB: str = "db"
71
+
72
+
73
+ # =============================================================================
74
+ # Active Span
75
+ # =============================================================================
76
+
77
class ActiveSpan:
    """An in-flight span that collects data and is sent when it ends.

    Returned by ``trace()`` / ``start_span()`` and created for every call of a
    ``@traceable`` function. When used as a context manager the span is the
    current span (see ``get_current_span``) for the body of the ``with``
    block; on exit any exception is recorded and the span is ended and sent.
    All setters return ``self`` so calls can be chained.
    """

    def __init__(
        self,
        name: str,
        kind: str = SpanKind.CHAIN,
        span_id: Optional[str] = None,
        parent_id: Optional[str] = None,
        trace_id: Optional[str] = None,
    ):
        self.name = name
        self.kind = kind
        self.span_id = span_id or self._generate_id()
        self.parent_id = parent_id
        # Without an explicit trace_id a fresh one is generated (a new trace).
        self.trace_id = trace_id or self._generate_id()
        self.start_time = time.time()
        self.end_time: Optional[float] = None
        self.attributes: Dict[str, Any] = {}
        self.events: list = []
        self.input_data: Optional[Any] = None
        self.output_data: Optional[Any] = None
        self.error: Optional[Exception] = None
        self.token_usage: Dict[str, int] = {}
        self._ended = False
        # ContextVar reset token: set by __enter__, consumed by __exit__.
        self._token: Optional[Any] = None

    @staticmethod
    def _generate_id() -> str:
        """Generate a random 16-hex-character ID."""
        import uuid
        return uuid.uuid4().hex[:16]

    def set_input(self, data: Any) -> "ActiveSpan":
        """Set input data."""
        self.input_data = data
        return self

    def set_output(self, data: Any) -> "ActiveSpan":
        """Set output data."""
        self.output_data = data
        return self

    def set_attribute(self, key: str, value: Any) -> "ActiveSpan":
        """Set a single span attribute."""
        self.attributes[key] = value
        return self

    def set_attributes(self, attributes: Dict[str, Any]) -> "ActiveSpan":
        """Merge *attributes* into the span attributes."""
        self.attributes.update(attributes)
        return self

    def add_event(self, name: str, attributes: Optional[Dict[str, Any]] = None) -> "ActiveSpan":
        """Append a timestamped event to the span."""
        self.events.append({
            "name": name,
            "timestamp": time.time(),
            "attributes": attributes or {},
        })
        return self

    def set_error(self, error: Exception) -> "ActiveSpan":
        """Record *error* (type, message and formatted traceback) on the span.

        The traceback is formatted from ``error.__traceback__`` rather than
        the interpreter's "currently handled" exception, so it is correct even
        when this is called outside the ``except`` block that caught *error*.
        """
        import traceback

        self.error = error
        self.attributes["error.type"] = type(error).__name__
        self.attributes["error.message"] = str(error)
        self.attributes["error.stack"] = "".join(
            traceback.format_exception(type(error), error, error.__traceback__)
        )
        return self

    def set_token_usage(
        self,
        prompt_tokens: Optional[int] = None,
        completion_tokens: Optional[int] = None,
        total_tokens: Optional[int] = None,
    ) -> "ActiveSpan":
        """Record token usage for LLM calls (only the counts that are given)."""
        if prompt_tokens is not None:
            self.token_usage["prompt"] = prompt_tokens
            self.attributes["gen_ai.usage.prompt_tokens"] = prompt_tokens
        if completion_tokens is not None:
            self.token_usage["completion"] = completion_tokens
            self.attributes["gen_ai.usage.completion_tokens"] = completion_tokens
        if total_tokens is not None:
            self.token_usage["total"] = total_tokens
            self.attributes["gen_ai.usage.total_tokens"] = total_tokens
        return self

    def set_model(self, model: str, provider: Optional[str] = None) -> "ActiveSpan":
        """Record the model name and, if given, the provider."""
        self.attributes["gen_ai.request.model"] = model
        if provider:
            self.attributes["gen_ai.system"] = provider
        return self

    def end(self) -> None:
        """End the span and send it. Calling it again has no effect."""
        if self._ended:
            return

        self._ended = True
        self.end_time = time.time()
        self._send()

    def _send(self) -> None:
        """Send the span to the Agentreplay backend (best effort).

        Does nothing unless the SDK is initialized and enabled. Any failure,
        including the SDK modules not being importable, is logged at DEBUG
        level and swallowed so tracing never breaks the traced code.
        """
        try:
            from agentreplay.sdk import get_batching_client, is_initialized, get_config

            if not is_initialized():
                return

            config = get_config()
            if not config.enabled:
                return

            from agentreplay.models import AgentFlowEdge, SpanType

            # Kinds missing from this map (guardrail, cache, http, db, custom
            # strings) fall back to ROOT, the same as CHAIN.
            span_type_map = {
                SpanKind.CHAIN: SpanType.ROOT,
                SpanKind.LLM: SpanType.TOOL_CALL,
                SpanKind.TOOL: SpanType.TOOL_CALL,
                SpanKind.RETRIEVER: SpanType.TOOL_CALL,
                SpanKind.EMBEDDING: SpanType.TOOL_CALL,
            }
            span_type = span_type_map.get(self.kind, SpanType.ROOT)

            # Wall-clock times can go backwards (clock adjustments); never
            # report a negative duration.
            if self.end_time is not None:
                duration_us = max(0, int((self.end_time - self.start_time) * 1_000_000))
            else:
                duration_us = 0

            # NOTE(review): name, span_id and parent_id are not passed to the
            # edge or the payload, so the backend cannot see the span name or
            # rebuild nesting from this record — confirm this is intended.
            edge = AgentFlowEdge(
                tenant_id=config.tenant_id,
                project_id=config.project_id,
                agent_id=config.agent_id,
                session_id=int(self.trace_id[:8], 16) if self.trace_id else 0,
                span_type=span_type,
                timestamp_us=int(self.start_time * 1_000_000),
                duration_us=duration_us,
                token_count=self.token_usage.get("total", 0),
                payload=self._build_payload(config),
            )

            get_batching_client().insert(edge)

        except Exception as e:
            logger.debug("Failed to send span: %s", e)

    def _build_payload(self, config: Any) -> Dict[str, Any]:
        """Assemble the edge payload, honoring the config's capture flags."""
        payload: Dict[str, Any] = {}
        if self.input_data is not None and config.capture_input:
            payload["input"] = self._safe_serialize(self.input_data)
        if self.output_data is not None and config.capture_output:
            payload["output"] = self._safe_serialize(self.output_data)
        if self.attributes:
            payload["attributes"] = self.attributes
        if self.events:
            payload["events"] = self.events
        if self.error is not None:
            payload["error"] = {
                "type": type(self.error).__name__,
                "message": str(self.error),
            }
        return payload

    def _safe_serialize(
        self, data: Any, max_size: int = 10000, preview_size: int = 1000
    ) -> Any:
        """Return a JSON-compatible, size-limited version of *data*.

        Values ``json`` cannot encode natively are converted with ``str()``,
        and the decoded JSON (dicts, lists, strings, numbers, ...) is returned
        so the payload is always serializable.

        Args:
            data: Value to serialize.
            max_size: Maximum length of the JSON text; larger values are
                replaced by a ``{"__truncated": True, "__preview": ...}`` dict.
            preview_size: Number of JSON characters kept in ``__preview``.

        Returns:
            JSON-compatible data, the truncation marker dict, or — when JSON
            encoding fails entirely (e.g. circular references) — ``str(data)``
            cut to *max_size* characters.
        """
        import json

        try:
            serialized = json.dumps(data, default=str)
        except Exception:
            return str(data)[:max_size]
        if len(serialized) > max_size:
            return {"__truncated": True, "__preview": serialized[:preview_size]}
        # Round-trip so objects that were only encodable through default=str
        # (datetimes, custom classes, ...) are replaced by their string form
        # instead of being passed on unchanged.
        return json.loads(serialized)

    def __enter__(self) -> "ActiveSpan":
        """Make this span the current span for the ``with`` body."""
        self._token = _current_span.set(self)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Record any exception, end the span and restore the previous span.

        Returns ``None``, so exceptions are never suppressed.
        """
        try:
            # Skip re-recording an exception that was already set explicitly,
            # which would only re-format the same traceback.
            if exc_val is not None and self.error is not exc_val:
                self.set_error(exc_val)
            self.end()
        finally:
            # Always restore the previously current span, even if ending fails.
            if self._token is not None:
                _current_span.reset(self._token)
                self._token = None
279
+
280
+
281
+ # =============================================================================
282
+ # Get Current Span
283
+ # =============================================================================
284
+
285
def get_current_span() -> Optional[ActiveSpan]:
    """Return the span that is current in the running context, if any.

    A span is current while it is used as a context manager, which includes
    the body of every ``@traceable``-decorated call. Spans returned by
    ``start_span`` are not made current.

    Returns:
        The current ``ActiveSpan``, or ``None`` when no span is active.
    """
    span = _current_span.get(None)
    return span
292
+
293
+
294
+ # =============================================================================
295
+ # Traceable Decorator
296
+ # =============================================================================
297
+
298
@overload
def traceable(func: Callable[P, R]) -> Callable[P, R]: ...


@overload
def traceable(
    *,
    name: Optional[str] = None,
    kind: str = SpanKind.CHAIN,
    capture_input: bool = True,
    capture_output: bool = True,
    metadata: Optional[Dict[str, Any]] = None,
) -> Callable[[Callable[P, R]], Callable[P, R]]: ...


def traceable(
    func: Optional[Callable[P, R]] = None,
    *,
    name: Optional[str] = None,
    kind: str = SpanKind.CHAIN,
    capture_input: bool = True,
    capture_output: bool = True,
    metadata: Optional[Dict[str, Any]] = None,
) -> Union[Callable[P, R], Callable[[Callable[P, R]], Callable[P, R]]]:
    """Decorator that records every call of a function as a span.

    Works with plain functions and ``async def`` coroutine functions. Each
    call creates a span parented to the current span. It records the
    arguments and the return value (unless disabled) and any exception
    raised, and ends the span when the call finishes.

    Usable bare (``@traceable``) or with options (``@traceable(...)``).

    Args:
        func: The function to decorate (when used without parentheses).
        name: Span name (default: the function's ``__name__``).
        kind: Span kind (chain, llm, tool, retriever, etc.).
        capture_input: Whether to record the call arguments as span input.
        capture_output: Whether to record the return value as span output.
        metadata: Attributes attached to every span this decorator creates.

    Returns:
        The wrapped function, or a decorator when *func* is ``None``.

    Note:
        Generator and async-generator functions are wrapped like plain sync
        functions. The span ends as soon as the generator object is returned,
        so iteration time and yielded values are not captured.

    Example:
        >>> @traceable
        ... def simple_function():
        ...     return "hello"

        >>> @traceable(name="my_operation", kind="tool")
        ... def tool_function(query: str):
        ...     return search(query)

        >>> @traceable(capture_input=False)  # Don't capture sensitive inputs
        ... def sensitive_function(password: str):
        ...     return authenticate(password)
    """
    def decorator(fn: Callable[P, R]) -> Callable[P, R]:
        # Callables such as functools.partial have no __name__.
        span_name = name or getattr(fn, "__name__", type(fn).__name__)

        def _open_span(args: tuple, kwargs: dict) -> ActiveSpan:
            """Create the span for one call, parented to the current span."""
            parent = get_current_span()
            span = ActiveSpan(
                name=span_name,
                kind=kind,
                parent_id=parent.span_id if parent else None,
                trace_id=parent.trace_id if parent else None,
            )
            if metadata:
                span.set_attributes(metadata)
            if capture_input:
                try:
                    span.set_input(_capture_args(fn, args, kwargs))
                except Exception as exc:
                    # Input capture is best effort and must never fail the call.
                    logger.debug("Failed to capture input for %s: %s", span_name, exc)
            return span

        # The span's __exit__ records any exception (then re-raises it) and
        # ends the span, so the wrappers need no try/except of their own.
        if inspect.iscoroutinefunction(fn):
            @functools.wraps(fn)
            async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
                with _open_span(args, kwargs) as span:
                    result = await fn(*args, **kwargs)
                    if capture_output:
                        span.set_output(result)
                    return result

            return async_wrapper

        @functools.wraps(fn)
        def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            with _open_span(args, kwargs) as span:
                result = fn(*args, **kwargs)
                if capture_output:
                    span.set_output(result)
                return result

        return sync_wrapper

    # Bare @traceable passes the function directly; @traceable(...) does not.
    if func is not None:
        return decorator(func)
    return decorator


# Alias for Langfuse-style API
observe = traceable
445
+
446
+
447
def _capture_args(fn: Callable, args: tuple, kwargs: dict) -> Dict[str, Any]:
    """Map one call's arguments to a ``{parameter_name: value}`` dict.

    Positional arguments are matched, in order, only to *fn*'s positional
    parameters. Keyword-only parameters, which follow ``*args``, can never be
    filled positionally. Surplus positional arguments are stored as a tuple
    under the ``*args`` parameter's name when *fn* has one, otherwise as
    ``arg_<index>``. Keyword arguments are copied in as-is. If *fn* has no
    inspectable signature, every positional argument becomes ``arg_<index>``.

    Args:
        fn: The called function.
        args: Positional arguments of the call.
        kwargs: Keyword arguments of the call.

    Returns:
        A new dict describing the call's arguments.
    """
    positional_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    positional_names = []
    var_positional_name = None
    try:
        parameters = list(inspect.signature(fn).parameters.values())
    except (TypeError, ValueError):
        # Some builtins / C callables expose no signature.
        parameters = []
    for param in parameters:
        if param.kind in positional_kinds:
            positional_names.append(param.name)
        elif param.kind is inspect.Parameter.VAR_POSITIONAL:
            var_positional_name = param.name

    result: Dict[str, Any] = dict(zip(positional_names, args))

    extra = tuple(args[len(positional_names):])
    if extra:
        if var_positional_name is not None:
            result[var_positional_name] = extra
        else:
            for i, arg in enumerate(extra, start=len(positional_names)):
                result[f"arg_{i}"] = arg

    result.update(kwargs)
    return result
461
+
462
+
463
+ # =============================================================================
464
+ # Trace Context Manager
465
+ # =============================================================================
466
+
467
def trace(
    name: str,
    *,
    kind: str = SpanKind.CHAIN,
    input: Optional[Any] = None,
    metadata: Optional[Dict[str, Any]] = None,
) -> ActiveSpan:
    """Build a span intended for use in a ``with`` statement.

    The parent is whatever span is current when ``trace()`` is called, not
    when the ``with`` block is entered. The new span takes the parent's span
    id and trace id. With no current span it gets a fresh trace id. Entering
    the ``with`` block makes the span current. Leaving it records any
    exception and ends the span.

    Args:
        name: Span name.
        kind: Span kind (chain, llm, tool, retriever, etc.).
        input: Input data to record; ignored when ``None``.
        metadata: Attributes to attach to the span.

    Returns:
        A new, not-yet-entered ``ActiveSpan``.

    Example:
        >>> with trace("retrieve_documents", kind="retriever") as span:
        ...     docs = vector_db.search(query)
        ...     span.set_output({"count": len(docs)})
    """
    parent = get_current_span()
    if parent:
        parent_id, trace_id = parent.span_id, parent.trace_id
    else:
        parent_id, trace_id = None, None

    new_span = ActiveSpan(name=name, kind=kind, parent_id=parent_id, trace_id=trace_id)

    if input is not None:
        new_span.set_input(input)
    if metadata:
        new_span.set_attributes(metadata)

    return new_span
511
+
512
+
513
def start_span(
    name: str,
    *,
    kind: str = SpanKind.CHAIN,
    input: Optional[Any] = None,
    metadata: Optional[Dict[str, Any]] = None,
) -> ActiveSpan:
    """Create a span whose lifetime is managed by hand.

    Prefer ``trace()`` in a ``with`` block. Use this only when the span must
    outlive a single lexical scope. Call ``span.end()`` when done; repeated
    calls are harmless.

    Note:
        The returned span is not made the current span. Spans created while
        it is open are therefore not parented to it, unless it is also used
        as a context manager.

    Args:
        name: Span name.
        kind: Span kind.
        input: Input data; ignored when ``None``.
        metadata: Attributes to attach to the span.

    Returns:
        A new ``ActiveSpan``.

    Example:
        >>> span = start_span("long_operation", kind="tool")
        >>> try:
        ...     result = do_something()
        ...     span.set_output(result)
        ... except Exception as e:
        ...     span.set_error(e)
        ...     raise
        ... finally:
        ...     span.end()
    """
    options = {"kind": kind, "input": input, "metadata": metadata}
    manual_span = trace(name, **options)
    return manual_span
@@ -50,8 +50,12 @@ def get_site_packages():
50
50
  import agentreplay
51
51
  agentreplay_dir = os.path.dirname(agentreplay.__file__)
52
52
  site_packages = os.path.dirname(agentreplay_dir)
53
- if site_packages not in paths:
54
- paths.insert(0, site_packages)
53
+
54
+ # Prioritize the directory where agentreplay is actually installed
55
+ if site_packages in paths:
56
+ paths.remove(site_packages)
57
+ paths.insert(0, site_packages)
58
+
55
59
  except ImportError:
56
60
  pass
57
61