braintrust 0.5.2__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,14 +1,15 @@
-from collections.abc import Mapping, Sequence
-from typing import Any, Optional
+from collections.abc import Callable, Sequence
+from typing import Any, Optional, TypeVar
 
-from .db_fields import IS_MERGE_FIELD, PARENT_ID_FIELD
-from .graph_util import UndirectedGraph, topological_sort, undirected_connected_components
+from .db_fields import IS_MERGE_FIELD
+
+T = TypeVar("T")
 from .util import merge_dicts
 
 _MergedRowKey = tuple[Optional[Any], ...]
 
 
-def _generate_merged_row_key(row: Mapping[str, Any], use_parent_id_for_id: bool = False) -> _MergedRowKey:
+def _generate_merged_row_key(row: dict[str, Any]) -> _MergedRowKey:
     return tuple(
         row.get(k)
         for k in [
@@ -18,7 +19,7 @@ def _generate_merged_row_key(row: Mapping[str, Any], use_parent_id_for_id: bool
             "dataset_id",
             "prompt_session_id",
             "log_id",
-            PARENT_ID_FIELD if use_parent_id_for_id else "id",
+            "id",
         ]
     )
 
@@ -49,12 +50,9 @@ def _restore_merge_row_skip_fields(row: dict[str, Any], skip_fields: dict[str, A
         row[field] = skip_fields[field]
 
 
-def merge_row_batch(rows: Sequence[dict[str, Any]]) -> list[list[dict[str, Any]]]:
+def merge_row_batch(rows: Sequence[dict[str, Any]]) -> list[dict[str, Any]]:
     """Given a batch of rows, merges conflicting rows together to end up with a
-    set of rows to insert. Returns a set of de-conflicted rows, as a list of
-    lists, where separate lists contain "independent" rows which can be
-    processed concurrently, while the rows in each list must be processed in
-    order, as later rows may depend on earlier ones.
+    set of rows to insert. Returns a set of de-conflicted rows as a flat list.
 
     Note that the returned rows will be the same objects as the input `rows`,
     meaning they are mutated in place.
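In 0.5.3 the parent/child ordering logic is removed: conflicting rows are still merged by their identity key, but the result is a single flat list rather than ordered groups. Below is a minimal standalone sketch of the grouping idea only. The key fields shown are just the ones visible in this diff (the rest are elided), dict.update() stands in for the package's merge_dicts, and the real implementation also honors IS_MERGE_FIELD, so treat this as an illustration, not the library's code:

from typing import Any

def merge_rows_sketch(rows: list[dict[str, Any]]) -> list[dict[str, Any]]:
    # Group rows by an identity key; later rows merge into earlier ones.
    key_fields = ("dataset_id", "prompt_session_id", "log_id", "id")  # diff elides the remaining key fields
    groups: dict[tuple[Any, ...], dict[str, Any]] = {}
    for row in rows:
        key = tuple(row.get(k) for k in key_fields)
        if key in groups:
            groups[key].update(row)  # simplified stand-in for merge_dicts
        else:
            groups[key] = row
    # 0.5.3 returns a flat list; 0.5.2 grouped and ordered these rows by
    # their PARENT_ID_FIELD relationships.
    return list(groups.values())

rows = [
    {"id": "r1", "log_id": "g", "output": "a"},
    {"id": "r1", "log_id": "g", "scores": {"accuracy": 1}},
    {"id": "r2", "log_id": "g", "output": "b"},
]
print(merge_rows_sketch(rows))  # two de-conflicted rows in a flat list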
@@ -118,65 +116,31 @@ def merge_row_batch(rows: Sequence[dict[str, Any]]) -> list[list[dict[str, Any]]
         else:
             row_groups[key] = row
 
-    merged = list(row_groups.values())
-
-    # Now that we have just one row per id, we can bucket and order the rows by
-    # their PARENT_ID_FIELD relationships.
-    row_to_label = {_generate_merged_row_key(r): i for i, r in enumerate(merged)}
-
-    # Form a graph where edges go from parents to their children.
-    graph = {i: set() for i in range(len(merged))}
-    for i, r in enumerate(merged):
-        parent_id = r.get(PARENT_ID_FIELD)
-        if not parent_id:
-            continue
-        parent_row_key = _generate_merged_row_key(r, use_parent_id_for_id=True)
-        parent_label = row_to_label.get(parent_row_key)
-        if parent_label is not None:
-            graph[parent_label].add(i)
-
-    # Group together all the connected components of the undirected graph to get
-    # all groups of rows which each row in a group has a PARENT_ID_FIELD
-    # relationship with at least one other row in the group.
-    connected_components = undirected_connected_components(
-        UndirectedGraph(vertices=set(graph.keys()), edges={(k, v) for k, vs in graph.items() for v in vs})
-    )
-
-    # For each connected row group, run topological sort over that subgraph to
-    # get an ordering of rows where parents come before children.
-    buckets = [topological_sort(graph, visitation_order=cc) for cc in connected_components]
-    return [[merged[i] for i in bucket] for bucket in buckets]
+    return list(row_groups.values())
 
 
 def batch_items(
-    items: list[list[str]], batch_max_num_items: int | None = None, batch_max_num_bytes: int | None = None
-) -> list[list[list[str]]]:
-    """Repartition the given list of items into sets of batches which can be
-    published in parallel or in sequence.
-
-    Output-wise, each outer List[List[str]] is a set of batches which must be
-    published in sequence. Within each set of batches, each individual List[str]
-    batch may be published in parallel with all other batches in its set,
-    retaining the order within the batch. So from outside to inside, it goes
-    ordered -> parallel -> ordered.
+    items: list[T],
+    batch_max_num_items: int | None = None,
+    batch_max_num_bytes: int | None = None,
+    get_byte_size: Callable[[T], int] | None = None,
+) -> list[list[T]]:
+    """Repartition the given list of items into batches.
 
     Arguments:
 
-    - `items` is a list of ordered buckets, where the constraint is that items
-      in different buckets can be published in parallel, while items within a
-      bucket must be published in sequence. That means that if two items are in
-      the same bucket, they will either appear in the same innermost List[str]
-      in the output, or in separate List[List[str]] batch sets, with their
-      relative order preserved. If two items are in different buckets, they can
-      appear in different List[str] batches.
-
-    - `batch_max_num_items` is the maximum number of items in each List[str]
-      batch. If not provided, there is no limit on the number of items.
-
-    - `batch_max_num_bytes` is the maximum number of bytes (computed as
-      `sum(len(item) for item in batch)`) in each List[str] batch. If an
-      individual item exceeds `batch_max_num_bytes` in size, we will place it in
-      its own batch. If not provided, there is no limit on the number of bytes.
+    - `items` is a list of items to batch.
+
+    - `batch_max_num_items` is the maximum number of items in each batch.
+      If not provided, there is no limit on the number of items.
+
+    - `batch_max_num_bytes` is the maximum number of bytes in each batch.
+      If an individual item exceeds `batch_max_num_bytes` in size, we
+      will place it in its own batch. If not provided, there is no limit on
+      the number of bytes.
+
+    - `get_byte_size` is a function that returns the byte size of an item.
+      If not provided, defaults to `len(item)` (works for strings).
    """
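The new docstring can be summarized with a small worked example. A hedged usage sketch, assuming the module is importable as braintrust.merge_row_batch (the diff does not show the file name) and relying on the documented default get_byte_size of len for strings:

from braintrust.merge_row_batch import batch_items  # import path assumed, not shown in the diff

items = ["aaa", "bbb", "cccc", "dddddddddddd"]
batches = batch_items(items, batch_max_num_bytes=10)

# Expected grouping under the new greedy loop:
#   ["aaa", "bbb"]      -> 6 bytes; adding "cccc" would reach the 10-byte limit, so the batch is flushed
#   ["cccc"]            -> flushed when the oversized item arrives
#   ["dddddddddddd"]    -> 12 bytes exceeds the limit, so it sits in its own batch
print(batches)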
@@ -184,60 +148,36 @@ def batch_items(
     if batch_max_num_bytes is not None and batch_max_num_bytes < 0:
         raise ValueError(f"batch_max_num_bytes must be nonnegative; got {batch_max_num_bytes}")
 
-    output = []
-    next_items = []
-    batch_set = []
-    batch = []
+    if get_byte_size is None:
+
+        def get_byte_size(item: T) -> int:
+            return len(item)  # type: ignore[arg-type]
+
+    output: list[list[T]] = []
+    batch: list[T] = []
     batch_len = 0
 
-    def add_to_batch(item):
+    def add_to_batch(item: T) -> None:
         nonlocal batch_len
         batch.append(item)
-        batch_len += len(item)
+        batch_len += get_byte_size(item)
 
-    def flush_batch():
+    def flush_batch() -> None:
         nonlocal batch, batch_len
-        batch_set.append(batch)
+        output.append(batch)
         batch = []
         batch_len = 0
 
-    while items:
-        for bucket in items:
-            i = 0
-            for item in bucket:
-                if len(batch) == 0 or (
-                    (batch_max_num_bytes is None or len(item) + batch_len < batch_max_num_bytes)
-                    and (batch_max_num_items is None or len(batch) < batch_max_num_items)
-                ):
-                    add_to_batch(item)
-                elif i == 0:
-                    # If the very first item in the bucket fills the batch, we
-                    # can flush this batch and start a new one which includes
-                    # this item.
-                    flush_batch()
-                    add_to_batch(item)
-                else:
-                    break
-                i += 1
-            # If we didn't completely exhaust the bucket, save it for the next
-            # batch set.
-            if i < len(bucket):
-                next_items.append(bucket[i:])
-            # If we have filled the batch, flush it.
-            if (batch_max_num_bytes is not None and batch_len >= batch_max_num_bytes) or (
-                batch_max_num_items is not None and len(batch) >= batch_max_num_items
-            ):
-                flush_batch()
-
-        # We've finished an iteration through all the buckets. Anything
-        # remaining in `next_items` will need to be processed in a subsequent
-        # batch set, so flush our remaining batch and the batch set, and use
-        # next_items for the next iteration.
-        if batch:
+    for item in items:
+        item_size = get_byte_size(item)
+        if len(batch) > 0 and not (
+            (batch_max_num_bytes is None or item_size + batch_len < batch_max_num_bytes)
+            and (batch_max_num_items is None or len(batch) < batch_max_num_items)
+        ):
             flush_batch()
-        if batch_set:
-            output.append(batch_set)
-            batch_set = []
-        items, next_items = next_items, []
+        add_to_batch(item)
+
+    if len(batch) > 0:
+        flush_batch()
 
     return output
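For non-string items, the new get_byte_size hook is the point of the rewrite: arbitrary objects can be measured however the caller likes, for example by their serialized size. A hedged usage sketch with the same assumed import path as above; the JSON-based size callback is illustrative, not something this diff prescribes:

import json
from typing import Any

from braintrust.merge_row_batch import batch_items  # import path assumed, not shown in the diff

rows: list[dict[str, Any]] = [{"id": f"r{i}", "output": "x" * 100} for i in range(25)]

batches = batch_items(
    rows,
    batch_max_num_items=10,                         # at most 10 rows per batch
    batch_max_num_bytes=1_000,                      # and roughly 1 KB of serialized payload per batch
    get_byte_size=lambda row: len(json.dumps(row)), # measure each row by its JSON length
)

for batch in batches:
    # In 0.5.3 each inner list is one flat batch; report its item count and byte size.
    print(len(batch), sum(len(json.dumps(r)) for r in batch))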