braintrust 0.4.0__py3-none-any.whl → 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
braintrust/test_otel.py CHANGED
@@ -294,13 +294,12 @@ class TestSpanFiltering:
294
294
  self.provider.shutdown()
295
295
  self.memory_exporter.clear()
296
296
 
297
- def test_keeps_root_spans(self):
297
+ def test_filters_out_root_spans(self):
298
298
  with self.tracer.start_as_current_span("root_operation"):
299
299
  pass
300
300
 
301
301
  spans = self.memory_exporter.get_finished_spans()
302
- assert len(spans) == 1
303
- assert spans[0].name == "root_operation"
302
+ assert len(spans) == 0
304
303
 
305
304
  def test_keeps_gen_ai_spans(self):
306
305
  with self.tracer.start_as_current_span("root"):
@@ -312,7 +311,7 @@ class TestSpanFiltering:
312
311
  spans = self.memory_exporter.get_finished_spans()
313
312
  span_names = [span.name for span in spans]
314
313
 
315
- assert "root" in span_names
314
+ assert "root" not in span_names
316
315
  assert "gen_ai.completion" in span_names
317
316
  assert "regular_operation" not in span_names
318
317
 
@@ -329,35 +328,37 @@ class TestSpanFiltering:
329
328
  assert "braintrust.eval" in span_names
330
329
  assert "database_query" not in span_names
331
330
 
332
- def test_keeps_llm_spans(self):
331
+ def test_keeps_traceloop_spans(self):
333
332
  with self.tracer.start_as_current_span("root"):
334
- with self.tracer.start_as_current_span("llm.generate"):
333
+ with self.tracer.start_as_current_span("traceloop.agent"):
334
+ pass
335
+ with self.tracer.start_as_current_span("traceloop.workflow.step"):
335
336
  pass
336
337
 
337
338
  spans = self.memory_exporter.get_finished_spans()
338
339
  span_names = [span.name for span in spans]
339
- assert "llm.generate" in span_names
340
+ assert "root" not in span_names
341
+ assert "traceloop.agent" in span_names
342
+ assert "traceloop.workflow.step" in span_names
340
343
 
341
- def test_keeps_ai_spans(self):
344
+ def test_keeps_llm_spans(self):
342
345
  with self.tracer.start_as_current_span("root"):
343
- with self.tracer.start_as_current_span("ai.model_call"):
346
+ with self.tracer.start_as_current_span("llm.generate"):
344
347
  pass
345
348
 
346
349
  spans = self.memory_exporter.get_finished_spans()
347
350
  span_names = [span.name for span in spans]
348
- assert "ai.model_call" in span_names
351
+ assert "llm.generate" in span_names
349
352
 
350
- def test_keeps_traceloop_spans(self):
353
+ def test_keeps_ai_spans(self):
351
354
  with self.tracer.start_as_current_span("root"):
352
- with self.tracer.start_as_current_span("traceloop.agent"):
353
- pass
354
- with self.tracer.start_as_current_span("traceloop.workflow.step"):
355
+ with self.tracer.start_as_current_span("ai.model_call"):
355
356
  pass
356
357
 
357
358
  spans = self.memory_exporter.get_finished_spans()
358
359
  span_names = [span.name for span in spans]
359
- assert "traceloop.agent" in span_names
360
- assert "traceloop.workflow.step" in span_names
360
+ assert "root" not in span_names
361
+ assert "ai.model_call" in span_names
361
362
 
362
363
  def test_keeps_spans_with_llm_attributes(self):
363
364
  with self.tracer.start_as_current_span("root"):
@@ -374,7 +375,7 @@ class TestSpanFiltering:
374
375
  spans = self.memory_exporter.get_finished_spans()
375
376
  span_names = [span.name for span in spans]
376
377
 
377
- assert "root" in span_names
378
+ assert "root" not in span_names
378
379
  assert "some_operation" in span_names # has gen_ai.model attribute
379
380
  assert "another_operation" in span_names # has llm.tokens attribute
380
381
  assert "traceloop_operation" in span_names # has traceloop.agent_id attribute
@@ -390,10 +391,7 @@ class TestSpanFiltering:
390
391
  pass
391
392
 
392
393
  spans = self.memory_exporter.get_finished_spans()
393
-
394
- # Only root should be kept
395
- assert len(spans) == 1
396
- assert spans[0].name == "root"
394
+ assert len(spans) == 0
397
395
 
398
396
  def test_custom_filter_keeps_spans(self):
399
397
  def custom_filter(span):
@@ -422,9 +420,9 @@ class TestSpanFiltering:
422
420
  spans = memory_exporter.get_finished_spans()
423
421
  span_names = [span.name for span in spans]
424
422
 
425
- assert "root" in span_names
426
423
  assert "custom_keep" in span_names # kept by custom filter
427
424
  assert "regular_operation" not in span_names # dropped by default logic
425
+ assert "root" not in span_names
428
426
 
429
427
  def test_custom_filter_drops_spans(self):
430
428
  def custom_filter(span):
@@ -453,9 +451,9 @@ class TestSpanFiltering:
453
451
  spans = memory_exporter.get_finished_spans()
454
452
  span_names = [span.name for span in spans]
455
453
 
456
- assert "root" in span_names
457
454
  assert "gen_ai.drop_this" not in span_names # dropped by custom filter
458
455
  assert "gen_ai.keep_this" in span_names # kept by default LLM logic
456
+ assert "root" not in span_names
459
457
 
460
458
  def test_custom_filter_none_uses_default_logic(self):
461
459
  def custom_filter(span):
@@ -482,7 +480,7 @@ class TestSpanFiltering:
482
480
  spans = memory_exporter.get_finished_spans()
483
481
  span_names = [span.name for span in spans]
484
482
 
485
- assert "root" in span_names
483
+ assert "root" not in span_names
486
484
  assert "gen_ai.completion" in span_names # kept by default LLM logic
487
485
  assert "regular_operation" not in span_names # dropped by default logic
488
486
 
@@ -546,11 +544,32 @@ class TestSpanFiltering:
546
544
  filtered_spans = filtered_spans_exporter.get_finished_spans()
547
545
  filtered_span_names = [span.name for span in filtered_spans]
548
546
 
549
- assert len(filtered_spans) == 3
550
- assert "user_request" in filtered_span_names # root span
547
+ assert len(filtered_spans) == 2
548
+ assert "user_request" not in filtered_span_names # root span
551
549
  assert "gen_ai.completion" in filtered_span_names # LLM name
552
550
  assert "response_formatting" in filtered_span_names # LLM attribute
553
551
 
552
+ def test_custom_filter_is_root_span(self):
553
+ from braintrust.otel import AISpanProcessor, is_root_span
554
+ from opentelemetry.sdk.trace import TracerProvider
555
+ from opentelemetry.sdk.trace.export import SimpleSpanProcessor
556
+ from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
557
+
558
+ memory_exporter = InMemorySpanExporter()
559
+ processor = AISpanProcessor(SimpleSpanProcessor(memory_exporter), custom_filter=is_root_span)
560
+ provider = TracerProvider()
561
+ provider.add_span_processor(processor)
562
+ tracer = provider.get_tracer("test-braintrust-root-filter")
563
+
564
+ with tracer.start_as_current_span("root_span"):
565
+ with tracer.start_as_current_span("child_span"):
566
+ pass
567
+
568
+ provider.shutdown()
569
+ spans = memory_exporter.get_finished_spans()
570
+ names = [span.name for span in spans]
571
+ assert "root_span" in names
572
+ assert "child_span" not in names
554
573
 
555
574
  def test_parent_from_headers_invalid_inputs():
556
575
  """Test parent_from_headers with various invalid inputs."""
@@ -716,3 +735,76 @@ def test_add_span_parent_to_baggage():
716
735
  # Test with None span (should return None and warn)
717
736
  token = add_span_parent_to_baggage(None)
718
737
  assert token is None
