braintrust 0.4.3__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44):
  1. braintrust/__init__.py +3 -0
  2. braintrust/_generated_types.py +106 -6
  3. braintrust/auto.py +179 -0
  4. braintrust/conftest.py +23 -4
  5. braintrust/framework.py +113 -3
  6. braintrust/functions/invoke.py +3 -1
  7. braintrust/functions/test_invoke.py +61 -0
  8. braintrust/generated_types.py +7 -1
  9. braintrust/logger.py +127 -45
  10. braintrust/oai.py +51 -0
  11. braintrust/span_cache.py +337 -0
  12. braintrust/span_identifier_v3.py +21 -0
  13. braintrust/test_bt_json.py +0 -5
  14. braintrust/test_framework.py +37 -0
  15. braintrust/test_http.py +444 -0
  16. braintrust/test_logger.py +295 -5
  17. braintrust/test_span_cache.py +344 -0
  18. braintrust/test_trace.py +267 -0
  19. braintrust/test_util.py +58 -1
  20. braintrust/trace.py +385 -0
  21. braintrust/util.py +20 -0
  22. braintrust/version.py +2 -2
  23. braintrust/wrappers/agno/__init__.py +2 -3
  24. braintrust/wrappers/anthropic.py +64 -0
  25. braintrust/wrappers/claude_agent_sdk/__init__.py +2 -3
  26. braintrust/wrappers/claude_agent_sdk/_wrapper.py +48 -6
  27. braintrust/wrappers/claude_agent_sdk/test_wrapper.py +115 -0
  28. braintrust/wrappers/dspy.py +52 -1
  29. braintrust/wrappers/google_genai/__init__.py +9 -6
  30. braintrust/wrappers/litellm.py +6 -43
  31. braintrust/wrappers/pydantic_ai.py +2 -3
  32. braintrust/wrappers/test_agno.py +9 -0
  33. braintrust/wrappers/test_anthropic.py +156 -0
  34. braintrust/wrappers/test_dspy.py +117 -0
  35. braintrust/wrappers/test_google_genai.py +9 -0
  36. braintrust/wrappers/test_litellm.py +57 -55
  37. braintrust/wrappers/test_openai.py +253 -1
  38. braintrust/wrappers/test_pydantic_ai_integration.py +9 -0
  39. braintrust/wrappers/test_utils.py +79 -0
  40. {braintrust-0.4.3.dist-info → braintrust-0.5.2.dist-info}/METADATA +1 -1
  41. {braintrust-0.4.3.dist-info → braintrust-0.5.2.dist-info}/RECORD +44 -37
  42. {braintrust-0.4.3.dist-info → braintrust-0.5.2.dist-info}/WHEEL +1 -1
  43. {braintrust-0.4.3.dist-info → braintrust-0.5.2.dist-info}/entry_points.txt +0 -0
  44. {braintrust-0.4.3.dist-info → braintrust-0.5.2.dist-info}/top_level.txt +0 -0
braintrust/functions/test_invoke.py ADDED
@@ -0,0 +1,61 @@
1
+ """Tests for the invoke module, particularly init_function."""
2
+
3
+
4
+ from braintrust.functions.invoke import init_function
5
+ from braintrust.logger import _internal_get_global_state, _internal_reset_global_state
6
+
7
+
8
+ class TestInitFunction:
9
+ """Tests for init_function."""
10
+
11
+ def setup_method(self):
12
+ """Reset state before each test."""
13
+ _internal_reset_global_state()
14
+
15
+ def teardown_method(self):
16
+ """Clean up after each test."""
17
+ _internal_reset_global_state()
18
+
19
+ def test_init_function_disables_span_cache(self):
20
+ """Test that init_function disables the span cache."""
21
+ state = _internal_get_global_state()
22
+
23
+ # Cache should be disabled by default (it's only enabled during evals)
24
+ assert state.span_cache.disabled is True
25
+
26
+ # Enable the cache (simulating what happens during eval)
27
+ state.span_cache.start()
28
+ assert state.span_cache.disabled is False
29
+
30
+ # Call init_function
31
+ f = init_function("test-project", "test-function")
32
+
33
+ # Cache should now be disabled (init_function explicitly disables it)
34
+ assert state.span_cache.disabled is True
35
+ assert f.__name__ == "init_function-test-project-test-function-latest"
36
+
37
+ def test_init_function_with_version(self):
38
+ """Test that init_function creates a function with the correct name including version."""
39
+ f = init_function("my-project", "my-scorer", version="v1")
40
+ assert f.__name__ == "init_function-my-project-my-scorer-v1"
41
+
42
+ def test_init_function_without_version_uses_latest(self):
43
+ """Test that init_function uses 'latest' in name when version not specified."""
44
+ f = init_function("my-project", "my-scorer")
45
+ assert f.__name__ == "init_function-my-project-my-scorer-latest"
46
+
47
+ def test_init_function_permanently_disables_cache(self):
48
+ """Test that init_function permanently disables the cache (can't be re-enabled)."""
49
+ state = _internal_get_global_state()
50
+
51
+ # Enable the cache
52
+ state.span_cache.start()
53
+ assert state.span_cache.disabled is False
54
+
55
+ # Call init_function
56
+ init_function("test-project", "test-function")
57
+ assert state.span_cache.disabled is True
58
+
59
+ # Try to start again - should still be disabled because of explicit disable
60
+ state.span_cache.start()
61
+ assert state.span_cache.disabled is True
braintrust/generated_types.py CHANGED
@@ -1,4 +1,4 @@
1
- """Auto-generated file (internal git SHA 87ac73f4945a47eff2d4e42775ba4dbc58854c73) -- do not modify"""
1
+ """Auto-generated file (internal git SHA 21146f64bf5ad1eadd3a99d186274728e25e5399) -- do not modify"""
2
2
 
