braintrust 0.3.15__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82):
  1. braintrust/_generated_types.py +737 -672
  2. braintrust/audit.py +2 -2
  3. braintrust/bt_json.py +178 -19
  4. braintrust/cli/eval.py +6 -7
  5. braintrust/cli/push.py +11 -11
  6. braintrust/context.py +12 -17
  7. braintrust/contrib/temporal/__init__.py +16 -27
  8. braintrust/contrib/temporal/test_temporal.py +8 -3
  9. braintrust/devserver/auth.py +8 -8
  10. braintrust/devserver/cache.py +3 -4
  11. braintrust/devserver/cors.py +8 -7
  12. braintrust/devserver/dataset.py +3 -5
  13. braintrust/devserver/eval_hooks.py +7 -6
  14. braintrust/devserver/schemas.py +22 -19
  15. braintrust/devserver/server.py +19 -12
  16. braintrust/devserver/test_cached_login.py +4 -4
  17. braintrust/framework.py +139 -142
  18. braintrust/framework2.py +88 -87
  19. braintrust/functions/invoke.py +66 -59
  20. braintrust/functions/stream.py +3 -2
  21. braintrust/generated_types.py +3 -1
  22. braintrust/git_fields.py +11 -11
  23. braintrust/gitutil.py +2 -3
  24. braintrust/graph_util.py +10 -10
  25. braintrust/id_gen.py +2 -2
  26. braintrust/logger.py +373 -471
  27. braintrust/merge_row_batch.py +10 -9
  28. braintrust/oai.py +21 -20
  29. braintrust/otel/__init__.py +49 -49
  30. braintrust/otel/context.py +16 -30
  31. braintrust/otel/test_distributed_tracing.py +14 -11
  32. braintrust/otel/test_otel_bt_integration.py +32 -31
  33. braintrust/parameters.py +8 -8
  34. braintrust/prompt.py +14 -14
  35. braintrust/prompt_cache/disk_cache.py +5 -4
  36. braintrust/prompt_cache/lru_cache.py +3 -2
  37. braintrust/prompt_cache/prompt_cache.py +13 -14
  38. braintrust/queue.py +4 -4
  39. braintrust/score.py +4 -4
  40. braintrust/serializable_data_class.py +4 -4
  41. braintrust/span_identifier_v1.py +1 -2
  42. braintrust/span_identifier_v2.py +3 -4
  43. braintrust/span_identifier_v3.py +23 -20
  44. braintrust/span_identifier_v4.py +34 -25
  45. braintrust/test_bt_json.py +644 -0
  46. braintrust/test_framework.py +72 -6
  47. braintrust/test_helpers.py +5 -5
  48. braintrust/test_id_gen.py +2 -3
  49. braintrust/test_logger.py +211 -107
  50. braintrust/test_otel.py +61 -53
  51. braintrust/test_queue.py +0 -1
  52. braintrust/test_score.py +1 -3
  53. braintrust/test_span_components.py +29 -44
  54. braintrust/util.py +9 -8
  55. braintrust/version.py +2 -2
  56. braintrust/wrappers/_anthropic_utils.py +4 -4
  57. braintrust/wrappers/agno/__init__.py +3 -4
  58. braintrust/wrappers/agno/agent.py +1 -2
  59. braintrust/wrappers/agno/function_call.py +1 -2
  60. braintrust/wrappers/agno/model.py +1 -2
  61. braintrust/wrappers/agno/team.py +1 -2
  62. braintrust/wrappers/agno/utils.py +12 -12
  63. braintrust/wrappers/anthropic.py +7 -8
  64. braintrust/wrappers/claude_agent_sdk/__init__.py +3 -4
  65. braintrust/wrappers/claude_agent_sdk/_wrapper.py +29 -27
  66. braintrust/wrappers/dspy.py +15 -17
  67. braintrust/wrappers/google_genai/__init__.py +17 -30
  68. braintrust/wrappers/langchain.py +22 -24
  69. braintrust/wrappers/litellm.py +4 -3
  70. braintrust/wrappers/openai.py +15 -15
  71. braintrust/wrappers/pydantic_ai.py +225 -110
  72. braintrust/wrappers/test_agno.py +0 -1
  73. braintrust/wrappers/test_dspy.py +0 -1
  74. braintrust/wrappers/test_google_genai.py +64 -4
  75. braintrust/wrappers/test_litellm.py +0 -1
  76. braintrust/wrappers/test_pydantic_ai_integration.py +819 -22
  77. {braintrust-0.3.15.dist-info → braintrust-0.4.1.dist-info}/METADATA +3 -2
  78. braintrust-0.4.1.dist-info/RECORD +121 -0
  79. braintrust-0.3.15.dist-info/RECORD +0 -120
  80. {braintrust-0.3.15.dist-info → braintrust-0.4.1.dist-info}/WHEEL +0 -0
  81. {braintrust-0.3.15.dist-info → braintrust-0.4.1.dist-info}/entry_points.txt +0 -0
  82. {braintrust-0.3.15.dist-info → braintrust-0.4.1.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,5 @@
1
+ from collections.abc import Awaitable, Callable
1
2
  from dataclasses import dataclass
2
- from typing import Awaitable, Callable, Dict, Optional
3
3
 
4
4
  from starlette.middleware.base import BaseHTTPMiddleware
5
5
  from starlette.requests import Request
@@ -15,14 +15,14 @@ BRAINTRUST_PROJECT_ID_HEADER = "x-bt-project-id"
15
15
 
16
16
  @dataclass
17
17
  class RequestContext:
18
- app_origin: Optional[str]
19
- token: Optional[str]
20
- org_name: Optional[str]
21
- project_id: Optional[str]
22
- state: Optional[BraintrustState]
18
+ app_origin: str | None
19
+ token: str | None
20
+ org_name: str | None
21
+ project_id: str | None
22
+ state: BraintrustState | None
23
23
 
24
24
 
25
- def extract_allowed_origin(origin: Optional[str]) -> Optional[str]:
25
+ def extract_allowed_origin(origin: str | None) -> str | None:
26
26
  """Extract and validate the origin header."""
27
27
  # This should use the same check_origin logic from cors.py
28
28
  from .cors import check_origin
@@ -32,7 +32,7 @@ def extract_allowed_origin(origin: Optional[str]) -> Optional[str]:
32
32
  return None
33
33
 
34
34
 
35
- def parse_braintrust_auth_header(headers: Dict[str, str]) -> Optional[str]:
35
+ def parse_braintrust_auth_header(headers: dict[str, str]) -> str | None:
36
36
  """Parse the authorization token from headers."""
37
37
  # Check x-bt-auth-token first
38
38
  token = headers.get(BRAINTRUST_AUTH_TOKEN_HEADER)
@@ -1,7 +1,6 @@
1
1
  """LRU cache implementation for the dev server."""
2
2
 
3
3
  import json
4
- from typing import Dict, Optional
5
4
 
6
5
  from ..logger import BraintrustState, login_to_state
7
6
 
@@ -11,10 +10,10 @@ class LRUCache:
11
10
 
12
11
  def __init__(self, max_size: int = 32):
13
12
  self.max_size = max_size
14
- self.cache: Dict[str, BraintrustState] = {}
13
+ self.cache: dict[str, BraintrustState] = {}
15
14
  self.access_order: list[str] = []
16
15
 
17
- def get(self, key: str) -> Optional[BraintrustState]:
16
+ def get(self, key: str) -> BraintrustState | None:
18
17
  """Get a value from the cache, updating access order."""
19
18
  if key in self.cache:
20
19
  # Move to end to mark as recently used
@@ -41,7 +40,7 @@ class LRUCache:
41
40
  _login_cache = LRUCache(max_size=32) # TODO: Make this configurable
42
41
 
43
42
 
44
- async def cached_login(api_key: str, app_url: str, org_name: Optional[str] = None) -> BraintrustState:
43
+ async def cached_login(api_key: str, app_url: str, org_name: str | None = None) -> BraintrustState:
45
44
  """Login with caching to avoid repeated API calls."""
46
45
  cache_key = json.dumps({"api_key": api_key, "app_url": app_url, "org_name": org_name})
47
46
 
@@ -1,9 +1,10 @@
1
1
  import os
2
2
  import re
3
- from typing import Any, Awaitable, Callable, Dict, List, Union
3
+ from collections.abc import Awaitable, Callable
4
+ from typing import Any
4
5
 
5
6
  # CORS configuration
6
- ALLOWED_ORIGINS: List[Union[str, re.Pattern]] = [
7
+ ALLOWED_ORIGINS: list[str | re.Pattern] = [
7
8
  "https://www.braintrust.dev",
8
9
  "https://www.braintrustdata.com",
9
10
  re.compile(r"https://.*\.preview\.braintrust\.dev"),
@@ -70,9 +71,9 @@ def create_cors_middleware() -> type:
70
71
 
71
72
  async def __call__(
72
73
  self,
73
- scope: Dict[str, Any],
74
- receive: Callable[[], Awaitable[Dict[str, Any]]],
75
- send: Callable[[Dict[str, Any]], Awaitable[None]],
74
+ scope: dict[str, Any],
75
+ receive: Callable[[], Awaitable[dict[str, Any]]],
76
+ send: Callable[[dict[str, Any]], Awaitable[None]],
76
77
  ) -> None:
77
78
  if scope["type"] == "http":
78
79
  headers = dict(scope["headers"])
@@ -81,7 +82,7 @@ def create_cors_middleware() -> type:
81
82
  # Handle OPTIONS requests
82
83
  if scope["method"] == "OPTIONS":
83
84
 
84
- async def send_options_wrapper(message: Dict[str, Any]) -> None:
85
+ async def send_options_wrapper(message: dict[str, Any]) -> None:
85
86
  if message["type"] == "http.response.start":
86
87
  headers_dict = dict(message.get("headers", []))
87
88
 
@@ -120,7 +121,7 @@ def create_cors_middleware() -> type:
120
121
  return
121
122
 
122
123
  # For other requests, add CORS headers if origin is valid
123
- async def send_wrapper(message: Dict[str, Any]) -> None:
124
+ async def send_wrapper(message: dict[str, Any]) -> None:
124
125
  if message["type"] == "http.response.start" and origin and check_origin(origin):
125
126
  headers_dict = dict(message.get("headers", []))
126
127
 
@@ -1,11 +1,11 @@
1
- from typing import Any, Dict, Union
1
+ from typing import Any
2
2
 
3
3
  from braintrust import init_dataset
4
4
  from braintrust._generated_types import RunEvalData, RunEvalData1, RunEvalData2
5
5
  from braintrust.logger import BraintrustState
6
6
 
7
7
 
8
- async def get_dataset_by_id(state: BraintrustState, dataset_id: str) -> Dict[str, str]:
8
+ async def get_dataset_by_id(state: BraintrustState, dataset_id: str) -> dict[str, str]:
9
9
  """Fetch dataset information by ID."""
10
10
  # Make API call to get dataset info
11
11
  conn = state.api_conn()
@@ -23,9 +23,7 @@ async def get_dataset_by_id(state: BraintrustState, dataset_id: str) -> Dict[str
23
23
 
24
24
 
25
25
  # NOTE: To make this performant, we'll have to make these functions work with async i/o
26
- async def get_dataset(
27
- state: BraintrustState, data: Union[RunEvalData, RunEvalData1, RunEvalData2, Dict[str, Any]]
28
- ) -> Any:
26
+ async def get_dataset(state: BraintrustState, data: RunEvalData | RunEvalData1 | RunEvalData2 | dict[str, Any]) -> Any:
29
27
  """
30
28
  Get dataset from various data sources.
31
29
 
@@ -7,7 +7,8 @@ for reporting progress during evaluation execution.
7
7
 
8
8
  import asyncio
9
9
  import json
10
- from typing import Any, Callable, Dict, Optional
10
+ from collections.abc import Callable
11
+ from typing import Any
11
12
 
12
13
 
13
14
  class EvalHooks:
@@ -15,13 +16,13 @@ class EvalHooks:
15
16
 
16
17
  def __init__(
17
18
  self,
18
- report_progress: Optional[Callable[[Dict[str, Any]], None]] = None,
19
- parameters: Optional[Dict[str, Any]] = None,
19
+ report_progress: Callable[[dict[str, Any]], None] | None = None,
20
+ parameters: dict[str, Any] | None = None,
20
21
  ):
21
22
  self._report_progress = report_progress
22
23
  self.parameters = parameters or {}
23
24
 
24
- def report_progress(self, event: Dict[str, Any]) -> None:
25
+ def report_progress(self, event: dict[str, Any]) -> None:
25
26
  """Report progress during task execution."""
26
27
  if self._report_progress:
27
28
  self._report_progress(event)
@@ -45,7 +46,7 @@ class SSEQueue:
45
46
  """Simple wrapper around asyncio.Queue for SSE events."""
46
47
 
47
48
  def __init__(self):
48
- self.queue: asyncio.Queue[Optional[str]] = asyncio.Queue()
49
+ self.queue: asyncio.Queue[str | None] = asyncio.Queue()
49
50
 
50
51
  async def put_event(self, event: str, data: Any) -> None:
51
52
  """Add an SSE event to the queue."""
@@ -56,6 +57,6 @@ class SSEQueue:
56
57
  """Signal end of stream."""
57
58
  await self.queue.put(None)
58
59
 
59
- async def get(self) -> Optional[str]:
60
+ async def get(self) -> str | None:
60
61
  """Get the next event from the queue."""
61
62
  return await self.queue.get()
@@ -1,7 +1,8 @@
1
1
  import json
2
- from typing import Any, Dict, List, Optional, Sequence, Union, get_args, get_origin
2
+ from collections.abc import Sequence
3
+ from typing import Any, Union, get_args, get_origin, get_type_hints
3
4
 
4
- from typing_extensions import TypedDict, get_type_hints
5
+ from typing_extensions import TypedDict
5
6
 
6
7
  # This is not beautiful code, but it saves us from introducing Pydantic as a dependency, and it is fairly
7
8
  # straightforward for an LLM to keep it up to date with runEvalBodySchema in JS.
@@ -16,12 +17,12 @@ class ValidationError(Exception):
16
17
  class ParsedFunctionId(TypedDict, total=False):
17
18
  """Parsed function identifier."""
18
19
 
19
- function_id: Optional[str]
20
- version: Optional[str]
21
- name: Optional[str]
22
- prompt_session_id: Optional[str]
23
- inline_code: Optional[str]
24
- global_function: Optional[str]
20
+ function_id: str | None
21
+ version: str | None
22
+ name: str | None
23
+ prompt_session_id: str | None
24
+ inline_code: str | None
25
+ global_function: str | None
25
26
 
26
27
 
27
28
  class ParsedParent(TypedDict):
@@ -35,16 +36,16 @@ class ParsedEvalBody(TypedDict, total=False):
35
36
  """Type for parsed eval request body."""
36
37
 
37
38
  name: str # Required
38
- parameters: Dict[str, Any]
39
+ parameters: dict[str, Any]
39
40
  data: Any
40
- scores: List[ParsedFunctionId]
41
+ scores: list[ParsedFunctionId]
41
42
  experiment_name: str
42
43
  project_id: str
43
- parent: Union[str, ParsedParent]
44
+ parent: str | ParsedParent
44
45
  stream: bool
45
46
 
46
47
 
47
- def validate_typed_dict(data: Any, typed_dict_class: type, path: str = "") -> Dict[str, Any]:
48
+ def validate_typed_dict(data: Any, typed_dict_class: type, path: str = "") -> dict[str, Any]:
48
49
  """Validate data against a TypedDict definition."""
49
50
  if not isinstance(data, dict):
50
51
  raise ValidationError(f"{path or 'Root'} must be a dictionary, got {type(data).__name__}")
@@ -107,7 +108,7 @@ def validate_value(value: Any, expected_type: type, path: str) -> Any:
107
108
  return validate_value(value, inner_type, path)
108
109
 
109
110
  # Handle List/Sequence
110
- if origin in (list, List, Sequence):
111
+ if origin in (list, list, Sequence):
111
112
  if not isinstance(value, list):
112
113
  raise ValidationError(f"{path} must be a list, got {type(value).__name__}")
113
114
 
@@ -115,7 +116,7 @@ def validate_value(value: Any, expected_type: type, path: str) -> Any:
115
116
  return [validate_value(item, item_type, f"{path}[{i}]") for i, item in enumerate(value)]
116
117
 
117
118
  # Handle Dict/Mapping
118
- if origin in (dict, Dict):
119
+ if origin in (dict, dict):
119
120
  if not isinstance(value, dict):
120
121
  raise ValidationError(f"{path} must be a dict, got {type(value).__name__}")
121
122
 
@@ -172,7 +173,7 @@ def parse_function_id(data: Any, path: str = "function") -> ParsedFunctionId:
172
173
  raise ValidationError(f"{path} must specify function_id, name, prompt_session_id, or inline_code")
173
174
 
174
175
 
175
- def parse_eval_body(request_data: Union[str, bytes, dict]) -> ParsedEvalBody:
176
+ def parse_eval_body(request_data: str | bytes | dict) -> ParsedEvalBody:
176
177
  """
177
178
  Parse request body for eval execution.
178
179
 
@@ -221,10 +222,12 @@ def parse_eval_body(request_data: Union[str, bytes, dict]) -> ParsedEvalBody:
221
222
  parsed_scores = []
222
223
  for i, score in enumerate(scores_data):
223
224
  try:
224
- parsed_scores.append({
225
- "name": score["name"],
226
- "function_id": parse_function_id(score["function_id"], f"scores[{i}]"),
227
- })
225
+ parsed_scores.append(
226
+ {
227
+ "name": score["name"],
228
+ "function_id": parse_function_id(score["function_id"], f"scores[{i}]"),
229
+ }
230
+ )
228
231
  except ValidationError as e:
229
232
  raise ValidationError(f"Invalid score at index {i}: {e}")
230
233
 
@@ -2,7 +2,7 @@ import asyncio
2
2
  import json
3
3
  import sys
4
4
  import textwrap
5
- from typing import Any, Optional, Union
5
+ from typing import Any
6
6
 
7
7
  try:
8
8
  import uvicorn
@@ -40,7 +40,7 @@ _all_evaluators: dict[str, Evaluator[Any, Any]] = {}
40
40
 
41
41
 
42
42
  class CheckAuthorizedMiddleware(BaseHTTPMiddleware):
43
- def __init__(self, app, allowed_org_name: Optional[str] = None):
43
+ def __init__(self, app, allowed_org_name: str | None = None):
44
44
  super().__init__(app)
45
45
  self.allowed_org_name = allowed_org_name
46
46
  self.protected_paths = ["/list", "/eval"]
@@ -100,7 +100,7 @@ async def list_evaluators(request: Request) -> JSONResponse:
100
100
  return JSONResponse(evaluator_list)
101
101
 
102
102
 
103
- async def run_eval(request: Request) -> Union[JSONResponse, StreamingResponse]:
103
+ async def run_eval(request: Request) -> JSONResponse | StreamingResponse:
104
104
  """Handle eval execution requests."""
105
105
  try:
106
106
  # Get request body
@@ -157,12 +157,14 @@ async def run_eval(request: Request) -> Union[JSONResponse, StreamingResponse]:
157
157
  result = await evaluator.task(input, hooks)
158
158
  else:
159
159
  result = evaluator.task(input, hooks)
160
- hooks.report_progress({
161
- "format": "code",
162
- "output_type": "completion",
163
- "event": "json_delta",
164
- "data": json.dumps(result),
165
- })
160
+ hooks.report_progress(
161
+ {
162
+ "format": "code",
163
+ "output_type": "completion",
164
+ "event": "json_delta",
165
+ "data": json.dumps(result),
166
+ }
167
+ )
166
168
  return result
167
169
 
168
170
  def on_start_fn(summary: ExperimentSummary):
@@ -214,6 +216,7 @@ async def run_eval(request: Request) -> Union[JSONResponse, StreamingResponse]:
214
216
 
215
217
  async def event_generator():
216
218
  """Generate SSE events from the queue."""
219
+
217
220
  # Create a task to run the eval and signal completion
218
221
  async def run_and_complete():
219
222
  try:
@@ -255,7 +258,7 @@ async def run_eval(request: Request) -> Union[JSONResponse, StreamingResponse]:
255
258
  return JSONResponse({"error": f"Failed to run evaluation: {str(e)}"}, status_code=500)
256
259
 
257
260
 
258
- def create_app(evaluators: list[Evaluator[Any, Any]], org_name: Optional[str] = None):
261
+ def create_app(evaluators: list[Evaluator[Any, Any]], org_name: str | None = None):
259
262
  """Create and configure the Starlette app for the dev server.
260
263
 
261
264
  Args:
@@ -283,7 +286,9 @@ def create_app(evaluators: list[Evaluator[Any, Any]], org_name: Optional[str] =
283
286
  return app
284
287
 
285
288
 
286
- def run_dev_server(evaluators: list[Evaluator[Any, Any]], host: str = "localhost", port: int = 8300, org_name: Optional[str] = None):
289
+ def run_dev_server(
290
+ evaluators: list[Evaluator[Any, Any]], host: str = "localhost", port: int = 8300, org_name: str | None = None
291
+ ):
287
292
  """Start the dev server.
288
293
 
289
294
  Args:
@@ -305,7 +310,9 @@ def snake_to_camel(snake_str: str) -> str:
305
310
  return components[0] + "".join(x.title() for x in components[1:]) if components else snake_str
306
311
 
307
312
 
308
- def make_scorer(state: BraintrustState, name: str, score: FunctionId, project_id: Optional[str] = None) -> EvalScorer[Any, Any]:
313
+ def make_scorer(
314
+ state: BraintrustState, name: str, score: FunctionId, project_id: str | None = None
315
+ ) -> EvalScorer[Any, Any]:
309
316
  def scorer_fn(input, output, expected, metadata):
310
317
  request = {
311
318
  **score,
@@ -10,7 +10,7 @@ class TestCachedLogin(unittest.TestCase):
10
10
  """Clear the cache before each test."""
11
11
  cache._login_cache = cache.LRUCache(max_size=32)
12
12
 
13
- @patch('braintrust.devserver.cache.login_to_state')
13
+ @patch("braintrust.devserver.cache.login_to_state")
14
14
  def test_cached_login_caches_results(self, mock_login):
15
15
  """Test that cached_login caches and reuses results."""
16
16
  mock_state = MagicMock()
@@ -26,7 +26,7 @@ class TestCachedLogin(unittest.TestCase):
26
26
  self.assertEqual(result2, mock_state)
27
27
  self.assertEqual(mock_login.call_count, 1) # Still 1, not called again
28
28
 
29
- @patch('braintrust.devserver.cache.login_to_state')
29
+ @patch("braintrust.devserver.cache.login_to_state")
30
30
  def test_cached_login_different_keys(self, mock_login):
31
31
  """Test that different cache keys create separate entries."""
32
32
  mock_state1 = MagicMock()
@@ -48,7 +48,7 @@ class TestCachedLogin(unittest.TestCase):
48
48
  self.assertEqual(result3, mock_state3)
49
49
  self.assertEqual(mock_login.call_count, 3)
50
50
 
51
- @patch('braintrust.devserver.cache.login_to_state')
51
+ @patch("braintrust.devserver.cache.login_to_state")
52
52
  def test_cached_login_with_org_name(self, mock_login):
53
53
  """Test caching with org_name parameter."""
54
54
  mock_state = MagicMock()
@@ -68,7 +68,7 @@ class TestCachedLogin(unittest.TestCase):
68
68
  result3 = asyncio.run(cache.cached_login("api_key_1", "https://app.braintrust.com", org_name="other_org"))
69
69
  self.assertEqual(mock_login.call_count, 2)
70
70
 
71
- @patch('braintrust.devserver.cache.login_to_state')
71
+ @patch("braintrust.devserver.cache.login_to_state")
72
72
  def test_cached_login_propagates_exceptions(self, mock_login):
73
73
  """Test that exceptions from login_to_state are propagated."""
74
74
  mock_login.side_effect = ValueError("Invalid API key")