braintrust 0.5.0__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. braintrust/__init__.py +3 -0
  2. braintrust/auto.py +179 -0
  3. braintrust/conftest.py +23 -4
  4. braintrust/framework.py +18 -5
  5. braintrust/logger.py +49 -13
  6. braintrust/oai.py +51 -0
  7. braintrust/test_bt_json.py +0 -5
  8. braintrust/test_framework.py +37 -0
  9. braintrust/test_http.py +444 -0
  10. braintrust/test_logger.py +179 -5
  11. braintrust/test_util.py +58 -1
  12. braintrust/util.py +20 -0
  13. braintrust/version.py +2 -2
  14. braintrust/wrappers/agno/__init__.py +2 -3
  15. braintrust/wrappers/anthropic.py +64 -0
  16. braintrust/wrappers/claude_agent_sdk/__init__.py +2 -3
  17. braintrust/wrappers/claude_agent_sdk/test_wrapper.py +9 -0
  18. braintrust/wrappers/dspy.py +52 -1
  19. braintrust/wrappers/google_genai/__init__.py +9 -6
  20. braintrust/wrappers/litellm.py +6 -43
  21. braintrust/wrappers/pydantic_ai.py +2 -3
  22. braintrust/wrappers/test_agno.py +9 -0
  23. braintrust/wrappers/test_anthropic.py +156 -0
  24. braintrust/wrappers/test_dspy.py +117 -0
  25. braintrust/wrappers/test_google_genai.py +9 -0
  26. braintrust/wrappers/test_litellm.py +57 -55
  27. braintrust/wrappers/test_openai.py +253 -1
  28. braintrust/wrappers/test_pydantic_ai_integration.py +9 -0
  29. braintrust/wrappers/test_utils.py +79 -0
  30. {braintrust-0.5.0.dist-info → braintrust-0.5.2.dist-info}/METADATA +1 -1
  31. {braintrust-0.5.0.dist-info → braintrust-0.5.2.dist-info}/RECORD +34 -32
  32. {braintrust-0.5.0.dist-info → braintrust-0.5.2.dist-info}/WHEEL +1 -1
  33. {braintrust-0.5.0.dist-info → braintrust-0.5.2.dist-info}/entry_points.txt +0 -0
  34. {braintrust-0.5.0.dist-info → braintrust-0.5.2.dist-info}/top_level.txt +0 -0
