palimpzest 0.8.1__py3-none-any.whl → 0.8.3__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their public registries; it is provided for informational purposes only.
Files changed (61)
  1. palimpzest/constants.py +38 -62
  2. palimpzest/core/data/dataset.py +1 -1
  3. palimpzest/core/data/iter_dataset.py +5 -5
  4. palimpzest/core/elements/groupbysig.py +1 -1
  5. palimpzest/core/elements/records.py +91 -109
  6. palimpzest/core/lib/schemas.py +23 -0
  7. palimpzest/core/models.py +3 -3
  8. palimpzest/prompts/__init__.py +2 -6
  9. palimpzest/prompts/convert_prompts.py +10 -66
  10. palimpzest/prompts/critique_and_refine_prompts.py +66 -0
  11. palimpzest/prompts/filter_prompts.py +8 -46
  12. palimpzest/prompts/join_prompts.py +12 -75
  13. palimpzest/prompts/{moa_aggregator_convert_prompts.py → moa_aggregator_prompts.py} +51 -2
  14. palimpzest/prompts/moa_proposer_prompts.py +87 -0
  15. palimpzest/prompts/prompt_factory.py +351 -479
  16. palimpzest/prompts/split_merge_prompts.py +51 -2
  17. palimpzest/prompts/split_proposer_prompts.py +48 -16
  18. palimpzest/prompts/utils.py +109 -0
  19. palimpzest/query/execution/all_sample_execution_strategy.py +1 -1
  20. palimpzest/query/execution/execution_strategy.py +4 -4
  21. palimpzest/query/execution/mab_execution_strategy.py +47 -23
  22. palimpzest/query/execution/parallel_execution_strategy.py +3 -3
  23. palimpzest/query/execution/single_threaded_execution_strategy.py +8 -8
  24. palimpzest/query/generators/generators.py +31 -17
  25. palimpzest/query/operators/__init__.py +15 -2
  26. palimpzest/query/operators/aggregate.py +21 -19
  27. palimpzest/query/operators/compute.py +6 -8
  28. palimpzest/query/operators/convert.py +12 -37
  29. palimpzest/query/operators/critique_and_refine.py +194 -0
  30. palimpzest/query/operators/distinct.py +7 -7
  31. palimpzest/query/operators/filter.py +13 -25
  32. palimpzest/query/operators/join.py +321 -192
  33. palimpzest/query/operators/limit.py +4 -4
  34. palimpzest/query/operators/mixture_of_agents.py +246 -0
  35. palimpzest/query/operators/physical.py +25 -2
  36. palimpzest/query/operators/project.py +4 -4
  37. palimpzest/query/operators/{rag_convert.py → rag.py} +202 -5
  38. palimpzest/query/operators/retrieve.py +10 -9
  39. palimpzest/query/operators/scan.py +9 -10
  40. palimpzest/query/operators/search.py +18 -24
  41. palimpzest/query/operators/split.py +321 -0
  42. palimpzest/query/optimizer/__init__.py +12 -8
  43. palimpzest/query/optimizer/optimizer.py +12 -10
  44. palimpzest/query/optimizer/rules.py +201 -108
  45. palimpzest/query/optimizer/tasks.py +18 -6
  46. palimpzest/query/processor/config.py +2 -2
  47. palimpzest/query/processor/query_processor.py +2 -2
  48. palimpzest/query/processor/query_processor_factory.py +9 -5
  49. palimpzest/validator/validator.py +7 -9
  50. {palimpzest-0.8.1.dist-info → palimpzest-0.8.3.dist-info}/METADATA +3 -8
  51. palimpzest-0.8.3.dist-info/RECORD +95 -0
  52. palimpzest/prompts/critique_and_refine_convert_prompts.py +0 -216
  53. palimpzest/prompts/moa_proposer_convert_prompts.py +0 -75
  54. palimpzest/prompts/util_phrases.py +0 -19
  55. palimpzest/query/operators/critique_and_refine_convert.py +0 -113
  56. palimpzest/query/operators/mixture_of_agents_convert.py +0 -140
  57. palimpzest/query/operators/split_convert.py +0 -170
  58. palimpzest-0.8.1.dist-info/RECORD +0 -95
  59. {palimpzest-0.8.1.dist-info → palimpzest-0.8.3.dist-info}/WHEEL +0 -0
  60. {palimpzest-0.8.1.dist-info → palimpzest-0.8.3.dist-info}/licenses/LICENSE +0 -0
  61. {palimpzest-0.8.1.dist-info → palimpzest-0.8.3.dist-info}/top_level.txt +0 -0
palimpzest/query/operators/search.py

@@ -91,17 +91,15 @@ class SmolAgentsSearch(PhysicalOperator):
         Given an input DataRecord and a determination of whether it passed the filter or not,
         construct the resulting RecordSet.
         """
-        # create new DataRecord and set passed_operator attribute
-        dr = DataRecord.from_parent(self.output_schema, parent_record=candidate)
-        for field in self.output_schema.model_fields:
-            if field in answer:
-                dr[field] = answer[field]
+        # create new DataRecord
+        data_item = {field: answer[field] for field in self.output_schema.model_fields if field in answer}
+        dr = DataRecord.from_parent(self.output_schema, data_item, parent_record=candidate)
 
         # create RecordOpStats object
         record_op_stats = RecordOpStats(
-            record_id=dr.id,
-            record_parent_ids=dr.parent_ids,
-            record_source_indices=dr.source_indices,
+            record_id=dr._id,
+            record_parent_ids=dr._parent_ids,
+            record_source_indices=dr._source_indices,
             record_state=dr.to_dict(include_bytes=False),
             full_op_id=self.get_full_op_id(),
             logical_op_id=self.logical_op_id,
@@ -248,17 +246,15 @@ class SmolAgentsSearch(PhysicalOperator):
     #     Given an input DataRecord and a determination of whether it passed the filter or not,
     #     construct the resulting RecordSet.
     #     """
-    #     # create new DataRecord and set passed_operator attribute
-    #     dr = DataRecord.from_parent(self.output_schema, parent_record=candidate)
-    #     for field in self.output_schema.model_fields:
-    #         if field in answer:
-    #             dr[field] = answer[field]
+    #     # create new DataRecord
+    #     data_item = {field: answer[field] for field in self.output_schema.model_fields if field in answer}
+    #     dr = DataRecord.from_parent(self.output_schema, data_item, parent_record=candidate)
 
     #     # create RecordOpStats object
     #     record_op_stats = RecordOpStats(
-    #         record_id=dr.id,
-    #         record_parent_ids=dr.parent_ids,
-    #         record_source_indices=dr.source_indices,
+    #         record_id=dr._id,
+    #         record_parent_ids=dr._parent_ids,
+    #         record_source_indices=dr._source_indices,
     #         record_state=dr.to_dict(include_bytes=False),
     #         full_op_id=self.get_full_op_id(),
     #         logical_op_id=self.logical_op_id,
@@ -440,17 +436,15 @@ class SmolAgentsSearch(PhysicalOperator):
     #     Given an input DataRecord and a determination of whether it passed the filter or not,
     #     construct the resulting RecordSet.
    #     """
-    #     # create new DataRecord and set passed_operator attribute
-    #     dr = DataRecord.from_parent(self.output_schema, parent_record=candidate)
-    #     for field in self.output_schema.model_fields:
-    #         if field in answer:
-    #             dr[field] = answer[field]
+    #     # create new DataRecord
+    #     data_item = {field: answer[field] for field in self.output_schema.model_fields if field in answer}
+    #     dr = DataRecord.from_parent(self.output_schema, data_item, parent_record=candidate)
 
     #     # create RecordOpStats object
     #     record_op_stats = RecordOpStats(
-    #         record_id=dr.id,
-    #         record_parent_ids=dr.parent_ids,
-    #         record_source_indices=dr.source_indices,
+    #         record_id=dr._id,
+    #         record_parent_ids=dr._parent_ids,
+    #         record_source_indices=dr._source_indices,
     #         record_state=dr.to_dict(include_bytes=False),
     #         full_op_id=self.get_full_op_id(),
     #         logical_op_id=self.logical_op_id,
palimpzest/query/operators/split.py

@@ -0,0 +1,321 @@
+from __future__ import annotations
+
+import math
+
+from pydantic.fields import FieldInfo
+
+from palimpzest.constants import (
+    MODEL_CARDS,
+    NAIVE_EST_NUM_INPUT_TOKENS,
+    NAIVE_EST_NUM_OUTPUT_TOKENS,
+    Cardinality,
+    PromptStrategy,
+)
+from palimpzest.core.elements.records import DataRecord
+from palimpzest.core.models import GenerationStats, OperatorCostEstimates
+from palimpzest.query.generators.generators import Generator
+from palimpzest.query.operators.convert import LLMConvert
+from palimpzest.query.operators.filter import LLMFilter
+
+
+class SplitConvert(LLMConvert):
+    def __init__(self, num_chunks: int = 2, min_size_to_chunk: int = 1000, *args, **kwargs):
+        kwargs["prompt_strategy"] = None
+        super().__init__(*args, **kwargs)
+        self.num_chunks = num_chunks
+        self.min_size_to_chunk = min_size_to_chunk
+        self.split_generator = Generator(self.model, PromptStrategy.MAP_SPLIT_PROPOSER, self.reasoning_effort, self.api_base, self.cardinality, self.desc, self.verbose)
+        self.split_merge_generator = Generator(self.model, PromptStrategy.MAP_SPLIT_MERGER, self.reasoning_effort, self.api_base, self.cardinality, self.desc, self.verbose)
+
+        # crude adjustment factor for naive estimation in unoptimized setting
+        self.naive_quality_adjustment = 0.6
+
+    def __str__(self):
+        op = super().__str__()
+        op += f" Chunk Size: {str(self.num_chunks)}\n"
+        op += f" Min Size to Chunk: {str(self.min_size_to_chunk)}\n"
+        return op
+
+    def get_id_params(self):
+        id_params = super().get_id_params()
+        id_params = {"num_chunks": self.num_chunks, "min_size_to_chunk": self.min_size_to_chunk, **id_params}
+
+        return id_params
+
+    def get_op_params(self):
+        op_params = super().get_op_params()
+        return {"num_chunks": self.num_chunks, "min_size_to_chunk": self.min_size_to_chunk, **op_params}
+
+    def naive_cost_estimates(self, source_op_cost_estimates: OperatorCostEstimates) -> OperatorCostEstimates:
+        """
+        Update the cost per record and quality estimates produced by LLMConvert's naive estimates.
+        We adjust the cost per record to account for the reduced number of input tokens following
+        the retrieval of relevant chunks, and we make a crude estimate of the quality degradation
+        that results from using a downsized input (although this may in fact improve quality in
+        some cases).
+        """
+        # get naive cost estimates from LLMConvert
+        naive_op_cost_estimates = super().naive_cost_estimates(source_op_cost_estimates)
+
+        # re-compute cost per record assuming we use fewer input tokens; naively assume a single input field
+        est_num_input_tokens = NAIVE_EST_NUM_INPUT_TOKENS
+        est_num_output_tokens = NAIVE_EST_NUM_OUTPUT_TOKENS
+        model_conversion_usd_per_record = (
+            MODEL_CARDS[self.model.value]["usd_per_input_token"] * est_num_input_tokens
+            + MODEL_CARDS[self.model.value]["usd_per_output_token"] * est_num_output_tokens
+        )
+
+        # set refined estimate of cost per record
+        naive_op_cost_estimates.cost_per_record = model_conversion_usd_per_record
+        naive_op_cost_estimates.cost_per_record_lower_bound = naive_op_cost_estimates.cost_per_record
+        naive_op_cost_estimates.cost_per_record_upper_bound = naive_op_cost_estimates.cost_per_record
+        naive_op_cost_estimates.quality = (naive_op_cost_estimates.quality) * self.naive_quality_adjustment
+        naive_op_cost_estimates.quality_lower_bound = naive_op_cost_estimates.quality
+        naive_op_cost_estimates.quality_upper_bound = naive_op_cost_estimates.quality
+
+        return naive_op_cost_estimates
+
+    def get_text_chunks(self, text: str, num_chunks: int) -> list[str]:
+        """
+        Given a text string, chunk it into num_chunks substrings of roughly equal size.
+        """
+        chunks = []
+
+        idx, chunk_size = 0, math.ceil(len(text) / num_chunks)
+        while idx + chunk_size < len(text):
+            chunks.append(text[idx : idx + chunk_size])
+            idx += chunk_size
+
+        if idx < len(text):
+            chunks.append(text[idx:])
+
+        return chunks
+
+    def get_chunked_candidate(self, candidate: DataRecord, input_fields: list[str]) -> list[DataRecord]:
+        """
+        For each text field, chunk the content. If a field is smaller than the chunk size,
+        simply include the full field.
+        """
+        # compute mapping from each field to its chunked content
+        field_name_to_chunked_content = {}
+        for field_name in input_fields:
+            field = candidate.get_field_type(field_name)
+            content = candidate[field_name]
+
+            # do not chunk this field if it is not a string or a list of strings
+            is_string_field = field.annotation in [str, str | None]
+            is_list_string_field = field.annotation in [list[str], list[str] | None]
+            if not (is_string_field or is_list_string_field):
+                field_name_to_chunked_content[field_name] = [content]
+                continue
+
+            # if this is a list of strings, join the strings
+            if is_list_string_field:
+                content = "[" + ", ".join(content) + "]"
+
+            # skip this field if its length is less than the min size to chunk
+            if len(content) < self.min_size_to_chunk:
+                field_name_to_chunked_content[field_name] = [content]
+                continue
+
+            # chunk the content
+            field_name_to_chunked_content[field_name] = self.get_text_chunks(content, self.num_chunks)
+
+        # compute the true number of chunks (may be 1 if all fields are not chunked)
+        num_chunks = max(len(chunks) for chunks in field_name_to_chunked_content.values())
+
+        # create the chunked canidates
+        candidates = []
+        for chunk_idx in range(num_chunks):
+            candidate_copy = candidate.copy()
+            for field_name in input_fields:
+                field_chunks = field_name_to_chunked_content[field_name]
+                candidate_copy[field_name] = field_chunks[chunk_idx] if len(field_chunks) > 1 else field_chunks[0]
+
+            candidates.append(candidate_copy)
+
+        return candidates
+
+    def convert(self, candidate: DataRecord, fields: dict[str, FieldInfo]) -> tuple[dict[str, list], GenerationStats]:
+        # get the set of input fields to use for the convert operation
+        input_fields = self.get_input_fields()
+
+        # lookup most relevant chunks for each field using embedding search
+        candidate_copy = candidate.copy()
+        chunked_candidates = self.get_chunked_candidate(candidate_copy, input_fields)
+
+        # construct kwargs for generation
+        gen_kwargs = {"project_cols": input_fields, "output_schema": self.output_schema}
+
+        # generate outputs for each chunk separately
+        chunk_outputs, chunk_generation_stats_lst = [], []
+        for candidate in chunked_candidates:
+            _, reasoning, chunk_generation_stats, _ = self.split_generator(candidate, fields, json_output=False, **gen_kwargs)
+            chunk_outputs.append(reasoning)
+            chunk_generation_stats_lst.append(chunk_generation_stats)
+
+        # call the merger
+        gen_kwargs = {
+            "project_cols": input_fields,
+            "output_schema": self.output_schema,
+            "chunk_outputs": chunk_outputs,
+        }
+        field_answers, _, merger_gen_stats, _ = self.split_merge_generator(candidate, fields, **gen_kwargs)
+
+        # compute the total generation stats
+        generation_stats = sum(chunk_generation_stats_lst) + merger_gen_stats
+
+        return field_answers, generation_stats
+
+
+class SplitFilter(LLMFilter):
+    def __init__(self, num_chunks: int = 2, min_size_to_chunk: int = 1000, *args, **kwargs):
+        kwargs["prompt_strategy"] = None
+        super().__init__(*args, **kwargs)
+        self.num_chunks = num_chunks
+        self.min_size_to_chunk = min_size_to_chunk
+        self.split_generator = Generator(self.model, PromptStrategy.FILTER_SPLIT_PROPOSER, self.reasoning_effort, self.api_base, Cardinality.ONE_TO_ONE, self.desc, self.verbose)
+        self.split_merge_generator = Generator(self.model, PromptStrategy.FILTER_SPLIT_MERGER, self.reasoning_effort, self.api_base, Cardinality.ONE_TO_ONE, self.desc, self.verbose)
+
+        # crude adjustment factor for naive estimation in no-sentinel setting
+        self.naive_quality_adjustment = 0.6
+
+    def __str__(self):
+        op = super().__str__()
+        op += f" Chunk Size: {str(self.num_chunks)}\n"
+        op += f" Min Size to Chunk: {str(self.min_size_to_chunk)}\n"
+        return op
+
+    def get_id_params(self):
+        id_params = super().get_id_params()
+        id_params = {"num_chunks": self.num_chunks, "min_size_to_chunk": self.min_size_to_chunk, **id_params}
+
+        return id_params
+
+    def get_op_params(self):
+        op_params = super().get_op_params()
+        return {"num_chunks": self.num_chunks, "min_size_to_chunk": self.min_size_to_chunk, **op_params}
+
+    def naive_cost_estimates(self, source_op_cost_estimates: OperatorCostEstimates) -> OperatorCostEstimates:
+        """
+        Update the cost per record and quality estimates produced by LLMFilter's naive estimates.
+        We adjust the cost per record to account for the reduced number of input tokens following
+        the retrieval of relevant chunks, and we make a crude estimate of the quality degradation
+        that results from using a downsized input (although this may in fact improve quality in
+        some cases).
+        """
+        # get naive cost estimates from LLMFilter
+        naive_op_cost_estimates = super().naive_cost_estimates(source_op_cost_estimates)
+
+        # re-compute cost per record assuming we use fewer input tokens; naively assume a single input field
+        est_num_input_tokens = NAIVE_EST_NUM_INPUT_TOKENS
+        est_num_output_tokens = NAIVE_EST_NUM_OUTPUT_TOKENS
+        model_conversion_usd_per_record = (
+            MODEL_CARDS[self.model.value]["usd_per_input_token"] * est_num_input_tokens
+            + MODEL_CARDS[self.model.value]["usd_per_output_token"] * est_num_output_tokens
+        )
+
+        # set refined estimate of cost per record
+        naive_op_cost_estimates.cost_per_record = model_conversion_usd_per_record
+        naive_op_cost_estimates.cost_per_record_lower_bound = naive_op_cost_estimates.cost_per_record
+        naive_op_cost_estimates.cost_per_record_upper_bound = naive_op_cost_estimates.cost_per_record
+        naive_op_cost_estimates.quality = (naive_op_cost_estimates.quality) * self.naive_quality_adjustment
+        naive_op_cost_estimates.quality_lower_bound = naive_op_cost_estimates.quality
+        naive_op_cost_estimates.quality_upper_bound = naive_op_cost_estimates.quality
+
+        return naive_op_cost_estimates
+
+    def get_text_chunks(self, text: str, num_chunks: int) -> list[str]:
+        """
+        Given a text string, chunk it into num_chunks substrings of roughly equal size.
+        """
+        chunks = []
+
+        idx, chunk_size = 0, math.ceil(len(text) / num_chunks)
+        while idx + chunk_size < len(text):
+            chunks.append(text[idx : idx + chunk_size])
+            idx += chunk_size
+
+        if idx < len(text):
+            chunks.append(text[idx:])
+
+        return chunks
+
+    def get_chunked_candidate(self, candidate: DataRecord, input_fields: list[str]) -> list[DataRecord]:
+        """
+        For each text field, chunk the content. If a field is smaller than the chunk size,
+        simply include the full field.
+        """
+        # compute mapping from each field to its chunked content
+        field_name_to_chunked_content = {}
+        for field_name in input_fields:
+            field = candidate.get_field_type(field_name)
+            content = candidate[field_name]
+
+            # do not chunk this field if it is not a string or a list of strings
+            is_string_field = field.annotation in [str, str | None]
+            is_list_string_field = field.annotation in [list[str], list[str] | None]
+            if not (is_string_field or is_list_string_field):
+                field_name_to_chunked_content[field_name] = [content]
+                continue
+
+            # if this is a list of strings, join the strings
+            if is_list_string_field:
+                content = "[" + ", ".join(content) + "]"
+
+            # skip this field if its length is less than the min size to chunk
+            if len(content) < self.min_size_to_chunk:
+                field_name_to_chunked_content[field_name] = [content]
+                continue
+
+            # chunk the content
+            field_name_to_chunked_content[field_name] = self.get_text_chunks(content, self.num_chunks)
+
+        # compute the true number of chunks (may be 1 if all fields are not chunked)
+        num_chunks = max(len(chunks) for chunks in field_name_to_chunked_content.values())
+
+        # create the chunked canidates
+        candidates = []
+        for chunk_idx in range(num_chunks):
+            candidate_copy = candidate.copy()
+            for field_name in input_fields:
+                field_chunks = field_name_to_chunked_content[field_name]
+                candidate_copy[field_name] = field_chunks[chunk_idx] if len(field_chunks) > 1 else field_chunks[0]
+
+            candidates.append(candidate_copy)
+
+        return candidates
+
+    def filter(self, candidate: DataRecord) -> tuple[dict[str, bool], GenerationStats]:
+        # get the set of input fields to use for the filter operation
+        input_fields = self.get_input_fields()
+
+        # construct output fields
+        fields = {"passed_operator": FieldInfo(annotation=bool, description="Whether the record passed the filter operation")}
+
+        # lookup most relevant chunks for each field using embedding search
+        candidate_copy = candidate.copy()
+        chunked_candidates = self.get_chunked_candidate(candidate_copy, input_fields)
+
+        # construct kwargs for generation
+        gen_kwargs = {"project_cols": input_fields, "filter_condition": self.filter_obj.filter_condition}
+
+        # generate outputs for each chunk separately
+        chunk_outputs, chunk_generation_stats_lst = [], []
+        for candidate in chunked_candidates:
+            _, reasoning, chunk_generation_stats, _ = self.split_generator(candidate, fields, json_output=False, **gen_kwargs)
+            chunk_outputs.append(reasoning)
+            chunk_generation_stats_lst.append(chunk_generation_stats)
+
+        # call the merger
+        gen_kwargs = {
+            "project_cols": input_fields,
+            "filter_condition": self.filter_obj.filter_condition,
+            "chunk_outputs": chunk_outputs,
+        }
+        field_answers, _, merger_gen_stats, _ = self.split_merge_generator(candidate, fields, **gen_kwargs)
+
+        # compute the total generation stats
+        generation_stats = sum(chunk_generation_stats_lst) + merger_gen_stats
+
+        return field_answers, generation_stats
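
For readers skimming the new split.py operators above: both SplitConvert and SplitFilter split each long text field into num_chunks roughly equal substrings (fields shorter than min_size_to_chunk are left whole), prompt the model once per chunk, and then merge the per-chunk outputs. The following is a minimal standalone sketch of just the chunking step, mirroring the get_text_chunks method shown in the diff; it is illustrative only and not part of the package.

import math

def get_text_chunks(text: str, num_chunks: int) -> list[str]:
    # split text into num_chunks substrings of roughly equal size
    chunks = []
    idx, chunk_size = 0, math.ceil(len(text) / num_chunks)
    while idx + chunk_size < len(text):
        chunks.append(text[idx : idx + chunk_size])
        idx += chunk_size
    if idx < len(text):
        chunks.append(text[idx:])
    return chunks

# example: a 10-character string split into 3 chunks of ceil(10 / 3) = 4, 4, and 2 characters
print(get_text_chunks("abcdefghij", 3))  # ['abcd', 'efgh', 'ij']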
palimpzest/query/optimizer/__init__.py

@@ -6,7 +6,7 @@ from palimpzest.query.optimizer.rules import (
     BasicSubstitutionRule as _BasicSubstitutionRule,
 )
 from palimpzest.query.optimizer.rules import (
-    CriticAndRefineConvertRule as _CriticAndRefineConvertRule,
+    CritiqueAndRefineRule as _CritiqueAndRefineRule,
 )
 from palimpzest.query.optimizer.rules import (
     ImplementationRule as _ImplementationRule,
@@ -21,7 +21,7 @@ from palimpzest.query.optimizer.rules import (
     LLMJoinRule as _LLMJoinRule,
 )
 from palimpzest.query.optimizer.rules import (
-    MixtureOfAgentsConvertRule as _MixtureOfAgentsConvertRule,
+    MixtureOfAgentsRule as _MixtureOfAgentsRule,
 )
 from palimpzest.query.optimizer.rules import (
     NonLLMConvertRule as _NonLLMConvertRule,
@@ -33,7 +33,10 @@ from palimpzest.query.optimizer.rules import (
     PushDownFilter as _PushDownFilter,
 )
 from palimpzest.query.optimizer.rules import (
-    RAGConvertRule as _RAGConvertRule,
+    RAGRule as _RAGRule,
+)
+from palimpzest.query.optimizer.rules import (
+    ReorderConverts as _ReorderConverts,
 )
 from palimpzest.query.optimizer.rules import (
     RetrieveRule as _RetrieveRule,
@@ -42,7 +45,7 @@ from palimpzest.query.optimizer.rules import (
     Rule as _Rule,
 )
 from palimpzest.query.optimizer.rules import (
-    SplitConvertRule as _SplitConvertRule,
+    SplitRule as _SplitRule,
 )
 from palimpzest.query.optimizer.rules import (
     TransformationRule as _TransformationRule,
@@ -52,19 +55,20 @@ ALL_RULES = [
     _AddContextsBeforeComputeRule,
     _AggregateRule,
     _BasicSubstitutionRule,
-    _CriticAndRefineConvertRule,
+    _CritiqueAndRefineRule,
     _ImplementationRule,
     _LLMConvertBondedRule,
     _LLMFilterRule,
     _LLMJoinRule,
-    _MixtureOfAgentsConvertRule,
+    _MixtureOfAgentsRule,
     _NonLLMConvertRule,
     _NonLLMFilterRule,
     _PushDownFilter,
-    _RAGConvertRule,
+    _RAGRule,
+    _ReorderConverts,
     _RetrieveRule,
     _Rule,
-    _SplitConvertRule,
+    _SplitRule,
     _TransformationRule,
 ]
 
palimpzest/query/optimizer/optimizer.py

@@ -29,15 +29,15 @@ from palimpzest.query.optimizer.optimizer_strategy_type import OptimizationStrat
 from palimpzest.query.optimizer.plan import PhysicalPlan
 from palimpzest.query.optimizer.primitives import Group, LogicalExpression
 from palimpzest.query.optimizer.rules import (
-    CriticAndRefineConvertRule,
+    CritiqueAndRefineRule,
     LLMConvertBondedRule,
-    MixtureOfAgentsConvertRule,
-    RAGConvertRule,
-    SplitConvertRule,
+    MixtureOfAgentsRule,
+    RAGRule,
+    SplitRule,
 )
 from palimpzest.query.optimizer.tasks import (
     ApplyRule,
-    ExpandGroup,
+    ExploreGroup,
     OptimizeGroup,
     OptimizeLogicalExpression,
     OptimizePhysicalExpression,
@@ -150,22 +150,22 @@ class Optimizer:
 
         if not self.allow_rag_reduction:
             self.implementation_rules = [
-                rule for rule in self.implementation_rules if not issubclass(rule, RAGConvertRule)
+                rule for rule in self.implementation_rules if not issubclass(rule, RAGRule)
             ]
 
         if not self.allow_mixtures:
             self.implementation_rules = [
-                rule for rule in self.implementation_rules if not issubclass(rule, MixtureOfAgentsConvertRule)
+                rule for rule in self.implementation_rules if not issubclass(rule, MixtureOfAgentsRule)
             ]
 
         if not self.allow_critic:
             self.implementation_rules = [
-                rule for rule in self.implementation_rules if not issubclass(rule, CriticAndRefineConvertRule)
+                rule for rule in self.implementation_rules if not issubclass(rule, CritiqueAndRefineRule)
             ]
 
         if not self.allow_split_merge:
             self.implementation_rules = [
-                rule for rule in self.implementation_rules if not issubclass(rule, SplitConvertRule)
+                rule for rule in self.implementation_rules if not issubclass(rule, SplitRule)
             ]
 
         logger.info(f"Initialized Optimizer with verbose={self.verbose}")
@@ -396,8 +396,9 @@
         # TODO: conditionally stop when X number of tasks have been executed to limit exhaustive search
         while len(self.tasks_stack) > 0:
             task = self.tasks_stack.pop(-1)
+
             new_tasks = []
-            if isinstance(task, (OptimizeGroup, ExpandGroup)):
+            if isinstance(task, (OptimizeGroup, ExploreGroup)):
                 new_tasks = task.perform(self.groups)
             elif isinstance(task, OptimizeLogicalExpression):
                 new_tasks = task.perform(self.transformation_rules, self.implementation_rules)
@@ -409,6 +410,7 @@
             elif isinstance(task, OptimizePhysicalExpression):
                 context = {"optimizer_strategy": self.optimizer_strategy, "execution_strategy": self.execution_strategy}
                 new_tasks = task.perform(self.cost_model, self.groups, self.policy, context=context)
+
             self.tasks_stack.extend(new_tasks)
 
         logger.debug(f"Done searching optimization space for group_id: {group_id}")