palimpzest 0.9.0__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in the public registry.
Files changed (35)
  1. palimpzest/constants.py +1 -0
  2. palimpzest/core/data/dataset.py +33 -5
  3. palimpzest/core/elements/groupbysig.py +10 -1
  4. palimpzest/core/elements/records.py +16 -7
  5. palimpzest/core/lib/schemas.py +20 -3
  6. palimpzest/core/models.py +10 -4
  7. palimpzest/query/execution/all_sample_execution_strategy.py +1 -1
  8. palimpzest/query/execution/execution_strategy.py +13 -11
  9. palimpzest/query/execution/mab_execution_strategy.py +40 -14
  10. palimpzest/query/execution/parallel_execution_strategy.py +31 -7
  11. palimpzest/query/execution/single_threaded_execution_strategy.py +23 -6
  12. palimpzest/query/generators/generators.py +1 -1
  13. palimpzest/query/operators/__init__.py +7 -6
  14. palimpzest/query/operators/aggregate.py +110 -5
  15. palimpzest/query/operators/convert.py +1 -1
  16. palimpzest/query/operators/join.py +279 -23
  17. palimpzest/query/operators/logical.py +20 -8
  18. palimpzest/query/operators/mixture_of_agents.py +3 -1
  19. palimpzest/query/operators/physical.py +5 -2
  20. palimpzest/query/operators/rag.py +5 -4
  21. palimpzest/query/operators/{retrieve.py → topk.py} +10 -10
  22. palimpzest/query/optimizer/__init__.py +7 -3
  23. palimpzest/query/optimizer/cost_model.py +5 -5
  24. palimpzest/query/optimizer/optimizer.py +3 -2
  25. palimpzest/query/optimizer/plan.py +2 -3
  26. palimpzest/query/optimizer/rules.py +31 -11
  27. palimpzest/query/optimizer/tasks.py +4 -4
  28. palimpzest/query/processor/config.py +1 -0
  29. palimpzest/utils/progress.py +51 -23
  30. palimpzest/validator/validator.py +7 -7
  31. {palimpzest-0.9.0.dist-info → palimpzest-1.1.0.dist-info}/METADATA +26 -66
  32. {palimpzest-0.9.0.dist-info → palimpzest-1.1.0.dist-info}/RECORD +35 -35
  33. {palimpzest-0.9.0.dist-info → palimpzest-1.1.0.dist-info}/WHEEL +0 -0
  34. {palimpzest-0.9.0.dist-info → palimpzest-1.1.0.dist-info}/licenses/LICENSE +0 -0
  35. {palimpzest-0.9.0.dist-info → palimpzest-1.1.0.dist-info}/top_level.txt +0 -0
palimpzest/constants.py CHANGED
@@ -207,6 +207,7 @@ class Modality(str, Enum):
 class AggFunc(str, Enum):
     COUNT = "count"
     AVERAGE = "average"
+    SUM = "sum"
     MIN = "min"
     MAX = "max"
 
palimpzest/core/data/dataset.py CHANGED
@@ -22,7 +22,7 @@ from palimpzest.query.operators.logical import (
     LimitScan,
     LogicalOperator,
     Project,
-    RetrieveScan,
+    TopKScan,
 )
 from palimpzest.query.processor.config import QueryProcessorConfig
 from palimpzest.utils.hash_helpers import hash_for_serialized_dict
@@ -243,7 +243,30 @@ class Dataset:
             id=self.id,
         )
 
-    def sem_join(self, other: Dataset, condition: str, desc: str | None = None, depends_on: str | list[str] | None = None) -> Dataset:
+    def join(self, other: Dataset, on: str | list[str], how: str = "inner") -> Dataset:
+        """
+        Perform the specified join on the specified (list of) column(s)
+        """
+        # enforce type for on
+        if isinstance(on, str):
+            on = [on]
+
+        # construct new output schema
+        combined_schema = union_schemas([self.schema, other.schema], join=True, on=on)
+
+        # construct logical operator
+        operator = JoinOp(
+            input_schema=combined_schema,
+            output_schema=combined_schema,
+            condition="",
+            on=on,
+            how=how,
+            depends_on=on,
+        )
+
+        return Dataset(sources=[self, other], operator=operator, schema=combined_schema)
+
+    def sem_join(self, other: Dataset, condition: str, desc: str | None = None, depends_on: str | list[str] | None = None, how: str = "inner") -> Dataset:
         """
         Perform a semantic (inner) join on the specified join predicate
         """
@@ -259,6 +282,7 @@ class Dataset:
             input_schema=combined_schema,
             output_schema=combined_schema,
             condition=condition,
+            how=how,
             desc=desc,
             depends_on=depends_on,
         )
@@ -346,7 +370,6 @@ class Dataset:
 
         return Dataset(sources=[self], operator=operator, schema=new_output_schema)
 
-
     def sem_add_columns(self, cols: list[dict] | type[BaseModel],
                         cardinality: Cardinality = Cardinality.ONE_TO_ONE,
                         desc: str | None = None,
@@ -534,6 +557,11 @@ class Dataset:
         operator = Aggregate(input_schema=self.schema, agg_func=AggFunc.AVERAGE)
         return Dataset(sources=[self], operator=operator, schema=operator.output_schema)
 
+    def sum(self) -> Dataset:
+        """Apply a summation to this set"""
+        operator = Aggregate(input_schema=self.schema, agg_func=AggFunc.SUM)
+        return Dataset(sources=[self], operator=operator, schema=operator.output_schema)
+
     def min(self) -> Dataset:
         """Apply an min operator to this set"""
         operator = Aggregate(input_schema=self.schema, agg_func=AggFunc.MIN)
@@ -581,7 +609,7 @@ class Dataset:
 
         return Dataset(sources=[self], operator=operator, schema=operator.output_schema)
 
-    def retrieve(
+    def sem_topk(
         self,
         index: Collection,
         search_attr: str,
@@ -608,7 +636,7 @@ class Dataset:
         # index = index_factory(index)
 
         # construct logical operator
-        operator = RetrieveScan(
+        operator = TopKScan(
            input_schema=self.schema,
            output_schema=new_output_schema,
            index=index,
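
The renamed and added Dataset methods above can be exercised as follows. This is a minimal usage sketch: the datasets and column names are hypothetical, and only join(), the new how argument on sem_join(), sum(), and the retrieve → sem_topk rename come from this diff.

orders = ...   # hypothetical pz.Dataset with an "order_id" column
returns = ...  # hypothetical pz.Dataset with an "order_id" column

# new equality join on a shared column; how defaults to "inner"
joined = orders.join(returns, on="order_id", how="left")

# sem_join now accepts a `how` argument as well
sem_joined = orders.sem_join(returns, condition="the return refers to this order", how="inner")

# new SUM aggregate, alongside the existing count/average/min/max
total = joined.sum()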
palimpzest/core/elements/groupbysig.py CHANGED
@@ -6,8 +6,16 @@ from pydantic import BaseModel
 
 from palimpzest.core.lib.schemas import create_schema_from_fields
 
+# TODO:
+# - move the arguments for group_by_fields, agg_funcs, and agg_fields into the Dataset.groupby() operator
+# - construct the correct output schema using the input schema and the group by and aggregation fields
+# - remove/update all other references to GroupBySig in the codebase
+
+# TODO:
+# - move the arguments for group_by_fields, agg_funcs, and agg_fields into the Dataset.groupby() operator
+# - construct the correct output schema using the input schema and the group by and aggregation fields
+# - remove/update all other references to GroupBySig in the codebase
 
-# TODO: need to rethink how group bys work
 # signature for a group by aggregate that applies
 # group and aggregation to an input tuple
 class GroupBySig:
@@ -50,6 +58,7 @@ class GroupBySig:
             ops.append(self.agg_funcs[i] + "(" + self.agg_fields[i] + ")")
         return ops
 
+    # TODO: output schema needs to account for input schema types and create new output schema types
     def output_schema(self) -> type[BaseModel]:
         # the output class varies depending on the group by, so here
         # we dynamically construct this output
palimpzest/core/elements/records.py CHANGED
@@ -140,7 +140,7 @@ class DataRecord:
     def schema(self) -> type[BaseModel]:
         return type(self._data_item)
 
-    def copy(self):
+    def copy(self) -> DataRecord:
         # get the set of fields to copy from the parent record
         copy_field_names = [field.split(".")[-1] for field in self.get_field_names()]
 
@@ -228,18 +228,18 @@ class DataRecord:
     @staticmethod
     def from_join_parents(
         schema: type[BaseModel],
-        left_parent_record: DataRecord,
-        right_parent_record: DataRecord,
+        left_parent_record: DataRecord | None,
+        right_parent_record: DataRecord | None,
         project_cols: list[str] | None = None,
         cardinality_idx: int = None,
     ) -> DataRecord:
         # get the set of fields and field descriptions to copy from the parent record(s)
-        left_copy_field_names = (
+        left_copy_field_names = [] if left_parent_record is None else (
            left_parent_record.get_field_names()
            if project_cols is None
            else [col for col in project_cols if col in left_parent_record.get_field_names()]
        )
-        right_copy_field_names = (
+        right_copy_field_names = [] if right_parent_record is None else (
            right_parent_record.get_field_names()
            if project_cols is None
            else [col for col in project_cols if col in right_parent_record.get_field_names()]
@@ -255,11 +255,20 @@ class DataRecord:
                 new_field_name = f"{field_name}_right"
                 data_item[new_field_name] = right_parent_record[field_name]
 
+        # for any missing fields in the schema, set them to None
+        for field_name in schema.model_fields:
+            if field_name not in data_item:
+                data_item[field_name] = None
+
         # make new record which has left and right parent record as its parents
+        left_parent_source_indices = [] if left_parent_record is None else list(left_parent_record._source_indices)
+        right_parent_source_indices = [] if right_parent_record is None else list(right_parent_record._source_indices)
+        left_parent_record_id = [] if left_parent_record is None else [left_parent_record._id]
+        right_parent_record_id = [] if right_parent_record is None else [right_parent_record._id]
         new_dr = DataRecord(
             schema(**data_item),
-            source_indices=list(left_parent_record._source_indices) + list(right_parent_record._source_indices),
-            parent_ids=[left_parent_record._id, right_parent_record._id],
+            source_indices=left_parent_source_indices + right_parent_source_indices,
+            parent_ids=left_parent_record_id + right_parent_record_id,
             cardinality_idx=cardinality_idx,
         )
 
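
The None-handling above is what makes outer joins representable at the record level: from_join_parents now accepts None for either parent, backfills the absent side's schema fields with None, and derives lineage only from the parent that exists. A hypothetical call for an unmatched left record in a left join (CombinedSchema and left_record are illustrative names):

record = DataRecord.from_join_parents(
    schema=CombinedSchema,           # union of the left and right schemas
    left_parent_record=left_record,  # the unmatched left input
    right_parent_record=None,        # no matching right record
)
# every right-side field of CombinedSchema is set to None, and the record's
# source_indices / parent_ids come from the left parent only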
palimpzest/core/lib/schemas.py CHANGED
@@ -142,16 +142,30 @@ def create_schema_from_df(df: pd.DataFrame) -> type[BaseModel]:
     return _create_pickleable_model(fields)
 
 
-def union_schemas(models: list[type[BaseModel]], join: bool = False) -> type[BaseModel]:
+def union_schemas(models: list[type[BaseModel]], join: bool = False, on: list[str] | None = None) -> type[BaseModel]:
     """Union multiple Pydantic models into a single model."""
+    # convert on to empty list if None
+    if on is None:
+        on = []
+
+    # build up the fields for the new schema
     fields = {}
     for model in models:
         for field_name, field in model.model_fields.items():
-            if field_name in fields and not join:
+            # for non-join unions, make sure duplicate fields have the same type
+            if not join and field_name in fields:
                 assert fields[field_name][0] == field.annotation, f"Field {field_name} has different types in different models"
-            elif field_name in fields and join:
+
+            # for joins with "on" specified, no need to rename fields in "on"
+            elif join and field_name in on and field_name in fields:
+                continue
+
+            # otherwise, rename duplicate fields by appending _right
+            elif join and field_name in fields:
                 while field_name in fields:
                     field_name = f"{field_name}_right"
+
+            # add the field to the new schema
             fields[field_name] = (field.annotation, field)
 
     # create and return the new schema
@@ -194,6 +208,9 @@ class Average(BaseModel):
 class Count(BaseModel):
     count: int = Field(description="The count of items in the dataset")
 
+class Sum(BaseModel):
+    sum: int = Field(description="The summation of items in the dataset")
+
 class Min(BaseModel):
     min: int | float = Field(description="The minimum value of some items in the dataset")
 
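To make the renaming rule concrete, here is a hypothetical sketch (the model names are illustrative; union_schemas is the function changed above): duplicate fields named in on are kept once, while any other duplicate gets a _right suffix.

from pydantic import create_model

from palimpzest.core.lib.schemas import union_schemas

# two schemas sharing "id" (the join key) and "name" (an ordinary duplicate)
Left = create_model("Left", id=(int, ...), name=(str, ...))
Right = create_model("Right", id=(int, ...), name=(str, ...))

Joined = union_schemas([Left, Right], join=True, on=["id"])
# Joined has fields: id, name, name_right -- the join key appears once,
# while the right schema's duplicate "name" is renamed to "name_right"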
palimpzest/core/models.py CHANGED
@@ -51,10 +51,10 @@ class GenerationStats(BaseModel):
     fn_call_duration_secs: float = 0.0
 
     # (if applicable) the total number of LLM calls made by this operator
-    total_llm_calls: int = 0
+    total_llm_calls: float = 0
 
     # (if applicable) the total number of embedding LLM calls made by this operator
-    total_embedding_llm_calls: int = 0
+    total_embedding_llm_calls: float = 0
 
     def __iadd__(self, other: GenerationStats) -> GenerationStats:
         # self.raw_answers.extend(other.raw_answers)
@@ -243,10 +243,10 @@ class RecordOpStats(BaseModel):
     fn_call_duration_secs: float = 0.0
 
     # (if applicable) the total number of LLM calls made by this operator
-    total_llm_calls: int = 0
+    total_llm_calls: float = 0
 
     # (if applicable) the total number of embedding LLM calls made by this operator
-    total_embedding_llm_calls: int = 0
+    total_embedding_llm_calls: float = 0
 
     # (if applicable) a boolean indicating whether this is the statistics captured from a failed convert operation
     failed_convert: bool | None = None
@@ -454,6 +454,12 @@ class BasePlanStats(BaseModel):
         """
         return sum([gen_stats.total_output_tokens for _, gen_stats in self.validation_gen_stats.items()])
 
+    def get_total_cost_so_far(self) -> float:
+        """
+        Get the total cost incurred so far in this plan execution.
+        """
+        return self.sum_op_costs() + self.sum_validation_costs()
+
 
 class PlanStats(BasePlanStats):
     """
palimpzest/query/execution/all_sample_execution_strategy.py CHANGED
@@ -225,7 +225,7 @@ class AllSamplingExecutionStrategy(SentinelExecutionStrategy):
         dataset_id_to_source_indices = {}
         for dataset_id, dataset in train_dataset.items():
             total_num_samples = len(dataset)
-            source_indices = [f"{dataset_id}-{int(idx)}" for idx in np.arange(total_num_samples)]
+            source_indices = [f"{dataset_id}---{int(idx)}" for idx in np.arange(total_num_samples)]
             dataset_id_to_source_indices[dataset_id] = source_indices
 
         # initialize set of physical operators for each logical operator
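
The switch from "-" to "---" as the dataset-id/index separator looks cosmetic, but it matters for parsing: mab_execution_strategy.py (below) recovers the dataset id with source_indices.split("---")[0], and a single hyphen is ambiguous whenever the dataset id itself contains one. A small illustration with a hypothetical id:

dataset_id = "movie-reviews"        # hypothetical id containing a hyphen
source_index = f"{dataset_id}---3"  # "movie-reviews---3"

# splitting on "-" truncates the id; splitting on "---" recovers it intact
assert source_index.split("-")[0] == "movie"
assert source_index.split("---")[0] == "movie-reviews"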
palimpzest/query/execution/execution_strategy.py CHANGED
@@ -14,8 +14,8 @@ from palimpzest.query.operators.convert import LLMConvert
 from palimpzest.query.operators.filter import LLMFilter
 from palimpzest.query.operators.join import JoinOp
 from palimpzest.query.operators.physical import PhysicalOperator
-from palimpzest.query.operators.retrieve import RetrieveOp
 from palimpzest.query.operators.scan import ContextScanOp, ScanPhysicalOp
+from palimpzest.query.operators.topk import TopKOp
 from palimpzest.query.optimizer.plan import PhysicalPlan, SentinelPlan
 from palimpzest.utils.progress import PZSentinelProgressManager
 from palimpzest.validator.validator import Validator
@@ -82,10 +82,11 @@ class SentinelExecutionStrategy(BaseExecutionStrategy, ABC):
     """
     def __init__(
         self,
-        k: int,
-        j: int,
-        sample_budget: int,
         policy: Policy,
+        k: int = 6,
+        j: int = 4,
+        sample_budget: int = 100,
+        sample_cost_budget: float | None = None,
         priors: dict | None = None,
         use_final_op_quality: bool = False,
         seed: int = 42,
@@ -97,6 +98,7 @@ class SentinelExecutionStrategy(BaseExecutionStrategy, ABC):
         self.k = k
         self.j = j
         self.sample_budget = sample_budget
+        self.sample_cost_budget = sample_cost_budget
         self.policy = policy
         self.priors = priors
         self.use_final_op_quality = use_final_op_quality
@@ -123,7 +125,7 @@ class SentinelExecutionStrategy(BaseExecutionStrategy, ABC):
         return (
             not isinstance(op, LLMConvert)
             and not isinstance(op, LLMFilter)
-            and not isinstance(op, RetrieveOp)
+            and not isinstance(op, TopKOp)
             and not isinstance(op, JoinOp)
         )
 
@@ -167,8 +169,8 @@ class SentinelExecutionStrategy(BaseExecutionStrategy, ABC):
                     full_hashes.add(full_hash)
                     futures.append(executor.submit(validator._score_flat_map, op, fields, input_record, output, full_hash))
 
-            # create future for retrieve
-            elif isinstance(op, RetrieveOp):
+            # create future for top-k
+            elif isinstance(op, TopKOp):
                 fields = op.generated_fields
                 input_record: DataRecord = record_set.input
                 output = record_set.data_records[0].to_dict(project_cols=fields)
@@ -176,7 +178,7 @@ class SentinelExecutionStrategy(BaseExecutionStrategy, ABC):
                 full_hash = f"{hash(input_record)}{hash(output_str)}"
                 if full_hash not in full_hashes:
                     full_hashes.add(full_hash)
-                    futures.append(executor.submit(validator._score_retrieve, op, fields, input_record, output, full_hash))
+                    futures.append(executor.submit(validator._score_topk, op, fields, input_record, output, full_hash))
 
             # create future for filter
             elif isinstance(op, LLMFilter):
@@ -235,7 +237,7 @@ class SentinelExecutionStrategy(BaseExecutionStrategy, ABC):
 
         # TODO: this scoring function will (likely) bias towards small values of k since it
         # measures precision and not recall / F1; will need to revisit this in the future
-        elif isinstance(op, RetrieveOp):
+        elif isinstance(op, TopKOp):
            fields = op.generated_fields
            input_record: DataRecord = record_set.input
            output_str = record_set.data_records[0].to_json_str(project_cols=fields, bytes_to_str=True, sorted=True)
@@ -341,9 +343,9 @@ class SentinelExecutionStrategy(BaseExecutionStrategy, ABC):
     def _is_llm_op(self, physical_op: PhysicalOperator) -> bool:
         is_llm_convert = isinstance(physical_op, LLMConvert)
         is_llm_filter = isinstance(physical_op, LLMFilter)
-        is_llm_retrieve = isinstance(physical_op, RetrieveOp) and isinstance(physical_op.index, Collection)
+        is_llm_topk = isinstance(physical_op, TopKOp) and isinstance(physical_op.index, Collection)
         is_llm_join = isinstance(physical_op, JoinOp)
-        return is_llm_convert or is_llm_filter or is_llm_retrieve or is_llm_join
+        return is_llm_convert or is_llm_filter or is_llm_topk or is_llm_join
 
     @abstractmethod
     def execute_sentinel_plan(self, sentinel_plan: SentinelPlan, train_dataset: dict[str, Dataset], validator: Validator) -> SentinelPlanStats:
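
With the reordering above, policy becomes the only required constructor argument, and the sampling phase can be capped either by sample count (sample_budget, default 100) or by dollar cost (sample_cost_budget). A hedged sketch of how a concrete strategy might now be constructed; MABExecutionStrategy subclasses SentinelExecutionStrategy in this diff, but the keyword pass-through shown here is assumed:

strategy = MABExecutionStrategy(
    policy=policy,            # the only required argument after this change
    sample_cost_budget=2.50,  # assumed usage: stop sampling once ~$2.50 is spent;
)                             # when None, the sample_budget of 100 applies instead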
palimpzest/query/execution/mab_execution_strategy.py CHANGED
@@ -14,8 +14,8 @@ from palimpzest.query.operators.convert import LLMConvert
 from palimpzest.query.operators.filter import FilterOp, LLMFilter, NonLLMFilter
 from palimpzest.query.operators.join import JoinOp
 from palimpzest.query.operators.physical import PhysicalOperator
-from palimpzest.query.operators.retrieve import RetrieveOp
 from palimpzest.query.operators.scan import ContextScanOp, ScanPhysicalOp
+from palimpzest.query.operators.topk import TopKOp
 from palimpzest.query.optimizer.plan import SentinelPlan
 from palimpzest.utils.progress import create_progress_manager
 from palimpzest.validator.validator import Validator
@@ -66,8 +66,8 @@ class OpFrontier:
         self.is_llm_join = isinstance(sample_op, JoinOp)
         is_llm_convert = isinstance(sample_op, LLMConvert)
         is_llm_filter = isinstance(sample_op, LLMFilter)
-        is_llm_retrieve = isinstance(sample_op, RetrieveOp) and isinstance(sample_op.index, Collection)
-        self.is_llm_op = is_llm_convert or is_llm_filter or is_llm_retrieve or self.is_llm_join
+        is_llm_topk = isinstance(sample_op, TopKOp) and isinstance(sample_op.index, Collection)
+        self.is_llm_op = is_llm_convert or is_llm_filter or is_llm_topk or self.is_llm_join
 
         # get order in which we will sample physical operators for this logical operator
         sample_op_indices = self._get_op_index_order(op_set, seed)
@@ -96,6 +96,12 @@ class OpFrontier:
         """
         return self.frontier_ops
 
+    def get_off_frontier_ops(self) -> list[PhysicalOperator]:
+        """
+        Returns the set of off-frontier operators for this OpFrontier.
+        """
+        return self.off_frontier_ops
+
     def _compute_op_id_to_pareto_distance(self, priors: dict[str, dict[str, float]]) -> dict[str, float]:
         """
         Return l2-distance for each operator from the pareto frontier.
@@ -298,7 +304,7 @@ class OpFrontier:
         def remove_unavailable_root_datasets(source_indices: str | tuple) -> str | tuple | None:
             # base case: source_indices is a string
             if isinstance(source_indices, str):
-                return source_indices if source_indices.split("-")[0] in self.root_dataset_ids else None
+                return source_indices if source_indices.split("---")[0] in self.root_dataset_ids else None
 
             # recursive case: source_indices is a tuple
             left_indices = source_indices[0]
@@ -383,6 +389,12 @@ class OpFrontier:
             # compute final list of record op stats
             full_op_id_to_record_op_stats[full_op_id] = list(record_id_to_max_quality_record_op_stats.values())
 
+        # NOTE: it is possible for the full_op_id_to_record_op_stats to be empty if there is a duplicate operator
+        # (e.g. a scan of the same dataset) which has all of its results cached and no new_record_op_stats;
+        # in this case, we do not update the frontier
+        if full_op_id_to_record_op_stats == {}:
+            return
+
         # update the set of source indices processed by each physical operator
         for full_op_id, source_indices_processed in full_op_id_to_source_indices_processed.items():
             # update the set of source indices processed
@@ -641,8 +653,8 @@ class MABExecutionStrategy(SentinelExecutionStrategy):
         """
         Returns the operator in the frontier with the highest (estimated) quality.
         """
-        # get the operators in the frontier set for this logical_op_id
-        frontier_ops = op_frontiers[unique_logical_op_id].get_frontier_ops()
+        # get the (off) frontier operators for this logical_op_id
+        frontier_ops = op_frontiers[unique_logical_op_id].get_frontier_ops() + op_frontiers[unique_logical_op_id].get_off_frontier_ops()
 
         # get a mapping from full_op_id --> list[RecordOpStats]
         full_op_id_to_op_stats: dict[str, OperatorStats] = plan_stats.operator_stats.get(unique_logical_op_id, {})
@@ -668,6 +680,9 @@ class MABExecutionStrategy(SentinelExecutionStrategy):
 
         return max_quality_op
 
+    def _compute_termination_condition(self, samples_drawn: int, sampling_cost: float) -> bool:
+        return (samples_drawn >= self.sample_budget) if self.sample_cost_budget is None else (sampling_cost >= self.sample_cost_budget)
+
     def _execute_sentinel_plan(
         self,
         plan: SentinelPlan,
@@ -676,8 +691,8 @@ class MABExecutionStrategy(SentinelExecutionStrategy):
         plan_stats: SentinelPlanStats,
     ) -> SentinelPlanStats:
         # sample records and operators and update the frontiers
-        samples_drawn = 0
-        while samples_drawn < self.sample_budget:
+        samples_drawn, sampling_cost = 0, 0.0
+        while not self._compute_termination_condition(samples_drawn, sampling_cost):
             # pre-compute the set of source indices which will need to be sampled
             source_indices_to_sample = set()
             for op_frontier in op_frontiers.values():
@@ -693,14 +708,21 @@ class MABExecutionStrategy(SentinelExecutionStrategy):
                 max_quality_op = self._get_max_quality_op(unique_logical_op_id, op_frontiers, plan_stats)
 
                 # get frontier ops and their next input
-                def is_filtered_out(tup: tuple) -> bool:
-                    return tup[-1] is None or isinstance(tup[-1], list) and all([record is None for record in tup[-1]])
+                def filter_and_clean_inputs(frontier_op_inputs: list[tuple]) -> bool:
+                    cleaned_inputs = []
+                    for tup in frontier_op_inputs:
+                        input = tup[-1]
+                        if isinstance(input, list):
+                            input = [record for record in input if record is not None]
+                        if input is not None and input != []:
+                            cleaned_inputs.append((tup[0], tup[1], input))
+                    return cleaned_inputs
                 frontier_op_inputs = op_frontiers[unique_logical_op_id].get_frontier_op_inputs(source_indices_to_sample, max_quality_op)
-                frontier_op_inputs = list(filter(lambda tup: not is_filtered_out(tup), frontier_op_inputs))
+                frontier_op_inputs = filter_and_clean_inputs(frontier_op_inputs)
 
                 # break out of the loop if frontier_op_inputs is empty, as this means all records have been filtered out
                 if len(frontier_op_inputs) == 0:
-                    break
+                    continue
 
                 # run sampled operators on sampled inputs and update the number of samples drawn
                 source_indices_to_record_set_tuples, num_llm_ops = self._execute_op_set(unique_logical_op_id, frontier_op_inputs)
@@ -713,6 +735,9 @@ class MABExecutionStrategy(SentinelExecutionStrategy):
                 }
                 source_indices_to_all_record_sets, val_gen_stats = self._score_quality(validator, source_indices_to_all_record_sets)
 
+                # update the progress manager with validation cost
+                self.progress_manager.incr_overall_progress_cost(val_gen_stats.cost_per_record)
+
                 # remove records that were read from the execution cache before adding to record op stats
                 new_record_op_stats = []
                 for _, record_set_tuples in source_indices_to_record_set_tuples.items():
@@ -723,6 +748,7 @@ class MABExecutionStrategy(SentinelExecutionStrategy):
                 # update plan stats
                 plan_stats.add_record_op_stats(unique_logical_op_id, new_record_op_stats)
                 plan_stats.add_validation_gen_stats(unique_logical_op_id, val_gen_stats)
+                sampling_cost = plan_stats.get_total_cost_so_far()
 
                 # provide the best record sets as inputs to the next logical operator
                 next_unique_logical_op_id = plan.get_next_unique_logical_op_id(unique_logical_op_id)
@@ -764,7 +790,7 @@ class MABExecutionStrategy(SentinelExecutionStrategy):
         dataset_id_to_shuffled_source_indices = {}
         for dataset_id, dataset in train_dataset.items():
             total_num_samples = len(dataset)
-            shuffled_source_indices = [f"{dataset_id}-{int(idx)}" for idx in np.arange(total_num_samples)]
+            shuffled_source_indices = [f"{dataset_id}---{int(idx)}" for idx in np.arange(total_num_samples)]
             self.rng.shuffle(shuffled_source_indices)
             dataset_id_to_shuffled_source_indices[dataset_id] = shuffled_source_indices
 
@@ -794,7 +820,7 @@ class MABExecutionStrategy(SentinelExecutionStrategy):
             op_frontiers[unique_logical_op_id] = OpFrontier(op_set, source_unique_logical_op_ids, root_dataset_ids, source_indices, self.k, self.j, self.seed, self.policy, self.priors)
 
         # initialize and start the progress manager
-        self.progress_manager = create_progress_manager(plan, sample_budget=self.sample_budget, progress=self.progress)
+        self.progress_manager = create_progress_manager(plan, sample_budget=self.sample_budget, sample_cost_budget=self.sample_cost_budget, progress=self.progress)
        self.progress_manager.start()
 
        # NOTE: we must handle progress manager outside of _execute_sentinel_plan to ensure that it is shut down correctly;
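
Tying the cost-budget pieces together: each iteration of the sampling loop adds record-op and validation stats to plan_stats, refreshes sampling_cost via get_total_cost_so_far() (operator costs plus validation costs, per models.py above), and re-checks _compute_termination_condition. A condensed sketch of the control flow, not the full implementation:

samples_drawn, sampling_cost = 0, 0.0
while not self._compute_termination_condition(samples_drawn, sampling_cost):
    ...  # sample frontier operators, execute them, score quality with the validator
    sampling_cost = plan_stats.get_total_cost_so_far()  # op costs + validation costs
    samples_drawn += num_llm_ops  # assumed bookkeeping, per the comment in the loop

# with sample_cost_budget=None the loop stops after sample_budget samples;
# otherwise it stops once sampling_cost reaches sample_cost_budget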
palimpzest/query/execution/parallel_execution_strategy.py CHANGED
@@ -9,7 +9,6 @@ from palimpzest.query.operators.aggregate import AggregateOp
 from palimpzest.query.operators.distinct import DistinctOp
 from palimpzest.query.operators.join import JoinOp
 from palimpzest.query.operators.limit import LimitScanOp
-from palimpzest.query.operators.physical import PhysicalOperator
 from palimpzest.query.operators.scan import ContextScanOp, ScanPhysicalOp
 from palimpzest.query.optimizer.plan import PhysicalPlan
 from palimpzest.utils.progress import create_progress_manager
@@ -35,14 +34,27 @@ class ParallelExecutionStrategy(ExecutionStrategy):
                 return True
         return False
 
-    def _upstream_ops_finished(self, plan: PhysicalPlan, topo_idx: int, operator: PhysicalOperator, input_queues: dict[str, dict[str, list]], future_queues: dict[str, list]) -> bool:
+    def _upstream_ops_finished(self, plan: PhysicalPlan, unique_full_op_id: str, input_queues: dict[str, dict[str, list]], future_queues: dict[str, list]) -> bool:
         """Helper function to check if agg / join operator is ready to process its inputs."""
-        # for agg / join operator, we can only process it when all upstream operators have finished processing their inputs
-        upstream_unique_full_op_ids = plan.get_upstream_unique_full_op_ids(topo_idx, operator)
+        upstream_unique_full_op_ids = plan.get_upstream_unique_full_op_ids(unique_full_op_id)
         upstream_input_queues = {upstream_unique_full_op_id: input_queues[upstream_unique_full_op_id] for upstream_unique_full_op_id in upstream_unique_full_op_ids}
         upstream_future_queues = {upstream_unique_full_op_id: future_queues[upstream_unique_full_op_id] for upstream_unique_full_op_id in upstream_unique_full_op_ids}
         return not (self._any_queue_not_empty(upstream_input_queues) or self._any_queue_not_empty(upstream_future_queues))
 
+    def _finish_outer_join(self, executor: ThreadPoolExecutor, plan: PhysicalPlan, unique_full_op_id: str, input_queues: dict[str, dict[str, list]], future_queues: dict[str, list]) -> None:
+        join_op_upstream_finished = self._upstream_ops_finished(plan, unique_full_op_id, input_queues, future_queues)
+        join_input_queues_empty = all(len(inputs) == 0 for inputs in input_queues[unique_full_op_id].values())
+        join_future_queue_empty = len(future_queues[unique_full_op_id]) == 0
+        if join_op_upstream_finished and join_input_queues_empty and join_future_queue_empty:
+            # process the join one last time with final=True to handle any left/right/outer join logic
+            operator = self.unique_full_op_id_to_operator[unique_full_op_id]
+            if not operator.finished:
+                def finalize_op(operator):
+                    return operator([], [], final=True)
+                future = executor.submit(finalize_op, operator)
+                future_queues[unique_full_op_id].append(future)
+                operator.set_finished()
+
     def _process_future_results(self, unique_full_op_id: str, future_queues: dict[str, list], plan_stats: PlanStats) -> list[DataRecord]:
         """
         Helper function which takes a full operator id, the future queues, and plan stats, and performs
@@ -117,15 +129,23 @@ class ParallelExecutionStrategy(ExecutionStrategy):
                     records = self._process_future_results(source_unique_full_op_id, future_queues, plan_stats)
                     input_queues[unique_full_op_id][source_unique_full_op_id].extend(records)
 
+                    # if the source is a left/right/outer join operator with no more inputs to process, then finish it
+                    if self.is_outer_join_op[source_unique_full_op_id]:
+                        self._finish_outer_join(executor, plan, source_unique_full_op_id, input_queues, future_queues)
+
                # for the final operator, add any finished futures to the output_records
                if unique_full_op_id == f"{topo_idx}-{final_op.get_full_op_id()}":
                    records = self._process_future_results(unique_full_op_id, future_queues, plan_stats)
                    output_records.extend(records)
 
+                    # if this is a left/right/outer join operator with no more inputs to process, then finish it
+                    if self.is_outer_join_op[unique_full_op_id]:
+                        self._finish_outer_join(executor, plan, unique_full_op_id, input_queues, future_queues)
+
                # if this operator does not have enough inputs to execute, then skip it
                num_inputs = sum(len(inputs) for inputs in input_queues[unique_full_op_id].values())
-                agg_op_not_ready = isinstance(operator, AggregateOp) and not self._upstream_ops_finished(plan, topo_idx, operator, input_queues, future_queues)
-                join_op_not_ready = isinstance(operator, JoinOp) and not self.join_has_downstream_limit_op[unique_full_op_id] and not self._upstream_ops_finished(plan, topo_idx, operator, input_queues, future_queues)
+                agg_op_not_ready = isinstance(operator, AggregateOp) and not self._upstream_ops_finished(plan, unique_full_op_id, input_queues, future_queues)
+                join_op_not_ready = isinstance(operator, JoinOp) and not self.join_has_downstream_limit_op[unique_full_op_id] and not self._upstream_ops_finished(plan, unique_full_op_id, input_queues, future_queues)
                if num_inputs == 0 or agg_op_not_ready or join_op_not_ready:
                    continue
 
@@ -225,8 +245,9 @@ class ParallelExecutionStrategy(ExecutionStrategy):
         input_queues = self._create_input_queues(plan)
         future_queues = {f"{topo_idx}-{op.get_full_op_id()}": [] for topo_idx, op in enumerate(plan)}
 
-        # precompute which operators are joins and which joins have downstream limit ops
+        # precompute which operators are (outer) joins and which joins have downstream limit ops
         self.is_join_op = {f"{topo_idx}-{op.get_full_op_id()}": isinstance(op, JoinOp) for topo_idx, op in enumerate(plan)}
+        self.is_outer_join_op = {f"{topo_idx}-{op.get_full_op_id()}": isinstance(op, JoinOp) and op.how in ("left", "right", "outer") for topo_idx, op in enumerate(plan)}
         self.join_has_downstream_limit_op = {}
         for topo_idx, op in enumerate(plan):
             if isinstance(op, JoinOp):
@@ -240,6 +261,9 @@ class ParallelExecutionStrategy(ExecutionStrategy):
                     break
             self.join_has_downstream_limit_op[unique_full_op_id] = has_downstream_limit_op
 
+        # precompute mapping from unique_full_op_id to operator instance
+        self.unique_full_op_id_to_operator = {f"{topo_idx}-{op.get_full_op_id()}": op for topo_idx, op in enumerate(plan)}
+
         # initialize and start the progress manager
         self.progress_manager = create_progress_manager(plan, num_samples=self.num_samples, progress=self.progress)
         self.progress_manager.start()
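
Both the parallel and single-threaded strategies now share one convention for outer joins: once a join's upstream operators are fully drained, the operator is invoked a final time with empty inputs and final=True so it can emit the unmatched rows that left/right/outer semantics require, and set_finished() guards against flushing twice. Schematically (a sketch of the shared convention, not new code):

if operator.how in ("left", "right", "outer") and not operator.finished:
    record_set, _ = operator([], [], final=True)  # flush unmatched rows
    operator.set_finished()                       # never finalize twice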
palimpzest/query/execution/single_threaded_execution_strategy.py CHANGED
@@ -6,7 +6,6 @@ from palimpzest.query.execution.execution_strategy import ExecutionStrategy
 from palimpzest.query.operators.aggregate import AggregateOp
 from palimpzest.query.operators.join import JoinOp
 from palimpzest.query.operators.limit import LimitScanOp
-from palimpzest.query.operators.physical import PhysicalOperator
 from palimpzest.query.operators.scan import ContextScanOp, ScanPhysicalOp
 from palimpzest.query.optimizer.plan import PhysicalPlan
 from palimpzest.utils.progress import create_progress_manager
@@ -70,6 +69,13 @@ class SequentialSingleThreadExecutionStrategy(ExecutionStrategy):
             record_set, num_inputs_processed = operator(left_input_records, right_input_records)
             records = record_set.data_records
             record_op_stats = record_set.record_op_stats
+
+            # process the join one last time with final=True to handle any left/right/outer join logic
+            if operator.how in ("left", "right", "outer"):
+                record_set, num_inputs_processed = operator([], [], final=True)
+                records.extend(record_set.data_records)
+                record_op_stats.extend(record_set.record_op_stats)
+
             num_outputs = sum(record._passed_operator for record in records)
 
             # update the progress manager
@@ -168,10 +174,9 @@ class PipelinedSingleThreadExecutionStrategy(ExecutionStrategy):
                 return True
         return False
 
-    def _upstream_ops_finished(self, plan: PhysicalPlan, topo_idx: int, operator: PhysicalOperator, input_queues: dict[str, dict[str, list]]) -> bool:
+    def _upstream_ops_finished(self, plan: PhysicalPlan, unique_full_op_id: str, input_queues: dict[str, dict[str, list]]) -> bool:
         """Helper function to check if agg / join operator is ready to process its inputs."""
-        # for agg / join operator, we can only process it when all upstream operators have finished processing their inputs
-        upstream_unique_full_op_ids = plan.get_upstream_unique_full_op_ids(topo_idx, operator)
+        upstream_unique_full_op_ids = plan.get_upstream_unique_full_op_ids(unique_full_op_id)
         upstream_input_queues = {upstream_unique_full_op_id: input_queues[upstream_unique_full_op_id] for upstream_unique_full_op_id in upstream_unique_full_op_ids}
         return not self._any_queue_not_empty(upstream_input_queues)
 
@@ -192,8 +197,8 @@ class PipelinedSingleThreadExecutionStrategy(ExecutionStrategy):
                 unique_full_op_id = f"{topo_idx}-{operator.get_full_op_id()}"
 
                 num_inputs = sum(len(input_queues[unique_full_op_id][source_unique_full_op_id]) for source_unique_full_op_id in source_unique_full_op_ids)
-                agg_op_not_ready = isinstance(operator, AggregateOp) and not self._upstream_ops_finished(plan, topo_idx, operator, input_queues)
-                join_op_not_ready = isinstance(operator, JoinOp) and not self._upstream_ops_finished(plan, topo_idx, operator, input_queues)
+                agg_op_not_ready = isinstance(operator, AggregateOp) and not self._upstream_ops_finished(plan, unique_full_op_id, input_queues)
+                join_op_not_ready = isinstance(operator, JoinOp) and not self._upstream_ops_finished(plan, unique_full_op_id, input_queues)
                 if num_inputs == 0 or agg_op_not_ready or join_op_not_ready:
                     continue
 
@@ -242,6 +247,18 @@ class PipelinedSingleThreadExecutionStrategy(ExecutionStrategy):
                 # update the progress manager
                 self.progress_manager.incr(unique_full_op_id, num_inputs=1, num_outputs=num_outputs, total_cost=record_set.get_total_cost())
 
+                # if this is a join operator with no more inputs to process, then finish it
+                if isinstance(operator, JoinOp) and operator.how in ("left", "right", "outer"):
+                    join_op_upstream_finished = self._upstream_ops_finished(plan, unique_full_op_id, input_queues)
+                    join_input_queues_empty = all(len(inputs) == 0 for inputs in input_queues[unique_full_op_id].values())
+                    if join_op_upstream_finished and join_input_queues_empty and not operator.finished:
+                        # process the join one last time with final=True to handle any left/right/outer join logic
+                        record_set, num_inputs_processed = operator([], [], final=True)
+                        records.extend(record_set.data_records)
+                        record_op_stats.extend(record_set.record_op_stats)
+                        num_outputs += sum(record._passed_operator for record in record_set.data_records)
+                        operator.set_finished()
+
                 # update plan stats
                 plan_stats.add_record_op_stats(unique_full_op_id, record_op_stats)