palimpzest 0.7.20__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (87)
  1. palimpzest/__init__.py +37 -6
  2. palimpzest/agents/__init__.py +0 -0
  3. palimpzest/agents/compute_agents.py +0 -0
  4. palimpzest/agents/search_agents.py +637 -0
  5. palimpzest/constants.py +259 -197
  6. palimpzest/core/data/context.py +393 -0
  7. palimpzest/core/data/context_manager.py +163 -0
  8. palimpzest/core/data/dataset.py +634 -0
  9. palimpzest/core/data/{datareaders.py → iter_dataset.py} +202 -126
  10. palimpzest/core/elements/groupbysig.py +16 -13
  11. palimpzest/core/elements/records.py +166 -75
  12. palimpzest/core/lib/schemas.py +152 -390
  13. palimpzest/core/{data/dataclasses.py → models.py} +306 -170
  14. palimpzest/policy.py +2 -27
  15. palimpzest/prompts/__init__.py +35 -5
  16. palimpzest/prompts/agent_prompts.py +357 -0
  17. palimpzest/prompts/context_search.py +9 -0
  18. palimpzest/prompts/convert_prompts.py +61 -5
  19. palimpzest/prompts/filter_prompts.py +50 -5
  20. palimpzest/prompts/join_prompts.py +163 -0
  21. palimpzest/prompts/moa_proposer_convert_prompts.py +5 -5
  22. palimpzest/prompts/prompt_factory.py +358 -46
  23. palimpzest/prompts/validator.py +239 -0
  24. palimpzest/query/execution/all_sample_execution_strategy.py +134 -76
  25. palimpzest/query/execution/execution_strategy.py +210 -317
  26. palimpzest/query/execution/execution_strategy_type.py +5 -7
  27. palimpzest/query/execution/mab_execution_strategy.py +249 -136
  28. palimpzest/query/execution/parallel_execution_strategy.py +153 -244
  29. palimpzest/query/execution/single_threaded_execution_strategy.py +107 -64
  30. palimpzest/query/generators/generators.py +157 -330
  31. palimpzest/query/operators/__init__.py +15 -5
  32. palimpzest/query/operators/aggregate.py +50 -33
  33. palimpzest/query/operators/compute.py +201 -0
  34. palimpzest/query/operators/convert.py +27 -21
  35. palimpzest/query/operators/critique_and_refine_convert.py +7 -5
  36. palimpzest/query/operators/distinct.py +62 -0
  37. palimpzest/query/operators/filter.py +22 -13
  38. palimpzest/query/operators/join.py +402 -0
  39. palimpzest/query/operators/limit.py +3 -3
  40. palimpzest/query/operators/logical.py +198 -80
  41. palimpzest/query/operators/mixture_of_agents_convert.py +10 -8
  42. palimpzest/query/operators/physical.py +27 -21
  43. palimpzest/query/operators/project.py +3 -3
  44. palimpzest/query/operators/rag_convert.py +7 -7
  45. palimpzest/query/operators/retrieve.py +9 -9
  46. palimpzest/query/operators/scan.py +81 -42
  47. palimpzest/query/operators/search.py +524 -0
  48. palimpzest/query/operators/split_convert.py +10 -8
  49. palimpzest/query/optimizer/__init__.py +7 -9
  50. palimpzest/query/optimizer/cost_model.py +108 -441
  51. palimpzest/query/optimizer/optimizer.py +123 -181
  52. palimpzest/query/optimizer/optimizer_strategy.py +66 -61
  53. palimpzest/query/optimizer/plan.py +352 -67
  54. palimpzest/query/optimizer/primitives.py +43 -19
  55. palimpzest/query/optimizer/rules.py +484 -646
  56. palimpzest/query/optimizer/tasks.py +127 -58
  57. palimpzest/query/processor/config.py +41 -76
  58. palimpzest/query/processor/query_processor.py +73 -18
  59. palimpzest/query/processor/query_processor_factory.py +46 -38
  60. palimpzest/schemabuilder/schema_builder.py +15 -28
  61. palimpzest/utils/model_helpers.py +27 -77
  62. palimpzest/utils/progress.py +114 -102
  63. palimpzest/validator/__init__.py +0 -0
  64. palimpzest/validator/validator.py +306 -0
  65. {palimpzest-0.7.20.dist-info → palimpzest-0.8.0.dist-info}/METADATA +6 -1
  66. palimpzest-0.8.0.dist-info/RECORD +95 -0
  67. palimpzest/core/lib/fields.py +0 -141
  68. palimpzest/prompts/code_synthesis_prompts.py +0 -28
  69. palimpzest/query/execution/random_sampling_execution_strategy.py +0 -240
  70. palimpzest/query/generators/api_client_factory.py +0 -30
  71. palimpzest/query/operators/code_synthesis_convert.py +0 -488
  72. palimpzest/query/operators/map.py +0 -130
  73. palimpzest/query/processor/nosentinel_processor.py +0 -33
  74. palimpzest/query/processor/processing_strategy_type.py +0 -28
  75. palimpzest/query/processor/sentinel_processor.py +0 -88
  76. palimpzest/query/processor/streaming_processor.py +0 -149
  77. palimpzest/sets.py +0 -405
  78. palimpzest/utils/datareader_helpers.py +0 -61
  79. palimpzest/utils/demo_helpers.py +0 -75
  80. palimpzest/utils/field_helpers.py +0 -69
  81. palimpzest/utils/generation_helpers.py +0 -69
  82. palimpzest/utils/sandbox.py +0 -183
  83. palimpzest-0.7.20.dist-info/RECORD +0 -95
  84. /palimpzest/core/{elements/index.py → data/index_dataset.py} +0 -0
  85. {palimpzest-0.7.20.dist-info → palimpzest-0.8.0.dist-info}/WHEEL +0 -0
  86. {palimpzest-0.7.20.dist-info → palimpzest-0.8.0.dist-info}/licenses/LICENSE +0 -0
  87. {palimpzest-0.7.20.dist-info → palimpzest-0.8.0.dist-info}/top_level.txt +0 -0
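The renames in the list above (datareaders.py → iter_dataset.py, dataclasses.py → models.py) and the removal of palimpzest/sets.py mean several 0.7.20 import paths no longer resolve after upgrading. Below is a small, hedged probe of those paths; the module names are taken verbatim from the file list, but whether the renamed modules keep the old class names is not confirmed by this diff.

    import importlib

    # 0.7.20 module path -> 0.8.0 path suggested by the rename arrows above
    # (None marks a module that was deleted outright).
    moved_modules = {
        "palimpzest.sets": None,
        "palimpzest.core.data.datareaders": "palimpzest.core.data.iter_dataset",
        "palimpzest.core.data.dataclasses": "palimpzest.core.models",
    }

    for old, new in moved_modules.items():
        for name in (old, new):
            if name is None:
                continue
            try:
                importlib.import_module(name)
                print(f"importable: {name}")
            except ImportError:
                print(f"missing:    {name}")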
palimpzest/query/processor/sentinel_processor.py DELETED
@@ -1,88 +0,0 @@
- import logging
-
- from palimpzest.core.data.dataclasses import ExecutionStats, SentinelPlanStats
- from palimpzest.core.elements.records import DataRecordCollection
- from palimpzest.query.optimizer.cost_model import SampleBasedCostModel
- from palimpzest.query.optimizer.optimizer_strategy_type import OptimizationStrategyType
- from palimpzest.query.optimizer.plan import SentinelPlan
- from palimpzest.query.processor.query_processor import QueryProcessor
-
- logger = logging.getLogger(__name__)
-
- class SentinelQueryProcessor(QueryProcessor):
-
-     def _generate_sample_observations(self, sentinel_plan: SentinelPlan) -> SentinelPlanStats:
-         """
-         This function is responsible for generating sample observation data which can be
-         consumed by the CostModel.
-
-         To accomplish this, we construct a special sentinel plan using the Optimizer which is
-         capable of executing any valid physical implementation of a Filter or Convert operator
-         on each record.
-         """
-         # if we're using validation data, get the set of expected output records
-         expected_outputs = {}
-         for source_idx in range(len(self.val_datasource)):
-             expected_output = self.val_datasource[source_idx]
-             expected_outputs[source_idx] = expected_output
-
-         # execute sentinel plan; returns sentinel_plan_stats
-         return self.sentinel_execution_strategy.execute_sentinel_plan(sentinel_plan, expected_outputs)
-
-     def _create_sentinel_plan(self) -> SentinelPlan:
-         """
-         Generates and returns a SentinelPlan for the given dataset.
-         """
-         # create a new optimizer and update its strategy to SENTINEL
-         optimizer = self.optimizer.deepcopy_clean()
-         optimizer.update_strategy(OptimizationStrategyType.SENTINEL)
-
-         # create copy of dataset, but change its data source to the validation data source
-         dataset = self.dataset.copy()
-         dataset._set_data_source(self.val_datasource)
-
-         # get the sentinel plan for the given dataset
-         sentinel_plans = optimizer.optimize(dataset)
-         sentinel_plan = sentinel_plans[0]
-
-         return sentinel_plan
-
-     def execute(self) -> DataRecordCollection:
-         # for now, enforce that we are using validation data; we can relax this after paper submission
-         if self.val_datasource is None:
-             raise Exception("Make sure you are using validation data with SentinelQueryProcessor")
-         logger.info(f"Executing {self.__class__.__name__}")
-
-         # create execution stats
-         execution_stats = ExecutionStats(execution_id=self.execution_id())
-         execution_stats.start()
-
-         # create sentinel plan
-         sentinel_plan = self._create_sentinel_plan()
-
-         # generate sample execution data
-         sentinel_plan_stats = self._generate_sample_observations(sentinel_plan)
-
-         # update the execution stats to account for the work done in optimization
-         execution_stats.add_plan_stats(sentinel_plan_stats)
-         execution_stats.finish_optimization()
-
-         # (re-)initialize the optimizer
-         optimizer = self.optimizer.deepcopy_clean()
-
-         # construct the CostModel with any sample execution data we've gathered
-         cost_model = SampleBasedCostModel(sentinel_plan_stats, self.verbose)
-         optimizer.update_cost_model(cost_model)
-
-         # execute plan(s) according to the optimization strategy
-         records, plan_stats = self._execute_best_plan(self.dataset, optimizer)
-
-         # update the execution stats to account for the work to execute the final plan
-         execution_stats.add_plan_stats(plan_stats)
-         execution_stats.finish()
-
-         # construct and return the DataRecordCollection
-         result = DataRecordCollection(records, execution_stats=execution_stats)
-         logger.info("Done executing SentinelQueryProcessor")
-
-         return result
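Condensed, the deleted processor runs in two phases. The sketch below is a reading aid only, built from the calls shown in the hunk above; the SampleBasedCostModel import path is the 0.7.20 one from the deleted file, and `processor` stands in for a constructed SentinelQueryProcessor.

    from palimpzest.query.optimizer.cost_model import SampleBasedCostModel

    def run_sentinel(processor):
        # Phase 1: build a sentinel plan over the validation data and execute every
        # candidate physical implementation to collect sample observations.
        sentinel_plan = processor._create_sentinel_plan()
        sentinel_plan_stats = processor._generate_sample_observations(sentinel_plan)

        # Phase 2: fit a sample-based cost model on those observations, hand it to a
        # fresh optimizer, and execute the chosen plan on the full dataset.
        optimizer = processor.optimizer.deepcopy_clean()
        optimizer.update_cost_model(SampleBasedCostModel(sentinel_plan_stats, processor.verbose))
        records, plan_stats = processor._execute_best_plan(processor.dataset, optimizer)
        return records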
palimpzest/query/processor/streaming_processor.py DELETED
@@ -1,149 +0,0 @@
- import logging
- import time
-
- from palimpzest.core.data.dataclasses import PlanStats
- from palimpzest.core.elements.records import DataRecordCollection
- from palimpzest.query.operators.aggregate import AggregateOp
- from palimpzest.query.operators.filter import FilterOp
- from palimpzest.query.operators.limit import LimitScanOp
- from palimpzest.query.operators.scan import ScanPhysicalOp
- from palimpzest.query.optimizer.plan import PhysicalPlan
- from palimpzest.query.processor.query_processor import QueryProcessor
- from palimpzest.sets import Dataset
-
- logger = logging.getLogger(__name__)
-
- class StreamingQueryProcessor(QueryProcessor):
-     """This class can be used for a streaming, record-based execution.
-     Results are returned as an iterable that can be consumed by the caller."""
-
-     def __init__(self, *args, **kwargs) -> None:
-         super().__init__(*args, **kwargs)
-         self._plan: PhysicalPlan | None = None
-         self._plan_stats: PlanStats | None = None
-         self.last_record = False
-         self.current_scan_idx: int = 0
-         self.plan_generated: bool = False
-         self.records_count: int = 0
-         logger.info("Initialized StreamingQueryProcessor")
-
-     @property
-     def plan(self) -> PhysicalPlan:
-         if self._plan is None:
-             raise Exception("Plan has not been generated yet.")
-         return self._plan
-
-     @plan.setter
-     def plan(self, plan: PhysicalPlan):
-         self._plan = plan
-
-     @property
-     def plan_stats(self) -> PlanStats:
-         if self._plan_stats is None:
-             raise Exception("Plan stats have not been generated yet.")
-         return self._plan_stats
-
-     @plan_stats.setter
-     def plan_stats(self, plan_stats: PlanStats):
-         self._plan_stats = plan_stats
-
-     def generate_plan(self, dataset: Dataset):
-         # self.clear_cached_examples()
-         start_time = time.time()
-
-         # check that the plan does not contain any aggregation operators
-         for op in self.plan.operators:
-             if isinstance(op, AggregateOp):
-                 raise Exception("You cannot have a Streaming Execution if there is an Aggregation Operator")
-
-         # TODO: Do we need to re-initialize the optimizer here?
-         # Effectively always use the optimal strategy
-         optimizer = self.optimizer.deepcopy_clean()
-         plans = optimizer.optimize(dataset)
-         self.plan = plans[0]
-         self.plan_stats = PlanStats.from_plan(self.plan)
-         self.plan_stats.start()
-         logger.info(f"Time for planning: {time.time() - start_time:.2f} seconds")
-         self.plan_generated = True
-         logger.info(f"Generated plan:\n{self.plan}")
-         return self.plan
-
-     def execute(self):
-         logger.info("Executing StreamingQueryProcessor")
-         # Always delete cache
-         if not self.plan_generated:
-             self.generate_plan(self.dataset)
-
-         # if dry_run:
-         #     yield [], self.plan, self.plan_stats
-         #     return
-
-         input_records = self.get_input_records()
-         for idx, record in enumerate(input_records):
-             # print("Iteration number: ", idx+1, "out of", len(input_records))
-             output_records = self.execute_opstream(self.plan, record)
-             if idx == len(input_records) - 1:
-                 # finalize plan stats
-                 self.plan_stats.finish()
-                 self.plan_stats.plan_str = str(self.plan)
-             yield DataRecordCollection(output_records, plan_stats=self.plan_stats)
-
-         logger.info("Done executing StreamingQueryProcessor")
-
-
-     def get_input_records(self):
-         scan_operator = self.plan.operators[0]
-         assert isinstance(scan_operator, ScanPhysicalOp), "First operator in physical plan must be a ScanPhysicalOp"
-         datareader = scan_operator.datareader
-         if not datareader:
-             raise Exception("DataReader not found")
-         datareader_len = len(datareader)
-
-         input_records = []
-         record_op_stats = []
-         for source_idx in range(datareader_len):
-             record_set = scan_operator(source_idx)
-             input_records += record_set.data_records
-             record_op_stats += record_set.record_op_stats
-
-         self.plan_stats.add_record_op_stats(record_op_stats)
-
-         return input_records
-
-     def execute_opstream(self, plan, record):
-         # initialize list of output records and intermediate variables
-         input_records = [record]
-         record_op_stats_lst = []
-
-         for operator in plan.operators:
-             # TODO: this being defined in the for loop potentially makes the return
-             # unbounded if plan.operators is empty. This should be defined outside the loop
-             # and the loop refactored to account for not redeclaring this for each operator
-             output_records = []
-
-             if isinstance(operator, ScanPhysicalOp):
-                 continue
-             # only invoke aggregate operator(s) once there are no more source records and all
-             # upstream operators' processing queues are empty
-             # elif isinstance(operator, AggregateOp):
-             #     output_records, record_op_stats_lst = operator(candidates=input_records)
-             elif isinstance(operator, LimitScanOp):
-                 if self.records_count >= operator.limit:
-                     break
-             else:
-                 for r in input_records:
-                     record_set = operator(r)
-                     output_records += record_set.data_records
-                     record_op_stats_lst += record_set.record_op_stats
-
-             if isinstance(operator, FilterOp):
-                 # delete all records that did not pass the filter
-                 output_records = [r for r in output_records if r.passed_operator]
-                 if not output_records:
-                     break
-
-             self.plan_stats.add_record_op_stats(record_op_stats_lst)
-             input_records = output_records
-             self.records_count += len(output_records)
-
-         return output_records
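Because execute() above is a generator that yields one DataRecordCollection per input record, a caller could consume results incrementally. The minimal consumption sketch below is not from the package: it assumes `processor` is a StreamingQueryProcessor built by the 0.7.20 QueryProcessorFactory, and that DataRecordCollection exposes its records via a `data_records` attribute, mirroring the RecordSet usage in the code above.

    # `processor` is assumed to be an already-constructed StreamingQueryProcessor.
    for collection in processor.execute():      # one DataRecordCollection per input record
        for record in collection.data_records:  # attribute name assumed; see note above
            print(record)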
palimpzest/sets.py DELETED
@@ -1,405 +0,0 @@
- from __future__ import annotations
-
- from pathlib import Path
- from typing import Callable
-
- import pandas as pd
- from chromadb.api.models.Collection import Collection
-
- from palimpzest.constants import AggFunc, Cardinality
- from palimpzest.core.data.datareaders import DataReader
- from palimpzest.core.elements.filters import Filter
- from palimpzest.core.elements.groupbysig import GroupBySig
- from palimpzest.core.lib.schemas import Number, Schema
- from palimpzest.policy import construct_policy_from_kwargs
- from palimpzest.query.processor.config import QueryProcessorConfig
- from palimpzest.utils.datareader_helpers import get_local_datareader
- from palimpzest.utils.hash_helpers import hash_for_serialized_dict
-
-
- #####################################################
- #
- #####################################################
- class Set:
-     """
-     """
-
-     def __init__(
-         self,
-         source: Set | DataReader,
-         schema: Schema,
-         desc: str | None = None,
-         filter: Filter | None = None,
-         udf: Callable | None = None,
-         agg_func: AggFunc | None = None,
-         group_by: GroupBySig | None = None,
-         project_cols: list[str] | None = None,
-         index: Collection | None = None,
-         search_func: Callable | None = None,
-         search_attr: str | None = None,
-         output_attrs: list[dict] | None = None,
-         k: int | None = None,  # TODO: disambiguate `k` to be something like `retrieve_k`
-         limit: int | None = None,
-         cardinality: Cardinality = Cardinality.ONE_TO_ONE,
-         depends_on: list[str] | None = None,
-         cache: bool = False,
-     ):
-         self._schema = schema
-         self._source = source
-         self._desc = desc
-         self._filter = filter
-         self._udf = udf
-         self._agg_func = agg_func
-         self._group_by = group_by
-         self._project_cols = None if project_cols is None else sorted(project_cols)
-         self._index = index
-         self._search_func = search_func
-         self._search_attr = search_attr
-         self._output_attrs = output_attrs
-         self._k = k
-         self._limit = limit
-         self._cardinality = cardinality
-         self._depends_on = [] if depends_on is None else sorted(depends_on)
-         self._cache = cache
-
-     @property
-     def schema(self) -> Schema:
-         return self._schema
-
-     def _set_data_source(self, source: DataReader):
-         if isinstance(self._source, Set):
-             self._source._set_data_source(source)
-         else:
-             self._source = source
-
-     def serialize(self):
-         # NOTE: I needed to remove depends_on from the serialization dictionary because
-         # the optimizer changes the name of the depends_on fields to be their "full" name.
-         # This created an issue with the node.universal_identifier() not being consistent
-         # after changing the field to its full name.
-         d = {
-             "schema": self.schema.json_schema(),
-             "source": self._source.serialize(),
-             "desc": repr(self._desc),
-             "filter": None if self._filter is None else self._filter.serialize(),
-             "udf": None if self._udf is None else self._udf.__name__,
-             "agg_func": None if self._agg_func is None else self._agg_func.value,
-             "cardinality": self._cardinality,
-             "limit": self._limit,
-             "group_by": None if self._group_by is None else self._group_by.serialize(),
-             "project_cols": None if self._project_cols is None else self._project_cols,
-             "index": None if self._index is None else self._index.__class__.__name__,
-             "search_func": None if self._search_func is None else self._search_func.__name__,
-             "search_attr": self._search_attr,
-             "output_attrs": None if self._output_attrs is None else str(self._output_attrs),
-             "k": self._k,
-         }
-
-         return d
-
-     def universal_identifier(self):
-         """Return a unique identifier for this Set."""
-         return hash_for_serialized_dict(self.serialize())
-
-     def json_schema(self):
-         """Return the JSON schema for this Set."""
-         return self.schema.json_schema()
-
-
- class Dataset(Set):
-     """
-     A Dataset is the intended abstraction for programmers to interact with when writing PZ programs.
-
-     Users instantiate a Dataset by specifying a `source` that either points to a DataReader
-     or an existing Dataset. Users can then perform computations on the Dataset in a lazy fashion
-     by leveraging functions such as `filter`, `sem_filter`, `sem_add_columns`, `aggregate`, etc.
-     Underneath the hood, each of these operations creates a new Dataset. As a result, the Dataset
-     defines a lineage of computation.
-     """
-
-     def __init__(
-         self,
-         source: str | Path | list | pd.DataFrame | DataReader | Dataset,
-         schema: Schema | None = None,
-         *args,
-         **kwargs,
-     ) -> None:
-         # NOTE: this function currently assumes that DataReader will always be provided with a schema;
-         # we will relax this assumption in a subsequent PR
-         # convert source into a DataReader
-         updated_source = get_local_datareader(source, **kwargs) if isinstance(source, (str, Path, list, pd.DataFrame)) else source
-
-         # get the schema
-         schema = updated_source.schema if schema is None else schema
-
-         # intialize class
-         super().__init__(updated_source, schema, *args, **kwargs)
-
-     def copy(self):
-         return Dataset(
-             source=self._source.copy() if isinstance(self._source, Set) else self._source,
-             schema=self._schema,
-             desc=self._desc,
-             filter=self._filter,
-             udf=self._udf,
-             agg_func=self._agg_func,
-             group_by=self._group_by,
-             project_cols=self._project_cols,
-             index=self._index,
-             search_func=self._search_func,
-             search_attr=self._search_attr,
-             output_attrs=self._output_attrs,
-             k=self._k,
-             limit=self._limit,
-             cardinality=self._cardinality,
-             depends_on=self._depends_on,
-             cache=self._cache,
-         )
-
-     def filter(
-         self,
-         _filter: Callable,
-         depends_on: str | list[str] | None = None,
-     ) -> Dataset:
-         """Add a user defined function as a filter to the Set. This filter will possibly restrict the items that are returned later."""
-         f = None
-         if callable(_filter):
-             f = Filter(filter_fn=_filter)
-         else:
-             error_str = f"Only support callable for filter, currently got {type(_filter)}"
-             if isinstance(_filter, str):
-                 error_str += ". Consider using sem_filter() for semantic filters."
-             raise Exception(error_str)
-
-         if isinstance(depends_on, str):
-             depends_on = [depends_on]
-
-         return Dataset(
-             source=self,
-             schema=self.schema,
-             filter=f,
-             depends_on=depends_on,
-             cache=self._cache,
-         )
-
-     def sem_filter(
-         self,
-         _filter: str,
-         depends_on: str | list[str] | None = None,
-     ) -> Dataset:
-         """Add a natural language description of a filter to the Set. This filter will possibly restrict the items that are returned later."""
-         f = None
-         if isinstance(_filter, str):
-             f = Filter(_filter)
-         else:
-             raise Exception("sem_filter() only supports `str` input for _filter.", type(_filter))
-
-         if isinstance(depends_on, str):
-             depends_on = [depends_on]
-
-         return Dataset(
-             source=self,
-             schema=self.schema,
-             filter=f,
-             depends_on=depends_on,
-             cache=self._cache,
-         )
-
-     def sem_add_columns(self, cols: list[dict] | type[Schema],
-                         cardinality: Cardinality = Cardinality.ONE_TO_ONE,
-                         depends_on: str | list[str] | None = None,
-                         desc: str = "Add new columns via semantic reasoning") -> Dataset:
-         """
-         Add new columns by specifying the column names, descriptions, and types.
-         The column will be computed during the execution of the Dataset.
-         Example:
-             sem_add_columns(
-                 [{'name': 'greeting', 'desc': 'The greeting message', 'type': str},
-                  {'name': 'age', 'desc': 'The age of the person', 'type': int},
-                  {'name': 'full_name', 'desc': 'The name of the person', 'type': str}]
-             )
-         """
-         if isinstance(depends_on, str):
-             depends_on = [depends_on]
-
-         new_output_schema = None
-         if isinstance(cols, list):
-             new_output_schema = self.schema.add_fields(cols)
-         elif issubclass(cols, Schema):
-             new_output_schema = self.schema.union(cols)
-         else:
-             raise ValueError("`cols` must be a list of dictionaries or a Schema.")
-
-         return Dataset(
-             source=self,
-             schema=new_output_schema,
-             udf=None,
-             cardinality=cardinality,
-             depends_on=depends_on,
-             desc=desc,
-             cache=self._cache,
-         )
-
-     def add_columns(self, udf: Callable,
-                     cols: list[dict] | type[Schema],
-                     cardinality: Cardinality = Cardinality.ONE_TO_ONE,
-                     depends_on: str | list[str] | None = None,
-                     desc: str = "Add new columns via UDF") -> Dataset:
-         """
-         Add new columns by specifying UDFs.
-
-         Examples:
-             add_columns(
-                 udf=compute_personal_greeting,
-                 cols=[
-                     {'name': 'greeting', 'desc': 'The greeting message', 'type': str},
-                     {'name': 'age', 'desc': 'The age of the person', 'type': int},
-                     {'name': 'full_name', 'desc': 'The name of the person', 'type': str},
-                 ]
-             )
-         """
-         if udf is None or cols is None:
-             raise ValueError("`udf` and `cols` must be provided for add_columns.")
-
-         if isinstance(depends_on, str):
-             depends_on = [depends_on]
-
-         new_output_schema = None
-         if isinstance(cols, list):
-             updated_cols = []
-             for col_dict in cols:
-                 assert isinstance(col_dict, dict), "each entry in `cols` must be a dictionary"
-                 assert "name" in col_dict, "each type must contain a 'name' key specifying the column name"
-                 assert "type" in col_dict, "each type must contain a 'type' key specifying the column type"
-                 col_dict["desc"] = col_dict.get("desc", "New column: " + col_dict["name"])
-                 updated_cols.append(col_dict)
-             new_output_schema = self.schema.add_fields(updated_cols)
-
-         elif issubclass(cols, Schema):
-             new_output_schema = self.schema.union(cols)
-
-         else:
-             raise ValueError("`cols` must be a list of dictionaries or a Schema.")
-
-         return Dataset(
-             source=self,
-             schema=new_output_schema,
-             udf=udf,
-             cardinality=cardinality,
-             desc=desc,
-             depends_on=depends_on,
-             cache=self._cache,
-         )
-
-     def map(self, udf: Callable) -> Dataset:
-         """
-         Apply a UDF map function.
-
-         Examples:
-             map(udf=clean_column_values)
-         """
-         if udf is None:
-             raise ValueError("`udf` must be provided for map.")
-
-         return Dataset(
-             source=self,
-             schema=self.schema,
-             udf=udf,
-             cache=self._cache,
-         )
-
-     def count(self) -> Dataset:
-         """Apply a count aggregation to this set"""
-         return Dataset(
-             source=self,
-             schema=Number,
-             desc="Count results",
-             agg_func=AggFunc.COUNT,
-             cache=self._cache,
-         )
-
-     def average(self) -> Dataset:
-         """Apply an average aggregation to this set"""
-         return Dataset(
-             source=self,
-             schema=Number,
-             desc="Average results",
-             agg_func=AggFunc.AVERAGE,
-             cache=self._cache,
-         )
-
-     def groupby(self, groupby: GroupBySig) -> Dataset:
-         return Dataset(
-             source=self,
-             schema=groupby.output_schema(),
-             desc="Group By",
-             group_by=groupby,
-             cache=self._cache,
-         )
-
-     def retrieve(
-         self,
-         index: Collection,
-         search_attr: str,
-         output_attrs: list[dict] | type[Schema],
-         search_func: Callable | None = None,
-         k: int = -1,
-     ) -> Dataset:
-         """
-         Retrieve the top-k nearest neighbors of the value of the `search_attr` from the `index` and
-         use these results to construct the `output_attrs` field(s).
-         """
-         new_output_schema = None
-         if isinstance(output_attrs, list):
-             new_output_schema = self.schema.add_fields(output_attrs)
-         elif issubclass(output_attrs, Schema):
-             new_output_schema = self.schema.union(output_attrs)
-         else:
-             raise ValueError("`cols` must be a list of dictionaries or a Schema.")
-
-         # TODO: revisit once we can think through abstraction(s)
-         # # construct the PZIndex from the user-provided index
-         # index = index_factory(index)
-
-         return Dataset(
-             source=self,
-             schema=new_output_schema,
-             desc="Retrieve",
-             index=index,
-             search_func=search_func,
-             search_attr=search_attr,
-             output_attrs=output_attrs,
-             k=k,
-             cache=self._cache,
-         )
-
-     def limit(self, n: int) -> Dataset:
-         """Limit the set size to no more than n rows"""
-         return Dataset(
-             source=self,
-             schema=self.schema,
-             desc="LIMIT " + str(n),
-             limit=n,
-             cache=self._cache,
-         )
-
-     def project(self, project_cols: list[str] | str) -> Dataset:
-         """Project the Set to only include the specified columns."""
-         return Dataset(
-             source=self,
-             schema=self.schema.project(project_cols),
-             project_cols=project_cols if isinstance(project_cols, list) else [project_cols],
-             cache=self._cache,
-         )
-
-     def run(self, config: QueryProcessorConfig | None = None, **kwargs):
-         """Invoke the QueryProcessor to execute the query. `kwargs` will be applied to the QueryProcessorConfig."""
-         # TODO: this import currently needs to be here to avoid a circular import; we should fix this in a subsequent PR
-         from palimpzest.query.processor.query_processor_factory import QueryProcessorFactory
-
-         # as syntactic sugar, we will allow some keyword arguments to parameterize our policies
-         policy = construct_policy_from_kwargs(**kwargs)
-         if policy is not None:
-             kwargs["policy"] = policy
-
-         return QueryProcessorFactory.create_and_run_processor(self, config, **kwargs)
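For reference, a minimal sketch of how the removed 0.7.20 Dataset API chained lazy operators. The import path and method signatures are taken from the deleted sets.py above; the directory path and column specifications are hypothetical, and per the NOTE in __init__ a real source may need to supply an explicit schema.

    from palimpzest.sets import Dataset

    # hypothetical local directory of documents; wrapped in a DataReader internally
    papers = Dataset("testdata/papers/")
    papers = papers.sem_add_columns([
        {"name": "title", "desc": "The title of the paper", "type": str},
        {"name": "topic", "desc": "The main research topic", "type": str},
    ])
    ml_papers = papers.sem_filter("The paper is about machine learning")

    # run() constructs a QueryProcessor via QueryProcessorFactory and executes the plan
    result = ml_papers.run()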