braintrust 0.3.14__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83) hide show
  1. braintrust/__init__.py +4 -0
  2. braintrust/_generated_types.py +1200 -611
  3. braintrust/audit.py +2 -2
  4. braintrust/cli/eval.py +6 -7
  5. braintrust/cli/push.py +11 -11
  6. braintrust/conftest.py +1 -0
  7. braintrust/context.py +12 -17
  8. braintrust/contrib/temporal/__init__.py +16 -27
  9. braintrust/contrib/temporal/test_temporal.py +8 -3
  10. braintrust/devserver/auth.py +8 -8
  11. braintrust/devserver/cache.py +3 -4
  12. braintrust/devserver/cors.py +8 -7
  13. braintrust/devserver/dataset.py +3 -5
  14. braintrust/devserver/eval_hooks.py +7 -6
  15. braintrust/devserver/schemas.py +22 -19
  16. braintrust/devserver/server.py +19 -12
  17. braintrust/devserver/test_cached_login.py +4 -4
  18. braintrust/framework.py +128 -140
  19. braintrust/framework2.py +88 -87
  20. braintrust/functions/invoke.py +93 -53
  21. braintrust/functions/stream.py +3 -2
  22. braintrust/generated_types.py +17 -1
  23. braintrust/git_fields.py +11 -11
  24. braintrust/gitutil.py +2 -3
  25. braintrust/graph_util.py +10 -10
  26. braintrust/id_gen.py +2 -2
  27. braintrust/logger.py +346 -357
  28. braintrust/merge_row_batch.py +10 -9
  29. braintrust/oai.py +107 -24
  30. braintrust/otel/__init__.py +49 -49
  31. braintrust/otel/context.py +16 -30
  32. braintrust/otel/test_distributed_tracing.py +14 -11
  33. braintrust/otel/test_otel_bt_integration.py +32 -31
  34. braintrust/parameters.py +8 -8
  35. braintrust/prompt.py +14 -14
  36. braintrust/prompt_cache/disk_cache.py +5 -4
  37. braintrust/prompt_cache/lru_cache.py +3 -2
  38. braintrust/prompt_cache/prompt_cache.py +13 -14
  39. braintrust/queue.py +4 -4
  40. braintrust/score.py +4 -4
  41. braintrust/serializable_data_class.py +4 -4
  42. braintrust/span_identifier_v1.py +1 -2
  43. braintrust/span_identifier_v2.py +3 -4
  44. braintrust/span_identifier_v3.py +23 -20
  45. braintrust/span_identifier_v4.py +34 -25
  46. braintrust/test_framework.py +16 -6
  47. braintrust/test_helpers.py +5 -5
  48. braintrust/test_id_gen.py +2 -3
  49. braintrust/test_otel.py +61 -53
  50. braintrust/test_queue.py +0 -1
  51. braintrust/test_score.py +1 -3
  52. braintrust/test_span_components.py +29 -44
  53. braintrust/util.py +9 -8
  54. braintrust/version.py +2 -2
  55. braintrust/wrappers/_anthropic_utils.py +4 -4
  56. braintrust/wrappers/agno/__init__.py +3 -4
  57. braintrust/wrappers/agno/agent.py +1 -2
  58. braintrust/wrappers/agno/function_call.py +1 -2
  59. braintrust/wrappers/agno/model.py +1 -2
  60. braintrust/wrappers/agno/team.py +1 -2
  61. braintrust/wrappers/agno/utils.py +12 -12
  62. braintrust/wrappers/anthropic.py +7 -8
  63. braintrust/wrappers/claude_agent_sdk/__init__.py +3 -4
  64. braintrust/wrappers/claude_agent_sdk/_wrapper.py +29 -27
  65. braintrust/wrappers/dspy.py +15 -17
  66. braintrust/wrappers/google_genai/__init__.py +16 -16
  67. braintrust/wrappers/langchain.py +22 -24
  68. braintrust/wrappers/litellm.py +4 -3
  69. braintrust/wrappers/openai.py +15 -15
  70. braintrust/wrappers/pydantic_ai.py +1204 -0
  71. braintrust/wrappers/test_agno.py +0 -1
  72. braintrust/wrappers/test_dspy.py +0 -1
  73. braintrust/wrappers/test_google_genai.py +2 -3
  74. braintrust/wrappers/test_litellm.py +0 -1
  75. braintrust/wrappers/test_oai_attachments.py +322 -0
  76. braintrust/wrappers/test_pydantic_ai_integration.py +1788 -0
  77. braintrust/wrappers/{test_pydantic_ai.py → test_pydantic_ai_wrap_openai.py} +1 -2
  78. {braintrust-0.3.14.dist-info → braintrust-0.4.0.dist-info}/METADATA +3 -2
  79. braintrust-0.4.0.dist-info/RECORD +120 -0
  80. braintrust-0.3.14.dist-info/RECORD +0 -117
  81. {braintrust-0.3.14.dist-info → braintrust-0.4.0.dist-info}/WHEEL +0 -0
  82. {braintrust-0.3.14.dist-info → braintrust-0.4.0.dist-info}/entry_points.txt +0 -0
  83. {braintrust-0.3.14.dist-info → braintrust-0.4.0.dist-info}/top_level.txt +0 -0
