palimpzest-0.5.3-py3-none-any.whl → palimpzest-0.6.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- palimpzest/__init__.py +7 -9
- palimpzest/constants.py +47 -7
- palimpzest/core/__init__.py +20 -26
- palimpzest/core/data/dataclasses.py +9 -2
- palimpzest/core/data/datareaders.py +497 -0
- palimpzest/core/elements/records.py +29 -37
- palimpzest/core/lib/fields.py +14 -12
- palimpzest/core/lib/schemas.py +80 -94
- palimpzest/policy.py +58 -0
- palimpzest/prompts/__init__.py +22 -0
- palimpzest/prompts/code_synthesis_prompts.py +28 -0
- palimpzest/prompts/convert_prompts.py +87 -0
- palimpzest/prompts/critique_and_refine_convert_prompts.py +216 -0
- palimpzest/prompts/filter_prompts.py +69 -0
- palimpzest/prompts/moa_aggregator_convert_prompts.py +57 -0
- palimpzest/prompts/moa_proposer_convert_prompts.py +79 -0
- palimpzest/prompts/prompt_factory.py +732 -0
- palimpzest/prompts/util_phrases.py +14 -0
- palimpzest/query/execution/execution_strategy.py +0 -3
- palimpzest/query/execution/parallel_execution_strategy.py +12 -25
- palimpzest/query/execution/single_threaded_execution_strategy.py +31 -45
- palimpzest/query/generators/generators.py +71 -347
- palimpzest/query/operators/__init__.py +5 -5
- palimpzest/query/operators/aggregate.py +10 -5
- palimpzest/query/operators/code_synthesis_convert.py +4 -48
- palimpzest/query/operators/convert.py +5 -2
- palimpzest/query/operators/critique_and_refine_convert.py +112 -0
- palimpzest/query/operators/filter.py +1 -1
- palimpzest/query/operators/limit.py +1 -1
- palimpzest/query/operators/logical.py +28 -27
- palimpzest/query/operators/mixture_of_agents_convert.py +4 -1
- palimpzest/query/operators/physical.py +32 -20
- palimpzest/query/operators/project.py +1 -1
- palimpzest/query/operators/rag_convert.py +6 -3
- palimpzest/query/operators/retrieve.py +13 -31
- palimpzest/query/operators/scan.py +150 -0
- palimpzest/query/optimizer/__init__.py +5 -1
- palimpzest/query/optimizer/cost_model.py +18 -34
- palimpzest/query/optimizer/optimizer.py +40 -25
- palimpzest/query/optimizer/optimizer_strategy.py +26 -0
- palimpzest/query/optimizer/plan.py +2 -2
- palimpzest/query/optimizer/rules.py +118 -27
- palimpzest/query/processor/config.py +12 -1
- palimpzest/query/processor/mab_sentinel_processor.py +125 -112
- palimpzest/query/processor/nosentinel_processor.py +46 -62
- palimpzest/query/processor/query_processor.py +10 -20
- palimpzest/query/processor/query_processor_factory.py +12 -5
- palimpzest/query/processor/random_sampling_sentinel_processor.py +112 -91
- palimpzest/query/processor/streaming_processor.py +11 -17
- palimpzest/sets.py +170 -94
- palimpzest/tools/pdfparser.py +5 -64
- palimpzest/utils/datareader_helpers.py +61 -0
- palimpzest/utils/field_helpers.py +69 -0
- palimpzest/utils/hash_helpers.py +3 -2
- palimpzest/utils/udfs.py +0 -28
- {palimpzest-0.5.3.dist-info → palimpzest-0.6.0.dist-info}/METADATA +49 -49
- palimpzest-0.6.0.dist-info/RECORD +87 -0
- {palimpzest-0.5.3.dist-info → palimpzest-0.6.0.dist-info}/top_level.txt +0 -1
- cli/README.md +0 -156
- cli/__init__.py +0 -0
- cli/cli_main.py +0 -390
- palimpzest/config.py +0 -89
- palimpzest/core/data/datasources.py +0 -369
- palimpzest/datamanager/__init__.py +0 -0
- palimpzest/datamanager/datamanager.py +0 -300
- palimpzest/prompts.py +0 -397
- palimpzest/query/operators/datasource.py +0 -202
- palimpzest-0.5.3.dist-info/RECORD +0 -83
- palimpzest-0.5.3.dist-info/entry_points.txt +0 -2
- {palimpzest-0.5.3.dist-info → palimpzest-0.6.0.dist-info}/LICENSE +0 -0
- {palimpzest-0.5.3.dist-info → palimpzest-0.6.0.dist-info}/WHEEL +0 -0
|
@@ -1,13 +1,11 @@
|
|
|
1
1
|
from copy import deepcopy
|
|
2
2
|
from itertools import combinations
|
|
3
|
-
from typing import Dict, Set, Tuple
|
|
4
3
|
|
|
5
4
|
from palimpzest.constants import AggFunc, Cardinality, Model, PromptStrategy
|
|
6
|
-
from palimpzest.core.lib.fields import ListField
|
|
7
5
|
from palimpzest.query.operators.aggregate import ApplyGroupByOp, AverageAggregateOp, CountAggregateOp
|
|
8
6
|
from palimpzest.query.operators.code_synthesis_convert import CodeSynthesisConvertSingle
|
|
9
7
|
from palimpzest.query.operators.convert import LLMConvertBonded, LLMConvertConventional, NonLLMConvert
|
|
10
|
-
from palimpzest.query.operators.datasource import CacheScanDataOp, MarshalAndScanDataOp
|
|
8
|
+
from palimpzest.query.operators.critique_and_refine_convert import CriticAndRefineConvert
|
|
11
9
|
from palimpzest.query.operators.filter import LLMFilter, NonLLMFilter
|
|
12
10
|
from palimpzest.query.operators.limit import LimitScanOp
|
|
13
11
|
from palimpzest.query.operators.logical import (
|
|
@@ -25,6 +23,7 @@ from palimpzest.query.operators.mixture_of_agents_convert import MixtureOfAgents
|
|
|
25
23
|
from palimpzest.query.operators.project import ProjectOp
|
|
26
24
|
from palimpzest.query.operators.rag_convert import RAGConvert
|
|
27
25
|
from palimpzest.query.operators.retrieve import RetrieveOp
|
|
26
|
+
from palimpzest.query.operators.scan import CacheScanDataOp, MarshalAndScanDataOp
|
|
28
27
|
from palimpzest.query.operators.token_reduction_convert import (
|
|
29
28
|
TokenReducedConvertBonded,
|
|
30
29
|
TokenReducedConvertConventional,
|
|
@@ -47,7 +46,7 @@ class Rule:
|
|
|
47
46
|
raise NotImplementedError("Calling this method from an abstract base class.")
|
|
48
47
|
|
|
49
48
|
@staticmethod
|
|
50
|
-
def substitute(logical_expression: LogicalExpression, **kwargs) -> Set[Expression]:
|
|
49
|
+
def substitute(logical_expression: LogicalExpression, **kwargs) -> set[Expression]:
|
|
51
50
|
raise NotImplementedError("Calling this method from an abstract base class.")
|
|
52
51
|
|
|
53
52
|
|
|
@@ -60,8 +59,8 @@ class TransformationRule(Rule):
|
|
|
60
59
|
|
|
61
60
|
@staticmethod
|
|
62
61
|
def substitute(
|
|
63
|
-
logical_expression: LogicalExpression, groups: Dict[int, Group], expressions: Dict[int, Expression], **kwargs
|
|
64
|
-
) -> Tuple[Set[LogicalExpression], Set[Group]]:
|
|
62
|
+
logical_expression: LogicalExpression, groups: dict[int, Group], expressions: dict[int, Expression], **kwargs
|
|
63
|
+
) -> tuple[set[LogicalExpression], set[Group]]:
|
|
65
64
|
"""
|
|
66
65
|
This function applies the transformation rule to the logical expression, which
|
|
67
66
|
potentially creates new intermediate expressions and groups.
|
|
@@ -86,8 +85,8 @@ class PushDownFilter(TransformationRule):
|
|
|
86
85
|
|
|
87
86
|
@staticmethod
|
|
88
87
|
def substitute(
|
|
89
|
-
logical_expression: LogicalExpression, groups: Dict[int, Group], expressions: Dict[int, Expression], **kwargs
|
|
90
|
-
) -> Tuple[Set[LogicalExpression], Set[Group]]:
|
|
88
|
+
logical_expression: LogicalExpression, groups: dict[int, Group], expressions: dict[int, Expression], **kwargs
|
|
89
|
+
) -> tuple[set[LogicalExpression], set[Group]]:
|
|
91
90
|
# initialize the sets of new logical expressions and groups to be returned
|
|
92
91
|
new_logical_expressions, new_groups = set(), set()
|
|
93
92
|
|
|
@@ -103,7 +102,7 @@ class PushDownFilter(TransformationRule):
|
|
|
103
102
|
continue
|
|
104
103
|
|
|
105
104
|
# iterate over logical expressions
|
|
106
|
-
logical_exprs = input_group.logical_expressions
|
|
105
|
+
logical_exprs = deepcopy(input_group.logical_expressions)
|
|
107
106
|
for expr in logical_exprs:
|
|
108
107
|
# if the expression operator is not a convert or a filter, we cannot swap
|
|
109
108
|
if not (isinstance(expr.operator, (ConvertScan, FilteredScan))):
|
|
@@ -114,10 +113,10 @@ class PushDownFilter(TransformationRule):
|
|
|
114
113
|
continue
|
|
115
114
|
|
|
116
115
|
# create new logical expression with filter pushed down to the input group's logical expression
|
|
117
|
-
new_input_group_ids = expr.input_group_ids
|
|
118
|
-
new_input_fields = expr.input_fields
|
|
119
|
-
new_depends_on_field_names = logical_expression.depends_on_field_names
|
|
120
|
-
new_generated_fields = logical_expression.generated_fields
|
|
116
|
+
new_input_group_ids = deepcopy(expr.input_group_ids)
|
|
117
|
+
new_input_fields = deepcopy(expr.input_fields)
|
|
118
|
+
new_depends_on_field_names = deepcopy(logical_expression.depends_on_field_names)
|
|
119
|
+
new_generated_fields = deepcopy(logical_expression.generated_fields)
|
|
121
120
|
new_filter_expr = LogicalExpression(
|
|
122
121
|
filter_operator,
|
|
123
122
|
input_group_ids=new_input_group_ids,
|
|
@@ -215,7 +214,7 @@ class NonLLMConvertRule(ImplementationRule):
|
|
|
215
214
|
return isinstance(logical_expression.operator, ConvertScan) and logical_expression.operator.udf is not None
|
|
216
215
|
|
|
217
216
|
@classmethod
|
|
218
|
-
def substitute(cls, logical_expression: LogicalExpression, **physical_op_params) -> Set[PhysicalExpression]:
|
|
217
|
+
def substitute(cls, logical_expression: LogicalExpression, **physical_op_params) -> set[PhysicalExpression]:
|
|
219
218
|
logical_op = logical_expression.operator
|
|
220
219
|
|
|
221
220
|
# get initial set of parameters for physical op
|
|
@@ -259,7 +258,7 @@ class LLMConvertRule(ImplementationRule):
|
|
|
259
258
|
return isinstance(logical_expression.operator, ConvertScan) and logical_expression.operator.udf is None
|
|
260
259
|
|
|
261
260
|
@classmethod
|
|
262
|
-
def substitute(cls, logical_expression: LogicalExpression, **physical_op_params) -> Set[PhysicalExpression]:
|
|
261
|
+
def substitute(cls, logical_expression: LogicalExpression, **physical_op_params) -> set[PhysicalExpression]:
|
|
263
262
|
logical_op = logical_expression.operator
|
|
264
263
|
|
|
265
264
|
# get initial set of parameters for physical op
|
|
@@ -293,7 +292,7 @@ class LLMConvertRule(ImplementationRule):
|
|
|
293
292
|
if field_name.split(".")[-1] in logical_expression.depends_on_field_names
|
|
294
293
|
])
|
|
295
294
|
list_image_field = any([
|
|
296
|
-
field.is_image_field and isinstance(field, ListField)
|
|
295
|
+
field.is_image_field and hasattr(field, "element_type")
|
|
297
296
|
for field_name, field in logical_expression.input_fields.items()
|
|
298
297
|
if field_name.split(".")[-1] in logical_expression.depends_on_field_names
|
|
299
298
|
])
|
|
@@ -368,7 +367,7 @@ class TokenReducedConvertRule(ImplementationRule):
|
|
|
368
367
|
return isinstance(logical_op, ConvertScan) and not is_image_conversion and logical_op.udf is None
|
|
369
368
|
|
|
370
369
|
@classmethod
|
|
371
|
-
def substitute(cls, logical_expression: LogicalExpression, **physical_op_params) -> Set[PhysicalExpression]:
|
|
370
|
+
def substitute(cls, logical_expression: LogicalExpression, **physical_op_params) -> set[PhysicalExpression]:
|
|
372
371
|
logical_op = logical_expression.operator
|
|
373
372
|
|
|
374
373
|
# get initial set of parameters for physical op
|
|
@@ -459,7 +458,7 @@ class CodeSynthesisConvertRule(ImplementationRule):
|
|
|
459
458
|
)
|
|
460
459
|
|
|
461
460
|
@classmethod
|
|
462
|
-
def substitute(cls, logical_expression: LogicalExpression, **physical_op_params) -> Set[PhysicalExpression]:
|
|
461
|
+
def substitute(cls, logical_expression: LogicalExpression, **physical_op_params) -> set[PhysicalExpression]:
|
|
463
462
|
logical_op = logical_expression.operator
|
|
464
463
|
|
|
465
464
|
# get initial set of parameters for physical op
|
|
@@ -518,7 +517,7 @@ class RAGConvertRule(ImplementationRule):
|
|
|
518
517
|
return isinstance(logical_op, ConvertScan) and not is_image_conversion and logical_op.udf is None
|
|
519
518
|
|
|
520
519
|
@classmethod
|
|
521
|
-
def substitute(cls, logical_expression: LogicalExpression, **physical_op_params) -> Set[PhysicalExpression]:
|
|
520
|
+
def substitute(cls, logical_expression: LogicalExpression, **physical_op_params) -> set[PhysicalExpression]:
|
|
522
521
|
logical_op = logical_expression.operator
|
|
523
522
|
|
|
524
523
|
# get initial set of parameters for physical op
|
|
@@ -580,7 +579,7 @@ class MixtureOfAgentsConvertRule(ImplementationRule):
|
|
|
580
579
|
return isinstance(logical_op, ConvertScan) and logical_op.udf is None
|
|
581
580
|
|
|
582
581
|
@classmethod
|
|
583
|
-
def substitute(cls, logical_expression: LogicalExpression, **physical_op_params) -> Set[PhysicalExpression]:
|
|
582
|
+
def substitute(cls, logical_expression: LogicalExpression, **physical_op_params) -> set[PhysicalExpression]:
|
|
584
583
|
logical_op = logical_expression.operator
|
|
585
584
|
|
|
586
585
|
# get initial set of parameters for physical op
|
|
@@ -605,7 +604,7 @@ class MixtureOfAgentsConvertRule(ImplementationRule):
|
|
|
605
604
|
if field_name.split(".")[-1] in logical_expression.depends_on_field_names
|
|
606
605
|
])
|
|
607
606
|
list_image_field = any([
|
|
608
|
-
field.is_image_field and isinstance(field, ListField)
|
|
607
|
+
field.is_image_field and hasattr(field, "element_type")
|
|
609
608
|
for field_name, field in logical_expression.input_fields.items()
|
|
610
609
|
if field_name.split(".")[-1] in logical_expression.depends_on_field_names
|
|
611
610
|
])
|
|
@@ -651,6 +650,98 @@ class MixtureOfAgentsConvertRule(ImplementationRule):
|
|
|
651
650
|
|
|
652
651
|
return set(physical_expressions)
|
|
653
652
|
|
|
653
|
+
class CriticAndRefineConvertRule(ImplementationRule):
|
|
654
|
+
"""
|
|
655
|
+
Implementation rule for the CriticAndRefineConvert operator.
|
|
656
|
+
"""
|
|
657
|
+
|
|
658
|
+
@classmethod
|
|
659
|
+
def matches_pattern(cls, logical_expression: LogicalExpression) -> bool:
|
|
660
|
+
logical_op = logical_expression.operator
|
|
661
|
+
return isinstance(logical_op, ConvertScan) and logical_op.udf is None
|
|
662
|
+
|
|
663
|
+
@classmethod
|
|
664
|
+
def substitute(cls, logical_expression: LogicalExpression, **physical_op_params) -> set[PhysicalExpression]:
|
|
665
|
+
logical_op = logical_expression.operator
|
|
666
|
+
|
|
667
|
+
# Get initial parameters for physical operator
|
|
668
|
+
op_kwargs = logical_op.get_logical_op_params()
|
|
669
|
+
op_kwargs.update(
|
|
670
|
+
{
|
|
671
|
+
"verbose": physical_op_params["verbose"],
|
|
672
|
+
"logical_op_id": logical_op.get_logical_op_id(),
|
|
673
|
+
"logical_op_name": logical_op.logical_op_name(),
|
|
674
|
+
}
|
|
675
|
+
)
|
|
676
|
+
|
|
677
|
+
# NOTE: when comparing pz.Model(s), equality is determined by the string (i.e. pz.Model.value)
|
|
678
|
+
# thus, Model.GPT_4o and Model.GPT_4o_V map to the same value; this allows us to use set logic
|
|
679
|
+
#
|
|
680
|
+
# identify models which can be used strictly for text or strictly for images
|
|
681
|
+
vision_models = set(get_vision_models())
|
|
682
|
+
text_models = set(get_models())
|
|
683
|
+
pure_text_models = {model for model in text_models if model not in vision_models}
|
|
684
|
+
pure_vision_models = {model for model in vision_models if model not in text_models}
|
|
685
|
+
|
|
686
|
+
# compute attributes about this convert operation
|
|
687
|
+
is_image_conversion = any([
|
|
688
|
+
field.is_image_field
|
|
689
|
+
for field_name, field in logical_expression.input_fields.items()
|
|
690
|
+
if field_name.split(".")[-1] in logical_expression.depends_on_field_names
|
|
691
|
+
])
|
|
692
|
+
num_image_fields = sum([
|
|
693
|
+
field.is_image_field
|
|
694
|
+
for field_name, field in logical_expression.input_fields.items()
|
|
695
|
+
if field_name.split(".")[-1] in logical_expression.depends_on_field_names
|
|
696
|
+
])
|
|
697
|
+
list_image_field = any([
|
|
698
|
+
field.is_image_field and hasattr(field, "element_type")
|
|
699
|
+
for field_name, field in logical_expression.input_fields.items()
|
|
700
|
+
if field_name.split(".")[-1] in logical_expression.depends_on_field_names
|
|
701
|
+
])
|
|
702
|
+
|
|
703
|
+
# identify models which can be used for this convert operation
|
|
704
|
+
models = []
|
|
705
|
+
for model in physical_op_params["available_models"]:
|
|
706
|
+
# skip this model if:
|
|
707
|
+
# 1. this is a pure vision model and we're not doing an image conversion, or
|
|
708
|
+
# 2. this is a pure text model and we're doing an image conversion, or
|
|
709
|
+
# 3. this is a vision model hosted by Together (i.e. LLAMA3_V) and there is more than one image field
|
|
710
|
+
first_criteria = model in pure_vision_models and not is_image_conversion
|
|
711
|
+
second_criteria = model in pure_text_models and is_image_conversion
|
|
712
|
+
third_criteria = model == Model.LLAMA3_V and (num_image_fields > 1 or list_image_field)
|
|
713
|
+
if first_criteria or second_criteria or third_criteria:
|
|
714
|
+
continue
|
|
715
|
+
|
|
716
|
+
models.append(model)
|
|
717
|
+
|
|
718
|
+
# TODO: heuristic(s) to narrow the space of critic and refine models we consider using class attributes
|
|
719
|
+
# construct CriticAndRefineConvert operations for every combination of model, critic model, and refinement model
|
|
720
|
+
physical_expressions = []
|
|
721
|
+
for model in models:
|
|
722
|
+
for critic_model in models:
|
|
723
|
+
for refine_model in models:
|
|
724
|
+
# construct multi-expression
|
|
725
|
+
op = CriticAndRefineConvert(
|
|
726
|
+
model=model,
|
|
727
|
+
prompt_strategy=PromptStrategy.COT_QA_IMAGE if is_image_conversion else PromptStrategy.COT_QA,
|
|
728
|
+
critic_model=critic_model,
|
|
729
|
+
refine_model=refine_model,
|
|
730
|
+
**op_kwargs,
|
|
731
|
+
)
|
|
732
|
+
expression = PhysicalExpression(
|
|
733
|
+
operator=op,
|
|
734
|
+
input_group_ids=logical_expression.input_group_ids,
|
|
735
|
+
input_fields=logical_expression.input_fields,
|
|
736
|
+
depends_on_field_names=logical_expression.depends_on_field_names,
|
|
737
|
+
generated_fields=logical_expression.generated_fields,
|
|
738
|
+
group_id=logical_expression.group_id,
|
|
739
|
+
)
|
|
740
|
+
physical_expressions.append(expression)
|
|
741
|
+
|
|
742
|
+
# Return the set containing the new physical expression
|
|
743
|
+
return set(physical_expressions)
|
|
744
|
+
|
|
654
745
|
|
|
655
746
|
class RetrieveRule(ImplementationRule):
|
|
656
747
|
"""
|
|
@@ -667,7 +758,7 @@ class RetrieveRule(ImplementationRule):
|
|
|
667
758
|
@classmethod
|
|
668
759
|
def substitute(
|
|
669
760
|
cls, logical_expression: LogicalExpression, **physical_op_params
|
|
670
|
-
) -> Set[PhysicalExpression]:
|
|
761
|
+
) -> set[PhysicalExpression]:
|
|
671
762
|
logical_op = logical_expression.operator
|
|
672
763
|
|
|
673
764
|
physical_expressions = []
|
|
@@ -713,7 +804,7 @@ class NonLLMFilterRule(ImplementationRule):
|
|
|
713
804
|
)
|
|
714
805
|
|
|
715
806
|
@staticmethod
|
|
716
|
-
def substitute(logical_expression: LogicalExpression, **physical_op_params) -> Set[PhysicalExpression]:
|
|
807
|
+
def substitute(logical_expression: LogicalExpression, **physical_op_params) -> set[PhysicalExpression]:
|
|
717
808
|
logical_op = logical_expression.operator
|
|
718
809
|
op_kwargs = logical_op.get_logical_op_params()
|
|
719
810
|
op_kwargs.update(
|
|
@@ -749,7 +840,7 @@ class LLMFilterRule(ImplementationRule):
|
|
|
749
840
|
)
|
|
750
841
|
|
|
751
842
|
@staticmethod
|
|
752
|
-
def substitute(logical_expression: LogicalExpression, **physical_op_params) -> Set[PhysicalExpression]:
|
|
843
|
+
def substitute(logical_expression: LogicalExpression, **physical_op_params) -> set[PhysicalExpression]:
|
|
753
844
|
logical_op = logical_expression.operator
|
|
754
845
|
op_kwargs = logical_op.get_logical_op_params()
|
|
755
846
|
op_kwargs.update({
|
|
@@ -779,7 +870,7 @@ class LLMFilterRule(ImplementationRule):
|
|
|
779
870
|
if field_name.split(".")[-1] in logical_expression.depends_on_field_names
|
|
780
871
|
])
|
|
781
872
|
list_image_field = any([
|
|
782
|
-
field.is_image_field and isinstance(field, ListField)
|
|
873
|
+
field.is_image_field and hasattr(field, "element_type")
|
|
783
874
|
for field_name, field in logical_expression.input_fields.items()
|
|
784
875
|
if field_name.split(".")[-1] in logical_expression.depends_on_field_names
|
|
785
876
|
])
|
|
@@ -825,7 +916,7 @@ class AggregateRule(ImplementationRule):
|
|
|
825
916
|
return isinstance(logical_expression.operator, Aggregate)
|
|
826
917
|
|
|
827
918
|
@staticmethod
|
|
828
|
-
def substitute(logical_expression: LogicalExpression, **physical_op_params) -> Set[PhysicalExpression]:
|
|
919
|
+
def substitute(logical_expression: LogicalExpression, **physical_op_params) -> set[PhysicalExpression]:
|
|
829
920
|
logical_op = logical_expression.operator
|
|
830
921
|
op_kwargs = logical_op.get_logical_op_params()
|
|
831
922
|
op_kwargs.update(
|
|
@@ -875,7 +966,7 @@ class BasicSubstitutionRule(ImplementationRule):
|
|
|
875
966
|
return logical_op_class in cls.LOGICAL_OP_CLASS_TO_PHYSICAL_OP_CLASS_MAP
|
|
876
967
|
|
|
877
968
|
@classmethod
|
|
878
|
-
def substitute(cls, logical_expression: LogicalExpression, **physical_op_params) -> Set[PhysicalExpression]:
|
|
969
|
+
def substitute(cls, logical_expression: LogicalExpression, **physical_op_params) -> set[PhysicalExpression]:
|
|
879
970
|
logical_op = logical_expression.operator
|
|
880
971
|
op_kwargs = logical_op.get_logical_op_params()
|
|
881
972
|
op_kwargs.update(
|
|
@@ -2,6 +2,7 @@ import json
|
|
|
2
2
|
from dataclasses import dataclass, field
|
|
3
3
|
|
|
4
4
|
from palimpzest.constants import Model
|
|
5
|
+
from palimpzest.core.data.datareaders import DataReader
|
|
5
6
|
from palimpzest.policy import MaxQuality, Policy
|
|
6
7
|
|
|
7
8
|
|
|
@@ -14,6 +15,8 @@ class QueryProcessorConfig:
|
|
|
14
15
|
execution_strategy: str = field(default="sequential")
|
|
15
16
|
optimizer_strategy: str = field(default="pareto")
|
|
16
17
|
|
|
18
|
+
val_datasource: DataReader | None = field(default=None)
|
|
19
|
+
|
|
17
20
|
policy: Policy = field(default_factory=MaxQuality)
|
|
18
21
|
scan_start_idx: int = field(default=0)
|
|
19
22
|
num_samples: int = field(default=float("inf"))
|
|
@@ -31,8 +34,9 @@ class QueryProcessorConfig:
|
|
|
31
34
|
allow_model_selection: bool = field(default=True)
|
|
32
35
|
allow_code_synth: bool = field(default=False)
|
|
33
36
|
allow_token_reduction: bool = field(default=False)
|
|
34
|
-
allow_rag_reduction: bool = field(default=True)
|
|
37
|
+
allow_rag_reduction: bool = field(default=False)
|
|
35
38
|
allow_mixtures: bool = field(default=True)
|
|
39
|
+
allow_critic: bool = field(default=False)
|
|
36
40
|
use_final_op_quality: bool = field(default=False)
|
|
37
41
|
|
|
38
42
|
def to_json_str(self):
|
|
@@ -40,6 +44,7 @@ class QueryProcessorConfig:
|
|
|
40
44
|
"processing_strategy": self.processing_strategy,
|
|
41
45
|
"execution_strategy": self.execution_strategy,
|
|
42
46
|
"optimizer_strategy": self.optimizer_strategy,
|
|
47
|
+
"val_datasource": None if self.val_datasource is None else self.val_datasource.serialize(),
|
|
43
48
|
"policy": self.policy.to_json_str(),
|
|
44
49
|
"scan_start_idx": self.scan_start_idx,
|
|
45
50
|
"num_samples": self.num_samples,
|
|
@@ -57,5 +62,11 @@ class QueryProcessorConfig:
|
|
|
57
62
|
"allow_token_reduction": self.allow_token_reduction,
|
|
58
63
|
"allow_rag_reduction": self.allow_rag_reduction,
|
|
59
64
|
"allow_mixtures": self.allow_mixtures,
|
|
65
|
+
"allow_critic": self.allow_critic,
|
|
60
66
|
"use_final_op_quality": self.use_final_op_quality,
|
|
61
67
|
}, indent=2)
|
|
68
|
+
|
|
69
|
+
def update(self, **kwargs) -> None:
|
|
70
|
+
for key, value in kwargs.items():
|
|
71
|
+
if hasattr(self, key):
|
|
72
|
+
setattr(self, key, value)
|