prela 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- prela/__init__.py +394 -0
- prela/_version.py +3 -0
- prela/contrib/CLI.md +431 -0
- prela/contrib/README.md +118 -0
- prela/contrib/__init__.py +5 -0
- prela/contrib/cli.py +1063 -0
- prela/contrib/explorer.py +571 -0
- prela/core/__init__.py +64 -0
- prela/core/clock.py +98 -0
- prela/core/context.py +228 -0
- prela/core/replay.py +403 -0
- prela/core/sampler.py +178 -0
- prela/core/span.py +295 -0
- prela/core/tracer.py +498 -0
- prela/evals/__init__.py +94 -0
- prela/evals/assertions/README.md +484 -0
- prela/evals/assertions/__init__.py +78 -0
- prela/evals/assertions/base.py +90 -0
- prela/evals/assertions/multi_agent.py +625 -0
- prela/evals/assertions/semantic.py +223 -0
- prela/evals/assertions/structural.py +443 -0
- prela/evals/assertions/tool.py +380 -0
- prela/evals/case.py +370 -0
- prela/evals/n8n/__init__.py +69 -0
- prela/evals/n8n/assertions.py +450 -0
- prela/evals/n8n/runner.py +497 -0
- prela/evals/reporters/README.md +184 -0
- prela/evals/reporters/__init__.py +32 -0
- prela/evals/reporters/console.py +251 -0
- prela/evals/reporters/json.py +176 -0
- prela/evals/reporters/junit.py +278 -0
- prela/evals/runner.py +525 -0
- prela/evals/suite.py +316 -0
- prela/exporters/__init__.py +27 -0
- prela/exporters/base.py +189 -0
- prela/exporters/console.py +443 -0
- prela/exporters/file.py +322 -0
- prela/exporters/http.py +394 -0
- prela/exporters/multi.py +154 -0
- prela/exporters/otlp.py +388 -0
- prela/instrumentation/ANTHROPIC.md +297 -0
- prela/instrumentation/LANGCHAIN.md +480 -0
- prela/instrumentation/OPENAI.md +59 -0
- prela/instrumentation/__init__.py +49 -0
- prela/instrumentation/anthropic.py +1436 -0
- prela/instrumentation/auto.py +129 -0
- prela/instrumentation/base.py +436 -0
- prela/instrumentation/langchain.py +959 -0
- prela/instrumentation/llamaindex.py +719 -0
- prela/instrumentation/multi_agent/__init__.py +48 -0
- prela/instrumentation/multi_agent/autogen.py +357 -0
- prela/instrumentation/multi_agent/crewai.py +404 -0
- prela/instrumentation/multi_agent/langgraph.py +299 -0
- prela/instrumentation/multi_agent/models.py +203 -0
- prela/instrumentation/multi_agent/swarm.py +231 -0
- prela/instrumentation/n8n/__init__.py +68 -0
- prela/instrumentation/n8n/code_node.py +534 -0
- prela/instrumentation/n8n/models.py +336 -0
- prela/instrumentation/n8n/webhook.py +489 -0
- prela/instrumentation/openai.py +1198 -0
- prela/license.py +245 -0
- prela/replay/__init__.py +31 -0
- prela/replay/comparison.py +390 -0
- prela/replay/engine.py +1227 -0
- prela/replay/loader.py +231 -0
- prela/replay/result.py +196 -0
- prela-0.1.0.dist-info/METADATA +399 -0
- prela-0.1.0.dist-info/RECORD +71 -0
- prela-0.1.0.dist-info/WHEEL +4 -0
- prela-0.1.0.dist-info/entry_points.txt +2 -0
- prela-0.1.0.dist-info/licenses/LICENSE +190 -0
prela/replay/engine.py
ADDED
|
@@ -0,0 +1,1227 @@
|
|
|
1
|
+
"""Replay execution engine for deterministic trace re-execution."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import logging
|
|
6
|
+
import threading
|
|
7
|
+
import time
|
|
8
|
+
from functools import wraps
|
|
9
|
+
from typing import Any, Callable, TypeVar
|
|
10
|
+
|
|
11
|
+
from prela.core.span import Span, SpanType
|
|
12
|
+
from prela.replay.loader import Trace
|
|
13
|
+
from prela.replay.result import ReplayResult, ReplayedSpan
|
|
14
|
+
|
|
15
|
+
logger = logging.getLogger(__name__)
|
|
16
|
+
|
|
17
|
+
T = TypeVar("T")
|
|
18
|
+
|
|
19
|
+
# Thread-local storage for retry counts
|
|
20
|
+
_retry_context = threading.local()
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def _is_retryable_error(error: Exception) -> bool:
|
|
24
|
+
"""Check if API error is retryable.
|
|
25
|
+
|
|
26
|
+
Args:
|
|
27
|
+
error: Exception from API call
|
|
28
|
+
|
|
29
|
+
Returns:
|
|
30
|
+
True if error is likely transient and should be retried
|
|
31
|
+
"""
|
|
32
|
+
error_str = str(error).lower()
|
|
33
|
+
error_type = type(error).__name__.lower()
|
|
34
|
+
|
|
35
|
+
# Retryable error patterns
|
|
36
|
+
retryable_patterns = [
|
|
37
|
+
"rate limit",
|
|
38
|
+
"429", # HTTP 429 Too Many Requests
|
|
39
|
+
"503", # HTTP 503 Service Unavailable
|
|
40
|
+
"502", # HTTP 502 Bad Gateway
|
|
41
|
+
"timeout",
|
|
42
|
+
"timed out",
|
|
43
|
+
"connection",
|
|
44
|
+
"temporarily unavailable",
|
|
45
|
+
"service unavailable",
|
|
46
|
+
"try again",
|
|
47
|
+
"overloaded",
|
|
48
|
+
]
|
|
49
|
+
|
|
50
|
+
# Check error message
|
|
51
|
+
for pattern in retryable_patterns:
|
|
52
|
+
if pattern in error_str:
|
|
53
|
+
return True
|
|
54
|
+
|
|
55
|
+
# Check error type
|
|
56
|
+
retryable_types = [
|
|
57
|
+
"timeout",
|
|
58
|
+
"connectionerror",
|
|
59
|
+
"httpstatuserror", # httpx
|
|
60
|
+
]
|
|
61
|
+
|
|
62
|
+
for error_type_pattern in retryable_types:
|
|
63
|
+
if error_type_pattern in error_type:
|
|
64
|
+
return True
|
|
65
|
+
|
|
66
|
+
return False
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def with_retry(
    max_retries: int = 3,
    initial_delay: float = 1.0,
    max_delay: float = 60.0,
    exponential_base: float = 2.0,
) -> Callable[[Callable[..., T]], Callable[..., T]]:
    """Decorator for API calls with exponential backoff retry logic.

    Stores retry count in thread-local storage accessible via _get_last_retry_count().
    Only errors accepted by _is_retryable_error() are retried; any other
    exception is re-raised immediately with the stored retry count reset to 0.

    Args:
        max_retries: Maximum number of retry attempts (0 = no retries).
            Negative values are treated as 0 so the wrapped function is
            always called at least once.
        initial_delay: Initial delay in seconds before first retry
        max_delay: Maximum delay between retries (cap for exponential backoff)
        exponential_base: Base for exponential backoff (2.0 = double each retry)

    Returns:
        Decorated function with retry logic
    """
    # A negative value would otherwise yield an empty attempt range, meaning
    # the wrapped function would never run at all.
    retries_allowed = max(max_retries, 0)

    def decorator(func: Callable[..., T]) -> Callable[..., T]:
        @wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> T:
            retry_count = 0

            # Reset retry count in thread-local storage
            _retry_context.retry_count = 0

            for attempt in range(retries_allowed + 1):
                try:
                    result = func(*args, **kwargs)
                except Exception as e:
                    # Check if retryable
                    if not _is_retryable_error(e):
                        logger.debug(f"Non-retryable error, not retrying: {e}")
                        _retry_context.retry_count = 0
                        raise

                    # Last attempt, don't retry
                    if attempt == retries_allowed:
                        logger.error(
                            f"API call failed after {retries_allowed + 1} attempts: {e}"
                        )
                        _retry_context.retry_count = retry_count
                        raise

                    # Increment retry count
                    retry_count += 1

                    # Exponential backoff capped at max_delay; floored at 0
                    # because time.sleep() rejects negative durations.
                    delay = max(
                        0.0,
                        min(initial_delay * (exponential_base**attempt), max_delay),
                    )

                    logger.warning(
                        f"API call failed (attempt {attempt + 1}/{retries_allowed + 1}), "
                        f"retrying in {delay:.1f}s: {e}"
                    )
                    time.sleep(delay)
                else:
                    # Store final retry count
                    _retry_context.retry_count = retry_count
                    return result

            # Unreachable: the final loop iteration always returns or raises.
            raise RuntimeError("Retry logic failed unexpectedly")

        return wrapper

    return decorator
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def _get_last_retry_count() -> int:
    """Return the retry count recorded for the current thread.

    Returns:
        The value of the thread-local ``retry_count`` attribute, or 0 when
        the attribute has not been set in this thread.
    """
    try:
        return _retry_context.retry_count
    except AttributeError:
        # Nothing has been recorded in this thread yet.
        return 0
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
class ReplayEngine:
|
|
157
|
+
"""Engine for replaying traces with exact or modified parameters.
|
|
158
|
+
|
|
159
|
+
Supports:
|
|
160
|
+
- Exact replay: Use captured data, no API calls
|
|
161
|
+
- Modified replay: Change parameters, make real API calls for modified spans
|
|
162
|
+
"""
|
|
163
|
+
|
|
164
|
+
def __init__(
    self,
    trace: Trace,
    max_retries: int = 3,
    retry_initial_delay: float = 1.0,
    retry_max_delay: float = 60.0,
    retry_exponential_base: float = 2.0,
) -> None:
    """Set up a replay engine for the given trace.

    The retry settings are stored on the instance as given. The trace must
    report that it has replay data; if it does but is incomplete, a warning
    naming at most the first five missing entries is logged.

    Args:
        trace: Trace to replay
        max_retries: Maximum retry attempts for API calls (default: 3)
        retry_initial_delay: Initial delay before first retry in seconds (default: 1.0)
        retry_max_delay: Maximum delay between retries in seconds (default: 60.0)
        retry_exponential_base: Base for exponential backoff (default: 2.0)

    Raises:
        ValueError: If trace lacks replay data
    """
    self.trace = trace
    self.max_retries = max_retries
    self.retry_initial_delay = retry_initial_delay
    self.retry_max_delay = retry_max_delay
    self.retry_exponential_base = retry_exponential_base

    # Refuse traces that carry no replay data at all.
    if not trace.has_replay_data():
        raise ValueError(
            "Trace does not contain replay data. "
            "Enable capture_for_replay=True when creating traces."
        )

    # Partially captured traces are accepted, but the gap is reported.
    complete, missing_names = trace.validate_replay_completeness()
    if complete:
        return
    logger.warning(
        f"Trace has incomplete replay data. Missing snapshots for: {', '.join(missing_names[:5])}"
    )
|
|
202
|
+
|
|
203
|
+
def replay_exact(self) -> ReplayResult:
    """Replay every span from its captured data and aggregate the metrics.

    Spans are visited in the order produced by ``trace.walk_depth_first()``
    and each one is replayed with ``_replay_span_exact``.

    Returns:
        ReplayResult holding the replayed spans, summed duration, tokens and
        cost, any per-span errors, and the output extracted from the last
        root span (when the trace has root spans).
    """
    outcome = ReplayResult(trace_id=self.trace.trace_id)

    for span in self.trace.walk_depth_first():
        replayed = self._replay_span_exact(span)
        outcome.spans.append(replayed)

        # Running totals across all spans
        outcome.total_duration_ms += replayed.duration_ms
        outcome.total_tokens += replayed.tokens_used
        outcome.total_cost_usd += replayed.cost_usd

        if replayed.error:
            outcome.errors.append(f"{span.name}: {replayed.error}")

    # The final output is taken from the last root span
    roots = self.trace.root_spans
    if roots:
        outcome.final_output = self._extract_output(roots[-1])

    return outcome
|
|
230
|
+
|
|
231
|
+
def replay_with_modifications(
    self,
    model: str | None = None,
    temperature: float | None = None,
    system_prompt: str | None = None,
    max_tokens: int | None = None,
    mock_tool_responses: dict[str, Any] | None = None,
    mock_retrieval_results: list[dict[str, Any]] | None = None,
    enable_tool_execution: bool = False,
    tool_execution_allowlist: list[str] | None = None,
    tool_execution_blocklist: list[str] | None = None,
    tool_registry: dict[str, Any] | None = None,
    enable_retrieval_execution: bool = False,
    retrieval_client: Any | None = None,
    retrieval_query_override: str | None = None,
    stream: bool = False,
    stream_callback: Any = None,
) -> ReplayResult:
    """Replay the trace, applying the requested overrides span by span.

    Each span is passed to ``_span_needs_modification``; spans that need a
    change go through ``_replay_span_modified``, all others through
    ``_replay_span_exact``.

    Args:
        model: Override model for LLM spans
        temperature: Override temperature for LLM spans
        system_prompt: Override system prompt for LLM spans
        max_tokens: Override max_tokens for LLM spans
        mock_tool_responses: Override tool outputs by tool name
        mock_retrieval_results: Override retrieval results
        enable_tool_execution: If True, re-execute tools instead of using cached data
        tool_execution_allowlist: Only execute tools in this list (if provided)
        tool_execution_blocklist: Never execute tools in this list
        tool_registry: Dictionary mapping tool names to callable functions
        enable_retrieval_execution: If True, re-query vector database
        retrieval_client: Vector database client (ChromaDB, Pinecone, etc.)
        retrieval_query_override: Override query for retrieval spans
        stream: If True, use streaming API for modified LLM spans
        stream_callback: Optional callback for streaming chunks (chunk_text: str) -> None

    Returns:
        ReplayResult with modifications applied
    """
    # Only the mocked tool names and the number of mocked documents are
    # recorded in the summary, not their contents.
    mocked_tool_names = list(mock_tool_responses.keys()) if mock_tool_responses else []
    mocked_document_count = len(mock_retrieval_results) if mock_retrieval_results else 0

    outcome = ReplayResult(
        trace_id=self.trace.trace_id,
        modifications_applied={
            "model": model,
            "temperature": temperature,
            "system_prompt": system_prompt,
            "max_tokens": max_tokens,
            "mock_tool_responses": mocked_tool_names,
            "mock_retrieval_results": mocked_document_count,
            "enable_tool_execution": enable_tool_execution,
            "tool_execution_allowlist": tool_execution_allowlist,
            "tool_execution_blocklist": tool_execution_blocklist,
            "enable_retrieval_execution": enable_retrieval_execution,
            "retrieval_query_override": retrieval_query_override,
        },
    )

    # Keyword arguments forwarded unchanged to every modified span
    span_overrides = {
        "model": model,
        "temperature": temperature,
        "system_prompt": system_prompt,
        "max_tokens": max_tokens,
        "mock_tool_responses": mock_tool_responses,
        "mock_retrieval_results": mock_retrieval_results,
        "enable_tool_execution": enable_tool_execution,
        "tool_execution_allowlist": tool_execution_allowlist,
        "tool_execution_blocklist": tool_execution_blocklist,
        "tool_registry": tool_registry,
        "enable_retrieval_execution": enable_retrieval_execution,
        "retrieval_client": retrieval_client,
        "retrieval_query_override": retrieval_query_override,
        "stream": stream,
        "stream_callback": stream_callback,
    }

    for span in self.trace.walk_depth_first():
        if self._span_needs_modification(
            span,
            model,
            temperature,
            system_prompt,
            max_tokens,
            mock_tool_responses,
            mock_retrieval_results,
            enable_tool_execution,
            enable_retrieval_execution,
            retrieval_query_override,
        ):
            replayed = self._replay_span_modified(span, **span_overrides)
        else:
            replayed = self._replay_span_exact(span)

        outcome.spans.append(replayed)

        # Running totals across all spans
        outcome.total_duration_ms += replayed.duration_ms
        outcome.total_tokens += replayed.tokens_used
        outcome.total_cost_usd += replayed.cost_usd

        if replayed.error:
            outcome.errors.append(f"{span.name}: {replayed.error}")

    # The final output is taken from the last root span of the trace
    roots = self.trace.root_spans
    if roots:
        outcome.final_output = self._extract_output(roots[-1])

    return outcome
|
|
349
|
+
|
|
350
|
+
def _replay_span_exact(self, span: Span) -> ReplayedSpan:
    """Build a ReplayedSpan from the data captured on the span.

    Args:
        span: Span to replay

    Returns:
        ReplayedSpan carrying the captured input/output for LLM, tool,
        retrieval and agent spans (None for other span types), or a
        ReplayedSpan with an error message when the span has no snapshot.
    """
    snapshot = span.replay_snapshot
    if snapshot is None:
        return ReplayedSpan(
            original_span_id=span.span_id,
            span_type=span.span_type.value,
            name=span.name,
            input=None,
            output=None,
            error="No replay data available",
        )

    # Defaults cover span types with no dedicated handling below
    input_data = None
    output_data = None
    tokens = 0
    cost = 0.0

    kind = span.span_type
    if kind == SpanType.LLM:
        request = snapshot.llm_request
        response = snapshot.llm_response
        input_data = request
        if response:
            output_data = response.get("text")
            tokens = response.get("prompt_tokens", 0) + response.get("completion_tokens", 0)
        # Cost is estimated from the captured model name and total tokens
        cost = self._estimate_cost(request.get("model") if request else None, tokens)
    elif kind == SpanType.TOOL:
        input_data = snapshot.tool_input
        output_data = snapshot.tool_output
    elif kind == SpanType.RETRIEVAL:
        input_data = snapshot.retrieval_query
        output_data = snapshot.retrieved_documents
    elif kind == SpanType.AGENT:
        input_data = snapshot.agent_memory
        output_data = snapshot.agent_config

    # Spans that never ended report a zero duration
    elapsed_ms = span.duration_ms if span.ended_at else 0.0

    return ReplayedSpan(
        original_span_id=span.span_id,
        span_type=span.span_type.value,
        name=span.name,
        input=input_data,
        output=output_data,
        was_modified=False,
        duration_ms=elapsed_ms,
        tokens_used=tokens,
        cost_usd=cost,
    )
|
|
418
|
+
|
|
419
|
+
def _replay_span_modified(
    self,
    span: Span,
    model: str | None = None,
    temperature: float | None = None,
    system_prompt: str | None = None,
    max_tokens: int | None = None,
    mock_tool_responses: dict[str, Any] | None = None,
    mock_retrieval_results: list[dict[str, Any]] | None = None,
    enable_tool_execution: bool = False,
    tool_execution_allowlist: list[str] | None = None,
    tool_execution_blocklist: list[str] | None = None,
    tool_registry: dict[str, Any] | None = None,
    enable_retrieval_execution: bool = False,
    retrieval_client: Any | None = None,
    retrieval_query_override: str | None = None,
    stream: bool = False,
    stream_callback: Any = None,
) -> ReplayedSpan:
    """Replay span with modifications.

    For LLM spans with param changes, makes real API call.
    For tool/retrieval with mocks, uses mock data.
    For tool/retrieval with execution enabled, re-executes.
    Any other span type falls back to exact replay.

    The captured snapshot is never mutated: the LLM request and its message
    list are copied before overrides are applied, so the same trace can be
    replayed repeatedly with different settings.

    Args:
        span: Span to replay
        model: Override model
        temperature: Override temperature
        system_prompt: Override system prompt
        max_tokens: Override max_tokens
        mock_tool_responses: Mock tool responses by name
        mock_retrieval_results: Mock retrieval documents
        enable_tool_execution: If True, re-execute tools
        tool_execution_allowlist: Only execute tools in this list
        tool_execution_blocklist: Never execute tools in this list
        tool_registry: Dictionary mapping tool names to callables
        enable_retrieval_execution: If True, re-query vector database
        retrieval_client: Vector database client
        retrieval_query_override: Override query for retrieval
        stream: Forwarded to the LLM API call for modified LLM spans
        stream_callback: Forwarded to the LLM API call for modified LLM spans

    Returns:
        ReplayedSpan with modified output
    """
    if span.replay_snapshot is None:
        return ReplayedSpan(
            original_span_id=span.span_id,
            span_type=span.span_type.value,
            name=span.name,
            input=None,
            output=None,
            error="No replay data available",
        )

    snapshot = span.replay_snapshot
    modifications: list[str] = []

    # Handle LLM spans
    if span.span_type == SpanType.LLM:
        # Build modified request (shallow copy of the captured request)
        modified_request = dict(snapshot.llm_request) if snapshot.llm_request else {}

        if model is not None:
            modified_request["model"] = model
            modifications.append(f"model={model}")
        if temperature is not None:
            modified_request["temperature"] = temperature
            modifications.append(f"temperature={temperature}")
        if max_tokens is not None:
            modified_request["max_tokens"] = max_tokens
            modifications.append(f"max_tokens={max_tokens}")
        if system_prompt is not None:
            # Update system prompt in messages
            if "messages" in modified_request:
                # Copy the list (and the replaced message) so the snapshot's
                # own messages are left untouched by this override.
                messages = list(modified_request["messages"] or [])
                if messages and messages[0].get("role") == "system":
                    messages[0] = {**messages[0], "content": system_prompt}
                else:
                    messages.insert(0, {"role": "system", "content": system_prompt})
                modified_request["messages"] = messages
            modifications.append("system_prompt=<modified>")

        # Make real API call, timing it so the span reports actual latency
        started = time.perf_counter()
        try:
            output_data, tokens, cost = self._call_llm_api(
                modified_request, stream=stream, stream_callback=stream_callback
            )
            duration_ms = (time.perf_counter() - started) * 1000.0
            error = None
            retry_count = _get_last_retry_count()
        except Exception as e:
            logger.error(f"API call failed for {span.name}: {e}")
            output_data = None
            tokens = 0
            cost = 0.0
            duration_ms = 0.0  # failed calls keep reporting zero duration
            error = str(e)
            retry_count = _get_last_retry_count()

        return ReplayedSpan(
            original_span_id=span.span_id,
            span_type=span.span_type.value,
            name=span.name,
            input=modified_request,
            output=output_data,
            was_modified=True,
            modification_details=", ".join(modifications),
            duration_ms=duration_ms,
            tokens_used=tokens,
            cost_usd=cost,
            error=error,
            retry_count=retry_count,
        )

    # Handle tool spans with mocks or execution
    elif span.span_type == SpanType.TOOL:
        # Priority 1: Mock responses
        if mock_tool_responses and snapshot.tool_name in mock_tool_responses:
            output_data = mock_tool_responses[snapshot.tool_name]
            modifications.append("mocked_output")
            error = None
        # Priority 2: Tool execution
        elif enable_tool_execution:
            try:
                output_data = self._execute_tool_safely(
                    snapshot.tool_name,
                    snapshot.tool_input,
                    tool_execution_allowlist,
                    tool_execution_blocklist,
                    tool_registry,
                )
                modifications.append("tool_executed")
                error = None
            except Exception as e:
                logger.error(f"Tool execution failed for {snapshot.tool_name}: {e}")
                output_data = None
                error = str(e)
        # Priority 3: Cached data
        else:
            output_data = snapshot.tool_output
            error = None

        return ReplayedSpan(
            original_span_id=span.span_id,
            span_type=span.span_type.value,
            name=span.name,
            input=snapshot.tool_input,
            output=output_data,
            was_modified=len(modifications) > 0,
            modification_details=", ".join(modifications) if modifications else None,
            duration_ms=span.duration_ms if span.ended_at else 0.0,
            error=error,
        )

    # Handle retrieval spans with mocks or execution
    elif span.span_type == SpanType.RETRIEVAL:
        # Priority 1: Mock results
        if mock_retrieval_results is not None:
            output_data = mock_retrieval_results
            modifications.append("mocked_documents")
            error = None
        # Priority 2: Retrieval execution
        elif enable_retrieval_execution and retrieval_client is not None:
            try:
                query = retrieval_query_override if retrieval_query_override else snapshot.retrieval_query
                output_data = self._execute_retrieval(
                    query,
                    retrieval_client,
                    snapshot,
                )
                modifications.append("retrieval_executed")
                if retrieval_query_override:
                    modifications.append("query_overridden")
                error = None
            except Exception as e:
                logger.error(f"Retrieval execution failed: {e}")
                output_data = None
                error = str(e)
        # Priority 3: Cached data
        else:
            output_data = snapshot.retrieved_documents
            error = None

        return ReplayedSpan(
            original_span_id=span.span_id,
            span_type=span.span_type.value,
            name=span.name,
            input=retrieval_query_override if retrieval_query_override else snapshot.retrieval_query,
            output=output_data,
            was_modified=len(modifications) > 0,
            modification_details=", ".join(modifications) if modifications else None,
            duration_ms=span.duration_ms if span.ended_at else 0.0,
            error=error,
        )

    # Default: use exact replay
    return self._replay_span_exact(span)
|
|
616
|
+
|
|
617
|
+
def _span_needs_modification(
    self,
    span: Span,
    model: str | None,
    temperature: float | None,
    system_prompt: str | None,
    max_tokens: int | None,
    mock_tool_responses: dict[str, Any] | None,
    mock_retrieval_results: list[dict[str, Any]] | None,
    enable_tool_execution: bool,
    enable_retrieval_execution: bool,
    retrieval_query_override: str | None,
) -> bool:
    """Check if span needs modification.

    LLM overrides are compared against None rather than tested for
    truthiness, so falsy-but-valid values such as ``temperature=0.0`` count
    as modifications. A tool span needs modification when its tool is mocked
    or, failing that, when tool execution is enabled, so mocks for some
    tools can be combined with re-execution of the others.

    Args:
        span: Span to check
        model: Model override
        temperature: Temperature override
        system_prompt: System prompt override
        max_tokens: Max tokens override
        mock_tool_responses: Tool response mocks
        mock_retrieval_results: Retrieval result mocks
        enable_tool_execution: If True, tools should be re-executed
        enable_retrieval_execution: If True, retrieval should be re-executed
        retrieval_query_override: If set, retrieval query should be overridden

    Returns:
        True if span should be modified
    """
    if span.span_type == SpanType.LLM:
        llm_overrides = (model, temperature, system_prompt, max_tokens)
        return any(value is not None for value in llm_overrides)

    if span.span_type == SpanType.TOOL:
        snapshot = span.replay_snapshot
        if mock_tool_responses and snapshot and snapshot.tool_name in mock_tool_responses:
            return True
        # Not mocked: modification depends solely on execution being enabled
        return enable_tool_execution

    if span.span_type == SpanType.RETRIEVAL:
        return any(
            [
                mock_retrieval_results is not None,
                enable_retrieval_execution,
                retrieval_query_override,
            ]
        )

    return False
|
|
657
|
+
|
|
658
|
+
def _execute_tool_safely(
|
|
659
|
+
self,
|
|
660
|
+
tool_name: str,
|
|
661
|
+
tool_input: Any,
|
|
662
|
+
allowlist: list[str] | None,
|
|
663
|
+
blocklist: list[str] | None,
|
|
664
|
+
tool_registry: dict[str, Any] | None,
|
|
665
|
+
) -> Any:
|
|
666
|
+
"""Execute tool with safety checks.
|
|
667
|
+
|
|
668
|
+
Args:
|
|
669
|
+
tool_name: Name of tool to execute
|
|
670
|
+
tool_input: Input data for tool
|
|
671
|
+
allowlist: Only execute tools in this list (if provided)
|
|
672
|
+
blocklist: Never execute tools in this list
|
|
673
|
+
tool_registry: Dictionary mapping tool names to callable functions
|
|
674
|
+
|
|
675
|
+
Returns:
|
|
676
|
+
Tool output
|
|
677
|
+
|
|
678
|
+
Raises:
|
|
679
|
+
ValueError: If tool is blocked, not in allowlist, or not found in registry
|
|
680
|
+
Exception: If tool execution fails
|
|
681
|
+
"""
|
|
682
|
+
# Check blocklist
|
|
683
|
+
if blocklist and tool_name in blocklist:
|
|
684
|
+
raise ValueError(f"Tool '{tool_name}' is blocked from execution")
|
|
685
|
+
|
|
686
|
+
# Check allowlist
|
|
687
|
+
if allowlist and tool_name not in allowlist:
|
|
688
|
+
raise ValueError(f"Tool '{tool_name}' not in allowlist")
|
|
689
|
+
|
|
690
|
+
# Find tool function from registry
|
|
691
|
+
if not tool_registry:
|
|
692
|
+
raise ValueError("tool_registry is required for tool execution")
|
|
693
|
+
|
|
694
|
+
tool_func = tool_registry.get(tool_name)
|
|
695
|
+
if not tool_func:
|
|
696
|
+
raise ValueError(f"Tool '{tool_name}' not found in registry")
|
|
697
|
+
|
|
698
|
+
# Execute tool
|
|
699
|
+
logger.debug(f"Executing tool '{tool_name}' with input: {tool_input}")
|
|
700
|
+
return tool_func(tool_input)
|
|
701
|
+
|
|
702
|
+
def _execute_retrieval(
    self,
    query: str,
    client: Any,
    snapshot: Any,
) -> list[dict[str, Any]]:
    """Execute retrieval query against vector database.

    The client type is determined with ``_detect_retrieval_client`` and the
    query is dispatched to the matching ``_query_*`` helper.

    Args:
        query: Search query
        client: Vector database client (ChromaDB, Pinecone, Qdrant, Weaviate)
        snapshot: Original span snapshot; its ``similarity_top_k`` attribute,
            when present and not None, sets the number of results requested
            (otherwise 5 is used)

    Returns:
        List of retrieved documents with text and scores

    Raises:
        ValueError: If client type is not supported
        Exception: If retrieval fails
    """
    # Detect client type
    client_type = self._detect_retrieval_client(client)

    # Extract metadata from snapshot. The attribute may exist but be unset
    # (None); fall back to the default in that case as well.
    top_k = getattr(snapshot, "similarity_top_k", None)
    if top_k is None:
        top_k = 5

    logger.debug(f"Executing retrieval with query: {query}, client: {client_type}, top_k: {top_k}")

    if client_type == "chromadb":
        return self._query_chromadb(client, query, top_k)
    if client_type == "pinecone":
        return self._query_pinecone(client, query, top_k)
    if client_type == "qdrant":
        return self._query_qdrant(client, query, top_k)
    if client_type == "weaviate":
        return self._query_weaviate(client, query, top_k)

    raise ValueError(
        f"Unsupported vector DB client type: {client_type}. "
        "Supported types: chromadb, pinecone, qdrant, weaviate"
    )
|
|
743
|
+
|
|
744
|
+
def _detect_retrieval_client(self, client: Any) -> str:
|
|
745
|
+
"""Detect vector database client type.
|
|
746
|
+
|
|
747
|
+
Args:
|
|
748
|
+
client: Vector database client object
|
|
749
|
+
|
|
750
|
+
Returns:
|
|
751
|
+
Client type string ("chromadb", "pinecone", "qdrant", "weaviate")
|
|
752
|
+
"""
|
|
753
|
+
client_class = client.__class__.__name__
|
|
754
|
+
|
|
755
|
+
if "chroma" in client_class.lower():
|
|
756
|
+
return "chromadb"
|
|
757
|
+
elif "pinecone" in client_class.lower():
|
|
758
|
+
return "pinecone"
|
|
759
|
+
elif "qdrant" in client_class.lower():
|
|
760
|
+
return "qdrant"
|
|
761
|
+
elif "weaviate" in client_class.lower():
|
|
762
|
+
return "weaviate"
|
|
763
|
+
else:
|
|
764
|
+
# Try module name as fallback
|
|
765
|
+
module_name = client.__class__.__module__
|
|
766
|
+
if "chroma" in module_name.lower():
|
|
767
|
+
return "chromadb"
|
|
768
|
+
elif "pinecone" in module_name.lower():
|
|
769
|
+
return "pinecone"
|
|
770
|
+
elif "qdrant" in module_name.lower():
|
|
771
|
+
return "qdrant"
|
|
772
|
+
elif "weaviate" in module_name.lower():
|
|
773
|
+
return "weaviate"
|
|
774
|
+
|
|
775
|
+
return "unknown"
|
|
776
|
+
|
|
777
|
+
def _query_chromadb(
|
|
778
|
+
self,
|
|
779
|
+
client: Any,
|
|
780
|
+
query: str,
|
|
781
|
+
top_k: int,
|
|
782
|
+
) -> list[dict[str, Any]]:
|
|
783
|
+
"""Query ChromaDB client.
|
|
784
|
+
|
|
785
|
+
Args:
|
|
786
|
+
client: ChromaDB client or collection
|
|
787
|
+
query: Search query
|
|
788
|
+
top_k: Number of results to return
|
|
789
|
+
|
|
790
|
+
Returns:
|
|
791
|
+
List of documents with text and scores
|
|
792
|
+
"""
|
|
793
|
+
# ChromaDB query API
|
|
794
|
+
results = client.query(query_texts=[query], n_results=top_k)
|
|
795
|
+
|
|
796
|
+
documents = []
|
|
797
|
+
if results and "documents" in results and results["documents"]:
|
|
798
|
+
docs = results["documents"][0] # First query
|
|
799
|
+
distances = results.get("distances", [[]])[0]
|
|
800
|
+
|
|
801
|
+
for i, doc in enumerate(docs):
|
|
802
|
+
documents.append({
|
|
803
|
+
"text": doc,
|
|
804
|
+
"score": 1.0 - distances[i] if i < len(distances) else 0.0,
|
|
805
|
+
})
|
|
806
|
+
|
|
807
|
+
return documents
|
|
808
|
+
|
|
809
|
+
def _query_pinecone(
|
|
810
|
+
self,
|
|
811
|
+
client: Any,
|
|
812
|
+
query: str,
|
|
813
|
+
top_k: int,
|
|
814
|
+
) -> list[dict[str, Any]]:
|
|
815
|
+
"""Query Pinecone index.
|
|
816
|
+
|
|
817
|
+
Args:
|
|
818
|
+
client: Pinecone index
|
|
819
|
+
query: Search query
|
|
820
|
+
top_k: Number of results to return
|
|
821
|
+
|
|
822
|
+
Returns:
|
|
823
|
+
List of documents with text and scores
|
|
824
|
+
"""
|
|
825
|
+
# Pinecone requires embedding the query first
|
|
826
|
+
# For now, return empty list (user should provide embedding model)
|
|
827
|
+
logger.warning("Pinecone retrieval requires embedding model - returning empty results")
|
|
828
|
+
return []
|
|
829
|
+
|
|
830
|
+
def _query_qdrant(
|
|
831
|
+
self,
|
|
832
|
+
client: Any,
|
|
833
|
+
query: str,
|
|
834
|
+
top_k: int,
|
|
835
|
+
) -> list[dict[str, Any]]:
|
|
836
|
+
"""Query Qdrant client.
|
|
837
|
+
|
|
838
|
+
Args:
|
|
839
|
+
client: Qdrant client
|
|
840
|
+
query: Search query
|
|
841
|
+
top_k: Number of results to return
|
|
842
|
+
|
|
843
|
+
Returns:
|
|
844
|
+
List of documents with text and scores
|
|
845
|
+
"""
|
|
846
|
+
# Qdrant requires embedding the query first
|
|
847
|
+
# For now, return empty list (user should provide embedding model)
|
|
848
|
+
logger.warning("Qdrant retrieval requires embedding model - returning empty results")
|
|
849
|
+
return []
|
|
850
|
+
|
|
851
|
+
def _query_weaviate(
|
|
852
|
+
self,
|
|
853
|
+
client: Any,
|
|
854
|
+
query: str,
|
|
855
|
+
top_k: int,
|
|
856
|
+
) -> list[dict[str, Any]]:
|
|
857
|
+
"""Query Weaviate client.
|
|
858
|
+
|
|
859
|
+
Args:
|
|
860
|
+
client: Weaviate client
|
|
861
|
+
query: Search query
|
|
862
|
+
top_k: Number of results to return
|
|
863
|
+
|
|
864
|
+
Returns:
|
|
865
|
+
List of documents with text and scores
|
|
866
|
+
"""
|
|
867
|
+
# Weaviate requires class name and schema
|
|
868
|
+
# For now, return empty list (user should provide class name)
|
|
869
|
+
logger.warning("Weaviate retrieval requires class name - returning empty results")
|
|
870
|
+
return []
|
|
871
|
+
|
|
872
|
+
def _call_llm_api(
|
|
873
|
+
self,
|
|
874
|
+
request: dict[str, Any],
|
|
875
|
+
stream: bool = False,
|
|
876
|
+
stream_callback: Any = None,
|
|
877
|
+
) -> tuple[str, int, float]:
|
|
878
|
+
"""Make real LLM API call with modified parameters.
|
|
879
|
+
|
|
880
|
+
Args:
|
|
881
|
+
request: LLM request with model, messages, etc.
|
|
882
|
+
stream: If True, use streaming API
|
|
883
|
+
stream_callback: Optional callback for streaming chunks (chunk_text: str) -> None
|
|
884
|
+
|
|
885
|
+
Returns:
|
|
886
|
+
Tuple of (response_text, tokens_used, cost_usd)
|
|
887
|
+
|
|
888
|
+
Raises:
|
|
889
|
+
ValueError: If model is missing or vendor cannot be detected
|
|
890
|
+
ImportError: If required SDK is not installed
|
|
891
|
+
Exception: If API call fails
|
|
892
|
+
"""
|
|
893
|
+
model = request.get("model")
|
|
894
|
+
if not model:
|
|
895
|
+
raise ValueError("Model is required for LLM API calls")
|
|
896
|
+
|
|
897
|
+
# Detect vendor from model name
|
|
898
|
+
vendor = self._detect_vendor(model)
|
|
899
|
+
|
|
900
|
+
if vendor == "openai":
|
|
901
|
+
return self._call_openai_api(request, stream=stream, stream_callback=stream_callback)
|
|
902
|
+
elif vendor == "anthropic":
|
|
903
|
+
return self._call_anthropic_api(request, stream=stream, stream_callback=stream_callback)
|
|
904
|
+
else:
|
|
905
|
+
raise ValueError(
|
|
906
|
+
f"Unsupported model vendor for '{model}'. "
|
|
907
|
+
f"Supported vendors: openai, anthropic"
|
|
908
|
+
)
|
|
909
|
+
|
|
910
|
+
def _detect_vendor(self, model: str) -> str:
|
|
911
|
+
"""Detect LLM vendor from model name.
|
|
912
|
+
|
|
913
|
+
Args:
|
|
914
|
+
model: Model name
|
|
915
|
+
|
|
916
|
+
Returns:
|
|
917
|
+
Vendor name (openai, anthropic)
|
|
918
|
+
|
|
919
|
+
Raises:
|
|
920
|
+
ValueError: If vendor cannot be detected
|
|
921
|
+
"""
|
|
922
|
+
model_lower = model.lower()
|
|
923
|
+
|
|
924
|
+
# OpenAI models
|
|
925
|
+
if any(prefix in model_lower for prefix in ["gpt-", "o1-", "text-embedding"]):
|
|
926
|
+
return "openai"
|
|
927
|
+
|
|
928
|
+
# Anthropic models
|
|
929
|
+
if any(prefix in model_lower for prefix in ["claude-", "claude"]):
|
|
930
|
+
return "anthropic"
|
|
931
|
+
|
|
932
|
+
raise ValueError(f"Cannot detect vendor from model name: {model}")
|
|
933
|
+
|
|
934
|
+
def _call_openai_api(
    self,
    request: dict[str, Any],
    stream: bool = False,
    stream_callback: Any = None,
) -> tuple[str, int, float]:
    """Call the OpenAI API through the instance's retry policy.

    Wraps ``_call_openai_api_impl`` with ``with_retry`` configured from
    ``max_retries``, ``retry_initial_delay``, ``retry_max_delay`` and
    ``retry_exponential_base``. Which errors are retried is decided by
    ``with_retry``.

    Args:
        request: Request with model, messages, temperature, etc.
        stream: If True, use the streaming API.
        stream_callback: Optional callback for streaming chunks,
            ``(chunk_text: str) -> None``.

    Returns:
        Tuple of (response_text, tokens_used, cost_usd).

    Raises:
        ImportError: If the openai package is not installed.
        Exception: Whatever the wrapped call raises once ``with_retry``
            stops retrying.
    """
    # Built per call so the current instance settings are always used.
    retry_policy = with_retry(
        max_retries=self.max_retries,
        initial_delay=self.retry_initial_delay,
        max_delay=self.retry_max_delay,
        exponential_base=self.retry_exponential_base,
    )

    def _make_call() -> tuple[str, int, float]:
        return self._call_openai_api_impl(request, stream, stream_callback)

    return retry_policy(_make_call)()
|
|
967
|
+
|
|
968
|
+
def _call_openai_api_impl(
    self,
    request: dict[str, Any],
    stream: bool = False,
    stream_callback: Any = None,
) -> tuple[str, int, float]:
    """Perform one OpenAI chat-completions call (no retry logic).

    Args:
        request: Request dict. ``model`` and ``messages`` are always sent;
            ``temperature`` and ``max_tokens`` are sent only when not None.
        stream: If True, use the streaming API.
        stream_callback: Optional callback invoked with each non-empty
            text chunk while streaming.

    Returns:
        Tuple of (response_text, tokens_used, cost_usd), where the cost
        comes from ``_estimate_cost``.

    Raises:
        ImportError: If the openai package is not installed.
    """
    try:
        import openai
    except ImportError as err:
        raise ImportError(
            "openai package is required for OpenAI API calls. "
            "Install it with: pip install openai"
        ) from err

    model = request.get("model")

    kwargs: dict[str, Any] = {
        "model": model,
        "messages": request.get("messages", []),
        "stream": stream,
    }
    for option in ("temperature", "max_tokens"):
        value = request.get(option)
        if value is not None:
            kwargs[option] = value
    if stream:
        # Usage is only reported on a stream when explicitly requested;
        # without this the token count and cost would always be zero.
        kwargs["stream_options"] = {"include_usage": True}

    client = openai.OpenAI()

    prompt_tokens = 0
    completion_tokens = 0

    if stream:
        pieces: list[str] = []
        for chunk in client.chat.completions.create(**kwargs):
            # The usage-only chunk carries an empty ``choices`` list.
            if chunk.choices and chunk.choices[0].delta.content:
                chunk_text = chunk.choices[0].delta.content
                pieces.append(chunk_text)
                if stream_callback:
                    stream_callback(chunk_text)

            # Checked per chunk so an empty stream cannot leave the loop
            # variable undefined.
            usage = getattr(chunk, "usage", None)
            if usage:
                prompt_tokens = usage.prompt_tokens
                completion_tokens = usage.completion_tokens

        response_text = "".join(pieces)
    else:
        response = client.chat.completions.create(**kwargs)
        response_text = response.choices[0].message.content or ""
        if response.usage:
            prompt_tokens = response.usage.prompt_tokens
            completion_tokens = response.usage.completion_tokens

    total_tokens = prompt_tokens + completion_tokens
    cost = self._estimate_cost(model, total_tokens)

    return response_text, total_tokens, cost
|
|
1044
|
+
|
|
1045
|
+
def _call_anthropic_api(
    self,
    request: dict[str, Any],
    stream: bool = False,
    stream_callback: Any = None,
) -> tuple[str, int, float]:
    """Call the Anthropic API through the instance's retry policy.

    Wraps ``_call_anthropic_api_impl`` with ``with_retry`` configured from
    ``max_retries``, ``retry_initial_delay``, ``retry_max_delay`` and
    ``retry_exponential_base``. Which errors are retried is decided by
    ``with_retry``.

    Args:
        request: Request with model, messages, temperature, etc.
        stream: If True, use the streaming API.
        stream_callback: Optional callback for streaming chunks,
            ``(chunk_text: str) -> None``.

    Returns:
        Tuple of (response_text, tokens_used, cost_usd).

    Raises:
        ImportError: If the anthropic package is not installed.
        Exception: Whatever the wrapped call raises once ``with_retry``
            stops retrying.
    """
    # Built per call so the current instance settings are always used.
    retry_policy = with_retry(
        max_retries=self.max_retries,
        initial_delay=self.retry_initial_delay,
        max_delay=self.retry_max_delay,
        exponential_base=self.retry_exponential_base,
    )

    def _make_call() -> tuple[str, int, float]:
        return self._call_anthropic_api_impl(request, stream, stream_callback)

    return retry_policy(_make_call)()
|
|
1078
|
+
|
|
1079
|
+
def _call_anthropic_api_impl(
    self,
    request: dict[str, Any],
    stream: bool = False,
    stream_callback: Any = None,
) -> tuple[str, int, float]:
    """Perform one Anthropic messages call (no retry logic).

    Messages with role "system" are removed from the message list; the
    content of the last one is sent as the ``system`` argument instead.

    Args:
        request: Request dict. ``model`` and ``messages`` are read;
            ``max_tokens`` falls back to 1024 when the key is absent;
            ``temperature`` is sent only when not None.
        stream: If True, use the streaming API.
        stream_callback: Optional callback invoked with each streamed
            text piece.

    Returns:
        Tuple of (response_text, tokens_used, cost_usd), where the cost
        comes from ``_estimate_cost``.

    Raises:
        ImportError: If the anthropic package is not installed.
    """
    try:
        import anthropic
    except ImportError:
        raise ImportError(
            "anthropic package is required for Anthropic API calls. "
            "Install it with: pip install anthropic"
        )

    model = request.get("model")
    temperature = request.get("temperature")

    # Split the system prompt out of the conversation turns.
    system_message = None
    conversation = []
    for entry in request.get("messages", []):
        if entry.get("role") == "system":
            system_message = entry.get("content")
        else:
            conversation.append(entry)

    kwargs: dict[str, Any] = {
        "model": model,
        "messages": conversation,
        "max_tokens": request.get("max_tokens", 1024),  # Anthropic requires max_tokens
    }
    if system_message:
        kwargs["system"] = system_message
    if temperature is not None:
        kwargs["temperature"] = temperature

    client = anthropic.Anthropic()

    if stream:
        pieces = []
        prompt_tokens = 0
        completion_tokens = 0

        with client.messages.stream(**kwargs) as message_stream:
            for text in message_stream.text_stream:
                pieces.append(text)
                if stream_callback:
                    stream_callback(text)

            # Usage statistics are read from the assembled final message.
            final_message = message_stream.get_final_message()
            if final_message.usage:
                prompt_tokens = final_message.usage.input_tokens
                completion_tokens = final_message.usage.output_tokens

        response_text = "".join(pieces)
    else:
        response = client.messages.create(**kwargs)

        # Only content blocks that expose ``text`` contribute to the output.
        response_text = "".join(
            block.text for block in response.content if hasattr(block, "text")
        )

        usage = response.usage
        prompt_tokens = usage.input_tokens if usage else 0
        completion_tokens = usage.output_tokens if usage else 0

    total_tokens = prompt_tokens + completion_tokens
    cost = self._estimate_cost(model, total_tokens)

    return response_text, total_tokens, cost
|
|
1172
|
+
|
|
1173
|
+
def _estimate_cost(self, model: str | None, tokens: int) -> float:
|
|
1174
|
+
"""Estimate cost for API call.
|
|
1175
|
+
|
|
1176
|
+
Args:
|
|
1177
|
+
model: Model name
|
|
1178
|
+
tokens: Total tokens used
|
|
1179
|
+
|
|
1180
|
+
Returns:
|
|
1181
|
+
Estimated cost in USD
|
|
1182
|
+
"""
|
|
1183
|
+
# Rough cost estimates (per 1M tokens)
|
|
1184
|
+
cost_per_1m = {
|
|
1185
|
+
"gpt-4": 30.0,
|
|
1186
|
+
"gpt-4-turbo": 10.0,
|
|
1187
|
+
"gpt-3.5-turbo": 1.5,
|
|
1188
|
+
"claude-3-opus": 15.0,
|
|
1189
|
+
"claude-3-sonnet": 3.0,
|
|
1190
|
+
"claude-3-haiku": 0.8,
|
|
1191
|
+
}
|
|
1192
|
+
|
|
1193
|
+
if not model:
|
|
1194
|
+
return 0.0
|
|
1195
|
+
|
|
1196
|
+
# Find matching model
|
|
1197
|
+
for model_prefix, cost in cost_per_1m.items():
|
|
1198
|
+
if model.startswith(model_prefix):
|
|
1199
|
+
return (tokens / 1_000_000) * cost
|
|
1200
|
+
|
|
1201
|
+
# Default rough estimate
|
|
1202
|
+
return (tokens / 1_000_000) * 5.0
|
|
1203
|
+
|
|
1204
|
+
def _extract_output(self, span: Span) -> Any:
    """Return the recorded output of a span from its replay snapshot.

    Args:
        span: Span to extract output from.

    Returns:
        ``None`` when the span has no replay snapshot or its type is not
        handled. Otherwise, by span type: LLM -> the ``"text"`` entry of
        ``llm_response`` (``None`` if the response is empty/missing);
        TOOL -> ``tool_output``; RETRIEVAL -> ``retrieved_documents``;
        AGENT -> ``agent_config``.
    """
    snapshot = span.replay_snapshot
    if snapshot is None:
        return None

    kind = span.span_type

    if kind == SpanType.LLM:
        llm_response = snapshot.llm_response
        if llm_response:
            return llm_response.get("text")
        return None

    if kind == SpanType.TOOL:
        return snapshot.tool_output

    if kind == SpanType.RETRIEVAL:
        return snapshot.retrieved_documents

    if kind == SpanType.AGENT:
        return snapshot.agent_config

    return None
|