palimpzest 0.6.4__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64) hide show
  1. palimpzest/__init__.py +5 -0
  2. palimpzest/constants.py +110 -43
  3. palimpzest/core/__init__.py +0 -78
  4. palimpzest/core/data/dataclasses.py +382 -44
  5. palimpzest/core/elements/filters.py +7 -3
  6. palimpzest/core/elements/index.py +70 -0
  7. palimpzest/core/elements/records.py +33 -11
  8. palimpzest/core/lib/fields.py +1 -0
  9. palimpzest/core/lib/schemas.py +4 -3
  10. palimpzest/prompts/moa_proposer_convert_prompts.py +0 -4
  11. palimpzest/prompts/prompt_factory.py +44 -7
  12. palimpzest/prompts/split_merge_prompts.py +56 -0
  13. palimpzest/prompts/split_proposer_prompts.py +55 -0
  14. palimpzest/query/execution/execution_strategy.py +435 -53
  15. palimpzest/query/execution/execution_strategy_type.py +20 -0
  16. palimpzest/query/execution/mab_execution_strategy.py +532 -0
  17. palimpzest/query/execution/parallel_execution_strategy.py +143 -172
  18. palimpzest/query/execution/random_sampling_execution_strategy.py +240 -0
  19. palimpzest/query/execution/single_threaded_execution_strategy.py +173 -203
  20. palimpzest/query/generators/api_client_factory.py +31 -0
  21. palimpzest/query/generators/generators.py +256 -76
  22. palimpzest/query/operators/__init__.py +1 -2
  23. palimpzest/query/operators/code_synthesis_convert.py +33 -18
  24. palimpzest/query/operators/convert.py +30 -97
  25. palimpzest/query/operators/critique_and_refine_convert.py +5 -6
  26. palimpzest/query/operators/filter.py +7 -10
  27. palimpzest/query/operators/logical.py +54 -10
  28. palimpzest/query/operators/map.py +130 -0
  29. palimpzest/query/operators/mixture_of_agents_convert.py +6 -6
  30. palimpzest/query/operators/physical.py +3 -12
  31. palimpzest/query/operators/rag_convert.py +66 -18
  32. palimpzest/query/operators/retrieve.py +230 -34
  33. palimpzest/query/operators/scan.py +5 -2
  34. palimpzest/query/operators/split_convert.py +169 -0
  35. palimpzest/query/operators/token_reduction_convert.py +8 -14
  36. palimpzest/query/optimizer/__init__.py +4 -16
  37. palimpzest/query/optimizer/cost_model.py +73 -266
  38. palimpzest/query/optimizer/optimizer.py +87 -58
  39. palimpzest/query/optimizer/optimizer_strategy.py +18 -97
  40. palimpzest/query/optimizer/optimizer_strategy_type.py +37 -0
  41. palimpzest/query/optimizer/plan.py +2 -3
  42. palimpzest/query/optimizer/primitives.py +5 -3
  43. palimpzest/query/optimizer/rules.py +336 -172
  44. palimpzest/query/optimizer/tasks.py +30 -100
  45. palimpzest/query/processor/config.py +38 -22
  46. palimpzest/query/processor/nosentinel_processor.py +16 -520
  47. palimpzest/query/processor/processing_strategy_type.py +28 -0
  48. palimpzest/query/processor/query_processor.py +38 -206
  49. palimpzest/query/processor/query_processor_factory.py +117 -130
  50. palimpzest/query/processor/sentinel_processor.py +90 -0
  51. palimpzest/query/processor/streaming_processor.py +25 -32
  52. palimpzest/sets.py +88 -41
  53. palimpzest/utils/model_helpers.py +8 -7
  54. palimpzest/utils/progress.py +368 -152
  55. palimpzest/utils/token_reduction_helpers.py +1 -3
  56. {palimpzest-0.6.4.dist-info → palimpzest-0.7.0.dist-info}/METADATA +19 -9
  57. palimpzest-0.7.0.dist-info/RECORD +96 -0
  58. {palimpzest-0.6.4.dist-info → palimpzest-0.7.0.dist-info}/WHEEL +1 -1
  59. palimpzest/query/processor/mab_sentinel_processor.py +0 -884
  60. palimpzest/query/processor/random_sampling_sentinel_processor.py +0 -639
  61. palimpzest/utils/index_helpers.py +0 -6
  62. palimpzest-0.6.4.dist-info/RECORD +0 -87
  63. {palimpzest-0.6.4.dist-info → palimpzest-0.7.0.dist-info/licenses}/LICENSE +0 -0
  64. {palimpzest-0.6.4.dist-info → palimpzest-0.7.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,532 @@
1
"""Multi-Armed Bandit (MAB) sentinel execution strategy.

Defines :class:`OpFrontier`, which tracks the frontier of physical operators
sampled for a single logical operator, and :class:`MABExecutionStrategy`,
which drives the MAB-style sampling of (operator, record) pairs for a
:class:`SentinelPlan`.
"""
import logging

import numpy as np

from palimpzest.core.data.dataclasses import OperatorStats, SentinelPlanStats
from palimpzest.core.elements.records import DataRecord, DataRecordSet
from palimpzest.policy import Policy
from palimpzest.query.execution.execution_strategy import SentinelExecutionStrategy
from palimpzest.query.operators.filter import FilterOp
from palimpzest.query.operators.physical import PhysicalOperator
from palimpzest.query.operators.scan import ScanPhysicalOp
from palimpzest.query.optimizer.plan import SentinelPlan
from palimpzest.utils.progress import create_progress_manager

# module-level logger, per stdlib logging convention
logger = logging.getLogger(__name__)
16
+
17
class OpFrontier:
    """
    This class represents the set of operators which are currently in the frontier for a given logical operator.
    Each operator in the frontier is an instance of a PhysicalOperator which either:

    1. lies on the Pareto frontier of the set of sampled operators, or
    2. has been sampled fewer than j times
    """

    def __init__(self, op_set: list[PhysicalOperator], source_indices: list[int], k: int, j: int, seed: int, policy: Policy):
        """
        Initialize the frontier for one logical operator.

        Args:
            op_set: candidate physical operators implementing this logical operator
            source_indices: (shuffled) indices of the validation records to sample
            k: initial number of operators kept on the frontier (capped at len(op_set))
            j: initial number of records sampled per frontier operator (capped at len(source_indices))
            seed: seed for shuffling the order in which operators are drawn from op_set
            policy: the policy that we are optimizing under
        """
        # set k and j, which are the initial number of operators in the frontier and the
        # initial number of records to sample for each frontier operator
        self.k = min(k, len(op_set))
        self.j = min(j, len(source_indices))

        # store the policy that we are optimizing under
        self.policy = policy

        # get order in which we will sample physical operators for this logical operator
        sample_op_indices = self._get_op_index_order(op_set, seed)

        # construct the initial set of frontier and reservoir operators; off_frontier_ops
        # holds previously sampled operators which have fallen off the frontier
        self.frontier_ops = [op_set[sample_idx] for sample_idx in sample_op_indices[:self.k]]
        self.reservoir_ops = [op_set[sample_idx] for sample_idx in sample_op_indices[self.k:]]
        self.off_frontier_ops = []

        # store the order in which we will sample the source records
        self.source_indices = source_indices

        # keep track of the source ids processed by each physical operator
        self.phys_op_id_to_sources_processed = {op.get_op_id(): set() for op in op_set}

        # set the initial inputs for this logical operator; for scan operators the input for a
        # source_idx is the index itself, for all other operators the inputs are filled in
        # later by update_inputs()
        is_scan_op = isinstance(op_set[0], ScanPhysicalOp)
        self.source_idx_to_input = {source_idx: [source_idx] for source_idx in self.source_indices} if is_scan_op else {}

        # boolean indication of whether this is a logical filter
        self.is_filter_op = isinstance(op_set[0], FilterOp)

    def get_frontier_ops(self) -> list[PhysicalOperator]:
        """
        Returns the set of frontier operators for this OpFrontier.
        """
        return self.frontier_ops

    def _get_op_index_order(self, op_set: list[PhysicalOperator], seed: int) -> list[int]:
        """
        Returns a list of indices for the operators in the op_set.

        NOTE(review): this actually returns a shuffled ``np.ndarray`` of indices (not a
        ``list[int]``); it is only used for indexing into ``op_set``, where an ndarray
        behaves like a list.
        """
        rng = np.random.default_rng(seed=seed)
        op_indices = np.arange(len(op_set))
        rng.shuffle(op_indices)
        return op_indices

    def _get_op_source_idx_pairs(self) -> list[tuple[PhysicalOperator, int]]:
        """
        Returns a list of tuples for (op, source_idx) which this operator needs to execute
        in the next iteration.
        """
        op_source_idx_pairs = []
        for op in self.frontier_ops:
            # execute new operators on first j source indices, and previously sampled operators on one additional source_idx
            num_processed = len(self.phys_op_id_to_sources_processed[op.get_op_id()])
            num_new_samples = 1 if num_processed > 0 else self.j
            # never request more samples than there are unprocessed source indices
            num_new_samples = min(num_new_samples, len(self.source_indices) - num_processed)
            assert num_new_samples >= 0, "Number of new samples must be non-negative"

            # construct list of inputs by looking up the input for the given source_idx
            samples_added = 0
            for source_idx in self.source_indices:
                # skip source indices this operator has already processed
                if source_idx in self.phys_op_id_to_sources_processed[op.get_op_id()]:
                    continue

                if samples_added == num_new_samples:
                    break

                # construct the (op, source_idx) for this source_idx
                op_source_idx_pairs.append((op, source_idx))
                samples_added += 1

        return op_source_idx_pairs

    def get_source_indices_for_next_iteration(self) -> set[int]:
        """
        Returns the set of source indices which need to be sampled for the next iteration.
        """
        op_source_idx_pairs = self._get_op_source_idx_pairs()
        return set(map(lambda tup: tup[1], op_source_idx_pairs))

    def get_frontier_op_input_pairs(self, source_indices_to_sample: set[int], max_quality_op: PhysicalOperator) -> list[tuple[PhysicalOperator, DataRecord | int | None]]:
        """
        Returns the list of frontier operators and their next input to process. If there are
        any indices in `source_indices_to_sample` which this operator does not sample on its own, then
        we also have this frontier process that source_idx's input with its max quality operator.
        """
        # get the list of (op, source_idx) pairs which this operator needs to execute
        op_source_idx_pairs = self._get_op_source_idx_pairs()

        # if there are any source_idxs in source_indices_to_sample which are not sampled
        # by this operator, apply the max quality operator (and any other frontier operators
        # with no samples)
        sampled_source_indices = set(map(lambda tup: tup[1], op_source_idx_pairs))
        unsampled_source_indices = source_indices_to_sample - sampled_source_indices
        for source_idx in unsampled_source_indices:
            op_source_idx_pairs.append((max_quality_op, source_idx))
            # also sample any frontier op with zero samples so far (avoid duplicating max_quality_op)
            for op in self.frontier_ops:
                if len(self.phys_op_id_to_sources_processed[op.get_op_id()]) == 0 and op.get_op_id() != max_quality_op.get_op_id():
                    op_source_idx_pairs.append((op, source_idx))

        # fetch the corresponding (op, input) pairs
        # NOTE(review): assumes every source_idx here is present in self.source_idx_to_input,
        #               i.e. update_inputs() has been called for non-scan operators — verify
        op_input_pairs = [
            (op, input)
            for op, source_idx in op_source_idx_pairs
            for input in self.source_idx_to_input[source_idx]
        ]

        return op_input_pairs

    def update_frontier(self, logical_op_id: str, plan_stats: SentinelPlanStats) -> None:
        """
        Update the set of frontier operators, pulling in new ones from the reservoir as needed.
        This function will:
        1. Compute the mean, LCB, and UCB for the cost, time, quality, and selectivity of each frontier operator
        2. Compute the pareto optimal set of frontier operators (using the mean values)
        3. Update the frontier and reservoir sets of operators based on their LCB/UCB overlap with the pareto frontier
        """
        # NOTE: downstream operators may end up re-computing the same record_id with a diff. input as upstream
        # upstream operators change; in this case, we de-duplicate record_op_stats with identical record_ids
        # and keep the one with the maximum quality
        # get a mapping from physical_op_id --> list[RecordOpStats]
        phys_op_id_to_op_stats: dict[str, OperatorStats] = plan_stats.operator_stats.get(logical_op_id, {})
        phys_op_id_to_record_op_stats = {}
        for phys_op_id, op_stats in phys_op_id_to_op_stats.items():
            # skip over operators which have not been sampled
            if len(op_stats.record_op_stats_lst) == 0:
                continue

            # compute mapping from record_id to highest quality record op stats
            record_id_to_max_quality_record_op_stats = {}
            for record_op_stats in op_stats.record_op_stats_lst:
                record_id = record_op_stats.record_id
                if record_id not in record_id_to_max_quality_record_op_stats:  # noqa: SIM114
                    record_id_to_max_quality_record_op_stats[record_id] = record_op_stats

                elif record_op_stats.quality > record_id_to_max_quality_record_op_stats[record_id].quality:
                    record_id_to_max_quality_record_op_stats[record_id] = record_op_stats

            # compute final list of record op stats
            phys_op_id_to_record_op_stats[phys_op_id] = list(record_id_to_max_quality_record_op_stats.values())

        # compute mapping of physical op to num samples and total samples drawn;
        # also update the set of source indices which have been processed by each physical operator
        phys_op_id_to_num_samples, total_num_samples = {}, 0
        for phys_op_id, record_op_stats_lst in phys_op_id_to_record_op_stats.items():
            # update the set of source indices processed
            for record_op_stats in record_op_stats_lst:
                self.phys_op_id_to_sources_processed[phys_op_id].add(record_op_stats.record_source_idx)

            # compute the number of samples as the number of source indices processed
            num_samples = len(self.phys_op_id_to_sources_processed[phys_op_id])
            phys_op_id_to_num_samples[phys_op_id] = num_samples
            total_num_samples += num_samples

        # compute avg. selectivity, cost, time, and quality for each physical operator
        def total_output(record_op_stats_lst):
            # number of records which passed the operator (booleans summed as 0/1)
            return sum([record_op_stats.passed_operator for record_op_stats in record_op_stats_lst])

        def total_input(record_op_stats_lst):
            # number of distinct parent records fed into the operator
            return len(set([record_op_stats.record_parent_id for record_op_stats in record_op_stats_lst]))

        phys_op_to_mean_selectivity = {
            op_id: total_output(record_op_stats_lst) / total_input(record_op_stats_lst)
            for op_id, record_op_stats_lst in phys_op_id_to_record_op_stats.items()
        }
        phys_op_to_mean_cost = {
            op_id: np.mean([record_op_stats.cost_per_record for record_op_stats in record_op_stats_lst])
            for op_id, record_op_stats_lst in phys_op_id_to_record_op_stats.items()
        }
        phys_op_to_mean_time = {
            op_id: np.mean([record_op_stats.time_per_record for record_op_stats in record_op_stats_lst])
            for op_id, record_op_stats_lst in phys_op_id_to_record_op_stats.items()
        }
        phys_op_to_mean_quality = {
            op_id: np.mean([record_op_stats.quality for record_op_stats in record_op_stats_lst])
            for op_id, record_op_stats_lst in phys_op_id_to_record_op_stats.items()
        }

        # compute the exploration (alpha) scale for the confidence bounds; alpha is set to
        # 0.5 * the range (max - min) of each metric's mean across the sampled operators,
        # so bounds widen when operators disagree and collapse when they agree
        cost_alpha = 0.5 * (np.max(list(phys_op_to_mean_cost.values())) - np.min(list(phys_op_to_mean_cost.values())))
        time_alpha = 0.5 * (np.max(list(phys_op_to_mean_time.values())) - np.min(list(phys_op_to_mean_time.values())))
        quality_alpha = 0.5 * (np.max(list(phys_op_to_mean_quality.values())) - np.min(list(phys_op_to_mean_quality.values())))
        selectivity_alpha = 0.5 * (np.max(list(phys_op_to_mean_selectivity.values())) - np.min(list(phys_op_to_mean_selectivity.values())))

        # compute metrics for each physical operator
        op_metrics = {}
        for op_id in phys_op_id_to_record_op_stats:
            # UCB1-style exploration ratio: grows for under-sampled operators
            sample_ratio = np.sqrt(np.log(total_num_samples) / phys_op_id_to_num_samples[op_id])
            exploration_terms = np.array([cost_alpha * sample_ratio, time_alpha * sample_ratio, quality_alpha * sample_ratio, selectivity_alpha * sample_ratio])
            mean_terms = (phys_op_to_mean_cost[op_id], phys_op_to_mean_time[op_id], phys_op_to_mean_quality[op_id], phys_op_to_mean_selectivity[op_id])

            # NOTE: we could clip these; however I will not do so for now to allow for arbitrary quality metric(s)
            lcb_terms = mean_terms - exploration_terms
            ucb_terms = mean_terms + exploration_terms
            op_metrics[op_id] = {"mean": mean_terms, "lcb": lcb_terms, "ucb": ucb_terms}

        # get the tuple representation of this policy
        policy_dict = self.policy.get_dict()

        # compute the pareto optimal set of operators
        pareto_op_set = set()
        for op_id, metrics in op_metrics.items():
            cost, time, quality, selectivity = metrics["mean"]
            pareto_frontier = True

            # check if any other operator dominates op_id
            for other_op_id, other_metrics in op_metrics.items():
                other_cost, other_time, other_quality, other_selectivity = other_metrics["mean"]
                if op_id == other_op_id:
                    continue

                # if op_id is dominated by other_op_id, set pareto_frontier = False and break
                # NOTE: here we use a strict inequality (instead of the usual <= or >=) because
                #       all ops which have equal cost / time / quality / sel. should not be
                #       filtered out from sampling by our logic in this function
                cost_dominated = True if policy_dict["cost"] == 0.0 else other_cost < cost
                time_dominated = True if policy_dict["time"] == 0.0 else other_time < time
                quality_dominated = True if policy_dict["quality"] == 0.0 else other_quality > quality
                selectivity_dominated = True if not self.is_filter_op else other_selectivity < selectivity
                if cost_dominated and time_dominated and quality_dominated and selectivity_dominated:
                    pareto_frontier = False
                    break

            # add op_id to pareto frontier if it's not dominated
            if pareto_frontier:
                pareto_op_set.add(op_id)

        # iterate over op metrics and compute the new frontier set of operators
        new_frontier_op_ids = set()
        for op_id, metrics in op_metrics.items():

            # if this op is fully sampled, do not keep it on the frontier
            if phys_op_id_to_num_samples[op_id] == len(self.source_indices):
                continue

            # if this op is pareto optimal keep it in our frontier ops
            if op_id in pareto_op_set:
                new_frontier_op_ids.add(op_id)
                continue

            # otherwise, if this op overlaps with an op on the pareto frontier, keep it in our frontier ops
            # NOTE: for now, we perform an optimistic comparison with the ucb/lcb
            pareto_frontier = True
            op_cost, op_time, _, op_selectivity = metrics["lcb"]
            op_quality = metrics["ucb"][2]
            for pareto_op_id in pareto_op_set:
                pareto_cost, pareto_time, _, pareto_selectivity = op_metrics[pareto_op_id]["ucb"]
                pareto_quality = op_metrics[pareto_op_id]["lcb"][2]

                # if op_id is dominated by pareto_op_id, set pareto_frontier = False and break
                cost_dominated = True if policy_dict["cost"] == 0.0 else pareto_cost <= op_cost
                time_dominated = True if policy_dict["time"] == 0.0 else pareto_time <= op_time
                quality_dominated = True if policy_dict["quality"] == 0.0 else pareto_quality >= op_quality
                selectivity_dominated = True if not self.is_filter_op else pareto_selectivity <= op_selectivity
                if cost_dominated and time_dominated and quality_dominated and selectivity_dominated:
                    pareto_frontier = False
                    break

            # add op_id to pareto frontier if it's not dominated
            if pareto_frontier:
                new_frontier_op_ids.add(op_id)

        # for operators that were in the frontier, keep them in the frontier if they
        # are still pareto optimal, otherwise, move them to the end of the reservoir
        new_frontier_ops = []
        for op in self.frontier_ops:
            if op.get_op_id() in new_frontier_op_ids:
                new_frontier_ops.append(op)
            else:
                self.off_frontier_ops.append(op)

        # if there are operators we previously sampled which are now back on the frontier
        # add them to the frontier, otherwise, put them back in the off_frontier_ops
        new_off_frontier_ops = []
        for op in self.off_frontier_ops:
            if op.get_op_id() in new_frontier_op_ids:
                new_frontier_ops.append(op)
            else:
                new_off_frontier_ops.append(op)

        # finally, if we have fewer than k operators in the frontier, sample new operators
        # from the reservoir and put them in the frontier
        while len(new_frontier_ops) < self.k and len(self.reservoir_ops) > 0:
            new_op = self.reservoir_ops.pop(0)
            new_frontier_ops.append(new_op)

        # update the frontier and off frontier ops
        self.frontier_ops = new_frontier_ops
        self.off_frontier_ops = new_off_frontier_ops

    def pick_highest_quality_output(self, record_sets: list[DataRecordSet]) -> DataRecordSet:
        """
        Merge a list of record sets (one per sampled operator) into a single record set by
        picking, at each output index, the record with the highest quality score.
        """
        # if there's only one operator in the set, we return its record_set
        if len(record_sets) == 1:
            return record_sets[0]

        # NOTE: I don't like that this assumes the models are consistent in
        #       how they order their record outputs for one-to-many converts;
        #       eventually we can try out more robust schemes to account for
        #       differences in ordering
        # aggregate records at each index in the response
        idx_to_records = {}
        for record_set in record_sets:
            for idx in range(len(record_set)):
                record, record_op_stats = record_set[idx], record_set.record_op_stats[idx]
                if idx not in idx_to_records:
                    idx_to_records[idx] = [(record, record_op_stats)]
                else:
                    idx_to_records[idx].append((record, record_op_stats))

        # compute highest quality answer at each index
        out_records = []
        out_record_op_stats = []
        for idx in range(len(idx_to_records)):
            records_lst, record_op_stats_lst = zip(*idx_to_records[idx])
            max_quality_record, max_quality = records_lst[0], record_op_stats_lst[0].quality
            max_quality_stats = record_op_stats_lst[0]
            for record, record_op_stats in zip(records_lst[1:], record_op_stats_lst[1:]):
                record_quality = record_op_stats.quality
                if record_quality > max_quality:
                    max_quality_record = record
                    max_quality = record_quality
                    max_quality_stats = record_op_stats
            out_records.append(max_quality_record)
            out_record_op_stats.append(max_quality_stats)

        # create and return final DataRecordSet
        return DataRecordSet(out_records, out_record_op_stats)

    def update_inputs(self, source_idx_to_record_sets: dict[int, list[DataRecordSet]]):
        """
        Update the inputs for this logical operator based on the outputs of the previous logical operator.
        """
        for source_idx, record_sets in source_idx_to_record_sets.items():
            # NOTE(review): `input` shadows the builtin of the same name; records which did
            #               not pass the upstream operator are kept as None placeholders so
            #               they can be filtered out before execution
            input = []
            max_quality_record_set = self.pick_highest_quality_output(record_sets)
            for record in max_quality_record_set:
                input.append(record if record.passed_operator else None)

            self.source_idx_to_input[source_idx] = input
371
+
372
+
373
# TODO: post-submission we will need to modify this to:
# - submit all inputs for aggregate operators
# - handle limits
class MABExecutionStrategy(SentinelExecutionStrategy):
    """
    This class implements the Multi-Armed Bandit (MAB) execution strategy for SentinelQueryProcessors.

    NOTE: the number of samples will slightly exceed the sample_budget if the number of operator
    calls does not perfectly match the sample_budget. This may cause some minor discrepancies with
    the progress manager as a result.
    """
    def _get_max_quality_op(self, logical_op_id: str, op_frontiers: dict[str, OpFrontier], plan_stats: SentinelPlanStats) -> PhysicalOperator:
        """
        Returns the operator in the frontier with the highest (estimated) quality.

        NOTE(review): returns None when the frontier is empty; operators with no samples
        default to an average quality of 0.0.
        """
        # get the operators in the frontier set for this logical_op_id
        frontier_ops = op_frontiers[logical_op_id].get_frontier_ops()

        # get a mapping from physical_op_id --> list[RecordOpStats]
        phys_op_id_to_op_stats: dict[str, OperatorStats] = plan_stats.operator_stats.get(logical_op_id, {})
        phys_op_id_to_record_op_stats = {
            phys_op_id: op_stats.record_op_stats_lst
            for phys_op_id, op_stats in phys_op_id_to_op_stats.items()
        }

        # iterate over the frontier ops and return the one with the highest quality
        max_quality_op, max_avg_quality = None, None
        for op in frontier_ops:
            op_quality_stats = []
            if op.get_op_id() in phys_op_id_to_record_op_stats:
                op_quality_stats = [record_op_stats.quality for record_op_stats in phys_op_id_to_record_op_stats[op.get_op_id()]]
            avg_op_quality = sum(op_quality_stats) / len(op_quality_stats) if len(op_quality_stats) > 0 else 0.0
            if max_avg_quality is None or avg_op_quality > max_avg_quality:
                max_quality_op = op
                max_avg_quality = avg_op_quality

        return max_quality_op

    def _execute_sentinel_plan(
        self,
        plan: SentinelPlan,
        op_frontiers: dict[str, OpFrontier],
        expected_outputs: dict[int, dict] | None,
        plan_stats: SentinelPlanStats,
    ) -> SentinelPlanStats:
        """
        Core MAB sampling loop: repeatedly execute frontier operators on sampled records,
        score the outputs, feed champion outputs downstream, and update each logical
        operator's frontier, until the sample budget is exhausted.

        Args:
            plan: the sentinel plan whose (logical_op_id, op_set) pairs are executed in sequence
            op_frontiers: one OpFrontier per logical operator in the plan
            expected_outputs: optional ground-truth outputs used to score quality
            plan_stats: stats object which accumulates per-operator record stats

        Returns:
            the finalized SentinelPlanStats
        """
        # sample records and operators and update the frontiers
        samples_drawn = 0
        while samples_drawn < self.sample_budget:
            # pre-compute the set of source indices which will need to be sampled
            source_indices_to_sample = set()
            for op_frontier in op_frontiers.values():
                source_indices = op_frontier.get_source_indices_for_next_iteration()
                source_indices_to_sample.update(source_indices)

            # execute operator sets in sequence
            for op_idx, (logical_op_id, op_set) in enumerate(plan):
                # use the execution cache to determine the maximum quality operator for this logical_op_id
                max_quality_op = self._get_max_quality_op(logical_op_id, op_frontiers, plan_stats)

                # TODO: can have None as an operator if _get_max_quality_op returns None
                # get frontier ops and their next input; drop None inputs (filtered-out records)
                frontier_op_input_pairs = op_frontiers[logical_op_id].get_frontier_op_input_pairs(source_indices_to_sample, max_quality_op)
                frontier_op_input_pairs = list(filter(lambda tup: tup[1] is not None, frontier_op_input_pairs))

                # break out of the loop if frontier_op_input_pairs is empty, as this means all records have been filtered out
                if len(frontier_op_input_pairs) == 0:
                    break

                # run sampled operators on sampled inputs and update the number of samples drawn
                source_idx_to_record_set_tuples, num_llm_ops = self._execute_op_set(frontier_op_input_pairs)
                samples_drawn += num_llm_ops

                # FUTURE TODO: have this return the highest quality record set simply based on our posterior (or prior) belief on operator quality
                # get the target record set for each source_idx
                source_idx_to_target_record_set = self._get_target_record_sets(logical_op_id, source_idx_to_record_set_tuples, expected_outputs)

                # FUTURE TODO: move this outside of the loop (i.e. assume we only get quality label(s) after executing full program)
                # score the quality of each generated output
                physical_op_cls = op_set[0].__class__
                source_idx_to_all_record_sets = {
                    source_idx: [record_set for record_set, _, _ in record_set_tuples]
                    for source_idx, record_set_tuples in source_idx_to_record_set_tuples.items()
                }
                source_idx_to_all_record_sets = self._score_quality(physical_op_cls, source_idx_to_all_record_sets, source_idx_to_target_record_set)

                # flatten the lists of newly computed records and record_op_stats
                # (record set tuples are (record_set, op, is_new); only new ones enter plan stats)
                source_idx_to_new_record_sets = {
                    source_idx: [record_set for record_set, _, is_new in record_set_tuples if is_new]
                    for source_idx, record_set_tuples in source_idx_to_record_set_tuples.items()
                }
                new_records, new_record_op_stats = self._flatten_record_sets(source_idx_to_new_record_sets)

                # update plan stats
                plan_stats.add_record_op_stats(new_record_op_stats)

                # add records (which are not filtered) to the cache, if allowed
                self._add_records_to_cache(logical_op_id, new_records)

                # FUTURE TODO: simply set input based on source_idx_to_target_record_set (b/c we won't have scores computed)
                # provide the champion record sets as inputs to the next logical operator
                if op_idx + 1 < len(plan):
                    next_logical_op_id = plan.logical_op_ids[op_idx + 1]
                    op_frontiers[next_logical_op_id].update_inputs(source_idx_to_all_record_sets)

                # update the (pareto) frontier for each set of operators
                op_frontiers[logical_op_id].update_frontier(logical_op_id, plan_stats)

        # FUTURE TODO: score op quality based on final outputs

        # close the cache
        self._close_cache(plan.logical_op_ids)

        # finalize plan stats
        plan_stats.finish()

        return plan_stats


    def execute_sentinel_plan(self, plan: SentinelPlan, expected_outputs: dict[int, dict] | None) -> SentinelPlanStats:
        """
        Entry point: set up plan stats, shuffle the validation source indices, build one
        OpFrontier per logical operator, and run the MAB sampling loop with progress
        tracking. Returns the finalized SentinelPlanStats.
        """
        # for now, assert that the first operator in the plan is a ScanPhysicalOp
        assert all(isinstance(op, ScanPhysicalOp) for op in plan.operator_sets[0]), "First operator in physical plan must be a ScanPhysicalOp"
        logger.info(f"Executing plan {plan.plan_id} with {self.max_workers} workers")
        logger.info(f"Plan Details: {plan}")

        # initialize plan stats
        plan_stats = SentinelPlanStats.from_plan(plan)
        plan_stats.start()

        # shuffle the indices of records to sample
        total_num_samples = len(self.val_datasource)
        shuffled_source_indices = [int(idx) for idx in np.arange(total_num_samples)]
        self.rng.shuffle(shuffled_source_indices)

        # initialize frontier for each logical operator
        op_frontiers = {
            logical_op_id: OpFrontier(op_set, shuffled_source_indices, self.k, self.j, self.seed, self.policy)
            for logical_op_id, op_set in plan
        }

        # initialize and start the progress manager
        self.progress_manager = create_progress_manager(plan, sample_budget=self.sample_budget, progress=self.progress)
        self.progress_manager.start()

        # NOTE: we must handle progress manager outside of _execute_sentinel_plan to ensure that it is shut down correctly;
        # if we don't have the `finally:` branch, then program crashes can cause future program runs to fail because
        # the progress manager cannot get a handle to the console
        try:
            # execute sentinel plan by sampling records and operators
            plan_stats = self._execute_sentinel_plan(plan, op_frontiers, expected_outputs, plan_stats)

        finally:
            # finish progress tracking
            self.progress_manager.finish()

        logger.info(f"Done executing sentinel plan: {plan.plan_id}")
        logger.debug(f"Plan stats: (plan_cost={plan_stats.total_plan_cost}, plan_time={plan_stats.total_plan_time})")

        return plan_stats