palimpzest 0.7.20__py3-none-any.whl → 0.8.0__py3-none-any.whl

Files changed (87)
  1. palimpzest/__init__.py +37 -6
  2. palimpzest/agents/__init__.py +0 -0
  3. palimpzest/agents/compute_agents.py +0 -0
  4. palimpzest/agents/search_agents.py +637 -0
  5. palimpzest/constants.py +259 -197
  6. palimpzest/core/data/context.py +393 -0
  7. palimpzest/core/data/context_manager.py +163 -0
  8. palimpzest/core/data/dataset.py +634 -0
  9. palimpzest/core/data/{datareaders.py → iter_dataset.py} +202 -126
  10. palimpzest/core/elements/groupbysig.py +16 -13
  11. palimpzest/core/elements/records.py +166 -75
  12. palimpzest/core/lib/schemas.py +152 -390
  13. palimpzest/core/{data/dataclasses.py → models.py} +306 -170
  14. palimpzest/policy.py +2 -27
  15. palimpzest/prompts/__init__.py +35 -5
  16. palimpzest/prompts/agent_prompts.py +357 -0
  17. palimpzest/prompts/context_search.py +9 -0
  18. palimpzest/prompts/convert_prompts.py +61 -5
  19. palimpzest/prompts/filter_prompts.py +50 -5
  20. palimpzest/prompts/join_prompts.py +163 -0
  21. palimpzest/prompts/moa_proposer_convert_prompts.py +5 -5
  22. palimpzest/prompts/prompt_factory.py +358 -46
  23. palimpzest/prompts/validator.py +239 -0
  24. palimpzest/query/execution/all_sample_execution_strategy.py +134 -76
  25. palimpzest/query/execution/execution_strategy.py +210 -317
  26. palimpzest/query/execution/execution_strategy_type.py +5 -7
  27. palimpzest/query/execution/mab_execution_strategy.py +249 -136
  28. palimpzest/query/execution/parallel_execution_strategy.py +153 -244
  29. palimpzest/query/execution/single_threaded_execution_strategy.py +107 -64
  30. palimpzest/query/generators/generators.py +157 -330
  31. palimpzest/query/operators/__init__.py +15 -5
  32. palimpzest/query/operators/aggregate.py +50 -33
  33. palimpzest/query/operators/compute.py +201 -0
  34. palimpzest/query/operators/convert.py +27 -21
  35. palimpzest/query/operators/critique_and_refine_convert.py +7 -5
  36. palimpzest/query/operators/distinct.py +62 -0
  37. palimpzest/query/operators/filter.py +22 -13
  38. palimpzest/query/operators/join.py +402 -0
  39. palimpzest/query/operators/limit.py +3 -3
  40. palimpzest/query/operators/logical.py +198 -80
  41. palimpzest/query/operators/mixture_of_agents_convert.py +10 -8
  42. palimpzest/query/operators/physical.py +27 -21
  43. palimpzest/query/operators/project.py +3 -3
  44. palimpzest/query/operators/rag_convert.py +7 -7
  45. palimpzest/query/operators/retrieve.py +9 -9
  46. palimpzest/query/operators/scan.py +81 -42
  47. palimpzest/query/operators/search.py +524 -0
  48. palimpzest/query/operators/split_convert.py +10 -8
  49. palimpzest/query/optimizer/__init__.py +7 -9
  50. palimpzest/query/optimizer/cost_model.py +108 -441
  51. palimpzest/query/optimizer/optimizer.py +123 -181
  52. palimpzest/query/optimizer/optimizer_strategy.py +66 -61
  53. palimpzest/query/optimizer/plan.py +352 -67
  54. palimpzest/query/optimizer/primitives.py +43 -19
  55. palimpzest/query/optimizer/rules.py +484 -646
  56. palimpzest/query/optimizer/tasks.py +127 -58
  57. palimpzest/query/processor/config.py +41 -76
  58. palimpzest/query/processor/query_processor.py +73 -18
  59. palimpzest/query/processor/query_processor_factory.py +46 -38
  60. palimpzest/schemabuilder/schema_builder.py +15 -28
  61. palimpzest/utils/model_helpers.py +27 -77
  62. palimpzest/utils/progress.py +114 -102
  63. palimpzest/validator/__init__.py +0 -0
  64. palimpzest/validator/validator.py +306 -0
  65. {palimpzest-0.7.20.dist-info → palimpzest-0.8.0.dist-info}/METADATA +6 -1
  66. palimpzest-0.8.0.dist-info/RECORD +95 -0
  67. palimpzest/core/lib/fields.py +0 -141
  68. palimpzest/prompts/code_synthesis_prompts.py +0 -28
  69. palimpzest/query/execution/random_sampling_execution_strategy.py +0 -240
  70. palimpzest/query/generators/api_client_factory.py +0 -30
  71. palimpzest/query/operators/code_synthesis_convert.py +0 -488
  72. palimpzest/query/operators/map.py +0 -130
  73. palimpzest/query/processor/nosentinel_processor.py +0 -33
  74. palimpzest/query/processor/processing_strategy_type.py +0 -28
  75. palimpzest/query/processor/sentinel_processor.py +0 -88
  76. palimpzest/query/processor/streaming_processor.py +0 -149
  77. palimpzest/sets.py +0 -405
  78. palimpzest/utils/datareader_helpers.py +0 -61
  79. palimpzest/utils/demo_helpers.py +0 -75
  80. palimpzest/utils/field_helpers.py +0 -69
  81. palimpzest/utils/generation_helpers.py +0 -69
  82. palimpzest/utils/sandbox.py +0 -183
  83. palimpzest-0.7.20.dist-info/RECORD +0 -95
  84. /palimpzest/core/{elements/index.py → data/index_dataset.py} +0 -0
  85. {palimpzest-0.7.20.dist-info → palimpzest-0.8.0.dist-info}/WHEEL +0 -0
  86. {palimpzest-0.7.20.dist-info → palimpzest-0.8.0.dist-info}/licenses/LICENSE +0 -0
  87. {palimpzest-0.7.20.dist-info → palimpzest-0.8.0.dist-info}/top_level.txt +0 -0
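
Two notable moves in this release: palimpzest/core/data/dataclasses.py becomes palimpzest/core/models.py, and palimpzest/core/data/datareaders.py becomes palimpzest/core/data/iter_dataset.py. Code importing the stats dataclasses must update its import path; a minimal migration sketch based on the import change visible in the mab_execution_strategy.py diff below:

    # palimpzest 0.7.20
    from palimpzest.core.data.dataclasses import OperatorStats, SentinelPlanStats

    # palimpzest 0.8.0 (stats models now live in palimpzest.core.models)
    from palimpzest.core.models import OperatorStats, RecordOpStats, SentinelPlanStats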
palimpzest/query/execution/execution_strategy_type.py
@@ -2,11 +2,7 @@ from enum import Enum
 
 from palimpzest.query.execution.all_sample_execution_strategy import AllSamplingExecutionStrategy
 from palimpzest.query.execution.mab_execution_strategy import MABExecutionStrategy
-from palimpzest.query.execution.parallel_execution_strategy import (
-    ParallelExecutionStrategy,
-    SequentialParallelExecutionStrategy,
-)
-from palimpzest.query.execution.random_sampling_execution_strategy import RandomSamplingExecutionStrategy
+from palimpzest.query.execution.parallel_execution_strategy import ParallelExecutionStrategy
 from palimpzest.query.execution.single_threaded_execution_strategy import (
     PipelinedSingleThreadExecutionStrategy,
     SequentialSingleThreadExecutionStrategy,
@@ -18,9 +14,11 @@ class ExecutionStrategyType(Enum):
     SEQUENTIAL = SequentialSingleThreadExecutionStrategy
     PIPELINED = PipelinedSingleThreadExecutionStrategy
     PARALLEL = ParallelExecutionStrategy
-    SEQUENTIAL_PARALLEL = SequentialParallelExecutionStrategy
+
+    def is_fully_parallel(self) -> bool:
+        """Check if the execution strategy executes operators in parallel."""
+        return self == ExecutionStrategyType.PARALLEL
 
 
 class SentinelExecutionStrategyType(Enum):
     MAB = MABExecutionStrategy
-    RANDOM = RandomSamplingExecutionStrategy
     ALL = AllSamplingExecutionStrategy
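
With SEQUENTIAL_PARALLEL and RANDOM removed, code that branched on strategy variants can use the new is_fully_parallel() helper instead. A minimal sketch; the branch body is illustrative, not library code:

    from palimpzest.query.execution.execution_strategy_type import ExecutionStrategyType

    strategy = ExecutionStrategyType.PARALLEL
    if strategy.is_fully_parallel():
        # only the PARALLEL strategy executes operators fully in parallel
        print("sizing a worker pool for parallel operator execution")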
palimpzest/query/execution/mab_execution_strategy.py
@@ -1,19 +1,26 @@
+
 import logging
 
 import numpy as np
 
-from palimpzest.core.data.dataclasses import OperatorStats, SentinelPlanStats
+from palimpzest.core.data.dataset import Dataset
 from palimpzest.core.elements.records import DataRecord, DataRecordSet
+from palimpzest.core.models import OperatorStats, RecordOpStats, SentinelPlanStats
 from palimpzest.policy import Policy
 from palimpzest.query.execution.execution_strategy import SentinelExecutionStrategy
+from palimpzest.query.operators.aggregate import AggregateOp
 from palimpzest.query.operators.filter import FilterOp
+from palimpzest.query.operators.join import JoinOp
 from palimpzest.query.operators.physical import PhysicalOperator
-from palimpzest.query.operators.scan import ScanPhysicalOp
+from palimpzest.query.operators.scan import ContextScanOp, ScanPhysicalOp
 from palimpzest.query.optimizer.plan import SentinelPlan
 from palimpzest.utils.progress import create_progress_manager
+from palimpzest.validator.validator import Validator
 
 logger = logging.getLogger(__name__)
 
+# NOTE: we currently do not support Sentinel Plans with aggregates or limits which are not the final plan operator
+
 class OpFrontier:
     """
     This class represents the set of operators which are currently in the frontier for a given logical operator.
@@ -23,11 +30,24 @@ class OpFrontier:
     2. has been sampled fewer than j times
     """
 
-    def __init__(self, op_set: list[PhysicalOperator], source_indices: list[int], k: int, j: int, seed: int, policy: Policy, priors: dict | None = None):
+    def __init__(
+        self,
+        op_set: list[PhysicalOperator],
+        source_unique_logical_op_ids: list[str],
+        root_dataset_ids: list[str],
+        source_indices: list[tuple],
+        k: int,
+        j: int,
+        seed: int,
+        policy: Policy,
+        priors: dict | None = None,
+    ):
         # set k and j, which are the initial number of operators in the frontier and the
         # initial number of records to sample for each frontier operator
         self.k = min(k, len(op_set))
-        self.j = min(j, len(source_indices))
+        self.j = j
+        self.source_indices = source_indices
+        self.root_dataset_ids = root_dataset_ids
 
         # store the policy that we are optimizing under
         self.policy = policy
@@ -43,18 +63,25 @@ class OpFrontier:
         self.reservoir_ops = [op_set[sample_idx] for sample_idx in sample_op_indices[self.k:]]
         self.off_frontier_ops: list[PhysicalOperator] = []
 
-        # store the order in which we will sample the source records
-        self.source_indices = source_indices
-
-        # keep track of the source ids processed by each physical operator
+        # keep track of the source indices processed by each physical operator
         self.full_op_id_to_sources_processed = {op.get_full_op_id(): set() for op in op_set}
-
-        # set the initial inputs for this logical operator
-        is_scan_op = isinstance(op_set[0], ScanPhysicalOp)
-        self.source_idx_to_input = {source_idx: [source_idx] for source_idx in self.source_indices} if is_scan_op else {}
-
-        # boolean indication of whether this is a logical filter
-        self.is_filter_op = isinstance(op_set[0], FilterOp)
+        self.full_op_id_to_sources_not_processed = {op.get_full_op_id(): source_indices for op in op_set}
+        self.max_inputs = len(source_indices)
+
+        # boolean indication of the type of operator in this OpFrontier
+        sample_op = op_set[0]
+        self.is_scan_op = isinstance(sample_op, (ScanPhysicalOp, ContextScanOp))
+        self.is_filter_op = isinstance(sample_op, FilterOp)
+        self.is_aggregate_op = isinstance(sample_op, AggregateOp)
+        self.is_llm_join = isinstance(sample_op, JoinOp)
+
+        # set the initial inputs for this logical operator; we maintain a mapping from source_unique_logical_op_id --> source_indices --> input;
+        # for each unique source and (tuple of) source indices, we store its output, which is an input to this operator
+        # for scan operators, we use the default name "source" since these operators have no source
+        self.source_indices_to_inputs = {source_unique_logical_op_id: {} for source_unique_logical_op_id in source_unique_logical_op_ids}
+        if self.is_scan_op:
+            self.source_indices_to_inputs["source"] = {source_idx: [int(source_idx.split("-")[-1])] for source_idx in source_indices}
+
 
     def get_frontier_ops(self) -> list[PhysicalOperator]:
         """
@@ -180,71 +207,126 @@ class OpFrontier:
 
         return op_indices
 
-    def _get_op_source_idx_pairs(self) -> list[tuple[PhysicalOperator, int]]:
+    def _get_op_source_indices_pairs(self) -> list[tuple[PhysicalOperator, tuple[str] | None]]:
         """
-        Returns a list of tuples for (op, source_idx) which this operator needs to execute
+        Returns a list of tuples for (op, source_indices) which this operator needs to execute
         in the next iteration.
         """
-        op_source_idx_pairs = []
-        for op in self.frontier_ops:
-            # execute new operators on first j source indices, and previously sampled operators on one additional source_idx
-            num_processed = len(self.full_op_id_to_sources_processed[op.get_full_op_id()])
-            num_new_samples = 1 if num_processed > 0 else self.j
-            num_new_samples = min(num_new_samples, len(self.source_indices) - num_processed)
-            assert num_new_samples >= 0, "Number of new samples must be non-negative"
-
-            # construct list of inputs by looking up the input for the given source_idx
-            samples_added = 0
-            for source_idx in self.source_indices:
-                if source_idx in self.full_op_id_to_sources_processed[op.get_full_op_id()]:
-                    continue
+        op_source_indices_pairs = []
 
-                if samples_added == num_new_samples:
-                    break
+        # if this operator is not being optimized: we don't request inputs, but simply process what we are given / told to (in the case of scans)
+        if not self.is_llm_join and len(self.frontier_ops) == 1:
+            return [(self.frontier_ops[0], None)]
 
-                # construct the (op, source_idx) for this source_idx
-                op_source_idx_pairs.append((op, source_idx))
-                samples_added += 1
-
-        return op_source_idx_pairs
-
-    def get_source_indices_for_next_iteration(self) -> set[int]:
+        # otherwise, sample (operator, source_indices) pairs
+        for op in self.frontier_ops:
+            # execute new operators on first j indices per root dataset, and previously sampled operators on one per root dataset
+            new_operator = self.full_op_id_to_sources_processed[op.get_full_op_id()] == set()
+            samples_per_root_dataset = self.j if new_operator else 1
+            num_root_datasets = len(self.root_dataset_ids)
+            num_samples = samples_per_root_dataset**num_root_datasets
+            samples = self.full_op_id_to_sources_not_processed[op.get_full_op_id()][:num_samples]
+            for source_indices in samples:
+                op_source_indices_pairs.append((op, source_indices))
+
+        return op_source_indices_pairs
+
+    def get_source_indices_for_next_iteration(self) -> set[tuple[str]]:
         """
         Returns the set of source indices which need to be sampled for the next iteration.
         """
-        op_source_idx_pairs = self._get_op_source_idx_pairs()
-        return set(map(lambda tup: tup[1], op_source_idx_pairs))
+        op_source_indices_pairs = self._get_op_source_indices_pairs()
+        return set([source_indices for _, source_indices in op_source_indices_pairs if source_indices is not None])
 
-    def get_frontier_op_input_pairs(self, source_indices_to_sample: set[int], max_quality_op: PhysicalOperator) -> list[PhysicalOperator, DataRecord | int | None]:
+    def get_frontier_op_inputs(self, source_indices_to_sample: set[tuple[str]], max_quality_op: PhysicalOperator) -> list[tuple[PhysicalOperator, tuple[str], list[DataRecord] | list[int] | None]]:
         """
         Returns the list of frontier operators and their next input to process. If there are
         any indices in `source_indices_to_sample` which this operator does not sample on its own, then
-        we also have this frontier process that source_idx's input with its max quality operator.
+        we also have this frontier process those source indices' input with its max quality operator.
         """
-        # get the list of (op, source_idx) pairs which this operator needs to execute
-        op_source_idx_pairs = self._get_op_source_idx_pairs()
-
-        # if there are any source_idxs in source_indices_to_sample which are not sampled
-        # by this operator, apply the max quality operator (and any other frontier operators
-        # with no samples)
-        sampled_source_indices = set(map(lambda tup: tup[1], op_source_idx_pairs))
+        # if this is an aggregate, run on every input
+        if self.is_aggregate_op:
+            # NOTE: we don't keep track of source indices for aggregate (would require computing powerset of all source records);
+            # thus, we cannot currently support optimizing plans w/LLM operators after aggregations
+            op = self.frontier_ops[0]
+            all_inputs = []
+            for _, source_indices_to_inputs in self.source_indices_to_inputs.items():
+                for _, inputs in source_indices_to_inputs.items():
+                    all_inputs.extend(inputs)
+            return [(op, tuple(), all_inputs)]
+
+        # if this is an un-optimized (non-scan, non-join) operator, flatten inputs and run on each one
+        elif not self.is_scan_op and not self.is_llm_join and len(self.frontier_ops) == 1:
+            op_inputs = []
+            op = self.frontier_ops[0]
+            for _, source_indices_to_inputs in self.source_indices_to_inputs.items():
+                for source_indices, inputs in source_indices_to_inputs.items():
+                    for input in inputs:
+                        op_inputs.append((op, source_indices, input))
+            return op_inputs
+
+        ### for optimized operators
+        # get the list of (op, source_indices) pairs which this operator needs to execute
+        op_source_indices_pairs = self._get_op_source_indices_pairs()
+
+        # remove any root datasets which this op frontier does not have access to from the source_indices_to_sample
+        def remove_unavailable_root_datasets(source_indices: str | tuple) -> str | tuple | None:
+            # base case: source_indices is a string
+            if isinstance(source_indices, str):
+                return source_indices if source_indices.split("-")[0] in self.root_dataset_ids else None
+
+            # recursive case: source_indices is a tuple
+            left_indices = source_indices[0]
+            right_indices = source_indices[1]
+            left_filtered = remove_unavailable_root_datasets(left_indices)
+            right_filtered = remove_unavailable_root_datasets(right_indices)
+            if left_filtered is None and right_filtered is None:
+                return None
+
+            if left_filtered is None:
+                return right_filtered
+            if right_filtered is None:
+                return left_filtered
+            return (left_filtered, right_filtered)
+
+        source_indices_to_sample = {remove_unavailable_root_datasets(source_indices) for source_indices in source_indices_to_sample}
+
+        # if there are any source_indices in source_indices_to_sample which are not sampled by this operator,
+        # apply the max quality operator (and any other frontier operators with no samples)
+        sampled_source_indices = set(map(lambda tup: tup[1], op_source_indices_pairs))
         unsampled_source_indices = source_indices_to_sample - sampled_source_indices
-        for source_idx in unsampled_source_indices:
-            op_source_idx_pairs.append((max_quality_op, source_idx))
+        for source_indices in unsampled_source_indices:
+            op_source_indices_pairs.append((max_quality_op, source_indices))
             for op in self.frontier_ops:
-                if len(self.full_op_id_to_sources_processed[op.get_full_op_id()]) == 0 and op.get_full_op_id() != max_quality_op.get_full_op_id():
-                    op_source_idx_pairs.append((op, source_idx))
-
-        # fetch the corresponding (op, input) pairs
-        op_input_pairs = [
-            (op, input)
-            for op, source_idx in op_source_idx_pairs
-            for input in self.source_idx_to_input[source_idx]
+                if self.full_op_id_to_sources_processed[op.get_full_op_id()] == set() and op.get_full_op_id() != max_quality_op.get_full_op_id():
+                    op_source_indices_pairs.append((op, source_indices))
+
+        # construct the op inputs
+        op_inputs = []
+        if self.is_llm_join:
+            left_source_unique_logical_op_id, right_source_unique_logical_op_id = list(self.source_indices_to_inputs)
+            left_source_indices_to_inputs = self.source_indices_to_inputs[left_source_unique_logical_op_id]
+            right_source_indices_to_inputs = self.source_indices_to_inputs[right_source_unique_logical_op_id]
+            for op, source_indices in op_source_indices_pairs:
+                left_source_indices = source_indices[0]
+                right_source_indices = source_indices[1]
+                left_inputs = left_source_indices_to_inputs.get(left_source_indices, [])
+                right_inputs = right_source_indices_to_inputs.get(right_source_indices, [])
+                if len(left_inputs) > 0 and len(right_inputs) > 0:
+                    op_inputs.append((op, (left_source_indices, right_source_indices), (left_inputs, right_inputs)))
+            return op_inputs
+
+        # if operator is not a join
+        source_unique_logical_op_id = list(self.source_indices_to_inputs)[0]
+        op_inputs = [
+            (op, source_indices, input)
+            for op, source_indices in op_source_indices_pairs
+            for input in self.source_indices_to_inputs[source_unique_logical_op_id].get(source_indices, [])
         ]
 
-        return op_input_pairs
+        return op_inputs
 
-    def update_frontier(self, logical_op_id: str, plan_stats: SentinelPlanStats) -> None:
+    def update_frontier(self, unique_logical_op_id: str, plan_stats: SentinelPlanStats) -> None:
         """
         Update the set of frontier operators, pulling in new ones from the reservoir as needed.
         This function will:
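
Join sampling nests source indices into (left, right) tuples, so the filter above recurses over that structure, dropping any leg whose dataset id is not among this frontier's root datasets. A standalone sketch of the same recursion (the helper name and dataset ids are hypothetical):

    def keep_available(source_indices, root_dataset_ids):
        # base case: a single "{dataset_id}-{idx}" string
        if isinstance(source_indices, str):
            return source_indices if source_indices.split("-")[0] in root_dataset_ids else None
        # recursive case: a (left, right) tuple produced by a join
        left = keep_available(source_indices[0], root_dataset_ids)
        right = keep_available(source_indices[1], root_dataset_ids)
        if left is None:
            return right
        if right is None:
            return left
        return (left, right)

    assert keep_available(("a-0", "b-3"), {"a"}) == "a-0"
    assert keep_available(("a-0", "b-3"), {"a", "b"}) == ("a-0", "b-3")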
@@ -256,8 +338,8 @@
         # upstream operators change; in this case, we de-duplicate record_op_stats with identical record_ids
         # and keep the one with the maximum quality
         # get a mapping from full_op_id --> list[RecordOpStats]
-        full_op_id_to_op_stats: dict[str, OperatorStats] = plan_stats.operator_stats.get(logical_op_id, {})
-        full_op_id_to_record_op_stats = {}
+        full_op_id_to_op_stats: dict[str, OperatorStats] = plan_stats.operator_stats.get(unique_logical_op_id, {})
+        full_op_id_to_record_op_stats: dict[str, list[RecordOpStats]] = {}
         for full_op_id, op_stats in full_op_id_to_op_stats.items():
             # skip over operators which have not been sampled
             if len(op_stats.record_op_stats_lst) == 0:
@@ -281,8 +363,23 @@
         full_op_id_to_num_samples, total_num_samples = {}, 0
         for full_op_id, record_op_stats_lst in full_op_id_to_record_op_stats.items():
             # update the set of source indices processed
+            source_indices_processed = set()
             for record_op_stats in record_op_stats_lst:
-                self.full_op_id_to_sources_processed[full_op_id].add(record_op_stats.record_source_idx)
+                source_indices = record_op_stats.record_source_indices
+
+                if len(source_indices) == 1:
+                    source_indices = source_indices[0]
+                elif self.is_llm_join or self.is_aggregate_op:
+                    source_indices = tuple(source_indices)
+
+                self.full_op_id_to_sources_processed[full_op_id].add(source_indices)
+                source_indices_processed.add(source_indices)
+
+            # update the set of source indices not processed
+            self.full_op_id_to_sources_not_processed[full_op_id] = [
+                indices for indices in self.full_op_id_to_sources_not_processed[full_op_id]
+                if indices not in source_indices_processed
+            ]
 
             # compute the number of samples as the number of source indices processed
             num_samples = len(self.full_op_id_to_sources_processed[full_op_id])
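
Each RecordOpStats now carries a list of source indices (record_source_indices replaces the scalar record_source_idx); the bookkeeping above normalizes it into a hashable key, a bare string for single-source records and a tuple for joins and aggregates. Illustratively, with hypothetical values:

    source_indices = ["books-2"]                  # single-source record
    key = source_indices[0]                       # "books-2"

    source_indices = ["books-2", "authors-5"]     # e.g. a record produced by an LLM join
    key = tuple(source_indices)                   # ("books-2", "authors-5")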
@@ -290,11 +387,20 @@
             total_num_samples += num_samples
 
         # compute avg. selectivity, cost, time, and quality for each physical operator
-        def total_output(record_op_stats_lst):
+        def total_output(record_op_stats_lst: list[RecordOpStats]):
             return sum([record_op_stats.passed_operator for record_op_stats in record_op_stats_lst])
 
-        def total_input(record_op_stats_lst):
-            return len(set([record_op_stats.record_parent_id for record_op_stats in record_op_stats_lst]))
+        def total_input(record_op_stats_lst: list[RecordOpStats]):
+            # TODO: this is okay for now because we only really need these calculations for Converts and Filters,
+            # but this will need more thought if/when we optimize joins
+            all_parent_ids = []
+            for record_op_stats in record_op_stats_lst:
+                all_parent_ids.extend(
+                    [None]
+                    if record_op_stats.record_parent_ids is None
+                    else record_op_stats.record_parent_ids
+                )
+            return len(set(all_parent_ids))
 
         full_op_id_to_mean_selectivity = {
             full_op_id: total_output(record_op_stats_lst) / total_input(record_op_stats_lst)
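
Under these definitions, total_output counts records that passed the operator and total_input counts distinct parent ids (a missing parent list contributes a single None), so the mean selectivity can exceed 1 for one-to-many converts. A hypothetical worked example:

    # a convert fans one parent record out into three outputs, all of which pass
    passed = [True, True, True]
    parent_ids = ["r1", "r1", "r1"]
    selectivity = sum(passed) / len(set(parent_ids))  # 3 / 1 = 3.0 (fan-out)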
@@ -309,7 +415,7 @@
             for full_op_id, record_op_stats_lst in full_op_id_to_record_op_stats.items()
         }
         full_op_id_to_mean_quality = {
-            full_op_id: np.mean([record_op_stats.quality for record_op_stats in record_op_stats_lst])
+            full_op_id: np.mean([record_op_stats.quality for record_op_stats in record_op_stats_lst if record_op_stats.quality is not None])
             for full_op_id, record_op_stats_lst in full_op_id_to_record_op_stats.items()
         }
 
@@ -373,7 +479,7 @@ class OpFrontier:
         for full_op_id, metrics in op_metrics.items():
 
             # if this op is fully sampled, do not keep it on the frontier
-            if full_op_id_to_num_samples[full_op_id] == len(self.source_indices):
+            if len(self.full_op_id_to_sources_processed[full_op_id]) == self.max_inputs:
                 continue
 
             # if this op is pareto optimal keep it in our frontier ops
@@ -455,10 +561,10 @@ class OpFrontier:
         out_record_op_stats = []
         for idx in range(len(idx_to_records)):
             records_lst, record_op_stats_lst = zip(*idx_to_records[idx])
-            max_quality_record, max_quality = records_lst[0], record_op_stats_lst[0].quality
+            max_quality_record, max_quality = records_lst[0], record_op_stats_lst[0].quality if record_op_stats_lst[0].quality is not None else 0.0
            max_quality_stats = record_op_stats_lst[0]
             for record, record_op_stats in zip(records_lst[1:], record_op_stats_lst[1:]):
-                record_quality = record_op_stats.quality
+                record_quality = record_op_stats.quality if record_op_stats.quality is not None else 0.0
                 if record_quality > max_quality:
                     max_quality_record = record
                     max_quality = record_quality
@@ -469,22 +575,19 @@ class OpFrontier:
         # create and return final DataRecordSet
         return DataRecordSet(out_records, out_record_op_stats)
 
-    def update_inputs(self, source_idx_to_record_sets: dict[int, DataRecordSet]):
+    def update_inputs(self, source_unique_logical_op_id: str, source_indices_to_record_sets: dict[tuple[int], list[DataRecordSet]]):
         """
         Update the inputs for this logical operator based on the outputs of the previous logical operator.
         """
-        for source_idx, record_sets in source_idx_to_record_sets.items():
+        for source_indices, record_sets in source_indices_to_record_sets.items():
             input = []
             max_quality_record_set = self.pick_highest_quality_output(record_sets)
             for record in max_quality_record_set:
                 input.append(record if record.passed_operator else None)
 
-            self.source_idx_to_input[source_idx] = input
+            self.source_indices_to_inputs[source_unique_logical_op_id][source_indices] = input
 
 
-# TODO: post-submission we will need to modify this to:
-# - submit all inputs for aggregate operators
-# - handle limits
 class MABExecutionStrategy(SentinelExecutionStrategy):
     """
     This class implements the Multi-Armed Bandit (MAB) execution strategy for SentinelQueryProcessors.
@@ -493,15 +596,15 @@ class MABExecutionStrategy(SentinelExecutionStrategy):
     calls does not perfectly match the sample_budget. This may cause some minor discrepancies with
     the progress manager as a result.
     """
-    def _get_max_quality_op(self, logical_op_id: str, op_frontiers: dict[str, OpFrontier], plan_stats: SentinelPlanStats) -> PhysicalOperator:
+    def _get_max_quality_op(self, unique_logical_op_id: str, op_frontiers: dict[str, OpFrontier], plan_stats: SentinelPlanStats) -> PhysicalOperator:
         """
         Returns the operator in the frontier with the highest (estimated) quality.
         """
         # get the operators in the frontier set for this logical_op_id
-        frontier_ops = op_frontiers[logical_op_id].get_frontier_ops()
+        frontier_ops = op_frontiers[unique_logical_op_id].get_frontier_ops()
 
         # get a mapping from full_op_id --> list[RecordOpStats]
-        full_op_id_to_op_stats: dict[str, OperatorStats] = plan_stats.operator_stats.get(logical_op_id, {})
+        full_op_id_to_op_stats: dict[str, OperatorStats] = plan_stats.operator_stats.get(unique_logical_op_id, {})
         full_op_id_to_record_op_stats = {
             full_op_id: op_stats.record_op_stats_lst
             for full_op_id, op_stats in full_op_id_to_op_stats.items()
@@ -524,7 +627,7 @@
         self,
         plan: SentinelPlan,
         op_frontiers: dict[str, OpFrontier],
-        expected_outputs: dict[int, dict] | None,
+        validator: Validator,
         plan_stats: SentinelPlanStats,
     ) -> SentinelPlanStats:
         # sample records and operators and update the frontiers
@@ -537,64 +640,54 @@ class MABExecutionStrategy(SentinelExecutionStrategy):
             source_indices_to_sample.update(source_indices)
 
         # execute operator sets in sequence
-        for op_idx, (logical_op_id, op_set) in enumerate(plan):
+        for topo_idx, (logical_op_id, _) in enumerate(plan):
+            # compute unique logical op id within plan
+            unique_logical_op_id = f"{topo_idx}-{logical_op_id}"
+
             # use the execution cache to determine the maximum quality operator for this logical_op_id
-            max_quality_op = self._get_max_quality_op(logical_op_id, op_frontiers, plan_stats)
+            max_quality_op = self._get_max_quality_op(unique_logical_op_id, op_frontiers, plan_stats)
 
-            # TODO: can have None as an operator if _get_max_quality_op returns None
             # get frontier ops and their next input
-            frontier_op_input_pairs = op_frontiers[logical_op_id].get_frontier_op_input_pairs(source_indices_to_sample, max_quality_op)
-            frontier_op_input_pairs = list(filter(lambda tup: tup[1] is not None, frontier_op_input_pairs))
+            frontier_op_inputs = op_frontiers[unique_logical_op_id].get_frontier_op_inputs(source_indices_to_sample, max_quality_op)
+            frontier_op_inputs = list(filter(lambda tup: tup[-1] is not None, frontier_op_inputs))
 
-            # break out of the loop if frontier_op_input_pairs is empty, as this means all records have been filtered out
-            if len(frontier_op_input_pairs) == 0:
+            # break out of the loop if frontier_op_inputs is empty, as this means all records have been filtered out
+            if len(frontier_op_inputs) == 0:
                 break
 
             # run sampled operators on sampled inputs and update the number of samples drawn
-            source_idx_to_record_set_tuples, num_llm_ops = self._execute_op_set(frontier_op_input_pairs)
+            source_indices_to_record_set_tuples, num_llm_ops = self._execute_op_set(unique_logical_op_id, frontier_op_inputs)
             samples_drawn += num_llm_ops
 
-            # FUTURE TODO: have this return the highest quality record set simply based on our posterior (or prior) belief on operator quality
-            # get the target record set for each source_idx
-            source_idx_to_target_record_set = self._get_target_record_sets(logical_op_id, source_idx_to_record_set_tuples, expected_outputs)
-
-            # FUTURE TODO: move this outside of the loop (i.e. assume we only get quality label(s) after executing full program)
             # score the quality of each generated output
-            physical_op_cls = op_set[0].__class__
-            source_idx_to_all_record_sets = {
-                source_idx: [record_set for record_set, _, _ in record_set_tuples]
-                for source_idx, record_set_tuples in source_idx_to_record_set_tuples.items()
+            source_indices_to_all_record_sets = {
+                source_indices: [(record_set, op) for record_set, op, _ in record_set_tuples]
+                for source_indices, record_set_tuples in source_indices_to_record_set_tuples.items()
             }
-            source_idx_to_all_record_sets = self._score_quality(physical_op_cls, source_idx_to_all_record_sets, source_idx_to_target_record_set)
+            source_indices_to_all_record_sets, val_gen_stats = self._score_quality(validator, source_indices_to_all_record_sets)
 
-            # flatten the lists of newly computed records and record_op_stats
-            source_idx_to_new_record_sets = {
-                source_idx: [record_set for record_set, _, is_new in record_set_tuples if is_new]
-                for source_idx, record_set_tuples in source_idx_to_record_set_tuples.items()
-            }
-            new_records, new_record_op_stats = self._flatten_record_sets(source_idx_to_new_record_sets)
-
-            # update the number of samples drawn for each operator
+            # remove records that were read from the execution cache before adding to record op stats
+            new_record_op_stats = []
+            for _, record_set_tuples in source_indices_to_record_set_tuples.items():
+                for record_set, _, is_new in record_set_tuples:
+                    if is_new:
+                        new_record_op_stats.extend(record_set.record_op_stats)
 
             # update plan stats
-            plan_stats.add_record_op_stats(new_record_op_stats)
-
-            # add records (which are not filtered) to the cache, if allowed
-            self._add_records_to_cache(logical_op_id, new_records)
-
-            # FUTURE TODO: simply set input based on source_idx_to_target_record_set (b/c we won't have scores computed)
-            # provide the champion record sets as inputs to the next logical operator
-            if op_idx + 1 < len(plan):
-                next_logical_op_id = plan.logical_op_ids[op_idx + 1]
-                op_frontiers[next_logical_op_id].update_inputs(source_idx_to_all_record_sets)
+            plan_stats.add_record_op_stats(unique_logical_op_id, new_record_op_stats)
+            plan_stats.add_validation_gen_stats(unique_logical_op_id, val_gen_stats)
+
+            # provide the best record sets as inputs to the next logical operator
+            next_unique_logical_op_id = plan.get_next_unique_logical_op_id(unique_logical_op_id)
+            if next_unique_logical_op_id is not None:
+                source_indices_to_all_record_sets = {
+                    source_indices: [record_set for record_set, _ in record_set_tuples]
+                    for source_indices, record_set_tuples in source_indices_to_all_record_sets.items()
+                }
+                op_frontiers[next_unique_logical_op_id].update_inputs(unique_logical_op_id, source_indices_to_all_record_sets)
 
             # update the (pareto) frontier for each set of operators
-            op_frontiers[logical_op_id].update_frontier(logical_op_id, plan_stats)
-
-            # FUTURE TODO: score op quality based on final outputs
-
-        # close the cache
-        self._close_cache(plan.logical_op_ids)
+            op_frontiers[unique_logical_op_id].update_frontier(unique_logical_op_id, plan_stats)
 
         # finalize plan stats
         plan_stats.finish()
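
Plan statistics are now keyed by unique_logical_op_id, which prefixes the logical operator id with its topological index so that the same logical operator appearing twice in one plan gets distinct keys. Illustratively, with hypothetical operator ids:

    logical_op_ids = ["scan-books", "scan-books", "join-llm"]
    unique_ids = [f"{topo_idx}-{op_id}" for topo_idx, op_id in enumerate(logical_op_ids)]
    # ["0-scan-books", "1-scan-books", "2-join-llm"]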
@@ -602,9 +695,7 @@
         return plan_stats
 
 
-    def execute_sentinel_plan(self, plan: SentinelPlan, expected_outputs: dict[int, dict] | None):
-        # for now, assert that the first operator in the plan is a ScanPhysicalOp
-        assert all(isinstance(op, ScanPhysicalOp) for op in plan.operator_sets[0]), "First operator in physical plan must be a ScanPhysicalOp"
+    def execute_sentinel_plan(self, plan: SentinelPlan, train_dataset: dict[str, Dataset], validator: Validator) -> SentinelPlanStats:
         logger.info(f"Executing plan {plan.plan_id} with {self.max_workers} workers")
         logger.info(f"Plan Details: {plan}")
 
@@ -613,26 +704,48 @@
         plan_stats.start()
 
         # shuffle the indices of records to sample
-        total_num_samples = len(self.val_datasource)
-        shuffled_source_indices = [int(idx) for idx in np.arange(total_num_samples)]
-        self.rng.shuffle(shuffled_source_indices)
+        dataset_id_to_shuffled_source_indices = {}
+        for dataset_id, dataset in train_dataset.items():
+            total_num_samples = len(dataset)
+            shuffled_source_indices = [f"{dataset_id}-{int(idx)}" for idx in np.arange(total_num_samples)]
+            self.rng.shuffle(shuffled_source_indices)
+            dataset_id_to_shuffled_source_indices[dataset_id] = shuffled_source_indices
 
         # initialize frontier for each logical operator
-        op_frontiers = {
-            logical_op_id: OpFrontier(op_set, shuffled_source_indices, self.k, self.j, self.seed, self.policy, self.priors)
-            for logical_op_id, op_set in plan
-        }
+        op_frontiers = {}
+        for topo_idx, (logical_op_id, op_set) in enumerate(plan):
+            unique_logical_op_id = f"{topo_idx}-{logical_op_id}"
+            source_unique_logical_op_ids = plan.get_source_unique_logical_op_ids(unique_logical_op_id)
+            root_dataset_ids = plan.get_root_dataset_ids(unique_logical_op_id)
+            sample_op = op_set[0]
+            if isinstance(sample_op, (ScanPhysicalOp, ContextScanOp)):
+                assert len(root_dataset_ids) == 1, f"Scan for {sample_op} has {len(root_dataset_ids)} > 1 root dataset ids"
+                root_dataset_id = root_dataset_ids[0]
+                source_indices = dataset_id_to_shuffled_source_indices[root_dataset_id]
+                op_frontiers[unique_logical_op_id] = OpFrontier(op_set, source_unique_logical_op_ids, root_dataset_ids, source_indices, self.k, self.j, self.seed, self.policy, self.priors)
+            elif isinstance(sample_op, JoinOp):
+                assert len(source_unique_logical_op_ids) == 2, f"Join for {sample_op} has {len(source_unique_logical_op_ids)} != 2 source logical operators"
+                left_source_indices = op_frontiers[source_unique_logical_op_ids[0]].source_indices
+                right_source_indices = op_frontiers[source_unique_logical_op_ids[1]].source_indices
+                source_indices = []
+                for left_source_idx in left_source_indices:
+                    for right_source_idx in right_source_indices:
+                        source_indices.append((left_source_idx, right_source_idx))
+                op_frontiers[unique_logical_op_id] = OpFrontier(op_set, source_unique_logical_op_ids, root_dataset_ids, source_indices, self.k, self.j, self.seed, self.policy, self.priors)
+            else:
+                source_indices = op_frontiers[source_unique_logical_op_ids[0]].source_indices
+                op_frontiers[unique_logical_op_id] = OpFrontier(op_set, source_unique_logical_op_ids, root_dataset_ids, source_indices, self.k, self.j, self.seed, self.policy, self.priors)
 
         # initialize and start the progress manager
         self.progress_manager = create_progress_manager(plan, sample_budget=self.sample_budget, progress=self.progress)
         self.progress_manager.start()
 
-        # NOTE: we must handle progress manager outside of _exeecute_sentinel_plan to ensure that it is shut down correctly;
+        # NOTE: we must handle progress manager outside of _execute_sentinel_plan to ensure that it is shut down correctly;
         # if we don't have the `finally:` branch, then program crashes can cause future program runs to fail because
         # the progress manager cannot get a handle to the console
         try:
             # execute sentinel plan by sampling records and operators
-            plan_stats = self._execute_sentinel_plan(plan, op_frontiers, expected_outputs, plan_stats)
+            plan_stats = self._execute_sentinel_plan(plan, op_frontiers, validator, plan_stats)
 
         finally:
             # finish progress tracking
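
For a join, the candidate source indices are the full cross product of the left and right scans' shuffled indices; the nested loop above is equivalent to itertools.product. A sketch with two hypothetical two-record datasets:

    from itertools import product

    left = ["a-0", "a-1"]
    right = ["b-0", "b-1"]
    source_indices = list(product(left, right))
    # [("a-0", "b-0"), ("a-0", "b-1"), ("a-1", "b-0"), ("a-1", "b-1")]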