braintrust 0.4.0__py3-none-any.whl → 0.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braintrust/bt_json.py +178 -19
- braintrust/db_fields.py +1 -0
- braintrust/framework.py +13 -4
- braintrust/logger.py +30 -120
- braintrust/otel/__init__.py +24 -15
- braintrust/test_bt_json.py +644 -0
- braintrust/test_framework.py +81 -0
- braintrust/test_logger.py +245 -107
- braintrust/test_otel.py +118 -26
- braintrust/test_util.py +51 -1
- braintrust/util.py +24 -3
- braintrust/version.py +2 -2
- braintrust/wrappers/google_genai/__init__.py +2 -15
- braintrust/wrappers/litellm.py +43 -0
- braintrust/wrappers/pydantic_ai.py +209 -95
- braintrust/wrappers/test_google_genai.py +62 -1
- braintrust/wrappers/test_litellm.py +73 -0
- braintrust/wrappers/test_pydantic_ai_integration.py +819 -22
- {braintrust-0.4.0.dist-info → braintrust-0.4.2.dist-info}/METADATA +1 -1
- {braintrust-0.4.0.dist-info → braintrust-0.4.2.dist-info}/RECORD +23 -22
- {braintrust-0.4.0.dist-info → braintrust-0.4.2.dist-info}/WHEEL +0 -0
- {braintrust-0.4.0.dist-info → braintrust-0.4.2.dist-info}/entry_points.txt +0 -0
- {braintrust-0.4.0.dist-info → braintrust-0.4.2.dist-info}/top_level.txt +0 -0
braintrust/test_otel.py
CHANGED
|
@@ -294,13 +294,12 @@ class TestSpanFiltering:
|
|
|
294
294
|
self.provider.shutdown()
|
|
295
295
|
self.memory_exporter.clear()
|
|
296
296
|
|
|
297
|
-
def
|
|
297
|
+
def test_filters_out_root_spans(self):
|
|
298
298
|
with self.tracer.start_as_current_span("root_operation"):
|
|
299
299
|
pass
|
|
300
300
|
|
|
301
301
|
spans = self.memory_exporter.get_finished_spans()
|
|
302
|
-
assert len(spans) ==
|
|
303
|
-
assert spans[0].name == "root_operation"
|
|
302
|
+
assert len(spans) == 0
|
|
304
303
|
|
|
305
304
|
def test_keeps_gen_ai_spans(self):
|
|
306
305
|
with self.tracer.start_as_current_span("root"):
|
|
@@ -312,7 +311,7 @@ class TestSpanFiltering:
|
|
|
312
311
|
spans = self.memory_exporter.get_finished_spans()
|
|
313
312
|
span_names = [span.name for span in spans]
|
|
314
313
|
|
|
315
|
-
assert "root" in span_names
|
|
314
|
+
assert "root" not in span_names
|
|
316
315
|
assert "gen_ai.completion" in span_names
|
|
317
316
|
assert "regular_operation" not in span_names
|
|
318
317
|
|
|
@@ -329,35 +328,37 @@ class TestSpanFiltering:
|
|
|
329
328
|
assert "braintrust.eval" in span_names
|
|
330
329
|
assert "database_query" not in span_names
|
|
331
330
|
|
|
332
|
-
def
|
|
331
|
+
def test_keeps_traceloop_spans(self):
|
|
333
332
|
with self.tracer.start_as_current_span("root"):
|
|
334
|
-
with self.tracer.start_as_current_span("
|
|
333
|
+
with self.tracer.start_as_current_span("traceloop.agent"):
|
|
334
|
+
pass
|
|
335
|
+
with self.tracer.start_as_current_span("traceloop.workflow.step"):
|
|
335
336
|
pass
|
|
336
337
|
|
|
337
338
|
spans = self.memory_exporter.get_finished_spans()
|
|
338
339
|
span_names = [span.name for span in spans]
|
|
339
|
-
assert "
|
|
340
|
+
assert "root" not in span_names
|
|
341
|
+
assert "traceloop.agent" in span_names
|
|
342
|
+
assert "traceloop.workflow.step" in span_names
|
|
340
343
|
|
|
341
|
-
def
|
|
344
|
+
def test_keeps_llm_spans(self):
|
|
342
345
|
with self.tracer.start_as_current_span("root"):
|
|
343
|
-
with self.tracer.start_as_current_span("
|
|
346
|
+
with self.tracer.start_as_current_span("llm.generate"):
|
|
344
347
|
pass
|
|
345
348
|
|
|
346
349
|
spans = self.memory_exporter.get_finished_spans()
|
|
347
350
|
span_names = [span.name for span in spans]
|
|
348
|
-
assert "
|
|
351
|
+
assert "llm.generate" in span_names
|
|
349
352
|
|
|
350
|
-
def
|
|
353
|
+
def test_keeps_ai_spans(self):
|
|
351
354
|
with self.tracer.start_as_current_span("root"):
|
|
352
|
-
with self.tracer.start_as_current_span("
|
|
353
|
-
pass
|
|
354
|
-
with self.tracer.start_as_current_span("traceloop.workflow.step"):
|
|
355
|
+
with self.tracer.start_as_current_span("ai.model_call"):
|
|
355
356
|
pass
|
|
356
357
|
|
|
357
358
|
spans = self.memory_exporter.get_finished_spans()
|
|
358
359
|
span_names = [span.name for span in spans]
|
|
359
|
-
assert "
|
|
360
|
-
assert "
|
|
360
|
+
assert "root" not in span_names
|
|
361
|
+
assert "ai.model_call" in span_names
|
|
361
362
|
|
|
362
363
|
def test_keeps_spans_with_llm_attributes(self):
|
|
363
364
|
with self.tracer.start_as_current_span("root"):
|
|
@@ -374,7 +375,7 @@ class TestSpanFiltering:
|
|
|
374
375
|
spans = self.memory_exporter.get_finished_spans()
|
|
375
376
|
span_names = [span.name for span in spans]
|
|
376
377
|
|
|
377
|
-
assert "root" in span_names
|
|
378
|
+
assert "root" not in span_names
|
|
378
379
|
assert "some_operation" in span_names # has gen_ai.model attribute
|
|
379
380
|
assert "another_operation" in span_names # has llm.tokens attribute
|
|
380
381
|
assert "traceloop_operation" in span_names # has traceloop.agent_id attribute
|
|
@@ -390,10 +391,7 @@ class TestSpanFiltering:
|
|
|
390
391
|
pass
|
|
391
392
|
|
|
392
393
|
spans = self.memory_exporter.get_finished_spans()
|
|
393
|
-
|
|
394
|
-
# Only root should be kept
|
|
395
|
-
assert len(spans) == 1
|
|
396
|
-
assert spans[0].name == "root"
|
|
394
|
+
assert len(spans) == 0
|
|
397
395
|
|
|
398
396
|
def test_custom_filter_keeps_spans(self):
|
|
399
397
|
def custom_filter(span):
|
|
@@ -422,9 +420,9 @@ class TestSpanFiltering:
|
|
|
422
420
|
spans = memory_exporter.get_finished_spans()
|
|
423
421
|
span_names = [span.name for span in spans]
|
|
424
422
|
|
|
425
|
-
assert "root" in span_names
|
|
426
423
|
assert "custom_keep" in span_names # kept by custom filter
|
|
427
424
|
assert "regular_operation" not in span_names # dropped by default logic
|
|
425
|
+
assert "root" not in span_names
|
|
428
426
|
|
|
429
427
|
def test_custom_filter_drops_spans(self):
|
|
430
428
|
def custom_filter(span):
|
|
@@ -453,9 +451,9 @@ class TestSpanFiltering:
|
|
|
453
451
|
spans = memory_exporter.get_finished_spans()
|
|
454
452
|
span_names = [span.name for span in spans]
|
|
455
453
|
|
|
456
|
-
assert "root" in span_names
|
|
457
454
|
assert "gen_ai.drop_this" not in span_names # dropped by custom filter
|
|
458
455
|
assert "gen_ai.keep_this" in span_names # kept by default LLM logic
|
|
456
|
+
assert "root" not in span_names
|
|
459
457
|
|
|
460
458
|
def test_custom_filter_none_uses_default_logic(self):
|
|
461
459
|
def custom_filter(span):
|
|
@@ -482,7 +480,7 @@ class TestSpanFiltering:
|
|
|
482
480
|
spans = memory_exporter.get_finished_spans()
|
|
483
481
|
span_names = [span.name for span in spans]
|
|
484
482
|
|
|
485
|
-
assert "root" in span_names
|
|
483
|
+
assert "root" not in span_names
|
|
486
484
|
assert "gen_ai.completion" in span_names # kept by default LLM logic
|
|
487
485
|
assert "regular_operation" not in span_names # dropped by default logic
|
|
488
486
|
|
|
@@ -546,11 +544,32 @@ class TestSpanFiltering:
|
|
|
546
544
|
filtered_spans = filtered_spans_exporter.get_finished_spans()
|
|
547
545
|
filtered_span_names = [span.name for span in filtered_spans]
|
|
548
546
|
|
|
549
|
-
assert len(filtered_spans) ==
|
|
550
|
-
assert "user_request" in filtered_span_names # root span
|
|
547
|
+
assert len(filtered_spans) == 2
|
|
548
|
+
assert "user_request" not in filtered_span_names # root span
|
|
551
549
|
assert "gen_ai.completion" in filtered_span_names # LLM name
|
|
552
550
|
assert "response_formatting" in filtered_span_names # LLM attribute
|
|
553
551
|
|
|
552
|
+
def test_custom_filter_is_root_span(self):
|
|
553
|
+
from braintrust.otel import AISpanProcessor, is_root_span
|
|
554
|
+
from opentelemetry.sdk.trace import TracerProvider
|
|
555
|
+
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
|
|
556
|
+
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
|
|
557
|
+
|
|
558
|
+
memory_exporter = InMemorySpanExporter()
|
|
559
|
+
processor = AISpanProcessor(SimpleSpanProcessor(memory_exporter), custom_filter=is_root_span)
|
|
560
|
+
provider = TracerProvider()
|
|
561
|
+
provider.add_span_processor(processor)
|
|
562
|
+
tracer = provider.get_tracer("test-braintrust-root-filter")
|
|
563
|
+
|
|
564
|
+
with tracer.start_as_current_span("root_span"):
|
|
565
|
+
with tracer.start_as_current_span("child_span"):
|
|
566
|
+
pass
|
|
567
|
+
|
|
568
|
+
provider.shutdown()
|
|
569
|
+
spans = memory_exporter.get_finished_spans()
|
|
570
|
+
names = [span.name for span in spans]
|
|
571
|
+
assert "root_span" in names
|
|
572
|
+
assert "child_span" not in names
|
|
554
573
|
|
|
555
574
|
def test_parent_from_headers_invalid_inputs():
|
|
556
575
|
"""Test parent_from_headers with various invalid inputs."""
|
|
@@ -716,3 +735,76 @@ def test_add_span_parent_to_baggage():
|
|
|
716
735
|
# Test with None span (should return None and warn)
|
|
717
736
|
token = add_span_parent_to_baggage(None)
|
|
718
737
|
assert token is None
|
|
738
|
+
|
|
739
|
+
|
|
740
|
+
def test_parent_from_headers_with_custom_propagator():
|
|
741
|
+
"""Test parent_from_headers with a custom propagator."""
|
|
742
|
+
if not _check_otel_installed():
|
|
743
|
+
pytest.skip("OpenTelemetry SDK not fully installed, skipping test")
|
|
744
|
+
|
|
745
|
+
from braintrust.otel import parent_from_headers
|
|
746
|
+
from opentelemetry import baggage as otel_baggage
|
|
747
|
+
from opentelemetry import context as otel_context
|
|
748
|
+
from opentelemetry import trace
|
|
749
|
+
from opentelemetry.propagators.textmap import CarrierT, Getter, TextMapPropagator, default_getter
|
|
750
|
+
from opentelemetry.trace import NonRecordingSpan, SpanContext, TraceFlags
|
|
751
|
+
|
|
752
|
+
class CustomHeaderPropagator(TextMapPropagator):
|
|
753
|
+
"""Custom propagator that reads trace context from X-Custom-* headers."""
|
|
754
|
+
|
|
755
|
+
def extract(
|
|
756
|
+
self,
|
|
757
|
+
carrier: CarrierT,
|
|
758
|
+
context: otel_context.Context | None = None,
|
|
759
|
+
getter: Getter = default_getter,
|
|
760
|
+
) -> otel_context.Context:
|
|
761
|
+
if context is None:
|
|
762
|
+
context = otel_context.get_current()
|
|
763
|
+
|
|
764
|
+
trace_id = getter.get(carrier, "X-Custom-Trace-Id")
|
|
765
|
+
span_id = getter.get(carrier, "X-Custom-Span-Id")
|
|
766
|
+
|
|
767
|
+
if trace_id and span_id:
|
|
768
|
+
trace_id_list = trace_id if isinstance(trace_id, list) else [trace_id]
|
|
769
|
+
span_id_list = span_id if isinstance(span_id, list) else [span_id]
|
|
770
|
+
|
|
771
|
+
span_context = SpanContext(
|
|
772
|
+
trace_id=int(trace_id_list[0], 16),
|
|
773
|
+
span_id=int(span_id_list[0], 16),
|
|
774
|
+
is_remote=True,
|
|
775
|
+
trace_flags=TraceFlags.SAMPLED,
|
|
776
|
+
)
|
|
777
|
+
span = NonRecordingSpan(span_context)
|
|
778
|
+
context = trace.set_span_in_context(span, context)
|
|
779
|
+
|
|
780
|
+
# Also extract baggage from standard baggage header
|
|
781
|
+
baggage_header = getter.get(carrier, "baggage")
|
|
782
|
+
if baggage_header:
|
|
783
|
+
baggage_list = baggage_header if isinstance(baggage_header, list) else [baggage_header]
|
|
784
|
+
for item in baggage_list[0].split(","):
|
|
785
|
+
if "=" in item:
|
|
786
|
+
key, value = item.split("=", 1)
|
|
787
|
+
context = otel_baggage.set_baggage(key.strip(), value.strip(), context)
|
|
788
|
+
|
|
789
|
+
return context
|
|
790
|
+
|
|
791
|
+
def inject(self, carrier, context=None, setter=None):
|
|
792
|
+
pass # Not needed for this test
|
|
793
|
+
|
|
794
|
+
@property
|
|
795
|
+
def fields(self):
|
|
796
|
+
return {"X-Custom-Trace-Id", "X-Custom-Span-Id", "baggage"}
|
|
797
|
+
|
|
798
|
+
propagator = CustomHeaderPropagator()
|
|
799
|
+
|
|
800
|
+
# Custom header format
|
|
801
|
+
headers = {
|
|
802
|
+
"X-Custom-Trace-Id": "4bf92f3577b34da6a3ce929d0e0e4736",
|
|
803
|
+
"X-Custom-Span-Id": "00f067aa0ba902b7",
|
|
804
|
+
"baggage": "braintrust.parent=project_name:test-project",
|
|
805
|
+
}
|
|
806
|
+
|
|
807
|
+
result = parent_from_headers(headers, propagator=propagator)
|
|
808
|
+
assert result is not None
|
|
809
|
+
assert isinstance(result, str)
|
|
810
|
+
assert len(result) > 0
|
braintrust/test_util.py
CHANGED
|
@@ -3,7 +3,7 @@ from typing import List
|
|
|
3
3
|
|
|
4
4
|
import pytest
|
|
5
5
|
|
|
6
|
-
from .util import LazyValue, mask_api_key
|
|
6
|
+
from .util import LazyValue, mask_api_key, merge_dicts_with_paths
|
|
7
7
|
|
|
8
8
|
|
|
9
9
|
class TestLazyValue(unittest.TestCase):
|
|
@@ -160,3 +160,53 @@ def test_mask_api_key():
|
|
|
160
160
|
assert mask_api_key("12345") == "12*45"
|
|
161
161
|
for i in ["", "1", "12", "123", "1234"]:
|
|
162
162
|
assert mask_api_key(i) == "*" * len(i)
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
class TestTagsSetUnionMerge:
|
|
166
|
+
def test_tags_arrays_are_merged_as_sets_by_default(self):
|
|
167
|
+
a = {"tags": ["a", "b"]}
|
|
168
|
+
b = {"tags": ["b", "c"]}
|
|
169
|
+
merge_dicts_with_paths(a, b, (), set())
|
|
170
|
+
assert set(a["tags"]) == {"a", "b", "c"}
|
|
171
|
+
|
|
172
|
+
def test_tags_merge_deduplicates_values(self):
|
|
173
|
+
a = {"tags": ["a", "b", "c"]}
|
|
174
|
+
b = {"tags": ["a", "b", "c", "d"]}
|
|
175
|
+
merge_dicts_with_paths(a, b, (), set())
|
|
176
|
+
assert set(a["tags"]) == {"a", "b", "c", "d"}
|
|
177
|
+
|
|
178
|
+
def test_tags_merge_works_when_merge_into_has_no_tags(self):
|
|
179
|
+
a = {"other": "data"}
|
|
180
|
+
b = {"tags": ["a", "b"]}
|
|
181
|
+
merge_dicts_with_paths(a, b, (), set())
|
|
182
|
+
assert set(a["tags"]) == {"a", "b"}
|
|
183
|
+
|
|
184
|
+
def test_tags_merge_works_when_merge_from_has_no_tags(self):
|
|
185
|
+
a = {"tags": ["a", "b"]}
|
|
186
|
+
b = {"other": "data"}
|
|
187
|
+
merge_dicts_with_paths(a, b, (), set())
|
|
188
|
+
assert set(a["tags"]) == {"a", "b"}
|
|
189
|
+
|
|
190
|
+
def test_tags_are_replaced_when_included_in_merge_paths(self):
|
|
191
|
+
a = {"tags": ["a", "b"]}
|
|
192
|
+
b = {"tags": ["c", "d"]}
|
|
193
|
+
merge_dicts_with_paths(a, b, (), {("tags",)})
|
|
194
|
+
assert a["tags"] == ["c", "d"]
|
|
195
|
+
|
|
196
|
+
def test_empty_tags_array_clears_tags_when_in_merge_paths(self):
|
|
197
|
+
a = {"tags": ["a", "b"]}
|
|
198
|
+
b = {"tags": []}
|
|
199
|
+
merge_dicts_with_paths(a, b, (), {("tags",)})
|
|
200
|
+
assert a["tags"] == []
|
|
201
|
+
|
|
202
|
+
def test_none_tags_replaces_tags(self):
|
|
203
|
+
a = {"tags": ["a", "b"]}
|
|
204
|
+
b = {"tags": None}
|
|
205
|
+
merge_dicts_with_paths(a, b, (), set())
|
|
206
|
+
assert a["tags"] is None
|
|
207
|
+
|
|
208
|
+
def test_set_union_only_applies_to_top_level_tags_field(self):
|
|
209
|
+
a = {"metadata": {"tags": ["a", "b"]}}
|
|
210
|
+
b = {"metadata": {"tags": ["c", "d"]}}
|
|
211
|
+
merge_dicts_with_paths(a, b, (), set())
|
|
212
|
+
assert a["metadata"]["tags"] == ["c", "d"]
|
braintrust/util.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import inspect
|
|
2
|
+
import json
|
|
2
3
|
import sys
|
|
3
4
|
import threading
|
|
4
5
|
import urllib.parse
|
|
@@ -29,11 +30,16 @@ def coalesce(*args):
|
|
|
29
30
|
return None
|
|
30
31
|
|
|
31
32
|
|
|
33
|
+
# Fields that automatically use set-union merge semantics (unless in merge_paths).
|
|
34
|
+
_SET_UNION_FIELDS = frozenset(["tags"])
|
|
35
|
+
|
|
36
|
+
|
|
32
37
|
def merge_dicts_with_paths(
|
|
33
|
-
merge_into: dict[str, Any], merge_from: Mapping[str, Any], path: tuple[str, ...], merge_paths: set[tuple[str]]
|
|
38
|
+
merge_into: dict[str, Any], merge_from: Mapping[str, Any], path: tuple[str, ...], merge_paths: set[tuple[str, ...]]
|
|
34
39
|
) -> dict[str, Any]:
|
|
35
40
|
"""Merges merge_from into merge_into, destructively updating merge_into. Does not merge any further than
|
|
36
|
-
merge_paths.""
|
|
41
|
+
merge_paths. For fields in _SET_UNION_FIELDS (like "tags"), arrays are merged as sets (union)
|
|
42
|
+
unless the field is explicitly listed in merge_paths (opt-out to replacement)."""
|
|
37
43
|
|
|
38
44
|
if not isinstance(merge_into, dict):
|
|
39
45
|
raise ValueError("merge_into must be a dictionary")
|
|
@@ -43,7 +49,22 @@ def merge_dicts_with_paths(
|
|
|
43
49
|
for k, merge_from_v in merge_from.items():
|
|
44
50
|
full_path = path + (k,)
|
|
45
51
|
merge_into_v = merge_into.get(k)
|
|
46
|
-
|
|
52
|
+
|
|
53
|
+
# Check if this field should use set-union merge (e.g., "tags" at top level)
|
|
54
|
+
is_set_union_field = len(path) == 0 and k in _SET_UNION_FIELDS and full_path not in merge_paths
|
|
55
|
+
|
|
56
|
+
if is_set_union_field and isinstance(merge_into_v, list) and isinstance(merge_from_v, list):
|
|
57
|
+
# Set-union merge: combine arrays, deduplicate using JSON for objects
|
|
58
|
+
seen: set[str] = set()
|
|
59
|
+
combined = []
|
|
60
|
+
for item in merge_into_v + list(merge_from_v):
|
|
61
|
+
# Use JSON serialization for consistent object comparison
|
|
62
|
+
item_key = json.dumps(item, sort_keys=True) if isinstance(item, (dict, list)) else str(item)
|
|
63
|
+
if item_key not in seen:
|
|
64
|
+
seen.add(item_key)
|
|
65
|
+
combined.append(item)
|
|
66
|
+
merge_into[k] = combined
|
|
67
|
+
elif isinstance(merge_into_v, dict) and isinstance(merge_from_v, dict) and full_path not in merge_paths:
|
|
47
68
|
merge_dicts_with_paths(merge_into_v, merge_from_v, full_path, merge_paths)
|
|
48
69
|
else:
|
|
49
70
|
merge_into[k] = merge_from_v
|
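braintrust/version.py
CHANGED
[The version.py hunk (+2 -2) is missing from this extract. The hunks that follow match the "+2 -15" summary entry for the next file in the list.]
braintrust/wrappers/google_genai/__init__.py
CHANGED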
|
@@ -3,6 +3,7 @@ import time
|
|
|
3
3
|
from collections.abc import Iterable
|
|
4
4
|
from typing import Any
|
|
5
5
|
|
|
6
|
+
from braintrust.bt_json import bt_safe_deep_copy
|
|
6
7
|
from braintrust.logger import NOOP_SPAN, Attachment, current_span, init_logger, start_span
|
|
7
8
|
from braintrust.span_types import SpanTypeAttribute
|
|
8
9
|
from wrapt import wrap_function_wrapper
|
|
@@ -149,7 +150,7 @@ def wrap_async_models(AsyncModels: Any):
|
|
|
149
150
|
|
|
150
151
|
|
|
151
152
|
def _serialize_input(api_client: Any, input: dict[str, Any]):
|
|
152
|
-
config =
|
|
153
|
+
config = bt_safe_deep_copy(input.get("config"))
|
|
153
154
|
|
|
154
155
|
if config is not None:
|
|
155
156
|
tools = _serialize_tools(api_client, input)
|
|
@@ -424,17 +425,3 @@ def get_path(obj: dict[str, Any], path: str, default: Any = None) -> Any | None:
|
|
|
424
425
|
current = current[key]
|
|
425
426
|
|
|
426
427
|
return current
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
def _try_dict(obj: Any) -> dict[str, Any] | None:
|
|
430
|
-
try:
|
|
431
|
-
return obj.model_dump()
|
|
432
|
-
except AttributeError:
|
|
433
|
-
pass
|
|
434
|
-
|
|
435
|
-
try:
|
|
436
|
-
return obj.dump()
|
|
437
|
-
except AttributeError:
|
|
438
|
-
pass
|
|
439
|
-
|
|
440
|
-
return obj
|
braintrust/wrappers/litellm.py
CHANGED
|
@@ -657,9 +657,52 @@ def patch_litellm():
|
|
|
657
657
|
import litellm
|
|
658
658
|
|
|
659
659
|
if not hasattr(litellm, "_braintrust_wrapped"):
|
|
660
|
+
# Store originals for unpatch_litellm()
|
|
661
|
+
litellm._braintrust_original_completion = litellm.completion
|
|
662
|
+
litellm._braintrust_original_acompletion = litellm.acompletion
|
|
663
|
+
litellm._braintrust_original_responses = litellm.responses
|
|
664
|
+
litellm._braintrust_original_aresponses = litellm.aresponses
|
|
665
|
+
|
|
660
666
|
wrapped = wrap_litellm(litellm)
|
|
661
667
|
litellm.completion = wrapped.completion
|
|
662
668
|
litellm.acompletion = wrapped.acompletion
|
|
669
|
+
litellm.responses = wrapped.responses
|
|
670
|
+
litellm.aresponses = wrapped.aresponses
|
|
663
671
|
litellm._braintrust_wrapped = True
|
|
664
672
|
except ImportError:
|
|
665
673
|
pass # litellm not available
|
|
674
|
+
|
|
675
|
+
|
|
676
|
+
def unpatch_litellm():
|
|
677
|
+
"""
|
|
678
|
+
Restore LiteLLM to its original state, removing Braintrust tracing.
|
|
679
|
+
|
|
680
|
+
This undoes the patching done by patch_litellm(), restoring the original
|
|
681
|
+
completion, acompletion, responses, and aresponses functions.
|
|
682
|
+
|
|
683
|
+
Example:
|
|
684
|
+
```python
|
|
685
|
+
import braintrust
|
|
686
|
+
braintrust.patch_litellm()
|
|
687
|
+
|
|
688
|
+
# ... use litellm with tracing ...
|
|
689
|
+
|
|
690
|
+
braintrust.unpatch_litellm() # restore original behavior
|
|
691
|
+
```
|
|
692
|
+
"""
|
|
693
|
+
try:
|
|
694
|
+
import litellm
|
|
695
|
+
|
|
696
|
+
if hasattr(litellm, "_braintrust_wrapped"):
|
|
697
|
+
litellm.completion = litellm._braintrust_original_completion
|
|
698
|
+
litellm.acompletion = litellm._braintrust_original_acompletion
|
|
699
|
+
litellm.responses = litellm._braintrust_original_responses
|
|
700
|
+
litellm.aresponses = litellm._braintrust_original_aresponses
|
|
701
|
+
|
|
702
|
+
delattr(litellm, "_braintrust_wrapped")
|
|
703
|
+
delattr(litellm, "_braintrust_original_completion")
|
|
704
|
+
delattr(litellm, "_braintrust_original_acompletion")
|
|
705
|
+
delattr(litellm, "_braintrust_original_responses")
|
|
706
|
+
delattr(litellm, "_braintrust_original_aresponses")
|
|
707
|
+
except ImportError:
|
|
708
|
+
pass # litellm not available
|