braintrust 0.3.13__py3-none-any.whl → 0.3.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
braintrust/conftest.py CHANGED
@@ -36,6 +36,7 @@ def override_app_url_for_tests():
36
36
@pytest.fixture(autouse=True)
def setup_braintrust():
    """Autouse fixture that fills in placeholder provider API keys when unset.

    Variables already present in the environment are never overwritten.
    GOOGLE_API_KEY falls back to the value of GEMINI_API_KEY, and then to a
    placeholder. OPENAI_API_KEY gets a dummy key; its value names it as meant
    for VCR-recorded tests.
    """
    google_fallback = os.getenv("GEMINI_API_KEY", "your_google_api_key_here")
    for var_name, fallback in (
        ("GOOGLE_API_KEY", google_fallback),
        ("OPENAI_API_KEY", "sk-test-dummy-api-key-for-vcr-tests"),
    ):
        os.environ.setdefault(var_name, fallback)
39
40
 
40
41
 
41
42
  @pytest.fixture(autouse=True)
@@ -1,14 +1,31 @@
1
- from typing import Any, Dict, List, Literal, Optional, TypeVar, Union, overload
1
+ from typing import Any, Dict, List, Literal, Optional, TypedDict, TypeVar, Union, overload
2
2
 
3
3
  from sseclient import SSEClient
4
4
 
5
+ from .._generated_types import InvokeContext
5
6
  from ..logger import Exportable, get_span_parent_object, login, proxy_conn
6
7
  from ..util import response_raise_for_status
7
8
  from .constants import INVOKE_API_VERSION
8
9
  from .stream import BraintrustInvokeError, BraintrustStream
9
10
 
10
11
# Generic type variable.
T = TypeVar("T")

# Presumably the accepted values for `invoke(mode=...)`; when given, the mode is
# forwarded verbatim as `request["mode"]` in the invoke request body.
ModeType = Literal["auto", "parallel", "json", "text"]

# Kinds of Braintrust objects an invocation context can point at.
# NOTE(review): presumably the type of the `object_type` field carried by the
# `context` argument of `invoke` — confirm against `InvokeContext`.
ObjectType = Literal["project_logs", "experiment", "dataset", "playground_logs"]

# NOTE(review): `_generated_types` also exports `SpanScope` and `TraceScope`;
# these hand-written copies may duplicate them and could drift — consider
# importing the generated definitions instead once their shapes are confirmed.


class SpanScope(TypedDict):
    """Scope for operating on a single span.

    The span is identified by its `id` together with `root_span_id`; `type`
    is always the literal string "span".
    """

    type: Literal["span"]
    id: str
    root_span_id: str


class TraceScope(TypedDict):
    """Scope for operating on an entire trace.

    The trace is identified by `root_span_id`; `type` is always the literal
    string "trace".
    """

    type: Literal["trace"]
    root_span_id: str
12
29
 
13
30
 
14
31
  @overload
