palimpzest 0.7.21__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (87) hide show
  1. palimpzest/__init__.py +37 -6
  2. palimpzest/agents/__init__.py +0 -0
  3. palimpzest/agents/compute_agents.py +0 -0
  4. palimpzest/agents/search_agents.py +637 -0
  5. palimpzest/constants.py +259 -197
  6. palimpzest/core/data/context.py +393 -0
  7. palimpzest/core/data/context_manager.py +163 -0
  8. palimpzest/core/data/dataset.py +634 -0
  9. palimpzest/core/data/{datareaders.py → iter_dataset.py} +202 -126
  10. palimpzest/core/elements/groupbysig.py +16 -13
  11. palimpzest/core/elements/records.py +166 -75
  12. palimpzest/core/lib/schemas.py +152 -390
  13. palimpzest/core/{data/dataclasses.py → models.py} +306 -170
  14. palimpzest/policy.py +2 -27
  15. palimpzest/prompts/__init__.py +35 -5
  16. palimpzest/prompts/agent_prompts.py +357 -0
  17. palimpzest/prompts/context_search.py +9 -0
  18. palimpzest/prompts/convert_prompts.py +61 -5
  19. palimpzest/prompts/filter_prompts.py +50 -5
  20. palimpzest/prompts/join_prompts.py +163 -0
  21. palimpzest/prompts/moa_proposer_convert_prompts.py +5 -5
  22. palimpzest/prompts/prompt_factory.py +358 -46
  23. palimpzest/prompts/validator.py +239 -0
  24. palimpzest/query/execution/all_sample_execution_strategy.py +134 -76
  25. palimpzest/query/execution/execution_strategy.py +210 -317
  26. palimpzest/query/execution/execution_strategy_type.py +5 -7
  27. palimpzest/query/execution/mab_execution_strategy.py +249 -136
  28. palimpzest/query/execution/parallel_execution_strategy.py +153 -244
  29. palimpzest/query/execution/single_threaded_execution_strategy.py +107 -64
  30. palimpzest/query/generators/generators.py +157 -330
  31. palimpzest/query/operators/__init__.py +15 -5
  32. palimpzest/query/operators/aggregate.py +50 -33
  33. palimpzest/query/operators/compute.py +201 -0
  34. palimpzest/query/operators/convert.py +27 -21
  35. palimpzest/query/operators/critique_and_refine_convert.py +7 -5
  36. palimpzest/query/operators/distinct.py +62 -0
  37. palimpzest/query/operators/filter.py +22 -13
  38. palimpzest/query/operators/join.py +402 -0
  39. palimpzest/query/operators/limit.py +3 -3
  40. palimpzest/query/operators/logical.py +198 -80
  41. palimpzest/query/operators/mixture_of_agents_convert.py +10 -8
  42. palimpzest/query/operators/physical.py +27 -21
  43. palimpzest/query/operators/project.py +3 -3
  44. palimpzest/query/operators/rag_convert.py +7 -7
  45. palimpzest/query/operators/retrieve.py +9 -9
  46. palimpzest/query/operators/scan.py +81 -42
  47. palimpzest/query/operators/search.py +524 -0
  48. palimpzest/query/operators/split_convert.py +10 -8
  49. palimpzest/query/optimizer/__init__.py +7 -9
  50. palimpzest/query/optimizer/cost_model.py +108 -441
  51. palimpzest/query/optimizer/optimizer.py +123 -181
  52. palimpzest/query/optimizer/optimizer_strategy.py +66 -61
  53. palimpzest/query/optimizer/plan.py +352 -67
  54. palimpzest/query/optimizer/primitives.py +43 -19
  55. palimpzest/query/optimizer/rules.py +484 -646
  56. palimpzest/query/optimizer/tasks.py +127 -58
  57. palimpzest/query/processor/config.py +41 -76
  58. palimpzest/query/processor/query_processor.py +73 -18
  59. palimpzest/query/processor/query_processor_factory.py +46 -38
  60. palimpzest/schemabuilder/schema_builder.py +15 -28
  61. palimpzest/utils/model_helpers.py +27 -77
  62. palimpzest/utils/progress.py +114 -102
  63. palimpzest/validator/__init__.py +0 -0
  64. palimpzest/validator/validator.py +306 -0
  65. {palimpzest-0.7.21.dist-info → palimpzest-0.8.0.dist-info}/METADATA +6 -1
  66. palimpzest-0.8.0.dist-info/RECORD +95 -0
  67. palimpzest/core/lib/fields.py +0 -141
  68. palimpzest/prompts/code_synthesis_prompts.py +0 -28
  69. palimpzest/query/execution/random_sampling_execution_strategy.py +0 -240
  70. palimpzest/query/generators/api_client_factory.py +0 -30
  71. palimpzest/query/operators/code_synthesis_convert.py +0 -488
  72. palimpzest/query/operators/map.py +0 -130
  73. palimpzest/query/processor/nosentinel_processor.py +0 -33
  74. palimpzest/query/processor/processing_strategy_type.py +0 -28
  75. palimpzest/query/processor/sentinel_processor.py +0 -88
  76. palimpzest/query/processor/streaming_processor.py +0 -149
  77. palimpzest/sets.py +0 -405
  78. palimpzest/utils/datareader_helpers.py +0 -61
  79. palimpzest/utils/demo_helpers.py +0 -75
  80. palimpzest/utils/field_helpers.py +0 -69
  81. palimpzest/utils/generation_helpers.py +0 -69
  82. palimpzest/utils/sandbox.py +0 -183
  83. palimpzest-0.7.21.dist-info/RECORD +0 -95
  84. /palimpzest/core/{elements/index.py → data/index_dataset.py} +0 -0
  85. {palimpzest-0.7.21.dist-info → palimpzest-0.8.0.dist-info}/WHEEL +0 -0
  86. {palimpzest-0.7.21.dist-info → palimpzest-0.8.0.dist-info}/licenses/LICENSE +0 -0
  87. {palimpzest-0.7.21.dist-info → palimpzest-0.8.0.dist-info}/top_level.txt +0 -0
@@ -1,38 +1,72 @@
1
1
  import logging
2
+ import os
2
3
  from copy import deepcopy
3
4
  from itertools import combinations
4
5
 
5
- from palimpzest.constants import AggFunc, Cardinality, PromptStrategy
6
+ from palimpzest.constants import AggFunc, Model, PromptStrategy
7
+ from palimpzest.core.data.context_manager import ContextManager
8
+ from palimpzest.core.lib.schemas import AudioBase64, AudioFilepath, ImageBase64, ImageFilepath, ImageURL
9
+ from palimpzest.prompts import CONTEXT_SEARCH_PROMPT
6
10
  from palimpzest.query.operators.aggregate import ApplyGroupByOp, AverageAggregateOp, CountAggregateOp
7
- from palimpzest.query.operators.code_synthesis_convert import CodeSynthesisConvertSingle
11
+ from palimpzest.query.operators.compute import SmolAgentsCompute
8
12
  from palimpzest.query.operators.convert import LLMConvertBonded, NonLLMConvert
9
13
  from palimpzest.query.operators.critique_and_refine_convert import CriticAndRefineConvert
14
+ from palimpzest.query.operators.distinct import DistinctOp
10
15
  from palimpzest.query.operators.filter import LLMFilter, NonLLMFilter
16
+ from palimpzest.query.operators.join import NestedLoopsJoin
11
17
  from palimpzest.query.operators.limit import LimitScanOp
12
18
  from palimpzest.query.operators.logical import (
13
19
  Aggregate,
14
20
  BaseScan,
15
- CacheScan,
21
+ ComputeOperator,
22
+ ContextScan,
16
23
  ConvertScan,
24
+ Distinct,
17
25
  FilteredScan,
18
26
  GroupByAggregate,
27
+ JoinOp,
19
28
  LimitScan,
20
- MapScan,
21
29
  Project,
22
30
  RetrieveScan,
31
+ SearchOperator,
23
32
  )
24
- from palimpzest.query.operators.map import MapOp
25
33
  from palimpzest.query.operators.mixture_of_agents_convert import MixtureOfAgentsConvert
34
+ from palimpzest.query.operators.physical import PhysicalOperator
26
35
  from palimpzest.query.operators.project import ProjectOp
27
36
  from palimpzest.query.operators.rag_convert import RAGConvert
28
37
  from palimpzest.query.operators.retrieve import RetrieveOp
29
- from palimpzest.query.operators.scan import CacheScanDataOp, MarshalAndScanDataOp
38
+ from palimpzest.query.operators.scan import ContextScanOp, MarshalAndScanDataOp
39
+ from palimpzest.query.operators.search import (
40
+ SmolAgentsSearch, # SmolAgentsCustomManagedSearch, # SmolAgentsManagedSearch
41
+ )
30
42
  from palimpzest.query.operators.split_convert import SplitConvert
31
43
  from palimpzest.query.optimizer.primitives import Expression, Group, LogicalExpression, PhysicalExpression
32
- from palimpzest.utils.model_helpers import get_models, get_vision_models
33
44
 
34
45
  logger = logging.getLogger(__name__)
35
46
 
47
+ # DEFINITIONS
48
+ IMAGE_LIST_FIELD_TYPES = [
49
+ list[ImageBase64],
50
+ list[ImageFilepath],
51
+ list[ImageURL],
52
+ list[ImageBase64] | None,
53
+ list[ImageFilepath] | None,
54
+ list[ImageURL] | None,
55
+ ]
56
+ IMAGE_FIELD_TYPES = IMAGE_LIST_FIELD_TYPES + [
57
+ ImageBase64, ImageFilepath, ImageURL,
58
+ ImageBase64 | None, ImageFilepath | None, ImageURL | None,
59
+ ]
60
+ AUDIO_LIST_FIELD_TYPES = [
61
+ list[AudioBase64],
62
+ list[AudioFilepath],
63
+ list[AudioBase64] | None,
64
+ list[AudioFilepath] | None,
65
+ ]
66
+ AUDIO_FIELD_TYPES = AUDIO_LIST_FIELD_TYPES + [
67
+ AudioBase64, AudioFilepath,
68
+ AudioBase64 | None, AudioFilepath | None,
69
+ ]
36
70
 
