palimpzest 0.6.3__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- palimpzest/__init__.py +5 -0
- palimpzest/constants.py +110 -43
- palimpzest/core/__init__.py +0 -78
- palimpzest/core/data/dataclasses.py +382 -44
- palimpzest/core/elements/filters.py +7 -3
- palimpzest/core/elements/index.py +70 -0
- palimpzest/core/elements/records.py +33 -11
- palimpzest/core/lib/fields.py +1 -0
- palimpzest/core/lib/schemas.py +4 -3
- palimpzest/prompts/moa_proposer_convert_prompts.py +0 -4
- palimpzest/prompts/prompt_factory.py +44 -7
- palimpzest/prompts/split_merge_prompts.py +56 -0
- palimpzest/prompts/split_proposer_prompts.py +55 -0
- palimpzest/query/execution/execution_strategy.py +435 -53
- palimpzest/query/execution/execution_strategy_type.py +20 -0
- palimpzest/query/execution/mab_execution_strategy.py +532 -0
- palimpzest/query/execution/parallel_execution_strategy.py +143 -172
- palimpzest/query/execution/random_sampling_execution_strategy.py +240 -0
- palimpzest/query/execution/single_threaded_execution_strategy.py +173 -203
- palimpzest/query/generators/api_client_factory.py +31 -0
- palimpzest/query/generators/generators.py +256 -76
- palimpzest/query/operators/__init__.py +1 -2
- palimpzest/query/operators/code_synthesis_convert.py +33 -18
- palimpzest/query/operators/convert.py +30 -97
- palimpzest/query/operators/critique_and_refine_convert.py +5 -6
- palimpzest/query/operators/filter.py +7 -10
- palimpzest/query/operators/logical.py +54 -10
- palimpzest/query/operators/map.py +130 -0
- palimpzest/query/operators/mixture_of_agents_convert.py +6 -6
- palimpzest/query/operators/physical.py +3 -12
- palimpzest/query/operators/rag_convert.py +66 -18
- palimpzest/query/operators/retrieve.py +230 -34
- palimpzest/query/operators/scan.py +5 -2
- palimpzest/query/operators/split_convert.py +169 -0
- palimpzest/query/operators/token_reduction_convert.py +8 -14
- palimpzest/query/optimizer/__init__.py +4 -16
- palimpzest/query/optimizer/cost_model.py +73 -266
- palimpzest/query/optimizer/optimizer.py +87 -58
- palimpzest/query/optimizer/optimizer_strategy.py +18 -97
- palimpzest/query/optimizer/optimizer_strategy_type.py +37 -0
- palimpzest/query/optimizer/plan.py +2 -3
- palimpzest/query/optimizer/primitives.py +5 -3
- palimpzest/query/optimizer/rules.py +336 -172
- palimpzest/query/optimizer/tasks.py +30 -100
- palimpzest/query/processor/config.py +38 -22
- palimpzest/query/processor/nosentinel_processor.py +16 -520
- palimpzest/query/processor/processing_strategy_type.py +28 -0
- palimpzest/query/processor/query_processor.py +38 -206
- palimpzest/query/processor/query_processor_factory.py +117 -130
- palimpzest/query/processor/sentinel_processor.py +90 -0
- palimpzest/query/processor/streaming_processor.py +25 -32
- palimpzest/sets.py +88 -41
- palimpzest/utils/model_helpers.py +8 -7
- palimpzest/utils/progress.py +368 -152
- palimpzest/utils/token_reduction_helpers.py +1 -3
- {palimpzest-0.6.3.dist-info → palimpzest-0.7.0.dist-info}/METADATA +28 -24
- palimpzest-0.7.0.dist-info/RECORD +96 -0
- {palimpzest-0.6.3.dist-info → palimpzest-0.7.0.dist-info}/WHEEL +1 -1
- palimpzest/query/processor/mab_sentinel_processor.py +0 -884
- palimpzest/query/processor/random_sampling_sentinel_processor.py +0 -639
- palimpzest/utils/index_helpers.py +0 -6
- palimpzest-0.6.3.dist-info/RECORD +0 -87
- {palimpzest-0.6.3.dist-info → palimpzest-0.7.0.dist-info/licenses}/LICENSE +0 -0
- {palimpzest-0.6.3.dist-info → palimpzest-0.7.0.dist-info}/top_level.txt +0 -0
|
@@ -1,10 +1,11 @@
|
|
|
1
|
+
import logging
|
|
1
2
|
from copy import deepcopy
|
|
2
3
|
from itertools import combinations
|
|
3
4
|
|
|
4
5
|
from palimpzest.constants import AggFunc, Cardinality, Model, PromptStrategy
|
|
5
6
|
from palimpzest.query.operators.aggregate import ApplyGroupByOp, AverageAggregateOp, CountAggregateOp
|
|
6
7
|
from palimpzest.query.operators.code_synthesis_convert import CodeSynthesisConvertSingle
|
|
7
|
-
from palimpzest.query.operators.convert import LLMConvertBonded,
|
|
8
|
+
from palimpzest.query.operators.convert import LLMConvertBonded, NonLLMConvert
|
|
8
9
|
from palimpzest.query.operators.critique_and_refine_convert import CriticAndRefineConvert
|
|
9
10
|
from palimpzest.query.operators.filter import LLMFilter, NonLLMFilter
|
|
10
11
|
from palimpzest.query.operators.limit import LimitScanOp
|
|
@@ -16,21 +17,23 @@ from palimpzest.query.operators.logical import (
|
|
|
16
17
|
FilteredScan,
|
|
17
18
|
GroupByAggregate,
|
|
18
19
|
LimitScan,
|
|
20
|
+
MapScan,
|
|
19
21
|
Project,
|
|
20
22
|
RetrieveScan,
|
|
21
23
|
)
|
|
24
|
+
from palimpzest.query.operators.map import MapOp
|
|
22
25
|
from palimpzest.query.operators.mixture_of_agents_convert import MixtureOfAgentsConvert
|
|
23
26
|
from palimpzest.query.operators.project import ProjectOp
|
|
24
27
|
from palimpzest.query.operators.rag_convert import RAGConvert
|
|
25
28
|
from palimpzest.query.operators.retrieve import RetrieveOp
|
|
26
29
|
from palimpzest.query.operators.scan import CacheScanDataOp, MarshalAndScanDataOp
|
|
27
|
-
from palimpzest.query.operators.
|
|
28
|
-
|
|
29
|
-
TokenReducedConvertConventional,
|
|
30
|
-
)
|
|
30
|
+
from palimpzest.query.operators.split_convert import SplitConvert
|
|
31
|
+
from palimpzest.query.operators.token_reduction_convert import TokenReducedConvertBonded
|
|
31
32
|
from palimpzest.query.optimizer.primitives import Expression, Group, LogicalExpression, PhysicalExpression
|
|
32
33
|
from palimpzest.utils.model_helpers import get_models, get_vision_models
|
|
33
34
|
|
|
35
|
+
logger = logging.getLogger(__name__)
|
|
36
|
+
|
|
34
37
|
|
|
35
38
|
class Rule:
|
|
36
39
|
"""
|
|
@@ -81,12 +84,16 @@ class PushDownFilter(TransformationRule):
|
|
|
81
84
|
|
|
82
85
|
@staticmethod
|
|
83
86
|
def matches_pattern(logical_expression: Expression) -> bool:
|
|
84
|
-
|
|
87
|
+
is_match = isinstance(logical_expression.operator, FilteredScan)
|
|
88
|
+
logger.debug(f"PushDownFilter matches_pattern: {is_match} for {logical_expression}")
|
|
89
|
+
return is_match
|
|
85
90
|
|
|
86
91
|
@staticmethod
|
|
87
92
|
def substitute(
|
|
88
93
|
logical_expression: LogicalExpression, groups: dict[int, Group], expressions: dict[int, Expression], **kwargs
|
|
89
94
|
) -> tuple[set[LogicalExpression], set[Group]]:
|
|
95
|
+
logger.debug(f"Substituting PushDownFilter for {logical_expression}")
|
|
96
|
+
|
|
90
97
|
# initialize the sets of new logical expressions and groups to be returned
|
|
91
98
|
new_logical_expressions, new_groups = set(), set()
|
|
92
99
|
|
|
@@ -102,8 +109,10 @@ class PushDownFilter(TransformationRule):
|
|
|
102
109
|
continue
|
|
103
110
|
|
|
104
111
|
# iterate over logical expressions
|
|
105
|
-
|
|
106
|
-
|
|
112
|
+
# NOTE: we previously deepcopy'ed the logical expression to avoid modifying the original;
|
|
113
|
+
# I think I've fixed this internally, but I'm leaving this NOTE as a reminder in case
|
|
114
|
+
# we see a regression / bug in the future
|
|
115
|
+
for expr in input_group.logical_expressions:
|
|
107
116
|
# if the expression operator is not a convert or a filter, we cannot swap
|
|
108
117
|
if not (isinstance(expr.operator, (ConvertScan, FilteredScan))):
|
|
109
118
|
continue
|
|
@@ -181,7 +190,7 @@ class PushDownFilter(TransformationRule):
|
|
|
181
190
|
|
|
182
191
|
# create final new logical expression with expr's operator pulled up
|
|
183
192
|
new_expr = LogicalExpression(
|
|
184
|
-
expr.operator,
|
|
193
|
+
expr.operator.copy(),
|
|
185
194
|
input_group_ids=[group_id]
|
|
186
195
|
+ [g_id for g_id in logical_expression.input_group_ids if g_id != input_group_id],
|
|
187
196
|
input_fields=group.fields,
|
|
@@ -193,6 +202,8 @@ class PushDownFilter(TransformationRule):
|
|
|
193
202
|
# add newly created expression to set of returned expressions
|
|
194
203
|
new_logical_expressions.add(new_expr)
|
|
195
204
|
|
|
205
|
+
logger.debug(f"Done substituting PushDownFilter for {logical_expression}")
|
|
206
|
+
|
|
196
207
|
return new_logical_expressions, new_groups
|
|
197
208
|
|
|
198
209
|
|
|
@@ -211,10 +222,14 @@ class NonLLMConvertRule(ImplementationRule):
|
|
|
211
222
|
|
|
212
223
|
@classmethod
|
|
213
224
|
def matches_pattern(cls, logical_expression: LogicalExpression) -> bool:
|
|
214
|
-
|
|
225
|
+
is_match = isinstance(logical_expression.operator, ConvertScan) and logical_expression.operator.udf is not None
|
|
226
|
+
logger.debug(f"NonLLMConvertRule matches_pattern: {is_match} for {logical_expression}")
|
|
227
|
+
return is_match
|
|
215
228
|
|
|
216
229
|
@classmethod
|
|
217
230
|
def substitute(cls, logical_expression: LogicalExpression, **physical_op_params) -> set[PhysicalExpression]:
|
|
231
|
+
logger.debug(f"Substituting NonLLMConvertRule for {logical_expression}")
|
|
232
|
+
|
|
218
233
|
logical_op = logical_expression.operator
|
|
219
234
|
|
|
220
235
|
# get initial set of parameters for physical op
|
|
@@ -238,27 +253,27 @@ class NonLLMConvertRule(ImplementationRule):
|
|
|
238
253
|
group_id=logical_expression.group_id,
|
|
239
254
|
)
|
|
240
255
|
|
|
241
|
-
|
|
256
|
+
deduped_physical_expressions = set([expression])
|
|
257
|
+
logger.debug(f"Done substituting NonLLMConvertRule for {logical_expression}")
|
|
242
258
|
|
|
259
|
+
return deduped_physical_expressions
|
|
243
260
|
|
|
244
|
-
class LLMConvertRule(ImplementationRule):
|
|
245
|
-
"""
|
|
246
|
-
Base rule for bonded and conventional LLM convert operators; the physical convert class
|
|
247
|
-
(LLMConvertBonded or LLMConvertConventional) is provided by sub-class rules.
|
|
248
261
|
|
|
249
|
-
|
|
250
|
-
|
|
262
|
+
class LLMConvertBondedRule(ImplementationRule):
|
|
263
|
+
"""
|
|
264
|
+
Substitute a logical expression for a ConvertScan with a bonded convert physical implementation.
|
|
251
265
|
"""
|
|
252
|
-
|
|
253
|
-
# overridden by sub-classes
|
|
254
|
-
physical_convert_class = None
|
|
255
266
|
|
|
256
267
|
@classmethod
|
|
257
268
|
def matches_pattern(cls, logical_expression: LogicalExpression) -> bool:
|
|
258
|
-
|
|
269
|
+
is_match = isinstance(logical_expression.operator, ConvertScan) and logical_expression.operator.udf is None
|
|
270
|
+
logger.debug(f"LLMConvertBondedRule matches_pattern: {is_match} for {logical_expression}")
|
|
271
|
+
return is_match
|
|
259
272
|
|
|
260
273
|
@classmethod
|
|
261
274
|
def substitute(cls, logical_expression: LogicalExpression, **physical_op_params) -> set[PhysicalExpression]:
|
|
275
|
+
logger.debug(f"Substituting LLMConvertBondedRule for {logical_expression}")
|
|
276
|
+
|
|
262
277
|
logical_op = logical_expression.operator
|
|
263
278
|
|
|
264
279
|
# get initial set of parameters for physical op
|
|
@@ -281,21 +296,27 @@ class LLMConvertRule(ImplementationRule):
|
|
|
281
296
|
pure_vision_models = {model for model in vision_models if model not in text_models}
|
|
282
297
|
|
|
283
298
|
# compute attributes about this convert operation
|
|
284
|
-
is_image_conversion = any(
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
+
is_image_conversion = any(
|
|
300
|
+
[
|
|
301
|
+
field.is_image_field
|
|
302
|
+
for field_name, field in logical_expression.input_fields.items()
|
|
303
|
+
if field_name.split(".")[-1] in logical_expression.depends_on_field_names
|
|
304
|
+
]
|
|
305
|
+
)
|
|
306
|
+
num_image_fields = sum(
|
|
307
|
+
[
|
|
308
|
+
field.is_image_field
|
|
309
|
+
for field_name, field in logical_expression.input_fields.items()
|
|
310
|
+
if field_name.split(".")[-1] in logical_expression.depends_on_field_names
|
|
311
|
+
]
|
|
312
|
+
)
|
|
313
|
+
list_image_field = any(
|
|
314
|
+
[
|
|
315
|
+
field.is_image_field and hasattr(field, "element_type")
|
|
316
|
+
for field_name, field in logical_expression.input_fields.items()
|
|
317
|
+
if field_name.split(".")[-1] in logical_expression.depends_on_field_names
|
|
318
|
+
]
|
|
319
|
+
)
|
|
299
320
|
|
|
300
321
|
physical_expressions = []
|
|
301
322
|
for model in physical_op_params["available_models"]:
|
|
@@ -310,7 +331,7 @@ class LLMConvertRule(ImplementationRule):
|
|
|
310
331
|
continue
|
|
311
332
|
|
|
312
333
|
# construct multi-expression
|
|
313
|
-
op =
|
|
334
|
+
op = LLMConvertBonded(
|
|
314
335
|
model=model,
|
|
315
336
|
prompt_strategy=PromptStrategy.COT_QA_IMAGE if is_image_conversion else PromptStrategy.COT_QA,
|
|
316
337
|
**op_kwargs,
|
|
@@ -325,49 +346,37 @@ class LLMConvertRule(ImplementationRule):
|
|
|
325
346
|
)
|
|
326
347
|
physical_expressions.append(expression)
|
|
327
348
|
|
|
328
|
-
|
|
349
|
+
deduped_physical_expressions = set(physical_expressions)
|
|
350
|
+
logger.debug(f"Done substituting LLMConvertBondedRule for {logical_expression}")
|
|
329
351
|
|
|
352
|
+
return deduped_physical_expressions
|
|
330
353
|
|
|
331
|
-
class LLMConvertBondedRule(LLMConvertRule):
|
|
332
|
-
"""
|
|
333
|
-
Substitute a logical expression for a ConvertScan with a bonded convert physical implementation.
|
|
334
|
-
"""
|
|
335
|
-
|
|
336
|
-
physical_convert_class = LLMConvertBonded
|
|
337
354
|
|
|
338
|
-
|
|
339
|
-
class LLMConvertConventionalRule(LLMConvertRule):
|
|
340
|
-
"""
|
|
341
|
-
Substitute a logical expression for a ConvertScan with a conventional convert physical implementation.
|
|
355
|
+
class TokenReducedConvertBondedRule(ImplementationRule):
|
|
342
356
|
"""
|
|
343
|
-
|
|
344
|
-
physical_convert_class = LLMConvertConventional
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
class TokenReducedConvertRule(ImplementationRule):
|
|
348
|
-
"""
|
|
349
|
-
Base rule for bonded and conventional token reduced convert operators; the physical convert class
|
|
350
|
-
(TokenReducedConvertBonded or TokenReducedConvertConventional) is provided by sub-class rules.
|
|
351
|
-
|
|
352
|
-
NOTE: we provide the physical convert class(es) in their own sub-classed rules to make
|
|
353
|
-
it easier to allow/disallow groups of rules at the Optimizer level.
|
|
357
|
+
Substitute a logical expression for a ConvertScan with a bonded token reduced physical implementation.
|
|
354
358
|
"""
|
|
355
359
|
|
|
356
|
-
physical_convert_class = None # overriden by sub-classes
|
|
357
360
|
token_budgets = [0.1, 0.5, 0.9]
|
|
358
361
|
|
|
359
362
|
@classmethod
|
|
360
363
|
def matches_pattern(cls, logical_expression: LogicalExpression) -> bool:
|
|
361
364
|
logical_op = logical_expression.operator
|
|
362
|
-
is_image_conversion = any(
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
365
|
+
is_image_conversion = any(
|
|
366
|
+
[
|
|
367
|
+
field.is_image_field
|
|
368
|
+
for field_name, field in logical_expression.input_fields.items()
|
|
369
|
+
if field_name.split(".")[-1] in logical_expression.depends_on_field_names
|
|
370
|
+
]
|
|
371
|
+
)
|
|
372
|
+
is_match = isinstance(logical_op, ConvertScan) and not is_image_conversion and logical_op.udf is None
|
|
373
|
+
logger.debug(f"TokenReducedConvertBondedRule matches_pattern: {is_match} for {logical_expression}")
|
|
374
|
+
return is_match
|
|
368
375
|
|
|
369
376
|
@classmethod
|
|
370
377
|
def substitute(cls, logical_expression: LogicalExpression, **physical_op_params) -> set[PhysicalExpression]:
|
|
378
|
+
logger.debug(f"Substituting TokenReducedConvertBondedRule for {logical_expression}")
|
|
379
|
+
|
|
371
380
|
logical_op = logical_expression.operator
|
|
372
381
|
|
|
373
382
|
# get initial set of parameters for physical op
|
|
@@ -396,7 +405,7 @@ class TokenReducedConvertRule(ImplementationRule):
|
|
|
396
405
|
continue
|
|
397
406
|
|
|
398
407
|
# construct multi-expression
|
|
399
|
-
op =
|
|
408
|
+
op = TokenReducedConvertBonded(
|
|
400
409
|
model=model,
|
|
401
410
|
prompt_strategy=PromptStrategy.COT_QA,
|
|
402
411
|
token_budget=token_budget,
|
|
@@ -412,23 +421,10 @@ class TokenReducedConvertRule(ImplementationRule):
|
|
|
412
421
|
)
|
|
413
422
|
physical_expressions.append(expression)
|
|
414
423
|
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
class TokenReducedConvertBondedRule(TokenReducedConvertRule):
|
|
419
|
-
"""
|
|
420
|
-
Substitute a logical expression for a ConvertScan with a bonded token reduced physical implementation.
|
|
421
|
-
"""
|
|
424
|
+
logger.debug(f"Done substituting TokenReducedConvertBondedRule for {logical_expression}")
|
|
425
|
+
deduped_physical_expressions = set(physical_expressions)
|
|
422
426
|
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
class TokenReducedConvertConventionalRule(TokenReducedConvertRule):
|
|
427
|
-
"""
|
|
428
|
-
Substitute a logical expression for a ConvertScan with a conventional token reduced physical implementation.
|
|
429
|
-
"""
|
|
430
|
-
|
|
431
|
-
physical_convert_class = TokenReducedConvertConventional
|
|
427
|
+
return deduped_physical_expressions
|
|
432
428
|
|
|
433
429
|
|
|
434
430
|
class CodeSynthesisConvertRule(ImplementationRule):
|
|
@@ -445,20 +441,26 @@ class CodeSynthesisConvertRule(ImplementationRule):
|
|
|
445
441
|
@classmethod
|
|
446
442
|
def matches_pattern(cls, logical_expression: LogicalExpression) -> bool:
|
|
447
443
|
logical_op = logical_expression.operator
|
|
448
|
-
is_image_conversion = any(
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
|
|
444
|
+
is_image_conversion = any(
|
|
445
|
+
[
|
|
446
|
+
field.is_image_field
|
|
447
|
+
for field_name, field in logical_expression.input_fields.items()
|
|
448
|
+
if field_name.split(".")[-1] in logical_expression.depends_on_field_names
|
|
449
|
+
]
|
|
450
|
+
)
|
|
451
|
+
is_match = (
|
|
454
452
|
isinstance(logical_op, ConvertScan)
|
|
455
453
|
and not is_image_conversion
|
|
456
454
|
and logical_op.cardinality != Cardinality.ONE_TO_MANY
|
|
457
455
|
and logical_op.udf is None
|
|
458
456
|
)
|
|
457
|
+
logger.debug(f"CodeSynthesisConvertRule matches_pattern: {is_match} for {logical_expression}")
|
|
458
|
+
return is_match
|
|
459
459
|
|
|
460
460
|
@classmethod
|
|
461
461
|
def substitute(cls, logical_expression: LogicalExpression, **physical_op_params) -> set[PhysicalExpression]:
|
|
462
|
+
logger.debug(f"Substituting CodeSynthesisConvertRule for {logical_expression}")
|
|
463
|
+
|
|
462
464
|
logical_op = logical_expression.operator
|
|
463
465
|
|
|
464
466
|
# get initial set of parameters for physical op
|
|
@@ -475,7 +477,7 @@ class CodeSynthesisConvertRule(ImplementationRule):
|
|
|
475
477
|
op = cls.physical_convert_class(
|
|
476
478
|
exemplar_generation_model=physical_op_params["champion_model"],
|
|
477
479
|
code_synth_model=physical_op_params["code_champion_model"],
|
|
478
|
-
|
|
480
|
+
fallback_model=physical_op_params["fallback_model"],
|
|
479
481
|
prompt_strategy=PromptStrategy.COT_QA,
|
|
480
482
|
**op_kwargs,
|
|
481
483
|
)
|
|
@@ -487,8 +489,10 @@ class CodeSynthesisConvertRule(ImplementationRule):
|
|
|
487
489
|
generated_fields=logical_expression.generated_fields,
|
|
488
490
|
group_id=logical_expression.group_id,
|
|
489
491
|
)
|
|
492
|
+
deduped_physical_expressions = set([expression])
|
|
493
|
+
logger.debug(f"Done substituting CodeSynthesisConvertRule for {logical_expression}")
|
|
490
494
|
|
|
491
|
-
return
|
|
495
|
+
return deduped_physical_expressions
|
|
492
496
|
|
|
493
497
|
|
|
494
498
|
class CodeSynthesisConvertSingleRule(CodeSynthesisConvertRule):
|
|
@@ -503,21 +507,28 @@ class RAGConvertRule(ImplementationRule):
|
|
|
503
507
|
"""
|
|
504
508
|
Substitute a logical expression for a ConvertScan with a RAGConvert physical implementation.
|
|
505
509
|
"""
|
|
510
|
+
|
|
506
511
|
num_chunks_per_fields = [1, 2, 4]
|
|
507
512
|
chunk_sizes = [1000, 2000, 4000]
|
|
508
513
|
|
|
509
514
|
@classmethod
|
|
510
515
|
def matches_pattern(cls, logical_expression: LogicalExpression) -> bool:
|
|
511
516
|
logical_op = logical_expression.operator
|
|
512
|
-
is_image_conversion = any(
|
|
513
|
-
|
|
514
|
-
|
|
515
|
-
|
|
516
|
-
|
|
517
|
-
|
|
517
|
+
is_image_conversion = any(
|
|
518
|
+
[
|
|
519
|
+
field.is_image_field
|
|
520
|
+
for field_name, field in logical_expression.input_fields.items()
|
|
521
|
+
if field_name.split(".")[-1] in logical_expression.depends_on_field_names
|
|
522
|
+
]
|
|
523
|
+
)
|
|
524
|
+
is_match = isinstance(logical_op, ConvertScan) and not is_image_conversion and logical_op.udf is None
|
|
525
|
+
logger.debug(f"RAGConvertRule matches_pattern: {is_match} for {logical_expression}")
|
|
526
|
+
return is_match
|
|
518
527
|
|
|
519
528
|
@classmethod
|
|
520
529
|
def substitute(cls, logical_expression: LogicalExpression, **physical_op_params) -> set[PhysicalExpression]:
|
|
530
|
+
logger.debug(f"Substituting RAGConvertRule for {logical_expression}")
|
|
531
|
+
|
|
521
532
|
logical_op = logical_expression.operator
|
|
522
533
|
|
|
523
534
|
# get initial set of parameters for physical op
|
|
@@ -564,31 +575,42 @@ class RAGConvertRule(ImplementationRule):
|
|
|
564
575
|
)
|
|
565
576
|
physical_expressions.append(expression)
|
|
566
577
|
|
|
567
|
-
|
|
578
|
+
logger.debug(f"Done substituting RAGConvertRule for {logical_expression}")
|
|
579
|
+
deduped_physical_expressions = set(physical_expressions)
|
|
580
|
+
|
|
581
|
+
return deduped_physical_expressions
|
|
582
|
+
|
|
568
583
|
|
|
569
584
|
class MixtureOfAgentsConvertRule(ImplementationRule):
|
|
570
585
|
"""
|
|
571
586
|
Implementation rule for the MixtureOfAgentsConvert operator.
|
|
572
587
|
"""
|
|
588
|
+
|
|
573
589
|
num_proposer_models = [1, 2, 3]
|
|
574
590
|
temperatures = [0.0, 0.4, 0.8]
|
|
575
591
|
|
|
576
592
|
@classmethod
|
|
577
593
|
def matches_pattern(cls, logical_expression: LogicalExpression) -> bool:
|
|
578
594
|
logical_op = logical_expression.operator
|
|
579
|
-
|
|
595
|
+
is_match = isinstance(logical_op, ConvertScan) and logical_op.udf is None
|
|
596
|
+
logger.debug(f"MixtureOfAgentsConvertRule matches_pattern: {is_match} for {logical_expression}")
|
|
597
|
+
return is_match
|
|
580
598
|
|
|
581
599
|
@classmethod
|
|
582
600
|
def substitute(cls, logical_expression: LogicalExpression, **physical_op_params) -> set[PhysicalExpression]:
|
|
601
|
+
logger.debug(f"Substituting MixtureOfAgentsConvertRule for {logical_expression}")
|
|
602
|
+
|
|
583
603
|
logical_op = logical_expression.operator
|
|
584
604
|
|
|
585
605
|
# get initial set of parameters for physical op
|
|
586
606
|
op_kwargs: dict = logical_op.get_logical_op_params()
|
|
587
|
-
op_kwargs.update(
|
|
588
|
-
|
|
589
|
-
|
|
590
|
-
|
|
591
|
-
|
|
607
|
+
op_kwargs.update(
|
|
608
|
+
{
|
|
609
|
+
"verbose": physical_op_params["verbose"],
|
|
610
|
+
"logical_op_id": logical_op.get_logical_op_id(),
|
|
611
|
+
"logical_op_name": logical_op.logical_op_name(),
|
|
612
|
+
}
|
|
613
|
+
)
|
|
592
614
|
|
|
593
615
|
# NOTE: when comparing pz.Model(s), equality is determined by the string (i.e. pz.Model.value)
|
|
594
616
|
# thus, Model.GPT_4o and Model.GPT_4o_V map to the same value; this allows us to use set logic
|
|
@@ -598,16 +620,20 @@ class MixtureOfAgentsConvertRule(ImplementationRule):
|
|
|
598
620
|
text_models = set(get_models())
|
|
599
621
|
|
|
600
622
|
# construct set of proposer models and set of aggregator models
|
|
601
|
-
num_image_fields = sum(
|
|
602
|
-
|
|
603
|
-
|
|
604
|
-
|
|
605
|
-
|
|
606
|
-
|
|
607
|
-
|
|
608
|
-
|
|
609
|
-
|
|
610
|
-
|
|
623
|
+
num_image_fields = sum(
|
|
624
|
+
[
|
|
625
|
+
field.is_image_field
|
|
626
|
+
for field_name, field in logical_expression.input_fields.items()
|
|
627
|
+
if field_name.split(".")[-1] in logical_expression.depends_on_field_names
|
|
628
|
+
]
|
|
629
|
+
)
|
|
630
|
+
list_image_field = any(
|
|
631
|
+
[
|
|
632
|
+
field.is_image_field and hasattr(field, "element_type")
|
|
633
|
+
for field_name, field in logical_expression.input_fields.items()
|
|
634
|
+
if field_name.split(".")[-1] in logical_expression.depends_on_field_names
|
|
635
|
+
]
|
|
636
|
+
)
|
|
611
637
|
proposer_model_set, is_image_conversion = text_models, False
|
|
612
638
|
if num_image_fields > 1 or list_image_field:
|
|
613
639
|
proposer_model_set = [model for model in vision_models if model != Model.LLAMA3_V]
|
|
@@ -618,8 +644,10 @@ class MixtureOfAgentsConvertRule(ImplementationRule):
|
|
|
618
644
|
aggregator_model_set = text_models
|
|
619
645
|
|
|
620
646
|
# filter un-available models out of sets
|
|
621
|
-
proposer_model_set = {model for model in proposer_model_set if model in physical_op_params[
|
|
622
|
-
aggregator_model_set = {
|
|
647
|
+
proposer_model_set = {model for model in proposer_model_set if model in physical_op_params["available_models"]}
|
|
648
|
+
aggregator_model_set = {
|
|
649
|
+
model for model in aggregator_model_set if model in physical_op_params["available_models"]
|
|
650
|
+
}
|
|
623
651
|
|
|
624
652
|
# construct MixtureOfAgentsConvert operations for various numbers of proposer models
|
|
625
653
|
# and for every combination of proposer models and aggregator model
|
|
@@ -634,7 +662,9 @@ class MixtureOfAgentsConvertRule(ImplementationRule):
|
|
|
634
662
|
temperatures=[temp] * len(proposer_models),
|
|
635
663
|
aggregator_model=aggregator_model,
|
|
636
664
|
proposer_prompt=op_kwargs.get("prompt"),
|
|
637
|
-
proposer_prompt_strategy=PromptStrategy.COT_MOA_PROPOSER_IMAGE
|
|
665
|
+
proposer_prompt_strategy=PromptStrategy.COT_MOA_PROPOSER_IMAGE
|
|
666
|
+
if is_image_conversion
|
|
667
|
+
else PromptStrategy.COT_MOA_PROPOSER,
|
|
638
668
|
aggregator_prompt_strategy=PromptStrategy.COT_MOA_AGG,
|
|
639
669
|
**op_kwargs,
|
|
640
670
|
)
|
|
@@ -648,7 +678,11 @@ class MixtureOfAgentsConvertRule(ImplementationRule):
|
|
|
648
678
|
)
|
|
649
679
|
physical_expressions.append(expression)
|
|
650
680
|
|
|
651
|
-
|
|
681
|
+
logger.debug(f"Done substituting MixtureOfAgentsConvertRule for {logical_expression}")
|
|
682
|
+
deduped_physical_expressions = set(physical_expressions)
|
|
683
|
+
|
|
684
|
+
return deduped_physical_expressions
|
|
685
|
+
|
|
652
686
|
|
|
653
687
|
class CriticAndRefineConvertRule(ImplementationRule):
|
|
654
688
|
"""
|
|
@@ -658,10 +692,14 @@ class CriticAndRefineConvertRule(ImplementationRule):
|
|
|
658
692
|
@classmethod
|
|
659
693
|
def matches_pattern(cls, logical_expression: LogicalExpression) -> bool:
|
|
660
694
|
logical_op = logical_expression.operator
|
|
661
|
-
|
|
695
|
+
is_match = isinstance(logical_op, ConvertScan) and logical_op.udf is None
|
|
696
|
+
logger.debug(f"CriticAndRefineConvertRule matches_pattern: {is_match} for {logical_expression}")
|
|
697
|
+
return is_match
|
|
662
698
|
|
|
663
699
|
@classmethod
|
|
664
700
|
def substitute(cls, logical_expression: LogicalExpression, **physical_op_params) -> set[PhysicalExpression]:
|
|
701
|
+
logger.debug(f"Substituting CriticAndRefineConvertRule for {logical_expression}")
|
|
702
|
+
|
|
665
703
|
logical_op = logical_expression.operator
|
|
666
704
|
|
|
667
705
|
# Get initial parameters for physical operator
|
|
@@ -684,21 +722,27 @@ class CriticAndRefineConvertRule(ImplementationRule):
|
|
|
684
722
|
pure_vision_models = {model for model in vision_models if model not in text_models}
|
|
685
723
|
|
|
686
724
|
# compute attributes about this convert operation
|
|
687
|
-
is_image_conversion = any(
|
|
688
|
-
|
|
689
|
-
|
|
690
|
-
|
|
691
|
-
|
|
692
|
-
|
|
693
|
-
|
|
694
|
-
|
|
695
|
-
|
|
696
|
-
|
|
697
|
-
|
|
698
|
-
|
|
699
|
-
|
|
700
|
-
|
|
701
|
-
|
|
725
|
+
is_image_conversion = any(
|
|
726
|
+
[
|
|
727
|
+
field.is_image_field
|
|
728
|
+
for field_name, field in logical_expression.input_fields.items()
|
|
729
|
+
if field_name.split(".")[-1] in logical_expression.depends_on_field_names
|
|
730
|
+
]
|
|
731
|
+
)
|
|
732
|
+
num_image_fields = sum(
|
|
733
|
+
[
|
|
734
|
+
field.is_image_field
|
|
735
|
+
for field_name, field in logical_expression.input_fields.items()
|
|
736
|
+
if field_name.split(".")[-1] in logical_expression.depends_on_field_names
|
|
737
|
+
]
|
|
738
|
+
)
|
|
739
|
+
list_image_field = any(
|
|
740
|
+
[
|
|
741
|
+
field.is_image_field and hasattr(field, "element_type")
|
|
742
|
+
for field_name, field in logical_expression.input_fields.items()
|
|
743
|
+
if field_name.split(".")[-1] in logical_expression.depends_on_field_names
|
|
744
|
+
]
|
|
745
|
+
)
|
|
702
746
|
|
|
703
747
|
# identify models which can be used for this convert operation
|
|
704
748
|
models = []
|
|
@@ -739,26 +783,104 @@ class CriticAndRefineConvertRule(ImplementationRule):
|
|
|
739
783
|
)
|
|
740
784
|
physical_expressions.append(expression)
|
|
741
785
|
|
|
742
|
-
|
|
743
|
-
|
|
786
|
+
logger.debug(f"Done substituting CriticAndRefineConvertRule for {logical_expression}")
|
|
787
|
+
deduped_physical_expressions = set(physical_expressions)
|
|
788
|
+
|
|
789
|
+
return deduped_physical_expressions
|
|
790
|
+
|
|
791
|
+
|
|
792
|
+
class SplitConvertRule(ImplementationRule):
|
|
793
|
+
"""
|
|
794
|
+
Substitute a logical expression for a ConvertScan with a SplitConvert physical implementation.
|
|
795
|
+
"""
|
|
796
|
+
num_chunks = [2, 4, 6]
|
|
797
|
+
min_size_to_chunk = [1000, 4000]
|
|
798
|
+
|
|
799
|
+
@classmethod
|
|
800
|
+
def matches_pattern(cls, logical_expression: LogicalExpression) -> bool:
|
|
801
|
+
logical_op = logical_expression.operator
|
|
802
|
+
is_image_conversion = any(
|
|
803
|
+
[
|
|
804
|
+
field.is_image_field
|
|
805
|
+
for field_name, field in logical_expression.input_fields.items()
|
|
806
|
+
if field_name.split(".")[-1] in logical_expression.depends_on_field_names
|
|
807
|
+
]
|
|
808
|
+
)
|
|
809
|
+
is_match = isinstance(logical_op, ConvertScan) and not is_image_conversion and logical_op.udf is None
|
|
810
|
+
logger.debug(f"SplitConvertRule matches_pattern: {is_match} for {logical_expression}")
|
|
811
|
+
return is_match
|
|
812
|
+
|
|
813
|
+
@classmethod
|
|
814
|
+
def substitute(cls, logical_expression: LogicalExpression, **physical_op_params) -> set[PhysicalExpression]:
|
|
815
|
+
logger.debug(f"Substituting SplitConvertRule for {logical_expression}")
|
|
816
|
+
|
|
817
|
+
logical_op = logical_expression.operator
|
|
818
|
+
|
|
819
|
+
# get initial set of parameters for physical op
|
|
820
|
+
op_kwargs = logical_op.get_logical_op_params()
|
|
821
|
+
op_kwargs.update(
|
|
822
|
+
{
|
|
823
|
+
"verbose": physical_op_params["verbose"],
|
|
824
|
+
"logical_op_id": logical_op.get_logical_op_id(),
|
|
825
|
+
"logical_op_name": logical_op.logical_op_name(),
|
|
826
|
+
}
|
|
827
|
+
)
|
|
828
|
+
|
|
829
|
+
# NOTE: when comparing pz.Model(s), equality is determined by the string (i.e. pz.Model.value)
|
|
830
|
+
# thus, Model.GPT_4o and Model.GPT_4o_V map to the same value; this allows us to use set logic
|
|
831
|
+
#
|
|
832
|
+
# identify models which can be used strictly for text or strictly for images
|
|
833
|
+
vision_models = set(get_vision_models())
|
|
834
|
+
text_models = set(get_models())
|
|
835
|
+
pure_vision_models = {model for model in vision_models if model not in text_models}
|
|
836
|
+
|
|
837
|
+
physical_expressions = []
|
|
838
|
+
for model in physical_op_params["available_models"]:
|
|
839
|
+
# skip this model if this is a pure image model
|
|
840
|
+
if model in pure_vision_models:
|
|
841
|
+
continue
|
|
842
|
+
|
|
843
|
+
for min_size_to_chunk in cls.min_size_to_chunk:
|
|
844
|
+
for num_chunks in cls.num_chunks:
|
|
845
|
+
# construct multi-expression
|
|
846
|
+
op = SplitConvert(
|
|
847
|
+
model=model,
|
|
848
|
+
num_chunks=num_chunks,
|
|
849
|
+
min_size_to_chunk=min_size_to_chunk,
|
|
850
|
+
**op_kwargs,
|
|
851
|
+
)
|
|
852
|
+
expression = PhysicalExpression(
|
|
853
|
+
operator=op,
|
|
854
|
+
input_group_ids=logical_expression.input_group_ids,
|
|
855
|
+
input_fields=logical_expression.input_fields,
|
|
856
|
+
depends_on_field_names=logical_expression.depends_on_field_names,
|
|
857
|
+
generated_fields=logical_expression.generated_fields,
|
|
858
|
+
group_id=logical_expression.group_id,
|
|
859
|
+
)
|
|
860
|
+
physical_expressions.append(expression)
|
|
861
|
+
|
|
862
|
+
logger.debug(f"Done substituting SplitConvertRule for {logical_expression}")
|
|
863
|
+
deduped_physical_expressions = set(physical_expressions)
|
|
864
|
+
|
|
865
|
+
return deduped_physical_expressions
|
|
744
866
|
|
|
745
867
|
|
|
746
868
|
class RetrieveRule(ImplementationRule):
|
|
747
869
|
"""
|
|
748
870
|
Substitute a logical expression for a RetrieveScan with a Retrieve physical implementation.
|
|
749
871
|
"""
|
|
750
|
-
k_budgets = [1, 3, 5, 10]
|
|
872
|
+
k_budgets = [1, 3, 5, 10, 15, 20, 25]
|
|
751
873
|
|
|
752
874
|
@classmethod
|
|
753
875
|
def matches_pattern(cls, logical_expression: LogicalExpression) -> bool:
|
|
754
|
-
|
|
755
|
-
|
|
756
|
-
|
|
876
|
+
is_match = isinstance(logical_expression.operator, RetrieveScan)
|
|
877
|
+
logger.debug(f"RetrieveRule matches_pattern: {is_match} for {logical_expression}")
|
|
878
|
+
return is_match
|
|
757
879
|
|
|
758
880
|
@classmethod
|
|
759
|
-
def substitute(
|
|
760
|
-
|
|
761
|
-
|
|
881
|
+
def substitute(cls, logical_expression: LogicalExpression, **physical_op_params) -> set[PhysicalExpression]:
|
|
882
|
+
logger.debug(f"Substituting RetrieveRule for {logical_expression}")
|
|
883
|
+
|
|
762
884
|
logical_op = logical_expression.operator
|
|
763
885
|
|
|
764
886
|
physical_expressions = []
|
|
@@ -788,7 +910,10 @@ class RetrieveRule(ImplementationRule):
|
|
|
788
910
|
|
|
789
911
|
physical_expressions.append(expression)
|
|
790
912
|
|
|
791
|
-
|
|
913
|
+
logger.debug(f"Done substituting RetrieveRule for {logical_expression}")
|
|
914
|
+
deduped_physical_expressions = set(physical_expressions)
|
|
915
|
+
|
|
916
|
+
return deduped_physical_expressions
|
|
792
917
|
|
|
793
918
|
|
|
794
919
|
class NonLLMFilterRule(ImplementationRule):
|
|
@@ -798,13 +923,17 @@ class NonLLMFilterRule(ImplementationRule):
|
|
|
798
923
|
|
|
799
924
|
@staticmethod
|
|
800
925
|
def matches_pattern(logical_expression: LogicalExpression) -> bool:
|
|
801
|
-
|
|
926
|
+
is_match = (
|
|
802
927
|
isinstance(logical_expression.operator, FilteredScan)
|
|
803
928
|
and logical_expression.operator.filter.filter_fn is not None
|
|
804
929
|
)
|
|
930
|
+
logger.debug(f"NonLLMFilterRule matches_pattern: {is_match} for {logical_expression}")
|
|
931
|
+
return is_match
|
|
805
932
|
|
|
806
933
|
@staticmethod
|
|
807
934
|
def substitute(logical_expression: LogicalExpression, **physical_op_params) -> set[PhysicalExpression]:
|
|
935
|
+
logger.debug(f"Substituting NonLLMFilterRule for {logical_expression}")
|
|
936
|
+
|
|
808
937
|
logical_op = logical_expression.operator
|
|
809
938
|
op_kwargs = logical_op.get_logical_op_params()
|
|
810
939
|
op_kwargs.update(
|
|
@@ -824,7 +953,10 @@ class NonLLMFilterRule(ImplementationRule):
|
|
|
824
953
|
generated_fields=logical_expression.generated_fields,
|
|
825
954
|
group_id=logical_expression.group_id,
|
|
826
955
|
)
|
|
827
|
-
|
|
956
|
+
logger.debug(f"Done substituting NonLLMFilterRule for {logical_expression}")
|
|
957
|
+
deduped_physical_expressions = set([expression])
|
|
958
|
+
|
|
959
|
+
return deduped_physical_expressions
|
|
828
960
|
|
|
829
961
|
|
|
830
962
|
class LLMFilterRule(ImplementationRule):
|
|
@@ -834,20 +966,26 @@ class LLMFilterRule(ImplementationRule):
|
|
|
834
966
|
|
|
835
967
|
@staticmethod
|
|
836
968
|
def matches_pattern(logical_expression: LogicalExpression) -> bool:
|
|
837
|
-
|
|
969
|
+
is_match = (
|
|
838
970
|
isinstance(logical_expression.operator, FilteredScan)
|
|
839
971
|
and logical_expression.operator.filter.filter_condition is not None
|
|
840
972
|
)
|
|
973
|
+
logger.debug(f"LLMFilterRule matches_pattern: {is_match} for {logical_expression}")
|
|
974
|
+
return is_match
|
|
841
975
|
|
|
842
976
|
@staticmethod
|
|
843
977
|
def substitute(logical_expression: LogicalExpression, **physical_op_params) -> set[PhysicalExpression]:
|
|
978
|
+
logger.debug(f"Substituting LLMFilterRule for {logical_expression}")
|
|
979
|
+
|
|
844
980
|
logical_op = logical_expression.operator
|
|
845
981
|
op_kwargs = logical_op.get_logical_op_params()
|
|
846
|
-
op_kwargs.update(
|
|
847
|
-
|
|
848
|
-
|
|
849
|
-
|
|
850
|
-
|
|
982
|
+
op_kwargs.update(
|
|
983
|
+
{
|
|
984
|
+
"verbose": physical_op_params["verbose"],
|
|
985
|
+
"logical_op_id": logical_op.get_logical_op_id(),
|
|
986
|
+
"logical_op_name": logical_op.logical_op_name(),
|
|
987
|
+
}
|
|
988
|
+
)
|
|
851
989
|
|
|
852
990
|
# NOTE: when comparing pz.Model(s), equality is determined by the string (i.e. pz.Model.value)
|
|
853
991
|
# thus, Model.GPT_4o and Model.GPT_4o_V map to the same value; this allows us to use set logic
|
|
@@ -859,21 +997,27 @@ class LLMFilterRule(ImplementationRule):
|
|
|
859
997
|
pure_vision_models = {model for model in vision_models if model not in text_models}
|
|
860
998
|
|
|
861
999
|
# compute attributes about this filter operation
|
|
862
|
-
is_image_filter = any(
|
|
863
|
-
|
|
864
|
-
|
|
865
|
-
|
|
866
|
-
|
|
867
|
-
|
|
868
|
-
|
|
869
|
-
|
|
870
|
-
|
|
871
|
-
|
|
872
|
-
|
|
873
|
-
|
|
874
|
-
|
|
875
|
-
|
|
876
|
-
|
|
1000
|
+
is_image_filter = any(
|
|
1001
|
+
[
|
|
1002
|
+
field.is_image_field
|
|
1003
|
+
for field_name, field in logical_expression.input_fields.items()
|
|
1004
|
+
if field_name.split(".")[-1] in logical_expression.depends_on_field_names
|
|
1005
|
+
]
|
|
1006
|
+
)
|
|
1007
|
+
num_image_fields = sum(
|
|
1008
|
+
[
|
|
1009
|
+
field.is_image_field
|
|
1010
|
+
for field_name, field in logical_expression.input_fields.items()
|
|
1011
|
+
if field_name.split(".")[-1] in logical_expression.depends_on_field_names
|
|
1012
|
+
]
|
|
1013
|
+
)
|
|
1014
|
+
list_image_field = any(
|
|
1015
|
+
[
|
|
1016
|
+
field.is_image_field and hasattr(field, "element_type")
|
|
1017
|
+
for field_name, field in logical_expression.input_fields.items()
|
|
1018
|
+
if field_name.split(".")[-1] in logical_expression.depends_on_field_names
|
|
1019
|
+
]
|
|
1020
|
+
)
|
|
877
1021
|
|
|
878
1022
|
physical_expressions = []
|
|
879
1023
|
for model in physical_op_params["available_models"]:
|
|
@@ -903,7 +1047,10 @@ class LLMFilterRule(ImplementationRule):
|
|
|
903
1047
|
)
|
|
904
1048
|
physical_expressions.append(expression)
|
|
905
1049
|
|
|
906
|
-
|
|
1050
|
+
logger.debug(f"Done substituting LLMFilterRule for {logical_expression}")
|
|
1051
|
+
deduped_physical_expressions = set(physical_expressions)
|
|
1052
|
+
|
|
1053
|
+
return deduped_physical_expressions
|
|
907
1054
|
|
|
908
1055
|
|
|
909
1056
|
class AggregateRule(ImplementationRule):
|
|
@@ -913,10 +1060,14 @@ class AggregateRule(ImplementationRule):
|
|
|
913
1060
|
|
|
914
1061
|
@staticmethod
|
|
915
1062
|
def matches_pattern(logical_expression: LogicalExpression) -> bool:
|
|
916
|
-
|
|
1063
|
+
is_match = isinstance(logical_expression.operator, Aggregate)
|
|
1064
|
+
logger.debug(f"AggregateRule matches_pattern: {is_match} for {logical_expression}")
|
|
1065
|
+
return is_match
|
|
917
1066
|
|
|
918
1067
|
@staticmethod
|
|
919
1068
|
def substitute(logical_expression: LogicalExpression, **physical_op_params) -> set[PhysicalExpression]:
|
|
1069
|
+
logger.debug(f"Substituting AggregateRule for {logical_expression}")
|
|
1070
|
+
|
|
920
1071
|
logical_op = logical_expression.operator
|
|
921
1072
|
op_kwargs = logical_op.get_logical_op_params()
|
|
922
1073
|
op_kwargs.update(
|
|
@@ -943,7 +1094,11 @@ class AggregateRule(ImplementationRule):
|
|
|
943
1094
|
generated_fields=logical_expression.generated_fields,
|
|
944
1095
|
group_id=logical_expression.group_id,
|
|
945
1096
|
)
|
|
946
|
-
|
|
1097
|
+
|
|
1098
|
+
logger.debug(f"Done substituting AggregateRule for {logical_expression}")
|
|
1099
|
+
deduped_physical_expressions = set([expression])
|
|
1100
|
+
|
|
1101
|
+
return deduped_physical_expressions
|
|
947
1102
|
|
|
948
1103
|
|
|
949
1104
|
class BasicSubstitutionRule(ImplementationRule):
|
|
@@ -958,15 +1113,20 @@ class BasicSubstitutionRule(ImplementationRule):
|
|
|
958
1113
|
LimitScan: LimitScanOp,
|
|
959
1114
|
Project: ProjectOp,
|
|
960
1115
|
GroupByAggregate: ApplyGroupByOp,
|
|
1116
|
+
MapScan: MapOp,
|
|
961
1117
|
}
|
|
962
1118
|
|
|
963
1119
|
@classmethod
|
|
964
1120
|
def matches_pattern(cls, logical_expression: LogicalExpression) -> bool:
|
|
965
1121
|
logical_op_class = logical_expression.operator.__class__
|
|
966
|
-
|
|
1122
|
+
is_match = logical_op_class in cls.LOGICAL_OP_CLASS_TO_PHYSICAL_OP_CLASS_MAP
|
|
1123
|
+
logger.debug(f"BasicSubstitutionRule matches_pattern: {is_match} for {logical_expression}")
|
|
1124
|
+
return is_match
|
|
967
1125
|
|
|
968
1126
|
@classmethod
|
|
969
1127
|
def substitute(cls, logical_expression: LogicalExpression, **physical_op_params) -> set[PhysicalExpression]:
|
|
1128
|
+
logger.debug(f"Substituting BasicSubstitutionRule for {logical_expression}")
|
|
1129
|
+
|
|
970
1130
|
logical_op = logical_expression.operator
|
|
971
1131
|
op_kwargs = logical_op.get_logical_op_params()
|
|
972
1132
|
op_kwargs.update(
|
|
@@ -987,4 +1147,8 @@ class BasicSubstitutionRule(ImplementationRule):
|
|
|
987
1147
|
generated_fields=logical_expression.generated_fields,
|
|
988
1148
|
group_id=logical_expression.group_id,
|
|
989
1149
|
)
|
|
990
|
-
|
|
1150
|
+
|
|
1151
|
+
logger.debug(f"Done substituting BasicSubstitutionRule for {logical_expression}")
|
|
1152
|
+
deduped_physical_expressions = set([expression])
|
|
1153
|
+
|
|
1154
|
+
return deduped_physical_expressions
|