palimpzest 0.7.21__py3-none-any.whl → 0.8.1__py3-none-any.whl
This diff compares the contents of publicly available package versions as released to their respective public registries. It is provided for informational purposes only.
- palimpzest/__init__.py +37 -6
- palimpzest/agents/__init__.py +0 -0
- palimpzest/agents/compute_agents.py +0 -0
- palimpzest/agents/search_agents.py +637 -0
- palimpzest/constants.py +343 -209
- palimpzest/core/data/context.py +393 -0
- palimpzest/core/data/context_manager.py +163 -0
- palimpzest/core/data/dataset.py +639 -0
- palimpzest/core/data/{datareaders.py → iter_dataset.py} +202 -126
- palimpzest/core/elements/groupbysig.py +16 -13
- palimpzest/core/elements/records.py +166 -75
- palimpzest/core/lib/schemas.py +152 -390
- palimpzest/core/{data/dataclasses.py → models.py} +306 -170
- palimpzest/policy.py +2 -27
- palimpzest/prompts/__init__.py +35 -5
- palimpzest/prompts/agent_prompts.py +357 -0
- palimpzest/prompts/context_search.py +9 -0
- palimpzest/prompts/convert_prompts.py +62 -6
- palimpzest/prompts/filter_prompts.py +51 -6
- palimpzest/prompts/join_prompts.py +163 -0
- palimpzest/prompts/moa_proposer_convert_prompts.py +6 -6
- palimpzest/prompts/prompt_factory.py +375 -47
- palimpzest/prompts/split_proposer_prompts.py +1 -1
- palimpzest/prompts/util_phrases.py +5 -0
- palimpzest/prompts/validator.py +239 -0
- palimpzest/query/execution/all_sample_execution_strategy.py +134 -76
- palimpzest/query/execution/execution_strategy.py +210 -317
- palimpzest/query/execution/execution_strategy_type.py +5 -7
- palimpzest/query/execution/mab_execution_strategy.py +249 -136
- palimpzest/query/execution/parallel_execution_strategy.py +153 -244
- palimpzest/query/execution/single_threaded_execution_strategy.py +107 -64
- palimpzest/query/generators/generators.py +160 -331
- palimpzest/query/operators/__init__.py +15 -5
- palimpzest/query/operators/aggregate.py +50 -33
- palimpzest/query/operators/compute.py +201 -0
- palimpzest/query/operators/convert.py +33 -19
- palimpzest/query/operators/critique_and_refine_convert.py +7 -5
- palimpzest/query/operators/distinct.py +62 -0
- palimpzest/query/operators/filter.py +26 -16
- palimpzest/query/operators/join.py +403 -0
- palimpzest/query/operators/limit.py +3 -3
- palimpzest/query/operators/logical.py +205 -77
- palimpzest/query/operators/mixture_of_agents_convert.py +10 -8
- palimpzest/query/operators/physical.py +27 -21
- palimpzest/query/operators/project.py +3 -3
- palimpzest/query/operators/rag_convert.py +7 -7
- palimpzest/query/operators/retrieve.py +9 -9
- palimpzest/query/operators/scan.py +81 -42
- palimpzest/query/operators/search.py +524 -0
- palimpzest/query/operators/split_convert.py +10 -8
- palimpzest/query/optimizer/__init__.py +7 -9
- palimpzest/query/optimizer/cost_model.py +108 -441
- palimpzest/query/optimizer/optimizer.py +123 -181
- palimpzest/query/optimizer/optimizer_strategy.py +66 -61
- palimpzest/query/optimizer/plan.py +352 -67
- palimpzest/query/optimizer/primitives.py +43 -19
- palimpzest/query/optimizer/rules.py +484 -646
- palimpzest/query/optimizer/tasks.py +127 -58
- palimpzest/query/processor/config.py +42 -76
- palimpzest/query/processor/query_processor.py +73 -18
- palimpzest/query/processor/query_processor_factory.py +46 -38
- palimpzest/schemabuilder/schema_builder.py +15 -28
- palimpzest/utils/model_helpers.py +32 -77
- palimpzest/utils/progress.py +114 -102
- palimpzest/validator/__init__.py +0 -0
- palimpzest/validator/validator.py +306 -0
- {palimpzest-0.7.21.dist-info → palimpzest-0.8.1.dist-info}/METADATA +6 -1
- palimpzest-0.8.1.dist-info/RECORD +95 -0
- palimpzest/core/lib/fields.py +0 -141
- palimpzest/prompts/code_synthesis_prompts.py +0 -28
- palimpzest/query/execution/random_sampling_execution_strategy.py +0 -240
- palimpzest/query/generators/api_client_factory.py +0 -30
- palimpzest/query/operators/code_synthesis_convert.py +0 -488
- palimpzest/query/operators/map.py +0 -130
- palimpzest/query/processor/nosentinel_processor.py +0 -33
- palimpzest/query/processor/processing_strategy_type.py +0 -28
- palimpzest/query/processor/sentinel_processor.py +0 -88
- palimpzest/query/processor/streaming_processor.py +0 -149
- palimpzest/sets.py +0 -405
- palimpzest/utils/datareader_helpers.py +0 -61
- palimpzest/utils/demo_helpers.py +0 -75
- palimpzest/utils/field_helpers.py +0 -69
- palimpzest/utils/generation_helpers.py +0 -69
- palimpzest/utils/sandbox.py +0 -183
- palimpzest-0.7.21.dist-info/RECORD +0 -95
- /palimpzest/core/{elements/index.py → data/index_dataset.py} +0 -0
- {palimpzest-0.7.21.dist-info → palimpzest-0.8.1.dist-info}/WHEEL +0 -0
- {palimpzest-0.7.21.dist-info → palimpzest-0.8.1.dist-info}/licenses/LICENSE +0 -0
- {palimpzest-0.7.21.dist-info → palimpzest-0.8.1.dist-info}/top_level.txt +0 -0
```diff
--- palimpzest-0.7.21/palimpzest/core/elements/records.py
+++ palimpzest-0.8.1/palimpzest/core/elements/records.py
@@ -5,40 +5,60 @@ from collections.abc import Generator
 from typing import Any
 
 import pandas as pd
-
-from
-
-from palimpzest.core.
-from palimpzest.
+from pydantic import BaseModel
+from pydantic.fields import FieldInfo
+
+from palimpzest.core.data import context
+from palimpzest.core.lib.schemas import (
+    AudioBase64,
+    AudioFilepath,
+    ImageBase64,
+    ImageFilepath,
+    ImageURL,
+    create_schema_from_df,
+    project,
+    union_schemas,
+)
+from palimpzest.core.models import ExecutionStats, PlanStats, RecordOpStats
+from palimpzest.utils.hash_helpers import hash_for_id, hash_for_serialized_dict
 
 
 class DataRecord:
-    """A DataRecord is a single record of data matching some
+    """A DataRecord is a single record of data matching some schema defined by a BaseModel."""
 
     def __init__(
         self,
-        schema:
-
-
+        schema: BaseModel,
+        source_indices: str | list[str],
+        parent_ids: str | list[str] | None = None,
         cardinality_idx: int | None = None,
     ):
-        # check that
-        assert
+        # check that source_indices are provided
+        assert source_indices is not None, "Every DataRecord must be constructed with source index (or indices)"
+
+        # normalize to list[str]
+        if not isinstance(source_indices, list):
+            source_indices = [source_indices]
+
+        # normalize to list[str]
+        if isinstance(parent_ids, str):
+            parent_ids = [parent_ids]
 
         # schema for the data record
         self.schema = schema
 
         # mapping from field names to Field objects; effectively a mapping from a field name to its type
-        self.field_types: dict[str,
+        self.field_types: dict[str, FieldInfo] = schema.model_fields
 
         # mapping from field names to their values
         self.field_values: dict[str, Any] = {}
 
-        # the index in the
-
+        # the index in the root Dataset from which this DataRecord is derived;
+        # each source index takes the form: f"{root_dataset.id}-{idx}"
+        self.source_indices = sorted(source_indices)
 
-        # the id of the parent record(s) from which this DataRecord is derived
-        self.
+        # the id(s) of the parent record(s) from which this DataRecord is derived
+        self.parent_ids = parent_ids
 
         # store the cardinality index
         self.cardinality_idx = cardinality_idx
@@ -48,7 +68,7 @@ class DataRecord:
 
         # NOTE: Record ids are hashed based on:
         # 0. their schema (keys)
-        # 1. their parent record id(s) (or
+        # 1. their parent record id(s) (or source_indices if there is no parent record)
         # 2. their index in the fan out (if this is in a one-to-many operation)
         #
         # We currently do NOT hash just based on record content (i.e. schema (key, value) pairs)
@@ -59,9 +79,9 @@ class DataRecord:
 
         # unique identifier for the record
         id_str = (
-            str(schema) + (
+            str(schema) + str(parent_ids) if parent_ids is not None else str(self.source_indices)
             if cardinality_idx is None
-            else str(schema) + str(cardinality_idx) + str(
+            else str(schema) + str(cardinality_idx) + str(parent_ids) if parent_ids is not None else str(self.source_indices)
         )
         # TODO(Jun): build-in id should has a special name, the current self.id is too general which would conflict with user defined schema too easily.
         # the options: built_in_id, generated_id
@@ -69,7 +89,7 @@ class DataRecord:
 
 
     def __setattr__(self, name: str, value: Any, /) -> None:
-        if name in ["schema", "field_types", "field_values", "
+        if name in ["schema", "field_types", "field_values", "source_indices", "parent_ids", "cardinality_idx", "passed_operator", "id"]:
             super().__setattr__(name, value)
         else:
             self.field_values[name] = value
@@ -103,11 +123,10 @@ class DataRecord:
         return self.__str__(truncate=None)
 
     def __eq__(self, other):
-        return isinstance(other, DataRecord) and self.field_values == other.field_values and self.schema
-
+        return isinstance(other, DataRecord) and self.field_values == other.field_values and self.schema == other.schema
 
     def __hash__(self):
-        return hash(self.to_json_str(bytes_to_str=True))
+        return hash(self.to_json_str(bytes_to_str=True, sorted=True))
 
 
     def __iter__(self):
@@ -118,16 +137,16 @@ class DataRecord:
         return list(self.field_values.keys())
 
 
-    def get_field_type(self, field_name: str) ->
+    def get_field_type(self, field_name: str) -> FieldInfo:
         return self.field_types[field_name]
 
 
     def copy(self, include_bytes: bool = True, project_cols: list[str] | None = None):
-        # make
+        # make copy of the current record
         new_dr = DataRecord(
             self.schema,
-
-
+            source_indices=self.source_indices,
+            parent_ids=self.parent_ids,
             cardinality_idx=self.cardinality_idx,
         )
 
@@ -158,25 +177,28 @@ class DataRecord:
 
     @staticmethod
     def from_parent(
-        schema:
+        schema: BaseModel,
         parent_record: DataRecord,
         project_cols: list[str] | None = None,
         cardinality_idx: int | None = None,
     ) -> DataRecord:
-        # project_cols must be None or contain at least one column
-        assert project_cols is None or len(project_cols) >= 1, "must have at least one column if using projection"
-
         # if project_cols is None, then the new schema is a union of the provided schema and parent_record.schema;
+        # if project_cols is an empty list, then the new schema is simply the provided schema
         # otherwise, it's a ProjectSchema
-        new_schema =
-        if project_cols is
-            new_schema =
+        new_schema = None
+        if project_cols is None:
+            new_schema = union_schemas([schema, parent_record.schema])
+        elif project_cols == []:
+            new_schema = schema
+        else:
+            new_schema = union_schemas([schema, parent_record.schema])
+            new_schema = project(new_schema, project_cols)
 
-        # make new record which has parent_record as its parent (and the same
+        # make new record which has parent_record as its parent (and the same source_indices)
         new_dr = DataRecord(
             new_schema,
-
-
+            source_indices=parent_record.source_indices,
+            parent_ids=[parent_record.id],
             cardinality_idx=cardinality_idx,
         )
 
@@ -194,35 +216,78 @@ class DataRecord:
 
     @staticmethod
     def from_agg_parents(
-        schema:
+        schema: BaseModel,
         parent_records: DataRecordSet,
-        project_cols: list[str] | None = None,
         cardinality_idx: int | None = None,
     ) -> DataRecord:
-        #
-
-
+        # flatten source indices from all parents
+        source_indices = [
+            source_idx
+            for parent_record in parent_records
+            for source_idx in parent_record.source_indices
+        ]
+
+        # make new record which has all parent records as its parents
+        return DataRecord(
+            schema,
+            source_indices=source_indices,
+            parent_ids=[parent_record.id for parent_record in parent_records],
+            cardinality_idx=cardinality_idx,
+        )
 
     @staticmethod
     def from_join_parents(
-
-        right_schema: Schema,
+        schema: BaseModel,
         left_parent_record: DataRecord,
         right_parent_record: DataRecord,
         project_cols: list[str] | None = None,
         cardinality_idx: int = None,
     ) -> DataRecord:
-        #
-
+        # make new record which has left and right parent record as its parents
+        new_dr = DataRecord(
+            schema,
+            source_indices=list(left_parent_record.source_indices) + list(right_parent_record.source_indices),
+            parent_ids=[left_parent_record.id, right_parent_record.id],
+            cardinality_idx=cardinality_idx,
+        )
+
+        # get the set of fields and field descriptions to copy from the parent record(s)
+        left_copy_field_names = (
+            left_parent_record.get_field_names()
+            if project_cols is None
+            else [col for col in project_cols if col in left_parent_record.get_field_names()]
+        )
+        right_copy_field_names = (
+            right_parent_record.get_field_names()
+            if project_cols is None
+            else [col for col in project_cols if col in right_parent_record.get_field_names()]
+        )
+        left_copy_field_names = [field.split(".")[-1] for field in left_copy_field_names]
+        right_copy_field_names = [field.split(".")[-1] for field in right_copy_field_names]
 
+        # copy fields from the parents
+        for field_name in left_copy_field_names:
+            new_dr.field_types[field_name] = left_parent_record.get_field_type(field_name)
+            new_dr[field_name] = left_parent_record[field_name]
 
+        for field_name in right_copy_field_names:
+            new_field_name = field_name
+            if field_name in left_copy_field_names:
+                new_field_name = f"{field_name}_right"
+            new_dr.field_types[new_field_name] = right_parent_record.get_field_type(field_name)
+            new_dr[new_field_name] = right_parent_record[field_name]
+
+        return new_dr
+
+
+    # TODO: unused outside of unit tests
     @staticmethod
-    def from_df(df: pd.DataFrame, schema:
+    def from_df(df: pd.DataFrame, schema: BaseModel | None = None) -> list[DataRecord]:
         """Create a list of DataRecords from a pandas DataFrame
 
         Args:
             df (pd.DataFrame): Input DataFrame
-            schema (
+            schema (BaseModel, optional): Schema for the DataRecords. If None, will be derived from DataFrame
 
         Returns:
             list[DataRecord]: List of DataRecord instances
@@ -230,16 +295,23 @@ class DataRecord:
         if df is None:
             raise ValueError("DataFrame is None!")
 
-
+        # create schema if one isn't provided
         if schema is None:
-            schema =
+            schema = create_schema_from_df(df)
+
+        # create an id for the dataset from the schema
+        dataset_id = hash_for_serialized_dict({
+            k: {"annotation": str(v.annotation), "default": str(v.default), "description": v.description}
+            for k, v in schema.model_fields.items()
+        })
 
-
-
+        # create records
+        records = []
+        for idx, row in df.iterrows():
            row_dict = row.to_dict()
-            record = DataRecord(schema=schema,
+            record = DataRecord(schema=schema, source_indices=[f"{dataset_id}-{idx}"])
             record.field_values = row_dict
-            record.field_types = {field_name:
+            record.field_types = {field_name: schema.model_fields[field_name] for field_name in row_dict}
             records.append(record)
 
         return records
@@ -253,33 +325,41 @@ class DataRecord:
         if project_cols is not None and len(project_cols) > 0:
             fields = [field for field in fields if field in project_cols]
 
+        # convert Context --> str
+        for record in records:
+            for k in fields:
+                if isinstance(record[k], context.Context):
+                    record[k] = record[k].description
+
         return pd.DataFrame([
             {k: record[k] for k in fields}
             for record in records
         ])
 
-    def to_json_str(self, include_bytes: bool = True, bytes_to_str: bool = False, project_cols: list[str] | None = None):
+    def to_json_str(self, include_bytes: bool = True, bytes_to_str: bool = False, project_cols: list[str] | None = None, sorted: bool = False):
         """Return a JSON representation of this DataRecord"""
-        record_dict = self.to_dict(include_bytes, bytes_to_str, project_cols)
-        record_dict = {
-            field_name: self.schema.field_to_json(field_name, field_value)
-            for field_name, field_value in record_dict.items()
-        }
+        record_dict = self.to_dict(include_bytes, bytes_to_str, project_cols, sorted)
         return json.dumps(record_dict, indent=2)
 
-    def to_dict(self, include_bytes: bool = True, bytes_to_str: bool = False, project_cols: list[str] | None = None):
+    def to_dict(self, include_bytes: bool = True, bytes_to_str: bool = False, project_cols: list[str] | None = None, _sorted: bool = False, mask_filepaths: bool = False):
         """Return a dictionary representation of this DataRecord"""
         # TODO(chjun): In case of numpy types, the json.dumps will fail. Convert to native types.
         # Better ways to handle this.
-
+        field_values = {
+            k: v.description
+            if isinstance(v, context.Context) else v
+            for k, v in self.field_values.items()
+        }
+        dct = pd.Series(field_values).to_dict()
 
         if project_cols is not None and len(project_cols) > 0:
             project_field_names = set(field.split(".")[-1] for field in project_cols)
             dct = {k: v for k, v in dct.items() if k in project_field_names}
 
         if not include_bytes:
-            for k
-
+            for k in dct:
+                field_type = self.field_types[k]
+                if field_type.annotation in [bytes, AudioBase64, ImageBase64, list[bytes], list[ImageBase64]]:
                    dct[k] = "<bytes>"
 
         if bytes_to_str:
@@ -289,12 +369,21 @@ class DataRecord:
                 elif isinstance(v, list) and len(v) > 0 and any([isinstance(elt, bytes) for elt in v]):
                     dct[k] = [elt.decode("utf-8") if isinstance(elt, bytes) else elt for elt in v]
 
+        if _sorted:
+            dct = dict(sorted(dct.items()))
+
+        if mask_filepaths:
+            for k in dct:
+                field_type = self.field_types[k]
+                if field_type.annotation in [AudioBase64, AudioFilepath, ImageBase64, ImageFilepath, ImageURL]:
+                    dct[k] = "<bytes>"
+
         return dct
 
 
 class DataRecordSet:
     """
-    A DataRecordSet contains a list of DataRecords that share the same schema, same
+    A DataRecordSet contains a list of DataRecords that share the same schema, same parent(s), and same source(s).
 
     We explicitly check that this is True.
 
@@ -305,20 +394,22 @@ class DataRecordSet:
         data_records: list[DataRecord],
         record_op_stats: list[RecordOpStats],
         field_to_score_fn: dict[str, str | callable] | None = None,
+        input: int | DataRecord | list[DataRecord] | tuple[list[DataRecord]] | None = None,
     ):
-        #
-        if len(data_records) > 0:
-            parent_id = data_records[0].parent_id
-            error_msg = "DataRecordSet must be constructed from the output of executing a single operator on a single input."
-            assert all([dr.parent_id == parent_id for dr in data_records]), error_msg
-
-        # set data_records, parent_id, and source_idx; note that it is possible for
+        # set data_records, parent_ids, and source_indices; note that it is possible for
         # data_records to be an empty list in the event of a failed convert
         self.data_records = data_records
-        self.
-        self.
+        self.parent_ids = data_records[0].parent_ids if len(data_records) > 0 else None
+        self.source_indices = data_records[0].source_indices if len(data_records) > 0 else None
         self.schema = data_records[0].schema if len(data_records) > 0 else None
 
+        # the input to the operator which produced the data_records; type is tuple[DataRecord] | tuple[int]
+        # - for scan operators, input is a singleton tuple[int] which wraps the source_idx, e.g.: (source_idx,)
+        # - for join operators, input is a tuple with one entry for the left input DataRecord and one entry for the right input DataRecord
+        # - for aggregate operators, input is a tuple with all the input DataRecords to the aggregation
+        # - for all other operaotrs, input is a singleton tuple[DataRecord] which wraps the single input
+        self.input = input
+
         # set statistics for generating these records
         self.record_op_stats = record_op_stats
 
@@ -350,7 +441,7 @@ class DataRecordCollection:
     The difference between DataRecordSet and DataRecordCollection
 
     Goal:
-        DataRecordSet is a set of DataRecords that share the same schema, same
+        DataRecordSet is a set of DataRecords that share the same schema, same parents, and same sources.
        DataRecordCollection is a general wrapper for list[DataRecord].
 
    Usage:
```