@@ -6,7 +6,6 @@ Tests serialization, deserialization, OTEL compatibility, and backward compatibi
6
6
  from uuid import uuid4
7
7
 
8
8
  import pytest
9
-
10
9
  from braintrust.id_gen import OTELIDGenerator
11
10
  from braintrust.span_identifier_v3 import SpanComponentsV3, SpanObjectTypeV3
12
11
  from braintrust.span_identifier_v4 import SpanComponentsV4
@@ -22,7 +21,7 @@ class TestSpanComponentsV3:
22
21
  object_id=str(uuid4()),
23
22
  row_id=str(uuid4()),
24
23
  span_id=str(uuid4()),
25
- root_span_id=str(uuid4())
24
+ root_span_id=str(uuid4()),
26
25
  )
27
26
 
28
27
  exported = components.to_str()
@@ -39,7 +38,7 @@ class TestSpanComponentsV3:
39
38
  components = SpanComponentsV3(
40
39
  object_type=SpanObjectTypeV3.EXPERIMENT,
41
40
  object_id=str(uuid4()),
42
- propagated_event={"key": "value", "nested": {"a": 1}}
41
+ propagated_event={"key": "value", "nested": {"a": 1}},
43
42
  )
44
43
 
45
44
  exported = components.to_str()
@@ -53,15 +52,15 @@ class TestSpanComponentsV3:
53
52
  """Test that V3 fails to preserve OTEL hex strings for 16-byte IDs (converts to UUID format)."""
54
53
  otel_gen = OTELIDGenerator()
55
54
  trace_id = otel_gen.get_trace_id() # 32-char hex (16 bytes)
56
- span_id = otel_gen.get_span_id() # 16-char hex (8 bytes)
55
+ span_id = otel_gen.get_span_id() # 16-char hex (8 bytes)
57
56
 
58
57
  # Use 16-byte hex strings for object_id and root_span_id to see UUID conversion
59
58
  components = SpanComponentsV3(
60
59
  object_type=SpanObjectTypeV3.PROJECT_LOGS,
61
60
  object_id=trace_id, # 16-byte hex should get converted to UUID format
62
- row_id='test-row-id',
63
- span_id=span_id, # 8-byte hex might be preserved
64
- root_span_id=trace_id # 16-byte hex should get converted to UUID format
61
+ row_id="test-row-id",
62
+ span_id=span_id, # 8-byte hex might be preserved
63
+ root_span_id=trace_id, # 16-byte hex should get converted to UUID format
65
64
  )
66
65
 
67
66
  exported = components.to_str()
@@ -79,14 +78,14 @@ class TestSpanComponentsV4:
79
78
  """Test that V4 preserves OTEL hex strings exactly."""
80
79
  otel_gen = OTELIDGenerator()
81
80
  trace_id = otel_gen.get_trace_id() # 32-char hex
82
- span_id = otel_gen.get_span_id() # 16-char hex
81
+ span_id = otel_gen.get_span_id() # 16-char hex
83
82
 
