braintrust 0.3.15__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82)
  1. braintrust/_generated_types.py +737 -672
  2. braintrust/audit.py +2 -2
  3. braintrust/bt_json.py +178 -19
  4. braintrust/cli/eval.py +6 -7
  5. braintrust/cli/push.py +11 -11
  6. braintrust/context.py +12 -17
  7. braintrust/contrib/temporal/__init__.py +16 -27
  8. braintrust/contrib/temporal/test_temporal.py +8 -3
  9. braintrust/devserver/auth.py +8 -8
  10. braintrust/devserver/cache.py +3 -4
  11. braintrust/devserver/cors.py +8 -7
  12. braintrust/devserver/dataset.py +3 -5
  13. braintrust/devserver/eval_hooks.py +7 -6
  14. braintrust/devserver/schemas.py +22 -19
  15. braintrust/devserver/server.py +19 -12
  16. braintrust/devserver/test_cached_login.py +4 -4
  17. braintrust/framework.py +139 -142
  18. braintrust/framework2.py +88 -87
  19. braintrust/functions/invoke.py +66 -59
  20. braintrust/functions/stream.py +3 -2
  21. braintrust/generated_types.py +3 -1
  22. braintrust/git_fields.py +11 -11
  23. braintrust/gitutil.py +2 -3
  24. braintrust/graph_util.py +10 -10
  25. braintrust/id_gen.py +2 -2
  26. braintrust/logger.py +373 -471
  27. braintrust/merge_row_batch.py +10 -9
  28. braintrust/oai.py +21 -20
  29. braintrust/otel/__init__.py +49 -49
  30. braintrust/otel/context.py +16 -30
  31. braintrust/otel/test_distributed_tracing.py +14 -11
  32. braintrust/otel/test_otel_bt_integration.py +32 -31
  33. braintrust/parameters.py +8 -8
  34. braintrust/prompt.py +14 -14
  35. braintrust/prompt_cache/disk_cache.py +5 -4
  36. braintrust/prompt_cache/lru_cache.py +3 -2
  37. braintrust/prompt_cache/prompt_cache.py +13 -14
  38. braintrust/queue.py +4 -4
  39. braintrust/score.py +4 -4
  40. braintrust/serializable_data_class.py +4 -4
  41. braintrust/span_identifier_v1.py +1 -2
  42. braintrust/span_identifier_v2.py +3 -4
  43. braintrust/span_identifier_v3.py +23 -20
  44. braintrust/span_identifier_v4.py +34 -25
  45. braintrust/test_bt_json.py +644 -0
  46. braintrust/test_framework.py +72 -6
  47. braintrust/test_helpers.py +5 -5
  48. braintrust/test_id_gen.py +2 -3
  49. braintrust/test_logger.py +211 -107
  50. braintrust/test_otel.py +61 -53
  51. braintrust/test_queue.py +0 -1
  52. braintrust/test_score.py +1 -3
  53. braintrust/test_span_components.py +29 -44
  54. braintrust/util.py +9 -8
  55. braintrust/version.py +2 -2
  56. braintrust/wrappers/_anthropic_utils.py +4 -4
  57. braintrust/wrappers/agno/__init__.py +3 -4
  58. braintrust/wrappers/agno/agent.py +1 -2
  59. braintrust/wrappers/agno/function_call.py +1 -2
  60. braintrust/wrappers/agno/model.py +1 -2
  61. braintrust/wrappers/agno/team.py +1 -2
  62. braintrust/wrappers/agno/utils.py +12 -12
  63. braintrust/wrappers/anthropic.py +7 -8
  64. braintrust/wrappers/claude_agent_sdk/__init__.py +3 -4
  65. braintrust/wrappers/claude_agent_sdk/_wrapper.py +29 -27
  66. braintrust/wrappers/dspy.py +15 -17
  67. braintrust/wrappers/google_genai/__init__.py +17 -30
  68. braintrust/wrappers/langchain.py +22 -24
  69. braintrust/wrappers/litellm.py +4 -3
  70. braintrust/wrappers/openai.py +15 -15
  71. braintrust/wrappers/pydantic_ai.py +225 -110
  72. braintrust/wrappers/test_agno.py +0 -1
  73. braintrust/wrappers/test_dspy.py +0 -1
  74. braintrust/wrappers/test_google_genai.py +64 -4
  75. braintrust/wrappers/test_litellm.py +0 -1
  76. braintrust/wrappers/test_pydantic_ai_integration.py +819 -22
  77. {braintrust-0.3.15.dist-info → braintrust-0.4.1.dist-info}/METADATA +3 -2
  78. braintrust-0.4.1.dist-info/RECORD +121 -0
  79. braintrust-0.3.15.dist-info/RECORD +0 -120
  80. {braintrust-0.3.15.dist-info → braintrust-0.4.1.dist-info}/WHEEL +0 -0
  81. {braintrust-0.3.15.dist-info → braintrust-0.4.1.dist-info}/entry_points.txt +0 -0
  82. {braintrust-0.3.15.dist-info → braintrust-0.4.1.dist-info}/top_level.txt +0 -0
braintrust/audit.py CHANGED
@@ -5,7 +5,7 @@ Utilities for working with audit headers.
  import base64
  import gzip
  import json
- from typing import List, TypedDict
+ from typing import TypedDict


  class AuditResource(TypedDict):
@@ -14,7 +14,7 @@ class AuditResource(TypedDict):
      name: str


- def parse_audit_resources(hdr: str) -> List[AuditResource]:
+ def parse_audit_resources(hdr: str) -> list[AuditResource]:
      j = json.loads(hdr)
      if j["v"] == 1:
          return json.loads(gzip.decompress(base64.b64decode(j["p"])))
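The change to `parse_audit_resources` is only a type-annotation update, but the body visible above also documents the header format: a JSON envelope `{"v": 1, "p": ...}` whose `p` field is the base64-encoded, gzip-compressed JSON array of resources. A minimal round-trip sketch, assuming the function is importable from `braintrust.audit` as the file path suggests (the resource dict is illustrative; the full `AuditResource` shape is not visible in this hunk):

```python
import base64
import gzip
import json

from braintrust.audit import parse_audit_resources

# Build a v1 audit header: "p" carries the base64-encoded,
# gzip-compressed JSON list of resources.
resources = [{"name": "my-project"}]  # illustrative; AuditResource has more fields than shown above
payload = base64.b64encode(gzip.compress(json.dumps(resources).encode())).decode()
header = json.dumps({"v": 1, "p": payload})

assert parse_audit_resources(header) == resources
```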
braintrust/bt_json.py CHANGED
@@ -1,6 +1,7 @@
  import dataclasses
  import json
- from typing import Any, cast
+ import math
+ from typing import Any, Callable, Mapping, NamedTuple, cast, overload

  # Try to import orjson for better performance
  # If not available, we'll use standard json
@@ -12,39 +13,184 @@ except ImportError:
      _HAS_ORJSON = False


- def _to_dict(obj: Any) -> Any:
-     """
-     Function-based default handler for non-JSON-serializable objects.

-     Handles:
-     - dataclasses
-     - Pydantic v2 BaseModel
-     - Pydantic v1 BaseModel
-     - Falls back to str() for unknown types
+ def _to_bt_safe(v: Any) -> Any:
+     """
+     Converts the object to a Braintrust-safe representation (i.e. Attachment objects are safe (specially handled by background logger)).
      """
-     if dataclasses.is_dataclass(obj) and not isinstance(obj, type):
-         return dataclasses.asdict(obj)
+     # avoid circular imports
+     from braintrust.logger import BaseAttachment, Dataset, Experiment, Logger, ReadonlyAttachment, Span
+
+     if isinstance(v, Span):
+         return "<span>"
+
+     if isinstance(v, Experiment):
+         return "<experiment>"
+
+     if isinstance(v, Dataset):
+         return "<dataset>"
+
+     if isinstance(v, Logger):
+         return "<logger>"
+
+     if isinstance(v, BaseAttachment):
+         return v
+
+     if isinstance(v, ReadonlyAttachment):
+         return v.reference
+
+     if dataclasses.is_dataclass(v) and not isinstance(v, type):
+         # Use manual field iteration instead of dataclasses.asdict() because
+         # asdict() deep-copies values, which breaks objects like Attachment
+         # that contain non-copyable items (thread locks, file handles, etc.)
+         return {f.name: _to_bt_safe(getattr(v, f.name)) for f in dataclasses.fields(v)}
+
+     # Pydantic model classes (not instances) with model_json_schema
+     if isinstance(v, type) and hasattr(v, "model_json_schema") and callable(cast(Any, v).model_json_schema):
+         try:
+             return cast(Any, v).model_json_schema()
+         except Exception:
+             pass

      # Attempt to dump a Pydantic v2 `BaseModel`.
      try:
-         return cast(Any, obj).model_dump()
+         return cast(Any, v).model_dump(exclude_none=True)
      except (AttributeError, TypeError):
          pass

      # Attempt to dump a Pydantic v1 `BaseModel`.
      try:
-         return cast(Any, obj).dict()
+         return cast(Any, v).dict(exclude_none=True)
      except (AttributeError, TypeError):
          pass

-     # When everything fails, try to return the string representation of the object
+     if isinstance(v, float):
+         # Handle NaN and Infinity for JSON compatibility
+         if math.isnan(v):
+             return "NaN"
+
+         if math.isinf(v):
+             return "Infinity" if v > 0 else "-Infinity"
+
+         return v
+
+     if isinstance(v, (int, str, bool)) or v is None:
+         # Skip roundtrip for primitive types.
+         return v
+
+     # Note: we avoid using copy.deepcopy, because it's difficult to
+     # guarantee the independence of such copied types from their origin.
+     # E.g. the original type could have a `__del__` method that alters
+     # some shared internal state, and we need this deep copy to be
+     # fully-independent from the original.
+
+     # We pass `encoder=_str_encoder` since we've already tried converting rich objects to json safe objects.
+     return bt_loads(bt_dumps(v, encoder=_str_encoder))
+
+ @overload
+ def bt_safe_deep_copy(
+     obj: Mapping[str, Any],
+     max_depth: int = ...,
+ ) -> dict[str, Any]: ...
+
+ @overload
+ def bt_safe_deep_copy(
+     obj: list[Any],
+     max_depth: int = ...,
+ ) -> list[Any]: ...
+
+ @overload
+ def bt_safe_deep_copy(
+     obj: Any,
+     max_depth: int = ...,
+ ) -> Any: ...
+ def bt_safe_deep_copy(obj: Any, max_depth: int=200):
+     """
+     Creates a deep copy of the given object and converts rich objects to Braintrust-safe representations. See `_to_bt_safe` for more details.
+
+     Args:
+         obj: Object to deep copy and sanitize.
+         to_json_safe: Function to ensure the object is json safe.
+         max_depth: Maximum depth to copy.
+
+     Returns:
+         Deep copy of the object.
+     """
+     # Track visited objects to detect circular references
+     visited: set[int] = set()
+
+     def _deep_copy_object(v: Any, depth: int = 0) -> Any:
+         # Check depth limit - use >= to stop before exceeding
+         if depth >= max_depth:
+             return "<max depth exceeded>"
+
+         # Check for circular references in mutable containers
+         # Use id() to track object identity
+         if isinstance(v, (Mapping, list, tuple, set)):
+             obj_id = id(v)
+             if obj_id in visited:
+                 return "<circular reference>"
+             visited.add(obj_id)
+             try:
+                 if isinstance(v, Mapping):
+                     # Prevent dict keys from holding references to user data. Note that
+                     # `bt_json` already coerces keys to string, a behavior that comes from
+                     # `json.dumps`. However, that runs at log upload time, while we want to
+                     # cut out all the references to user objects synchronously in this
+                     # function.
+                     result = {}
+                     for k in v:
+                         try:
+                             key_str = str(k)
+                         except Exception:
+                             # If str() fails on the key, use a fallback representation
+                             key_str = f"<non-stringifiable-key: {type(k).__name__}>"
+                         result[key_str] = _deep_copy_object(v[k], depth + 1)
+                     return result
+                 elif isinstance(v, (list, tuple, set)):
+                     return [_deep_copy_object(x, depth + 1) for x in v]
+             finally:
+                 # Remove from visited set after processing to allow the same object
+                 # to appear in different branches of the tree
+                 visited.discard(obj_id)
+
+         try:
+             return _to_bt_safe(v)
+         except Exception:
+             return f"<non-sanitizable: {type(v).__name__}>"
+
+     return _deep_copy_object(obj)
+
+ def _safe_str(obj: Any) -> str:
      try:
          return str(obj)
      except Exception:
-         # If str() fails, return an error placeholder
          return f"<non-serializable: {type(obj).__name__}>"


+ def _to_json_safe(obj: Any) -> Any:
+     """
+     Handler for non-JSON-serializable objects. Returns a string representation of the object.
+     """
+     # avoid circular imports
+     from braintrust.logger import BaseAttachment
+
+     try:
+         v = _to_bt_safe(obj)
+
+         # JSON-safe representation of Attachment objects are their reference.
+         # If we get this object at this point, we have to assume someone has already uploaded the attachment!
+         if isinstance(v, BaseAttachment):
+             v = v.reference
+
+         return v
+     except Exception:
+         pass
+
+     # When everything fails, try to return the string representation of the object
+     return _safe_str(obj)
+
+
  class BraintrustJSONEncoder(json.JSONEncoder):
      """
      Custom JSON encoder for standard json library.
@@ -53,10 +199,22 @@ class BraintrustJSONEncoder(json.JSONEncoder):
      """

      def default(self, o: Any):
-         return _to_dict(o)
+         return _to_json_safe(o)
+
+
+ class BraintrustStrEncoder(json.JSONEncoder):
+     def default(self, o: Any):
+         return _safe_str(o)
+
+
+ class Encoder(NamedTuple):
+     native: type[json.JSONEncoder]
+     orjson: Callable[[Any], Any]

+ _json_encoder = Encoder(native=BraintrustJSONEncoder, orjson=_to_json_safe)
+ _str_encoder = Encoder(native=BraintrustStrEncoder, orjson=_safe_str)

- def bt_dumps(obj, **kwargs) -> str:
+ def bt_dumps(obj: Any, encoder: Encoder | None=_json_encoder, **kwargs: Any) -> str:
      """
      Serialize obj to a JSON-formatted string.

@@ -65,6 +223,7 @@ def bt_dumps(obj, **kwargs) -> str:

      Args:
          obj: Object to serialize
+         encoder: Encoder to use, defaults to `_default_encoder`
          **kwargs: Additional arguments (passed to json.dumps in fallback path)

      Returns:
@@ -76,7 +235,7 @@ def bt_dumps(obj, **kwargs) -> str:
              # pylint: disable=no-member # orjson is a C extension, pylint can't introspect it
              return orjson.dumps( # type: ignore[possibly-unbound]
                  obj,
-                 default=_to_dict,
+                 default=encoder.orjson if encoder else None,
                  # options match json.dumps behavior for bc
                  option=orjson.OPT_SORT_KEYS | orjson.OPT_SERIALIZE_NUMPY | orjson.OPT_NON_STR_KEYS, # type: ignore[possibly-unbound]
              ).decode("utf-8")
@@ -86,7 +245,7 @@ def bt_dumps(obj, **kwargs) -> str:

      # Use standard json (either orjson not available or it failed)
      # Use sort_keys=True for deterministic output (matches orjson OPT_SORT_KEYS)
-     return json.dumps(obj, cls=BraintrustJSONEncoder, allow_nan=False, sort_keys=True, **kwargs)
+     return json.dumps(obj, cls=encoder.native if encoder else None, allow_nan=False, sort_keys=True, **kwargs)


  def bt_loads(s: str, **kwargs) -> Any:
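The net effect of the bt_json rewrite: rich SDK objects (spans, experiments, loggers, attachments) get placeholder or reference representations, dataclasses are expanded field by field instead of via `dataclasses.asdict()`, non-finite floats become the strings `"NaN"`, `"Infinity"`, and `"-Infinity"`, and the new `bt_safe_deep_copy` cuts circular references and depth overruns instead of recursing. A usage sketch, assuming these helpers are importable from `braintrust.bt_json` as the file path suggests (exact output spacing depends on whether orjson is installed):

```python
import math
from dataclasses import dataclass

from braintrust.bt_json import bt_dumps, bt_loads, bt_safe_deep_copy


@dataclass
class Score:  # illustrative dataclass, not part of the SDK
    name: str
    value: float


# The dataclass is expanded field by field, and the non-finite float is
# stringified rather than rejected (the stdlib fallback uses allow_nan=False).
print(bt_dumps({"score": Score("accuracy", math.nan)}))
# -> {"score": {"name": "accuracy", "value": "NaN"}}  (spacing varies by backend)

# Circular references are replaced with a marker instead of recursing forever.
d: dict = {"x": 1}
d["self"] = d
print(bt_safe_deep_copy(d))
# -> {'x': 1, 'self': '<circular reference>'}

assert bt_loads(bt_dumps([1, 2, 3])) == [1, 2, 3]
```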
braintrust/cli/eval.py CHANGED
@@ -6,7 +6,6 @@ import os
  import sys
  from dataclasses import dataclass, field
  from threading import Lock
- from typing import Dict, List, Optional, Union

  from .. import login
  from ..framework import (
@@ -70,7 +69,7 @@ class EvaluatorOpts:
      no_progress_bars: bool
      terminate_on_failure: bool
      watch: bool
-     filters: List[str]
+     filters: list[str]
      list: bool
      jsonl: bool

@@ -79,13 +78,13 @@ class EvaluatorOpts:
  class LoadedEvaluator:
      handle: FileHandle
      evaluator: Evaluator
-     reporter: Optional[Union[ReporterDef, str]] = None
+     reporter: ReporterDef | str | None = None


  @dataclass
  class EvaluatorState:
-     evaluators: List[LoadedEvaluator] = field(default_factory=list)
-     reporters: Dict[str, ReporterDef] = field(default_factory=dict)
+     evaluators: list[LoadedEvaluator] = field(default_factory=list)
+     reporters: dict[str, ReporterDef] = field(default_factory=dict)


  def update_evaluators(eval_state: EvaluatorState, handles, terminate_on_failure):
@@ -157,7 +156,7 @@ async def run_evaluator_task(evaluator, position, opts: EvaluatorOpts):
      experiment.flush()


- def resolve_reporter(reporter: Optional[Union[ReporterDef, str]], reporters: Dict[str, ReporterDef]) -> ReporterDef:
+ def resolve_reporter(reporter: ReporterDef | str | None, reporters: dict[str, ReporterDef]) -> ReporterDef:
      if isinstance(reporter, str):
          if reporter not in reporters:
              raise ValueError(f"Reporter {reporter} not found")
@@ -179,7 +178,7 @@ def add_report(eval_reports, reporter, report):
      eval_reports[reporter.name]["results"].append(report)


- async def run_once(handles: List[FileHandle], evaluator_opts: EvaluatorOpts) -> bool:
+ async def run_once(handles: list[FileHandle], evaluator_opts: EvaluatorOpts) -> bool:
      objects = EvaluatorState()
      update_evaluators(objects, handles, terminate_on_failure=evaluator_opts.terminate_on_failure)

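The CLI changes in this release are almost entirely mechanical typing modernization: `typing.List`/`Dict`/`Optional`/`Union` are replaced by the built-in generics of PEP 585 and the `X | None` unions of PEP 604. A short before/after sketch of the pattern (the class name is a stand-in, not the real `EvaluatorState`):

```python
from dataclasses import dataclass, field


# Requires Python 3.10+ for `X | None` in evaluated annotations
# (or `from __future__ import annotations`).
@dataclass
class StateExample:  # stand-in for dataclasses like EvaluatorOpts/EvaluatorState
    # Before: filters: List[str], reporters: Dict[str, str], reporter: Optional[str]
    filters: list[str] = field(default_factory=list)
    reporters: dict[str, str] = field(default_factory=dict)
    reporter: str | None = None
```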
braintrust/cli/push.py CHANGED
@@ -13,7 +13,7 @@ import sys
  import tempfile
  import textwrap
  import zipfile
- from typing import Any, Dict, List, Optional
+ from typing import Any

  import requests
  from braintrust.framework import _set_lazy_load
@@ -24,7 +24,7 @@ from ..generated_types import IfExists
  from ..util import add_azure_blob_headers


- def _pkg_install_arg(pkg) -> Optional[str]:
+ def _pkg_install_arg(pkg) -> str | None:
      try:
          dist = importlib.metadata.distribution(pkg)
          direct_url = dist._path / "direct_url.json" # type: ignore
@@ -73,11 +73,11 @@ class _ProjectRootImporter(importlib.abc.MetaPathFinder):
          self._project_root, self._path_rest = sys.path[0], sys.path[1:]
          self._sources = []

-     def _under_project_root(self, path: List[str]) -> bool:
+     def _under_project_root(self, path: list[str]) -> bool:
          """Returns true if all paths in `path` are under the project root."""
          return all(p.startswith(self._project_root) for p in path)

-     def _under_rest(self, path: List[str]) -> bool:
+     def _under_rest(self, path: list[str]) -> bool:
          """Returns true if any path in `path` is under one of the remaining paths in `sys.path`."""
          return any(p.startswith(pr) for p in path for pr in self._path_rest)

@@ -94,11 +94,11 @@ class _ProjectRootImporter(importlib.abc.MetaPathFinder):
              self._sources.append(spec.origin)
          return spec

-     def sources(self) -> List[str]:
+     def sources(self) -> list[str]:
          return self._sources


- def _import_module(name: str, path: str) -> List[str]:
+ def _import_module(name: str, path: str) -> list[str]:
      """Imports the module and returns the list of source files
      of all modules imported in the process.

@@ -120,7 +120,7 @@ def _py_version() -> str:
      return f"{sys.version_info.major}.{sys.version_info.minor}"


- def _run_install(install_args: List[str], packages_dir: str):
+ def _run_install(install_args: list[str], packages_dir: str):
      subprocess.run(
          [
              "uv",
@@ -138,7 +138,7 @@ def _run_install(install_args: List[str], packages_dir: str):
      )


- def _upload_bundle(entry_module_name: str, sources: List[str], requirements: Optional[str]) -> str:
+ def _upload_bundle(entry_module_name: str, sources: list[str], requirements: str | None) -> str:
      _check_uv()

      resp = proxy_conn().post_json(
@@ -212,7 +212,7 @@ def _upload_bundle(entry_module_name: str, sources: List[str], requirements: Opt


  def _collect_function_function_defs(
-     project_ids: ProjectIdCache, functions: List[Dict[str, Any]], bundle_id: str, if_exists: IfExists
+     project_ids: ProjectIdCache, functions: list[dict[str, Any]], bundle_id: str, if_exists: IfExists
  ) -> None:
      for i, f in enumerate(global_.functions):
          source = inspect.getsource(f.handler)
@@ -262,7 +262,7 @@


  def _collect_prompt_function_defs(
-     project_ids: ProjectIdCache, functions: List[Dict[str, Any]], if_exists: IfExists
+     project_ids: ProjectIdCache, functions: list[dict[str, Any]], if_exists: IfExists
  ) -> None:
      for p in global_.prompts:
          functions.append(p.to_function_definition(if_exists, project_ids))
@@ -300,7 +300,7 @@ def run(args):
          raise

      project_ids = ProjectIdCache()
-     functions: List[Dict[str, Any]] = []
+     functions: list[dict[str, Any]] = []
      if len(global_.functions) > 0:
          bundle_id = _upload_bundle(module_name, sources, args.requirements)
          _collect_function_function_defs(project_ids, functions, bundle_id, args.if_exists)
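`_ProjectRootImporter` is a `MetaPathFinder` that sits on `sys.meta_path` and records which source files get pulled in while the entry module is imported, so the push bundle can include them. A self-contained sketch of that pattern using only the standard library (names here are illustrative, not the SDK's own implementation):

```python
import importlib.abc
import importlib.machinery
import sys


class SourceRecorder(importlib.abc.MetaPathFinder):
    """Delegates to the normal path-based finder and records the origin of
    every module it resolves. Top-level lookups are restricted to one root;
    submodules use their parent package's __path__."""

    def __init__(self, project_root: str) -> None:
        self._project_root = project_root
        self._sources: list[str] = []

    def find_spec(self, fullname, path=None, target=None):
        search = list(path) if path else [self._project_root]
        spec = importlib.machinery.PathFinder.find_spec(fullname, search)
        if spec is not None and spec.origin:
            self._sources.append(spec.origin)
        return spec

    def sources(self) -> list[str]:
        return self._sources


# Install ahead of the default finders; modules outside the root fall through
# to the regular import machinery because find_spec() returns None for them.
finder = SourceRecorder(sys.path[0])
sys.meta_path.insert(0, finder)
```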
braintrust/context.py CHANGED
@@ -5,12 +5,13 @@ import os
  from abc import ABC, abstractmethod
  from contextvars import ContextVar
  from dataclasses import dataclass
- from typing import Any, List, Optional
+ from typing import Any


  @dataclass
  class SpanInfo:
      """Information about a span in the context."""
+
      trace_id: str
      span_id: str
      span_object: Any = None
@@ -19,7 +20,7 @@
  @dataclass
  class ParentSpanIds:
      root_span_id: str
-     span_parents: List[str]
+     span_parents: list[str]


  class ContextManager(ABC):
@@ -30,7 +31,7 @@
      """

      @abstractmethod
-     def get_current_span_info(self) -> Optional[Any]:
+     def get_current_span_info(self) -> Any | None:
          """Get information about the currently active span.

          Returns:
@@ -40,7 +41,7 @@
          pass

      @abstractmethod
-     def get_parent_span_ids(self) -> Optional[ParentSpanIds]:
+     def get_parent_span_ids(self) -> ParentSpanIds | None:
          """Get parent span IDs for creating a new Braintrust span.

          Returns:
@@ -75,32 +76,25 @@ class BraintrustContextManager(ContextManager):
      """Braintrust-only context manager using contextvars when OTEL is not available."""

      def __init__(self):
-         self._current_span: ContextVar[Optional[Any]] = ContextVar('braintrust_current_span', default=None)
+         self._current_span: ContextVar[Any | None] = ContextVar("braintrust_current_span", default=None)

-     def get_current_span_info(self) -> Optional[SpanInfo]:
+     def get_current_span_info(self) -> SpanInfo | None:
          """Get information about the currently active span."""
          current_span = self._current_span.get()
          if not current_span:
              return None

          # Return SpanInfo for BT spans
-         return SpanInfo(
-             trace_id=current_span.root_span_id,
-             span_id=current_span.span_id,
-             span_object=current_span
-         )
+         return SpanInfo(trace_id=current_span.root_span_id, span_id=current_span.span_id, span_object=current_span)

-     def get_parent_span_ids(self) -> Optional[ParentSpanIds]:
+     def get_parent_span_ids(self) -> ParentSpanIds | None:
          """Get parent information for creating a new Braintrust span."""
          current_span = self._current_span.get()
          if not current_span:
              return None

          # If current span is a BT span, use it as parent
-         return ParentSpanIds(
-             root_span_id=current_span.root_span_id,
-             span_parents=[current_span.span_id]
-         )
+         return ParentSpanIds(root_span_id=current_span.root_span_id, span_parents=[current_span.span_id])

      def set_current_span(self, span_object: Any) -> Any:
          """Set the current active span."""
@@ -123,9 +117,10 @@ def get_context_manager() -> ContextManager:
      """

      # Check if OTEL should be explicitly enabled via environment variable
-     if os.environ.get('BRAINTRUST_OTEL_COMPAT', '').lower() in ('1', 'true', 'yes'):
+     if os.environ.get("BRAINTRUST_OTEL_COMPAT", "").lower() in ("1", "true", "yes"):
          try:
              from braintrust.otel.context import ContextManager as OtelContextManager
+
              return OtelContextManager()
          except ImportError:
              logging.warning("OTEL not available, falling back to Braintrust-only version")
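The context module keeps its behavior: a contextvars-backed tracker of the current Braintrust span, with the OTEL-backed implementation swapped in when `BRAINTRUST_OTEL_COMPAT` is set and the OTEL extras are importable. A standalone sketch of the contextvars pattern it relies on (names are illustrative, not the SDK's; Python 3.10+ union syntax, matching the annotations the diff moves to):

```python
from contextvars import ContextVar
from dataclasses import dataclass


@dataclass
class FakeSpan:  # stand-in for a Braintrust span
    root_span_id: str
    span_id: str


_current_span: ContextVar[FakeSpan | None] = ContextVar("current_span", default=None)


def parent_span_ids() -> tuple[str, list[str]] | None:
    """Mirror of get_parent_span_ids(): the root span id plus the parent chain."""
    span = _current_span.get()
    if span is None:
        return None
    return span.root_span_id, [span.span_id]


token = _current_span.set(FakeSpan(root_span_id="r1", span_id="s1"))
print(parent_span_ids())  # ('r1', ['s1'])
_current_span.reset(token)
print(parent_span_ids())  # None
```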
braintrust/contrib/temporal/__init__.py CHANGED
@@ -80,7 +80,8 @@ The integration will automatically:
  """

  import dataclasses
- from typing import Any, Dict, Mapping, Optional, Type
+ from collections.abc import Mapping
+ from typing import Any

  import braintrust
  import temporalio.activity
@@ -129,7 +130,7 @@ class BraintrustInterceptor(temporalio.client.Interceptor, temporalio.worker.Int
      - Ensures replay safety (no duplicate spans during workflow replay)
      """

-     def __init__(self, logger: Optional[Any] = None) -> None:
+     def __init__(self, logger: Any | None = None) -> None:
          """Initialize interceptor.

          Args:
@@ -142,7 +143,7 @@ class BraintrustInterceptor(temporalio.client.Interceptor, temporalio.worker.Int
              braintrust.logger._state._override_bg_logger.logger = logger
          self._logger = braintrust.current_logger()

-     def _get_logger(self) -> Optional[Any]:
+     def _get_logger(self) -> Any | None:
          """Get logger for creating spans.

          Sets thread-local override if background logger provided (for testing),
@@ -152,9 +153,7 @@ class BraintrustInterceptor(temporalio.client.Interceptor, temporalio.worker.Int
              braintrust.logger._state._override_bg_logger.logger = self._bg_logger
          return self._logger

-     def intercept_client(
-         self, next: temporalio.client.OutboundInterceptor
-     ) -> temporalio.client.OutboundInterceptor:
+     def intercept_client(self, next: temporalio.client.OutboundInterceptor) -> temporalio.client.OutboundInterceptor:
          """Intercept client calls to propagate span context to workflows."""
          return _BraintrustClientOutboundInterceptor(next, self)

@@ -166,14 +165,14 @@

      def workflow_interceptor_class(
          self, input: temporalio.worker.WorkflowInterceptorClassInput
-     ) -> Optional[Type["BraintrustWorkflowInboundInterceptor"]]:
+     ) -> type["BraintrustWorkflowInboundInterceptor"] | None:
          """Return workflow interceptor class to propagate context to activities."""
          input.unsafe_extern_functions["__braintrust_get_logger"] = self._get_logger
          return BraintrustWorkflowInboundInterceptor

      def _span_context_to_headers(
          self,
-         span_context: Dict[str, Any],
+         span_context: dict[str, Any],
          headers: Mapping[str, temporalio.api.common.v1.Payload],
      ) -> Mapping[str, temporalio.api.common.v1.Payload]:
          """Add span context to headers."""
@@ -188,7 +187,7 @@

      def _span_context_from_headers(
          self, headers: Mapping[str, temporalio.api.common.v1.Payload]
-     ) -> Optional[Dict[str, Any]]:
+     ) -> dict[str, Any] | None:
          """Extract span context from headers."""
          if _HEADER_KEY not in headers:
              return None
@@ -204,9 +203,7 @@
  class _BraintrustClientOutboundInterceptor(temporalio.client.OutboundInterceptor):
      """Client interceptor that propagates span context to workflows."""

-     def __init__(
-         self, next: temporalio.client.OutboundInterceptor, root: BraintrustInterceptor
-     ) -> None:
+     def __init__(self, next: temporalio.client.OutboundInterceptor, root: BraintrustInterceptor) -> None:
          super().__init__(next)
          self.root = root

@@ -233,9 +230,7 @@ class _BraintrustActivityInboundInterceptor(temporalio.worker.ActivityInboundInt
          super().__init__(next)
          self.root = root

-     async def execute_activity(
-         self, input: temporalio.worker.ExecuteActivityInput
-     ) -> Any:
+     async def execute_activity(self, input: temporalio.worker.ExecuteActivityInput) -> Any:
          info = temporalio.activity.info()

          # Extract parent span context from headers
@@ -281,14 +276,12 @@ class BraintrustWorkflowInboundInterceptor(temporalio.worker.WorkflowInboundInte
      def __init__(self, next: temporalio.worker.WorkflowInboundInterceptor) -> None:
          super().__init__(next)
          self.payload_converter = temporalio.converter.PayloadConverter.default
-         self._parent_span_context: Optional[Dict[str, Any]] = None
+         self._parent_span_context: dict[str, Any] | None = None

      def init(self, outbound: temporalio.worker.WorkflowOutboundInterceptor) -> None:
          super().init(_BraintrustWorkflowOutboundInterceptor(outbound, self))

-     async def execute_workflow(
-         self, input: temporalio.worker.ExecuteWorkflowInput
-     ) -> Any:
+     async def execute_workflow(self, input: temporalio.worker.ExecuteWorkflowInput) -> Any:
          # Extract parent span context from workflow headers (set by client)
          parent_span_context = None
          if _HEADER_KEY in input.headers:
@@ -342,9 +335,7 @@
              span.end()


- class _BraintrustWorkflowOutboundInterceptor(
-     temporalio.worker.WorkflowOutboundInterceptor
- ):
+ class _BraintrustWorkflowOutboundInterceptor(temporalio.worker.WorkflowOutboundInterceptor):
      """Outbound workflow interceptor that propagates span context to activities."""

      def __init__(
@@ -371,9 +362,7 @@ class _BraintrustWorkflowOutboundInterceptor(
              return {**headers, _HEADER_KEY: payloads[0]}
          return headers

-     def start_activity(
-         self, input: temporalio.worker.StartActivityInput
-     ) -> temporalio.workflow.ActivityHandle:
+     def start_activity(self, input: temporalio.worker.StartActivityInput) -> temporalio.workflow.ActivityHandle:
          input.headers = self._add_span_context_to_headers(input.headers)
          return super().start_activity(input)

@@ -390,7 +379,7 @@
          return super().start_child_workflow(input)


- def _modify_workflow_runner(existing: Optional[WorkflowRunner]) -> Optional[WorkflowRunner]:
+ def _modify_workflow_runner(existing: WorkflowRunner | None) -> WorkflowRunner | None:
      """Add braintrust to sandbox passthrough modules."""
      if isinstance(existing, SandboxedWorkflowRunner):
          new_restrictions = existing.restrictions.with_passthrough_modules("braintrust")
@@ -420,7 +409,7 @@ class BraintrustPlugin(SimplePlugin):
      Requires temporalio >= 1.19.0.
      """

-     def __init__(self, logger: Optional[Any] = None) -> None:
+     def __init__(self, logger: Any | None = None) -> None:
          """Initialize the Braintrust plugin.

          Args:
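For reference, the interceptor is registered on both sides of a Temporal deployment: the client writes the Braintrust span context into workflow headers, and the worker reads it back around workflow and activity execution. A minimal wiring sketch, assuming a Temporal server on localhost:7233 and user-defined `MyWorkflow`/`my_activity`; the `BraintrustPlugin` shown above (temporalio >= 1.19.0) is the alternative to wiring the interceptor by hand:

```python
import asyncio

from braintrust.contrib.temporal import BraintrustInterceptor
from temporalio.client import Client
from temporalio.worker import Worker

from my_app.workflows import MyWorkflow, my_activity  # hypothetical user code


async def main() -> None:
    interceptor = BraintrustInterceptor()

    # Client side: span context is injected into workflow headers.
    client = await Client.connect("localhost:7233", interceptors=[interceptor])

    # Worker side: the same interceptor restores the context for
    # workflow and activity spans.
    worker = Worker(
        client,
        task_queue="braintrust-demo",
        workflows=[MyWorkflow],
        activities=[my_activity],
        interceptors=[interceptor],
    )
    await worker.run()


if __name__ == "__main__":
    asyncio.run(main())
```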
braintrust/contrib/temporal/test_temporal.py CHANGED
@@ -279,11 +279,15 @@ class TestBraintrustPluginIntegration:

          # Verify workflow span was created
          workflow_spans = [s for s in spans if "temporal.workflow" in s.get("span_attributes", {}).get("name", "")]
-         assert len(workflow_spans) > 0, f"Expected workflow span to be created. Span names: {[s.get('span_attributes', {}).get('name', 'unknown') for s in spans]}"
+         assert len(workflow_spans) > 0, (
+             f"Expected workflow span to be created. Span names: {[s.get('span_attributes', {}).get('name', 'unknown') for s in spans]}"
+         )

          # Verify activity span was created
          activity_spans = [s for s in spans if "temporal.activity" in s.get("span_attributes", {}).get("name", "")]
-         assert len(activity_spans) > 0, f"Expected activity span to be created. Span names: {[s.get('span_attributes', {}).get('name', 'unknown') for s in spans]}"
+         assert len(activity_spans) > 0, (
+             f"Expected activity span to be created. Span names: {[s.get('span_attributes', {}).get('name', 'unknown') for s in spans]}"
+         )

      @pytest.mark.asyncio
      async def test_plugin_context_propagation(self, temporal_env, memory_logger):
@@ -498,5 +502,6 @@ class TestBraintrustPluginIntegration:
          # Verify parent-child relationship: workflow should have client span as parent
          workflow_span = workflow_spans[0]
          client_span = client_spans[0]
-         assert workflow_span.get("root_span_id") == client_span.get("root_span_id"), \
+         assert workflow_span.get("root_span_id") == client_span.get("root_span_id"), (
              "Workflow span should be in same trace as client span"
+         )