braintrust 0.3.14__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83):
  1. braintrust/__init__.py +4 -0
  2. braintrust/_generated_types.py +1200 -611
  3. braintrust/audit.py +2 -2
  4. braintrust/cli/eval.py +6 -7
  5. braintrust/cli/push.py +11 -11
  6. braintrust/conftest.py +1 -0
  7. braintrust/context.py +12 -17
  8. braintrust/contrib/temporal/__init__.py +16 -27
  9. braintrust/contrib/temporal/test_temporal.py +8 -3
  10. braintrust/devserver/auth.py +8 -8
  11. braintrust/devserver/cache.py +3 -4
  12. braintrust/devserver/cors.py +8 -7
  13. braintrust/devserver/dataset.py +3 -5
  14. braintrust/devserver/eval_hooks.py +7 -6
  15. braintrust/devserver/schemas.py +22 -19
  16. braintrust/devserver/server.py +19 -12
  17. braintrust/devserver/test_cached_login.py +4 -4
  18. braintrust/framework.py +128 -140
  19. braintrust/framework2.py +88 -87
  20. braintrust/functions/invoke.py +93 -53
  21. braintrust/functions/stream.py +3 -2
  22. braintrust/generated_types.py +17 -1
  23. braintrust/git_fields.py +11 -11
  24. braintrust/gitutil.py +2 -3
  25. braintrust/graph_util.py +10 -10
  26. braintrust/id_gen.py +2 -2
  27. braintrust/logger.py +346 -357
  28. braintrust/merge_row_batch.py +10 -9
  29. braintrust/oai.py +107 -24
  30. braintrust/otel/__init__.py +49 -49
  31. braintrust/otel/context.py +16 -30
  32. braintrust/otel/test_distributed_tracing.py +14 -11
  33. braintrust/otel/test_otel_bt_integration.py +32 -31
  34. braintrust/parameters.py +8 -8
  35. braintrust/prompt.py +14 -14
  36. braintrust/prompt_cache/disk_cache.py +5 -4
  37. braintrust/prompt_cache/lru_cache.py +3 -2
  38. braintrust/prompt_cache/prompt_cache.py +13 -14
  39. braintrust/queue.py +4 -4
  40. braintrust/score.py +4 -4
  41. braintrust/serializable_data_class.py +4 -4
  42. braintrust/span_identifier_v1.py +1 -2
  43. braintrust/span_identifier_v2.py +3 -4
  44. braintrust/span_identifier_v3.py +23 -20
  45. braintrust/span_identifier_v4.py +34 -25
  46. braintrust/test_framework.py +16 -6
  47. braintrust/test_helpers.py +5 -5
  48. braintrust/test_id_gen.py +2 -3
  49. braintrust/test_otel.py +61 -53
  50. braintrust/test_queue.py +0 -1
  51. braintrust/test_score.py +1 -3
  52. braintrust/test_span_components.py +29 -44
  53. braintrust/util.py +9 -8
  54. braintrust/version.py +2 -2
  55. braintrust/wrappers/_anthropic_utils.py +4 -4
  56. braintrust/wrappers/agno/__init__.py +3 -4
  57. braintrust/wrappers/agno/agent.py +1 -2
  58. braintrust/wrappers/agno/function_call.py +1 -2
  59. braintrust/wrappers/agno/model.py +1 -2
  60. braintrust/wrappers/agno/team.py +1 -2
  61. braintrust/wrappers/agno/utils.py +12 -12
  62. braintrust/wrappers/anthropic.py +7 -8
  63. braintrust/wrappers/claude_agent_sdk/__init__.py +3 -4
  64. braintrust/wrappers/claude_agent_sdk/_wrapper.py +29 -27
  65. braintrust/wrappers/dspy.py +15 -17
  66. braintrust/wrappers/google_genai/__init__.py +16 -16
  67. braintrust/wrappers/langchain.py +22 -24
  68. braintrust/wrappers/litellm.py +4 -3
  69. braintrust/wrappers/openai.py +15 -15
  70. braintrust/wrappers/pydantic_ai.py +1204 -0
  71. braintrust/wrappers/test_agno.py +0 -1
  72. braintrust/wrappers/test_dspy.py +0 -1
  73. braintrust/wrappers/test_google_genai.py +2 -3
  74. braintrust/wrappers/test_litellm.py +0 -1
  75. braintrust/wrappers/test_oai_attachments.py +322 -0
  76. braintrust/wrappers/test_pydantic_ai_integration.py +1788 -0
  77. braintrust/wrappers/{test_pydantic_ai.py → test_pydantic_ai_wrap_openai.py} +1 -2
  78. {braintrust-0.3.14.dist-info → braintrust-0.4.0.dist-info}/METADATA +3 -2
  79. braintrust-0.4.0.dist-info/RECORD +120 -0
  80. braintrust-0.3.14.dist-info/RECORD +0 -117
  81. {braintrust-0.3.14.dist-info → braintrust-0.4.0.dist-info}/WHEEL +0 -0
  82. {braintrust-0.3.14.dist-info → braintrust-0.4.0.dist-info}/entry_points.txt +0 -0
  83. {braintrust-0.3.14.dist-info → braintrust-0.4.0.dist-info}/top_level.txt +0 -0
