braintrust 0.3.15__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff shows the changes between publicly released versions of the package, as published to its public registry. It is provided for informational purposes only.
Files changed (82)
  1. braintrust/_generated_types.py +737 -672
  2. braintrust/audit.py +2 -2
  3. braintrust/bt_json.py +178 -19
  4. braintrust/cli/eval.py +6 -7
  5. braintrust/cli/push.py +11 -11
  6. braintrust/context.py +12 -17
  7. braintrust/contrib/temporal/__init__.py +16 -27
  8. braintrust/contrib/temporal/test_temporal.py +8 -3
  9. braintrust/devserver/auth.py +8 -8
  10. braintrust/devserver/cache.py +3 -4
  11. braintrust/devserver/cors.py +8 -7
  12. braintrust/devserver/dataset.py +3 -5
  13. braintrust/devserver/eval_hooks.py +7 -6
  14. braintrust/devserver/schemas.py +22 -19
  15. braintrust/devserver/server.py +19 -12
  16. braintrust/devserver/test_cached_login.py +4 -4
  17. braintrust/framework.py +139 -142
  18. braintrust/framework2.py +88 -87
  19. braintrust/functions/invoke.py +66 -59
  20. braintrust/functions/stream.py +3 -2
  21. braintrust/generated_types.py +3 -1
  22. braintrust/git_fields.py +11 -11
  23. braintrust/gitutil.py +2 -3
  24. braintrust/graph_util.py +10 -10
  25. braintrust/id_gen.py +2 -2
  26. braintrust/logger.py +373 -471
  27. braintrust/merge_row_batch.py +10 -9
  28. braintrust/oai.py +21 -20
  29. braintrust/otel/__init__.py +49 -49
  30. braintrust/otel/context.py +16 -30
  31. braintrust/otel/test_distributed_tracing.py +14 -11
  32. braintrust/otel/test_otel_bt_integration.py +32 -31
  33. braintrust/parameters.py +8 -8
  34. braintrust/prompt.py +14 -14
  35. braintrust/prompt_cache/disk_cache.py +5 -4
  36. braintrust/prompt_cache/lru_cache.py +3 -2
  37. braintrust/prompt_cache/prompt_cache.py +13 -14
  38. braintrust/queue.py +4 -4
  39. braintrust/score.py +4 -4
  40. braintrust/serializable_data_class.py +4 -4
  41. braintrust/span_identifier_v1.py +1 -2
  42. braintrust/span_identifier_v2.py +3 -4
  43. braintrust/span_identifier_v3.py +23 -20
  44. braintrust/span_identifier_v4.py +34 -25
  45. braintrust/test_bt_json.py +644 -0
  46. braintrust/test_framework.py +72 -6
  47. braintrust/test_helpers.py +5 -5
  48. braintrust/test_id_gen.py +2 -3
  49. braintrust/test_logger.py +211 -107
  50. braintrust/test_otel.py +61 -53
  51. braintrust/test_queue.py +0 -1
  52. braintrust/test_score.py +1 -3
  53. braintrust/test_span_components.py +29 -44
  54. braintrust/util.py +9 -8
  55. braintrust/version.py +2 -2
  56. braintrust/wrappers/_anthropic_utils.py +4 -4
  57. braintrust/wrappers/agno/__init__.py +3 -4
  58. braintrust/wrappers/agno/agent.py +1 -2
  59. braintrust/wrappers/agno/function_call.py +1 -2
  60. braintrust/wrappers/agno/model.py +1 -2
  61. braintrust/wrappers/agno/team.py +1 -2
  62. braintrust/wrappers/agno/utils.py +12 -12
  63. braintrust/wrappers/anthropic.py +7 -8
  64. braintrust/wrappers/claude_agent_sdk/__init__.py +3 -4
  65. braintrust/wrappers/claude_agent_sdk/_wrapper.py +29 -27
  66. braintrust/wrappers/dspy.py +15 -17
  67. braintrust/wrappers/google_genai/__init__.py +17 -30
  68. braintrust/wrappers/langchain.py +22 -24
  69. braintrust/wrappers/litellm.py +4 -3
  70. braintrust/wrappers/openai.py +15 -15
  71. braintrust/wrappers/pydantic_ai.py +225 -110
  72. braintrust/wrappers/test_agno.py +0 -1
  73. braintrust/wrappers/test_dspy.py +0 -1
  74. braintrust/wrappers/test_google_genai.py +64 -4
  75. braintrust/wrappers/test_litellm.py +0 -1
  76. braintrust/wrappers/test_pydantic_ai_integration.py +819 -22
  77. {braintrust-0.3.15.dist-info → braintrust-0.4.1.dist-info}/METADATA +3 -2
  78. braintrust-0.4.1.dist-info/RECORD +121 -0
  79. braintrust-0.3.15.dist-info/RECORD +0 -120
  80. {braintrust-0.3.15.dist-info → braintrust-0.4.1.dist-info}/WHEEL +0 -0
  81. {braintrust-0.3.15.dist-info → braintrust-0.4.1.dist-info}/entry_points.txt +0 -0
  82. {braintrust-0.3.15.dist-info → braintrust-0.4.1.dist-info}/top_level.txt +0 -0
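Most of the wrapper-module hunks below are the same mechanical migration: typing.Dict, typing.List, typing.Optional, and typing.Tuple are replaced with builtin generics (PEP 585) and | unions (PEP 604). A minimal before/after sketch of the pattern (illustrative only; the function is hypothetical, not code from the package):

```python
from typing import Any


# 0.3.15 style (requires typing.Dict/List/Optional imports):
#   def extract(usage: Any, keys: Optional[List[str]] = None) -> Dict[str, Any]: ...

# 0.4.1 style: builtin generics and PEP 604 unions, no extra typing imports needed.
def extract(usage: Any, keys: list[str] | None = None) -> dict[str, Any]:
    """Hypothetical helper shown only to illustrate the annotation change."""
    return {k: getattr(usage, k, None) for k in (keys or [])}
```

The `X | None` syntax needs Python 3.10+ at runtime (or `from __future__ import annotations` where only annotations are involved), and `dict[...]`/`list[...]` subscripting needs 3.9+.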
braintrust/wrappers/agno/utils.py
@@ -1,8 +1,8 @@
  import time
- from typing import Any, Dict, List, Optional
+ from typing import Any


- def omit(obj: Dict[str, Any], keys: List[str]):
+ def omit(obj: dict[str, Any], keys: list[str]):
  return {k: v for k, v in obj.items() if k not in keys}


@@ -14,11 +14,11 @@ def mark_patched(obj: Any):
  setattr(obj, "_braintrust_patched", True)


- def clean(obj: Dict[str, Any]) -> Dict[str, Any]:
+ def clean(obj: dict[str, Any]) -> dict[str, Any]:
  return {k: v for k, v in obj.items() if v is not None}


- def get_args_kwargs(args: List[str], kwargs: Dict[str, Any], keys: List[str]):
+ def get_args_kwargs(args: list[str], kwargs: dict[str, Any], keys: list[str]):
  return {k: args[i] if args else kwargs.get(k) for i, k in enumerate(keys)}, omit(kwargs, keys)


@@ -71,7 +71,7 @@ AGNO_METRICS_MAP = {
  }


- def extract_metadata(instance: Any, component: str) -> Dict[str, Any]:
+ def extract_metadata(instance: Any, component: str) -> dict[str, Any]:
  """Extract metadata from any component (model, agent, team)."""
  metadata = {"component": component}

@@ -100,7 +100,7 @@ def extract_metadata(instance: Any, component: str) -> Dict[str, Any]:
  return metadata


- def parse_metrics_from_agno(usage: Any) -> Dict[str, Any]:
+ def parse_metrics_from_agno(usage: Any) -> dict[str, Any]:
  """Parse metrics from Agno usage object, following OpenAI wrapper pattern."""
  metrics = {}

@@ -121,7 +121,7 @@ def parse_metrics_from_agno(usage: Any) -> Dict[str, Any]:
  return metrics


- def extract_metrics(result: Any, messages: Optional[list] = None) -> Dict[str, Any]:
+ def extract_metrics(result: Any, messages: list | None = None) -> dict[str, Any]:
  """
  Unified metrics extraction for all components.

@@ -163,7 +163,7 @@ def extract_metrics(result: Any, messages: Optional[list] = None) -> Dict[str, A
  return {}


- def extract_streaming_metrics(aggregated: Dict[str, Any], start_time: float) -> Optional[Dict[str, Any]]:
+ def extract_streaming_metrics(aggregated: dict[str, Any], start_time: float) -> dict[str, Any] | None:
  """Extract metrics from aggregated streaming response."""
  metrics = {}

@@ -187,7 +187,7 @@ def extract_streaming_metrics(aggregated: Dict[str, Any], start_time: float) ->
  return metrics if metrics else None


- def _aggregate_metrics(target: Dict[str, Any], source: Dict[str, Any]) -> None:
+ def _aggregate_metrics(target: dict[str, Any], source: dict[str, Any]) -> None:
  """Aggregate metrics from source into target dict."""
  for key, value in source.items():
  if _is_numeric(value):
@@ -205,7 +205,7 @@ def _aggregate_metrics(target: Dict[str, Any], source: Dict[str, Any]) -> None:
  target[key] = value


- def _aggregate_model_chunks(chunks: List[Any]) -> Dict[str, Any]:
+ def _aggregate_model_chunks(chunks: list[Any]) -> dict[str, Any]:
  """Aggregate ModelResponse chunks from invoke_stream into a complete response."""
  aggregated = {
  "content": "",
@@ -263,7 +263,7 @@ def _aggregate_model_chunks(chunks: List[Any]) -> Dict[str, Any]:
  return aggregated


- def _aggregate_response_stream_chunks(chunks: List[Any]) -> Dict[str, Any]:
+ def _aggregate_response_stream_chunks(chunks: list[Any]) -> dict[str, Any]:
  """
  Aggregate chunks from response_stream which can be ModelResponse, RunOutputEvent, or TeamRunOutputEvent.

@@ -344,7 +344,7 @@ def _aggregate_response_stream_chunks(chunks: List[Any]) -> Dict[str, Any]:
  return aggregated


- def _aggregate_agent_chunks(chunks: List[Any]) -> Dict[str, Any]:
+ def _aggregate_agent_chunks(chunks: list[Any]) -> dict[str, Any]:
  """Aggregate BaseAgentRunEvent/BaseTeamRunEvent chunks into a complete response."""
  aggregated = {
  "content": "",
braintrust/wrappers/anthropic.py
@@ -2,7 +2,6 @@ import logging
  import time
  import warnings
  from contextlib import contextmanager
- from typing import Optional

  from braintrust.logger import NOOP_SPAN, log_exc_info_to_span, start_span
  from braintrust.wrappers._anthropic_utils import Wrapper, extract_anthropic_usage, finalize_anthropic_tokens
@@ -10,7 +9,6 @@ from braintrust.wrappers._anthropic_utils import Wrapper, extract_anthropic_usag
  log = logging.getLogger(__name__)


-
  # This tracer depends on an internal anthropic method used to merge
  # streamed messages together. It's a bit tricky so I'm opting to use it
  # here. If it goes away, this polyfill will make it a no-op and the only
@@ -242,7 +240,7 @@ class TracedMessageStream(Wrapper):
  self.__metrics = {}
  self.__snapshot = None
  self.__request_start_time = request_start_time
- self.__time_to_first_token: Optional[float] = None
+ self.__time_to_first_token: float | None = None

  def _get_final_traced_message(self):
  return self.__snapshot
@@ -314,7 +312,7 @@ def _start_span(name, kwargs):
  return NOOP_SPAN


- def _log_message_to_span(message, span, time_to_first_token: Optional[float] = None):
+ def _log_message_to_span(message, span, time_to_first_token: float | None = None):
  """Log telemetry from the given anthropic.Message to the given span."""
  with _catch_exceptions():
  usage = getattr(message, "usage", {})
@@ -326,13 +324,14 @@ def _log_message_to_span(message, span, time_to_first_token: Optional[float] = N

  # Create output dict with only truthy values for role and content
  output = {
- k: v for k, v in {
- "role": getattr(message, "role", None),
- "content": getattr(message, "content", None)
- }.items() if v
+ k: v
+ for k, v in {"role": getattr(message, "role", None), "content": getattr(message, "content", None)}.items()
+ if v
  } or None

  span.log(output=output, metrics=metrics)
+
+
  @contextmanager
  def _catch_exceptions():
  try:
braintrust/wrappers/claude_agent_sdk/__init__.py
@@ -16,7 +16,6 @@ Usage (imports can be before or after setup):
  """

  import logging
- from typing import Optional

  from braintrust.logger import NOOP_SPAN, current_span, init_logger

@@ -28,9 +27,9 @@ __all__ = ["setup_claude_agent_sdk"]


  def setup_claude_agent_sdk(
- api_key: Optional[str] = None,
- project_id: Optional[str] = None,
- project: Optional[str] = None,
+ api_key: str | None = None,
+ project_id: str | None = None,
+ project: str | None = None,
  ) -> bool:
  """
  Setup Braintrust integration with Claude Agent SDK. Will automatically patch the SDK for automatic tracing.
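Based on the signature above, a minimal usage sketch; the project name is a placeholder and the behavior of the default arguments is not shown in this diff:

```python
from braintrust.wrappers.claude_agent_sdk import setup_claude_agent_sdk

# Patch the Claude Agent SDK once; per the docstring, imports can come before or after setup.
# Returns a bool (per the -> bool annotation above) indicating whether setup completed.
ok = setup_claude_agent_sdk(project="my-agent-project")  # "my-agent-project" is hypothetical
```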
braintrust/wrappers/claude_agent_sdk/_wrapper.py
@@ -2,7 +2,8 @@ import dataclasses
  import logging
  import threading
  import time
- from typing import Any, AsyncGenerator, Callable, Dict, List, Optional, Tuple
+ from collections.abc import AsyncGenerator, Callable
+ from typing import Any

  from braintrust.logger import start_span
  from braintrust.span_types import SpanTypeAttribute
@@ -108,12 +109,12 @@ def _wrap_tool_handler(handler: Any, tool_name: Any) -> Callable[..., Any]:
  so we try the context variable first, then fall back to current_span export.
  """
  # Check if already wrapped to prevent double-wrapping
- if hasattr(handler, '_braintrust_wrapped'):
+ if hasattr(handler, "_braintrust_wrapped"):
  return handler

  async def wrapped_handler(args: Any) -> Any:
  # Get parent span export from thread-local storage
- parent_export = getattr(_thread_local, 'parent_span_export', None)
+ parent_export = getattr(_thread_local, "parent_span_export", None)

  with start_span(
  name=str(tool_name),
@@ -144,11 +145,14 @@ def _create_client_wrapper_class(original_client_class: Any) -> Any:
  We end the previous span when the next AssistantMessage arrives, using the marked
  start time to ensure sequential timing (no overlapping LLM spans).
  """
- def __init__(self, query_start_time: Optional[float] = None):
- self.current_span: Optional[Any] = None
- self.next_start_time: Optional[float] = query_start_time

- def start_llm_span(self, message: Any, prompt: Any, conversation_history: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
+ def __init__(self, query_start_time: float | None = None):
+ self.current_span: Any | None = None
+ self.next_start_time: float | None = query_start_time
+
+ def start_llm_span(
+ self, message: Any, prompt: Any, conversation_history: list[dict[str, Any]]
+ ) -> dict[str, Any] | None:
  """Start a new LLM span, ending the previous one if it exists."""
  # Use the marked start time, or current time as fallback
  start_time = self.next_start_time if self.next_start_time is not None else time.time()
@@ -158,8 +162,7 @@ def _create_client_wrapper_class(original_client_class: Any) -> Any:
  self.current_span.end(end_time=start_time)

  final_content, span = _create_llm_span_for_messages(
- [message], prompt, conversation_history,
- start_time=start_time
+ [message], prompt, conversation_history, start_time=start_time
  )
  self.current_span = span
  self.next_start_time = None # Reset for next span
@@ -169,7 +172,7 @@ def _create_client_wrapper_class(original_client_class: Any) -> Any:
  """Mark when the next LLM call will start (after tool results)."""
  self.next_start_time = time.time()

- def log_usage(self, usage_metrics: Dict[str, float]) -> None:
+ def log_usage(self, usage_metrics: dict[str, float]) -> None:
  """Log usage metrics to the current LLM span."""
  if self.current_span and usage_metrics:
  self.current_span.log(metrics=usage_metrics)
@@ -186,8 +189,8 @@ def _create_client_wrapper_class(original_client_class: Any) -> Any:
  client = original_client_class(*args, **kwargs)
  super().__init__(client)
  self.__client = client
- self.__last_prompt: Optional[str] = None
- self.__query_start_time: Optional[float] = None
+ self.__last_prompt: str | None = None
+ self.__query_start_time: float | None = None

  async def query(self, *args: Any, **kwargs: Any) -> Any:
  """Wrap query to capture the prompt and start time for tracing."""
@@ -220,7 +223,7 @@ def _create_client_wrapper_class(original_client_class: Any) -> Any:
  # Store the parent span export in thread-local storage for tool handlers
  _thread_local.parent_span_export = span.export()

- final_results: List[Dict[str, Any]] = []
+ final_results: list[dict[str, Any]] = []
  llm_tracker = LLMSpanTracker(query_start_time=self.__query_start_time)

  try:
@@ -243,10 +246,12 @@ def _create_client_wrapper_class(original_client_class: Any) -> Any:
  llm_tracker.log_usage(usage_metrics)

  result_metadata = {
- k: v for k, v in {
+ k: v
+ for k, v in {
  "num_turns": getattr(message, "num_turns", None),
  "session_id": getattr(message, "session_id", None),
- }.items() if v is not None
+ }.items()
+ if v is not None
  }
  if result_metadata:
  span.log(metadata=result_metadata)
@@ -257,8 +262,8 @@ def _create_client_wrapper_class(original_client_class: Any) -> Any:
  log.warning("Error in tracing code", exc_info=e)
  finally:
  llm_tracker.cleanup()
- if hasattr(_thread_local, 'parent_span_export'):
- delattr(_thread_local, 'parent_span_export')
+ if hasattr(_thread_local, "parent_span_export"):
+ delattr(_thread_local, "parent_span_export")

  async def __aenter__(self) -> "WrappedClaudeSDKClient":
  await self.__client.__aenter__()
@@ -271,11 +276,11 @@ def _create_client_wrapper_class(original_client_class: Any) -> Any:


  def _create_llm_span_for_messages(
- messages: List[Any], # List of AssistantMessage objects
+ messages: list[Any], # List of AssistantMessage objects
  prompt: Any,
- conversation_history: List[Dict[str, Any]],
- start_time: Optional[float] = None,
- ) -> Tuple[Optional[Dict[str, Any]], Optional[Any]]:
+ conversation_history: list[dict[str, Any]],
+ start_time: float | None = None,
+ ) -> tuple[dict[str, Any] | None, Any | None]:
  """Creates an LLM span for a group of AssistantMessage objects.

  Returns a tuple of (final_content, span):
@@ -295,13 +300,12 @@ def _create_llm_span_for_messages(
  model = getattr(last_message, "model", None)
  input_messages = _build_llm_input(prompt, conversation_history)

- outputs: List[Dict[str, Any]] = []
+ outputs: list[dict[str, Any]] = []
  for msg in messages:
  if hasattr(msg, "content"):
  content = _serialize_content_blocks(msg.content)
  outputs.append({"content": content, "role": "assistant"})

-
  llm_span = start_span(
  name="anthropic.messages.create",
  span_attributes={"type": SpanTypeAttribute.LLM},
@@ -355,7 +359,7 @@ def _serialize_content_blocks(content: Any) -> Any:
  return content


- def _extract_usage_from_result_message(result_message: Any) -> Dict[str, float]:
+ def _extract_usage_from_result_message(result_message: Any) -> dict[str, float]:
  """Extracts and normalizes usage metrics from a ResultMessage.

  Uses shared Anthropic utilities for consistent metric extraction.
@@ -374,9 +378,7 @@ def _extract_usage_from_result_message(result_message: Any) -> Dict[str, float]:
  return metrics


- def _build_llm_input(
- prompt: Any, conversation_history: List[Dict[str, Any]]
- ) -> Optional[List[Dict[str, Any]]]:
+ def _build_llm_input(prompt: Any, conversation_history: list[dict[str, Any]]) -> list[dict[str, Any]] | None:
  """Builds the input array for an LLM span from the initial prompt and conversation history.

  Formats input to match Anthropic messages API format for proper UI rendering.
braintrust/wrappers/dspy.py
@@ -47,7 +47,7 @@ Advanced Usage with LiteLLM Patching:
  ```
  """

- from typing import Any, Dict, Optional
+ from typing import Any

  from braintrust.logger import current_span, start_span
  from braintrust.span_types import SpanTypeAttribute
@@ -58,9 +58,7 @@ from braintrust.span_types import SpanTypeAttribute
  try:
  from dspy.utils.callback import BaseCallback
  except ImportError:
- raise ImportError(
- "DSPy is not installed. Please install it with: pip install dspy"
- )
+ raise ImportError("DSPy is not installed. Please install it with: pip install dspy")


  class BraintrustDSpyCallback(BaseCallback):
@@ -130,13 +128,13 @@ class BraintrustDSpyCallback(BaseCallback):
  """Initialize the Braintrust DSPy callback handler."""
  super().__init__()
  # Map call_id to span objects for proper nesting
- self._spans: Dict[str, Any] = {}
+ self._spans: dict[str, Any] = {}

  def on_lm_start(
  self,
  call_id: str,
  instance: Any,
- inputs: Dict[str, Any],
+ inputs: dict[str, Any],
  ):
  """Log the start of a language model call.

@@ -174,8 +172,8 @@ class BraintrustDSpyCallback(BaseCallback):
  def on_lm_end(
  self,
  call_id: str,
- outputs: Optional[Dict[str, Any]],
- exception: Optional[Exception] = None,
+ outputs: dict[str, Any] | None,
+ exception: Exception | None = None,
  ):
  """Log the end of a language model call.

@@ -205,7 +203,7 @@ class BraintrustDSpyCallback(BaseCallback):
  self,
  call_id: str,
  instance: Any,
- inputs: Dict[str, Any],
+ inputs: dict[str, Any],
  ):
  """Log the start of a DSPy module execution.

@@ -236,8 +234,8 @@ class BraintrustDSpyCallback(BaseCallback):
  def on_module_end(
  self,
  call_id: str,
- outputs: Optional[Any],
- exception: Optional[Exception] = None,
+ outputs: Any | None,
+ exception: Exception | None = None,
  ):
  """Log the end of a DSPy module execution.

@@ -274,7 +272,7 @@ class BraintrustDSpyCallback(BaseCallback):
  self,
  call_id: str,
  instance: Any,
- inputs: Dict[str, Any],
+ inputs: dict[str, Any],
  ):
  """Log the start of a tool invocation.

@@ -309,8 +307,8 @@ class BraintrustDSpyCallback(BaseCallback):
  def on_tool_end(
  self,
  call_id: str,
- outputs: Optional[Dict[str, Any]],
- exception: Optional[Exception] = None,
+ outputs: dict[str, Any] | None,
+ exception: Exception | None = None,
  ):
  """Log the end of a tool invocation.

@@ -340,7 +338,7 @@ class BraintrustDSpyCallback(BaseCallback):
  self,
  call_id: str,
  instance: Any,
- inputs: Dict[str, Any],
+ inputs: dict[str, Any],
  ):
  """Log the start of an evaluation run.

@@ -374,8 +372,8 @@ class BraintrustDSpyCallback(BaseCallback):
  def on_evaluate_end(
  self,
  call_id: str,
- outputs: Optional[Any],
- exception: Optional[Exception] = None,
+ outputs: Any | None,
+ exception: Exception | None = None,
  ):
  """Log the end of an evaluation run.

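The hunks above only touch annotations; the callback itself is used by registering it with DSPy's callback mechanism. A hedged sketch — the dspy.configure(callbacks=[...]) registration hook is an assumption about the DSPy API, not something shown in this diff:

```python
import dspy  # the wrapper raises ImportError if DSPy is not installed

from braintrust.wrappers.dspy import BraintrustDSpyCallback

# Assumption: DSPy picks up callbacks from its global settings; check the DSPy
# docs for the exact registration hook in your DSPy version.
dspy.configure(callbacks=[BraintrustDSpyCallback()])
```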
braintrust/wrappers/google_genai/__init__.py
@@ -1,19 +1,20 @@
  import logging
  import time
- from typing import Any, Dict, Iterable, List, Optional, Tuple
-
- from wrapt import wrap_function_wrapper
+ from collections.abc import Iterable
+ from typing import Any

+ from braintrust.bt_json import bt_safe_deep_copy
  from braintrust.logger import NOOP_SPAN, Attachment, current_span, init_logger, start_span
  from braintrust.span_types import SpanTypeAttribute
+ from wrapt import wrap_function_wrapper

  logger = logging.getLogger(__name__)


  def setup_genai(
- api_key: Optional[str] = None,
- project_id: Optional[str] = None,
- project_name: Optional[str] = None,
+ api_key: str | None = None,
+ project_id: str | None = None,
+ project_name: str | None = None,
  ):
  span = current_span()
  if span == NOOP_SPAN:
@@ -148,8 +149,8 @@ def wrap_async_models(AsyncModels: Any):
  return AsyncModels


- def _serialize_input(api_client: Any, input: Dict[str, Any]):
- config = _try_dict(input.get("config"))
+ def _serialize_input(api_client: Any, input: dict[str, Any]):
+ config = bt_safe_deep_copy(input.get("config"))

  if config is not None:
  tools = _serialize_tools(api_client, input)
@@ -223,7 +224,7 @@ def _serialize_content_item(item: Any) -> Any:
  return item


- def _serialize_tools(api_client: Any, input: Optional[Any]):
+ def _serialize_tools(api_client: Any, input: Any | None):
  try:
  from google.genai.models import (
  _GenerateContentParameters_to_mldev, # pyright: ignore [reportPrivateUsage]
@@ -242,7 +243,7 @@ def _serialize_tools(api_client: Any, input: Optional[Any]):
  return None


- def omit(obj: Dict[str, Any], keys: Iterable[str]):
+ def omit(obj: dict[str, Any], keys: Iterable[str]):
  return {k: v for k, v in obj.items() if k not in keys}


@@ -254,11 +255,11 @@ def mark_patched(obj: Any):
  return setattr(obj, "_braintrust_patched", True)


- def get_args_kwargs(args: List[str], kwargs: Dict[str, Any], keys: Iterable[str]):
+ def get_args_kwargs(args: list[str], kwargs: dict[str, Any], keys: Iterable[str]):
  return {k: args[i] if args else kwargs.get(k) for i, k in enumerate(keys)}, omit(kwargs, keys)


- def _extract_generate_content_metrics(response: Any, start: float) -> Dict[str, Any]:
+ def _extract_generate_content_metrics(response: Any, start: float) -> dict[str, Any]:
  """Extract metrics from a non-streaming generate_content response."""
  end_time = time.time()
  metrics = dict(
@@ -297,8 +298,8 @@ def _extract_generate_content_metrics(response: Any, start: float) -> Dict[str,


  def _aggregate_generate_content_chunks(
- chunks: List[Any], start: float, first_token_time: Optional[float] = None
- ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+ chunks: list[Any], start: float, first_token_time: float | None = None
+ ) -> tuple[dict[str, Any], dict[str, Any]]:
  """Aggregate streaming chunks into a single response with metrics."""
  end_time = time.time()
  metrics = dict(
@@ -410,11 +411,11 @@ def _aggregate_generate_content_chunks(
  return aggregated, clean_metrics


- def clean(obj: Dict[str, Any]) -> Dict[str, Any]:
+ def clean(obj: dict[str, Any]) -> dict[str, Any]:
  return {k: v for k, v in obj.items() if v is not None}


- def get_path(obj: Dict[str, Any], path: str, default: Any = None) -> Optional[Any]:
+ def get_path(obj: dict[str, Any], path: str, default: Any = None) -> Any | None:
  keys = path.split(".")
  current = obj

@@ -424,17 +425,3 @@ def get_path(obj: Dict[str, Any], path: str, default: Any = None) -> Optional[An
  current = current[key]

  return current
-
-
- def _try_dict(obj: Any) -> Optional[Dict[str, Any]]:
- try:
- return obj.model_dump()
- except AttributeError:
- pass
-
- try:
- return obj.dump()
- except AttributeError:
- pass
-
- return obj
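setup_genai keeps the same three optional arguments, just with the new union syntax, and the ad-hoc _try_dict fallback gives way to bt_safe_deep_copy from braintrust.bt_json. A minimal usage sketch based on the signature above (the project name is a placeholder; how unset arguments are resolved is not shown in this diff):

```python
from google import genai

from braintrust.wrappers.google_genai import setup_genai

# Assumption: call setup_genai before constructing the client so subsequent
# generate_content calls are traced. "my-genai-project" is hypothetical.
setup_genai(project_name="my-genai-project")

client = genai.Client()
```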
braintrust/wrappers/langchain.py
@@ -1,6 +1,6 @@
  import contextvars
  import logging
- from typing import Any, Dict, List, Optional
+ from typing import Any
  from uuid import UUID

  import braintrust
@@ -30,7 +30,7 @@ class BraintrustTracer(BaseCallbackHandler):
  self.logger = logger
  self.spans = {}

- def _start_span(self, parent_run_id, run_id, name: Optional[str], **kwargs: Any) -> Any:
+ def _start_span(self, parent_run_id, run_id, name: str | None, **kwargs: Any) -> Any:
  assert run_id not in self.spans, f"Span already exists for run_id {run_id} (this is likely a bug)"

  current_parent = langchain_parent.get()
@@ -60,29 +60,29 @@ class BraintrustTracer(BaseCallbackHandler):

  def on_chain_start(
  self,
- serialized: Dict[str, Any],
- inputs: Dict[str, Any],
+ serialized: dict[str, Any],
+ inputs: dict[str, Any],
  *,
  run_id: UUID,
- parent_run_id: Optional[UUID] = None,
- tags: Optional[List[str]] = None,
+ parent_run_id: UUID | None = None,
+ tags: list[str] | None = None,
  **kwargs: Any,
  ) -> Any:
  self._start_span(parent_run_id, run_id, "Chain", input=inputs, metadata={"tags": tags})

  def on_chain_end(
- self, outputs: Dict[str, Any], *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any
+ self, outputs: dict[str, Any], *, run_id: UUID, parent_run_id: UUID | None = None, **kwargs: Any
  ) -> Any:
  self._end_span(run_id, output=outputs)

  def on_llm_start(
  self,
- serialized: Dict[str, Any],
- prompts: List[str],
+ serialized: dict[str, Any],
+ prompts: list[str],
  *,
  run_id: UUID,
- parent_run_id: Optional[UUID] = None,
- tags: Optional[List[str]] = None,
+ parent_run_id: UUID | None = None,
+ tags: list[str] | None = None,
  **kwargs: Any,
  ) -> Any:
  self._start_span(
@@ -95,12 +95,12 @@ class BraintrustTracer(BaseCallbackHandler):

  def on_chat_model_start(
  self,
- serialized: Dict[str, Any],
- messages: List[List[BaseMessage]],
+ serialized: dict[str, Any],
+ messages: list[list[BaseMessage]],
  *,
  run_id: UUID,
- parent_run_id: Optional[UUID] = None,
- tags: Optional[List[str]] = None,
+ parent_run_id: UUID | None = None,
+ tags: list[str] | None = None,
  **kwargs: Any,
  ) -> Any:
  self._start_span(
@@ -112,7 +112,7 @@ class BraintrustTracer(BaseCallbackHandler):
  )

  def on_llm_end(
- self, response: LLMResult, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any
+ self, response: LLMResult, *, run_id: UUID, parent_run_id: UUID | None = None, **kwargs: Any
  ) -> Any:
  metrics = {}
  token_usage = response.llm_output.get("token_usage", {})
@@ -127,25 +127,23 @@ class BraintrustTracer(BaseCallbackHandler):

  def on_tool_start(
  self,
- serialized: Dict[str, Any],
+ serialized: dict[str, Any],
  input_str: str,
  *,
  run_id: UUID,
- parent_run_id: Optional[UUID] = None,
- tags: Optional[List[str]] = None,
+ parent_run_id: UUID | None = None,
+ tags: list[str] | None = None,
  **kwargs: Any,
  ) -> Any:
  _logger.warning("Starting tool, but it will not be traced in braintrust (unsupported)")

- def on_tool_end(self, output: str, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any) -> Any:
+ def on_tool_end(self, output: str, *, run_id: UUID, parent_run_id: UUID | None = None, **kwargs: Any) -> Any:
  pass

- def on_retriever_start(
- self, query: str, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any
- ) -> Any:
+ def on_retriever_start(self, query: str, *, run_id: UUID, parent_run_id: UUID | None = None, **kwargs: Any) -> Any:
  _logger.warning("Starting retriever, but it will not be traced in braintrust (unsupported)")

  def on_retriever_end(
- self, response: List[Document], *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any
+ self, response: list[Document], *, run_id: UUID, parent_run_id: UUID | None = None, **kwargs: Any
  ) -> Any:
  pass
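BraintrustTracer subclasses LangChain's BaseCallbackHandler, so it plugs into the usual callbacks config. A hedged sketch — its constructor arguments are not visible in these hunks, so the no-argument call is an assumption, and the chat model is purely illustrative:

```python
from braintrust.wrappers.langchain import BraintrustTracer
from langchain_openai import ChatOpenAI  # any LangChain chat model works here

# Assumption: the tracer can be constructed without arguments; some versions may
# accept an explicit braintrust logger instead.
tracer = BraintrustTracer()

llm = ChatOpenAI(model="gpt-4o-mini")  # placeholder model name
llm.invoke("Say hello", config={"callbacks": [tracer]})
```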
braintrust/wrappers/litellm.py
@@ -1,9 +1,9 @@
  from __future__ import annotations

  import time
- from collections.abc import AsyncGenerator, Generator
+ from collections.abc import AsyncGenerator, Callable, Generator
  from types import TracebackType
- from typing import Any, Callable
+ from typing import Any

  from braintrust.logger import Span, start_span
  from braintrust.span_types import SpanTypeAttribute
@@ -655,7 +655,8 @@ def patch_litellm():
  """
  try:
  import litellm
- if not hasattr(litellm, '_braintrust_wrapped'):
+
+ if not hasattr(litellm, "_braintrust_wrapped"):
  wrapped = wrap_litellm(litellm)
  litellm.completion = wrapped.completion
  litellm.acompletion = wrapped.acompletion
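patch_litellm swaps litellm.completion and litellm.acompletion for traced wrappers in place, guarded by the _braintrust_wrapped flag so repeated calls are skipped. A minimal usage sketch based on the hunk above (the model name is a placeholder):

```python
import litellm

from braintrust.wrappers.litellm import patch_litellm

patch_litellm()  # idempotent: skipped if litellm is already wrapped

# Subsequent completion calls go through the traced wrapper.
resp = litellm.completion(
    model="gpt-4o-mini",  # placeholder model name
    messages=[{"role": "user", "content": "Hello"}],
)
```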