84
83
  components = SpanComponentsV4(
85
84
  object_type=SpanObjectTypeV3.PROJECT_LOGS,
86
- object_id='test-project-id',
87
- row_id='test-row-id',
85
+ object_id="test-project-id",
86
+ row_id="test-row-id",
88
87
  span_id=span_id,
89
- root_span_id=trace_id
88
+ root_span_id=trace_id,
90
89
  )
91
90
 
92
91
  exported = components.to_str()
@@ -108,9 +107,9 @@ class TestSpanComponentsV4:
108
107
  components = SpanComponentsV4(
109
108
  object_type=SpanObjectTypeV3.PROJECT_LOGS,
110
109
  object_id=uuid_object_id,
111
- row_id='test-row-id',
110
+ row_id="test-row-id",
112
111
  span_id=uuid_span_id,
113
- root_span_id=uuid_root_span_id
112
+ root_span_id=uuid_root_span_id,
114
113
  )
115
114
 
116
115
  exported = components.to_str()
@@ -133,9 +132,9 @@ class TestSpanComponentsV4:
133
132
  components = SpanComponentsV4(
134
133
  object_type=SpanObjectTypeV3.EXPERIMENT,
135
134
  object_id=uuid_object_id,
136
- row_id='test-row-id',
135
+ row_id="test-row-id",
137
136
  span_id=hex_span_id,
138
- root_span_id=hex_trace_id
137
+ root_span_id=hex_trace_id,
139
138
  )
140
139
 
141
140
  exported = components.to_str()
@@ -162,10 +161,10 @@ class TestSpanComponentsV4:
162
161
  # Create equivalent Python object
163
162
  py_components = SpanComponentsV4(
164
163
  object_type=SpanObjectTypeV3.EXPERIMENT,
165
- object_id='js-test-experiment-id',
166
- row_id='js-test-row-id',
167
- span_id='abcdef1234567890',
168
- root_span_id='fedcba0987654321fedcba0987654321'
164
+ object_id="js-test-experiment-id",
165
+ row_id="js-test-row-id",
166
+ span_id="abcdef1234567890",
167
+ root_span_id="fedcba0987654321fedcba0987654321",
169
168
  )
170
169
 
171
170
  # Python should generate the same slug
@@ -184,8 +183,8 @@ class TestSpanComponentsV4:
184
183
  """Test V4 with additional metadata."""
185
184
  components = SpanComponentsV4(
186
185
  object_type=SpanObjectTypeV3.PLAYGROUND_LOGS,
187
- object_id='test-session-id',
188
- propagated_event={"user": "test", "data": [1, 2, 3]}
186
+ object_id="test-session-id",
187
+ propagated_event={"user": "test", "data": [1, 2, 3]},
189
188
  )
190
189
 
191
190
  exported = components.to_str()
@@ -199,14 +198,14 @@ class TestSpanComponentsV4:
199
198
  """Test that non-UUID/hex strings are stored in JSON portion."""
200
199
  components = SpanComponentsV4(
201
200
  object_type=SpanObjectTypeV3.PROJECT_LOGS,
202
- object_id='not-a-uuid-or-hex', # Will be stored in JSON
201
+ object_id="not-a-uuid-or-hex", # Will be stored in JSON
203
202
  # Don't test row_id alone - if present, span_id and root_span_id must also be present
204
203
  )
205
204
 
206
205
  exported = components.to_str()
207
206
  imported = SpanComponentsV4.from_str(exported)
208
207
 
209
- assert imported.object_id == 'not-a-uuid-or-hex'
208
+ assert imported.object_id == "not-a-uuid-or-hex"
210
209
 
211
210
 
212
211
  class TestBackwardCompatibility:
@@ -221,7 +220,7 @@ class TestBackwardCompatibility:
221
220
  row_id=str(uuid4()),
222
221
  span_id=str(uuid4()),
223
222
  root_span_id=str(uuid4()),
224
- propagated_event={"version": "v3"}
223
+ propagated_event={"version": "v3"},
225
224
  )
226
225
 
227
226
  # Serialize with V3
@@ -238,7 +237,6 @@ class TestBackwardCompatibility:
238
237
  assert v4_imported.propagated_event == v3_components.propagated_event
