palimpzest-0.9.0-py3-none-any.whl → palimpzest-1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32):
  1. palimpzest/constants.py +1 -0
  2. palimpzest/core/data/dataset.py +33 -5
  3. palimpzest/core/elements/groupbysig.py +5 -1
  4. palimpzest/core/elements/records.py +16 -7
  5. palimpzest/core/lib/schemas.py +20 -3
  6. palimpzest/core/models.py +4 -4
  7. palimpzest/query/execution/all_sample_execution_strategy.py +1 -1
  8. palimpzest/query/execution/execution_strategy.py +8 -8
  9. palimpzest/query/execution/mab_execution_strategy.py +30 -11
  10. palimpzest/query/execution/parallel_execution_strategy.py +31 -7
  11. palimpzest/query/execution/single_threaded_execution_strategy.py +23 -6
  12. palimpzest/query/operators/__init__.py +7 -6
  13. palimpzest/query/operators/aggregate.py +110 -5
  14. palimpzest/query/operators/convert.py +1 -1
  15. palimpzest/query/operators/join.py +279 -23
  16. palimpzest/query/operators/logical.py +20 -8
  17. palimpzest/query/operators/mixture_of_agents.py +3 -1
  18. palimpzest/query/operators/physical.py +5 -2
  19. palimpzest/query/operators/{retrieve.py → topk.py} +10 -10
  20. palimpzest/query/optimizer/__init__.py +7 -3
  21. palimpzest/query/optimizer/cost_model.py +5 -5
  22. palimpzest/query/optimizer/optimizer.py +3 -2
  23. palimpzest/query/optimizer/plan.py +2 -3
  24. palimpzest/query/optimizer/rules.py +31 -11
  25. palimpzest/query/optimizer/tasks.py +4 -4
  26. palimpzest/utils/progress.py +19 -17
  27. palimpzest/validator/validator.py +7 -7
  28. {palimpzest-0.9.0.dist-info → palimpzest-1.0.0.dist-info}/METADATA +26 -66
  29. {palimpzest-0.9.0.dist-info → palimpzest-1.0.0.dist-info}/RECORD +32 -32
  30. {palimpzest-0.9.0.dist-info → palimpzest-1.0.0.dist-info}/WHEEL +0 -0
  31. {palimpzest-0.9.0.dist-info → palimpzest-1.0.0.dist-info}/licenses/LICENSE +0 -0
  32. {palimpzest-0.9.0.dist-info → palimpzest-1.0.0.dist-info}/top_level.txt +0 -0
@@ -9,7 +9,7 @@ from palimpzest.constants import AggFunc, Cardinality
9
9
  from palimpzest.core.data import context, dataset
10
10
  from palimpzest.core.elements.filters import Filter
11
11
  from palimpzest.core.elements.groupbysig import GroupBySig
12
- from palimpzest.core.lib.schemas import Average, Count, Max, Min
12
+ from palimpzest.core.lib.schemas import Average, Count, Max, Min, Sum
13
13
  from palimpzest.utils.hash_helpers import hash_for_id
14
14
 
15
15
 
@@ -25,7 +25,7 @@ class LogicalOperator:
25
25
  - LimitScan (scans up to N records from a Set)
26
26
  - GroupByAggregate (applies a group by on the Set)
27
27
  - Aggregate (applies an aggregation on the Set)
28
- - RetrieveScan (fetches documents from a provided input for a given query)
28
+ - TopKScan (fetches documents from a provided input for a given query)
29
29
  - Map (applies a function to each record in the Set without adding any new columns)
30
30
  - ComputeOperator (executes a computation described in natural language)
31
31
  - SearchOperator (executes a search query on the input Context)
@@ -160,6 +160,8 @@ class Aggregate(LogicalOperator):
160
160
  kwargs["output_schema"] = Count
161
161
  elif agg_func == AggFunc.AVERAGE:
162
162
  kwargs["output_schema"] = Average
163
+ elif agg_func == AggFunc.SUM:
164
+ kwargs["output_schema"] = Sum
163
165
  elif agg_func == AggFunc.MIN:
164
166
  kwargs["output_schema"] = Min
165
167
  elif agg_func == AggFunc.MAX:
@@ -411,17 +413,25 @@ class GroupByAggregate(LogicalOperator):
411
413
 
412
414
 