@@ -2,7 +2,8 @@ import dataclasses
2
2
  import logging
3
3
  import threading
4
4
  import time
5
- from typing import Any, AsyncGenerator, Callable, Dict, List, Optional, Tuple
5
+ from collections.abc import AsyncGenerator, Callable
6
+ from typing import Any
6
7
 
7
8
  from braintrust.logger import start_span
8
9
  from braintrust.span_types import SpanTypeAttribute
@@ -108,12 +109,12 @@ def _wrap_tool_handler(handler: Any, tool_name: Any) -> Callable[..., Any]:
108
109
  so we try the context variable first, then fall back to current_span export.
109
110
  """
110
111
  # Check if already wrapped to prevent double-wrapping
111
- if hasattr(handler, '_braintrust_wrapped'):
112
+ if hasattr(handler, "_braintrust_wrapped"):
112
113
  return handler
113
114
 
114
115
  async def wrapped_handler(args: Any) -> Any:
115
116
  # Get parent span export from thread-local storage
116
- parent_export = getattr(_thread_local, 'parent_span_export', None)
117
+ parent_export = getattr(_thread_local, "parent_span_export", None)
117
118
 
118
119
  with start_span(
119
120
  name=str(tool_name),
@@ -144,11 +145,14 @@ def _create_client_wrapper_class(original_client_class: Any) -> Any:
144
145
  We end the previous span when the next AssistantMessage arrives, using the marked
145
146
  start time to ensure sequential timing (no overlapping LLM spans).
146
147
  """
147
- def __init__(self, query_start_time: Optional[float] = None):
148
- self.current_span: Optional[Any] = None
149
- self.next_start_time: Optional[float] = query_start_time
150
148
 