3
3
  from ._generated_types import (
4
4
  Acl,
@@ -29,6 +29,9 @@ from ._generated_types import (
29
29
  Dataset,
30
30
  DatasetEvent,
31
31
  EnvVar,
32
+ EvalStatusPage,
33
+ EvalStatusPageConfig,
34
+ EvalStatusPageTheme,
32
35
  Experiment,
33
36
  ExperimentEvent,
34
37
  ExtendedSavedFunctionId,
@@ -136,6 +139,9 @@ __all__ = [
136
139
  "Dataset",
137
140
  "DatasetEvent",
138
141
  "EnvVar",
142
+ "EvalStatusPage",
143
+ "EvalStatusPageConfig",
144
+ "EvalStatusPageTheme",
139
145
  "Experiment",
140
146
  "ExperimentEvent",
141
147
  "ExtendedSavedFunctionId",
braintrust/logger.py CHANGED
@@ -47,12 +47,9 @@ from urllib3.util.retry import Retry
47
47
  from . import context, id_gen
48
48
  from .bt_json import bt_dumps, bt_safe_deep_copy
49
49
  from .db_fields import (
50
- ASYNC_SCORING_CONTROL_FIELD,
51
50
  AUDIT_METADATA_FIELD,
52
51
  AUDIT_SOURCE_FIELD,
53
52
  IS_MERGE_FIELD,
54
- MERGE_PATHS_FIELD,
55
- SKIP_ASYNC_SCORING_FIELD,
56
53
  TRANSACTION_ID_FIELD,
57
54
  VALID_SOURCES,
58
55
  )
@@ -90,6 +87,7 @@ from .util import (
90
87
  get_caller_location,
91
88
  mask_api_key,
92
89
  merge_dicts,
90
+ parse_env_var_float,
93
91
  response_raise_for_status,
94
92
  )
95
93
 
@@ -101,6 +99,14 @@ from .xact_ids import prettify_xact
101
99
  Metadata = dict[str, Any]
102
100
  DATA_API_VERSION = 2
103
101
 
102
+
103
+ class DatasetRef(TypedDict, total=False):
104
+ """Reference to a dataset by ID and optional version."""
105
+
106
+ id: str
107
+ version: str
108
+
109
+
104
110
  T = TypeVar("T")
105
111
  TMapping = TypeVar("TMapping", bound=Mapping[str, Any])
106
112
  TMutableMapping = TypeVar("TMutableMapping", bound=MutableMapping[str, Any])
@@ -344,9 +350,16 @@ class BraintrustState:
344
350
  def __init__(self):
345
351
  self.id = str(uuid.uuid4())
346
352
  self.current_experiment: Experiment | None = None
347
- self.current_logger: contextvars.ContextVar[Logger | None] = contextvars.ContextVar(
353
+ # We use both a ContextVar and a plain attribute for the current logger:
354
+ # - _cv_logger (ContextVar): Provides async context isolation so different
355
+ # async tasks can have different loggers without affecting each other.
356
+ # - _local_logger (plain attribute): Fallback for threads, since ContextVars
357
+ # don't propagate to new threads. This way if users don't want to do
358
+ # anything specific they'll always have a "global logger"
359
+ self._cv_logger: contextvars.ContextVar[Logger | None] = contextvars.ContextVar(
348
360
  "braintrust_current_logger", default=None
349
361
  )
362
+ self._local_logger: Logger | None = None
350
363
  self.current_parent: contextvars.ContextVar[str | None] = contextvars.ContextVar(
351
364
  "braintrust_current_parent", default=None
352
365
  )
@@ -396,6 +409,11 @@ class BraintrustState:
396
409
  ),
397
410
  )
398
411
 
412
+ from braintrust.span_cache import SpanCache
413
+
414
+ self.span_cache = SpanCache()
415
+ self._otel_flush_callback: Any | None = None
416
+
399
417
  def reset_login_info(self):
400
418
  self.app_url: str | None = None
401
419
  self.app_public_url: str | None = None
@@ -415,7 +433,8 @@ class BraintrustState:
415
433
  def reset_parent_state(self):
416
434
  # reset possible parent state for tests
417
435
  self.current_experiment = None
418
- self.current_logger.set(None)
436
+ self._cv_logger.set(None)
437
+ self._local_logger = None
419
438
  self.current_parent.set(None)
420
439
  self.current_span.set(NOOP_SPAN)
421
440
 
@@ -452,6 +471,21 @@ class BraintrustState:
452
471
 
453
472
  return self._context_manager
454
473
 
474
+ def register_otel_flush(self, callback: Any) -> None:
475
+ """
476
+ Register an OTEL flush callback. This is called by the OTEL integration
477
+ when it initializes a span processor/exporter.
478
+ """
479
+ self._otel_flush_callback = callback
480
+
481
+ async def flush_otel(self) -> None:
482
+ """
483
+ Flush OTEL spans if a callback is registered.
484
+ Called during ensure_spans_flushed to ensure OTEL spans are visible in BTQL.
485
+ """
486
+ if self._otel_flush_callback:
487
+ await self._otel_flush_callback()
488
+
455
489
  def copy_state(self, other: "BraintrustState"):
456
490
  """Copy login information from another BraintrustState instance."""
457
491
  self.__dict__.update({
@@ -460,7 +494,8 @@ class BraintrustState:
460
494
  if k
461
495
  not in (
462
496
  "current_experiment",
463
- "current_logger",
497
+ "_cv_logger",
498
+ "_local_logger",
464
499
  "current_parent",
465
500
  "current_span",
466
501
  "_global_bg_logger",
@@ -530,10 +565,6 @@ class BraintrustState:
530
565
  self._user_info = self.api_conn().get_json("ping")
531
566
  return self._user_info
532
567
 
533
- def set_user_info_if_null(self, info: Mapping[str, Any]):
534
- if not self._user_info:
535
- self._user_info = info
536
-
537
568
  def global_bg_logger(self) -> "_BackgroundLogger":
538
569
  return getattr(self._override_bg_logger, "logger", None) or self._global_bg_logger.get()
539
570
 
@@ -595,14 +626,28 @@ class RetryRequestExceptionsAdapter(HTTPAdapter):
595
626
  base_num_retries: Maximum number of retries before giving up and re-raising the exception.
596
627
  backoff_factor: A multiplier used to determine the time to wait between retries.
597
628
  The actual wait time is calculated as: backoff_factor * (2 ** retry_count).
629
+ default_timeout_secs: Default timeout in seconds for requests that don't specify one.
630
+ Prevents indefinite hangs on stale connections.
598
631
  """
599
632
 
600
- def __init__(self, *args: Any, base_num_retries: int = 0, backoff_factor: float = 0.5, **kwargs: Any):
633
+ def __init__(
634
+ self,
635
+ *args: Any,
636
+ base_num_retries: int = 0,
637
+ backoff_factor: float = 0.5,
638
+ default_timeout_secs: float = 60,
639
+ **kwargs: Any,
640
+ ):
601
641
  self.base_num_retries = base_num_retries
602
642
  self.backoff_factor = backoff_factor
643
+ self.default_timeout_secs = default_timeout_secs
603
644
  super().__init__(*args, **kwargs)
604
645
 
605
646
  def send(self, *args, **kwargs):
647
+ # Apply default timeout if none provided to prevent indefinite hangs
648
+ if kwargs.get("timeout") is None:
649
+ kwargs["timeout"] = self.default_timeout_secs
650
+
606
651
  num_prev_retries = 0
607
652
  while True:
608
653
  try:
@@ -614,6 +659,14 @@ class RetryRequestExceptionsAdapter(HTTPAdapter):
614
659
  return response
615
660
  except (urllib3.exceptions.HTTPError, requests.exceptions.RequestException) as e:
616
661
  if num_prev_retries < self.base_num_retries:
662
+ if isinstance(e, requests.exceptions.ReadTimeout):
663
+ # Clear all connection pools to discard stale connections. This
664
+ # fixes hangs caused by NAT gateways silently dropping idle TCP
665
+ # connections (e.g., Azure's ~4 min timeout). close() calls
666
+ # PoolManager.clear() which is thread-safe: in-flight requests
667
+ # keep their checked-out connections, and new requests create
668
+ # fresh pools on demand.
669
+ self.close()
617
670
  # Emulates the sleeping logic in the backoff_factor of urllib3 Retry
618
671
  sleep_s = self.backoff_factor * (2**num_prev_retries)
619
672
  print("Retrying request after error:", e, file=sys.stderr)
@@ -635,14 +688,16 @@ class HTTPConnection:
635
688
  def ping(self) -> bool:
636
689
  try:
637
690
  resp = self.get("ping")
638
- _state.set_user_info_if_null(resp.json())
639
691
  return resp.ok
640
692
  except requests.exceptions.ConnectionError:
641
693
  return False
642
694
 
643
695
  def make_long_lived(self) -> None:
644
696
  if not self.adapter:
645
- self.adapter = RetryRequestExceptionsAdapter(base_num_retries=10, backoff_factor=0.5)
697
+ timeout_secs = parse_env_var_float("BRAINTRUST_HTTP_TIMEOUT", 60.0)
698
+ self.adapter = RetryRequestExceptionsAdapter(
699
+ base_num_retries=10, backoff_factor=0.5, default_timeout_secs=timeout_secs
700
+ )
646
701
  self._reset()
647
702
 
648
703
  @staticmethod
@@ -687,6 +742,8 @@ class HTTPConnection:
687
742
  return self.session.delete(_urljoin(self.base_url, path), *args, **kwargs)
688
743
 
689
744
  def get_json(self, object_type: str, args: Mapping[str, Any] | None = None, retries: int = 0) -> Mapping[str, Any]:
745
+ # FIXME[matt]: the retry logic seems to be unused and could be n*2 because of the the retry logic
746
+ # in the RetryRequestExceptionsAdapter. We should probably remove this.
690
747
  tries = retries + 1
691
748
  for i in range(tries):
692
749
  resp = self.get(f"/{object_type}", params=args)
@@ -1297,7 +1354,7 @@ def init(
1297
1354
  project: str | None = None,
1298
1355
  experiment: str | None = None,
1299
1356
  description: str | None = None,
1300
- dataset: Optional["Dataset"] = None,
1357
+ dataset: Optional["Dataset"] | DatasetRef = None,
1301
1358
  open: bool = False,
1302
1359
  base_experiment: str | None = None,
1303
1360
  is_public: bool = False,
@@ -1410,12 +1467,19 @@ def init(
1410
1467
  args["base_exp_id"] = base_experiment_id
1411
1468
  elif base_experiment is not None:
1412
1469
  args["base_experiment"] = base_experiment
1413
- else:
1470
+ elif merged_git_metadata_settings and merged_git_metadata_settings.collect != "none":
1414
1471
  args["ancestor_commits"] = list(get_past_n_ancestors())
1415
1472
 
1416
1473
  if dataset is not None:
1417
- args["dataset_id"] = dataset.id
1418
- args["dataset_version"] = dataset.version
1474
+ if isinstance(dataset, dict):
1475
+ # Simple {"id": ..., "version": ...} dict
1476
+ args["dataset_id"] = dataset["id"]
1477
+ if "version" in dataset:
1478
+ args["dataset_version"] = dataset["version"]
1479
+ else:
1480
+ # Full Dataset object
1481
+ args["dataset_id"] = dataset.id
1482
+ args["dataset_version"] = dataset.version
1419
1483
 
1420
1484
  if is_public is not None:
1421
1485
  args["public"] = is_public
@@ -1446,7 +1510,11 @@ def init(
1446
1510
  # For experiments, disable queue size limit enforcement (unlimited queue)
1447
1511
  state.enforce_queue_size_limit(False)
1448
1512
 
1449
- ret = Experiment(lazy_metadata=LazyValue(compute_metadata, use_mutex=True), dataset=dataset, state=state)
1513
+ ret = Experiment(
1514
+ lazy_metadata=LazyValue(compute_metadata, use_mutex=True),
1515
+ dataset=dataset if isinstance(dataset, Dataset) else None,
1516
+ state=state,
1517
+ )
1450
1518
  if set_current:
1451
1519
  state.current_experiment = ret
1452
1520
  return ret
@@ -1598,7 +1666,8 @@ def init_logger(
1598
1666
  if set_current:
1599
1667
  if _state is None:
1600
1668
  raise RuntimeError("_state is None in init_logger. This should never happen.")
1601
- _state.current_logger.set(ret)
1669
+ _state._cv_logger.set(ret)
1670
+ _state._local_logger = ret
1602
1671
  return ret
1603
1672
 
1604
1673
 
@@ -1761,6 +1830,25 @@ def login(
1761
1830
  _state.login(app_url=app_url, api_key=api_key, org_name=org_name, force_login=force_login)
1762
1831
 
1763
1832
 
1833
+ def register_otel_flush(callback: Any) -> None:
1834
+ """
1835
+ Register a callback to flush OTEL spans. This is called by the OTEL integration
1836
+ when it initializes a span processor/exporter.
1837
+
1838
+ When ensure_spans_flushed is called (e.g., before a BTQL query in scorers),
1839
+ this callback will be invoked to ensure OTEL spans are flushed to the server.
1840
+
1841
+ Also disables the span cache, since OTEL spans aren't in the local cache
1842
+ and we need BTQL to see the complete span tree (both native + OTEL spans).
1843
+
1844
+ :param callback: The async callback function to flush OTEL spans.
1845
+ """
1846
+ global _state
1847
+ _state.register_otel_flush(callback)
1848
+ # Disable span cache since OTEL spans aren't in the local cache
1849
+ _state.span_cache.disable()
1850
+
1851
+
1764
1852
  def login_to_state(
1765
1853
  app_url: str | None = None,
1766
1854
  api_key: str | None = None,
@@ -1900,7 +1988,7 @@ def current_experiment() -> Optional["Experiment"]:
1900
1988
  def current_logger() -> Optional["Logger"]:
1901
1989
  """Returns the currently-active logger (set by `braintrust.init_logger(...)`). Returns None if no current logger has been set."""
1902
1990
 
1903
- return _state.current_logger.get()
1991
+ return _state._cv_logger.get() or _state._local_logger
1904
1992
 
1905
1993
 
1906
1994
  def current_span() -> Span:
@@ -2323,30 +2411,6 @@ def _enrich_attachments(event: TMutableMapping) -> TMutableMapping:
2323
2411
 
2324
2412
 
2325
2413
  def _validate_and_sanitize_experiment_log_partial_args(event: Mapping[str, Any]) -> dict[str, Any]:
2326
- # Make sure only certain keys are specified.
2327
- forbidden_keys = set(event.keys()) - {
2328
- "input",
2329
- "output",
2330
- "expected",
2331
- "tags",
2332
- "scores",
2333
- "metadata",
2334
- "metrics",
2335
- "error",
2336
- "dataset_record_id",
2337
- "origin",
2338
- "inputs",
2339
- "span_attributes",
2340
- ASYNC_SCORING_CONTROL_FIELD,
2341
- MERGE_PATHS_FIELD,
2342
- SKIP_ASYNC_SCORING_FIELD,
2343
- "span_id",
2344
- "root_span_id",
2345
- "_bt_internal_override_pagination_key",
2346
- }
2347
- if forbidden_keys:
2348
- raise ValueError(f"The following keys are not permitted: {forbidden_keys}")
2349
-
2350
2414
  scores = event.get("scores")
2351
2415
  if scores:
2352
2416
  for name, score in scores.items():
@@ -3855,6 +3919,21 @@ class SpanImpl(Span):
3855
3919
  if serializable_partial_record.get("metrics", {}).get("end") is not None:
3856
3920
  self._logged_end_time = serializable_partial_record["metrics"]["end"]
3857
3921
 
3922
+ # Write to local span cache for scorer access
3923
+ # Only cache experiment spans - regular logs don't need caching
3924
+ if self.parent_object_type == SpanObjectTypeV3.EXPERIMENT:
3925
+ from braintrust.span_cache import CachedSpan
3926
+
3927
+ cached_span = CachedSpan(
3928
+ span_id=self.span_id,
3929
+ input=serializable_partial_record.get("input"),
3930
+ output=serializable_partial_record.get("output"),
3931
+ metadata=serializable_partial_record.get("metadata"),
3932
+ span_parents=self.span_parents,
3933
+ span_attributes=serializable_partial_record.get("span_attributes"),
3934
+ )
3935
+ self.state.span_cache.queue_write(self.root_span_id, self.span_id, cached_span)
3936
+
3858
3937
  def compute_record() -> dict[str, Any]:
3859
3938
  exporter = _get_exporter()
3860
3939
  return dict(
@@ -3938,6 +4017,9 @@ class SpanImpl(Span):
3938
4017
  use_v4 = os.getenv("BRAINTRUST_OTEL_COMPAT", "false").lower() == "true"
3939
4018
  span_components_class = SpanComponentsV4 if use_v4 else SpanComponentsV3
3940
4019
 
4020
+ # Disable span cache since remote function spans won't be in the local cache
4021
+ self.state.span_cache.disable()
4022
+
3941
4023
  return span_components_class(
3942
4024
  object_type=self.parent_object_type,
3943
4025
  object_id=object_id,
@@ -3951,7 +4033,7 @@ class SpanImpl(Span):
3951
4033
  def link(self) -> str:
3952
4034
  parent_type, info = self._get_parent_info()
3953
4035
  if parent_type == SpanObjectTypeV3.PROJECT_LOGS:
3954
- cur_logger = self.state.current_logger.get()
4036
+ cur_logger = self.state._cv_logger.get() or self.state._local_logger
3955
4037
  if not cur_logger:
3956
4038
  return NOOP_SPAN_PERMALINK
3957
4039
  base_url = cur_logger._get_link_base_url()
braintrust/oai.py CHANGED
@@ -5,6 +5,8 @@ import time
5
5
  from collections.abc import Callable
6
6
  from typing import Any
7
7
 
8
+ from wrapt import wrap_function_wrapper
9
+
8
10
  from .logger import Attachment, Span, start_span
9
11
  from .span_types import SpanTypeAttribute
10
12
  from .util import merge_dicts
@@ -986,3 +988,52 @@ def _is_not_given(value: Any) -> bool:
986
988
  return type_name == "NotGiven"
987
989
  except Exception:
988
990
  return False
991
+
992
+
993
+ def _openai_init_wrapper(wrapped, instance, args, kwargs):
994
+ """Wrapper for OpenAI.__init__ that applies tracing after initialization."""
995
+ wrapped(*args, **kwargs)
996
+ _apply_openai_wrapper(instance)
997
+
998
+
999
+ def patch_openai() -> bool:
1000
+ """
1001
+ Patch OpenAI to add Braintrust tracing globally.
1002
+
1003
+ After calling this, all new OpenAI() and AsyncOpenAI() clients
1004
+ will automatically have tracing enabled.
1005
+
1006
+ Returns:
1007
+ True if OpenAI was patched (or already patched), False if OpenAI is not installed.
1008
+
1009
+ Example:
1010
+ ```python
1011
+ import braintrust
1012
+ braintrust.patch_openai()
1013
+
1014
+ import openai
1015
+ client = openai.OpenAI()
1016
+ # All calls are now traced!
1017
+ ```
1018
+ """
1019
+ try:
1020
+ import openai
1021
+
1022
+ if getattr(openai, "__braintrust_wrapped__", False):
1023
+ return True # Already patched
1024
+
1025
+ wrap_function_wrapper("openai", "OpenAI.__init__", _openai_init_wrapper)
1026
+ wrap_function_wrapper("openai", "AsyncOpenAI.__init__", _openai_init_wrapper)
1027
+ openai.__braintrust_wrapped__ = True
1028
+ return True
1029
+
1030
+ except ImportError:
1031
+ return False
1032
+
1033
+
1034
+ def _apply_openai_wrapper(client):
1035
+ """Apply tracing wrapper to an OpenAI client instance in-place."""
1036
+ wrapped = wrap_openai(client)
1037
+ for attr in ("chat", "responses", "embeddings", "moderations", "beta"):
1038
+ if hasattr(wrapped, attr):
1039
+ setattr(client, attr, getattr(wrapped, attr))