palimpzest 0.6.3__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64) hide show
  1. palimpzest/__init__.py +5 -0
  2. palimpzest/constants.py +110 -43
  3. palimpzest/core/__init__.py +0 -78
  4. palimpzest/core/data/dataclasses.py +382 -44
  5. palimpzest/core/elements/filters.py +7 -3
  6. palimpzest/core/elements/index.py +70 -0
  7. palimpzest/core/elements/records.py +33 -11
  8. palimpzest/core/lib/fields.py +1 -0
  9. palimpzest/core/lib/schemas.py +4 -3
  10. palimpzest/prompts/moa_proposer_convert_prompts.py +0 -4
  11. palimpzest/prompts/prompt_factory.py +44 -7
  12. palimpzest/prompts/split_merge_prompts.py +56 -0
  13. palimpzest/prompts/split_proposer_prompts.py +55 -0
  14. palimpzest/query/execution/execution_strategy.py +435 -53
  15. palimpzest/query/execution/execution_strategy_type.py +20 -0
  16. palimpzest/query/execution/mab_execution_strategy.py +532 -0
  17. palimpzest/query/execution/parallel_execution_strategy.py +143 -172
  18. palimpzest/query/execution/random_sampling_execution_strategy.py +240 -0
  19. palimpzest/query/execution/single_threaded_execution_strategy.py +173 -203
  20. palimpzest/query/generators/api_client_factory.py +31 -0
  21. palimpzest/query/generators/generators.py +256 -76
  22. palimpzest/query/operators/__init__.py +1 -2
  23. palimpzest/query/operators/code_synthesis_convert.py +33 -18
  24. palimpzest/query/operators/convert.py +30 -97
  25. palimpzest/query/operators/critique_and_refine_convert.py +5 -6
  26. palimpzest/query/operators/filter.py +7 -10
  27. palimpzest/query/operators/logical.py +54 -10
  28. palimpzest/query/operators/map.py +130 -0
  29. palimpzest/query/operators/mixture_of_agents_convert.py +6 -6
  30. palimpzest/query/operators/physical.py +3 -12
  31. palimpzest/query/operators/rag_convert.py +66 -18
  32. palimpzest/query/operators/retrieve.py +230 -34
  33. palimpzest/query/operators/scan.py +5 -2
  34. palimpzest/query/operators/split_convert.py +169 -0
  35. palimpzest/query/operators/token_reduction_convert.py +8 -14
  36. palimpzest/query/optimizer/__init__.py +4 -16
  37. palimpzest/query/optimizer/cost_model.py +73 -266
  38. palimpzest/query/optimizer/optimizer.py +87 -58
  39. palimpzest/query/optimizer/optimizer_strategy.py +18 -97
  40. palimpzest/query/optimizer/optimizer_strategy_type.py +37 -0
  41. palimpzest/query/optimizer/plan.py +2 -3
  42. palimpzest/query/optimizer/primitives.py +5 -3
  43. palimpzest/query/optimizer/rules.py +336 -172
  44. palimpzest/query/optimizer/tasks.py +30 -100
  45. palimpzest/query/processor/config.py +38 -22
  46. palimpzest/query/processor/nosentinel_processor.py +16 -520
  47. palimpzest/query/processor/processing_strategy_type.py +28 -0
  48. palimpzest/query/processor/query_processor.py +38 -206
  49. palimpzest/query/processor/query_processor_factory.py +117 -130
  50. palimpzest/query/processor/sentinel_processor.py +90 -0
  51. palimpzest/query/processor/streaming_processor.py +25 -32
  52. palimpzest/sets.py +88 -41
  53. palimpzest/utils/model_helpers.py +8 -7
  54. palimpzest/utils/progress.py +368 -152
  55. palimpzest/utils/token_reduction_helpers.py +1 -3
  56. {palimpzest-0.6.3.dist-info → palimpzest-0.7.0.dist-info}/METADATA +28 -24
  57. palimpzest-0.7.0.dist-info/RECORD +96 -0
  58. {palimpzest-0.6.3.dist-info → palimpzest-0.7.0.dist-info}/WHEEL +1 -1
  59. palimpzest/query/processor/mab_sentinel_processor.py +0 -884
  60. palimpzest/query/processor/random_sampling_sentinel_processor.py +0 -639
  61. palimpzest/utils/index_helpers.py +0 -6
  62. palimpzest-0.6.3.dist-info/RECORD +0 -87
  63. {palimpzest-0.6.3.dist-info → palimpzest-0.7.0.dist-info/licenses}/LICENSE +0 -0
  64. {palimpzest-0.6.3.dist-info → palimpzest-0.7.0.dist-info}/top_level.txt +0 -0
@@ -1,10 +1,11 @@
1
+ import logging
1
2
  from copy import deepcopy
2
3
  from itertools import combinations
3
4
 
4
5
  from palimpzest.constants import AggFunc, Cardinality, Model, PromptStrategy
5
6
  from palimpzest.query.operators.aggregate import ApplyGroupByOp, AverageAggregateOp, CountAggregateOp
6
7
  from palimpzest.query.operators.code_synthesis_convert import CodeSynthesisConvertSingle
7
- from palimpzest.query.operators.convert import LLMConvertBonded, LLMConvertConventional, NonLLMConvert
8
+ from palimpzest.query.operators.convert import LLMConvertBonded, NonLLMConvert
8
9
  from palimpzest.query.operators.critique_and_refine_convert import CriticAndRefineConvert
9
10
  from palimpzest.query.operators.filter import LLMFilter, NonLLMFilter
10
11
  from palimpzest.query.operators.limit import LimitScanOp
@@ -16,21 +17,23 @@ from palimpzest.query.operators.logical import (
16
17
  FilteredScan,
17
18
  GroupByAggregate,
18
19
  LimitScan,
20
+ MapScan,
19
21
  Project,
20
22
  RetrieveScan,
21
23
  )
