braintrust 0.3.14__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. braintrust/__init__.py +4 -0
  2. braintrust/_generated_types.py +1200 -611
  3. braintrust/audit.py +2 -2
  4. braintrust/cli/eval.py +6 -7
  5. braintrust/cli/push.py +11 -11
  6. braintrust/conftest.py +1 -0
  7. braintrust/context.py +12 -17
  8. braintrust/contrib/temporal/__init__.py +16 -27
  9. braintrust/contrib/temporal/test_temporal.py +8 -3
  10. braintrust/devserver/auth.py +8 -8
  11. braintrust/devserver/cache.py +3 -4
  12. braintrust/devserver/cors.py +8 -7
  13. braintrust/devserver/dataset.py +3 -5
  14. braintrust/devserver/eval_hooks.py +7 -6
  15. braintrust/devserver/schemas.py +22 -19
  16. braintrust/devserver/server.py +19 -12
  17. braintrust/devserver/test_cached_login.py +4 -4
  18. braintrust/framework.py +128 -140
  19. braintrust/framework2.py +88 -87
  20. braintrust/functions/invoke.py +93 -53
  21. braintrust/functions/stream.py +3 -2
  22. braintrust/generated_types.py +17 -1
  23. braintrust/git_fields.py +11 -11
  24. braintrust/gitutil.py +2 -3
  25. braintrust/graph_util.py +10 -10
  26. braintrust/id_gen.py +2 -2
  27. braintrust/logger.py +346 -357
  28. braintrust/merge_row_batch.py +10 -9
  29. braintrust/oai.py +107 -24
  30. braintrust/otel/__init__.py +49 -49
  31. braintrust/otel/context.py +16 -30
  32. braintrust/otel/test_distributed_tracing.py +14 -11
  33. braintrust/otel/test_otel_bt_integration.py +32 -31
  34. braintrust/parameters.py +8 -8
  35. braintrust/prompt.py +14 -14
  36. braintrust/prompt_cache/disk_cache.py +5 -4
  37. braintrust/prompt_cache/lru_cache.py +3 -2
  38. braintrust/prompt_cache/prompt_cache.py +13 -14
  39. braintrust/queue.py +4 -4
  40. braintrust/score.py +4 -4
  41. braintrust/serializable_data_class.py +4 -4
  42. braintrust/span_identifier_v1.py +1 -2
  43. braintrust/span_identifier_v2.py +3 -4
  44. braintrust/span_identifier_v3.py +23 -20
  45. braintrust/span_identifier_v4.py +34 -25
  46. braintrust/test_framework.py +16 -6
  47. braintrust/test_helpers.py +5 -5
  48. braintrust/test_id_gen.py +2 -3
  49. braintrust/test_otel.py +61 -53
  50. braintrust/test_queue.py +0 -1
  51. braintrust/test_score.py +1 -3
  52. braintrust/test_span_components.py +29 -44
  53. braintrust/util.py +9 -8
  54. braintrust/version.py +2 -2
  55. braintrust/wrappers/_anthropic_utils.py +4 -4
  56. braintrust/wrappers/agno/__init__.py +3 -4
  57. braintrust/wrappers/agno/agent.py +1 -2
  58. braintrust/wrappers/agno/function_call.py +1 -2
  59. braintrust/wrappers/agno/model.py +1 -2
  60. braintrust/wrappers/agno/team.py +1 -2
  61. braintrust/wrappers/agno/utils.py +12 -12
  62. braintrust/wrappers/anthropic.py +7 -8
  63. braintrust/wrappers/claude_agent_sdk/__init__.py +3 -4
  64. braintrust/wrappers/claude_agent_sdk/_wrapper.py +29 -27
  65. braintrust/wrappers/dspy.py +15 -17
  66. braintrust/wrappers/google_genai/__init__.py +16 -16
  67. braintrust/wrappers/langchain.py +22 -24
  68. braintrust/wrappers/litellm.py +4 -3
  69. braintrust/wrappers/openai.py +15 -15
  70. braintrust/wrappers/pydantic_ai.py +1204 -0
  71. braintrust/wrappers/test_agno.py +0 -1
  72. braintrust/wrappers/test_dspy.py +0 -1
  73. braintrust/wrappers/test_google_genai.py +2 -3
  74. braintrust/wrappers/test_litellm.py +0 -1
  75. braintrust/wrappers/test_oai_attachments.py +322 -0
  76. braintrust/wrappers/test_pydantic_ai_integration.py +1788 -0
  77. braintrust/wrappers/{test_pydantic_ai.py → test_pydantic_ai_wrap_openai.py} +1 -2
  78. {braintrust-0.3.14.dist-info → braintrust-0.4.0.dist-info}/METADATA +3 -2
  79. braintrust-0.4.0.dist-info/RECORD +120 -0
  80. braintrust-0.3.14.dist-info/RECORD +0 -117
  81. {braintrust-0.3.14.dist-info → braintrust-0.4.0.dist-info}/WHEEL +0 -0
  82. {braintrust-0.3.14.dist-info → braintrust-0.4.0.dist-info}/entry_points.txt +0 -0
  83. {braintrust-0.3.14.dist-info → braintrust-0.4.0.dist-info}/top_level.txt +0 -0
