palimpzest 0.7.7__py3-none-any.whl → 0.7.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- palimpzest/constants.py +113 -75
- palimpzest/core/data/dataclasses.py +55 -38
- palimpzest/core/elements/index.py +5 -15
- palimpzest/core/elements/records.py +1 -1
- palimpzest/prompts/prompt_factory.py +1 -1
- palimpzest/query/execution/all_sample_execution_strategy.py +216 -0
- palimpzest/query/execution/execution_strategy.py +4 -4
- palimpzest/query/execution/execution_strategy_type.py +7 -1
- palimpzest/query/execution/mab_execution_strategy.py +184 -72
- palimpzest/query/execution/parallel_execution_strategy.py +182 -15
- palimpzest/query/execution/single_threaded_execution_strategy.py +21 -21
- palimpzest/query/generators/api_client_factory.py +6 -7
- palimpzest/query/generators/generators.py +5 -8
- palimpzest/query/operators/aggregate.py +4 -3
- palimpzest/query/operators/convert.py +1 -1
- palimpzest/query/operators/filter.py +1 -1
- palimpzest/query/operators/limit.py +1 -1
- palimpzest/query/operators/map.py +1 -1
- palimpzest/query/operators/physical.py +8 -4
- palimpzest/query/operators/project.py +1 -1
- palimpzest/query/operators/retrieve.py +7 -23
- palimpzest/query/operators/scan.py +1 -1
- palimpzest/query/optimizer/cost_model.py +54 -62
- palimpzest/query/optimizer/optimizer.py +2 -6
- palimpzest/query/optimizer/plan.py +4 -4
- palimpzest/query/optimizer/primitives.py +1 -1
- palimpzest/query/optimizer/rules.py +8 -26
- palimpzest/query/optimizer/tasks.py +3 -3
- palimpzest/query/processor/processing_strategy_type.py +2 -2
- palimpzest/query/processor/sentinel_processor.py +0 -2
- palimpzest/sets.py +2 -3
- palimpzest/utils/generation_helpers.py +1 -1
- palimpzest/utils/model_helpers.py +27 -9
- palimpzest/utils/progress.py +81 -72
- {palimpzest-0.7.7.dist-info → palimpzest-0.7.9.dist-info}/METADATA +4 -2
- {palimpzest-0.7.7.dist-info → palimpzest-0.7.9.dist-info}/RECORD +39 -38
- {palimpzest-0.7.7.dist-info → palimpzest-0.7.9.dist-info}/WHEEL +1 -1
- {palimpzest-0.7.7.dist-info → palimpzest-0.7.9.dist-info}/licenses/LICENSE +0 -0
- {palimpzest-0.7.7.dist-info → palimpzest-0.7.9.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,216 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
from palimpzest.core.data.dataclasses import SentinelPlanStats
|
|
6
|
+
from palimpzest.core.elements.records import DataRecord, DataRecordSet
|
|
7
|
+
from palimpzest.query.execution.execution_strategy import SentinelExecutionStrategy
|
|
8
|
+
from palimpzest.query.operators.physical import PhysicalOperator
|
|
9
|
+
from palimpzest.query.operators.scan import ScanPhysicalOp
|
|
10
|
+
from palimpzest.query.optimizer.plan import SentinelPlan
|
|
11
|
+
from palimpzest.utils.progress import create_progress_manager
|
|
12
|
+
|
|
13
|
+
# module-level logger for this execution strategy
logger = logging.getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
class OpSet:
    """
    The full set of physical operators for a single logical operator in a sentinel plan.

    Unlike a sampling frontier, no operators or source indices are held back: every operator in
    the set is paired with the current input(s) of every source index. The set also tracks those
    inputs. For scan operators they are seeded with the source index itself; for all other
    operators they are supplied by the upstream logical operator via `update_inputs()`.
    """

    def __init__(self, op_set: list[PhysicalOperator], source_indices: list[int]):
        # the physical operators implementing this logical operator
        self.ops = op_set

        # the source indices to execute on, in the order they will be processed
        self.source_indices = source_indices

        # scan operators consume the source index directly; all other operators start with no
        # inputs and receive them from the previous logical operator via update_inputs()
        is_scan_op = isinstance(op_set[0], ScanPhysicalOp)
        self.source_idx_to_input = {source_idx: [source_idx] for source_idx in self.source_indices} if is_scan_op else {}

    def get_op_input_pairs(self) -> list[tuple[PhysicalOperator, DataRecord | int | None]]:
        """
        Return every (operator, input) pair which this operator set needs to execute.

        Each operator is paired with every input currently stored for every index in
        `self.source_indices`, ordered by operator first, then by source index.

        Source indices with no stored input contribute no pairs rather than raising a KeyError.
        This can happen for a non-scan operator when the upstream operator produced no entry for
        that index. As a result, an empty return value means there is nothing left to execute.

        An input may be None when the upstream record did not pass its operator
        (see `update_inputs()`).
        """
        return [
            (op, op_input)
            for op in self.ops
            for source_idx in self.source_indices
            for op_input in self.source_idx_to_input.get(source_idx, [])
        ]

    def pick_highest_quality_output(self, record_sets: list[DataRecordSet]) -> DataRecordSet:
        """
        Combine several operators' record sets for one source index into a single record set.

        For each output position, the (record, record_op_stats) pair with the highest
        `record_op_stats.quality` is kept. Ties go to the earliest record set in `record_sets`.
        If only one record set is given, it is returned unchanged.

        NOTE: this assumes every operator orders its outputs consistently (e.g. for one-to-many
        converts); more robust alignment schemes could be explored in the future.
        """
        # if there's only one operator in the set, we return its record_set
        if len(record_sets) == 1:
            return record_sets[0]

        # group the (record, record_op_stats) candidates by their position in each record set
        idx_to_candidates = {}
        for record_set in record_sets:
            for idx in range(len(record_set)):
                idx_to_candidates.setdefault(idx, []).append((record_set[idx], record_set.record_op_stats[idx]))

        # keep the highest quality candidate at each position; max() returns the first maximal
        # element, matching a strict `>` scan over the candidates
        out_records, out_record_op_stats = [], []
        for idx in range(len(idx_to_candidates)):
            best_record, best_stats = max(idx_to_candidates[idx], key=lambda candidate: candidate[1].quality)
            out_records.append(best_record)
            out_record_op_stats.append(best_stats)

        # create and return final DataRecordSet
        return DataRecordSet(out_records, out_record_op_stats)

    def update_inputs(self, source_idx_to_record_sets: dict[int, list[DataRecordSet]]):
        """
        Set the inputs for this logical operator from the outputs of the previous logical operator.

        For each source index, the highest-quality output is chosen from the previous operator's
        record sets (see `pick_highest_quality_output()`). Each of its records becomes an input;
        records which did not pass the previous operator are replaced by None.
        """
        for source_idx, record_sets in source_idx_to_record_sets.items():
            champion_record_set = self.pick_highest_quality_output(record_sets)
            self.source_idx_to_input[source_idx] = [
                record if record.passed_operator else None for record in champion_record_set
            ]
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
class AllSamplingExecutionStrategy(SentinelExecutionStrategy):
    """
    Sentinel execution strategy which executes every physical operator on every validation record.

    There is no operator or record sampling. For each logical operator in the plan:
    - all of its physical operators are run on the input(s) of every source index in
      `self.val_datasource`;
    - every output is scored;
    - the highest-quality output for each source index becomes the input to the next logical
      operator.
    """

    def _get_source_indices(self) -> list[int]:
        """Get the list of source indices which the sentinel plan should execute over."""
        # use every record in the validation datasource, in order (no shuffling or sub-sampling);
        # build plain Python ints so source indices (and values derived from them) are not numpy
        # scalar types
        return list(range(len(self.val_datasource)))

    def _execute_sentinel_plan(
        self,
        plan: SentinelPlan,
        op_sets: dict[str, OpSet],
        expected_outputs: dict[int, dict] | None,
        plan_stats: SentinelPlanStats,
    ) -> SentinelPlanStats:
        """
        Execute each logical operator's full operator set in plan order, scoring every output.

        Args:
            plan: the sentinel plan to execute; iterating it yields (logical_op_id, op_set) pairs.
            op_sets: mapping from logical operator id to the OpSet holding its operators and inputs.
            expected_outputs: optional expected outputs used to build the target record set for
                each source index (presumably keyed by source index -- confirm against callers).
            plan_stats: stats object which receives every record's op stats.

        Returns:
            `plan_stats`, after `finish()` has been called on it.
        """
        for op_idx, (logical_op_id, op_set) in enumerate(plan):
            # get every (operator, input) pair which must be executed for this logical operator
            op_input_pairs = op_sets[logical_op_id].get_op_input_pairs()

            # break out of the loop if op_input_pairs is empty, as this means all records have been filtered out
            if not op_input_pairs:
                break

            # run every operator on its input(s)
            source_idx_to_record_sets_and_ops, _ = self._execute_op_set(op_input_pairs)

            # FUTURE TODO: have this return the highest quality record set simply based on our posterior (or prior) belief on operator quality
            # get the target record set for each source_idx
            source_idx_to_target_record_set = self._get_target_record_sets(
                logical_op_id, source_idx_to_record_sets_and_ops, expected_outputs
            )

            # TODO: make consistent across here and RandomSampling
            # FUTURE TODO: move this outside of the loop (i.e. assume we only get quality label(s) after executing full program)
            # score the quality of each generated output; only the record set (first element)
            # of each result tuple is needed for scoring
            physical_op_cls = op_set[0].__class__
            source_idx_to_record_sets = {
                source_idx: [record_set_and_op[0] for record_set_and_op in record_sets_and_ops]
                for source_idx, record_sets_and_ops in source_idx_to_record_sets_and_ops.items()
            }
            source_idx_to_record_sets = self._score_quality(
                physical_op_cls, source_idx_to_record_sets, source_idx_to_target_record_set
            )

            # flatten the lists of records and record_op_stats
            all_records, all_record_op_stats = self._flatten_record_sets(source_idx_to_record_sets)

            # update plan stats
            plan_stats.add_record_op_stats(all_record_op_stats)

            # add records (which are not filtered) to the cache, if allowed
            self._add_records_to_cache(logical_op_id, all_records)

            # FUTURE TODO: simply set input based on source_idx_to_target_record_set (b/c we won't have scores computed)
            # provide the champion record sets as inputs to the next logical operator
            if op_idx + 1 < len(plan):
                next_logical_op_id = plan.logical_op_ids[op_idx + 1]
                op_sets[next_logical_op_id].update_inputs(source_idx_to_record_sets)

        # close the cache
        self._close_cache(plan.logical_op_ids)

        # finalize plan stats
        plan_stats.finish()

        return plan_stats

    def execute_sentinel_plan(self, plan: SentinelPlan, expected_outputs: dict[int, dict] | None) -> SentinelPlanStats:
        """
        Execute `plan` by running every physical operator on every validation record.

        NOTE: this class does not use `k` or `j` to choose what to execute. `sample_budget` is
        only passed to the progress manager, so the progress display may not match the actual
        number of operator calls.

        Args:
            plan: the sentinel plan to execute; its first operator set must contain only
                ScanPhysicalOps.
            expected_outputs: optional expected outputs used for quality scoring.

        Returns:
            The finished SentinelPlanStats for the execution.
        """
        # for now, assert that the first operator in the plan is a ScanPhysicalOp
        assert all(isinstance(op, ScanPhysicalOp) for op in plan.operator_sets[0]), "First operator in physical plan must be a ScanPhysicalOp"
        logger.info(f"Executing plan {plan.plan_id} with {self.max_workers} workers")
        logger.info(f"Plan Details: {plan}")

        # initialize plan stats
        plan_stats = SentinelPlanStats.from_plan(plan)
        plan_stats.start()

        # every OpSet shares the same list of source indices
        source_indices = self._get_source_indices()
        op_sets = {logical_op_id: OpSet(op_set, source_indices) for logical_op_id, op_set in plan}

        # initialize and start the progress manager
        self.progress_manager = create_progress_manager(plan, sample_budget=self.sample_budget, progress=self.progress)
        self.progress_manager.start()

        # NOTE: we must handle progress manager outside of _execute_sentinel_plan to ensure that it is shut down correctly;
        # if we don't have the `finally:` branch, then program crashes can cause future program runs to fail because
        # the progress manager cannot get a handle to the console
        try:
            plan_stats = self._execute_sentinel_plan(plan, op_sets, expected_outputs, plan_stats)
        finally:
            self.progress_manager.finish()

        logger.info(f"Done executing sentinel plan: {plan.plan_id}")
        logger.debug(f"Plan stats: (plan_cost={plan_stats.total_plan_cost}, plan_time={plan_stats.total_plan_time})")

        return plan_stats
|
|
@@ -78,7 +78,7 @@ class ExecutionStrategy(BaseExecutionStrategy, ABC):
|
|
|
78
78
|
else min(self.scan_start_idx + self.num_samples, len(op.datareader))
|
|
79
79
|
)
|
|
80
80
|
inputs = [idx for idx in range(self.scan_start_idx, scan_end_idx)]
|
|
81
|
-
input_queues[op.
|
|
81
|
+
input_queues[op.get_full_op_id()] = inputs
|
|
82
82
|
|
|
83
83
|
return input_queues
|
|
84
84
|
|
|
@@ -95,6 +95,7 @@ class SentinelExecutionStrategy(BaseExecutionStrategy, ABC):
|
|
|
95
95
|
j: int,
|
|
96
96
|
sample_budget: int,
|
|
97
97
|
policy: Policy,
|
|
98
|
+
priors: dict | None = None,
|
|
98
99
|
use_final_op_quality: bool = False,
|
|
99
100
|
seed: int = 42,
|
|
100
101
|
exp_name: str | None = None,
|
|
@@ -107,6 +108,7 @@ class SentinelExecutionStrategy(BaseExecutionStrategy, ABC):
|
|
|
107
108
|
self.j = j
|
|
108
109
|
self.sample_budget = sample_budget
|
|
109
110
|
self.policy = policy
|
|
111
|
+
self.priors = priors
|
|
110
112
|
self.use_final_op_quality = use_final_op_quality
|
|
111
113
|
self.seed = seed
|
|
112
114
|
self.rng = np.random.default_rng(seed=seed)
|
|
@@ -378,9 +380,7 @@ class SentinelExecutionStrategy(BaseExecutionStrategy, ABC):
|
|
|
378
380
|
return input.source_idx if isinstance(input, DataRecord) else input
|
|
379
381
|
|
|
380
382
|
def get_hash(operator, input):
|
|
381
|
-
|
|
382
|
-
phys_op_id = operator.get_op_id()
|
|
383
|
-
return hash(f"{logical_op_id}{phys_op_id}{hash(input)}")
|
|
383
|
+
return hash(f"{operator.get_full_op_id()}{hash(input)}")
|
|
384
384
|
|
|
385
385
|
# initialize mapping from source indices to output record sets
|
|
386
386
|
source_idx_to_record_sets_and_ops = {get_source_idx(input): [] for _, input in op_input_pairs}
|
|
@@ -1,7 +1,11 @@
|
|
|
1
1
|
from enum import Enum
|
|
2
2
|
|
|
3
|
+
from palimpzest.query.execution.all_sample_execution_strategy import AllSamplingExecutionStrategy
|
|
3
4
|
from palimpzest.query.execution.mab_execution_strategy import MABExecutionStrategy
|
|
4
|
-
from palimpzest.query.execution.parallel_execution_strategy import
|
|
5
|
+
from palimpzest.query.execution.parallel_execution_strategy import (
|
|
6
|
+
ParallelExecutionStrategy,
|
|
7
|
+
SequentialParallelExecutionStrategy,
|
|
8
|
+
)
|
|
5
9
|
from palimpzest.query.execution.random_sampling_execution_strategy import RandomSamplingExecutionStrategy
|
|
6
10
|
from palimpzest.query.execution.single_threaded_execution_strategy import (
|
|
7
11
|
PipelinedSingleThreadExecutionStrategy,
|
|
@@ -14,7 +18,9 @@ class ExecutionStrategyType(Enum):
|
|
|
14
18
|
SEQUENTIAL = SequentialSingleThreadExecutionStrategy
|
|
15
19
|
PIPELINED = PipelinedSingleThreadExecutionStrategy
|
|
16
20
|
PARALLEL = ParallelExecutionStrategy
|
|
21
|
+
SEQUENTIAL_PARALLEL = SequentialParallelExecutionStrategy
|
|
17
22
|
|
|
18
23
|
class SentinelExecutionStrategyType(Enum):
|
|
19
24
|
MAB = MABExecutionStrategy
|
|
20
25
|
RANDOM = RandomSamplingExecutionStrategy
|
|
26
|
+
ALL = AllSamplingExecutionStrategy
|