braintrust 0.4.3__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braintrust/__init__.py +3 -0
- braintrust/_generated_types.py +106 -6
- braintrust/auto.py +179 -0
- braintrust/conftest.py +23 -4
- braintrust/framework.py +113 -3
- braintrust/functions/invoke.py +3 -1
- braintrust/functions/test_invoke.py +61 -0
- braintrust/generated_types.py +7 -1
- braintrust/logger.py +127 -45
- braintrust/oai.py +51 -0
- braintrust/span_cache.py +337 -0
- braintrust/span_identifier_v3.py +21 -0
- braintrust/test_bt_json.py +0 -5
- braintrust/test_framework.py +37 -0
- braintrust/test_http.py +444 -0
- braintrust/test_logger.py +295 -5
- braintrust/test_span_cache.py +344 -0
- braintrust/test_trace.py +267 -0
- braintrust/test_util.py +58 -1
- braintrust/trace.py +385 -0
- braintrust/util.py +20 -0
- braintrust/version.py +2 -2
- braintrust/wrappers/agno/__init__.py +2 -3
- braintrust/wrappers/anthropic.py +64 -0
- braintrust/wrappers/claude_agent_sdk/__init__.py +2 -3
- braintrust/wrappers/claude_agent_sdk/_wrapper.py +48 -6
- braintrust/wrappers/claude_agent_sdk/test_wrapper.py +115 -0
- braintrust/wrappers/dspy.py +52 -1
- braintrust/wrappers/google_genai/__init__.py +9 -6
- braintrust/wrappers/litellm.py +6 -43
- braintrust/wrappers/pydantic_ai.py +2 -3
- braintrust/wrappers/test_agno.py +9 -0
- braintrust/wrappers/test_anthropic.py +156 -0
- braintrust/wrappers/test_dspy.py +117 -0
- braintrust/wrappers/test_google_genai.py +9 -0
- braintrust/wrappers/test_litellm.py +57 -55
- braintrust/wrappers/test_openai.py +253 -1
- braintrust/wrappers/test_pydantic_ai_integration.py +9 -0
- braintrust/wrappers/test_utils.py +79 -0
- {braintrust-0.4.3.dist-info → braintrust-0.5.2.dist-info}/METADATA +1 -1
- {braintrust-0.4.3.dist-info → braintrust-0.5.2.dist-info}/RECORD +44 -37
- {braintrust-0.4.3.dist-info → braintrust-0.5.2.dist-info}/WHEEL +1 -1
- {braintrust-0.4.3.dist-info → braintrust-0.5.2.dist-info}/entry_points.txt +0 -0
- {braintrust-0.4.3.dist-info → braintrust-0.5.2.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
"""Tests for the invoke module, particularly init_function."""
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
from braintrust.functions.invoke import init_function
|
|
5
|
+
from braintrust.logger import _internal_get_global_state, _internal_reset_global_state
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class TestInitFunction:
|
|
9
|
+
"""Tests for init_function."""
|
|
10
|
+
|
|
11
|
+
def setup_method(self):
|
|
12
|
+
"""Reset state before each test."""
|
|
13
|
+
_internal_reset_global_state()
|
|
14
|
+
|
|
15
|
+
def teardown_method(self):
|
|
16
|
+
"""Clean up after each test."""
|
|
17
|
+
_internal_reset_global_state()
|
|
18
|
+
|
|
19
|
+
def test_init_function_disables_span_cache(self):
|
|
20
|
+
"""Test that init_function disables the span cache."""
|
|
21
|
+
state = _internal_get_global_state()
|
|
22
|
+
|
|
23
|
+
# Cache should be disabled by default (it's only enabled during evals)
|
|
24
|
+
assert state.span_cache.disabled is True
|
|
25
|
+
|
|
26
|
+
# Enable the cache (simulating what happens during eval)
|
|
27
|
+
state.span_cache.start()
|
|
28
|
+
assert state.span_cache.disabled is False
|
|
29
|
+
|
|
30
|
+
# Call init_function
|
|
31
|
+
f = init_function("test-project", "test-function")
|
|
32
|
+
|
|
33
|
+
# Cache should now be disabled (init_function explicitly disables it)
|
|
34
|
+
assert state.span_cache.disabled is True
|
|
35
|
+
assert f.__name__ == "init_function-test-project-test-function-latest"
|
|
36
|
+
|
|
37
|
+
def test_init_function_with_version(self):
|
|
38
|
+
"""Test that init_function creates a function with the correct name including version."""
|
|
39
|
+
f = init_function("my-project", "my-scorer", version="v1")
|
|
40
|
+
assert f.__name__ == "init_function-my-project-my-scorer-v1"
|
|
41
|
+
|
|
42
|
+
def test_init_function_without_version_uses_latest(self):
|
|
43
|
+
"""Test that init_function uses 'latest' in name when version not specified."""
|
|
44
|
+
f = init_function("my-project", "my-scorer")
|
|
45
|
+
assert f.__name__ == "init_function-my-project-my-scorer-latest"
|
|
46
|
+
|
|
47
|
+
def test_init_function_permanently_disables_cache(self):
|
|
48
|
+
"""Test that init_function permanently disables the cache (can't be re-enabled)."""
|
|
49
|
+
state = _internal_get_global_state()
|
|
50
|
+
|
|
51
|
+
# Enable the cache
|
|
52
|
+
state.span_cache.start()
|
|
53
|
+
assert state.span_cache.disabled is False
|
|
54
|
+
|
|
55
|
+
# Call init_function
|
|
56
|
+
init_function("test-project", "test-function")
|
|
57
|
+
assert state.span_cache.disabled is True
|
|
58
|
+
|
|
59
|
+
# Try to start again - should still be disabled because of explicit disable
|
|
60
|
+
state.span_cache.start()
|
|
61
|
+
assert state.span_cache.disabled is True
|
braintrust/generated_types.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
"""Auto-generated file (internal git SHA
|
|
1
|
+
"""Auto-generated file (internal git SHA 21146f64bf5ad1eadd3a99d186274728e25e5399) -- do not modify"""
|
|
2
2
|
|
|
3
3
|
from ._generated_types import (
|
|
4
4
|
Acl,
|
|
@@ -29,6 +29,9 @@ from ._generated_types import (
|
|
|
29
29
|
Dataset,
|
|
30
30
|
DatasetEvent,
|
|
31
31
|
EnvVar,
|
|
32
|
+
EvalStatusPage,
|
|
33
|
+
EvalStatusPageConfig,
|
|
34
|
+
EvalStatusPageTheme,
|
|
32
35
|
Experiment,
|
|
33
36
|
ExperimentEvent,
|
|
34
37
|
ExtendedSavedFunctionId,
|
|
@@ -136,6 +139,9 @@ __all__ = [
|
|
|
136
139
|
"Dataset",
|
|
137
140
|
"DatasetEvent",
|
|
138
141
|
"EnvVar",
|
|
142
|
+
"EvalStatusPage",
|
|
143
|
+
"EvalStatusPageConfig",
|
|
144
|
+
"EvalStatusPageTheme",
|
|
139
145
|
"Experiment",
|
|
140
146
|
"ExperimentEvent",
|
|
141
147
|
"ExtendedSavedFunctionId",
|
braintrust/logger.py
CHANGED
|
@@ -47,12 +47,9 @@ from urllib3.util.retry import Retry
|
|
|
47
47
|
from . import context, id_gen
|
|
48
48
|
from .bt_json import bt_dumps, bt_safe_deep_copy
|
|
49
49
|
from .db_fields import (
|
|
50
|
-
ASYNC_SCORING_CONTROL_FIELD,
|
|
51
50
|
AUDIT_METADATA_FIELD,
|
|
52
51
|
AUDIT_SOURCE_FIELD,
|
|
53
52
|
IS_MERGE_FIELD,
|
|
54
|
-
MERGE_PATHS_FIELD,
|
|
55
|
-
SKIP_ASYNC_SCORING_FIELD,
|
|
56
53
|
TRANSACTION_ID_FIELD,
|
|
57
54
|
VALID_SOURCES,
|
|
58
55
|
)
|
|
@@ -90,6 +87,7 @@ from .util import (
|
|
|
90
87
|
get_caller_location,
|
|
91
88
|
mask_api_key,
|
|
92
89
|
merge_dicts,
|
|
90
|
+
parse_env_var_float,
|
|
93
91
|
response_raise_for_status,
|
|
94
92
|
)
|
|
95
93
|
|
|
@@ -101,6 +99,14 @@ from .xact_ids import prettify_xact
|
|
|
101
99
|
Metadata = dict[str, Any]
|
|
102
100
|
DATA_API_VERSION = 2
|
|
103
101
|
|
|
102
|
+
|
|
103
|
+
class DatasetRef(TypedDict, total=False):
|
|
104
|
+
"""Reference to a dataset by ID and optional version."""
|
|
105
|
+
|
|
106
|
+
id: str
|
|
107
|
+
version: str
|
|
108
|
+
|
|
109
|
+
|
|
104
110
|
T = TypeVar("T")
|
|
105
111
|
TMapping = TypeVar("TMapping", bound=Mapping[str, Any])
|
|
106
112
|
TMutableMapping = TypeVar("TMutableMapping", bound=MutableMapping[str, Any])
|
|
@@ -344,9 +350,16 @@ class BraintrustState:
|
|
|
344
350
|
def __init__(self):
|
|
345
351
|
self.id = str(uuid.uuid4())
|
|
346
352
|
self.current_experiment: Experiment | None = None
|
|
347
|
-
|
|
353
|
+
# We use both a ContextVar and a plain attribute for the current logger:
|
|
354
|
+
# - _cv_logger (ContextVar): Provides async context isolation so different
|
|
355
|
+
# async tasks can have different loggers without affecting each other.
|
|
356
|
+
# - _local_logger (plain attribute): Fallback for threads, since ContextVars
|
|
357
|
+
# don't propagate to new threads. This way if users don't want to do
|
|
358
|
+
# anything specific they'll always have a "global logger"
|
|
359
|
+
self._cv_logger: contextvars.ContextVar[Logger | None] = contextvars.ContextVar(
|
|
348
360
|
"braintrust_current_logger", default=None
|
|
349
361
|
)
|
|
362
|
+
self._local_logger: Logger | None = None
|
|
350
363
|
self.current_parent: contextvars.ContextVar[str | None] = contextvars.ContextVar(
|
|
351
364
|
"braintrust_current_parent", default=None
|
|
352
365
|
)
|
|
@@ -396,6 +409,11 @@ class BraintrustState:
|
|
|
396
409
|
),
|
|
397
410
|
)
|
|
398
411
|
|
|
412
|
+
from braintrust.span_cache import SpanCache
|
|
413
|
+
|
|
414
|
+
self.span_cache = SpanCache()
|
|
415
|
+
self._otel_flush_callback: Any | None = None
|
|
416
|
+
|
|
399
417
|
def reset_login_info(self):
|
|
400
418
|
self.app_url: str | None = None
|
|
401
419
|
self.app_public_url: str | None = None
|
|
@@ -415,7 +433,8 @@ class BraintrustState:
|
|
|
415
433
|
def reset_parent_state(self):
|
|
416
434
|
# reset possible parent state for tests
|
|
417
435
|
self.current_experiment = None
|
|
418
|
-
self.
|
|
436
|
+
self._cv_logger.set(None)
|
|
437
|
+
self._local_logger = None
|
|
419
438
|
self.current_parent.set(None)
|
|
420
439
|
self.current_span.set(NOOP_SPAN)
|
|
421
440
|
|
|
@@ -452,6 +471,21 @@ class BraintrustState:
|
|
|
452
471
|
|
|
453
472
|
return self._context_manager
|
|
454
473
|
|
|
474
|
+
def register_otel_flush(self, callback: Any) -> None:
|
|
475
|
+
"""
|
|
476
|
+
Register an OTEL flush callback. This is called by the OTEL integration
|
|
477
|
+
when it initializes a span processor/exporter.
|
|
478
|
+
"""
|
|
479
|
+
self._otel_flush_callback = callback
|
|
480
|
+
|
|
481
|
+
async def flush_otel(self) -> None:
|
|
482
|
+
"""
|
|
483
|
+
Flush OTEL spans if a callback is registered.
|
|
484
|
+
Called during ensure_spans_flushed to ensure OTEL spans are visible in BTQL.
|
|
485
|
+
"""
|
|
486
|
+
if self._otel_flush_callback:
|
|
487
|
+
await self._otel_flush_callback()
|
|
488
|
+
|
|
455
489
|
def copy_state(self, other: "BraintrustState"):
|
|
456
490
|
"""Copy login information from another BraintrustState instance."""
|
|
457
491
|
self.__dict__.update({
|
|
@@ -460,7 +494,8 @@ class BraintrustState:
|
|
|
460
494
|
if k
|
|
461
495
|
not in (
|
|
462
496
|
"current_experiment",
|
|
463
|
-
"
|
|
497
|
+
"_cv_logger",
|
|
498
|
+
"_local_logger",
|
|
464
499
|
"current_parent",
|
|
465
500
|
"current_span",
|
|
466
501
|
"_global_bg_logger",
|
|
@@ -530,10 +565,6 @@ class BraintrustState:
|
|
|
530
565
|
self._user_info = self.api_conn().get_json("ping")
|
|
531
566
|
return self._user_info
|
|
532
567
|
|
|
533
|
-
def set_user_info_if_null(self, info: Mapping[str, Any]):
|
|
534
|
-
if not self._user_info:
|
|
535
|
-
self._user_info = info
|
|
536
|
-
|
|
537
568
|
def global_bg_logger(self) -> "_BackgroundLogger":
|
|
538
569
|
return getattr(self._override_bg_logger, "logger", None) or self._global_bg_logger.get()
|
|
539
570
|
|
|
@@ -595,14 +626,28 @@ class RetryRequestExceptionsAdapter(HTTPAdapter):
|
|
|
595
626
|
base_num_retries: Maximum number of retries before giving up and re-raising the exception.
|
|
596
627
|
backoff_factor: A multiplier used to determine the time to wait between retries.
|
|
597
628
|
The actual wait time is calculated as: backoff_factor * (2 ** retry_count).
|
|
629
|
+
default_timeout_secs: Default timeout in seconds for requests that don't specify one.
|
|
630
|
+
Prevents indefinite hangs on stale connections.
|
|
598
631
|
"""
|
|
599
632
|
|
|
600
|
-
def __init__(
|
|
633
|
+
def __init__(
|
|
634
|
+
self,
|
|
635
|
+
*args: Any,
|
|
636
|
+
base_num_retries: int = 0,
|
|
637
|
+
backoff_factor: float = 0.5,
|
|
638
|
+
default_timeout_secs: float = 60,
|
|
639
|
+
**kwargs: Any,
|
|
640
|
+
):
|
|
601
641
|
self.base_num_retries = base_num_retries
|
|
602
642
|
self.backoff_factor = backoff_factor
|
|
643
|
+
self.default_timeout_secs = default_timeout_secs
|
|
603
644
|
super().__init__(*args, **kwargs)
|
|
604
645
|
|
|
605
646
|
def send(self, *args, **kwargs):
|
|
647
|
+
# Apply default timeout if none provided to prevent indefinite hangs
|
|
648
|
+
if kwargs.get("timeout") is None:
|
|
649
|
+
kwargs["timeout"] = self.default_timeout_secs
|
|
650
|
+
|
|
606
651
|
num_prev_retries = 0
|
|
607
652
|
while True:
|
|
608
653
|
try:
|
|
@@ -614,6 +659,14 @@ class RetryRequestExceptionsAdapter(HTTPAdapter):
|
|
|
614
659
|
return response
|
|
615
660
|
except (urllib3.exceptions.HTTPError, requests.exceptions.RequestException) as e:
|
|
616
661
|
if num_prev_retries < self.base_num_retries:
|
|
662
|
+
if isinstance(e, requests.exceptions.ReadTimeout):
|
|
663
|
+
# Clear all connection pools to discard stale connections. This
|
|
664
|
+
# fixes hangs caused by NAT gateways silently dropping idle TCP
|
|
665
|
+
# connections (e.g., Azure's ~4 min timeout). close() calls
|
|
666
|
+
# PoolManager.clear() which is thread-safe: in-flight requests
|
|
667
|
+
# keep their checked-out connections, and new requests create
|
|
668
|
+
# fresh pools on demand.
|
|
669
|
+
self.close()
|
|
617
670
|
# Emulates the sleeping logic in the backoff_factor of urllib3 Retry
|
|
618
671
|
sleep_s = self.backoff_factor * (2**num_prev_retries)
|
|
619
672
|
print("Retrying request after error:", e, file=sys.stderr)
|
|
@@ -635,14 +688,16 @@ class HTTPConnection:
|
|
|
635
688
|
def ping(self) -> bool:
|
|
636
689
|
try:
|
|
637
690
|
resp = self.get("ping")
|
|
638
|
-
_state.set_user_info_if_null(resp.json())
|
|
639
691
|
return resp.ok
|
|
640
692
|
except requests.exceptions.ConnectionError:
|
|
641
693
|
return False
|
|
642
694
|
|
|
643
695
|
def make_long_lived(self) -> None:
|
|
644
696
|
if not self.adapter:
|
|
645
|
-
|
|
697
|
+
timeout_secs = parse_env_var_float("BRAINTRUST_HTTP_TIMEOUT", 60.0)
|
|
698
|
+
self.adapter = RetryRequestExceptionsAdapter(
|
|
699
|
+
base_num_retries=10, backoff_factor=0.5, default_timeout_secs=timeout_secs
|
|
700
|
+
)
|
|
646
701
|
self._reset()
|
|
647
702
|
|
|
648
703
|
@staticmethod
|
|
@@ -687,6 +742,8 @@ class HTTPConnection:
|
|
|
687
742
|
return self.session.delete(_urljoin(self.base_url, path), *args, **kwargs)
|
|
688
743
|
|
|
689
744
|
def get_json(self, object_type: str, args: Mapping[str, Any] | None = None, retries: int = 0) -> Mapping[str, Any]:
|
|
745
|
+
# FIXME[matt]: the retry logic seems to be unused and could be n*2 because of the retry logic
|
|
746
|
+
# in the RetryRequestExceptionsAdapter. We should probably remove this.
|
|
690
747
|
tries = retries + 1
|
|
691
748
|
for i in range(tries):
|
|
692
749
|
resp = self.get(f"/{object_type}", params=args)
|
|
@@ -1297,7 +1354,7 @@ def init(
|
|
|
1297
1354
|
project: str | None = None,
|
|
1298
1355
|
experiment: str | None = None,
|
|
1299
1356
|
description: str | None = None,
|
|
1300
|
-
dataset: Optional["Dataset"] = None,
|
|
1357
|
+
dataset: Optional["Dataset"] | DatasetRef = None,
|
|
1301
1358
|
open: bool = False,
|
|
1302
1359
|
base_experiment: str | None = None,
|
|
1303
1360
|
is_public: bool = False,
|
|
@@ -1410,12 +1467,19 @@ def init(
|
|
|
1410
1467
|
args["base_exp_id"] = base_experiment_id
|
|
1411
1468
|
elif base_experiment is not None:
|
|
1412
1469
|
args["base_experiment"] = base_experiment
|
|
1413
|
-
|
|
1470
|
+
elif merged_git_metadata_settings and merged_git_metadata_settings.collect != "none":
|
|
1414
1471
|
args["ancestor_commits"] = list(get_past_n_ancestors())
|
|
1415
1472
|
|
|
1416
1473
|
if dataset is not None:
|
|
1417
|
-
|
|
1418
|
-
|
|
1474
|
+
if isinstance(dataset, dict):
|
|
1475
|
+
# Simple {"id": ..., "version": ...} dict
|
|
1476
|
+
args["dataset_id"] = dataset["id"]
|
|
1477
|
+
if "version" in dataset:
|
|
1478
|
+
args["dataset_version"] = dataset["version"]
|
|
1479
|
+
else:
|
|
1480
|
+
# Full Dataset object
|
|
1481
|
+
args["dataset_id"] = dataset.id
|
|
1482
|
+
args["dataset_version"] = dataset.version
|
|
1419
1483
|
|
|
1420
1484
|
if is_public is not None:
|
|
1421
1485
|
args["public"] = is_public
|
|
@@ -1446,7 +1510,11 @@ def init(
|
|
|
1446
1510
|
# For experiments, disable queue size limit enforcement (unlimited queue)
|
|
1447
1511
|
state.enforce_queue_size_limit(False)
|
|
1448
1512
|
|
|
1449
|
-
ret = Experiment(
|
|
1513
|
+
ret = Experiment(
|
|
1514
|
+
lazy_metadata=LazyValue(compute_metadata, use_mutex=True),
|
|
1515
|
+
dataset=dataset if isinstance(dataset, Dataset) else None,
|
|
1516
|
+
state=state,
|
|
1517
|
+
)
|
|
1450
1518
|
if set_current:
|
|
1451
1519
|
state.current_experiment = ret
|
|
1452
1520
|
return ret
|
|
@@ -1598,7 +1666,8 @@ def init_logger(
|
|
|
1598
1666
|
if set_current:
|
|
1599
1667
|
if _state is None:
|
|
1600
1668
|
raise RuntimeError("_state is None in init_logger. This should never happen.")
|
|
1601
|
-
_state.
|
|
1669
|
+
_state._cv_logger.set(ret)
|
|
1670
|
+
_state._local_logger = ret
|
|
1602
1671
|
return ret
|
|
1603
1672
|
|
|
1604
1673
|
|
|
@@ -1761,6 +1830,25 @@ def login(
|
|
|
1761
1830
|
_state.login(app_url=app_url, api_key=api_key, org_name=org_name, force_login=force_login)
|
|
1762
1831
|
|
|
1763
1832
|
|
|
1833
|
+
def register_otel_flush(callback: Any) -> None:
|
|
1834
|
+
"""
|
|
1835
|
+
Register a callback to flush OTEL spans. This is called by the OTEL integration
|
|
1836
|
+
when it initializes a span processor/exporter.
|
|
1837
|
+
|
|
1838
|
+
When ensure_spans_flushed is called (e.g., before a BTQL query in scorers),
|
|
1839
|
+
this callback will be invoked to ensure OTEL spans are flushed to the server.
|
|
1840
|
+
|
|
1841
|
+
Also disables the span cache, since OTEL spans aren't in the local cache
|
|
1842
|
+
and we need BTQL to see the complete span tree (both native + OTEL spans).
|
|
1843
|
+
|
|
1844
|
+
:param callback: The async callback function to flush OTEL spans.
|
|
1845
|
+
"""
|
|
1846
|
+
global _state
|
|
1847
|
+
_state.register_otel_flush(callback)
|
|
1848
|
+
# Disable span cache since OTEL spans aren't in the local cache
|
|
1849
|
+
_state.span_cache.disable()
|
|
1850
|
+
|
|
1851
|
+
|
|
1764
1852
|
def login_to_state(
|
|
1765
1853
|
app_url: str | None = None,
|
|
1766
1854
|
api_key: str | None = None,
|
|
@@ -1900,7 +1988,7 @@ def current_experiment() -> Optional["Experiment"]:
|
|
|
1900
1988
|
def current_logger() -> Optional["Logger"]:
|
|
1901
1989
|
"""Returns the currently-active logger (set by `braintrust.init_logger(...)`). Returns None if no current logger has been set."""
|
|
1902
1990
|
|
|
1903
|
-
return _state.
|
|
1991
|
+
return _state._cv_logger.get() or _state._local_logger
|
|
1904
1992
|
|
|
1905
1993
|
|
|
1906
1994
|
def current_span() -> Span:
|
|
@@ -2323,30 +2411,6 @@ def _enrich_attachments(event: TMutableMapping) -> TMutableMapping:
|
|
|
2323
2411
|
|
|
2324
2412
|
|
|
2325
2413
|
def _validate_and_sanitize_experiment_log_partial_args(event: Mapping[str, Any]) -> dict[str, Any]:
|
|
2326
|
-
# Make sure only certain keys are specified.
|
|
2327
|
-
forbidden_keys = set(event.keys()) - {
|
|
2328
|
-
"input",
|
|
2329
|
-
"output",
|
|
2330
|
-
"expected",
|
|
2331
|
-
"tags",
|
|
2332
|
-
"scores",
|
|
2333
|
-
"metadata",
|
|
2334
|
-
"metrics",
|
|
2335
|
-
"error",
|
|
2336
|
-
"dataset_record_id",
|
|
2337
|
-
"origin",
|
|
2338
|
-
"inputs",
|
|
2339
|
-
"span_attributes",
|
|
2340
|
-
ASYNC_SCORING_CONTROL_FIELD,
|
|
2341
|
-
MERGE_PATHS_FIELD,
|
|
2342
|
-
SKIP_ASYNC_SCORING_FIELD,
|
|
2343
|
-
"span_id",
|
|
2344
|
-
"root_span_id",
|
|
2345
|
-
"_bt_internal_override_pagination_key",
|
|
2346
|
-
}
|
|
2347
|
-
if forbidden_keys:
|
|
2348
|
-
raise ValueError(f"The following keys are not permitted: {forbidden_keys}")
|
|
2349
|
-
|
|
2350
2414
|
scores = event.get("scores")
|
|
2351
2415
|
if scores:
|
|
2352
2416
|
for name, score in scores.items():
|
|
@@ -3855,6 +3919,21 @@ class SpanImpl(Span):
|
|
|
3855
3919
|
if serializable_partial_record.get("metrics", {}).get("end") is not None:
|
|
3856
3920
|
self._logged_end_time = serializable_partial_record["metrics"]["end"]
|
|
3857
3921
|
|
|
3922
|
+
# Write to local span cache for scorer access
|
|
3923
|
+
# Only cache experiment spans - regular logs don't need caching
|
|
3924
|
+
if self.parent_object_type == SpanObjectTypeV3.EXPERIMENT:
|
|
3925
|
+
from braintrust.span_cache import CachedSpan
|
|
3926
|
+
|
|
3927
|
+
cached_span = CachedSpan(
|
|
3928
|
+
span_id=self.span_id,
|
|
3929
|
+
input=serializable_partial_record.get("input"),
|
|
3930
|
+
output=serializable_partial_record.get("output"),
|
|
3931
|
+
metadata=serializable_partial_record.get("metadata"),
|
|
3932
|
+
span_parents=self.span_parents,
|
|
3933
|
+
span_attributes=serializable_partial_record.get("span_attributes"),
|
|
3934
|
+
)
|
|
3935
|
+
self.state.span_cache.queue_write(self.root_span_id, self.span_id, cached_span)
|
|
3936
|
+
|
|
3858
3937
|
def compute_record() -> dict[str, Any]:
|
|
3859
3938
|
exporter = _get_exporter()
|
|
3860
3939
|
return dict(
|
|
@@ -3938,6 +4017,9 @@ class SpanImpl(Span):
|
|
|
3938
4017
|
use_v4 = os.getenv("BRAINTRUST_OTEL_COMPAT", "false").lower() == "true"
|
|
3939
4018
|
span_components_class = SpanComponentsV4 if use_v4 else SpanComponentsV3
|
|
3940
4019
|
|
|
4020
|
+
# Disable span cache since remote function spans won't be in the local cache
|
|
4021
|
+
self.state.span_cache.disable()
|
|
4022
|
+
|
|
3941
4023
|
return span_components_class(
|
|
3942
4024
|
object_type=self.parent_object_type,
|
|
3943
4025
|
object_id=object_id,
|
|
@@ -3951,7 +4033,7 @@ class SpanImpl(Span):
|
|
|
3951
4033
|
def link(self) -> str:
|
|
3952
4034
|
parent_type, info = self._get_parent_info()
|
|
3953
4035
|
if parent_type == SpanObjectTypeV3.PROJECT_LOGS:
|
|
3954
|
-
cur_logger = self.state.
|
|
4036
|
+
cur_logger = self.state._cv_logger.get() or self.state._local_logger
|
|
3955
4037
|
if not cur_logger:
|
|
3956
4038
|
return NOOP_SPAN_PERMALINK
|
|
3957
4039
|
base_url = cur_logger._get_link_base_url()
|
braintrust/oai.py
CHANGED
|
@@ -5,6 +5,8 @@ import time
|
|
|
5
5
|
from collections.abc import Callable
|
|
6
6
|
from typing import Any
|
|
7
7
|
|
|
8
|
+
from wrapt import wrap_function_wrapper
|
|
9
|
+
|
|
8
10
|
from .logger import Attachment, Span, start_span
|
|
9
11
|
from .span_types import SpanTypeAttribute
|
|
10
12
|
from .util import merge_dicts
|
|
@@ -986,3 +988,52 @@ def _is_not_given(value: Any) -> bool:
|
|
|
986
988
|
return type_name == "NotGiven"
|
|
987
989
|
except Exception:
|
|
988
990
|
return False
|
|
991
|
+
|
|
992
|
+
|
|
993
|
+
def _openai_init_wrapper(wrapped, instance, args, kwargs):
|
|
994
|
+
"""Wrapper for OpenAI.__init__ that applies tracing after initialization."""
|
|
995
|
+
wrapped(*args, **kwargs)
|
|
996
|
+
_apply_openai_wrapper(instance)
|
|
997
|
+
|
|
998
|
+
|
|
999
|
+
def patch_openai() -> bool:
|
|
1000
|
+
"""
|
|
1001
|
+
Patch OpenAI to add Braintrust tracing globally.
|
|
1002
|
+
|
|
1003
|
+
After calling this, all new OpenAI() and AsyncOpenAI() clients
|
|
1004
|
+
will automatically have tracing enabled.
|
|
1005
|
+
|
|
1006
|
+
Returns:
|
|
1007
|
+
True if OpenAI was patched (or already patched), False if OpenAI is not installed.
|
|
1008
|
+
|
|
1009
|
+
Example:
|
|
1010
|
+
```python
|
|
1011
|
+
import braintrust
|
|
1012
|
+
braintrust.patch_openai()
|
|
1013
|
+
|
|
1014
|
+
import openai
|
|
1015
|
+
client = openai.OpenAI()
|
|
1016
|
+
# All calls are now traced!
|
|
1017
|
+
```
|
|
1018
|
+
"""
|
|
1019
|
+
try:
|
|
1020
|
+
import openai
|
|
1021
|
+
|
|
1022
|
+
if getattr(openai, "__braintrust_wrapped__", False):
|
|
1023
|
+
return True # Already patched
|
|
1024
|
+
|
|
1025
|
+
wrap_function_wrapper("openai", "OpenAI.__init__", _openai_init_wrapper)
|
|
1026
|
+
wrap_function_wrapper("openai", "AsyncOpenAI.__init__", _openai_init_wrapper)
|
|
1027
|
+
openai.__braintrust_wrapped__ = True
|
|
1028
|
+
return True
|
|
1029
|
+
|
|
1030
|
+
except ImportError:
|
|
1031
|
+
return False
|
|
1032
|
+
|
|
1033
|
+
|
|
1034
|
+
def _apply_openai_wrapper(client):
|
|
1035
|
+
"""Apply tracing wrapper to an OpenAI client instance in-place."""
|
|
1036
|
+
wrapped = wrap_openai(client)
|
|
1037
|
+
for attr in ("chat", "responses", "embeddings", "moderations", "beta"):
|
|
1038
|
+
if hasattr(wrapped, attr):
|
|
1039
|
+
setattr(client, attr, getattr(wrapped, attr))
|