braintrust/audit.py CHANGED
@@ -5,7 +5,7 @@ Utilities for working with audit headers.
5
5
  import base64
6
6
  import gzip
7
7
  import json
8
- from typing import List, TypedDict
8
+ from typing import TypedDict
9
9
 
10
10
 
11
11
  class AuditResource(TypedDict):
@@ -14,7 +14,7 @@ class AuditResource(TypedDict):
14
14
  name: str
15
15
 
16
16
 
17
- def parse_audit_resources(hdr: str) -> List[AuditResource]:
17
+ def parse_audit_resources(hdr: str) -> list[AuditResource]:
18
18
  j = json.loads(hdr)
19
19
  if j["v"] == 1:
20
20
  return json.loads(gzip.decompress(base64.b64decode(j["p"])))
braintrust/cli/eval.py CHANGED
@@ -6,7 +6,6 @@ import os
6
6
  import sys
7
7
  from dataclasses import dataclass, field
8
8
  from threading import Lock
9
- from typing import Dict, List, Optional, Union
10
9
 
11
10
  from .. import login
12
11
  from ..framework import (
@@ -70,7 +69,7 @@ class EvaluatorOpts:
70
69
  no_progress_bars: bool
71
70
  terminate_on_failure: bool
72
71
  watch: bool
73
- filters: List[str]
72
+ filters: list[str]
74
73
  list: bool
75
74
  jsonl: bool
76
75
 
@@ -79,13 +78,13 @@ class EvaluatorOpts:
79
78
  class LoadedEvaluator:
80
79
  handle: FileHandle
81
80
  evaluator: Evaluator
82
- reporter: Optional[Union[ReporterDef, str]] = None
81
+ reporter: ReporterDef | str | None = None
83
82
 
84
83
 
85
84
  @dataclass
86
85
  class EvaluatorState:
87
- evaluators: List[LoadedEvaluator] = field(default_factory=list)
88
- reporters: Dict[str, ReporterDef] = field(default_factory=dict)
86
+ evaluators: list[LoadedEvaluator] = field(default_factory=list)
87
+ reporters: dict[str, ReporterDef] = field(default_factory=dict)
89
88
 
90
89
 
91
90
  def update_evaluators(eval_state: EvaluatorState, handles, terminate_on_failure):
@@ -157,7 +156,7 @@ async def run_evaluator_task(evaluator, position, opts: EvaluatorOpts):
157
156
  experiment.flush()
158
157
 
159
158
 
160
- def resolve_reporter(reporter: Optional[Union[ReporterDef, str]], reporters: Dict[str, ReporterDef]) -> ReporterDef:
159
+ def resolve_reporter(reporter: ReporterDef | str | None, reporters: dict[str, ReporterDef]) -> ReporterDef:
161
160
  if isinstance(reporter, str):
162
161
  if reporter not in reporters:
163
162
  raise ValueError(f"Reporter {reporter} not found")
@@ -179,7 +178,7 @@ def add_report(eval_reports, reporter, report):
179
178
  eval_reports[reporter.name]["results"].append(report)
180
179
 
181
180
 
182
- async def run_once(handles: List[FileHandle], evaluator_opts: EvaluatorOpts) -> bool:
181
+ async def run_once(handles: list[FileHandle], evaluator_opts: EvaluatorOpts) -> bool:
183
182
  objects = EvaluatorState()
184
183
  update_evaluators(objects, handles, terminate_on_failure=evaluator_opts.terminate_on_failure)
185
184
 
braintrust/cli/push.py CHANGED
@@ -13,7 +13,7 @@ import sys
13
13
  import tempfile
14
14
  import textwrap
15
15
  import zipfile
16
- from typing import Any, Dict, List, Optional
16
+ from typing import Any
17
17
 
18
18
  import requests
19
19
  from braintrust.framework import _set_lazy_load
@@ -24,7 +24,7 @@ from ..generated_types import IfExists
24
24
  from ..util import add_azure_blob_headers
25
25
 
26
26
 
27
- def _pkg_install_arg(pkg) -> Optional[str]:
27
+ def _pkg_install_arg(pkg) -> str | None:
28
28
  try:
29
29
  dist = importlib.metadata.distribution(pkg)
30
30
  direct_url = dist._path / "direct_url.json" # type: ignore
@@ -73,11 +73,11 @@ class _ProjectRootImporter(importlib.abc.MetaPathFinder):
73
73
  self._project_root, self._path_rest = sys.path[0], sys.path[1:]
74
74
  self._sources = []
75
75
 
76
- def _under_project_root(self, path: List[str]) -> bool:
76
+ def _under_project_root(self, path: list[str]) -> bool:
77
77
  """Returns true if all paths in `path` are under the project root."""
78
78
  return all(p.startswith(self._project_root) for p in path)
79
79
 
80
- def _under_rest(self, path: List[str]) -> bool:
80
+ def _under_rest(self, path: list[str]) -> bool:
81
81
  """Returns true if any path in `path` is under one of the remaining paths in `sys.path`."""
82
82
  return any(p.startswith(pr) for p in path for pr in self._path_rest)
83
83
 
@@ -94,11 +94,11 @@ class _ProjectRootImporter(importlib.abc.MetaPathFinder):
94
94
  self._sources.append(spec.origin)
95
95
  return spec
96
96
 
97
- def sources(self) -> List[str]:
97
+ def sources(self) -> list[str]:
98
98
  return self._sources
99
99
 
100
100
 
101
- def _import_module(name: str, path: str) -> List[str]:
101
+ def _import_module(name: str, path: str) -> list[str]:
102
102
  """Imports the module and returns the list of source files
103
103
  of all modules imported in the process.
104
104
 
@@ -120,7 +120,7 @@ def _py_version() -> str:
120
120
  return f"{sys.version_info.major}.{sys.version_info.minor}"
121
121
 
122
122
 
123
- def _run_install(install_args: List[str], packages_dir: str):
123
+ def _run_install(install_args: list[str], packages_dir: str):
124
124
  subprocess.run(
125
125
  [
126
126
  "uv",
@@ -138,7 +138,7 @@ def _run_install(install_args: List[str], packages_dir: str):
138
138
  )
139
139
 
140
140
 
141
- def _upload_bundle(entry_module_name: str, sources: List[str], requirements: Optional[str]) -> str:
141
+ def _upload_bundle(entry_module_name: str, sources: list[str], requirements: str | None) -> str:
142
142
  _check_uv()
143
143
 
144
144
  resp = proxy_conn().post_json(
@@ -212,7 +212,7 @@ def _upload_bundle(entry_module_name: str, sources: List[str], requirements: Opt
212
212
 
213
213
 
214
214
  def _collect_function_function_defs(
215
- project_ids: ProjectIdCache, functions: List[Dict[str, Any]], bundle_id: str, if_exists: IfExists
215
+ project_ids: ProjectIdCache, functions: list[dict[str, Any]], bundle_id: str, if_exists: IfExists
216
216
  ) -> None:
217
217
  for i, f in enumerate(global_.functions):
218
218
  source = inspect.getsource(f.handler)
@@ -262,7 +262,7 @@ def _collect_function_function_defs(
262
262
 
263
263
 
264
264
  def _collect_prompt_function_defs(
265
- project_ids: ProjectIdCache, functions: List[Dict[str, Any]], if_exists: IfExists
265
+ project_ids: ProjectIdCache, functions: list[dict[str, Any]], if_exists: IfExists
266
266
  ) -> None:
267
267
  for p in global_.prompts:
268
268
  functions.append(p.to_function_definition(if_exists, project_ids))
@@ -300,7 +300,7 @@ def run(args):
300
300
  raise
301
301
 
302
302
  project_ids = ProjectIdCache()
303
- functions: List[Dict[str, Any]] = []
303
+ functions: list[dict[str, Any]] = []
304
304
  if len(global_.functions) > 0:
305
305
  bundle_id = _upload_bundle(module_name, sources, args.requirements)
306
306
  _collect_function_function_defs(project_ids, functions, bundle_id, args.if_exists)
braintrust/conftest.py CHANGED
@@ -36,6 +36,7 @@ def override_app_url_for_tests():
36
36
  @pytest.fixture(autouse=True)
37
37
  def setup_braintrust():
38
38
  os.environ.setdefault("GOOGLE_API_KEY", os.getenv("GEMINI_API_KEY", "your_google_api_key_here"))
39
+ os.environ.setdefault("OPENAI_API_KEY", "sk-test-dummy-api-key-for-vcr-tests")
39
40
 
40
41
 
41
42
  @pytest.fixture(autouse=True)
braintrust/context.py CHANGED
@@ -5,12 +5,13 @@ import os
5
5
  from abc import ABC, abstractmethod
6
6
  from contextvars import ContextVar
7
7
  from dataclasses import dataclass
8
- from typing import Any, List, Optional
8
+ from typing import Any
9
9
 
10
10
 
11
11
  @dataclass
12
12
  class SpanInfo:
13
13
  """Information about a span in the context."""
14
+
14
15
  trace_id: str
15
16
  span_id: str
16
17
  span_object: Any = None
@@ -19,7 +20,7 @@ class SpanInfo:
19
20
  @dataclass
20
21
  class ParentSpanIds:
21
22
  root_span_id: str
22
- span_parents: List[str]
23
+ span_parents: list[str]
23
24
 
24
25
 
25
26
  class ContextManager(ABC):
@@ -30,7 +31,7 @@ class ContextManager(ABC):
30
31
  """
31
32
 
32
33
  @abstractmethod
33
- def get_current_span_info(self) -> Optional[Any]:
34
+ def get_current_span_info(self) -> Any | None:
34
35
  """Get information about the currently active span.
35
36
 
36
37
  Returns:
@@ -40,7 +41,7 @@ class ContextManager(ABC):
40
41
  pass
41
42
 
42
43
  @abstractmethod
43
- def get_parent_span_ids(self) -> Optional[ParentSpanIds]:
44
+ def get_parent_span_ids(self) -> ParentSpanIds | None:
44
45
  """Get parent span IDs for creating a new Braintrust span.
45
46
 
46
47
  Returns:
@@ -75,32 +76,25 @@ class BraintrustContextManager(ContextManager):
75
76
  """Braintrust-only context manager using contextvars when OTEL is not available."""
76
77
 
77
78
  def __init__(self):
78
- self._current_span: ContextVar[Optional[Any]] = ContextVar('braintrust_current_span', default=None)
79
+ self._current_span: ContextVar[Any | None] = ContextVar("braintrust_current_span", default=None)
79
80
 
80
- def get_current_span_info(self) -> Optional[SpanInfo]:
81
+ def get_current_span_info(self) -> SpanInfo | None:
81
82
  """Get information about the currently active span."""
82
83
  current_span = self._current_span.get()
83
84
  if not current_span:
84
85
  return None
85
86
 
86
87
  # Return SpanInfo for BT spans
87
- return SpanInfo(
88
- trace_id=current_span.root_span_id,
89
- span_id=current_span.span_id,
90
- span_object=current_span
91
- )
88
+ return SpanInfo(trace_id=current_span.root_span_id, span_id=current_span.span_id, span_object=current_span)
92
89
 
93
- def get_parent_span_ids(self) -> Optional[ParentSpanIds]:
90
+ def get_parent_span_ids(self) -> ParentSpanIds | None:
94
91
  """Get parent information for creating a new Braintrust span."""
95
92
  current_span = self._current_span.get()
96
93
  if not current_span:
97
94
  return None
98
95
 
99
96
  # If current span is a BT span, use it as parent
100
- return ParentSpanIds(
101
- root_span_id=current_span.root_span_id,
102
- span_parents=[current_span.span_id]
103
- )
97
+ return ParentSpanIds(root_span_id=current_span.root_span_id, span_parents=[current_span.span_id])
104
98
 
105
99
  def set_current_span(self, span_object: Any) -> Any:
106
100
  """Set the current active span."""
@@ -123,9 +117,10 @@ def get_context_manager() -> ContextManager:
123
117
  """
124
118
 
125
119
  # Check if OTEL should be explicitly enabled via environment variable
126
- if os.environ.get('BRAINTRUST_OTEL_COMPAT', '').lower() in ('1', 'true', 'yes'):
120
+ if os.environ.get("BRAINTRUST_OTEL_COMPAT", "").lower() in ("1", "true", "yes"):
127
121
  try:
128
122
  from braintrust.otel.context import ContextManager as OtelContextManager
123
+
129
124
  return OtelContextManager()
130
125
  except ImportError:
131
126
  logging.warning("OTEL not available, falling back to Braintrust-only version")
braintrust/contrib/temporal/__init__.py CHANGED
@@ -80,7 +80,8 @@ The integration will automatically:
80
80
  """
81
81
 
82
82
  import dataclasses
83
- from typing import Any, Dict, Mapping, Optional, Type
83
+ from collections.abc import Mapping
84
+ from typing import Any
84
85
 
85
86
  import braintrust
86
87
  import temporalio.activity
@@ -129,7 +130,7 @@ class BraintrustInterceptor(temporalio.client.Interceptor, temporalio.worker.Int
129
130
  - Ensures replay safety (no duplicate spans during workflow replay)
130
131
  """
131
132
 
132
- def __init__(self, logger: Optional[Any] = None) -> None:
133
+ def __init__(self, logger: Any | None = None) -> None:
133
134
  """Initialize interceptor.
134
135
 
135
136
  Args:
@@ -142,7 +143,7 @@ class BraintrustInterceptor(temporalio.client.Interceptor, temporalio.worker.Int
142
143
  braintrust.logger._state._override_bg_logger.logger = logger
143
144
  self._logger = braintrust.current_logger()
144
145
 
145
- def _get_logger(self) -> Optional[Any]:
146
+ def _get_logger(self) -> Any | None:
146
147
  """Get logger for creating spans.
147
148
 
148
149
  Sets thread-local override if background logger provided (for testing),
@@ -152,9 +153,7 @@ class BraintrustInterceptor(temporalio.client.Interceptor, temporalio.worker.Int
152
153
  braintrust.logger._state._override_bg_logger.logger = self._bg_logger
153
154
  return self._logger
154
155
 
155
- def intercept_client(
156
- self, next: temporalio.client.OutboundInterceptor
157
- ) -> temporalio.client.OutboundInterceptor:
156
+ def intercept_client(self, next: temporalio.client.OutboundInterceptor) -> temporalio.client.OutboundInterceptor:
158
157
  """Intercept client calls to propagate span context to workflows."""
159
158
  return _BraintrustClientOutboundInterceptor(next, self)
160
159
 
@@ -166,14 +165,14 @@ class BraintrustInterceptor(temporalio.client.Interceptor, temporalio.worker.Int
166
165
 
167
166
  def workflow_interceptor_class(
168
167
  self, input: temporalio.worker.WorkflowInterceptorClassInput
169
- ) -> Optional[Type["BraintrustWorkflowInboundInterceptor"]]:
168
+ ) -> type["BraintrustWorkflowInboundInterceptor"] | None:
170
169
  """Return workflow interceptor class to propagate context to activities."""
171
170
  input.unsafe_extern_functions["__braintrust_get_logger"] = self._get_logger
172
171
  return BraintrustWorkflowInboundInterceptor
173
172
 
174
173
  def _span_context_to_headers(
175
174
  self,
176
- span_context: Dict[str, Any],
175
+ span_context: dict[str, Any],
177
176
  headers: Mapping[str, temporalio.api.common.v1.Payload],
178
177
  ) -> Mapping[str, temporalio.api.common.v1.Payload]:
179
178
  """Add span context to headers."""
@@ -188,7 +187,7 @@ class BraintrustInterceptor(temporalio.client.Interceptor, temporalio.worker.Int
188
187
 
189
188
  def _span_context_from_headers(
190
189
  self, headers: Mapping[str, temporalio.api.common.v1.Payload]
191
- ) -> Optional[Dict[str, Any]]:
190
+ ) -> dict[str, Any] | None:
192
191
  """Extract span context from headers."""
193
192
  if _HEADER_KEY not in headers:
194
193
  return None
@@ -204,9 +203,7 @@ class BraintrustInterceptor(temporalio.client.Interceptor, temporalio.worker.Int
204
203
  class _BraintrustClientOutboundInterceptor(temporalio.client.OutboundInterceptor):
205
204
  """Client interceptor that propagates span context to workflows."""
206
205
 
207
- def __init__(
208
- self, next: temporalio.client.OutboundInterceptor, root: BraintrustInterceptor
209
- ) -> None:
206
+ def __init__(self, next: temporalio.client.OutboundInterceptor, root: BraintrustInterceptor) -> None:
210
207
  super().__init__(next)
211
208
  self.root = root
212
209
 
@@ -233,9 +230,7 @@ class _BraintrustActivityInboundInterceptor(temporalio.worker.ActivityInboundInt
233
230
  super().__init__(next)
234
231
  self.root = root
235
232
 
236
- async def execute_activity(
237
- self, input: temporalio.worker.ExecuteActivityInput
238
- ) -> Any:
233
+ async def execute_activity(self, input: temporalio.worker.ExecuteActivityInput) -> Any:
239
234
  info = temporalio.activity.info()
240
235
 
241
236
  # Extract parent span context from headers
@@ -281,14 +276,12 @@ class BraintrustWorkflowInboundInterceptor(temporalio.worker.WorkflowInboundInte
281
276
  def __init__(self, next: temporalio.worker.WorkflowInboundInterceptor) -> None:
282
277
  super().__init__(next)
283
278
  self.payload_converter = temporalio.converter.PayloadConverter.default
284
- self._parent_span_context: Optional[Dict[str, Any]] = None
279
+ self._parent_span_context: dict[str, Any] | None = None
285
280
 
286
281
  def init(self, outbound: temporalio.worker.WorkflowOutboundInterceptor) -> None:
287
282
  super().init(_BraintrustWorkflowOutboundInterceptor(outbound, self))
288
283
 
289
- async def execute_workflow(
290
- self, input: temporalio.worker.ExecuteWorkflowInput
291
- ) -> Any:
284
+ async def execute_workflow(self, input: temporalio.worker.ExecuteWorkflowInput) -> Any:
292
285
  # Extract parent span context from workflow headers (set by client)
293
286
  parent_span_context = None
294
287
  if _HEADER_KEY in input.headers:
@@ -342,9 +335,7 @@ class BraintrustWorkflowInboundInterceptor(temporalio.worker.WorkflowInboundInte
342
335
  span.end()
343
336
 
344
337
 
345
- class _BraintrustWorkflowOutboundInterceptor(
346
- temporalio.worker.WorkflowOutboundInterceptor
347
- ):
338
+ class _BraintrustWorkflowOutboundInterceptor(temporalio.worker.WorkflowOutboundInterceptor):
348
339
  """Outbound workflow interceptor that propagates span context to activities."""
349
340
 
350
341
  def __init__(
@@ -371,9 +362,7 @@ class _BraintrustWorkflowOutboundInterceptor(
371
362
  return {**headers, _HEADER_KEY: payloads[0]}
372
363
  return headers
373
364
 
374
- def start_activity(
375
- self, input: temporalio.worker.StartActivityInput
376
- ) -> temporalio.workflow.ActivityHandle:
365
+ def start_activity(self, input: temporalio.worker.StartActivityInput) -> temporalio.workflow.ActivityHandle:
377
366
  input.headers = self._add_span_context_to_headers(input.headers)
378
367
  return super().start_activity(input)
379
368
 
@@ -390,7 +379,7 @@ class _BraintrustWorkflowOutboundInterceptor(
390
379
  return super().start_child_workflow(input)
391
380
 
392
381
 
393
- def _modify_workflow_runner(existing: Optional[WorkflowRunner]) -> Optional[WorkflowRunner]:
382
+ def _modify_workflow_runner(existing: WorkflowRunner | None) -> WorkflowRunner | None:
394
383
  """Add braintrust to sandbox passthrough modules."""
395
384
  if isinstance(existing, SandboxedWorkflowRunner):
396
385
  new_restrictions = existing.restrictions.with_passthrough_modules("braintrust")
@@ -420,7 +409,7 @@ class BraintrustPlugin(SimplePlugin):
420
409
  Requires temporalio >= 1.19.0.
421
410
  """
422
411
 
423
- def __init__(self, logger: Optional[Any] = None) -> None:
412
+ def __init__(self, logger: Any | None = None) -> None:
424
413
  """Initialize the Braintrust plugin.
425
414
 
426
415
  Args:
braintrust/contrib/temporal/test_temporal.py CHANGED
@@ -279,11 +279,15 @@ class TestBraintrustPluginIntegration:
279
279
 
280
280
  # Verify workflow span was created
281
281
  workflow_spans = [s for s in spans if "temporal.workflow" in s.get("span_attributes", {}).get("name", "")]
282
- assert len(workflow_spans) > 0, f"Expected workflow span to be created. Span names: {[s.get('span_attributes', {}).get('name', 'unknown') for s in spans]}"
282
+ assert len(workflow_spans) > 0, (
283
+ f"Expected workflow span to be created. Span names: {[s.get('span_attributes', {}).get('name', 'unknown') for s in spans]}"
284
+ )
283
285
 
284
286
  # Verify activity span was created
285
287
  activity_spans = [s for s in spans if "temporal.activity" in s.get("span_attributes", {}).get("name", "")]
286
- assert len(activity_spans) > 0, f"Expected activity span to be created. Span names: {[s.get('span_attributes', {}).get('name', 'unknown') for s in spans]}"
288
+ assert len(activity_spans) > 0, (
289
+ f"Expected activity span to be created. Span names: {[s.get('span_attributes', {}).get('name', 'unknown') for s in spans]}"
290
+ )
287
291
 
288
292
  @pytest.mark.asyncio
289
293
  async def test_plugin_context_propagation(self, temporal_env, memory_logger):
@@ -498,5 +502,6 @@ class TestBraintrustPluginIntegration:
498
502
  # Verify parent-child relationship: workflow should have client span as parent
499
503
  workflow_span = workflow_spans[0]
500
504
  client_span = client_spans[0]
501
- assert workflow_span.get("root_span_id") == client_span.get("root_span_id"), \
505
+ assert workflow_span.get("root_span_id") == client_span.get("root_span_id"), (
502
506
  "Workflow span should be in same trace as client span"
507
+ )
braintrust/devserver/auth.py CHANGED
@@ -1,5 +1,5 @@
1
+ from collections.abc import Awaitable, Callable
1
2
  from dataclasses import dataclass
2
- from typing import Awaitable, Callable, Dict, Optional
3
3
 
4
4
  from starlette.middleware.base import BaseHTTPMiddleware
5
5
  from starlette.requests import Request
@@ -15,14 +15,14 @@ BRAINTRUST_PROJECT_ID_HEADER = "x-bt-project-id"
15
15
 
16
16
  @dataclass
17
17
  class RequestContext:
18
- app_origin: Optional[str]
19
- token: Optional[str]
20
- org_name: Optional[str]
21
- project_id: Optional[str]
22
- state: Optional[BraintrustState]
18
+ app_origin: str | None
19
+ token: str | None
20
+ org_name: str | None
21
+ project_id: str | None
22
+ state: BraintrustState | None
23
23
 
24
24
 
25
- def extract_allowed_origin(origin: Optional[str]) -> Optional[str]:
25
+ def extract_allowed_origin(origin: str | None) -> str | None:
26
26
  """Extract and validate the origin header."""
27
27
  # This should use the same check_origin logic from cors.py
28
28
  from .cors import check_origin
@@ -32,7 +32,7 @@ def extract_allowed_origin(origin: Optional[str]) -> Optional[str]:
32
32
  return None
33
33
 
34
34
 
35
- def parse_braintrust_auth_header(headers: Dict[str, str]) -> Optional[str]:
35
+ def parse_braintrust_auth_header(headers: dict[str, str]) -> str | None:
36
36
  """Parse the authorization token from headers."""
37
37
  # Check x-bt-auth-token first
38
38
  token = headers.get(BRAINTRUST_AUTH_TOKEN_HEADER)
braintrust/devserver/cache.py CHANGED
@@ -1,7 +1,6 @@
1
1
  """LRU cache implementation for the dev server."""
2
2
 
3
3
  import json
4
- from typing import Dict, Optional
5
4
 
6
5
  from ..logger import BraintrustState, login_to_state
7
6
 
@@ -11,10 +10,10 @@ class LRUCache:
11
10
 
12
11
  def __init__(self, max_size: int = 32):
13
12
  self.max_size = max_size
14
- self.cache: Dict[str, BraintrustState] = {}
13
+ self.cache: dict[str, BraintrustState] = {}
15
14
  self.access_order: list[str] = []
16
15
 
17
- def get(self, key: str) -> Optional[BraintrustState]:
16
+ def get(self, key: str) -> BraintrustState | None:
18
17
  """Get a value from the cache, updating access order."""
19
18
  if key in self.cache:
20
19
  # Move to end to mark as recently used
@@ -41,7 +40,7 @@ class LRUCache:
41
40
  _login_cache = LRUCache(max_size=32) # TODO: Make this configurable
42
41
 
43
42
 
44
- async def cached_login(api_key: str, app_url: str, org_name: Optional[str] = None) -> BraintrustState:
43
+ async def cached_login(api_key: str, app_url: str, org_name: str | None = None) -> BraintrustState:
45
44
  """Login with caching to avoid repeated API calls."""
46
45
  cache_key = json.dumps({"api_key": api_key, "app_url": app_url, "org_name": org_name})
47
46
 
braintrust/devserver/cors.py CHANGED
@@ -1,9 +1,10 @@
1
1
  import os
2
2
  import re
3
- from typing import Any, Awaitable, Callable, Dict, List, Union
3
+ from collections.abc import Awaitable, Callable
4
+ from typing import Any
4
5
 
5
6
  # CORS configuration
6
- ALLOWED_ORIGINS: List[Union[str, re.Pattern]] = [
7
+ ALLOWED_ORIGINS: list[str | re.Pattern] = [
7
8
  "https://www.braintrust.dev",
8
9
  "https://www.braintrustdata.com",
9
10
  re.compile(r"https://.*\.preview\.braintrust\.dev"),
@@ -70,9 +71,9 @@ def create_cors_middleware() -> type:
70
71
 
71
72
  async def __call__(
72
73
  self,
73
- scope: Dict[str, Any],
74
- receive: Callable[[], Awaitable[Dict[str, Any]]],
75
- send: Callable[[Dict[str, Any]], Awaitable[None]],
74
+ scope: dict[str, Any],
75
+ receive: Callable[[], Awaitable[dict[str, Any]]],
76
+ send: Callable[[dict[str, Any]], Awaitable[None]],
76
77
  ) -> None:
77
78
  if scope["type"] == "http":
78
79
  headers = dict(scope["headers"])
@@ -81,7 +82,7 @@ def create_cors_middleware() -> type:
81
82
  # Handle OPTIONS requests
82
83
  if scope["method"] == "OPTIONS":
83
84
 
84
- async def send_options_wrapper(message: Dict[str, Any]) -> None:
85
+ async def send_options_wrapper(message: dict[str, Any]) -> None:
85
86
  if message["type"] == "http.response.start":
86
87
  headers_dict = dict(message.get("headers", []))
87
88
 
@@ -120,7 +121,7 @@ def create_cors_middleware() -> type:
120
121
  return
121
122
 
122
123
  # For other requests, add CORS headers if origin is valid
123
- async def send_wrapper(message: Dict[str, Any]) -> None:
124
+ async def send_wrapper(message: dict[str, Any]) -> None:
124
125
  if message["type"] == "http.response.start" and origin and check_origin(origin):
125
126
  headers_dict = dict(message.get("headers", []))
126
127
 
braintrust/devserver/dataset.py CHANGED
@@ -1,11 +1,11 @@
1
- from typing import Any, Dict, Union
1
+ from typing import Any
2
2
 
3
3
  from braintrust import init_dataset
4
4
  from braintrust._generated_types import RunEvalData, RunEvalData1, RunEvalData2
5
5
  from braintrust.logger import BraintrustState
6
6
 
7
7
 
8
- async def get_dataset_by_id(state: BraintrustState, dataset_id: str) -> Dict[str, str]:
8
+ async def get_dataset_by_id(state: BraintrustState, dataset_id: str) -> dict[str, str]:
9
9
  """Fetch dataset information by ID."""
10
10
  # Make API call to get dataset info
11
11
  conn = state.api_conn()
@@ -23,9 +23,7 @@ async def get_dataset_by_id(state: BraintrustState, dataset_id: str) -> Dict[str
23
23
 
24
24
 
25
25
  # NOTE: To make this performant, we'll have to make these functions work with async i/o
26
- async def get_dataset(
27
- state: BraintrustState, data: Union[RunEvalData, RunEvalData1, RunEvalData2, Dict[str, Any]]
28
- ) -> Any:
26
+ async def get_dataset(state: BraintrustState, data: RunEvalData | RunEvalData1 | RunEvalData2 | dict[str, Any]) -> Any:
29
27
  """
30
28
  Get dataset from various data sources.
31
29
 
braintrust/devserver/eval_hooks.py CHANGED
@@ -7,7 +7,8 @@ for reporting progress during evaluation execution.
7
7
 
8
8
  import asyncio
9
9
  import json
10
- from typing import Any, Callable, Dict, Optional
10
+ from collections.abc import Callable
11
+ from typing import Any
11
12
 
12
13
 
13
14
  class EvalHooks:
@@ -15,13 +16,13 @@ class EvalHooks:
15
16
 
16
17
  def __init__(
17
18
  self,
18
- report_progress: Optional[Callable[[Dict[str, Any]], None]] = None,
19
- parameters: Optional[Dict[str, Any]] = None,
19
+ report_progress: Callable[[dict[str, Any]], None] | None = None,
20
+ parameters: dict[str, Any] | None = None,
20
21
  ):
21
22
  self._report_progress = report_progress
22
23
  self.parameters = parameters or {}
23
24
 
24
- def report_progress(self, event: Dict[str, Any]) -> None:
25
+ def report_progress(self, event: dict[str, Any]) -> None:
25
26
  """Report progress during task execution."""
26
27
  if self._report_progress:
27
28
  self._report_progress(event)
@@ -45,7 +46,7 @@ class SSEQueue:
45
46
  """Simple wrapper around asyncio.Queue for SSE events."""
46
47
 
47
48
  def __init__(self):
48
- self.queue: asyncio.Queue[Optional[str]] = asyncio.Queue()
49
+ self.queue: asyncio.Queue[str | None] = asyncio.Queue()
49
50
 
50
51
  async def put_event(self, event: str, data: Any) -> None:
51
52
  """Add an SSE event to the queue."""
@@ -56,6 +57,6 @@ class SSEQueue:
56
57
  """Signal end of stream."""
57
58
  await self.queue.put(None)
58
59
 
59
- async def get(self) -> Optional[str]:
60
+ async def get(self) -> str | None:
60
61
  """Get the next event from the queue."""
61
62
  return await self.queue.get()