239
238
 
240
239
 
241
-
242
240
  class TestErrorHandling:
243
241
  """Test error handling and edge cases."""
244
242
 
@@ -247,7 +245,7 @@ class TestErrorHandling:
247
245
  with pytest.raises(AssertionError):
248
246
  SpanComponentsV4(
249
247
  object_type="invalid_type", # Should be SpanObjectTypeV3 enum
250
- object_id="test-id"
248
+ object_id="test-id",
251
249
  )
252
250
 
253
251
  def test_missing_required_fields(self):
@@ -280,10 +278,7 @@ class TestErrorHandling:
280
278
  import base64
281
279
 
282
280
  # Create valid data then corrupt it
283
- components = SpanComponentsV4(
284
- object_type=SpanObjectTypeV3.PROJECT_LOGS,
285
- object_id="test-id"
286
- )
281
+ components = SpanComponentsV4(object_type=SpanObjectTypeV3.PROJECT_LOGS, object_id="test-id")
287
282
  valid_exported = components.to_str()
288
283
 
289
284
  # Decode, corrupt, re-encode
@@ -302,30 +297,21 @@ class TestObjectIdFields:
302
297
 
303
298
  def test_experiment_object_id_fields(self):
304
299
  """Test object_id_fields for experiment type."""
305
- components = SpanComponentsV4(
306
- object_type=SpanObjectTypeV3.EXPERIMENT,
307
- object_id="test-experiment-id"
308
- )
300
+ components = SpanComponentsV4(object_type=SpanObjectTypeV3.EXPERIMENT, object_id="test-experiment-id")
309
301
 
310
302
  fields = components.object_id_fields()
311
303
  assert fields == {"experiment_id": "test-experiment-id"}
312
304
 
313
305
  def test_project_logs_object_id_fields(self):
314
306
  """Test object_id_fields for project_logs type."""
315
- components = SpanComponentsV4(
316
- object_type=SpanObjectTypeV3.PROJECT_LOGS,
317
- object_id="test-project-id"
318
- )
307
+ components = SpanComponentsV4(object_type=SpanObjectTypeV3.PROJECT_LOGS, object_id="test-project-id")
319
308
 
320
309
  fields = components.object_id_fields()
321
310
  assert fields == {"project_id": "test-project-id", "log_id": "g"}
322
311
 
323
312
  def test_playground_logs_object_id_fields(self):
324
313
  """Test object_id_fields for playground_logs type."""
325
- components = SpanComponentsV4(
326
- object_type=SpanObjectTypeV3.PLAYGROUND_LOGS,
327
- object_id="test-session-id"
328
- )
314
+ components = SpanComponentsV4(object_type=SpanObjectTypeV3.PLAYGROUND_LOGS, object_id="test-session-id")
329
315
 
330
316
  fields = components.object_id_fields()
331
317
  assert fields == {"prompt_session_id": "test-session-id", "log_id": "x"}
@@ -333,8 +319,7 @@ class TestObjectIdFields:
333
319
  def test_object_id_fields_without_object_id(self):
334
320
  """Test that object_id_fields raises error without object_id."""
335
321
  components = SpanComponentsV4(
336
- object_type=SpanObjectTypeV3.PROJECT_LOGS,
337
- compute_object_metadata_args={"key": "value"}
322
+ object_type=SpanObjectTypeV3.PROJECT_LOGS, compute_object_metadata_args={"key": "value"}
338
323
  )
339
324
 
340
325
  with pytest.raises(Exception) as exc_info:
braintrust/util.py CHANGED
@@ -2,8 +2,9 @@ import inspect
2
2
  import sys
3
3
  import threading
4
4
  import urllib.parse
5
+ from collections.abc import Callable, Mapping
5
6
  from dataclasses import dataclass
6
- from typing import Any, Callable, Dict, Generic, Literal, Mapping, Optional, Set, Tuple, TypedDict, TypeVar, Union
7
+ from typing import Any, Generic, Literal, TypedDict, TypeVar, Union
7
8
 
8
9
  from requests import HTTPError, Response
9
10
 
@@ -29,8 +30,8 @@ def coalesce(*args):
29
30
 
