palimpzest 0.7.20__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- palimpzest/__init__.py +37 -6
- palimpzest/agents/__init__.py +0 -0
- palimpzest/agents/compute_agents.py +0 -0
- palimpzest/agents/search_agents.py +637 -0
- palimpzest/constants.py +259 -197
- palimpzest/core/data/context.py +393 -0
- palimpzest/core/data/context_manager.py +163 -0
- palimpzest/core/data/dataset.py +634 -0
- palimpzest/core/data/{datareaders.py → iter_dataset.py} +202 -126
- palimpzest/core/elements/groupbysig.py +16 -13
- palimpzest/core/elements/records.py +166 -75
- palimpzest/core/lib/schemas.py +152 -390
- palimpzest/core/{data/dataclasses.py → models.py} +306 -170
- palimpzest/policy.py +2 -27
- palimpzest/prompts/__init__.py +35 -5
- palimpzest/prompts/agent_prompts.py +357 -0
- palimpzest/prompts/context_search.py +9 -0
- palimpzest/prompts/convert_prompts.py +61 -5
- palimpzest/prompts/filter_prompts.py +50 -5
- palimpzest/prompts/join_prompts.py +163 -0
- palimpzest/prompts/moa_proposer_convert_prompts.py +5 -5
- palimpzest/prompts/prompt_factory.py +358 -46
- palimpzest/prompts/validator.py +239 -0
- palimpzest/query/execution/all_sample_execution_strategy.py +134 -76
- palimpzest/query/execution/execution_strategy.py +210 -317
- palimpzest/query/execution/execution_strategy_type.py +5 -7
- palimpzest/query/execution/mab_execution_strategy.py +249 -136
- palimpzest/query/execution/parallel_execution_strategy.py +153 -244
- palimpzest/query/execution/single_threaded_execution_strategy.py +107 -64
- palimpzest/query/generators/generators.py +157 -330
- palimpzest/query/operators/__init__.py +15 -5
- palimpzest/query/operators/aggregate.py +50 -33
- palimpzest/query/operators/compute.py +201 -0
- palimpzest/query/operators/convert.py +27 -21
- palimpzest/query/operators/critique_and_refine_convert.py +7 -5
- palimpzest/query/operators/distinct.py +62 -0
- palimpzest/query/operators/filter.py +22 -13
- palimpzest/query/operators/join.py +402 -0
- palimpzest/query/operators/limit.py +3 -3
- palimpzest/query/operators/logical.py +198 -80
- palimpzest/query/operators/mixture_of_agents_convert.py +10 -8
- palimpzest/query/operators/physical.py +27 -21
- palimpzest/query/operators/project.py +3 -3
- palimpzest/query/operators/rag_convert.py +7 -7
- palimpzest/query/operators/retrieve.py +9 -9
- palimpzest/query/operators/scan.py +81 -42
- palimpzest/query/operators/search.py +524 -0
- palimpzest/query/operators/split_convert.py +10 -8
- palimpzest/query/optimizer/__init__.py +7 -9
- palimpzest/query/optimizer/cost_model.py +108 -441
- palimpzest/query/optimizer/optimizer.py +123 -181
- palimpzest/query/optimizer/optimizer_strategy.py +66 -61
- palimpzest/query/optimizer/plan.py +352 -67
- palimpzest/query/optimizer/primitives.py +43 -19
- palimpzest/query/optimizer/rules.py +484 -646
- palimpzest/query/optimizer/tasks.py +127 -58
- palimpzest/query/processor/config.py +41 -76
- palimpzest/query/processor/query_processor.py +73 -18
- palimpzest/query/processor/query_processor_factory.py +46 -38
- palimpzest/schemabuilder/schema_builder.py +15 -28
- palimpzest/utils/model_helpers.py +27 -77
- palimpzest/utils/progress.py +114 -102
- palimpzest/validator/__init__.py +0 -0
- palimpzest/validator/validator.py +306 -0
- {palimpzest-0.7.20.dist-info → palimpzest-0.8.0.dist-info}/METADATA +6 -1
- palimpzest-0.8.0.dist-info/RECORD +95 -0
- palimpzest/core/lib/fields.py +0 -141
- palimpzest/prompts/code_synthesis_prompts.py +0 -28
- palimpzest/query/execution/random_sampling_execution_strategy.py +0 -240
- palimpzest/query/generators/api_client_factory.py +0 -30
- palimpzest/query/operators/code_synthesis_convert.py +0 -488
- palimpzest/query/operators/map.py +0 -130
- palimpzest/query/processor/nosentinel_processor.py +0 -33
- palimpzest/query/processor/processing_strategy_type.py +0 -28
- palimpzest/query/processor/sentinel_processor.py +0 -88
- palimpzest/query/processor/streaming_processor.py +0 -149
- palimpzest/sets.py +0 -405
- palimpzest/utils/datareader_helpers.py +0 -61
- palimpzest/utils/demo_helpers.py +0 -75
- palimpzest/utils/field_helpers.py +0 -69
- palimpzest/utils/generation_helpers.py +0 -69
- palimpzest/utils/sandbox.py +0 -183
- palimpzest-0.7.20.dist-info/RECORD +0 -95
- /palimpzest/core/{elements/index.py → data/index_dataset.py} +0 -0
- {palimpzest-0.7.20.dist-info → palimpzest-0.8.0.dist-info}/WHEEL +0 -0
- {palimpzest-0.7.20.dist-info → palimpzest-0.8.0.dist-info}/licenses/LICENSE +0 -0
- {palimpzest-0.7.20.dist-info → palimpzest-0.8.0.dist-info}/top_level.txt +0 -0
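
The most visible structural changes in this release are the removal of `palimpzest/sets.py` (its `Set`/`Dataset` API is reproduced in full in the deleted-file listing at the end of this diff), the addition of `palimpzest/core/data/dataset.py`, and the renames `core/data/datareaders.py → core/data/iter_dataset.py` and `core/data/dataclasses.py → core/models.py`. A hedged sketch of the import-level impact, based only on the file paths listed above (this diff does not show the bodies of the new 0.8.0 modules):

```python
# Import paths in 0.7.20, taken from the deleted files shown later in this diff.
from palimpzest.sets import Dataset
from palimpzest.core.data.dataclasses import ExecutionStats, PlanStats
from palimpzest.core.data.datareaders import DataReader

# In 0.8.0 the file listing implies these modules moved:
#   palimpzest/core/data/dataset.py is newly added (presumably the new Dataset home),
#   palimpzest/core/data/dataclasses.py became palimpzest/core/models.py, and
#   palimpzest/core/data/datareaders.py became palimpzest/core/data/iter_dataset.py.
# The exact class names exported from those new paths are an assumption, not verified here.
```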
palimpzest/query/processor/sentinel_processor.py
DELETED

```diff
@@ -1,88 +0,0 @@
-import logging
-
-from palimpzest.core.data.dataclasses import ExecutionStats, SentinelPlanStats
-from palimpzest.core.elements.records import DataRecordCollection
-from palimpzest.query.optimizer.cost_model import SampleBasedCostModel
-from palimpzest.query.optimizer.optimizer_strategy_type import OptimizationStrategyType
-from palimpzest.query.optimizer.plan import SentinelPlan
-from palimpzest.query.processor.query_processor import QueryProcessor
-
-logger = logging.getLogger(__name__)
-
-class SentinelQueryProcessor(QueryProcessor):
-
-    def _generate_sample_observations(self, sentinel_plan: SentinelPlan) -> SentinelPlanStats:
-        """
-        This function is responsible for generating sample observation data which can be
-        consumed by the CostModel.
-
-        To accomplish this, we construct a special sentinel plan using the Optimizer which is
-        capable of executing any valid physical implementation of a Filter or Convert operator
-        on each record.
-        """
-        # if we're using validation data, get the set of expected output records
-        expected_outputs = {}
-        for source_idx in range(len(self.val_datasource)):
-            expected_output = self.val_datasource[source_idx]
-            expected_outputs[source_idx] = expected_output
-
-        # execute sentinel plan; returns sentinel_plan_stats
-        return self.sentinel_execution_strategy.execute_sentinel_plan(sentinel_plan, expected_outputs)
-
-    def _create_sentinel_plan(self) -> SentinelPlan:
-        """
-        Generates and returns a SentinelPlan for the given dataset.
-        """
-        # create a new optimizer and update its strategy to SENTINEL
-        optimizer = self.optimizer.deepcopy_clean()
-        optimizer.update_strategy(OptimizationStrategyType.SENTINEL)
-
-        # create copy of dataset, but change its data source to the validation data source
-        dataset = self.dataset.copy()
-        dataset._set_data_source(self.val_datasource)
-
-        # get the sentinel plan for the given dataset
-        sentinel_plans = optimizer.optimize(dataset)
-        sentinel_plan = sentinel_plans[0]
-
-        return sentinel_plan
-
-    def execute(self) -> DataRecordCollection:
-        # for now, enforce that we are using validation data; we can relax this after paper submission
-        if self.val_datasource is None:
-            raise Exception("Make sure you are using validation data with SentinelQueryProcessor")
-        logger.info(f"Executing {self.__class__.__name__}")
-
-        # create execution stats
-        execution_stats = ExecutionStats(execution_id=self.execution_id())
-        execution_stats.start()
-
-        # create sentinel plan
-        sentinel_plan = self._create_sentinel_plan()
-
-        # generate sample execution data
-        sentinel_plan_stats = self._generate_sample_observations(sentinel_plan)
-
-        # update the execution stats to account for the work done in optimization
-        execution_stats.add_plan_stats(sentinel_plan_stats)
-        execution_stats.finish_optimization()
-
-        # (re-)initialize the optimizer
-        optimizer = self.optimizer.deepcopy_clean()
-
-        # construct the CostModel with any sample execution data we've gathered
-        cost_model = SampleBasedCostModel(sentinel_plan_stats, self.verbose)
-        optimizer.update_cost_model(cost_model)
-
-        # execute plan(s) according to the optimization strategy
-        records, plan_stats = self._execute_best_plan(self.dataset, optimizer)
-
-        # update the execution stats to account for the work to execute the final plan
-        execution_stats.add_plan_stats(plan_stats)
-        execution_stats.finish()
-
-        # construct and return the DataRecordCollection
-        result = DataRecordCollection(records, execution_stats=execution_stats)
-        logger.info("Done executing SentinelQueryProcessor")
-
-        return result
```
palimpzest/query/processor/streaming_processor.py
DELETED

```diff
@@ -1,149 +0,0 @@
-import logging
-import time
-
-from palimpzest.core.data.dataclasses import PlanStats
-from palimpzest.core.elements.records import DataRecordCollection
-from palimpzest.query.operators.aggregate import AggregateOp
-from palimpzest.query.operators.filter import FilterOp
-from palimpzest.query.operators.limit import LimitScanOp
-from palimpzest.query.operators.scan import ScanPhysicalOp
-from palimpzest.query.optimizer.plan import PhysicalPlan
-from palimpzest.query.processor.query_processor import QueryProcessor
-from palimpzest.sets import Dataset
-
-logger = logging.getLogger(__name__)
-
-class StreamingQueryProcessor(QueryProcessor):
-    """This class can be used for a streaming, record-based execution.
-    Results are returned as an iterable that can be consumed by the caller."""
-
-    def __init__(self, *args, **kwargs) -> None:
-        super().__init__(*args, **kwargs)
-        self._plan: PhysicalPlan | None = None
-        self._plan_stats: PlanStats | None = None
-        self.last_record = False
-        self.current_scan_idx: int = 0
-        self.plan_generated: bool = False
-        self.records_count: int = 0
-        logger.info("Initialized StreamingQueryProcessor")
-
-    @property
-    def plan(self) -> PhysicalPlan:
-        if self._plan is None:
-            raise Exception("Plan has not been generated yet.")
-        return self._plan
-
-    @plan.setter
-    def plan(self, plan: PhysicalPlan):
-        self._plan = plan
-
-    @property
-    def plan_stats(self) -> PlanStats:
-        if self._plan_stats is None:
-            raise Exception("Plan stats have not been generated yet.")
-        return self._plan_stats
-
-    @plan_stats.setter
-    def plan_stats(self, plan_stats: PlanStats):
-        self._plan_stats = plan_stats
-
-    def generate_plan(self, dataset: Dataset):
-        # self.clear_cached_examples()
-        start_time = time.time()
-
-        # check that the plan does not contain any aggregation operators
-        for op in self.plan.operators:
-            if isinstance(op, AggregateOp):
-                raise Exception("You cannot have a Streaming Execution if there is an Aggregation Operator")
-
-        # TODO: Do we need to re-initialize the optimizer here?
-        # Effectively always use the optimal strategy
-        optimizer = self.optimizer.deepcopy_clean()
-        plans = optimizer.optimize(dataset)
-        self.plan = plans[0]
-        self.plan_stats = PlanStats.from_plan(self.plan)
-        self.plan_stats.start()
-        logger.info(f"Time for planning: {time.time() - start_time:.2f} seconds")
-        self.plan_generated = True
-        logger.info(f"Generated plan:\n{self.plan}")
-        return self.plan
-
-    def execute(self):
-        logger.info("Executing StreamingQueryProcessor")
-        # Always delete cache
-        if not self.plan_generated:
-            self.generate_plan(self.dataset)
-
-        # if dry_run:
-        #     yield [], self.plan, self.plan_stats
-        #     return
-
-        input_records = self.get_input_records()
-        for idx, record in enumerate(input_records):
-            # print("Iteration number: ", idx+1, "out of", len(input_records))
-            output_records = self.execute_opstream(self.plan, record)
-            if idx == len(input_records) - 1:
-                # finalize plan stats
-                self.plan_stats.finish()
-                self.plan_stats.plan_str = str(self.plan)
-            yield DataRecordCollection(output_records, plan_stats=self.plan_stats)
-
-        logger.info("Done executing StreamingQueryProcessor")
-
-
-    def get_input_records(self):
-        scan_operator = self.plan.operators[0]
-        assert isinstance(scan_operator, ScanPhysicalOp), "First operator in physical plan must be a ScanPhysicalOp"
-        datareader = scan_operator.datareader
-        if not datareader:
-            raise Exception("DataReader not found")
-        datareader_len = len(datareader)
-
-        input_records = []
-        record_op_stats = []
-        for source_idx in range(datareader_len):
-            record_set = scan_operator(source_idx)
-            input_records += record_set.data_records
-            record_op_stats += record_set.record_op_stats
-
-        self.plan_stats.add_record_op_stats(record_op_stats)
-
-        return input_records
-
-    def execute_opstream(self, plan, record):
-        # initialize list of output records and intermediate variables
-        input_records = [record]
-        record_op_stats_lst = []
-
-        for operator in plan.operators:
-            # TODO: this being defined in the for loop potentially makes the return
-            # unbounded if plan.operators is empty. This should be defined outside the loop
-            # and the loop refactored to account for not redeclaring this for each operator
-            output_records = []
-
-            if isinstance(operator, ScanPhysicalOp):
-                continue
-            # only invoke aggregate operator(s) once there are no more source records and all
-            # upstream operators' processing queues are empty
-            # elif isinstance(operator, AggregateOp):
-            #     output_records, record_op_stats_lst = operator(candidates=input_records)
-            elif isinstance(operator, LimitScanOp):
-                if self.records_count >= operator.limit:
-                    break
-            else:
-                for r in input_records:
-                    record_set = operator(r)
-                    output_records += record_set.data_records
-                    record_op_stats_lst += record_set.record_op_stats
-
-                if isinstance(operator, FilterOp):
-                    # delete all records that did not pass the filter
-                    output_records = [r for r in output_records if r.passed_operator]
-                    if not output_records:
-                        break
-
-            self.plan_stats.add_record_op_stats(record_op_stats_lst)
-            input_records = output_records
-            self.records_count += len(output_records)
-
-        return output_records
```
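
For context, the removed `StreamingQueryProcessor.execute()` above is a generator: it plans lazily on first use and then yields one `DataRecordCollection` per input record produced by the scan operator. A minimal consumption sketch, assuming an already-constructed processor instance (construction went through the query-processor factory, whose 0.7.20 form is not shown in this diff):

```python
# Hypothetical driver for the deleted StreamingQueryProcessor shown above.
# `processor` is assumed to be an already-constructed StreamingQueryProcessor.
def consume_stream(processor):
    # execute() generates the physical plan on first use, then yields one
    # DataRecordCollection per input record as results become available.
    collections = []
    for collection in processor.execute():
        collections.append(collection)
    return collections
```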
palimpzest/sets.py
DELETED
```diff
@@ -1,405 +0,0 @@
-from __future__ import annotations
-
-from pathlib import Path
-from typing import Callable
-
-import pandas as pd
-from chromadb.api.models.Collection import Collection
-
-from palimpzest.constants import AggFunc, Cardinality
-from palimpzest.core.data.datareaders import DataReader
-from palimpzest.core.elements.filters import Filter
-from palimpzest.core.elements.groupbysig import GroupBySig
-from palimpzest.core.lib.schemas import Number, Schema
-from palimpzest.policy import construct_policy_from_kwargs
-from palimpzest.query.processor.config import QueryProcessorConfig
-from palimpzest.utils.datareader_helpers import get_local_datareader
-from palimpzest.utils.hash_helpers import hash_for_serialized_dict
-
-
-#####################################################
-#
-#####################################################
-class Set:
-    """
-    """
-
-    def __init__(
-        self,
-        source: Set | DataReader,
-        schema: Schema,
-        desc: str | None = None,
-        filter: Filter | None = None,
-        udf: Callable | None = None,
-        agg_func: AggFunc | None = None,
-        group_by: GroupBySig | None = None,
-        project_cols: list[str] | None = None,
-        index: Collection | None = None,
-        search_func: Callable | None = None,
-        search_attr: str | None = None,
-        output_attrs: list[dict] | None = None,
-        k: int | None = None,  # TODO: disambiguate `k` to be something like `retrieve_k`
-        limit: int | None = None,
-        cardinality: Cardinality = Cardinality.ONE_TO_ONE,
-        depends_on: list[str] | None = None,
-        cache: bool = False,
-    ):
-        self._schema = schema
-        self._source = source
-        self._desc = desc
-        self._filter = filter
-        self._udf = udf
-        self._agg_func = agg_func
-        self._group_by = group_by
-        self._project_cols = None if project_cols is None else sorted(project_cols)
-        self._index = index
-        self._search_func = search_func
-        self._search_attr = search_attr
-        self._output_attrs = output_attrs
-        self._k = k
-        self._limit = limit
-        self._cardinality = cardinality
-        self._depends_on = [] if depends_on is None else sorted(depends_on)
-        self._cache = cache
-
-    @property
-    def schema(self) -> Schema:
-        return self._schema
-
-    def _set_data_source(self, source: DataReader):
-        if isinstance(self._source, Set):
-            self._source._set_data_source(source)
-        else:
-            self._source = source
-
-    def serialize(self):
-        # NOTE: I needed to remove depends_on from the serialization dictionary because
-        # the optimizer changes the name of the depends_on fields to be their "full" name.
-        # This created an issue with the node.universal_identifier() not being consistent
-        # after changing the field to its full name.
-        d = {
-            "schema": self.schema.json_schema(),
-            "source": self._source.serialize(),
-            "desc": repr(self._desc),
-            "filter": None if self._filter is None else self._filter.serialize(),
-            "udf": None if self._udf is None else self._udf.__name__,
-            "agg_func": None if self._agg_func is None else self._agg_func.value,
-            "cardinality": self._cardinality,
-            "limit": self._limit,
-            "group_by": None if self._group_by is None else self._group_by.serialize(),
-            "project_cols": None if self._project_cols is None else self._project_cols,
-            "index": None if self._index is None else self._index.__class__.__name__,
-            "search_func": None if self._search_func is None else self._search_func.__name__,
-            "search_attr": self._search_attr,
-            "output_attrs": None if self._output_attrs is None else str(self._output_attrs),
-            "k": self._k,
-        }
-
-        return d
-
-    def universal_identifier(self):
-        """Return a unique identifier for this Set."""
-        return hash_for_serialized_dict(self.serialize())
-
-    def json_schema(self):
-        """Return the JSON schema for this Set."""
-        return self.schema.json_schema()
-
-
-class Dataset(Set):
-    """
-    A Dataset is the intended abstraction for programmers to interact with when writing PZ programs.
-
-    Users instantiate a Dataset by specifying a `source` that either points to a DataReader
-    or an existing Dataset. Users can then perform computations on the Dataset in a lazy fashion
-    by leveraging functions such as `filter`, `sem_filter`, `sem_add_columns`, `aggregate`, etc.
-    Underneath the hood, each of these operations creates a new Dataset. As a result, the Dataset
-    defines a lineage of computation.
-    """
-
-    def __init__(
-        self,
-        source: str | Path | list | pd.DataFrame | DataReader | Dataset,
-        schema: Schema | None = None,
-        *args,
-        **kwargs,
-    ) -> None:
-        # NOTE: this function currently assumes that DataReader will always be provided with a schema;
-        # we will relax this assumption in a subsequent PR
-        # convert source into a DataReader
-        updated_source = get_local_datareader(source, **kwargs) if isinstance(source, (str, Path, list, pd.DataFrame)) else source
-
-        # get the schema
-        schema = updated_source.schema if schema is None else schema
-
-        # intialize class
-        super().__init__(updated_source, schema, *args, **kwargs)
-
-    def copy(self):
-        return Dataset(
-            source=self._source.copy() if isinstance(self._source, Set) else self._source,
-            schema=self._schema,
-            desc=self._desc,
-            filter=self._filter,
-            udf=self._udf,
-            agg_func=self._agg_func,
-            group_by=self._group_by,
-            project_cols=self._project_cols,
-            index=self._index,
-            search_func=self._search_func,
-            search_attr=self._search_attr,
-            output_attrs=self._output_attrs,
-            k=self._k,
-            limit=self._limit,
-            cardinality=self._cardinality,
-            depends_on=self._depends_on,
-            cache=self._cache,
-        )
-
-    def filter(
-        self,
-        _filter: Callable,
-        depends_on: str | list[str] | None = None,
-    ) -> Dataset:
-        """Add a user defined function as a filter to the Set. This filter will possibly restrict the items that are returned later."""
-        f = None
-        if callable(_filter):
-            f = Filter(filter_fn=_filter)
-        else:
-            error_str = f"Only support callable for filter, currently got {type(_filter)}"
-            if isinstance(_filter, str):
-                error_str += ". Consider using sem_filter() for semantic filters."
-            raise Exception(error_str)
-
-        if isinstance(depends_on, str):
-            depends_on = [depends_on]
-
-        return Dataset(
-            source=self,
-            schema=self.schema,
-            filter=f,
-            depends_on=depends_on,
-            cache=self._cache,
-        )
-
-    def sem_filter(
-        self,
-        _filter: str,
-        depends_on: str | list[str] | None = None,
-    ) -> Dataset:
-        """Add a natural language description of a filter to the Set. This filter will possibly restrict the items that are returned later."""
-        f = None
-        if isinstance(_filter, str):
-            f = Filter(_filter)
-        else:
-            raise Exception("sem_filter() only supports `str` input for _filter.", type(_filter))
-
-        if isinstance(depends_on, str):
-            depends_on = [depends_on]
-
-        return Dataset(
-            source=self,
-            schema=self.schema,
-            filter=f,
-            depends_on=depends_on,
-            cache=self._cache,
-        )
-
-    def sem_add_columns(self, cols: list[dict] | type[Schema],
-                        cardinality: Cardinality = Cardinality.ONE_TO_ONE,
-                        depends_on: str | list[str] | None = None,
-                        desc: str = "Add new columns via semantic reasoning") -> Dataset:
-        """
-        Add new columns by specifying the column names, descriptions, and types.
-        The column will be computed during the execution of the Dataset.
-        Example:
-            sem_add_columns(
-                [{'name': 'greeting', 'desc': 'The greeting message', 'type': str},
-                 {'name': 'age', 'desc': 'The age of the person', 'type': int},
-                 {'name': 'full_name', 'desc': 'The name of the person', 'type': str}]
-            )
-        """
-        if isinstance(depends_on, str):
-            depends_on = [depends_on]
-
-        new_output_schema = None
-        if isinstance(cols, list):
-            new_output_schema = self.schema.add_fields(cols)
-        elif issubclass(cols, Schema):
-            new_output_schema = self.schema.union(cols)
-        else:
-            raise ValueError("`cols` must be a list of dictionaries or a Schema.")
-
-        return Dataset(
-            source=self,
-            schema=new_output_schema,
-            udf=None,
-            cardinality=cardinality,
-            depends_on=depends_on,
-            desc=desc,
-            cache=self._cache,
-        )
-
-    def add_columns(self, udf: Callable,
-                    cols: list[dict] | type[Schema],
-                    cardinality: Cardinality = Cardinality.ONE_TO_ONE,
-                    depends_on: str | list[str] | None = None,
-                    desc: str = "Add new columns via UDF") -> Dataset:
-        """
-        Add new columns by specifying UDFs.
-
-        Examples:
-            add_columns(
-                udf=compute_personal_greeting,
-                cols=[
-                    {'name': 'greeting', 'desc': 'The greeting message', 'type': str},
-                    {'name': 'age', 'desc': 'The age of the person', 'type': int},
-                    {'name': 'full_name', 'desc': 'The name of the person', 'type': str},
-                ]
-            )
-        """
-        if udf is None or cols is None:
-            raise ValueError("`udf` and `cols` must be provided for add_columns.")
-
-        if isinstance(depends_on, str):
-            depends_on = [depends_on]
-
-        new_output_schema = None
-        if isinstance(cols, list):
-            updated_cols = []
-            for col_dict in cols:
-                assert isinstance(col_dict, dict), "each entry in `cols` must be a dictionary"
-                assert "name" in col_dict, "each type must contain a 'name' key specifying the column name"
-                assert "type" in col_dict, "each type must contain a 'type' key specifying the column type"
-                col_dict["desc"] = col_dict.get("desc", "New column: " + col_dict["name"])
-                updated_cols.append(col_dict)
-            new_output_schema = self.schema.add_fields(updated_cols)
-
-        elif issubclass(cols, Schema):
-            new_output_schema = self.schema.union(cols)
-
-        else:
-            raise ValueError("`cols` must be a list of dictionaries or a Schema.")
-
-        return Dataset(
-            source=self,
-            schema=new_output_schema,
-            udf=udf,
-            cardinality=cardinality,
-            desc=desc,
-            depends_on=depends_on,
-            cache=self._cache,
-        )
-
-    def map(self, udf: Callable) -> Dataset:
-        """
-        Apply a UDF map function.
-
-        Examples:
-            map(udf=clean_column_values)
-        """
-        if udf is None:
-            raise ValueError("`udf` must be provided for map.")
-
-        return Dataset(
-            source=self,
-            schema=self.schema,
-            udf=udf,
-            cache=self._cache,
-        )
-
-    def count(self) -> Dataset:
-        """Apply a count aggregation to this set"""
-        return Dataset(
-            source=self,
-            schema=Number,
-            desc="Count results",
-            agg_func=AggFunc.COUNT,
-            cache=self._cache,
-        )
-
-    def average(self) -> Dataset:
-        """Apply an average aggregation to this set"""
-        return Dataset(
-            source=self,
-            schema=Number,
-            desc="Average results",
-            agg_func=AggFunc.AVERAGE,
-            cache=self._cache,
-        )
-
-    def groupby(self, groupby: GroupBySig) -> Dataset:
-        return Dataset(
-            source=self,
-            schema=groupby.output_schema(),
-            desc="Group By",
-            group_by=groupby,
-            cache=self._cache,
-        )
-
-    def retrieve(
-        self,
-        index: Collection,
-        search_attr: str,
-        output_attrs: list[dict] | type[Schema],
-        search_func: Callable | None = None,
-        k: int = -1,
-    ) -> Dataset:
-        """
-        Retrieve the top-k nearest neighbors of the value of the `search_attr` from the `index` and
-        use these results to construct the `output_attrs` field(s).
-        """
-        new_output_schema = None
-        if isinstance(output_attrs, list):
-            new_output_schema = self.schema.add_fields(output_attrs)
-        elif issubclass(output_attrs, Schema):
-            new_output_schema = self.schema.union(output_attrs)
-        else:
-            raise ValueError("`cols` must be a list of dictionaries or a Schema.")
-
-        # TODO: revisit once we can think through abstraction(s)
-        # # construct the PZIndex from the user-provided index
-        # index = index_factory(index)
-
-        return Dataset(
-            source=self,
-            schema=new_output_schema,
-            desc="Retrieve",
-            index=index,
-            search_func=search_func,
-            search_attr=search_attr,
-            output_attrs=output_attrs,
-            k=k,
-            cache=self._cache,
-        )
-
-    def limit(self, n: int) -> Dataset:
-        """Limit the set size to no more than n rows"""
-        return Dataset(
-            source=self,
-            schema=self.schema,
-            desc="LIMIT " + str(n),
-            limit=n,
-            cache=self._cache,
-        )
-
-    def project(self, project_cols: list[str] | str) -> Dataset:
-        """Project the Set to only include the specified columns."""
-        return Dataset(
-            source=self,
-            schema=self.schema.project(project_cols),
-            project_cols=project_cols if isinstance(project_cols, list) else [project_cols],
-            cache=self._cache,
-        )
-
-    def run(self, config: QueryProcessorConfig | None = None, **kwargs):
-        """Invoke the QueryProcessor to execute the query. `kwargs` will be applied to the QueryProcessorConfig."""
-        # TODO: this import currently needs to be here to avoid a circular import; we should fix this in a subsequent PR
-        from palimpzest.query.processor.query_processor_factory import QueryProcessorFactory
-
-        # as syntactic sugar, we will allow some keyword arguments to parameterize our policies
-        policy = construct_policy_from_kwargs(**kwargs)
-        if policy is not None:
-            kwargs["policy"] = policy
-
-        return QueryProcessorFactory.create_and_run_processor(self, config, **kwargs)
```
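
Taken together, the deleted `sets.py` above defined the 0.7.x user-facing API: lazily chained `Dataset` operators terminated by `run()`. A minimal, hypothetical usage sketch based only on the code above (the source path, filter text, and column specs are made up for illustration; `run()` may also take a `QueryProcessorConfig` and policy kwargs handled by `construct_policy_from_kwargs`):

```python
from palimpzest.sets import Dataset  # 0.7.x import path; removed in 0.8.0

# Build a lazy chain of Datasets; each call returns a new Dataset node in the lineage.
ds = Dataset("emails/")                                 # str/Path/list/DataFrame sources are wrapped in a DataReader
ds = ds.sem_filter("The email discusses a meeting")     # natural-language (semantic) filter
ds = ds.sem_add_columns([
    {"name": "sender", "desc": "The email sender", "type": str},
    {"name": "subject", "desc": "The subject line", "type": str},
])
ds = ds.limit(10)

# Execution is deferred until run(), which hands the lineage to the QueryProcessorFactory.
output = ds.run()
```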