413
415
  class JoinOp(LogicalOperator):
414
- def __init__(self, condition: str, desc: str | None = None, *args, **kwargs):
416
+ def __init__(self, condition: str, on: list[str] | None = None, how: str = "inner", desc: str | None = None, *args, **kwargs):
415
417
  super().__init__(*args, **kwargs)
416
418
  self.condition = condition
419
+ self.on = on
420
+ self.how = how
417
421
  self.desc = desc
418
422
 
419
423
  def __str__(self):
420
- return f"Join(condition={self.condition})"
424
+ return f"Join(condition={self.condition})" if self.on is None else f"Join(on={self.on}, how={self.how})"
421
425
 
422
426
  def get_logical_id_params(self) -> dict:
423
427
  logical_id_params = super().get_logical_id_params()
424
- logical_id_params = {"condition": self.condition, "desc": self.desc, **logical_id_params}
428
+ logical_id_params = {
429
+ "condition": self.condition,
430
+ "on": self.on,
431
+ "how": self.how,
432
+ "desc": self.desc,
433
+ **logical_id_params,
434
+ }
425
435
 
426
436
  return logical_id_params
427
437
 
@@ -429,6 +439,8 @@ class JoinOp(LogicalOperator):
429
439
  logical_op_params = super().get_logical_op_params()
430
440
  logical_op_params = {
431
441
  "condition": self.condition,
442
+ "on": self.on,
443
+ "how": self.how,
432
444
  "desc": self.desc,
433
445
  **logical_op_params,
434
446
  }
@@ -484,8 +496,8 @@ class Project(LogicalOperator):
484
496
  return logical_op_params
485
497
 
486
498
 
487
- class RetrieveScan(LogicalOperator):
488
- """A RetrieveScan is a logical operator that represents a scan of a particular input Dataset, with a convert-like retrieve applied."""
499
+ class TopKScan(LogicalOperator):
500
+ """A TopKScan is a logical operator that represents a scan of a particular input Dataset, with a top-k operation applied."""
489
501
 
