braintrust 0.5.2__py3-none-any.whl → 0.5.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braintrust/__init__.py +11 -0
- braintrust/_generated_types.py +193 -11
- braintrust/db_fields.py +10 -0
- braintrust/generated_types.py +7 -1
- braintrust/logger.py +324 -125
- braintrust/merge_row_batch.py +49 -109
- braintrust/oai.py +13 -4
- braintrust/test_context.py +1264 -0
- braintrust/test_merge_row_batch.py +160 -0
- braintrust/version.py +2 -2
- braintrust/wrappers/test_openai.py +100 -1
- braintrust/wrappers/threads.py +114 -0
- {braintrust-0.5.2.dist-info → braintrust-0.5.4.dist-info}/METADATA +1 -1
- {braintrust-0.5.2.dist-info → braintrust-0.5.4.dist-info}/RECORD +17 -15
- braintrust/graph_util.py +0 -147
- {braintrust-0.5.2.dist-info → braintrust-0.5.4.dist-info}/WHEEL +0 -0
- {braintrust-0.5.2.dist-info → braintrust-0.5.4.dist-info}/entry_points.txt +0 -0
- {braintrust-0.5.2.dist-info → braintrust-0.5.4.dist-info}/top_level.txt +0 -0
braintrust/merge_row_batch.py
CHANGED
@@ -1,14 +1,15 @@
-from collections.abc import Mapping, Sequence
-from typing import Any, Optional
+from collections.abc import Callable, Sequence
+from typing import Any, Optional, TypeVar
 
-from .db_fields import IS_MERGE_FIELD, PARENT_ID_FIELD
-from .graph_util import UndirectedGraph, topological_sort, undirected_connected_components
+from .db_fields import IS_MERGE_FIELD
+
+T = TypeVar("T")
 from .util import merge_dicts
 
 _MergedRowKey = tuple[Optional[Any], ...]
 
 
-def _generate_merged_row_key(row: Mapping[str, Any], use_parent_id_for_id: bool = False) -> _MergedRowKey:
+def _generate_merged_row_key(row: dict[str, Any]) -> _MergedRowKey:
     return tuple(
         row.get(k)
         for k in [
@@ -18,7 +19,7 @@ def _generate_merged_row_key(row: Mapping[str, Any], use_parent_id_for_id: bool
             "dataset_id",
             "prompt_session_id",
             "log_id",
-            PARENT_ID_FIELD if use_parent_id_for_id else "id",
+            "id",
         ]
     )
 
@@ -49,12 +50,9 @@ def _restore_merge_row_skip_fields(row: dict[str, Any], skip_fields: dict[str, A
         row[field] = skip_fields[field]
 
 
-def merge_row_batch(rows: Sequence[dict[str, Any]]) -> list[list[dict[str, Any]]]:
+def merge_row_batch(rows: Sequence[dict[str, Any]]) -> list[dict[str, Any]]:
     """Given a batch of rows, merges conflicting rows together to end up with a
-    set of rows to insert. Returns a set of de-conflicted rows as a list of
-    lists, where separate lists contain "independent" rows which can be
-    processed concurrently, while the rows in each list must be processed in
-    order, as later rows may depend on earlier ones.
+    set of rows to insert. Returns a set of de-conflicted rows as a flat list.
 
     Note that the returned rows will be the same objects as the input `rows`,
     meaning they are mutated in place.
@@ -118,65 +116,31 @@ def merge_row_batch(rows: Sequence[dict[str, Any]]) -> list[list[dict[str, Any]]
         else:
             row_groups[key] = row
 
-
-
-    # Now that we have just one row per id, we can bucket and order the rows by
-    # their PARENT_ID_FIELD relationships.
-    row_to_label = {_generate_merged_row_key(r): i for i, r in enumerate(merged)}
-
-    # Form a graph where edges go from parents to their children.
-    graph = {i: set() for i in range(len(merged))}
-    for i, r in enumerate(merged):
-        parent_id = r.get(PARENT_ID_FIELD)
-        if not parent_id:
-            continue
-        parent_row_key = _generate_merged_row_key(r, use_parent_id_for_id=True)
-        parent_label = row_to_label.get(parent_row_key)
-        if parent_label is not None:
-            graph[parent_label].add(i)
-
-    # Group together all the connected components of the undirected graph to get
-    # all groups of rows which each row in a group has a PARENT_ID_FIELD
-    # relationship with at least one other row in the group.
-    connected_components = undirected_connected_components(
-        UndirectedGraph(vertices=set(graph.keys()), edges={(k, v) for k, vs in graph.items() for v in vs})
-    )
-
-    # For each connected row group, run topological sort over that subgraph to
-    # get an ordering of rows where parents come before children.
-    buckets = [topological_sort(graph, visitation_order=cc) for cc in connected_components]
-    return [[merged[i] for i in bucket] for bucket in buckets]
+    return list(row_groups.values())
 
 
 def batch_items(
-    items: list[list[str]],
-    batch_max_num_items: Optional[int] = None,
-    batch_max_num_bytes: Optional[int] = None,
-) -> list[list[list[str]]]:
-    """Repartition the given list of items into batch sets, which must be
-    published in sequence. Within each set of batches, each individual List[str]
-    batch may be published in parallel with all other batches in its set,
-    retaining the order within the batch. So from outside to inside, it goes
-    ordered -> parallel -> ordered.
+    items: list[T],
+    batch_max_num_items: int | None = None,
+    batch_max_num_bytes: int | None = None,
+    get_byte_size: Callable[[T], int] | None = None,
+) -> list[list[T]]:
+    """Repartition the given list of items into batches.
 
     Arguments:
 
-    - `items` is a list of
-    - `batch_max_num_bytes` is the maximum number of bytes (as computed by
-      `sum(len(item) for item in batch)`) in each List[str] batch. If an
-      individual item exceeds `batch_max_num_bytes` in size, we will place it in
-      its own batch. If not provided, there is no limit on the number of bytes.
+    - `items` is a list of items to batch.
+
+    - `batch_max_num_items` is the maximum number of items in each batch.
+      If not provided, there is no limit on the number of items.
+
+    - `batch_max_num_bytes` is the maximum number of bytes in each batch.
+      If an individual item exceeds `batch_max_num_bytes` in size, we
+      will place it in its own batch. If not provided, there is no limit on
+      the number of bytes.
+
+    - `get_byte_size` is a function that returns the byte size of an item.
+      If not provided, defaults to `len(item)` (works for strings).
     """
 
     if batch_max_num_items is not None and batch_max_num_items <= 0:
@@ -184,60 +148,36 @@ def batch_items(
     if batch_max_num_bytes is not None and batch_max_num_bytes < 0:
         raise ValueError(f"batch_max_num_bytes must be nonnegative; got {batch_max_num_bytes}")
 
-    output = []
-    batch = []
+    if get_byte_size is None:
+
+        def get_byte_size(item: T) -> int:
+            return len(item)  # type: ignore[arg-type]
+
+    output: list[list[T]] = []
+    batch: list[T] = []
     batch_len = 0
 
-    def add_to_batch(item):
+    def add_to_batch(item: T) -> None:
         nonlocal batch_len
         batch.append(item)
-        batch_len += len(item)
+        batch_len += get_byte_size(item)
 
-    def flush_batch():
+    def flush_batch() -> None:
         nonlocal batch, batch_len
+        output.append(batch)
         batch = []
         batch_len = 0
 
-                and (batch_max_num_items is None or len(batch) < batch_max_num_items)
-                ):
-                    add_to_batch(item)
-                elif i == 0:
-                    # If the very first item in the bucket fills the batch, we
-                    # can flush this batch and start a new one which includes
-                    # this item.
-                    flush_batch()
-                    add_to_batch(item)
-                else:
-                    break
-                i += 1
-            # If we didn't completely exhaust the bucket, save it for the next
-            # batch set.
-            if i < len(bucket):
-                next_items.append(bucket[i:])
-            # If we have filled the batch, flush it.
-            if (batch_max_num_bytes is not None and batch_len >= batch_max_num_bytes) or (
-                batch_max_num_items is not None and len(batch) >= batch_max_num_items
-            ):
-                flush_batch()
-
-        # We've finished an iteration through all the buckets. Anything
-        # remaining in `next_items` will need to be processed in a subsequent
-        # batch set, so flush our remaining batch and the batch set, and use
-        # next_items for the next iteration.
-        if batch:
+    for item in items:
+        item_size = get_byte_size(item)
+        if len(batch) > 0 and not (
+            (batch_max_num_bytes is None or item_size + batch_len < batch_max_num_bytes)
+            and (batch_max_num_items is None or len(batch) < batch_max_num_items)
+        ):
             flush_batch()
+        add_to_batch(item)
+
+    if len(batch) > 0:
+        flush_batch()
 
     return output
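The net effect for this file: merge_row_batch now returns a flat list (the parent/child bucketing and its graph_util dependency are gone, matching the deleted braintrust/graph_util.py above), and batch_items becomes a generic single-pass batcher parameterized by T. A usage sketch of the new batch_items contract, traced from the `+` lines above; it assumes the internal module braintrust.merge_row_batch is importable from the installed wheel (the RECORD entry suggests it is), and the sample items and printed groupings are illustrative:

import json

from braintrust.merge_row_batch import batch_items

items = ["a", "bb", "ccc", "dddd", "e", "ff"]

# Cap by item count: a batch flushes once it already holds
# batch_max_num_items items and another item arrives.
print(batch_items(items, batch_max_num_items=2))
# [['a', 'bb'], ['ccc', 'dddd'], ['e', 'ff']]

# Cap by bytes: the guard is `item_size + batch_len < batch_max_num_bytes`,
# so an item that would reach the cap starts a new batch. Byte size
# defaults to len(item), which works for strings.
print(batch_items(items, batch_max_num_bytes=5))
# [['a', 'bb'], ['ccc'], ['dddd'], ['e', 'ff']]

# Non-string items need an explicit get_byte_size.
rows = [{"id": i, "payload": "x" * (10 * i)} for i in range(4)]
print(batch_items(rows, batch_max_num_bytes=64, get_byte_size=lambda r: len(json.dumps(r))))

Because the flush decision happens before the append, an oversized single item still lands in a batch of its own rather than being rejected, which is what the new docstring promises.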
braintrust/oai.py
CHANGED
@@ -312,16 +312,25 @@ class ChatCompletionWrapper:
 
                 # pylint: disable=unsubscriptable-object
                 if not tool_calls or (tool_delta.get("id") and tool_calls[-1]["id"] != tool_delta.get("id")):
+                    function_arg = tool_delta.get("function", {})
                     tool_calls = (tool_calls or []) + [
                         {
-                            "id":
-                            "type":
-                            "function":
+                            "id": tool_delta.get("id"),
+                            "type": tool_delta.get("type"),
+                            "function": {
+                                "name": function_arg.get("name"),
+                                "arguments": function_arg.get("arguments") or "",
+                            },
                         }
                     ]
                 else:
                     # pylint: disable=unsubscriptable-object
+                    # append to existing tool call
+                    function_arg = tool_delta.get("function", {})
+                    args = function_arg.get("arguments") or ""
+                    if isinstance(args, str):
+                        # pylint: disable=unsubscriptable-object
+                        tool_calls[-1]["function"]["arguments"] += args
 
         return {
             "metrics": metrics,