palimpzest-0.8.1-py3-none-any.whl → palimpzest-0.8.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. palimpzest/constants.py +38 -62
  2. palimpzest/core/data/dataset.py +1 -1
  3. palimpzest/core/data/iter_dataset.py +5 -5
  4. palimpzest/core/elements/groupbysig.py +1 -1
  5. palimpzest/core/elements/records.py +91 -109
  6. palimpzest/core/lib/schemas.py +23 -0
  7. palimpzest/core/models.py +3 -3
  8. palimpzest/prompts/__init__.py +2 -6
  9. palimpzest/prompts/convert_prompts.py +10 -66
  10. palimpzest/prompts/critique_and_refine_prompts.py +66 -0
  11. palimpzest/prompts/filter_prompts.py +8 -46
  12. palimpzest/prompts/join_prompts.py +12 -75
  13. palimpzest/prompts/{moa_aggregator_convert_prompts.py → moa_aggregator_prompts.py} +51 -2
  14. palimpzest/prompts/moa_proposer_prompts.py +87 -0
  15. palimpzest/prompts/prompt_factory.py +351 -479
  16. palimpzest/prompts/split_merge_prompts.py +51 -2
  17. palimpzest/prompts/split_proposer_prompts.py +48 -16
  18. palimpzest/prompts/utils.py +109 -0
  19. palimpzest/query/execution/all_sample_execution_strategy.py +1 -1
  20. palimpzest/query/execution/execution_strategy.py +4 -4
  21. palimpzest/query/execution/mab_execution_strategy.py +47 -23
  22. palimpzest/query/execution/parallel_execution_strategy.py +3 -3
  23. palimpzest/query/execution/single_threaded_execution_strategy.py +8 -8
  24. palimpzest/query/generators/generators.py +31 -17
  25. palimpzest/query/operators/__init__.py +15 -2
  26. palimpzest/query/operators/aggregate.py +21 -19
  27. palimpzest/query/operators/compute.py +6 -8
  28. palimpzest/query/operators/convert.py +12 -37
  29. palimpzest/query/operators/critique_and_refine.py +194 -0
  30. palimpzest/query/operators/distinct.py +7 -7
  31. palimpzest/query/operators/filter.py +13 -25
  32. palimpzest/query/operators/join.py +321 -192
  33. palimpzest/query/operators/limit.py +4 -4
  34. palimpzest/query/operators/mixture_of_agents.py +246 -0
  35. palimpzest/query/operators/physical.py +25 -2
  36. palimpzest/query/operators/project.py +4 -4
  37. palimpzest/query/operators/{rag_convert.py → rag.py} +202 -5
  38. palimpzest/query/operators/retrieve.py +10 -9
  39. palimpzest/query/operators/scan.py +9 -10
  40. palimpzest/query/operators/search.py +18 -24
  41. palimpzest/query/operators/split.py +321 -0
  42. palimpzest/query/optimizer/__init__.py +12 -8
  43. palimpzest/query/optimizer/optimizer.py +12 -10
  44. palimpzest/query/optimizer/rules.py +201 -108
  45. palimpzest/query/optimizer/tasks.py +18 -6
  46. palimpzest/query/processor/config.py +2 -2
  47. palimpzest/query/processor/query_processor.py +2 -2
  48. palimpzest/query/processor/query_processor_factory.py +9 -5
  49. palimpzest/validator/validator.py +7 -9
  50. {palimpzest-0.8.1.dist-info → palimpzest-0.8.3.dist-info}/METADATA +3 -8
  51. palimpzest-0.8.3.dist-info/RECORD +95 -0
  52. palimpzest/prompts/critique_and_refine_convert_prompts.py +0 -216
  53. palimpzest/prompts/moa_proposer_convert_prompts.py +0 -75
  54. palimpzest/prompts/util_phrases.py +0 -19
  55. palimpzest/query/operators/critique_and_refine_convert.py +0 -113
  56. palimpzest/query/operators/mixture_of_agents_convert.py +0 -140
  57. palimpzest/query/operators/split_convert.py +0 -170
  58. palimpzest-0.8.1.dist-info/RECORD +0 -95
  59. {palimpzest-0.8.1.dist-info → palimpzest-0.8.3.dist-info}/WHEEL +0 -0
  60. {palimpzest-0.8.1.dist-info → palimpzest-0.8.3.dist-info}/licenses/LICENSE +0 -0
  61. {palimpzest-0.8.1.dist-info → palimpzest-0.8.3.dist-info}/top_level.txt +0 -0
@@ -5,12 +5,17 @@ from itertools import combinations
5
5
 
6
6
  from palimpzest.constants import AggFunc, Model, PromptStrategy
7
7
  from palimpzest.core.data.context_manager import ContextManager
8
- from palimpzest.core.lib.schemas import AudioBase64, AudioFilepath, ImageBase64, ImageFilepath, ImageURL
8
+ from palimpzest.core.lib.schemas import (
9
+ AUDIO_FIELD_TYPES,
10
+ AUDIO_LIST_FIELD_TYPES,
11
+ IMAGE_FIELD_TYPES,
12
+ IMAGE_LIST_FIELD_TYPES,
13
+ )
9
14
  from palimpzest.prompts import CONTEXT_SEARCH_PROMPT
10
15
  from palimpzest.query.operators.aggregate import ApplyGroupByOp, AverageAggregateOp, CountAggregateOp
11
16
  from palimpzest.query.operators.compute import SmolAgentsCompute
12
17
  from palimpzest.query.operators.convert import LLMConvertBonded, NonLLMConvert
13
- from palimpzest.query.operators.critique_and_refine_convert import CriticAndRefineConvert
18
+ from palimpzest.query.operators.critique_and_refine import CritiqueAndRefineConvert, CritiqueAndRefineFilter
14
19
  from palimpzest.query.operators.distinct import DistinctOp
15
20
  from palimpzest.query.operators.filter import LLMFilter, NonLLMFilter
16
21
  from palimpzest.query.operators.join import NestedLoopsJoin
@@ -30,43 +35,20 @@ from palimpzest.query.operators.logical import (
30
35
  RetrieveScan,
31
36
  SearchOperator,
32
37
  )
33
- from palimpzest.query.operators.mixture_of_agents_convert import MixtureOfAgentsConvert
38
+ from palimpzest.query.operators.mixture_of_agents import MixtureOfAgentsConvert, MixtureOfAgentsFilter
34
39
  from palimpzest.query.operators.physical import PhysicalOperator
35
40
  from palimpzest.query.operators.project import ProjectOp
36
- from palimpzest.query.operators.rag_convert import RAGConvert
41
+ from palimpzest.query.operators.rag import RAGConvert, RAGFilter
37
42
  from palimpzest.query.operators.retrieve import RetrieveOp
38
43
  from palimpzest.query.operators.scan import ContextScanOp, MarshalAndScanDataOp
39
44
  from palimpzest.query.operators.search import (
40
45
  SmolAgentsSearch, # SmolAgentsCustomManagedSearch, # SmolAgentsManagedSearch
41
46
  )
42
- from palimpzest.query.operators.split_convert import SplitConvert
47
+ from palimpzest.query.operators.split import SplitConvert, SplitFilter
43
48
  from palimpzest.query.optimizer.primitives import Expression, Group, LogicalExpression, PhysicalExpression
44
49
 
45
50
  logger = logging.getLogger(__name__)
46
51
 
47
- # DEFINITIONS
48
- IMAGE_LIST_FIELD_TYPES = [
49
- list[ImageBase64],
50
- list[ImageFilepath],
51
- list[ImageURL],
52
- list[ImageBase64] | None,
53
- list[ImageFilepath] | None,
54
- list[ImageURL] | None,
55
- ]
56
- IMAGE_FIELD_TYPES = IMAGE_LIST_FIELD_TYPES + [
57
- ImageBase64, ImageFilepath, ImageURL,
58
- ImageBase64 | None, ImageFilepath | None, ImageURL | None,
59
- ]
60
- AUDIO_LIST_FIELD_TYPES = [
61
- list[AudioBase64],
62
- list[AudioFilepath],
63
- list[AudioBase64] | None,
64
- list[AudioFilepath] | None,
65
- ]
66
- AUDIO_FIELD_TYPES = AUDIO_LIST_FIELD_TYPES + [
67
- AudioBase64, AudioFilepath,
68
- AudioBase64 | None, AudioFilepath | None,
69
- ]
70
52
 
71
53
  class Rule:
72
54
  """
@@ -93,6 +75,11 @@ class TransformationRule(Rule):
93
75
  which are created during the substitution.
94
76
  """
95
77
 
78
+ @classmethod
79
+ def is_exploration_rule(cls) -> bool:
80
+ """Returns True if this rule is an exploration rule and False otherwise. Default is False."""
81
+ return False
82
+
96
83
  @classmethod
