palimpzest 0.7.20__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (87)
  1. palimpzest/__init__.py +37 -6
  2. palimpzest/agents/__init__.py +0 -0
  3. palimpzest/agents/compute_agents.py +0 -0
  4. palimpzest/agents/search_agents.py +637 -0
  5. palimpzest/constants.py +259 -197
  6. palimpzest/core/data/context.py +393 -0
  7. palimpzest/core/data/context_manager.py +163 -0
  8. palimpzest/core/data/dataset.py +634 -0
  9. palimpzest/core/data/{datareaders.py → iter_dataset.py} +202 -126
  10. palimpzest/core/elements/groupbysig.py +16 -13
  11. palimpzest/core/elements/records.py +166 -75
  12. palimpzest/core/lib/schemas.py +152 -390
  13. palimpzest/core/{data/dataclasses.py → models.py} +306 -170
  14. palimpzest/policy.py +2 -27
  15. palimpzest/prompts/__init__.py +35 -5
  16. palimpzest/prompts/agent_prompts.py +357 -0
  17. palimpzest/prompts/context_search.py +9 -0
  18. palimpzest/prompts/convert_prompts.py +61 -5
  19. palimpzest/prompts/filter_prompts.py +50 -5
  20. palimpzest/prompts/join_prompts.py +163 -0
  21. palimpzest/prompts/moa_proposer_convert_prompts.py +5 -5
  22. palimpzest/prompts/prompt_factory.py +358 -46
  23. palimpzest/prompts/validator.py +239 -0
  24. palimpzest/query/execution/all_sample_execution_strategy.py +134 -76
  25. palimpzest/query/execution/execution_strategy.py +210 -317
  26. palimpzest/query/execution/execution_strategy_type.py +5 -7
  27. palimpzest/query/execution/mab_execution_strategy.py +249 -136
  28. palimpzest/query/execution/parallel_execution_strategy.py +153 -244
  29. palimpzest/query/execution/single_threaded_execution_strategy.py +107 -64
  30. palimpzest/query/generators/generators.py +157 -330
  31. palimpzest/query/operators/__init__.py +15 -5
  32. palimpzest/query/operators/aggregate.py +50 -33
  33. palimpzest/query/operators/compute.py +201 -0
  34. palimpzest/query/operators/convert.py +27 -21
  35. palimpzest/query/operators/critique_and_refine_convert.py +7 -5
  36. palimpzest/query/operators/distinct.py +62 -0
  37. palimpzest/query/operators/filter.py +22 -13
  38. palimpzest/query/operators/join.py +402 -0
  39. palimpzest/query/operators/limit.py +3 -3
  40. palimpzest/query/operators/logical.py +198 -80
  41. palimpzest/query/operators/mixture_of_agents_convert.py +10 -8
  42. palimpzest/query/operators/physical.py +27 -21
  43. palimpzest/query/operators/project.py +3 -3
  44. palimpzest/query/operators/rag_convert.py +7 -7
  45. palimpzest/query/operators/retrieve.py +9 -9
  46. palimpzest/query/operators/scan.py +81 -42
  47. palimpzest/query/operators/search.py +524 -0
  48. palimpzest/query/operators/split_convert.py +10 -8
  49. palimpzest/query/optimizer/__init__.py +7 -9
  50. palimpzest/query/optimizer/cost_model.py +108 -441
  51. palimpzest/query/optimizer/optimizer.py +123 -181
  52. palimpzest/query/optimizer/optimizer_strategy.py +66 -61
  53. palimpzest/query/optimizer/plan.py +352 -67
  54. palimpzest/query/optimizer/primitives.py +43 -19
  55. palimpzest/query/optimizer/rules.py +484 -646
  56. palimpzest/query/optimizer/tasks.py +127 -58
  57. palimpzest/query/processor/config.py +41 -76
  58. palimpzest/query/processor/query_processor.py +73 -18
  59. palimpzest/query/processor/query_processor_factory.py +46 -38
  60. palimpzest/schemabuilder/schema_builder.py +15 -28
  61. palimpzest/utils/model_helpers.py +27 -77
  62. palimpzest/utils/progress.py +114 -102
  63. palimpzest/validator/__init__.py +0 -0
  64. palimpzest/validator/validator.py +306 -0
  65. {palimpzest-0.7.20.dist-info → palimpzest-0.8.0.dist-info}/METADATA +6 -1
  66. palimpzest-0.8.0.dist-info/RECORD +95 -0
  67. palimpzest/core/lib/fields.py +0 -141
  68. palimpzest/prompts/code_synthesis_prompts.py +0 -28
  69. palimpzest/query/execution/random_sampling_execution_strategy.py +0 -240
  70. palimpzest/query/generators/api_client_factory.py +0 -30
  71. palimpzest/query/operators/code_synthesis_convert.py +0 -488
  72. palimpzest/query/operators/map.py +0 -130
  73. palimpzest/query/processor/nosentinel_processor.py +0 -33
  74. palimpzest/query/processor/processing_strategy_type.py +0 -28
  75. palimpzest/query/processor/sentinel_processor.py +0 -88
  76. palimpzest/query/processor/streaming_processor.py +0 -149
  77. palimpzest/sets.py +0 -405
  78. palimpzest/utils/datareader_helpers.py +0 -61
  79. palimpzest/utils/demo_helpers.py +0 -75
  80. palimpzest/utils/field_helpers.py +0 -69
  81. palimpzest/utils/generation_helpers.py +0 -69
  82. palimpzest/utils/sandbox.py +0 -183
  83. palimpzest-0.7.20.dist-info/RECORD +0 -95
  84. /palimpzest/core/{elements/index.py → data/index_dataset.py} +0 -0
  85. {palimpzest-0.7.20.dist-info → palimpzest-0.8.0.dist-info}/WHEEL +0 -0
  86. {palimpzest-0.7.20.dist-info → palimpzest-0.8.0.dist-info}/licenses/LICENSE +0 -0
  87. {palimpzest-0.7.20.dist-info → palimpzest-0.8.0.dist-info}/top_level.txt +0 -0
The single hunk below shows the full contents of the new file palimpzest/core/data/dataset.py (item 8 in the list above):

@@ -0,0 +1,634 @@
+ from __future__ import annotations
+
+ import warnings
+ from collections.abc import Iterator
+ from typing import Callable
+
+ from chromadb.api.models.Collection import Collection
+ from pydantic import BaseModel
+
+ from palimpzest.constants import AggFunc, Cardinality
+ from palimpzest.core.elements.filters import Filter
+ from palimpzest.core.elements.groupbysig import GroupBySig
+ from palimpzest.core.lib.schemas import create_schema_from_fields, project, union_schemas
+ from palimpzest.policy import construct_policy_from_kwargs
+ from palimpzest.query.operators.logical import (
+     Aggregate,
+     ConvertScan,
+     Distinct,
+     FilteredScan,
+     GroupByAggregate,
+     JoinOp,
+     LimitScan,
+     LogicalOperator,
+     Project,
+     RetrieveScan,
+ )
+ from palimpzest.query.processor.config import QueryProcessorConfig
+ from palimpzest.utils.hash_helpers import hash_for_serialized_dict
+ from palimpzest.validator.validator import Validator
+
+
+ # TODO?: remove `schema` from `Dataset` and access it from `operator`?
+ # - Q: how do you handle datasets with multiple sources?
+ #   - for joins the operator should have the union'ed schema
+ #   - but for Contexts it may be trickier
+ class Dataset:
+     """
+     A `Dataset` represents a collection of structured or unstructured data that can be processed and
+     transformed. Each `Dataset` is either a "root" `Dataset` (which yields data items) or it is the
+     result of performing data processing operation(s) on root `Dataset(s)`.
+
+     Users can perform computations on a `Dataset` in a lazy or eager fashion. Applying functions
+     such as `sem_filter`, `sem_map`, `sem_join`, `sem_agg`, etc. will lazily create a new `Dataset`.
+     Users can invoke the `run()` method to execute the computation and retrieve a materialized `Dataset`.
+     Materialized `Dataset`s can be processed further, or their results can be retrieved using `.get()`.
+
+     A root `Dataset` must subclass at least one of `pz.IterDataset`, `pz.IndexDataset`, or `pz.Context`.
+     Each of these classes supports a different access pattern:
+
+     - `pz.IterDataset`: supports accessing data via iteration
+         - Ex: iterating over a list of PDFs
+         - Ex: iterating over rows in a DataFrame
+     - `pz.IndexDataset`: supports accessing data via point lookups / queries
+         - Ex: querying a vector database
+         - Ex: querying a SQL database
+     - `pz.Context`: supports accessing data with an agent
+         - Ex: processing a set of CSV files with a data science agent
+         - Ex: processing time series data with a data cleaning agent
+
+     A root `Dataset` may subclass more than one of the aforementioned classes. For example, the root
+     `Dataset` for a list of files may inherit from `pz.IterDataset` and `pz.IndexDataset` to support
+     iterating over the files and performing point lookups for individual files.
+
+     For details on how to create your own root `Dataset`, please see: TODO
+     """
+     def __init__(
+         self,
+         sources: list[Dataset] | Dataset | None,
+         operator: LogicalOperator,
+         schema: type[BaseModel] | None = None,
+         id: str | None = None,
+     ) -> None:
+         """
+         Initialize a `Dataset` with one or more `sources` and the operator that is being applied.
+         Root `Datasets` subclass `pz.IterDataset`, `pz.IndexDataset`, and/or `pz.Context` and use
+         their own constructors.
+
+         Args:
+             sources (`list[Dataset] | Dataset`): The (list of) `Dataset(s)` which are input(s) to
+                 the operator used to compute this `Dataset`.
+             operator (`LogicalOperator`): The `LogicalOperator` used to compute this `Dataset`.
+             schema (type[`BaseModel`] | None): The schema of this `Dataset`.
+             id (str | None): an identifier for this `Dataset` provided by the user
+
+         Raises:
+             ValueError: if `sources` is not a `Dataset` or list of `Datasets`
+         """
+         # set sources
+         self._sources = []
+         if isinstance(sources, list):
+             self._sources = sources
+         elif isinstance(sources, Dataset):
+             self._sources = [sources]
+         elif sources is not None:
+             raise ValueError("Dataset sources must be another Dataset or a list of Datasets. For root Datasets, you must subclass pz.IterDataset, pz.IndexDataset, or pz.Context.")
+
+         # set the logical operator and schema
+         self._operator: LogicalOperator = operator
+         self._schema = schema
+
+         # compute the dataset id
+         self._id = self._compute_dataset_id() if id is None else id
+
+     @property
+     def id(self) -> str:
+         """The string identifier for this `Dataset`"""
+         return self._id
+
+     @property
+     def schema(self) -> type[BaseModel]:
+         """The Pydantic model defining the schema of this `Dataset`"""
+         return self._schema
+
+     @property
+     def is_root(self) -> bool:
+         return len(self._sources) == 0
+
+     def __str__(self) -> str:
+         return f"Dataset(schema={self._schema}, id={self._id}, op_id={self._operator.get_logical_op_id()})"
+
+     def __iter__(self) -> Iterator[Dataset]:
+         for source in self._sources:
+             yield from source
+         yield self
+
+     def _compute_dataset_id(self) -> str:
+         """
+         Compute the identifier for this `Dataset`. The ID is uniquely defined by the operation(s)
+         applied to the `Dataset's` sources.
+         """
+         return hash_for_serialized_dict({
+             "source_ids": [source.id for source in self._sources],
+             "logical_op_id": self._operator.get_logical_op_id(),
+         })
+
+     def _set_root_datasets(self, new_root_datasets: dict[str, Dataset]) -> None:
+         """
+         Update the root dataset(s) for this dataset with the `new_root_datasets`. This is used during
+         optimization to reuse the same physical plan while running it on a train dataset.
+
+         Args:
+             new_root_datasets (dict[str, Dataset]): the new root datasets for this dataset.
+         """
+         new_sources = []
+         for old_source in self._sources:
+             if old_source.id in new_root_datasets:
+                 new_sources.append(new_root_datasets[old_source.id])
+             else:
+                 old_source._set_root_datasets(new_root_datasets)
+                 new_sources.append(old_source)
+         self._sources = new_sources
+
+     # TODO: the entire way (unique) logical op ids are computed and stored needs to be rethought
+     def _generate_unique_logical_op_ids(self, topo_idx: int | None = None) -> None:
+         """
+         Generate unique operation IDs for all operators in this dataset and its sources.
+         This is used to ensure that each operator can be uniquely identified during execution.
+         """
+         # generate the unique op ids for all sources' operators
+         for source in self._sources:
+             topo_idx = source._generate_unique_logical_op_ids(topo_idx)
+             topo_idx += 1
+
+         # if topo_idx is None, this is the first call, so we initialize it to 0
+         if topo_idx is None:
+             topo_idx = 0
+
+         # compute this operator's unique operator id
+         this_unique_logical_op_id = f"{topo_idx}-{self._operator.get_logical_op_id()}"
+
+         # update the unique logical op id for this operator
+         self._operator.set_unique_logical_op_id(this_unique_logical_op_id)
+
+         # return the current unique full_op_id for this operator
+         return topo_idx
+
+     # TODO
+     def _resolve_depends_on(self, depends_on: list[str]) -> list[str]:
+         """
+         TODO: resolve the `depends_on` strings to their full field names ({Dataset.id}.{field_name}).
+         """
+         return []
+
+     def _get_root_datasets(self) -> dict[str, Dataset]:
+         """Return a mapping from the id --> Dataset for all root datasets."""
+         if self.is_root:
+             return {self.id: self}
+
+         root_datasets = {}
+         for source in self._sources:
+             child_root_datasets = source._get_root_datasets()
+             root_datasets = {**root_datasets, **child_root_datasets}
+
+         return root_datasets
+
+     def get_upstream_datasets(self) -> list[Dataset]:
+         """
+         Get the list of all upstream datasets that are sources to this dataset.
+         """
+         # recursively get the upstream datasets
+         upstream = []
+         for source in self._sources:
+             upstream.extend(source.get_upstream_datasets())
+             upstream.append(source)
+         return upstream
+
+     def get_limit(self) -> int | None:
+         """Get the limit applied to this Dataset, if any."""
+         if isinstance(self._operator, LimitScan):
+             return self._operator.limit
+
+         source_limits = []
+         for source in self._sources:
+             source_limit = source.get_limit()
+             if source_limit is not None:
+                 source_limits.append(source_limit)
+
+         if len(source_limits) == 0:
+             return None
+
+         return min([limit for limit in source_limits if limit is not None])
+
+     def copy(self):
+         return Dataset(
+             sources=[source.copy() for source in self._sources],
+             operator=self._operator.copy(),
+             schema=self._schema,
+             id=self.id,
+         )
+
+     def sem_join(self, other: Dataset, condition: str, depends_on: str | list[str] | None = None) -> Dataset:
+         """
+         Perform a semantic (inner) join on the specified join predicate
+         """
+         # enforce type for depends_on
+         if isinstance(depends_on, str):
+             depends_on = [depends_on]
+
+         # construct new output schema
+         combined_schema = union_schemas([self.schema, other.schema], join=True)
+
+         # construct logical operator
+         operator = JoinOp(
+             input_schema=combined_schema,
+             output_schema=combined_schema,
+             condition=condition,
+             depends_on=depends_on,
+         )
+
+         return Dataset(sources=[self, other], operator=operator, schema=combined_schema)
+
+     def filter(
+         self,
+         filter: Callable,
+         depends_on: str | list[str] | None = None,
+     ) -> Dataset:
+         """Add a user defined function as a filter to the Set. This filter will possibly restrict the items that are returned later."""
+         # construct Filter object
+         f = None
+         if callable(filter):
+             f = Filter(filter_fn=filter)
+         else:
+             error_str = f"Only support callable for filter, currently got {type(filter)}"
+             if isinstance(filter, str):
+                 error_str += ". Consider using sem_filter() for semantic filters."
+             raise Exception(error_str)
+
+         # enforce type for depends_on
+         if isinstance(depends_on, str):
+             depends_on = [depends_on]
+
+         # construct logical operator
+         operator = FilteredScan(input_schema=self.schema, output_schema=self.schema, filter=f, depends_on=depends_on)
+
+         return Dataset(sources=[self], operator=operator, schema=self.schema)
+
+     def sem_filter(
+         self,
+         filter: str,
+         depends_on: str | list[str] | None = None,
+     ) -> Dataset:
+         """Add a natural language description of a filter to the Set. This filter will possibly restrict the items that are returned later."""
+         # construct Filter object
+         f = None
+         if isinstance(filter, str):
+             f = Filter(filter)
+         else:
+             raise Exception("sem_filter() only supports `str` input for _filter.", type(filter))
+
+         # enforce type for depends_on
+         if isinstance(depends_on, str):
+             depends_on = [depends_on]
+
+         # construct logical operator
+         operator = FilteredScan(input_schema=self.schema, output_schema=self.schema, filter=f, depends_on=depends_on)
+
+         return Dataset(sources=[self], operator=operator, schema=self.schema)
+
+     def _sem_map(self, cols: list[dict] | type[BaseModel] | None,
+                  cardinality: Cardinality,
+                  depends_on: str | list[str] | None = None) -> Dataset:
+         """Execute the semantic map operation with the appropriate cardinality."""
+         # construct new output schema
+         new_output_schema = None
+         if cols is None:
+             new_output_schema = self.schema
+         elif isinstance(cols, list):
+             cols = create_schema_from_fields(cols)
+             new_output_schema = union_schemas([self.schema, cols])
+         elif issubclass(cols, BaseModel):
+             new_output_schema = union_schemas([self.schema, cols])
+         else:
+             raise ValueError("`cols` must be a list of dictionaries or a BaseModel.")
+
+         # enforce type for depends_on
+         if isinstance(depends_on, str):
+             depends_on = [depends_on]
+
+         # construct logical operator
+         operator = ConvertScan(
+             input_schema=self.schema,
+             output_schema=new_output_schema,
+             cardinality=cardinality,
+             udf=None,
+             depends_on=depends_on,
+         )
+
+         return Dataset(sources=[self], operator=operator, schema=new_output_schema)
+
+
+     def sem_add_columns(self, cols: list[dict] | type[BaseModel],
+                         cardinality: Cardinality = Cardinality.ONE_TO_ONE,
+                         depends_on: str | list[str] | None = None) -> Dataset:
+         """
+         NOTE: we are renaming this function to `sem_map` and deprecating `sem_add_columns` in the next
+         release of PZ. To update your code, simply change your calls from `.sem_add_columns(...)` to `.sem_map(...)`.
+         The function arguments will stay the same.
+
+         Add new columns by specifying the column names, descriptions, and types.
+         The column will be computed during the execution of the Dataset.
+         Example:
+             sem_add_columns(
+                 [{'name': 'greeting', 'desc': 'The greeting message', 'type': str},
+                  {'name': 'age', 'desc': 'The age of the person', 'type': int},
+                  {'name': 'full_name', 'desc': 'The name of the person', 'type': str}]
+             )
+         """
+         # issue deprecation warning
+         warnings.warn(
+             "we are renaming this function to `sem_map` and deprecating `sem_add_columns` in the next"
+             " release of PZ. To update your code, simply change your calls from `.sem_add_columns(...)`"
+             " to `.sem_map(...)`. The function arguments will stay the same.",
+             DeprecationWarning,
+             stacklevel=2
+         )
+
+         return self._sem_map(cols, cardinality, depends_on)
+
+     def sem_map(self, cols: list[dict] | type[BaseModel], depends_on: str | list[str] | None = None) -> Dataset:
+         """
+         Compute new field(s) by specifying their names, descriptions, and types. For each input there will
+         be one output. The field(s) will be computed during the execution of the Dataset.
+
+         Example:
+             sem_map(
+                 [{'name': 'greeting', 'desc': 'The greeting message', 'type': str},
+                  {'name': 'age', 'desc': 'The age of the person', 'type': int},
+                  {'name': 'full_name', 'desc': 'The name of the person', 'type': str}]
+             )
+         """
+         return self._sem_map(cols, Cardinality.ONE_TO_ONE, depends_on)
+
+     def sem_flat_map(self, cols: list[dict] | type[BaseModel], depends_on: str | list[str] | None = None) -> Dataset:
+         """
+         Compute new field(s) by specifying their names, descriptions, and types. For each input there will
+         be one or more output(s). The field(s) will be computed during the execution of the Dataset.
+
+         Example:
+             sem_flat_map(
+                 cols=[
+                     {'name': 'author_name', 'description': 'The name of the author', 'type': str},
+                     {'name': 'institution', 'description': 'The institution of the author', 'type': str},
+                     {'name': 'email', 'description': "The author's email", 'type': str},
+                 ]
+             )
+         """
+         return self._sem_map(cols, Cardinality.ONE_TO_MANY, depends_on)
+
+     def _map(self, udf: Callable,
+              cols: list[dict] | type[BaseModel] | None,
+              cardinality: Cardinality,
+              depends_on: str | list[str] | None = None) -> Dataset:
+         """Execute the map operation with the appropriate cardinality."""
+         # construct new output schema
+         new_output_schema = None
+         if cols is None:
+             new_output_schema = self.schema
+         elif isinstance(cols, list):
+             cols = create_schema_from_fields(cols)
+             new_output_schema = union_schemas([self.schema, cols])
+         elif issubclass(cols, BaseModel):
+             new_output_schema = union_schemas([self.schema, cols])
+         else:
+             raise ValueError("`cols` must be a list of dictionaries, a BaseModel, or None.")
+
+         # enforce type for depends_on
+         if isinstance(depends_on, str):
+             depends_on = [depends_on]
+
+         # construct logical operator
+         operator = ConvertScan(
+             input_schema=self.schema,
+             output_schema=new_output_schema,
+             cardinality=cardinality,
+             udf=udf,
+             depends_on=depends_on,
+         )
+
+         return Dataset(sources=[self], operator=operator, schema=new_output_schema)
+
+     def add_columns(self, udf: Callable,
+                     cols: list[dict] | type[BaseModel] | None,
+                     cardinality: Cardinality = Cardinality.ONE_TO_ONE,
+                     depends_on: str | list[str] | None = None) -> Dataset:
+         """
+         NOTE: we are renaming this function to `map` and deprecating `add_columns` in the next
+         release of PZ. To update your code, simply change your calls from `.add_columns(...)` to `.map(...)`.
+         The function arguments will stay the same.
+
+         Compute new fields (or update existing ones) with a UDF. For each input, this function will compute one output.
+
+         Set `cols=None` if your add_columns operation is not computing any new fields.
+
+         Examples:
+             add_columns(
+                 udf=compute_personal_greeting,
+                 cols=[
+                     {'name': 'greeting', 'description': 'The greeting message', 'type': str},
+                     {'name': 'age', 'description': 'The age of the person', 'type': int},
+                     {'name': 'full_name', 'description': 'The name of the person', 'type': str},
+                 ]
+             )
+         """
+         # issue deprecation warning
+         warnings.warn(
+             "we are renaming this function to `map` and deprecating `add_columns` in the next"
+             " release of PZ. To update your code, simply change your calls from `.add_columns(...)`"
+             " to `.map(...)`. The function arguments will stay the same.",
+             DeprecationWarning,
+             stacklevel=2
+         )
+
+         # sanity check inputs
+         if udf is None:
+             raise ValueError("`udf` must be provided for add_columns.")
+
+         return self._map(udf, cols, cardinality, depends_on)
+
+     def map(self, udf: Callable,
+             cols: list[dict] | type[BaseModel] | None,
+             depends_on: str | list[str] | None = None) -> Dataset:
+         """
+         Compute new fields (or update existing ones) with a UDF. For each input, this function will compute one output.
+
+         Set `cols=None` if your map is not computing any new fields.
+
+         Examples:
+             map(
+                 udf=compute_personal_greeting,
+                 cols=[
+                     {'name': 'greeting', 'description': 'The greeting message', 'type': str},
+                     {'name': 'age', 'description': 'The age of the person', 'type': int},
+                     {'name': 'full_name', 'description': 'The name of the person', 'type': str},
+                 ]
+             )
+         """
+         # sanity check inputs
+         if udf is None:
+             raise ValueError("`udf` must be provided for map.")
+
+         return self._map(udf, cols, Cardinality.ONE_TO_ONE, depends_on)
+
+     def flat_map(self, udf: Callable,
+                  cols: list[dict] | type[BaseModel] | None,
+                  depends_on: str | list[str] | None = None) -> Dataset:
+         """
+         Compute new fields (or update existing ones) with a UDF. For each input, this function will compute one or more outputs.
+
+         Set `cols=None` if your flat_map is not computing any new fields.
+
+         Examples:
+             flat_map(
+                 udf=extract_paper_authors,
+                 cols=[
+                     {'name': 'author_name', 'description': 'The name of the author', 'type': str},
+                     {'name': 'institution', 'description': 'The institution of the author', 'type': str},
+                     {'name': 'email', 'description': "The author's email", 'type': str},
+                 ]
+             )
+         """
+         # sanity check inputs
+         if udf is None:
+             raise ValueError("`udf` must be provided for map.")
+
+         return self._map(udf, cols, Cardinality.ONE_TO_MANY, depends_on)
+
+     def count(self) -> Dataset:
+         """Apply a count aggregation to this set"""
+         operator = Aggregate(input_schema=self.schema, agg_func=AggFunc.COUNT)
+         return Dataset(sources=[self], operator=operator, schema=operator.output_schema)
+
+     def average(self) -> Dataset:
+         """Apply an average aggregation to this set"""
+         operator = Aggregate(input_schema=self.schema, agg_func=AggFunc.AVERAGE)
+         return Dataset(sources=[self], operator=operator, schema=operator.output_schema)
+
+     def groupby(self, groupby: GroupBySig) -> Dataset:
+         output_schema = groupby.output_schema()
+         operator = GroupByAggregate(input_schema=self.schema, output_schema=output_schema, group_by_sig=groupby)
+         return Dataset(sources=[self], operator=operator, schema=output_schema)
+
+     def retrieve(
+         self,
+         index: Collection,
+         search_attr: str,
+         output_attrs: list[dict] | type[BaseModel],
+         search_func: Callable | None = None,
+         k: int = -1,
+     ) -> Dataset:
+         """
+         Retrieve the top-k nearest neighbors of the value of the `search_attr` from the `index` and
+         use these results to construct the `output_attrs` field(s).
+         """
+         # construct new output schema
+         new_output_schema = None
+         if isinstance(output_attrs, list):
+             output_attrs = create_schema_from_fields(output_attrs)
+             new_output_schema = union_schemas([self.schema, output_attrs])
+         elif issubclass(output_attrs, BaseModel):
+             new_output_schema = union_schemas([self.schema, output_attrs])
+         else:
+             raise ValueError("`output_attrs` must be a list of dictionaries or a BaseModel.")
+
+         # TODO: revisit once we can think through abstraction(s)
+         # # construct the PZIndex from the user-provided index
+         # index = index_factory(index)
+
+         # construct logical operator
+         operator = RetrieveScan(
+             input_schema=self.schema,
+             output_schema=new_output_schema,
+             index=index,
+             search_func=search_func,
+             search_attr=search_attr,
+             output_attrs=output_attrs,
+             k=k,
+         )
+
+         return Dataset(sources=[self], operator=operator, schema=new_output_schema)
+
+     def limit(self, n: int) -> Dataset:
+         """Limit the set size to no more than n rows"""
+         operator = LimitScan(input_schema=self.schema, output_schema=self.schema, limit=n)
+         return Dataset(sources=[self], operator=operator, schema=self.schema)
+
+     def distinct(self, distinct_cols: list[str] | None = None) -> Dataset:
+         """Return a new Dataset with distinct rows based on the current schema."""
+         operator = Distinct(input_schema=self.schema, output_schema=self.schema, distinct_cols=distinct_cols)
+         return Dataset(sources=[self], operator=operator, schema=self.schema)
+
+     def project(self, project_cols: list[str] | str) -> Dataset:
+         """Project the Set to only include the specified columns."""
+         project_cols = project_cols if isinstance(project_cols, list) else [project_cols]
+         new_output_schema = project(self.schema, project_cols)
+         operator = Project(input_schema=self.schema, output_schema=new_output_schema, project_cols=project_cols)
+         return Dataset(sources=[self], operator=operator, schema=new_output_schema)
+
+     def run(self, config: QueryProcessorConfig | None = None, **kwargs):
+         """Invoke the QueryProcessor to execute the query. `kwargs` will be applied to the QueryProcessorConfig."""
+         # TODO: this import currently needs to be here to avoid a circular import; we should fix this in a subsequent PR
+         from palimpzest.query.processor.query_processor_factory import QueryProcessorFactory
+
+         # as syntactic sugar, we will allow some keyword arguments to parameterize our policies
+         policy = construct_policy_from_kwargs(**kwargs)
+         if policy is not None:
+             kwargs["policy"] = policy
+
+         # construct unique logical op ids for all operators in this dataset
+         self._generate_unique_logical_op_ids()
+
+         return QueryProcessorFactory.create_and_run_processor(self, config)
+
+     def optimize_and_run(self, train_dataset: dict[str, Dataset] | Dataset | None = None, validator: Validator | None = None, config: QueryProcessorConfig | None = None, **kwargs):
+         """Optimize the PZ program using the train_dataset and validator before running the optimized plan."""
+         # TODO: this import currently needs to be here to avoid a circular import; we should fix this in a subsequent PR
+         from palimpzest.query.processor.query_processor_factory import QueryProcessorFactory
+
+         # confirm that either train_dataset or validator is provided
+         assert train_dataset is not None or validator is not None, "Must provide at least one of train_dataset or validator to use optimize_and_run()"
+
+         # validate the train_dataset has one input for each source dataset and normalize its type to be a dict
+         if train_dataset is not None:
+             root_datasets = self._get_root_datasets()
+             if isinstance(train_dataset, Dataset) and len(root_datasets) > 1:
+                 raise ValueError(
+                     "For plans with more than one root dataset, `train_dataset` must be a dictionary mapping"
+                     " {'dataset_id' --> Dataset} for all root Datasets"
+                 )
+
+             elif isinstance(train_dataset, Dataset):
+                 root_dataset_id = list(root_datasets.values())[0].id
+                 if train_dataset.id != root_dataset_id:
+                     warnings.warn(
+                         f"train_dataset.id={train_dataset.id} does not match root dataset id={root_dataset_id}\n"
+                         f"Setting train_dataset to be the training data for root dataset with id={root_dataset_id} anyways.",
+                         stacklevel=2,
+                     )
+                 train_dataset = {root_dataset_id: train_dataset}
+
+             elif not all(dataset_id in train_dataset for dataset_id in root_datasets):
+                 missing_ids = [dataset_id for dataset_id in root_datasets if dataset_id not in train_dataset]
+                 raise ValueError(
+                     f"`train_dataset` is missing the following root dataset id(s): {missing_ids}"
+                 )
+
+         # as syntactic sugar, we will allow some keyword arguments to parameterize our policies
+         policy = construct_policy_from_kwargs(**kwargs)
+         if policy is not None:
+             kwargs["policy"] = policy
+
+         # construct unique logical op ids for all operators in this dataset
+         self._generate_unique_logical_op_ids()
+
+         return QueryProcessorFactory.create_and_run_processor(self, config, train_dataset, validator)
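
For orientation, the sketch below shows how the Dataset API added in this file might be wired together. It is a minimal, hypothetical example based only on the method signatures and docstrings in the hunk above: `EmailDataset` stands in for a user-defined root dataset (a subclass of the new `pz.IterDataset`), and the field dictionaries follow the `sem_map` docstring example; none of these names ship with this diff, and the snippet is a sketch rather than a runnable program.

    import palimpzest as pz  # conventional alias used in the docstrings above

    # `EmailDataset` is a stand-in for any user-defined root dataset, i.e. a
    # subclass of pz.IterDataset; no such class is defined in this diff.
    emails = EmailDataset("testdata/enron-emails/")

    # lazily compose a plan with the operators defined in dataset.py; note that
    # sem_map/map replace the now-deprecated sem_add_columns/add_columns
    urgent = emails.sem_filter("The email discusses an urgent deadline")
    urgent = urgent.sem_map([
        {"name": "sender", "desc": "The email address of the sender", "type": str},
        {"name": "deadline", "desc": "The deadline mentioned in the email", "type": str},
    ])

    # run() builds a QueryProcessor and executes the plan; per the class
    # docstring, results can then be retrieved from the materialized Dataset
    # (e.g. via .get())
    output = urgent.run()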