151
- def start_llm_span(self, message: Any, prompt: Any, conversation_history: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
149
+ def __init__(self, query_start_time: float | None = None):
150
+ self.current_span: Any | None = None
151
+ self.next_start_time: float | None = query_start_time
152
+
153
+ def start_llm_span(
154
+ self, message: Any, prompt: Any, conversation_history: list[dict[str, Any]]
155
+ ) -> dict[str, Any] | None:
152
156
  """Start a new LLM span, ending the previous one if it exists."""
153
157
  # Use the marked start time, or current time as fallback
154
158
  start_time = self.next_start_time if self.next_start_time is not None else time.time()
@@ -158,8 +162,7 @@ def _create_client_wrapper_class(original_client_class: Any) -> Any:
158
162
  self.current_span.end(end_time=start_time)
159
163
 
160
164
  final_content, span = _create_llm_span_for_messages(
161
- [message], prompt, conversation_history,
162
- start_time=start_time
165
+ [message], prompt, conversation_history, start_time=start_time
163
166
  )
164
167
  self.current_span = span
165
168
  self.next_start_time = None # Reset for next span
@@ -169,7 +172,7 @@ def _create_client_wrapper_class(original_client_class: Any) -> Any:
169
172
  """Mark when the next LLM call will start (after tool results)."""
170
173
  self.next_start_time = time.time()
171
174
 
172
- def log_usage(self, usage_metrics: Dict[str, float]) -> None:
175
+ def log_usage(self, usage_metrics: dict[str, float]) -> None:
173
176
  """Log usage metrics to the current LLM span."""
174
177
  if self.current_span and usage_metrics:
175
178
  self.current_span.log(metrics=usage_metrics)
@@ -186,8 +189,8 @@ def _create_client_wrapper_class(original_client_class: Any) -> Any:
186
189
  client = original_client_class(*args, **kwargs)
187
190
  super().__init__(client)
188
191
  self.__client = client
189
- self.__last_prompt: Optional[str] = None
190
- self.__query_start_time: Optional[float] = None
192
+ self.__last_prompt: str | None = None
193
+ self.__query_start_time: float | None = None
191
194
 
192
195
  async def query(self, *args: Any, **kwargs: Any) -> Any:
193
196
  """Wrap query to capture the prompt and start time for tracing."""
@@ -220,7 +223,7 @@ def _create_client_wrapper_class(original_client_class: Any) -> Any:
220
223
  # Store the parent span export in thread-local storage for tool handlers
221
224
  _thread_local.parent_span_export = span.export()
222
225
 
223
- final_results: List[Dict[str, Any]] = []
226
+ final_results: list[dict[str, Any]] = []
224
227
  llm_tracker = LLMSpanTracker(query_start_time=self.__query_start_time)
225
228
 
226
229
  try:
@@ -243,10 +246,12 @@ def _create_client_wrapper_class(original_client_class: Any) -> Any:
243
246
  llm_tracker.log_usage(usage_metrics)
244
247
 
245
248
  result_metadata = {
246
- k: v for k, v in {
249
+ k: v
250
+ for k, v in {
247
251
  "num_turns": getattr(message, "num_turns", None),
248
252
  "session_id": getattr(message, "session_id", None),
249
- }.items() if v is not None
253
+ }.items()
254
+ if v is not None
250
255
  }
251
256
  if result_metadata:
252
257
  span.log(metadata=result_metadata)
@@ -257,8 +262,8 @@ def _create_client_wrapper_class(original_client_class: Any) -> Any:
257
262
  log.warning("Error in tracing code", exc_info=e)
258
263
  finally:
259
264
  llm_tracker.cleanup()
260
- if hasattr(_thread_local, 'parent_span_export'):
261
- delattr(_thread_local, 'parent_span_export')
265
+ if hasattr(_thread_local, "parent_span_export"):
266
+ delattr(_thread_local, "parent_span_export")
262
267
 
263
268
  async def __aenter__(self) -> "WrappedClaudeSDKClient":
264
269
  await self.__client.__aenter__()
@@ -271,11 +276,11 @@ def _create_client_wrapper_class(original_client_class: Any) -> Any:
271
276
 
272
277
 
273
278
  def _create_llm_span_for_messages(
274
- messages: List[Any], # List of AssistantMessage objects
279
+ messages: list[Any], # List of AssistantMessage objects
275
280
  prompt: Any,
276
- conversation_history: List[Dict[str, Any]],
277
- start_time: Optional[float] = None,
278
- ) -> Tuple[Optional[Dict[str, Any]], Optional[Any]]:
281
+ conversation_history: list[dict[str, Any]],
282
+ start_time: float | None = None,
283
+ ) -> tuple[dict[str, Any] | None, Any | None]:
279
284
  """Creates an LLM span for a group of AssistantMessage objects.
280
285
 
281
286
  Returns a tuple of (final_content, span):
@@ -295,13 +300,12 @@ def _create_llm_span_for_messages(
295
300
  model = getattr(last_message, "model", None)
296
301
  input_messages = _build_llm_input(prompt, conversation_history)
297
302
 
298
- outputs: List[Dict[str, Any]] = []
303
+ outputs: list[dict[str, Any]] = []
299
304
  for msg in messages:
300
305
  if hasattr(msg, "content"):
301
306
  content = _serialize_content_blocks(msg.content)
302
307
  outputs.append({"content": content, "role": "assistant"})
303
308
 
304
-
305
309
  llm_span = start_span(
306
310
  name="anthropic.messages.create",
307
311
  span_attributes={"type": SpanTypeAttribute.LLM},
@@ -355,7 +359,7 @@ def _serialize_content_blocks(content: Any) -> Any:
355
359
  return content
356
360
 
357
361
 
358
- def _extract_usage_from_result_message(result_message: Any) -> Dict[str, float]:
362
+ def _extract_usage_from_result_message(result_message: Any) -> dict[str, float]:
359
363
  """Extracts and normalizes usage metrics from a ResultMessage.
360
364
 
361
365
  Uses shared Anthropic utilities for consistent metric extraction.
@@ -374,9 +378,7 @@ def _extract_usage_from_result_message(result_message: Any) -> Dict[str, float]:
374
378
  return metrics
