palimpzest 0.6.4__py3-none-any.whl → 0.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- palimpzest/__init__.py +5 -0
- palimpzest/constants.py +110 -43
- palimpzest/core/__init__.py +0 -78
- palimpzest/core/data/dataclasses.py +382 -44
- palimpzest/core/elements/filters.py +7 -3
- palimpzest/core/elements/index.py +70 -0
- palimpzest/core/elements/records.py +33 -11
- palimpzest/core/lib/fields.py +1 -0
- palimpzest/core/lib/schemas.py +4 -3
- palimpzest/prompts/moa_proposer_convert_prompts.py +0 -4
- palimpzest/prompts/prompt_factory.py +44 -7
- palimpzest/prompts/split_merge_prompts.py +56 -0
- palimpzest/prompts/split_proposer_prompts.py +55 -0
- palimpzest/query/execution/execution_strategy.py +435 -53
- palimpzest/query/execution/execution_strategy_type.py +20 -0
- palimpzest/query/execution/mab_execution_strategy.py +532 -0
- palimpzest/query/execution/parallel_execution_strategy.py +143 -172
- palimpzest/query/execution/random_sampling_execution_strategy.py +240 -0
- palimpzest/query/execution/single_threaded_execution_strategy.py +173 -203
- palimpzest/query/generators/api_client_factory.py +31 -0
- palimpzest/query/generators/generators.py +256 -76
- palimpzest/query/operators/__init__.py +1 -2
- palimpzest/query/operators/code_synthesis_convert.py +33 -18
- palimpzest/query/operators/convert.py +30 -97
- palimpzest/query/operators/critique_and_refine_convert.py +5 -6
- palimpzest/query/operators/filter.py +7 -10
- palimpzest/query/operators/logical.py +54 -10
- palimpzest/query/operators/map.py +130 -0
- palimpzest/query/operators/mixture_of_agents_convert.py +6 -6
- palimpzest/query/operators/physical.py +3 -12
- palimpzest/query/operators/rag_convert.py +66 -18
- palimpzest/query/operators/retrieve.py +230 -34
- palimpzest/query/operators/scan.py +5 -2
- palimpzest/query/operators/split_convert.py +169 -0
- palimpzest/query/operators/token_reduction_convert.py +8 -14
- palimpzest/query/optimizer/__init__.py +4 -16
- palimpzest/query/optimizer/cost_model.py +73 -266
- palimpzest/query/optimizer/optimizer.py +87 -58
- palimpzest/query/optimizer/optimizer_strategy.py +18 -97
- palimpzest/query/optimizer/optimizer_strategy_type.py +37 -0
- palimpzest/query/optimizer/plan.py +2 -3
- palimpzest/query/optimizer/primitives.py +5 -3
- palimpzest/query/optimizer/rules.py +336 -172
- palimpzest/query/optimizer/tasks.py +30 -100
- palimpzest/query/processor/config.py +38 -22
- palimpzest/query/processor/nosentinel_processor.py +16 -520
- palimpzest/query/processor/processing_strategy_type.py +28 -0
- palimpzest/query/processor/query_processor.py +38 -206
- palimpzest/query/processor/query_processor_factory.py +117 -130
- palimpzest/query/processor/sentinel_processor.py +90 -0
- palimpzest/query/processor/streaming_processor.py +25 -32
- palimpzest/sets.py +88 -41
- palimpzest/utils/model_helpers.py +8 -7
- palimpzest/utils/progress.py +368 -152
- palimpzest/utils/token_reduction_helpers.py +1 -3
- {palimpzest-0.6.4.dist-info → palimpzest-0.7.1.dist-info}/METADATA +19 -9
- palimpzest-0.7.1.dist-info/RECORD +96 -0
- {palimpzest-0.6.4.dist-info → palimpzest-0.7.1.dist-info}/WHEEL +1 -1
- palimpzest/query/processor/mab_sentinel_processor.py +0 -884
- palimpzest/query/processor/random_sampling_sentinel_processor.py +0 -639
- palimpzest/utils/index_helpers.py +0 -6
- palimpzest-0.6.4.dist-info/RECORD +0 -87
- {palimpzest-0.6.4.dist-info → palimpzest-0.7.1.dist-info/licenses}/LICENSE +0 -0
- {palimpzest-0.6.4.dist-info → palimpzest-0.7.1.dist-info}/top_level.txt +0 -0
|
@@ -1,71 +1,453 @@
|
|
|
1
|
-
import
|
|
1
|
+
import logging
|
|
2
2
|
from abc import ABC, abstractmethod
|
|
3
|
-
from
|
|
3
|
+
from concurrent.futures import ThreadPoolExecutor, wait
|
|
4
4
|
|
|
5
|
-
|
|
6
|
-
from
|
|
7
|
-
from palimpzest.query.optimizer.plan import PhysicalPlan
|
|
5
|
+
import numpy as np
|
|
6
|
+
from chromadb.api.models.Collection import Collection
|
|
8
7
|
|
|
8
|
+
from palimpzest.constants import PARALLEL_EXECUTION_SLEEP_INTERVAL_SECS
|
|
9
|
+
from palimpzest.core.data.dataclasses import OperatorCostEstimates, PlanStats, RecordOpStats
|
|
10
|
+
from palimpzest.core.data.datareaders import DataReader
|
|
11
|
+
from palimpzest.core.elements.records import DataRecord, DataRecordSet
|
|
12
|
+
from palimpzest.policy import Policy
|
|
13
|
+
from palimpzest.query.operators.convert import LLMConvert
|
|
14
|
+
from palimpzest.query.operators.filter import FilterOp, LLMFilter
|
|
15
|
+
from palimpzest.query.operators.physical import PhysicalOperator
|
|
16
|
+
from palimpzest.query.operators.retrieve import RetrieveOp
|
|
17
|
+
from palimpzest.query.operators.scan import ScanPhysicalOp
|
|
18
|
+
from palimpzest.query.optimizer.plan import PhysicalPlan, SentinelPlan
|
|
19
|
+
from palimpzest.utils.progress import PZSentinelProgressManager
|
|
9
20
|
|
|
10
|
-
|
|
11
|
-
"""Available execution strategy types"""
|
|
12
|
-
SEQUENTIAL = "sequential"
|
|
13
|
-
PIPELINED_SINGLE_THREAD = "pipelined"
|
|
14
|
-
PIPELINED_PARALLEL = "pipelined_parallel"
|
|
15
|
-
AUTO = "auto"
|
|
21
|
+
logger = logging.getLogger(__name__)
|
|
16
22
|
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
"""
|
|
20
|
-
Base strategy for executing query plans.
|
|
21
|
-
Defines how to execute a single plan.
|
|
22
|
-
"""
|
|
23
|
-
def __init__(self,
|
|
23
|
+
class BaseExecutionStrategy:
|
|
24
|
+
def __init__(self,
|
|
24
25
|
scan_start_idx: int = 0,
|
|
25
26
|
max_workers: int | None = None,
|
|
26
|
-
|
|
27
|
-
|
|
27
|
+
num_samples: int | None = None,
|
|
28
|
+
cache: bool = False,
|
|
29
|
+
verbose: bool = False,
|
|
30
|
+
progress: bool = True,
|
|
31
|
+
*args,
|
|
32
|
+
**kwargs):
|
|
28
33
|
self.scan_start_idx = scan_start_idx
|
|
29
|
-
self.nocache = nocache
|
|
30
|
-
self.verbose = verbose
|
|
31
34
|
self.max_workers = max_workers
|
|
32
|
-
self.
|
|
35
|
+
self.num_samples = num_samples
|
|
36
|
+
self.cache = cache
|
|
37
|
+
self.verbose = verbose
|
|
38
|
+
self.progress = progress
|
|
39
|
+
|
|
33
40
|
|
|
41
|
+
def _add_records_to_cache(self, target_cache_id: str, records: list[DataRecord]) -> None:
|
|
42
|
+
"""Add each record (which isn't filtered) to the cache for the given target_cache_id."""
|
|
43
|
+
if self.cache:
|
|
44
|
+
for record in records:
|
|
45
|
+
if getattr(record, "passed_operator", True):
|
|
46
|
+
# self.datadir.append_cache(target_cache_id, record)
|
|
47
|
+
pass
|
|
48
|
+
|
|
49
|
+
def _close_cache(self, target_cache_ids: list[str]) -> None:
|
|
50
|
+
"""Close the cache for each of the given target_cache_ids"""
|
|
51
|
+
if self.cache:
|
|
52
|
+
for target_cache_id in target_cache_ids: # noqa: B007
|
|
53
|
+
# self.datadir.close_cache(target_cache_id)
|
|
54
|
+
pass
|
|
55
|
+
|
|
56
|
+
class ExecutionStrategy(BaseExecutionStrategy, ABC):
|
|
57
|
+
"""Base strategy for executing query plans. Defines how to execute a PhysicalPlan.
|
|
58
|
+
"""
|
|
59
|
+
def __init__(self, *args, **kwargs):
|
|
60
|
+
super().__init__(*args, **kwargs)
|
|
61
|
+
logger.info(f"Initialized ExecutionStrategy {self.__class__.__name__}")
|
|
62
|
+
logger.debug(f"ExecutionStrategy initialized with config: {self.__dict__}")
|
|
34
63
|
|
|
35
64
|
@abstractmethod
|
|
36
|
-
def execute_plan(
|
|
37
|
-
self,
|
|
38
|
-
plan: PhysicalPlan,
|
|
39
|
-
num_samples: int | float = float("inf"),
|
|
40
|
-
workers: int = 1
|
|
41
|
-
) -> tuple[list[DataRecord], PlanStats]:
|
|
65
|
+
def execute_plan(self, plan: PhysicalPlan) -> tuple[list[DataRecord], PlanStats]:
|
|
42
66
|
"""Execute a single plan according to strategy"""
|
|
43
67
|
pass
|
|
44
68
|
|
|
69
|
+
def _create_input_queues(self, plan: PhysicalPlan) -> dict[str, list]:
|
|
70
|
+
"""Initialize input queues for each operator in the plan."""
|
|
71
|
+
input_queues = {}
|
|
72
|
+
for op in plan.operators:
|
|
73
|
+
inputs = []
|
|
74
|
+
if isinstance(op, ScanPhysicalOp):
|
|
75
|
+
scan_end_idx = (
|
|
76
|
+
len(op.datareader)
|
|
77
|
+
if self.num_samples is None
|
|
78
|
+
else min(self.scan_start_idx + self.num_samples, len(op.datareader))
|
|
79
|
+
)
|
|
80
|
+
inputs = [idx for idx in range(self.scan_start_idx, scan_end_idx)]
|
|
81
|
+
input_queues[op.get_op_id()] = inputs
|
|
82
|
+
|
|
83
|
+
return input_queues
|
|
45
84
|
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
# plan_stats=aggregate_plan_stats,
|
|
54
|
-
# total_execution_time=time.time() - execution_start_time,
|
|
55
|
-
# total_execution_cost=sum(
|
|
56
|
-
# list(map(lambda plan_stats: plan_stats.total_plan_cost, aggregate_plan_stats.values()))
|
|
57
|
-
# ),
|
|
58
|
-
# plan_strs={plan_id: plan_stats.plan_str for plan_id, plan_stats in aggregate_plan_stats.items()},
|
|
59
|
-
# )
|
|
60
|
-
def _create_execution_stats(
|
|
85
|
+
class SentinelExecutionStrategy(BaseExecutionStrategy, ABC):
|
|
86
|
+
"""Base strategy for executing sentinel query plans. Defines how to execute a SentinelPlan."""
|
|
87
|
+
"""
|
|
88
|
+
Abstract base for sentinel execution strategies (e.g. the MAB sentinel strategy);
|
|
89
|
+
subclasses implement execute_sentinel_plan() to coordinate optimization and execution.
|
|
90
|
+
"""
|
|
91
|
+
def __init__(
|
|
61
92
|
self,
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
93
|
+
val_datasource: DataReader,
|
|
94
|
+
k: int,
|
|
95
|
+
j: int,
|
|
96
|
+
sample_budget: int,
|
|
97
|
+
policy: Policy,
|
|
98
|
+
use_final_op_quality: bool = False,
|
|
99
|
+
seed: int = 42,
|
|
100
|
+
exp_name: str | None = None,
|
|
101
|
+
*args,
|
|
102
|
+
**kwargs,
|
|
103
|
+
):
|
|
104
|
+
super().__init__(*args, **kwargs)
|
|
105
|
+
self.val_datasource = val_datasource
|
|
106
|
+
self.k = k
|
|
107
|
+
self.j = j
|
|
108
|
+
self.sample_budget = sample_budget
|
|
109
|
+
self.policy = policy
|
|
110
|
+
self.use_final_op_quality = use_final_op_quality
|
|
111
|
+
self.seed = seed
|
|
112
|
+
self.rng = np.random.default_rng(seed=seed)
|
|
113
|
+
self.exp_name = exp_name
|
|
114
|
+
|
|
115
|
+
# special cache which is used for tracking the target record sets for each (source_idx, logical_op_id)
|
|
116
|
+
self.champion_output_cache: dict[int, dict[str, tuple[DataRecordSet, float]]] = {}
|
|
117
|
+
|
|
118
|
+
# general cache which maps hash(logical_op_id, phys_op_id, hash(input)) --> (record_set, operator)
|
|
119
|
+
self.cache: dict[int, DataRecordSet] = {}
|
|
120
|
+
|
|
121
|
+
# progress manager used to track progress of the execution
|
|
122
|
+
self.progress_manager: PZSentinelProgressManager | None = None
|
|
123
|
+
|
|
124
|
+
def _compute_quality(
|
|
125
|
+
self,
|
|
126
|
+
physical_op_cls: type[PhysicalOperator],
|
|
127
|
+
record_set: DataRecordSet,
|
|
128
|
+
target_record_set: DataRecordSet,
|
|
129
|
+
) -> DataRecordSet:
|
|
130
|
+
"""
|
|
131
|
+
Compute the quality for the given `record_set` by comparing it to the `target_record_set`.
|
|
132
|
+
|
|
133
|
+
Update the record_set by assigning the quality to each entry in its record_op_stats and
|
|
134
|
+
returning the updated record_set.
|
|
135
|
+
"""
|
|
136
|
+
# if this operation failed
|
|
137
|
+
if len(record_set) == 0:
|
|
138
|
+
record_set.record_op_stats[0].quality = 0.0
|
|
139
|
+
|
|
140
|
+
# if this operation is a filter:
|
|
141
|
+
# - return 1.0 if there's a match in the expected output which this operator does not filter out and 0.0 otherwise
|
|
142
|
+
elif issubclass(physical_op_cls, FilterOp):
|
|
143
|
+
# NOTE: we know that record_set.data_records will contain a single entry for a filter op
|
|
144
|
+
record = record_set.data_records[0]
|
|
145
|
+
|
|
146
|
+
# search for a record in the target with the same set of fields
|
|
147
|
+
found_match_in_target = False
|
|
148
|
+
for target_record in target_record_set:
|
|
149
|
+
all_correct = True
|
|
150
|
+
for field, value in record.field_values.items():
|
|
151
|
+
if value != target_record[field]:
|
|
152
|
+
all_correct = False
|
|
153
|
+
break
|
|
154
|
+
|
|
155
|
+
if all_correct:
|
|
156
|
+
found_match_in_target = target_record.passed_operator
|
|
157
|
+
break
|
|
158
|
+
|
|
159
|
+
# set quality based on whether we found a match in the target and return
|
|
160
|
+
record_set.record_op_stats[0].quality = int(record.passed_operator == found_match_in_target)
|
|
161
|
+
|
|
162
|
+
return record_set
|
|
163
|
+
|
|
164
|
+
# if this is a successful convert operation
|
|
165
|
+
else:
|
|
166
|
+
# NOTE: the following computation assumes we do not project out computed values
|
|
167
|
+
# (and that the validation examples provide all computed fields); even if
|
|
168
|
+
# a user program does add projection, we can ignore the projection on the
|
|
169
|
+
# validation dataset and use the champion model (as opposed to the validation
|
|
170
|
+
# output) for scoring fields which have their values projected out
|
|
171
|
+
|
|
172
|
+
# GREEDY ALGORITHM
|
|
173
|
+
# for each record in the expected output, we look for the computed record which maximizes the quality metric;
|
|
174
|
+
# once we've identified that computed record we remove it from consideration for the next expected output
|
|
175
|
+
field_to_score_fn = target_record_set.get_field_to_score_fn()
|
|
176
|
+
for target_record in target_record_set:
|
|
177
|
+
best_quality, best_record_op_stats = 0.0, None
|
|
178
|
+
for record_op_stats in record_set.record_op_stats:
|
|
179
|
+
# if we already assigned this record a quality, skip it
|
|
180
|
+
if record_op_stats.quality is not None:
|
|
181
|
+
continue
|
|
182
|
+
|
|
183
|
+
# compute number of matches between this record's computed fields and this expected record's outputs
|
|
184
|
+
total_quality = 0
|
|
185
|
+
for field in record_op_stats.generated_fields:
|
|
186
|
+
computed_value = record_op_stats.record_state.get(field, None)
|
|
187
|
+
expected_value = target_record[field]
|
|
188
|
+
|
|
189
|
+
# get the metric function for this field
|
|
190
|
+
score_fn = field_to_score_fn.get(field, "exact")
|
|
191
|
+
|
|
192
|
+
# compute exact match
|
|
193
|
+
if score_fn == "exact":
|
|
194
|
+
total_quality += int(computed_value == expected_value)
|
|
195
|
+
|
|
196
|
+
# compute UDF metric
|
|
197
|
+
elif callable(score_fn):
|
|
198
|
+
total_quality += score_fn(computed_value, expected_value)
|
|
199
|
+
|
|
200
|
+
# otherwise, throw an exception
|
|
201
|
+
else:
|
|
202
|
+
raise Exception(f"Unrecognized score_fn: {score_fn}")
|
|
203
|
+
|
|
204
|
+
# compute recall and update best seen so far
|
|
205
|
+
quality = total_quality / len(record_op_stats.generated_fields)
|
|
206
|
+
if quality > best_quality:
|
|
207
|
+
best_quality = quality
|
|
208
|
+
best_record_op_stats = record_op_stats
|
|
209
|
+
|
|
210
|
+
# set best_quality as quality for the best_record_op_stats
|
|
211
|
+
if best_record_op_stats is not None:
|
|
212
|
+
best_record_op_stats.quality = best_quality
|
|
213
|
+
|
|
214
|
+
# for any records which did not receive a quality, set it to 0.0 as these are unexpected extras
|
|
215
|
+
for record_op_stats in record_set.record_op_stats:
|
|
216
|
+
if record_op_stats.quality is None:
|
|
217
|
+
record_op_stats.quality = 0.0
|
|
218
|
+
|
|
219
|
+
return record_set
|
|
220
|
+
|
|
221
|
+
def _score_quality(
|
|
222
|
+
self,
|
|
223
|
+
physical_op_cls: type[PhysicalOperator],
|
|
224
|
+
source_idx_to_record_sets: dict[int, list[DataRecordSet]],
|
|
225
|
+
source_idx_to_target_record_set: dict[int, DataRecordSet],
|
|
226
|
+
) -> dict[int, list[DataRecordSet]]:
|
|
227
|
+
"""
|
|
228
|
+
NOTE: This approach to cost modeling does not work directly for aggregation queries;
|
|
229
|
+
for these queries, we would ask the user to provide validation data for the step immediately
|
|
230
|
+
before a final aggregation
|
|
231
|
+
|
|
232
|
+
NOTE: This function currently assumes that one-to-many converts do NOT create duplicate outputs.
|
|
233
|
+
This assumption would break if, for example, we extracted the breed of every dog in an image.
|
|
234
|
+
If there were two golden retrievers and a bernedoodle in an image and we extracted:
|
|
235
|
+
|
|
236
|
+
{"image": "file1.png", "breed": "Golden Retriever"}
|
|
237
|
+
{"image": "file1.png", "breed": "Golden Retriever"}
|
|
238
|
+
{"image": "file1.png", "breed": "Bernedoodle"}
|
|
239
|
+
|
|
240
|
+
This function would currently give perfect accuracy to the following output:
|
|
241
|
+
|
|
242
|
+
{"image": "file1.png", "breed": "Golden Retriever"}
|
|
243
|
+
{"image": "file1.png", "breed": "Bernedoodle"}
|
|
244
|
+
|
|
245
|
+
Even though it is missing one of the golden retrievers.
|
|
246
|
+
"""
|
|
247
|
+
# extract information about the logical operation performed at this stage of the sentinel plan;
|
|
248
|
+
# NOTE: we can infer these fields from context clues, but in the long-term we should have a more
|
|
249
|
+
# principled way of getting these directly from attributes either stored in the sentinel_plan
|
|
250
|
+
# or in the PhysicalOperator
|
|
251
|
+
is_perfect_quality_op = (
|
|
252
|
+
not issubclass(physical_op_cls, LLMConvert)
|
|
253
|
+
and not issubclass(physical_op_cls, LLMFilter)
|
|
254
|
+
and not issubclass(physical_op_cls, RetrieveOp)
|
|
71
255
|
)
|
|
256
|
+
|
|
257
|
+
# compute quality of each output computed by this operator
|
|
258
|
+
for source_idx, record_sets in source_idx_to_record_sets.items():
|
|
259
|
+
# if this operation does not involve an LLM, every record_op_stats object gets perfect quality
|
|
260
|
+
if is_perfect_quality_op:
|
|
261
|
+
for record_set in record_sets:
|
|
262
|
+
for record_op_stats in record_set.record_op_stats:
|
|
263
|
+
record_op_stats.quality = 1.0
|
|
264
|
+
continue
|
|
265
|
+
|
|
266
|
+
# extract target output for this record set
|
|
267
|
+
target_record_set = source_idx_to_target_record_set[source_idx]
|
|
268
|
+
|
|
269
|
+
# for each record_set produced by an operation, compute its quality
|
|
270
|
+
for record_set in record_sets:
|
|
271
|
+
record_set = self._compute_quality(physical_op_cls, record_set, target_record_set)
|
|
272
|
+
|
|
273
|
+
# return the quality annotated record sets
|
|
274
|
+
return source_idx_to_record_sets
|
|
275
|
+
|
|
276
|
+
def _get_target_record_sets(
|
|
277
|
+
self,
|
|
278
|
+
logical_op_id: str,
|
|
279
|
+
source_idx_to_record_set_tuples: dict[int, list[tuple[DataRecordSet, PhysicalOperator, bool]]],
|
|
280
|
+
expected_outputs: dict[int, dict] | None,
|
|
281
|
+
) -> dict[int, DataRecordSet]:
|
|
282
|
+
# initialize mapping from source index to target record sets
|
|
283
|
+
source_idx_to_target_record_set = {}
|
|
284
|
+
|
|
285
|
+
for source_idx, record_set_tuples in source_idx_to_record_set_tuples.items():
|
|
286
|
+
# get the first generated output for this source_idx
|
|
287
|
+
base_target_record = None
|
|
288
|
+
for record_set, _, _ in record_set_tuples:
|
|
289
|
+
if len(record_set) > 0:
|
|
290
|
+
base_target_record = record_set[0]
|
|
291
|
+
break
|
|
292
|
+
|
|
293
|
+
# compute availability of data
|
|
294
|
+
base_target_present = base_target_record is not None
|
|
295
|
+
labels_present = expected_outputs is not None
|
|
296
|
+
labels_for_source_present = False
|
|
297
|
+
if labels_present and source_idx in expected_outputs:
|
|
298
|
+
labels = expected_outputs[source_idx].get("labels", [])
|
|
299
|
+
labels_dict_lst = labels if isinstance(labels, list) else [labels]
|
|
300
|
+
labels_for_source_present = labels_dict_lst != [] and labels_dict_lst != [None]
|
|
301
|
+
|
|
302
|
+
# if we have a base target record and label info, use the label info to construct the target record set
|
|
303
|
+
if base_target_present and labels_for_source_present:
|
|
304
|
+
# get the field_to_score_fn
|
|
305
|
+
field_to_score_fn = expected_outputs[source_idx].get("score_fn", {})
|
|
306
|
+
|
|
307
|
+
# construct the target record set; we force passed_operator to be True for all target records
|
|
308
|
+
target_records = []
|
|
309
|
+
for labels_dict in labels_dict_lst:
|
|
310
|
+
target_record = base_target_record.copy()
|
|
311
|
+
for field, value in labels_dict.items():
|
|
312
|
+
target_record[field] = value
|
|
313
|
+
target_record.passed_operator = True
|
|
314
|
+
target_records.append(target_record)
|
|
315
|
+
|
|
316
|
+
source_idx_to_target_record_set[source_idx] = DataRecordSet(target_records, None, field_to_score_fn)
|
|
317
|
+
continue
|
|
318
|
+
|
|
319
|
+
# get the best computed output for this (source_idx, logical_op_id) so far (if one exists)
|
|
320
|
+
champion_record_set, champion_op_quality = None, None
|
|
321
|
+
if source_idx in self.champion_output_cache and logical_op_id in self.champion_output_cache[source_idx]:
|
|
322
|
+
champion_record_set, champion_op_quality = self.champion_output_cache[source_idx][logical_op_id]
|
|
323
|
+
|
|
324
|
+
# get the highest quality output that we just computed
|
|
325
|
+
max_quality_record_set, max_op_quality = self._pick_champion_output(record_set_tuples)
|
|
326
|
+
|
|
327
|
+
# if this new output is of higher quality than our previous champion (or if we didn't have
|
|
328
|
+
# a previous champion) then we update our champion record set
|
|
329
|
+
if champion_op_quality is None or (max_op_quality is not None and max_op_quality > champion_op_quality):
|
|
330
|
+
champion_record_set, champion_op_quality = max_quality_record_set, max_op_quality
|
|
331
|
+
|
|
332
|
+
# update the cache with the new champion record set and quality
|
|
333
|
+
if source_idx not in self.champion_output_cache:
|
|
334
|
+
self.champion_output_cache[source_idx] = {}
|
|
335
|
+
self.champion_output_cache[source_idx][logical_op_id] = (champion_record_set, champion_op_quality)
|
|
336
|
+
|
|
337
|
+
# set the target
|
|
338
|
+
source_idx_to_target_record_set[source_idx] = champion_record_set
|
|
339
|
+
|
|
340
|
+
return source_idx_to_target_record_set
|
|
341
|
+
|
|
342
|
+
def _pick_champion_output(self, record_set_tuples: list[tuple[DataRecordSet, PhysicalOperator, bool]]) -> tuple[DataRecordSet, float | None]:
|
|
343
|
+
# find the operator with the highest estimated quality and return its record_set
|
|
344
|
+
base_op_cost_est = OperatorCostEstimates(cardinality=1.0, cost_per_record=0.0, time_per_record=0.0, quality=1.0)
|
|
345
|
+
champion_record_set, champion_quality = None, None
|
|
346
|
+
for record_set, op, _ in record_set_tuples:
|
|
347
|
+
# skip failed operations
|
|
348
|
+
if len(record_set) == 0:
|
|
349
|
+
continue
|
|
350
|
+
|
|
351
|
+
# get the estimated quality of this operator
|
|
352
|
+
est_quality = op.naive_cost_estimates(base_op_cost_est).quality if self._is_llm_op(op) else 1.0
|
|
353
|
+
if champion_quality is None or est_quality > champion_quality:
|
|
354
|
+
champion_record_set, champion_quality = record_set, est_quality
|
|
355
|
+
|
|
356
|
+
return champion_record_set, champion_quality
|
|
357
|
+
|
|
358
|
+
def _flatten_record_sets(self, source_idx_to_record_sets: dict[int, list[DataRecordSet]]) -> tuple[list[DataRecord], list[RecordOpStats]]:
|
|
359
|
+
"""
|
|
360
|
+
Flatten the list of record sets and record op stats for each source_idx.
|
|
361
|
+
"""
|
|
362
|
+
all_records, all_record_op_stats = [], []
|
|
363
|
+
for _, record_sets in source_idx_to_record_sets.items():
|
|
364
|
+
for record_set in record_sets:
|
|
365
|
+
all_records.extend(record_set.data_records)
|
|
366
|
+
all_record_op_stats.extend(record_set.record_op_stats)
|
|
367
|
+
|
|
368
|
+
return all_records, all_record_op_stats
|
|
369
|
+
|
|
370
|
+
def _execute_op_set(self, op_input_pairs: list[tuple[PhysicalOperator, DataRecord | int]]) -> tuple[dict[int, list[tuple[DataRecordSet, PhysicalOperator, bool]]], dict[str, int]]:
|
|
371
|
+
def execute_op_wrapper(operator, input) -> tuple[DataRecordSet, PhysicalOperator, DataRecord | int]:
|
|
372
|
+
record_set = operator(input)
|
|
373
|
+
return record_set, operator, input
|
|
374
|
+
|
|
375
|
+
# TODO: modify unit tests to always have record_op_stats so we can use record_op_stats for source_idx
|
|
376
|
+
# for scan operators, `input` will be the source_idx
|
|
377
|
+
def get_source_idx(input):
|
|
378
|
+
return input.source_idx if isinstance(input, DataRecord) else input
|
|
379
|
+
|
|
380
|
+
def get_hash(operator, input):
|
|
381
|
+
logical_op_id = operator.get_logical_op_id()
|
|
382
|
+
phys_op_id = operator.get_op_id()
|
|
383
|
+
return hash(f"{logical_op_id}{phys_op_id}{hash(input)}")
|
|
384
|
+
|
|
385
|
+
# initialize mapping from source indices to output record sets
|
|
386
|
+
source_idx_to_record_sets_and_ops = {get_source_idx(input): [] for _, input in op_input_pairs}
|
|
387
|
+
|
|
388
|
+
# if any operations were previously executed, read the results from the cache
|
|
389
|
+
final_op_input_pairs = []
|
|
390
|
+
for operator, input in op_input_pairs:
|
|
391
|
+
# compute hash
|
|
392
|
+
op_input_hash = get_hash(operator, input)
|
|
393
|
+
|
|
394
|
+
# get result from cache
|
|
395
|
+
if op_input_hash in self.cache:
|
|
396
|
+
source_idx = get_source_idx(input)
|
|
397
|
+
record_set, operator = self.cache[op_input_hash]
|
|
398
|
+
source_idx_to_record_sets_and_ops[source_idx].append((record_set, operator, False))
|
|
399
|
+
|
|
400
|
+
# otherwise, add to final_op_input_pairs
|
|
401
|
+
else:
|
|
402
|
+
final_op_input_pairs.append((operator, input))
|
|
403
|
+
|
|
404
|
+
# keep track of the number of llm operations
|
|
405
|
+
num_llm_ops = 0
|
|
406
|
+
|
|
407
|
+
# create thread pool w/max workers and run futures over worker pool
|
|
408
|
+
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
|
|
409
|
+
# create futures
|
|
410
|
+
futures = [
|
|
411
|
+
executor.submit(execute_op_wrapper, operator, input)
|
|
412
|
+
for operator, input in final_op_input_pairs
|
|
413
|
+
]
|
|
414
|
+
output_record_sets = []
|
|
415
|
+
while len(futures) > 0:
|
|
416
|
+
done_futures, not_done_futures = wait(futures, timeout=PARALLEL_EXECUTION_SLEEP_INTERVAL_SECS)
|
|
417
|
+
for future in done_futures:
|
|
418
|
+
# update output record sets
|
|
419
|
+
record_set, operator, input = future.result()
|
|
420
|
+
output_record_sets.append((record_set, operator, input))
|
|
421
|
+
|
|
422
|
+
# update cache
|
|
423
|
+
op_input_hash = get_hash(operator, input)
|
|
424
|
+
self.cache[op_input_hash] = (record_set, operator)
|
|
425
|
+
|
|
426
|
+
# update progress manager
|
|
427
|
+
if self._is_llm_op(operator):
|
|
428
|
+
num_llm_ops += 1
|
|
429
|
+
self.progress_manager.incr(operator.get_logical_op_id(), num_samples=1, total_cost=record_set.get_total_cost())
|
|
430
|
+
|
|
431
|
+
# update futures
|
|
432
|
+
futures = list(not_done_futures)
|
|
433
|
+
|
|
434
|
+
# update mapping from source_idx to record sets and operators
|
|
435
|
+
for record_set, operator, input in output_record_sets:
|
|
436
|
+
# get the source_idx associated with this input record;
|
|
437
|
+
source_idx = get_source_idx(input)
|
|
438
|
+
|
|
439
|
+
# add record_set to mapping from source_idx --> record_sets
|
|
440
|
+
source_idx_to_record_sets_and_ops[source_idx].append((record_set, operator, True))
|
|
441
|
+
|
|
442
|
+
return source_idx_to_record_sets_and_ops, num_llm_ops
|
|
443
|
+
|
|
444
|
+
def _is_llm_op(self, physical_op: PhysicalOperator) -> bool:
|
|
445
|
+
is_llm_convert = isinstance(physical_op, LLMConvert)
|
|
446
|
+
is_llm_filter = isinstance(physical_op, LLMFilter)
|
|
447
|
+
is_llm_retrieve = isinstance(physical_op, RetrieveOp) and isinstance(physical_op.index, Collection)
|
|
448
|
+
return is_llm_convert or is_llm_filter or is_llm_retrieve
|
|
449
|
+
|
|
450
|
+
@abstractmethod
|
|
451
|
+
def execute_sentinel_plan(self, sentinel_plan: SentinelPlan, expected_outputs: dict[str, dict]):
|
|
452
|
+
"""Execute a SentinelPlan according to strategy"""
|
|
453
|
+
pass
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
from enum import Enum
|
|
2
|
+
|
|
3
|
+
from palimpzest.query.execution.mab_execution_strategy import MABExecutionStrategy
|
|
4
|
+
from palimpzest.query.execution.parallel_execution_strategy import ParallelExecutionStrategy
|
|
5
|
+
from palimpzest.query.execution.random_sampling_execution_strategy import RandomSamplingExecutionStrategy
|
|
6
|
+
from palimpzest.query.execution.single_threaded_execution_strategy import (
|
|
7
|
+
PipelinedSingleThreadExecutionStrategy,
|
|
8
|
+
SequentialSingleThreadExecutionStrategy,
|
|
9
|
+
)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class ExecutionStrategyType(Enum):
|
|
13
|
+
"""Available execution strategy types"""
|
|
14
|
+
SEQUENTIAL = SequentialSingleThreadExecutionStrategy
|
|
15
|
+
PIPELINED = PipelinedSingleThreadExecutionStrategy
|
|
16
|
+
PARALLEL = ParallelExecutionStrategy
|
|
17
|
+
|
|
18
|
+
class SentinelExecutionStrategyType(Enum):
|
|
19
|
+
MAB = MABExecutionStrategy
|
|
20
|
+
RANDOM = RandomSamplingExecutionStrategy
|