37
71
  class Rule:
38
72
  """
@@ -43,12 +77,12 @@ class Rule:
43
77
  def get_rule_id(cls):
44
78
  return cls.__name__
45
79
 
46
- @staticmethod
47
- def matches_pattern(logical_expression: LogicalExpression) -> bool:
80
+ @classmethod
81
+ def matches_pattern(cls, logical_expression: LogicalExpression) -> bool:
48
82
  raise NotImplementedError("Calling this method from an abstract base class.")
49
83
 
50
- @staticmethod
51
- def substitute(logical_expression: LogicalExpression, **kwargs) -> set[Expression]:
84
+ @classmethod
85
+ def substitute(cls, logical_expression: LogicalExpression, **kwargs: dict) -> set[Expression]:
52
86
  raise NotImplementedError("Calling this method from an abstract base class.")
53
87
 
54
88
 
@@ -59,9 +93,9 @@ class TransformationRule(Rule):
59
93
  which are created during the substitution.
60
94
  """
61
95
 
62
- @staticmethod
96
+ @classmethod
63
97
  def substitute(
64
- logical_expression: LogicalExpression, groups: dict[int, Group], expressions: dict[int, Expression], **kwargs
98
+ cls, logical_expression: LogicalExpression, groups: dict[int, Group], expressions: dict[int, Expression], **kwargs
65
99
  ) -> tuple[set[LogicalExpression], set[Group]]:
66
100
  """
67
101
  This function applies the transformation rule to the logical expression, which
@@ -81,15 +115,15 @@ class PushDownFilter(TransformationRule):
81
115
  most expensive operator in the input group.
82
116
  """
83
117
 
84
- @staticmethod
85
- def matches_pattern(logical_expression: Expression) -> bool:
118
+ @classmethod
119
+ def matches_pattern(cls, logical_expression: Expression) -> bool:
86
120
  is_match = isinstance(logical_expression.operator, FilteredScan)
87
121
  logger.debug(f"PushDownFilter matches_pattern: {is_match} for {logical_expression}")
88
122
  return is_match
89
123
 
90
- @staticmethod
124
+ @classmethod
91
125
  def substitute(
92
- logical_expression: LogicalExpression, groups: dict[int, Group], expressions: dict[int, Expression], **kwargs
126
+ cls, logical_expression: LogicalExpression, groups: dict[int, Group], expressions: dict[int, Expression], **kwargs: dict
93
127
  ) -> tuple[set[LogicalExpression], set[Group]]:
94
128
  logger.debug(f"Substituting PushDownFilter for {logical_expression}")
95
129
 
@@ -113,7 +147,7 @@ class PushDownFilter(TransformationRule):
113
147
  # we see a regression / bug in the future
114
148
  for expr in input_group.logical_expressions:
115
149
  # if the expression operator is not a convert or a filter, we cannot swap
116
- if not (isinstance(expr.operator, (ConvertScan, FilteredScan))):
150
+ if not (isinstance(expr.operator, (ConvertScan, FilteredScan, JoinOp))):
117
151
  continue
118
152
 
119
153
  # if this filter depends on a field generated by the expression we're trying to swap with, we can't swap
@@ -141,8 +175,8 @@ class PushDownFilter(TransformationRule):
141
175
  group_id, group = None, None
142
176
 
143
177
  # if the expression already exists, lookup the group_id and group
144
- if new_filter_expr.get_expr_id() in expressions:
145
- group_id = expressions[new_filter_expr.get_expr_id()].group_id
178
+ if new_filter_expr.expr_id in expressions:
179
+ group_id = expressions[new_filter_expr.expr_id].group_id
146
180
  new_filter_expr.set_group_id(group_id)
147
181
  group = groups[group_id]
148
182
 
@@ -190,8 +224,7 @@ class PushDownFilter(TransformationRule):
190
224
  # create final new logical expression with expr's operator pulled up