braintrust/__init__.py CHANGED
@@ -50,6 +50,9 @@ BRAINTRUST_API_KEY=<YOUR_BRAINTRUST_API_KEY> braintrust eval eval_hello.py
50
50
  """
51
51
 
52
52
  from .audit import *
53
+ from .auto import (
54
+ auto_instrument, # noqa: F401 # type: ignore[reportUnusedImport]
55
+ )
53
56
  from .framework import *
54
57
  from .framework2 import *
55
58
  from .functions.invoke import *
braintrust/auto.py ADDED
@@ -0,0 +1,179 @@
1
+ """
2
+ Auto-instrumentation for AI/ML libraries.
3
+
4
+ Provides one-line instrumentation for supported libraries.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import logging
10
+ from contextlib import contextmanager
11
+
12
+ __all__ = ["auto_instrument"]
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
+ @contextmanager
18
+ def _try_patch():
19
+ """Context manager that suppresses ImportError and logs other exceptions."""
20
+ try:
21
+ yield
22
+ except ImportError:
23
+ pass
24
+ except Exception:
25
+ logger.exception("Failed to instrument")
26
+
27
+
28
+ def auto_instrument(
29
+ *,
30
+ openai: bool = True,
31
+ anthropic: bool = True,
32
+ litellm: bool = True,
33
+ pydantic_ai: bool = True,
34
+ google_genai: bool = True,
35
+ agno: bool = True,
36
+ claude_agent_sdk: bool = True,
37
+ dspy: bool = True,
38
+ ) -> dict[str, bool]:
39
+ """
40
+ Auto-instrument supported AI/ML libraries for Braintrust tracing.
41
+
42
+ Safe to call multiple times - already instrumented libraries are skipped.
43
+
44
+ Note on import order: If you use `from openai import OpenAI` style imports,
45
+ call auto_instrument() first. If you use `import openai` style imports,
46
+ order doesn't matter since attribute lookup happens dynamically.
47
+
48
+ Args:
49
+ openai: Enable OpenAI instrumentation (default: True)
50
+ anthropic: Enable Anthropic instrumentation (default: True)
51
+ litellm: Enable LiteLLM instrumentation (default: True)
52
+ pydantic_ai: Enable Pydantic AI instrumentation (default: True)
53
+ google_genai: Enable Google GenAI instrumentation (default: True)
54
+ agno: Enable Agno instrumentation (default: True)
55
+ claude_agent_sdk: Enable Claude Agent SDK instrumentation (default: True)
56
+ dspy: Enable DSPy instrumentation (default: True)
57
+
58
+ Returns:
59
+ Dict mapping integration name to whether it was successfully instrumented.
60
+
61
+ Example:
62
+ ```python
63
+ import braintrust
64
+ braintrust.auto_instrument()
65
+
66
+ # OpenAI
67
+ import openai
68
+ client = openai.OpenAI()
69
+ client.chat.completions.create(model="gpt-4o-mini", messages=[...])
70
+
71
+ # Anthropic
72
+ import anthropic
73
+ client = anthropic.Anthropic()
74
+ client.messages.create(model="claude-sonnet-4-20250514", messages=[...])
75
+
76
+ # LiteLLM
77
+ import litellm
78
+ litellm.completion(model="gpt-4o-mini", messages=[...])
79
+
80
+ # DSPy
81
+ import dspy
82
+ lm = dspy.LM("openai/gpt-4o-mini")
83
+ dspy.configure(lm=lm)
84
+
85
+ # Pydantic AI
86
+ from pydantic_ai import Agent
87
+ agent = Agent("openai:gpt-4o-mini")
88
+ result = agent.run_sync("Hello!")
89
+
90
+ # Google GenAI
91
+ from google.genai import Client
92
+ client = Client()
93
+ client.models.generate_content(model="gemini-2.0-flash", contents="Hello!")
94
+ ```
95
+ """
96
+ results = {}
97
+
98
+ if openai:
99
+ results["openai"] = _instrument_openai()
100
+ if anthropic:
101
+ results["anthropic"] = _instrument_anthropic()
102
+ if litellm:
103
+ results["litellm"] = _instrument_litellm()
104
+ if pydantic_ai:
105
+ results["pydantic_ai"] = _instrument_pydantic_ai()
106
+ if google_genai:
107
+ results["google_genai"] = _instrument_google_genai()
108
+ if agno:
109
+ results["agno"] = _instrument_agno()
110
+ if claude_agent_sdk:
111
+ results["claude_agent_sdk"] = _instrument_claude_agent_sdk()
112
+ if dspy:
113
+ results["dspy"] = _instrument_dspy()
114
+
115
+ return results
116
+
117
+
118
+ def _instrument_openai() -> bool:
119
+ with _try_patch():
120
+ from braintrust.oai import patch_openai
121
+
122
+ return patch_openai()
123
+ return False
124
+
125
+
126
+ def _instrument_anthropic() -> bool:
127
+ with _try_patch():
128
+ from braintrust.wrappers.anthropic import patch_anthropic
129
+
130
+ return patch_anthropic()
131
+ return False
132
+
133
+
134
+ def _instrument_litellm() -> bool:
135
+ with _try_patch():
136
+ from braintrust.wrappers.litellm import patch_litellm
137
+
138
+ return patch_litellm()
139
+ return False
140
+
141
+
142
+ def _instrument_pydantic_ai() -> bool:
143
+ with _try_patch():
144
+ from braintrust.wrappers.pydantic_ai import setup_pydantic_ai
145
+
146
+ return setup_pydantic_ai()
147
+ return False
148
+
149
+
150
+ def _instrument_google_genai() -> bool:
151
+ with _try_patch():
152
+ from braintrust.wrappers.google_genai import setup_genai
153
+
154
+ return setup_genai()
155
+ return False
156
+
157
+
158
+ def _instrument_agno() -> bool:
159
+ with _try_patch():
160
+ from braintrust.wrappers.agno import setup_agno
161
+
162
+ return setup_agno()
163
+ return False
164
+
165
+
166
+ def _instrument_claude_agent_sdk() -> bool:
167
+ with _try_patch():
168
+ from braintrust.wrappers.claude_agent_sdk import setup_claude_agent_sdk
169
+
170
+ return setup_claude_agent_sdk()
171
+ return False
172
+
173
+
174
+ def _instrument_dspy() -> bool:
175
+ with _try_patch():
176
+ from braintrust.wrappers.dspy import patch_dspy
177
+
178
+ return patch_dspy()
179
+ return False
braintrust/conftest.py CHANGED
@@ -48,16 +48,29 @@ def reset_braintrust_state():
48
48
  logger._state = logger.BraintrustState()