375
379
 
376
380
 
377
- def _build_llm_input(
378
- prompt: Any, conversation_history: List[Dict[str, Any]]
379
- ) -> Optional[List[Dict[str, Any]]]:
381
+ def _build_llm_input(prompt: Any, conversation_history: list[dict[str, Any]]) -> list[dict[str, Any]] | None:
380
382
  """Builds the input array for an LLM span from the initial prompt and conversation history.
381
383
 
382
384
  Formats input to match Anthropic messages API format for proper UI rendering.
@@ -47,7 +47,7 @@ Advanced Usage with LiteLLM Patching:
47
47
  ```
48
48
  """
49
49
 
50
- from typing import Any, Dict, Optional
50
+ from typing import Any
51
51
 
52
52
  from braintrust.logger import current_span, start_span
53
53
  from braintrust.span_types import SpanTypeAttribute
@@ -58,9 +58,7 @@ from braintrust.span_types import SpanTypeAttribute
58
58
  try:
59
59
  from dspy.utils.callback import BaseCallback
60
60
  except ImportError:
61
- raise ImportError(
62
- "DSPy is not installed. Please install it with: pip install dspy"
63
- )
61
+ raise ImportError("DSPy is not installed. Please install it with: pip install dspy")
64
62
 
65
63
 
66
64
  class BraintrustDSpyCallback(BaseCallback):
@@ -130,13 +128,13 @@ class BraintrustDSpyCallback(BaseCallback):
130
128
  """Initialize the Braintrust DSPy callback handler."""
131
129
  super().__init__()
132
130
  # Map call_id to span objects for proper nesting
133
- self._spans: Dict[str, Any] = {}
131
+ self._spans: dict[str, Any] = {}
134
132
 
135
133
  def on_lm_start(
136
134
  self,
137
135
  call_id: str,
138
136
  instance: Any,
139
- inputs: Dict[str, Any],
137
+ inputs: dict[str, Any],
140
138
  ):
141
139
  """Log the start of a language model call.
142
140
 
@@ -174,8 +172,8 @@ class BraintrustDSpyCallback(BaseCallback):
174
172
  def on_lm_end(
175
173
  self,
176
174
  call_id: str,
177
- outputs: Optional[Dict[str, Any]],
178
- exception: Optional[Exception] = None,
175
+ outputs: dict[str, Any] | None,
176
+ exception: Exception | None = None,
179
177
  ):
180
178
  """Log the end of a language model call.
181
179
 
@@ -205,7 +203,7 @@ class BraintrustDSpyCallback(BaseCallback):
205
203
  self,
206
204
  call_id: str,
207
205
  instance: Any,
208
- inputs: Dict[str, Any],
206
+ inputs: dict[str, Any],
209
207
  ):
210
208
  """Log the start of a DSPy module execution.
211
209
 
@@ -236,8 +234,8 @@ class BraintrustDSpyCallback(BaseCallback):
236
234
  def on_module_end(
237
235
  self,
238
236
  call_id: str,
239
- outputs: Optional[Any],
240
- exception: Optional[Exception] = None,
237
+ outputs: Any | None,
238
+ exception: Exception | None = None,
241
239
  ):
242
240
  """Log the end of a DSPy module execution.
243
241
 
@@ -274,7 +272,7 @@ class BraintrustDSpyCallback(BaseCallback):
274
272
  self,
275
273
  call_id: str,
276
274
  instance: Any,
277
- inputs: Dict[str, Any],
275
+ inputs: dict[str, Any],
278
276
  ):
279
277
  """Log the start of a tool invocation.
280
278
 
@@ -309,8 +307,8 @@ class BraintrustDSpyCallback(BaseCallback):
309
307
  def on_tool_end(
310
308
  self,
311
309
  call_id: str,
312
- outputs: Optional[Dict[str, Any]],
313
- exception: Optional[Exception] = None,
310
+ outputs: dict[str, Any] | None,
311
+ exception: Exception | None = None,
314
312
  ):
315
313
  """Log the end of a tool invocation.
316
314
 
@@ -340,7 +338,7 @@ class BraintrustDSpyCallback(BaseCallback):
340
338
  self,
341
339
  call_id: str,
342
340
  instance: Any,
343
- inputs: Dict[str, Any],
341
+ inputs: dict[str, Any],
344
342
  ):
