braintrust 0.3.14__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braintrust/__init__.py +4 -0
- braintrust/_generated_types.py +1200 -611
- braintrust/audit.py +2 -2
- braintrust/cli/eval.py +6 -7
- braintrust/cli/push.py +11 -11
- braintrust/conftest.py +1 -0
- braintrust/context.py +12 -17
- braintrust/contrib/temporal/__init__.py +16 -27
- braintrust/contrib/temporal/test_temporal.py +8 -3
- braintrust/devserver/auth.py +8 -8
- braintrust/devserver/cache.py +3 -4
- braintrust/devserver/cors.py +8 -7
- braintrust/devserver/dataset.py +3 -5
- braintrust/devserver/eval_hooks.py +7 -6
- braintrust/devserver/schemas.py +22 -19
- braintrust/devserver/server.py +19 -12
- braintrust/devserver/test_cached_login.py +4 -4
- braintrust/framework.py +128 -140
- braintrust/framework2.py +88 -87
- braintrust/functions/invoke.py +93 -53
- braintrust/functions/stream.py +3 -2
- braintrust/generated_types.py +17 -1
- braintrust/git_fields.py +11 -11
- braintrust/gitutil.py +2 -3
- braintrust/graph_util.py +10 -10
- braintrust/id_gen.py +2 -2
- braintrust/logger.py +346 -357
- braintrust/merge_row_batch.py +10 -9
- braintrust/oai.py +107 -24
- braintrust/otel/__init__.py +49 -49
- braintrust/otel/context.py +16 -30
- braintrust/otel/test_distributed_tracing.py +14 -11
- braintrust/otel/test_otel_bt_integration.py +32 -31
- braintrust/parameters.py +8 -8
- braintrust/prompt.py +14 -14
- braintrust/prompt_cache/disk_cache.py +5 -4
- braintrust/prompt_cache/lru_cache.py +3 -2
- braintrust/prompt_cache/prompt_cache.py +13 -14
- braintrust/queue.py +4 -4
- braintrust/score.py +4 -4
- braintrust/serializable_data_class.py +4 -4
- braintrust/span_identifier_v1.py +1 -2
- braintrust/span_identifier_v2.py +3 -4
- braintrust/span_identifier_v3.py +23 -20
- braintrust/span_identifier_v4.py +34 -25
- braintrust/test_framework.py +16 -6
- braintrust/test_helpers.py +5 -5
- braintrust/test_id_gen.py +2 -3
- braintrust/test_otel.py +61 -53
- braintrust/test_queue.py +0 -1
- braintrust/test_score.py +1 -3
- braintrust/test_span_components.py +29 -44
- braintrust/util.py +9 -8
- braintrust/version.py +2 -2
- braintrust/wrappers/_anthropic_utils.py +4 -4
- braintrust/wrappers/agno/__init__.py +3 -4
- braintrust/wrappers/agno/agent.py +1 -2
- braintrust/wrappers/agno/function_call.py +1 -2
- braintrust/wrappers/agno/model.py +1 -2
- braintrust/wrappers/agno/team.py +1 -2
- braintrust/wrappers/agno/utils.py +12 -12
- braintrust/wrappers/anthropic.py +7 -8
- braintrust/wrappers/claude_agent_sdk/__init__.py +3 -4
- braintrust/wrappers/claude_agent_sdk/_wrapper.py +29 -27
- braintrust/wrappers/dspy.py +15 -17
- braintrust/wrappers/google_genai/__init__.py +16 -16
- braintrust/wrappers/langchain.py +22 -24
- braintrust/wrappers/litellm.py +4 -3
- braintrust/wrappers/openai.py +15 -15
- braintrust/wrappers/pydantic_ai.py +1204 -0
- braintrust/wrappers/test_agno.py +0 -1
- braintrust/wrappers/test_dspy.py +0 -1
- braintrust/wrappers/test_google_genai.py +2 -3
- braintrust/wrappers/test_litellm.py +0 -1
- braintrust/wrappers/test_oai_attachments.py +322 -0
- braintrust/wrappers/test_pydantic_ai_integration.py +1788 -0
- braintrust/wrappers/{test_pydantic_ai.py → test_pydantic_ai_wrap_openai.py} +1 -2
- {braintrust-0.3.14.dist-info → braintrust-0.4.0.dist-info}/METADATA +3 -2
- braintrust-0.4.0.dist-info/RECORD +120 -0
- braintrust-0.3.14.dist-info/RECORD +0 -117
- {braintrust-0.3.14.dist-info → braintrust-0.4.0.dist-info}/WHEEL +0 -0
- {braintrust-0.3.14.dist-info → braintrust-0.4.0.dist-info}/entry_points.txt +0 -0
- {braintrust-0.3.14.dist-info → braintrust-0.4.0.dist-info}/top_level.txt +0 -0
braintrust/audit.py
CHANGED
|
@@ -5,7 +5,7 @@ Utilities for working with audit headers.
|
|
|
5
5
|
import base64
|
|
6
6
|
import gzip
|
|
7
7
|
import json
|
|
8
|
-
from typing import
|
|
8
|
+
from typing import TypedDict
|
|
9
9
|
|
|
10
10
|
|
|
11
11
|
class AuditResource(TypedDict):
|
|
@@ -14,7 +14,7 @@ class AuditResource(TypedDict):
|
|
|
14
14
|
name: str
|
|
15
15
|
|
|
16
16
|
|
|
17
|
-
def parse_audit_resources(hdr: str) ->
|
|
17
|
+
def parse_audit_resources(hdr: str) -> list[AuditResource]:
|
|
18
18
|
j = json.loads(hdr)
|
|
19
19
|
if j["v"] == 1:
|
|
20
20
|
return json.loads(gzip.decompress(base64.b64decode(j["p"])))
|
braintrust/cli/eval.py
CHANGED
|
@@ -6,7 +6,6 @@ import os
|
|
|
6
6
|
import sys
|
|
7
7
|
from dataclasses import dataclass, field
|
|
8
8
|
from threading import Lock
|
|
9
|
-
from typing import Dict, List, Optional, Union
|
|
10
9
|
|
|
11
10
|
from .. import login
|
|
12
11
|
from ..framework import (
|
|
@@ -70,7 +69,7 @@ class EvaluatorOpts:
|
|
|
70
69
|
no_progress_bars: bool
|
|
71
70
|
terminate_on_failure: bool
|
|
72
71
|
watch: bool
|
|
73
|
-
filters:
|
|
72
|
+
filters: list[str]
|
|
74
73
|
list: bool
|
|
75
74
|
jsonl: bool
|
|
76
75
|
|
|
@@ -79,13 +78,13 @@ class EvaluatorOpts:
|
|
|
79
78
|
class LoadedEvaluator:
|
|
80
79
|
handle: FileHandle
|
|
81
80
|
evaluator: Evaluator
|
|
82
|
-
reporter:
|
|
81
|
+
reporter: ReporterDef | str | None = None
|
|
83
82
|
|
|
84
83
|
|
|
85
84
|
@dataclass
|
|
86
85
|
class EvaluatorState:
|
|
87
|
-
evaluators:
|
|
88
|
-
reporters:
|
|
86
|
+
evaluators: list[LoadedEvaluator] = field(default_factory=list)
|
|
87
|
+
reporters: dict[str, ReporterDef] = field(default_factory=dict)
|
|
89
88
|
|
|
90
89
|
|
|
91
90
|
def update_evaluators(eval_state: EvaluatorState, handles, terminate_on_failure):
|
|
@@ -157,7 +156,7 @@ async def run_evaluator_task(evaluator, position, opts: EvaluatorOpts):
|
|
|
157
156
|
experiment.flush()
|
|
158
157
|
|
|
159
158
|
|
|
160
|
-
def resolve_reporter(reporter:
|
|
159
|
+
def resolve_reporter(reporter: ReporterDef | str | None, reporters: dict[str, ReporterDef]) -> ReporterDef:
|
|
161
160
|
if isinstance(reporter, str):
|
|
162
161
|
if reporter not in reporters:
|
|
163
162
|
raise ValueError(f"Reporter {reporter} not found")
|
|
@@ -179,7 +178,7 @@ def add_report(eval_reports, reporter, report):
|
|
|
179
178
|
eval_reports[reporter.name]["results"].append(report)
|
|
180
179
|
|
|
181
180
|
|
|
182
|
-
async def run_once(handles:
|
|
181
|
+
async def run_once(handles: list[FileHandle], evaluator_opts: EvaluatorOpts) -> bool:
|
|
183
182
|
objects = EvaluatorState()
|
|
184
183
|
update_evaluators(objects, handles, terminate_on_failure=evaluator_opts.terminate_on_failure)
|
|
185
184
|
|
braintrust/cli/push.py
CHANGED
|
@@ -13,7 +13,7 @@ import sys
|
|
|
13
13
|
import tempfile
|
|
14
14
|
import textwrap
|
|
15
15
|
import zipfile
|
|
16
|
-
from typing import Any
|
|
16
|
+
from typing import Any
|
|
17
17
|
|
|
18
18
|
import requests
|
|
19
19
|
from braintrust.framework import _set_lazy_load
|
|
@@ -24,7 +24,7 @@ from ..generated_types import IfExists
|
|
|
24
24
|
from ..util import add_azure_blob_headers
|
|
25
25
|
|
|
26
26
|
|
|
27
|
-
def _pkg_install_arg(pkg) ->
|
|
27
|
+
def _pkg_install_arg(pkg) -> str | None:
|
|
28
28
|
try:
|
|
29
29
|
dist = importlib.metadata.distribution(pkg)
|
|
30
30
|
direct_url = dist._path / "direct_url.json" # type: ignore
|
|
@@ -73,11 +73,11 @@ class _ProjectRootImporter(importlib.abc.MetaPathFinder):
|
|
|
73
73
|
self._project_root, self._path_rest = sys.path[0], sys.path[1:]
|
|
74
74
|
self._sources = []
|
|
75
75
|
|
|
76
|
-
def _under_project_root(self, path:
|
|
76
|
+
def _under_project_root(self, path: list[str]) -> bool:
|
|
77
77
|
"""Returns true if all paths in `path` are under the project root."""
|
|
78
78
|
return all(p.startswith(self._project_root) for p in path)
|
|
79
79
|
|
|
80
|
-
def _under_rest(self, path:
|
|
80
|
+
def _under_rest(self, path: list[str]) -> bool:
|
|
81
81
|
"""Returns true if any path in `path` is under one of the remaining paths in `sys.path`."""
|
|
82
82
|
return any(p.startswith(pr) for p in path for pr in self._path_rest)
|
|
83
83
|
|
|
@@ -94,11 +94,11 @@ class _ProjectRootImporter(importlib.abc.MetaPathFinder):
|
|
|
94
94
|
self._sources.append(spec.origin)
|
|
95
95
|
return spec
|
|
96
96
|
|
|
97
|
-
def sources(self) ->
|
|
97
|
+
def sources(self) -> list[str]:
|
|
98
98
|
return self._sources
|
|
99
99
|
|
|
100
100
|
|
|
101
|
-
def _import_module(name: str, path: str) ->
|
|
101
|
+
def _import_module(name: str, path: str) -> list[str]:
|
|
102
102
|
"""Imports the module and returns the list of source files
|
|
103
103
|
of all modules imported in the process.
|
|
104
104
|
|
|
@@ -120,7 +120,7 @@ def _py_version() -> str:
|
|
|
120
120
|
return f"{sys.version_info.major}.{sys.version_info.minor}"
|
|
121
121
|
|
|
122
122
|
|
|
123
|
-
def _run_install(install_args:
|
|
123
|
+
def _run_install(install_args: list[str], packages_dir: str):
|
|
124
124
|
subprocess.run(
|
|
125
125
|
[
|
|
126
126
|
"uv",
|
|
@@ -138,7 +138,7 @@ def _run_install(install_args: List[str], packages_dir: str):
|
|
|
138
138
|
)
|
|
139
139
|
|
|
140
140
|
|
|
141
|
-
def _upload_bundle(entry_module_name: str, sources:
|
|
141
|
+
def _upload_bundle(entry_module_name: str, sources: list[str], requirements: str | None) -> str:
|
|
142
142
|
_check_uv()
|
|
143
143
|
|
|
144
144
|
resp = proxy_conn().post_json(
|
|
@@ -212,7 +212,7 @@ def _upload_bundle(entry_module_name: str, sources: List[str], requirements: Opt
|
|
|
212
212
|
|
|
213
213
|
|
|
214
214
|
def _collect_function_function_defs(
|
|
215
|
-
project_ids: ProjectIdCache, functions:
|
|
215
|
+
project_ids: ProjectIdCache, functions: list[dict[str, Any]], bundle_id: str, if_exists: IfExists
|
|
216
216
|
) -> None:
|
|
217
217
|
for i, f in enumerate(global_.functions):
|
|
218
218
|
source = inspect.getsource(f.handler)
|
|
@@ -262,7 +262,7 @@ def _collect_function_function_defs(
|
|
|
262
262
|
|
|
263
263
|
|
|
264
264
|
def _collect_prompt_function_defs(
|
|
265
|
-
project_ids: ProjectIdCache, functions:
|
|
265
|
+
project_ids: ProjectIdCache, functions: list[dict[str, Any]], if_exists: IfExists
|
|
266
266
|
) -> None:
|
|
267
267
|
for p in global_.prompts:
|
|
268
268
|
functions.append(p.to_function_definition(if_exists, project_ids))
|
|
@@ -300,7 +300,7 @@ def run(args):
|
|
|
300
300
|
raise
|
|
301
301
|
|
|
302
302
|
project_ids = ProjectIdCache()
|
|
303
|
-
functions:
|
|
303
|
+
functions: list[dict[str, Any]] = []
|
|
304
304
|
if len(global_.functions) > 0:
|
|
305
305
|
bundle_id = _upload_bundle(module_name, sources, args.requirements)
|
|
306
306
|
_collect_function_function_defs(project_ids, functions, bundle_id, args.if_exists)
|
braintrust/conftest.py
CHANGED
|
@@ -36,6 +36,7 @@ def override_app_url_for_tests():
|
|
|
36
36
|
@pytest.fixture(autouse=True)
|
|
37
37
|
def setup_braintrust():
|
|
38
38
|
os.environ.setdefault("GOOGLE_API_KEY", os.getenv("GEMINI_API_KEY", "your_google_api_key_here"))
|
|
39
|
+
os.environ.setdefault("OPENAI_API_KEY", "sk-test-dummy-api-key-for-vcr-tests")
|
|
39
40
|
|
|
40
41
|
|
|
41
42
|
@pytest.fixture(autouse=True)
|
braintrust/context.py
CHANGED
|
@@ -5,12 +5,13 @@ import os
|
|
|
5
5
|
from abc import ABC, abstractmethod
|
|
6
6
|
from contextvars import ContextVar
|
|
7
7
|
from dataclasses import dataclass
|
|
8
|
-
from typing import Any
|
|
8
|
+
from typing import Any
|
|
9
9
|
|
|
10
10
|
|
|
11
11
|
@dataclass
|
|
12
12
|
class SpanInfo:
|
|
13
13
|
"""Information about a span in the context."""
|
|
14
|
+
|
|
14
15
|
trace_id: str
|
|
15
16
|
span_id: str
|
|
16
17
|
span_object: Any = None
|
|
@@ -19,7 +20,7 @@ class SpanInfo:
|
|
|
19
20
|
@dataclass
|
|
20
21
|
class ParentSpanIds:
|
|
21
22
|
root_span_id: str
|
|
22
|
-
span_parents:
|
|
23
|
+
span_parents: list[str]
|
|
23
24
|
|
|
24
25
|
|
|
25
26
|
class ContextManager(ABC):
|
|
@@ -30,7 +31,7 @@ class ContextManager(ABC):
|
|
|
30
31
|
"""
|
|
31
32
|
|
|
32
33
|
@abstractmethod
|
|
33
|
-
def get_current_span_info(self) ->
|
|
34
|
+
def get_current_span_info(self) -> Any | None:
|
|
34
35
|
"""Get information about the currently active span.
|
|
35
36
|
|
|
36
37
|
Returns:
|
|
@@ -40,7 +41,7 @@ class ContextManager(ABC):
|
|
|
40
41
|
pass
|
|
41
42
|
|
|
42
43
|
@abstractmethod
|
|
43
|
-
def get_parent_span_ids(self) ->
|
|
44
|
+
def get_parent_span_ids(self) -> ParentSpanIds | None:
|
|
44
45
|
"""Get parent span IDs for creating a new Braintrust span.
|
|
45
46
|
|
|
46
47
|
Returns:
|
|
@@ -75,32 +76,25 @@ class BraintrustContextManager(ContextManager):
|
|
|
75
76
|
"""Braintrust-only context manager using contextvars when OTEL is not available."""
|
|
76
77
|
|
|
77
78
|
def __init__(self):
|
|
78
|
-
self._current_span: ContextVar[
|
|
79
|
+
self._current_span: ContextVar[Any | None] = ContextVar("braintrust_current_span", default=None)
|
|
79
80
|
|
|
80
|
-
def get_current_span_info(self) ->
|
|
81
|
+
def get_current_span_info(self) -> SpanInfo | None:
|
|
81
82
|
"""Get information about the currently active span."""
|
|
82
83
|
current_span = self._current_span.get()
|
|
83
84
|
if not current_span:
|
|
84
85
|
return None
|
|
85
86
|
|
|
86
87
|
# Return SpanInfo for BT spans
|
|
87
|
-
return SpanInfo(
|
|
88
|
-
trace_id=current_span.root_span_id,
|
|
89
|
-
span_id=current_span.span_id,
|
|
90
|
-
span_object=current_span
|
|
91
|
-
)
|
|
88
|
+
return SpanInfo(trace_id=current_span.root_span_id, span_id=current_span.span_id, span_object=current_span)
|
|
92
89
|
|
|
93
|
-
def get_parent_span_ids(self) ->
|
|
90
|
+
def get_parent_span_ids(self) -> ParentSpanIds | None:
|
|
94
91
|
"""Get parent information for creating a new Braintrust span."""
|
|
95
92
|
current_span = self._current_span.get()
|
|
96
93
|
if not current_span:
|
|
97
94
|
return None
|
|
98
95
|
|
|
99
96
|
# If current span is a BT span, use it as parent
|
|
100
|
-
return ParentSpanIds(
|
|
101
|
-
root_span_id=current_span.root_span_id,
|
|
102
|
-
span_parents=[current_span.span_id]
|
|
103
|
-
)
|
|
97
|
+
return ParentSpanIds(root_span_id=current_span.root_span_id, span_parents=[current_span.span_id])
|
|
104
98
|
|
|
105
99
|
def set_current_span(self, span_object: Any) -> Any:
|
|
106
100
|
"""Set the current active span."""
|
|
@@ -123,9 +117,10 @@ def get_context_manager() -> ContextManager:
|
|
|
123
117
|
"""
|
|
124
118
|
|
|
125
119
|
# Check if OTEL should be explicitly enabled via environment variable
|
|
126
|
-
if os.environ.get(
|
|
120
|
+
if os.environ.get("BRAINTRUST_OTEL_COMPAT", "").lower() in ("1", "true", "yes"):
|
|
127
121
|
try:
|
|
128
122
|
from braintrust.otel.context import ContextManager as OtelContextManager
|
|
123
|
+
|
|
129
124
|
return OtelContextManager()
|
|
130
125
|
except ImportError:
|
|
131
126
|
logging.warning("OTEL not available, falling back to Braintrust-only version")
|
|
@@ -80,7 +80,8 @@ The integration will automatically:
|
|
|
80
80
|
"""
|
|
81
81
|
|
|
82
82
|
import dataclasses
|
|
83
|
-
from
|
|
83
|
+
from collections.abc import Mapping
|
|
84
|
+
from typing import Any
|
|
84
85
|
|
|
85
86
|
import braintrust
|
|
86
87
|
import temporalio.activity
|
|
@@ -129,7 +130,7 @@ class BraintrustInterceptor(temporalio.client.Interceptor, temporalio.worker.Int
|
|
|
129
130
|
- Ensures replay safety (no duplicate spans during workflow replay)
|
|
130
131
|
"""
|
|
131
132
|
|
|
132
|
-
def __init__(self, logger:
|
|
133
|
+
def __init__(self, logger: Any | None = None) -> None:
|
|
133
134
|
"""Initialize interceptor.
|
|
134
135
|
|
|
135
136
|
Args:
|
|
@@ -142,7 +143,7 @@ class BraintrustInterceptor(temporalio.client.Interceptor, temporalio.worker.Int
|
|
|
142
143
|
braintrust.logger._state._override_bg_logger.logger = logger
|
|
143
144
|
self._logger = braintrust.current_logger()
|
|
144
145
|
|
|
145
|
-
def _get_logger(self) ->
|
|
146
|
+
def _get_logger(self) -> Any | None:
|
|
146
147
|
"""Get logger for creating spans.
|
|
147
148
|
|
|
148
149
|
Sets thread-local override if background logger provided (for testing),
|
|
@@ -152,9 +153,7 @@ class BraintrustInterceptor(temporalio.client.Interceptor, temporalio.worker.Int
|
|
|
152
153
|
braintrust.logger._state._override_bg_logger.logger = self._bg_logger
|
|
153
154
|
return self._logger
|
|
154
155
|
|
|
155
|
-
def intercept_client(
|
|
156
|
-
self, next: temporalio.client.OutboundInterceptor
|
|
157
|
-
) -> temporalio.client.OutboundInterceptor:
|
|
156
|
+
def intercept_client(self, next: temporalio.client.OutboundInterceptor) -> temporalio.client.OutboundInterceptor:
|
|
158
157
|
"""Intercept client calls to propagate span context to workflows."""
|
|
159
158
|
return _BraintrustClientOutboundInterceptor(next, self)
|
|
160
159
|
|
|
@@ -166,14 +165,14 @@ class BraintrustInterceptor(temporalio.client.Interceptor, temporalio.worker.Int
|
|
|
166
165
|
|
|
167
166
|
def workflow_interceptor_class(
|
|
168
167
|
self, input: temporalio.worker.WorkflowInterceptorClassInput
|
|
169
|
-
) ->
|
|
168
|
+
) -> type["BraintrustWorkflowInboundInterceptor"] | None:
|
|
170
169
|
"""Return workflow interceptor class to propagate context to activities."""
|
|
171
170
|
input.unsafe_extern_functions["__braintrust_get_logger"] = self._get_logger
|
|
172
171
|
return BraintrustWorkflowInboundInterceptor
|
|
173
172
|
|
|
174
173
|
def _span_context_to_headers(
|
|
175
174
|
self,
|
|
176
|
-
span_context:
|
|
175
|
+
span_context: dict[str, Any],
|
|
177
176
|
headers: Mapping[str, temporalio.api.common.v1.Payload],
|
|
178
177
|
) -> Mapping[str, temporalio.api.common.v1.Payload]:
|
|
179
178
|
"""Add span context to headers."""
|
|
@@ -188,7 +187,7 @@ class BraintrustInterceptor(temporalio.client.Interceptor, temporalio.worker.Int
|
|
|
188
187
|
|
|
189
188
|
def _span_context_from_headers(
|
|
190
189
|
self, headers: Mapping[str, temporalio.api.common.v1.Payload]
|
|
191
|
-
) ->
|
|
190
|
+
) -> dict[str, Any] | None:
|
|
192
191
|
"""Extract span context from headers."""
|
|
193
192
|
if _HEADER_KEY not in headers:
|
|
194
193
|
return None
|
|
@@ -204,9 +203,7 @@ class BraintrustInterceptor(temporalio.client.Interceptor, temporalio.worker.Int
|
|
|
204
203
|
class _BraintrustClientOutboundInterceptor(temporalio.client.OutboundInterceptor):
|
|
205
204
|
"""Client interceptor that propagates span context to workflows."""
|
|
206
205
|
|
|
207
|
-
def __init__(
|
|
208
|
-
self, next: temporalio.client.OutboundInterceptor, root: BraintrustInterceptor
|
|
209
|
-
) -> None:
|
|
206
|
+
def __init__(self, next: temporalio.client.OutboundInterceptor, root: BraintrustInterceptor) -> None:
|
|
210
207
|
super().__init__(next)
|
|
211
208
|
self.root = root
|
|
212
209
|
|
|
@@ -233,9 +230,7 @@ class _BraintrustActivityInboundInterceptor(temporalio.worker.ActivityInboundInt
|
|
|
233
230
|
super().__init__(next)
|
|
234
231
|
self.root = root
|
|
235
232
|
|
|
236
|
-
async def execute_activity(
|
|
237
|
-
self, input: temporalio.worker.ExecuteActivityInput
|
|
238
|
-
) -> Any:
|
|
233
|
+
async def execute_activity(self, input: temporalio.worker.ExecuteActivityInput) -> Any:
|
|
239
234
|
info = temporalio.activity.info()
|
|
240
235
|
|
|
241
236
|
# Extract parent span context from headers
|
|
@@ -281,14 +276,12 @@ class BraintrustWorkflowInboundInterceptor(temporalio.worker.WorkflowInboundInte
|
|
|
281
276
|
def __init__(self, next: temporalio.worker.WorkflowInboundInterceptor) -> None:
|
|
282
277
|
super().__init__(next)
|
|
283
278
|
self.payload_converter = temporalio.converter.PayloadConverter.default
|
|
284
|
-
self._parent_span_context:
|
|
279
|
+
self._parent_span_context: dict[str, Any] | None = None
|
|
285
280
|
|
|
286
281
|
def init(self, outbound: temporalio.worker.WorkflowOutboundInterceptor) -> None:
|
|
287
282
|
super().init(_BraintrustWorkflowOutboundInterceptor(outbound, self))
|
|
288
283
|
|
|
289
|
-
async def execute_workflow(
|
|
290
|
-
self, input: temporalio.worker.ExecuteWorkflowInput
|
|
291
|
-
) -> Any:
|
|
284
|
+
async def execute_workflow(self, input: temporalio.worker.ExecuteWorkflowInput) -> Any:
|
|
292
285
|
# Extract parent span context from workflow headers (set by client)
|
|
293
286
|
parent_span_context = None
|
|
294
287
|
if _HEADER_KEY in input.headers:
|
|
@@ -342,9 +335,7 @@ class BraintrustWorkflowInboundInterceptor(temporalio.worker.WorkflowInboundInte
|
|
|
342
335
|
span.end()
|
|
343
336
|
|
|
344
337
|
|
|
345
|
-
class _BraintrustWorkflowOutboundInterceptor(
|
|
346
|
-
temporalio.worker.WorkflowOutboundInterceptor
|
|
347
|
-
):
|
|
338
|
+
class _BraintrustWorkflowOutboundInterceptor(temporalio.worker.WorkflowOutboundInterceptor):
|
|
348
339
|
"""Outbound workflow interceptor that propagates span context to activities."""
|
|
349
340
|
|
|
350
341
|
def __init__(
|
|
@@ -371,9 +362,7 @@ class _BraintrustWorkflowOutboundInterceptor(
|
|
|
371
362
|
return {**headers, _HEADER_KEY: payloads[0]}
|
|
372
363
|
return headers
|
|
373
364
|
|
|
374
|
-
def start_activity(
|
|
375
|
-
self, input: temporalio.worker.StartActivityInput
|
|
376
|
-
) -> temporalio.workflow.ActivityHandle:
|
|
365
|
+
def start_activity(self, input: temporalio.worker.StartActivityInput) -> temporalio.workflow.ActivityHandle:
|
|
377
366
|
input.headers = self._add_span_context_to_headers(input.headers)
|
|
378
367
|
return super().start_activity(input)
|
|
379
368
|
|
|
@@ -390,7 +379,7 @@ class _BraintrustWorkflowOutboundInterceptor(
|
|
|
390
379
|
return super().start_child_workflow(input)
|
|
391
380
|
|
|
392
381
|
|
|
393
|
-
def _modify_workflow_runner(existing:
|
|
382
|
+
def _modify_workflow_runner(existing: WorkflowRunner | None) -> WorkflowRunner | None:
|
|
394
383
|
"""Add braintrust to sandbox passthrough modules."""
|
|
395
384
|
if isinstance(existing, SandboxedWorkflowRunner):
|
|
396
385
|
new_restrictions = existing.restrictions.with_passthrough_modules("braintrust")
|
|
@@ -420,7 +409,7 @@ class BraintrustPlugin(SimplePlugin):
|
|
|
420
409
|
Requires temporalio >= 1.19.0.
|
|
421
410
|
"""
|
|
422
411
|
|
|
423
|
-
def __init__(self, logger:
|
|
412
|
+
def __init__(self, logger: Any | None = None) -> None:
|
|
424
413
|
"""Initialize the Braintrust plugin.
|
|
425
414
|
|
|
426
415
|
Args:
|
|
@@ -279,11 +279,15 @@ class TestBraintrustPluginIntegration:
|
|
|
279
279
|
|
|
280
280
|
# Verify workflow span was created
|
|
281
281
|
workflow_spans = [s for s in spans if "temporal.workflow" in s.get("span_attributes", {}).get("name", "")]
|
|
282
|
-
assert len(workflow_spans) > 0,
|
|
282
|
+
assert len(workflow_spans) > 0, (
|
|
283
|
+
f"Expected workflow span to be created. Span names: {[s.get('span_attributes', {}).get('name', 'unknown') for s in spans]}"
|
|
284
|
+
)
|
|
283
285
|
|
|
284
286
|
# Verify activity span was created
|
|
285
287
|
activity_spans = [s for s in spans if "temporal.activity" in s.get("span_attributes", {}).get("name", "")]
|
|
286
|
-
assert len(activity_spans) > 0,
|
|
288
|
+
assert len(activity_spans) > 0, (
|
|
289
|
+
f"Expected activity span to be created. Span names: {[s.get('span_attributes', {}).get('name', 'unknown') for s in spans]}"
|
|
290
|
+
)
|
|
287
291
|
|
|
288
292
|
@pytest.mark.asyncio
|
|
289
293
|
async def test_plugin_context_propagation(self, temporal_env, memory_logger):
|
|
@@ -498,5 +502,6 @@ class TestBraintrustPluginIntegration:
|
|
|
498
502
|
# Verify parent-child relationship: workflow should have client span as parent
|
|
499
503
|
workflow_span = workflow_spans[0]
|
|
500
504
|
client_span = client_spans[0]
|
|
501
|
-
assert workflow_span.get("root_span_id") == client_span.get("root_span_id"),
|
|
505
|
+
assert workflow_span.get("root_span_id") == client_span.get("root_span_id"), (
|
|
502
506
|
"Workflow span should be in same trace as client span"
|
|
507
|
+
)
|
braintrust/devserver/auth.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
|
+
from collections.abc import Awaitable, Callable
|
|
1
2
|
from dataclasses import dataclass
|
|
2
|
-
from typing import Awaitable, Callable, Dict, Optional
|
|
3
3
|
|
|
4
4
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
5
5
|
from starlette.requests import Request
|
|
@@ -15,14 +15,14 @@ BRAINTRUST_PROJECT_ID_HEADER = "x-bt-project-id"
|
|
|
15
15
|
|
|
16
16
|
@dataclass
|
|
17
17
|
class RequestContext:
|
|
18
|
-
app_origin:
|
|
19
|
-
token:
|
|
20
|
-
org_name:
|
|
21
|
-
project_id:
|
|
22
|
-
state:
|
|
18
|
+
app_origin: str | None
|
|
19
|
+
token: str | None
|
|
20
|
+
org_name: str | None
|
|
21
|
+
project_id: str | None
|
|
22
|
+
state: BraintrustState | None
|
|
23
23
|
|
|
24
24
|
|
|
25
|
-
def extract_allowed_origin(origin:
|
|
25
|
+
def extract_allowed_origin(origin: str | None) -> str | None:
|
|
26
26
|
"""Extract and validate the origin header."""
|
|
27
27
|
# This should use the same check_origin logic from cors.py
|
|
28
28
|
from .cors import check_origin
|
|
@@ -32,7 +32,7 @@ def extract_allowed_origin(origin: Optional[str]) -> Optional[str]:
|
|
|
32
32
|
return None
|
|
33
33
|
|
|
34
34
|
|
|
35
|
-
def parse_braintrust_auth_header(headers:
|
|
35
|
+
def parse_braintrust_auth_header(headers: dict[str, str]) -> str | None:
|
|
36
36
|
"""Parse the authorization token from headers."""
|
|
37
37
|
# Check x-bt-auth-token first
|
|
38
38
|
token = headers.get(BRAINTRUST_AUTH_TOKEN_HEADER)
|
braintrust/devserver/cache.py
CHANGED
|
@@ -1,7 +1,6 @@
|
|
|
1
1
|
"""LRU cache implementation for the dev server."""
|
|
2
2
|
|
|
3
3
|
import json
|
|
4
|
-
from typing import Dict, Optional
|
|
5
4
|
|
|
6
5
|
from ..logger import BraintrustState, login_to_state
|
|
7
6
|
|
|
@@ -11,10 +10,10 @@ class LRUCache:
|
|
|
11
10
|
|
|
12
11
|
def __init__(self, max_size: int = 32):
|
|
13
12
|
self.max_size = max_size
|
|
14
|
-
self.cache:
|
|
13
|
+
self.cache: dict[str, BraintrustState] = {}
|
|
15
14
|
self.access_order: list[str] = []
|
|
16
15
|
|
|
17
|
-
def get(self, key: str) ->
|
|
16
|
+
def get(self, key: str) -> BraintrustState | None:
|
|
18
17
|
"""Get a value from the cache, updating access order."""
|
|
19
18
|
if key in self.cache:
|
|
20
19
|
# Move to end to mark as recently used
|
|
@@ -41,7 +40,7 @@ class LRUCache:
|
|
|
41
40
|
_login_cache = LRUCache(max_size=32) # TODO: Make this configurable
|
|
42
41
|
|
|
43
42
|
|
|
44
|
-
async def cached_login(api_key: str, app_url: str, org_name:
|
|
43
|
+
async def cached_login(api_key: str, app_url: str, org_name: str | None = None) -> BraintrustState:
|
|
45
44
|
"""Login with caching to avoid repeated API calls."""
|
|
46
45
|
cache_key = json.dumps({"api_key": api_key, "app_url": app_url, "org_name": org_name})
|
|
47
46
|
|
braintrust/devserver/cors.py
CHANGED
|
@@ -1,9 +1,10 @@
|
|
|
1
1
|
import os
|
|
2
2
|
import re
|
|
3
|
-
from
|
|
3
|
+
from collections.abc import Awaitable, Callable
|
|
4
|
+
from typing import Any
|
|
4
5
|
|
|
5
6
|
# CORS configuration
|
|
6
|
-
ALLOWED_ORIGINS:
|
|
7
|
+
ALLOWED_ORIGINS: list[str | re.Pattern] = [
|
|
7
8
|
"https://www.braintrust.dev",
|
|
8
9
|
"https://www.braintrustdata.com",
|
|
9
10
|
re.compile(r"https://.*\.preview\.braintrust\.dev"),
|
|
@@ -70,9 +71,9 @@ def create_cors_middleware() -> type:
|
|
|
70
71
|
|
|
71
72
|
async def __call__(
|
|
72
73
|
self,
|
|
73
|
-
scope:
|
|
74
|
-
receive: Callable[[], Awaitable[
|
|
75
|
-
send: Callable[[
|
|
74
|
+
scope: dict[str, Any],
|
|
75
|
+
receive: Callable[[], Awaitable[dict[str, Any]]],
|
|
76
|
+
send: Callable[[dict[str, Any]], Awaitable[None]],
|
|
76
77
|
) -> None:
|
|
77
78
|
if scope["type"] == "http":
|
|
78
79
|
headers = dict(scope["headers"])
|
|
@@ -81,7 +82,7 @@ def create_cors_middleware() -> type:
|
|
|
81
82
|
# Handle OPTIONS requests
|
|
82
83
|
if scope["method"] == "OPTIONS":
|
|
83
84
|
|
|
84
|
-
async def send_options_wrapper(message:
|
|
85
|
+
async def send_options_wrapper(message: dict[str, Any]) -> None:
|
|
85
86
|
if message["type"] == "http.response.start":
|
|
86
87
|
headers_dict = dict(message.get("headers", []))
|
|
87
88
|
|
|
@@ -120,7 +121,7 @@ def create_cors_middleware() -> type:
|
|
|
120
121
|
return
|
|
121
122
|
|
|
122
123
|
# For other requests, add CORS headers if origin is valid
|
|
123
|
-
async def send_wrapper(message:
|
|
124
|
+
async def send_wrapper(message: dict[str, Any]) -> None:
|
|
124
125
|
if message["type"] == "http.response.start" and origin and check_origin(origin):
|
|
125
126
|
headers_dict = dict(message.get("headers", []))
|
|
126
127
|
|
braintrust/devserver/dataset.py
CHANGED
|
@@ -1,11 +1,11 @@
|
|
|
1
|
-
from typing import Any
|
|
1
|
+
from typing import Any
|
|
2
2
|
|
|
3
3
|
from braintrust import init_dataset
|
|
4
4
|
from braintrust._generated_types import RunEvalData, RunEvalData1, RunEvalData2
|
|
5
5
|
from braintrust.logger import BraintrustState
|
|
6
6
|
|
|
7
7
|
|
|
8
|
-
async def get_dataset_by_id(state: BraintrustState, dataset_id: str) ->
|
|
8
|
+
async def get_dataset_by_id(state: BraintrustState, dataset_id: str) -> dict[str, str]:
|
|
9
9
|
"""Fetch dataset information by ID."""
|
|
10
10
|
# Make API call to get dataset info
|
|
11
11
|
conn = state.api_conn()
|
|
@@ -23,9 +23,7 @@ async def get_dataset_by_id(state: BraintrustState, dataset_id: str) -> Dict[str
|
|
|
23
23
|
|
|
24
24
|
|
|
25
25
|
# NOTE: To make this performant, we'll have to make these functions work with async i/o
|
|
26
|
-
async def get_dataset(
|
|
27
|
-
state: BraintrustState, data: Union[RunEvalData, RunEvalData1, RunEvalData2, Dict[str, Any]]
|
|
28
|
-
) -> Any:
|
|
26
|
+
async def get_dataset(state: BraintrustState, data: RunEvalData | RunEvalData1 | RunEvalData2 | dict[str, Any]) -> Any:
|
|
29
27
|
"""
|
|
30
28
|
Get dataset from various data sources.
|
|
31
29
|
|
|
@@ -7,7 +7,8 @@ for reporting progress during evaluation execution.
|
|
|
7
7
|
|
|
8
8
|
import asyncio
|
|
9
9
|
import json
|
|
10
|
-
from
|
|
10
|
+
from collections.abc import Callable
|
|
11
|
+
from typing import Any
|
|
11
12
|
|
|
12
13
|
|
|
13
14
|
class EvalHooks:
|
|
@@ -15,13 +16,13 @@ class EvalHooks:
|
|
|
15
16
|
|
|
16
17
|
def __init__(
|
|
17
18
|
self,
|
|
18
|
-
report_progress:
|
|
19
|
-
parameters:
|
|
19
|
+
report_progress: Callable[[dict[str, Any]], None] | None = None,
|
|
20
|
+
parameters: dict[str, Any] | None = None,
|
|
20
21
|
):
|
|
21
22
|
self._report_progress = report_progress
|
|
22
23
|
self.parameters = parameters or {}
|
|
23
24
|
|
|
24
|
-
def report_progress(self, event:
|
|
25
|
+
def report_progress(self, event: dict[str, Any]) -> None:
|
|
25
26
|
"""Report progress during task execution."""
|
|
26
27
|
if self._report_progress:
|
|
27
28
|
self._report_progress(event)
|
|
@@ -45,7 +46,7 @@ class SSEQueue:
|
|
|
45
46
|
"""Simple wrapper around asyncio.Queue for SSE events."""
|
|
46
47
|
|
|
47
48
|
def __init__(self):
|
|
48
|
-
self.queue: asyncio.Queue[
|
|
49
|
+
self.queue: asyncio.Queue[str | None] = asyncio.Queue()
|
|
49
50
|
|
|
50
51
|
async def put_event(self, event: str, data: Any) -> None:
|
|
51
52
|
"""Add an SSE event to the queue."""
|
|
@@ -56,6 +57,6 @@ class SSEQueue:
|
|
|
56
57
|
"""Signal end of stream."""
|
|
57
58
|
await self.queue.put(None)
|
|
58
59
|
|
|
59
|
-
async def get(self) ->
|
|
60
|
+
async def get(self) -> str | None:
|
|
60
61
|
"""Get the next event from the queue."""
|
|
61
62
|
return await self.queue.get()
|