24
+ from palimpzest.query.operators.map import MapOp
22
25
  from palimpzest.query.operators.mixture_of_agents_convert import MixtureOfAgentsConvert
23
26
  from palimpzest.query.operators.project import ProjectOp
24
27
  from palimpzest.query.operators.rag_convert import RAGConvert
25
28
  from palimpzest.query.operators.retrieve import RetrieveOp
26
29
  from palimpzest.query.operators.scan import CacheScanDataOp, MarshalAndScanDataOp
27
- from palimpzest.query.operators.token_reduction_convert import (
28
- TokenReducedConvertBonded,
29
- TokenReducedConvertConventional,
30
- )
30
+ from palimpzest.query.operators.split_convert import SplitConvert
31
+ from palimpzest.query.operators.token_reduction_convert import TokenReducedConvertBonded
31
32
  from palimpzest.query.optimizer.primitives import Expression, Group, LogicalExpression, PhysicalExpression
32
33
  from palimpzest.utils.model_helpers import get_models, get_vision_models
33
34
 
35
+ logger = logging.getLogger(__name__)
36
+
34
37
 
35
38
  class Rule:
36
39
  """
@@ -81,12 +84,16 @@ class PushDownFilter(TransformationRule):
81
84
 
82
85
  @staticmethod
83
86
  def matches_pattern(logical_expression: Expression) -> bool:
84
- return isinstance(logical_expression.operator, FilteredScan)
87
+ is_match = isinstance(logical_expression.operator, FilteredScan)
88
+ logger.debug(f"PushDownFilter matches_pattern: {is_match} for {logical_expression}")
89
+ return is_match
85
90
 
86
91
  @staticmethod
87
92
  def substitute(
88
93
  logical_expression: LogicalExpression, groups: dict[int, Group], expressions: dict[int, Expression], **kwargs
89
94
  ) -> tuple[set[LogicalExpression], set[Group]]:
95
+ logger.debug(f"Substituting PushDownFilter for {logical_expression}")
96
+
90
97
  # initialize the sets of new logical expressions and groups to be returned
91
98
  new_logical_expressions, new_groups = set(), set()
92
99
 
@@ -102,8 +109,10 @@ class PushDownFilter(TransformationRule):
102
109
  continue
103
110
 
104
111
  # iterate over logical expressions
105
- logical_exprs = deepcopy(input_group.logical_expressions)
106
- for expr in logical_exprs:
112
+ # NOTE: we previously deepcopy'ed the logical expression to avoid modifying the original;
113
+ # I think I've fixed this internally, but I'm leaving this NOTE as a reminder in case
114
+ # we see a regression / bug in the future
115
+ for expr in input_group.logical_expressions:
107
116
  # if the expression operator is not a convert or a filter, we cannot swap
108
117
  if not (isinstance(expr.operator, (ConvertScan, FilteredScan))):
109
118
  continue
@@ -181,7 +190,7 @@ class PushDownFilter(TransformationRule):
181
190
 
182
191
  # create final new logical expression with expr's operator pulled up
183
192
  new_expr = LogicalExpression(
184
- expr.operator,
193
+ expr.operator.copy(),
185
194
  input_group_ids=[group_id]
186
195
  + [g_id for g_id in logical_expression.input_group_ids if g_id != input_group_id],
187
196
  input_fields=group.fields,
@@ -193,6 +202,8 @@ class PushDownFilter(TransformationRule):
193
202
  # add newly created expression to set of returned expressions
194
203
  new_logical_expressions.add(new_expr)
195
204
 
205
+ logger.debug(f"Done substituting PushDownFilter for {logical_expression}")
206
+
196
207
  return new_logical_expressions, new_groups
197
208
 
198
209
 
@@ -211,10 +222,14 @@ class NonLLMConvertRule(ImplementationRule):
211
222
 
212
223
  @classmethod
213
224
  def matches_pattern(cls, logical_expression: LogicalExpression) -> bool:
214
- return isinstance(logical_expression.operator, ConvertScan) and logical_expression.operator.udf is not None
225
+ is_match = isinstance(logical_expression.operator, ConvertScan) and logical_expression.operator.udf is not None
226
+ logger.debug(f"NonLLMConvertRule matches_pattern: {is_match} for {logical_expression}")
227
+ return is_match
215
228
 
216
229
  @classmethod
217
230
  def substitute(cls, logical_expression: LogicalExpression, **physical_op_params) -> set[PhysicalExpression]:
231
+ logger.debug(f"Substituting NonLLMConvertRule for {logical_expression}")
232
+
218
233
  logical_op = logical_expression.operator
219
234
 
220
235
  # get initial set of parameters for physical op
@@ -238,27 +253,27 @@ class NonLLMConvertRule(ImplementationRule):
238
253
  group_id=logical_expression.group_id,
239
254
  )
240
255
 
241
- return set([expression])
256
+ deduped_physical_expressions = set([expression])
257
+ logger.debug(f"Done substituting NonLLMConvertRule for {logical_expression}")
242
258
 
259
+ return deduped_physical_expressions
243
260
 
244
- class LLMConvertRule(ImplementationRule):
245
- """
246
- Base rule for bonded and conventional LLM convert operators; the physical convert class
247
- (LLMConvertBonded or LLMConvertConventional) is provided by sub-class rules.
248
261
 
249
- NOTE: we provide the physical convert class(es) in their own sub-classed rules to make
250
- it easier to allow/disallow groups of rules at the Optimizer level.
262
+ class LLMConvertBondedRule(ImplementationRule):
263
+ """
264
+ Substitute a logical expression for a ConvertScan with a bonded convert physical implementation.
251
265
  """
252
-
253
- # overridden by sub-classes
254
- physical_convert_class = None
255
266
 
256
267
  @classmethod
257
268
  def matches_pattern(cls, logical_expression: LogicalExpression) -> bool:
258
- return isinstance(logical_expression.operator, ConvertScan) and logical_expression.operator.udf is None
269
+ is_match = isinstance(logical_expression.operator, ConvertScan) and logical_expression.operator.udf is None
270
+ logger.debug(f"LLMConvertBondedRule matches_pattern: {is_match} for {logical_expression}")
271
+ return is_match
259
272
 
260
273
  @classmethod
261
274
  def substitute(cls, logical_expression: LogicalExpression, **physical_op_params) -> set[PhysicalExpression]:
275
+ logger.debug(f"Substituting LLMConvertBondedRule for {logical_expression}")
276
+
262
277
  logical_op = logical_expression.operator
263
278
 
264
279
  # get initial set of parameters for physical op
@@ -281,21 +296,27 @@ class LLMConvertRule(ImplementationRule):
281
296
  pure_vision_models = {model for model in vision_models if model not in text_models}
282
297
 
283
298
  # compute attributes about this convert operation
284
- is_image_conversion = any([
285
- field.is_image_field
286
- for field_name, field in logical_expression.input_fields.items()
287
- if field_name.split(".")[-1] in logical_expression.depends_on_field_names
288
- ])
289
- num_image_fields = sum([
290
- field.is_image_field
291
- for field_name, field in logical_expression.input_fields.items()
292
- if field_name.split(".")[-1] in logical_expression.depends_on_field_names
293
- ])
294
- list_image_field = any([
295
- field.is_image_field and hasattr(field, "element_type")
296
- for field_name, field in logical_expression.input_fields.items()
297
- if field_name.split(".")[-1] in logical_expression.depends_on_field_names
298
- ])
299
+ is_image_conversion = any(
300
+ [
301
+ field.is_image_field
302
+ for field_name, field in logical_expression.input_fields.items()
303
+ if field_name.split(".")[-1] in logical_expression.depends_on_field_names
304
+ ]
305
+ )
306
+ num_image_fields = sum(
307
+ [
308
+ field.is_image_field
309
+ for field_name, field in logical_expression.input_fields.items()
310
+ if field_name.split(".")[-1] in logical_expression.depends_on_field_names
311
+ ]
312
+ )
313
+ list_image_field = any(
314
+ [
315
+ field.is_image_field and hasattr(field, "element_type")
316
+ for field_name, field in logical_expression.input_fields.items()
317
+ if field_name.split(".")[-1] in logical_expression.depends_on_field_names
318
+ ]
319
+ )
299
320
 
300
321
  physical_expressions = []
301
322
  for model in physical_op_params["available_models"]:
@@ -310,7 +331,7 @@ class LLMConvertRule(ImplementationRule):
310
331
  continue
311
332
 
312
333
  # construct multi-expression
