braintrust 0.3.15__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braintrust/_generated_types.py +737 -672
- braintrust/audit.py +2 -2
- braintrust/bt_json.py +178 -19
- braintrust/cli/eval.py +6 -7
- braintrust/cli/push.py +11 -11
- braintrust/context.py +12 -17
- braintrust/contrib/temporal/__init__.py +16 -27
- braintrust/contrib/temporal/test_temporal.py +8 -3
- braintrust/devserver/auth.py +8 -8
- braintrust/devserver/cache.py +3 -4
- braintrust/devserver/cors.py +8 -7
- braintrust/devserver/dataset.py +3 -5
- braintrust/devserver/eval_hooks.py +7 -6
- braintrust/devserver/schemas.py +22 -19
- braintrust/devserver/server.py +19 -12
- braintrust/devserver/test_cached_login.py +4 -4
- braintrust/framework.py +139 -142
- braintrust/framework2.py +88 -87
- braintrust/functions/invoke.py +66 -59
- braintrust/functions/stream.py +3 -2
- braintrust/generated_types.py +3 -1
- braintrust/git_fields.py +11 -11
- braintrust/gitutil.py +2 -3
- braintrust/graph_util.py +10 -10
- braintrust/id_gen.py +2 -2
- braintrust/logger.py +373 -471
- braintrust/merge_row_batch.py +10 -9
- braintrust/oai.py +21 -20
- braintrust/otel/__init__.py +49 -49
- braintrust/otel/context.py +16 -30
- braintrust/otel/test_distributed_tracing.py +14 -11
- braintrust/otel/test_otel_bt_integration.py +32 -31
- braintrust/parameters.py +8 -8
- braintrust/prompt.py +14 -14
- braintrust/prompt_cache/disk_cache.py +5 -4
- braintrust/prompt_cache/lru_cache.py +3 -2
- braintrust/prompt_cache/prompt_cache.py +13 -14
- braintrust/queue.py +4 -4
- braintrust/score.py +4 -4
- braintrust/serializable_data_class.py +4 -4
- braintrust/span_identifier_v1.py +1 -2
- braintrust/span_identifier_v2.py +3 -4
- braintrust/span_identifier_v3.py +23 -20
- braintrust/span_identifier_v4.py +34 -25
- braintrust/test_bt_json.py +644 -0
- braintrust/test_framework.py +72 -6
- braintrust/test_helpers.py +5 -5
- braintrust/test_id_gen.py +2 -3
- braintrust/test_logger.py +211 -107
- braintrust/test_otel.py +61 -53
- braintrust/test_queue.py +0 -1
- braintrust/test_score.py +1 -3
- braintrust/test_span_components.py +29 -44
- braintrust/util.py +9 -8
- braintrust/version.py +2 -2
- braintrust/wrappers/_anthropic_utils.py +4 -4
- braintrust/wrappers/agno/__init__.py +3 -4
- braintrust/wrappers/agno/agent.py +1 -2
- braintrust/wrappers/agno/function_call.py +1 -2
- braintrust/wrappers/agno/model.py +1 -2
- braintrust/wrappers/agno/team.py +1 -2
- braintrust/wrappers/agno/utils.py +12 -12
- braintrust/wrappers/anthropic.py +7 -8
- braintrust/wrappers/claude_agent_sdk/__init__.py +3 -4
- braintrust/wrappers/claude_agent_sdk/_wrapper.py +29 -27
- braintrust/wrappers/dspy.py +15 -17
- braintrust/wrappers/google_genai/__init__.py +17 -30
- braintrust/wrappers/langchain.py +22 -24
- braintrust/wrappers/litellm.py +4 -3
- braintrust/wrappers/openai.py +15 -15
- braintrust/wrappers/pydantic_ai.py +225 -110
- braintrust/wrappers/test_agno.py +0 -1
- braintrust/wrappers/test_dspy.py +0 -1
- braintrust/wrappers/test_google_genai.py +64 -4
- braintrust/wrappers/test_litellm.py +0 -1
- braintrust/wrappers/test_pydantic_ai_integration.py +819 -22
- {braintrust-0.3.15.dist-info → braintrust-0.4.1.dist-info}/METADATA +3 -2
- braintrust-0.4.1.dist-info/RECORD +121 -0
- braintrust-0.3.15.dist-info/RECORD +0 -120
- {braintrust-0.3.15.dist-info → braintrust-0.4.1.dist-info}/WHEEL +0 -0
- {braintrust-0.3.15.dist-info → braintrust-0.4.1.dist-info}/entry_points.txt +0 -0
- {braintrust-0.3.15.dist-info → braintrust-0.4.1.dist-info}/top_level.txt +0 -0
braintrust/devserver/auth.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
|
+
from collections.abc import Awaitable, Callable
|
|
1
2
|
from dataclasses import dataclass
|
|
2
|
-
from typing import Awaitable, Callable, Dict, Optional
|
|
3
3
|
|
|
4
4
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
5
5
|
from starlette.requests import Request
|
|
@@ -15,14 +15,14 @@ BRAINTRUST_PROJECT_ID_HEADER = "x-bt-project-id"
|
|
|
15
15
|
|
|
16
16
|
@dataclass
|
|
17
17
|
class RequestContext:
|
|
18
|
-
app_origin:
|
|
19
|
-
token:
|
|
20
|
-
org_name:
|
|
21
|
-
project_id:
|
|
22
|
-
state:
|
|
18
|
+
app_origin: str | None
|
|
19
|
+
token: str | None
|
|
20
|
+
org_name: str | None
|
|
21
|
+
project_id: str | None
|
|
22
|
+
state: BraintrustState | None
|
|
23
23
|
|
|
24
24
|
|
|
25
|
-
def extract_allowed_origin(origin:
|
|
25
|
+
def extract_allowed_origin(origin: str | None) -> str | None:
|
|
26
26
|
"""Extract and validate the origin header."""
|
|
27
27
|
# This should use the same check_origin logic from cors.py
|
|
28
28
|
from .cors import check_origin
|
|
@@ -32,7 +32,7 @@ def extract_allowed_origin(origin: Optional[str]) -> Optional[str]:
|
|
|
32
32
|
return None
|
|
33
33
|
|
|
34
34
|
|
|
35
|
-
def parse_braintrust_auth_header(headers:
|
|
35
|
+
def parse_braintrust_auth_header(headers: dict[str, str]) -> str | None:
|
|
36
36
|
"""Parse the authorization token from headers."""
|
|
37
37
|
# Check x-bt-auth-token first
|
|
38
38
|
token = headers.get(BRAINTRUST_AUTH_TOKEN_HEADER)
|
braintrust/devserver/cache.py
CHANGED
|
@@ -1,7 +1,6 @@
|
|
|
1
1
|
"""LRU cache implementation for the dev server."""
|
|
2
2
|
|
|
3
3
|
import json
|
|
4
|
-
from typing import Dict, Optional
|
|
5
4
|
|
|
6
5
|
from ..logger import BraintrustState, login_to_state
|
|
7
6
|
|
|
@@ -11,10 +10,10 @@ class LRUCache:
|
|
|
11
10
|
|
|
12
11
|
def __init__(self, max_size: int = 32):
|
|
13
12
|
self.max_size = max_size
|
|
14
|
-
self.cache:
|
|
13
|
+
self.cache: dict[str, BraintrustState] = {}
|
|
15
14
|
self.access_order: list[str] = []
|
|
16
15
|
|
|
17
|
-
def get(self, key: str) ->
|
|
16
|
+
def get(self, key: str) -> BraintrustState | None:
|
|
18
17
|
"""Get a value from the cache, updating access order."""
|
|
19
18
|
if key in self.cache:
|
|
20
19
|
# Move to end to mark as recently used
|
|
@@ -41,7 +40,7 @@ class LRUCache:
|
|
|
41
40
|
_login_cache = LRUCache(max_size=32) # TODO: Make this configurable
|
|
42
41
|
|
|
43
42
|
|
|
44
|
-
async def cached_login(api_key: str, app_url: str, org_name:
|
|
43
|
+
async def cached_login(api_key: str, app_url: str, org_name: str | None = None) -> BraintrustState:
|
|
45
44
|
"""Login with caching to avoid repeated API calls."""
|
|
46
45
|
cache_key = json.dumps({"api_key": api_key, "app_url": app_url, "org_name": org_name})
|
|
47
46
|
|
braintrust/devserver/cors.py
CHANGED
|
@@ -1,9 +1,10 @@
|
|
|
1
1
|
import os
|
|
2
2
|
import re
|
|
3
|
-
from
|
|
3
|
+
from collections.abc import Awaitable, Callable
|
|
4
|
+
from typing import Any
|
|
4
5
|
|
|
5
6
|
# CORS configuration
|
|
6
|
-
ALLOWED_ORIGINS:
|
|
7
|
+
ALLOWED_ORIGINS: list[str | re.Pattern] = [
|
|
7
8
|
"https://www.braintrust.dev",
|
|
8
9
|
"https://www.braintrustdata.com",
|
|
9
10
|
re.compile(r"https://.*\.preview\.braintrust\.dev"),
|
|
@@ -70,9 +71,9 @@ def create_cors_middleware() -> type:
|
|
|
70
71
|
|
|
71
72
|
async def __call__(
|
|
72
73
|
self,
|
|
73
|
-
scope:
|
|
74
|
-
receive: Callable[[], Awaitable[
|
|
75
|
-
send: Callable[[
|
|
74
|
+
scope: dict[str, Any],
|
|
75
|
+
receive: Callable[[], Awaitable[dict[str, Any]]],
|
|
76
|
+
send: Callable[[dict[str, Any]], Awaitable[None]],
|
|
76
77
|
) -> None:
|
|
77
78
|
if scope["type"] == "http":
|
|
78
79
|
headers = dict(scope["headers"])
|
|
@@ -81,7 +82,7 @@ def create_cors_middleware() -> type:
|
|
|
81
82
|
# Handle OPTIONS requests
|
|
82
83
|
if scope["method"] == "OPTIONS":
|
|
83
84
|
|
|
84
|
-
async def send_options_wrapper(message:
|
|
85
|
+
async def send_options_wrapper(message: dict[str, Any]) -> None:
|
|
85
86
|
if message["type"] == "http.response.start":
|
|
86
87
|
headers_dict = dict(message.get("headers", []))
|
|
87
88
|
|
|
@@ -120,7 +121,7 @@ def create_cors_middleware() -> type:
|
|
|
120
121
|
return
|
|
121
122
|
|
|
122
123
|
# For other requests, add CORS headers if origin is valid
|
|
123
|
-
async def send_wrapper(message:
|
|
124
|
+
async def send_wrapper(message: dict[str, Any]) -> None:
|
|
124
125
|
if message["type"] == "http.response.start" and origin and check_origin(origin):
|
|
125
126
|
headers_dict = dict(message.get("headers", []))
|
|
126
127
|
|
braintrust/devserver/dataset.py
CHANGED
|
@@ -1,11 +1,11 @@
|
|
|
1
|
-
from typing import Any
|
|
1
|
+
from typing import Any
|
|
2
2
|
|
|
3
3
|
from braintrust import init_dataset
|
|
4
4
|
from braintrust._generated_types import RunEvalData, RunEvalData1, RunEvalData2
|
|
5
5
|
from braintrust.logger import BraintrustState
|
|
6
6
|
|
|
7
7
|
|
|
8
|
-
async def get_dataset_by_id(state: BraintrustState, dataset_id: str) ->
|
|
8
|
+
async def get_dataset_by_id(state: BraintrustState, dataset_id: str) -> dict[str, str]:
|
|
9
9
|
"""Fetch dataset information by ID."""
|
|
10
10
|
# Make API call to get dataset info
|
|
11
11
|
conn = state.api_conn()
|
|
@@ -23,9 +23,7 @@ async def get_dataset_by_id(state: BraintrustState, dataset_id: str) -> Dict[str
|
|
|
23
23
|
|
|
24
24
|
|
|
25
25
|
# NOTE: To make this performant, we'll have to make these functions work with async i/o
|
|
26
|
-
async def get_dataset(
|
|
27
|
-
state: BraintrustState, data: Union[RunEvalData, RunEvalData1, RunEvalData2, Dict[str, Any]]
|
|
28
|
-
) -> Any:
|
|
26
|
+
async def get_dataset(state: BraintrustState, data: RunEvalData | RunEvalData1 | RunEvalData2 | dict[str, Any]) -> Any:
|
|
29
27
|
"""
|
|
30
28
|
Get dataset from various data sources.
|
|
31
29
|
|
|
@@ -7,7 +7,8 @@ for reporting progress during evaluation execution.
|
|
|
7
7
|
|
|
8
8
|
import asyncio
|
|
9
9
|
import json
|
|
10
|
-
from
|
|
10
|
+
from collections.abc import Callable
|
|
11
|
+
from typing import Any
|
|
11
12
|
|
|
12
13
|
|
|
13
14
|
class EvalHooks:
|
|
@@ -15,13 +16,13 @@ class EvalHooks:
|
|
|
15
16
|
|
|
16
17
|
def __init__(
|
|
17
18
|
self,
|
|
18
|
-
report_progress:
|
|
19
|
-
parameters:
|
|
19
|
+
report_progress: Callable[[dict[str, Any]], None] | None = None,
|
|
20
|
+
parameters: dict[str, Any] | None = None,
|
|
20
21
|
):
|
|
21
22
|
self._report_progress = report_progress
|
|
22
23
|
self.parameters = parameters or {}
|
|
23
24
|
|
|
24
|
-
def report_progress(self, event:
|
|
25
|
+
def report_progress(self, event: dict[str, Any]) -> None:
|
|
25
26
|
"""Report progress during task execution."""
|
|
26
27
|
if self._report_progress:
|
|
27
28
|
self._report_progress(event)
|
|
@@ -45,7 +46,7 @@ class SSEQueue:
|
|
|
45
46
|
"""Simple wrapper around asyncio.Queue for SSE events."""
|
|
46
47
|
|
|
47
48
|
def __init__(self):
|
|
48
|
-
self.queue: asyncio.Queue[
|
|
49
|
+
self.queue: asyncio.Queue[str | None] = asyncio.Queue()
|
|
49
50
|
|
|
50
51
|
async def put_event(self, event: str, data: Any) -> None:
|
|
51
52
|
"""Add an SSE event to the queue."""
|
|
@@ -56,6 +57,6 @@ class SSEQueue:
|
|
|
56
57
|
"""Signal end of stream."""
|
|
57
58
|
await self.queue.put(None)
|
|
58
59
|
|
|
59
|
-
async def get(self) ->
|
|
60
|
+
async def get(self) -> str | None:
|
|
60
61
|
"""Get the next event from the queue."""
|
|
61
62
|
return await self.queue.get()
|
braintrust/devserver/schemas.py
CHANGED
|
@@ -1,7 +1,8 @@
|
|
|
1
1
|
import json
|
|
2
|
-
from
|
|
2
|
+
from collections.abc import Sequence
|
|
3
|
+
from typing import Any, Union, get_args, get_origin, get_type_hints
|
|
3
4
|
|
|
4
|
-
from typing_extensions import TypedDict
|
|
5
|
+
from typing_extensions import TypedDict
|
|
5
6
|
|
|
6
7
|
# This is not beautiful code, but it saves us from introducing Pydantic as a dependency, and it is fairly
|
|
7
8
|
# straightforward for an LLM to keep it up to date with runEvalBodySchema in JS.
|
|
@@ -16,12 +17,12 @@ class ValidationError(Exception):
|
|
|
16
17
|
class ParsedFunctionId(TypedDict, total=False):
|
|
17
18
|
"""Parsed function identifier."""
|
|
18
19
|
|
|
19
|
-
function_id:
|
|
20
|
-
version:
|
|
21
|
-
name:
|
|
22
|
-
prompt_session_id:
|
|
23
|
-
inline_code:
|
|
24
|
-
global_function:
|
|
20
|
+
function_id: str | None
|
|
21
|
+
version: str | None
|
|
22
|
+
name: str | None
|
|
23
|
+
prompt_session_id: str | None
|
|
24
|
+
inline_code: str | None
|
|
25
|
+
global_function: str | None
|
|
25
26
|
|
|
26
27
|
|
|
27
28
|
class ParsedParent(TypedDict):
|
|
@@ -35,16 +36,16 @@ class ParsedEvalBody(TypedDict, total=False):
|
|
|
35
36
|
"""Type for parsed eval request body."""
|
|
36
37
|
|
|
37
38
|
name: str # Required
|
|
38
|
-
parameters:
|
|
39
|
+
parameters: dict[str, Any]
|
|
39
40
|
data: Any
|
|
40
|
-
scores:
|
|
41
|
+
scores: list[ParsedFunctionId]
|
|
41
42
|
experiment_name: str
|
|
42
43
|
project_id: str
|
|
43
|
-
parent:
|
|
44
|
+
parent: str | ParsedParent
|
|
44
45
|
stream: bool
|
|
45
46
|
|
|
46
47
|
|
|
47
|
-
def validate_typed_dict(data: Any, typed_dict_class: type, path: str = "") ->
|
|
48
|
+
def validate_typed_dict(data: Any, typed_dict_class: type, path: str = "") -> dict[str, Any]:
|
|
48
49
|
"""Validate data against a TypedDict definition."""
|
|
49
50
|
if not isinstance(data, dict):
|
|
50
51
|
raise ValidationError(f"{path or 'Root'} must be a dictionary, got {type(data).__name__}")
|
|
@@ -107,7 +108,7 @@ def validate_value(value: Any, expected_type: type, path: str) -> Any:
|
|
|
107
108
|
return validate_value(value, inner_type, path)
|
|
108
109
|
|
|
109
110
|
# Handle List/Sequence
|
|
110
|
-
if origin in (list,
|
|
111
|
+
if origin in (list, list, Sequence):
|
|
111
112
|
if not isinstance(value, list):
|
|
112
113
|
raise ValidationError(f"{path} must be a list, got {type(value).__name__}")
|
|
113
114
|
|
|
@@ -115,7 +116,7 @@ def validate_value(value: Any, expected_type: type, path: str) -> Any:
|
|
|
115
116
|
return [validate_value(item, item_type, f"{path}[{i}]") for i, item in enumerate(value)]
|
|
116
117
|
|
|
117
118
|
# Handle Dict/Mapping
|
|
118
|
-
if origin in (dict,
|
|
119
|
+
if origin in (dict, dict):
|
|
119
120
|
if not isinstance(value, dict):
|
|
120
121
|
raise ValidationError(f"{path} must be a dict, got {type(value).__name__}")
|
|
121
122
|
|
|
@@ -172,7 +173,7 @@ def parse_function_id(data: Any, path: str = "function") -> ParsedFunctionId:
|
|
|
172
173
|
raise ValidationError(f"{path} must specify function_id, name, prompt_session_id, or inline_code")
|
|
173
174
|
|
|
174
175
|
|
|
175
|
-
def parse_eval_body(request_data:
|
|
176
|
+
def parse_eval_body(request_data: str | bytes | dict) -> ParsedEvalBody:
|
|
176
177
|
"""
|
|
177
178
|
Parse request body for eval execution.
|
|
178
179
|
|
|
@@ -221,10 +222,12 @@ def parse_eval_body(request_data: Union[str, bytes, dict]) -> ParsedEvalBody:
|
|
|
221
222
|
parsed_scores = []
|
|
222
223
|
for i, score in enumerate(scores_data):
|
|
223
224
|
try:
|
|
224
|
-
parsed_scores.append(
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
225
|
+
parsed_scores.append(
|
|
226
|
+
{
|
|
227
|
+
"name": score["name"],
|
|
228
|
+
"function_id": parse_function_id(score["function_id"], f"scores[{i}]"),
|
|
229
|
+
}
|
|
230
|
+
)
|
|
228
231
|
except ValidationError as e:
|
|
229
232
|
raise ValidationError(f"Invalid score at index {i}: {e}")
|
|
230
233
|
|
braintrust/devserver/server.py
CHANGED
|
@@ -2,7 +2,7 @@ import asyncio
|
|
|
2
2
|
import json
|
|
3
3
|
import sys
|
|
4
4
|
import textwrap
|
|
5
|
-
from typing import Any
|
|
5
|
+
from typing import Any
|
|
6
6
|
|
|
7
7
|
try:
|
|
8
8
|
import uvicorn
|
|
@@ -40,7 +40,7 @@ _all_evaluators: dict[str, Evaluator[Any, Any]] = {}
|
|
|
40
40
|
|
|
41
41
|
|
|
42
42
|
class CheckAuthorizedMiddleware(BaseHTTPMiddleware):
|
|
43
|
-
def __init__(self, app, allowed_org_name:
|
|
43
|
+
def __init__(self, app, allowed_org_name: str | None = None):
|
|
44
44
|
super().__init__(app)
|
|
45
45
|
self.allowed_org_name = allowed_org_name
|
|
46
46
|
self.protected_paths = ["/list", "/eval"]
|
|
@@ -100,7 +100,7 @@ async def list_evaluators(request: Request) -> JSONResponse:
|
|
|
100
100
|
return JSONResponse(evaluator_list)
|
|
101
101
|
|
|
102
102
|
|
|
103
|
-
async def run_eval(request: Request) ->
|
|
103
|
+
async def run_eval(request: Request) -> JSONResponse | StreamingResponse:
|
|
104
104
|
"""Handle eval execution requests."""
|
|
105
105
|
try:
|
|
106
106
|
# Get request body
|
|
@@ -157,12 +157,14 @@ async def run_eval(request: Request) -> Union[JSONResponse, StreamingResponse]:
|
|
|
157
157
|
result = await evaluator.task(input, hooks)
|
|
158
158
|
else:
|
|
159
159
|
result = evaluator.task(input, hooks)
|
|
160
|
-
hooks.report_progress(
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
160
|
+
hooks.report_progress(
|
|
161
|
+
{
|
|
162
|
+
"format": "code",
|
|
163
|
+
"output_type": "completion",
|
|
164
|
+
"event": "json_delta",
|
|
165
|
+
"data": json.dumps(result),
|
|
166
|
+
}
|
|
167
|
+
)
|
|
166
168
|
return result
|
|
167
169
|
|
|
168
170
|
def on_start_fn(summary: ExperimentSummary):
|
|
@@ -214,6 +216,7 @@ async def run_eval(request: Request) -> Union[JSONResponse, StreamingResponse]:
|
|
|
214
216
|
|
|
215
217
|
async def event_generator():
|
|
216
218
|
"""Generate SSE events from the queue."""
|
|
219
|
+
|
|
217
220
|
# Create a task to run the eval and signal completion
|
|
218
221
|
async def run_and_complete():
|
|
219
222
|
try:
|
|
@@ -255,7 +258,7 @@ async def run_eval(request: Request) -> Union[JSONResponse, StreamingResponse]:
|
|
|
255
258
|
return JSONResponse({"error": f"Failed to run evaluation: {str(e)}"}, status_code=500)
|
|
256
259
|
|
|
257
260
|
|
|
258
|
-
def create_app(evaluators: list[Evaluator[Any, Any]], org_name:
|
|
261
|
+
def create_app(evaluators: list[Evaluator[Any, Any]], org_name: str | None = None):
|
|
259
262
|
"""Create and configure the Starlette app for the dev server.
|
|
260
263
|
|
|
261
264
|
Args:
|
|
@@ -283,7 +286,9 @@ def create_app(evaluators: list[Evaluator[Any, Any]], org_name: Optional[str] =
|
|
|
283
286
|
return app
|
|
284
287
|
|
|
285
288
|
|
|
286
|
-
def run_dev_server(
|
|
289
|
+
def run_dev_server(
|
|
290
|
+
evaluators: list[Evaluator[Any, Any]], host: str = "localhost", port: int = 8300, org_name: str | None = None
|
|
291
|
+
):
|
|
287
292
|
"""Start the dev server.
|
|
288
293
|
|
|
289
294
|
Args:
|
|
@@ -305,7 +310,9 @@ def snake_to_camel(snake_str: str) -> str:
|
|
|
305
310
|
return components[0] + "".join(x.title() for x in components[1:]) if components else snake_str
|
|
306
311
|
|
|
307
312
|
|
|
308
|
-
def make_scorer(
|
|
313
|
+
def make_scorer(
|
|
314
|
+
state: BraintrustState, name: str, score: FunctionId, project_id: str | None = None
|
|
315
|
+
) -> EvalScorer[Any, Any]:
|
|
309
316
|
def scorer_fn(input, output, expected, metadata):
|
|
310
317
|
request = {
|
|
311
318
|
**score,
|
|
@@ -10,7 +10,7 @@ class TestCachedLogin(unittest.TestCase):
|
|
|
10
10
|
"""Clear the cache before each test."""
|
|
11
11
|
cache._login_cache = cache.LRUCache(max_size=32)
|
|
12
12
|
|
|
13
|
-
@patch(
|
|
13
|
+
@patch("braintrust.devserver.cache.login_to_state")
|
|
14
14
|
def test_cached_login_caches_results(self, mock_login):
|
|
15
15
|
"""Test that cached_login caches and reuses results."""
|
|
16
16
|
mock_state = MagicMock()
|
|
@@ -26,7 +26,7 @@ class TestCachedLogin(unittest.TestCase):
|
|
|
26
26
|
self.assertEqual(result2, mock_state)
|
|
27
27
|
self.assertEqual(mock_login.call_count, 1) # Still 1, not called again
|
|
28
28
|
|
|
29
|
-
@patch(
|
|
29
|
+
@patch("braintrust.devserver.cache.login_to_state")
|
|
30
30
|
def test_cached_login_different_keys(self, mock_login):
|
|
31
31
|
"""Test that different cache keys create separate entries."""
|
|
32
32
|
mock_state1 = MagicMock()
|
|
@@ -48,7 +48,7 @@ class TestCachedLogin(unittest.TestCase):
|
|
|
48
48
|
self.assertEqual(result3, mock_state3)
|
|
49
49
|
self.assertEqual(mock_login.call_count, 3)
|
|
50
50
|
|
|
51
|
-
@patch(
|
|
51
|
+
@patch("braintrust.devserver.cache.login_to_state")
|
|
52
52
|
def test_cached_login_with_org_name(self, mock_login):
|
|
53
53
|
"""Test caching with org_name parameter."""
|
|
54
54
|
mock_state = MagicMock()
|
|
@@ -68,7 +68,7 @@ class TestCachedLogin(unittest.TestCase):
|
|
|
68
68
|
result3 = asyncio.run(cache.cached_login("api_key_1", "https://app.braintrust.com", org_name="other_org"))
|
|
69
69
|
self.assertEqual(mock_login.call_count, 2)
|
|
70
70
|
|
|
71
|
-
@patch(
|
|
71
|
+
@patch("braintrust.devserver.cache.login_to_state")
|
|
72
72
|
def test_cached_login_propagates_exceptions(self, mock_login):
|
|
73
73
|
"""Test that exceptions from login_to_state are propagated."""
|
|
74
74
|
mock_login.side_effect = ValueError("Invalid API key")
|