30
31
 
31
32
  def merge_dicts_with_paths(
32
- merge_into: Dict[str, Any], merge_from: Mapping[str, Any], path: Tuple[str, ...], merge_paths: Set[Tuple[str]]
33
- ) -> Dict[str, Any]:
33
+ merge_into: dict[str, Any], merge_from: Mapping[str, Any], path: tuple[str, ...], merge_paths: set[tuple[str]]
34
+ ) -> dict[str, Any]:
34
35
  """Merges merge_from into merge_into, destructively updating merge_into. Does not merge any further than
35
36
  merge_paths."""
36
37
 
@@ -50,7 +51,7 @@ def merge_dicts_with_paths(
50
51
  return merge_into
51
52
 
52
53
 
53
- def merge_dicts(merge_into: Dict[str, Any], merge_from: Mapping[str, Any]) -> Dict[str, Any]:
54
+ def merge_dicts(merge_into: dict[str, Any], merge_from: Mapping[str, Any]) -> dict[str, Any]:
54
55
  """Merges merge_from into merge_into, destructively updating merge_into."""
55
56
 
56
57
  return merge_dicts_with_paths(merge_into, merge_from, (), set())
@@ -92,7 +93,7 @@ class CallerLocation(TypedDict):
92
93
  caller_lineno: int
93
94
 
94
95
 
95
- def get_caller_location() -> Optional[CallerLocation]:
96
+ def get_caller_location() -> CallerLocation | None:
96
97
  frame = inspect.currentframe()
97
98
  while frame:
98
99
  frame = frame.f_back
@@ -145,7 +146,7 @@ class LazyValue(Generic[T]):
145
146
  return self._state.has_succeeded
146
147
 
147
148
  @property
148
- def value(self) -> Optional[T]:
149
+ def value(self) -> T | None:
149
150
  return self._state.value if self._state.has_succeeded == True else None
150
151
 
151
152
  def get(self) -> T:
@@ -167,7 +168,7 @@ class LazyValue(Generic[T]):
167
168
  if self.mutex:
168
169
  self.mutex.release()
169
170
 
170
- def get_sync(self) -> Tuple[bool, Optional[T]]:
171
+ def get_sync(self) -> tuple[bool, T | None]:
171
172
  """Returns a tuple of (has_succeeded, value) without triggering evaluation."""
172
173
  if self._state.has_succeeded:
173
174
  # should be fine without the mutex check
@@ -206,7 +207,7 @@ def bt_iscoroutinefunction(f):
206
207
  return inspect.iscoroutinefunction(f) or inspect.isasyncgenfunction(f) or getattr(f, BT_IS_ASYNC_ATTRIBUTE, False)
207
208
 
208
209
 
209
- def add_azure_blob_headers(headers: Dict[str, str], url: str) -> None:
210
+ def add_azure_blob_headers(headers: dict[str, str], url: str) -> None:
210
211
  # According to https://stackoverflow.com/questions/37824136/put-on-sas-blob-url-without-specifying-x-ms-blob-type-header,
211
212
  # there is no way to avoid including this.
212
213
  if "blob.core.windows.net" in url:
braintrust/version.py CHANGED
@@ -1,4 +1,4 @@
1
- VERSION = "0.3.14"
1
+ VERSION = "0.4.0"
2
2
 
3
3
  # this will be templated during the build
4
- GIT_COMMIT = "dbbc1894ef31143816e5913676301261bc44aa4c"
4
+ GIT_COMMIT = "8ab13f3f48af6a4d3c0b053e4bbabfd4f24f23ec"
@@ -1,6 +1,6 @@
1
1
  """Shared utilities for Anthropic API wrappers."""
2
2
 
3
- from typing import Any, Dict
3
+ from typing import Any
4
4
 
5
5
 
6
6
  class Wrapper:
@@ -13,7 +13,7 @@ class Wrapper:
13
13
  return getattr(self.__wrapped, name)
14
14
 
15
15
 
16
- def extract_anthropic_usage(usage: Any) -> Dict[str, float]:
16
+ def extract_anthropic_usage(usage: Any) -> dict[str, float]:
17
17
  """Extract and normalize usage metrics from Anthropic usage object or dict.
18
18
 
19
19
  Converts Anthropic's usage format to Braintrust's standard token metric names.
@@ -29,7 +29,7 @@ def extract_anthropic_usage(usage: Any) -> Dict[str, float]:
29
29
  - prompt_cached_tokens (from cache_read_input_tokens)
30
30
  - prompt_cache_creation_tokens (from cache_creation_input_tokens)
31
31
  """
32
- metrics: Dict[str, float] = {}
32
+ metrics: dict[str, float] = {}
33
33
 
34
34
  if not usage:
35
35
  return metrics
@@ -73,7 +73,7 @@ def extract_anthropic_usage(usage: Any) -> Dict[str, float]:
73
73
  return metrics
74
74
 
75
75
 
76
- def finalize_anthropic_tokens(metrics: Dict[str, float]) -> Dict[str, float]:
76
+ def finalize_anthropic_tokens(metrics: dict[str, float]) -> dict[str, float]:
77
77
  """Finalize Anthropic token calculations.
78
78
 
79
79
  Anthropic doesn't include cache tokens in the total, so we need to sum them.
@@ -21,7 +21,6 @@ Usage:
21
21
  __all__ = ["setup_agno", "wrap_agent", "wrap_function_call", "wrap_model", "wrap_team"]
22
22
 
23
23
  import logging
24
- from typing import Optional
25
24
 
26
25
  from braintrust.logger import NOOP_SPAN, current_span, init_logger
27
26
 
@@ -34,9 +33,9 @@ logger = logging.getLogger(__name__)
34
33
 
35
34
 
36
35
  def setup_agno(
37
- api_key: Optional[str] = None,
38
- project_id: Optional[str] = None,
39
- project_name: Optional[str] = None,
36
+ api_key: str | None = None,
37
+ project_id: str | None = None,
38
+ project_name: str | None = None,
40
39
  ) -> bool:
41
40
  """
42
41
  Setup Braintrust integration with Agno. Will automatically patch Agno agents, models, and function calls for tracing.
@@ -1,10 +1,9 @@
1
1
  import time
2
2
  from typing import Any
3
3
 
4
- from wrapt import wrap_function_wrapper
5
-
6
4
  from braintrust.logger import start_span
7
5
  from braintrust.span_types import SpanTypeAttribute
6
+ from wrapt import wrap_function_wrapper
8
7
 
9
8
  from .utils import (
10
9
  _aggregate_agent_chunks,
@@ -1,9 +1,8 @@
1
1
  from typing import Any
2
2
 
3
- from wrapt import wrap_function_wrapper
4
-
5
3
  from braintrust.logger import start_span
6
4
  from braintrust.span_types import SpanTypeAttribute
5
+ from wrapt import wrap_function_wrapper
7
6
 
8
7
  from .utils import is_patched
9
8
 
@@ -5,10 +5,9 @@ ModelWrapper class for Braintrust-Agno model observability.
5
5
  import time
6
6
  from typing import Any
7
7
 
8
- from wrapt import wrap_function_wrapper
9
-
10
8
  from braintrust.logger import start_span
11
9
  from braintrust.span_types import SpanTypeAttribute
10
+ from wrapt import wrap_function_wrapper
12
11
 
13
12
  from .utils import (
14
13
  _aggregate_model_chunks,
@@ -1,10 +1,9 @@
1
1
  import time
2
2
  from typing import Any
3
3
 
4
- from wrapt import wrap_function_wrapper
5
-
6
4
  from braintrust.logger import start_span
7
5
  from braintrust.span_types import SpanTypeAttribute
6
+ from wrapt import wrap_function_wrapper
8
7
 
9
8
  from .utils import (
10
9
  _aggregate_agent_chunks,
@@ -1,8 +1,8 @@
1
1
  import time
2
- from typing import Any, Dict, List, Optional
2
+ from typing import Any
3
3
 
4
4
 
5
- def omit(obj: Dict[str, Any], keys: List[str]):
5
+ def omit(obj: dict[str, Any], keys: list[str]):
6
6
  return {k: v for k, v in obj.items() if k not in keys}
7
7
 
8
8
 
@@ -14,11 +14,11 @@ def mark_patched(obj: Any):
14
14
  setattr(obj, "_braintrust_patched", True)
15
15
 
16
16
 
17
- def clean(obj: Dict[str, Any]) -> Dict[str, Any]:
17
+ def clean(obj: dict[str, Any]) -> dict[str, Any]:
18
18
  return {k: v for k, v in obj.items() if v is not None}
19
19
 
20
20
 
21
- def get_args_kwargs(args: List[str], kwargs: Dict[str, Any], keys: List[str]):
21
+ def get_args_kwargs(args: list[str], kwargs: dict[str, Any], keys: list[str]):
22
22
  return {k: args[i] if args else kwargs.get(k) for i, k in enumerate(keys)}, omit(kwargs, keys)
23
23
 
24
24
 
@@ -71,7 +71,7 @@ AGNO_METRICS_MAP = {
71
71
  }
72
72
 
73
73
 
74
- def extract_metadata(instance: Any, component: str) -> Dict[str, Any]:
74
+ def extract_metadata(instance: Any, component: str) -> dict[str, Any]:
75
75
  """Extract metadata from any component (model, agent, team)."""
76
76
  metadata = {"component": component}
77
77
 
@@ -100,7 +100,7 @@ def extract_metadata(instance: Any, component: str) -> Dict[str, Any]:
100
100
  return metadata
101
101
 
102
102
 
103
- def parse_metrics_from_agno(usage: Any) -> Dict[str, Any]:
103
+ def parse_metrics_from_agno(usage: Any) -> dict[str, Any]:
104
104
  """Parse metrics from Agno usage object, following OpenAI wrapper pattern."""
105
105
  metrics = {}
106
106
 
@@ -121,7 +121,7 @@ def parse_metrics_from_agno(usage: Any) -> Dict[str, Any]:
121
121
  return metrics
122
122
 
123
123
 
124
- def extract_metrics(result: Any, messages: Optional[list] = None) -> Dict[str, Any]:
124
+ def extract_metrics(result: Any, messages: list | None = None) -> dict[str, Any]:
125
125
  """
126
126
  Unified metrics extraction for all components.
127
127
 
@@ -163,7 +163,7 @@ def extract_metrics(result: Any, messages: Optional[list] = None) -> Dict[str, A
163
163
  return {}
164
164
 
165
165
 
166
- def extract_streaming_metrics(aggregated: Dict[str, Any], start_time: float) -> Optional[Dict[str, Any]]:
166
+ def extract_streaming_metrics(aggregated: dict[str, Any], start_time: float) -> dict[str, Any] | None:
167
167
  """Extract metrics from aggregated streaming response."""
168
168
  metrics = {}
169
169
 
@@ -187,7 +187,7 @@ def extract_streaming_metrics(aggregated: Dict[str, Any], start_time: float) ->
187
187
  return metrics if metrics else None
188
188
 
189
189
 
190
- def _aggregate_metrics(target: Dict[str, Any], source: Dict[str, Any]) -> None:
190
+ def _aggregate_metrics(target: dict[str, Any], source: dict[str, Any]) -> None:
191
191
  """Aggregate metrics from source into target dict."""
192
192
  for key, value in source.items():
193
193
  if _is_numeric(value):
@@ -205,7 +205,7 @@ def _aggregate_metrics(target: Dict[str, Any], source: Dict[str, Any]) -> None:
205
205
  target[key] = value
206
206
 
207
207
 
208
- def _aggregate_model_chunks(chunks: List[Any]) -> Dict[str, Any]:
208
+ def _aggregate_model_chunks(chunks: list[Any]) -> dict[str, Any]:
209
209
  """Aggregate ModelResponse chunks from invoke_stream into a complete response."""
210
210
  aggregated = {
211
211
  "content": "",
@@ -263,7 +263,7 @@ def _aggregate_model_chunks(chunks: List[Any]) -> Dict[str, Any]:
263
263
  return aggregated
264
264
 
265
265
 
266
- def _aggregate_response_stream_chunks(chunks: List[Any]) -> Dict[str, Any]:
266
+ def _aggregate_response_stream_chunks(chunks: list[Any]) -> dict[str, Any]:
267
267
  """
268
268
  Aggregate chunks from response_stream which can be ModelResponse, RunOutputEvent, or TeamRunOutputEvent.
269
269
 
@@ -344,7 +344,7 @@ def _aggregate_response_stream_chunks(chunks: List[Any]) -> Dict[str, Any]:
344
344
  return aggregated
345
345
 
346
346
 
347
- def _aggregate_agent_chunks(chunks: List[Any]) -> Dict[str, Any]:
347
+ def _aggregate_agent_chunks(chunks: list[Any]) -> dict[str, Any]:
348
348
  """Aggregate BaseAgentRunEvent/BaseTeamRunEvent chunks into a complete response."""
349
349
  aggregated = {
350
350
  "content": "",
@@ -2,7 +2,6 @@ import logging
2
2
  import time
3
3
  import warnings
4
4
  from contextlib import contextmanager
5
- from typing import Optional
6
5
 
7
6
  from braintrust.logger import NOOP_SPAN, log_exc_info_to_span, start_span
8
7
  from braintrust.wrappers._anthropic_utils import Wrapper, extract_anthropic_usage, finalize_anthropic_tokens
@@ -10,7 +9,6 @@ from braintrust.wrappers._anthropic_utils import Wrapper, extract_anthropic_usag
10
9
  log = logging.getLogger(__name__)
11
10
 
12
11
 
13
-
14
12
  # This tracer depends on an internal anthropic method used to merge
15
13
  # streamed messages together. It's a bit tricky so I'm opting to use it
16
14
  # here. If it goes away, this polyfill will make it a no-op and the only
@@ -242,7 +240,7 @@ class TracedMessageStream(Wrapper):
242
240
  self.__metrics = {}
243
241
  self.__snapshot = None
244
242
  self.__request_start_time = request_start_time
245
- self.__time_to_first_token: Optional[float] = None
243
+ self.__time_to_first_token: float | None = None
246
244
 
247
245
  def _get_final_traced_message(self):
248
246
  return self.__snapshot
@@ -314,7 +312,7 @@ def _start_span(name, kwargs):
314
312
  return NOOP_SPAN
315
313
 
316
314
 
317
- def _log_message_to_span(message, span, time_to_first_token: Optional[float] = None):
315
+ def _log_message_to_span(message, span, time_to_first_token: float | None = None):
318
316
  """Log telemetry from the given anthropic.Message to the given span."""
319
317
  with _catch_exceptions():
320
318
  usage = getattr(message, "usage", {})
@@ -326,13 +324,14 @@ def _log_message_to_span(message, span, time_to_first_token: Optional[float] = N
326
324
 
327
325
  # Create output dict with only truthy values for role and content
328
326
  output = {
329
- k: v for k, v in {
330
- "role": getattr(message, "role", None),
331
- "content": getattr(message, "content", None)
332
- }.items() if v
327
+ k: v
328
+ for k, v in {"role": getattr(message, "role", None), "content": getattr(message, "content", None)}.items()
329
+ if v
333
330
  } or None
334
331
 
335
332
  span.log(output=output, metrics=metrics)
333
+
334
+
336
335
  @contextmanager
337
336
  def _catch_exceptions():
338
337
  try:
@@ -16,7 +16,6 @@ Usage (imports can be before or after setup):
16
16
  """
17
17
 
18
18
  import logging
19
- from typing import Optional
20
19
 
21
20
  from braintrust.logger import NOOP_SPAN, current_span, init_logger
22
21
 
@@ -28,9 +27,9 @@ __all__ = ["setup_claude_agent_sdk"]
28
27
 
29
28
 
30
29
  def setup_claude_agent_sdk(
31
- api_key: Optional[str] = None,
32
- project_id: Optional[str] = None,
33
- project: Optional[str] = None,
30
+ api_key: str | None = None,
31
+ project_id: str | None = None,
32
+ project: str | None = None,
34
33
  ) -> bool:
35
34
  """
36
35
  Setup Braintrust integration with Claude Agent SDK. Will automatically patch the SDK for automatic tracing.