490
502
  def __init__(
491
503
  self,
@@ -505,7 +517,7 @@ class RetrieveScan(LogicalOperator):
505
517
  self.k = k
506
518
 
507
519
  def __str__(self):
508
- return f"RetrieveScan({self.input_schema} -> {str(self.output_schema)})"
520
+ return f"TopKScan({self.input_schema} -> {str(self.output_schema)})"
509
521
 
510
522
  def get_logical_id_params(self) -> dict:
511
523
  # NOTE: if we allow optimization over index, then we will need to include it in the id params
@@ -75,8 +75,9 @@ class MixtureOfAgentsConvert(LLMConvert):
75
75
  In practice, this naive quality estimate will be overwritten by the CostModel's estimate
76
76
  once it executes a few instances of the operator.
77
77
  """
78
- # temporarily set self.model so that super().naive_cost_estimates(...) can compute an estimate
78
+ # temporarily set self.model and self.prompt_strategy so that super().naive_cost_estimates(...) can compute an estimate
79
79
  self.model = self.proposer_models[0]
80
+ self.prompt_strategy = PromptStrategy.MAP_MOA_PROPOSER
80
81
 
81
82
  # get naive cost estimates for single LLM call and scale it by number of LLMs used in MoA
82
83
  naive_op_cost_estimates = super().naive_cost_estimates(source_op_cost_estimates)
@@ -98,6 +99,7 @@ class MixtureOfAgentsConvert(LLMConvert):
98
99
 
99
100
  # reset self.model to be None
100
101
  self.model = None
102
+ self.prompt_strategy = None
101
103
 
102
104
  return naive_op_cost_estimates
103
105
 
@@ -42,10 +42,13 @@ class PhysicalOperator:
42
42
  self.op_id = None
43
43
 
44
44
  # compute the input modalities (if any) for this physical operator
45
+ depends_on_short_field_names = [field.split(".")[-1] for field in self.depends_on] if self.depends_on is not None else None
45
46
  self.input_modalities = None
46
47
  if self.input_schema is not None:
47
48
  self.input_modalities = set()
48
- for field in self.input_schema.model_fields.values():
49
+ for field_name, field in self.input_schema.model_fields.items():
50
+ if self.depends_on is not None and field_name not in depends_on_short_field_names:
51
+ continue
49
52
  field_type = field.annotation
50
53
  if field_type in IMAGE_FIELD_TYPES:
51
54
  self.input_modalities.add(Modality.IMAGE)
@@ -191,7 +194,7 @@ class PhysicalOperator:
191
194
  in the candidate. This is important for operators with retry logic, where we may only need to
192
195
  recompute a subset of self.generated_fields.
193
196
 
194
- Right now this is only used by convert and retrieve operators.
197
+ Right now this is only used by convert and top-k operators.
195
198
  """
196
199
  fields_to_generate = [
197
200
  field_name
@@ -17,7 +17,7 @@ from palimpzest.core.models import GenerationStats, OperatorCostEstimates, Recor
17
17
  from palimpzest.query.operators.physical import PhysicalOperator
18
18
 
19
19
 
20
- class RetrieveOp(PhysicalOperator):
20
+ class TopKOp(PhysicalOperator):
21
21
  def __init__(
22
22
  self,
23
23
  index: Collection,
@@ -29,7 +29,7 @@ class RetrieveOp(PhysicalOperator):
29
29
  **kwargs,
30
30
  ) -> None:
31
31
  """
32
- Initialize the RetrieveOp object.
32
+ Initialize the TopKOp object.
33
33
 
34
34
  Args:
35
35
  index (Collection): The PZ index to use for retrieval.
@@ -59,7 +59,7 @@ class RetrieveOp(PhysicalOperator):
59
59
 
60
60
  def __str__(self):
61
61
  op = super().__str__()
62
- op += f" Retrieve: {self.index.__class__.__name__} with top {self.k}\n"
62
+ op += f" Top-K: {self.index.__class__.__name__} with k={self.k}\n"
63
63
  return op
64
64
 
65
65
  def get_id_params(self):
@@ -89,8 +89,8 @@ class RetrieveOp(PhysicalOperator):
89
89
 
90
90
  def naive_cost_estimates(self, source_op_cost_estimates: OperatorCostEstimates) -> OperatorCostEstimates:
91
91
  """
92
- Compute naive cost estimates for the Retrieve operation. These estimates assume
93
- that the Retrieve (1) has no cost and (2) has perfect quality.
92
+ Compute naive cost estimates for the Top-K operation. These estimates assume
93
+ that the Top-K (1) has negligible cost and (2) has perfect quality.
94
94
  """
95
95
  return OperatorCostEstimates(
96
96
  cardinality=source_op_cost_estimates.cardinality,
@@ -101,7 +101,7 @@ class RetrieveOp(PhysicalOperator):
101
101
 
102
102
  def default_search_func(self, index: Collection, query: list[str] | list[list[float]], k: int) -> list[str] | list[list[str]]:
103
103
  """
104
- Default search function for the Retrieve operation. This function uses the index to
104
+ Default search function for the Top-K operation. This function uses the index to
105
105
  retrieve the top-k results for the given query. The query will be a (possibly singleton)
106
106
  list of strings or a list of lists of floats (i.e., embeddings). The function will return
107
107
  the top-k results per-query in (descending) sorted order. If the input is a singleton list,
@@ -111,7 +111,7 @@ class RetrieveOp(PhysicalOperator):
111
111
  Args:
112
112
  index (PZIndex): The index to use for retrieval.
113
113
  query (list[str] | list[list[float]]): The query (or queries) to search for.
114
- k (int): The maximum number of results the retrieve operator will return.
114
+ k (int): The maximum number of results the top-k operator will return.
115
115
 
116
116
  Returns:
117
117
  list[str] | list[list[str]]: The top results in (descending) sorted order per query.
@@ -260,10 +260,10 @@ class RetrieveOp(PhysicalOperator):
260
260
  top_results = self.search_func(self.index, inputs, self.k)
261
261
 
262
262
  except Exception:
263
- top_results = ["error-in-retrieve"]
264
- os.makedirs("retrieve-errors", exist_ok=True)
263
+ top_results = ["error-in-topk"]
264
+ os.makedirs("topk-errors", exist_ok=True)
265
265
  ts = time.time()
266
- with open(f"retrieve-errors/error-{ts}.txt", "w") as f:
266
+ with open(f"topk-errors/error-{ts}.txt", "w") as f:
267
267
  f.write(str(query))
268
268
 
269
269
  # TODO: the user is always right! let's drop this post-processing in the future
@@ -39,10 +39,10 @@ from palimpzest.query.optimizer.rules import (
39
39
  RAGRule as _RAGRule,
40
40
  )
41
41
  from palimpzest.query.optimizer.rules import (
42
- ReorderConverts as _ReorderConverts,
42
+ RelationalJoinRule as _RelationalJoinRule,
43
43
  )
44
44
  from palimpzest.query.optimizer.rules import (
45
- RetrieveRule as _RetrieveRule,
45
+ ReorderConverts as _ReorderConverts,
46
46
  )
47
47
  from palimpzest.query.optimizer.rules import (
48
48
  Rule as _Rule,
@@ -53,6 +53,9 @@ from palimpzest.query.optimizer.rules import (
53
53
  from palimpzest.query.optimizer.rules import (
54
54
  SplitRule as _SplitRule,
55
55
  )
56
+ from palimpzest.query.optimizer.rules import (
57
+ TopKRule as _TopKRule,
58
+ )
56
59
  from palimpzest.query.optimizer.rules import (
57
60
  TransformationRule as _TransformationRule,
58
61
  )
@@ -72,8 +75,9 @@ ALL_RULES = [
72
75
  _NonLLMFilterRule,
73
76
  _PushDownFilter,
74
77
  _RAGRule,
78
+ _RelationalJoinRule,
75
79
  _ReorderConverts,
76
- _RetrieveRule,
80
+ _TopKRule,
77
81
  _Rule,
78
82
  _SemanticAggregateRule,
79
83
  _SplitRule,
@@ -131,17 +131,17 @@ class SampleBasedCostModel:
131
131
  # compute selectivity
132
132
  selectivity = physical_op_df.passed_operator.sum() / num_source_records
133
133
 
134
+ # compute quality; if all qualities are None then this will be NaN
135
+ quality = physical_op_df.quality.mean()
136
+
137
+ # set operator stats for this physical operator
134
138
  operator_to_stats[unique_logical_op_id][full_op_id] = {
135
139
  "cost": physical_op_df.cost_per_record.mean(),
136
140
  "time": physical_op_df.time_per_record.mean(),
137
- "quality": physical_op_df.quality.mean(),
141
+ "quality": 1.0 if pd.isna(quality) else quality,
138
142
  "selectivity": selectivity,
139
143
  }
140
144
 
141
- # if this is an experiment, log the dataframe and operator_to_stats dictionary
142
- if self.exp_name is not None:
143
- operator_stats_df.to_csv(f"opt-profiling-data/{self.exp_name}-operator-stats.csv", index=False)
144
-
145
145
  logger.debug(f"Done computing operator statistics for {len(operator_to_stats)} operators!")
146
146
  return operator_to_stats
147
147
 
@@ -284,10 +284,11 @@ class Optimizer:
284
284
  all_properties["filters"] = set([op_filter_str])
285
285
 
286
286
  elif isinstance(op, JoinOp):
287
+ unique_join_str = str(sorted(op.on)) if op.condition is None else op.condition
287
288
  if "joins" in all_properties:
288
- all_properties["joins"].add(op.condition)
289
+ all_properties["joins"].add(unique_join_str)
289
290
  else:
290
- all_properties["joins"] = set([op.condition])
291
+ all_properties["joins"] = set([unique_join_str])
291
292
 
292
293
  elif isinstance(op, LimitScan):
293
294
  op_limit_str = op.get_logical_op_id()
@@ -203,9 +203,8 @@ class PhysicalPlan(Plan):
203
203
  # return the current index and the upstream unique full_op_ids for this operator
204
204
  return current_idx, self.operator.get_full_op_id(), upstream_map[this_unique_full_op_id]
205
205
 
206
- def get_upstream_unique_full_op_ids(self, topo_idx: int, operator: PhysicalOperator) -> list[str]:
207
- """Return the list of unique full_op_ids for the upstream operators of this operator."""
208
- unique_full_op_id = f"{topo_idx}-{operator.get_full_op_id()}"
206
+ def get_upstream_unique_full_op_ids(self, unique_full_op_id: str) -> list[str]:
207
+ """Return the list of unique full_op_ids for the upstream operators of the operator specified by `unique_full_op_id`."""
209
208
  return self.unique_full_op_id_to_upstream_full_op_ids[unique_full_op_id]
210
209
 
211
210
  def _compute_source_unique_full_op_ids_map(self, source_map: dict[str, list[str]], current_idx: int | None = None) -> tuple[int, str]:
@@ -19,13 +19,14 @@ from palimpzest.query.operators.aggregate import (
19
19
  MaxAggregateOp,
20
20
  MinAggregateOp,
21
21
  SemanticAggregate,
22
+ SumAggregateOp,
22
23
  )
23
24
  from palimpzest.query.operators.compute import SmolAgentsCompute
24
25
  from palimpzest.query.operators.convert import LLMConvertBonded, NonLLMConvert
25
26
  from palimpzest.query.operators.critique_and_refine import CritiqueAndRefineConvert, CritiqueAndRefineFilter
26
27
  from palimpzest.query.operators.distinct import DistinctOp
27
28
  from palimpzest.query.operators.filter import LLMFilter, NonLLMFilter
28
- from palimpzest.query.operators.join import EmbeddingJoin, NestedLoopsJoin
29
+ from palimpzest.query.operators.join import EmbeddingJoin, NestedLoopsJoin, RelationalJoin
29
30
  from palimpzest.query.operators.limit import LimitScanOp
30
31
  from palimpzest.query.operators.logical import (
31
32
  Aggregate,
@@ -39,19 +40,19 @@ from palimpzest.query.operators.logical import (
39
40
  JoinOp,
40
41
  LimitScan,
41
42
  Project,
42
- RetrieveScan,
43
43
  SearchOperator,
44
+ TopKScan,
44
45
  )
45
46
  from palimpzest.query.operators.mixture_of_agents import MixtureOfAgentsConvert, MixtureOfAgentsFilter
46
47
  from palimpzest.query.operators.physical import PhysicalOperator
47
48
  from palimpzest.query.operators.project import ProjectOp
48
49
  from palimpzest.query.operators.rag import RAGConvert, RAGFilter
49
- from palimpzest.query.operators.retrieve import RetrieveOp
50
50
  from palimpzest.query.operators.scan import ContextScanOp, MarshalAndScanDataOp
51
51
  from palimpzest.query.operators.search import (
52
52
  SmolAgentsSearch, # SmolAgentsCustomManagedSearch, # SmolAgentsManagedSearch
53
53
  )
54
54
  from palimpzest.query.operators.split import SplitConvert, SplitFilter
55
+ from palimpzest.query.operators.topk import TopKOp
55
56
  from palimpzest.query.optimizer.primitives import Expression, Group, LogicalExpression, PhysicalExpression
56
57
 
57
58
  logger = logging.getLogger(__name__)
@@ -796,26 +797,26 @@ class SplitRule(ImplementationRule):
796
797
  return cls._perform_substitution(logical_expression, phys_op_cls, runtime_kwargs, variable_op_kwargs)
797
798
 
798
799
 
799
- class RetrieveRule(ImplementationRule):
800
+ class TopKRule(ImplementationRule):
800
801
  """
801
- Substitute a logical expression for a RetrieveScan with a Retrieve physical implementation.
802
+ Substitute a logical expression for a TopKScan with a TopK physical implementation.
802
803
  """
803
804
  k_budgets = [1, 3, 5, 10, 15, 20, 25]
804
805
 
805
806
  @classmethod
806
807
  def matches_pattern(cls, logical_expression: LogicalExpression) -> bool:
807
- is_match = isinstance(logical_expression.operator, RetrieveScan)
808
- logger.debug(f"RetrieveRule matches_pattern: {is_match} for {logical_expression}")
808
+ is_match = isinstance(logical_expression.operator, TopKScan)
809
+ logger.debug(f"TopKRule matches_pattern: {is_match} for {logical_expression}")
809
810
  return is_match
810
811
 
811
812
  @classmethod
812
813
  def substitute(cls, logical_expression: LogicalExpression, **runtime_kwargs) -> set[PhysicalExpression]:
813
- logger.debug(f"Substituting RetrieveRule for {logical_expression}")
814
+ logger.debug(f"Substituting TopKRule for {logical_expression}")
814
815
 
815
816
  # create variable physical operator kwargs for each model which can implement this logical_expression
816
817
  ks = cls.k_budgets if logical_expression.operator.k == -1 else [logical_expression.operator.k]
817
818
  variable_op_kwargs = [{"k": k} for k in ks]
818
- return cls._perform_substitution(logical_expression, RetrieveOp, runtime_kwargs, variable_op_kwargs)
819
+ return cls._perform_substitution(logical_expression, TopKOp, runtime_kwargs, variable_op_kwargs)
819
820
 
820
821
 
821
822
  class NonLLMFilterRule(ImplementationRule):
@@ -867,6 +868,23 @@ class LLMFilterRule(ImplementationRule):
867
868
  return cls._perform_substitution(logical_expression, LLMFilter, runtime_kwargs, variable_op_kwargs)
868
869
 
869
870
 
871
+ class RelationalJoinRule(ImplementationRule):
872
+ """
873
+ Substitute a logical expression for a JoinOp with a RelationalJoin physical implementation.
874
+ """
875
+
876
+ @classmethod
877
+ def matches_pattern(cls, logical_expression: LogicalExpression) -> bool:
878
+ is_match = isinstance(logical_expression.operator, JoinOp) and logical_expression.operator.condition == ""
879
+ logger.debug(f"RelationalJoinRule matches_pattern: {is_match} for {logical_expression}")
880
+ return is_match
881
+
882
+ @classmethod
883
+ def substitute(cls, logical_expression: LogicalExpression, **runtime_kwargs) -> set[PhysicalExpression]:
884
+ logger.debug(f"Substituting RelationalJoinRule for {logical_expression}")
885
+ return cls._perform_substitution(logical_expression, RelationalJoin, runtime_kwargs)
886
+
887
+
870
888
  class NestedLoopsJoinRule(ImplementationRule):
871
889
  """
872
890
  Substitute a logical expression for a JoinOp with an (LLM) NestedLoopsJoin physical implementation.
@@ -874,7 +892,7 @@ class NestedLoopsJoinRule(ImplementationRule):
874
892
 
875
893
  @classmethod
876
894
  def matches_pattern(cls, logical_expression: LogicalExpression) -> bool:
877
- is_match = isinstance(logical_expression.operator, JoinOp)
895
+ is_match = isinstance(logical_expression.operator, JoinOp) and logical_expression.operator.condition != ""
878
896
  logger.debug(f"NestedLoopsJoinRule matches_pattern: {is_match} for {logical_expression}")
879
897
  return is_match
880
898
 
@@ -906,7 +924,7 @@ class EmbeddingJoinRule(ImplementationRule):
906
924
 
907
925
  @classmethod
908
926
  def matches_pattern(cls, logical_expression: LogicalExpression) -> bool:
909
- is_match = isinstance(logical_expression.operator, JoinOp) and not cls._is_audio_operation(logical_expression)
927
+ is_match = isinstance(logical_expression.operator, JoinOp) and logical_expression.operator.condition != "" and not cls._is_audio_operation(logical_expression)
910
928
  logger.debug(f"EmbeddingJoinRule matches_pattern: {is_match} for {logical_expression}")
911
929
  return is_match
912
930
 
@@ -982,6 +1000,8 @@ class AggregateRule(ImplementationRule):
982
1000
  physical_op_class = CountAggregateOp
983
1001
  elif logical_expression.operator.agg_func == AggFunc.AVERAGE:
984
1002
  physical_op_class = AverageAggregateOp
1003
+ elif logical_expression.operator.agg_func == AggFunc.SUM:
1004
+ physical_op_class = SumAggregateOp
985
1005
  elif logical_expression.operator.agg_func == AggFunc.MIN:
986
1006
  physical_op_class = MinAggregateOp
987
1007
  elif logical_expression.operator.agg_func == AggFunc.MAX:
@@ -501,8 +501,8 @@ class OptimizePhysicalExpression(Task):
501
501
 
502
502
  # compute the total cost for this physical expression by summing its operator's PlanCost
503
503
  # with the input groups' total PlanCost; also set the op_estimates for this expression's operator
504
- execution_strategy = "parallel" if execution_strategy.is_fully_parallel() else "sequential"
505
- full_plan_cost = op_plan_cost.join_add(left_input_plan_cost, right_input_plan_cost, execution_strategy)
504
+ execution_strategy_str = "parallel" if execution_strategy.is_fully_parallel() else "sequential"
505
+ full_plan_cost = op_plan_cost.join_add(left_input_plan_cost, right_input_plan_cost, execution_strategy_str)
506
506
  full_plan_cost.op_estimates = op_plan_cost.op_estimates
507
507
  all_possible_plan_costs.append((full_plan_cost, (left_input_plan_cost, right_input_plan_cost)))
508
508
 
@@ -570,8 +570,8 @@ class OptimizePhysicalExpression(Task):
570
570
 
571
571
  # compute the total cost for this physical expression by summing its operator's PlanCost
572
572
  # with the input groups' total PlanCost; also set the op_estimates for this expression's operator
573
- execution_strategy = "parallel" if execution_strategy.is_fully_parallel() else "sequential"
574
- full_plan_cost = op_plan_cost.join_add(left_best_input_plan_cost, right_best_input_plan_cost, execution_strategy)
573
+ execution_strategy_str = "parallel" if execution_strategy.is_fully_parallel() else "sequential"
574
+ full_plan_cost = op_plan_cost.join_add(left_best_input_plan_cost, right_best_input_plan_cost, execution_strategy_str)
575
575
  full_plan_cost.op_estimates = op_plan_cost.op_estimates
576
576
 
577
577
  else:
@@ -24,7 +24,7 @@ from palimpzest.query.operators.filter import LLMFilter
24
24
  from palimpzest.query.operators.join import JoinOp
25
25
  from palimpzest.query.operators.limit import LimitScanOp
26
26
  from palimpzest.query.operators.physical import PhysicalOperator
27
- from palimpzest.query.operators.retrieve import RetrieveOp
27
+ from palimpzest.query.operators.topk import TopKOp
28
28
  from palimpzest.query.optimizer.plan import PhysicalPlan, SentinelPlan
29
29
 
30
30
 
@@ -225,20 +225,22 @@ class PZProgressManager(ProgressManager):
225
225
  current_unique_full_op_id = unique_full_op_id
226
226
  next_op, next_unique_full_op_id = self.unique_full_op_id_to_next_op_and_id[unique_full_op_id]
227
227
  while next_op is not None:
228
- if not isinstance(next_op, (AggregateOp, LimitScanOp)):
229
- next_task = self.unique_full_op_id_to_task[next_unique_full_op_id]
230
- multiplier = 1
231
- if isinstance(next_op, JoinOp):
232
- # for joins, scale the delta by the number of inputs from the other side of the join
233
- left_input_unique_full_op_id, right_input_unique_input_op_id = self.unique_full_op_id_to_input_unique_full_op_ids[next_unique_full_op_id]
234
- if current_unique_full_op_id == left_input_unique_full_op_id:
235
- multiplier = self.get_task_total(right_input_unique_input_op_id)
236
- elif current_unique_full_op_id == right_input_unique_input_op_id:
237
- multiplier = self.get_task_total(left_input_unique_full_op_id)
238
- else:
239
- raise ValueError(f"Current op ID {current_unique_full_op_id} not found in join inputs {left_input_unique_full_op_id}, {right_input_unique_input_op_id}")
240
- delta_adjusted = delta * multiplier
241
- self.progress.update(next_task, total=self.get_task_total(next_unique_full_op_id) + delta_adjusted)
228
+ if isinstance(next_op, (AggregateOp, LimitScanOp)):
229
+ break
230
+
231
+ next_task = self.unique_full_op_id_to_task[next_unique_full_op_id]
232
+ multiplier = 1
233
+ if isinstance(next_op, JoinOp):
234
+ # for joins, scale the delta by the number of inputs from the other side of the join
235
+ left_input_unique_full_op_id, right_input_unique_input_op_id = self.unique_full_op_id_to_input_unique_full_op_ids[next_unique_full_op_id]
236
+ if current_unique_full_op_id == left_input_unique_full_op_id:
237
+ multiplier = self.get_task_total(right_input_unique_input_op_id)
238
+ elif current_unique_full_op_id == right_input_unique_input_op_id:
239
+ multiplier = self.get_task_total(left_input_unique_full_op_id)
240
+ else:
241
+ raise ValueError(f"Current op ID {current_unique_full_op_id} not found in join inputs {left_input_unique_full_op_id}, {right_input_unique_input_op_id}")
242
+ delta_adjusted = delta * multiplier
243
+ self.progress.update(next_task, total=self.get_task_total(next_unique_full_op_id) + delta_adjusted)
242
244
 
243
245
  # move to the next operator in the plan
244
246
  current_unique_full_op_id = next_unique_full_op_id
@@ -348,9 +350,9 @@ class PZSentinelProgressManager(ProgressManager):
348
350
  def _is_llm_op(self, physical_op: PhysicalOperator) -> bool:
349
351
  is_llm_convert = isinstance(physical_op, LLMConvert)
350
352
  is_llm_filter = isinstance(physical_op, LLMFilter)
351
- is_llm_retrieve = isinstance(physical_op, RetrieveOp) and isinstance(physical_op.index, Collection)
353
+ is_llm_topk = isinstance(physical_op, TopKOp) and isinstance(physical_op.index, Collection)
352
354
  is_llm_join = isinstance(physical_op, JoinOp)
353
- return is_llm_convert or is_llm_filter or is_llm_retrieve or is_llm_join
355
+ return is_llm_convert or is_llm_filter or is_llm_topk or is_llm_join
354
356
 
355
357
  def get_task_description(self, unique_logical_op_id: str) -> str:
356
358
  """Return the current description for the given task."""
@@ -19,7 +19,7 @@ from palimpzest.query.generators.generators import get_json_from_answer
19
19
  from palimpzest.query.operators.convert import LLMConvert
20
20
  from palimpzest.query.operators.filter import LLMFilter
21
21
  from palimpzest.query.operators.join import JoinOp
22
- from palimpzest.query.operators.retrieve import RetrieveOp
22
+ from palimpzest.query.operators.topk import TopKOp
23
23
 
24
24
 
25
25
  class Validator:
@@ -47,7 +47,7 @@ class Validator:
47
47
  def join_score_fn(self, condition: str, left_input_record: dict, right_input_record: dict, output: bool) -> float | None:
48
48
  raise NotImplementedError("Validator.join_score_fn not implemented.")
49
49
 
50
- def retrieve_score_fn(self, fields: list[str], input_record: dict, output: dict) -> float | None:
50
+ def topk_score_fn(self, fields: list[str], input_record: dict, output: dict) -> float | None:
51
51
  raise NotImplementedError("Validator.map_score_fn not implemented.")
52
52
 
53
53
  def _get_gen_stats_from_completion(self, completion, start_time: float) -> GenerationStats:
@@ -218,11 +218,11 @@ class Validator:
218
218
 
219
219
  return score, gen_stats
220
220
 
221
- def _default_retrieve_score_fn(self, op: RetrieveOp, fields: list[str], input_record: DataRecord, output: dict) -> tuple[float | None, GenerationStats]:
221
+ def _default_topk_score_fn(self, op: TopKOp, fields: list[str], input_record: DataRecord, output: dict) -> tuple[float | None, GenerationStats]:
222
222
  """
223
223
  Compute the quality of the generated output for the given fields and input_record.
224
224
  """
225
- # TODO: retrieve k=25; score each item based on relevance; compute F1
225
+ # TODO: top-k k=25; score each item based on relevance; compute F1
226
226
  # TODO: support retrieval over images
227
227
  # create prompt factory
228
228
  factory = PromptFactory(PromptStrategy.MAP, self.model, Cardinality.ONE_TO_ONE)
@@ -294,11 +294,11 @@ class Validator:
294
294
  score, gen_stats = self._default_join_score_fn(op, condition, left_input_record, right_input_record, output)
295
295
  return score, gen_stats, full_hash
296
296
 
297
- def _score_retrieve(self, op: RetrieveOp, fields: list[str], input_record: DataRecord, output: dict, full_hash: str) -> tuple[float | None, GenerationStats, str]:
297
+ def _score_topk(self, op: TopKOp, fields: list[str], input_record: DataRecord, output: dict, full_hash: str) -> tuple[float | None, GenerationStats, str]:
298
298
  try:
299
- out = self.retrieve_score_fn(fields, input_record.to_dict(), output)
299
+ out = self.topk_score_fn(fields, input_record.to_dict(), output)
300
300
  score, gen_stats = out if isinstance(out, tuple) else (out, GenerationStats())
301
301
  return score, gen_stats, full_hash
302
302
  except NotImplementedError:
303
- score, gen_stats = self._default_retrieve_score_fn(op, fields, input_record, output)
303
+ score, gen_stats = self._default_topk_score_fn(op, fields, input_record, output)
304
304
  return score, gen_stats, full_hash