49
49
 
50
50
 
51
- @pytest.fixture(scope="session")
52
- def vcr_config():
51
+ @pytest.fixture(autouse=True)
52
+ def skip_vcr_tests_in_wheel_mode(request):
53
+ """Skip VCR tests when running from an installed wheel.
54
+
55
+ Wheel mode (BRAINTRUST_TESTING_WHEEL=1) is a pre-release sanity check
56
+ that verifies the built package installs and runs correctly. It's not
57
+ intended to be a full test suite - VCR cassettes are not included in
58
+ the wheel, so we skip those tests here. The full test suite with VCR
59
+ tests runs against source code during normal CI.
60
+ """
61
+ if os.environ.get("BRAINTRUST_TESTING_WHEEL") == "1":
62
+ if request.node.get_closest_marker("vcr"):
63
+ pytest.skip("VCR tests skipped in wheel mode (pre-release sanity check only)")
64
+
65
+
66
+ def get_vcr_config():
53
67
  """
54
- VCR configuration for recording/playing back HTTP interactions.
68
+ Get VCR configuration for recording/playing back HTTP interactions.
55
69
 
56
70
  In CI, use "none" to fail if cassette is missing.
57
71
  Locally, use "once" to record new cassettes if they don't exist.
58
72
  """
59
73
  record_mode = "none" if (os.environ.get("CI") or os.environ.get("GITHUB_ACTIONS")) else "once"
60
-
61
74
  return {
62
75
  "record_mode": record_mode,
63
76
  "filter_headers": [
@@ -70,3 +83,9 @@ def vcr_config():
70
83
  "x-bt-auth-token",
71
84
  ],
72
85
  }
86
+
87
+
88
+ @pytest.fixture(scope="session")
89
+ def vcr_config():
90
+ """Pytest fixture wrapper for get_vcr_config()."""
91
+ return get_vcr_config()
braintrust/framework.py CHANGED
@@ -673,6 +673,7 @@ def _EvalCommon(
673
673
  stream: Callable[[SSEProgressEvent], None] | None = None,
674
674
  parent: str | None = None,
675
675
  state: BraintrustState | None = None,
676
+ enable_cache: bool = True,
676
677
  ) -> Callable[[], Coroutine[Any, Any, EvalResultWithSummary[Input, Output]]]:
677
678
  """
678
679
  This helper is needed because in case of `_lazy_load`, we need to update
@@ -759,7 +760,7 @@ def _EvalCommon(
759
760
  async def run_to_completion():
760
761
  with parent_context(parent, state):
761
762
  try:
762
- ret = await run_evaluator(experiment, evaluator, 0, [], stream, state)
763
+ ret = await run_evaluator(experiment, evaluator, 0, [], stream, state, enable_cache)
763
764
  reporter.report_eval(evaluator, ret, verbose=True, jsonl=False)
764
765
  return ret
765
766
  finally:
@@ -798,6 +799,7 @@ async def EvalAsync(
798
799
  stream: Callable[[SSEProgressEvent], None] | None = None,
799
800
  parent: str | None = None,
800
801
  state: BraintrustState | None = None,
802
+ enable_cache: bool = True,
801
803
  ) -> EvalResultWithSummary[Input, Output]:
802
804
  """
803
805
  A function you can use to define an evaluator. This is a convenience wrapper around the `Evaluator` class.