738
+
739
+
740
+ def test_parent_from_headers_with_custom_propagator():
741
+ """Test parent_from_headers with a custom propagator."""
742
+ if not _check_otel_installed():
743
+ pytest.skip("OpenTelemetry SDK not fully installed, skipping test")
744
+
745
+ from braintrust.otel import parent_from_headers
746
+ from opentelemetry import baggage as otel_baggage
747
+ from opentelemetry import context as otel_context
748
+ from opentelemetry import trace
749
+ from opentelemetry.propagators.textmap import CarrierT, Getter, TextMapPropagator, default_getter
750
+ from opentelemetry.trace import NonRecordingSpan, SpanContext, TraceFlags
751
+
752
+ class CustomHeaderPropagator(TextMapPropagator):
753
+ """Custom propagator that reads trace context from X-Custom-* headers."""
754
+
755
+ def extract(
756
+ self,
757
+ carrier: CarrierT,
758
+ context: otel_context.Context | None = None,
759
+ getter: Getter = default_getter,
760
+ ) -> otel_context.Context:
761
+ if context is None:
762
+ context = otel_context.get_current()
763
+
764
+ trace_id = getter.get(carrier, "X-Custom-Trace-Id")
765
+ span_id = getter.get(carrier, "X-Custom-Span-Id")
766
+
767
+ if trace_id and span_id:
768
+ trace_id_list = trace_id if isinstance(trace_id, list) else [trace_id]
769
+ span_id_list = span_id if isinstance(span_id, list) else [span_id]
770
+
771
+ span_context = SpanContext(
772
+ trace_id=int(trace_id_list[0], 16),
773
+ span_id=int(span_id_list[0], 16),
774
+ is_remote=True,
775
+ trace_flags=TraceFlags.SAMPLED,
776
+ )
777
+ span = NonRecordingSpan(span_context)
778
+ context = trace.set_span_in_context(span, context)
779
+
780
+ # Also extract baggage from standard baggage header
781
+ baggage_header = getter.get(carrier, "baggage")
782
+ if baggage_header:
783
+ baggage_list = baggage_header if isinstance(baggage_header, list) else [baggage_header]
784
+ for item in baggage_list[0].split(","):
785
+ if "=" in item:
786
+ key, value = item.split("=", 1)
787
+ context = otel_baggage.set_baggage(key.strip(), value.strip(), context)
788
+
789
+ return context
790
+
791
+ def inject(self, carrier, context=None, setter=None):
792
+ pass # Not needed for this test
793
+
794
+ @property
795
+ def fields(self):
796
+ return {"X-Custom-Trace-Id", "X-Custom-Span-Id", "baggage"}
797
+
798
+ propagator = CustomHeaderPropagator()
799
+
800
+ # Custom header format
801
+ headers = {
802
+ "X-Custom-Trace-Id": "4bf92f3577b34da6a3ce929d0e0e4736",
803
+ "X-Custom-Span-Id": "00f067aa0ba902b7",
804
+ "baggage": "braintrust.parent=project_name:test-project",
805
+ }
806
+
807
+ result = parent_from_headers(headers, propagator=propagator)
808
+ assert result is not None
809
+ assert isinstance(result, str)
810
+ assert len(result) > 0
braintrust/test_util.py CHANGED
@@ -3,7 +3,7 @@ from typing import List
3
3
 
4
4
  import pytest
5
5
 
6
- from .util import LazyValue, mask_api_key
6
+ from .util import LazyValue, mask_api_key, merge_dicts_with_paths
7
7
 
8
8
 
9
9
  class TestLazyValue(unittest.TestCase):
@@ -160,3 +160,53 @@ def test_mask_api_key():
160
160
  assert mask_api_key("12345") == "12*45"
161
161
  for i in ["", "1", "12", "123", "1234"]:
162
162
  assert mask_api_key(i) == "*" * len(i)
163
+
164
+
165
+ class TestTagsSetUnionMerge:
166
+ def test_tags_arrays_are_merged_as_sets_by_default(self):
167
+ a = {"tags": ["a", "b"]}
168
+ b = {"tags": ["b", "c"]}
169
+ merge_dicts_with_paths(a, b, (), set())
170
+ assert set(a["tags"]) == {"a", "b", "c"}
171
+
172
+ def test_tags_merge_deduplicates_values(self):
173
+ a = {"tags": ["a", "b", "c"]}
174
+ b = {"tags": ["a", "b", "c", "d"]}
175
+ merge_dicts_with_paths(a, b, (), set())
176
+ assert set(a["tags"]) == {"a", "b", "c", "d"}
177
+
178
+ def test_tags_merge_works_when_merge_into_has_no_tags(self):
179
+ a = {"other": "data"}
180
+ b = {"tags": ["a", "b"]}
181
+ merge_dicts_with_paths(a, b, (), set())
182
+ assert set(a["tags"]) == {"a", "b"}
183
+
184
+ def test_tags_merge_works_when_merge_from_has_no_tags(self):
185
+ a = {"tags": ["a", "b"]}
186
+ b = {"other": "data"}
187
+ merge_dicts_with_paths(a, b, (), set())
188
+ assert set(a["tags"]) == {"a", "b"}
189
+
190
+ def test_tags_are_replaced_when_included_in_merge_paths(self):
191
+ a = {"tags": ["a", "b"]}
192
+ b = {"tags": ["c", "d"]}
193
+ merge_dicts_with_paths(a, b, (), {("tags",)})
194
+ assert a["tags"] == ["c", "d"]
195
+
196
+ def test_empty_tags_array_clears_tags_when_in_merge_paths(self):
197
+ a = {"tags": ["a", "b"]}
198
+ b = {"tags": []}
199
+ merge_dicts_with_paths(a, b, (), {("tags",)})
200
+ assert a["tags"] == []
201
+
202
+ def test_none_tags_replaces_tags(self):
203
+ a = {"tags": ["a", "b"]}
204
+ b = {"tags": None}
205
+ merge_dicts_with_paths(a, b, (), set())
206
+ assert a["tags"] is None
207
+
208
+ def test_set_union_only_applies_to_top_level_tags_field(self):
209
+ a = {"metadata": {"tags": ["a", "b"]}}
210
+ b = {"metadata": {"tags": ["c", "d"]}}
211
+ merge_dicts_with_paths(a, b, (), set())
212
+ assert a["metadata"]["tags"] == ["c", "d"]
braintrust/util.py CHANGED
@@ -1,4 +1,5 @@
1
1
  import inspect
2
+ import json
2
3
  import sys
3
4
  import threading
4
5
  import urllib.parse
@@ -29,11 +30,16 @@ def coalesce(*args):
29
30
  return None
30
31
 
31
32
 
33
+ # Fields that automatically use set-union merge semantics (unless in merge_paths).
34
+ _SET_UNION_FIELDS = frozenset(["tags"])
35
+
36
+
32
37
  def merge_dicts_with_paths(
33
- merge_into: dict[str, Any], merge_from: Mapping[str, Any], path: tuple[str, ...], merge_paths: set[tuple[str]]
38
+ merge_into: dict[str, Any], merge_from: Mapping[str, Any], path: tuple[str, ...], merge_paths: set[tuple[str, ...]]
34
39
  ) -> dict[str, Any]:
35
40
  """Merges merge_from into merge_into, destructively updating merge_into. Does not merge any further than
36
- merge_paths."""
41
+ merge_paths. For fields in _SET_UNION_FIELDS (like "tags"), arrays are merged as sets (union)
42
+ unless the field is explicitly listed in merge_paths (opt-out to replacement)."""
37
43
 
38
44
  if not isinstance(merge_into, dict):
39
45
  raise ValueError("merge_into must be a dictionary")
@@ -43,7 +49,22 @@ def merge_dicts_with_paths(
43
49
  for k, merge_from_v in merge_from.items():
44
50
  full_path = path + (k,)
45
51
  merge_into_v = merge_into.get(k)
46
- if isinstance(merge_into_v, dict) and isinstance(merge_from_v, dict) and full_path not in merge_paths:
52
+
53
+ # Check if this field should use set-union merge (e.g., "tags" at top level)
54
+ is_set_union_field = len(path) == 0 and k in _SET_UNION_FIELDS and full_path not in merge_paths
55
+
56
+ if is_set_union_field and isinstance(merge_into_v, list) and isinstance(merge_from_v, list):
57
+ # Set-union merge: combine arrays, deduplicate using JSON for objects
58
+ seen: set[str] = set()
59
+ combined = []
60
+ for item in merge_into_v + list(merge_from_v):
61
+ # Use JSON serialization for consistent object comparison
62
+ item_key = json.dumps(item, sort_keys=True) if isinstance(item, (dict, list)) else str(item)
63
+ if item_key not in seen:
64
+ seen.add(item_key)
65
+ combined.append(item)
66
+ merge_into[k] = combined
67
+ elif isinstance(merge_into_v, dict) and isinstance(merge_from_v, dict) and full_path not in merge_paths:
47
68
  merge_dicts_with_paths(merge_into_v, merge_from_v, full_path, merge_paths)
48
69
  else:
49
70
  merge_into[k] = merge_from_v
braintrust/version.py CHANGED
@@ -1,4 +1,4 @@
1
- VERSION = "0.4.0"
1
+ VERSION = "0.4.2"
2
2
 
3
3
  # this will be templated during the build
4
- GIT_COMMIT = "8ab13f3f48af6a4d3c0b053e4bbabfd4f24f23ec"
4
+ GIT_COMMIT = "3ca420e53e77d4665b91ccc7631c95dc97ce566d"
@@ -3,6 +3,7 @@ import time
3
3
  from collections.abc import Iterable
4
4
  from typing import Any
5
5
 
6
+ from braintrust.bt_json import bt_safe_deep_copy
6
7
  from braintrust.logger import NOOP_SPAN, Attachment, current_span, init_logger, start_span
7
8
  from braintrust.span_types import SpanTypeAttribute
8
9
  from wrapt import wrap_function_wrapper
@@ -149,7 +150,7 @@ def wrap_async_models(AsyncModels: Any):
149
150
 
150
151
 
151
152
  def _serialize_input(api_client: Any, input: dict[str, Any]):
152
- config = _try_dict(input.get("config"))
153
+ config = bt_safe_deep_copy(input.get("config"))
153
154
 
154
155
  if config is not None:
155
156
  tools = _serialize_tools(api_client, input)
@@ -424,17 +425,3 @@ def get_path(obj: dict[str, Any], path: str, default: Any = None) -> Any | None:
424
425
  current = current[key]
425
426
 
426
427
  return current
427
-
428
-
429
- def _try_dict(obj: Any) -> dict[str, Any] | None:
430
- try:
431
- return obj.model_dump()
432
- except AttributeError:
433
- pass
434
-
435
- try:
436
- return obj.dump()
437
- except AttributeError:
438
- pass
439
-
440
- return obj
@@ -657,9 +657,52 @@ def patch_litellm():
657
657
  import litellm
658
658
 
659
659
  if not hasattr(litellm, "_braintrust_wrapped"):
660
+ # Store originals for unpatch_litellm()
661
+ litellm._braintrust_original_completion = litellm.completion
662
+ litellm._braintrust_original_acompletion = litellm.acompletion
663
+ litellm._braintrust_original_responses = litellm.responses
664
+ litellm._braintrust_original_aresponses = litellm.aresponses
665
+
660
666
  wrapped = wrap_litellm(litellm)
661
667
  litellm.completion = wrapped.completion
662
668
  litellm.acompletion = wrapped.acompletion
669
+ litellm.responses = wrapped.responses
670
+ litellm.aresponses = wrapped.aresponses
663
671
  litellm._braintrust_wrapped = True
664
672
  except ImportError:
665
673
  pass # litellm not available
674
+
675
+
676
+ def unpatch_litellm():
677
+ """
678
+ Restore LiteLLM to its original state, removing Braintrust tracing.
679
+
680
+ This undoes the patching done by patch_litellm(), restoring the original
681
+ completion, acompletion, responses, and aresponses functions.
682
+
683
+ Example:
684
+ ```python
685
+ import braintrust
686
+ braintrust.patch_litellm()
687
+
688
+ # ... use litellm with tracing ...
689
+
690
+ braintrust.unpatch_litellm() # restore original behavior
691
+ ```
692
+ """
693
+ try:
694
+ import litellm
695
+
696
+ if hasattr(litellm, "_braintrust_wrapped"):
697
+ litellm.completion = litellm._braintrust_original_completion
698
+ litellm.acompletion = litellm._braintrust_original_acompletion
699
+ litellm.responses = litellm._braintrust_original_responses
700
+ litellm.aresponses = litellm._braintrust_original_aresponses
701
+
702
+ delattr(litellm, "_braintrust_wrapped")
703
+ delattr(litellm, "_braintrust_original_completion")
704
+ delattr(litellm, "_braintrust_original_acompletion")
705
+ delattr(litellm, "_braintrust_original_responses")
706
+ delattr(litellm, "_braintrust_original_aresponses")
707
+ except ImportError:
708
+ pass # litellm not available