braintrust 0.3.15__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff compares publicly available package versions as published to their public registry. It is provided for informational purposes only and reflects the changes between the two versions exactly as they appear in the registry.
- braintrust/_generated_types.py +737 -672
- braintrust/audit.py +2 -2
- braintrust/bt_json.py +178 -19
- braintrust/cli/eval.py +6 -7
- braintrust/cli/push.py +11 -11
- braintrust/context.py +12 -17
- braintrust/contrib/temporal/__init__.py +16 -27
- braintrust/contrib/temporal/test_temporal.py +8 -3
- braintrust/devserver/auth.py +8 -8
- braintrust/devserver/cache.py +3 -4
- braintrust/devserver/cors.py +8 -7
- braintrust/devserver/dataset.py +3 -5
- braintrust/devserver/eval_hooks.py +7 -6
- braintrust/devserver/schemas.py +22 -19
- braintrust/devserver/server.py +19 -12
- braintrust/devserver/test_cached_login.py +4 -4
- braintrust/framework.py +139 -142
- braintrust/framework2.py +88 -87
- braintrust/functions/invoke.py +66 -59
- braintrust/functions/stream.py +3 -2
- braintrust/generated_types.py +3 -1
- braintrust/git_fields.py +11 -11
- braintrust/gitutil.py +2 -3
- braintrust/graph_util.py +10 -10
- braintrust/id_gen.py +2 -2
- braintrust/logger.py +373 -471
- braintrust/merge_row_batch.py +10 -9
- braintrust/oai.py +21 -20
- braintrust/otel/__init__.py +49 -49
- braintrust/otel/context.py +16 -30
- braintrust/otel/test_distributed_tracing.py +14 -11
- braintrust/otel/test_otel_bt_integration.py +32 -31
- braintrust/parameters.py +8 -8
- braintrust/prompt.py +14 -14
- braintrust/prompt_cache/disk_cache.py +5 -4
- braintrust/prompt_cache/lru_cache.py +3 -2
- braintrust/prompt_cache/prompt_cache.py +13 -14
- braintrust/queue.py +4 -4
- braintrust/score.py +4 -4
- braintrust/serializable_data_class.py +4 -4
- braintrust/span_identifier_v1.py +1 -2
- braintrust/span_identifier_v2.py +3 -4
- braintrust/span_identifier_v3.py +23 -20
- braintrust/span_identifier_v4.py +34 -25
- braintrust/test_bt_json.py +644 -0
- braintrust/test_framework.py +72 -6
- braintrust/test_helpers.py +5 -5
- braintrust/test_id_gen.py +2 -3
- braintrust/test_logger.py +211 -107
- braintrust/test_otel.py +61 -53
- braintrust/test_queue.py +0 -1
- braintrust/test_score.py +1 -3
- braintrust/test_span_components.py +29 -44
- braintrust/util.py +9 -8
- braintrust/version.py +2 -2
- braintrust/wrappers/_anthropic_utils.py +4 -4
- braintrust/wrappers/agno/__init__.py +3 -4
- braintrust/wrappers/agno/agent.py +1 -2
- braintrust/wrappers/agno/function_call.py +1 -2
- braintrust/wrappers/agno/model.py +1 -2
- braintrust/wrappers/agno/team.py +1 -2
- braintrust/wrappers/agno/utils.py +12 -12
- braintrust/wrappers/anthropic.py +7 -8
- braintrust/wrappers/claude_agent_sdk/__init__.py +3 -4
- braintrust/wrappers/claude_agent_sdk/_wrapper.py +29 -27
- braintrust/wrappers/dspy.py +15 -17
- braintrust/wrappers/google_genai/__init__.py +17 -30
- braintrust/wrappers/langchain.py +22 -24
- braintrust/wrappers/litellm.py +4 -3
- braintrust/wrappers/openai.py +15 -15
- braintrust/wrappers/pydantic_ai.py +225 -110
- braintrust/wrappers/test_agno.py +0 -1
- braintrust/wrappers/test_dspy.py +0 -1
- braintrust/wrappers/test_google_genai.py +64 -4
- braintrust/wrappers/test_litellm.py +0 -1
- braintrust/wrappers/test_pydantic_ai_integration.py +819 -22
- {braintrust-0.3.15.dist-info → braintrust-0.4.1.dist-info}/METADATA +3 -2
- braintrust-0.4.1.dist-info/RECORD +121 -0
- braintrust-0.3.15.dist-info/RECORD +0 -120
- {braintrust-0.3.15.dist-info → braintrust-0.4.1.dist-info}/WHEEL +0 -0
- {braintrust-0.3.15.dist-info → braintrust-0.4.1.dist-info}/entry_points.txt +0 -0
- {braintrust-0.3.15.dist-info → braintrust-0.4.1.dist-info}/top_level.txt +0 -0
braintrust/audit.py
CHANGED
@@ -5,7 +5,7 @@ Utilities for working with audit headers.
 import base64
 import gzip
 import json
-from typing import
+from typing import TypedDict


 class AuditResource(TypedDict):
@@ -14,7 +14,7 @@ class AuditResource(TypedDict):
     name: str


-def parse_audit_resources(hdr: str) ->
+def parse_audit_resources(hdr: str) -> list[AuditResource]:
     j = json.loads(hdr)
     if j["v"] == 1:
         return json.loads(gzip.decompress(base64.b64decode(j["p"])))
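For reference, a minimal sketch of the v1 header shape implied by parse_audit_resources above: a JSON envelope whose "p" field is a base64-encoded, gzipped JSON list of resources. Only the "name" field of AuditResource is confirmed by the hunk; the header built here is illustrative, not a real Braintrust header.

import base64
import gzip
import json

from braintrust.audit import parse_audit_resources

# Hypothetical resource payload; only "name" is shown in the AuditResource diff above.
resources = [{"name": "my-experiment"}]
packed = base64.b64encode(gzip.compress(json.dumps(resources).encode("utf-8"))).decode("ascii")
hdr = json.dumps({"v": 1, "p": packed})

parsed = parse_audit_resources(hdr)  # -> list[AuditResource]
assert parsed[0]["name"] == "my-experiment"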
braintrust/bt_json.py
CHANGED
@@ -1,6 +1,7 @@
 import dataclasses
 import json
-
+import math
+from typing import Any, Callable, Mapping, NamedTuple, cast, overload

 # Try to import orjson for better performance
 # If not available, we'll use standard json
@@ -12,39 +13,184 @@ except ImportError:
     _HAS_ORJSON = False


-def _to_dict(obj: Any) -> Any:
-    """
-    Function-based default handler for non-JSON-serializable objects.

-
-
-    -
-    - Pydantic v1 BaseModel
-    - Falls back to str() for unknown types
+def _to_bt_safe(v: Any) -> Any:
+    """
+    Converts the object to a Braintrust-safe representation (i.e. Attachment objects are safe (specially handled by background logger)).
     """
-
-
+    # avoid circular imports
+    from braintrust.logger import BaseAttachment, Dataset, Experiment, Logger, ReadonlyAttachment, Span
+
+    if isinstance(v, Span):
+        return "<span>"
+
+    if isinstance(v, Experiment):
+        return "<experiment>"
+
+    if isinstance(v, Dataset):
+        return "<dataset>"
+
+    if isinstance(v, Logger):
+        return "<logger>"
+
+    if isinstance(v, BaseAttachment):
+        return v
+
+    if isinstance(v, ReadonlyAttachment):
+        return v.reference
+
+    if dataclasses.is_dataclass(v) and not isinstance(v, type):
+        # Use manual field iteration instead of dataclasses.asdict() because
+        # asdict() deep-copies values, which breaks objects like Attachment
+        # that contain non-copyable items (thread locks, file handles, etc.)
+        return {f.name: _to_bt_safe(getattr(v, f.name)) for f in dataclasses.fields(v)}
+
+    # Pydantic model classes (not instances) with model_json_schema
+    if isinstance(v, type) and hasattr(v, "model_json_schema") and callable(cast(Any, v).model_json_schema):
+        try:
+            return cast(Any, v).model_json_schema()
+        except Exception:
+            pass

     # Attempt to dump a Pydantic v2 `BaseModel`.
     try:
-        return cast(Any,
+        return cast(Any, v).model_dump(exclude_none=True)
     except (AttributeError, TypeError):
         pass

     # Attempt to dump a Pydantic v1 `BaseModel`.
     try:
-        return cast(Any,
+        return cast(Any, v).dict(exclude_none=True)
     except (AttributeError, TypeError):
         pass

-
+    if isinstance(v, float):
+        # Handle NaN and Infinity for JSON compatibility
+        if math.isnan(v):
+            return "NaN"
+
+        if math.isinf(v):
+            return "Infinity" if v > 0 else "-Infinity"
+
+        return v
+
+    if isinstance(v, (int, str, bool)) or v is None:
+        # Skip roundtrip for primitive types.
+        return v
+
+    # Note: we avoid using copy.deepcopy, because it's difficult to
+    # guarantee the independence of such copied types from their origin.
+    # E.g. the original type could have a `__del__` method that alters
+    # some shared internal state, and we need this deep copy to be
+    # fully-independent from the original.
+
+    # We pass `encoder=_str_encoder` since we've already tried converting rich objects to json safe objects.
+    return bt_loads(bt_dumps(v, encoder=_str_encoder))
+
+
+@overload
+def bt_safe_deep_copy(
+    obj: Mapping[str, Any],
+    max_depth: int = ...,
+) -> dict[str, Any]: ...
+
+@overload
+def bt_safe_deep_copy(
+    obj: list[Any],
+    max_depth: int = ...,
+) -> list[Any]: ...
+
+@overload
+def bt_safe_deep_copy(
+    obj: Any,
+    max_depth: int = ...,
+) -> Any: ...
+def bt_safe_deep_copy(obj: Any, max_depth: int=200):
+    """
+    Creates a deep copy of the given object and converts rich objects to Braintrust-safe representations. See `_to_bt_safe` for more details.
+
+    Args:
+        obj: Object to deep copy and sanitize.
+        to_json_safe: Function to ensure the object is json safe.
+        max_depth: Maximum depth to copy.
+
+    Returns:
+        Deep copy of the object.
+    """
+    # Track visited objects to detect circular references
+    visited: set[int] = set()
+
+    def _deep_copy_object(v: Any, depth: int = 0) -> Any:
+        # Check depth limit - use >= to stop before exceeding
+        if depth >= max_depth:
+            return "<max depth exceeded>"
+
+        # Check for circular references in mutable containers
+        # Use id() to track object identity
+        if isinstance(v, (Mapping, list, tuple, set)):
+            obj_id = id(v)
+            if obj_id in visited:
+                return "<circular reference>"
+            visited.add(obj_id)
+            try:
+                if isinstance(v, Mapping):
+                    # Prevent dict keys from holding references to user data. Note that
+                    # `bt_json` already coerces keys to string, a behavior that comes from
+                    # `json.dumps`. However, that runs at log upload time, while we want to
+                    # cut out all the references to user objects synchronously in this
+                    # function.
+                    result = {}
+                    for k in v:
+                        try:
+                            key_str = str(k)
+                        except Exception:
+                            # If str() fails on the key, use a fallback representation
+                            key_str = f"<non-stringifiable-key: {type(k).__name__}>"
+                        result[key_str] = _deep_copy_object(v[k], depth + 1)
+                    return result
+                elif isinstance(v, (list, tuple, set)):
+                    return [_deep_copy_object(x, depth + 1) for x in v]
+            finally:
+                # Remove from visited set after processing to allow the same object
+                # to appear in different branches of the tree
+                visited.discard(obj_id)

+        try:
+            return _to_bt_safe(v)
+        except Exception:
+            return f"<non-sanitizable: {type(v).__name__}>"
+
+    return _deep_copy_object(obj)
+
+def _safe_str(obj: Any) -> str:
     try:
         return str(obj)
     except Exception:
-        # If str() fails, return an error placeholder
         return f"<non-serializable: {type(obj).__name__}>"


+def _to_json_safe(obj: Any) -> Any:
+    """
+    Handler for non-JSON-serializable objects. Returns a string representation of the object.
+    """
+    # avoid circular imports
+    from braintrust.logger import BaseAttachment
+
+    try:
+        v = _to_bt_safe(obj)
+
+        # JSON-safe representation of Attachment objects are their reference.
+        # If we get this object at this point, we have to assume someone has already uploaded the attachment!
+        if isinstance(v, BaseAttachment):
+            v = v.reference
+
+        return v
+    except Exception:
+        pass
+
+    # When everything fails, try to return the string representation of the object
+    return _safe_str(obj)
+
+
 class BraintrustJSONEncoder(json.JSONEncoder):
     """
     Custom JSON encoder for standard json library.
@@ -53,10 +199,22 @@ class BraintrustJSONEncoder(json.JSONEncoder):
     """

     def default(self, o: Any):
-        return
+        return _to_json_safe(o)
+
+
+class BraintrustStrEncoder(json.JSONEncoder):
+    def default(self, o: Any):
+        return _safe_str(o)
+
+
+class Encoder(NamedTuple):
+    native: type[json.JSONEncoder]
+    orjson: Callable[[Any], Any]

+_json_encoder = Encoder(native=BraintrustJSONEncoder, orjson=_to_json_safe)
+_str_encoder = Encoder(native=BraintrustStrEncoder, orjson=_safe_str)

-def bt_dumps(obj, **kwargs) -> str:
+def bt_dumps(obj: Any, encoder: Encoder | None=_json_encoder, **kwargs: Any) -> str:
     """
     Serialize obj to a JSON-formatted string.

@@ -65,6 +223,7 @@ def bt_dumps(obj, **kwargs) -> str:

     Args:
         obj: Object to serialize
+        encoder: Encoder to use, defaults to `_default_encoder`
         **kwargs: Additional arguments (passed to json.dumps in fallback path)

     Returns:
@@ -76,7 +235,7 @@ def bt_dumps(obj, **kwargs) -> str:
        # pylint: disable=no-member # orjson is a C extension, pylint can't introspect it
        return orjson.dumps(  # type: ignore[possibly-unbound]
            obj,
-            default=
+            default=encoder.orjson if encoder else None,
            # options match json.dumps behavior for bc
            option=orjson.OPT_SORT_KEYS | orjson.OPT_SERIALIZE_NUMPY | orjson.OPT_NON_STR_KEYS,  # type: ignore[possibly-unbound]
        ).decode("utf-8")
@@ -86,7 +245,7 @@ def bt_dumps(obj, **kwargs) -> str:

    # Use standard json (either orjson not available or it failed)
    # Use sort_keys=True for deterministic output (matches orjson OPT_SORT_KEYS)
-    return json.dumps(obj, cls=
+    return json.dumps(obj, cls=encoder.native if encoder else None, allow_nan=False, sort_keys=True, **kwargs)


 def bt_loads(s: str, **kwargs) -> Any:
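For illustration, a minimal sketch of how the reworked bt_json helpers might be used, assuming they are importable from braintrust.bt_json as shown above; the record contents are made up.

import math

from braintrust.bt_json import bt_dumps, bt_safe_deep_copy

record = {"score": math.nan, "nested": {}}
record["nested"]["self"] = record["nested"]  # deliberate circular reference

# Per the hunks above, NaN is stringified and cycles are cut off:
#   {"score": "NaN", "nested": {"self": "<circular reference>"}}
safe = bt_safe_deep_copy(record)

# bt_dumps now accepts an optional Encoder; the default (_json_encoder) converts
# rich objects and falls back to their string representation when that fails.
print(bt_dumps(safe))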
braintrust/cli/eval.py
CHANGED
@@ -6,7 +6,6 @@ import os
 import sys
 from dataclasses import dataclass, field
 from threading import Lock
-from typing import Dict, List, Optional, Union

 from .. import login
 from ..framework import (
@@ -70,7 +69,7 @@ class EvaluatorOpts:
     no_progress_bars: bool
     terminate_on_failure: bool
     watch: bool
-    filters:
+    filters: list[str]
     list: bool
     jsonl: bool

@@ -79,13 +78,13 @@ class EvaluatorOpts:
 class LoadedEvaluator:
     handle: FileHandle
     evaluator: Evaluator
-    reporter:
+    reporter: ReporterDef | str | None = None


 @dataclass
 class EvaluatorState:
-    evaluators:
-    reporters:
+    evaluators: list[LoadedEvaluator] = field(default_factory=list)
+    reporters: dict[str, ReporterDef] = field(default_factory=dict)


 def update_evaluators(eval_state: EvaluatorState, handles, terminate_on_failure):
@@ -157,7 +156,7 @@ async def run_evaluator_task(evaluator, position, opts: EvaluatorOpts):
         experiment.flush()


-def resolve_reporter(reporter:
+def resolve_reporter(reporter: ReporterDef | str | None, reporters: dict[str, ReporterDef]) -> ReporterDef:
     if isinstance(reporter, str):
         if reporter not in reporters:
             raise ValueError(f"Reporter {reporter} not found")
@@ -179,7 +178,7 @@ def add_report(eval_reports, reporter, report):
     eval_reports[reporter.name]["results"].append(report)


-async def run_once(handles:
+async def run_once(handles: list[FileHandle], evaluator_opts: EvaluatorOpts) -> bool:
     objects = EvaluatorState()
     update_evaluators(objects, handles, terminate_on_failure=evaluator_opts.terminate_on_failure)

braintrust/cli/push.py
CHANGED
@@ -13,7 +13,7 @@ import sys
 import tempfile
 import textwrap
 import zipfile
-from typing import Any
+from typing import Any

 import requests
 from braintrust.framework import _set_lazy_load
@@ -24,7 +24,7 @@ from ..generated_types import IfExists
 from ..util import add_azure_blob_headers


-def _pkg_install_arg(pkg) ->
+def _pkg_install_arg(pkg) -> str | None:
     try:
         dist = importlib.metadata.distribution(pkg)
         direct_url = dist._path / "direct_url.json"  # type: ignore
@@ -73,11 +73,11 @@ class _ProjectRootImporter(importlib.abc.MetaPathFinder):
         self._project_root, self._path_rest = sys.path[0], sys.path[1:]
         self._sources = []

-    def _under_project_root(self, path:
+    def _under_project_root(self, path: list[str]) -> bool:
         """Returns true if all paths in `path` are under the project root."""
         return all(p.startswith(self._project_root) for p in path)

-    def _under_rest(self, path:
+    def _under_rest(self, path: list[str]) -> bool:
         """Returns true if any path in `path` is under one of the remaining paths in `sys.path`."""
         return any(p.startswith(pr) for p in path for pr in self._path_rest)

@@ -94,11 +94,11 @@ class _ProjectRootImporter(importlib.abc.MetaPathFinder):
             self._sources.append(spec.origin)
         return spec

-    def sources(self) ->
+    def sources(self) -> list[str]:
         return self._sources


-def _import_module(name: str, path: str) ->
+def _import_module(name: str, path: str) -> list[str]:
     """Imports the module and returns the list of source files
     of all modules imported in the process.

@@ -120,7 +120,7 @@ def _py_version() -> str:
     return f"{sys.version_info.major}.{sys.version_info.minor}"


-def _run_install(install_args:
+def _run_install(install_args: list[str], packages_dir: str):
     subprocess.run(
         [
             "uv",
@@ -138,7 +138,7 @@ def _run_install(install_args: List[str], packages_dir: str):
     )


-def _upload_bundle(entry_module_name: str, sources:
+def _upload_bundle(entry_module_name: str, sources: list[str], requirements: str | None) -> str:
     _check_uv()

     resp = proxy_conn().post_json(
@@ -212,7 +212,7 @@ def _upload_bundle(entry_module_name: str, sources: List[str], requirements: Opt


 def _collect_function_function_defs(
-    project_ids: ProjectIdCache, functions:
+    project_ids: ProjectIdCache, functions: list[dict[str, Any]], bundle_id: str, if_exists: IfExists
 ) -> None:
     for i, f in enumerate(global_.functions):
         source = inspect.getsource(f.handler)
@@ -262,7 +262,7 @@ def _collect_function_function_defs(


 def _collect_prompt_function_defs(
-    project_ids: ProjectIdCache, functions:
+    project_ids: ProjectIdCache, functions: list[dict[str, Any]], if_exists: IfExists
 ) -> None:
     for p in global_.prompts:
         functions.append(p.to_function_definition(if_exists, project_ids))
@@ -300,7 +300,7 @@ def run(args):
         raise

     project_ids = ProjectIdCache()
-    functions:
+    functions: list[dict[str, Any]] = []
     if len(global_.functions) > 0:
         bundle_id = _upload_bundle(module_name, sources, args.requirements)
         _collect_function_function_defs(project_ids, functions, bundle_id, args.if_exists)
braintrust/context.py
CHANGED
@@ -5,12 +5,13 @@ import os
 from abc import ABC, abstractmethod
 from contextvars import ContextVar
 from dataclasses import dataclass
-from typing import Any
+from typing import Any


 @dataclass
 class SpanInfo:
     """Information about a span in the context."""
+
     trace_id: str
     span_id: str
     span_object: Any = None
@@ -19,7 +20,7 @@ class SpanInfo:
 @dataclass
 class ParentSpanIds:
     root_span_id: str
-    span_parents:
+    span_parents: list[str]


 class ContextManager(ABC):
@@ -30,7 +31,7 @@ class ContextManager(ABC):
     """

     @abstractmethod
-    def get_current_span_info(self) ->
+    def get_current_span_info(self) -> Any | None:
         """Get information about the currently active span.

         Returns:
@@ -40,7 +41,7 @@ class ContextManager(ABC):
         pass

     @abstractmethod
-    def get_parent_span_ids(self) ->
+    def get_parent_span_ids(self) -> ParentSpanIds | None:
         """Get parent span IDs for creating a new Braintrust span.

         Returns:
@@ -75,32 +76,25 @@ class BraintrustContextManager(ContextManager):
     """Braintrust-only context manager using contextvars when OTEL is not available."""

     def __init__(self):
-        self._current_span: ContextVar[
+        self._current_span: ContextVar[Any | None] = ContextVar("braintrust_current_span", default=None)

-    def get_current_span_info(self) ->
+    def get_current_span_info(self) -> SpanInfo | None:
         """Get information about the currently active span."""
         current_span = self._current_span.get()
         if not current_span:
             return None

         # Return SpanInfo for BT spans
-        return SpanInfo(
-            trace_id=current_span.root_span_id,
-            span_id=current_span.span_id,
-            span_object=current_span
-        )
+        return SpanInfo(trace_id=current_span.root_span_id, span_id=current_span.span_id, span_object=current_span)

-    def get_parent_span_ids(self) ->
+    def get_parent_span_ids(self) -> ParentSpanIds | None:
         """Get parent information for creating a new Braintrust span."""
         current_span = self._current_span.get()
         if not current_span:
             return None

         # If current span is a BT span, use it as parent
-        return ParentSpanIds(
-            root_span_id=current_span.root_span_id,
-            span_parents=[current_span.span_id]
-        )
+        return ParentSpanIds(root_span_id=current_span.root_span_id, span_parents=[current_span.span_id])

     def set_current_span(self, span_object: Any) -> Any:
         """Set the current active span."""
@@ -123,9 +117,10 @@ def get_context_manager() -> ContextManager:
     """

     # Check if OTEL should be explicitly enabled via environment variable
-    if os.environ.get(
+    if os.environ.get("BRAINTRUST_OTEL_COMPAT", "").lower() in ("1", "true", "yes"):
         try:
             from braintrust.otel.context import ContextManager as OtelContextManager
+
             return OtelContextManager()
         except ImportError:
             logging.warning("OTEL not available, falling back to Braintrust-only version")
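A minimal sketch of the selection logic shown in the get_context_manager hunk above, assuming the module is importable as braintrust.context. Whether the OTEL-backed manager is actually chosen also depends on the optional OTEL dependencies being installed.

import os

from braintrust.context import get_context_manager

# Per the hunk above, "1", "true", and "yes" all enable the OTEL-compat manager.
os.environ["BRAINTRUST_OTEL_COMPAT"] = "true"
cm = get_context_manager()

# Both managers implement the ContextManager ABC from this file.
print(type(cm).__name__)
print(cm.get_current_span_info())  # None when no span is active
print(cm.get_parent_span_ids())    # None when there is no parent span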
braintrust/contrib/temporal/__init__.py
CHANGED
@@ -80,7 +80,8 @@ The integration will automatically:
 """

 import dataclasses
-from
+from collections.abc import Mapping
+from typing import Any

 import braintrust
 import temporalio.activity
@@ -129,7 +130,7 @@ class BraintrustInterceptor(temporalio.client.Interceptor, temporalio.worker.Int
     - Ensures replay safety (no duplicate spans during workflow replay)
     """

-    def __init__(self, logger:
+    def __init__(self, logger: Any | None = None) -> None:
         """Initialize interceptor.

         Args:
@@ -142,7 +143,7 @@ class BraintrustInterceptor(temporalio.client.Interceptor, temporalio.worker.Int
             braintrust.logger._state._override_bg_logger.logger = logger
         self._logger = braintrust.current_logger()

-    def _get_logger(self) ->
+    def _get_logger(self) -> Any | None:
         """Get logger for creating spans.

         Sets thread-local override if background logger provided (for testing),
@@ -152,9 +153,7 @@ class BraintrustInterceptor(temporalio.client.Interceptor, temporalio.worker.Int
             braintrust.logger._state._override_bg_logger.logger = self._bg_logger
         return self._logger

-    def intercept_client(
-        self, next: temporalio.client.OutboundInterceptor
-    ) -> temporalio.client.OutboundInterceptor:
+    def intercept_client(self, next: temporalio.client.OutboundInterceptor) -> temporalio.client.OutboundInterceptor:
         """Intercept client calls to propagate span context to workflows."""
         return _BraintrustClientOutboundInterceptor(next, self)

@@ -166,14 +165,14 @@ class BraintrustInterceptor(temporalio.client.Interceptor, temporalio.worker.Int

     def workflow_interceptor_class(
         self, input: temporalio.worker.WorkflowInterceptorClassInput
-    ) ->
+    ) -> type["BraintrustWorkflowInboundInterceptor"] | None:
         """Return workflow interceptor class to propagate context to activities."""
         input.unsafe_extern_functions["__braintrust_get_logger"] = self._get_logger
         return BraintrustWorkflowInboundInterceptor

     def _span_context_to_headers(
         self,
-        span_context:
+        span_context: dict[str, Any],
         headers: Mapping[str, temporalio.api.common.v1.Payload],
     ) -> Mapping[str, temporalio.api.common.v1.Payload]:
         """Add span context to headers."""
@@ -188,7 +187,7 @@ class BraintrustInterceptor(temporalio.client.Interceptor, temporalio.worker.Int

     def _span_context_from_headers(
         self, headers: Mapping[str, temporalio.api.common.v1.Payload]
-    ) ->
+    ) -> dict[str, Any] | None:
         """Extract span context from headers."""
         if _HEADER_KEY not in headers:
             return None
@@ -204,9 +203,7 @@ class BraintrustInterceptor(temporalio.client.Interceptor, temporalio.worker.Int
 class _BraintrustClientOutboundInterceptor(temporalio.client.OutboundInterceptor):
     """Client interceptor that propagates span context to workflows."""

-    def __init__(
-        self, next: temporalio.client.OutboundInterceptor, root: BraintrustInterceptor
-    ) -> None:
+    def __init__(self, next: temporalio.client.OutboundInterceptor, root: BraintrustInterceptor) -> None:
         super().__init__(next)
         self.root = root

@@ -233,9 +230,7 @@ class _BraintrustActivityInboundInterceptor(temporalio.worker.ActivityInboundInt
         super().__init__(next)
         self.root = root

-    async def execute_activity(
-        self, input: temporalio.worker.ExecuteActivityInput
-    ) -> Any:
+    async def execute_activity(self, input: temporalio.worker.ExecuteActivityInput) -> Any:
         info = temporalio.activity.info()

         # Extract parent span context from headers
@@ -281,14 +276,12 @@ class BraintrustWorkflowInboundInterceptor(temporalio.worker.WorkflowInboundInte
     def __init__(self, next: temporalio.worker.WorkflowInboundInterceptor) -> None:
         super().__init__(next)
         self.payload_converter = temporalio.converter.PayloadConverter.default
-        self._parent_span_context:
+        self._parent_span_context: dict[str, Any] | None = None

     def init(self, outbound: temporalio.worker.WorkflowOutboundInterceptor) -> None:
         super().init(_BraintrustWorkflowOutboundInterceptor(outbound, self))

-    async def execute_workflow(
-        self, input: temporalio.worker.ExecuteWorkflowInput
-    ) -> Any:
+    async def execute_workflow(self, input: temporalio.worker.ExecuteWorkflowInput) -> Any:
         # Extract parent span context from workflow headers (set by client)
         parent_span_context = None
         if _HEADER_KEY in input.headers:
@@ -342,9 +335,7 @@ class BraintrustWorkflowInboundInterceptor(temporalio.worker.WorkflowInboundInte
             span.end()


-class _BraintrustWorkflowOutboundInterceptor(
-    temporalio.worker.WorkflowOutboundInterceptor
-):
+class _BraintrustWorkflowOutboundInterceptor(temporalio.worker.WorkflowOutboundInterceptor):
     """Outbound workflow interceptor that propagates span context to activities."""

     def __init__(
@@ -371,9 +362,7 @@ class _BraintrustWorkflowOutboundInterceptor(
             return {**headers, _HEADER_KEY: payloads[0]}
         return headers

-    def start_activity(
-        self, input: temporalio.worker.StartActivityInput
-    ) -> temporalio.workflow.ActivityHandle:
+    def start_activity(self, input: temporalio.worker.StartActivityInput) -> temporalio.workflow.ActivityHandle:
         input.headers = self._add_span_context_to_headers(input.headers)
         return super().start_activity(input)

@@ -390,7 +379,7 @@ class _BraintrustWorkflowOutboundInterceptor(
         return super().start_child_workflow(input)


-def _modify_workflow_runner(existing:
+def _modify_workflow_runner(existing: WorkflowRunner | None) -> WorkflowRunner | None:
     """Add braintrust to sandbox passthrough modules."""
     if isinstance(existing, SandboxedWorkflowRunner):
         new_restrictions = existing.restrictions.with_passthrough_modules("braintrust")
@@ -420,7 +409,7 @@ class BraintrustPlugin(SimplePlugin):
     Requires temporalio >= 1.19.0.
     """

-    def __init__(self, logger:
+    def __init__(self, logger: Any | None = None) -> None:
         """Initialize the Braintrust plugin.

         Args:
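A minimal wiring sketch for the interceptor shown above, assuming it is importable from braintrust.contrib.temporal and that a Temporal server and worker already exist; the server address, workflow name, id, and task queue are placeholders.

import asyncio

from braintrust.contrib.temporal import BraintrustInterceptor
from temporalio.client import Client


async def main() -> None:
    # Registering the interceptor on the client propagates Braintrust span
    # context into workflow headers (and on to activities via the worker side).
    client = await Client.connect(
        "localhost:7233",
        interceptors=[BraintrustInterceptor()],
    )
    await client.start_workflow(
        "MyWorkflow",            # placeholder workflow name
        id="example-workflow-id",
        task_queue="example-task-queue",
    )


if __name__ == "__main__":
    asyncio.run(main())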
braintrust/contrib/temporal/test_temporal.py
CHANGED
@@ -279,11 +279,15 @@ class TestBraintrustPluginIntegration:

         # Verify workflow span was created
         workflow_spans = [s for s in spans if "temporal.workflow" in s.get("span_attributes", {}).get("name", "")]
-        assert len(workflow_spans) > 0,
+        assert len(workflow_spans) > 0, (
+            f"Expected workflow span to be created. Span names: {[s.get('span_attributes', {}).get('name', 'unknown') for s in spans]}"
+        )

         # Verify activity span was created
         activity_spans = [s for s in spans if "temporal.activity" in s.get("span_attributes", {}).get("name", "")]
-        assert len(activity_spans) > 0,
+        assert len(activity_spans) > 0, (
+            f"Expected activity span to be created. Span names: {[s.get('span_attributes', {}).get('name', 'unknown') for s in spans]}"
+        )

     @pytest.mark.asyncio
     async def test_plugin_context_propagation(self, temporal_env, memory_logger):
@@ -498,5 +502,6 @@ class TestBraintrustPluginIntegration:
         # Verify parent-child relationship: workflow should have client span as parent
         workflow_span = workflow_spans[0]
         client_span = client_spans[0]
-        assert workflow_span.get("root_span_id") == client_span.get("root_span_id"),
+        assert workflow_span.get("root_span_id") == client_span.get("root_span_id"), (
             "Workflow span should be in same trace as client span"
+        )