345
343
  """Log the start of an evaluation run.
346
344
 
@@ -374,8 +372,8 @@ class BraintrustDSpyCallback(BaseCallback):
374
372
  def on_evaluate_end(
375
373
  self,
376
374
  call_id: str,
377
- outputs: Optional[Any],
378
- exception: Optional[Exception] = None,
375
+ outputs: Any | None,
376
+ exception: Exception | None = None,
379
377
  ):
380
378
  """Log the end of an evaluation run.
381
379
 
@@ -1,19 +1,19 @@
1
1
  import logging
2
2
  import time
3
- from typing import Any, Dict, Iterable, List, Optional, Tuple
4
-
5
- from wrapt import wrap_function_wrapper
3
+ from collections.abc import Iterable
4
+ from typing import Any
6
5
 
7
6
  from braintrust.logger import NOOP_SPAN, Attachment, current_span, init_logger, start_span
8
7
  from braintrust.span_types import SpanTypeAttribute
8
+ from wrapt import wrap_function_wrapper
9
9
 
10
10
  logger = logging.getLogger(__name__)
11
11
 
12
12
 
13
13
  def setup_genai(
14
- api_key: Optional[str] = None,
15
- project_id: Optional[str] = None,
16
- project_name: Optional[str] = None,
14
+ api_key: str | None = None,
15
+ project_id: str | None = None,
16
+ project_name: str | None = None,
17
17
  ):
18
18
  span = current_span()
19
19
  if span == NOOP_SPAN:
@@ -148,7 +148,7 @@ def wrap_async_models(AsyncModels: Any):
148
148
  return AsyncModels
149
149
 
150
150
 
151
- def _serialize_input(api_client: Any, input: Dict[str, Any]):
151
+ def _serialize_input(api_client: Any, input: dict[str, Any]):
152
152
  config = _try_dict(input.get("config"))
153
153
 
154
154
  if config is not None:
@@ -223,7 +223,7 @@ def _serialize_content_item(item: Any) -> Any:
223
223
  return item
224
224
 
225
225
 
226
- def _serialize_tools(api_client: Any, input: Optional[Any]):
226
+ def _serialize_tools(api_client: Any, input: Any | None):
227
227
  try:
