palimpzest 0.6.3__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- palimpzest/__init__.py +5 -0
- palimpzest/constants.py +110 -43
- palimpzest/core/__init__.py +0 -78
- palimpzest/core/data/dataclasses.py +382 -44
- palimpzest/core/elements/filters.py +7 -3
- palimpzest/core/elements/index.py +70 -0
- palimpzest/core/elements/records.py +33 -11
- palimpzest/core/lib/fields.py +1 -0
- palimpzest/core/lib/schemas.py +4 -3
- palimpzest/prompts/moa_proposer_convert_prompts.py +0 -4
- palimpzest/prompts/prompt_factory.py +44 -7
- palimpzest/prompts/split_merge_prompts.py +56 -0
- palimpzest/prompts/split_proposer_prompts.py +55 -0
- palimpzest/query/execution/execution_strategy.py +435 -53
- palimpzest/query/execution/execution_strategy_type.py +20 -0
- palimpzest/query/execution/mab_execution_strategy.py +532 -0
- palimpzest/query/execution/parallel_execution_strategy.py +143 -172
- palimpzest/query/execution/random_sampling_execution_strategy.py +240 -0
- palimpzest/query/execution/single_threaded_execution_strategy.py +173 -203
- palimpzest/query/generators/api_client_factory.py +31 -0
- palimpzest/query/generators/generators.py +256 -76
- palimpzest/query/operators/__init__.py +1 -2
- palimpzest/query/operators/code_synthesis_convert.py +33 -18
- palimpzest/query/operators/convert.py +30 -97
- palimpzest/query/operators/critique_and_refine_convert.py +5 -6
- palimpzest/query/operators/filter.py +7 -10
- palimpzest/query/operators/logical.py +54 -10
- palimpzest/query/operators/map.py +130 -0
- palimpzest/query/operators/mixture_of_agents_convert.py +6 -6
- palimpzest/query/operators/physical.py +3 -12
- palimpzest/query/operators/rag_convert.py +66 -18
- palimpzest/query/operators/retrieve.py +230 -34
- palimpzest/query/operators/scan.py +5 -2
- palimpzest/query/operators/split_convert.py +169 -0
- palimpzest/query/operators/token_reduction_convert.py +8 -14
- palimpzest/query/optimizer/__init__.py +4 -16
- palimpzest/query/optimizer/cost_model.py +73 -266
- palimpzest/query/optimizer/optimizer.py +87 -58
- palimpzest/query/optimizer/optimizer_strategy.py +18 -97
- palimpzest/query/optimizer/optimizer_strategy_type.py +37 -0
- palimpzest/query/optimizer/plan.py +2 -3
- palimpzest/query/optimizer/primitives.py +5 -3
- palimpzest/query/optimizer/rules.py +336 -172
- palimpzest/query/optimizer/tasks.py +30 -100
- palimpzest/query/processor/config.py +38 -22
- palimpzest/query/processor/nosentinel_processor.py +16 -520
- palimpzest/query/processor/processing_strategy_type.py +28 -0
- palimpzest/query/processor/query_processor.py +38 -206
- palimpzest/query/processor/query_processor_factory.py +117 -130
- palimpzest/query/processor/sentinel_processor.py +90 -0
- palimpzest/query/processor/streaming_processor.py +25 -32
- palimpzest/sets.py +88 -41
- palimpzest/utils/model_helpers.py +8 -7
- palimpzest/utils/progress.py +368 -152
- palimpzest/utils/token_reduction_helpers.py +1 -3
- {palimpzest-0.6.3.dist-info → palimpzest-0.7.0.dist-info}/METADATA +28 -24
- palimpzest-0.7.0.dist-info/RECORD +96 -0
- {palimpzest-0.6.3.dist-info → palimpzest-0.7.0.dist-info}/WHEEL +1 -1
- palimpzest/query/processor/mab_sentinel_processor.py +0 -884
- palimpzest/query/processor/random_sampling_sentinel_processor.py +0 -639
- palimpzest/utils/index_helpers.py +0 -6
- palimpzest-0.6.3.dist-info/RECORD +0 -87
- {palimpzest-0.6.3.dist-info → palimpzest-0.7.0.dist-info/licenses}/LICENSE +0 -0
- {palimpzest-0.6.3.dist-info → palimpzest-0.7.0.dist-info}/top_level.txt +0 -0
|
@@ -1,18 +1,21 @@
|
|
|
1
|
+
import logging
|
|
1
2
|
import multiprocessing
|
|
2
|
-
import time
|
|
3
3
|
from concurrent.futures import ThreadPoolExecutor, wait
|
|
4
4
|
|
|
5
5
|
from palimpzest.constants import PARALLEL_EXECUTION_SLEEP_INTERVAL_SECS
|
|
6
|
-
from palimpzest.core.data.dataclasses import
|
|
6
|
+
from palimpzest.core.data.dataclasses import PlanStats
|
|
7
|
+
from palimpzest.core.elements.records import DataRecord, DataRecordSet
|
|
7
8
|
from palimpzest.query.execution.execution_strategy import ExecutionStrategy
|
|
8
9
|
from palimpzest.query.operators.aggregate import AggregateOp
|
|
9
10
|
from palimpzest.query.operators.limit import LimitScanOp
|
|
10
11
|
from palimpzest.query.operators.physical import PhysicalOperator
|
|
11
12
|
from palimpzest.query.operators.scan import ScanPhysicalOp
|
|
12
13
|
from palimpzest.query.optimizer.plan import PhysicalPlan
|
|
14
|
+
from palimpzest.utils.progress import create_progress_manager
|
|
13
15
|
|
|
16
|
+
logger = logging.getLogger(__name__)
|
|
14
17
|
|
|
15
|
-
class
|
|
18
|
+
class ParallelExecutionStrategy(ExecutionStrategy):
|
|
16
19
|
"""
|
|
17
20
|
A parallel execution strategy that processes data through a pipeline of operators using thread-based parallelism.
|
|
18
21
|
"""
|
|
@@ -20,12 +23,12 @@ class PipelinedParallelExecutionStrategy(ExecutionStrategy):
|
|
|
20
23
|
def __init__(self, *args, **kwargs):
|
|
21
24
|
super().__init__(*args, **kwargs)
|
|
22
25
|
self.max_workers = (
|
|
23
|
-
self.
|
|
26
|
+
self._get_parallel_max_workers()
|
|
24
27
|
if self.max_workers is None
|
|
25
28
|
else self.max_workers
|
|
26
29
|
)
|
|
27
30
|
|
|
28
|
-
def
|
|
31
|
+
def _get_parallel_max_workers(self):
|
|
29
32
|
# for now, return the number of system CPUs;
|
|
30
33
|
# in the future, we may want to consider the models the user has access to
|
|
31
34
|
# and whether or not they will encounter rate-limits. If they will, we should
|
|
@@ -35,180 +38,148 @@ class PipelinedParallelExecutionStrategy(ExecutionStrategy):
|
|
|
35
38
|
# changing the max_workers in response to 429 errors.
|
|
36
39
|
return max(int(0.8 * multiprocessing.cpu_count()), 1)
|
|
37
40
|
|
|
38
|
-
def
|
|
39
|
-
"""
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
#
|
|
41
|
+
def _any_queue_not_empty(self, queues: dict[str, list]) -> bool:
|
|
42
|
+
"""Helper function to check if any queue is not empty."""
|
|
43
|
+
return any(len(queue) > 0 for queue in queues.values())
|
|
44
|
+
|
|
45
|
+
def _upstream_ops_finished(self, plan: PhysicalPlan, op_idx: int, input_queues: dict[str, list], future_queues: dict[str, list]) -> bool:
|
|
46
|
+
"""Helper function to check if all upstream operators have finished processing their inputs."""
|
|
47
|
+
for upstream_op_idx in range(op_idx):
|
|
48
|
+
upstream_op_id = plan.operators[upstream_op_idx].get_op_id()
|
|
49
|
+
if len(input_queues[upstream_op_id]) > 0 or len(future_queues[upstream_op_id]) > 0:
|
|
50
|
+
return False
|
|
51
|
+
|
|
52
|
+
return True
|
|
53
|
+
|
|
54
|
+
def _process_future_results(self, operator: PhysicalOperator, future_queues: dict[str, list], plan_stats: PlanStats) -> list[DataRecord]:
|
|
55
|
+
"""
|
|
56
|
+
Helper function which takes an operator, the future queues, and plan stats, and performs
|
|
57
|
+
the updates to plan stats and progress manager before returning the results from the finished futures.
|
|
58
|
+
"""
|
|
59
|
+
# get the op_id for the operator
|
|
60
|
+
op_id = operator.get_op_id()
|
|
61
|
+
|
|
62
|
+
# this function is called when the future queue is not empty
|
|
63
|
+
# and the executor is not busy processing other futures
|
|
64
|
+
done_futures, not_done_futures = wait(future_queues[op_id], timeout=PARALLEL_EXECUTION_SLEEP_INTERVAL_SECS)
|
|
65
|
+
|
|
66
|
+
# add the unfinished futures back to the previous op's future queue
|
|
67
|
+
future_queues[op_id] = list(not_done_futures)
|
|
68
|
+
|
|
69
|
+
# add the finished futures to the input queue for this operator
|
|
70
|
+
output_records = []
|
|
71
|
+
for future in done_futures:
|
|
72
|
+
record_set: DataRecordSet = future.result()
|
|
73
|
+
records = record_set.data_records
|
|
74
|
+
record_op_stats = record_set.record_op_stats
|
|
75
|
+
num_outputs = sum(record.passed_operator for record in records)
|
|
76
|
+
|
|
77
|
+
# update the progress manager
|
|
78
|
+
self.progress_manager.incr(op_id, num_outputs=num_outputs, total_cost=record_set.get_total_cost())
|
|
79
|
+
|
|
80
|
+
# update plan stats
|
|
81
|
+
plan_stats.add_record_op_stats(record_op_stats)
|
|
82
|
+
|
|
83
|
+
# add records to the cache
|
|
84
|
+
self._add_records_to_cache(operator.target_cache_id, records)
|
|
85
|
+
|
|
86
|
+
# add records which aren't filtered to the output records
|
|
87
|
+
output_records.extend([record for record in records if record.passed_operator])
|
|
88
|
+
|
|
89
|
+
return output_records
|
|
90
|
+
|
|
91
|
+
def _execute_plan(
|
|
92
|
+
self,
|
|
93
|
+
plan: PhysicalPlan,
|
|
94
|
+
input_queues: dict[str, list],
|
|
95
|
+
future_queues: dict[str, list],
|
|
96
|
+
plan_stats: PlanStats,
|
|
97
|
+
) -> tuple[list[DataRecord], PlanStats]:
|
|
98
|
+
# process all of the input records using a thread pool
|
|
57
99
|
output_records = []
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
op_id_to_next_operator = {
|
|
68
|
-
op.get_op_id(): plan.operators[idx + 1] if idx + 1 < len(plan.operators) else None
|
|
69
|
-
for idx, op in enumerate(plan.operators)
|
|
70
|
-
}
|
|
71
|
-
op_id_to_op_idx = {op.get_op_id(): idx for idx, op in enumerate(plan.operators)}
|
|
72
|
-
|
|
73
|
-
# get handle to scan operator and pre-compute its op_id and size
|
|
74
|
-
source_operator = plan.operators[0]
|
|
75
|
-
assert isinstance(source_operator, ScanPhysicalOp), "First operator in physical plan must be a ScanPhysicalOp"
|
|
76
|
-
source_op_id = source_operator.get_op_id()
|
|
77
|
-
datareader_len = len(source_operator.datareader)
|
|
78
|
-
|
|
79
|
-
# get limit of final limit operator (if one exists)
|
|
80
|
-
final_limit = plan.operators[-1].limit if isinstance(plan.operators[-1], LimitScanOp) else None
|
|
81
|
-
|
|
82
|
-
# create thread pool w/max workers
|
|
83
|
-
futures = []
|
|
84
|
-
current_scan_idx = self.scan_start_idx
|
|
85
|
-
with ThreadPoolExecutor(max_workers=plan_workers) as executor:
|
|
86
|
-
# create initial (set of) future(s) to read first source record;
|
|
87
|
-
futures.append(executor.submit(PhysicalOperator.execute_op_wrapper, source_operator, current_scan_idx))
|
|
88
|
-
op_id_to_futures_in_flight[source_op_id] += 1
|
|
89
|
-
current_scan_idx += 1
|
|
90
|
-
|
|
91
|
-
# iterate until we have processed all operators on all records or come to an early stopping condition
|
|
92
|
-
while len(futures) > 0:
|
|
93
|
-
# get the set of futures that have (and have not) finished in the last PARALLEL_EXECUTION_SLEEP_INTERVAL_SECS
|
|
94
|
-
done_futures, not_done_futures = wait(futures, timeout=PARALLEL_EXECUTION_SLEEP_INTERVAL_SECS)
|
|
95
|
-
|
|
96
|
-
# cast not_done_futures from a set to a list so we can append to it
|
|
97
|
-
not_done_futures = list(not_done_futures)
|
|
98
|
-
|
|
99
|
-
# process finished futures, creating new ones as needed
|
|
100
|
-
new_futures = []
|
|
101
|
-
for future in done_futures:
|
|
102
|
-
# get the result
|
|
103
|
-
record_set, operator, _ = future.result()
|
|
100
|
+
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
|
|
101
|
+
logger.debug(f"Created thread pool with {self.max_workers} workers")
|
|
102
|
+
|
|
103
|
+
# execute the plan until either:
|
|
104
|
+
# 1. all records have been processed, or
|
|
105
|
+
# 2. the final limit operation has completed (we break out of the loop if this happens)
|
|
106
|
+
final_op = plan.operators[-1]
|
|
107
|
+
while self._any_queue_not_empty(input_queues) or self._any_queue_not_empty(future_queues):
|
|
108
|
+
for op_idx, operator in enumerate(plan.operators):
|
|
104
109
|
op_id = operator.get_op_id()
|
|
105
110
|
|
|
106
|
-
#
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
#
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
next_operator = op_id_to_next_operator[op_id]
|
|
130
|
-
if next_operator is not None:
|
|
131
|
-
processing_queue.append((next_operator, record))
|
|
132
|
-
else:
|
|
133
|
-
output_records.append(record)
|
|
134
|
-
|
|
135
|
-
# if this operator was a source scan, update the number of source records scanned
|
|
136
|
-
if op_id == source_op_id:
|
|
137
|
-
source_records_scanned += len(record_set)
|
|
138
|
-
|
|
139
|
-
# scan next record if we can still draw records from source
|
|
140
|
-
if source_records_scanned < num_samples and current_scan_idx < datareader_len:
|
|
141
|
-
new_futures.append(executor.submit(PhysicalOperator.execute_op_wrapper, source_operator, current_scan_idx))
|
|
142
|
-
op_id_to_futures_in_flight[source_op_id] += 1
|
|
143
|
-
current_scan_idx += 1
|
|
144
|
-
|
|
145
|
-
# check early stopping condition based on final limit
|
|
146
|
-
if final_limit is not None and len(output_records) >= final_limit:
|
|
147
|
-
output_records = output_records[:final_limit]
|
|
148
|
-
futures = []
|
|
149
|
-
break
|
|
150
|
-
|
|
151
|
-
# process all records in the processing queue which are ready to be executed
|
|
152
|
-
temp_processing_queue = []
|
|
153
|
-
for operator, candidate in processing_queue:
|
|
154
|
-
# if the candidate is not an input to an aggregate, execute it right away
|
|
155
|
-
if not isinstance(operator, AggregateOp):
|
|
156
|
-
future = executor.submit(PhysicalOperator.execute_op_wrapper, operator, candidate)
|
|
157
|
-
new_futures.append(future)
|
|
158
|
-
op_id_to_futures_in_flight[operator.get_op_id()] += 1
|
|
159
|
-
|
|
160
|
-
# otherwise, put it back on the queue
|
|
161
|
-
else:
|
|
162
|
-
temp_processing_queue.append((operator, candidate))
|
|
163
|
-
|
|
164
|
-
# any remaining candidates are inputs to aggregate operators; for each aggregate operator
|
|
165
|
-
# determine if it is ready to execute -- and execute all of its candidates if so
|
|
166
|
-
processing_queue = []
|
|
167
|
-
agg_op_ids = set([operator.get_op_id() for operator, _ in temp_processing_queue])
|
|
168
|
-
for agg_op_id in agg_op_ids:
|
|
169
|
-
agg_op_idx = op_id_to_op_idx[agg_op_id]
|
|
170
|
-
|
|
171
|
-
# compute if all upstream operators' processing queues are empty and their in-flight futures are finished
|
|
172
|
-
upstream_ops_are_finished = True
|
|
173
|
-
for upstream_op_idx in range(agg_op_idx):
|
|
174
|
-
upstream_op_id = plan.operators[upstream_op_idx].get_op_id()
|
|
175
|
-
upstream_op_id_queue = list(
|
|
176
|
-
filter(lambda tup: tup[0].get_op_id() == upstream_op_id, temp_processing_queue)
|
|
177
|
-
)
|
|
178
|
-
|
|
179
|
-
upstream_ops_are_finished = (
|
|
180
|
-
upstream_ops_are_finished
|
|
181
|
-
and len(upstream_op_id_queue) == 0
|
|
182
|
-
and op_id_to_futures_in_flight[upstream_op_id] == 0
|
|
183
|
-
)
|
|
184
|
-
|
|
185
|
-
# get the subset of candidates for this aggregate operator
|
|
186
|
-
candidate_tuples = list(filter(lambda tup: tup[0].get_op_id() == agg_op_id, temp_processing_queue))
|
|
187
|
-
|
|
188
|
-
# execute the operator on the candidates if it's ready
|
|
189
|
-
if upstream_ops_are_finished:
|
|
190
|
-
operator = op_id_to_operator[agg_op_id]
|
|
191
|
-
candidates = list(map(lambda tup: tup[1], candidate_tuples))
|
|
192
|
-
future = executor.submit(PhysicalOperator.execute_op_wrapper, operator, candidates)
|
|
193
|
-
new_futures.append(future)
|
|
194
|
-
op_id_to_futures_in_flight[operator.get_op_id()] += 1
|
|
195
|
-
|
|
196
|
-
# otherwise, add the candidates back to the processing queue
|
|
111
|
+
# get any finished futures from the previous operator and add them to the input queue for this operator
|
|
112
|
+
if not isinstance(operator, ScanPhysicalOp):
|
|
113
|
+
prev_operator = plan.operators[op_idx - 1]
|
|
114
|
+
records = self._process_future_results(prev_operator, future_queues, plan_stats)
|
|
115
|
+
input_queues[op_id].extend(records)
|
|
116
|
+
|
|
117
|
+
# for the final operator, add any finished futures to the output_records
|
|
118
|
+
if operator.get_op_id() == final_op.get_op_id():
|
|
119
|
+
records = self._process_future_results(operator, future_queues, plan_stats)
|
|
120
|
+
output_records.extend(records)
|
|
121
|
+
|
|
122
|
+
# if this operator does not have enough inputs to execute, then skip it
|
|
123
|
+
num_inputs = len(input_queues[op_id])
|
|
124
|
+
agg_op_not_ready = isinstance(operator, AggregateOp) and not self._upstream_ops_finished(plan, op_idx, input_queues, future_queues)
|
|
125
|
+
if num_inputs == 0 or agg_op_not_ready:
|
|
126
|
+
continue
|
|
127
|
+
|
|
128
|
+
# if this operator is an aggregate, process all the records in the input queue
|
|
129
|
+
if isinstance(operator, AggregateOp):
|
|
130
|
+
input_records = [input_queues[op_id].pop(0) for _ in range(num_inputs)]
|
|
131
|
+
future = executor.submit(operator, input_records)
|
|
132
|
+
future_queues[op_id].append(future)
|
|
133
|
+
|
|
197
134
|
else:
|
|
198
|
-
|
|
135
|
+
input_record = input_queues[op_id].pop(0)
|
|
136
|
+
future = executor.submit(operator, input_record)
|
|
137
|
+
future_queues[op_id].append(future)
|
|
199
138
|
|
|
200
|
-
#
|
|
201
|
-
|
|
202
|
-
|
|
139
|
+
# break out of loop if the final operator is a LimitScanOp and we've reached its limit
|
|
140
|
+
if isinstance(final_op, LimitScanOp) and len(output_records) == final_op.limit:
|
|
141
|
+
break
|
|
203
142
|
|
|
204
|
-
#
|
|
205
|
-
|
|
206
|
-
for _ in plan.operators:
|
|
207
|
-
# self.datadir.close_cache(operator.target_cache_id)
|
|
208
|
-
pass
|
|
143
|
+
# close the cache
|
|
144
|
+
self._close_cache([op.target_cache_id for op in plan.operators])
|
|
209
145
|
|
|
210
146
|
# finalize plan stats
|
|
211
|
-
|
|
212
|
-
|
|
147
|
+
plan_stats.finish()
|
|
148
|
+
|
|
149
|
+
return output_records, plan_stats
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
def execute_plan(self, plan: PhysicalPlan):
|
|
153
|
+
"""Initialize the stats and execute the plan."""
|
|
154
|
+
# for now, assert that the first operator in the plan is a ScanPhysicalOp
|
|
155
|
+
assert isinstance(plan.operators[0], ScanPhysicalOp), "First operator in physical plan must be a ScanPhysicalOp"
|
|
156
|
+
logger.info(f"Executing plan {plan.plan_id} with {self.max_workers} workers")
|
|
157
|
+
logger.info(f"Plan Details: {plan}")
|
|
158
|
+
|
|
159
|
+
# initialize plan stats
|
|
160
|
+
plan_stats = PlanStats.from_plan(plan)
|
|
161
|
+
plan_stats.start()
|
|
162
|
+
|
|
163
|
+
# initialize input queues and future queues for each operation
|
|
164
|
+
input_queues = self._create_input_queues(plan)
|
|
165
|
+
future_queues = {op.get_op_id(): [] for op in plan.operators}
|
|
166
|
+
|
|
167
|
+
# initialize and start the progress manager
|
|
168
|
+
self.progress_manager = create_progress_manager(plan, num_samples=self.num_samples, progress=self.progress)
|
|
169
|
+
self.progress_manager.start()
|
|
170
|
+
|
|
171
|
+
# NOTE: we must handle progress manager outside of _exeecute_plan to ensure that it is shut down correctly;
|
|
172
|
+
# if we don't have the `finally:` branch, then program crashes can cause future program runs to fail
|
|
173
|
+
# because the progress manager cannot get a handle to the console
|
|
174
|
+
try:
|
|
175
|
+
# execute plan
|
|
176
|
+
output_records, plan_stats = self._execute_plan(plan, input_queues, future_queues, plan_stats)
|
|
177
|
+
|
|
178
|
+
finally:
|
|
179
|
+
# finish progress tracking
|
|
180
|
+
self.progress_manager.finish()
|
|
181
|
+
|
|
182
|
+
logger.info(f"Done executing plan: {plan.plan_id}")
|
|
183
|
+
logger.debug(f"Plan stats: (plan_cost={plan_stats.total_plan_cost}, plan_time={plan_stats.total_plan_time})")
|
|
213
184
|
|
|
214
185
|
return output_records, plan_stats
|
|
@@ -0,0 +1,240 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
from palimpzest.core.data.dataclasses import SentinelPlanStats
|
|
6
|
+
from palimpzest.core.elements.records import DataRecord, DataRecordSet
|
|
7
|
+
from palimpzest.query.execution.execution_strategy import SentinelExecutionStrategy
|
|
8
|
+
from palimpzest.query.operators.physical import PhysicalOperator
|
|
9
|
+
from palimpzest.query.operators.scan import ScanPhysicalOp
|
|
10
|
+
from palimpzest.query.optimizer.plan import SentinelPlan
|
|
11
|
+
from palimpzest.utils.progress import create_progress_manager
|
|
12
|
+
|
|
13
|
+
logger = logging.getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
class OpSet:
|
|
16
|
+
"""
|
|
17
|
+
This class represents the set of operators which are currently in the frontier for a given logical operator.
|
|
18
|
+
Each operator in the frontier is an instance of a PhysicalOperator which either:
|
|
19
|
+
|
|
20
|
+
1. lies on the Pareto frontier of the set of sampled operators, or
|
|
21
|
+
2. has been sampled fewer than j times
|
|
22
|
+
"""
|
|
23
|
+
|
|
24
|
+
def __init__(self, op_set: list[PhysicalOperator], source_indices: list[int], k: int, j: int, seed: int):
|
|
25
|
+
# set k and j, which are the initial number of operators in the frontier and the
|
|
26
|
+
# initial number of records to sample for each frontier operator
|
|
27
|
+
self.k = min(k, len(op_set))
|
|
28
|
+
self.j = min(j, len(source_indices))
|
|
29
|
+
|
|
30
|
+
# get order in which we will sample physical operators for this logical operator
|
|
31
|
+
sample_op_indices = self._get_op_index_order(op_set, seed)
|
|
32
|
+
|
|
33
|
+
# construct the set of operators
|
|
34
|
+
self.ops = [op_set[sample_idx] for sample_idx in sample_op_indices[:self.k]]
|
|
35
|
+
|
|
36
|
+
# store the order in which we will sample the source records
|
|
37
|
+
self.source_indices = source_indices
|
|
38
|
+
|
|
39
|
+
# set the initial inputs for this logical operator
|
|
40
|
+
is_scan_op = isinstance(op_set[0], ScanPhysicalOp)
|
|
41
|
+
self.source_idx_to_input = {source_idx: [source_idx] for source_idx in self.source_indices} if is_scan_op else {}
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def _get_op_index_order(self, op_set: list[PhysicalOperator], seed: int) -> list[int]:
|
|
45
|
+
"""
|
|
46
|
+
Returns a list of indices for the operators in the op_set.
|
|
47
|
+
"""
|
|
48
|
+
rng = np.random.default_rng(seed=seed)
|
|
49
|
+
op_indices = np.arange(len(op_set))
|
|
50
|
+
rng.shuffle(op_indices)
|
|
51
|
+
return op_indices
|
|
52
|
+
|
|
53
|
+
def get_op_input_pairs(self) -> list[PhysicalOperator, DataRecord | int | None]:
|
|
54
|
+
"""
|
|
55
|
+
Returns the list of frontier operators and their next input to process. If there are
|
|
56
|
+
any indices in `source_indices_to_sample` which this operator does not sample on its own, then
|
|
57
|
+
we also have this frontier process that source_idx's input with its max quality operator.
|
|
58
|
+
"""
|
|
59
|
+
# get the list of (op, source_idx) pairs which this operator needs to execute
|
|
60
|
+
op_source_idx_pairs = []
|
|
61
|
+
for op in self.ops:
|
|
62
|
+
# construct list of inputs by looking up the input for the given source_idx
|
|
63
|
+
for sample_idx in range(self.j):
|
|
64
|
+
source_idx = self.source_indices[sample_idx]
|
|
65
|
+
op_source_idx_pairs.append((op, source_idx))
|
|
66
|
+
|
|
67
|
+
# fetch the corresponding (op, input) pairs
|
|
68
|
+
op_input_pairs = []
|
|
69
|
+
for op, source_idx in op_source_idx_pairs:
|
|
70
|
+
op_input_pairs.extend([(op, input_record) for input_record in self.source_idx_to_input[source_idx]])
|
|
71
|
+
|
|
72
|
+
return op_input_pairs
|
|
73
|
+
|
|
74
|
+
def pick_highest_quality_output(self, record_sets: list[DataRecordSet]) -> DataRecordSet:
|
|
75
|
+
# if there's only one operator in the set, we return its record_set
|
|
76
|
+
if len(record_sets) == 1:
|
|
77
|
+
return record_sets[0]
|
|
78
|
+
|
|
79
|
+
# NOTE: I don't like that this assumes the models are consistent in
|
|
80
|
+
# how they order their record outputs for one-to-many converts;
|
|
81
|
+
# eventually we can try out more robust schemes to account for
|
|
82
|
+
# differences in ordering
|
|
83
|
+
# aggregate records at each index in the response
|
|
84
|
+
idx_to_records = {}
|
|
85
|
+
for record_set in record_sets:
|
|
86
|
+
for idx in range(len(record_set)):
|
|
87
|
+
record, record_op_stats = record_set[idx], record_set.record_op_stats[idx]
|
|
88
|
+
if idx not in idx_to_records:
|
|
89
|
+
idx_to_records[idx] = [(record, record_op_stats)]
|
|
90
|
+
else:
|
|
91
|
+
idx_to_records[idx].append((record, record_op_stats))
|
|
92
|
+
|
|
93
|
+
# compute highest quality answer at each index
|
|
94
|
+
out_records = []
|
|
95
|
+
out_record_op_stats = []
|
|
96
|
+
for idx in range(len(idx_to_records)):
|
|
97
|
+
records_lst, record_op_stats_lst = zip(*idx_to_records[idx])
|
|
98
|
+
max_quality_record, max_quality = records_lst[0], record_op_stats_lst[0].quality
|
|
99
|
+
max_quality_stats = record_op_stats_lst[0]
|
|
100
|
+
for record, record_op_stats in zip(records_lst[1:], record_op_stats_lst[1:]):
|
|
101
|
+
record_quality = record_op_stats.quality
|
|
102
|
+
if record_quality > max_quality:
|
|
103
|
+
max_quality_record = record
|
|
104
|
+
max_quality = record_quality
|
|
105
|
+
max_quality_stats = record_op_stats
|
|
106
|
+
out_records.append(max_quality_record)
|
|
107
|
+
out_record_op_stats.append(max_quality_stats)
|
|
108
|
+
|
|
109
|
+
# create and return final DataRecordSet
|
|
110
|
+
return DataRecordSet(out_records, out_record_op_stats)
|
|
111
|
+
|
|
112
|
+
def update_inputs(self, source_idx_to_record_sets: dict[int, DataRecordSet]):
|
|
113
|
+
"""
|
|
114
|
+
Update the inputs for this logical operator based on the outputs of the previous logical operator.
|
|
115
|
+
"""
|
|
116
|
+
for source_idx, record_sets in source_idx_to_record_sets.items():
|
|
117
|
+
input = []
|
|
118
|
+
max_quality_record_set = self.pick_highest_quality_output(record_sets)
|
|
119
|
+
for record in max_quality_record_set:
|
|
120
|
+
input.append(record if record.passed_operator else None)
|
|
121
|
+
|
|
122
|
+
self.source_idx_to_input[source_idx] = input
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
class RandomSamplingExecutionStrategy(SentinelExecutionStrategy):
|
|
126
|
+
|
|
127
|
+
def _get_source_indices(self):
|
|
128
|
+
"""Get the list of source indices which the sentinel plan should execute over."""
|
|
129
|
+
# create list of all source indices and shuffle it
|
|
130
|
+
total_num_samples = len(self.val_datasource)
|
|
131
|
+
source_indices = list(np.arange(total_num_samples))
|
|
132
|
+
self.rng.shuffle(source_indices)
|
|
133
|
+
|
|
134
|
+
# slice the list of source indices to get the first j indices
|
|
135
|
+
j = min(self.j, len(source_indices))
|
|
136
|
+
source_indices = source_indices[:j]
|
|
137
|
+
|
|
138
|
+
return source_indices
|
|
139
|
+
|
|
140
|
+
def _execute_sentinel_plan(self,
|
|
141
|
+
plan: SentinelPlan,
|
|
142
|
+
op_sets: dict[str, OpSet],
|
|
143
|
+
expected_outputs: dict[int, dict] | None,
|
|
144
|
+
plan_stats: SentinelPlanStats,
|
|
145
|
+
) -> SentinelPlanStats:
|
|
146
|
+
# execute operator sets in sequence
|
|
147
|
+
for op_idx, (logical_op_id, op_set) in enumerate(plan):
|
|
148
|
+
# get frontier ops and their next input
|
|
149
|
+
op_input_pairs = op_sets[logical_op_id].get_op_input_pairs()
|
|
150
|
+
|
|
151
|
+
# break out of the loop if op_input_pairs is empty, as this means all records have been filtered out
|
|
152
|
+
if len(op_input_pairs) == 0:
|
|
153
|
+
break
|
|
154
|
+
|
|
155
|
+
# run sampled operators on sampled inputs
|
|
156
|
+
source_idx_to_record_sets_and_ops, _ = self._execute_op_set(op_input_pairs)
|
|
157
|
+
|
|
158
|
+
# FUTURE TODO: have this return the highest quality record set simply based on our posterior (or prior) belief on operator quality
|
|
159
|
+
# get the target record set for each source_idx
|
|
160
|
+
source_idx_to_target_record_set = self._get_target_record_sets(logical_op_id, source_idx_to_record_sets_and_ops, expected_outputs)
|
|
161
|
+
|
|
162
|
+
# TODO: make consistent across here and RandomSampling
|
|
163
|
+
# FUTURE TODO: move this outside of the loop (i.e. assume we only get quality label(s) after executing full program)
|
|
164
|
+
# score the quality of each generated output
|
|
165
|
+
physical_op_cls = op_set[0].__class__
|
|
166
|
+
source_idx_to_record_sets = {
|
|
167
|
+
source_idx: list(map(lambda tup: tup[0], record_sets_and_ops))
|
|
168
|
+
for source_idx, record_sets_and_ops in source_idx_to_record_sets_and_ops.items()
|
|
169
|
+
}
|
|
170
|
+
source_idx_to_record_sets = self._score_quality(physical_op_cls, source_idx_to_record_sets, source_idx_to_target_record_set)
|
|
171
|
+
|
|
172
|
+
# flatten the lists of records and record_op_stats
|
|
173
|
+
all_records, all_record_op_stats = self._flatten_record_sets(source_idx_to_record_sets)
|
|
174
|
+
|
|
175
|
+
# update plan stats
|
|
176
|
+
plan_stats.add_record_op_stats(all_record_op_stats)
|
|
177
|
+
|
|
178
|
+
# add records (which are not filtered) to the cache, if allowed
|
|
179
|
+
self._add_records_to_cache(logical_op_id, all_records)
|
|
180
|
+
|
|
181
|
+
# FUTURE TODO: simply set input based on source_idx_to_target_record_set (b/c we won't have scores computed)
|
|
182
|
+
# provide the champion record sets as inputs to the next logical operator
|
|
183
|
+
if op_idx + 1 < len(plan):
|
|
184
|
+
next_logical_op_id = plan.logical_op_ids[op_idx + 1]
|
|
185
|
+
op_sets[next_logical_op_id].update_inputs(source_idx_to_record_sets)
|
|
186
|
+
|
|
187
|
+
# close the cache
|
|
188
|
+
self._close_cache(plan.logical_op_ids)
|
|
189
|
+
|
|
190
|
+
# finalize plan stats
|
|
191
|
+
plan_stats.finish()
|
|
192
|
+
|
|
193
|
+
return plan_stats
|
|
194
|
+
|
|
195
|
+
def execute_sentinel_plan(self, plan: SentinelPlan, expected_outputs: dict[int, dict] | None):
|
|
196
|
+
"""
|
|
197
|
+
NOTE: this function currently requires us to set k and j properly in order to make
|
|
198
|
+
comparison in our research against the corresponding sample budget in MAB.
|
|
199
|
+
|
|
200
|
+
NOTE: the number of samples will slightly exceed the sample_budget if the number of operator
|
|
201
|
+
calls does not perfectly match the sample_budget. This may cause some minor discrepancies with
|
|
202
|
+
the progress manager as a result.
|
|
203
|
+
"""
|
|
204
|
+
# for now, assert that the first operator in the plan is a ScanPhysicalOp
|
|
205
|
+
assert all(isinstance(op, ScanPhysicalOp) for op in plan.operator_sets[0]), "First operator in physical plan must be a ScanPhysicalOp"
|
|
206
|
+
logger.info(f"Executing plan {plan.plan_id} with {self.max_workers} workers")
|
|
207
|
+
logger.info(f"Plan Details: {plan}")
|
|
208
|
+
|
|
209
|
+
# initialize plan stats
|
|
210
|
+
plan_stats = SentinelPlanStats.from_plan(plan)
|
|
211
|
+
plan_stats.start()
|
|
212
|
+
|
|
213
|
+
# get list of source indices which can be sampled from
|
|
214
|
+
source_indices = self._get_source_indices()
|
|
215
|
+
|
|
216
|
+
# initialize set of physical operators for each logical operator
|
|
217
|
+
op_sets = {
|
|
218
|
+
logical_op_id: OpSet(op_set, source_indices, self.k, self.j, self.seed)
|
|
219
|
+
for logical_op_id, op_set in plan
|
|
220
|
+
}
|
|
221
|
+
|
|
222
|
+
# initialize and start the progress manager
|
|
223
|
+
self.progress_manager = create_progress_manager(plan, sample_budget=self.sample_budget, progress=self.progress)
|
|
224
|
+
self.progress_manager.start()
|
|
225
|
+
|
|
226
|
+
# NOTE: we must handle progress manager outside of _exeecute_sentinel_plan to ensure that it is shut down correctly;
|
|
227
|
+
# if we don't have the `finally:` branch, then program crashes can cause future program runs to fail because
|
|
228
|
+
# the progress manager cannot get a handle to the console
|
|
229
|
+
try:
|
|
230
|
+
# execute sentinel plan by sampling records and operators
|
|
231
|
+
plan_stats = self._execute_sentinel_plan(plan, op_sets, expected_outputs, plan_stats)
|
|
232
|
+
|
|
233
|
+
finally:
|
|
234
|
+
# finish progress tracking
|
|
235
|
+
self.progress_manager.finish()
|
|
236
|
+
|
|
237
|
+
logger.info(f"Done executing sentinel plan: {plan.plan_id}")
|
|
238
|
+
logger.debug(f"Plan stats: (plan_cost={plan_stats.total_plan_cost}, plan_time={plan_stats.total_plan_time})")
|
|
239
|
+
|
|
240
|
+
return plan_stats
|