97
84
  def substitute(
98
85
  cls, logical_expression: LogicalExpression, groups: dict[int, Group], expressions: dict[int, Expression], **kwargs
@@ -109,6 +96,143 @@ class TransformationRule(Rule):
109
96
  raise NotImplementedError("Calling this method from an abstract base class.")
110
97
 
111
98
 
99
+ class ReorderConverts(TransformationRule):
100
+ """
101
+ This rule is an exploration rule that returns new logical expressions by re-ordering a sequence of ConvertScans.
102
+ """
103
+
104
+ @classmethod
105
+ def is_exploration_rule(cls) -> bool:
106
+ return True
107
+
108
+ @classmethod
109
+ def matches_pattern(cls, logical_expression: Expression) -> bool:
110
+ is_match = isinstance(logical_expression.operator, ConvertScan)
111
+ logger.debug(f"ReorderConverts matches_pattern: {is_match} for {logical_expression}")
112
+ return is_match
113
+
114
+ @classmethod
115
+ def substitute(
116
+ cls, logical_expression: LogicalExpression, groups: dict[int, Group], expressions: dict[int, Expression], **kwargs: dict
117
+ ) -> tuple[set[LogicalExpression], set[Group]]:
118
+ logger.debug(f"Substituting ReorderConverts for {logical_expression}")
119
+
120
+ # initialize the sets of new logical expressions and groups to be returned
121
+ new_logical_expressions, new_groups = set(), set()
122
+
123
+ # for each input group, if this convert does not depend on an operator in that group:
124
+ # then swap the group with this convert
125
+ convert_operator: ConvertScan = logical_expression.operator
126
+ for input_group_id in logical_expression.input_group_ids:
127
+ input_group = groups[input_group_id]
128
+
129
+ # if the convert's dependencies aren't contained within the input group's fields,
130
+ # then we can not push it down into this group
131
+ if any([field not in input_group.fields for field in convert_operator.depends_on]):
132
+ continue
133
+
134
+ # iterate over logical expressions
135
+ for expr in input_group.logical_expressions:
136
+ # if the expression operator is not a convert, we cannot swap
137
+ if not isinstance(expr.operator, ConvertScan):
138
+ continue
139
+
140
+ # if this convert depends on a field generated by the expression we're trying to swap with, we can't swap
141
+ if any([field in expr.generated_fields for field in convert_operator.depends_on]):
142
+ continue
143
+
144
+ # create new logical expression with convert pushed down to the input group's logical expression
145
+ new_input_group_ids = deepcopy(expr.input_group_ids)
146
+ new_input_fields = deepcopy(expr.input_fields)
147
+ new_depends_on_field_names = deepcopy(logical_expression.depends_on_field_names)
148
+ new_generated_fields = deepcopy(logical_expression.generated_fields)
149
+ new_convert_expr = LogicalExpression(
150
+ convert_operator,
151
+ input_group_ids=new_input_group_ids,
152
+ input_fields=new_input_fields,
153
+ depends_on_field_names=new_depends_on_field_names,
154
+ generated_fields=new_generated_fields,
155
+ group_id=None,
156
+ )
157
+
158
+ # add new_convert_expr to set of new expressions
159
+ new_logical_expressions.add(new_convert_expr)
160
+
161
+ # get or compute the group_id and group for this new expression
162
+ group_id, group = None, None
163
+
164
+ # if the expression already exists, lookup the group_id and group
165
+ if new_convert_expr.expr_id in expressions:
166
+ group_id = expressions[new_convert_expr.expr_id].group_id
167
+ new_convert_expr.set_group_id(group_id)
168
+ group = groups[group_id]
169
+
170
+ # otherwise, lookup or create expression's group and add it to the new expressions
171
+ else:
172
+ # first, compute the fields for the group
173
+ all_fields = {**new_input_fields, **new_generated_fields}
174
+
175
+ # next, compute the properties; the properties will be identical to those of the input group
176
+ # EXCEPT for the filters which will change as a result of our swap
177
+ new_group_properties = deepcopy(input_group.properties)
178
+
179
+ # if the expression we're swapping with is a map,
180
+ # we need to remove its model fields from the input group properties
181
+ if sorted(expr.operator.input_schema.model_fields.keys()) == sorted(expr.operator.output_schema.model_fields.keys()):
182
+ model_fields_dict = {
183
+ k: {"annotation": v.annotation, "default": v.default, "description": v.description}
184
+ for k, v in expr.operator.output_schema.model_fields.items()
185
+ }
186
+ new_group_properties["maps"].remove(model_fields_dict)
187
+
188
+ # finally, if this expression is a map, add its model fields to the new group's properties
189
+ if sorted(convert_operator.input_schema.model_fields.keys()) == sorted(convert_operator.output_schema.model_fields.keys()):
190
+ model_fields_dict = {
191
+ k: {"annotation": v.annotation, "default": v.default, "description": v.description}
192
+ for k, v in convert_operator.output_schema.model_fields.items()
193
+ }
194
+ if "maps" in new_group_properties:
195
+ new_group_properties["maps"].add(model_fields_dict)
196
+ else:
197
+ new_group_properties["maps"] = set([model_fields_dict])
198
+
199
+ # create group for this new convert expression
200
+ group = Group(
201
+ logical_expressions=[new_convert_expr],
202
+ fields=all_fields,
203
+ properties=new_group_properties,
204
+ )
205
+ group_id = group.group_id
206
+ new_convert_expr.set_group_id(group_id)
207
+
208
+ # if the group already exists, add the expression to that group
209
+ if group_id in groups:
210
+ group = groups[group_id]
211
+ group.logical_expressions.add(new_convert_expr)
212
+
213
+ # otherwise, add this new group to groups and to the set of new groups
214
+ else:
215
+ groups[group_id] = group
216
+ new_groups.add(group)
217
+
218
+ # create final new logical expression with expr's operator pulled up
219
+ new_expr = LogicalExpression(
220
+ expr.operator.copy(),
221
+ input_group_ids=[group_id] + [g_id for g_id in logical_expression.input_group_ids if g_id != input_group_id],
222
+ input_fields=group.fields,
223
+ depends_on_field_names=expr.depends_on_field_names,
224
+ generated_fields=expr.generated_fields,
225
+ group_id=logical_expression.group_id,
226
+ )
227
+
228
+ # add newly created expression to set of returned expressions
229
+ new_logical_expressions.add(new_expr)
230
+
231
+ logger.debug(f"Done substituting ReorderConverts for {logical_expression}")
232
+
233
+ return new_logical_expressions, new_groups
234
+
235
+
112
236
  class PushDownFilter(TransformationRule):
113
237
  """
114
238
  If this operator is a filter, push down the filter and replace it with the
@@ -496,22 +620,11 @@ class LLMConvertBondedRule(ImplementationRule):
496
620
 
497
621
  # create variable physical operator kwargs for each model which can implement this logical_expression
498
622
  models = [model for model in runtime_kwargs["available_models"] if cls._model_matches_input(model, logical_expression)]
499
- # NOTE: right now we exclusively allow image or audio operations, but not both simultaneously
500
- prompt_strategy, no_reasoning_prompt_strategy = None, None
501
623
  no_reasoning = runtime_kwargs["reasoning_effort"] in [None, "minimal", "low"]
502
- if cls._is_text_only_operation(logical_expression):
503
- prompt_strategy = PromptStrategy.COT_QA
504
- no_reasoning_prompt_strategy = PromptStrategy.COT_QA_NO_REASONING
505
- elif cls._is_image_operation(logical_expression):
506
- prompt_strategy = PromptStrategy.COT_QA_IMAGE
507
- no_reasoning_prompt_strategy = PromptStrategy.COT_QA_IMAGE_NO_REASONING
508
- elif cls._is_audio_operation(logical_expression):
509
- prompt_strategy = PromptStrategy.COT_QA_AUDIO
510
- no_reasoning_prompt_strategy = PromptStrategy.COT_QA_AUDIO_NO_REASONING
511
624
  variable_op_kwargs = [
512
625
  {
513
626
  "model": model,
514
- "prompt_strategy": no_reasoning_prompt_strategy if model.is_reasoning_model() and no_reasoning else prompt_strategy,
627
+ "prompt_strategy": PromptStrategy.MAP_NO_REASONING if model.is_reasoning_model() and no_reasoning else PromptStrategy.MAP,
515
628
  "reasoning_effort": runtime_kwargs["reasoning_effort"],
516
629
  }
517
630
  for model in models
@@ -520,9 +633,9 @@ class LLMConvertBondedRule(ImplementationRule):
520
633
  return cls._perform_substitution(logical_expression, LLMConvertBonded, runtime_kwargs, variable_op_kwargs)
521
634
 
522
635
 
523
- class RAGConvertRule(ImplementationRule):
636
+ class RAGRule(ImplementationRule):
524
637
  """
525
- Substitute a logical expression for a ConvertScan with a RAGConvert physical implementation.
638
+ Implementation rule for the RAG operators.
526
639
  """
527
640
 
528
641
  num_chunks_per_fields = [1, 2, 4]
@@ -531,20 +644,23 @@ class RAGConvertRule(ImplementationRule):
531
644
  @classmethod
532
645
  def matches_pattern(cls, logical_expression: LogicalExpression) -> bool:
533
646
  logical_op = logical_expression.operator
534
- is_match = isinstance(logical_op, ConvertScan) and cls._is_text_only_operation(logical_expression) and logical_op.udf is None
535
- logger.debug(f"RAGConvertRule matches_pattern: {is_match} for {logical_expression}")
536
- return is_match
647
+ is_map_match = isinstance(logical_op, ConvertScan) and cls._is_text_only_operation(logical_expression) and logical_op.udf is None
648
+ is_filter_match = isinstance(logical_op, FilteredScan) and cls._is_text_only_operation(logical_expression) and logical_op.filter.filter_fn is None
649
+ logger.debug(f"RAGRule matches_pattern: {is_map_match or is_filter_match} for {logical_expression}")
650
+ return is_map_match or is_filter_match
537
651
 
538
652
  @classmethod
539
653
  def substitute(cls, logical_expression: LogicalExpression, **runtime_kwargs) -> set[PhysicalExpression]:
540
- logger.debug(f"Substituting RAGConvertRule for {logical_expression}")
654
+ logger.debug(f"Substituting RAGRule for {logical_expression}")
655
+ # select physical operator class based on whether this is a map or filter operation
656
+ phys_op_cls = RAGConvert if isinstance(logical_expression.operator, ConvertScan) else RAGFilter
541
657
 
542
658
  # create variable physical operator kwargs for each model which can implement this logical_expression
543
659
  models = [model for model in runtime_kwargs["available_models"] if cls._model_matches_input(model, logical_expression)]
544
660
  variable_op_kwargs = [
545
661
  {
546
662
  "model": model,
547
- "prompt_strategy": PromptStrategy.COT_QA,
663
+ "prompt_strategy": PromptStrategy.MAP if phys_op_cls is RAGConvert else PromptStrategy.FILTER,
548
664
  "num_chunks_per_field": num_chunks_per_field,
549
665
  "chunk_size": chunk_size,
550
666
  "reasoning_effort": runtime_kwargs["reasoning_effort"],
@@ -554,12 +670,12 @@ class RAGConvertRule(ImplementationRule):
554
670
  for chunk_size in cls.chunk_sizes
555
671
  ]
556
672
 
557
- return cls._perform_substitution(logical_expression, RAGConvert, runtime_kwargs, variable_op_kwargs)
673
+ return cls._perform_substitution(logical_expression, phys_op_cls, runtime_kwargs, variable_op_kwargs)
558
674
 
559
675
 
560
- class MixtureOfAgentsConvertRule(ImplementationRule):
676
+ class MixtureOfAgentsRule(ImplementationRule):
561
677
  """
562
- Implementation rule for the MixtureOfAgentsConvert operator.
678
+ Implementation rule for the MixtureOfAgents operators.
563
679
  """
564
680
 
565
681
  num_proposer_models = [1, 2, 3]
@@ -568,26 +684,25 @@ class MixtureOfAgentsConvertRule(ImplementationRule):
568
684
  @classmethod
569
685
  def matches_pattern(cls, logical_expression: LogicalExpression) -> bool:
570
686
  logical_op = logical_expression.operator
571
- # TODO: remove audio limitation once I add prompts
572
- is_match = isinstance(logical_op, ConvertScan) and logical_op.udf is None and not cls._is_audio_operation(logical_expression)
573
- logger.debug(f"MixtureOfAgentsConvertRule matches_pattern: {is_match} for {logical_expression}")
574
- return is_match
687
+ is_map_match = isinstance(logical_op, ConvertScan) and logical_op.udf is None
688
+ is_filter_match = isinstance(logical_op, FilteredScan) and logical_op.filter.filter_fn is None
689
+ logger.debug(f"MixtureOfAgentsRule matches_pattern: {is_map_match or is_filter_match} for {logical_expression}")
690
+ return is_map_match or is_filter_match
575
691
 
576
692
  @classmethod
577
693
  def substitute(cls, logical_expression: LogicalExpression, **runtime_kwargs) -> set[PhysicalExpression]:
578
- logger.debug(f"Substituting MixtureOfAgentsConvertRule for {logical_expression}")
694
+ logger.debug(f"Substituting MixtureOfAgentsRule for {logical_expression}")
695
+ # select physical operator class based on whether this is a map or filter operation
696
+ phys_op_cls = MixtureOfAgentsConvert if isinstance(logical_expression.operator, ConvertScan) else MixtureOfAgentsFilter
579
697
 
580
698
  # create variable physical operator kwargs for each model which can implement this logical_expression
581
699
  proposer_model_set = {model for model in runtime_kwargs["available_models"] if cls._model_matches_input(model, logical_expression)}
582
700
  aggregator_model_set = {model for model in runtime_kwargs["available_models"] if model.is_text_model()}
583
- proposer_prompt_strategy = PromptStrategy.COT_MOA_PROPOSER_IMAGE if cls._is_image_operation(logical_expression) else PromptStrategy.COT_MOA_PROPOSER
584
701
  variable_op_kwargs = [
585
702
  {
586
703
  "proposer_models": list(proposer_models),
587
704
  "temperatures": [temp] * len(proposer_models),
588
705
  "aggregator_model": aggregator_model,
589
- "proposer_prompt_strategy": proposer_prompt_strategy,
590
- "aggregator_prompt_strategy": PromptStrategy.COT_MOA_AGG,
591
706
  "reasoning_effort": runtime_kwargs["reasoning_effort"],
592
707
  }
593
708
  for k in cls.num_proposer_models
@@ -596,35 +711,36 @@ class MixtureOfAgentsConvertRule(ImplementationRule):
596
711
  for aggregator_model in aggregator_model_set
597
712
  ]
598
713
 
599
- return cls._perform_substitution(logical_expression, MixtureOfAgentsConvert, runtime_kwargs, variable_op_kwargs)
714
+ return cls._perform_substitution(logical_expression, phys_op_cls, runtime_kwargs, variable_op_kwargs)
600
715
 
601
716
 
602
- class CriticAndRefineConvertRule(ImplementationRule):
717
+ class CritiqueAndRefineRule(ImplementationRule):
603
718
  """
604
- Implementation rule for the CriticAndRefineConvert operator.
719
+ Implementation rule for the CritiqueAndRefine operators.
605
720
  """
606
721
 
607
722
  @classmethod
608
723
  def matches_pattern(cls, logical_expression: LogicalExpression) -> bool:
609
724
  logical_op = logical_expression.operator
610
- # TODO: remove audio limitation once I add prompts
611
- is_match = isinstance(logical_op, ConvertScan) and logical_op.udf is None and not cls._is_audio_operation(logical_expression)
612
- logger.debug(f"CriticAndRefineConvertRule matches_pattern: {is_match} for {logical_expression}")
613
- return is_match
725
+ is_map_match = isinstance(logical_op, ConvertScan) and logical_op.udf is None
726
+ is_filter_match = isinstance(logical_op, FilteredScan) and logical_op.filter.filter_fn is None
727
+ logger.debug(f"CritiqueAndRefineRule matches_pattern: {is_map_match or is_filter_match} for {logical_expression}")
728
+ return is_map_match or is_filter_match
614
729
 
615
730
  @classmethod
616
731
  def substitute(cls, logical_expression: LogicalExpression, **runtime_kwargs) -> set[PhysicalExpression]:
617
- logger.debug(f"Substituting CriticAndRefineConvertRule for {logical_expression}")
732
+ logger.debug(f"Substituting CritiqueAndRefineRule for {logical_expression}")
733
+ # select physical operator class based on whether this is a map or filter operation
734
+ phys_op_cls = CritiqueAndRefineConvert if isinstance(logical_expression.operator, ConvertScan) else CritiqueAndRefineFilter
618
735
 
619
736
  # create variable physical operator kwargs for each model which can implement this logical_expression
620
737
  models = [model for model in runtime_kwargs["available_models"] if cls._model_matches_input(model, logical_expression)]
621
- prompt_strategy = PromptStrategy.COT_QA_IMAGE if cls._is_image_operation(logical_expression) else PromptStrategy.COT_QA
622
738
  variable_op_kwargs = [
623
739
  {
624
740
  "model": model,
625
741
  "critic_model": critic_model,
626
742
  "refine_model": refine_model,
627
- "prompt_strategy": prompt_strategy,
743
+ "prompt_strategy": PromptStrategy.MAP if phys_op_cls is CritiqueAndRefineConvert else PromptStrategy.FILTER,
628
744
  "reasoning_effort": runtime_kwargs["reasoning_effort"],
629
745
  }
630
746
  for model in models
@@ -632,12 +748,12 @@ class CriticAndRefineConvertRule(ImplementationRule):
632
748
  for refine_model in models
633
749
  ]
634
750
 
635
- return cls._perform_substitution(logical_expression, CriticAndRefineConvert, runtime_kwargs, variable_op_kwargs)
751
+ return cls._perform_substitution(logical_expression, phys_op_cls, runtime_kwargs, variable_op_kwargs)
636
752
 
637
753
 
638
- class SplitConvertRule(ImplementationRule):
754
+ class SplitRule(ImplementationRule):
639
755
  """
640
- Substitute a logical expression for a ConvertScan with a SplitConvert physical implementation.
756
+ Implementation rule for the Split operators.
641
757
  """
642
758
  num_chunks = [2, 4, 6]
643
759
  min_size_to_chunk = [1000, 4000]
@@ -645,13 +761,16 @@ class SplitConvertRule(ImplementationRule):
645
761
  @classmethod
646
762
  def matches_pattern(cls, logical_expression: LogicalExpression) -> bool:
647
763
  logical_op = logical_expression.operator
648
- is_match = isinstance(logical_op, ConvertScan) and cls._is_text_only_operation() and logical_op.udf is None
649
- logger.debug(f"SplitConvertRule matches_pattern: {is_match} for {logical_expression}")
650
- return is_match
764
+ is_map_match = isinstance(logical_op, ConvertScan) and cls._is_text_only_operation() and logical_op.udf is None
765
+ is_filter_match = isinstance(logical_op, FilteredScan) and cls._is_text_only_operation() and logical_op.filter.filter_fn is None
766
+ logger.debug(f"SplitRule matches_pattern: {is_map_match or is_filter_match} for {logical_expression}")
767
+ return is_map_match or is_filter_match
651
768
 
652
769
  @classmethod
653
770
  def substitute(cls, logical_expression: LogicalExpression, **runtime_kwargs) -> set[PhysicalExpression]:
654
- logger.debug(f"Substituting SplitConvertRule for {logical_expression}")
771
+ logger.debug(f"Substituting SplitRule for {logical_expression}")
772
+ # select physical operator class based on whether this is a map or filter operation
773
+ phys_op_cls = SplitConvert if isinstance(logical_expression.operator, ConvertScan) else SplitFilter
655
774
 
656
775
  # create variable physical operator kwargs for each model which can implement this logical_expression
657
776
  models = [model for model in runtime_kwargs["available_models"] if cls._model_matches_input(model, logical_expression)]
@@ -667,7 +786,7 @@ class SplitConvertRule(ImplementationRule):
667
786
  for num_chunks in cls.num_chunks
668
787
  ]
669
788
 
670
- return cls._perform_substitution(logical_expression, SplitConvert, runtime_kwargs, variable_op_kwargs)
789
+ return cls._perform_substitution(logical_expression, phys_op_cls, runtime_kwargs, variable_op_kwargs)
671
790
 
672
791
 
673
792
  class RetrieveRule(ImplementationRule):
@@ -699,10 +818,8 @@ class NonLLMFilterRule(ImplementationRule):
699
818
 
700
819
  @classmethod
701
820
  def matches_pattern(cls, logical_expression: LogicalExpression) -> bool:
702
- is_match = (
703
- isinstance(logical_expression.operator, FilteredScan)
704
- and logical_expression.operator.filter.filter_fn is not None
705
- )
821
+ logical_op = logical_expression.operator
822
+ is_match = isinstance(logical_op, FilteredScan) and logical_op.filter.filter_fn is not None
706
823
  logger.debug(f"NonLLMFilterRule matches_pattern: {is_match} for {logical_expression}")
707
824
  return is_match
708
825
 
@@ -719,10 +836,8 @@ class LLMFilterRule(ImplementationRule):
719
836
 
720
837
  @classmethod
721
838
  def matches_pattern(cls, logical_expression: LogicalExpression) -> bool:
722
- is_match = (
723
- isinstance(logical_expression.operator, FilteredScan)
724
- and logical_expression.operator.filter.filter_condition is not None
725
- )
839
+ logical_op = logical_expression.operator
840
+ is_match = isinstance(logical_op, FilteredScan) and logical_op.filter.filter_fn is None
726
841
  logger.debug(f"LLMFilterRule matches_pattern: {is_match} for {logical_expression}")
727
842
  return is_match
728
843
 
@@ -732,22 +847,11 @@ class LLMFilterRule(ImplementationRule):
732
847
 
733
848
  # create variable physical operator kwargs for each model which can implement this logical_expression
734
849
  models = [model for model in runtime_kwargs["available_models"] if cls._model_matches_input(model, logical_expression)]
735
- # NOTE: right now we exclusively allow image or audio operations, but not both simultaneously
736
- prompt_strategy, no_reasoning_prompt_strategy = None, None
737
850
  no_reasoning = runtime_kwargs["reasoning_effort"] in [None, "minimal", "low"]
738
- if cls._is_text_only_operation(logical_expression):
739
- prompt_strategy = PromptStrategy.COT_BOOL
740
- no_reasoning_prompt_strategy = PromptStrategy.COT_BOOL_NO_REASONING
741
- elif cls._is_image_operation(logical_expression):
742
- prompt_strategy = PromptStrategy.COT_BOOL_IMAGE
743
- no_reasoning_prompt_strategy = PromptStrategy.COT_BOOL_IMAGE_NO_REASONING
744
- elif cls._is_audio_operation(logical_expression):
745
- prompt_strategy = PromptStrategy.COT_BOOL_AUDIO
746
- no_reasoning_prompt_strategy = PromptStrategy.COT_BOOL_AUDIO_NO_REASONING
747
851
  variable_op_kwargs = [
748
852
  {
749
853
  "model": model,
750
- "prompt_strategy": no_reasoning_prompt_strategy if model.is_reasoning_model() and no_reasoning else prompt_strategy,
854
+ "prompt_strategy": PromptStrategy.FILTER_NO_REASONING if model.is_reasoning_model() and no_reasoning else PromptStrategy.FILTER,
751
855
  "reasoning_effort": runtime_kwargs["reasoning_effort"]
752
856
  }
753
857
  for model in models
@@ -773,22 +877,11 @@ class LLMJoinRule(ImplementationRule):
773
877
 
774
878
  # create variable physical operator kwargs for each model which can implement this logical_expression
775
879
  models = [model for model in runtime_kwargs["available_models"] if cls._model_matches_input(model, logical_expression)]
776
- # NOTE: right now we exclusively allow image or audio operations, but not both simultaneously
777
- prompt_strategy, no_reasoning_prompt_strategy = None, None
778
880
  no_reasoning = runtime_kwargs["reasoning_effort"] in [None, "minimal", "low"]
779
- if cls._is_text_only_operation(logical_expression):
780
- prompt_strategy = PromptStrategy.COT_JOIN
781
- no_reasoning_prompt_strategy = PromptStrategy.COT_JOIN_NO_REASONING
782
- elif cls._is_image_operation(logical_expression):
783
- prompt_strategy = PromptStrategy.COT_JOIN_IMAGE
784
- no_reasoning_prompt_strategy = PromptStrategy.COT_JOIN_IMAGE_NO_REASONING
785
- elif cls._is_audio_operation(logical_expression):
786
- prompt_strategy = PromptStrategy.COT_JOIN_AUDIO
787
- no_reasoning_prompt_strategy = PromptStrategy.COT_JOIN_AUDIO_NO_REASONING
788
881
  variable_op_kwargs = [
789
882
  {
790
883
  "model": model,
791
- "prompt_strategy": no_reasoning_prompt_strategy if model.is_reasoning_model() and no_reasoning else prompt_strategy,
884
+ "prompt_strategy": PromptStrategy.JOIN_NO_REASONING if model.is_reasoning_model() and no_reasoning else PromptStrategy.JOIN,
792
885
  "join_parallelism": runtime_kwargs["join_parallelism"],
793
886
  "reasoning_effort": runtime_kwargs["reasoning_effort"],
794
887
  }
@@ -66,17 +66,19 @@ class OptimizeGroup(Task):
66
66
  task = OptimizePhysicalExpression(physical_expr)
67
67
  new_tasks.append(task)
68
68
 
69
+ # and first explore the group if it hasn't been explored yet
70
+ if not group.explored:
71
+ task = ExploreGroup(self.group_id)
72
+ new_tasks.append(task)
73
+
69
74
  logger.debug(f"Done optimizing group {self.group_id}")
70
75
  logger.debug(f"New tasks: {len(new_tasks)}")
71
76
  return new_tasks
72
77
 
73
78
 
74
- class ExpandGroup(Task):
79
+ class ExploreGroup(Task):
75
80
  """
76
- The task to expand a group.
77
-
78
- NOTE: we currently do not use this task, but I'm keeping it around in case we need it
79
- once we add join operations.
81
+ The task to explore a group and add additional logical expressions.
80
82
  """
81
83
 
82
84
  def __init__(self, group_id: int):
@@ -100,6 +102,12 @@ class ExpandGroup(Task):
100
102
  task = OptimizeLogicalExpression(logical_expr, exploring=True)
101
103
  new_tasks.append(task)
102
104
 
105
+ # but first (tasks are LIFO), we recursively explore input groups of logical expressions in this group
106
+ for logical_expr in group.logical_expressions:
107
+ for input_group_id in logical_expr.input_group_ids:
108
+ task = ExploreGroup(input_group_id)
109
+ new_tasks.append(task)
110
+
103
111
  # mark the group as explored and return tasks
104
112
  group.set_explored()
105
113
 
@@ -131,7 +139,11 @@ class OptimizeLogicalExpression(Task):
131
139
  context = {}
132
140
 
133
141
  # if we're exploring, only apply transformation rules
134
- rules = transformation_rules if self.exploring else transformation_rules + implementation_rules
142
+ rules = (
143
+ [rule for rule in transformation_rules if rule.is_exploration_rule()]
144
+ if self.exploring
145
+ else transformation_rules + implementation_rules
146
+ )
135
147
 
136
148
  # filter out rules that have already been applied to logical expression
137
149
  rules = list(filter(lambda rule: rule.get_rule_id() not in self.logical_expression.rules_applied, rules))
@@ -40,8 +40,8 @@ class QueryProcessorConfig(BaseModel):
40
40
  use_final_op_quality: bool = Field(default=False)
41
41
 
42
42
  # sentinel optimization flags
43
- k: int = Field(default=5)
44
- j: int = Field(default=5)
43
+ k: int = Field(default=6)
44
+ j: int = Field(default=4)
45
45
  sample_budget: int = Field(default=100)
46
46
  seed: int = Field(default=42)
47
47
  exp_name: str | None = Field(default=None)
@@ -114,8 +114,8 @@ class QueryProcessor:
114
114
  execution_stats = ExecutionStats(execution_id=self.execution_id())
115
115
  execution_stats.start()
116
116
 
117
- # if the user provides a train_dataset or validator, we perform optimization
118
- if self.train_dataset is not None or self.validator is not None:
117
+ # if the user provides a validator, we perform optimization
118
+ if self.validator is not None:
119
119
  # create sentinel plan
120
120
  sentinel_plan = self._create_sentinel_plan(self.train_dataset)
121
121
 
@@ -62,13 +62,17 @@ class QueryProcessorFactory:
62
62
  print("WARNING: Both `progress` and `verbose` are set to True, but only one can be True at a time; defaulting to `progress=True`")
63
63
  config.verbose = False
64
64
 
65
+ # if the user provides a training dataset, but no validator, create a default validator
66
+ if train_dataset is not None and validator is None:
67
+ validator = Validator()
68
+ logger.info("No validator provided; using default Validator")
69
+
65
70
  # boolean flag for whether we're performing optimization or not
66
- optimization = train_dataset is not None or validator is not None
67
- val_based_opt = train_dataset is None and validator is not None
71
+ optimization = validator is not None
68
72
 
69
73
  # handle "auto" default for sentinel execution strategies
70
74
  if config.sentinel_execution_strategy == "auto":
71
- config.sentinel_execution_strategy = ("validator" if val_based_opt else "mab") if optimization else None
75
+ config.sentinel_execution_strategy = "mab" if optimization else None
72
76
 
73
77
  # convert the config values for processing, execution, and optimization strategies to enums
74
78
  config = cls._normalize_strategies(config)
@@ -87,7 +91,7 @@ class QueryProcessorFactory:
87
91
  # set the final set of available models in the config
88
92
  config.available_models = available_models
89
93
 
90
- return config
94
+ return config, validator
91
95
 
92
96
  @classmethod
93
97
  def _create_optimizer(cls, config: QueryProcessorConfig) -> Optimizer:
@@ -143,7 +147,7 @@ class QueryProcessorFactory:
143
147
  config = QueryProcessorConfig()
144
148
 
145
149
  # apply any additional keyword arguments to the config and validate its contents
146
- config = cls._config_validation_and_normalization(config, train_dataset, validator)
150
+ config, validator = cls._config_validation_and_normalization(config, train_dataset, validator)
147
151
 
148
152
  # create the optimizer, execution strateg(ies), and processor
149
153
  optimizer = cls._create_optimizer(config)