braintrust 0.4.0__py3-none-any.whl → 0.4.2__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in the public registry.
braintrust/bt_json.py CHANGED
@@ -1,6 +1,7 @@
 import dataclasses
 import json
-from typing import Any, cast
+import math
+from typing import Any, Callable, Mapping, NamedTuple, cast, overload

 # Try to import orjson for better performance
 # If not available, we'll use standard json
@@ -12,39 +13,183 @@ except ImportError:
     _HAS_ORJSON = False


-def _to_dict(obj: Any) -> Any:
-    """
-    Function-based default handler for non-JSON-serializable objects.

-    Handles:
-    - dataclasses
-    - Pydantic v2 BaseModel
-    - Pydantic v1 BaseModel
-    - Falls back to str() for unknown types
+def _to_bt_safe(v: Any) -> Any:
+    """
+    Converts the object to a Braintrust-safe representation (Attachment objects count as safe; the background logger handles them specially).
     """
-    if dataclasses.is_dataclass(obj) and not isinstance(obj, type):
-        return dataclasses.asdict(obj)
+    # avoid circular imports
+    from braintrust.logger import BaseAttachment, Dataset, Experiment, Logger, ReadonlyAttachment, Span
+
+    if isinstance(v, Span):
+        return "<span>"
+
+    if isinstance(v, Experiment):
+        return "<experiment>"
+
+    if isinstance(v, Dataset):
+        return "<dataset>"
+
+    if isinstance(v, Logger):
+        return "<logger>"
+
+    if isinstance(v, BaseAttachment):
+        return v
+
+    if isinstance(v, ReadonlyAttachment):
+        return v.reference
+
+    if dataclasses.is_dataclass(v) and not isinstance(v, type):
+        # Use manual field iteration instead of dataclasses.asdict() because
+        # asdict() deep-copies values, which breaks objects like Attachment
+        # that contain non-copyable items (thread locks, file handles, etc.)
+        return {f.name: _to_bt_safe(getattr(v, f.name)) for f in dataclasses.fields(v)}
+
+    # Pydantic model classes (not instances) with model_json_schema
+    if isinstance(v, type) and hasattr(v, "model_json_schema") and callable(cast(Any, v).model_json_schema):
+        try:
+            return cast(Any, v).model_json_schema()
+        except Exception:
+            pass

     # Attempt to dump a Pydantic v2 `BaseModel`.
     try:
-        return cast(Any, obj).model_dump()
+        return cast(Any, v).model_dump(exclude_none=True)
     except (AttributeError, TypeError):
         pass

     # Attempt to dump a Pydantic v1 `BaseModel`.
     try:
-        return cast(Any, obj).dict()
+        return cast(Any, v).dict(exclude_none=True)
     except (AttributeError, TypeError):
         pass

-    # When everything fails, try to return the string representation of the object
+    if isinstance(v, float):
+        # Handle NaN and Infinity for JSON compatibility
+        if math.isnan(v):
+            return "NaN"
+
+        if math.isinf(v):
+            return "Infinity" if v > 0 else "-Infinity"
+
+        return v
+
+    if isinstance(v, (int, str, bool)) or v is None:
+        # Skip roundtrip for primitive types.
+        return v
+
+    # Note: we avoid using copy.deepcopy, because it's difficult to
+    # guarantee the independence of such copied types from their origin.
+    # E.g. the original type could have a `__del__` method that alters
+    # some shared internal state, and we need this deep copy to be
+    # fully independent from the original.
+
+    # We pass `encoder=_str_encoder` since we've already tried converting rich objects to JSON-safe objects.
+    return bt_loads(bt_dumps(v, encoder=_str_encoder))
+
+@overload
+def bt_safe_deep_copy(
+    obj: Mapping[str, Any],
+    max_depth: int = ...,
+) -> dict[str, Any]: ...
+
+@overload
+def bt_safe_deep_copy(
+    obj: list[Any],
+    max_depth: int = ...,
+) -> list[Any]: ...
+
+@overload
+def bt_safe_deep_copy(
+    obj: Any,
+    max_depth: int = ...,
+) -> Any: ...
+def bt_safe_deep_copy(obj: Any, max_depth: int = 200):
+    """
+    Creates a deep copy of the given object and converts rich objects to Braintrust-safe representations. See `_to_bt_safe` for more details.
+
+    Args:
+        obj: Object to deep copy and sanitize.
+        max_depth: Maximum depth to copy.
+
+    Returns:
+        Deep copy of the object.
+    """
+    # Track visited objects to detect circular references
+    visited: set[int] = set()
+
+    def _deep_copy_object(v: Any, depth: int = 0) -> Any:
+        # Check depth limit - use >= to stop before exceeding
+        if depth >= max_depth:
+            return "<max depth exceeded>"
+
+        # Check for circular references in mutable containers
+        # Use id() to track object identity
+        if isinstance(v, (Mapping, list, tuple, set)):
+            obj_id = id(v)
+            if obj_id in visited:
+                return "<circular reference>"
+            visited.add(obj_id)
+            try:
+                if isinstance(v, Mapping):
+                    # Prevent dict keys from holding references to user data. Note that
+                    # `bt_json` already coerces keys to string, a behavior that comes from
+                    # `json.dumps`. However, that runs at log upload time, while we want to
+                    # cut out all the references to user objects synchronously in this
+                    # function.
+                    result = {}
+                    for k in v:
+                        try:
+                            key_str = str(k)
+                        except Exception:
+                            # If str() fails on the key, use a fallback representation
+                            key_str = f"<non-stringifiable-key: {type(k).__name__}>"
+                        result[key_str] = _deep_copy_object(v[k], depth + 1)
+                    return result
+                elif isinstance(v, (list, tuple, set)):
+                    return [_deep_copy_object(x, depth + 1) for x in v]
+            finally:
+                # Remove from visited set after processing to allow the same object
+                # to appear in different branches of the tree
+                visited.discard(obj_id)
+
+        try:
+            return _to_bt_safe(v)
+        except Exception:
+            return f"<non-sanitizable: {type(v).__name__}>"
+
+    return _deep_copy_object(obj)
+
+def _safe_str(obj: Any) -> str:
     try:
         return str(obj)
     except Exception:
-        # If str() fails, return an error placeholder
         return f"<non-serializable: {type(obj).__name__}>"


+def _to_json_safe(obj: Any) -> Any:
+    """
+    Handler for non-JSON-serializable objects. Returns a JSON-safe representation of the object, falling back to a string.
+    """
+    # avoid circular imports
+    from braintrust.logger import BaseAttachment
+
+    try:
+        v = _to_bt_safe(obj)
+
+        # The JSON-safe representation of an Attachment object is its reference.
+        # If we get this object at this point, we have to assume someone has already uploaded the attachment!
+        if isinstance(v, BaseAttachment):
+            v = v.reference
+
+        return v
+    except Exception:
+        pass
+
+    # When everything fails, try to return the string representation of the object
+    return _safe_str(obj)
+
+
 class BraintrustJSONEncoder(json.JSONEncoder):
     """
     Custom JSON encoder for standard json library.
@@ -53,10 +199,22 @@ class BraintrustJSONEncoder(json.JSONEncoder):
     """

     def default(self, o: Any):
-        return _to_dict(o)
+        return _to_json_safe(o)
+
+
+class BraintrustStrEncoder(json.JSONEncoder):
+    def default(self, o: Any):
+        return _safe_str(o)
+
+
+class Encoder(NamedTuple):
+    native: type[json.JSONEncoder]
+    orjson: Callable[[Any], Any]

+_json_encoder = Encoder(native=BraintrustJSONEncoder, orjson=_to_json_safe)
+_str_encoder = Encoder(native=BraintrustStrEncoder, orjson=_safe_str)

-def bt_dumps(obj, **kwargs) -> str:
+def bt_dumps(obj: Any, encoder: Encoder | None = _json_encoder, **kwargs: Any) -> str:
     """
     Serialize obj to a JSON-formatted string.

@@ -65,6 +223,7 @@ def bt_dumps(obj, **kwargs) -> str:

     Args:
         obj: Object to serialize
+        encoder: Encoder to use, defaults to `_json_encoder`
         **kwargs: Additional arguments (passed to json.dumps in fallback path)

     Returns:
@@ -76,7 +235,7 @@ def bt_dumps(obj, **kwargs) -> str:
         # pylint: disable=no-member # orjson is a C extension, pylint can't introspect it
         return orjson.dumps(  # type: ignore[possibly-unbound]
             obj,
-            default=_to_dict,
+            default=encoder.orjson if encoder else None,
             # options match json.dumps behavior for bc
             option=orjson.OPT_SORT_KEYS | orjson.OPT_SERIALIZE_NUMPY | orjson.OPT_NON_STR_KEYS,  # type: ignore[possibly-unbound]
         ).decode("utf-8")
@@ -86,7 +245,7 @@ def bt_dumps(obj, **kwargs) -> str:

     # Use standard json (either orjson not available or it failed)
     # Use sort_keys=True for deterministic output (matches orjson OPT_SORT_KEYS)
-    return json.dumps(obj, cls=BraintrustJSONEncoder, allow_nan=False, sort_keys=True, **kwargs)
+    return json.dumps(obj, cls=encoder.native if encoder else None, allow_nan=False, sort_keys=True, **kwargs)


 def bt_loads(s: str, **kwargs) -> Any:
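
The `Encoder` named tuple exists because the two serialization backends take their fallback handler in different shapes: orjson wants a `default=` callable, while the stdlib wants a `json.JSONEncoder` subclass. A sketch of the three resulting channels; `FakeModel` is a hypothetical object that duck-types a Pydantic v2 model, and `_str_encoder` is imported here purely for illustration (it is an internal helper):

```python
from braintrust.bt_json import bt_dumps, _str_encoder


class FakeModel:
    """Hypothetical stand-in that duck-types a Pydantic v2 model."""

    def model_dump(self, exclude_none: bool = False):
        data = {"a": 1, "b": None}
        return {k: v for k, v in data.items() if not (exclude_none and v is None)}

    def __str__(self) -> str:
        return "FakeModel()"


obj = {"x": FakeModel()}

# Default encoder: rich conversion first, so model_dump(exclude_none=True) wins.
print(bt_dumps(obj))  # {"x":{"a":1}} (exact whitespace depends on the orjson/stdlib path)

# String encoder: skips rich conversion and stringifies immediately.
print(bt_dumps(obj, encoder=_str_encoder))  # {"x":"FakeModel()"}

# No encoder: no fallback handler at all, so unknown types raise.
try:
    bt_dumps(obj, encoder=None)
except TypeError as e:
    print("raises:", e)
```
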
braintrust/db_fields.py CHANGED
@@ -5,6 +5,7 @@ ID_FIELD = "id"

 IS_MERGE_FIELD = "_is_merge"
 MERGE_PATHS_FIELD = "_merge_paths"
+ARRAY_DELETE_FIELD = "_array_delete"

 AUDIT_SOURCE_FIELD = "_audit_source"
 AUDIT_METADATA_FIELD = "_audit_metadata"
braintrust/framework.py CHANGED
@@ -47,7 +47,7 @@ from .resource_manager import ResourceManager
 from .score import Score, is_score, is_scorer
 from .serializable_data_class import SerializableDataClass
 from .span_types import SpanTypeAttribute
-from .util import bt_iscoroutinefunction, eprint
+from .util import bt_iscoroutinefunction, eprint, merge_dicts

 Input = TypeVar("Input")
 Output = TypeVar("Output")
@@ -1284,8 +1284,17 @@ async def _run_evaluator_internal(
     event_loop = asyncio.get_event_loop()

     async def await_or_run_scorer(root_span, scorer, name, **kwargs):
+        # Merge purpose into parent's propagated_event rather than replacing it
+        parent_propagated = root_span.propagated_event or {}
+        merged_propagated = merge_dicts(
+            {**parent_propagated},
+            {"span_attributes": {"purpose": "scorer"}},
+        )
         with root_span.start_span(
-            name=name, span_attributes={"type": SpanTypeAttribute.SCORE}, input=dict(**kwargs)
+            name=name,
+            span_attributes={"type": SpanTypeAttribute.SCORE, "purpose": "scorer"},
+            propagated_event=merged_propagated,
+            input=dict(**kwargs),
         ) as span:
             score = scorer
             if hasattr(scorer, "eval_async"):
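
The copy-then-merge above lets the scorer span inherit everything the parent already propagates while overriding only `span_attributes.purpose`. A small illustration of the deep-merge semantics this relies on (assuming `braintrust.util.merge_dicts` recursively merges its second argument into its first, which is what the shallow copy in the hunk suggests):

```python
from braintrust.util import merge_dicts

base = {"span_attributes": {"purpose": "eval", "team": "ml"}, "metadata": {"run": 7}}
merged = merge_dicts(base, {"span_attributes": {"purpose": "scorer"}})

# Sibling keys survive at every level; only the overlapping leaf is replaced.
print(merged)
# {'span_attributes': {'purpose': 'scorer', 'team': 'ml'}, 'metadata': {'run': 7}}
```
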
@@ -1550,9 +1559,9 @@ def build_local_summary(
     scores_by_name = defaultdict(lambda: (0, 0))
     for result in results:
         for name, score in result.scores.items():
-            curr = scores_by_name[name]
-            if curr is None:
+            if score is None:
                 continue
+            curr = scores_by_name[name]
             scores_by_name[name] = (curr[0] + score, curr[1] + 1)
     longest_score_name = max(len(name) for name in scores_by_name) if scores_by_name else 0
     avg_scores = {
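
The old code fetched the accumulator tuple first and tested *it* for `None`, but a `defaultdict` entry is never `None`, so a `None` score crashed the `curr[0] + score` addition. A standalone sketch of the corrected aggregation:

```python
from collections import defaultdict

results = [
    {"accuracy": 1.0, "fluency": None},  # a scorer that returned no score
    {"accuracy": 0.5, "fluency": 0.8},
]

scores_by_name = defaultdict(lambda: (0, 0))
for scores in results:
    for name, score in scores.items():
        if score is None:  # skip missing scores, not the (never-None) tuple
            continue
        curr = scores_by_name[name]
        scores_by_name[name] = (curr[0] + score, curr[1] + 1)

avg_scores = {name: total / count for name, (total, count) in scores_by_name.items()}
print(avg_scores)  # {'accuracy': 0.75, 'fluency': 0.8}
```
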
braintrust/logger.py CHANGED
@@ -9,7 +9,6 @@ import inspect
 import io
 import json
 import logging
-import math
 import os
 import sys
 import textwrap
@@ -46,7 +45,7 @@ from requests.adapters import HTTPAdapter
 from urllib3.util.retry import Retry

 from . import context, id_gen
-from .bt_json import bt_dumps, bt_loads
+from .bt_json import bt_dumps, bt_safe_deep_copy
 from .db_fields import (
     ASYNC_SCORING_CONTROL_FIELD,
     AUDIT_METADATA_FIELD,
@@ -271,6 +270,10 @@ class _NoopSpan(Span):
     def id(self):
         return ""

+    @property
+    def propagated_event(self):
+        return None
+
     def log(self, **event: Any):
         pass

@@ -739,13 +742,6 @@ def construct_logs3_data(items: Sequence[str]):
     return '{"rows": ' + rowsS + ', "api_version": ' + str(DATA_API_VERSION) + "}"


-def _check_json_serializable(event):
-    try:
-        return bt_dumps(event)
-    except (TypeError, ValueError) as e:
-        raise Exception(f"All logged values must be JSON-serializable: {event}") from e
-
-
 class _MaskingError:
     """Internal class to signal masking errors that need special handling."""

@@ -795,6 +791,7 @@ class _MemoryBackgroundLogger(_BackgroundLogger):
         self.lock = threading.Lock()
         self.logs = []
         self.masking_function: Callable[[Any], Any] | None = None
+        self.upload_attempts: list[BaseAttachment] = []  # Track upload attempts

     def enforce_queue_size_limit(self, enforce: bool) -> None:
         pass
@@ -808,7 +805,21 @@ class _MemoryBackgroundLogger(_BackgroundLogger):
         self.masking_function = masking_function

     def flush(self, batch_size: int | None = None):
-        pass
+        """Flush the memory logger, extracting attachments and tracking upload attempts."""
+        with self.lock:
+            if not self.logs:
+                return
+
+            # Unwrap lazy values and extract attachments
+            logs = [l.get() for l in self.logs]
+
+            # Extract attachments from all logs
+            attachments: list[BaseAttachment] = []
+            for log in logs:
+                _extract_attachments(log, attachments)
+
+            # Track upload attempts (don't actually call upload() in tests)
+            self.upload_attempts.extend(attachments)

     def pop(self):
         with self.lock:
@@ -1959,24 +1970,14 @@ def get_span_parent_object(

 def _try_log_input(span, f_sig, f_args, f_kwargs):
     if f_sig:
-        bound_args = f_sig.bind(*f_args, **f_kwargs).arguments
-        input_serializable = bound_args
+        input_data = f_sig.bind(*f_args, **f_kwargs).arguments
     else:
-        input_serializable = dict(args=f_args, kwargs=f_kwargs)
-    try:
-        _check_json_serializable(input_serializable)
-    except Exception as e:
-        input_serializable = "<input not json-serializable>: " + str(e)
-    span.log(input=input_serializable)
+        input_data = dict(args=f_args, kwargs=f_kwargs)
+    span.log(input=input_data)


 def _try_log_output(span, output):
-    output_serializable = output
-    try:
-        _check_json_serializable(output)
-    except Exception as e:
-        output_serializable = "<output not json-serializable>: " + str(e)
-    span.log(output=output_serializable)
+    span.log(output=output)


 F = TypeVar("F", bound=Callable[..., Any])
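
With `_check_json_serializable` removed, traced inputs and outputs are no longer round-tripped through `bt_dumps` on every call; sanitization now happens once, inside `SpanImpl.log_internal`, via `bt_safe_deep_copy`. A rough sketch of the effect (the project name is a placeholder):

```python
import threading

import braintrust
from braintrust import traced

braintrust.init_logger(project="my-project")  # placeholder project name


@traced
def acquire(lock):
    return lock


# Previously every call paid for a full JSON dump of the event up front; now
# the lock is converted to a placeholder string when the record is built.
acquire(threading.Lock())
```
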
@@ -2426,91 +2427,6 @@ def _validate_and_sanitize_experiment_log_full_args(event: Mapping[str, Any], ha
     return event


-def _deep_copy_event(event: Mapping[str, Any]) -> dict[str, Any]:
-    """
-    Creates a deep copy of the given event. Replaces references to user objects
-    with placeholder strings to ensure serializability, except for `Attachment`
-    and `ExternalAttachment` objects, which are preserved and not deep-copied.
-
-    Handles circular references and excessive nesting depth to prevent
-    RecursionError during serialization.
-    """
-    # Maximum depth to prevent hitting Python's recursion limit
-    # Python's default limit is ~1000, we use a conservative limit
-    # to account for existing call stack usage from pytest, application code, etc.
-    MAX_DEPTH = 200
-
-    # Track visited objects to detect circular references
-    visited: set[int] = set()
-
-    def _deep_copy_object(v: Any, depth: int = 0) -> Any:
-        # Check depth limit - use >= to stop before exceeding
-        if depth >= MAX_DEPTH:
-            return "<max depth exceeded>"
-
-        # Check for circular references in mutable containers
-        # Use id() to track object identity
-        if isinstance(v, (Mapping, list, tuple, set)):
-            obj_id = id(v)
-            if obj_id in visited:
-                return "<circular reference>"
-            visited.add(obj_id)
-            try:
-                if isinstance(v, Mapping):
-                    # Prevent dict keys from holding references to user data. Note that
-                    # `bt_json` already coerces keys to string, a behavior that comes from
-                    # `json.dumps`. However, that runs at log upload time, while we want to
-                    # cut out all the references to user objects synchronously in this
-                    # function.
-                    result = {}
-                    for k in v:
-                        try:
-                            key_str = str(k)
-                        except Exception:
-                            # If str() fails on the key, use a fallback representation
-                            key_str = f"<non-stringifiable-key: {type(k).__name__}>"
-                        result[key_str] = _deep_copy_object(v[k], depth + 1)
-                    return result
-                elif isinstance(v, (list, tuple, set)):
-                    return [_deep_copy_object(x, depth + 1) for x in v]
-            finally:
-                # Remove from visited set after processing to allow the same object
-                # to appear in different branches of the tree
-                visited.discard(obj_id)
-
-        if isinstance(v, Span):
-            return "<span>"
-        elif isinstance(v, Experiment):
-            return "<experiment>"
-        elif isinstance(v, Dataset):
-            return "<dataset>"
-        elif isinstance(v, Logger):
-            return "<logger>"
-        elif isinstance(v, BaseAttachment):
-            return v
-        elif isinstance(v, ReadonlyAttachment):
-            return v.reference
-        elif isinstance(v, float):
-            # Handle NaN and Infinity for JSON compatibility
-            if math.isnan(v):
-                return "NaN"
-            elif math.isinf(v):
-                return "Infinity" if v > 0 else "-Infinity"
-            return v
-        elif isinstance(v, (int, str, bool)) or v is None:
-            # Skip roundtrip for primitive types.
-            return v
-        else:
-            # Note: we avoid using copy.deepcopy, because it's difficult to
-            # guarantee the independence of such copied types from their origin.
-            # E.g. the original type could have a `__del__` method that alters
-            # some shared internal state, and we need this deep copy to be
-            # fully-independent from the original.
-            return bt_loads(bt_dumps(v))
-
-    return _deep_copy_object(event)
-
-
 class ObjectIterator(Generic[T]):
     def __init__(self, refetch_fn: Callable[[], Sequence[T]]):
         self.refetch_fn = refetch_fn
@@ -3060,7 +2976,7 @@ def _log_feedback_impl(
     metadata = update_event.pop("metadata")
     update_event = {k: v for k, v in update_event.items() if v is not None}

-    update_event = _deep_copy_event(update_event)
+    update_event = bt_safe_deep_copy(update_event)

     def parent_ids():
         exporter = _get_exporter()
@@ -3116,7 +3032,7 @@ def _update_span_impl(
         event=event,
     )

-    update_event = _deep_copy_event(update_event)
+    update_event = bt_safe_deep_copy(update_event)

    def parent_ids():
        exporter = _get_exporter()
@@ -3936,14 +3852,10 @@ class SpanImpl(Span):
             **{IS_MERGE_FIELD: self._is_merge},
         )

-        serializable_partial_record = _deep_copy_event(partial_record)
-        _check_json_serializable(serializable_partial_record)
+        serializable_partial_record = bt_safe_deep_copy(partial_record)
         if serializable_partial_record.get("metrics", {}).get("end") is not None:
             self._logged_end_time = serializable_partial_record["metrics"]["end"]

-        if len(serializable_partial_record.get("tags", [])) > 0 and self.span_parents:
-            raise Exception("Tags can only be logged to the root span")
-
         def compute_record() -> dict[str, Any]:
             exporter = _get_exporter()
             return dict(
@@ -4304,8 +4216,7 @@ class Dataset(ObjectFetcher[DatasetEvent]):
            args[IS_MERGE_FIELD] = True
            args = _filter_none_args(args)  # If merging, then remove None values to prevent null value writes

-        _check_json_serializable(args)
-        args = _deep_copy_event(args)
+        args = bt_safe_deep_copy(args)

        def compute_args() -> dict[str, Any]:
            return dict(
@@ -4408,8 +4319,7 @@ class Dataset(ObjectFetcher[DatasetEvent]):
                "_object_delete": True,  # XXX potentially place this in the logging endpoint
            },
        )
-        _check_json_serializable(partial_args)
-        partial_args = _deep_copy_event(partial_args)
+        partial_args = bt_safe_deep_copy(partial_args)

        def compute_args():
            return dict(
braintrust/otel.py CHANGED
@@ -90,18 +90,13 @@ class AISpanProcessor:
     def _should_keep_filtered_span(self, span):
         """
         Keep spans if:
-        1. It's a root span (no parent)
-        2. Custom filter returns True/False (if provided)
-        3. Span name starts with 'gen_ai.', 'braintrust.', 'llm.', 'ai.', or 'traceloop.'
-        4. Any attribute name starts with those prefixes
+        1. Custom filter returns True/False (if provided)
+        2. Span name starts with 'gen_ai.', 'braintrust.', 'llm.', 'ai.', or 'traceloop.'
+        3. Any attribute name starts with those prefixes
         """
         if not span:
             return False

-        # Braintrust requires root spans, so always keep them
-        if span.parent is None:
-            return True
-
         # Apply custom filter if provided
         if self._custom_filter:
             custom_result = self._custom_filter(span)
@@ -384,6 +379,9 @@ def _get_braintrust_parent(object_type, object_id: str | None = None, compute_ar

     return None

+def is_root_span(span) -> bool:
+    """Returns True if the span is a root span (no parent span)."""
+    return getattr(span, "parent", None) is None

 def context_from_span_export(export_str: str):
     """
@@ -522,15 +520,17 @@ def add_span_parent_to_baggage(span, ctx=None):
     return add_parent_to_baggage(parent_value, ctx=ctx)


-def parent_from_headers(headers: dict[str, str]) -> str | None:
+def parent_from_headers(headers: dict[str, str], propagator=None) -> str | None:
     """
-    Extract a Braintrust-compatible parent string from W3C Trace Context headers.
+    Extract a Braintrust-compatible parent string from trace context headers.

-    This converts OTEL trace context headers (traceparent/baggage) into a format
-    that can be passed as the 'parent' parameter to Braintrust's start_span() method.
+    This converts OTEL trace context headers into a format that can be passed
+    as the 'parent' parameter to Braintrust's start_span() method.

     Args:
-        headers: Dictionary with 'traceparent' and optionally 'baggage' keys
+        headers: Dictionary with trace context headers (e.g., 'traceparent'/'baggage' for W3C)
+        propagator: Optional custom TextMapPropagator. If not provided, uses the
+            globally registered propagator (W3C TraceContext by default).

     Returns:
         Braintrust V4 export string that can be used as parent parameter,
@@ -545,6 +545,12 @@ def parent_from_headers(headers: dict[str, str]) -> str | None:
     >>> parent = parent_from_headers(headers)
     >>> with project.start_span(name="service_c", parent=parent) as span:
     >>>     span.log(input="BT span as child of OTEL parent")
+
+    >>> # Using a custom propagator (e.g., B3 format)
+    >>> from opentelemetry.propagators.b3 import B3MultiFormat
+    >>> propagator = B3MultiFormat()
+    >>> headers = {'X-B3-TraceId': '...', 'X-B3-SpanId': '...', 'baggage': '...'}
+    >>> parent = parent_from_headers(headers, propagator=propagator)
     """
     if not OTEL_AVAILABLE:
         raise ImportError(INSTALL_ERR_MSG)
@@ -553,8 +559,11 @@ def parent_from_headers(headers: dict[str, str]) -> str | None:
     from opentelemetry import baggage, trace
     from opentelemetry.propagate import extract

-    # Extract context from headers using W3C Trace Context propagator
-    ctx = extract(headers)
+    # Extract context from headers using provided propagator or global propagator
+    if propagator is not None:
+        ctx = propagator.extract(headers)
+    else:
+        ctx = extract(headers)

     # Get span from context
     span = trace.get_current_span(ctx)
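
End to end, the new `propagator` parameter means header extraction is no longer hard-wired to the globally registered propagator. A hypothetical FastAPI handler showing the intended use (the app, project, and route names are assumptions):

```python
from fastapi import FastAPI, Request

import braintrust
from braintrust.otel import parent_from_headers

app = FastAPI()
project = braintrust.init_logger(project="my-project")  # placeholder project


@app.post("/generate")
async def generate(request: Request):
    # Uses the global propagator (W3C by default); pass `propagator=` to
    # accept other header formats such as B3.
    parent = parent_from_headers(dict(request.headers))
    with project.start_span(name="generate", parent=parent) as span:
        span.log(input="...")
        return {"ok": True}
```
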