palimpzest 0.6.3__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff shows the changes between publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (64)
  1. palimpzest/__init__.py +5 -0
  2. palimpzest/constants.py +110 -43
  3. palimpzest/core/__init__.py +0 -78
  4. palimpzest/core/data/dataclasses.py +382 -44
  5. palimpzest/core/elements/filters.py +7 -3
  6. palimpzest/core/elements/index.py +70 -0
  7. palimpzest/core/elements/records.py +33 -11
  8. palimpzest/core/lib/fields.py +1 -0
  9. palimpzest/core/lib/schemas.py +4 -3
  10. palimpzest/prompts/moa_proposer_convert_prompts.py +0 -4
  11. palimpzest/prompts/prompt_factory.py +44 -7
  12. palimpzest/prompts/split_merge_prompts.py +56 -0
  13. palimpzest/prompts/split_proposer_prompts.py +55 -0
  14. palimpzest/query/execution/execution_strategy.py +435 -53
  15. palimpzest/query/execution/execution_strategy_type.py +20 -0
  16. palimpzest/query/execution/mab_execution_strategy.py +532 -0
  17. palimpzest/query/execution/parallel_execution_strategy.py +143 -172
  18. palimpzest/query/execution/random_sampling_execution_strategy.py +240 -0
  19. palimpzest/query/execution/single_threaded_execution_strategy.py +173 -203
  20. palimpzest/query/generators/api_client_factory.py +31 -0
  21. palimpzest/query/generators/generators.py +256 -76
  22. palimpzest/query/operators/__init__.py +1 -2
  23. palimpzest/query/operators/code_synthesis_convert.py +33 -18
  24. palimpzest/query/operators/convert.py +30 -97
  25. palimpzest/query/operators/critique_and_refine_convert.py +5 -6
  26. palimpzest/query/operators/filter.py +7 -10
  27. palimpzest/query/operators/logical.py +54 -10
  28. palimpzest/query/operators/map.py +130 -0
  29. palimpzest/query/operators/mixture_of_agents_convert.py +6 -6
  30. palimpzest/query/operators/physical.py +3 -12
  31. palimpzest/query/operators/rag_convert.py +66 -18
  32. palimpzest/query/operators/retrieve.py +230 -34
  33. palimpzest/query/operators/scan.py +5 -2
  34. palimpzest/query/operators/split_convert.py +169 -0
  35. palimpzest/query/operators/token_reduction_convert.py +8 -14
  36. palimpzest/query/optimizer/__init__.py +4 -16
  37. palimpzest/query/optimizer/cost_model.py +73 -266
  38. palimpzest/query/optimizer/optimizer.py +87 -58
  39. palimpzest/query/optimizer/optimizer_strategy.py +18 -97
  40. palimpzest/query/optimizer/optimizer_strategy_type.py +37 -0
  41. palimpzest/query/optimizer/plan.py +2 -3
  42. palimpzest/query/optimizer/primitives.py +5 -3
  43. palimpzest/query/optimizer/rules.py +336 -172
  44. palimpzest/query/optimizer/tasks.py +30 -100
  45. palimpzest/query/processor/config.py +38 -22
  46. palimpzest/query/processor/nosentinel_processor.py +16 -520
  47. palimpzest/query/processor/processing_strategy_type.py +28 -0
  48. palimpzest/query/processor/query_processor.py +38 -206
  49. palimpzest/query/processor/query_processor_factory.py +117 -130
  50. palimpzest/query/processor/sentinel_processor.py +90 -0
  51. palimpzest/query/processor/streaming_processor.py +25 -32
  52. palimpzest/sets.py +88 -41
  53. palimpzest/utils/model_helpers.py +8 -7
  54. palimpzest/utils/progress.py +368 -152
  55. palimpzest/utils/token_reduction_helpers.py +1 -3
  56. {palimpzest-0.6.3.dist-info → palimpzest-0.7.0.dist-info}/METADATA +28 -24
  57. palimpzest-0.7.0.dist-info/RECORD +96 -0
  58. {palimpzest-0.6.3.dist-info → palimpzest-0.7.0.dist-info}/WHEEL +1 -1
  59. palimpzest/query/processor/mab_sentinel_processor.py +0 -884
  60. palimpzest/query/processor/random_sampling_sentinel_processor.py +0 -639
  61. palimpzest/utils/index_helpers.py +0 -6
  62. palimpzest-0.6.3.dist-info/RECORD +0 -87
  63. {palimpzest-0.6.3.dist-info → palimpzest-0.7.0.dist-info/licenses}/LICENSE +0 -0
  64. {palimpzest-0.6.3.dist-info → palimpzest-0.7.0.dist-info}/top_level.txt +0 -0
palimpzest/query/execution/parallel_execution_strategy.py

@@ -1,18 +1,21 @@
+import logging
 import multiprocessing
-import time
 from concurrent.futures import ThreadPoolExecutor, wait
 
 from palimpzest.constants import PARALLEL_EXECUTION_SLEEP_INTERVAL_SECS
-from palimpzest.core.data.dataclasses import OperatorStats, PlanStats
+from palimpzest.core.data.dataclasses import PlanStats
+from palimpzest.core.elements.records import DataRecord, DataRecordSet
 from palimpzest.query.execution.execution_strategy import ExecutionStrategy
 from palimpzest.query.operators.aggregate import AggregateOp
 from palimpzest.query.operators.limit import LimitScanOp
 from palimpzest.query.operators.physical import PhysicalOperator
 from palimpzest.query.operators.scan import ScanPhysicalOp
 from palimpzest.query.optimizer.plan import PhysicalPlan
+from palimpzest.utils.progress import create_progress_manager
 
+logger = logging.getLogger(__name__)
 
-class PipelinedParallelExecutionStrategy(ExecutionStrategy):
+class ParallelExecutionStrategy(ExecutionStrategy):
     """
     A parallel execution strategy that processes data through a pipeline of operators using thread-based parallelism.
     """
@@ -20,12 +23,12 @@ class PipelinedParallelExecutionStrategy(ExecutionStrategy):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.max_workers = (
-            self.get_parallel_max_workers()
+            self._get_parallel_max_workers()
             if self.max_workers is None
             else self.max_workers
         )
 
-    def get_parallel_max_workers(self):
+    def _get_parallel_max_workers(self):
         # for now, return the number of system CPUs;
         # in the future, we may want to consider the models the user has access to
        # and whether or not they will encounter rate-limits. If they will, we should
@@ -35,180 +38,148 @@ class PipelinedParallelExecutionStrategy(ExecutionStrategy):
         # changing the max_workers in response to 429 errors.
         return max(int(0.8 * multiprocessing.cpu_count()), 1)
 
-    def execute_plan(self, plan: PhysicalPlan, num_samples: int | float = float("inf"), plan_workers: int = 1):
-        """Initialize the stats and the execute the plan."""
-        if self.verbose:
-            print("----------------------")
-            print(f"PLAN[{plan.plan_id}] (n={num_samples}):")
-            print(plan)
-            print("---")
-
-        plan_start_time = time.time()
-
-        # initialize plan stats and operator stats
-        plan_stats = PlanStats(plan_id=plan.plan_id, plan_str=str(plan))
-        for op in plan.operators:
-            op_id = op.get_op_id()
-            op_name = op.op_name()
-            op_details = {k: str(v) for k, v in op.get_id_params().items()}
-            plan_stats.operator_stats[op_id] = OperatorStats(op_id=op_id, op_name=op_name, op_details=op_details)
-
-        # initialize list of output records and intermediate variables
+    def _any_queue_not_empty(self, queues: dict[str, list]) -> bool:
+        """Helper function to check if any queue is not empty."""
+        return any(len(queue) > 0 for queue in queues.values())
+
+    def _upstream_ops_finished(self, plan: PhysicalPlan, op_idx: int, input_queues: dict[str, list], future_queues: dict[str, list]) -> bool:
+        """Helper function to check if all upstream operators have finished processing their inputs."""
+        for upstream_op_idx in range(op_idx):
+            upstream_op_id = plan.operators[upstream_op_idx].get_op_id()
+            if len(input_queues[upstream_op_id]) > 0 or len(future_queues[upstream_op_id]) > 0:
+                return False
+
+        return True
+
+    def _process_future_results(self, operator: PhysicalOperator, future_queues: dict[str, list], plan_stats: PlanStats) -> list[DataRecord]:
+        """
+        Helper function which takes an operator, the future queues, and plan stats, and performs
+        the updates to plan stats and progress manager before returning the results from the finished futures.
+        """
+        # get the op_id for the operator
+        op_id = operator.get_op_id()
+
+        # this function is called when the future queue is not empty
+        # and the executor is not busy processing other futures
+        done_futures, not_done_futures = wait(future_queues[op_id], timeout=PARALLEL_EXECUTION_SLEEP_INTERVAL_SECS)
+
+        # add the unfinished futures back to the previous op's future queue
+        future_queues[op_id] = list(not_done_futures)
+
+        # add the finished futures to the input queue for this operator
+        output_records = []
+        for future in done_futures:
+            record_set: DataRecordSet = future.result()
+            records = record_set.data_records
+            record_op_stats = record_set.record_op_stats
+            num_outputs = sum(record.passed_operator for record in records)
+
+            # update the progress manager
+            self.progress_manager.incr(op_id, num_outputs=num_outputs, total_cost=record_set.get_total_cost())
+
+            # update plan stats
+            plan_stats.add_record_op_stats(record_op_stats)
+
+            # add records to the cache
+            self._add_records_to_cache(operator.target_cache_id, records)
+
+            # add records which aren't filtered to the output records
+            output_records.extend([record for record in records if record.passed_operator])
+
+        return output_records
+
+    def _execute_plan(
+        self,
+        plan: PhysicalPlan,
+        input_queues: dict[str, list],
+        future_queues: dict[str, list],
+        plan_stats: PlanStats,
+    ) -> tuple[list[DataRecord], PlanStats]:
+        # process all of the input records using a thread pool
         output_records = []
-        source_records_scanned = 0
-
-        # initialize data structures to help w/processing DAG
-        processing_queue = []
-        op_id_to_futures_in_flight = {op.get_op_id(): 0 for op in plan.operators}
-        op_id_to_operator = {op.get_op_id(): op for op in plan.operators}
-        op_id_to_prev_operator = {
-            op.get_op_id(): plan.operators[idx - 1] if idx > 0 else None for idx, op in enumerate(plan.operators)
-        }
-        op_id_to_next_operator = {
-            op.get_op_id(): plan.operators[idx + 1] if idx + 1 < len(plan.operators) else None
-            for idx, op in enumerate(plan.operators)
-        }
-        op_id_to_op_idx = {op.get_op_id(): idx for idx, op in enumerate(plan.operators)}
-
-        # get handle to scan operator and pre-compute its op_id and size
-        source_operator = plan.operators[0]
-        assert isinstance(source_operator, ScanPhysicalOp), "First operator in physical plan must be a ScanPhysicalOp"
-        source_op_id = source_operator.get_op_id()
-        datareader_len = len(source_operator.datareader)
-
-        # get limit of final limit operator (if one exists)
-        final_limit = plan.operators[-1].limit if isinstance(plan.operators[-1], LimitScanOp) else None
-
-        # create thread pool w/max workers
-        futures = []
-        current_scan_idx = self.scan_start_idx
-        with ThreadPoolExecutor(max_workers=plan_workers) as executor:
-            # create initial (set of) future(s) to read first source record;
-            futures.append(executor.submit(PhysicalOperator.execute_op_wrapper, source_operator, current_scan_idx))
-            op_id_to_futures_in_flight[source_op_id] += 1
-            current_scan_idx += 1
-
-            # iterate until we have processed all operators on all records or come to an early stopping condition
-            while len(futures) > 0:
-                # get the set of futures that have (and have not) finished in the last PARALLEL_EXECUTION_SLEEP_INTERVAL_SECS
-                done_futures, not_done_futures = wait(futures, timeout=PARALLEL_EXECUTION_SLEEP_INTERVAL_SECS)
-
-                # cast not_done_futures from a set to a list so we can append to it
-                not_done_futures = list(not_done_futures)
-
-                # process finished futures, creating new ones as needed
-                new_futures = []
-                for future in done_futures:
-                    # get the result
-                    record_set, operator, _ = future.result()
+        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
+            logger.debug(f"Created thread pool with {self.max_workers} workers")
+
+            # execute the plan until either:
+            # 1. all records have been processed, or
+            # 2. the final limit operation has completed (we break out of the loop if this happens)
+            final_op = plan.operators[-1]
+            while self._any_queue_not_empty(input_queues) or self._any_queue_not_empty(future_queues):
+                for op_idx, operator in enumerate(plan.operators):
                     op_id = operator.get_op_id()
 
-                    # decrement future from mapping of futures in-flight
-                    op_id_to_futures_in_flight[op_id] -= 1
-
-                    # update plan stats
-                    prev_operator = op_id_to_prev_operator[op_id]
-                    plan_stats.operator_stats[op_id].add_record_op_stats(
-                        record_set.record_op_stats,
-                        source_op_id=prev_operator.get_op_id() if prev_operator is not None else None,
-                        plan_id=plan.plan_id,
-                    )
-
-                    # process each record output by the future's operator
-                    for record in record_set:
-                        # skip records which are filtered out
-                        if not getattr(record, "passed_operator", True):
-                            continue
-
-                        # add records (which are not filtered) to the cache, if allowed
-                        if not self.nocache:
-                            # self.datadir.append_cache(operator.target_cache_id, record)
-                            pass
-
-                        # add records to processing queue if there is a next_operator; otherwise add to output_records
-                        next_operator = op_id_to_next_operator[op_id]
-                        if next_operator is not None:
-                            processing_queue.append((next_operator, record))
-                        else:
-                            output_records.append(record)
-
-                    # if this operator was a source scan, update the number of source records scanned
-                    if op_id == source_op_id:
-                        source_records_scanned += len(record_set)
-
-                        # scan next record if we can still draw records from source
-                        if source_records_scanned < num_samples and current_scan_idx < datareader_len:
-                            new_futures.append(executor.submit(PhysicalOperator.execute_op_wrapper, source_operator, current_scan_idx))
-                            op_id_to_futures_in_flight[source_op_id] += 1
-                            current_scan_idx += 1
-
-                # check early stopping condition based on final limit
-                if final_limit is not None and len(output_records) >= final_limit:
-                    output_records = output_records[:final_limit]
-                    futures = []
-                    break
-
-                # process all records in the processing queue which are ready to be executed
-                temp_processing_queue = []
-                for operator, candidate in processing_queue:
-                    # if the candidate is not an input to an aggregate, execute it right away
-                    if not isinstance(operator, AggregateOp):
-                        future = executor.submit(PhysicalOperator.execute_op_wrapper, operator, candidate)
-                        new_futures.append(future)
-                        op_id_to_futures_in_flight[operator.get_op_id()] += 1
-
-                    # otherwise, put it back on the queue
-                    else:
-                        temp_processing_queue.append((operator, candidate))
-
-                # any remaining candidates are inputs to aggregate operators; for each aggregate operator
-                # determine if it is ready to execute -- and execute all of its candidates if so
-                processing_queue = []
-                agg_op_ids = set([operator.get_op_id() for operator, _ in temp_processing_queue])
-                for agg_op_id in agg_op_ids:
-                    agg_op_idx = op_id_to_op_idx[agg_op_id]
-
-                    # compute if all upstream operators' processing queues are empty and their in-flight futures are finished
-                    upstream_ops_are_finished = True
-                    for upstream_op_idx in range(agg_op_idx):
-                        upstream_op_id = plan.operators[upstream_op_idx].get_op_id()
-                        upstream_op_id_queue = list(
-                            filter(lambda tup: tup[0].get_op_id() == upstream_op_id, temp_processing_queue)
-                        )
-
-                        upstream_ops_are_finished = (
-                            upstream_ops_are_finished
-                            and len(upstream_op_id_queue) == 0
-                            and op_id_to_futures_in_flight[upstream_op_id] == 0
-                        )
-
-                    # get the subset of candidates for this aggregate operator
-                    candidate_tuples = list(filter(lambda tup: tup[0].get_op_id() == agg_op_id, temp_processing_queue))
-
-                    # execute the operator on the candidates if it's ready
-                    if upstream_ops_are_finished:
-                        operator = op_id_to_operator[agg_op_id]
-                        candidates = list(map(lambda tup: tup[1], candidate_tuples))
-                        future = executor.submit(PhysicalOperator.execute_op_wrapper, operator, candidates)
-                        new_futures.append(future)
-                        op_id_to_futures_in_flight[operator.get_op_id()] += 1
-
-                    # otherwise, add the candidates back to the processing queue
+                    # get any finished futures from the previous operator and add them to the input queue for this operator
+                    if not isinstance(operator, ScanPhysicalOp):
+                        prev_operator = plan.operators[op_idx - 1]
+                        records = self._process_future_results(prev_operator, future_queues, plan_stats)
+                        input_queues[op_id].extend(records)
+
+                    # for the final operator, add any finished futures to the output_records
+                    if operator.get_op_id() == final_op.get_op_id():
+                        records = self._process_future_results(operator, future_queues, plan_stats)
+                        output_records.extend(records)
+
+                    # if this operator does not have enough inputs to execute, then skip it
+                    num_inputs = len(input_queues[op_id])
+                    agg_op_not_ready = isinstance(operator, AggregateOp) and not self._upstream_ops_finished(plan, op_idx, input_queues, future_queues)
+                    if num_inputs == 0 or agg_op_not_ready:
+                        continue

+                    # if this operator is an aggregate, process all the records in the input queue
+                    if isinstance(operator, AggregateOp):
+                        input_records = [input_queues[op_id].pop(0) for _ in range(num_inputs)]
+                        future = executor.submit(operator, input_records)
+                        future_queues[op_id].append(future)
+
                     else:
-                        processing_queue.extend(candidate_tuples)
+                        input_record = input_queues[op_id].pop(0)
+                        future = executor.submit(operator, input_record)
+                        future_queues[op_id].append(future)
 
-                # update list of futures
-                not_done_futures.extend(new_futures)
-                futures = not_done_futures
+                # break out of loop if the final operator is a LimitScanOp and we've reached its limit
+                if isinstance(final_op, LimitScanOp) and len(output_records) == final_op.limit:
+                    break
 
-        # if caching was allowed, close the cache
-        if not self.nocache:
-            for _ in plan.operators:
-                # self.datadir.close_cache(operator.target_cache_id)
-                pass
+        # close the cache
+        self._close_cache([op.target_cache_id for op in plan.operators])
 
         # finalize plan stats
-        total_plan_time = time.time() - plan_start_time
-        plan_stats.finalize(total_plan_time)
+        plan_stats.finish()
+
+        return output_records, plan_stats
+
+
+    def execute_plan(self, plan: PhysicalPlan):
+        """Initialize the stats and execute the plan."""
+        # for now, assert that the first operator in the plan is a ScanPhysicalOp
+        assert isinstance(plan.operators[0], ScanPhysicalOp), "First operator in physical plan must be a ScanPhysicalOp"
+        logger.info(f"Executing plan {plan.plan_id} with {self.max_workers} workers")
+        logger.info(f"Plan Details: {plan}")
+
+        # initialize plan stats
+        plan_stats = PlanStats.from_plan(plan)
+        plan_stats.start()
+
+        # initialize input queues and future queues for each operation
+        input_queues = self._create_input_queues(plan)
+        future_queues = {op.get_op_id(): [] for op in plan.operators}
+
+        # initialize and start the progress manager
+        self.progress_manager = create_progress_manager(plan, num_samples=self.num_samples, progress=self.progress)
+        self.progress_manager.start()
+
+        # NOTE: we must handle progress manager outside of _exeecute_plan to ensure that it is shut down correctly;
+        # if we don't have the `finally:` branch, then program crashes can cause future program runs to fail
+        # because the progress manager cannot get a handle to the console
+        try:
+            # execute plan
+            output_records, plan_stats = self._execute_plan(plan, input_queues, future_queues, plan_stats)
+
+        finally:
+            # finish progress tracking
+            self.progress_manager.finish()
+
+        logger.info(f"Done executing plan: {plan.plan_id}")
+        logger.debug(f"Plan stats: (plan_cost={plan_stats.total_plan_cost}, plan_time={plan_stats.total_plan_time})")
 
         return output_records, plan_stats
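
For orientation: the rewritten ParallelExecutionStrategy replaces the old single futures list with one input queue and one future queue per operator. Each pass over the plan drains finished futures from the previous operator into the next operator's input queue, and aggregates wait until every upstream queue is empty. The standalone sketch below illustrates that queue-per-operator pattern with plain callables standing in for palimpzest's PhysicalOperators; run_pipeline, drain, and poll_secs are invented names for illustration, not part of the package API.

# Minimal sketch of the queue-per-operator pipeline pattern (not palimpzest code).
from concurrent.futures import ThreadPoolExecutor, wait

def run_pipeline(operators, source_records, max_workers=4, poll_secs=0.1):
    # one pending-input queue and one in-flight-future queue per operator
    input_queues = {idx: [] for idx in range(len(operators))}
    future_queues = {idx: [] for idx in range(len(operators))}
    input_queues[0].extend(source_records)
    output_records = []

    def drain(idx):
        # collect results of finished futures; keep unfinished futures queued
        done, not_done = wait(future_queues[idx], timeout=poll_secs)
        future_queues[idx] = list(not_done)
        return [future.result() for future in done]

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # loop until every input queue is drained and every future has completed
        while any(input_queues.values()) or any(future_queues.values()):
            for idx, op in enumerate(operators):
                # route finished outputs of the previous operator into this operator's input queue
                if idx > 0:
                    input_queues[idx].extend(drain(idx - 1))
                # submit one pending input per pass so every operator keeps making progress
                if input_queues[idx]:
                    future_queues[idx].append(executor.submit(op, input_queues[idx].pop(0)))
            # finished outputs of the final operator are the pipeline's results
            output_records.extend(drain(len(operators) - 1))
    return output_records

if __name__ == "__main__":
    ops = [lambda x: x * 2, lambda x: f"out={x}"]
    print(run_pipeline(ops, [1, 2, 3]))  # e.g. ['out=2', 'out=4', 'out=6'] (order may vary)

The real implementation adds what the sketch omits: per-record stats, caching, progress updates, filter semantics (passed_operator), and the gating that keeps AggregateOps from firing before all upstream work is done.
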
palimpzest/query/execution/random_sampling_execution_strategy.py (new file)

@@ -0,0 +1,240 @@
+import logging
+
+import numpy as np
+
+from palimpzest.core.data.dataclasses import SentinelPlanStats
+from palimpzest.core.elements.records import DataRecord, DataRecordSet
+from palimpzest.query.execution.execution_strategy import SentinelExecutionStrategy
+from palimpzest.query.operators.physical import PhysicalOperator
+from palimpzest.query.operators.scan import ScanPhysicalOp
+from palimpzest.query.optimizer.plan import SentinelPlan
+from palimpzest.utils.progress import create_progress_manager
+
+logger = logging.getLogger(__name__)
+
+class OpSet:
+    """
+    This class represents the set of operators which are currently in the frontier for a given logical operator.
+    Each operator in the frontier is an instance of a PhysicalOperator which either:
+
+    1. lies on the Pareto frontier of the set of sampled operators, or
+    2. has been sampled fewer than j times
+    """
+
+    def __init__(self, op_set: list[PhysicalOperator], source_indices: list[int], k: int, j: int, seed: int):
+        # set k and j, which are the initial number of operators in the frontier and the
+        # initial number of records to sample for each frontier operator
+        self.k = min(k, len(op_set))
+        self.j = min(j, len(source_indices))
+
+        # get order in which we will sample physical operators for this logical operator
+        sample_op_indices = self._get_op_index_order(op_set, seed)
+
+        # construct the set of operators
+        self.ops = [op_set[sample_idx] for sample_idx in sample_op_indices[:self.k]]
+
+        # store the order in which we will sample the source records
+        self.source_indices = source_indices
+
+        # set the initial inputs for this logical operator
+        is_scan_op = isinstance(op_set[0], ScanPhysicalOp)
+        self.source_idx_to_input = {source_idx: [source_idx] for source_idx in self.source_indices} if is_scan_op else {}
+
+
+    def _get_op_index_order(self, op_set: list[PhysicalOperator], seed: int) -> list[int]:
+        """
+        Returns a list of indices for the operators in the op_set.
+        """
+        rng = np.random.default_rng(seed=seed)
+        op_indices = np.arange(len(op_set))
+        rng.shuffle(op_indices)
+        return op_indices
+
+    def get_op_input_pairs(self) -> list[PhysicalOperator, DataRecord | int | None]:
+        """
+        Returns the list of frontier operators and their next input to process. If there are
+        any indices in `source_indices_to_sample` which this operator does not sample on its own, then
+        we also have this frontier process that source_idx's input with its max quality operator.
+        """
+        # get the list of (op, source_idx) pairs which this operator needs to execute
+        op_source_idx_pairs = []
+        for op in self.ops:
+            # construct list of inputs by looking up the input for the given source_idx
+            for sample_idx in range(self.j):
+                source_idx = self.source_indices[sample_idx]
+                op_source_idx_pairs.append((op, source_idx))
+
+        # fetch the corresponding (op, input) pairs
+        op_input_pairs = []
+        for op, source_idx in op_source_idx_pairs:
+            op_input_pairs.extend([(op, input_record) for input_record in self.source_idx_to_input[source_idx]])
+
+        return op_input_pairs
+
+    def pick_highest_quality_output(self, record_sets: list[DataRecordSet]) -> DataRecordSet:
+        # if there's only one operator in the set, we return its record_set
+        if len(record_sets) == 1:
+            return record_sets[0]
+
+        # NOTE: I don't like that this assumes the models are consistent in
+        #       how they order their record outputs for one-to-many converts;
+        #       eventually we can try out more robust schemes to account for
+        #       differences in ordering
+        # aggregate records at each index in the response
+        idx_to_records = {}
+        for record_set in record_sets:
+            for idx in range(len(record_set)):
+                record, record_op_stats = record_set[idx], record_set.record_op_stats[idx]
+                if idx not in idx_to_records:
+                    idx_to_records[idx] = [(record, record_op_stats)]
+                else:
+                    idx_to_records[idx].append((record, record_op_stats))
+
+        # compute highest quality answer at each index
+        out_records = []
+        out_record_op_stats = []
+        for idx in range(len(idx_to_records)):
+            records_lst, record_op_stats_lst = zip(*idx_to_records[idx])
+            max_quality_record, max_quality = records_lst[0], record_op_stats_lst[0].quality
+            max_quality_stats = record_op_stats_lst[0]
+            for record, record_op_stats in zip(records_lst[1:], record_op_stats_lst[1:]):
+                record_quality = record_op_stats.quality
+                if record_quality > max_quality:
+                    max_quality_record = record
+                    max_quality = record_quality
+                    max_quality_stats = record_op_stats
+            out_records.append(max_quality_record)
+            out_record_op_stats.append(max_quality_stats)
+
+        # create and return final DataRecordSet
+        return DataRecordSet(out_records, out_record_op_stats)
+
+    def update_inputs(self, source_idx_to_record_sets: dict[int, DataRecordSet]):
+        """
+        Update the inputs for this logical operator based on the outputs of the previous logical operator.
+        """
+        for source_idx, record_sets in source_idx_to_record_sets.items():
+            input = []
+            max_quality_record_set = self.pick_highest_quality_output(record_sets)
+            for record in max_quality_record_set:
+                input.append(record if record.passed_operator else None)
+
+            self.source_idx_to_input[source_idx] = input
+
+
+class RandomSamplingExecutionStrategy(SentinelExecutionStrategy):
+
+    def _get_source_indices(self):
+        """Get the list of source indices which the sentinel plan should execute over."""
+        # create list of all source indices and shuffle it
+        total_num_samples = len(self.val_datasource)
+        source_indices = list(np.arange(total_num_samples))
+        self.rng.shuffle(source_indices)
+
+        # slice the list of source indices to get the first j indices
+        j = min(self.j, len(source_indices))
+        source_indices = source_indices[:j]
+
+        return source_indices
+
+    def _execute_sentinel_plan(self,
+        plan: SentinelPlan,
+        op_sets: dict[str, OpSet],
+        expected_outputs: dict[int, dict] | None,
+        plan_stats: SentinelPlanStats,
+    ) -> SentinelPlanStats:
+        # execute operator sets in sequence
+        for op_idx, (logical_op_id, op_set) in enumerate(plan):
+            # get frontier ops and their next input
+            op_input_pairs = op_sets[logical_op_id].get_op_input_pairs()
+
+            # break out of the loop if op_input_pairs is empty, as this means all records have been filtered out
+            if len(op_input_pairs) == 0:
+                break
+
+            # run sampled operators on sampled inputs
+            source_idx_to_record_sets_and_ops, _ = self._execute_op_set(op_input_pairs)
+
+            # FUTURE TODO: have this return the highest quality record set simply based on our posterior (or prior) belief on operator quality
+            # get the target record set for each source_idx
+            source_idx_to_target_record_set = self._get_target_record_sets(logical_op_id, source_idx_to_record_sets_and_ops, expected_outputs)
+
+            # TODO: make consistent across here and RandomSampling
+            # FUTURE TODO: move this outside of the loop (i.e. assume we only get quality label(s) after executing full program)
+            # score the quality of each generated output
+            physical_op_cls = op_set[0].__class__
+            source_idx_to_record_sets = {
+                source_idx: list(map(lambda tup: tup[0], record_sets_and_ops))
+                for source_idx, record_sets_and_ops in source_idx_to_record_sets_and_ops.items()
+            }
+            source_idx_to_record_sets = self._score_quality(physical_op_cls, source_idx_to_record_sets, source_idx_to_target_record_set)
+
+            # flatten the lists of records and record_op_stats
+            all_records, all_record_op_stats = self._flatten_record_sets(source_idx_to_record_sets)
+
+            # update plan stats
+            plan_stats.add_record_op_stats(all_record_op_stats)
+
+            # add records (which are not filtered) to the cache, if allowed
+            self._add_records_to_cache(logical_op_id, all_records)
+
+            # FUTURE TODO: simply set input based on source_idx_to_target_record_set (b/c we won't have scores computed)
+            # provide the champion record sets as inputs to the next logical operator
+            if op_idx + 1 < len(plan):
+                next_logical_op_id = plan.logical_op_ids[op_idx + 1]
+                op_sets[next_logical_op_id].update_inputs(source_idx_to_record_sets)
+
+        # close the cache
+        self._close_cache(plan.logical_op_ids)
+
+        # finalize plan stats
+        plan_stats.finish()
+
+        return plan_stats
+
+    def execute_sentinel_plan(self, plan: SentinelPlan, expected_outputs: dict[int, dict] | None):
+        """
+        NOTE: this function currently requires us to set k and j properly in order to make
+        comparison in our research against the corresponding sample budget in MAB.
+
+        NOTE: the number of samples will slightly exceed the sample_budget if the number of operator
+        calls does not perfectly match the sample_budget. This may cause some minor discrepancies with
+        the progress manager as a result.
+        """
+        # for now, assert that the first operator in the plan is a ScanPhysicalOp
+        assert all(isinstance(op, ScanPhysicalOp) for op in plan.operator_sets[0]), "First operator in physical plan must be a ScanPhysicalOp"
+        logger.info(f"Executing plan {plan.plan_id} with {self.max_workers} workers")
+        logger.info(f"Plan Details: {plan}")
+
+        # initialize plan stats
+        plan_stats = SentinelPlanStats.from_plan(plan)
+        plan_stats.start()
+
+        # get list of source indices which can be sampled from
+        source_indices = self._get_source_indices()
+
+        # initialize set of physical operators for each logical operator
+        op_sets = {
+            logical_op_id: OpSet(op_set, source_indices, self.k, self.j, self.seed)
+            for logical_op_id, op_set in plan
+        }
+
+        # initialize and start the progress manager
+        self.progress_manager = create_progress_manager(plan, sample_budget=self.sample_budget, progress=self.progress)
+        self.progress_manager.start()
+
+        # NOTE: we must handle progress manager outside of _exeecute_sentinel_plan to ensure that it is shut down correctly;
+        # if we don't have the `finally:` branch, then program crashes can cause future program runs to fail because
+        # the progress manager cannot get a handle to the console
+        try:
+            # execute sentinel plan by sampling records and operators
+            plan_stats = self._execute_sentinel_plan(plan, op_sets, expected_outputs, plan_stats)
+
+        finally:
+            # finish progress tracking
+            self.progress_manager.finish()
+
+        logger.info(f"Done executing sentinel plan: {plan.plan_id}")
+        logger.debug(f"Plan stats: (plan_cost={plan_stats.total_plan_cost}, plan_time={plan_stats.total_plan_time})")
+
+        return plan_stats
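
The core sampling idea in the new file: each OpSet seeds an RNG, keeps the first k shuffled physical operators as its frontier, and runs each of them on the first j shuffled source indices, so every sampled operator is scored on the same sampled records. The standalone sketch below mirrors that seeded (operator, source index) pairing; the function name, operator names, and the use of a single RNG (instead of the per-OpSet seed plus self.rng split above) are invented for illustration and are not part of the package API.

# Minimal sketch of the seeded frontier/record sampling (illustrative only).
import numpy as np

def sample_frontier(op_names, num_sources, k, j, seed):
    rng = np.random.default_rng(seed=seed)

    # shuffle the operators and keep the first k as the frontier (mirrors OpSet._get_op_index_order)
    op_indices = np.arange(len(op_names))
    rng.shuffle(op_indices)
    frontier = [op_names[idx] for idx in op_indices[: min(k, len(op_names))]]

    # shuffle the source indices and keep the first j (mirrors _get_source_indices)
    source_indices = list(np.arange(num_sources))
    rng.shuffle(source_indices)
    sampled_sources = source_indices[: min(j, num_sources)]

    # every frontier operator is executed on every sampled source record
    return [(op, int(src)) for op in frontier for src in sampled_sources]

print(sample_frontier(["op-a", "op-b", "op-c"], num_sources=10, k=2, j=3, seed=42))
# -> 2 * 3 = 6 (operator, source_idx) pairs, deterministic for a fixed seed

After execution, pick_highest_quality_output plays the complementary role: for each source index it keeps, per output position, the record whose record_op_stats.quality is highest across the sampled operators, and feeds those champion records forward as inputs to the next logical operator.
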