@@ -19,11 +36,13 @@ def invoke(
19
36
  prompt_session_id: Optional[str] = None,
20
37
  prompt_session_function_id: Optional[str] = None,
21
38
  project_name: Optional[str] = None,
39
+ project_id: Optional[str] = None,
22
40
  slug: Optional[str] = None,
23
41
  global_function: Optional[str] = None,
24
42
  # arguments to the function
25
43
  input: Any = None,
26
44
  messages: Optional[List[Any]] = None,
45
+ context: Optional[InvokeContext] = None,
27
46
  metadata: Optional[Dict[str, Any]] = None,
28
47
  tags: Optional[List[str]] = None,
29
48
  parent: Optional[Union[Exportable, str]] = None,
@@ -45,11 +64,13 @@ def invoke(
45
64
  prompt_session_id: Optional[str] = None,
46
65
  prompt_session_function_id: Optional[str] = None,
47
66
  project_name: Optional[str] = None,
67
+ project_id: Optional[str] = None,
48
68
  slug: Optional[str] = None,
49
69
  global_function: Optional[str] = None,
50
70
  # arguments to the function
51
71
  input: Any = None,
52
72
  messages: Optional[List[Any]] = None,
73
+ context: Optional[InvokeContext] = None,
53
74
  metadata: Optional[Dict[str, Any]] = None,
54
75
  tags: Optional[List[str]] = None,
55
76
  parent: Optional[Union[Exportable, str]] = None,
@@ -70,11 +91,13 @@ def invoke(
70
91
  prompt_session_id: Optional[str] = None,
71
92
  prompt_session_function_id: Optional[str] = None,
72
93
  project_name: Optional[str] = None,
94
+ project_id: Optional[str] = None,
73
95
  slug: Optional[str] = None,
74
96
  global_function: Optional[str] = None,
75
97
  # arguments to the function
76
98
  input: Any = None,
77
99
  messages: Optional[List[Any]] = None,
100
+ context: Optional[InvokeContext] = None,
78
101
  metadata: Optional[Dict[str, Any]] = None,
79
102
  tags: Optional[List[str]] = None,
80
103
  parent: Optional[Union[Exportable, str]] = None,
@@ -93,6 +116,8 @@ def invoke(
93
116
  Args:
94
117
  input: The input to the function. This will be logged as the `input` field in the span.
95
118
  messages: Additional OpenAI-style messages to add to the prompt (only works for llm functions).
119
+ context: Context for functions that operate on spans/traces (e.g., facets). Should contain
120
+ `object_type`, `object_id`, and `scope` fields.
96
121
  metadata: Additional metadata to add to the span. This will be logged as the `metadata` field in the span.
97
122
  It will also be available as the {{metadata}} field in the prompt and as the `metadata` argument
98
123
  to the function.
@@ -118,6 +143,8 @@ def invoke(
118
143
  prompt_session_id: The ID of the prompt session to invoke the function from.
119
144
  prompt_session_function_id: The ID of the function in the prompt session to invoke.
120
145
  project_name: The name of the project containing the function to invoke.
146
+ project_id: The ID of the project to use for execution context (API keys, project defaults, etc.).
147
+ This is not the project the function belongs to, but the project context for the invocation.
121
148
  slug: The slug of the function to invoke.
122
149
  global_function: The name of the global function to invoke.
123
150
 
@@ -161,12 +188,18 @@ def invoke(
161
188
  )
162
189
  if messages is not None:
163
190
  request["messages"] = messages
191
+ if context is not None:
192
+ request["context"] = context
164
193
  if mode is not None:
165
194
  request["mode"] = mode
166
195
  if strict is not None:
167
196
  request["strict"] = strict
168
197
 
169
198
  headers = {"Accept": "text/event-stream" if stream else "application/json"}
199
+ if project_id is not None:
200
+ headers["x-bt-project-id"] = project_id
201
+ if org_name is not None:
202
+ headers["x-bt-org-name"] = org_name
170
203
 
171
204
  resp = proxy_conn().post("function/invoke", json=request, headers=headers, stream=stream)
172
205
  if resp.status_code == 500:
@@ -1,4 +1,4 @@
1
- """Auto-generated file (internal git SHA 8e9c0a96b3cf291360978c17580f72f6817bd6c8) -- do not modify"""
1
+ """Auto-generated file (internal git SHA 437eb5379a737f70dec98033fccf81de43e8e177) -- do not modify"""
2
2
 
3
3
  from ._generated_types import (
4
4
  Acl,
@@ -32,6 +32,7 @@ from ._generated_types import (
32
32
  ExperimentEvent,
33
33
  ExtendedSavedFunctionId,
34
34
  ExternalAttachmentReference,
35
+ FacetData,
35
36
  Function,
36
37
  FunctionData,
37
38
  FunctionFormat,
@@ -47,10 +48,14 @@ from ._generated_types import (
47
48
  GraphNode,
48
49
  Group,
49
50
  IfExists,
51
+ InvokeContext,
50
52
  InvokeFunction,
51
53
  InvokeParent,
54
+ InvokeScope,
55
+ MCPServer,
52
56
  MessageRole,
53
57
  ModelParams,
58
+ NullableSavedFunctionId,
54
59
  ObjectReference,
55
60
  ObjectReferenceNullish,
56
61
  OnlineScoreConfig,
@@ -86,11 +91,13 @@ from ._generated_types import (
86
91
  ServiceToken,
87
92
  SpanAttributes,
88
93
  SpanIFrame,
94
+ SpanScope,
89
95
  SpanType,
90
96
  SSEConsoleEventData,
91
97
  SSEProgressEventData,
92
98
  StreamingMode,
93
99
  ToolFunctionDefinition,
100
+ TraceScope,
94
101
  UploadStatus,
95
102
  User,
96
103
  View,
@@ -131,6 +138,7 @@ __all__ = [
131
138
  "ExperimentEvent",
132
139
  "ExtendedSavedFunctionId",
133
140
  "ExternalAttachmentReference",
141
+ "FacetData",
134
142
  "Function",
135
143
  "FunctionData",
136
144
  "FunctionFormat",
@@ -146,10 +154,14 @@ __all__ = [
146
154
  "GraphNode",
147
155
  "Group",
148
156
  "IfExists",
157
+ "InvokeContext",
149
158
  "InvokeFunction",
150
159
  "InvokeParent",
160
+ "InvokeScope",
161
+ "MCPServer",
151
162
  "MessageRole",
152
163
  "ModelParams",
164
+ "NullableSavedFunctionId",
153
165
  "ObjectReference",
154
166
  "ObjectReferenceNullish",
155
167
  "OnlineScoreConfig",
@@ -187,9 +199,11 @@ __all__ = [
187
199
  "ServiceToken",
188
200
  "SpanAttributes",
189
201
  "SpanIFrame",
202
+ "SpanScope",
190
203
  "SpanType",
191
204
  "StreamingMode",
192
205
  "ToolFunctionDefinition",
206
+ "TraceScope",
193
207
  "UploadStatus",
194
208
  "User",
195
209
  "View",
braintrust/gitutil.py CHANGED
@@ -88,8 +88,12 @@ def get_past_n_ancestors(n=1000, remote=None):
88
88
  if ancestor_output is None:
89
89
  return
90
90
  ancestor = repo.commit(ancestor_output)
91
+ count = 0
91
92
  for _ in range(n):
93
+ if count >= n:
94
+ break
92
95
  yield ancestor.hexsha
96
+ count += 1
93
97
  try:
94
98
  if ancestor.parents:
95
99
  ancestor = ancestor.parents[0]
braintrust/logger.py CHANGED
@@ -1104,7 +1104,7 @@ class _HTTPBackgroundLogger:
1104
1104
  _HTTPBackgroundLogger._write_payload_to_dir(payload_dir=self.all_publish_payloads_dir, payload=dataStr)
1105
1105
  for i in range(self.num_tries):
1106
1106
  start_time = time.time()
1107
- resp = conn.post("/logs3", data=dataStr)
1107
+ resp = conn.post("/logs3", data=dataStr.encode("utf-8"))
1108
1108
  if resp.ok:
1109
1109
  return
1110
1110
  resp_errmsg = f"{resp.status_code}: {resp.text}"
braintrust/oai.py CHANGED
@@ -1,8 +1,10 @@
1
1
  import abc
2
+ import base64
3
+ import re
2
4
  import time
3
- from typing import Any, Callable, Dict, List, Optional
5
+ from typing import Any, Callable, Dict, List, Optional, Union
4
6
 
5
- from .logger import Span, start_span
7
+ from .logger import Attachment, Span, start_span
6
8
  from .span_types import SpanTypeAttribute
7
9
  from .util import merge_dicts
8
10
 
@@ -68,6 +70,75 @@ def log_headers(response: Any, span: Span):
68
70
  )
69
71
 
70
72
 
73
# Matches base64 data URLs (RFC 2397): ``data:<mime>[;param=value...];base64,<payload>``.
# - The MIME type may not contain "," — a comma there means the URL is not base64
#   and the ";base64," text is actually part of its data.
# - Media-type parameters such as ";charset=utf-8" or ";name=a.pdf" are skipped.
# - DOTALL lets the payload span line breaks produced by line-wrapping base64 encoders.
# Each parameter starts with ";" and cannot contain ";" or ",", so matching is linear.
_BASE64_DATA_URL_RE = re.compile(r"^data:([^;,]+)(?:;[^;,]*)*;base64,(.+)$", re.DOTALL)


def _convert_data_url_to_attachment(data_url: str, filename: Optional[str] = None) -> Union[Attachment, str]:
    """Convert a base64 data URL into an ``Attachment``.

    Args:
        data_url: Candidate data URL, e.g. ``data:image/png;base64,iVBOR...``.
            Media-type parameters before ``;base64`` are tolerated and ignored.
        filename: Name to give the attachment. When ``None``, a name is derived
            from the MIME type:
            - ``image/<subtype>`` -> ``image.<subtype>``
            - any other ``<type>/<subtype>`` -> ``document.<subtype>``
            - a MIME type without ``/`` -> ``document.bin``

    Returns:
        An ``Attachment`` holding the decoded bytes, with the URL's MIME type as
        its content type. Returns ``data_url`` unchanged when it is not a base64
        data URL or when conversion fails.
    """
    data_url_match = _BASE64_DATA_URL_RE.match(data_url)
    if not data_url_match:
        return data_url

    mime_type, base64_data = data_url_match.groups()

    # Deliberately best-effort. This runs while building the logged span input
    # of a wrapped OpenAI call. Any failure (malformed payload, Attachment
    # construction error) falls back to the raw URL instead of raising out of
    # the logging path.
    try:
        # Non-alphabet characters (e.g. embedded newlines) are discarded by b64decode.
        binary_data = base64.b64decode(base64_data)

        if filename is None:
            extension = mime_type.split("/")[1] if "/" in mime_type else "bin"
            prefix = "image" if mime_type.startswith("image/") else "document"
            filename = f"{prefix}.{extension}"

        return Attachment(data=binary_data, filename=filename, content_type=mime_type)
    except Exception:
        return data_url
94
+
95
+
96
def _process_attachments_in_input(input_data: Any) -> Any:
    """Replace base64 data URLs in OpenAI request input with ``Attachment`` objects.

    Lists and dicts are walked recursively and rebuilt, so the caller's objects
    are never mutated. In the following content parts, the data URL is passed
    to ``_convert_data_url_to_attachment``:

    * Chat Completions image: ``{"type": "image_url", "image_url": {"url": ...}}``
    * Chat Completions file:  ``{"type": "file", "file": {"file_data": ..., "filename": ...}}``
    * Responses API image:    ``{"type": "input_image", "image_url": ...}``
    * Responses API file:     ``{"type": "input_file", "file_data": ..., "filename": ...}``

    Strings that are not base64 data URLs come back unchanged from the
    converter, so ordinary http(s) URLs are left as-is.

    Args:
        input_data: Chat Completions ``messages`` or Responses API ``input``.
            This may be any JSON-like value.

    Returns:
        A structure of the same shape with recognised data URLs converted.
        Values that are neither lists nor dicts are returned as-is.
    """
    if isinstance(input_data, list):
        return [_process_attachments_in_input(item) for item in input_data]

    if not isinstance(input_data, dict):
        return input_data

    part_type = input_data.get("type")

    # Chat Completions: {"type": "image_url", "image_url": {"url": "data:..."}}
    if part_type == "image_url":
        image_url = input_data.get("image_url")
        if isinstance(image_url, dict) and isinstance(image_url.get("url"), str):
            return {
                **input_data,
                "image_url": {
                    **image_url,
                    "url": _convert_data_url_to_attachment(image_url["url"]),
                },
            }

    # Chat Completions: {"type": "file", "file": {"file_data": "data:...", "filename": ...}}
    if part_type == "file":
        file_part = input_data.get("file")
        if isinstance(file_part, dict) and isinstance(file_part.get("file_data"), str):
            file_filename = file_part.get("filename")
            return {
                **input_data,
                "file": {
                    **file_part,
                    "file_data": _convert_data_url_to_attachment(
                        file_part["file_data"],
                        filename=file_filename if isinstance(file_filename, str) else None,
                    ),
                },
            }

    # Responses API: {"type": "input_image", "image_url": "data:..."} -- here
    # image_url is a plain string rather than a nested object.
    if part_type == "input_image" and isinstance(input_data.get("image_url"), str):
        return {
            **input_data,
            "image_url": _convert_data_url_to_attachment(input_data["image_url"]),
        }

    # Responses API: {"type": "input_file", "file_data": "data:...", "filename": ...}
    # -- here the fields sit directly on the part.
    if part_type == "input_file" and isinstance(input_data.get("file_data"), str):
        input_filename = input_data.get("filename")
        return {
            **input_data,
            "file_data": _convert_data_url_to_attachment(
                input_data["file_data"],
                filename=input_filename if isinstance(input_filename, str) else None,
            ),
        }

    # Not a recognised content part: recurse into nested values.
    return {key: _process_attachments_in_input(value) for key, value in input_data.items()}
140
+
141
+
71
142
  class ChatCompletionWrapper:
72
143
  def __init__(self, create_fn: Optional[Callable[..., Any]], acreate_fn: Optional[Callable[..., Any]]):
73
144
  self.create_fn = create_fn
@@ -190,10 +261,14 @@ class ChatCompletionWrapper:
190
261
  # Then, copy the rest of the params
191
262
  params = prettify_params(params)
192
263
  messages = params.pop("messages", None)
264
+
265
+ # Process attachments in input (convert data URLs to Attachment objects)
266
+ processed_input = _process_attachments_in_input(messages)
267
+
193
268
  return merge_dicts(
194
269
  ret,
195
270
  {
196
- "input": messages,
271
+ "input": processed_input,
197
272
  "metadata": {**params, "provider": "openai"},
198
273
  },
199
274
  )
@@ -379,10 +454,14 @@ class ResponseWrapper:
379
454
  # Then, copy the rest of the params
380
455
  params = prettify_params(params)
381
456
  input_data = params.pop("input", None)
457
+
458
+ # Process attachments in input (convert data URLs to Attachment objects)
459
+ processed_input = _process_attachments_in_input(input_data)
460
+
382
461
  return merge_dicts(
383
462
  ret,
384
463
  {
385
- "input": input_data,
464
+ "input": processed_input,
386
465
  "metadata": {**params, "provider": "openai"},
387
466
  },
388
467
  )
@@ -540,12 +619,15 @@ class BaseWrapper(abc.ABC):
540
619
  ret = params.pop("span_info", {})
541
620
 
542
621
  params = prettify_params(params)
543
- input = params.pop("input", None)
622
+ input_data = params.pop("input", None)
623
+
624
+ # Process attachments in input (convert data URLs to Attachment objects)
625
+ processed_input = _process_attachments_in_input(input_data)
544
626
 
545
627
  return merge_dicts(
546
628
  ret,
547
629
  {
548
- "input": input,
630
+ "input": processed_input,
549
631
  "metadata": {**params, "provider": "openai"},
550
632
  },
551
633
  )
braintrust/score.py CHANGED
@@ -34,6 +34,7 @@ class Score(SerializableDataClass):
34
34
 
35
35
  def as_dict(self):
36
36
  return {
37
+ "name": self.name,
37
38
  "score": self.score,
38
39
  "metadata": self.metadata,
39
40
  }
@@ -0,0 +1,157 @@
1
+ import json
2
+ import unittest
3
+
4
+ from .score import Score
5
+
6
+
7
class TestScore(unittest.TestCase):
    """Tests for Score serialization and validation.

    Covers ``as_dict``, ``as_json`` and ``from_dict`` round-trips, plus the
    bounds check on ``score``.
    """

    def _check_serialized(self, result, *, name, score, metadata):
        """Assert that a serialized score carries the expected name, score and metadata."""
        self.assertEqual(result["name"], name)
        if score is None:
            self.assertIsNone(result["score"])
        else:
            self.assertEqual(result["score"], score)
        self.assertEqual(result["metadata"], metadata)

    def test_as_dict_includes_all_required_fields(self):
        """as_dict() exposes the name, score and metadata fields."""
        result = Score(name="test_scorer", score=0.85, metadata={"key": "value"}).as_dict()

        for field_name in ("name", "score", "metadata"):
            self.assertIn(field_name, result)
        self._check_serialized(result, name="test_scorer", score=0.85, metadata={"key": "value"})

    def test_as_dict_with_null_score(self):
        """A None score is serialized as None."""
        result = Score(name="null_scorer", score=None, metadata={}).as_dict()
        self._check_serialized(result, name="null_scorer", score=None, metadata={})

    def test_as_dict_with_empty_metadata(self):
        """Omitting metadata serializes it as an empty dict."""
        result = Score(name="empty_metadata_scorer", score=1.0).as_dict()
        self._check_serialized(result, name="empty_metadata_scorer", score=1.0, metadata={})

    def test_as_dict_with_complex_metadata(self):
        """Nested metadata structures pass through as_dict() intact."""
        complex_metadata = {
            "reason": "Test reason",
            "details": {"nested": {"deeply": "value"}},
            "list": [1, 2, 3],
            "bool": True,
        }
        result = Score(name="complex_scorer", score=0.5, metadata=complex_metadata).as_dict()
        self._check_serialized(result, name="complex_scorer", score=0.5, metadata=complex_metadata)

    def test_as_json_serialization(self):
        """as_json() emits JSON that decodes back to the same fields."""
        json_str = Score(name="json_scorer", score=0.75, metadata={"test": "data"}).as_json()
        parsed = json.loads(json_str)  # raises if as_json() did not produce valid JSON
        self._check_serialized(parsed, name="json_scorer", score=0.75, metadata={"test": "data"})

    def test_from_dict_round_trip(self):
        """from_dict(as_dict()) reproduces the original score."""
        original = Score(name="round_trip_scorer", score=0.95, metadata={"info": "test"})
        restored = Score.from_dict(original.as_dict())

        for attr in ("name", "score", "metadata"):
            self.assertEqual(getattr(restored, attr), getattr(original, attr))

    def test_array_of_scores_serialization(self):
        """A list of scores serializes element-wise with every field present."""
        scores = [
            Score(name="score_1", score=0.8, metadata={"index": 1}),
            Score(name="score_2", score=0.6, metadata={"index": 2}),
            Score(name="score_3", score=None, metadata={}),
        ]
        serialized = [s.as_dict() for s in scores]

        for position, entry in enumerate(serialized, start=1):
            for field_name in ("name", "score", "metadata"):
                self.assertIn(field_name, entry)
            self.assertEqual(entry["name"], f"score_{position}")

        self.assertEqual(serialized[0]["score"], 0.8)
        self.assertEqual(serialized[1]["score"], 0.6)
        self.assertIsNone(serialized[2]["score"])

    def test_array_of_scores_json_serialization(self):
        """A list of serialized scores round-trips through json.dumps/json.loads."""
        scores = [
            Score(name="json_score_1", score=0.9),
            Score(name="json_score_2", score=0.7),
        ]
        parsed = json.loads(json.dumps([s.as_dict() for s in scores]))

        self.assertEqual(len(parsed), 2)
        expected = [("json_score_1", 0.9), ("json_score_2", 0.7)]
        for entry, (expected_name, expected_score) in zip(parsed, expected):
            self.assertEqual(entry["name"], expected_name)
            self.assertEqual(entry["score"], expected_score)

    def test_score_validation_enforces_bounds(self):
        """Scores in [0, 1] (or None) are accepted; values outside raise ValueError."""
        accepted = (("valid_0", 0.0), ("valid_1", 1.0), ("valid_mid", 0.5), ("valid_null", None))
        for scorer_name, value in accepted:
            Score(name=scorer_name, score=value)

        rejected = (("invalid_negative", -0.1), ("invalid_over_one", 1.1))
        for scorer_name, value in rejected:
            with self.assertRaises(ValueError):
                Score(name=scorer_name, score=value)

    def test_score_does_not_include_deprecated_error_field(self):
        """The deprecated ``error`` attribute never appears in as_dict() output."""
        self.assertNotIn("error", Score(name="test_scorer", score=0.5).as_dict())

        # Even when assigned after construction, it must not leak into the serialized form.
        score_with_error = Score(name="error_scorer", score=0.5)
        score_with_error.error = Exception("test")
        self.assertNotIn("error", score_with_error.as_dict())


if __name__ == "__main__":
    unittest.main()
braintrust/version.py CHANGED
@@ -1,4 +1,4 @@
1
# Release version of this package (matches the published wheel, 0.3.15).
VERSION = "0.3.15"

# this will be templated during the build
# NOTE(review): presumably replaced with the commit SHA the build was made
# from; keep this assignment on a single line so the templating still matches.
GIT_COMMIT = "dcd4f5a4be171b1cac28a5eb3534e4b55420cc06"