braintrust 0.3.14__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff compares the contents of two package versions as published to a supported public registry. It is provided for informational purposes only and reflects the changes between the versions exactly as they appear in that registry.
- braintrust/__init__.py +4 -0
- braintrust/_generated_types.py +1200 -611
- braintrust/audit.py +2 -2
- braintrust/cli/eval.py +6 -7
- braintrust/cli/push.py +11 -11
- braintrust/conftest.py +1 -0
- braintrust/context.py +12 -17
- braintrust/contrib/temporal/__init__.py +16 -27
- braintrust/contrib/temporal/test_temporal.py +8 -3
- braintrust/devserver/auth.py +8 -8
- braintrust/devserver/cache.py +3 -4
- braintrust/devserver/cors.py +8 -7
- braintrust/devserver/dataset.py +3 -5
- braintrust/devserver/eval_hooks.py +7 -6
- braintrust/devserver/schemas.py +22 -19
- braintrust/devserver/server.py +19 -12
- braintrust/devserver/test_cached_login.py +4 -4
- braintrust/framework.py +128 -140
- braintrust/framework2.py +88 -87
- braintrust/functions/invoke.py +93 -53
- braintrust/functions/stream.py +3 -2
- braintrust/generated_types.py +17 -1
- braintrust/git_fields.py +11 -11
- braintrust/gitutil.py +2 -3
- braintrust/graph_util.py +10 -10
- braintrust/id_gen.py +2 -2
- braintrust/logger.py +346 -357
- braintrust/merge_row_batch.py +10 -9
- braintrust/oai.py +107 -24
- braintrust/otel/__init__.py +49 -49
- braintrust/otel/context.py +16 -30
- braintrust/otel/test_distributed_tracing.py +14 -11
- braintrust/otel/test_otel_bt_integration.py +32 -31
- braintrust/parameters.py +8 -8
- braintrust/prompt.py +14 -14
- braintrust/prompt_cache/disk_cache.py +5 -4
- braintrust/prompt_cache/lru_cache.py +3 -2
- braintrust/prompt_cache/prompt_cache.py +13 -14
- braintrust/queue.py +4 -4
- braintrust/score.py +4 -4
- braintrust/serializable_data_class.py +4 -4
- braintrust/span_identifier_v1.py +1 -2
- braintrust/span_identifier_v2.py +3 -4
- braintrust/span_identifier_v3.py +23 -20
- braintrust/span_identifier_v4.py +34 -25
- braintrust/test_framework.py +16 -6
- braintrust/test_helpers.py +5 -5
- braintrust/test_id_gen.py +2 -3
- braintrust/test_otel.py +61 -53
- braintrust/test_queue.py +0 -1
- braintrust/test_score.py +1 -3
- braintrust/test_span_components.py +29 -44
- braintrust/util.py +9 -8
- braintrust/version.py +2 -2
- braintrust/wrappers/_anthropic_utils.py +4 -4
- braintrust/wrappers/agno/__init__.py +3 -4
- braintrust/wrappers/agno/agent.py +1 -2
- braintrust/wrappers/agno/function_call.py +1 -2
- braintrust/wrappers/agno/model.py +1 -2
- braintrust/wrappers/agno/team.py +1 -2
- braintrust/wrappers/agno/utils.py +12 -12
- braintrust/wrappers/anthropic.py +7 -8
- braintrust/wrappers/claude_agent_sdk/__init__.py +3 -4
- braintrust/wrappers/claude_agent_sdk/_wrapper.py +29 -27
- braintrust/wrappers/dspy.py +15 -17
- braintrust/wrappers/google_genai/__init__.py +16 -16
- braintrust/wrappers/langchain.py +22 -24
- braintrust/wrappers/litellm.py +4 -3
- braintrust/wrappers/openai.py +15 -15
- braintrust/wrappers/pydantic_ai.py +1204 -0
- braintrust/wrappers/test_agno.py +0 -1
- braintrust/wrappers/test_dspy.py +0 -1
- braintrust/wrappers/test_google_genai.py +2 -3
- braintrust/wrappers/test_litellm.py +0 -1
- braintrust/wrappers/test_oai_attachments.py +322 -0
- braintrust/wrappers/test_pydantic_ai_integration.py +1788 -0
- braintrust/wrappers/{test_pydantic_ai.py → test_pydantic_ai_wrap_openai.py} +1 -2
- {braintrust-0.3.14.dist-info → braintrust-0.4.0.dist-info}/METADATA +3 -2
- braintrust-0.4.0.dist-info/RECORD +120 -0
- braintrust-0.3.14.dist-info/RECORD +0 -117
- {braintrust-0.3.14.dist-info → braintrust-0.4.0.dist-info}/WHEEL +0 -0
- {braintrust-0.3.14.dist-info → braintrust-0.4.0.dist-info}/entry_points.txt +0 -0
- {braintrust-0.3.14.dist-info → braintrust-0.4.0.dist-info}/top_level.txt +0 -0
braintrust/wrappers/claude_agent_sdk/_wrapper.py
CHANGED

@@ -2,7 +2,8 @@ import dataclasses
 import logging
 import threading
 import time
-from …
+from collections.abc import AsyncGenerator, Callable
+from typing import Any
 
 from braintrust.logger import start_span
 from braintrust.span_types import SpanTypeAttribute
@@ -108,12 +109,12 @@ def _wrap_tool_handler(handler: Any, tool_name: Any) -> Callable[..., Any]:
     so we try the context variable first, then fall back to current_span export.
     """
     # Check if already wrapped to prevent double-wrapping
-    if hasattr(handler, …
+    if hasattr(handler, "_braintrust_wrapped"):
        return handler
 
    async def wrapped_handler(args: Any) -> Any:
        # Get parent span export from thread-local storage
-        parent_export = getattr(_thread_local, …
+        parent_export = getattr(_thread_local, "parent_span_export", None)
 
        with start_span(
            name=str(tool_name),
@@ -144,11 +145,14 @@ def _create_client_wrapper_class(original_client_class: Any) -> Any:
        We end the previous span when the next AssistantMessage arrives, using the marked
        start time to ensure sequential timing (no overlapping LLM spans).
        """
-        def __init__(self, query_start_time: Optional[float] = None):
-            self.current_span: Optional[Any] = None
-            self.next_start_time: Optional[float] = query_start_time
 
-        def …
+        def __init__(self, query_start_time: float | None = None):
+            self.current_span: Any | None = None
+            self.next_start_time: float | None = query_start_time
+
+        def start_llm_span(
+            self, message: Any, prompt: Any, conversation_history: list[dict[str, Any]]
+        ) -> dict[str, Any] | None:
            """Start a new LLM span, ending the previous one if it exists."""
            # Use the marked start time, or current time as fallback
            start_time = self.next_start_time if self.next_start_time is not None else time.time()
@@ -158,8 +162,7 @@ def _create_client_wrapper_class(original_client_class: Any) -> Any:
                self.current_span.end(end_time=start_time)
 
            final_content, span = _create_llm_span_for_messages(
-                [message], prompt, conversation_history,
-                start_time=start_time
+                [message], prompt, conversation_history, start_time=start_time
            )
            self.current_span = span
            self.next_start_time = None  # Reset for next span
@@ -169,7 +172,7 @@ def _create_client_wrapper_class(original_client_class: Any) -> Any:
            """Mark when the next LLM call will start (after tool results)."""
            self.next_start_time = time.time()
 
-        def log_usage(self, usage_metrics: …
+        def log_usage(self, usage_metrics: dict[str, float]) -> None:
            """Log usage metrics to the current LLM span."""
            if self.current_span and usage_metrics:
                self.current_span.log(metrics=usage_metrics)
@@ -186,8 +189,8 @@ def _create_client_wrapper_class(original_client_class: Any) -> Any:
            client = original_client_class(*args, **kwargs)
            super().__init__(client)
            self.__client = client
-            self.__last_prompt: …
-            self.__query_start_time: …
+            self.__last_prompt: str | None = None
+            self.__query_start_time: float | None = None
 
        async def query(self, *args: Any, **kwargs: Any) -> Any:
            """Wrap query to capture the prompt and start time for tracing."""
@@ -220,7 +223,7 @@ def _create_client_wrapper_class(original_client_class: Any) -> Any:
            # Store the parent span export in thread-local storage for tool handlers
            _thread_local.parent_span_export = span.export()
 
-            final_results: …
+            final_results: list[dict[str, Any]] = []
            llm_tracker = LLMSpanTracker(query_start_time=self.__query_start_time)
 
            try:
@@ -243,10 +246,12 @@ def _create_client_wrapper_class(original_client_class: Any) -> Any:
                llm_tracker.log_usage(usage_metrics)
 
                result_metadata = {
-                    k: v …
+                    k: v
+                    for k, v in {
                        "num_turns": getattr(message, "num_turns", None),
                        "session_id": getattr(message, "session_id", None),
-                    }.items() …
+                    }.items()
+                    if v is not None
                }
                if result_metadata:
                    span.log(metadata=result_metadata)
@@ -257,8 +262,8 @@ def _create_client_wrapper_class(original_client_class: Any) -> Any:
                log.warning("Error in tracing code", exc_info=e)
            finally:
                llm_tracker.cleanup()
-                if hasattr(_thread_local, …
-                    delattr(_thread_local, …
+                if hasattr(_thread_local, "parent_span_export"):
+                    delattr(_thread_local, "parent_span_export")
 
        async def __aenter__(self) -> "WrappedClaudeSDKClient":
            await self.__client.__aenter__()
@@ -271,11 +276,11 @@ def _create_client_wrapper_class(original_client_class: Any) -> Any:
 
 
 def _create_llm_span_for_messages(
-    messages: …
+    messages: list[Any],  # List of AssistantMessage objects
    prompt: Any,
-    conversation_history: …
-    start_time: …
-) -> …
+    conversation_history: list[dict[str, Any]],
+    start_time: float | None = None,
+) -> tuple[dict[str, Any] | None, Any | None]:
    """Creates an LLM span for a group of AssistantMessage objects.
 
    Returns a tuple of (final_content, span):
@@ -295,13 +300,12 @@ def _create_llm_span_for_messages(
    model = getattr(last_message, "model", None)
    input_messages = _build_llm_input(prompt, conversation_history)
 
-    outputs: …
+    outputs: list[dict[str, Any]] = []
    for msg in messages:
        if hasattr(msg, "content"):
            content = _serialize_content_blocks(msg.content)
            outputs.append({"content": content, "role": "assistant"})
 
-
    llm_span = start_span(
        name="anthropic.messages.create",
        span_attributes={"type": SpanTypeAttribute.LLM},
@@ -355,7 +359,7 @@ def _serialize_content_blocks(content: Any) -> Any:
    return content
 
 
-def _extract_usage_from_result_message(result_message: Any) -> …
+def _extract_usage_from_result_message(result_message: Any) -> dict[str, float]:
    """Extracts and normalizes usage metrics from a ResultMessage.
 
    Uses shared Anthropic utilities for consistent metric extraction.
@@ -374,9 +378,7 @@ def _extract_usage_from_result_message(result_message: Any) -> Dict[str, float]:
    return metrics
 
 
-def _build_llm_input(
-    prompt: Any, conversation_history: List[Dict[str, Any]]
-) -> Optional[List[Dict[str, Any]]]:
+def _build_llm_input(prompt: Any, conversation_history: list[dict[str, Any]]) -> list[dict[str, Any]] | None:
    """Builds the input array for an LLM span from the initial prompt and conversation history.
 
    Formats input to match Anthropic messages API format for proper UI rendering.
braintrust/wrappers/dspy.py
CHANGED

@@ -47,7 +47,7 @@ Advanced Usage with LiteLLM Patching:
    ```
    """
 
-from typing import Any …
+from typing import Any
 
 from braintrust.logger import current_span, start_span
 from braintrust.span_types import SpanTypeAttribute
@@ -58,9 +58,7 @@ from braintrust.span_types import SpanTypeAttribute
 try:
    from dspy.utils.callback import BaseCallback
 except ImportError:
-    raise ImportError(
-        "DSPy is not installed. Please install it with: pip install dspy"
-    )
+    raise ImportError("DSPy is not installed. Please install it with: pip install dspy")
 
 
 class BraintrustDSpyCallback(BaseCallback):
@@ -130,13 +128,13 @@ class BraintrustDSpyCallback(BaseCallback):
        """Initialize the Braintrust DSPy callback handler."""
        super().__init__()
        # Map call_id to span objects for proper nesting
-        self._spans: …
+        self._spans: dict[str, Any] = {}
 
    def on_lm_start(
        self,
        call_id: str,
        instance: Any,
-        inputs: …
+        inputs: dict[str, Any],
    ):
        """Log the start of a language model call.
 
@@ -174,8 +172,8 @@ class BraintrustDSpyCallback(BaseCallback):
    def on_lm_end(
        self,
        call_id: str,
-        outputs: …
-        exception: …
+        outputs: dict[str, Any] | None,
+        exception: Exception | None = None,
    ):
        """Log the end of a language model call.
 
@@ -205,7 +203,7 @@ class BraintrustDSpyCallback(BaseCallback):
        self,
        call_id: str,
        instance: Any,
-        inputs: …
+        inputs: dict[str, Any],
    ):
        """Log the start of a DSPy module execution.
 
@@ -236,8 +234,8 @@ class BraintrustDSpyCallback(BaseCallback):
    def on_module_end(
        self,
        call_id: str,
-        outputs: …
-        exception: …
+        outputs: Any | None,
+        exception: Exception | None = None,
    ):
        """Log the end of a DSPy module execution.
 
@@ -274,7 +272,7 @@ class BraintrustDSpyCallback(BaseCallback):
        self,
        call_id: str,
        instance: Any,
-        inputs: …
+        inputs: dict[str, Any],
    ):
        """Log the start of a tool invocation.
 
@@ -309,8 +307,8 @@ class BraintrustDSpyCallback(BaseCallback):
    def on_tool_end(
        self,
        call_id: str,
-        outputs: …
-        exception: …
+        outputs: dict[str, Any] | None,
+        exception: Exception | None = None,
    ):
        """Log the end of a tool invocation.
 
@@ -340,7 +338,7 @@ class BraintrustDSpyCallback(BaseCallback):
        self,
        call_id: str,
        instance: Any,
-        inputs: …
+        inputs: dict[str, Any],
    ):
        """Log the start of an evaluation run.
 
@@ -374,8 +372,8 @@ class BraintrustDSpyCallback(BaseCallback):
    def on_evaluate_end(
        self,
        call_id: str,
-        outputs: …
-        exception: …
+        outputs: Any | None,
+        exception: Exception | None = None,
    ):
        """Log the end of an evaluation run.
 
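The dspy.py changes are type-annotation modernization (builtin generics, `X | None`); the callback's behavior is unchanged. A minimal registration sketch, assuming a DSPy version whose `dspy.configure` accepts a `callbacks` list (the project name is illustrative):

```python
import braintrust
import dspy

from braintrust.wrappers.dspy import BraintrustDSpyCallback

braintrust.init_logger(project="my-project")  # illustrative project name
dspy.configure(callbacks=[BraintrustDSpyCallback()])
# LM, module, tool, and evaluate calls now open and close Braintrust spans.
```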
braintrust/wrappers/google_genai/__init__.py
CHANGED

@@ -1,19 +1,19 @@
 import logging
 import time
-from …
-
-from wrapt import wrap_function_wrapper
+from collections.abc import Iterable
+from typing import Any
 
 from braintrust.logger import NOOP_SPAN, Attachment, current_span, init_logger, start_span
 from braintrust.span_types import SpanTypeAttribute
+from wrapt import wrap_function_wrapper
 
 logger = logging.getLogger(__name__)
 
 
 def setup_genai(
-    api_key: …
-    project_id: …
-    project_name: …
+    api_key: str | None = None,
+    project_id: str | None = None,
+    project_name: str | None = None,
 ):
    span = current_span()
    if span == NOOP_SPAN:
@@ -148,7 +148,7 @@ def wrap_async_models(AsyncModels: Any):
    return AsyncModels
 
 
-def _serialize_input(api_client: Any, input: …
+def _serialize_input(api_client: Any, input: dict[str, Any]):
    config = _try_dict(input.get("config"))
 
    if config is not None:
@@ -223,7 +223,7 @@ def _serialize_content_item(item: Any) -> Any:
    return item
 
 
-def _serialize_tools(api_client: Any, input: …
+def _serialize_tools(api_client: Any, input: Any | None):
    try:
        from google.genai.models import (
            _GenerateContentParameters_to_mldev,  # pyright: ignore [reportPrivateUsage]
@@ -242,7 +242,7 @@ def _serialize_tools(api_client: Any, input: Optional[Any]):
        return None
 
 
-def omit(obj: …
+def omit(obj: dict[str, Any], keys: Iterable[str]):
    return {k: v for k, v in obj.items() if k not in keys}
 
 
@@ -254,11 +254,11 @@ def mark_patched(obj: Any):
    return setattr(obj, "_braintrust_patched", True)
 
 
-def get_args_kwargs(args: …
+def get_args_kwargs(args: list[str], kwargs: dict[str, Any], keys: Iterable[str]):
    return {k: args[i] if args else kwargs.get(k) for i, k in enumerate(keys)}, omit(kwargs, keys)
 
 
-def _extract_generate_content_metrics(response: Any, start: float) -> …
+def _extract_generate_content_metrics(response: Any, start: float) -> dict[str, Any]:
    """Extract metrics from a non-streaming generate_content response."""
    end_time = time.time()
    metrics = dict(
@@ -297,8 +297,8 @@ def _extract_generate_content_metrics(response: Any, start: float) -> Dict[str, …
 
 
 def _aggregate_generate_content_chunks(
-    chunks: …
-) -> …
+    chunks: list[Any], start: float, first_token_time: float | None = None
+) -> tuple[dict[str, Any], dict[str, Any]]:
    """Aggregate streaming chunks into a single response with metrics."""
    end_time = time.time()
    metrics = dict(
@@ -410,11 +410,11 @@ def _aggregate_generate_content_chunks(
    return aggregated, clean_metrics
 
 
-def clean(obj: …
+def clean(obj: dict[str, Any]) -> dict[str, Any]:
    return {k: v for k, v in obj.items() if v is not None}
 
 
-def get_path(obj: …
+def get_path(obj: dict[str, Any], path: str, default: Any = None) -> Any | None:
    keys = path.split(".")
    current = obj
 
@@ -426,7 +426,7 @@ def get_path(obj: Dict[str, Any], path: str, default: Any = None) -> Optional[An…
    return current
 
 
-def _try_dict(obj: Any) -> …
+def _try_dict(obj: Any) -> dict[str, Any] | None:
    try:
        return obj.model_dump()
    except AttributeError:
braintrust/wrappers/langchain.py
CHANGED

@@ -1,6 +1,6 @@
 import contextvars
 import logging
-from typing import Any …
+from typing import Any
 from uuid import UUID
 
 import braintrust
@@ -30,7 +30,7 @@ class BraintrustTracer(BaseCallbackHandler):
        self.logger = logger
        self.spans = {}
 
-    def _start_span(self, parent_run_id, run_id, name: …
+    def _start_span(self, parent_run_id, run_id, name: str | None, **kwargs: Any) -> Any:
        assert run_id not in self.spans, f"Span already exists for run_id {run_id} (this is likely a bug)"
 
        current_parent = langchain_parent.get()
@@ -60,29 +60,29 @@ class BraintrustTracer(BaseCallbackHandler):
 
    def on_chain_start(
        self,
-        serialized: …
-        inputs: …
+        serialized: dict[str, Any],
+        inputs: dict[str, Any],
        *,
        run_id: UUID,
-        parent_run_id: …
-        tags: …
+        parent_run_id: UUID | None = None,
+        tags: list[str] | None = None,
        **kwargs: Any,
    ) -> Any:
        self._start_span(parent_run_id, run_id, "Chain", input=inputs, metadata={"tags": tags})
 
    def on_chain_end(
-        self, outputs: …
+        self, outputs: dict[str, Any], *, run_id: UUID, parent_run_id: UUID | None = None, **kwargs: Any
    ) -> Any:
        self._end_span(run_id, output=outputs)
 
    def on_llm_start(
        self,
-        serialized: …
-        prompts: …
+        serialized: dict[str, Any],
+        prompts: list[str],
        *,
        run_id: UUID,
-        parent_run_id: …
-        tags: …
+        parent_run_id: UUID | None = None,
+        tags: list[str] | None = None,
        **kwargs: Any,
    ) -> Any:
        self._start_span(
@@ -95,12 +95,12 @@ class BraintrustTracer(BaseCallbackHandler):
 
    def on_chat_model_start(
        self,
-        serialized: …
-        messages: …
+        serialized: dict[str, Any],
+        messages: list[list[BaseMessage]],
        *,
        run_id: UUID,
-        parent_run_id: …
-        tags: …
+        parent_run_id: UUID | None = None,
+        tags: list[str] | None = None,
        **kwargs: Any,
    ) -> Any:
        self._start_span(
@@ -112,7 +112,7 @@ class BraintrustTracer(BaseCallbackHandler):
        )
 
    def on_llm_end(
-        self, response: LLMResult, *, run_id: UUID, parent_run_id: …
+        self, response: LLMResult, *, run_id: UUID, parent_run_id: UUID | None = None, **kwargs: Any
    ) -> Any:
        metrics = {}
        token_usage = response.llm_output.get("token_usage", {})
@@ -127,25 +127,23 @@ class BraintrustTracer(BaseCallbackHandler):
 
    def on_tool_start(
        self,
-        serialized: …
+        serialized: dict[str, Any],
        input_str: str,
        *,
        run_id: UUID,
-        parent_run_id: …
-        tags: …
+        parent_run_id: UUID | None = None,
+        tags: list[str] | None = None,
        **kwargs: Any,
    ) -> Any:
        _logger.warning("Starting tool, but it will not be traced in braintrust (unsupported)")
 
-    def on_tool_end(self, output: str, *, run_id: UUID, parent_run_id: …
+    def on_tool_end(self, output: str, *, run_id: UUID, parent_run_id: UUID | None = None, **kwargs: Any) -> Any:
        pass
 
-    def on_retriever_start(
-        self, query: str, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any
-    ) -> Any:
+    def on_retriever_start(self, query: str, *, run_id: UUID, parent_run_id: UUID | None = None, **kwargs: Any) -> Any:
        _logger.warning("Starting retriever, but it will not be traced in braintrust (unsupported)")
 
    def on_retriever_end(
-        self, response: …
+        self, response: list[Document], *, run_id: UUID, parent_run_id: UUID | None = None, **kwargs: Any
    ) -> Any:
        pass
braintrust/wrappers/litellm.py
CHANGED

@@ -1,9 +1,9 @@
 from __future__ import annotations
 
 import time
-from collections.abc import AsyncGenerator, Generator
+from collections.abc import AsyncGenerator, Callable, Generator
 from types import TracebackType
-from typing import Any …
+from typing import Any
 
 from braintrust.logger import Span, start_span
 from braintrust.span_types import SpanTypeAttribute
@@ -655,7 +655,8 @@ def patch_litellm():
    """
    try:
        import litellm
-
+
+        if not hasattr(litellm, "_braintrust_wrapped"):
            wrapped = wrap_litellm(litellm)
            litellm.completion = wrapped.completion
            litellm.acompletion = wrapped.acompletion
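
The litellm change adds an idempotence guard around patching. A usage sketch; the comment about the marker attribute assumes the guarded block sets `litellm._braintrust_wrapped`, which the hunk implies but does not show:

```python
import litellm

from braintrust.wrappers.litellm import patch_litellm

patch_litellm()
patch_litellm()  # no-op on repeat calls, assuming the guard sets litellm._braintrust_wrapped

# response = litellm.completion(
#     model="gpt-4o-mini",  # illustrative model name
#     messages=[{"role": "user", "content": "hello"}],
# )
```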
braintrust/wrappers/openai.py
CHANGED

@@ -3,7 +3,7 @@ Exports `BraintrustTracingProcessor`, a `tracing.TracingProcessor` that logs tra…
 """
 
 import datetime
-from typing import Any …
+from typing import Any
 
 import braintrust
 from agents import tracing
@@ -40,13 +40,13 @@ def _span_name(span: tracing.Span[Any]) -> str:
    return "Unknown"
 
 
-def _timestamp_from_maybe_iso(timestamp: …
+def _timestamp_from_maybe_iso(timestamp: str | None) -> float | None:
    if timestamp is None:
        return None
    return datetime.datetime.fromisoformat(timestamp).timestamp()
 
 
-def _maybe_timestamp_elapsed(end: …
+def _maybe_timestamp_elapsed(end: str | None, start: str | None) -> float | None:
    if start is None or end is None:
        return None
    return (datetime.datetime.fromisoformat(end) - datetime.datetime.fromisoformat(start)).total_seconds()
@@ -61,11 +61,11 @@ class BraintrustTracingProcessor(tracing.TracingProcessor):
    If `None`, the current span, experiment, or logger will be selected exactly as in `braintrust.start_span`.
    """
 
-    def __init__(self, logger: …
+    def __init__(self, logger: braintrust.Span | braintrust.Experiment | braintrust.Logger | None = None):
        self._logger = logger
-        self._spans: …
-        self._first_input: …
-        self._last_output: …
+        self._spans: dict[str, braintrust.Span] = {}
+        self._first_input: dict[str, Any] = {}
+        self._last_output: dict[str, Any] = {}
 
    def on_trace_start(self, trace: tracing.Trace) -> None:
        trace_meta = trace.export() or {}
@@ -113,7 +113,7 @@ class BraintrustTracingProcessor(tracing.TracingProcessor):
        # TODO(sachin): Add end time when SDK provides it.
        # span.end(_timestamp_from_maybe_iso(trace.ended_at))
 
-    def _agent_log_data(self, span: tracing.Span[tracing.AgentSpanData]) -> …
+    def _agent_log_data(self, span: tracing.Span[tracing.AgentSpanData]) -> dict[str, Any]:
        return {
            "metadata": {
                "tools": span.span_data.tools,
@@ -122,7 +122,7 @@ class BraintrustTracingProcessor(tracing.TracingProcessor):
            }
        }
 
-    def _response_log_data(self, span: tracing.Span[tracing.ResponseSpanData]) -> …
+    def _response_log_data(self, span: tracing.Span[tracing.ResponseSpanData]) -> dict[str, Any]:
        data = {}
        if span.span_data.input is not None:
            data["input"] = span.span_data.input
@@ -145,13 +145,13 @@ class BraintrustTracingProcessor(tracing.TracingProcessor):
 
        return data
 
-    def _function_log_data(self, span: tracing.Span[tracing.FunctionSpanData]) -> …
+    def _function_log_data(self, span: tracing.Span[tracing.FunctionSpanData]) -> dict[str, Any]:
        return {
            "input": span.span_data.input,
            "output": span.span_data.output,
        }
 
-    def _handoff_log_data(self, span: tracing.Span[tracing.HandoffSpanData]) -> …
+    def _handoff_log_data(self, span: tracing.Span[tracing.HandoffSpanData]) -> dict[str, Any]:
        return {
            "metadata": {
                "from_agent": span.span_data.from_agent,
@@ -159,14 +159,14 @@ class BraintrustTracingProcessor(tracing.TracingProcessor):
            }
        }
 
-    def _guardrail_log_data(self, span: tracing.Span[tracing.GuardrailSpanData]) -> …
+    def _guardrail_log_data(self, span: tracing.Span[tracing.GuardrailSpanData]) -> dict[str, Any]:
        return {
            "metadata": {
                "triggered": span.span_data.triggered,
            }
        }
 
-    def _generation_log_data(self, span: tracing.Span[tracing.GenerationSpanData]) -> …
+    def _generation_log_data(self, span: tracing.Span[tracing.GenerationSpanData]) -> dict[str, Any]:
        metrics = {}
        ttft = _maybe_timestamp_elapsed(span.ended_at, span.started_at)
 
@@ -199,10 +199,10 @@ class BraintrustTracingProcessor(tracing.TracingProcessor):
            "metrics": metrics,
        }
 
-    def _custom_log_data(self, span: tracing.Span[tracing.CustomSpanData]) -> …
+    def _custom_log_data(self, span: tracing.Span[tracing.CustomSpanData]) -> dict[str, Any]:
        return span.span_data.data
 
-    def _log_data(self, span: tracing.Span[Any]) -> …
+    def _log_data(self, span: tracing.Span[Any]) -> dict[str, Any]:
        if isinstance(span.span_data, tracing.AgentSpanData):
            return self._agent_log_data(span)
        elif isinstance(span.span_data, tracing.ResponseSpanData):