228
228
  from google.genai.models import (
229
229
  _GenerateContentParameters_to_mldev, # pyright: ignore [reportPrivateUsage]
@@ -242,7 +242,7 @@ def _serialize_tools(api_client: Any, input: Optional[Any]):
242
242
  return None
243
243
 
244
244
 
245
- def omit(obj: Dict[str, Any], keys: Iterable[str]):
245
+ def omit(obj: dict[str, Any], keys: Iterable[str]):
246
246
  return {k: v for k, v in obj.items() if k not in keys}
247
247
 
248
248
 
@@ -254,11 +254,11 @@ def mark_patched(obj: Any):
254
254
  return setattr(obj, "_braintrust_patched", True)
255
255
 
256
256
 
257
- def get_args_kwargs(args: List[str], kwargs: Dict[str, Any], keys: Iterable[str]):
257
+ def get_args_kwargs(args: list[str], kwargs: dict[str, Any], keys: Iterable[str]):
258
258
  return {k: args[i] if args else kwargs.get(k) for i, k in enumerate(keys)}, omit(kwargs, keys)
259
259
 
260
260
 
261
- def _extract_generate_content_metrics(response: Any, start: float) -> Dict[str, Any]:
261
+ def _extract_generate_content_metrics(response: Any, start: float) -> dict[str, Any]:
262
262
  """Extract metrics from a non-streaming generate_content response."""
263
263
  end_time = time.time()
264
264
  metrics = dict(
@@ -297,8 +297,8 @@ def _extract_generate_content_metrics(response: Any, start: float) -> Dict[str,
297
297
 
298
298
 
299
299
  def _aggregate_generate_content_chunks(
300
- chunks: List[Any], start: float, first_token_time: Optional[float] = None
301
- ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
300
+ chunks: list[Any], start: float, first_token_time: float | None = None
301
+ ) -> tuple[dict[str, Any], dict[str, Any]]:
302
302
  """Aggregate streaming chunks into a single response with metrics."""
303
303
  end_time = time.time()
304
304
  metrics = dict(
@@ -410,11 +410,11 @@ def _aggregate_generate_content_chunks(
410
410
  return aggregated, clean_metrics
411
411
 
412
412
 
413
- def clean(obj: Dict[str, Any]) -> Dict[str, Any]:
413
+ def clean(obj: dict[str, Any]) -> dict[str, Any]:
414
414
  return {k: v for k, v in obj.items() if v is not None}
415
415
 
416
416
 
417
- def get_path(obj: Dict[str, Any], path: str, default: Any = None) -> Optional[Any]:
417
+ def get_path(obj: dict[str, Any], path: str, default: Any = None) -> Any | None:
418
418
  keys = path.split(".")
419
419
  current = obj
420
420
 
@@ -426,7 +426,7 @@ def get_path(obj: Dict[str, Any], path: str, default: Any = None) -> Optional[An
426
426
  return current
427
427
 
428
428
 
429
- def _try_dict(obj: Any) -> Optional[Dict[str, Any]]:
429
+ def _try_dict(obj: Any) -> dict[str, Any] | None:
430
430
  try:
431
431
  return obj.model_dump()
432
432
  except AttributeError:
@@ -1,6 +1,6 @@
1
1
  import contextvars
2
2
  import logging
3
- from typing import Any, Dict, List, Optional
3
+ from typing import Any
4
4
  from uuid import UUID
5
5
 
6
6
  import braintrust
@@ -30,7 +30,7 @@ class BraintrustTracer(BaseCallbackHandler):
30
30
  self.logger = logger
31
31
  self.spans = {}
32
32
 
33
- def _start_span(self, parent_run_id, run_id, name: Optional[str], **kwargs: Any) -> Any:
33
+ def _start_span(self, parent_run_id, run_id, name: str | None, **kwargs: Any) -> Any:
34
34
  assert run_id not in self.spans, f"Span already exists for run_id {run_id} (this is likely a bug)"
35
35
 
36
36
  current_parent = langchain_parent.get()
@@ -60,29 +60,29 @@ class BraintrustTracer(BaseCallbackHandler):
60
60
 
61
61
  def on_chain_start(
62
62
  self,
63
- serialized: Dict[str, Any],
64
- inputs: Dict[str, Any],
63
+ serialized: dict[str, Any],
64
+ inputs: dict[str, Any],
65
65
  *,
66
66
  run_id: UUID,
67
- parent_run_id: Optional[UUID] = None,
68
- tags: Optional[List[str]] = None,
67
+ parent_run_id: UUID | None = None,
68
+ tags: list[str] | None = None,
69
69
  **kwargs: Any,
70
70
  ) -> Any:
71
71
  self._start_span(parent_run_id, run_id, "Chain", input=inputs, metadata={"tags": tags})
72
72
 
73
73
  def on_chain_end(
74
- self, outputs: Dict[str, Any], *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any
74
+ self, outputs: dict[str, Any], *, run_id: UUID, parent_run_id: UUID | None = None, **kwargs: Any
75
75
  ) -> Any:
76
76
  self._end_span(run_id, output=outputs)
77
77
 
78
78
  def on_llm_start(
79
79
  self,
80
- serialized: Dict[str, Any],
81
- prompts: List[str],
80
+ serialized: dict[str, Any],
81
+ prompts: list[str],
82
82
  *,
83
83
  run_id: UUID,
84
- parent_run_id: Optional[UUID] = None,
85
- tags: Optional[List[str]] = None,
84
+ parent_run_id: UUID | None = None,
85
+ tags: list[str] | None = None,
86
86
  **kwargs: Any,
87
87
  ) -> Any:
88
88
  self._start_span(
@@ -95,12 +95,12 @@ class BraintrustTracer(BaseCallbackHandler):
95
95
 
96
96
  def on_chat_model_start(
97
97
  self,
98
- serialized: Dict[str, Any],
99
- messages: List[List[BaseMessage]],
98
+ serialized: dict[str, Any],
99
+ messages: list[list[BaseMessage]],
100
100
  *,
101
101
  run_id: UUID,
102
- parent_run_id: Optional[UUID] = None,
103
- tags: Optional[List[str]] = None,
102
+ parent_run_id: UUID | None = None,
103
+ tags: list[str] | None = None,
104
104
  **kwargs: Any,
105
105
  ) -> Any:
106
106
  self._start_span(
@@ -112,7 +112,7 @@ class BraintrustTracer(BaseCallbackHandler):
112
112
  )
113
113
 
114
114
  def on_llm_end(
115
- self, response: LLMResult, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any
115
+ self, response: LLMResult, *, run_id: UUID, parent_run_id: UUID | None = None, **kwargs: Any
116
116
  ) -> Any:
117
117
  metrics = {}
118
118
  token_usage = response.llm_output.get("token_usage", {})
@@ -127,25 +127,23 @@ class BraintrustTracer(BaseCallbackHandler):
127
127
 
128
128
  def on_tool_start(
129
129
  self,
130
- serialized: Dict[str, Any],
130
+ serialized: dict[str, Any],
131
131
  input_str: str,
132
132
  *,
133
133
  run_id: UUID,
134
- parent_run_id: Optional[UUID] = None,
135
- tags: Optional[List[str]] = None,
134
+ parent_run_id: UUID | None = None,
135
+ tags: list[str] | None = None,
136
136
  **kwargs: Any,
137
137
  ) -> Any:
138
138
  _logger.warning("Starting tool, but it will not be traced in braintrust (unsupported)")
139
139
 
140
- def on_tool_end(self, output: str, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any) -> Any:
140
+ def on_tool_end(self, output: str, *, run_id: UUID, parent_run_id: UUID | None = None, **kwargs: Any) -> Any:
141
141
  pass
142
142
 
143
- def on_retriever_start(
144
- self, query: str, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any
145
- ) -> Any:
143
+ def on_retriever_start(self, query: str, *, run_id: UUID, parent_run_id: UUID | None = None, **kwargs: Any) -> Any:
146
144
  _logger.warning("Starting retriever, but it will not be traced in braintrust (unsupported)")
147
145
 
148
146
  def on_retriever_end(
149
- self, response: List[Document], *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any
147
+ self, response: list[Document], *, run_id: UUID, parent_run_id: UUID | None = None, **kwargs: Any
150
148
  ) -> Any:
151
149
  pass
@@ -1,9 +1,9 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import time
4
- from collections.abc import AsyncGenerator, Generator
4
+ from collections.abc import AsyncGenerator, Callable, Generator
5
5
  from types import TracebackType
6
- from typing import Any, Callable
6
+ from typing import Any
7
7
 
8
8
  from braintrust.logger import Span, start_span
9
9
  from braintrust.span_types import SpanTypeAttribute
@@ -655,7 +655,8 @@ def patch_litellm():
655
655
  """
656
656
  try:
657
657
  import litellm
658
- if not hasattr(litellm, '_braintrust_wrapped'):
658
+
659
+ if not hasattr(litellm, "_braintrust_wrapped"):
659
660
  wrapped = wrap_litellm(litellm)
660
661
  litellm.completion = wrapped.completion
661
662
  litellm.acompletion = wrapped.acompletion
@@ -3,7 +3,7 @@ Exports `BraintrustTracingProcessor`, a `tracing.TracingProcessor` that logs tra
3
3
  """
4
4
 
5
5
  import datetime
6
- from typing import Any, Dict, Optional, Union
6
+ from typing import Any
7
7
 
8
8
  import braintrust
9
9
  from agents import tracing
@@ -40,13 +40,13 @@ def _span_name(span: tracing.Span[Any]) -> str:
40
40
  return "Unknown"
41
41
 
42
42
 
43
- def _timestamp_from_maybe_iso(timestamp: Optional[str]) -> Optional[float]:
43
+ def _timestamp_from_maybe_iso(timestamp: str | None) -> float | None:
44
44
  if timestamp is None:
45
45
  return None
46
46
  return datetime.datetime.fromisoformat(timestamp).timestamp()
47
47
 
48
48
 
49
- def _maybe_timestamp_elapsed(end: Optional[str], start: Optional[str]) -> Optional[float]:
49
+ def _maybe_timestamp_elapsed(end: str | None, start: str | None) -> float | None:
50
50
  if start is None or end is None:
51
51
  return None
52
52
  return (datetime.datetime.fromisoformat(end) - datetime.datetime.fromisoformat(start)).total_seconds()
@@ -61,11 +61,11 @@ class BraintrustTracingProcessor(tracing.TracingProcessor):
61
61
  If `None`, the current span, experiment, or logger will be selected exactly as in `braintrust.start_span`.
62
62
  """
63
63
 
64
- def __init__(self, logger: Optional[Union[braintrust.Span, braintrust.Experiment, braintrust.Logger]] = None):
64
+ def __init__(self, logger: braintrust.Span | braintrust.Experiment | braintrust.Logger | None = None):
65
65
  self._logger = logger
66
- self._spans: Dict[str, braintrust.Span] = {}
67
- self._first_input: Dict[str, Any] = {}
68
- self._last_output: Dict[str, Any] = {}
66
+ self._spans: dict[str, braintrust.Span] = {}
67
+ self._first_input: dict[str, Any] = {}
68
+ self._last_output: dict[str, Any] = {}
69
69
 
70
70
  def on_trace_start(self, trace: tracing.Trace) -> None:
71
71
  trace_meta = trace.export() or {}
@@ -113,7 +113,7 @@ class BraintrustTracingProcessor(tracing.TracingProcessor):
113
113
  # TODO(sachin): Add end time when SDK provides it.
114
114
  # span.end(_timestamp_from_maybe_iso(trace.ended_at))
115
115
 
116
- def _agent_log_data(self, span: tracing.Span[tracing.AgentSpanData]) -> Dict[str, Any]:
116
+ def _agent_log_data(self, span: tracing.Span[tracing.AgentSpanData]) -> dict[str, Any]:
117
117
  return {
118
118
  "metadata": {
119
119
  "tools": span.span_data.tools,
@@ -122,7 +122,7 @@ class BraintrustTracingProcessor(tracing.TracingProcessor):
122
122
  }
123
123
  }
124
124
 
125
- def _response_log_data(self, span: tracing.Span[tracing.ResponseSpanData]) -> Dict[str, Any]:
125
+ def _response_log_data(self, span: tracing.Span[tracing.ResponseSpanData]) -> dict[str, Any]:
126
126
  data = {}
127
127
  if span.span_data.input is not None:
128
128
  data["input"] = span.span_data.input
@@ -145,13 +145,13 @@ class BraintrustTracingProcessor(tracing.TracingProcessor):
145
145
 
146
146
  return data
147
147
 
148
- def _function_log_data(self, span: tracing.Span[tracing.FunctionSpanData]) -> Dict[str, Any]:
148
+ def _function_log_data(self, span: tracing.Span[tracing.FunctionSpanData]) -> dict[str, Any]:
149
149
  return {
150
150
  "input": span.span_data.input,
151
151
  "output": span.span_data.output,
152
152
  }
153
153
 
154
- def _handoff_log_data(self, span: tracing.Span[tracing.HandoffSpanData]) -> Dict[str, Any]:
154
+ def _handoff_log_data(self, span: tracing.Span[tracing.HandoffSpanData]) -> dict[str, Any]:
155
155
  return {
156
156
  "metadata": {
157
157
  "from_agent": span.span_data.from_agent,
@@ -159,14 +159,14 @@ class BraintrustTracingProcessor(tracing.TracingProcessor):
159
159
  }
160
160
  }
161
161
 
162
- def _guardrail_log_data(self, span: tracing.Span[tracing.GuardrailSpanData]) -> Dict[str, Any]:
162
+ def _guardrail_log_data(self, span: tracing.Span[tracing.GuardrailSpanData]) -> dict[str, Any]:
163
163
  return {
164
164
  "metadata": {
165
165
  "triggered": span.span_data.triggered,
166
166
  }
167
167
  }
168
168
 
169
- def _generation_log_data(self, span: tracing.Span[tracing.GenerationSpanData]) -> Dict[str, Any]:
169
+ def _generation_log_data(self, span: tracing.Span[tracing.GenerationSpanData]) -> dict[str, Any]:
170
170
  metrics = {}
171
171
  ttft = _maybe_timestamp_elapsed(span.ended_at, span.started_at)
172
172
 
@@ -199,10 +199,10 @@ class BraintrustTracingProcessor(tracing.TracingProcessor):
199
199
  "metrics": metrics,
200
200
  }
201
201
 
202
- def _custom_log_data(self, span: tracing.Span[tracing.CustomSpanData]) -> Dict[str, Any]:
202
+ def _custom_log_data(self, span: tracing.Span[tracing.CustomSpanData]) -> dict[str, Any]:
203
203
  return span.span_data.data
204
204
 
205
- def _log_data(self, span: tracing.Span[Any]) -> Dict[str, Any]:
205
+ def _log_data(self, span: tracing.Span[Any]) -> dict[str, Any]:
206
206
  if isinstance(span.span_data, tracing.AgentSpanData):
207
207
  return self._agent_log_data(span)
208
208
  elif isinstance(span.span_data, tracing.ResponseSpanData):