313
- op = cls.physical_convert_class(
334
+ op = LLMConvertBonded(
314
335
  model=model,
315
336
  prompt_strategy=PromptStrategy.COT_QA_IMAGE if is_image_conversion else PromptStrategy.COT_QA,
316
337
  **op_kwargs,
@@ -325,49 +346,37 @@ class LLMConvertRule(ImplementationRule):
325
346
  )
326
347
  physical_expressions.append(expression)
327
348
 
328
- return set(physical_expressions)
349
+ deduped_physical_expressions = set(physical_expressions)
350
+ logger.debug(f"Done substituting LLMConvertBondedRule for {logical_expression}")
329
351
 
352
+ return deduped_physical_expressions
330
353
 
331
- class LLMConvertBondedRule(LLMConvertRule):
332
- """
333
- Substitute a logical expression for a ConvertScan with a bonded convert physical implementation.
334
- """
335
-
336
- physical_convert_class = LLMConvertBonded
337
354
 
338
-
339
- class LLMConvertConventionalRule(LLMConvertRule):
340
- """
341
- Substitute a logical expression for a ConvertScan with a conventional convert physical implementation.
355
+ class TokenReducedConvertBondedRule(ImplementationRule):
342
356
  """
343
-
344
- physical_convert_class = LLMConvertConventional
345
-
346
-
347
- class TokenReducedConvertRule(ImplementationRule):
348
- """
349
- Base rule for bonded and conventional token reduced convert operators; the physical convert class
350
- (TokenReducedConvertBonded or TokenReducedConvertConventional) is provided by sub-class rules.
351
-
352
- NOTE: we provide the physical convert class(es) in their own sub-classed rules to make
353
- it easier to allow/disallow groups of rules at the Optimizer level.
357
+ Substitute a logical expression for a ConvertScan with a bonded token reduced physical implementation.
354
358
  """
355
359
 
356
- physical_convert_class = None # overriden by sub-classes
357
360
  token_budgets = [0.1, 0.5, 0.9]
358
361
 
359
362
  @classmethod
360
363
  def matches_pattern(cls, logical_expression: LogicalExpression) -> bool:
361
364
  logical_op = logical_expression.operator
362
- is_image_conversion = any([
363
- field.is_image_field
364
- for field_name, field in logical_expression.input_fields.items()
365
- if field_name.split(".")[-1] in logical_expression.depends_on_field_names
366
- ])
367
- return isinstance(logical_op, ConvertScan) and not is_image_conversion and logical_op.udf is None
365
+ is_image_conversion = any(
366
+ [
367
+ field.is_image_field
368
+ for field_name, field in logical_expression.input_fields.items()
369
+ if field_name.split(".")[-1] in logical_expression.depends_on_field_names
370
+ ]
371
+ )
372
+ is_match = isinstance(logical_op, ConvertScan) and not is_image_conversion and logical_op.udf is None
373
+ logger.debug(f"TokenReducedConvertBondedRule matches_pattern: {is_match} for {logical_expression}")
374
+ return is_match
368
375
 
369
376
  @classmethod
370
377
  def substitute(cls, logical_expression: LogicalExpression, **physical_op_params) -> set[PhysicalExpression]:
378
+ logger.debug(f"Substituting TokenReducedConvertBondedRule for {logical_expression}")
379
+
371
380
  logical_op = logical_expression.operator
372
381
 
373
382
  # get initial set of parameters for physical op
@@ -396,7 +405,7 @@ class TokenReducedConvertRule(ImplementationRule):
396
405
  continue
397
406
 
398
407
  # construct multi-expression
399
- op = cls.physical_convert_class(
408
+ op = TokenReducedConvertBonded(
400
409
  model=model,
401
410
  prompt_strategy=PromptStrategy.COT_QA,
402
411
  token_budget=token_budget,
@@ -412,23 +421,10 @@ class TokenReducedConvertRule(ImplementationRule):
412
421
  )
413
422
  physical_expressions.append(expression)
414
423
 
415
- return set(physical_expressions)
416
-
417
-
418
- class TokenReducedConvertBondedRule(TokenReducedConvertRule):
419
- """
420
- Substitute a logical expression for a ConvertScan with a bonded token reduced physical implementation.
421
- """
424
+ logger.debug(f"Done substituting TokenReducedConvertBondedRule for {logical_expression}")
425
+ deduped_physical_expressions = set(physical_expressions)
422
426
 
423
- physical_convert_class = TokenReducedConvertBonded
424
-
425
-
426
- class TokenReducedConvertConventionalRule(TokenReducedConvertRule):
427
- """
428
- Substitute a logical expression for a ConvertScan with a conventional token reduced physical implementation.
429
- """
430
-
431
- physical_convert_class = TokenReducedConvertConventional
427
+ return deduped_physical_expressions
432
428
 
433
429
 
434
430
  class CodeSynthesisConvertRule(ImplementationRule):
@@ -445,20 +441,26 @@ class CodeSynthesisConvertRule(ImplementationRule):
445
441
  @classmethod
446
442
  def matches_pattern(cls, logical_expression: LogicalExpression) -> bool:
447
443
  logical_op = logical_expression.operator
448
- is_image_conversion = any([
449
- field.is_image_field
450
- for field_name, field in logical_expression.input_fields.items()
451
- if field_name.split(".")[-1] in logical_expression.depends_on_field_names
452
- ])
453
- return (
444
+ is_image_conversion = any(
445
+ [
446
+ field.is_image_field
447
+ for field_name, field in logical_expression.input_fields.items()
448
+ if field_name.split(".")[-1] in logical_expression.depends_on_field_names
449
+ ]
450
+ )
451
+ is_match = (
454
452
  isinstance(logical_op, ConvertScan)
455
453
  and not is_image_conversion
456
454
  and logical_op.cardinality != Cardinality.ONE_TO_MANY
457
455
  and logical_op.udf is None
458
456
  )
457
+ logger.debug(f"CodeSynthesisConvertRule matches_pattern: {is_match} for {logical_expression}")
458
+ return is_match
459
459
 
460
460
  @classmethod
461
461
  def substitute(cls, logical_expression: LogicalExpression, **physical_op_params) -> set[PhysicalExpression]:
462
+ logger.debug(f"Substituting CodeSynthesisConvertRule for {logical_expression}")
463
+
462
464
  logical_op = logical_expression.operator
463
465
 
464
466
  # get initial set of parameters for physical op
@@ -475,7 +477,7 @@ class CodeSynthesisConvertRule(ImplementationRule):
475
477
  op = cls.physical_convert_class(
476
478
  exemplar_generation_model=physical_op_params["champion_model"],
477
479
  code_synth_model=physical_op_params["code_champion_model"],
478
- conventional_fallback_model=physical_op_params["conventional_fallback_model"],
480
+ fallback_model=physical_op_params["fallback_model"],
479
481
  prompt_strategy=PromptStrategy.COT_QA,
480
482
  **op_kwargs,
481
483
  )
@@ -487,8 +489,10 @@ class CodeSynthesisConvertRule(ImplementationRule):
487
489
  generated_fields=logical_expression.generated_fields,
488
490
  group_id=logical_expression.group_id,
489
491
  )
492
+ deduped_physical_expressions = set([expression])
493
+ logger.debug(f"Done substituting CodeSynthesisConvertRule for {logical_expression}")
490
494
 
491
- return set([expression])
495
+ return deduped_physical_expressions
492
496
 
493
497
 
494
498
  class CodeSynthesisConvertSingleRule(CodeSynthesisConvertRule):
@@ -503,21 +507,28 @@ class RAGConvertRule(ImplementationRule):
503
507
  """
504
508
  Substitute a logical expression for a ConvertScan with a RAGConvert physical implementation.
505
509
  """
510
+
506
511
  num_chunks_per_fields = [1, 2, 4]
507
512
  chunk_sizes = [1000, 2000, 4000]
508
513
 
509
514
  @classmethod
510
515
  def matches_pattern(cls, logical_expression: LogicalExpression) -> bool:
511
516
  logical_op = logical_expression.operator
512
- is_image_conversion = any([
513
- field.is_image_field
514
- for field_name, field in logical_expression.input_fields.items()
515
- if field_name.split(".")[-1] in logical_expression.depends_on_field_names
516
- ])
517
- return isinstance(logical_op, ConvertScan) and not is_image_conversion and logical_op.udf is None
517
+ is_image_conversion = any(
518
+ [
519
+ field.is_image_field
520
+ for field_name, field in logical_expression.input_fields.items()
521
+ if field_name.split(".")[-1] in logical_expression.depends_on_field_names
522
+ ]
523
+ )
524
+ is_match = isinstance(logical_op, ConvertScan) and not is_image_conversion and logical_op.udf is None
525
+ logger.debug(f"RAGConvertRule matches_pattern: {is_match} for {logical_expression}")
526
+ return is_match
518
527
 
519
528
  @classmethod
520
529
  def substitute(cls, logical_expression: LogicalExpression, **physical_op_params) -> set[PhysicalExpression]:
530
+ logger.debug(f"Substituting RAGConvertRule for {logical_expression}")
531
+
521
532
  logical_op = logical_expression.operator
522
533
 
523
534
  # get initial set of parameters for physical op
@@ -564,31 +575,42 @@ class RAGConvertRule(ImplementationRule):
564
575
  )
565
576
  physical_expressions.append(expression)
566
577
 
567
- return set(physical_expressions)
578
+ logger.debug(f"Done substituting RAGConvertRule for {logical_expression}")
579
+ deduped_physical_expressions = set(physical_expressions)
580
+
581
+ return deduped_physical_expressions
582
+
568
583
 
569
584
  class MixtureOfAgentsConvertRule(ImplementationRule):
570
585
  """
571
586
  Implementation rule for the MixtureOfAgentsConvert operator.
572
587
  """
588
+
573
589
  num_proposer_models = [1, 2, 3]
574
590
  temperatures = [0.0, 0.4, 0.8]
575
591
 
576
592
  @classmethod
577
593
  def matches_pattern(cls, logical_expression: LogicalExpression) -> bool:
578
594
  logical_op = logical_expression.operator
579
- return isinstance(logical_op, ConvertScan) and logical_op.udf is None
595
+ is_match = isinstance(logical_op, ConvertScan) and logical_op.udf is None
596
+ logger.debug(f"MixtureOfAgentsConvertRule matches_pattern: {is_match} for {logical_expression}")
597
+ return is_match
580
598
 
581
599
  @classmethod
582
600
  def substitute(cls, logical_expression: LogicalExpression, **physical_op_params) -> set[PhysicalExpression]:
601
+ logger.debug(f"Substituting MixtureOfAgentsConvertRule for {logical_expression}")
602
+
583
603
  logical_op = logical_expression.operator
584
604
 
585
605
  # get initial set of parameters for physical op
586
606
  op_kwargs: dict = logical_op.get_logical_op_params()
587
- op_kwargs.update({
588
- "verbose": physical_op_params['verbose'],
589
- "logical_op_id": logical_op.get_logical_op_id(),
590
- "logical_op_name": logical_op.logical_op_name(),
591
- })
607
+ op_kwargs.update(
608
+ {
609
+ "verbose": physical_op_params["verbose"],
610
+ "logical_op_id": logical_op.get_logical_op_id(),
611
+ "logical_op_name": logical_op.logical_op_name(),
612
+ }
613
+ )
592
614
 
593
615
  # NOTE: when comparing pz.Model(s), equality is determined by the string (i.e. pz.Model.value)
594
616
  # thus, Model.GPT_4o and Model.GPT_4o_V map to the same value; this allows us to use set logic
@@ -598,16 +620,20 @@ class MixtureOfAgentsConvertRule(ImplementationRule):
598
620
  text_models = set(get_models())
599
621
 
600
622
  # construct set of proposer models and set of aggregator models
601
- num_image_fields = sum([
602
- field.is_image_field
603
- for field_name, field in logical_expression.input_fields.items()
604
- if field_name.split(".")[-1] in logical_expression.depends_on_field_names
605
- ])
606
- list_image_field = any([
607
- field.is_image_field and hasattr(field, "element_type")
608
- for field_name, field in logical_expression.input_fields.items()
609
- if field_name.split(".")[-1] in logical_expression.depends_on_field_names
610
- ])
623
+ num_image_fields = sum(
624
+ [
625
+ field.is_image_field
626
+ for field_name, field in logical_expression.input_fields.items()
627
+ if field_name.split(".")[-1] in logical_expression.depends_on_field_names
628
+ ]
629
+ )
630
+ list_image_field = any(
631
+ [
632
+ field.is_image_field and hasattr(field, "element_type")
633
+ for field_name, field in logical_expression.input_fields.items()
634
+ if field_name.split(".")[-1] in logical_expression.depends_on_field_names
635
+ ]
636
+ )
611
637
  proposer_model_set, is_image_conversion = text_models, False
612
638
  if num_image_fields > 1 or list_image_field:
613
639
  proposer_model_set = [model for model in vision_models if model != Model.LLAMA3_V]
@@ -618,8 +644,10 @@ class MixtureOfAgentsConvertRule(ImplementationRule):
618
644
  aggregator_model_set = text_models
619
645
 
620
646
  # filter un-available models out of sets
621
- proposer_model_set = {model for model in proposer_model_set if model in physical_op_params['available_models']}
622
- aggregator_model_set = {model for model in aggregator_model_set if model in physical_op_params['available_models']}
647
+ proposer_model_set = {model for model in proposer_model_set if model in physical_op_params["available_models"]}
648
+ aggregator_model_set = {
649
+ model for model in aggregator_model_set if model in physical_op_params["available_models"]
650
+ }
623
651
 
624
652
  # construct MixtureOfAgentsConvert operations for various numbers of proposer models
625
653
  # and for every combination of proposer models and aggregator model
@@ -634,7 +662,9 @@ class MixtureOfAgentsConvertRule(ImplementationRule):
634
662
  temperatures=[temp] * len(proposer_models),
635
663
  aggregator_model=aggregator_model,
636
664
  proposer_prompt=op_kwargs.get("prompt"),
637
- proposer_prompt_strategy=PromptStrategy.COT_MOA_PROPOSER_IMAGE if is_image_conversion else PromptStrategy.COT_MOA_PROPOSER,
665
+ proposer_prompt_strategy=PromptStrategy.COT_MOA_PROPOSER_IMAGE
666
+ if is_image_conversion
667
+ else PromptStrategy.COT_MOA_PROPOSER,
638
668
  aggregator_prompt_strategy=PromptStrategy.COT_MOA_AGG,
639
669
  **op_kwargs,
640
670
  )