@@ -855,6 +857,8 @@ async def EvalAsync(
855
857
  :param parent: If specified, instead of creating a new experiment object, the Eval() will populate
856
858
  the object or span specified by this parent.
857
859
  :param state: Optional BraintrustState to use for the evaluation. If not specified, the global login state will be used.
860
+ :param enable_cache: Whether to enable the span cache for this evaluation. Defaults to True. The span cache stores
861
+ span data on disk to minimize memory usage and allow scorers to read spans without server round-trips.
858
862
  :return: An `EvalResultWithSummary` object, which contains all results and a summary.
859
863
  """
860
864
  f = _EvalCommon(
@@ -883,6 +887,7 @@ async def EvalAsync(
883
887
  stream=stream,
884
888
  parent=parent,
885
889
  state=state,
890
+ enable_cache=enable_cache,
886
891
  )
887
892
 
888
893
  return await f()
@@ -918,6 +923,7 @@ def Eval(
918
923
  stream: Callable[[SSEProgressEvent], None] | None = None,
919
924
  parent: str | None = None,
920
925
  state: BraintrustState | None = None,
926
+ enable_cache: bool = True,
921
927
  ) -> EvalResultWithSummary[Input, Output]:
922
928
  """
923
929
  A function you can use to define an evaluator. This is a convenience wrapper around the `Evaluator` class.
@@ -975,6 +981,8 @@ def Eval(
975
981
  :param parent: If specified, instead of creating a new experiment object, the Eval() will populate
976
982
  the object or span specified by this parent.
977
983
  :param state: Optional BraintrustState to use for the evaluation. If not specified, the global login state will be used.
984
+ :param enable_cache: Whether to enable the span cache for this evaluation. Defaults to True. The span cache stores
985
+ span data on disk to minimize memory usage and allow scorers to read spans without server round-trips.
978
986
  :return: An `EvalResultWithSummary` object, which contains all results and a summary.
979
987
  """
980
988
 
@@ -1005,6 +1013,7 @@ def Eval(
1005
1013
  stream=stream,
1006
1014
  parent=parent,
1007
1015
  state=state,
1016
+ enable_cache=enable_cache,
1008
1017
  )
1009
1018
 
1010
1019
  # https://stackoverflow.com/questions/55409641/asyncio-run-cannot-be-called-from-a-running-event-loop-when-using-jupyter-no
@@ -1249,10 +1258,11 @@ async def run_evaluator(
1249
1258
  filters: list[Filter],
1250
1259
  stream: Callable[[SSEProgressEvent], None] | None = None,
1251
1260
  state: BraintrustState | None = None,
1261
+ enable_cache: bool = True,
1252
1262
  ) -> EvalResultWithSummary[Input, Output]:
1253
1263
  """Wrapper on _run_evaluator_internal that times out execution after evaluator.timeout."""
1254
1264
  results = await asyncio.wait_for(
1255
- _run_evaluator_internal(experiment, evaluator, position, filters, stream, state), evaluator.timeout
1265
+ _run_evaluator_internal(experiment, evaluator, position, filters, stream, state, enable_cache), evaluator.timeout
1256
1266
  )
1257
1267
 
1258
1268
  if experiment:
@@ -1280,6 +1290,7 @@ async def _run_evaluator_internal(
1280
1290
  filters: list[Filter],
1281
1291
  stream: Callable[[SSEProgressEvent], None] | None = None,
1282
1292
  state: BraintrustState | None = None,
1293
+ enable_cache: bool = True,
1283
1294
  ):
1284
1295
  # Start span cache for this eval (it's disabled by default to avoid temp files outside of evals)
1285
1296
  if state is None:
@@ -1287,13 +1298,15 @@ async def _run_evaluator_internal(
1287
1298
 
1288
1299
  state = _internal_get_global_state()
1289
1300
 
1290
- state.span_cache.start()
1301
+ if enable_cache:
1302
+ state.span_cache.start()
1291
1303
  try:
1292
1304
  return await _run_evaluator_internal_impl(experiment, evaluator, position, filters, stream, state)
1293
1305
  finally:
1294
1306
  # Clean up disk-based span cache after eval completes and stop caching
1295
- state.span_cache.dispose()
1296
- state.span_cache.stop()
1307
+ if enable_cache:
1308
+ state.span_cache.dispose()
1309
+ state.span_cache.stop()
1297
1310
 
1298
1311
 
1299
1312
  async def _run_evaluator_internal_impl(
braintrust/logger.py CHANGED
@@ -87,6 +87,7 @@ from .util import (
87
87
  get_caller_location,
88
88
  mask_api_key,
89
89
  merge_dicts,
90
+ parse_env_var_float,
90
91
  response_raise_for_status,
91
92
  )
92
93
 
@@ -349,9 +350,16 @@ class BraintrustState:
349
350
  def __init__(self):
350
351
  self.id = str(uuid.uuid4())
351
352
  self.current_experiment: Experiment | None = None
352
- self.current_logger: contextvars.ContextVar[Logger | None] = contextvars.ContextVar(
353
+ # We use both a ContextVar and a plain attribute for the current logger:
354
+ # - _cv_logger (ContextVar): Provides async context isolation so different
355
+ # async tasks can have different loggers without affecting each other.
356
+ # - _local_logger (plain attribute): Fallback for threads, since ContextVars
357
+ # don't propagate to new threads. This way if users don't want to do
358
+ # anything specific they'll always have a "global logger"
359
+ self._cv_logger: contextvars.ContextVar[Logger | None] = contextvars.ContextVar(
353
360
  "braintrust_current_logger", default=None
354
361
  )
362
+ self._local_logger: Logger | None = None
355
363
  self.current_parent: contextvars.ContextVar[str | None] = contextvars.ContextVar(
356
364
  "braintrust_current_parent", default=None
357
365
  )
@@ -425,7 +433,8 @@ class BraintrustState:
425
433
  def reset_parent_state(self):
426
434
  # reset possible parent state for tests
427
435
  self.current_experiment = None
428
- self.current_logger.set(None)
436
+ self._cv_logger.set(None)
437
+ self._local_logger = None
429
438
  self.current_parent.set(None)
430
439
  self.current_span.set(NOOP_SPAN)
431
440
 
@@ -485,7 +494,8 @@ class BraintrustState:
485
494
  if k
486
495
  not in (
487
496
  "current_experiment",
488
- "current_logger",
497
+ "_cv_logger",
498
+ "_local_logger",
489
499
  "current_parent",
490
500
  "current_span",
491
501
  "_global_bg_logger",
@@ -555,10 +565,6 @@ class BraintrustState:
555
565
  self._user_info = self.api_conn().get_json("ping")
556
566
  return self._user_info
557
567
 
558
- def set_user_info_if_null(self, info: Mapping[str, Any]):
559
- if not self._user_info:
560
- self._user_info = info
561
-
562
568
  def global_bg_logger(self) -> "_BackgroundLogger":
563
569
  return getattr(self._override_bg_logger, "logger", None) or self._global_bg_logger.get()
564
570
 
@@ -620,14 +626,28 @@ class RetryRequestExceptionsAdapter(HTTPAdapter):
620
626
  base_num_retries: Maximum number of retries before giving up and re-raising the exception.
621
627
  backoff_factor: A multiplier used to determine the time to wait between retries.
622
628
  The actual wait time is calculated as: backoff_factor * (2 ** retry_count).
629
+ default_timeout_secs: Default timeout in seconds for requests that don't specify one.
630
+ Prevents indefinite hangs on stale connections.
623
631
  """
624
632
 
625
- def __init__(self, *args: Any, base_num_retries: int = 0, backoff_factor: float = 0.5, **kwargs: Any):
633
+ def __init__(
634
+ self,
635
+ *args: Any,
636
+ base_num_retries: int = 0,
637
+ backoff_factor: float = 0.5,
638
+ default_timeout_secs: float = 60,
639
+ **kwargs: Any,
640
+ ):
626
641
  self.base_num_retries = base_num_retries
627
642
  self.backoff_factor = backoff_factor
643
+ self.default_timeout_secs = default_timeout_secs
628
644
  super().__init__(*args, **kwargs)
629
645
 
630
646
  def send(self, *args, **kwargs):
647
+ # Apply default timeout if none provided to prevent indefinite hangs
648
+ if kwargs.get("timeout") is None:
649
+ kwargs["timeout"] = self.default_timeout_secs
650
+
631
651
  num_prev_retries = 0
632
652
  while True:
633
653
  try:
@@ -639,6 +659,14 @@ class RetryRequestExceptionsAdapter(HTTPAdapter):
639
659
  return response
640
660
  except (urllib3.exceptions.HTTPError, requests.exceptions.RequestException) as e:
641
661
  if num_prev_retries < self.base_num_retries:
662
+ if isinstance(e, requests.exceptions.ReadTimeout):
663
+ # Clear all connection pools to discard stale connections. This
664
+ # fixes hangs caused by NAT gateways silently dropping idle TCP
665
+ # connections (e.g., Azure's ~4 min timeout). close() calls
666
+ # PoolManager.clear() which is thread-safe: in-flight requests
667
+ # keep their checked-out connections, and new requests create
668
+ # fresh pools on demand.
669
+ self.close()
642
670
  # Emulates the sleeping logic in the backoff_factor of urllib3 Retry
643
671
  sleep_s = self.backoff_factor * (2**num_prev_retries)
644
672
  print("Retrying request after error:", e, file=sys.stderr)
@@ -660,14 +688,16 @@ class HTTPConnection:
660
688
  def ping(self) -> bool:
661
689
  try:
662
690
  resp = self.get("ping")
663
- _state.set_user_info_if_null(resp.json())
664
691
  return resp.ok
665
692
  except requests.exceptions.ConnectionError:
666
693
  return False
667
694
 
668
695
  def make_long_lived(self) -> None:
669
696
  if not self.adapter:
670
- self.adapter = RetryRequestExceptionsAdapter(base_num_retries=10, backoff_factor=0.5)
697
+ timeout_secs = parse_env_var_float("BRAINTRUST_HTTP_TIMEOUT", 60.0)
698
+ self.adapter = RetryRequestExceptionsAdapter(
699
+ base_num_retries=10, backoff_factor=0.5, default_timeout_secs=timeout_secs
700
+ )
671
701
  self._reset()
672
702
 
673
703
  @staticmethod
@@ -712,6 +742,8 @@ class HTTPConnection:
712
742
  return self.session.delete(_urljoin(self.base_url, path), *args, **kwargs)
713
743
 
714
744
  def get_json(self, object_type: str, args: Mapping[str, Any] | None = None, retries: int = 0) -> Mapping[str, Any]:
745
+ # FIXME[matt]: the retry logic seems to be unused and could be n*2 because of the retry logic
746
+ # in the RetryRequestExceptionsAdapter. We should probably remove this.
715
747
  tries = retries + 1
716
748
  for i in range(tries):
717
749
  resp = self.get(f"/{object_type}", params=args)
@@ -1634,7 +1666,8 @@ def init_logger(
1634
1666
  if set_current:
1635
1667
  if _state is None:
1636
1668
  raise RuntimeError("_state is None in init_logger. This should never happen.")
1637
- _state.current_logger.set(ret)
1669
+ _state._cv_logger.set(ret)
1670
+ _state._local_logger = ret
1638
1671
  return ret
1639
1672
 
1640
1673
 
@@ -1955,7 +1988,7 @@ def current_experiment() -> Optional["Experiment"]:
1955
1988
  def current_logger() -> Optional["Logger"]:
1956
1989
  """Returns the currently-active logger (set by `braintrust.init_logger(...)`). Returns None if no current logger has been set."""
1957
1990
 
1958
- return _state.current_logger.get()
1991
+ return _state._cv_logger.get() or _state._local_logger
1959
1992
 
1960
1993
 
1961
1994
  def current_span() -> Span:
@@ -3984,6 +4017,9 @@ class SpanImpl(Span):
3984
4017
  use_v4 = os.getenv("BRAINTRUST_OTEL_COMPAT", "false").lower() == "true"
3985
4018
  span_components_class = SpanComponentsV4 if use_v4 else SpanComponentsV3
3986
4019
 
4020
+ # Disable span cache since remote function spans won't be in the local cache
4021
+ self.state.span_cache.disable()
4022
+
3987
4023
  return span_components_class(
3988
4024
  object_type=self.parent_object_type,
3989
4025
  object_id=object_id,
@@ -3997,7 +4033,7 @@ class SpanImpl(Span):
3997
4033
  def link(self) -> str:
3998
4034
  parent_type, info = self._get_parent_info()
3999
4035
  if parent_type == SpanObjectTypeV3.PROJECT_LOGS:
4000
- cur_logger = self.state.current_logger.get()
4036
+ cur_logger = self.state._cv_logger.get() or self.state._local_logger
4001
4037
  if not cur_logger:
4002
4038
  return NOOP_SPAN_PERMALINK
4003
4039
  base_url = cur_logger._get_link_base_url()
braintrust/oai.py CHANGED
@@ -5,6 +5,8 @@ import time
5
5
  from collections.abc import Callable
6
6
  from typing import Any
7
7
 
8
+ from wrapt import wrap_function_wrapper
9
+
8
10
  from .logger import Attachment, Span, start_span
9
11
  from .span_types import SpanTypeAttribute
10
12
  from .util import merge_dicts
@@ -986,3 +988,52 @@ def _is_not_given(value: Any) -> bool:
986
988
  return type_name == "NotGiven"
987
989
  except Exception:
988
990
  return False
991
+
992
+
993
+ def _openai_init_wrapper(wrapped, instance, args, kwargs):
994
+ """Wrapper for OpenAI.__init__ that applies tracing after initialization."""
995
+ wrapped(*args, **kwargs)
996
+ _apply_openai_wrapper(instance)
997
+
998
+
999
+ def patch_openai() -> bool:
1000
+ """
1001
+ Patch OpenAI to add Braintrust tracing globally.
1002
+
1003
+ After calling this, all new OpenAI() and AsyncOpenAI() clients
1004
+ will automatically have tracing enabled.
1005
+
1006
+ Returns:
1007
+ True if OpenAI was patched (or already patched), False if OpenAI is not installed.
1008
+
1009
+ Example:
1010
+ ```python
1011
+ import braintrust
1012
+ braintrust.patch_openai()
1013
+
1014
+ import openai
1015
+ client = openai.OpenAI()
1016
+ # All calls are now traced!
1017
+ ```
1018
+ """
1019
+ try:
1020
+ import openai
1021
+
1022
+ if getattr(openai, "__braintrust_wrapped__", False):
1023
+ return True # Already patched
1024
+
1025
+ wrap_function_wrapper("openai", "OpenAI.__init__", _openai_init_wrapper)
1026
+ wrap_function_wrapper("openai", "AsyncOpenAI.__init__", _openai_init_wrapper)
1027
+ openai.__braintrust_wrapped__ = True
1028
+ return True
1029
+
1030
+ except ImportError:
1031
+ return False
1032
+
1033
+
1034
+ def _apply_openai_wrapper(client):
1035
+ """Apply tracing wrapper to an OpenAI client instance in-place."""
1036
+ wrapped = wrap_openai(client)
1037
+ for attr in ("chat", "responses", "embeddings", "moderations", "beta"):
1038
+ if hasattr(wrapped, attr):
1039
+ setattr(client, attr, getattr(wrapped, attr))
@@ -302,11 +302,6 @@ def test_to_bt_safe_special_objects():
302
302
  assert _to_bt_safe(dataset) == "<dataset>"
303
303
  assert _to_bt_safe(logger) == "<logger>"
304
304
 
305
- # Clean up
306
- exp.flush()
307
- dataset.flush()
308
- logger.flush()
309
-
310
305
 
311
306
  class TestBTJsonAttachments(TestCase):
312
307
  def test_to_bt_safe_attachments(self):
@@ -1,6 +1,8 @@
1
1
  from typing import List
2
+ from unittest.mock import MagicMock
2
3
 
3
4
  import pytest
5
+ from braintrust.logger import BraintrustState
4
6
 
5
7
  from .framework import (
6
8
  Eval,
@@ -241,6 +243,7 @@ async def test_hooks_trial_index_multiple_inputs():
241
243
  assert sorted(input_2_trials) == [0, 1]
242
244
 
243
245
 
246
+ @pytest.mark.vcr
244
247
  @pytest.mark.asyncio
245
248
  async def test_scorer_spans_have_purpose_attribute(with_memory_logger, with_simulate_login):
246
249
  """Test that scorer spans have span_attributes.purpose='scorer' and propagate to subspans."""
@@ -527,3 +530,37 @@ async def test_hooks_without_setting_tags(with_memory_logger, with_simulate_logi
527
530
  root_span = [log for log in logs if not log["span_parents"]]
528
531
  assert len(root_span) == 1
529
532
  assert root_span[0].get("tags") == None
533
+
534
+ @pytest.mark.asyncio
535
+ async def test_eval_enable_cache():
536
+ state = BraintrustState()
537
+ state.span_cache = MagicMock()
538
+
539
+ # Test enable_cache=False
540
+ await Eval(
541
+ "test-enable-cache-false",
542
+ data=[EvalCase(input=1, expected=1)],
543
+ task=lambda x: x,
544
+ scores=[],
545
+ state=state,
546
+ no_send_logs=True,
547
+ enable_cache=False,
548
+ )
549
+ state.span_cache.start.assert_not_called()
550
+ state.span_cache.stop.assert_not_called()
551
+
552
+ # Test enable_cache=True (default)
553
+ state.span_cache.start.reset_mock()
554
+ state.span_cache.stop.reset_mock()
555
+
556
+ await Eval(
557
+ "test-enable-cache-true",
558
+ data=[EvalCase(input=1, expected=1)],
559
+ task=lambda x: x,
560
+ scores=[],
561
+ state=state,
562
+ no_send_logs=True,
563
+ # enable_cache defaults to True
564
+ )
565
+ state.span_cache.start.assert_called()
566
+ state.span_cache.stop.assert_called()