palimpzest 0.7.21__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- palimpzest/__init__.py +37 -6
- palimpzest/agents/__init__.py +0 -0
- palimpzest/agents/compute_agents.py +0 -0
- palimpzest/agents/search_agents.py +637 -0
- palimpzest/constants.py +259 -197
- palimpzest/core/data/context.py +393 -0
- palimpzest/core/data/context_manager.py +163 -0
- palimpzest/core/data/dataset.py +634 -0
- palimpzest/core/data/{datareaders.py → iter_dataset.py} +202 -126
- palimpzest/core/elements/groupbysig.py +16 -13
- palimpzest/core/elements/records.py +166 -75
- palimpzest/core/lib/schemas.py +152 -390
- palimpzest/core/{data/dataclasses.py → models.py} +306 -170
- palimpzest/policy.py +2 -27
- palimpzest/prompts/__init__.py +35 -5
- palimpzest/prompts/agent_prompts.py +357 -0
- palimpzest/prompts/context_search.py +9 -0
- palimpzest/prompts/convert_prompts.py +61 -5
- palimpzest/prompts/filter_prompts.py +50 -5
- palimpzest/prompts/join_prompts.py +163 -0
- palimpzest/prompts/moa_proposer_convert_prompts.py +5 -5
- palimpzest/prompts/prompt_factory.py +358 -46
- palimpzest/prompts/validator.py +239 -0
- palimpzest/query/execution/all_sample_execution_strategy.py +134 -76
- palimpzest/query/execution/execution_strategy.py +210 -317
- palimpzest/query/execution/execution_strategy_type.py +5 -7
- palimpzest/query/execution/mab_execution_strategy.py +249 -136
- palimpzest/query/execution/parallel_execution_strategy.py +153 -244
- palimpzest/query/execution/single_threaded_execution_strategy.py +107 -64
- palimpzest/query/generators/generators.py +157 -330
- palimpzest/query/operators/__init__.py +15 -5
- palimpzest/query/operators/aggregate.py +50 -33
- palimpzest/query/operators/compute.py +201 -0
- palimpzest/query/operators/convert.py +27 -21
- palimpzest/query/operators/critique_and_refine_convert.py +7 -5
- palimpzest/query/operators/distinct.py +62 -0
- palimpzest/query/operators/filter.py +22 -13
- palimpzest/query/operators/join.py +402 -0
- palimpzest/query/operators/limit.py +3 -3
- palimpzest/query/operators/logical.py +198 -80
- palimpzest/query/operators/mixture_of_agents_convert.py +10 -8
- palimpzest/query/operators/physical.py +27 -21
- palimpzest/query/operators/project.py +3 -3
- palimpzest/query/operators/rag_convert.py +7 -7
- palimpzest/query/operators/retrieve.py +9 -9
- palimpzest/query/operators/scan.py +81 -42
- palimpzest/query/operators/search.py +524 -0
- palimpzest/query/operators/split_convert.py +10 -8
- palimpzest/query/optimizer/__init__.py +7 -9
- palimpzest/query/optimizer/cost_model.py +108 -441
- palimpzest/query/optimizer/optimizer.py +123 -181
- palimpzest/query/optimizer/optimizer_strategy.py +66 -61
- palimpzest/query/optimizer/plan.py +352 -67
- palimpzest/query/optimizer/primitives.py +43 -19
- palimpzest/query/optimizer/rules.py +484 -646
- palimpzest/query/optimizer/tasks.py +127 -58
- palimpzest/query/processor/config.py +41 -76
- palimpzest/query/processor/query_processor.py +73 -18
- palimpzest/query/processor/query_processor_factory.py +46 -38
- palimpzest/schemabuilder/schema_builder.py +15 -28
- palimpzest/utils/model_helpers.py +27 -77
- palimpzest/utils/progress.py +114 -102
- palimpzest/validator/__init__.py +0 -0
- palimpzest/validator/validator.py +306 -0
- {palimpzest-0.7.21.dist-info → palimpzest-0.8.0.dist-info}/METADATA +6 -1
- palimpzest-0.8.0.dist-info/RECORD +95 -0
- palimpzest/core/lib/fields.py +0 -141
- palimpzest/prompts/code_synthesis_prompts.py +0 -28
- palimpzest/query/execution/random_sampling_execution_strategy.py +0 -240
- palimpzest/query/generators/api_client_factory.py +0 -30
- palimpzest/query/operators/code_synthesis_convert.py +0 -488
- palimpzest/query/operators/map.py +0 -130
- palimpzest/query/processor/nosentinel_processor.py +0 -33
- palimpzest/query/processor/processing_strategy_type.py +0 -28
- palimpzest/query/processor/sentinel_processor.py +0 -88
- palimpzest/query/processor/streaming_processor.py +0 -149
- palimpzest/sets.py +0 -405
- palimpzest/utils/datareader_helpers.py +0 -61
- palimpzest/utils/demo_helpers.py +0 -75
- palimpzest/utils/field_helpers.py +0 -69
- palimpzest/utils/generation_helpers.py +0 -69
- palimpzest/utils/sandbox.py +0 -183
- palimpzest-0.7.21.dist-info/RECORD +0 -95
- /palimpzest/core/{elements/index.py → data/index_dataset.py} +0 -0
- {palimpzest-0.7.21.dist-info → palimpzest-0.8.0.dist-info}/WHEEL +0 -0
- {palimpzest-0.7.21.dist-info → palimpzest-0.8.0.dist-info}/licenses/LICENSE +0 -0
- {palimpzest-0.7.21.dist-info → palimpzest-0.8.0.dist-info}/top_level.txt +0 -0
palimpzest/query/execution/execution_strategy_type.py

@@ -2,11 +2,7 @@ from enum import Enum
 
 from palimpzest.query.execution.all_sample_execution_strategy import AllSamplingExecutionStrategy
 from palimpzest.query.execution.mab_execution_strategy import MABExecutionStrategy
-from palimpzest.query.execution.parallel_execution_strategy import (
-    ParallelExecutionStrategy,
-    SequentialParallelExecutionStrategy,
-)
-from palimpzest.query.execution.random_sampling_execution_strategy import RandomSamplingExecutionStrategy
+from palimpzest.query.execution.parallel_execution_strategy import ParallelExecutionStrategy
 from palimpzest.query.execution.single_threaded_execution_strategy import (
     PipelinedSingleThreadExecutionStrategy,
     SequentialSingleThreadExecutionStrategy,
@@ -18,9 +14,11 @@ class ExecutionStrategyType(Enum):
     SEQUENTIAL = SequentialSingleThreadExecutionStrategy
     PIPELINED = PipelinedSingleThreadExecutionStrategy
     PARALLEL = ParallelExecutionStrategy
-
+
+    def is_fully_parallel(self) -> bool:
+        """Check if the execution strategy executes operators in parallel."""
+        return self == ExecutionStrategyType.PARALLEL
 
 class SentinelExecutionStrategyType(Enum):
     MAB = MABExecutionStrategy
-    RANDOM = RandomSamplingExecutionStrategy
     ALL = AllSamplingExecutionStrategy
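In short, this file drops the RandomSamplingExecutionStrategy and SequentialParallelExecutionStrategy options and adds an is_fully_parallel() helper to ExecutionStrategyType. A minimal usage sketch of the new helper follows; the caller code is hypothetical and not taken from the package:

    from palimpzest.query.execution.execution_strategy_type import ExecutionStrategyType

    # hypothetical caller: branch on the helper instead of comparing enum members directly
    strategy_type = ExecutionStrategyType.PARALLEL
    if strategy_type.is_fully_parallel():
        # PARALLEL is currently the only strategy for which this returns True
        print("operators will be executed in parallel")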
palimpzest/query/execution/mab_execution_strategy.py

@@ -1,19 +1,26 @@
+
 import logging
 
 import numpy as np
 
-from palimpzest.core.data.
+from palimpzest.core.data.dataset import Dataset
 from palimpzest.core.elements.records import DataRecord, DataRecordSet
+from palimpzest.core.models import OperatorStats, RecordOpStats, SentinelPlanStats
 from palimpzest.policy import Policy
 from palimpzest.query.execution.execution_strategy import SentinelExecutionStrategy
+from palimpzest.query.operators.aggregate import AggregateOp
 from palimpzest.query.operators.filter import FilterOp
+from palimpzest.query.operators.join import JoinOp
 from palimpzest.query.operators.physical import PhysicalOperator
-from palimpzest.query.operators.scan import ScanPhysicalOp
+from palimpzest.query.operators.scan import ContextScanOp, ScanPhysicalOp
 from palimpzest.query.optimizer.plan import SentinelPlan
 from palimpzest.utils.progress import create_progress_manager
+from palimpzest.validator.validator import Validator
 
 logger = logging.getLogger(__name__)
 
+# NOTE: we currently do not support Sentinel Plans with aggregates or limits which are not the final plan operator
+
 class OpFrontier:
     """
     This class represents the set of operators which are currently in the frontier for a given logical operator.
@@ -23,11 +30,24 @@ class OpFrontier:
     2. has been sampled fewer than j times
     """
 
-    def __init__(
+    def __init__(
+        self,
+        op_set: list[PhysicalOperator],
+        source_unique_logical_op_ids: list[str],
+        root_dataset_ids: list[str],
+        source_indices: list[tuple],
+        k: int,
+        j: int,
+        seed: int,
+        policy: Policy,
+        priors: dict | None = None,
+    ):
         # set k and j, which are the initial number of operators in the frontier and the
         # initial number of records to sample for each frontier operator
         self.k = min(k, len(op_set))
-        self.j =
+        self.j = j
+        self.source_indices = source_indices
+        self.root_dataset_ids = root_dataset_ids
 
         # store the policy that we are optimizing under
         self.policy = policy
@@ -43,18 +63,25 @@ class OpFrontier:
         self.reservoir_ops = [op_set[sample_idx] for sample_idx in sample_op_indices[self.k:]]
         self.off_frontier_ops: list[PhysicalOperator] = []
 
-        #
-        self.source_indices = source_indices
-
-        # keep track of the source ids processed by each physical operator
+        # keep track of the source indices processed by each physical operator
         self.full_op_id_to_sources_processed = {op.get_full_op_id(): set() for op in op_set}
-
-
-
-
-
-
-        self.is_filter_op = isinstance(
+        self.full_op_id_to_sources_not_processed = {op.get_full_op_id(): source_indices for op in op_set}
+        self.max_inputs = len(source_indices)
+
+        # boolean indication of the type of operator in this OpFrontier
+        sample_op = op_set[0]
+        self.is_scan_op = isinstance(sample_op, (ScanPhysicalOp, ContextScanOp))
+        self.is_filter_op = isinstance(sample_op, FilterOp)
+        self.is_aggregate_op = isinstance(sample_op, AggregateOp)
+        self.is_llm_join = isinstance(sample_op, JoinOp)
+
+        # set the initial inputs for this logical operator; we maintain a mapping from source_unique_logical_op_id --> source_indices --> input;
+        # for each unique source and (tuple of) source indices, we store its output, which is an input to this operator
+        # for scan operators, we use the default name "source" since these operators have no source
+        self.source_indices_to_inputs = {source_unique_logical_op_id: {} for source_unique_logical_op_id in source_unique_logical_op_ids}
+        if self.is_scan_op:
+            self.source_indices_to_inputs["source"] = {source_idx: [int(source_idx.split("-")[-1])] for source_idx in source_indices}
+
 
     def get_frontier_ops(self) -> list[PhysicalOperator]:
         """
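The scan-level source indices referenced above follow a "<dataset_id>-<row index>" naming scheme (they are constructed this way in the execute_sentinel_plan hunk further below), and join frontiers track pairs of them. A small illustrative sketch; the dataset ids and values here are made up:

    # illustrative values only; dataset ids are whatever the train datasets are keyed under
    scan_source_idx = "papers-3"
    row = int(scan_source_idx.split("-")[-1])         # -> 3, the record index within the "papers" dataset
    join_source_indices = ("papers-3", "authors-7")   # a join samples one index from each side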
@@ -180,71 +207,126 @@ class OpFrontier:
 
         return op_indices
 
-    def
+    def _get_op_source_indices_pairs(self) -> list[tuple[PhysicalOperator, tuple[str] | None]]:
         """
-        Returns a list of tuples for (op,
+        Returns a list of tuples for (op, source_indices) which this operator needs to execute
         in the next iteration.
         """
-
-        for op in self.frontier_ops:
-            # execute new operators on first j source indices, and previously sampled operators on one additional source_idx
-            num_processed = len(self.full_op_id_to_sources_processed[op.get_full_op_id()])
-            num_new_samples = 1 if num_processed > 0 else self.j
-            num_new_samples = min(num_new_samples, len(self.source_indices) - num_processed)
-            assert num_new_samples >= 0, "Number of new samples must be non-negative"
-
-            # construct list of inputs by looking up the input for the given source_idx
-            samples_added = 0
-            for source_idx in self.source_indices:
-                if source_idx in self.full_op_id_to_sources_processed[op.get_full_op_id()]:
-                    continue
+        op_source_indices_pairs = []
 
-
-
+        # if this operator is not being optimized: we don't request inputs, but simply process what we are given / told to (in the case of scans)
+        if not self.is_llm_join and len(self.frontier_ops) == 1:
+            return [(self.frontier_ops[0], None)]
 
-
-
-
-
-
-
-
+        # otherwise, sample (operator, source_indices) pairs
+        for op in self.frontier_ops:
+            # execute new operators on first j indices per root dataset, and previously sampled operators on one per root dataset
+            new_operator = self.full_op_id_to_sources_processed[op.get_full_op_id()] == set()
+            samples_per_root_dataset = self.j if new_operator else 1
+            num_root_datasets = len(self.root_dataset_ids)
+            num_samples = samples_per_root_dataset**num_root_datasets
+            samples = self.full_op_id_to_sources_not_processed[op.get_full_op_id()][:num_samples]
+            for source_indices in samples:
+                op_source_indices_pairs.append((op, source_indices))
+
+        return op_source_indices_pairs
+
+    def get_source_indices_for_next_iteration(self) -> set[tuple[str]]:
         """
         Returns the set of source indices which need to be sampled for the next iteration.
         """
-
-        return set(
+        op_source_indices_pairs = self._get_op_source_indices_pairs()
+        return set([source_indices for _, source_indices in op_source_indices_pairs if source_indices is not None])
 
-    def
+    def get_frontier_op_inputs(self, source_indices_to_sample: set[tuple[str]], max_quality_op: PhysicalOperator) -> list[tuple[PhysicalOperator, tuple[str], list[DataRecord] | list[int] | None]]:
         """
         Returns the list of frontier operators and their next input to process. If there are
         any indices in `source_indices_to_sample` which this operator does not sample on its own, then
-        we also have this frontier process
+        we also have this frontier process those source indices' input with its max quality operator.
         """
-        #
-
-
-
-
-
-
+        # if this is an aggregate, run on every input
+        if self.is_aggregate_op:
+            # NOTE: we don't keep track of source indices for aggregate (would require computing powerset of all source records);
+            # thus, we cannot currently support optimizing plans w/LLM operators after aggregations
+            op = self.frontier_ops[0]
+            all_inputs = []
+            for _, source_indices_to_inputs in self.source_indices_to_inputs.items():
+                for _, inputs in source_indices_to_inputs.items():
+                    all_inputs.extend(inputs)
+            return [(op, tuple(), all_inputs)]
+
+        # if this is an un-optimized (non-scan, non-join) operator, flatten inputs and run on each one
+        elif not self.is_scan_op and not self.is_llm_join and len(self.frontier_ops) == 1:
+            op_inputs = []
+            op = self.frontier_ops[0]
+            for _, source_indices_to_inputs in self.source_indices_to_inputs.items():
+                for source_indices, inputs in source_indices_to_inputs.items():
+                    for input in inputs:
+                        op_inputs.append((op, source_indices, input))
+            return op_inputs
+
+        ### for optimized operators
+        # get the list of (op, source_indices) pairs which this operator needs to execute
+        op_source_indices_pairs = self._get_op_source_indices_pairs()
+
+        # remove any root datasets which this op frontier does not have access to from the source_indices_to_sample
+        def remove_unavailable_root_datasets(source_indices: str | tuple) -> str | tuple | None:
+            # base case: source_indices is a string
+            if isinstance(source_indices, str):
+                return source_indices if source_indices.split("-")[0] in self.root_dataset_ids else None
+
+            # recursive case: source_indices is a tuple
+            left_indices = source_indices[0]
+            right_indices = source_indices[1]
+            left_filtered = remove_unavailable_root_datasets(left_indices)
+            right_filtered = remove_unavailable_root_datasets(right_indices)
+            if left_filtered is None and right_filtered is None:
+                return None
+
+            if left_filtered is None:
+                return right_filtered
+            if right_filtered is None:
+                return left_filtered
+            return (left_filtered, right_filtered)
+
+        source_indices_to_sample = {remove_unavailable_root_datasets(source_indices) for source_indices in source_indices_to_sample}
+
+        # if there are any source_indices in source_indices_to_sample which are not sampled by this operator,
+        # apply the max quality operator (and any other frontier operators with no samples)
+        sampled_source_indices = set(map(lambda tup: tup[1], op_source_indices_pairs))
         unsampled_source_indices = source_indices_to_sample - sampled_source_indices
-        for
-
+        for source_indices in unsampled_source_indices:
+            op_source_indices_pairs.append((max_quality_op, source_indices))
             for op in self.frontier_ops:
-                if
-
-
-                    #
-
-
-
-
+                if self.full_op_id_to_sources_processed[op.get_full_op_id()] == set() and op.get_full_op_id() != max_quality_op.get_full_op_id():
+                    op_source_indices_pairs.append((op, source_indices))
+
+        # construct the op inputs
+        op_inputs = []
+        if self.is_llm_join:
+            left_source_unique_logical_op_id, right_source_unique_logical_op_id = list(self.source_indices_to_inputs)
+            left_source_indices_to_inputs = self.source_indices_to_inputs[left_source_unique_logical_op_id]
+            right_source_indices_to_inputs = self.source_indices_to_inputs[right_source_unique_logical_op_id]
+            for op, source_indices in op_source_indices_pairs:
+                left_source_indices = source_indices[0]
+                right_source_indices = source_indices[1]
+                left_inputs = left_source_indices_to_inputs.get(left_source_indices, [])
+                right_inputs = right_source_indices_to_inputs.get(right_source_indices, [])
+                if len(left_inputs) > 0 and len(right_inputs) > 0:
+                    op_inputs.append((op, (left_source_indices, right_source_indices), (left_inputs, right_inputs)))
+            return op_inputs
+
+        # if operator is not a join
+        source_unique_logical_op_id = list(self.source_indices_to_inputs)[0]
+        op_inputs = [
+            (op, source_indices, input)
+            for op, source_indices in op_source_indices_pairs
+            for input in self.source_indices_to_inputs[source_unique_logical_op_id].get(source_indices, [])
         ]
 
-        return
+        return op_inputs
 
-    def update_frontier(self,
+    def update_frontier(self, unique_logical_op_id: str, plan_stats: SentinelPlanStats) -> None:
         """
         Update the set of frontier operators, pulling in new ones from the reservoir as needed.
         This function will:
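The sampling arithmetic in _get_op_source_indices_pairs above is worth spelling out: a newly added frontier operator requests j samples per root dataset, so with multiple root datasets it pulls j raised to the number of root datasets source-index tuples, while a previously sampled operator pulls one per root dataset. A quick worked sketch, with illustrative numbers:

    # illustrative numbers only
    j, num_root_datasets = 3, 2                 # e.g. a join over two scanned datasets
    new_op_samples = j ** num_root_datasets     # 3**2 = 9 source-index tuples for a brand-new operator
    warm_op_samples = 1 ** num_root_datasets    # 1 tuple per iteration once the operator has samples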
@@ -256,8 +338,8 @@ class OpFrontier:
         # upstream operators change; in this case, we de-duplicate record_op_stats with identical record_ids
         # and keep the one with the maximum quality
         # get a mapping from full_op_id --> list[RecordOpStats]
-        full_op_id_to_op_stats: dict[str, OperatorStats] = plan_stats.operator_stats.get(
-        full_op_id_to_record_op_stats = {}
+        full_op_id_to_op_stats: dict[str, OperatorStats] = plan_stats.operator_stats.get(unique_logical_op_id, {})
+        full_op_id_to_record_op_stats: dict[str, list[RecordOpStats]] = {}
         for full_op_id, op_stats in full_op_id_to_op_stats.items():
             # skip over operators which have not been sampled
             if len(op_stats.record_op_stats_lst) == 0:
@@ -281,8 +363,23 @@ class OpFrontier:
         full_op_id_to_num_samples, total_num_samples = {}, 0
         for full_op_id, record_op_stats_lst in full_op_id_to_record_op_stats.items():
             # update the set of source indices processed
+            source_indices_processed = set()
             for record_op_stats in record_op_stats_lst:
-
+                source_indices = record_op_stats.record_source_indices
+
+                if len(source_indices) == 1:
+                    source_indices = source_indices[0]
+                elif self.is_llm_join or self.is_aggregate_op:
+                    source_indices = tuple(source_indices)
+
+                self.full_op_id_to_sources_processed[full_op_id].add(source_indices)
+                source_indices_processed.add(source_indices)
+
+            # update the set of source indices not processed
+            self.full_op_id_to_sources_not_processed[full_op_id] = [
+                indices for indices in self.full_op_id_to_sources_not_processed[full_op_id]
+                if indices not in source_indices_processed
+            ]
 
             # compute the number of samples as the number of source indices processed
             num_samples = len(self.full_op_id_to_sources_processed[full_op_id])
@@ -290,11 +387,20 @@ class OpFrontier:
             total_num_samples += num_samples
 
         # compute avg. selectivity, cost, time, and quality for each physical operator
-        def total_output(record_op_stats_lst):
+        def total_output(record_op_stats_lst: list[RecordOpStats]):
             return sum([record_op_stats.passed_operator for record_op_stats in record_op_stats_lst])
 
-        def total_input(record_op_stats_lst):
-
+        def total_input(record_op_stats_lst: list[RecordOpStats]):
+            # TODO: this is okay for now because we only really need these calculations for Converts and Filters,
+            # but this will need more thought if/when we optimize joins
+            all_parent_ids = []
+            for record_op_stats in record_op_stats_lst:
+                all_parent_ids.extend(
+                    [None]
+                    if record_op_stats.record_parent_ids is None
+                    else record_op_stats.record_parent_ids
+                )
+            return len(set(all_parent_ids))
 
         full_op_id_to_mean_selectivity = {
             full_op_id: total_output(record_op_stats_lst) / total_input(record_op_stats_lst)
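The per-operator selectivity estimate above divides the number of records that passed the operator by the number of distinct parent records that were fed in. A toy illustration with made-up values:

    # toy numbers: 10 distinct parent records entered a filter, 4 outputs passed it
    total_output, total_input = 4, 10
    mean_selectivity = total_output / total_input   # -> 0.4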
@@ -309,7 +415,7 @@ class OpFrontier:
             for full_op_id, record_op_stats_lst in full_op_id_to_record_op_stats.items()
         }
         full_op_id_to_mean_quality = {
-            full_op_id: np.mean([record_op_stats.quality for record_op_stats in record_op_stats_lst])
+            full_op_id: np.mean([record_op_stats.quality for record_op_stats in record_op_stats_lst if record_op_stats.quality is not None])
             for full_op_id, record_op_stats_lst in full_op_id_to_record_op_stats.items()
         }
 
@@ -373,7 +479,7 @@ class OpFrontier:
         for full_op_id, metrics in op_metrics.items():
 
             # if this op is fully sampled, do not keep it on the frontier
-            if
+            if len(self.full_op_id_to_sources_processed[full_op_id]) == self.max_inputs:
                 continue
 
             # if this op is pareto optimal keep it in our frontier ops
@@ -455,10 +561,10 @@ class OpFrontier:
         out_record_op_stats = []
         for idx in range(len(idx_to_records)):
             records_lst, record_op_stats_lst = zip(*idx_to_records[idx])
-            max_quality_record, max_quality = records_lst[0], record_op_stats_lst[0].quality
+            max_quality_record, max_quality = records_lst[0], record_op_stats_lst[0].quality if record_op_stats_lst[0].quality is not None else 0.0
             max_quality_stats = record_op_stats_lst[0]
             for record, record_op_stats in zip(records_lst[1:], record_op_stats_lst[1:]):
-                record_quality = record_op_stats.quality
+                record_quality = record_op_stats.quality if record_op_stats.quality is not None else 0.0
                 if record_quality > max_quality:
                     max_quality_record = record
                     max_quality = record_quality
@@ -469,22 +575,19 @@ class OpFrontier:
         # create and return final DataRecordSet
         return DataRecordSet(out_records, out_record_op_stats)
 
-    def update_inputs(self,
+    def update_inputs(self, source_unique_logical_op_id: str, source_indices_to_record_sets: dict[tuple[int], list[DataRecordSet]]):
         """
         Update the inputs for this logical operator based on the outputs of the previous logical operator.
         """
-        for
+        for source_indices, record_sets in source_indices_to_record_sets.items():
             input = []
             max_quality_record_set = self.pick_highest_quality_output(record_sets)
             for record in max_quality_record_set:
                 input.append(record if record.passed_operator else None)
 
-            self.
+            self.source_indices_to_inputs[source_unique_logical_op_id][source_indices] = input
 
 
-# TODO: post-submission we will need to modify this to:
-# - submit all inputs for aggregate operators
-# - handle limits
 class MABExecutionStrategy(SentinelExecutionStrategy):
     """
     This class implements the Multi-Armed Bandit (MAB) execution strategy for SentinelQueryProcessors.
@@ -493,15 +596,15 @@ class MABExecutionStrategy(SentinelExecutionStrategy):
     calls does not perfectly match the sample_budget. This may cause some minor discrepancies with
     the progress manager as a result.
     """
-    def _get_max_quality_op(self,
+    def _get_max_quality_op(self, unique_logical_op_id: str, op_frontiers: dict[str, OpFrontier], plan_stats: SentinelPlanStats) -> PhysicalOperator:
         """
         Returns the operator in the frontier with the highest (estimated) quality.
         """
         # get the operators in the frontier set for this logical_op_id
-        frontier_ops = op_frontiers[
+        frontier_ops = op_frontiers[unique_logical_op_id].get_frontier_ops()
 
         # get a mapping from full_op_id --> list[RecordOpStats]
-        full_op_id_to_op_stats: dict[str, OperatorStats] = plan_stats.operator_stats.get(
+        full_op_id_to_op_stats: dict[str, OperatorStats] = plan_stats.operator_stats.get(unique_logical_op_id, {})
         full_op_id_to_record_op_stats = {
             full_op_id: op_stats.record_op_stats_lst
             for full_op_id, op_stats in full_op_id_to_op_stats.items()
@@ -524,7 +627,7 @@ class MABExecutionStrategy(SentinelExecutionStrategy):
         self,
         plan: SentinelPlan,
         op_frontiers: dict[str, OpFrontier],
-
+        validator: Validator,
         plan_stats: SentinelPlanStats,
     ) -> SentinelPlanStats:
         # sample records and operators and update the frontiers
@@ -537,64 +640,54 @@ class MABExecutionStrategy(SentinelExecutionStrategy):
             source_indices_to_sample.update(source_indices)
 
         # execute operator sets in sequence
-        for
+        for topo_idx, (logical_op_id, _) in enumerate(plan):
+            # compute unique logical op id within plan
+            unique_logical_op_id = f"{topo_idx}-{logical_op_id}"
+
             # use the execution cache to determine the maximum quality operator for this logical_op_id
-            max_quality_op = self._get_max_quality_op(
+            max_quality_op = self._get_max_quality_op(unique_logical_op_id, op_frontiers, plan_stats)
 
-            # TODO: can have None as an operator if _get_max_quality_op returns None
             # get frontier ops and their next input
-
-
+            frontier_op_inputs = op_frontiers[unique_logical_op_id].get_frontier_op_inputs(source_indices_to_sample, max_quality_op)
+            frontier_op_inputs = list(filter(lambda tup: tup[-1] is not None, frontier_op_inputs))
 
-            # break out of the loop if
-            if len(
+            # break out of the loop if frontier_op_inputs is empty, as this means all records have been filtered out
+            if len(frontier_op_inputs) == 0:
                 break
 
             # run sampled operators on sampled inputs and update the number of samples drawn
-
+            source_indices_to_record_set_tuples, num_llm_ops = self._execute_op_set(unique_logical_op_id, frontier_op_inputs)
             samples_drawn += num_llm_ops
 
-            # FUTURE TODO: have this return the highest quality record set simply based on our posterior (or prior) belief on operator quality
-            # get the target record set for each source_idx
-            source_idx_to_target_record_set = self._get_target_record_sets(logical_op_id, source_idx_to_record_set_tuples, expected_outputs)
-
-            # FUTURE TODO: move this outside of the loop (i.e. assume we only get quality label(s) after executing full program)
             # score the quality of each generated output
-
-
-
-                for source_idx, record_set_tuples in source_idx_to_record_set_tuples.items()
+            source_indices_to_all_record_sets = {
+                source_indices: [(record_set, op) for record_set, op, _ in record_set_tuples]
+                for source_indices, record_set_tuples in source_indices_to_record_set_tuples.items()
             }
-
+            source_indices_to_all_record_sets, val_gen_stats = self._score_quality(validator, source_indices_to_all_record_sets)
 
-            #
-
-
-            for
-
-
-
-            # update the number of samples drawn for each operator
+            # remove records that were read from the execution cache before adding to record op stats
+            new_record_op_stats = []
+            for _, record_set_tuples in source_indices_to_record_set_tuples.items():
+                for record_set, _, is_new in record_set_tuples:
+                    if is_new:
+                        new_record_op_stats.extend(record_set.record_op_stats)
 
             # update plan stats
-            plan_stats.add_record_op_stats(new_record_op_stats)
-
-
-
-
-
-
-
-
-
+            plan_stats.add_record_op_stats(unique_logical_op_id, new_record_op_stats)
+            plan_stats.add_validation_gen_stats(unique_logical_op_id, val_gen_stats)
+
+            # provide the best record sets as inputs to the next logical operator
+            next_unique_logical_op_id = plan.get_next_unique_logical_op_id(unique_logical_op_id)
+            if next_unique_logical_op_id is not None:
+                source_indices_to_all_record_sets = {
+                    source_indices: [record_set for record_set, _ in record_set_tuples]
+                    for source_indices, record_set_tuples in source_indices_to_all_record_sets.items()
+                }
+                op_frontiers[next_unique_logical_op_id].update_inputs(unique_logical_op_id, source_indices_to_all_record_sets)
 
             # update the (pareto) frontier for each set of operators
-            op_frontiers[
-
-            # FUTURE TODO: score op quality based on final outputs
-
-        # close the cache
-        self._close_cache(plan.logical_op_ids)
+            op_frontiers[unique_logical_op_id].update_frontier(unique_logical_op_id, plan_stats)
 
         # finalize plan stats
         plan_stats.finish()
@@ -602,9 +695,7 @@ class MABExecutionStrategy(SentinelExecutionStrategy):
         return plan_stats
 
 
-    def execute_sentinel_plan(self, plan: SentinelPlan,
-        # for now, assert that the first operator in the plan is a ScanPhysicalOp
-        assert all(isinstance(op, ScanPhysicalOp) for op in plan.operator_sets[0]), "First operator in physical plan must be a ScanPhysicalOp"
+    def execute_sentinel_plan(self, plan: SentinelPlan, train_dataset: dict[str, Dataset], validator: Validator) -> SentinelPlanStats:
         logger.info(f"Executing plan {plan.plan_id} with {self.max_workers} workers")
         logger.info(f"Plan Details: {plan}")
 
@@ -613,26 +704,48 @@ class MABExecutionStrategy(SentinelExecutionStrategy):
         plan_stats.start()
 
         # shuffle the indices of records to sample
-
-
-
+        dataset_id_to_shuffled_source_indices = {}
+        for dataset_id, dataset in train_dataset.items():
+            total_num_samples = len(dataset)
+            shuffled_source_indices = [f"{dataset_id}-{int(idx)}" for idx in np.arange(total_num_samples)]
+            self.rng.shuffle(shuffled_source_indices)
+            dataset_id_to_shuffled_source_indices[dataset_id] = shuffled_source_indices
 
         # initialize frontier for each logical operator
-        op_frontiers = {
-
-
-
+        op_frontiers = {}
+        for topo_idx, (logical_op_id, op_set) in enumerate(plan):
+            unique_logical_op_id = f"{topo_idx}-{logical_op_id}"
+            source_unique_logical_op_ids = plan.get_source_unique_logical_op_ids(unique_logical_op_id)
+            root_dataset_ids = plan.get_root_dataset_ids(unique_logical_op_id)
+            sample_op = op_set[0]
+            if isinstance(sample_op, (ScanPhysicalOp, ContextScanOp)):
+                assert len(root_dataset_ids) == 1, f"Scan for {sample_op} has {len(root_dataset_ids)} > 1 root dataset ids"
+                root_dataset_id = root_dataset_ids[0]
+                source_indices = dataset_id_to_shuffled_source_indices[root_dataset_id]
+                op_frontiers[unique_logical_op_id] = OpFrontier(op_set, source_unique_logical_op_ids, root_dataset_ids, source_indices, self.k, self.j, self.seed, self.policy, self.priors)
+            elif isinstance(sample_op, JoinOp):
+                assert len(source_unique_logical_op_ids) == 2, f"Join for {sample_op} has {len(source_unique_logical_op_ids)} != 2 source logical operators"
+                left_source_indices = op_frontiers[source_unique_logical_op_ids[0]].source_indices
+                right_source_indices = op_frontiers[source_unique_logical_op_ids[1]].source_indices
+                source_indices = []
+                for left_source_idx in left_source_indices:
+                    for right_source_idx in right_source_indices:
+                        source_indices.append((left_source_idx, right_source_idx))
+                op_frontiers[unique_logical_op_id] = OpFrontier(op_set, source_unique_logical_op_ids, root_dataset_ids, source_indices, self.k, self.j, self.seed, self.policy, self.priors)
+            else:
+                source_indices = op_frontiers[source_unique_logical_op_ids[0]].source_indices
+                op_frontiers[unique_logical_op_id] = OpFrontier(op_set, source_unique_logical_op_ids, root_dataset_ids, source_indices, self.k, self.j, self.seed, self.policy, self.priors)
 
         # initialize and start the progress manager
         self.progress_manager = create_progress_manager(plan, sample_budget=self.sample_budget, progress=self.progress)
         self.progress_manager.start()
 
-        # NOTE: we must handle progress manager outside of
+        # NOTE: we must handle progress manager outside of _execute_sentinel_plan to ensure that it is shut down correctly;
        # if we don't have the `finally:` branch, then program crashes can cause future program runs to fail because
        # the progress manager cannot get a handle to the console
        try:
            # execute sentinel plan by sampling records and operators
-            plan_stats = self._execute_sentinel_plan(plan, op_frontiers,
+            plan_stats = self._execute_sentinel_plan(plan, op_frontiers, validator, plan_stats)

        finally:
            # finish progress tracking