braintrust 0.3.13__py3-none-any.whl → 0.3.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braintrust/__init__.py +4 -0
- braintrust/_generated_types.py +596 -72
- braintrust/conftest.py +1 -0
- braintrust/functions/invoke.py +35 -2
- braintrust/generated_types.py +15 -1
- braintrust/gitutil.py +4 -0
- braintrust/logger.py +1 -1
- braintrust/oai.py +88 -6
- braintrust/score.py +1 -0
- braintrust/test_score.py +157 -0
- braintrust/version.py +2 -2
- braintrust/wrappers/pydantic_ai.py +1203 -0
- braintrust/wrappers/test_oai_attachments.py +322 -0
- braintrust/wrappers/test_pydantic_ai_integration.py +1788 -0
- braintrust/wrappers/{test_pydantic_ai.py → test_pydantic_ai_wrap_openai.py} +1 -2
- {braintrust-0.3.13.dist-info → braintrust-0.3.15.dist-info}/METADATA +1 -1
- {braintrust-0.3.13.dist-info → braintrust-0.3.15.dist-info}/RECORD +20 -16
- {braintrust-0.3.13.dist-info → braintrust-0.3.15.dist-info}/WHEEL +0 -0
- {braintrust-0.3.13.dist-info → braintrust-0.3.15.dist-info}/entry_points.txt +0 -0
- {braintrust-0.3.13.dist-info → braintrust-0.3.15.dist-info}/top_level.txt +0 -0
braintrust/conftest.py
CHANGED
@@ -36,6 +36,7 @@ def override_app_url_for_tests():
 @pytest.fixture(autouse=True)
 def setup_braintrust():
     os.environ.setdefault("GOOGLE_API_KEY", os.getenv("GEMINI_API_KEY", "your_google_api_key_here"))
+    os.environ.setdefault("OPENAI_API_KEY", "sk-test-dummy-api-key-for-vcr-tests")
 
 
 @pytest.fixture(autouse=True)
braintrust/functions/invoke.py
CHANGED
@@ -1,14 +1,31 @@
-from typing import Any, Dict, List, Literal, Optional, TypeVar, Union, overload
+from typing import Any, Dict, List, Literal, Optional, TypedDict, TypeVar, Union, overload
 
 from sseclient import SSEClient
 
+from .._generated_types import InvokeContext
 from ..logger import Exportable, get_span_parent_object, login, proxy_conn
 from ..util import response_raise_for_status
 from .constants import INVOKE_API_VERSION
 from .stream import BraintrustInvokeError, BraintrustStream
 
 T = TypeVar("T")
-ModeType = Literal["auto", "parallel"]
+ModeType = Literal["auto", "parallel", "json", "text"]
+ObjectType = Literal["project_logs", "experiment", "dataset", "playground_logs"]
+
+
+class SpanScope(TypedDict):
+    """Scope for operating on a single span."""
+
+    type: Literal["span"]
+    id: str
+    root_span_id: str
+
+
+class TraceScope(TypedDict):
+    """Scope for operating on an entire trace."""
+
+    type: Literal["trace"]
+    root_span_id: str
 
 
 @overload
@@ -19,11 +36,13 @@ def invoke(
     prompt_session_id: Optional[str] = None,
     prompt_session_function_id: Optional[str] = None,
     project_name: Optional[str] = None,
+    project_id: Optional[str] = None,
     slug: Optional[str] = None,
     global_function: Optional[str] = None,
     # arguments to the function
     input: Any = None,
     messages: Optional[List[Any]] = None,
+    context: Optional[InvokeContext] = None,
     metadata: Optional[Dict[str, Any]] = None,
     tags: Optional[List[str]] = None,
     parent: Optional[Union[Exportable, str]] = None,
@@ -45,11 +64,13 @@ def invoke(
     prompt_session_id: Optional[str] = None,
     prompt_session_function_id: Optional[str] = None,
     project_name: Optional[str] = None,
+    project_id: Optional[str] = None,
     slug: Optional[str] = None,
     global_function: Optional[str] = None,
     # arguments to the function
     input: Any = None,
     messages: Optional[List[Any]] = None,
+    context: Optional[InvokeContext] = None,
     metadata: Optional[Dict[str, Any]] = None,
     tags: Optional[List[str]] = None,
     parent: Optional[Union[Exportable, str]] = None,
@@ -70,11 +91,13 @@ def invoke(
     prompt_session_id: Optional[str] = None,
     prompt_session_function_id: Optional[str] = None,
     project_name: Optional[str] = None,
+    project_id: Optional[str] = None,
     slug: Optional[str] = None,
     global_function: Optional[str] = None,
     # arguments to the function
     input: Any = None,
     messages: Optional[List[Any]] = None,
+    context: Optional[InvokeContext] = None,
     metadata: Optional[Dict[str, Any]] = None,
     tags: Optional[List[str]] = None,
     parent: Optional[Union[Exportable, str]] = None,
@@ -93,6 +116,8 @@ def invoke(
     Args:
         input: The input to the function. This will be logged as the `input` field in the span.
         messages: Additional OpenAI-style messages to add to the prompt (only works for llm functions).
+        context: Context for functions that operate on spans/traces (e.g., facets). Should contain
+            `object_type`, `object_id`, and `scope` fields.
         metadata: Additional metadata to add to the span. This will be logged as the `metadata` field in the span.
             It will also be available as the {{metadata}} field in the prompt and as the `metadata` argument
             to the function.
@@ -118,6 +143,8 @@ def invoke(
         prompt_session_id: The ID of the prompt session to invoke the function from.
         prompt_session_function_id: The ID of the function in the prompt session to invoke.
        project_name: The name of the project containing the function to invoke.
+        project_id: The ID of the project to use for execution context (API keys, project defaults, etc.).
+            This is not the project the function belongs to, but the project context for the invocation.
         slug: The slug of the function to invoke.
         global_function: The name of the global function to invoke.
 
@@ -161,12 +188,18 @@ def invoke(
     )
     if messages is not None:
         request["messages"] = messages
+    if context is not None:
+        request["context"] = context
     if mode is not None:
         request["mode"] = mode
     if strict is not None:
         request["strict"] = strict
 
     headers = {"Accept": "text/event-stream" if stream else "application/json"}
+    if project_id is not None:
+        headers["x-bt-project-id"] = project_id
+    if org_name is not None:
+        headers["x-bt-org-name"] = org_name
 
     resp = proxy_conn().post("function/invoke", json=request, headers=headers, stream=stream)
     if resp.status_code == 500:
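Taken together, the new parameters let a caller scope an invocation to a particular span or trace and choose the project used as execution context (sent as the `x-bt-project-id` header). A usage sketch with placeholder names and IDs; the `context` dict here follows the docstring and the `SpanScope` shape above:

```python
from braintrust.functions.invoke import invoke

result = invoke(
    project_name="my-project",      # placeholder: project that owns the function
    slug="summarize-span",          # placeholder: slug of the function to invoke
    input={"question": "What went wrong here?"},
    # New in 0.3.15: scope the invocation to a single span within a logged object.
    context={
        "object_type": "project_logs",
        "object_id": "<object-id>",
        "scope": {"type": "span", "id": "<span-id>", "root_span_id": "<root-span-id>"},
    },
    # New in 0.3.15: execution-context project, sent as the x-bt-project-id header.
    project_id="<execution-project-id>",
)
```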
braintrust/generated_types.py
CHANGED
@@ -1,4 +1,4 @@
-"""Auto-generated file (internal git SHA
+"""Auto-generated file (internal git SHA 437eb5379a737f70dec98033fccf81de43e8e177) -- do not modify"""
 
 from ._generated_types import (
     Acl,
@@ -32,6 +32,7 @@ from ._generated_types import (
     ExperimentEvent,
     ExtendedSavedFunctionId,
     ExternalAttachmentReference,
+    FacetData,
     Function,
     FunctionData,
     FunctionFormat,
@@ -47,10 +48,14 @@ from ._generated_types import (
     GraphNode,
     Group,
     IfExists,
+    InvokeContext,
     InvokeFunction,
     InvokeParent,
+    InvokeScope,
+    MCPServer,
     MessageRole,
     ModelParams,
+    NullableSavedFunctionId,
     ObjectReference,
     ObjectReferenceNullish,
     OnlineScoreConfig,
@@ -86,11 +91,13 @@ from ._generated_types import (
     ServiceToken,
     SpanAttributes,
     SpanIFrame,
+    SpanScope,
     SpanType,
     SSEConsoleEventData,
     SSEProgressEventData,
     StreamingMode,
     ToolFunctionDefinition,
+    TraceScope,
     UploadStatus,
     User,
     View,
@@ -131,6 +138,7 @@ __all__ = [
     "ExperimentEvent",
     "ExtendedSavedFunctionId",
     "ExternalAttachmentReference",
+    "FacetData",
     "Function",
     "FunctionData",
     "FunctionFormat",
@@ -146,10 +154,14 @@ __all__ = [
     "GraphNode",
     "Group",
     "IfExists",
+    "InvokeContext",
     "InvokeFunction",
     "InvokeParent",
+    "InvokeScope",
+    "MCPServer",
     "MessageRole",
     "ModelParams",
+    "NullableSavedFunctionId",
     "ObjectReference",
     "ObjectReferenceNullish",
     "OnlineScoreConfig",
@@ -187,9 +199,11 @@ __all__ = [
     "ServiceToken",
     "SpanAttributes",
     "SpanIFrame",
+    "SpanScope",
     "SpanType",
     "StreamingMode",
     "ToolFunctionDefinition",
+    "TraceScope",
    "UploadStatus",
     "User",
     "View",
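For reference, the newly re-exported names can be imported from the stable wrapper module once 0.3.15 is installed; a minimal import check:

```python
from braintrust.generated_types import (
    FacetData,
    InvokeContext,
    InvokeScope,
    MCPServer,
    NullableSavedFunctionId,
    SpanScope,
    TraceScope,
)
```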
braintrust/gitutil.py
CHANGED
@@ -88,8 +88,12 @@ def get_past_n_ancestors(n=1000, remote=None):
     if ancestor_output is None:
         return
     ancestor = repo.commit(ancestor_output)
+    count = 0
     for _ in range(n):
+        if count >= n:
+            break
         yield ancestor.hexsha
+        count += 1
         try:
             if ancestor.parents:
                 ancestor = ancestor.parents[0]
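`get_past_n_ancestors` yields at most `n` commit SHAs by walking first parents; the added counter makes that cap explicit alongside the `range(n)` bound. A standalone sketch of the same bounded first-parent walk (illustrative only, not the module's API), using the GitPython calls the diff above relies on:

```python
from git import Repo  # GitPython


def first_parent_ancestors(repo: Repo, start: str = "HEAD", n: int = 1000):
    """Yield up to n commit SHAs, following first parents from `start`."""
    commit = repo.commit(start)
    for _ in range(n):
        yield commit.hexsha
        if not commit.parents:
            break
        commit = commit.parents[0]


# e.g. list(first_parent_ancestors(Repo("."), n=5))
```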
braintrust/logger.py
CHANGED
@@ -1104,7 +1104,7 @@ class _HTTPBackgroundLogger:
             _HTTPBackgroundLogger._write_payload_to_dir(payload_dir=self.all_publish_payloads_dir, payload=dataStr)
         for i in range(self.num_tries):
             start_time = time.time()
-            resp = conn.post("/logs3", data=dataStr)
+            resp = conn.post("/logs3", data=dataStr.encode("utf-8"))
             if resp.ok:
                 return
             resp_errmsg = f"{resp.status_code}: {resp.text}"
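The `/logs3` payload is now sent as explicit UTF-8 bytes rather than a Python `str`, so non-ASCII characters in logged data no longer depend on the HTTP client's default text encoding. A minimal, standard-library-only sketch of the distinction:

```python
import json

record = {"output": "naïve café ☕"}  # non-ASCII content in a log record
data_str = json.dumps(record, ensure_ascii=False)

# Encoding explicitly pins the request body to UTF-8 rather than leaving the
# str-to-bytes conversion up to the HTTP library.
body = data_str.encode("utf-8")
assert json.loads(body.decode("utf-8")) == record
```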
braintrust/oai.py
CHANGED
@@ -1,8 +1,10 @@
 import abc
+import base64
+import re
 import time
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Union
 
-from .logger import Span, start_span
+from .logger import Attachment, Span, start_span
 from .span_types import SpanTypeAttribute
 from .util import merge_dicts
 
@@ -68,6 +70,75 @@ def log_headers(response: Any, span: Span):
     )
 
 
+def _convert_data_url_to_attachment(data_url: str, filename: Optional[str] = None) -> Union[Attachment, str]:
+    """Helper function to convert data URL to an Attachment."""
+    data_url_match = re.match(r"^data:([^;]+);base64,(.+)$", data_url)
+    if not data_url_match:
+        return data_url
+
+    mime_type, base64_data = data_url_match.groups()
+
+    try:
+        binary_data = base64.b64decode(base64_data)
+
+        if filename is None:
+            extension = mime_type.split("/")[1] if "/" in mime_type else "bin"
+            prefix = "image" if mime_type.startswith("image/") else "document"
+            filename = f"{prefix}.{extension}"
+
+        attachment = Attachment(data=binary_data, filename=filename, content_type=mime_type)
+
+        return attachment
+    except Exception:
+        return data_url
+
+
+def _process_attachments_in_input(input_data: Any) -> Any:
+    """Process input to convert data URL images and base64 documents to Attachment objects."""
+    if isinstance(input_data, list):
+        return [_process_attachments_in_input(item) for item in input_data]
+
+    if isinstance(input_data, dict):
+        # Check for OpenAI's image_url format with data URLs
+        if (
+            input_data.get("type") == "image_url"
+            and isinstance(input_data.get("image_url"), dict)
+            and isinstance(input_data["image_url"].get("url"), str)
+        ):
+            processed_url = _convert_data_url_to_attachment(input_data["image_url"]["url"])
+            return {
+                **input_data,
+                "image_url": {
+                    **input_data["image_url"],
+                    "url": processed_url,
+                },
+            }
+
+        # Check for OpenAI's file format with data URL (e.g., PDFs)
+        if (
+            input_data.get("type") == "file"
+            and isinstance(input_data.get("file"), dict)
+            and isinstance(input_data["file"].get("file_data"), str)
+        ):
+            file_filename = input_data["file"].get("filename")
+            processed_file_data = _convert_data_url_to_attachment(
+                input_data["file"]["file_data"],
+                filename=file_filename if isinstance(file_filename, str) else None,
+            )
+            return {
+                **input_data,
+                "file": {
+                    **input_data["file"],
+                    "file_data": processed_file_data,
+                },
+            }
+
+        # Recursively process nested objects
+        return {key: _process_attachments_in_input(value) for key, value in input_data.items()}
+
+    return input_data
+
+
 class ChatCompletionWrapper:
     def __init__(self, create_fn: Optional[Callable[..., Any]], acreate_fn: Optional[Callable[..., Any]]):
         self.create_fn = create_fn
@@ -190,10 +261,14 @@ class ChatCompletionWrapper:
         # Then, copy the rest of the params
         params = prettify_params(params)
         messages = params.pop("messages", None)
+
+        # Process attachments in input (convert data URLs to Attachment objects)
+        processed_input = _process_attachments_in_input(messages)
+
         return merge_dicts(
             ret,
             {
-                "input":
+                "input": processed_input,
                 "metadata": {**params, "provider": "openai"},
             },
         )
@@ -379,10 +454,14 @@ class ResponseWrapper:
         # Then, copy the rest of the params
         params = prettify_params(params)
         input_data = params.pop("input", None)
+
+        # Process attachments in input (convert data URLs to Attachment objects)
+        processed_input = _process_attachments_in_input(input_data)
+
         return merge_dicts(
             ret,
             {
-                "input":
+                "input": processed_input,
                 "metadata": {**params, "provider": "openai"},
             },
         )
@@ -540,12 +619,15 @@ class BaseWrapper(abc.ABC):
         ret = params.pop("span_info", {})
 
         params = prettify_params(params)
-
+        input_data = params.pop("input", None)
+
+        # Process attachments in input (convert data URLs to Attachment objects)
+        processed_input = _process_attachments_in_input(input_data)
 
         return merge_dicts(
             ret,
             {
-                "input":
+                "input": processed_input,
                 "metadata": {**params, "provider": "openai"},
             },
         )
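The net effect is that base64 data URLs in wrapped OpenAI requests are logged as Braintrust attachments instead of inline strings. A small sketch of the transformation using the internal helper added above (shown for illustration only; it is not a public API):

```python
import base64

from braintrust.oai import _process_attachments_in_input  # internal helper from the diff above

# A tiny fake payload, just to have bytes to base64-encode.
data_url = "data:image/png;base64," + base64.b64encode(b"\x89PNG fake bytes").decode("ascii")

message = {
    "role": "user",
    "content": [
        {"type": "text", "text": "What is in this image?"},
        {"type": "image_url", "image_url": {"url": data_url}},
    ],
}

processed = _process_attachments_in_input([message])
# The data URL under image_url.url becomes an Attachment(filename="image.png",
# content_type="image/png"); plain https:// URLs and other shapes pass through unchanged.
```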
braintrust/score.py
CHANGED
braintrust/test_score.py
ADDED
@@ -0,0 +1,157 @@
+import json
+import unittest
+
+from .score import Score
+
+
+class TestScore(unittest.TestCase):
+    def test_as_dict_includes_all_required_fields(self):
+        """Test that as_dict() includes name, score, and metadata fields."""
+        score = Score(name="test_scorer", score=0.85, metadata={"key": "value"})
+        result = score.as_dict()
+
+        self.assertIn("name", result)
+        self.assertIn("score", result)
+        self.assertIn("metadata", result)
+
+        self.assertEqual(result["name"], "test_scorer")
+        self.assertEqual(result["score"], 0.85)
+        self.assertEqual(result["metadata"], {"key": "value"})
+
+    def test_as_dict_with_null_score(self):
+        """Test that as_dict() works correctly with null score."""
+        score = Score(name="null_scorer", score=None, metadata={})
+        result = score.as_dict()
+
+        self.assertEqual(result["name"], "null_scorer")
+        self.assertIsNone(result["score"])
+        self.assertEqual(result["metadata"], {})
+
+    def test_as_dict_with_empty_metadata(self):
+        """Test that as_dict() works correctly with empty metadata."""
+        score = Score(name="empty_metadata_scorer", score=1.0)
+        result = score.as_dict()
+
+        self.assertEqual(result["name"], "empty_metadata_scorer")
+        self.assertEqual(result["score"], 1.0)
+        self.assertEqual(result["metadata"], {})
+
+    def test_as_dict_with_complex_metadata(self):
+        """Test that as_dict() works correctly with complex nested metadata."""
+        complex_metadata = {
+            "reason": "Test reason",
+            "details": {"nested": {"deeply": "value"}},
+            "list": [1, 2, 3],
+            "bool": True,
+        }
+        score = Score(name="complex_scorer", score=0.5, metadata=complex_metadata)
+        result = score.as_dict()
+
+        self.assertEqual(result["name"], "complex_scorer")
+        self.assertEqual(result["score"], 0.5)
+        self.assertEqual(result["metadata"], complex_metadata)
+
+    def test_as_json_serialization(self):
+        """Test that as_json() produces valid JSON string."""
+        score = Score(name="json_scorer", score=0.75, metadata={"test": "data"})
+        json_str = score.as_json()
+
+        # Should be valid JSON
+        parsed = json.loads(json_str)
+
+        self.assertEqual(parsed["name"], "json_scorer")
+        self.assertEqual(parsed["score"], 0.75)
+        self.assertEqual(parsed["metadata"], {"test": "data"})
+
+    def test_from_dict_round_trip(self):
+        """Test that Score can be serialized to dict and deserialized back."""
+        original = Score(
+            name="round_trip_scorer", score=0.95, metadata={"info": "test"}
+        )
+
+        # Serialize to dict
+        as_dict = original.as_dict()
+
+        # Deserialize from dict
+        restored = Score.from_dict(as_dict)
+
+        self.assertEqual(restored.name, original.name)
+        self.assertEqual(restored.score, original.score)
+        self.assertEqual(restored.metadata, original.metadata)
+
+    def test_array_of_scores_serialization(self):
+        """Test that arrays of Score objects can be serialized correctly."""
+        scores = [
+            Score(name="score_1", score=0.8, metadata={"index": 1}),
+            Score(name="score_2", score=0.6, metadata={"index": 2}),
+            Score(name="score_3", score=None, metadata={}),
+        ]
+
+        # Serialize each score
+        serialized = [s.as_dict() for s in scores]
+
+        # Check that all scores have required fields
+        for i, s_dict in enumerate(serialized):
+            self.assertIn("name", s_dict)
+            self.assertIn("score", s_dict)
+            self.assertIn("metadata", s_dict)
+            self.assertEqual(s_dict["name"], f"score_{i + 1}")
+
+        # Check specific values
+        self.assertEqual(serialized[0]["score"], 0.8)
+        self.assertEqual(serialized[1]["score"], 0.6)
+        self.assertIsNone(serialized[2]["score"])
+
+    def test_array_of_scores_json_serialization(self):
+        """Test that arrays of Score objects can be JSON serialized."""
+        scores = [
+            Score(name="json_score_1", score=0.9),
+            Score(name="json_score_2", score=0.7),
+        ]
+
+        # Serialize to JSON
+        serialized = [s.as_dict() for s in scores]
+        json_str = json.dumps(serialized)
+
+        # Parse back
+        parsed = json.loads(json_str)
+
+        self.assertEqual(len(parsed), 2)
+        self.assertEqual(parsed[0]["name"], "json_score_1")
+        self.assertEqual(parsed[0]["score"], 0.9)
+        self.assertEqual(parsed[1]["name"], "json_score_2")
+        self.assertEqual(parsed[1]["score"], 0.7)
+
+    def test_score_validation_enforces_bounds(self):
+        """Test that Score validates score values are between 0 and 1."""
+        # Valid scores
+        Score(name="valid_0", score=0.0)
+        Score(name="valid_1", score=1.0)
+        Score(name="valid_mid", score=0.5)
+        Score(name="valid_null", score=None)
+
+        # Invalid scores
+        with self.assertRaises(ValueError):
+            Score(name="invalid_negative", score=-0.1)
+
+        with self.assertRaises(ValueError):
+            Score(name="invalid_over_one", score=1.1)
+
+    def test_score_does_not_include_deprecated_error_field(self):
+        """Test that as_dict() does not include the deprecated error field."""
+        score = Score(name="test_scorer", score=0.5)
+        result = score.as_dict()
+
+        # The error field should not be in the serialized output
+        self.assertNotIn("error", result)
+
+        # Even if error was set (though deprecated), it shouldn't be in as_dict
+        score_with_error = Score(name="error_scorer", score=0.5)
+        score_with_error.error = Exception("test")  # Set after construction
+        result_with_error = score_with_error.as_dict()
+
+        self.assertNotIn("error", result_with_error)
+
+
+if __name__ == "__main__":
+    unittest.main()
braintrust/version.py
CHANGED