@@ -648,7 +678,11 @@ class MixtureOfAgentsConvertRule(ImplementationRule):
648
678
  )
649
679
  physical_expressions.append(expression)
650
680
 
651
- return set(physical_expressions)
681
+ logger.debug(f"Done substituting MixtureOfAgentsConvertRule for {logical_expression}")
682
+ deduped_physical_expressions = set(physical_expressions)
683
+
684
+ return deduped_physical_expressions
685
+
652
686
 
653
687
  class CriticAndRefineConvertRule(ImplementationRule):
654
688
  """
@@ -658,10 +692,14 @@ class CriticAndRefineConvertRule(ImplementationRule):
658
692
  @classmethod
659
693
  def matches_pattern(cls, logical_expression: LogicalExpression) -> bool:
660
694
  logical_op = logical_expression.operator
661
- return isinstance(logical_op, ConvertScan) and logical_op.udf is None
695
+ is_match = isinstance(logical_op, ConvertScan) and logical_op.udf is None
696
+ logger.debug(f"CriticAndRefineConvertRule matches_pattern: {is_match} for {logical_expression}")
697
+ return is_match
662
698
 
663
699
  @classmethod
664
700
  def substitute(cls, logical_expression: LogicalExpression, **physical_op_params) -> set[PhysicalExpression]:
701
+ logger.debug(f"Substituting CriticAndRefineConvertRule for {logical_expression}")
702
+
665
703
  logical_op = logical_expression.operator
666
704
 
667
705
  # Get initial parameters for physical operator
@@ -684,21 +722,27 @@ class CriticAndRefineConvertRule(ImplementationRule):
684
722
  pure_vision_models = {model for model in vision_models if model not in text_models}
685
723
 
686
724
  # compute attributes about this convert operation
687
- is_image_conversion = any([
688
- field.is_image_field
689
- for field_name, field in logical_expression.input_fields.items()
690
- if field_name.split(".")[-1] in logical_expression.depends_on_field_names
691
- ])
692
- num_image_fields = sum([
693
- field.is_image_field
694
- for field_name, field in logical_expression.input_fields.items()
695
- if field_name.split(".")[-1] in logical_expression.depends_on_field_names
696
- ])
697
- list_image_field = any([
698
- field.is_image_field and hasattr(field, "element_type")
699
- for field_name, field in logical_expression.input_fields.items()
700
- if field_name.split(".")[-1] in logical_expression.depends_on_field_names
701
- ])
725
+ is_image_conversion = any(
726
+ [
727
+ field.is_image_field
728
+ for field_name, field in logical_expression.input_fields.items()
729
+ if field_name.split(".")[-1] in logical_expression.depends_on_field_names
730
+ ]
731
+ )
732
+ num_image_fields = sum(
733
+ [
734
+ field.is_image_field
735
+ for field_name, field in logical_expression.input_fields.items()
736
+ if field_name.split(".")[-1] in logical_expression.depends_on_field_names
737
+ ]
738
+ )
739
+ list_image_field = any(
740
+ [
741
+ field.is_image_field and hasattr(field, "element_type")
742
+ for field_name, field in logical_expression.input_fields.items()
743
+ if field_name.split(".")[-1] in logical_expression.depends_on_field_names
744
+ ]
745
+ )
702
746
 
703
747
  # identify models which can be used for this convert operation
704
748
  models = []
@@ -739,26 +783,104 @@ class CriticAndRefineConvertRule(ImplementationRule):
739
783
  )
740
784
  physical_expressions.append(expression)
741
785
 
742
- # Return the set containing the new physical expression
743
- return set(physical_expressions)
786
+ logger.debug(f"Done substituting CriticAndRefineConvertRule for {logical_expression}")
787
+ deduped_physical_expressions = set(physical_expressions)
788
+
789
+ return deduped_physical_expressions
790
+
791
+
792
class SplitConvertRule(ImplementationRule):
    """
    Substitute a logical expression for a ConvertScan with a SplitConvert physical implementation.

    One candidate SplitConvert operator is generated for every combination of available
    model (excluding pure vision models), `min_size_to_chunk` value, and `num_chunks` value.
    """

    # candidate values passed as the `num_chunks` argument of SplitConvert
    num_chunks = [2, 4, 6]
    # candidate values passed as the `min_size_to_chunk` argument of SplitConvert
    min_size_to_chunk = [1000, 4000]

    @classmethod
    def matches_pattern(cls, logical_expression: LogicalExpression) -> bool:
        """
        Return True if the expression's operator is a ConvertScan without a UDF and none of
        the input fields it depends on is an image field.
        """
        logical_op = logical_expression.operator

        # evaluate the cheap operator-level checks first; thanks to short-circuiting, the scan
        # over the input fields only runs for ConvertScan operators which have no UDF
        is_match = (
            isinstance(logical_op, ConvertScan)
            and logical_op.udf is None
            and not any(
                field.is_image_field
                for field_name, field in logical_expression.input_fields.items()
                if field_name.split(".")[-1] in logical_expression.depends_on_field_names
            )
        )
        logger.debug(f"SplitConvertRule matches_pattern: {is_match} for {logical_expression}")
        return is_match

    @classmethod
    def substitute(cls, logical_expression: LogicalExpression, **physical_op_params) -> set[PhysicalExpression]:
        """
        Build the set of SplitConvert physical expressions for the given logical expression.

        `physical_op_params` must contain the keys "verbose" and "available_models".
        Returns the physical expressions collected into a set.
        """
        logger.debug(f"Substituting SplitConvertRule for {logical_expression}")

        logical_op = logical_expression.operator

        # get initial set of parameters for physical op
        op_kwargs = logical_op.get_logical_op_params()
        op_kwargs.update(
            {
                "verbose": physical_op_params["verbose"],
                "logical_op_id": logical_op.get_logical_op_id(),
                "logical_op_name": logical_op.logical_op_name(),
            }
        )

        # NOTE: when comparing pz.Model(s), equality is determined by the string (i.e. pz.Model.value)
        #       thus, Model.GPT_4o and Model.GPT_4o_V map to the same value; this allows us to use set logic
        #
        # identify models which can be used strictly for images (vision models that are not text models)
        pure_vision_models = set(get_vision_models()) - set(get_models())

        physical_expressions = []
        for model in physical_op_params["available_models"]:
            # skip this model if this is a pure image model
            if model in pure_vision_models:
                continue

            for min_size_to_chunk in cls.min_size_to_chunk:
                for num_chunks in cls.num_chunks:
                    # construct multi-expression
                    op = SplitConvert(
                        model=model,
                        num_chunks=num_chunks,
                        min_size_to_chunk=min_size_to_chunk,
                        **op_kwargs,
                    )
                    expression = PhysicalExpression(
                        operator=op,
                        input_group_ids=logical_expression.input_group_ids,
                        input_fields=logical_expression.input_fields,
                        depends_on_field_names=logical_expression.depends_on_field_names,
                        generated_fields=logical_expression.generated_fields,
                        group_id=logical_expression.group_id,
                    )
                    physical_expressions.append(expression)

        logger.debug(f"Done substituting SplitConvertRule for {logical_expression}")
        deduped_physical_expressions = set(physical_expressions)

        return deduped_physical_expressions
744
866
 
745
867
 
746
868
  class RetrieveRule(ImplementationRule):
747
869
  """
