palimpzest 0.7.21__py3-none-any.whl → 0.8.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- palimpzest/__init__.py +37 -6
- palimpzest/agents/__init__.py +0 -0
- palimpzest/agents/compute_agents.py +0 -0
- palimpzest/agents/search_agents.py +637 -0
- palimpzest/constants.py +343 -209
- palimpzest/core/data/context.py +393 -0
- palimpzest/core/data/context_manager.py +163 -0
- palimpzest/core/data/dataset.py +639 -0
- palimpzest/core/data/{datareaders.py → iter_dataset.py} +202 -126
- palimpzest/core/elements/groupbysig.py +16 -13
- palimpzest/core/elements/records.py +166 -75
- palimpzest/core/lib/schemas.py +152 -390
- palimpzest/core/{data/dataclasses.py → models.py} +306 -170
- palimpzest/policy.py +2 -27
- palimpzest/prompts/__init__.py +35 -5
- palimpzest/prompts/agent_prompts.py +357 -0
- palimpzest/prompts/context_search.py +9 -0
- palimpzest/prompts/convert_prompts.py +62 -6
- palimpzest/prompts/filter_prompts.py +51 -6
- palimpzest/prompts/join_prompts.py +163 -0
- palimpzest/prompts/moa_proposer_convert_prompts.py +6 -6
- palimpzest/prompts/prompt_factory.py +375 -47
- palimpzest/prompts/split_proposer_prompts.py +1 -1
- palimpzest/prompts/util_phrases.py +5 -0
- palimpzest/prompts/validator.py +239 -0
- palimpzest/query/execution/all_sample_execution_strategy.py +134 -76
- palimpzest/query/execution/execution_strategy.py +210 -317
- palimpzest/query/execution/execution_strategy_type.py +5 -7
- palimpzest/query/execution/mab_execution_strategy.py +249 -136
- palimpzest/query/execution/parallel_execution_strategy.py +153 -244
- palimpzest/query/execution/single_threaded_execution_strategy.py +107 -64
- palimpzest/query/generators/generators.py +160 -331
- palimpzest/query/operators/__init__.py +15 -5
- palimpzest/query/operators/aggregate.py +50 -33
- palimpzest/query/operators/compute.py +201 -0
- palimpzest/query/operators/convert.py +33 -19
- palimpzest/query/operators/critique_and_refine_convert.py +7 -5
- palimpzest/query/operators/distinct.py +62 -0
- palimpzest/query/operators/filter.py +26 -16
- palimpzest/query/operators/join.py +403 -0
- palimpzest/query/operators/limit.py +3 -3
- palimpzest/query/operators/logical.py +205 -77
- palimpzest/query/operators/mixture_of_agents_convert.py +10 -8
- palimpzest/query/operators/physical.py +27 -21
- palimpzest/query/operators/project.py +3 -3
- palimpzest/query/operators/rag_convert.py +7 -7
- palimpzest/query/operators/retrieve.py +9 -9
- palimpzest/query/operators/scan.py +81 -42
- palimpzest/query/operators/search.py +524 -0
- palimpzest/query/operators/split_convert.py +10 -8
- palimpzest/query/optimizer/__init__.py +7 -9
- palimpzest/query/optimizer/cost_model.py +108 -441
- palimpzest/query/optimizer/optimizer.py +123 -181
- palimpzest/query/optimizer/optimizer_strategy.py +66 -61
- palimpzest/query/optimizer/plan.py +352 -67
- palimpzest/query/optimizer/primitives.py +43 -19
- palimpzest/query/optimizer/rules.py +484 -646
- palimpzest/query/optimizer/tasks.py +127 -58
- palimpzest/query/processor/config.py +42 -76
- palimpzest/query/processor/query_processor.py +73 -18
- palimpzest/query/processor/query_processor_factory.py +46 -38
- palimpzest/schemabuilder/schema_builder.py +15 -28
- palimpzest/utils/model_helpers.py +32 -77
- palimpzest/utils/progress.py +114 -102
- palimpzest/validator/__init__.py +0 -0
- palimpzest/validator/validator.py +306 -0
- {palimpzest-0.7.21.dist-info → palimpzest-0.8.1.dist-info}/METADATA +6 -1
- palimpzest-0.8.1.dist-info/RECORD +95 -0
- palimpzest/core/lib/fields.py +0 -141
- palimpzest/prompts/code_synthesis_prompts.py +0 -28
- palimpzest/query/execution/random_sampling_execution_strategy.py +0 -240
- palimpzest/query/generators/api_client_factory.py +0 -30
- palimpzest/query/operators/code_synthesis_convert.py +0 -488
- palimpzest/query/operators/map.py +0 -130
- palimpzest/query/processor/nosentinel_processor.py +0 -33
- palimpzest/query/processor/processing_strategy_type.py +0 -28
- palimpzest/query/processor/sentinel_processor.py +0 -88
- palimpzest/query/processor/streaming_processor.py +0 -149
- palimpzest/sets.py +0 -405
- palimpzest/utils/datareader_helpers.py +0 -61
- palimpzest/utils/demo_helpers.py +0 -75
- palimpzest/utils/field_helpers.py +0 -69
- palimpzest/utils/generation_helpers.py +0 -69
- palimpzest/utils/sandbox.py +0 -183
- palimpzest-0.7.21.dist-info/RECORD +0 -95
- /palimpzest/core/{elements/index.py → data/index_dataset.py} +0 -0
- {palimpzest-0.7.21.dist-info → palimpzest-0.8.1.dist-info}/WHEEL +0 -0
- {palimpzest-0.7.21.dist-info → palimpzest-0.8.1.dist-info}/licenses/LICENSE +0 -0
- {palimpzest-0.7.21.dist-info → palimpzest-0.8.1.dist-info}/top_level.txt +0 -0
palimpzest/core/data/dataset.py (new file):

@@ -0,0 +1,639 @@
+from __future__ import annotations
+
+import warnings
+from collections.abc import Iterator
+from typing import Callable
+
+from chromadb.api.models.Collection import Collection
+from pydantic import BaseModel
+
+from palimpzest.constants import AggFunc, Cardinality
+from palimpzest.core.elements.filters import Filter
+from palimpzest.core.elements.groupbysig import GroupBySig
+from palimpzest.core.lib.schemas import create_schema_from_fields, project, union_schemas
+from palimpzest.policy import construct_policy_from_kwargs
+from palimpzest.query.operators.logical import (
+    Aggregate,
+    ConvertScan,
+    Distinct,
+    FilteredScan,
+    GroupByAggregate,
+    JoinOp,
+    LimitScan,
+    LogicalOperator,
+    Project,
+    RetrieveScan,
+)
+from palimpzest.query.processor.config import QueryProcessorConfig
+from palimpzest.utils.hash_helpers import hash_for_serialized_dict
+from palimpzest.validator.validator import Validator
+
+
+# TODO?: remove `schema` from `Dataset` and access it from `operator`?
+# - Q: how do you handle datasets with multiple sources?
+#   - for joins the operator should have the union'ed schema
+#   - but for Contexts it may be trickier
+class Dataset:
+    """
+    A `Dataset` represents a collection of structured or unstructured data that can be processed and
+    transformed. Each `Dataset` is either a "root" `Dataset` (which yields data items) or it is the
+    result of performing data processing operation(s) on root `Dataset(s)`.
+
+    Users can perform computations on a `Dataset` in a lazy or eager fashion. Applying functions
+    such as `sem_filter`, `sem_map`, `sem_join`, `sem_agg`, etc. will lazily create a new `Dataset`.
+    Users can invoke the `run()` method to execute the computation and retrieve a materialized `Dataset`.
+    Materialized `Dataset`s can be processed further, or their results can be retrieved using `.get()`.
+
+    A root `Dataset` must subclass at least one of `pz.IterDataset`, `pz.IndexDataset`, or `pz.Context`.
+    Each of these classes supports a different access pattern:
+
+    - `pz.IterDataset`: supports accessing data via iteration
+        - Ex: iterating over a list of PDFs
+        - Ex: iterating over rows in a DataFrame
+    - `pz.IndexDataset`: supports accessing data via point lookups / queries
+        - Ex: querying a vector database
+        - Ex: querying a SQL database
+    - `pz.Context`: supports accessing data with an agent
+        - Ex: processing a set of CSV files with a data science agent
+        - Ex: processing time series data with a data cleaning agent
+
+    A root `Dataset` may subclass more than one of the aforementioned classes. For example, the root
+    `Dataset` for a list of files may inherit from `pz.IterDataset` and `pz.IndexDataset` to support
+    iterating over the files and performing point lookups for individual files.
+
+    For details on how to create your own root `Dataset`, please see: TODO
+    """
+    def __init__(
+        self,
+        sources: list[Dataset] | Dataset | None,
+        operator: LogicalOperator,
+        schema: type[BaseModel] | None = None,
+        id: str | None = None,
+    ) -> None:
+        """
+        Initialize a `Dataset` with one or more `sources` and the operator that is being applied.
+        Root `Datasets` subclass `pz.IterDataset`, `pz.IndexDataset`, and/or `pz.Context` and use
+        their own constructors.
+
+        Args:
+            sources (`list[Dataset] | Dataset`): The (list of) `Dataset(s)` which are input(s) to
+                the operator used to compute this `Dataset`.
+            operator (`LogicalOperator`): The `LogicalOperator` used to compute this `Dataset`.
+            schema (type[`BaseModel`] | None): The schema of this `Dataset`.
+            id (str | None): an identifier for this `Dataset` provided by the user
+
+        Raises:
+            ValueError: if `sources` is not a `Dataset` or list of `Datasets`
+        """
+        # set sources
+        self._sources = []
+        if isinstance(sources, list):
+            self._sources = sources
+        elif isinstance(sources, Dataset):
+            self._sources = [sources]
+        elif sources is not None:
+            raise ValueError("Dataset sources must be another Dataset or a list of Datasets. For root Datasets, you must subclass pz.IterDataset, pz.IndexDataset, or pz.Context.")
+
+        # set the logical operator and schema
+        self._operator: LogicalOperator = operator
+        self._schema = schema
+
+        # compute the dataset id
+        self._id = self._compute_dataset_id() if id is None else id
+
+    @property
+    def id(self) -> str:
+        """The string identifier for this `Dataset`"""
+        return self._id
+
+    @property
+    def schema(self) -> type[BaseModel]:
+        """The Pydantic model defining the schema of this `Dataset`"""
+        return self._schema
+
+    @property
+    def is_root(self) -> bool:
+        return len(self._sources) == 0
+
+    def __str__(self) -> str:
+        return f"Dataset(schema={self._schema}, id={self._id}, op_id={self._operator.get_logical_op_id()})"
+
+    def __iter__(self) -> Iterator[Dataset]:
+        for source in self._sources:
+            yield from source
+        yield self
+
+    def _compute_dataset_id(self) -> str:
+        """
+        Compute the identifier for this `Dataset`. The ID is uniquely defined by the operation(s)
+        applied to the `Dataset's` sources.
+        """
+        return hash_for_serialized_dict({
+            "source_ids": [source.id for source in self._sources],
+            "logical_op_id": self._operator.get_logical_op_id(),
+        })
+
+    def _set_root_datasets(self, new_root_datasets: dict[str, Dataset]) -> None:
+        """
+        Update the root dataset(s) for this dataset with the `new_root_datasets`. This is used during
+        optimization to reuse the same physical plan while running it on a train dataset.
+
+        Args:
+            new_root_datasets (dict[str, Dataset]): the new root datasets for this dataset.
+        """
+        new_sources = []
+        for old_source in self._sources:
+            if old_source.id in new_root_datasets:
+                new_sources.append(new_root_datasets[old_source.id])
+            else:
+                old_source._set_root_datasets(new_root_datasets)
+                new_sources.append(old_source)
+        self._sources = new_sources
+
+    # TODO: the entire way (unique) logical op ids are computed and stored needs to be rethought
+    def _generate_unique_logical_op_ids(self, topo_idx: int | None = None) -> None:
+        """
+        Generate unique operation IDs for all operators in this dataset and its sources.
+        This is used to ensure that each operator can be uniquely identified during execution.
+        """
+        # generate the unique op ids for all sources' operators
+        for source in self._sources:
+            topo_idx = source._generate_unique_logical_op_ids(topo_idx)
+            topo_idx += 1
+
+        # if topo_idx is None, this is the first call, so we initialize it to 0
+        if topo_idx is None:
+            topo_idx = 0
+
+        # compute this operator's unique operator id
+        this_unique_logical_op_id = f"{topo_idx}-{self._operator.get_logical_op_id()}"
+
+        # update the unique logical op id for this operator
+        self._operator.set_unique_logical_op_id(this_unique_logical_op_id)
+
+        # return the current unique full_op_id for this operator
+        return topo_idx
+
+    # TODO
+    def _resolve_depends_on(self, depends_on: list[str]) -> list[str]:
+        """
+        TODO: resolve the `depends_on` strings to their full field names ({Dataset.id}.{field_name}).
+        """
+        return []
+
+    def _get_root_datasets(self) -> dict[str, Dataset]:
+        """Return a mapping from the id --> Dataset for all root datasets."""
+        if self.is_root:
+            return {self.id: self}
+
+        root_datasets = {}
+        for source in self._sources:
+            child_root_datasets = source._get_root_datasets()
+            root_datasets = {**root_datasets, **child_root_datasets}
+
+        return root_datasets
+
+    def get_upstream_datasets(self) -> list[Dataset]:
+        """
+        Get the list of all upstream datasets that are sources to this dataset.
+        """
+        # recursively get the upstream datasets
+        upstream = []
+        for source in self._sources:
+            upstream.extend(source.get_upstream_datasets())
+            upstream.append(source)
+        return upstream
+
+    def get_limit(self) -> int | None:
+        """Get the limit applied to this Dataset, if any."""
+        if isinstance(self._operator, LimitScan):
+            return self._operator.limit
+
+        source_limits = []
+        for source in self._sources:
+            source_limit = source.get_limit()
+            if source_limit is not None:
+                source_limits.append(source_limit)
+
+        if len(source_limits) == 0:
+            return None
+
+        return min([limit for limit in source_limits if limit is not None])
+
+    def copy(self):
+        return Dataset(
+            sources=[source.copy() for source in self._sources],
+            operator=self._operator.copy(),
+            schema=self._schema,
+            id=self.id,
+        )
+
+    def sem_join(self, other: Dataset, condition: str, desc: str | None = None, depends_on: str | list[str] | None = None) -> Dataset:
+        """
+        Perform a semantic (inner) join on the specified join predicate
+        """
+        # enforce type for depends_on
+        if isinstance(depends_on, str):
+            depends_on = [depends_on]
+
+        # construct new output schema
+        combined_schema = union_schemas([self.schema, other.schema], join=True)
+
+        # construct logical operator
+        operator = JoinOp(
+            input_schema=combined_schema,
+            output_schema=combined_schema,
+            condition=condition,
+            desc=desc,
+            depends_on=depends_on,
+        )
+
+        return Dataset(sources=[self, other], operator=operator, schema=combined_schema)
+
+    def filter(
+        self,
+        filter: Callable,
+        depends_on: str | list[str] | None = None,
+    ) -> Dataset:
+        """Add a user defined function as a filter to the Set. This filter will possibly restrict the items that are returned later."""
+        # construct Filter object
+        f = None
+        if callable(filter):
+            f = Filter(filter_fn=filter)
+        else:
+            error_str = f"Only support callable for filter, currently got {type(filter)}"
+            if isinstance(filter, str):
+                error_str += ". Consider using sem_filter() for semantic filters."
+            raise Exception(error_str)
+
+        # enforce type for depends_on
+        if isinstance(depends_on, str):
+            depends_on = [depends_on]
+
+        # construct logical operator
+        operator = FilteredScan(input_schema=self.schema, output_schema=self.schema, filter=f, depends_on=depends_on)
+
+        return Dataset(sources=[self], operator=operator, schema=self.schema)
+
+    def sem_filter(
+        self,
+        filter: str,
+        desc: str | None = None,
+        depends_on: str | list[str] | None = None,
+    ) -> Dataset:
+        """Add a natural language description of a filter to the Set. This filter will possibly restrict the items that are returned later."""
+        # construct Filter object
+        f = None
+        if isinstance(filter, str):
+            f = Filter(filter)
+        else:
+            raise Exception("sem_filter() only supports `str` input for _filter.", type(filter))
+
+        # enforce type for depends_on
+        if isinstance(depends_on, str):
+            depends_on = [depends_on]
+
+        # construct logical operator
+        operator = FilteredScan(input_schema=self.schema, output_schema=self.schema, filter=f, desc=desc, depends_on=depends_on)
+
+        return Dataset(sources=[self], operator=operator, schema=self.schema)
+
+    def _sem_map(self, cols: list[dict] | type[BaseModel] | None,
+                 cardinality: Cardinality,
+                 desc: str | None = None,
+                 depends_on: str | list[str] | None = None) -> Dataset:
+        """Execute the semantic map operation with the appropriate cardinality."""
+        # construct new output schema
+        new_output_schema = None
+        if cols is None:
+            new_output_schema = self.schema
+        elif isinstance(cols, list):
+            cols = create_schema_from_fields(cols)
+            new_output_schema = union_schemas([self.schema, cols])
+        elif issubclass(cols, BaseModel):
+            new_output_schema = union_schemas([self.schema, cols])
+        else:
+            raise ValueError("`cols` must be a list of dictionaries or a BaseModel.")
+
+        # enforce type for depends_on
+        if isinstance(depends_on, str):
+            depends_on = [depends_on]
+
+        # construct logical operator
+        operator = ConvertScan(
+            input_schema=self.schema,
+            output_schema=new_output_schema,
+            cardinality=cardinality,
+            udf=None,
+            desc=desc,
+            depends_on=depends_on,
+        )
+
+        return Dataset(sources=[self], operator=operator, schema=new_output_schema)
+
+
+    def sem_add_columns(self, cols: list[dict] | type[BaseModel],
+                        cardinality: Cardinality = Cardinality.ONE_TO_ONE,
+                        desc: str | None = None,
+                        depends_on: str | list[str] | None = None) -> Dataset:
+        """
+        NOTE: we are renaming this function to `sem_map` and deprecating `sem_add_columns` in the next
+        release of PZ. To update your code, simply change your calls from `.sem_add_columns(...)` to `.sem_map(...)`.
+        The function arguments will stay the same.
+
+        Add new columns by specifying the column names, descriptions, and types.
+        The column will be computed during the execution of the Dataset.
+        Example:
+            sem_add_columns(
+                [{'name': 'greeting', 'desc': 'The greeting message', 'type': str},
+                 {'name': 'age', 'desc': 'The age of the person', 'type': int},
+                 {'name': 'full_name', 'desc': 'The name of the person', 'type': str}]
+            )
+        """
+        # issue deprecation warning
+        warnings.warn(
+            "we are renaming this function to `sem_map` and deprecating `sem_add_columns` in the next"
+            " release of PZ. To update your code, simply change your calls from `.sem_add_columns(...)`"
+            " to `.sem_map(...)`. The function arguments will stay the same.",
+            DeprecationWarning,
+            stacklevel=2
+        )
+
+        return self._sem_map(cols, cardinality, desc, depends_on)
+
+    def sem_map(self, cols: list[dict] | type[BaseModel], desc: str | None = None, depends_on: str | list[str] | None = None) -> Dataset:
+        """
+        Compute new field(s) by specifying their names, descriptions, and types. For each input there will
+        be one output. The field(s) will be computed during the execution of the Dataset.
+
+        Example:
+            sem_map(
+                [{'name': 'greeting', 'desc': 'The greeting message', 'type': str},
+                 {'name': 'age', 'desc': 'The age of the person', 'type': int},
+                 {'name': 'full_name', 'desc': 'The name of the person', 'type': str}]
+            )
+        """
+        return self._sem_map(cols, Cardinality.ONE_TO_ONE, desc, depends_on)
+
+    def sem_flat_map(self, cols: list[dict] | type[BaseModel], desc: str | None = None, depends_on: str | list[str] | None = None) -> Dataset:
+        """
+        Compute new field(s) by specifying their names, descriptions, and types. For each input there will
+        be one or more output(s). The field(s) will be computed during the execution of the Dataset.
+
+        Example:
+            sem_flat_map(
+                cols=[
+                    {'name': 'author_name', 'description': 'The name of the author', 'type': str},
+                    {'name': 'institution', 'description': 'The institution of the author', 'type': str},
+                    {'name': 'email', 'description': 'The author's email', 'type': str},
+                ]
+            )
+        """
+        return self._sem_map(cols, Cardinality.ONE_TO_MANY, desc, depends_on)
+
+    def _map(self, udf: Callable,
+             cols: list[dict] | type[BaseModel] | None,
+             cardinality: Cardinality,
+             depends_on: str | list[str] | None = None) -> Dataset:
+        """Execute the map operation with the appropriate cardinality."""
+        # construct new output schema
+        new_output_schema = None
+        if cols is None:
+            new_output_schema = self.schema
+        elif isinstance(cols, list):
+            cols = create_schema_from_fields(cols)
+            new_output_schema = union_schemas([self.schema, cols])
+        elif issubclass(cols, BaseModel):
+            new_output_schema = union_schemas([self.schema, cols])
+        else:
+            raise ValueError("`cols` must be a list of dictionaries, a BaseModel, or None.")
+
+        # enforce type for depends_on
+        if isinstance(depends_on, str):
+            depends_on = [depends_on]
+
+        # construct logical operator
+        operator = ConvertScan(
+            input_schema=self.schema,
+            output_schema=new_output_schema,
+            cardinality=cardinality,
+            udf=udf,
+            depends_on=depends_on,
+        )
+
+        return Dataset(sources=[self], operator=operator, schema=new_output_schema)
+
+    def add_columns(self, udf: Callable,
+                    cols: list[dict] | type[BaseModel] | None,
+                    cardinality: Cardinality = Cardinality.ONE_TO_ONE,
+                    depends_on: str | list[str] | None = None) -> Dataset:
+        """
+        NOTE: we are renaming this function to `map` and deprecating `add_columns` in the next
+        release of PZ. To update your code, simply change your calls from `.add_columns(...)` to `.map(...)`.
+        The function arguments will stay the same.
+
+        Compute new fields (or update existing ones) with a UDF. For each input, this function will compute one output.
+
+        Set `cols=None` if your add_columns operation is not computing any new fields.
+
+        Examples:
+            add_columns(
+                udf=compute_personal_greeting,
+                cols=[
+                    {'name': 'greeting', 'description': 'The greeting message', 'type': str},
+                    {'name': 'age', 'description': 'The age of the person', 'type': int},
+                    {'name': 'full_name', 'description': 'The name of the person', 'type': str},
+                ]
+            )
+        """
+        # issue deprecation warning
+        warnings.warn(
+            "we are renaming this function to `map` and deprecating `add_columns` in the next"
+            " release of PZ. To update your code, simply change your calls from `.add_columns(...)`"
+            " to `.map(...)`. The function arguments will stay the same.",
+            DeprecationWarning,
+            stacklevel=2
+        )
+
+        # sanity check inputs
+        if udf is None:
+            raise ValueError("`udf` must be provided for add_columns.")
+
+        return self._map(udf, cols, cardinality, depends_on)
+
+    def map(self, udf: Callable,
+            cols: list[dict] | type[BaseModel] | None,
+            depends_on: str | list[str] | None = None) -> Dataset:
+        """
+        Compute new fields (or update existing ones) with a UDF. For each input, this function will compute one output.
+
+        Set `cols=None` if your map is not computing any new fields.
+
+        Examples:
+            map(
+                udf=compute_personal_greeting,
+                cols=[
+                    {'name': 'greeting', 'description': 'The greeting message', 'type': str},
+                    {'name': 'age', 'description': 'The age of the person', 'type': int},
+                    {'name': 'full_name', 'description': 'The name of the person', 'type': str},
+                ]
+            )
+        """
+        # sanity check inputs
+        if udf is None:
+            raise ValueError("`udf` must be provided for map.")
+
+        return self._map(udf, cols, Cardinality.ONE_TO_ONE, depends_on)
+
+    def flat_map(self, udf: Callable,
+                 cols: list[dict] | type[BaseModel] | None,
+                 depends_on: str | list[str] | None = None) -> Dataset:
+        """
+        Compute new fields (or update existing ones) with a UDF. For each input, this function will compute one or more outputs.
+
+        Set `cols=None` if your flat_map is not computing any new fields.
+
+        Examples:
+            flat_map(
+                udf=extract_paper_authors,
+                cols=[
+                    {'name': 'author_name', 'description': 'The name of the author', 'type': str},
+                    {'name': 'institution', 'description': 'The institution of the author', 'type': str},
+                    {'name': 'email', 'description': 'The author's email', 'type': str},
+                ]
+            )
+        """
+        # sanity check inputs
+        if udf is None:
+            raise ValueError("`udf` must be provided for map.")
+
+        return self._map(udf, cols, Cardinality.ONE_TO_MANY, depends_on)
+
+    def count(self) -> Dataset:
+        """Apply a count aggregation to this set"""
+        operator = Aggregate(input_schema=self.schema, agg_func=AggFunc.COUNT)
+        return Dataset(sources=[self], operator=operator, schema=operator.output_schema)
+
+    def average(self) -> Dataset:
+        """Apply an average aggregation to this set"""
+        operator = Aggregate(input_schema=self.schema, agg_func=AggFunc.AVERAGE)
+        return Dataset(sources=[self], operator=operator, schema=operator.output_schema)
+
+    def groupby(self, groupby: GroupBySig) -> Dataset:
+        output_schema = groupby.output_schema()
+        operator = GroupByAggregate(input_schema=self.schema, output_schema=output_schema, group_by_sig=groupby)
+        return Dataset(sources=[self], operator=operator, schema=output_schema)
+
+    def retrieve(
+        self,
+        index: Collection,
+        search_attr: str,
+        output_attrs: list[dict] | type[BaseModel],
+        search_func: Callable | None = None,
+        k: int = -1,
+    ) -> Dataset:
+        """
+        Retrieve the top-k nearest neighbors of the value of the `search_attr` from the `index` and
+        use these results to construct the `output_attrs` field(s).
+        """
+        # construct new output schema
+        new_output_schema = None
+        if isinstance(output_attrs, list):
+            output_attrs = create_schema_from_fields(output_attrs)
+            new_output_schema = union_schemas([self.schema, output_attrs])
+        elif issubclass(output_attrs, BaseModel):
+            new_output_schema = union_schemas([self.schema, output_attrs])
+        else:
+            raise ValueError("`output_attrs` must be a list of dictionaries or a BaseModel.")
+
+        # TODO: revisit once we can think through abstraction(s)
+        # # construct the PZIndex from the user-provided index
+        # index = index_factory(index)
+
+        # construct logical operator
+        operator = RetrieveScan(
+            input_schema=self.schema,
+            output_schema=new_output_schema,
+            index=index,
+            search_func=search_func,
+            search_attr=search_attr,
+            output_attrs=output_attrs,
+            k=k,
+        )
+
+        return Dataset(sources=[self], operator=operator, schema=new_output_schema)
+
+    def limit(self, n: int) -> Dataset:
+        """Limit the set size to no more than n rows"""
+        operator = LimitScan(input_schema=self.schema, output_schema=self.schema, limit=n)
+        return Dataset(sources=[self], operator=operator, schema=self.schema)
+
+    def distinct(self, distinct_cols: list[str] | None = None) -> Dataset:
+        """Return a new Dataset with distinct rows based on the current schema."""
+        operator = Distinct(input_schema=self.schema, output_schema=self.schema, distinct_cols=distinct_cols)
+        return Dataset(sources=[self], operator=operator, schema=self.schema)
+
+    def project(self, project_cols: list[str] | str) -> Dataset:
+        """Project the Set to only include the specified columns."""
+        project_cols = project_cols if isinstance(project_cols, list) else [project_cols]
+        new_output_schema = project(self.schema, project_cols)
+        operator = Project(input_schema=self.schema, output_schema=new_output_schema, project_cols=project_cols)
+        return Dataset(sources=[self], operator=operator, schema=new_output_schema)
+
+    def run(self, config: QueryProcessorConfig | None = None, **kwargs):
+        """Invoke the QueryProcessor to execute the query. `kwargs` will be applied to the QueryProcessorConfig."""
+        # TODO: this import currently needs to be here to avoid a circular import; we should fix this in a subsequent PR
+        from palimpzest.query.processor.query_processor_factory import QueryProcessorFactory
+
+        # as syntactic sugar, we will allow some keyword arguments to parameterize our policies
+        policy = construct_policy_from_kwargs(**kwargs)
+        if policy is not None:
+            kwargs["policy"] = policy
+
+        # construct unique logical op ids for all operators in this dataset
+        self._generate_unique_logical_op_ids()
+
+        return QueryProcessorFactory.create_and_run_processor(self, config)
+
+    def optimize_and_run(self, train_dataset: dict[str, Dataset] | Dataset | None = None, validator: Validator | None = None, config: QueryProcessorConfig | None = None, **kwargs):
+        """Optimize the PZ program using the train_dataset and validator before running the optimized plan."""
+        # TODO: this import currently needs to be here to avoid a circular import; we should fix this in a subsequent PR
+        from palimpzest.query.processor.query_processor_factory import QueryProcessorFactory
+
+        # confirm that either train_dataset or validator is provided
+        assert train_dataset is not None or validator is not None, "Must provide at least one of train_dataset or validator to use optimize_and_run()"
+
+        # validate the train_dataset has one input for each source dataset and normalize its type to be a dict
+        if train_dataset is not None:
+            root_datasets = self._get_root_datasets()
+            if isinstance(train_dataset, Dataset) and len(root_datasets) > 1:
+                raise ValueError(
+                    "For plans with more than one root dataset, `train_dataset` must be a dictionary mapping"
+                    " {'dataset_id' --> Dataset} for all root Datasets"
+                )
+
+            elif isinstance(train_dataset, Dataset):
+                root_dataset_id = list(root_datasets.values())[0].id
+                if train_dataset.id != root_dataset_id:
+                    warnings.warn(
+                        f"train_dataset.id={train_dataset.id} does not match root dataset id={root_dataset_id}\n"
+                        f"Setting train_dataset to be the training data for root dataset with id={root_dataset_id} anyways.",
+                        stacklevel=2,
+                    )
+                train_dataset = {root_dataset_id: train_dataset}
+
+            elif not all(dataset_id in train_dataset for dataset_id in root_datasets):
+                missing_ids = [dataset_id for dataset_id in root_datasets if dataset_id not in train_dataset]
+                raise ValueError(
+                    f"`train_dataset` is missing the following root dataset id(s): {missing_ids}"
+                )
+
+        # as syntactic sugar, we will allow some keyword arguments to parameterize our policies
+        policy = construct_policy_from_kwargs(**kwargs)
+        if policy is not None:
+            kwargs["policy"] = policy
+
+        # construct unique logical op ids for all operators in this dataset
+        self._generate_unique_logical_op_ids()
+
+        return QueryProcessorFactory.create_and_run_processor(self, config, train_dataset, validator)
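
For orientation, here is a minimal sketch of how the new `Dataset` API in 0.8.1 chains lazy operators and executes them, using only the methods and signatures visible in this hunk. It assumes an already-constructed root dataset (a `pz.IterDataset`, `pz.IndexDataset`, or `pz.Context` subclass, defined in other files of this release); `build_plan`, `papers`, and the column names are illustrative and not part of the package.

```python
from palimpzest.core.data.dataset import Dataset


def build_plan(papers: Dataset) -> Dataset:
    """Chain lazy logical operators on an existing root Dataset.

    `papers` is assumed to be a root Dataset (e.g., a pz.IterDataset subclass
    over a folder of papers); constructing it is outside the scope of this hunk.
    """
    # each call below wraps a LogicalOperator and returns a new (lazy) Dataset
    relevant = papers.sem_filter("The paper studies query optimization")
    titled = relevant.sem_map(
        [{"name": "title", "desc": "The title of the paper", "type": str},
         {"name": "year", "desc": "The publication year", "type": int}]
    )
    return titled.limit(10)


# Executing the plan materializes it; keyword arguments passed to run() can
# parameterize the policy (via construct_policy_from_kwargs) and the
# QueryProcessorConfig.
# output = build_plan(papers).run()
```

When labeled data or a `Validator` is available, `optimize_and_run(train_dataset=..., validator=...)` follows the same pattern but optimizes the plan before executing it. Note that `sem_add_columns(...)` and `add_columns(...)` still work in this release but emit `DeprecationWarning`s pointing to `sem_map(...)` and `map(...)`.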