palimpzest 0.5.3__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71)
  1. palimpzest/__init__.py +7 -9
  2. palimpzest/constants.py +47 -7
  3. palimpzest/core/__init__.py +20 -26
  4. palimpzest/core/data/dataclasses.py +9 -2
  5. palimpzest/core/data/datareaders.py +497 -0
  6. palimpzest/core/elements/records.py +29 -37
  7. palimpzest/core/lib/fields.py +14 -12
  8. palimpzest/core/lib/schemas.py +80 -94
  9. palimpzest/policy.py +58 -0
  10. palimpzest/prompts/__init__.py +22 -0
  11. palimpzest/prompts/code_synthesis_prompts.py +28 -0
  12. palimpzest/prompts/convert_prompts.py +87 -0
  13. palimpzest/prompts/critique_and_refine_convert_prompts.py +216 -0
  14. palimpzest/prompts/filter_prompts.py +69 -0
  15. palimpzest/prompts/moa_aggregator_convert_prompts.py +57 -0
  16. palimpzest/prompts/moa_proposer_convert_prompts.py +79 -0
  17. palimpzest/prompts/prompt_factory.py +732 -0
  18. palimpzest/prompts/util_phrases.py +14 -0
  19. palimpzest/query/execution/execution_strategy.py +0 -3
  20. palimpzest/query/execution/parallel_execution_strategy.py +12 -25
  21. palimpzest/query/execution/single_threaded_execution_strategy.py +31 -45
  22. palimpzest/query/generators/generators.py +71 -347
  23. palimpzest/query/operators/__init__.py +5 -5
  24. palimpzest/query/operators/aggregate.py +10 -5
  25. palimpzest/query/operators/code_synthesis_convert.py +4 -48
  26. palimpzest/query/operators/convert.py +5 -2
  27. palimpzest/query/operators/critique_and_refine_convert.py +112 -0
  28. palimpzest/query/operators/filter.py +1 -1
  29. palimpzest/query/operators/limit.py +1 -1
  30. palimpzest/query/operators/logical.py +28 -27
  31. palimpzest/query/operators/mixture_of_agents_convert.py +4 -1
  32. palimpzest/query/operators/physical.py +32 -20
  33. palimpzest/query/operators/project.py +1 -1
  34. palimpzest/query/operators/rag_convert.py +6 -3
  35. palimpzest/query/operators/retrieve.py +13 -31
  36. palimpzest/query/operators/scan.py +150 -0
  37. palimpzest/query/optimizer/__init__.py +5 -1
  38. palimpzest/query/optimizer/cost_model.py +18 -34
  39. palimpzest/query/optimizer/optimizer.py +40 -25
  40. palimpzest/query/optimizer/optimizer_strategy.py +26 -0
  41. palimpzest/query/optimizer/plan.py +2 -2
  42. palimpzest/query/optimizer/rules.py +118 -27
  43. palimpzest/query/processor/config.py +12 -1
  44. palimpzest/query/processor/mab_sentinel_processor.py +125 -112
  45. palimpzest/query/processor/nosentinel_processor.py +46 -62
  46. palimpzest/query/processor/query_processor.py +10 -20
  47. palimpzest/query/processor/query_processor_factory.py +12 -5
  48. palimpzest/query/processor/random_sampling_sentinel_processor.py +112 -91
  49. palimpzest/query/processor/streaming_processor.py +11 -17
  50. palimpzest/sets.py +170 -94
  51. palimpzest/tools/pdfparser.py +5 -64
  52. palimpzest/utils/datareader_helpers.py +61 -0
  53. palimpzest/utils/field_helpers.py +69 -0
  54. palimpzest/utils/hash_helpers.py +3 -2
  55. palimpzest/utils/udfs.py +0 -28
  56. {palimpzest-0.5.3.dist-info → palimpzest-0.6.0.dist-info}/METADATA +49 -49
  57. palimpzest-0.6.0.dist-info/RECORD +87 -0
  58. {palimpzest-0.5.3.dist-info → palimpzest-0.6.0.dist-info}/top_level.txt +0 -1
  59. cli/README.md +0 -156
  60. cli/__init__.py +0 -0
  61. cli/cli_main.py +0 -390
  62. palimpzest/config.py +0 -89
  63. palimpzest/core/data/datasources.py +0 -369
  64. palimpzest/datamanager/__init__.py +0 -0
  65. palimpzest/datamanager/datamanager.py +0 -300
  66. palimpzest/prompts.py +0 -397
  67. palimpzest/query/operators/datasource.py +0 -202
  68. palimpzest-0.5.3.dist-info/RECORD +0 -83
  69. palimpzest-0.5.3.dist-info/entry_points.txt +0 -2
  70. {palimpzest-0.5.3.dist-info → palimpzest-0.6.0.dist-info}/LICENSE +0 -0
  71. {palimpzest-0.5.3.dist-info → palimpzest-0.6.0.dist-info}/WHEEL +0 -0
palimpzest/query/operators/retrieve.py

@@ -9,9 +9,10 @@ from palimpzest.query.operators.physical import PhysicalOperator
 
 
 class RetrieveOp(PhysicalOperator):
-    def __init__(self, index, search_attr, output_attr, k, *args, **kwargs):
+    def __init__(self, index, search_func, search_attr, output_attr, k, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.index = index
+        self.search_func = search_func
         self.search_attr = search_attr
         self.output_attr = output_attr
         self.k = k
@@ -36,6 +37,7 @@ class RetrieveOp(PhysicalOperator):
         op_params = super().get_op_params()
         op_params = {
             "index": self.index,
+            "search_func": self.search_func,
             "search_attr": self.search_attr,
             "output_attr": self.output_attr,
             "k": self.k,
@@ -61,51 +63,31 @@ class RetrieveOp(PhysicalOperator):
 
         query = getattr(candidate, self.search_attr)
 
-        top_k_results, top_k_result_doc_ids = [], []
-        if isinstance(query, str):
-            results = self.index.search(query, k=self.k)
-            top_k_results = [result["content"] for result in results]
-
-            # This is hacky, fix this later.
-            top_k_result_doc_ids = list({result["document_id"] for result in results})
-
-        elif isinstance(query, list):
-            try:
-                # retrieve top entry for each query
-                results = self.index.search(query, k=1)
-
-                # filter for the top-k entries
-                results = [result[0] if isinstance(result, list) else result for result in results]
-                sorted_results = sorted(results, key=lambda result: result["score"], reverse=True)
-                top_k_results = [result["content"] for result in sorted_results[:self.k]]
-                top_k_result_doc_ids = [result["document_id"] for result in sorted_results[:self.k]]
-            except Exception:
-                os.makedirs("retrieve-errors", exist_ok=True)
-                ts = time.time()
-                with open(f"retrieve-errors/error-{ts}.txt", "w") as f:
-                    f.write(str(query))
-
-                top_k_results = ["error-in-retrieve"]
-                top_k_result_doc_ids = ["error-in-retrieve"]
+        try:
+            top_k_results = self.search_func(self.index, query, self.k)
+        except Exception:
+            top_k_results = ["error-in-retrieve"]
+            os.makedirs("retrieve-errors", exist_ok=True)
+            ts = time.time()
+            with open(f"retrieve-errors/error-{ts}.txt", "w") as f:
+                f.write(str(query))
 
         output_dr = DataRecord.from_parent(self.output_schema, parent_record=candidate)
         setattr(output_dr, self.output_attr, top_k_results)
-        output_dr._evidence_file_ids = top_k_result_doc_ids
 
         duration_secs = time.time() - start_time
         answer = {self.output_attr: top_k_results}
         record_state = output_dr.to_dict(include_bytes=False)
-        record_state["_evidence_file_ids"] = top_k_result_doc_ids
 
         # NOTE: right now this should be equivalent to [self.output_attr], but in the future we may
         # want to support the RetrieveOp generating multiple fields. (Also, the function will
         # return the full field name (as opposed to the short field name))
-        generated_fields = self.get_fields_to_generate(candidate, self.input_schema, self.output_schema)
+        generated_fields = self.get_fields_to_generate(candidate)
 
         record_op_stats = RecordOpStats(
             record_id=output_dr.id,
             record_parent_id=output_dr.parent_id,
-            record_source_id=output_dr.source_id,
+            record_source_idx=output_dr.source_idx,
             record_state=record_state,
             op_id=self.get_op_id(),
             logical_op_id=self.logical_op_id,
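
In 0.6.0 the RetrieveOp no longer hard-codes the string-vs-list query handling shown above; it simply calls self.search_func(self.index, query, self.k) and stores whatever list comes back on the output attribute. A minimal sketch of such a search function follows; the index object and its search(query, k) method are hypothetical stand-ins, not part of the palimpzest API.

    # Illustrative only: a caller-supplied search function matching the
    # self.search_func(self.index, query, self.k) call made by RetrieveOp.
    # The index object and its search() method are assumed, not palimpzest APIs.
    def keyword_search(index, query: str, k: int) -> list[str]:
        results = index.search(query, k=k)  # hypothetical index API
        return [result["content"] for result in results]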

palimpzest/query/operators/scan.py (new file)

@@ -0,0 +1,150 @@
+from __future__ import annotations
+
+import time
+from abc import ABC, abstractmethod
+
+from palimpzest.constants import (
+    LOCAL_SCAN_TIME_PER_KB,
+    MEMORY_SCAN_TIME_PER_KB,
+    Cardinality,
+)
+from palimpzest.core.data.dataclasses import OperatorCostEstimates, RecordOpStats
+from palimpzest.core.data.datareaders import DataReader, DirectoryReader, FileReader
+from palimpzest.core.elements.records import DataRecord, DataRecordSet
+from palimpzest.query.operators.physical import PhysicalOperator
+
+
+class ScanPhysicalOp(PhysicalOperator, ABC):
+    """
+    Physical operators which implement DataReaders require slightly more information
+    in order to accurately compute naive cost estimates. Thus, we use a slightly
+    modified abstract base class for these operators.
+    """
+
+    def __init__(self, datareader: DataReader, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.datareader = datareader
+
+    def __str__(self):
+        op = f"{self.op_name()}({self.datareader}) -> {self.output_schema}\n"
+        op += f" ({', '.join(self.output_schema.field_names())[:30]})\n"
+        return op
+
+    def get_id_params(self):
+        return super().get_id_params()
+
+    def get_op_params(self):
+        op_params = super().get_op_params()
+        return {"datareader": self.datareader, **op_params}
+
+    @abstractmethod
+    def naive_cost_estimates(
+        self,
+        source_op_cost_estimates: OperatorCostEstimates,
+        input_cardinality: Cardinality,
+        input_record_size_in_bytes: int | float,
+    ) -> OperatorCostEstimates:
+        """
+        This function returns a naive estimate of this operator's:
+        - cardinality
+        - time_per_record
+        - cost_per_record
+        - quality
+
+        For the implemented operator. These will be used by the CostModel
+        when PZ does not have sample execution data -- and it will be necessary
+        in some cases even when sample execution data is present. (For example,
+        the cardinality of each operator cannot be estimated based on sample
+        execution data alone -- thus ScanPhysicalOps need to give
+        at least ballpark correct estimates of this quantity).
+        """
+        pass
+
+    def __call__(self, idx: int) -> DataRecordSet:
+        """
+        This function invokes `self.datareader.__getitem__` on the given `idx` to retrieve the next data item.
+        It then returns this item as a DataRecord wrapped in a DataRecordSet.
+        """
+        start_time = time.time()
+        item = self.datareader[idx]
+        end_time = time.time()
+
+        # check that item covers fields in output schema
+        output_field_names = self.output_schema.field_names()
+        assert all([field in item for field in output_field_names]), f"Some fields in DataReader schema not present in item!\n - DataReader fields: {output_field_names}\n - Item fields: {list(item.keys())}"
+
+        # construct a DataRecord from the item
+        dr = DataRecord(self.output_schema, source_idx=idx)
+        for field in output_field_names:
+            setattr(dr, field, item[field])
+
+        # create RecordOpStats objects
+        record_op_stats = RecordOpStats(
+            record_id=dr.id,
+            record_parent_id=dr.parent_id,
+            record_source_idx=dr.source_idx,
+            record_state=dr.to_dict(include_bytes=False),
+            op_id=self.get_op_id(),
+            logical_op_id=self.logical_op_id,
+            op_name=self.op_name(),
+            time_per_record=(end_time - start_time),
+            cost_per_record=0.0,
+            op_details={k: str(v) for k, v in self.get_id_params().items()},
+        )
+
+        # construct and return DataRecordSet object
+        return DataRecordSet([dr], [record_op_stats])
+
+
+class MarshalAndScanDataOp(ScanPhysicalOp):
+    def naive_cost_estimates(
+        self,
+        source_op_cost_estimates: OperatorCostEstimates,
+        input_record_size_in_bytes: int | float,
+    ) -> OperatorCostEstimates:
+        # get inputs needed for naive cost estimation
+        # TODO: we should rename cardinality --> "multiplier" or "selectivity" one-to-one / one-to-many
+
+        # estimate time spent reading each record
+        per_record_size_kb = input_record_size_in_bytes / 1024.0
+        time_per_record = (
+            LOCAL_SCAN_TIME_PER_KB * per_record_size_kb
+            if isinstance(self.datareader, (DirectoryReader, FileReader))
+            else MEMORY_SCAN_TIME_PER_KB * per_record_size_kb
+        )
+
+        # estimate output cardinality
+        cardinality = source_op_cost_estimates.cardinality
+
+        # for now, assume no cost per record for reading data
+        return OperatorCostEstimates(
+            cardinality=cardinality,
+            time_per_record=time_per_record,
+            cost_per_record=0,
+            quality=1.0,
+        )
+
+
+class CacheScanDataOp(ScanPhysicalOp):
+    def naive_cost_estimates(
+        self,
+        source_op_cost_estimates: OperatorCostEstimates,
+        input_record_size_in_bytes: int | float,
+    ):
+        # get inputs needed for naive cost estimation
+        # TODO: we should rename cardinality --> "multiplier" or "selectivity" one-to-one / one-to-many
+
+        # estimate time spent reading each record
+        per_record_size_kb = input_record_size_in_bytes / 1024.0
+        time_per_record = LOCAL_SCAN_TIME_PER_KB * per_record_size_kb
+
+        # estimate output cardinality
+        cardinality = source_op_cost_estimates.cardinality
+
+        # for now, assume no cost per record for reading from cache
+        return OperatorCostEstimates(
+            cardinality=cardinality,
+            time_per_record=time_per_record,
+            cost_per_record=0,
+            quality=1.0,
+        )
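
ScanPhysicalOp.__call__ pulls one item per index from the DataReader: datareader[idx] must return a dict whose keys cover the output schema's field names, or the assertion above fails. A rough sketch of an object satisfying that contract follows; it only implements the __len__/__getitem__ protocol exercised above, and the field names ("title", "text") are placeholders rather than a real palimpzest schema.

    # Illustrative only: an in-memory, dict-per-item source of the shape
    # ScanPhysicalOp.__call__ expects. Subclassing the real DataReader base
    # class (and its constructor) is intentionally not shown here.
    class InMemoryItems:
        def __init__(self, rows: list[dict]):
            self.rows = rows

        def __len__(self) -> int:
            return len(self.rows)

        def __getitem__(self, idx: int) -> dict:
            return self.rows[idx]

    items = InMemoryItems([{"title": "a", "text": "alpha"}, {"title": "b", "text": "beta"}])
    assert all(field in items[0] for field in ["title", "text"])  # mirrors the schema check in __call__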

palimpzest/query/optimizer/__init__.py

@@ -10,6 +10,9 @@ from palimpzest.query.optimizer.rules import (
 from palimpzest.query.optimizer.rules import (
     CodeSynthesisConvertSingleRule as _CodeSynthesisConvertSingleRule,
 )
+from palimpzest.query.optimizer.rules import (
+    CriticAndRefineConvertRule as _CriticAndRefineConvertRule,
+)
 from palimpzest.query.optimizer.rules import (
     ImplementationRule as _ImplementationRule,
 )
@@ -64,6 +67,7 @@ ALL_RULES = [
     _BasicSubstitutionRule,
     _CodeSynthesisConvertRule,
     _CodeSynthesisConvertSingleRule,
+    _CriticAndRefineConvertRule,
     _ImplementationRule,
     _LLMConvertBondedRule,
     _LLMConvertConventionalRule,
@@ -86,7 +90,7 @@ IMPLEMENTATION_RULES = [
     rule
     for rule in ALL_RULES
     if issubclass(rule, _ImplementationRule)
-    and rule not in [_CodeSynthesisConvertRule, _ImplementationRule, _LLMConvertRule, _RAGConvertRule, _TokenReducedConvertRule]
+    and rule not in [_CodeSynthesisConvertRule, _ImplementationRule, _LLMConvertRule, _TokenReducedConvertRule]
 ]
 
 TRANSFORMATION_RULES = [

palimpzest/query/optimizer/cost_model.py

@@ -14,18 +14,17 @@ from typing import Any
 import pandas as pd
 import scipy.stats as stats
 
-from palimpzest.constants import MODEL_CARDS, GPT_4o_MODEL_CARD, Model
+from palimpzest.constants import MODEL_CARDS, NAIVE_BYTES_PER_RECORD, GPT_4o_MODEL_CARD, Model
 from palimpzest.core.data.dataclasses import OperatorCostEstimates, PlanCost, RecordOpStats
 from palimpzest.core.elements.records import DataRecordSet
-from palimpzest.datamanager.datamanager import DataDirectory
 from palimpzest.query.operators.aggregate import ApplyGroupByOp, AverageAggregateOp, CountAggregateOp
 from palimpzest.query.operators.code_synthesis_convert import CodeSynthesisConvert
 from palimpzest.query.operators.convert import LLMConvert
-from palimpzest.query.operators.datasource import CacheScanDataOp, DataSourcePhysicalOp, MarshalAndScanDataOp
 from palimpzest.query.operators.filter import LLMFilter, NonLLMFilter
 from palimpzest.query.operators.limit import LimitScanOp
 from palimpzest.query.operators.physical import PhysicalOperator
 from palimpzest.query.operators.rag_convert import RAGConvert
+from palimpzest.query.operators.scan import CacheScanDataOp, MarshalAndScanDataOp, ScanPhysicalOp
 from palimpzest.query.operators.token_reduction_convert import TokenReducedConvert
 from palimpzest.query.optimizer.plan import SentinelPlan
 from palimpzest.utils.model_helpers import get_champion_model_name, get_models
@@ -89,10 +88,6 @@ class SampleBasedCostModel:
             for phys_op_id, _ in phys_op_id_to_stats.items()
         ])
 
-        # reference to data directory
-        self.datadir = DataDirectory()
-
-        # import pdb; pdb.set_trace()
 
     def get_costed_phys_op_ids(self):
         return self.costed_phys_op_ids
@@ -131,9 +126,9 @@ class SampleBasedCostModel:
                 "time_per_record": record_op_stats.time_per_record,
                 "quality": record_op_stats.quality,
                 "passed_operator": record_op_stats.passed_operator,
-                "source_id": record_op_stats.record_source_id,  # TODO: remove
-                "op_details": record_op_stats.op_details,  # TODO: remove
-                "answer": record_op_stats.answer,  # TODO: remove
+                "source_idx": record_op_stats.record_source_idx,  # TODO: remove
+                "op_details": record_op_stats.op_details,  # TODO: remove
+                "answer": record_op_stats.answer,  # TODO: remove
             }
             execution_record_op_stats.append(record_op_stats_dict)
 
@@ -189,14 +184,13 @@ class SampleBasedCostModel:
         est_quality = self.operator_to_stats[logical_op_id][phys_op_id]["quality"]
         est_selectivity = self.operator_to_stats[logical_op_id][phys_op_id]["selectivity"]
 
-        # create source_op_estimates for datasources if they are not provided
-        if isinstance(operator, DataSourcePhysicalOp):
-            # get handle to DataSource and pre-compute its size (number of records)
-            datasource = operator.get_datasource()
-            datasource_len = len(datasource)
+        # create source_op_estimates for scan operators if they are not provided
+        if isinstance(operator, ScanPhysicalOp):
+            # get handle to scan operator and pre-compute its size (number of records)
+            datareader_len = len(operator.datareader)
 
             source_op_estimates = OperatorCostEstimates(
-                cardinality=datasource_len,
+                cardinality=datareader_len,
                 time_per_record=0.0,
                 cost_per_record=0.0,
                 quality=1.0,
@@ -245,9 +239,6 @@ class CostModel(BaseCostModel):
         # df contains a column called record_state, that sometimes contain a dict
        # we want to extract the keys from the dict and create a new column for each key
 
-        # reference to data directory
-        self.datadir = DataDirectory()
-
         # set available models
         self.available_models = available_models
 
@@ -610,36 +601,29 @@
 
        # initialize estimates of operator metrics based on naive (but sometimes precise) logic
        if isinstance(operator, MarshalAndScanDataOp):
-            # get handle to DataSource and pre-compute its size (number of records)
-            datasource = operator.get_datasource()
-            dataset_type = operator.get_datasource_type()
-            datasource_len = len(datasource)
-            datasource_memsize = datasource.get_size()
+            # get handle to scan operator and pre-compute its size (number of records)
+            datareader_len = len(operator.datareader)
 
            source_op_estimates = OperatorCostEstimates(
-                cardinality=datasource_len,
+                cardinality=datareader_len,
                time_per_record=0.0,
                cost_per_record=0.0,
                quality=1.0,
            )
 
-            op_estimates = operator.naive_cost_estimates(source_op_estimates,
-                                                         input_record_size_in_bytes=datasource_memsize/datasource_len,
-                                                         dataset_type=dataset_type)
+            op_estimates = operator.naive_cost_estimates(source_op_estimates, input_record_size_in_bytes=NAIVE_BYTES_PER_RECORD)
 
        elif isinstance(operator, CacheScanDataOp):
-            datasource = operator.get_datasource()
-            datasource_len = len(datasource)
-            datasource_memsize = datasource.get_size()
+            datareader_len = len(operator.datareader)
 
            source_op_estimates = OperatorCostEstimates(
-                cardinality=datasource_len,
+                cardinality=datareader_len,
                time_per_record=0.0,
                cost_per_record=0.0,
                quality=1.0,
            )
 
-            op_estimates = operator.naive_cost_estimates(source_op_estimates, input_record_size_in_bytes=datasource_memsize/datasource_len)
+            op_estimates = operator.naive_cost_estimates(source_op_estimates, input_record_size_in_bytes=NAIVE_BYTES_PER_RECORD)
 
        else:
            op_estimates = operator.naive_cost_estimates(source_op_estimates)
@@ -660,7 +644,7 @@
            # NOTE: this cardinality is the only cardinality we estimate directly b/c we can observe how many groups are
            # produced by the groupby in our sample and assume it may generalize to the full workload. To estimate
            # actual cardinalities of operators we estimate their selectivities / fan-outs and multiply those by
-            # the input cardinality (where the initial input cardinality from the datasource is known).
+            # the input cardinality (where the initial input cardinality from the datareader is known).
            op_estimates.cardinality = sample_op_estimates[op_id]["cardinality"]
            op_estimates.cardinality_lower_bound = op_estimates.cardinality
            op_estimates.cardinality_upper_bound = op_estimates.cardinality

palimpzest/query/optimizer/optimizer.py

@@ -3,14 +3,12 @@ from __future__ import annotations
 from copy import deepcopy
 
 from palimpzest.constants import Model
-from palimpzest.core.data.datasources import DataSource
+from palimpzest.core.data.datareaders import DataReader
 from palimpzest.core.lib.fields import Field
-from palimpzest.datamanager.datamanager import DataDirectory
 from palimpzest.policy import Policy
 from palimpzest.query.operators.logical import (
     Aggregate,
     BaseScan,
-    CacheScan,
     ConvertScan,
     FilteredScan,
     GroupByAggregate,
@@ -32,6 +30,7 @@ from palimpzest.query.optimizer.plan import PhysicalPlan
 from palimpzest.query.optimizer.primitives import Group, LogicalExpression
 from palimpzest.query.optimizer.rules import (
     CodeSynthesisConvertRule,
+    CriticAndRefineConvertRule,
     LLMConvertBondedRule,
     LLMConvertConventionalRule,
     MixtureOfAgentsConvertRule,
@@ -48,9 +47,18 @@ from palimpzest.query.optimizer.tasks import (
     OptimizePhysicalExpression,
 )
 from palimpzest.sets import Dataset, Set
+from palimpzest.utils.hash_helpers import hash_for_serialized_dict
 from palimpzest.utils.model_helpers import get_champion_model, get_code_champion_model, get_conventional_fallback_model
 
 
+def get_node_uid(node: Dataset | DataReader) -> str:
+    """Helper function to compute the universal identifier for a node in the query plan."""
+    # NOTE: technically, hash_for_serialized_dict(node.serialize()) would be valid for both DataReader and Dataset;
+    # for the moment, I want to be explicit in Dataset about what constitutes a unique Dataset object, but
+    # in ther future we may be able to remove universal_identifier() from Dataset and just use this function
+    return node.universal_identifier() if isinstance(node, Dataset) else hash_for_serialized_dict(node.serialize())
+
+
 class Optimizer:
     """
     The optimizer is responsible for searching the space of possible physical plans
@@ -85,8 +93,9 @@ class Optimizer:
         allow_conventional_query: bool = False,
         allow_code_synth: bool = False,
         allow_token_reduction: bool = False,
-        allow_rag_reduction: bool = True,
+        allow_rag_reduction: bool = False,
         allow_mixtures: bool = True,
+        allow_critic: bool = False,
         optimization_strategy_type: OptimizationStrategyType = OptimizationStrategyType.PARETO,
         use_final_op_quality: bool = False,  # TODO: make this func(plan) -> final_quality
     ):
@@ -129,6 +138,7 @@ class Optimizer:
             self.allow_token_reduction = False
             self.allow_rag_reduction = False
             self.allow_mixtures = False
+            self.allow_critic = False
             self.available_models = [available_models[0]]
 
         # store optimization hyperparameters
@@ -141,6 +151,7 @@ class Optimizer:
         self.allow_token_reduction = allow_token_reduction
         self.allow_rag_reduction = allow_rag_reduction
         self.allow_mixtures = allow_mixtures
+        self.allow_critic = allow_critic
         self.optimization_strategy_type = optimization_strategy_type
         self.use_final_op_quality = use_final_op_quality
 
@@ -180,10 +191,14 @@ class Optimizer:
                 if not issubclass(rule, MixtureOfAgentsConvertRule)
             ]
 
+        if not self.allow_critic:
+            self.implementation_rules = [
+                rule for rule in self.implementation_rules if not issubclass(rule, CriticAndRefineConvertRule)
+            ]
+
     def update_cost_model(self, cost_model: CostModel):
         self.cost_model = cost_model
 
-
     def get_physical_op_params(self):
         return {
             "verbose": self.verbose,
@@ -214,20 +229,22 @@ class Optimizer:
         self.strategy = OptimizerStrategyRegistry.get_strategy(optimizer_strategy_type.value)
 
     def construct_group_tree(self, dataset_nodes: list[Set]) -> tuple[list[int], dict[str, Field], dict[str, set[str]]]:
-        # get node, output_schema, and input_schema(if applicable)
+        # get node, output_schema, and input_schema (if applicable)
         node = dataset_nodes[-1]
         output_schema = node.schema
         input_schema = dataset_nodes[-2].schema if len(dataset_nodes) > 1 else None
-
+
         ### convert node --> Group ###
-        uid = node.universal_identifier()
+        uid = get_node_uid(node)
 
         # create the op for the given node
         op: LogicalOperator | None = None
-        if not self.no_cache and DataDirectory().has_cached_answer(uid):
-            op = CacheScan(dataset_id=uid, input_schema=None, output_schema=output_schema)
-        elif isinstance(node, DataSource):
-            op = BaseScan(dataset_id=uid, output_schema=output_schema)
+
+        # TODO: add cache scan when we add caching back to PZ
+        # if not self.no_cache:
+        #     op = CacheScan(datareader=node, output_schema=output_schema)
+        if isinstance(node, DataReader):
+            op = BaseScan(datareader=node, output_schema=output_schema)
         elif node._filter is not None:
             op = FilteredScan(
                 input_schema=input_schema,
@@ -269,6 +286,7 @@
                 input_schema=input_schema,
                 output_schema=output_schema,
                 index=node._index,
+                search_func=node._search_func,
                 search_attr=node._search_attr,
                 output_attr=node._output_attr,
                 k=node._k,
@@ -283,6 +301,9 @@
                 depends_on=node._depends_on,
                 target_cache_id=uid,
             )
+        # some legacy plans may have a useless convert; for now we simply skip it
+        elif output_schema == input_schema:
+            return self.construct_group_tree(dataset_nodes[:-1]) if len(dataset_nodes) > 1 else ([], {}, {})
         else:
             raise NotImplementedError(
                 f"""No logical operator exists for the specified dataset construction.
@@ -306,7 +327,7 @@
         # compute the set of (short) field names this operation depends on
         depends_on_field_names = (
            {}
-            if isinstance(node, DataSource)
+            if isinstance(node, DataReader)
            else {field_name.split(".")[-1] for field_name in node._depends_on}
        )
 
@@ -359,28 +380,22 @@
 
     def convert_query_plan_to_group_tree(self, query_plan: Dataset) -> str:
         # Obtain ordered list of datasets
-        dataset_nodes = []
-        node = query_plan.copy()
+        dataset_nodes: list[Dataset | DataReader] = []
+        node = deepcopy(query_plan)
 
+        # NOTE: the very first node will be a DataReader; the rest will be Dataset
         while isinstance(node, Dataset):
            dataset_nodes.append(node)
            node = node._source
        dataset_nodes.append(node)
        dataset_nodes = list(reversed(dataset_nodes))
 
-        # remove unnecessary convert if output schema from data source scan matches
-        # input schema for the next operator
-        if len(dataset_nodes) > 1 and dataset_nodes[0].schema.get_desc() == dataset_nodes[1].schema.get_desc():
-            dataset_nodes = [dataset_nodes[0]] + dataset_nodes[2:]
-            if len(dataset_nodes) > 1:
-                dataset_nodes[1]._source = dataset_nodes[0]
-
        # compute depends_on field for every node
        short_to_full_field_name = {}
        for node_idx, node in enumerate(dataset_nodes):
            # update mapping from short to full field names
            short_field_names = node.schema.field_names()
-            full_field_names = node.schema.field_names(unique=True, id=node.universal_identifier())
+            full_field_names = node.schema.field_names(unique=True, id=get_node_uid(node))
            for short_field_name, full_field_name in zip(short_field_names, full_field_names):
                # set mapping automatically if this is a new field
                if short_field_name not in short_to_full_field_name or (
@@ -389,7 +404,7 @@
                    short_to_full_field_name[short_field_name] = full_field_name
 
            # if the node is a data source, then skip
-            if isinstance(node, DataSource):
+            if isinstance(node, DataReader):
                continue
 
            # If the node already has depends_on specified, then resolve each field name to a full (unique) field name
@@ -400,7 +415,7 @@
            # otherwise, make the node depend on all upstream nodes
            node._depends_on = set()
            for upstream_node in dataset_nodes[:node_idx]:
-                node._depends_on.update(upstream_node.schema.field_names(unique=True, id=upstream_node.universal_identifier()))
+                node._depends_on.update(upstream_node.schema.field_names(unique=True, id=get_node_uid(upstream_node)))
            node._depends_on = list(node._depends_on)
 
        # construct tree of groups
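
Alongside the DataReader plumbing, this diff also changes the Optimizer's tuning flags: allow_rag_reduction now defaults to False, and a new allow_critic flag (default False) gates the critique-and-refine convert implementations; when it is False, every CriticAndRefineConvertRule subclass is pruned from implementation_rules. A sketch of just those keyword arguments, with the remaining constructor arguments omitted:

    # Illustrative only: the optimization flags touched by this diff, written as
    # keyword arguments one might pass when constructing the Optimizer.
    optimizer_flags = {
        "allow_rag_reduction": False,  # default flipped from True in 0.5.3
        "allow_mixtures": True,        # default unchanged
        "allow_critic": True,          # new in 0.6.0; keeps CriticAndRefineConvertRule implementations
    }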

palimpzest/query/optimizer/optimizer_strategy.py

@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 from abc import ABC, abstractmethod
+from copy import deepcopy
 from enum import Enum
 
 from palimpzest.policy import Policy
@@ -31,6 +32,31 @@ class OptimizationStrategy(ABC):
         """Factory method to create strategy instances"""
         return OptimizerStrategyRegistry.get_strategy(strategy_type)
 
+    def normalize_final_plans(self, plans: list[PhysicalPlan]) -> list[PhysicalPlan]:
+        """
+        For each plan in `plans`, this function enforces that the input schema of every
+        operator is the output schema of the previous operator in the plan.
+
+        Args:
+            plans list[PhysicalPlan]: list of physical plans to normalize
+
+        Returns:
+            list[PhysicalPlan]: list of normalized physical plans
+        """
+        normalized_plans = []
+        for plan in plans:
+            normalized_ops = []
+            for idx, op in enumerate(plan.operators):
+                op_copy = deepcopy(op)
+                if idx == 0:
+                    normalized_ops.append(op_copy)
+                else:
+                    op_copy.input_schema = plan.operators[-1].output_schema
+                    normalized_ops.append(op_copy)
+            normalized_plans.append(PhysicalPlan(operators=normalized_ops, plan_cost=plan.plan_cost))
+
+        return normalized_plans
+
 
 class GreedyStrategy(OptimizationStrategy):
     def _get_greedy_physical_plan(self, groups: dict, group_id: int) -> PhysicalPlan:

palimpzest/query/optimizer/plan.py

@@ -3,8 +3,8 @@ from __future__ import annotations
 from abc import ABC, abstractmethod
 
 from palimpzest.core.data.dataclasses import PlanCost
-from palimpzest.query.operators.datasource import DataSourcePhysicalOp
 from palimpzest.query.operators.physical import PhysicalOperator
+from palimpzest.query.operators.scan import ScanPhysicalOp
 from palimpzest.utils.hash_helpers import hash_for_id
 
 
@@ -100,7 +100,7 @@ class SentinelPlan(Plan):
     def __init__(self, operator_sets: list[list[PhysicalOperator]]):
         # enforce that first operator_set is a scan and that every operator_set has at least one operator
         if len(operator_sets) > 0:
-            assert isinstance(operator_sets[0][0], DataSourcePhysicalOp), "first operator set must be a scan"
+            assert isinstance(operator_sets[0][0], ScanPhysicalOp), "first operator set must be a scan"
             assert all(len(op_set) > 0 for op_set in operator_sets), "every operator set must have at least one operator"
 
         # store operator_sets and logical_op_ids; sort operator_sets internally by op_id