191
225
  new_expr = LogicalExpression(
192
226
  expr.operator.copy(),
193
- input_group_ids=[group_id]
194
- + [g_id for g_id in logical_expression.input_group_ids if g_id != input_group_id],
227
+ input_group_ids=[group_id] + [g_id for g_id in logical_expression.input_group_ids if g_id != input_group_id],
195
228
  input_fields=group.fields,
196
229
  depends_on_field_names=expr.depends_on_field_names,
197
230
  generated_fields=expr.generated_fields,
@@ -211,218 +244,280 @@ class ImplementationRule(Rule):
211
244
  Base class for implementation rules which convert a logical expression to a physical expression.
212
245
  """
213
246
 
214
- pass
215
-
216
-
217
- class NonLLMConvertRule(ImplementationRule):
218
- """
219
- Substitute a logical expression for a UDF ConvertScan with a NonLLMConvert physical implementation.
220
- """
247
+ @classmethod
248
+ def _get_image_fields(cls, logical_expression: LogicalExpression) -> set[str]:
249
+ """Returns the set of fields which have an image (or list[image]) type."""
250
+ return set([
251
+ field_name.split(".")[-1]
252
+ for field_name, field in logical_expression.input_fields.items()
253
+ if field.annotation in IMAGE_FIELD_TYPES and field_name.split(".")[-1] in logical_expression.depends_on_field_names
254
+ ])
221
255
 
222
256
  @classmethod
223
- def matches_pattern(cls, logical_expression: LogicalExpression) -> bool:
224
- is_match = isinstance(logical_expression.operator, ConvertScan) and logical_expression.operator.udf is not None
225
- logger.debug(f"NonLLMConvertRule matches_pattern: {is_match} for {logical_expression}")
226
- return is_match
257
+ def _get_list_image_fields(cls, logical_expression: LogicalExpression) -> set[str]:
258
+ """Returns the set of fields which have a list[image] type."""
259
+ return set([
260
+ field_name.split(".")[-1]
261
+ for field_name, field in logical_expression.input_fields.items()
262
+ if field.annotation in IMAGE_LIST_FIELD_TYPES and field_name.split(".")[-1] in logical_expression.depends_on_field_names
263
+ ])
227
264
 
228
265
  @classmethod
229
- def substitute(cls, logical_expression: LogicalExpression, **physical_op_params) -> set[PhysicalExpression]:
230
- logger.debug(f"Substituting NonLLMConvertRule for {logical_expression}")
266
+ def _get_audio_fields(cls, logical_expression: LogicalExpression) -> set[str]:
267
+ """Returns the set of fields which have an audio (or list[audio]) type."""
268
+ return set([
269
+ field_name.split(".")[-1]
270
+ for field_name, field in logical_expression.input_fields.items()
271
+ if field.annotation in AUDIO_FIELD_TYPES and field_name.split(".")[-1] in logical_expression.depends_on_field_names
272
+ ])
231
273
 
232
- logical_op = logical_expression.operator
274
+ @classmethod
275
+ def _get_list_audio_fields(cls, logical_expression: LogicalExpression) -> set[str]:
276
+ """Returns the set of fields which have a list[audio] type."""
277
+ return set([
278
+ field_name.split(".")[-1]
279
+ for field_name, field in logical_expression.input_fields.items()
280
+ if field.annotation in AUDIO_LIST_FIELD_TYPES and field_name.split(".")[-1] in logical_expression.depends_on_field_names
281
+ ])
233
282
 
234
- # get initial set of parameters for physical op
235
- op_kwargs = logical_op.get_logical_op_params()
236
- op_kwargs.update(
237
- {
238
- "verbose": physical_op_params["verbose"],
239
- "logical_op_id": logical_op.get_logical_op_id(),
240
- "logical_op_name": logical_op.logical_op_name(),
241
- }
242
- )
283
+ @classmethod
284
+ def _is_image_only_operation(cls, logical_expression: LogicalExpression) -> bool:
285
+ """Returns True if the logical_expression processes only image input(s) and False otherwise."""
286
+ return all([
287
+ field.annotation in IMAGE_FIELD_TYPES
288
+ for field_name, field in logical_expression.input_fields.items()
289
+ if field_name.split(".")[-1] in logical_expression.depends_on_field_names
290
+ ])
243
291
 
244
- # construct multi-expression
245
- op = NonLLMConvert(**op_kwargs)
246
- expression = PhysicalExpression(
247
- operator=op,
248
- input_group_ids=logical_expression.input_group_ids,
249
- input_fields=logical_expression.input_fields,
250
- depends_on_field_names=logical_expression.depends_on_field_names,
251
- generated_fields=logical_expression.generated_fields,
252
- group_id=logical_expression.group_id,
253
- )
292
+ @classmethod
293
+ def _is_image_operation(cls, logical_expression: LogicalExpression) -> bool:
294
+ """Returns True if the logical_expression processes image input(s) and False otherwise."""
295
+ return any([
296
+ field.annotation in IMAGE_FIELD_TYPES
297
+ for field_name, field in logical_expression.input_fields.items()
298
+ if field_name.split(".")[-1] in logical_expression.depends_on_field_names
299
+ ])
254
300
 
255
- deduped_physical_expressions = set([expression])
256
- logger.debug(f"Done substituting NonLLMConvertRule for {logical_expression}")
301
+ @classmethod
302
+ def _is_audio_only_operation(cls, logical_expression: LogicalExpression) -> bool:
303
+ """Returns True if the logical_expression processes only audio input(s) and False otherwise."""
304
+ return all([
305
+ field.annotation in AUDIO_FIELD_TYPES
306
+ for field_name, field in logical_expression.input_fields.items()
307
+ if field_name.split(".")[-1] in logical_expression.depends_on_field_names
308
+ ])
257
309
 
258
- return deduped_physical_expressions
310
+ @classmethod
311
+ def _is_audio_operation(cls, logical_expression: LogicalExpression) -> bool:
312
+ """Returns True if the logical_expression processes audio input(s) and False otherwise."""
313
+ return any([
314
+ field.annotation in AUDIO_FIELD_TYPES
315
+ for field_name, field in logical_expression.input_fields.items()
316
+ if field_name.split(".")[-1] in logical_expression.depends_on_field_names
317
+ ])
259
318
 
319
+ @classmethod
320
+ def _is_text_only_operation(cls, logical_expression: LogicalExpression) -> bool:
321
+ """Returns True if the logical_expression processes only text input(s) and False otherwise."""
322
+ return all([
323
+ field.annotation not in IMAGE_FIELD_TYPES + AUDIO_FIELD_TYPES
324
+ for field_name, field in logical_expression.input_fields.items()
325
+ if field_name.split(".")[-1] in logical_expression.depends_on_field_names
326
+ ])
260
327
 
261
- class LLMConvertBondedRule(ImplementationRule):
262
- """
263
- Substitute a logical expression for a ConvertScan with a bonded convert physical implementation.
264
- """
328
+ @classmethod
329
+ def _is_text_operation(cls, logical_expression: LogicalExpression) -> bool:
330
+ """Returns True if the logical_expression processes text input(s) and False otherwise."""
331
+ return any([
332
+ field.annotation not in IMAGE_FIELD_TYPES + AUDIO_FIELD_TYPES
333
+ for field_name, field in logical_expression.input_fields.items()
334
+ if field_name.split(".")[-1] in logical_expression.depends_on_field_names
335
+ ])
336
+
337
+ # TODO: support powerset of text + image + audio (+ video) multi-modal operations
338
+ @classmethod
339
+ def _is_text_image_multimodal_operation(cls, logical_expression: LogicalExpression) -> bool:
340
+ """Returns True if the logical_expression processes text and image inputs and False otherwise."""
341
+ return cls._is_image_operation(logical_expression) and cls._is_text_operation(logical_expression)
265
342
 
266
343
  @classmethod
267
- def matches_pattern(cls, logical_expression: LogicalExpression) -> bool:
268
- is_match = isinstance(logical_expression.operator, ConvertScan) and logical_expression.operator.udf is None
269
- logger.debug(f"LLMConvertBondedRule matches_pattern: {is_match} for {logical_expression}")
270
- return is_match
344
+ def _is_text_audio_multimodal_operation(cls, logical_expression: LogicalExpression) -> bool:
345
+ """Returns True if the logical_expression processes text and audio inputs and False otherwise."""
346
+ return cls._is_audio_operation(logical_expression) and cls._is_text_operation(logical_expression)
271
347
 
272
348
  @classmethod
273
- def substitute(cls, logical_expression: LogicalExpression, **physical_op_params) -> set[PhysicalExpression]:
274
- logger.debug(f"Substituting LLMConvertBondedRule for {logical_expression}")
349
+ def _model_matches_input(cls, model: Model, logical_expression: LogicalExpression) -> bool:
350
+ """Returns True if the model is capable of processing the input and False otherwise."""
351
+ # compute how many image fields are in the input, and whether any fields are list[image] fields
352
+ num_image_fields = len(cls._get_image_fields(logical_expression))
353
+ has_list_image_field = len(cls._get_list_image_fields(logical_expression)) > 0
354
+ num_audio_fields = len(cls._get_audio_fields(logical_expression))
355
+ has_list_audio_field = len(cls._get_list_audio_fields(logical_expression)) > 0
356
+
357
+ # corner-case: for now, all operators use text or vision models for processing inputs to __call__
358
+ if model.is_embedding_model():
359
+ return False
275
360
 
361
+ # corner-case: Llama vision models cannot handle multiple image inputs (at least using Together)
362
+ if model.is_llama_model() and model.is_vision_model() and (num_image_fields > 1 or has_list_image_field):
363
+ return False
364
+
365
+ # corner-case: Gemini models cannot handle multiple audio inputs
366
+ if model.is_vertex_model() and model.is_audio_model() and (num_audio_fields > 1 or has_list_audio_field):
367
+ return False
368
+
369
+ # text-only input and text supporting model
370
+ if cls._is_text_only_operation(logical_expression) and model.is_text_model():
371
+ return True
372
+
373
+ # image-only input and image supporting model
374
+ if cls._is_image_only_operation(logical_expression) and model.is_vision_model():
375
+ return True
376
+
377
+ # audio-only input and audio supporting model
378
+ if cls._is_audio_only_operation(logical_expression) and model.is_audio_model():
379
+ return True
380
+
381
+ # multi-modal input and multi-modal supporting model
382
+ if cls._is_text_image_multimodal_operation(logical_expression) and model.is_text_image_multimodal_model(): # noqa: SIM103
383
+ return True
384
+
385
+ # multi-modal input and multi-modal supporting model
386
+ if cls._is_text_audio_multimodal_operation(logical_expression) and model.is_text_audio_multimodal_model(): # noqa: SIM103
387
+ return True
388
+
389
+ return False
390
+
391
+ @classmethod
392
+ def _get_fixed_op_kwargs(cls, logical_expression: LogicalExpression, runtime_kwargs: dict) -> dict:
393
+ """Get the fixed set of physical op kwargs provided by the logical expression and the runtime keyword arguments."""
394
+ # get logical operator
276
395
  logical_op = logical_expression.operator
277
396
 
278
- # get initial set of parameters for physical op
397
+ # set initial set of parameters for physical op
279
398
  op_kwargs = logical_op.get_logical_op_params()
280
399
  op_kwargs.update(
281
400
  {
282
- "verbose": physical_op_params["verbose"],
401
+ "verbose": runtime_kwargs["verbose"],
283
402
  "logical_op_id": logical_op.get_logical_op_id(),
403
+ "unique_logical_op_id": logical_op.get_unique_logical_op_id(),
284
404
  "logical_op_name": logical_op.logical_op_name(),
405
+ "api_base": runtime_kwargs["api_base"],
285
406
  }
286
407
  )
287
408
 
288
- # identify models which can be used strictly for text or strictly for images
289
- vision_models = set(get_vision_models())
290
- text_models = set(get_models())
291
- pure_text_models = {model for model in text_models if model not in vision_models}
292
- pure_vision_models = {model for model in vision_models if model not in text_models}
293
-
294
- # compute attributes about this convert operation
295
- is_image_conversion = any(
296
- [
297
- field.is_image_field
298
- for field_name, field in logical_expression.input_fields.items()
299
- if field_name.split(".")[-1] in logical_expression.depends_on_field_names
300
- ]
301
- )
302
- num_image_fields = sum(
303
- [
304
- field.is_image_field
305
- for field_name, field in logical_expression.input_fields.items()
306
- if field_name.split(".")[-1] in logical_expression.depends_on_field_names
307
- ]
308
- )
309
- list_image_field = any(
310
- [
311
- field.is_image_field and hasattr(field, "element_type")
312
- for field_name, field in logical_expression.input_fields.items()
313
- if field_name.split(".")[-1] in logical_expression.depends_on_field_names
314
- ]
315
- )
409
+ return op_kwargs
410
+
411
+ @classmethod
412
+ def _perform_substitution(
413
+ cls,
414
+ logical_expression: LogicalExpression,
415
+ physical_op_class: type[PhysicalOperator],
416
+ runtime_kwargs: dict,
417
+ variable_op_kwargs: list[dict] | dict | None = None,
418
+ ) -> set[PhysicalExpression]:
419
+ """
420
+ This performs basic substitution logic which proceeds in four steps:
421
+
422
+ 1. The basic kwargs for the physical operator are computed using the logical operator
423
+ and runtime kwargs.
424
+ 2. If variable kwargs are provided, then they are merged with the basic kwargs and one
425
+ instance of the physical operator is created for each dictionary of variable kwargs.
426
+ 3. A physical expression is created for each physical operator instance.
427
+ 4. The unique set of physical expressions is returned.
428
+
429
+ Args:
430
+ logical_expression (LogicalExpression): The logical expression containing a logical operator.
431
+ physical_op_class (type[PhysicalOperator]): The class of the physical operator we wish to construct.
432
+ runtime_kwargs (dict): Keyword arguments which are provided at runtime.
433
+ variable_op_kwargs (list[dict] | dict | None): A (list of) variable kwargs to customize each
434
+ physical operator instance.
435
+
436
+ Returns:
437
+ set[PhysicalExpression]: The unique set of physical expressions produced by initializing the
438
+ physical_op_class with the provided keyword arguments.
439
+ """
440
+ # get physical operator kwargs which are fixed for each instance of the physical operator
441
+ fixed_op_kwargs = cls._get_fixed_op_kwargs(logical_expression, runtime_kwargs)
442
+
443
+ # make variable_op_kwargs a list of dictionaries
444
+ if variable_op_kwargs is None:
445
+ variable_op_kwargs = [{}]
446
+ elif isinstance(variable_op_kwargs, dict):
447
+ variable_op_kwargs = [variable_op_kwargs]
316
448
 
449
+ # construct physical operators for each set of kwargs
317
450
  physical_expressions = []
318
- for model in physical_op_params["available_models"]:
319
- # skip this model if:
320
- # 1. this is a pure vision model and we're not doing an image conversion, or
321
- # 2. this is a pure text model and we're doing an image conversion, or
322
- # 3. this is a vision model hosted by Together (i.e. LLAMA3 vision) and there is more than one image field
323
- first_criteria = model in pure_vision_models and not is_image_conversion
324
- second_criteria = model in pure_text_models and is_image_conversion
325
- third_criteria = model.is_llama_model() and model.is_vision_model() and (num_image_fields > 1 or list_image_field)
326
- fourth_criteria = model.is_embedding_model()
327
- if first_criteria or second_criteria or third_criteria or fourth_criteria:
328
- continue
451
+ for var_op_kwargs in variable_op_kwargs:
452
+ # get kwargs for this physical operator instance
453
+ op_kwargs = {**fixed_op_kwargs, **var_op_kwargs}
329
454
 
330
- # construct multi-expression
331
- op = LLMConvertBonded(
332
- model=model,
333
- prompt_strategy=PromptStrategy.COT_QA_IMAGE if is_image_conversion else PromptStrategy.COT_QA,
334
- **op_kwargs,
335
- )
336
- expression = PhysicalExpression(
337
- operator=op,
338
- input_group_ids=logical_expression.input_group_ids,
339
- input_fields=logical_expression.input_fields,
340
- depends_on_field_names=logical_expression.depends_on_field_names,
341
- generated_fields=logical_expression.generated_fields,
342
- group_id=logical_expression.group_id,
343
- )
344
- physical_expressions.append(expression)
455
+ # construct the physical operator
456
+ op = physical_op_class(**op_kwargs)
345
457
 
346
- deduped_physical_expressions = set(physical_expressions)
347
- logger.debug(f"Done substituting LLMConvertBondedRule for {logical_expression}")
458
+ # construct physical expression and add to list of expressions
459
+ expression = PhysicalExpression.from_op_and_logical_expr(op, logical_expression)
460
+ physical_expressions.append(expression)
348
461
 
349
- return deduped_physical_expressions
462
+ return set(physical_expressions)
350
463
 
351
464
 
352
- class CodeSynthesisConvertRule(ImplementationRule):
465
+ class NonLLMConvertRule(ImplementationRule):
353
466
  """
354
- Base rule for code synthesis convert operators; the physical convert class
355
- (CodeSynthesisConvertSingle) is provided by sub-class rules.
356
-
357
- NOTE: we provide the physical convert class(es) in their own sub-classed rules to make
358
- it easier to allow/disallow groups of rules at the Optimizer level.
467
+ Substitute a logical expression for a UDF ConvertScan with a NonLLMConvert physical implementation.
359
468
  """
360
469
 
361
- physical_convert_class = None # overriden by sub-classes
362
-
363
470
  @classmethod
364
471
  def matches_pattern(cls, logical_expression: LogicalExpression) -> bool:
365
- logical_op = logical_expression.operator
366
- is_image_conversion = any(
367
- [
368
- field.is_image_field
369
- for field_name, field in logical_expression.input_fields.items()
370
- if field_name.split(".")[-1] in logical_expression.depends_on_field_names
371
- ]
372
- )
373
- is_match = (
374
- isinstance(logical_op, ConvertScan)
375
- and not is_image_conversion
376
- and logical_op.cardinality != Cardinality.ONE_TO_MANY
377
- and logical_op.udf is None
378
- )
379
- logger.debug(f"CodeSynthesisConvertRule matches_pattern: {is_match} for {logical_expression}")
472
+ is_match = isinstance(logical_expression.operator, ConvertScan) and logical_expression.operator.udf is not None
473
+ logger.debug(f"NonLLMConvertRule matches_pattern: {is_match} for {logical_expression}")
380
474
  return is_match
381
475
 
382
476
  @classmethod
383
- def substitute(cls, logical_expression: LogicalExpression, **physical_op_params) -> set[PhysicalExpression]:
384
- logger.debug(f"Substituting CodeSynthesisConvertRule for {logical_expression}")
385
-
386
- logical_op = logical_expression.operator
477
+ def substitute(cls, logical_expression: LogicalExpression, **runtime_kwargs) -> set[PhysicalExpression]:
478
+ logger.debug(f"Substituting NonLLMConvertRule for {logical_expression}")
479
+ return cls._perform_substitution(logical_expression, NonLLMConvert, runtime_kwargs)
387
480
 
388
- # get initial set of parameters for physical op
389
- op_kwargs = logical_op.get_logical_op_params()
390
- op_kwargs.update(
391
- {
392
- "verbose": physical_op_params["verbose"],
393
- "logical_op_id": logical_op.get_logical_op_id(),
394
- "logical_op_name": logical_op.logical_op_name(),
395
- }
396
- )
397
481
 
398
- # construct multi-expression
399
- op = cls.physical_convert_class(
400
- exemplar_generation_model=physical_op_params["champion_model"],
401
- code_synth_model=physical_op_params["code_champion_model"],
402
- fallback_model=physical_op_params["fallback_model"],
403
- prompt_strategy=PromptStrategy.COT_QA,
404
- **op_kwargs,
405
- )
406
- expression = PhysicalExpression(
407
- operator=op,
408
- input_group_ids=logical_expression.input_group_ids,
409
- input_fields=logical_expression.input_fields,
410
- depends_on_field_names=logical_expression.depends_on_field_names,
411
- generated_fields=logical_expression.generated_fields,
412
- group_id=logical_expression.group_id,
413
- )
414
- deduped_physical_expressions = set([expression])
415
- logger.debug(f"Done substituting CodeSynthesisConvertRule for {logical_expression}")
482
+ class LLMConvertBondedRule(ImplementationRule):
483
+ """
484
+ Substitute a logical expression for a ConvertScan with a bonded convert physical implementation.
485
+ """
416
486
 
417
- return deduped_physical_expressions
487
+ @classmethod
488
+ def matches_pattern(cls, logical_expression: LogicalExpression) -> bool:
489
+ is_match = isinstance(logical_expression.operator, ConvertScan) and logical_expression.operator.udf is None
490
+ logger.debug(f"LLMConvertBondedRule matches_pattern: {is_match} for {logical_expression}")
491
+ return is_match
418
492
 
493
+ @classmethod
494
+ def substitute(cls, logical_expression: LogicalExpression, **runtime_kwargs) -> set[PhysicalExpression]:
495
+ logger.debug(f"Substituting LLMConvertBondedRule for {logical_expression}")
419
496
 
420
- class CodeSynthesisConvertSingleRule(CodeSynthesisConvertRule):
421
- """
422
- Substitute a logical expression for a ConvertScan with a (single) code synthesis physical implementation.
423
- """
497
+ # create variable physical operator kwargs for each model which can implement this logical_expression
498
+ models = [model for model in runtime_kwargs["available_models"] if cls._model_matches_input(model, logical_expression)]
499
+ # NOTE: right now we exclusively allow image or audio operations, but not both simultaneously
500
+ prompt_strategy, no_reasoning_prompt_strategy = None, None
501
+ no_reasoning = runtime_kwargs["reasoning_effort"] in [None, "minimal", "low"]
502
+ if cls._is_text_only_operation(logical_expression):
503
+ prompt_strategy = PromptStrategy.COT_QA
504
+ no_reasoning_prompt_strategy = PromptStrategy.COT_QA_NO_REASONING
505
+ elif cls._is_image_operation(logical_expression):
506
+ prompt_strategy = PromptStrategy.COT_QA_IMAGE
507
+ no_reasoning_prompt_strategy = PromptStrategy.COT_QA_IMAGE_NO_REASONING
508
+ elif cls._is_audio_operation(logical_expression):
509
+ prompt_strategy = PromptStrategy.COT_QA_AUDIO
510
+ no_reasoning_prompt_strategy = PromptStrategy.COT_QA_AUDIO_NO_REASONING
511
+ variable_op_kwargs = [
512
+ {
513
+ "model": model,
514
+ "prompt_strategy": no_reasoning_prompt_strategy if model.is_reasoning_model() and no_reasoning else prompt_strategy,
515
+ "reasoning_effort": runtime_kwargs["reasoning_effort"],
516
+ }
517
+ for model in models
518
+ ]
424
519
 
425
- physical_convert_class = CodeSynthesisConvertSingle
520
+ return cls._perform_substitution(logical_expression, LLMConvertBonded, runtime_kwargs, variable_op_kwargs)
426
521
 
427
522
 
428
523
  class RAGConvertRule(ImplementationRule):
@@ -436,68 +531,30 @@ class RAGConvertRule(ImplementationRule):
436
531
  @classmethod
437
532
  def matches_pattern(cls, logical_expression: LogicalExpression) -> bool:
438
533
  logical_op = logical_expression.operator
439
- is_image_conversion = any(
440
- [
441
- field.is_image_field
442
- for field_name, field in logical_expression.input_fields.items()
443
- if field_name.split(".")[-1] in logical_expression.depends_on_field_names
444
- ]
445
- )
446
- is_match = isinstance(logical_op, ConvertScan) and not is_image_conversion and logical_op.udf is None
534
+ is_match = isinstance(logical_op, ConvertScan) and cls._is_text_only_operation(logical_expression) and logical_op.udf is None
447
535
  logger.debug(f"RAGConvertRule matches_pattern: {is_match} for {logical_expression}")
448
536
  return is_match
449
537
 
450
538
  @classmethod
451
- def substitute(cls, logical_expression: LogicalExpression, **physical_op_params) -> set[PhysicalExpression]:
539
+ def substitute(cls, logical_expression: LogicalExpression, **runtime_kwargs) -> set[PhysicalExpression]:
452
540
  logger.debug(f"Substituting RAGConvertRule for {logical_expression}")
453
541
 
454
- logical_op = logical_expression.operator
455
-
456
- # get initial set of parameters for physical op
457
- op_kwargs = logical_op.get_logical_op_params()
458
- op_kwargs.update(
542
+ # create variable physical operator kwargs for each model which can implement this logical_expression
543
+ models = [model for model in runtime_kwargs["available_models"] if cls._model_matches_input(model, logical_expression)]
544
+ variable_op_kwargs = [
459
545
  {
460
- "verbose": physical_op_params["verbose"],
461
- "logical_op_id": logical_op.get_logical_op_id(),
462
- "logical_op_name": logical_op.logical_op_name(),
546
+ "model": model,
547
+ "prompt_strategy": PromptStrategy.COT_QA,
548
+ "num_chunks_per_field": num_chunks_per_field,
549
+ "chunk_size": chunk_size,
550
+ "reasoning_effort": runtime_kwargs["reasoning_effort"],
463
551
  }
464
- )
552
+ for model in models
553
+ for num_chunks_per_field in cls.num_chunks_per_fields
554
+ for chunk_size in cls.chunk_sizes
555
+ ]
465
556
 
466
- # identify models which can be used strictly for text or strictly for images
467
- vision_models = set(get_vision_models())
468
- text_models = set(get_models())
469
- pure_vision_models = {model for model in vision_models if model not in text_models}
470
-
471
- physical_expressions = []
472
- for model in physical_op_params["available_models"]:
473
- # skip this model if this is a pure image model
474
- if model in pure_vision_models or model.is_embedding_model():
475
- continue
476
-
477
- for num_chunks_per_field in cls.num_chunks_per_fields:
478
- for chunk_size in cls.chunk_sizes:
479
- # construct multi-expression
480
- op = RAGConvert(
481
- model=model,
482
- prompt_strategy=PromptStrategy.COT_QA,
483
- num_chunks_per_field=num_chunks_per_field,
484
- chunk_size=chunk_size,
485
- **op_kwargs,
486
- )
487
- expression = PhysicalExpression(
488
- operator=op,
489
- input_group_ids=logical_expression.input_group_ids,
490
- input_fields=logical_expression.input_fields,
491
- depends_on_field_names=logical_expression.depends_on_field_names,
492
- generated_fields=logical_expression.generated_fields,
493
- group_id=logical_expression.group_id,
494
- )
495
- physical_expressions.append(expression)
496
-
497
- logger.debug(f"Done substituting RAGConvertRule for {logical_expression}")
498
- deduped_physical_expressions = set(physical_expressions)
499
-
500
- return deduped_physical_expressions
557
+ return cls._perform_substitution(logical_expression, RAGConvert, runtime_kwargs, variable_op_kwargs)
501
558
 
502
559
 
503
560
  class MixtureOfAgentsConvertRule(ImplementationRule):
@@ -511,93 +568,35 @@ class MixtureOfAgentsConvertRule(ImplementationRule):
511
568
  @classmethod
512
569
  def matches_pattern(cls, logical_expression: LogicalExpression) -> bool:
513
570
  logical_op = logical_expression.operator
514
- is_match = isinstance(logical_op, ConvertScan) and logical_op.udf is None
571
+ # TODO: remove audio limitation once I add prompts
572
+ is_match = isinstance(logical_op, ConvertScan) and logical_op.udf is None and not cls._is_audio_operation(logical_expression)
515
573
  logger.debug(f"MixtureOfAgentsConvertRule matches_pattern: {is_match} for {logical_expression}")
516
574
  return is_match
517
575
 
518
576
  @classmethod
519
- def substitute(cls, logical_expression: LogicalExpression, **physical_op_params) -> set[PhysicalExpression]:
577
+ def substitute(cls, logical_expression: LogicalExpression, **runtime_kwargs) -> set[PhysicalExpression]:
520
578
  logger.debug(f"Substituting MixtureOfAgentsConvertRule for {logical_expression}")
521
579
 
522
- logical_op = logical_expression.operator
523
-
524
- # get initial set of parameters for physical op
525
- op_kwargs: dict = logical_op.get_logical_op_params()
526
- op_kwargs.update(
580
+ # create variable physical operator kwargs for each model which can implement this logical_expression
581
+ proposer_model_set = {model for model in runtime_kwargs["available_models"] if cls._model_matches_input(model, logical_expression)}
582
+ aggregator_model_set = {model for model in runtime_kwargs["available_models"] if model.is_text_model()}
583
+ proposer_prompt_strategy = PromptStrategy.COT_MOA_PROPOSER_IMAGE if cls._is_image_operation(logical_expression) else PromptStrategy.COT_MOA_PROPOSER
584
+ variable_op_kwargs = [
527
585
  {
528
- "verbose": physical_op_params["verbose"],
529
- "logical_op_id": logical_op.get_logical_op_id(),
530
- "logical_op_name": logical_op.logical_op_name(),
586
+ "proposer_models": list(proposer_models),
587
+ "temperatures": [temp] * len(proposer_models),
588
+ "aggregator_model": aggregator_model,
589
+ "proposer_prompt_strategy": proposer_prompt_strategy,
590
+ "aggregator_prompt_strategy": PromptStrategy.COT_MOA_AGG,
591
+ "reasoning_effort": runtime_kwargs["reasoning_effort"],
531
592
  }
532
- )
593
+ for k in cls.num_proposer_models
594
+ for temp in cls.temperatures
595
+ for proposer_models in combinations(proposer_model_set, k)
596
+ for aggregator_model in aggregator_model_set
597
+ ]
533
598
 
534
- # identify models which can be used strictly for text or strictly for images
535
- vision_models = set(get_vision_models())
536
- text_models = set(get_models())
537
-
538
- # construct set of proposer models and set of aggregator models
539
- num_image_fields = sum(
540
- [
541
- field.is_image_field
542
- for field_name, field in logical_expression.input_fields.items()
543
- if field_name.split(".")[-1] in logical_expression.depends_on_field_names
544
- ]
545
- )
546
- list_image_field = any(
547
- [
548
- field.is_image_field and hasattr(field, "element_type")
549
- for field_name, field in logical_expression.input_fields.items()
550
- if field_name.split(".")[-1] in logical_expression.depends_on_field_names
551
- ]
552
- )
553
- proposer_model_set, is_image_conversion = text_models, False
554
- if num_image_fields > 1 or list_image_field:
555
- proposer_model_set = [model for model in vision_models if not model.is_llama_model()]
556
- is_image_conversion = True
557
- elif num_image_fields == 1:
558
- proposer_model_set = vision_models
559
- is_image_conversion = True
560
- aggregator_model_set = text_models
561
-
562
- # filter un-available models out of sets
563
- proposer_model_set = {model for model in proposer_model_set if model in physical_op_params["available_models"]}
564
- aggregator_model_set = {
565
- model for model in aggregator_model_set if model in physical_op_params["available_models"]
566
- }
567
-
568
- # construct MixtureOfAgentsConvert operations for various numbers of proposer models
569
- # and for every combination of proposer models and aggregator model
570
- physical_expressions = []
571
- for k in cls.num_proposer_models:
572
- for temp in cls.temperatures:
573
- for proposer_models in combinations(proposer_model_set, k):
574
- for aggregator_model in aggregator_model_set:
575
- # construct multi-expression
576
- op = MixtureOfAgentsConvert(
577
- proposer_models=list(proposer_models),
578
- temperatures=[temp] * len(proposer_models),
579
- aggregator_model=aggregator_model,
580
- proposer_prompt=op_kwargs.get("prompt"),
581
- proposer_prompt_strategy=PromptStrategy.COT_MOA_PROPOSER_IMAGE
582
- if is_image_conversion
583
- else PromptStrategy.COT_MOA_PROPOSER,
584
- aggregator_prompt_strategy=PromptStrategy.COT_MOA_AGG,
585
- **op_kwargs,
586
- )
587
- expression = PhysicalExpression(
588
- operator=op,
589
- input_group_ids=logical_expression.input_group_ids,
590
- input_fields=logical_expression.input_fields,
591
- depends_on_field_names=logical_expression.depends_on_field_names,
592
- generated_fields=logical_expression.generated_fields,
593
- group_id=logical_expression.group_id,
594
- )
595
- physical_expressions.append(expression)
596
-
597
- logger.debug(f"Done substituting MixtureOfAgentsConvertRule for {logical_expression}")
598
- deduped_physical_expressions = set(physical_expressions)
599
-
600
- return deduped_physical_expressions
599
+ return cls._perform_substitution(logical_expression, MixtureOfAgentsConvert, runtime_kwargs, variable_op_kwargs)
601
600
 
602
601
 
603
602
  class CriticAndRefineConvertRule(ImplementationRule):
@@ -608,99 +607,32 @@ class CriticAndRefineConvertRule(ImplementationRule):
608
607
  @classmethod
609
608
  def matches_pattern(cls, logical_expression: LogicalExpression) -> bool:
610
609
  logical_op = logical_expression.operator
611
- is_match = isinstance(logical_op, ConvertScan) and logical_op.udf is None
610
+ # TODO: remove audio limitation once I add prompts
611
+ is_match = isinstance(logical_op, ConvertScan) and logical_op.udf is None and not cls._is_audio_operation(logical_expression)
612
612
  logger.debug(f"CriticAndRefineConvertRule matches_pattern: {is_match} for {logical_expression}")
613
613
  return is_match
614
614
 
615
615
  @classmethod
616
- def substitute(cls, logical_expression: LogicalExpression, **physical_op_params) -> set[PhysicalExpression]:
616
+ def substitute(cls, logical_expression: LogicalExpression, **runtime_kwargs) -> set[PhysicalExpression]:
617
617
  logger.debug(f"Substituting CriticAndRefineConvertRule for {logical_expression}")
618
618
 
619
- logical_op = logical_expression.operator
620
-
621
- # Get initial parameters for physical operator
622
- op_kwargs = logical_op.get_logical_op_params()
623
- op_kwargs.update(
619
+ # create variable physical operator kwargs for each model which can implement this logical_expression
620
+ models = [model for model in runtime_kwargs["available_models"] if cls._model_matches_input(model, logical_expression)]
621
+ prompt_strategy = PromptStrategy.COT_QA_IMAGE if cls._is_image_operation(logical_expression) else PromptStrategy.COT_QA
622
+ variable_op_kwargs = [
624
623
  {
625
- "verbose": physical_op_params["verbose"],
626
- "logical_op_id": logical_op.get_logical_op_id(),
627
- "logical_op_name": logical_op.logical_op_name(),
624
+ "model": model,
625
+ "critic_model": critic_model,
626
+ "refine_model": refine_model,
627
+ "prompt_strategy": prompt_strategy,
628
+ "reasoning_effort": runtime_kwargs["reasoning_effort"],
628
629
  }
629
- )
630
+ for model in models
631
+ for critic_model in models
632
+ for refine_model in models
633
+ ]
630
634
 
631
- # identify models which can be used strictly for text or strictly for images
632
- vision_models = set(get_vision_models())
633
- text_models = set(get_models())
634
- pure_text_models = {model for model in text_models if model not in vision_models}
635
- pure_vision_models = {model for model in vision_models if model not in text_models}
636
-
637
- # compute attributes about this convert operation
638
- is_image_conversion = any(
639
- [
640
- field.is_image_field
641
- for field_name, field in logical_expression.input_fields.items()
642
- if field_name.split(".")[-1] in logical_expression.depends_on_field_names
643
- ]
644
- )
645
- num_image_fields = sum(
646
- [
647
- field.is_image_field
648
- for field_name, field in logical_expression.input_fields.items()
649
- if field_name.split(".")[-1] in logical_expression.depends_on_field_names
650
- ]
651
- )
652
- list_image_field = any(
653
- [
654
- field.is_image_field and hasattr(field, "element_type")
655
- for field_name, field in logical_expression.input_fields.items()
656
- if field_name.split(".")[-1] in logical_expression.depends_on_field_names
657
- ]
658
- )
659
-
660
- # identify models which can be used for this convert operation
661
- models = []
662
- for model in physical_op_params["available_models"]:
663
- # skip this model if:
664
- # 1. this is a pure vision model and we're not doing an image conversion, or
665
- # 2. this is a pure text model and we're doing an image conversion, or
666
- # 3. this is a vision model hosted by Together (i.e. LLAMA3 vision) and there is more than one image field
667
- first_criteria = model in pure_vision_models and not is_image_conversion
668
- second_criteria = model in pure_text_models and is_image_conversion
669
- third_criteria = model.is_llama_model() and model.is_vision_model() and (num_image_fields > 1 or list_image_field)
670
- fourth_criteria = model.is_embedding_model()
671
- if first_criteria or second_criteria or third_criteria or fourth_criteria:
672
- continue
673
-
674
- models.append(model)
675
-
676
- # TODO: heuristic(s) to narrow the space of critic and refine models we consider using class attributes
677
- # construct CriticAndRefineConvert operations for every combination of model, critic model, and refinement model
678
- physical_expressions = []
679
- for model in models:
680
- for critic_model in models:
681
- for refine_model in models:
682
- # construct multi-expression
683
- op = CriticAndRefineConvert(
684
- model=model,
685
- prompt_strategy=PromptStrategy.COT_QA_IMAGE if is_image_conversion else PromptStrategy.COT_QA,
686
- critic_model=critic_model,
687
- refine_model=refine_model,
688
- **op_kwargs,
689
- )
690
- expression = PhysicalExpression(
691
- operator=op,
692
- input_group_ids=logical_expression.input_group_ids,
693
- input_fields=logical_expression.input_fields,
694
- depends_on_field_names=logical_expression.depends_on_field_names,
695
- generated_fields=logical_expression.generated_fields,
696
- group_id=logical_expression.group_id,
697
- )
698
- physical_expressions.append(expression)
699
-
700
- logger.debug(f"Done substituting CriticAndRefineConvertRule for {logical_expression}")
701
- deduped_physical_expressions = set(physical_expressions)
702
-
703
- return deduped_physical_expressions
635
+ return cls._perform_substitution(logical_expression, CriticAndRefineConvert, runtime_kwargs, variable_op_kwargs)
704
636
 
705
637
 
706
638
  class SplitConvertRule(ImplementationRule):
@@ -713,67 +645,29 @@ class SplitConvertRule(ImplementationRule):
713
645
  @classmethod
714
646
  def matches_pattern(cls, logical_expression: LogicalExpression) -> bool:
715
647
  logical_op = logical_expression.operator
716
- is_image_conversion = any(
717
- [
718
- field.is_image_field
719
- for field_name, field in logical_expression.input_fields.items()
720
- if field_name.split(".")[-1] in logical_expression.depends_on_field_names
721
- ]
722
- )
723
- is_match = isinstance(logical_op, ConvertScan) and not is_image_conversion and logical_op.udf is None
648
+ is_match = isinstance(logical_op, ConvertScan) and cls._is_text_only_operation() and logical_op.udf is None
724
649
  logger.debug(f"SplitConvertRule matches_pattern: {is_match} for {logical_expression}")
725
650
  return is_match
726
651
 
727
652
  @classmethod
728
- def substitute(cls, logical_expression: LogicalExpression, **physical_op_params) -> set[PhysicalExpression]:
653
+ def substitute(cls, logical_expression: LogicalExpression, **runtime_kwargs) -> set[PhysicalExpression]:
729
654
  logger.debug(f"Substituting SplitConvertRule for {logical_expression}")
730
655
 
731
- logical_op = logical_expression.operator
732
-
733
- # get initial set of parameters for physical op
734
- op_kwargs = logical_op.get_logical_op_params()
735
- op_kwargs.update(
656
+ # create variable physical operator kwargs for each model which can implement this logical_expression
657
+ models = [model for model in runtime_kwargs["available_models"] if cls._model_matches_input(model, logical_expression)]
658
+ variable_op_kwargs = [
736
659
  {
737
- "verbose": physical_op_params["verbose"],
738
- "logical_op_id": logical_op.get_logical_op_id(),
739
- "logical_op_name": logical_op.logical_op_name(),
660
+ "model": model,
661
+ "min_size_to_chunk": min_size_to_chunk,
662
+ "num_chunks": num_chunks,
663
+ "reasoning_effort": runtime_kwargs["reasoning_effort"],
740
664
  }
741
- )
742
-
743
- # identify models which can be used strictly for text or strictly for images
744
- vision_models = set(get_vision_models())
745
- text_models = set(get_models())
746
- pure_vision_models = {model for model in vision_models if model not in text_models}
747
-
748
- physical_expressions = []
749
- for model in physical_op_params["available_models"]:
750
- # skip this model if this is a pure image model
751
- if model in pure_vision_models or model.is_embedding_model():
752
- continue
665
+ for model in models
666
+ for min_size_to_chunk in cls.min_size_to_chunk
667
+ for num_chunks in cls.num_chunks
668
+ ]
753
669
 
754
- for min_size_to_chunk in cls.min_size_to_chunk:
755
- for num_chunks in cls.num_chunks:
756
- # construct multi-expression
757
- op = SplitConvert(
758
- model=model,
759
- num_chunks=num_chunks,
760
- min_size_to_chunk=min_size_to_chunk,
761
- **op_kwargs,
762
- )
763
- expression = PhysicalExpression(
764
- operator=op,
765
- input_group_ids=logical_expression.input_group_ids,
766
- input_fields=logical_expression.input_fields,
767
- depends_on_field_names=logical_expression.depends_on_field_names,
768
- generated_fields=logical_expression.generated_fields,
769
- group_id=logical_expression.group_id,
770
- )
771
- physical_expressions.append(expression)
772
-
773
- logger.debug(f"Done substituting SplitConvertRule for {logical_expression}")
774
- deduped_physical_expressions = set(physical_expressions)
775
-
776
- return deduped_physical_expressions
670
+ return cls._perform_substitution(logical_expression, SplitConvert, runtime_kwargs, variable_op_kwargs)
777
671
 
778
672
 
779
673
  class RetrieveRule(ImplementationRule):
@@ -789,42 +683,13 @@ class RetrieveRule(ImplementationRule):
789
683
  return is_match
790
684
 
791
685
  @classmethod
792
- def substitute(cls, logical_expression: LogicalExpression, **physical_op_params) -> set[PhysicalExpression]:
686
+ def substitute(cls, logical_expression: LogicalExpression, **runtime_kwargs) -> set[PhysicalExpression]:
793
687
  logger.debug(f"Substituting RetrieveRule for {logical_expression}")
794
688
 
795
- logical_op = logical_expression.operator
796
-
797
- physical_expressions = []
798
- ks = cls.k_budgets if logical_op.k == -1 else [logical_op.k]
799
- for k in ks:
800
- # get initial set of parameters for physical op
801
- op_kwargs = logical_op.get_logical_op_params()
802
- op_kwargs.update(
803
- {
804
- "verbose": physical_op_params["verbose"],
805
- "logical_op_id": logical_op.get_logical_op_id(),
806
- "logical_op_name": logical_op.logical_op_name(),
807
- "k": k,
808
- }
809
- )
810
-
811
- # construct multi-expression
812
- op = RetrieveOp(**op_kwargs)
813
- expression = PhysicalExpression(
814
- operator=op,
815
- input_group_ids=logical_expression.input_group_ids,
816
- input_fields=logical_expression.input_fields,
817
- depends_on_field_names=logical_expression.depends_on_field_names,
818
- generated_fields=logical_expression.generated_fields,
819
- group_id=logical_expression.group_id,
820
- )
821
-
822
- physical_expressions.append(expression)
823
-
824
- logger.debug(f"Done substituting RetrieveRule for {logical_expression}")
825
- deduped_physical_expressions = set(physical_expressions)
826
-
827
- return deduped_physical_expressions
689
+ # create variable physical operator kwargs for each model which can implement this logical_expression
690
+ ks = cls.k_budgets if logical_expression.operator.k == -1 else [logical_expression.operator.k]
691
+ variable_op_kwargs = [{"k": k} for k in ks]
692
+ return cls._perform_substitution(logical_expression, RetrieveOp, runtime_kwargs, variable_op_kwargs)
828
693
 
829
694
 
830
695
  class NonLLMFilterRule(ImplementationRule):
@@ -832,8 +697,8 @@ class NonLLMFilterRule(ImplementationRule):
832
697
  Substitute a logical expression for a FilteredScan with a non-llm filter physical implementation.
833
698
  """
834
699
 
835
- @staticmethod
836
- def matches_pattern(logical_expression: LogicalExpression) -> bool:
700
+ @classmethod
701
+ def matches_pattern(cls, logical_expression: LogicalExpression) -> bool:
837
702
  is_match = (
838
703
  isinstance(logical_expression.operator, FilteredScan)
839
704
  and logical_expression.operator.filter.filter_fn is not None
@@ -841,33 +706,10 @@ class NonLLMFilterRule(ImplementationRule):
841
706
  logger.debug(f"NonLLMFilterRule matches_pattern: {is_match} for {logical_expression}")
842
707
  return is_match
843
708
 
844
- @staticmethod
845
- def substitute(logical_expression: LogicalExpression, **physical_op_params) -> set[PhysicalExpression]:
709
+ @classmethod
710
+ def substitute(cls, logical_expression: LogicalExpression, **runtime_kwargs) -> set[PhysicalExpression]:
846
711
  logger.debug(f"Substituting NonLLMFilterRule for {logical_expression}")
847
-
848
- logical_op = logical_expression.operator
849
- op_kwargs = logical_op.get_logical_op_params()
850
- op_kwargs.update(
851
- {
852
- "verbose": physical_op_params["verbose"],
853
- "logical_op_id": logical_op.get_logical_op_id(),
854
- "logical_op_name": logical_op.logical_op_name(),
855
- }
856
- )
857
- op = NonLLMFilter(**op_kwargs)
858
-
859
- expression = PhysicalExpression(
860
- operator=op,
861
- input_group_ids=logical_expression.input_group_ids,
862
- input_fields=logical_expression.input_fields,
863
- depends_on_field_names=logical_expression.depends_on_field_names,
864
- generated_fields=logical_expression.generated_fields,
865
- group_id=logical_expression.group_id,
866
- )
867
- logger.debug(f"Done substituting NonLLMFilterRule for {logical_expression}")
868
- deduped_physical_expressions = set([expression])
869
-
870
- return deduped_physical_expressions
712
+ return cls._perform_substitution(logical_expression, NonLLMFilter, runtime_kwargs)
871
713
 
872
714
 
873
715
  class LLMFilterRule(ImplementationRule):
@@ -875,8 +717,8 @@ class LLMFilterRule(ImplementationRule):
875
717
  Substitute a logical expression for a FilteredScan with an llm filter physical implementation.
876
718
  """
877
719
 
878
- @staticmethod
879
- def matches_pattern(logical_expression: LogicalExpression) -> bool:
720
+ @classmethod
721
+ def matches_pattern(cls, logical_expression: LogicalExpression) -> bool:
880
722
  is_match = (
881
723
  isinstance(logical_expression.operator, FilteredScan)
882
724
  and logical_expression.operator.filter.filter_condition is not None
@@ -884,82 +726,76 @@ class LLMFilterRule(ImplementationRule):
884
726
  logger.debug(f"LLMFilterRule matches_pattern: {is_match} for {logical_expression}")
885
727
  return is_match
886
728
 
887
- @staticmethod
888
- def substitute(logical_expression: LogicalExpression, **physical_op_params) -> set[PhysicalExpression]:
729
+ @classmethod
730
+ def substitute(cls, logical_expression: LogicalExpression, **runtime_kwargs) -> set[PhysicalExpression]:
889
731
  logger.debug(f"Substituting LLMFilterRule for {logical_expression}")
890
732
 
891
- logical_op = logical_expression.operator
892
- op_kwargs = logical_op.get_logical_op_params()
893
- op_kwargs.update(
733
+ # create variable physical operator kwargs for each model which can implement this logical_expression
734
+ models = [model for model in runtime_kwargs["available_models"] if cls._model_matches_input(model, logical_expression)]
735
+ # NOTE: right now we exclusively allow image or audio operations, but not both simultaneously
736
+ prompt_strategy, no_reasoning_prompt_strategy = None, None
737
+ no_reasoning = runtime_kwargs["reasoning_effort"] in [None, "minimal", "low"]
738
+ if cls._is_text_only_operation(logical_expression):
739
+ prompt_strategy = PromptStrategy.COT_BOOL
740
+ no_reasoning_prompt_strategy = PromptStrategy.COT_BOOL_NO_REASONING
741
+ elif cls._is_image_operation(logical_expression):
742
+ prompt_strategy = PromptStrategy.COT_BOOL_IMAGE
743
+ no_reasoning_prompt_strategy = PromptStrategy.COT_BOOL_IMAGE_NO_REASONING
744
+ elif cls._is_audio_operation(logical_expression):
745
+ prompt_strategy = PromptStrategy.COT_BOOL_AUDIO
746
+ no_reasoning_prompt_strategy = PromptStrategy.COT_BOOL_AUDIO_NO_REASONING
747
+ variable_op_kwargs = [
894
748
  {
895
- "verbose": physical_op_params["verbose"],
896
- "logical_op_id": logical_op.get_logical_op_id(),
897
- "logical_op_name": logical_op.logical_op_name(),
749
+ "model": model,
750
+ "prompt_strategy": no_reasoning_prompt_strategy if model.is_reasoning_model() and no_reasoning else prompt_strategy,
751
+ "reasoning_effort": runtime_kwargs["reasoning_effort"]
898
752
  }
899
- )
753
+ for model in models
754
+ ]
900
755
 
901
- # identify models which can be used strictly for text or strictly for images
902
- vision_models = set(get_vision_models())
903
- text_models = set(get_models())
904
- pure_text_models = {model for model in text_models if model not in vision_models}
905
- pure_vision_models = {model for model in vision_models if model not in text_models}
906
-
907
- # compute attributes about this filter operation
908
- is_image_filter = any(
909
- [
910
- field.is_image_field
911
- for field_name, field in logical_expression.input_fields.items()
912
- if field_name.split(".")[-1] in logical_expression.depends_on_field_names
913
- ]
914
- )
915
- num_image_fields = sum(
916
- [
917
- field.is_image_field
918
- for field_name, field in logical_expression.input_fields.items()
919
- if field_name.split(".")[-1] in logical_expression.depends_on_field_names
920
- ]
921
- )
922
- list_image_field = any(
923
- [
924
- field.is_image_field and hasattr(field, "element_type")
925
- for field_name, field in logical_expression.input_fields.items()
926
- if field_name.split(".")[-1] in logical_expression.depends_on_field_names
927
- ]
928
- )
756
+ return cls._perform_substitution(logical_expression, LLMFilter, runtime_kwargs, variable_op_kwargs)
929
757
 
930
- physical_expressions = []
931
- for model in physical_op_params["available_models"]:
932
- # skip this model if:
933
- # 1. this is a pure vision model and we're not doing an image filter, or
934
- # 2. this is a pure text model and we're doing an image filter, or
935
- # 3. this is a vision model hosted by Together (i.e. LLAMA3 vision) and there is more than one image field
936
- first_criteria = model in pure_vision_models and not is_image_filter
937
- second_criteria = model in pure_text_models and is_image_filter
938
- third_criteria = model.is_llama_model() and model.is_vision_model() and (num_image_fields > 1 or list_image_field)
939
- fourth_criteria = model.is_embedding_model()
940
- if first_criteria or second_criteria or third_criteria or fourth_criteria:
941
- continue
942
758
 
943
- # construct multi-expression
944
- op = LLMFilter(
945
- model=model,
946
- prompt_strategy=PromptStrategy.COT_BOOL_IMAGE if is_image_filter else PromptStrategy.COT_BOOL,
947
- **op_kwargs,
948
- )
949
- expression = PhysicalExpression(
950
- operator=op,
951
- input_group_ids=logical_expression.input_group_ids,
952
- input_fields=logical_expression.input_fields,
953
- depends_on_field_names=logical_expression.depends_on_field_names,
954
- generated_fields=logical_expression.generated_fields,
955
- group_id=logical_expression.group_id,
956
- )
957
- physical_expressions.append(expression)
759
+ class LLMJoinRule(ImplementationRule):
760
+ """
761
+ Substitute a logical expression for a JoinOp with an (LLM) NestedLoopsJoin physical implementation.
762
+ """
958
763
 
959
- logger.debug(f"Done substituting LLMFilterRule for {logical_expression}")
960
- deduped_physical_expressions = set(physical_expressions)
764
+ @classmethod
765
+ def matches_pattern(cls, logical_expression: LogicalExpression) -> bool:
766
+ is_match = isinstance(logical_expression.operator, JoinOp)
767
+ logger.debug(f"LLMJoinRule matches_pattern: {is_match} for {logical_expression}")
768
+ return is_match
961
769
 
962
- return deduped_physical_expressions
770
+ @classmethod
771
+ def substitute(cls, logical_expression: LogicalExpression, **runtime_kwargs) -> set[PhysicalExpression]:
772
+ logger.debug(f"Substituting LLMJoinRule for {logical_expression}")
773
+
774
+ # create variable physical operator kwargs for each model which can implement this logical_expression
775
+ models = [model for model in runtime_kwargs["available_models"] if cls._model_matches_input(model, logical_expression)]
776
+ # NOTE: right now we exclusively allow image or audio operations, but not both simultaneously
777
+ prompt_strategy, no_reasoning_prompt_strategy = None, None
778
+ no_reasoning = runtime_kwargs["reasoning_effort"] in [None, "minimal", "low"]
779
+ if cls._is_text_only_operation(logical_expression):
780
+ prompt_strategy = PromptStrategy.COT_JOIN
781
+ no_reasoning_prompt_strategy = PromptStrategy.COT_JOIN_NO_REASONING
782
+ elif cls._is_image_operation(logical_expression):
783
+ prompt_strategy = PromptStrategy.COT_JOIN_IMAGE
784
+ no_reasoning_prompt_strategy = PromptStrategy.COT_JOIN_IMAGE_NO_REASONING
785
+ elif cls._is_audio_operation(logical_expression):
786
+ prompt_strategy = PromptStrategy.COT_JOIN_AUDIO
787
+ no_reasoning_prompt_strategy = PromptStrategy.COT_JOIN_AUDIO_NO_REASONING
788
+ variable_op_kwargs = [
789
+ {
790
+ "model": model,
791
+ "prompt_strategy": no_reasoning_prompt_strategy if model.is_reasoning_model() and no_reasoning else prompt_strategy,
792
+ "join_parallelism": runtime_kwargs["join_parallelism"],
793
+ "reasoning_effort": runtime_kwargs["reasoning_effort"],
794
+ }
795
+ for model in models
796
+ ]
797
+
798
+ return cls._perform_substitution(logical_expression, NestedLoopsJoin, runtime_kwargs, variable_op_kwargs)
963
799
 
964
800
 
965
801
class AggregateRule(ImplementationRule):
    """
    Substitute the logical expression for an aggregate with its physical counterpart.
    """

    @classmethod
    def matches_pattern(cls, logical_expression: LogicalExpression) -> bool:
        """Return True when the expression's operator is an `Aggregate`."""
        is_match = isinstance(logical_expression.operator, Aggregate)
        logger.debug(f"AggregateRule matches_pattern: {is_match} for {logical_expression}")
        return is_match

    @classmethod
    def substitute(cls, logical_expression: LogicalExpression, **runtime_kwargs) -> set[PhysicalExpression]:
        """
        Map the aggregate's `agg_func` to a physical operator class and perform the substitution.

        `AggFunc.COUNT` maps to `CountAggregateOp` and `AggFunc.AVERAGE` maps to
        `AverageAggregateOp`; any other aggregation function raises an `Exception`.
        """
        logger.debug(f"Substituting AggregateRule for {logical_expression}")

        # reject unsupported aggregation functions up front
        agg_func = logical_expression.operator.agg_func
        is_count = agg_func == AggFunc.COUNT
        if not is_count and not agg_func == AggFunc.AVERAGE:
            raise Exception(f"Cannot support aggregate function: {logical_expression.operator.agg_func}")

        # get the physical op class based on the aggregation function
        physical_op_class = CountAggregateOp if is_count else AverageAggregateOp

        # perform the substitution
        return cls._perform_substitution(logical_expression, physical_op_class, runtime_kwargs)
827
+
828
+
829
class AddContextsBeforeComputeRule(ImplementationRule):
    """
    Searches the ContextManager for additional contexts which may be useful for the given computation.

    TODO: track cost of generating search query
    """
    # number of contexts requested from the ContextManager search
    k = 1
    # prompt template; formatted with the compute operator's `instruction`
    SEARCH_GENERATOR_PROMPT = CONTEXT_SEARCH_PROMPT

    @classmethod
    def matches_pattern(cls, logical_expression: LogicalExpression) -> bool:
        """Return True when the expression's operator is a `ComputeOperator`."""
        is_match = isinstance(logical_expression.operator, ComputeOperator)
        logger.debug(f"AddContextsBeforeComputeRule matches_pattern: {is_match} for {logical_expression}")
        return is_match

    @classmethod
    def substitute(cls, logical_expression: LogicalExpression, **runtime_kwargs) -> set[PhysicalExpression]:
        """
        Generate a search query with an LLM, look up matching contexts, and substitute `SmolAgentsCompute`.

        The model used to generate the query is chosen from the first provider API key found in
        the environment (OpenAI, then Anthropic, then Gemini, then Together). The generated query
        is used to search the ContextManager for `cls.k` materialized contexts, which are passed
        to the physical operator as `additional_contexts`.

        Raises:
            RuntimeError: if none of the supported provider API keys is set, since there is
                then no model available to generate the search query.
        """
        logger.debug(f"Substituting AddContextsBeforeComputeRule for {logical_expression}")

        # load an LLM to generate a short search query; providers are tried in priority order
        search_model_by_api_key = (
            ("OPENAI_API_KEY", "openai/gpt-4o-mini"),
            ("ANTHROPIC_API_KEY", "anthropic/claude-3-5-sonnet-20241022"),
            ("GEMINI_API_KEY", "vertex_ai/gemini-2.0-flash"),
            ("TOGETHER_API_KEY", "together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo"),
        )
        model = next((name for env_var, name in search_model_by_api_key if os.getenv(env_var)), None)

        # fail fast with an actionable message instead of calling the LLM with model=None
        if model is None:
            expected_keys = ", ".join(env_var for env_var, _ in search_model_by_api_key)
            raise RuntimeError(
                "AddContextsBeforeComputeRule cannot generate a search query: "
                f"none of the following environment variables is set: {expected_keys}"
            )

        # importing litellm here because importing above causes deprecation warning
        import litellm

        # retrieve any additional context which may be useful
        cm = ContextManager()
        prompt = cls.SEARCH_GENERATOR_PROMPT.format(instruction=logical_expression.operator.instruction)
        response = litellm.completion(
            model=model,
            messages=[{"role": "user", "content": prompt}],
        )
        query = response.choices[0].message.content
        variable_op_kwargs = {"additional_contexts": cm.search_context(query, k=cls.k, where={"materialized": True})}
        return cls._perform_substitution(logical_expression, SmolAgentsCompute, runtime_kwargs, variable_op_kwargs)
1011
871
 
1012
872
 
1013
873
  class BasicSubstitutionRule(ImplementationRule):
@@ -1018,11 +878,13 @@ class BasicSubstitutionRule(ImplementationRule):
1018
878
 
1019
879
  LOGICAL_OP_CLASS_TO_PHYSICAL_OP_CLASS_MAP = {
1020
880
  BaseScan: MarshalAndScanDataOp,
1021
- CacheScan: CacheScanDataOp,
881
+ # ComputeOperator: SmolAgentsCompute,
882
+ SearchOperator: SmolAgentsSearch, # SmolAgentsManagedSearch, # SmolAgentsCustomManagedSearch
883
+ ContextScan: ContextScanOp,
884
+ Distinct: DistinctOp,
1022
885
  LimitScan: LimitScanOp,
1023
886
  Project: ProjectOp,
1024
887
  GroupByAggregate: ApplyGroupByOp,
1025
- MapScan: MapOp,
1026
888
  }
1027
889
 
1028
890
  @classmethod
@@ -1033,31 +895,7 @@ class BasicSubstitutionRule(ImplementationRule):
1033
895
  return is_match
1034
896
 
1035
897
  @classmethod
1036
- def substitute(cls, logical_expression: LogicalExpression, **physical_op_params) -> set[PhysicalExpression]:
898
+ def substitute(cls, logical_expression: LogicalExpression, **runtime_kwargs) -> set[PhysicalExpression]:
1037
899
  logger.debug(f"Substituting BasicSubstitutionRule for {logical_expression}")
1038
-
1039
- logical_op = logical_expression.operator
1040
- op_kwargs = logical_op.get_logical_op_params()
1041
- op_kwargs.update(
1042
- {
1043
- "verbose": physical_op_params["verbose"],
1044
- "logical_op_id": logical_op.get_logical_op_id(),
1045
- "logical_op_name": logical_op.logical_op_name(),
1046
- }
1047
- )
1048
- physical_op_class = cls.LOGICAL_OP_CLASS_TO_PHYSICAL_OP_CLASS_MAP[logical_op.__class__]
1049
- op = physical_op_class(**op_kwargs)
1050
-
1051
- expression = PhysicalExpression(
1052
- operator=op,
1053
- input_group_ids=logical_expression.input_group_ids,
1054
- input_fields=logical_expression.input_fields,
1055
- depends_on_field_names=logical_expression.depends_on_field_names,
1056
- generated_fields=logical_expression.generated_fields,
1057
- group_id=logical_expression.group_id,
1058
- )
1059
-
1060
- logger.debug(f"Done substituting BasicSubstitutionRule for {logical_expression}")
1061
- deduped_physical_expressions = set([expression])
1062
-
1063
- return deduped_physical_expressions
900
+ physical_op_class = cls.LOGICAL_OP_CLASS_TO_PHYSICAL_OP_CLASS_MAP[logical_expression.operator.__class__]
901
+ return cls._perform_substitution(logical_expression, physical_op_class, runtime_kwargs)