748
870
  Substitute a logical expression for a RetrieveScan with a Retrieve physical implementation.
749
871
  """
750
- k_budgets = [1, 3, 5, 10]
872
+ k_budgets = [1, 3, 5, 10, 15, 20, 25]
751
873
 
752
874
  @classmethod
753
875
  def matches_pattern(cls, logical_expression: LogicalExpression) -> bool:
754
- return (
755
- isinstance(logical_expression.operator, RetrieveScan)
756
- )
876
+ is_match = isinstance(logical_expression.operator, RetrieveScan)
877
+ logger.debug(f"RetrieveRule matches_pattern: {is_match} for {logical_expression}")
878
+ return is_match
757
879
 
758
880
  @classmethod
759
- def substitute(
760
- cls, logical_expression: LogicalExpression, **physical_op_params
761
- ) -> set[PhysicalExpression]:
881
+ def substitute(cls, logical_expression: LogicalExpression, **physical_op_params) -> set[PhysicalExpression]:
882
+ logger.debug(f"Substituting RetrieveRule for {logical_expression}")
883
+
762
884
  logical_op = logical_expression.operator
763
885
 
764
886
  physical_expressions = []
@@ -788,7 +910,10 @@ class RetrieveRule(ImplementationRule):
788
910
 
789
911
  physical_expressions.append(expression)
790
912
 
791
- return set(physical_expressions)
913
+ logger.debug(f"Done substituting RetrieveRule for {logical_expression}")
914
+ deduped_physical_expressions = set(physical_expressions)
915
+
916
+ return deduped_physical_expressions
792
917
 
793
918
 
794
919
  class NonLLMFilterRule(ImplementationRule):
@@ -798,13 +923,17 @@ class NonLLMFilterRule(ImplementationRule):
798
923
 
799
924
  @staticmethod
800
925
  def matches_pattern(logical_expression: LogicalExpression) -> bool:
801
- return (
926
+ is_match = (
802
927
  isinstance(logical_expression.operator, FilteredScan)
803
928
  and logical_expression.operator.filter.filter_fn is not None
804
929
  )
930
+ logger.debug(f"NonLLMFilterRule matches_pattern: {is_match} for {logical_expression}")
931
+ return is_match
805
932
 
806
933
  @staticmethod
807
934
  def substitute(logical_expression: LogicalExpression, **physical_op_params) -> set[PhysicalExpression]:
935
+ logger.debug(f"Substituting NonLLMFilterRule for {logical_expression}")
936
+
808
937
  logical_op = logical_expression.operator
809
938
  op_kwargs = logical_op.get_logical_op_params()
810
939
  op_kwargs.update(
@@ -824,7 +953,10 @@ class NonLLMFilterRule(ImplementationRule):
824
953
  generated_fields=logical_expression.generated_fields,
825
954
  group_id=logical_expression.group_id,
826
955
  )
827
- return set([expression])
956
+ logger.debug(f"Done substituting NonLLMFilterRule for {logical_expression}")
957
+ deduped_physical_expressions = set([expression])
958
+
959
+ return deduped_physical_expressions
828
960
 
829
961
 
830
962
  class LLMFilterRule(ImplementationRule):
@@ -834,20 +966,26 @@ class LLMFilterRule(ImplementationRule):
834
966
 
835
967
  @staticmethod
836
968
  def matches_pattern(logical_expression: LogicalExpression) -> bool:
837
- return (
969
+ is_match = (
838
970
  isinstance(logical_expression.operator, FilteredScan)
839
971
  and logical_expression.operator.filter.filter_condition is not None
840
972
  )
973
+ logger.debug(f"LLMFilterRule matches_pattern: {is_match} for {logical_expression}")
974
+ return is_match
841
975
 
842
976
  @staticmethod
843
977
  def substitute(logical_expression: LogicalExpression, **physical_op_params) -> set[PhysicalExpression]:
978
+ logger.debug(f"Substituting LLMFilterRule for {logical_expression}")
979
+
844
980
  logical_op = logical_expression.operator
845
981
  op_kwargs = logical_op.get_logical_op_params()
846
- op_kwargs.update({
847
- "verbose": physical_op_params["verbose"],
848
- "logical_op_id": logical_op.get_logical_op_id(),
849
- "logical_op_name": logical_op.logical_op_name(),
850
- })
982
+ op_kwargs.update(
983
+ {
984
+ "verbose": physical_op_params["verbose"],
985
+ "logical_op_id": logical_op.get_logical_op_id(),
986
+ "logical_op_name": logical_op.logical_op_name(),
987
+ }
988
+ )
851
989
 
852
990
  # NOTE: when comparing pz.Model(s), equality is determined by the string (i.e. pz.Model.value)
853
991
  # thus, Model.GPT_4o and Model.GPT_4o_V map to the same value; this allows us to use set logic
@@ -859,21 +997,27 @@ class LLMFilterRule(ImplementationRule):
859
997
  pure_vision_models = {model for model in vision_models if model not in text_models}
860
998
 
861
999
  # compute attributes about this filter operation
862
- is_image_filter = any([
863
- field.is_image_field
864
- for field_name, field in logical_expression.input_fields.items()
865
- if field_name.split(".")[-1] in logical_expression.depends_on_field_names
866
- ])
867
- num_image_fields = sum([
868
- field.is_image_field
869
- for field_name, field in logical_expression.input_fields.items()
870
- if field_name.split(".")[-1] in logical_expression.depends_on_field_names
871
- ])
872
- list_image_field = any([
873
- field.is_image_field and hasattr(field, "element_type")
874
- for field_name, field in logical_expression.input_fields.items()
875
- if field_name.split(".")[-1] in logical_expression.depends_on_field_names
876
- ])
1000
+ is_image_filter = any(
1001
+ [
1002
+ field.is_image_field
1003
+ for field_name, field in logical_expression.input_fields.items()
1004
+ if field_name.split(".")[-1] in logical_expression.depends_on_field_names
1005
+ ]
1006
+ )
1007
+ num_image_fields = sum(
1008
+ [
1009
+ field.is_image_field
1010
+ for field_name, field in logical_expression.input_fields.items()
1011
+ if field_name.split(".")[-1] in logical_expression.depends_on_field_names
1012
+ ]
1013
+ )
1014
+ list_image_field = any(
1015
+ [
1016
+ field.is_image_field and hasattr(field, "element_type")
1017
+ for field_name, field in logical_expression.input_fields.items()
1018
+ if field_name.split(".")[-1] in logical_expression.depends_on_field_names
1019
+ ]
1020
+ )
877
1021
 
878
1022
  physical_expressions = []
879
1023
  for model in physical_op_params["available_models"]:
@@ -903,7 +1047,10 @@ class LLMFilterRule(ImplementationRule):
903
1047
  )
904
1048
  physical_expressions.append(expression)
905
1049
 
906
- return set(physical_expressions)
1050
+ logger.debug(f"Done substituting LLMFilterRule for {logical_expression}")
1051
+ deduped_physical_expressions = set(physical_expressions)
1052
+
1053
+ return deduped_physical_expressions
907
1054
 
908
1055
 
909
1056
  class AggregateRule(ImplementationRule):
@@ -913,10 +1060,14 @@ class AggregateRule(ImplementationRule):
913
1060
 
914
1061
  @staticmethod
915
1062
  def matches_pattern(logical_expression: LogicalExpression) -> bool:
916
- return isinstance(logical_expression.operator, Aggregate)
1063
+ is_match = isinstance(logical_expression.operator, Aggregate)
1064
+ logger.debug(f"AggregateRule matches_pattern: {is_match} for {logical_expression}")
1065
+ return is_match
917
1066
 
918
1067
  @staticmethod
919
1068
  def substitute(logical_expression: LogicalExpression, **physical_op_params) -> set[PhysicalExpression]:
1069
+ logger.debug(f"Substituting AggregateRule for {logical_expression}")
1070
+
920
1071
  logical_op = logical_expression.operator
921
1072
  op_kwargs = logical_op.get_logical_op_params()
922
1073
  op_kwargs.update(
@@ -943,7 +1094,11 @@ class AggregateRule(ImplementationRule):
943
1094
  generated_fields=logical_expression.generated_fields,
944
1095
  group_id=logical_expression.group_id,
945
1096
  )
946
- return set([expression])
1097
+
1098
+ logger.debug(f"Done substituting AggregateRule for {logical_expression}")
1099
+ deduped_physical_expressions = set([expression])
1100
+
1101
+ return deduped_physical_expressions
947
1102
 
948
1103
 
949
1104
  class BasicSubstitutionRule(ImplementationRule):
@@ -958,15 +1113,20 @@ class BasicSubstitutionRule(ImplementationRule):
958
1113
  LimitScan: LimitScanOp,
959
1114
  Project: ProjectOp,
960
1115
  GroupByAggregate: ApplyGroupByOp,
1116
+ MapScan: MapOp,
961
1117
  }
962
1118
 
963
1119
  @classmethod
964
1120
  def matches_pattern(cls, logical_expression: LogicalExpression) -> bool:
965
1121
  logical_op_class = logical_expression.operator.__class__
966
- return logical_op_class in cls.LOGICAL_OP_CLASS_TO_PHYSICAL_OP_CLASS_MAP
1122
+ is_match = logical_op_class in cls.LOGICAL_OP_CLASS_TO_PHYSICAL_OP_CLASS_MAP
1123
+ logger.debug(f"BasicSubstitutionRule matches_pattern: {is_match} for {logical_expression}")
1124
+ return is_match
967
1125
 
968
1126
  @classmethod
969
1127
  def substitute(cls, logical_expression: LogicalExpression, **physical_op_params) -> set[PhysicalExpression]:
1128
+ logger.debug(f"Substituting BasicSubstitutionRule for {logical_expression}")
1129
+
970
1130
  logical_op = logical_expression.operator
971
1131
  op_kwargs = logical_op.get_logical_op_params()
972
1132
  op_kwargs.update(
@@ -987,4 +1147,8 @@ class BasicSubstitutionRule(ImplementationRule):
987
1147
  generated_fields=logical_expression.generated_fields,
988
1148
  group_id=logical_expression.group_id,
989
1149
  )
990
- return set([expression])
1150
+
1151
+ logger.debug(f"Done substituting BasicSubstitutionRule for {logical_expression}")
1152
+ deduped_physical_expressions = set([expression])
1153
+
1154
+ return deduped_physical_expressions