palimpzest 0.7.21__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (87)
  1. palimpzest/__init__.py +37 -6
  2. palimpzest/agents/__init__.py +0 -0
  3. palimpzest/agents/compute_agents.py +0 -0
  4. palimpzest/agents/search_agents.py +637 -0
  5. palimpzest/constants.py +259 -197
  6. palimpzest/core/data/context.py +393 -0
  7. palimpzest/core/data/context_manager.py +163 -0
  8. palimpzest/core/data/dataset.py +634 -0
  9. palimpzest/core/data/{datareaders.py → iter_dataset.py} +202 -126
  10. palimpzest/core/elements/groupbysig.py +16 -13
  11. palimpzest/core/elements/records.py +166 -75
  12. palimpzest/core/lib/schemas.py +152 -390
  13. palimpzest/core/{data/dataclasses.py → models.py} +306 -170
  14. palimpzest/policy.py +2 -27
  15. palimpzest/prompts/__init__.py +35 -5
  16. palimpzest/prompts/agent_prompts.py +357 -0
  17. palimpzest/prompts/context_search.py +9 -0
  18. palimpzest/prompts/convert_prompts.py +61 -5
  19. palimpzest/prompts/filter_prompts.py +50 -5
  20. palimpzest/prompts/join_prompts.py +163 -0
  21. palimpzest/prompts/moa_proposer_convert_prompts.py +5 -5
  22. palimpzest/prompts/prompt_factory.py +358 -46
  23. palimpzest/prompts/validator.py +239 -0
  24. palimpzest/query/execution/all_sample_execution_strategy.py +134 -76
  25. palimpzest/query/execution/execution_strategy.py +210 -317
  26. palimpzest/query/execution/execution_strategy_type.py +5 -7
  27. palimpzest/query/execution/mab_execution_strategy.py +249 -136
  28. palimpzest/query/execution/parallel_execution_strategy.py +153 -244
  29. palimpzest/query/execution/single_threaded_execution_strategy.py +107 -64
  30. palimpzest/query/generators/generators.py +157 -330
  31. palimpzest/query/operators/__init__.py +15 -5
  32. palimpzest/query/operators/aggregate.py +50 -33
  33. palimpzest/query/operators/compute.py +201 -0
  34. palimpzest/query/operators/convert.py +27 -21
  35. palimpzest/query/operators/critique_and_refine_convert.py +7 -5
  36. palimpzest/query/operators/distinct.py +62 -0
  37. palimpzest/query/operators/filter.py +22 -13
  38. palimpzest/query/operators/join.py +402 -0
  39. palimpzest/query/operators/limit.py +3 -3
  40. palimpzest/query/operators/logical.py +198 -80
  41. palimpzest/query/operators/mixture_of_agents_convert.py +10 -8
  42. palimpzest/query/operators/physical.py +27 -21
  43. palimpzest/query/operators/project.py +3 -3
  44. palimpzest/query/operators/rag_convert.py +7 -7
  45. palimpzest/query/operators/retrieve.py +9 -9
  46. palimpzest/query/operators/scan.py +81 -42
  47. palimpzest/query/operators/search.py +524 -0
  48. palimpzest/query/operators/split_convert.py +10 -8
  49. palimpzest/query/optimizer/__init__.py +7 -9
  50. palimpzest/query/optimizer/cost_model.py +108 -441
  51. palimpzest/query/optimizer/optimizer.py +123 -181
  52. palimpzest/query/optimizer/optimizer_strategy.py +66 -61
  53. palimpzest/query/optimizer/plan.py +352 -67
  54. palimpzest/query/optimizer/primitives.py +43 -19
  55. palimpzest/query/optimizer/rules.py +484 -646
  56. palimpzest/query/optimizer/tasks.py +127 -58
  57. palimpzest/query/processor/config.py +41 -76
  58. palimpzest/query/processor/query_processor.py +73 -18
  59. palimpzest/query/processor/query_processor_factory.py +46 -38
  60. palimpzest/schemabuilder/schema_builder.py +15 -28
  61. palimpzest/utils/model_helpers.py +27 -77
  62. palimpzest/utils/progress.py +114 -102
  63. palimpzest/validator/__init__.py +0 -0
  64. palimpzest/validator/validator.py +306 -0
  65. {palimpzest-0.7.21.dist-info → palimpzest-0.8.0.dist-info}/METADATA +6 -1
  66. palimpzest-0.8.0.dist-info/RECORD +95 -0
  67. palimpzest/core/lib/fields.py +0 -141
  68. palimpzest/prompts/code_synthesis_prompts.py +0 -28
  69. palimpzest/query/execution/random_sampling_execution_strategy.py +0 -240
  70. palimpzest/query/generators/api_client_factory.py +0 -30
  71. palimpzest/query/operators/code_synthesis_convert.py +0 -488
  72. palimpzest/query/operators/map.py +0 -130
  73. palimpzest/query/processor/nosentinel_processor.py +0 -33
  74. palimpzest/query/processor/processing_strategy_type.py +0 -28
  75. palimpzest/query/processor/sentinel_processor.py +0 -88
  76. palimpzest/query/processor/streaming_processor.py +0 -149
  77. palimpzest/sets.py +0 -405
  78. palimpzest/utils/datareader_helpers.py +0 -61
  79. palimpzest/utils/demo_helpers.py +0 -75
  80. palimpzest/utils/field_helpers.py +0 -69
  81. palimpzest/utils/generation_helpers.py +0 -69
  82. palimpzest/utils/sandbox.py +0 -183
  83. palimpzest-0.7.21.dist-info/RECORD +0 -95
  84. /palimpzest/core/{elements/index.py → data/index_dataset.py} +0 -0
  85. {palimpzest-0.7.21.dist-info → palimpzest-0.8.0.dist-info}/WHEEL +0 -0
  86. {palimpzest-0.7.21.dist-info → palimpzest-0.8.0.dist-info}/licenses/LICENSE +0 -0
  87. {palimpzest-0.7.21.dist-info → palimpzest-0.8.0.dist-info}/top_level.txt +0 -0
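Many of the renames above are module moves rather than behavior changes. As a quick orientation, here is a minimal sketch of the import-path changes implied by the file list; the PlanStats move is confirmed by the first hunk below, while the other entries are module renames only and their exported symbols are not verified here:

```python
# palimpzest 0.7.21:
#   from palimpzest.core.data.dataclasses import PlanStats
# palimpzest 0.8.0 (core/data/dataclasses.py was renamed to core/models.py):
from palimpzest.core.models import PlanStats

# other module renames from the file list (paths only; symbols unverified):
#   palimpzest.core.data.datareaders  -> palimpzest.core.data.iter_dataset
#   palimpzest.core.elements.index    -> palimpzest.core.data.index_dataset
```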
@@ -1,20 +1,22 @@
 import logging
-import multiprocessing
 from concurrent.futures import ThreadPoolExecutor, wait
 
 from palimpzest.constants import PARALLEL_EXECUTION_SLEEP_INTERVAL_SECS
-from palimpzest.core.data.dataclasses import PlanStats
-from palimpzest.core.elements.records import DataRecord, DataRecordSet
+from palimpzest.core.elements.records import DataRecord
+from palimpzest.core.models import PlanStats
 from palimpzest.query.execution.execution_strategy import ExecutionStrategy
 from palimpzest.query.operators.aggregate import AggregateOp
+from palimpzest.query.operators.distinct import DistinctOp
+from palimpzest.query.operators.join import JoinOp
 from palimpzest.query.operators.limit import LimitScanOp
 from palimpzest.query.operators.physical import PhysicalOperator
-from palimpzest.query.operators.scan import ScanPhysicalOp
+from palimpzest.query.operators.scan import ContextScanOp, ScanPhysicalOp
 from palimpzest.query.optimizer.plan import PhysicalPlan
 from palimpzest.utils.progress import create_progress_manager
 
 logger = logging.getLogger(__name__)
 
+
 class ParallelExecutionStrategy(ExecutionStrategy):
     """
     A parallel execution strategy that processes data through a pipeline of operators using thread-based parallelism.
@@ -22,76 +24,72 @@ class ParallelExecutionStrategy(ExecutionStrategy):
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.max_workers = (
-            self._get_parallel_max_workers()
-            if self.max_workers is None
-            else self.max_workers
-        )
-
-    def _get_parallel_max_workers(self):
-        # for now, return the number of system CPUs;
-        # in the future, we may want to consider the models the user has access to
-        # and whether or not they will encounter rate-limits. If they will, we should
-        # set the max workers in a manner that is designed to avoid hitting them.
-        # Doing this "right" may require considering their logical, physical plan,
-        # and tier status with LLM providers. It may also be worth dynamically
-        # changing the max_workers in response to 429 errors.
-        return max(int(0.8 * multiprocessing.cpu_count()), 1)
-
-    def _any_queue_not_empty(self, queues: dict[str, list]) -> bool:
-        """Helper function to check if any queue is not empty."""
-        return any(len(queue) > 0 for queue in queues.values())
-
-    def _upstream_ops_finished(self, plan: PhysicalPlan, op_idx: int, input_queues: dict[str, list], future_queues: dict[str, list]) -> bool:
-        """Helper function to check if all upstream operators have finished processing their inputs."""
-        for upstream_op_idx in range(op_idx):
-            upstream_full_op_id = plan.operators[upstream_op_idx].get_full_op_id()
-            if len(input_queues[upstream_full_op_id]) > 0 or len(future_queues[upstream_full_op_id]) > 0:
-                return False
-
-        return True
 
-    def _process_future_results(self, operator: PhysicalOperator, future_queues: dict[str, list], plan_stats: PlanStats) -> list[DataRecord]:
+    def _any_queue_not_empty(self, queues: dict[str, list] | dict[str, dict[str, list]]) -> bool:
+        """Helper function to check if any queue is not empty."""
+        for _, value in queues.items():
+            if isinstance(value, dict):
+                if any(len(subqueue) > 0 for subqueue in value.values()):
+                    return True
+            elif len(value) > 0:
+                return True
+        return False
+
+    def _upstream_ops_finished(self, plan: PhysicalPlan, topo_idx: int, operator: PhysicalOperator, input_queues: dict[str, dict[str, list]], future_queues: dict[str, list]) -> bool:
+        """Helper function to check if agg / join operator is ready to process its inputs."""
+        # for agg / join operator, we can only process it when all upstream operators have finished processing their inputs
+        upstream_unique_full_op_ids = plan.get_upstream_unique_full_op_ids(topo_idx, operator)
+        upstream_input_queues = {upstream_unique_full_op_id: input_queues[upstream_unique_full_op_id] for upstream_unique_full_op_id in upstream_unique_full_op_ids}
+        upstream_future_queues = {upstream_unique_full_op_id: future_queues[upstream_unique_full_op_id] for upstream_unique_full_op_id in upstream_unique_full_op_ids}
+        return not (self._any_queue_not_empty(upstream_input_queues) or self._any_queue_not_empty(upstream_future_queues))
+
+    def _process_future_results(self, unique_full_op_id: str, future_queues: dict[str, list], plan_stats: PlanStats) -> list[DataRecord]:
         """
-        Helper function which takes an operator, the future queues, and plan stats, and performs
+        Helper function which takes a full operator id, the future queues, and plan stats, and performs
         the updates to plan stats and progress manager before returning the results from the finished futures.
         """
-        # get the op_id for the operator
-        full_op_id = operator.get_full_op_id()
-
         # this function is called when the future queue is not empty
         # and the executor is not busy processing other futures
-        done_futures, not_done_futures = wait(future_queues[full_op_id], timeout=PARALLEL_EXECUTION_SLEEP_INTERVAL_SECS)
+        done_futures, not_done_futures = wait(future_queues[unique_full_op_id], timeout=PARALLEL_EXECUTION_SLEEP_INTERVAL_SECS)
 
         # add the unfinished futures back to the previous op's future queue
-        future_queues[full_op_id] = list(not_done_futures)
+        future_queues[unique_full_op_id] = list(not_done_futures)
 
         # add the finished futures to the input queue for this operator
-        output_records = []
+        output_records, total_inputs_processed, total_cost = [], 0, 0.0
         for future in done_futures:
-            record_set: DataRecordSet = future.result()
+            output = future.result()
+            record_set, num_inputs_processed = output if self.is_join_op[unique_full_op_id] else (output, 1)
+
+            # record set can be None if one side of join has no input records yet
+            if record_set is None:
+                continue
+
+            # otherwise, process records and their stats
             records = record_set.data_records
             record_op_stats = record_set.record_op_stats
-            num_outputs = sum(record.passed_operator for record in records)
 
-            # update the progress manager
-            self.progress_manager.incr(full_op_id, num_outputs=num_outputs, total_cost=record_set.get_total_cost())
+            # update the inputs processed and total cost
+            total_inputs_processed += num_inputs_processed
+            total_cost += record_set.get_total_cost()
 
             # update plan stats
-            plan_stats.add_record_op_stats(record_op_stats)
-
-            # add records to the cache
-            self._add_records_to_cache(operator.target_cache_id, records)
+            plan_stats.add_record_op_stats(unique_full_op_id, record_op_stats)
 
             # add records which aren't filtered to the output records
             output_records.extend([record for record in records if record.passed_operator])
-
+
+        # update the progress manager
+        if total_inputs_processed > 0:
+            num_outputs = len(output_records)
+            self.progress_manager.incr(unique_full_op_id, num_inputs=total_inputs_processed, num_outputs=num_outputs, total_cost=total_cost)
+
         return output_records
 
     def _execute_plan(
         self,
         plan: PhysicalPlan,
-        input_queues: dict[str, list],
+        input_queues: dict[str, dict[str, list]],
         future_queues: dict[str, list],
         plan_stats: PlanStats,
     ) -> tuple[list[DataRecord], PlanStats]:
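The hunk above replaces the flat per-operator queues with input queues keyed first by operator and then by source operator, so the emptiness check must handle both shapes. A standalone sketch of that check, mirroring the `_any_queue_not_empty` helper added above (the queue keys in the usage example are illustrative, not real operator ids):

```python
def any_queue_not_empty(queues: dict[str, list] | dict[str, dict[str, list]]) -> bool:
    """Return True if any queue (flat or nested one level) still holds items."""
    for value in queues.values():
        if isinstance(value, dict):
            # nested input queues: {op_id: {source_op_id: [records, ...]}}
            if any(len(subqueue) > 0 for subqueue in value.values()):
                return True
        elif len(value) > 0:
            # flat future queues: {op_id: [futures, ...]}
            return True
    return False

# flat future queues vs. nested input queues
future_queues = {"0-scan": [], "1-filter": ["pending-future"]}
input_queues = {"1-filter": {"0-scan": []}}
assert any_queue_not_empty(future_queues) is True
assert any_queue_not_empty(input_queues) is False
```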
@@ -103,56 +101,119 @@ class ParallelExecutionStrategy(ExecutionStrategy):
             # execute the plan until either:
             # 1. all records have been processed, or
             # 2. the final limit operation has completed (we break out of the loop if this happens)
-            final_op = plan.operators[-1]
+            final_op = plan.operator
             while self._any_queue_not_empty(input_queues) or self._any_queue_not_empty(future_queues):
-                for op_idx, operator in enumerate(plan.operators):
-                    full_op_id = operator.get_full_op_id()
+                for topo_idx, operator in enumerate(plan):
+                    source_unique_full_op_ids = (
+                        [f"source_{operator.get_full_op_id()}"]
+                        if isinstance(operator, (ContextScanOp, ScanPhysicalOp))
+                        else plan.get_source_unique_full_op_ids(topo_idx, operator)
+                    )
+                    unique_full_op_id = f"{topo_idx}-{operator.get_full_op_id()}"
 
                     # get any finished futures from the previous operator and add them to the input queue for this operator
-                    if not isinstance(operator, ScanPhysicalOp):
-                        prev_operator = plan.operators[op_idx - 1]
-                        records = self._process_future_results(prev_operator, future_queues, plan_stats)
-                        input_queues[full_op_id].extend(records)
+                    if not isinstance(operator, (ContextScanOp, ScanPhysicalOp)):
+                        for source_unique_full_op_id in source_unique_full_op_ids:
+                            records = self._process_future_results(source_unique_full_op_id, future_queues, plan_stats)
+                            input_queues[unique_full_op_id][source_unique_full_op_id].extend(records)
 
                     # for the final operator, add any finished futures to the output_records
-                    if full_op_id == final_op.get_full_op_id():
-                        records = self._process_future_results(operator, future_queues, plan_stats)
+                    if unique_full_op_id == f"{topo_idx}-{final_op.get_full_op_id()}":
+                        records = self._process_future_results(unique_full_op_id, future_queues, plan_stats)
                         output_records.extend(records)
 
                     # if this operator does not have enough inputs to execute, then skip it
-                    num_inputs = len(input_queues[full_op_id])
-                    agg_op_not_ready = isinstance(operator, AggregateOp) and not self._upstream_ops_finished(plan, op_idx, input_queues, future_queues)
-                    if num_inputs == 0 or agg_op_not_ready:
+                    num_inputs = sum(len(inputs) for inputs in input_queues[unique_full_op_id].values())
+                    agg_op_not_ready = isinstance(operator, AggregateOp) and not self._upstream_ops_finished(plan, topo_idx, operator, input_queues, future_queues)
+                    join_op_not_ready = isinstance(operator, JoinOp) and not self.join_has_downstream_limit_op[unique_full_op_id] and not self._upstream_ops_finished(plan, topo_idx, operator, input_queues, future_queues)
+                    if num_inputs == 0 or agg_op_not_ready or join_op_not_ready:
                         continue
 
                     # if this operator is an aggregate, process all the records in the input queue
                     if isinstance(operator, AggregateOp):
-                        input_records = [input_queues[full_op_id].pop(0) for _ in range(num_inputs)]
+                        source_unique_full_op_id = source_unique_full_op_ids[0]
+                        input_records = [input_queues[unique_full_op_id][source_unique_full_op_id].pop(0) for _ in range(num_inputs)]
                         future = executor.submit(operator, input_records)
-                        future_queues[full_op_id].append(future)
-
+                        future_queues[unique_full_op_id].append(future)
+
+                    # if this operator is a join, process all pairs of records from the two input queues
+                    elif isinstance(operator, JoinOp):
+                        left_unique_full_source_op_id = source_unique_full_op_ids[0]
+                        left_num_inputs = len(input_queues[unique_full_op_id][left_unique_full_source_op_id])
+                        left_input_records = [input_queues[unique_full_op_id][left_unique_full_source_op_id].pop(0) for _ in range(left_num_inputs)]
+
+                        right_unique_full_source_op_id = source_unique_full_op_ids[1]
+                        right_num_inputs = len(input_queues[unique_full_op_id][right_unique_full_source_op_id])
+                        right_input_records = [input_queues[unique_full_op_id][right_unique_full_source_op_id].pop(0) for _ in range(right_num_inputs)]
+
+                        # NOTE: it would be nice to use executor for join inputs here; but for now synchronizing may be necessary
+                        # future = executor.submit(operator, left_input_records, right_input_records)
+                        # future_queues[unique_full_op_id].append(future)
+                        record_set, num_inputs_processed = operator(left_input_records, right_input_records)
+                        def no_op(rset, num_inputs_processed):
+                            return rset, num_inputs_processed
+                        future = executor.submit(no_op, record_set, num_inputs_processed)
+                        future_queues[unique_full_op_id].append(future)
+
+                    # if this operator is a limit, process one record at a time
+                    elif isinstance(operator, LimitScanOp):
+                        source_unique_full_op_id = source_unique_full_op_ids[0]
+                        num_records_to_process = min(len(input_queues[unique_full_op_id][source_unique_full_op_id]), operator.limit - len(output_records))
+                        for _ in range(num_records_to_process):
+                            input_record = input_queues[unique_full_op_id][source_unique_full_op_id].pop(0)
+                            future = executor.submit(operator, input_record)
+                            future_queues[unique_full_op_id].append(future)
+
+                        # if this is the final operator, add any finished futures to the output_records
+                        # immediately so that we can break out of the loop if we've reached the limit
+                        if unique_full_op_id == f"{topo_idx}-{final_op.get_full_op_id()}":
+                            records = self._process_future_results(unique_full_op_id, future_queues, plan_stats)
+                            output_records.extend(records)
+
+                    # if this operator is a distinct, process records sequentially
+                    # (distinct is not parallelized because it requires maintaining a set of seen records)
+                    elif isinstance(operator, DistinctOp):
+                        source_unique_full_op_id = source_unique_full_op_ids[0]
+                        input_records = input_queues[unique_full_op_id][source_unique_full_op_id]
+                        for record in input_records:
+                            record_set = operator(record)
+                            def no_op(rset):
+                                return rset
+                            future = executor.submit(no_op, record_set)
+                            future_queues[unique_full_op_id].append(future)
+
+                        # clear the input queue for this operator since we processed all records
+                        input_queues[unique_full_op_id][source_unique_full_op_id].clear()
+
+                    # otherwise, process records according to batch size
                     else:
-                        input_record = input_queues[full_op_id].pop(0)
-                        future = executor.submit(operator, input_record)
-                        future_queues[full_op_id].append(future)
-
+                        source_unique_full_op_id = source_unique_full_op_ids[0]
+                        input_records = input_queues[unique_full_op_id][source_unique_full_op_id]
+                        if self.batch_size is None:
+                            for input_record in input_records:
+                                future = executor.submit(operator, input_record)
+                                future_queues[unique_full_op_id].append(future)
+                            input_queues[unique_full_op_id][source_unique_full_op_id].clear()
+                        else:
+                            batch_size = min(self.batch_size, len(input_records))
+                            batch_input_records = input_records[:batch_size]
+                            for input_record in batch_input_records:
+                                future = executor.submit(operator, input_record)
+                                future_queues[unique_full_op_id].append(future)
+                            input_queues[unique_full_op_id][source_unique_full_op_id] = input_records[batch_size:]
+
+                # TODO: change logic to stop upstream operators once a limit is reached
                 # break out of loop if the final operator is a LimitScanOp and we've reached its limit
                 if isinstance(final_op, LimitScanOp) and len(output_records) == final_op.limit:
                     break
 
-        # close the cache
-        self._close_cache([op.target_cache_id for op in plan.operators])
-
         # finalize plan stats
         plan_stats.finish()
 
         return output_records, plan_stats
 
-
     def execute_plan(self, plan: PhysicalPlan):
         """Initialize the stats and execute the plan."""
-        # for now, assert that the first operator in the plan is a ScanPhysicalOp
-        assert isinstance(plan.operators[0], ScanPhysicalOp), "First operator in physical plan must be a ScanPhysicalOp"
         logger.info(f"Executing plan {plan.plan_id} with {self.max_workers} workers")
         logger.info(f"Plan Details: {plan}")
 
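The dispatch loop above submits most operators record-by-record, but when `self.batch_size` is set it drains only one batch per pass and leaves the remaining records in the input queue for a later iteration. A minimal sketch of that batching pattern, using a stand-in `submit_batch` helper and a toy operator rather than palimpzest's own APIs:

```python
from concurrent.futures import ThreadPoolExecutor

def submit_batch(executor, operator, input_queue, future_queue, batch_size=None):
    """Submit the whole queue (batch_size=None) or just one batch, returning the leftover records."""
    if batch_size is None:
        for record in input_queue:
            future_queue.append(executor.submit(operator, record))
        input_queue.clear()
        return input_queue
    size = min(batch_size, len(input_queue))
    for record in input_queue[:size]:
        future_queue.append(executor.submit(operator, record))
    # caller reassigns the input queue to the leftover slice, as the diff does
    return input_queue[size:]

with ThreadPoolExecutor(max_workers=4) as executor:
    futures = []
    remaining = submit_batch(executor, lambda r: r * 2, [1, 2, 3, 4, 5], futures, batch_size=2)
    print([f.result() for f in futures], remaining)  # [2, 4] [3, 4, 5]
```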
@@ -162,180 +223,28 @@ class ParallelExecutionStrategy(ExecutionStrategy):
 
         # initialize input queues and future queues for each operation
         input_queues = self._create_input_queues(plan)
-        future_queues = {op.get_full_op_id(): [] for op in plan.operators}
-
-        # initialize and start the progress manager
-        self.progress_manager = create_progress_manager(plan, num_samples=self.num_samples, progress=self.progress)
-        self.progress_manager.start()
-
-        # NOTE: we must handle progress manager outside of _exeecute_plan to ensure that it is shut down correctly;
-        # if we don't have the `finally:` branch, then program crashes can cause future program runs to fail
-        # because the progress manager cannot get a handle to the console
-        try:
-            # execute plan
-            output_records, plan_stats = self._execute_plan(plan, input_queues, future_queues, plan_stats)
-
-        finally:
-            # finish progress tracking
-            self.progress_manager.finish()
-
-        logger.info(f"Done executing plan: {plan.plan_id}")
-        logger.debug(f"Plan stats: (plan_cost={plan_stats.total_plan_cost}, plan_time={plan_stats.total_plan_time})")
-
-        return output_records, plan_stats
-
-
-class SequentialParallelExecutionStrategy(ExecutionStrategy):
-    """
-    A parallel execution strategy that processes operators sequentially.
-    """
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.max_workers = (
-            self._get_parallel_max_workers()
-            if self.max_workers is None
-            else self.max_workers
-        )
-
-    def _get_parallel_max_workers(self):
-        # for now, return the number of system CPUs;
-        # in the future, we may want to consider the models the user has access to
-        # and whether or not they will encounter rate-limits. If they will, we should
-        # set the max workers in a manner that is designed to avoid hitting them.
-        # Doing this "right" may require considering their logical, physical plan,
-        # and tier status with LLM providers. It may also be worth dynamically
-        # changing the max_workers in response to 429 errors.
-        return max(int(0.8 * multiprocessing.cpu_count()), 1)
-
-    def _any_queue_not_empty(self, queues: dict[str, list]) -> bool:
-        """Helper function to check if any queue is not empty."""
-        return any(len(queue) > 0 for queue in queues.values())
-
-    def _upstream_ops_finished(self, plan: PhysicalPlan, op_idx: int, input_queues: dict[str, list], future_queues: dict[str, list]) -> bool:
-        """Helper function to check if all upstream operators have finished processing their inputs."""
-        for upstream_op_idx in range(op_idx):
-            upstream_full_op_id = plan.operators[upstream_op_idx].get_full_op_id()
-            if len(input_queues[upstream_full_op_id]) > 0 or len(future_queues[upstream_full_op_id]) > 0:
-                return False
-
-        return True
-
-    def _process_future_results(self, operator: PhysicalOperator, future_queues: dict[str, list], plan_stats: PlanStats) -> list[DataRecord]:
-        """
-        Helper function which takes an operator, the future queues, and plan stats, and performs
-        the updates to plan stats and progress manager before returning the results from the finished futures.
-        """
-        # get the op_id for the operator
-        full_op_id = operator.get_full_op_id()
-
-        # this function is called when the future queue is not empty
-        # and the executor is not busy processing other futures
-        done_futures, not_done_futures = wait(future_queues[full_op_id], timeout=PARALLEL_EXECUTION_SLEEP_INTERVAL_SECS)
-
-        # add the unfinished futures back to the previous op's future queue
-        future_queues[full_op_id] = list(not_done_futures)
-
-        # add the finished futures to the input queue for this operator
-        output_records = []
-        for future in done_futures:
-            record_set: DataRecordSet = future.result()
-            records = record_set.data_records
-            record_op_stats = record_set.record_op_stats
-            num_outputs = sum(record.passed_operator for record in records)
-
-            # update the progress manager
-            self.progress_manager.incr(full_op_id, num_outputs=num_outputs, total_cost=record_set.get_total_cost())
-
-            # update plan stats
-            plan_stats.add_record_op_stats(record_op_stats)
-
-            # add records to the cache
-            self._add_records_to_cache(operator.target_cache_id, records)
-
-            # add records which aren't filtered to the output records
-            output_records.extend([record for record in records if record.passed_operator])
-
-        return output_records
-
-    def _execute_plan(
-        self,
-        plan: PhysicalPlan,
-        input_queues: dict[str, list],
-        future_queues: dict[str, list],
-        plan_stats: PlanStats,
-    ) -> tuple[list[DataRecord], PlanStats]:
-        # process all of the input records using a thread pool
-        output_records = []
-        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
-            logger.debug(f"Created thread pool with {self.max_workers} workers")
-
-            # execute the plan until either:
-            # 1. all records have been processed, or
-            # 2. the final limit operation has completed (we break out of the loop if this happens)
-            final_op = plan.operators[-1]
-            for op_idx, operator in enumerate(plan.operators):
-                full_op_id = operator.get_full_op_id()
-                input_queue = input_queues[full_op_id]
-
-                # if this operator is an aggregate, process all the records in the input queue
-                if isinstance(operator, AggregateOp):
-                    num_inputs = len(input_queue)
-                    input_records = [input_queue.pop(0) for _ in range(num_inputs)]
-                    future = executor.submit(operator, input_records)
-                    future_queues[full_op_id].append(future)
-
-                else:
-                    while len(input_queue) > 0:
-                        input_record = input_queue.pop(0)
-                        future = executor.submit(operator, input_record)
-                        future_queues[full_op_id].append(future)
-
-                # block until all futures for this operator have completed; and add finished futures to next operator's input
-                while len(future_queues[full_op_id]) > 0:
-                    records = self._process_future_results(operator, future_queues, plan_stats)
-
-                    # get any finished futures from the previous operator and add them to the input queue for this operator
-                    if full_op_id != final_op.get_full_op_id():
-                        next_op_id = plan.operators[op_idx + 1].get_full_op_id()
-                        input_queues[next_op_id].extend(records)
-
-                    # for the final operator, add any finished futures to the output_records
-                    else:
-                        output_records.extend(records)
-
-                # break out of loop if the final operator is a LimitScanOp and we've reached its limit
-                if isinstance(final_op, LimitScanOp) and len(output_records) == final_op.limit:
-                    break
-
-        # close the cache
-        self._close_cache([op.target_cache_id for op in plan.operators])
-
-        # finalize plan stats
-        plan_stats.finish()
-
-        return output_records, plan_stats
-
-
-    def execute_plan(self, plan: PhysicalPlan):
-        """Initialize the stats and execute the plan."""
-        # for now, assert that the first operator in the plan is a ScanPhysicalOp
-        assert isinstance(plan.operators[0], ScanPhysicalOp), "First operator in physical plan must be a ScanPhysicalOp"
-        logger.info(f"Executing plan {plan.plan_id} with {self.max_workers} workers")
-        logger.info(f"Plan Details: {plan}")
-
-        # initialize plan stats
-        plan_stats = PlanStats.from_plan(plan)
-        plan_stats.start()
-
-        # initialize input queues and future queues for each operation
-        input_queues = self._create_input_queues(plan)
-        future_queues = {op.get_full_op_id(): [] for op in plan.operators}
+        future_queues = {f"{topo_idx}-{op.get_full_op_id()}": [] for topo_idx, op in enumerate(plan)}
+
+        # precompute which operators are joins and which joins have downstream limit ops
+        self.is_join_op = {f"{topo_idx}-{op.get_full_op_id()}": isinstance(op, JoinOp) for topo_idx, op in enumerate(plan)}
+        self.join_has_downstream_limit_op = {}
+        for topo_idx, op in enumerate(plan):
+            if isinstance(op, JoinOp):
+                unique_full_op_id = f"{topo_idx}-{op.get_full_op_id()}"
+                has_downstream_limit_op = False
+                for inner_topo_idx, op in enumerate(plan):
+                    if inner_topo_idx <= topo_idx:
+                        continue
+                    if isinstance(op, LimitScanOp):
+                        has_downstream_limit_op = True
+                        break
+                self.join_has_downstream_limit_op[unique_full_op_id] = has_downstream_limit_op
 
         # initialize and start the progress manager
         self.progress_manager = create_progress_manager(plan, num_samples=self.num_samples, progress=self.progress)
         self.progress_manager.start()
 
-        # NOTE: we must handle progress manager outside of _exeecute_plan to ensure that it is shut down correctly;
+        # NOTE: we must handle progress manager outside of _execute_plan to ensure that it is shut down correctly;
         # if we don't have the `finally:` branch, then program crashes can cause future program runs to fail
         # because the progress manager cannot get a handle to the console
         try: