braintrust 0.3.15__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braintrust/_generated_types.py +737 -672
- braintrust/audit.py +2 -2
- braintrust/bt_json.py +178 -19
- braintrust/cli/eval.py +6 -7
- braintrust/cli/push.py +11 -11
- braintrust/context.py +12 -17
- braintrust/contrib/temporal/__init__.py +16 -27
- braintrust/contrib/temporal/test_temporal.py +8 -3
- braintrust/devserver/auth.py +8 -8
- braintrust/devserver/cache.py +3 -4
- braintrust/devserver/cors.py +8 -7
- braintrust/devserver/dataset.py +3 -5
- braintrust/devserver/eval_hooks.py +7 -6
- braintrust/devserver/schemas.py +22 -19
- braintrust/devserver/server.py +19 -12
- braintrust/devserver/test_cached_login.py +4 -4
- braintrust/framework.py +139 -142
- braintrust/framework2.py +88 -87
- braintrust/functions/invoke.py +66 -59
- braintrust/functions/stream.py +3 -2
- braintrust/generated_types.py +3 -1
- braintrust/git_fields.py +11 -11
- braintrust/gitutil.py +2 -3
- braintrust/graph_util.py +10 -10
- braintrust/id_gen.py +2 -2
- braintrust/logger.py +373 -471
- braintrust/merge_row_batch.py +10 -9
- braintrust/oai.py +21 -20
- braintrust/otel/__init__.py +49 -49
- braintrust/otel/context.py +16 -30
- braintrust/otel/test_distributed_tracing.py +14 -11
- braintrust/otel/test_otel_bt_integration.py +32 -31
- braintrust/parameters.py +8 -8
- braintrust/prompt.py +14 -14
- braintrust/prompt_cache/disk_cache.py +5 -4
- braintrust/prompt_cache/lru_cache.py +3 -2
- braintrust/prompt_cache/prompt_cache.py +13 -14
- braintrust/queue.py +4 -4
- braintrust/score.py +4 -4
- braintrust/serializable_data_class.py +4 -4
- braintrust/span_identifier_v1.py +1 -2
- braintrust/span_identifier_v2.py +3 -4
- braintrust/span_identifier_v3.py +23 -20
- braintrust/span_identifier_v4.py +34 -25
- braintrust/test_bt_json.py +644 -0
- braintrust/test_framework.py +72 -6
- braintrust/test_helpers.py +5 -5
- braintrust/test_id_gen.py +2 -3
- braintrust/test_logger.py +211 -107
- braintrust/test_otel.py +61 -53
- braintrust/test_queue.py +0 -1
- braintrust/test_score.py +1 -3
- braintrust/test_span_components.py +29 -44
- braintrust/util.py +9 -8
- braintrust/version.py +2 -2
- braintrust/wrappers/_anthropic_utils.py +4 -4
- braintrust/wrappers/agno/__init__.py +3 -4
- braintrust/wrappers/agno/agent.py +1 -2
- braintrust/wrappers/agno/function_call.py +1 -2
- braintrust/wrappers/agno/model.py +1 -2
- braintrust/wrappers/agno/team.py +1 -2
- braintrust/wrappers/agno/utils.py +12 -12
- braintrust/wrappers/anthropic.py +7 -8
- braintrust/wrappers/claude_agent_sdk/__init__.py +3 -4
- braintrust/wrappers/claude_agent_sdk/_wrapper.py +29 -27
- braintrust/wrappers/dspy.py +15 -17
- braintrust/wrappers/google_genai/__init__.py +17 -30
- braintrust/wrappers/langchain.py +22 -24
- braintrust/wrappers/litellm.py +4 -3
- braintrust/wrappers/openai.py +15 -15
- braintrust/wrappers/pydantic_ai.py +225 -110
- braintrust/wrappers/test_agno.py +0 -1
- braintrust/wrappers/test_dspy.py +0 -1
- braintrust/wrappers/test_google_genai.py +64 -4
- braintrust/wrappers/test_litellm.py +0 -1
- braintrust/wrappers/test_pydantic_ai_integration.py +819 -22
- {braintrust-0.3.15.dist-info → braintrust-0.4.1.dist-info}/METADATA +3 -2
- braintrust-0.4.1.dist-info/RECORD +121 -0
- braintrust-0.3.15.dist-info/RECORD +0 -120
- {braintrust-0.3.15.dist-info → braintrust-0.4.1.dist-info}/WHEEL +0 -0
- {braintrust-0.3.15.dist-info → braintrust-0.4.1.dist-info}/entry_points.txt +0 -0
- {braintrust-0.3.15.dist-info → braintrust-0.4.1.dist-info}/top_level.txt +0 -0
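Note: most of the churn in the per-file diffs below is a mechanical typing migration: typing.Dict/List/Optional/Tuple are replaced by builtin generics (PEP 585) and X | None unions (PEP 604). A minimal sketch of the pattern, not taken verbatim from the diff (the new syntax assumes Python 3.9+ for builtin generics and 3.10+ for | in annotations, or from __future__ import annotations):

    # Before: typing-module generics
    from typing import Any, Dict, List, Optional

    def omit(obj: Dict[str, Any], keys: List[str]) -> Dict[str, Any]:
        return {k: v for k, v in obj.items() if k not in keys}

    def first(xs: List[int]) -> Optional[int]:
        return xs[0] if xs else None

    # After: builtin generics and PEP 604 unions
    from typing import Any

    def omit(obj: dict[str, Any], keys: list[str]) -> dict[str, Any]:
        return {k: v for k, v in obj.items() if k not in keys}

    def first(xs: list[int]) -> int | None:
        return xs[0] if xs else None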
braintrust/wrappers/agno/utils.py
CHANGED
@@ -1,8 +1,8 @@
 import time
-from typing import Any, Dict, List, Optional
+from typing import Any


-def omit(obj: Dict[str, Any], keys: List[str]):
+def omit(obj: dict[str, Any], keys: list[str]):
     return {k: v for k, v in obj.items() if k not in keys}


@@ -14,11 +14,11 @@ def mark_patched(obj: Any):
     setattr(obj, "_braintrust_patched", True)


-def clean(obj: Dict[str, Any]) -> Dict[str, Any]:
+def clean(obj: dict[str, Any]) -> dict[str, Any]:
     return {k: v for k, v in obj.items() if v is not None}


-def get_args_kwargs(args: List[str], kwargs: Dict[str, Any], keys: List[str]):
+def get_args_kwargs(args: list[str], kwargs: dict[str, Any], keys: list[str]):
     return {k: args[i] if args else kwargs.get(k) for i, k in enumerate(keys)}, omit(kwargs, keys)


@@ -71,7 +71,7 @@ AGNO_METRICS_MAP = {
 }


-def extract_metadata(instance: Any, component: str) -> Dict[str, Any]:
+def extract_metadata(instance: Any, component: str) -> dict[str, Any]:
     """Extract metadata from any component (model, agent, team)."""
     metadata = {"component": component}

@@ -100,7 +100,7 @@ def extract_metadata(instance: Any, component: str) -> Dict[str, Any]:
     return metadata


-def parse_metrics_from_agno(usage: Any) -> Dict[str, Any]:
+def parse_metrics_from_agno(usage: Any) -> dict[str, Any]:
     """Parse metrics from Agno usage object, following OpenAI wrapper pattern."""
     metrics = {}

@@ -121,7 +121,7 @@ def parse_metrics_from_agno(usage: Any) -> Dict[str, Any]:
     return metrics


-def extract_metrics(result: Any, messages: Optional[list] = None) -> Dict[str, Any]:
+def extract_metrics(result: Any, messages: list | None = None) -> dict[str, Any]:
     """
     Unified metrics extraction for all components.

@@ -163,7 +163,7 @@ def extract_metrics(result: Any, messages: Optional[list] = None) -> Dict[str, Any]:
     return {}


-def extract_streaming_metrics(aggregated: Dict[str, Any], start_time: float) -> Optional[Dict[str, Any]]:
+def extract_streaming_metrics(aggregated: dict[str, Any], start_time: float) -> dict[str, Any] | None:
     """Extract metrics from aggregated streaming response."""
     metrics = {}

@@ -187,7 +187,7 @@ def extract_streaming_metrics(aggregated: Dict[str, Any], start_time: float) -> Optional[Dict[str, Any]]:
     return metrics if metrics else None


-def _aggregate_metrics(target: Dict[str, Any], source: Dict[str, Any]) -> None:
+def _aggregate_metrics(target: dict[str, Any], source: dict[str, Any]) -> None:
     """Aggregate metrics from source into target dict."""
     for key, value in source.items():
         if _is_numeric(value):
@@ -205,7 +205,7 @@ def _aggregate_metrics(target: Dict[str, Any], source: Dict[str, Any]) -> None:
             target[key] = value


-def _aggregate_model_chunks(chunks: List[Any]) -> Dict[str, Any]:
+def _aggregate_model_chunks(chunks: list[Any]) -> dict[str, Any]:
     """Aggregate ModelResponse chunks from invoke_stream into a complete response."""
     aggregated = {
         "content": "",
@@ -263,7 +263,7 @@ def _aggregate_model_chunks(chunks: List[Any]) -> Dict[str, Any]:
     return aggregated


-def _aggregate_response_stream_chunks(chunks: List[Any]) -> Dict[str, Any]:
+def _aggregate_response_stream_chunks(chunks: list[Any]) -> dict[str, Any]:
     """
     Aggregate chunks from response_stream which can be ModelResponse, RunOutputEvent, or TeamRunOutputEvent.

@@ -344,7 +344,7 @@ def _aggregate_response_stream_chunks(chunks: List[Any]) -> Dict[str, Any]:
     return aggregated


-def _aggregate_agent_chunks(chunks: List[Any]) -> Dict[str, Any]:
+def _aggregate_agent_chunks(chunks: list[Any]) -> dict[str, Any]:
     """Aggregate BaseAgentRunEvent/BaseTeamRunEvent chunks into a complete response."""
     aggregated = {
         "content": "",
braintrust/wrappers/anthropic.py
CHANGED
@@ -2,7 +2,6 @@ import logging
 import time
 import warnings
 from contextlib import contextmanager
-from typing import Optional

 from braintrust.logger import NOOP_SPAN, log_exc_info_to_span, start_span
 from braintrust.wrappers._anthropic_utils import Wrapper, extract_anthropic_usage, finalize_anthropic_tokens
@@ -10,7 +9,6 @@ from braintrust.wrappers._anthropic_utils import Wrapper, extract_anthropic_usage, finalize_anthropic_tokens
 log = logging.getLogger(__name__)


-
 # This tracer depends on an internal anthropic method used to merge
 # streamed messages together. It's a bit tricky so I'm opting to use it
 # here. If it goes away, this polyfill will make it a no-op and the only
@@ -242,7 +240,7 @@ class TracedMessageStream(Wrapper):
         self.__metrics = {}
         self.__snapshot = None
         self.__request_start_time = request_start_time
-        self.__time_to_first_token: Optional[float] = None
+        self.__time_to_first_token: float | None = None

     def _get_final_traced_message(self):
         return self.__snapshot
@@ -314,7 +312,7 @@ def _start_span(name, kwargs):
         return NOOP_SPAN


-def _log_message_to_span(message, span, time_to_first_token: Optional[float] = None):
+def _log_message_to_span(message, span, time_to_first_token: float | None = None):
     """Log telemetry from the given anthropic.Message to the given span."""
     with _catch_exceptions():
         usage = getattr(message, "usage", {})
@@ -326,13 +324,14 @@ def _log_message_to_span(message, span, time_to_first_token: Optional[float] = None):

         # Create output dict with only truthy values for role and content
         output = {
-            k: v
-            for k, v in {
-                "role": getattr(message, "role", None), "content": getattr(message, "content", None)
-            }.items() if v
+            k: v
+            for k, v in {"role": getattr(message, "role", None), "content": getattr(message, "content", None)}.items()
+            if v
         } or None

         span.log(output=output, metrics=metrics)
+
+
 @contextmanager
 def _catch_exceptions():
     try:
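Note: the _log_message_to_span hunk above only reflows the truthy-filter comprehension; the idiom itself (keep only truthy fields, collapse an empty dict to None) works like this standalone sketch:

    message_fields = {"role": "assistant", "content": ""}
    output = {k: v for k, v in message_fields.items() if v} or None
    # "content" is falsy and is dropped, leaving {"role": "assistant"};
    # if every value were falsy, `{} or None` would yield None.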
braintrust/wrappers/claude_agent_sdk/__init__.py
CHANGED
@@ -16,7 +16,6 @@ Usage (imports can be before or after setup):
 """

 import logging
-from typing import Optional

 from braintrust.logger import NOOP_SPAN, current_span, init_logger

@@ -28,9 +27,9 @@ __all__ = ["setup_claude_agent_sdk"]


 def setup_claude_agent_sdk(
-    api_key: Optional[str] = None,
-    project_id: Optional[str] = None,
-    project: Optional[str] = None,
+    api_key: str | None = None,
+    project_id: str | None = None,
+    project: str | None = None,
 ) -> bool:
     """
     Setup Braintrust integration with Claude Agent SDK. Will automatically patch the SDK for automatic tracing.
braintrust/wrappers/claude_agent_sdk/_wrapper.py
CHANGED
@@ -2,7 +2,8 @@ import dataclasses
 import logging
 import threading
 import time
-from typing import Any, AsyncGenerator, Callable, Dict, List, Optional, Tuple
+from collections.abc import AsyncGenerator, Callable
+from typing import Any

 from braintrust.logger import start_span
 from braintrust.span_types import SpanTypeAttribute
@@ -108,12 +109,12 @@ def _wrap_tool_handler(handler: Any, tool_name: Any) -> Callable[..., Any]:
     so we try the context variable first, then fall back to current_span export.
     """
     # Check if already wrapped to prevent double-wrapping
-    if hasattr(handler, '_braintrust_wrapped'):
+    if hasattr(handler, "_braintrust_wrapped"):
        return handler

    async def wrapped_handler(args: Any) -> Any:
        # Get parent span export from thread-local storage
-        parent_export = getattr(_thread_local, 'parent_span_export', None)
+        parent_export = getattr(_thread_local, "parent_span_export", None)

        with start_span(
            name=str(tool_name),
@@ -144,11 +145,14 @@ def _create_client_wrapper_class(original_client_class: Any) -> Any:
         We end the previous span when the next AssistantMessage arrives, using the marked
         start time to ensure sequential timing (no overlapping LLM spans).
         """
-        def __init__(self, query_start_time: Optional[float] = None):
-            self.current_span: Optional[Any] = None
-            self.next_start_time: Optional[float] = query_start_time

-        def start_llm_span(self, message: Any, prompt: Any, conversation_history: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
+        def __init__(self, query_start_time: float | None = None):
+            self.current_span: Any | None = None
+            self.next_start_time: float | None = query_start_time
+
+        def start_llm_span(
+            self, message: Any, prompt: Any, conversation_history: list[dict[str, Any]]
+        ) -> dict[str, Any] | None:
             """Start a new LLM span, ending the previous one if it exists."""
             # Use the marked start time, or current time as fallback
             start_time = self.next_start_time if self.next_start_time is not None else time.time()
@@ -158,8 +162,7 @@ def _create_client_wrapper_class(original_client_class: Any) -> Any:
                 self.current_span.end(end_time=start_time)

             final_content, span = _create_llm_span_for_messages(
-                [message], prompt, conversation_history,
-                start_time=start_time
+                [message], prompt, conversation_history, start_time=start_time
             )
             self.current_span = span
             self.next_start_time = None  # Reset for next span
@@ -169,7 +172,7 @@ def _create_client_wrapper_class(original_client_class: Any) -> Any:
             """Mark when the next LLM call will start (after tool results)."""
             self.next_start_time = time.time()

-        def log_usage(self, usage_metrics: Dict[str, float]) -> None:
+        def log_usage(self, usage_metrics: dict[str, float]) -> None:
             """Log usage metrics to the current LLM span."""
             if self.current_span and usage_metrics:
                 self.current_span.log(metrics=usage_metrics)
@@ -186,8 +189,8 @@ def _create_client_wrapper_class(original_client_class: Any) -> Any:
             client = original_client_class(*args, **kwargs)
             super().__init__(client)
             self.__client = client
-            self.__last_prompt: Optional[str] = None
-            self.__query_start_time: Optional[float] = None
+            self.__last_prompt: str | None = None
+            self.__query_start_time: float | None = None

         async def query(self, *args: Any, **kwargs: Any) -> Any:
             """Wrap query to capture the prompt and start time for tracing."""
@@ -220,7 +223,7 @@ def _create_client_wrapper_class(original_client_class: Any) -> Any:
             # Store the parent span export in thread-local storage for tool handlers
             _thread_local.parent_span_export = span.export()

-            final_results: List[Dict[str, Any]] = []
+            final_results: list[dict[str, Any]] = []
             llm_tracker = LLMSpanTracker(query_start_time=self.__query_start_time)

             try:
@@ -243,10 +246,12 @@ def _create_client_wrapper_class(original_client_class: Any) -> Any:
                         llm_tracker.log_usage(usage_metrics)

                         result_metadata = {
-                            k: v for k, v in {
+                            k: v
+                            for k, v in {
                                 "num_turns": getattr(message, "num_turns", None),
                                 "session_id": getattr(message, "session_id", None),
-                            }.items() if v is not None
+                            }.items()
+                            if v is not None
                         }
                         if result_metadata:
                             span.log(metadata=result_metadata)
@@ -257,8 +262,8 @@ def _create_client_wrapper_class(original_client_class: Any) -> Any:
                 log.warning("Error in tracing code", exc_info=e)
             finally:
                 llm_tracker.cleanup()
-                if hasattr(_thread_local, 'parent_span_export'):
-                    delattr(_thread_local, 'parent_span_export')
+                if hasattr(_thread_local, "parent_span_export"):
+                    delattr(_thread_local, "parent_span_export")

         async def __aenter__(self) -> "WrappedClaudeSDKClient":
             await self.__client.__aenter__()
@@ -271,11 +276,11 @@ def _create_client_wrapper_class(original_client_class: Any) -> Any:


 def _create_llm_span_for_messages(
-    messages: List[Any],  # List of AssistantMessage objects
+    messages: list[Any],  # List of AssistantMessage objects
     prompt: Any,
-    conversation_history: List[Dict[str, Any]],
-    start_time: Optional[float] = None,
-) -> Tuple[Optional[Dict[str, Any]], Optional[Any]]:
+    conversation_history: list[dict[str, Any]],
+    start_time: float | None = None,
+) -> tuple[dict[str, Any] | None, Any | None]:
     """Creates an LLM span for a group of AssistantMessage objects.

     Returns a tuple of (final_content, span):
@@ -295,13 +300,12 @@ def _create_llm_span_for_messages(
     model = getattr(last_message, "model", None)
     input_messages = _build_llm_input(prompt, conversation_history)

-    outputs: List[Dict[str, Any]] = []
+    outputs: list[dict[str, Any]] = []
     for msg in messages:
         if hasattr(msg, "content"):
             content = _serialize_content_blocks(msg.content)
             outputs.append({"content": content, "role": "assistant"})

-
     llm_span = start_span(
         name="anthropic.messages.create",
         span_attributes={"type": SpanTypeAttribute.LLM},
@@ -355,7 +359,7 @@ def _serialize_content_blocks(content: Any) -> Any:
     return content


-def _extract_usage_from_result_message(result_message: Any) -> Dict[str, float]:
+def _extract_usage_from_result_message(result_message: Any) -> dict[str, float]:
     """Extracts and normalizes usage metrics from a ResultMessage.

     Uses shared Anthropic utilities for consistent metric extraction.
@@ -374,9 +378,7 @@ def _extract_usage_from_result_message(result_message: Any) -> Dict[str, float]:
     return metrics


-def _build_llm_input(
-    prompt: Any, conversation_history: List[Dict[str, Any]]
-) -> Optional[List[Dict[str, Any]]]:
+def _build_llm_input(prompt: Any, conversation_history: list[dict[str, Any]]) -> list[dict[str, Any]] | None:
     """Builds the input array for an LLM span from the initial prompt and conversation history.

     Formats input to match Anthropic messages API format for proper UI rendering.
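Note: the _wrapper.py changes keep the existing pattern of handing the parent span to tool handlers through thread-local storage (set before the query loop, cleaned up in finally). A self-contained sketch of that handoff, using a hypothetical handler in place of the SDK callbacks:

    import threading

    _thread_local = threading.local()

    def handle_tool_call() -> None:
        # Tool handlers read the parent export if one was stashed.
        parent = getattr(_thread_local, "parent_span_export", None)
        print(f"tool handler sees parent export: {parent!r}")

    def run_query(span_export: str) -> None:
        # Stash the parent span export where nested handlers can find it.
        _thread_local.parent_span_export = span_export
        try:
            handle_tool_call()
        finally:
            # Mirror the wrapper's cleanup: never leak state across queries.
            if hasattr(_thread_local, "parent_span_export"):
                delattr(_thread_local, "parent_span_export")

    run_query("example-export-token")  # hypothetical export string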
braintrust/wrappers/dspy.py
CHANGED
@@ -47,7 +47,7 @@ Advanced Usage with LiteLLM Patching:
     ```
 """

-from typing import Any, Dict, Optional
+from typing import Any

 from braintrust.logger import current_span, start_span
 from braintrust.span_types import SpanTypeAttribute
@@ -58,9 +58,7 @@ from braintrust.span_types import SpanTypeAttribute
 try:
     from dspy.utils.callback import BaseCallback
 except ImportError:
-    raise ImportError(
-        "DSPy is not installed. Please install it with: pip install dspy"
-    )
+    raise ImportError("DSPy is not installed. Please install it with: pip install dspy")


 class BraintrustDSpyCallback(BaseCallback):
@@ -130,13 +128,13 @@ class BraintrustDSpyCallback(BaseCallback):
         """Initialize the Braintrust DSPy callback handler."""
         super().__init__()
         # Map call_id to span objects for proper nesting
-        self._spans: Dict[str, Any] = {}
+        self._spans: dict[str, Any] = {}

     def on_lm_start(
         self,
         call_id: str,
         instance: Any,
-        inputs: Dict[str, Any],
+        inputs: dict[str, Any],
     ):
         """Log the start of a language model call.

@@ -174,8 +172,8 @@ class BraintrustDSpyCallback(BaseCallback):
     def on_lm_end(
         self,
         call_id: str,
-        outputs: Optional[Dict[str, Any]],
-        exception: Optional[Exception] = None,
+        outputs: dict[str, Any] | None,
+        exception: Exception | None = None,
     ):
         """Log the end of a language model call.

@@ -205,7 +203,7 @@ class BraintrustDSpyCallback(BaseCallback):
         self,
         call_id: str,
         instance: Any,
-        inputs: Dict[str, Any],
+        inputs: dict[str, Any],
     ):
         """Log the start of a DSPy module execution.

@@ -236,8 +234,8 @@ class BraintrustDSpyCallback(BaseCallback):
     def on_module_end(
         self,
         call_id: str,
-        outputs: Optional[Any],
-        exception: Optional[Exception] = None,
+        outputs: Any | None,
+        exception: Exception | None = None,
     ):
         """Log the end of a DSPy module execution.

@@ -274,7 +272,7 @@ class BraintrustDSpyCallback(BaseCallback):
         self,
         call_id: str,
         instance: Any,
-        inputs: Dict[str, Any],
+        inputs: dict[str, Any],
     ):
         """Log the start of a tool invocation.

@@ -309,8 +307,8 @@ class BraintrustDSpyCallback(BaseCallback):
     def on_tool_end(
         self,
         call_id: str,
-        outputs: Optional[Dict[str, Any]],
-        exception: Optional[Exception] = None,
+        outputs: dict[str, Any] | None,
+        exception: Exception | None = None,
     ):
         """Log the end of a tool invocation.

@@ -340,7 +338,7 @@ class BraintrustDSpyCallback(BaseCallback):
         self,
         call_id: str,
         instance: Any,
-        inputs: Dict[str, Any],
+        inputs: dict[str, Any],
     ):
         """Log the start of an evaluation run.

@@ -374,8 +372,8 @@ class BraintrustDSpyCallback(BaseCallback):
     def on_evaluate_end(
         self,
         call_id: str,
-        outputs: Optional[Any],
-        exception: Optional[Exception] = None,
+        outputs: Any | None,
+        exception: Exception | None = None,
     ):
         """Log the end of an evaluation run.

braintrust/wrappers/google_genai/__init__.py
CHANGED
@@ -1,19 +1,20 @@
 import logging
 import time
-from typing import Any, Dict, Iterable, List, Optional, Tuple
-
-from wrapt import wrap_function_wrapper
+from collections.abc import Iterable
+from typing import Any

+from braintrust.bt_json import bt_safe_deep_copy
 from braintrust.logger import NOOP_SPAN, Attachment, current_span, init_logger, start_span
 from braintrust.span_types import SpanTypeAttribute
+from wrapt import wrap_function_wrapper

 logger = logging.getLogger(__name__)


 def setup_genai(
-    api_key: Optional[str] = None,
-    project_id: Optional[str] = None,
-    project_name: Optional[str] = None,
+    api_key: str | None = None,
+    project_id: str | None = None,
+    project_name: str | None = None,
 ):
     span = current_span()
     if span == NOOP_SPAN:
@@ -148,8 +149,8 @@ def wrap_async_models(AsyncModels: Any):
     return AsyncModels


-def _serialize_input(api_client: Any, input: Dict[str, Any]):
-    config = _try_dict(input.get("config"))
+def _serialize_input(api_client: Any, input: dict[str, Any]):
+    config = bt_safe_deep_copy(input.get("config"))

     if config is not None:
         tools = _serialize_tools(api_client, input)
@@ -223,7 +224,7 @@ def _serialize_content_item(item: Any) -> Any:
         return item


-def _serialize_tools(api_client: Any, input: Optional[Any]):
+def _serialize_tools(api_client: Any, input: Any | None):
     try:
         from google.genai.models import (
             _GenerateContentParameters_to_mldev,  # pyright: ignore [reportPrivateUsage]
@@ -242,7 +243,7 @@ def _serialize_tools(api_client: Any, input: Optional[Any]):
         return None


-def omit(obj: Dict[str, Any], keys: Iterable[str]):
+def omit(obj: dict[str, Any], keys: Iterable[str]):
     return {k: v for k, v in obj.items() if k not in keys}


@@ -254,11 +255,11 @@ def mark_patched(obj: Any):
     return setattr(obj, "_braintrust_patched", True)


-def get_args_kwargs(args: List[str], kwargs: Dict[str, Any], keys: Iterable[str]):
+def get_args_kwargs(args: list[str], kwargs: dict[str, Any], keys: Iterable[str]):
     return {k: args[i] if args else kwargs.get(k) for i, k in enumerate(keys)}, omit(kwargs, keys)


-def _extract_generate_content_metrics(response: Any, start: float) -> Dict[str, Any]:
+def _extract_generate_content_metrics(response: Any, start: float) -> dict[str, Any]:
     """Extract metrics from a non-streaming generate_content response."""
     end_time = time.time()
     metrics = dict(
@@ -297,8 +298,8 @@ def _extract_generate_content_metrics(response: Any, start: float) -> Dict[str, Any]:


 def _aggregate_generate_content_chunks(
-    chunks: List[Any], start: float, first_token_time: Optional[float] = None
-) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+    chunks: list[Any], start: float, first_token_time: float | None = None
+) -> tuple[dict[str, Any], dict[str, Any]]:
     """Aggregate streaming chunks into a single response with metrics."""
     end_time = time.time()
     metrics = dict(
@@ -410,11 +411,11 @@ def _aggregate_generate_content_chunks(
     return aggregated, clean_metrics


-def clean(obj: Dict[str, Any]) -> Dict[str, Any]:
+def clean(obj: dict[str, Any]) -> dict[str, Any]:
     return {k: v for k, v in obj.items() if v is not None}


-def get_path(obj: Dict[str, Any], path: str, default: Any = None) -> Optional[Any]:
+def get_path(obj: dict[str, Any], path: str, default: Any = None) -> Any | None:
     keys = path.split(".")
     current = obj

@@ -424,17 +425,3 @@ def get_path(obj: Dict[str, Any], path: str, default: Any = None) -> Optional[Any]:
             current = current[key]

     return current
-
-
-def _try_dict(obj: Any) -> Optional[Dict[str, Any]]:
-    try:
-        return obj.model_dump()
-    except AttributeError:
-        pass
-
-    try:
-        return obj.dump()
-    except AttributeError:
-        pass
-
-    return obj
braintrust/wrappers/langchain.py
CHANGED
@@ -1,6 +1,6 @@
 import contextvars
 import logging
-from typing import Any, Dict, List, Optional
+from typing import Any
 from uuid import UUID

 import braintrust
@@ -30,7 +30,7 @@ class BraintrustTracer(BaseCallbackHandler):
         self.logger = logger
         self.spans = {}

-    def _start_span(self, parent_run_id, run_id, name: Optional[str], **kwargs: Any) -> Any:
+    def _start_span(self, parent_run_id, run_id, name: str | None, **kwargs: Any) -> Any:
         assert run_id not in self.spans, f"Span already exists for run_id {run_id} (this is likely a bug)"

         current_parent = langchain_parent.get()
@@ -60,29 +60,29 @@ class BraintrustTracer(BaseCallbackHandler):

     def on_chain_start(
         self,
-        serialized: Dict[str, Any],
-        inputs: Dict[str, Any],
+        serialized: dict[str, Any],
+        inputs: dict[str, Any],
         *,
         run_id: UUID,
-        parent_run_id: Optional[UUID] = None,
-        tags: Optional[List[str]] = None,
+        parent_run_id: UUID | None = None,
+        tags: list[str] | None = None,
         **kwargs: Any,
     ) -> Any:
         self._start_span(parent_run_id, run_id, "Chain", input=inputs, metadata={"tags": tags})

     def on_chain_end(
-        self, outputs: Dict[str, Any], *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any
+        self, outputs: dict[str, Any], *, run_id: UUID, parent_run_id: UUID | None = None, **kwargs: Any
     ) -> Any:
         self._end_span(run_id, output=outputs)

     def on_llm_start(
         self,
-        serialized: Dict[str, Any],
-        prompts: List[str],
+        serialized: dict[str, Any],
+        prompts: list[str],
         *,
         run_id: UUID,
-        parent_run_id: Optional[UUID] = None,
-        tags: Optional[List[str]] = None,
+        parent_run_id: UUID | None = None,
+        tags: list[str] | None = None,
         **kwargs: Any,
     ) -> Any:
         self._start_span(
@@ -95,12 +95,12 @@ class BraintrustTracer(BaseCallbackHandler):

     def on_chat_model_start(
         self,
-        serialized: Dict[str, Any],
-        messages: List[List[BaseMessage]],
+        serialized: dict[str, Any],
+        messages: list[list[BaseMessage]],
         *,
         run_id: UUID,
-        parent_run_id: Optional[UUID] = None,
-        tags: Optional[List[str]] = None,
+        parent_run_id: UUID | None = None,
+        tags: list[str] | None = None,
         **kwargs: Any,
     ) -> Any:
         self._start_span(
@@ -112,7 +112,7 @@ class BraintrustTracer(BaseCallbackHandler):
         )

     def on_llm_end(
-        self, response: LLMResult, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any
+        self, response: LLMResult, *, run_id: UUID, parent_run_id: UUID | None = None, **kwargs: Any
     ) -> Any:
         metrics = {}
         token_usage = response.llm_output.get("token_usage", {})
@@ -127,25 +127,23 @@ class BraintrustTracer(BaseCallbackHandler):

     def on_tool_start(
         self,
-        serialized: Dict[str, Any],
+        serialized: dict[str, Any],
         input_str: str,
         *,
         run_id: UUID,
-        parent_run_id: Optional[UUID] = None,
-        tags: Optional[List[str]] = None,
+        parent_run_id: UUID | None = None,
+        tags: list[str] | None = None,
         **kwargs: Any,
     ) -> Any:
         _logger.warning("Starting tool, but it will not be traced in braintrust (unsupported)")

-    def on_tool_end(self, output: str, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any) -> Any:
+    def on_tool_end(self, output: str, *, run_id: UUID, parent_run_id: UUID | None = None, **kwargs: Any) -> Any:
         pass

-    def on_retriever_start(
-        self, query: str, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any
-    ) -> Any:
+    def on_retriever_start(self, query: str, *, run_id: UUID, parent_run_id: UUID | None = None, **kwargs: Any) -> Any:
         _logger.warning("Starting retriever, but it will not be traced in braintrust (unsupported)")

     def on_retriever_end(
-        self, response: List[Document], *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any
+        self, response: list[Document], *, run_id: UUID, parent_run_id: UUID | None = None, **kwargs: Any
     ) -> Any:
         pass
braintrust/wrappers/litellm.py
CHANGED
@@ -1,9 +1,9 @@
 from __future__ import annotations

 import time
-from collections.abc import AsyncGenerator, Generator
+from collections.abc import AsyncGenerator, Callable, Generator
 from types import TracebackType
-from typing import Any, Callable
+from typing import Any

 from braintrust.logger import Span, start_span
 from braintrust.span_types import SpanTypeAttribute
@@ -655,7 +655,8 @@ def patch_litellm():
     """
     try:
         import litellm
-        if not hasattr(litellm, '_braintrust_wrapped'):
+
+        if not hasattr(litellm, "_braintrust_wrapped"):
             wrapped = wrap_litellm(litellm)
             litellm.completion = wrapped.completion
             litellm.acompletion = wrapped.acompletion