palimpzest 0.7.21__py3-none-any.whl → 0.8.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- palimpzest/__init__.py +37 -6
- palimpzest/agents/__init__.py +0 -0
- palimpzest/agents/compute_agents.py +0 -0
- palimpzest/agents/search_agents.py +637 -0
- palimpzest/constants.py +343 -209
- palimpzest/core/data/context.py +393 -0
- palimpzest/core/data/context_manager.py +163 -0
- palimpzest/core/data/dataset.py +639 -0
- palimpzest/core/data/{datareaders.py → iter_dataset.py} +202 -126
- palimpzest/core/elements/groupbysig.py +16 -13
- palimpzest/core/elements/records.py +166 -75
- palimpzest/core/lib/schemas.py +152 -390
- palimpzest/core/{data/dataclasses.py → models.py} +306 -170
- palimpzest/policy.py +2 -27
- palimpzest/prompts/__init__.py +35 -5
- palimpzest/prompts/agent_prompts.py +357 -0
- palimpzest/prompts/context_search.py +9 -0
- palimpzest/prompts/convert_prompts.py +62 -6
- palimpzest/prompts/filter_prompts.py +51 -6
- palimpzest/prompts/join_prompts.py +163 -0
- palimpzest/prompts/moa_proposer_convert_prompts.py +6 -6
- palimpzest/prompts/prompt_factory.py +375 -47
- palimpzest/prompts/split_proposer_prompts.py +1 -1
- palimpzest/prompts/util_phrases.py +5 -0
- palimpzest/prompts/validator.py +239 -0
- palimpzest/query/execution/all_sample_execution_strategy.py +134 -76
- palimpzest/query/execution/execution_strategy.py +210 -317
- palimpzest/query/execution/execution_strategy_type.py +5 -7
- palimpzest/query/execution/mab_execution_strategy.py +249 -136
- palimpzest/query/execution/parallel_execution_strategy.py +153 -244
- palimpzest/query/execution/single_threaded_execution_strategy.py +107 -64
- palimpzest/query/generators/generators.py +160 -331
- palimpzest/query/operators/__init__.py +15 -5
- palimpzest/query/operators/aggregate.py +50 -33
- palimpzest/query/operators/compute.py +201 -0
- palimpzest/query/operators/convert.py +33 -19
- palimpzest/query/operators/critique_and_refine_convert.py +7 -5
- palimpzest/query/operators/distinct.py +62 -0
- palimpzest/query/operators/filter.py +26 -16
- palimpzest/query/operators/join.py +403 -0
- palimpzest/query/operators/limit.py +3 -3
- palimpzest/query/operators/logical.py +205 -77
- palimpzest/query/operators/mixture_of_agents_convert.py +10 -8
- palimpzest/query/operators/physical.py +27 -21
- palimpzest/query/operators/project.py +3 -3
- palimpzest/query/operators/rag_convert.py +7 -7
- palimpzest/query/operators/retrieve.py +9 -9
- palimpzest/query/operators/scan.py +81 -42
- palimpzest/query/operators/search.py +524 -0
- palimpzest/query/operators/split_convert.py +10 -8
- palimpzest/query/optimizer/__init__.py +7 -9
- palimpzest/query/optimizer/cost_model.py +108 -441
- palimpzest/query/optimizer/optimizer.py +123 -181
- palimpzest/query/optimizer/optimizer_strategy.py +66 -61
- palimpzest/query/optimizer/plan.py +352 -67
- palimpzest/query/optimizer/primitives.py +43 -19
- palimpzest/query/optimizer/rules.py +484 -646
- palimpzest/query/optimizer/tasks.py +127 -58
- palimpzest/query/processor/config.py +42 -76
- palimpzest/query/processor/query_processor.py +73 -18
- palimpzest/query/processor/query_processor_factory.py +46 -38
- palimpzest/schemabuilder/schema_builder.py +15 -28
- palimpzest/utils/model_helpers.py +32 -77
- palimpzest/utils/progress.py +114 -102
- palimpzest/validator/__init__.py +0 -0
- palimpzest/validator/validator.py +306 -0
- {palimpzest-0.7.21.dist-info → palimpzest-0.8.1.dist-info}/METADATA +6 -1
- palimpzest-0.8.1.dist-info/RECORD +95 -0
- palimpzest/core/lib/fields.py +0 -141
- palimpzest/prompts/code_synthesis_prompts.py +0 -28
- palimpzest/query/execution/random_sampling_execution_strategy.py +0 -240
- palimpzest/query/generators/api_client_factory.py +0 -30
- palimpzest/query/operators/code_synthesis_convert.py +0 -488
- palimpzest/query/operators/map.py +0 -130
- palimpzest/query/processor/nosentinel_processor.py +0 -33
- palimpzest/query/processor/processing_strategy_type.py +0 -28
- palimpzest/query/processor/sentinel_processor.py +0 -88
- palimpzest/query/processor/streaming_processor.py +0 -149
- palimpzest/sets.py +0 -405
- palimpzest/utils/datareader_helpers.py +0 -61
- palimpzest/utils/demo_helpers.py +0 -75
- palimpzest/utils/field_helpers.py +0 -69
- palimpzest/utils/generation_helpers.py +0 -69
- palimpzest/utils/sandbox.py +0 -183
- palimpzest-0.7.21.dist-info/RECORD +0 -95
- /palimpzest/core/{elements/index.py → data/index_dataset.py} +0 -0
- {palimpzest-0.7.21.dist-info → palimpzest-0.8.1.dist-info}/WHEEL +0 -0
- {palimpzest-0.7.21.dist-info → palimpzest-0.8.1.dist-info}/licenses/LICENSE +0 -0
- {palimpzest-0.7.21.dist-info → palimpzest-0.8.1.dist-info}/top_level.txt +0 -0
@@ -1,38 +1,72 @@
 import logging
+import os
 from copy import deepcopy
 from itertools import combinations
 
-from palimpzest.constants import AggFunc,
+from palimpzest.constants import AggFunc, Model, PromptStrategy
+from palimpzest.core.data.context_manager import ContextManager
+from palimpzest.core.lib.schemas import AudioBase64, AudioFilepath, ImageBase64, ImageFilepath, ImageURL
+from palimpzest.prompts import CONTEXT_SEARCH_PROMPT
 from palimpzest.query.operators.aggregate import ApplyGroupByOp, AverageAggregateOp, CountAggregateOp
-from palimpzest.query.operators.
+from palimpzest.query.operators.compute import SmolAgentsCompute
 from palimpzest.query.operators.convert import LLMConvertBonded, NonLLMConvert
 from palimpzest.query.operators.critique_and_refine_convert import CriticAndRefineConvert
+from palimpzest.query.operators.distinct import DistinctOp
 from palimpzest.query.operators.filter import LLMFilter, NonLLMFilter
+from palimpzest.query.operators.join import NestedLoopsJoin
 from palimpzest.query.operators.limit import LimitScanOp
 from palimpzest.query.operators.logical import (
     Aggregate,
     BaseScan,
-
+    ComputeOperator,
+    ContextScan,
     ConvertScan,
+    Distinct,
     FilteredScan,
     GroupByAggregate,
+    JoinOp,
     LimitScan,
-    MapScan,
     Project,
     RetrieveScan,
+    SearchOperator,
 )
-from palimpzest.query.operators.map import MapOp
 from palimpzest.query.operators.mixture_of_agents_convert import MixtureOfAgentsConvert
+from palimpzest.query.operators.physical import PhysicalOperator
 from palimpzest.query.operators.project import ProjectOp
 from palimpzest.query.operators.rag_convert import RAGConvert
 from palimpzest.query.operators.retrieve import RetrieveOp
-from palimpzest.query.operators.scan import
+from palimpzest.query.operators.scan import ContextScanOp, MarshalAndScanDataOp
+from palimpzest.query.operators.search import (
+    SmolAgentsSearch,  # SmolAgentsCustomManagedSearch, # SmolAgentsManagedSearch
+)
 from palimpzest.query.operators.split_convert import SplitConvert
 from palimpzest.query.optimizer.primitives import Expression, Group, LogicalExpression, PhysicalExpression
-from palimpzest.utils.model_helpers import get_models, get_vision_models
 
 logger = logging.getLogger(__name__)
 
+# DEFINITIONS
+IMAGE_LIST_FIELD_TYPES = [
+    list[ImageBase64],
+    list[ImageFilepath],
+    list[ImageURL],
+    list[ImageBase64] | None,
+    list[ImageFilepath] | None,
+    list[ImageURL] | None,
+]
+IMAGE_FIELD_TYPES = IMAGE_LIST_FIELD_TYPES + [
+    ImageBase64, ImageFilepath, ImageURL,
+    ImageBase64 | None, ImageFilepath | None, ImageURL | None,
+]
+AUDIO_LIST_FIELD_TYPES = [
+    list[AudioBase64],
+    list[AudioFilepath],
+    list[AudioBase64] | None,
+    list[AudioFilepath] | None,
+]
+AUDIO_FIELD_TYPES = AUDIO_LIST_FIELD_TYPES + [
+    AudioBase64, AudioFilepath,
+    AudioBase64 | None, AudioFilepath | None,
+]
 
 class Rule:
     """
@@ -43,12 +77,12 @@ class Rule:
     def get_rule_id(cls):
         return cls.__name__
 
-    @
-    def matches_pattern(logical_expression: LogicalExpression) -> bool:
+    @classmethod
+    def matches_pattern(cls, logical_expression: LogicalExpression) -> bool:
         raise NotImplementedError("Calling this method from an abstract base class.")
 
-    @
-    def substitute(logical_expression: LogicalExpression, **kwargs) -> set[Expression]:
+    @classmethod
+    def substitute(cls, logical_expression: LogicalExpression, **kwargs: dict) -> set[Expression]:
         raise NotImplementedError("Calling this method from an abstract base class.")
 
 
@@ -59,9 +93,9 @@ class TransformationRule(Rule):
     which are created during the substitution.
     """
 
-    @
+    @classmethod
     def substitute(
-        logical_expression: LogicalExpression, groups: dict[int, Group], expressions: dict[int, Expression], **kwargs
+        cls, logical_expression: LogicalExpression, groups: dict[int, Group], expressions: dict[int, Expression], **kwargs
     ) -> tuple[set[LogicalExpression], set[Group]]:
         """
         This function applies the transformation rule to the logical expression, which
@@ -81,15 +115,15 @@ class PushDownFilter(TransformationRule):
     most expensive operator in the input group.
     """
 
-    @
-    def matches_pattern(logical_expression: Expression) -> bool:
+    @classmethod
+    def matches_pattern(cls, logical_expression: Expression) -> bool:
         is_match = isinstance(logical_expression.operator, FilteredScan)
         logger.debug(f"PushDownFilter matches_pattern: {is_match} for {logical_expression}")
         return is_match
 
-    @
+    @classmethod
     def substitute(
-        logical_expression: LogicalExpression, groups: dict[int, Group], expressions: dict[int, Expression], **kwargs
+        cls, logical_expression: LogicalExpression, groups: dict[int, Group], expressions: dict[int, Expression], **kwargs: dict
     ) -> tuple[set[LogicalExpression], set[Group]]:
         logger.debug(f"Substituting PushDownFilter for {logical_expression}")
 
@@ -113,7 +147,7 @@ class PushDownFilter(TransformationRule):
         # we see a regression / bug in the future
         for expr in input_group.logical_expressions:
             # if the expression operator is not a convert or a filter, we cannot swap
-            if not (isinstance(expr.operator, (ConvertScan, FilteredScan))):
+            if not (isinstance(expr.operator, (ConvertScan, FilteredScan, JoinOp))):
                 continue
 
             # if this filter depends on a field generated by the expression we're trying to swap with, we can't swap
@@ -141,8 +175,8 @@ class PushDownFilter(TransformationRule):
             group_id, group = None, None
 
             # if the expression already exists, lookup the group_id and group
-            if new_filter_expr.
-                group_id = expressions[new_filter_expr.
+            if new_filter_expr.expr_id in expressions:
+                group_id = expressions[new_filter_expr.expr_id].group_id
                 new_filter_expr.set_group_id(group_id)
                 group = groups[group_id]
 
@@ -190,8 +224,7 @@ class PushDownFilter(TransformationRule):
             # create final new logical expression with expr's operator pulled up
             new_expr = LogicalExpression(
                 expr.operator.copy(),
-                input_group_ids=[group_id]
-                + [g_id for g_id in logical_expression.input_group_ids if g_id != input_group_id],
+                input_group_ids=[group_id] + [g_id for g_id in logical_expression.input_group_ids if g_id != input_group_id],
                 input_fields=group.fields,
                 depends_on_field_names=expr.depends_on_field_names,
                 generated_fields=expr.generated_fields,
@@ -211,218 +244,280 @@ class ImplementationRule(Rule):
     Base class for implementation rules which convert a logical expression to a physical expression.
     """
 
-
-
-
-
-
-
-
+    @classmethod
+    def _get_image_fields(cls, logical_expression: LogicalExpression) -> set[str]:
+        """Returns the set of fields which have an image (or list[image]) type."""
+        return set([
+            field_name.split(".")[-1]
+            for field_name, field in logical_expression.input_fields.items()
+            if field.annotation in IMAGE_FIELD_TYPES and field_name.split(".")[-1] in logical_expression.depends_on_field_names
+        ])
 
     @classmethod
-    def
-
-
-
+    def _get_list_image_fields(cls, logical_expression: LogicalExpression) -> set[str]:
+        """Returns the set of fields which have a list[image] type."""
+        return set([
+            field_name.split(".")[-1]
+            for field_name, field in logical_expression.input_fields.items()
+            if field.annotation in IMAGE_LIST_FIELD_TYPES and field_name.split(".")[-1] in logical_expression.depends_on_field_names
+        ])
 
     @classmethod
-    def
-
+    def _get_audio_fields(cls, logical_expression: LogicalExpression) -> set[str]:
+        """Returns the set of fields which have an audio (or list[audio]) type."""
+        return set([
+            field_name.split(".")[-1]
+            for field_name, field in logical_expression.input_fields.items()
+            if field.annotation in AUDIO_FIELD_TYPES and field_name.split(".")[-1] in logical_expression.depends_on_field_names
+        ])
 
-
+    @classmethod
+    def _get_list_audio_fields(cls, logical_expression: LogicalExpression) -> set[str]:
+        """Returns the set of fields which have a list[audio] type."""
+        return set([
+            field_name.split(".")[-1]
+            for field_name, field in logical_expression.input_fields.items()
+            if field.annotation in AUDIO_LIST_FIELD_TYPES and field_name.split(".")[-1] in logical_expression.depends_on_field_names
+        ])
 
-
-
-
-
-
-
-
-
-            )
+    @classmethod
+    def _is_image_only_operation(cls, logical_expression: LogicalExpression) -> bool:
+        """Returns True if the logical_expression processes only image input(s) and False otherwise."""
+        return all([
+            field.annotation in IMAGE_FIELD_TYPES
+            for field_name, field in logical_expression.input_fields.items()
+            if field_name.split(".")[-1] in logical_expression.depends_on_field_names
+        ])
 
-
-
-
-
-
-
-
-
-            group_id=logical_expression.group_id,
-        )
+    @classmethod
+    def _is_image_operation(cls, logical_expression: LogicalExpression) -> bool:
+        """Returns True if the logical_expression processes image input(s) and False otherwise."""
+        return any([
+            field.annotation in IMAGE_FIELD_TYPES
+            for field_name, field in logical_expression.input_fields.items()
+            if field_name.split(".")[-1] in logical_expression.depends_on_field_names
+        ])
 
-
-
+    @classmethod
+    def _is_audio_only_operation(cls, logical_expression: LogicalExpression) -> bool:
+        """Returns True if the logical_expression processes only audio input(s) and False otherwise."""
+        return all([
+            field.annotation in AUDIO_FIELD_TYPES
+            for field_name, field in logical_expression.input_fields.items()
+            if field_name.split(".")[-1] in logical_expression.depends_on_field_names
+        ])
 
-
+    @classmethod
+    def _is_audio_operation(cls, logical_expression: LogicalExpression) -> bool:
+        """Returns True if the logical_expression processes audio input(s) and False otherwise."""
+        return any([
+            field.annotation in AUDIO_FIELD_TYPES
+            for field_name, field in logical_expression.input_fields.items()
+            if field_name.split(".")[-1] in logical_expression.depends_on_field_names
+        ])
 
+    @classmethod
+    def _is_text_only_operation(cls, logical_expression: LogicalExpression) -> bool:
+        """Returns True if the logical_expression processes only text input(s) and False otherwise."""
+        return all([
+            field.annotation not in IMAGE_FIELD_TYPES + AUDIO_FIELD_TYPES
+            for field_name, field in logical_expression.input_fields.items()
+            if field_name.split(".")[-1] in logical_expression.depends_on_field_names
+        ])
 
-
-
-
-
+    @classmethod
+    def _is_text_operation(cls, logical_expression: LogicalExpression) -> bool:
+        """Returns True if the logical_expression processes text input(s) and False otherwise."""
+        return any([
+            field.annotation not in IMAGE_FIELD_TYPES + AUDIO_FIELD_TYPES
+            for field_name, field in logical_expression.input_fields.items()
+            if field_name.split(".")[-1] in logical_expression.depends_on_field_names
+        ])
+
+    # TODO: support powerset of text + image + audio (+ video) multi-modal operations
+    @classmethod
+    def _is_text_image_multimodal_operation(cls, logical_expression: LogicalExpression) -> bool:
+        """Returns True if the logical_expression processes text and image inputs and False otherwise."""
+        return cls._is_image_operation(logical_expression) and cls._is_text_operation(logical_expression)
 
     @classmethod
-    def
-
-
-        return is_match
+    def _is_text_audio_multimodal_operation(cls, logical_expression: LogicalExpression) -> bool:
+        """Returns True if the logical_expression processes text and audio inputs and False otherwise."""
+        return cls._is_audio_operation(logical_expression) and cls._is_text_operation(logical_expression)
 
     @classmethod
-    def
-
+    def _model_matches_input(cls, model: Model, logical_expression: LogicalExpression) -> bool:
+        """Returns True if the model is capable of processing the input and False otherwise."""
+        # compute how many image fields are in the input, and whether any fields are list[image] fields
+        num_image_fields = len(cls._get_image_fields(logical_expression))
+        has_list_image_field = len(cls._get_list_image_fields(logical_expression)) > 0
+        num_audio_fields = len(cls._get_audio_fields(logical_expression))
+        has_list_audio_field = len(cls._get_list_audio_fields(logical_expression)) > 0
+
+        # corner-case: for now, all operators use text or vision models for processing inputs to __call__
+        if model.is_embedding_model():
+            return False
 
+        # corner-case: Llama vision models cannot handle multiple image inputs (at least using Together)
+        if model.is_llama_model() and model.is_vision_model() and (num_image_fields > 1 or has_list_image_field):
+            return False
+
+        # corner-case: Gemini models cannot handle multiple audio inputs
+        if model.is_vertex_model() and model.is_audio_model() and (num_audio_fields > 1 or has_list_audio_field):
+            return False
+
+        # text-only input and text supporting model
+        if cls._is_text_only_operation(logical_expression) and model.is_text_model():
+            return True
+
+        # image-only input and image supporting model
+        if cls._is_image_only_operation(logical_expression) and model.is_vision_model():
+            return True
+
+        # audio-only input and audio supporting model
+        if cls._is_audio_only_operation(logical_expression) and model.is_audio_model():
+            return True
+
+        # multi-modal input and multi-modal supporting model
+        if cls._is_text_image_multimodal_operation(logical_expression) and model.is_text_image_multimodal_model():  # noqa: SIM103
+            return True
+
+        # multi-modal input and multi-modal supporting model
+        if cls._is_text_audio_multimodal_operation(logical_expression) and model.is_text_audio_multimodal_model():  # noqa: SIM103
+            return True
+
+        return False
+
+    @classmethod
+    def _get_fixed_op_kwargs(cls, logical_expression: LogicalExpression, runtime_kwargs: dict) -> dict:
+        """Get the fixed set of physical op kwargs provided by the logical expression and the runtime keyword arguments."""
+        # get logical operator
         logical_op = logical_expression.operator
 
-        #
+        # set initial set of parameters for physical op
         op_kwargs = logical_op.get_logical_op_params()
         op_kwargs.update(
             {
-                "verbose":
+                "verbose": runtime_kwargs["verbose"],
                 "logical_op_id": logical_op.get_logical_op_id(),
+                "unique_logical_op_id": logical_op.get_unique_logical_op_id(),
                 "logical_op_name": logical_op.logical_op_name(),
+                "api_base": runtime_kwargs["api_base"],
             }
         )
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            [
-
-
-
-
-
+        return op_kwargs
+
+    @classmethod
+    def _perform_substitution(
+        cls,
+        logical_expression: LogicalExpression,
+        physical_op_class: type[PhysicalOperator],
+        runtime_kwargs: dict,
+        variable_op_kwargs: list[dict] | dict | None = None,
+    ) -> set[PhysicalExpression]:
+        """
+        This performs basic substitution logic which proceeds in four steps:
+
+        1. The basic kwargs for the physical operator are computed using the logical operator
+           and runtime kwargs.
+        2. If variable kwargs are provided, then they are merged with the basic kwargs and one
+           instance of the physical operator is created for each dictionary of variable kwargs.
+        3. A physical expression is created for each physical operator instance.
+        4. The unique set of physical expressions is returned.
+
+        Args:
+            logical_expression (LogicalExpression): The logical expression containing a logical operator.
+            physical_op_class (type[PhysicalOperator]): The class of the physical operator we wish to construct.
+            runtime_kwargs (dict): Keyword arguments which are provided at runtime.
+            variable_op_kwargs (list[dict] | dict | None): A (list of) variable kwargs to customize each
+                physical operator instance.
+
+        Returns:
+            set[PhysicalExpression]: The unique set of physical expressions produced by initializing the
+                physical_op_class with the provided keyword arguments.
+        """
+        # get physical operator kwargs which are fixed for each instance of the physical operator
+        fixed_op_kwargs = cls._get_fixed_op_kwargs(logical_expression, runtime_kwargs)
+
+        # make variable_op_kwargs a list of dictionaries
+        if variable_op_kwargs is None:
+            variable_op_kwargs = [{}]
+        elif isinstance(variable_op_kwargs, dict):
+            variable_op_kwargs = [variable_op_kwargs]
 
+        # construct physical operators for each set of kwargs
         physical_expressions = []
-        for
-            #
-
-            # 2. this is a pure text model and we're doing an image conversion, or
-            # 3. this is a vision model hosted by Together (i.e. LLAMA3 vision) and there is more than one image field
-            first_criteria = model in pure_vision_models and not is_image_conversion
-            second_criteria = model in pure_text_models and is_image_conversion
-            third_criteria = model.is_llama_model() and model.is_vision_model() and (num_image_fields > 1 or list_image_field)
-            fourth_criteria = model.is_embedding_model()
-            if first_criteria or second_criteria or third_criteria or fourth_criteria:
-                continue
+        for var_op_kwargs in variable_op_kwargs:
+            # get kwargs for this physical operator instance
+            op_kwargs = {**fixed_op_kwargs, **var_op_kwargs}
 
-            # construct
-            op =
-                model=model,
-                prompt_strategy=PromptStrategy.COT_QA_IMAGE if is_image_conversion else PromptStrategy.COT_QA,
-                **op_kwargs,
-            )
-            expression = PhysicalExpression(
-                operator=op,
-                input_group_ids=logical_expression.input_group_ids,
-                input_fields=logical_expression.input_fields,
-                depends_on_field_names=logical_expression.depends_on_field_names,
-                generated_fields=logical_expression.generated_fields,
-                group_id=logical_expression.group_id,
-            )
-            physical_expressions.append(expression)
+            # construct the physical operator
+            op = physical_op_class(**op_kwargs)
 
-
-
+            # construct physical expression and add to list of expressions
+            expression = PhysicalExpression.from_op_and_logical_expr(op, logical_expression)
+            physical_expressions.append(expression)
 
-        return
+        return set(physical_expressions)
 
 
-class
+class NonLLMConvertRule(ImplementationRule):
     """
-
-    (CodeSynthesisConvertSingle) is provided by sub-class rules.
-
-    NOTE: we provide the physical convert class(es) in their own sub-classed rules to make
-    it easier to allow/disallow groups of rules at the Optimizer level.
+    Substitute a logical expression for a UDF ConvertScan with a NonLLMConvert physical implementation.
     """
 
-    physical_convert_class = None # overriden by sub-classes
-
     @classmethod
     def matches_pattern(cls, logical_expression: LogicalExpression) -> bool:
-
-
-            [
-                field.is_image_field
-                for field_name, field in logical_expression.input_fields.items()
-                if field_name.split(".")[-1] in logical_expression.depends_on_field_names
-            ]
-        )
-        is_match = (
-            isinstance(logical_op, ConvertScan)
-            and not is_image_conversion
-            and logical_op.cardinality != Cardinality.ONE_TO_MANY
-            and logical_op.udf is None
-        )
-        logger.debug(f"CodeSynthesisConvertRule matches_pattern: {is_match} for {logical_expression}")
+        is_match = isinstance(logical_expression.operator, ConvertScan) and logical_expression.operator.udf is not None
+        logger.debug(f"NonLLMConvertRule matches_pattern: {is_match} for {logical_expression}")
         return is_match
 
     @classmethod
-    def substitute(cls, logical_expression: LogicalExpression, **
-        logger.debug(f"Substituting
-
-        logical_op = logical_expression.operator
+    def substitute(cls, logical_expression: LogicalExpression, **runtime_kwargs) -> set[PhysicalExpression]:
+        logger.debug(f"Substituting NonLLMConvertRule for {logical_expression}")
+        return cls._perform_substitution(logical_expression, NonLLMConvert, runtime_kwargs)
 
-        # get initial set of parameters for physical op
-        op_kwargs = logical_op.get_logical_op_params()
-        op_kwargs.update(
-            {
-                "verbose": physical_op_params["verbose"],
-                "logical_op_id": logical_op.get_logical_op_id(),
-                "logical_op_name": logical_op.logical_op_name(),
-            }
-        )
 
-
-
-
-
-            fallback_model=physical_op_params["fallback_model"],
-            prompt_strategy=PromptStrategy.COT_QA,
-            **op_kwargs,
-        )
-        expression = PhysicalExpression(
-            operator=op,
-            input_group_ids=logical_expression.input_group_ids,
-            input_fields=logical_expression.input_fields,
-            depends_on_field_names=logical_expression.depends_on_field_names,
-            generated_fields=logical_expression.generated_fields,
-            group_id=logical_expression.group_id,
-        )
-        deduped_physical_expressions = set([expression])
-        logger.debug(f"Done substituting CodeSynthesisConvertRule for {logical_expression}")
+class LLMConvertBondedRule(ImplementationRule):
+    """
+    Substitute a logical expression for a ConvertScan with a bonded convert physical implementation.
+    """
 
-
+    @classmethod
+    def matches_pattern(cls, logical_expression: LogicalExpression) -> bool:
+        is_match = isinstance(logical_expression.operator, ConvertScan) and logical_expression.operator.udf is None
+        logger.debug(f"LLMConvertBondedRule matches_pattern: {is_match} for {logical_expression}")
+        return is_match
 
+    @classmethod
+    def substitute(cls, logical_expression: LogicalExpression, **runtime_kwargs) -> set[PhysicalExpression]:
+        logger.debug(f"Substituting LLMConvertBondedRule for {logical_expression}")
 
-
-
-
-
+        # create variable physical operator kwargs for each model which can implement this logical_expression
+        models = [model for model in runtime_kwargs["available_models"] if cls._model_matches_input(model, logical_expression)]
+        # NOTE: right now we exclusively allow image or audio operations, but not both simultaneously
+        prompt_strategy, no_reasoning_prompt_strategy = None, None
+        no_reasoning = runtime_kwargs["reasoning_effort"] in [None, "minimal", "low"]
+        if cls._is_text_only_operation(logical_expression):
+            prompt_strategy = PromptStrategy.COT_QA
+            no_reasoning_prompt_strategy = PromptStrategy.COT_QA_NO_REASONING
+        elif cls._is_image_operation(logical_expression):
+            prompt_strategy = PromptStrategy.COT_QA_IMAGE
+            no_reasoning_prompt_strategy = PromptStrategy.COT_QA_IMAGE_NO_REASONING
+        elif cls._is_audio_operation(logical_expression):
+            prompt_strategy = PromptStrategy.COT_QA_AUDIO
+            no_reasoning_prompt_strategy = PromptStrategy.COT_QA_AUDIO_NO_REASONING
+        variable_op_kwargs = [
+            {
+                "model": model,
+                "prompt_strategy": no_reasoning_prompt_strategy if model.is_reasoning_model() and no_reasoning else prompt_strategy,
+                "reasoning_effort": runtime_kwargs["reasoning_effort"],
+            }
+            for model in models
+        ]
 
-
+        return cls._perform_substitution(logical_expression, LLMConvertBonded, runtime_kwargs, variable_op_kwargs)
 
 
 class RAGConvertRule(ImplementationRule):
@@ -436,68 +531,30 @@ class RAGConvertRule(ImplementationRule):
     @classmethod
     def matches_pattern(cls, logical_expression: LogicalExpression) -> bool:
         logical_op = logical_expression.operator
-
-            [
-                field.is_image_field
-                for field_name, field in logical_expression.input_fields.items()
-                if field_name.split(".")[-1] in logical_expression.depends_on_field_names
-            ]
-        )
-        is_match = isinstance(logical_op, ConvertScan) and not is_image_conversion and logical_op.udf is None
+        is_match = isinstance(logical_op, ConvertScan) and cls._is_text_only_operation(logical_expression) and logical_op.udf is None
         logger.debug(f"RAGConvertRule matches_pattern: {is_match} for {logical_expression}")
         return is_match
 
     @classmethod
-    def substitute(cls, logical_expression: LogicalExpression, **
+    def substitute(cls, logical_expression: LogicalExpression, **runtime_kwargs) -> set[PhysicalExpression]:
        logger.debug(f"Substituting RAGConvertRule for {logical_expression}")
 
-
-
-
-        op_kwargs = logical_op.get_logical_op_params()
-        op_kwargs.update(
+        # create variable physical operator kwargs for each model which can implement this logical_expression
+        models = [model for model in runtime_kwargs["available_models"] if cls._model_matches_input(model, logical_expression)]
+        variable_op_kwargs = [
            {
-                "
-                "
-                "
+                "model": model,
+                "prompt_strategy": PromptStrategy.COT_QA,
+                "num_chunks_per_field": num_chunks_per_field,
+                "chunk_size": chunk_size,
+                "reasoning_effort": runtime_kwargs["reasoning_effort"],
            }
-
+            for model in models
+            for num_chunks_per_field in cls.num_chunks_per_fields
+            for chunk_size in cls.chunk_sizes
+        ]
 
-
-        vision_models = set(get_vision_models())
-        text_models = set(get_models())
-        pure_vision_models = {model for model in vision_models if model not in text_models}
-
-        physical_expressions = []
-        for model in physical_op_params["available_models"]:
-            # skip this model if this is a pure image model
-            if model in pure_vision_models or model.is_embedding_model():
-                continue
-
-            for num_chunks_per_field in cls.num_chunks_per_fields:
-                for chunk_size in cls.chunk_sizes:
-                    # construct multi-expression
-                    op = RAGConvert(
-                        model=model,
-                        prompt_strategy=PromptStrategy.COT_QA,
-                        num_chunks_per_field=num_chunks_per_field,
-                        chunk_size=chunk_size,
-                        **op_kwargs,
-                    )
-                    expression = PhysicalExpression(
-                        operator=op,
-                        input_group_ids=logical_expression.input_group_ids,
-                        input_fields=logical_expression.input_fields,
-                        depends_on_field_names=logical_expression.depends_on_field_names,
-                        generated_fields=logical_expression.generated_fields,
-                        group_id=logical_expression.group_id,
-                    )
-                    physical_expressions.append(expression)
-
-        logger.debug(f"Done substituting RAGConvertRule for {logical_expression}")
-        deduped_physical_expressions = set(physical_expressions)
-
-        return deduped_physical_expressions
+        return cls._perform_substitution(logical_expression, RAGConvert, runtime_kwargs, variable_op_kwargs)
 
 
 class MixtureOfAgentsConvertRule(ImplementationRule):
@@ -511,93 +568,35 @@ class MixtureOfAgentsConvertRule(ImplementationRule):
     @classmethod
     def matches_pattern(cls, logical_expression: LogicalExpression) -> bool:
         logical_op = logical_expression.operator
-
+        # TODO: remove audio limitation once I add prompts
+        is_match = isinstance(logical_op, ConvertScan) and logical_op.udf is None and not cls._is_audio_operation(logical_expression)
         logger.debug(f"MixtureOfAgentsConvertRule matches_pattern: {is_match} for {logical_expression}")
         return is_match
 
     @classmethod
-    def substitute(cls, logical_expression: LogicalExpression, **
+    def substitute(cls, logical_expression: LogicalExpression, **runtime_kwargs) -> set[PhysicalExpression]:
         logger.debug(f"Substituting MixtureOfAgentsConvertRule for {logical_expression}")
 
-
-
-
-
-
+        # create variable physical operator kwargs for each model which can implement this logical_expression
+        proposer_model_set = {model for model in runtime_kwargs["available_models"] if cls._model_matches_input(model, logical_expression)}
+        aggregator_model_set = {model for model in runtime_kwargs["available_models"] if model.is_text_model()}
+        proposer_prompt_strategy = PromptStrategy.COT_MOA_PROPOSER_IMAGE if cls._is_image_operation(logical_expression) else PromptStrategy.COT_MOA_PROPOSER
+        variable_op_kwargs = [
            {
-                "
-                "
-                "
+                "proposer_models": list(proposer_models),
+                "temperatures": [temp] * len(proposer_models),
+                "aggregator_model": aggregator_model,
+                "proposer_prompt_strategy": proposer_prompt_strategy,
+                "aggregator_prompt_strategy": PromptStrategy.COT_MOA_AGG,
+                "reasoning_effort": runtime_kwargs["reasoning_effort"],
            }
-
+            for k in cls.num_proposer_models
+            for temp in cls.temperatures
+            for proposer_models in combinations(proposer_model_set, k)
+            for aggregator_model in aggregator_model_set
+        ]
 
-
-        vision_models = set(get_vision_models())
-        text_models = set(get_models())
-
-        # construct set of proposer models and set of aggregator models
-        num_image_fields = sum(
-            [
-                field.is_image_field
-                for field_name, field in logical_expression.input_fields.items()
-                if field_name.split(".")[-1] in logical_expression.depends_on_field_names
-            ]
-        )
-        list_image_field = any(
-            [
-                field.is_image_field and hasattr(field, "element_type")
-                for field_name, field in logical_expression.input_fields.items()
-                if field_name.split(".")[-1] in logical_expression.depends_on_field_names
-            ]
-        )
-        proposer_model_set, is_image_conversion = text_models, False
-        if num_image_fields > 1 or list_image_field:
-            proposer_model_set = [model for model in vision_models if not model.is_llama_model()]
-            is_image_conversion = True
-        elif num_image_fields == 1:
-            proposer_model_set = vision_models
-            is_image_conversion = True
-        aggregator_model_set = text_models
-
-        # filter un-available models out of sets
-        proposer_model_set = {model for model in proposer_model_set if model in physical_op_params["available_models"]}
-        aggregator_model_set = {
-            model for model in aggregator_model_set if model in physical_op_params["available_models"]
-        }
-
-        # construct MixtureOfAgentsConvert operations for various numbers of proposer models
-        # and for every combination of proposer models and aggregator model
-        physical_expressions = []
-        for k in cls.num_proposer_models:
-            for temp in cls.temperatures:
-                for proposer_models in combinations(proposer_model_set, k):
-                    for aggregator_model in aggregator_model_set:
-                        # construct multi-expression
-                        op = MixtureOfAgentsConvert(
-                            proposer_models=list(proposer_models),
-                            temperatures=[temp] * len(proposer_models),
-                            aggregator_model=aggregator_model,
-                            proposer_prompt=op_kwargs.get("prompt"),
-                            proposer_prompt_strategy=PromptStrategy.COT_MOA_PROPOSER_IMAGE
-                            if is_image_conversion
-                            else PromptStrategy.COT_MOA_PROPOSER,
-                            aggregator_prompt_strategy=PromptStrategy.COT_MOA_AGG,
-                            **op_kwargs,
-                        )
-                        expression = PhysicalExpression(
-                            operator=op,
-                            input_group_ids=logical_expression.input_group_ids,
-                            input_fields=logical_expression.input_fields,
-                            depends_on_field_names=logical_expression.depends_on_field_names,
-                            generated_fields=logical_expression.generated_fields,
-                            group_id=logical_expression.group_id,
-                        )
-                        physical_expressions.append(expression)
-
-        logger.debug(f"Done substituting MixtureOfAgentsConvertRule for {logical_expression}")
-        deduped_physical_expressions = set(physical_expressions)
-
-        return deduped_physical_expressions
+        return cls._perform_substitution(logical_expression, MixtureOfAgentsConvert, runtime_kwargs, variable_op_kwargs)
 
 
 class CriticAndRefineConvertRule(ImplementationRule):
@@ -608,99 +607,32 @@ class CriticAndRefineConvertRule(ImplementationRule):
     @classmethod
     def matches_pattern(cls, logical_expression: LogicalExpression) -> bool:
         logical_op = logical_expression.operator
-
+        # TODO: remove audio limitation once I add prompts
+        is_match = isinstance(logical_op, ConvertScan) and logical_op.udf is None and not cls._is_audio_operation(logical_expression)
         logger.debug(f"CriticAndRefineConvertRule matches_pattern: {is_match} for {logical_expression}")
         return is_match
 
     @classmethod
-    def substitute(cls, logical_expression: LogicalExpression, **
+    def substitute(cls, logical_expression: LogicalExpression, **runtime_kwargs) -> set[PhysicalExpression]:
         logger.debug(f"Substituting CriticAndRefineConvertRule for {logical_expression}")
 
-
-
-
-
-        op_kwargs.update(
+        # create variable physical operator kwargs for each model which can implement this logical_expression
+        models = [model for model in runtime_kwargs["available_models"] if cls._model_matches_input(model, logical_expression)]
+        prompt_strategy = PromptStrategy.COT_QA_IMAGE if cls._is_image_operation(logical_expression) else PromptStrategy.COT_QA
+        variable_op_kwargs = [
            {
-                "
-                "
-                "
+                "model": model,
+                "critic_model": critic_model,
+                "refine_model": refine_model,
+                "prompt_strategy": prompt_strategy,
+                "reasoning_effort": runtime_kwargs["reasoning_effort"],
            }
-
+            for model in models
+            for critic_model in models
+            for refine_model in models
+        ]
 
-
-        vision_models = set(get_vision_models())
-        text_models = set(get_models())
-        pure_text_models = {model for model in text_models if model not in vision_models}
-        pure_vision_models = {model for model in vision_models if model not in text_models}
-
-        # compute attributes about this convert operation
-        is_image_conversion = any(
-            [
-                field.is_image_field
-                for field_name, field in logical_expression.input_fields.items()
-                if field_name.split(".")[-1] in logical_expression.depends_on_field_names
-            ]
-        )
-        num_image_fields = sum(
-            [
-                field.is_image_field
-                for field_name, field in logical_expression.input_fields.items()
-                if field_name.split(".")[-1] in logical_expression.depends_on_field_names
-            ]
-        )
-        list_image_field = any(
-            [
-                field.is_image_field and hasattr(field, "element_type")
-                for field_name, field in logical_expression.input_fields.items()
-                if field_name.split(".")[-1] in logical_expression.depends_on_field_names
-            ]
-        )
-
-        # identify models which can be used for this convert operation
-        models = []
-        for model in physical_op_params["available_models"]:
-            # skip this model if:
-            # 1. this is a pure vision model and we're not doing an image conversion, or
-            # 2. this is a pure text model and we're doing an image conversion, or
-            # 3. this is a vision model hosted by Together (i.e. LLAMA3 vision) and there is more than one image field
-            first_criteria = model in pure_vision_models and not is_image_conversion
-            second_criteria = model in pure_text_models and is_image_conversion
-            third_criteria = model.is_llama_model() and model.is_vision_model() and (num_image_fields > 1 or list_image_field)
-            fourth_criteria = model.is_embedding_model()
-            if first_criteria or second_criteria or third_criteria or fourth_criteria:
-                continue
-
-            models.append(model)
-
-        # TODO: heuristic(s) to narrow the space of critic and refine models we consider using class attributes
-        # construct CriticAndRefineConvert operations for every combination of model, critic model, and refinement model
-        physical_expressions = []
-        for model in models:
-            for critic_model in models:
-                for refine_model in models:
-                    # construct multi-expression
-                    op = CriticAndRefineConvert(
-                        model=model,
-                        prompt_strategy=PromptStrategy.COT_QA_IMAGE if is_image_conversion else PromptStrategy.COT_QA,
-                        critic_model=critic_model,
-                        refine_model=refine_model,
-                        **op_kwargs,
-                    )
-                    expression = PhysicalExpression(
-                        operator=op,
-                        input_group_ids=logical_expression.input_group_ids,
-                        input_fields=logical_expression.input_fields,
-                        depends_on_field_names=logical_expression.depends_on_field_names,
-                        generated_fields=logical_expression.generated_fields,
-                        group_id=logical_expression.group_id,
-                    )
-                    physical_expressions.append(expression)
-
-        logger.debug(f"Done substituting CriticAndRefineConvertRule for {logical_expression}")
-        deduped_physical_expressions = set(physical_expressions)
-
-        return deduped_physical_expressions
+        return cls._perform_substitution(logical_expression, CriticAndRefineConvert, runtime_kwargs, variable_op_kwargs)
 
 
 class SplitConvertRule(ImplementationRule):
@@ -713,67 +645,29 @@ class SplitConvertRule(ImplementationRule):
     @classmethod
     def matches_pattern(cls, logical_expression: LogicalExpression) -> bool:
         logical_op = logical_expression.operator
-
-            [
-                field.is_image_field
-                for field_name, field in logical_expression.input_fields.items()
-                if field_name.split(".")[-1] in logical_expression.depends_on_field_names
-            ]
-        )
-        is_match = isinstance(logical_op, ConvertScan) and not is_image_conversion and logical_op.udf is None
+        is_match = isinstance(logical_op, ConvertScan) and cls._is_text_only_operation() and logical_op.udf is None
         logger.debug(f"SplitConvertRule matches_pattern: {is_match} for {logical_expression}")
         return is_match
 
     @classmethod
-    def substitute(cls, logical_expression: LogicalExpression, **
+    def substitute(cls, logical_expression: LogicalExpression, **runtime_kwargs) -> set[PhysicalExpression]:
         logger.debug(f"Substituting SplitConvertRule for {logical_expression}")
 
-
-
-
-        op_kwargs = logical_op.get_logical_op_params()
-        op_kwargs.update(
+        # create variable physical operator kwargs for each model which can implement this logical_expression
+        models = [model for model in runtime_kwargs["available_models"] if cls._model_matches_input(model, logical_expression)]
+        variable_op_kwargs = [
            {
-                "
-                "
-                "
+                "model": model,
+                "min_size_to_chunk": min_size_to_chunk,
+                "num_chunks": num_chunks,
+                "reasoning_effort": runtime_kwargs["reasoning_effort"],
            }
-
-
-
-
-        text_models = set(get_models())
-        pure_vision_models = {model for model in vision_models if model not in text_models}
-
-        physical_expressions = []
-        for model in physical_op_params["available_models"]:
-            # skip this model if this is a pure image model
-            if model in pure_vision_models or model.is_embedding_model():
-                continue
+            for model in models
+            for min_size_to_chunk in cls.min_size_to_chunk
+            for num_chunks in cls.num_chunks
+        ]
 
-
-                for num_chunks in cls.num_chunks:
-                    # construct multi-expression
-                    op = SplitConvert(
-                        model=model,
-                        num_chunks=num_chunks,
-                        min_size_to_chunk=min_size_to_chunk,
-                        **op_kwargs,
-                    )
-                    expression = PhysicalExpression(
-                        operator=op,
-                        input_group_ids=logical_expression.input_group_ids,
-                        input_fields=logical_expression.input_fields,
-                        depends_on_field_names=logical_expression.depends_on_field_names,
-                        generated_fields=logical_expression.generated_fields,
-                        group_id=logical_expression.group_id,
-                    )
-                    physical_expressions.append(expression)
-
-        logger.debug(f"Done substituting SplitConvertRule for {logical_expression}")
-        deduped_physical_expressions = set(physical_expressions)
-
-        return deduped_physical_expressions
+        return cls._perform_substitution(logical_expression, SplitConvert, runtime_kwargs, variable_op_kwargs)
 
 
 class RetrieveRule(ImplementationRule):
@@ -789,42 +683,13 @@ class RetrieveRule(ImplementationRule):
         return is_match
 
     @classmethod
-    def substitute(cls, logical_expression: LogicalExpression, **
+    def substitute(cls, logical_expression: LogicalExpression, **runtime_kwargs) -> set[PhysicalExpression]:
         logger.debug(f"Substituting RetrieveRule for {logical_expression}")
 
-
-
-
-
-        for k in ks:
-            # get initial set of parameters for physical op
-            op_kwargs = logical_op.get_logical_op_params()
-            op_kwargs.update(
-                {
-                    "verbose": physical_op_params["verbose"],
-                    "logical_op_id": logical_op.get_logical_op_id(),
-                    "logical_op_name": logical_op.logical_op_name(),
-                    "k": k,
-                }
-            )
-
-            # construct multi-expression
-            op = RetrieveOp(**op_kwargs)
-            expression = PhysicalExpression(
-                operator=op,
-                input_group_ids=logical_expression.input_group_ids,
-                input_fields=logical_expression.input_fields,
-                depends_on_field_names=logical_expression.depends_on_field_names,
-                generated_fields=logical_expression.generated_fields,
-                group_id=logical_expression.group_id,
-            )
-
-            physical_expressions.append(expression)
-
-        logger.debug(f"Done substituting RetrieveRule for {logical_expression}")
-        deduped_physical_expressions = set(physical_expressions)
-
-        return deduped_physical_expressions
+        # create variable physical operator kwargs for each model which can implement this logical_expression
+        ks = cls.k_budgets if logical_expression.operator.k == -1 else [logical_expression.operator.k]
+        variable_op_kwargs = [{"k": k} for k in ks]
+        return cls._perform_substitution(logical_expression, RetrieveOp, runtime_kwargs, variable_op_kwargs)
 
 
 class NonLLMFilterRule(ImplementationRule):
@@ -832,8 +697,8 @@ class NonLLMFilterRule(ImplementationRule):
     Substitute a logical expression for a FilteredScan with a non-llm filter physical implementation.
     """
 
-    @
-    def matches_pattern(logical_expression: LogicalExpression) -> bool:
+    @classmethod
+    def matches_pattern(cls, logical_expression: LogicalExpression) -> bool:
         is_match = (
             isinstance(logical_expression.operator, FilteredScan)
             and logical_expression.operator.filter.filter_fn is not None
@@ -841,33 +706,10 @@ class NonLLMFilterRule(ImplementationRule):
         logger.debug(f"NonLLMFilterRule matches_pattern: {is_match} for {logical_expression}")
         return is_match
 
-    @
-    def substitute(logical_expression: LogicalExpression, **
+    @classmethod
+    def substitute(cls, logical_expression: LogicalExpression, **runtime_kwargs) -> set[PhysicalExpression]:
         logger.debug(f"Substituting NonLLMFilterRule for {logical_expression}")
-
-        logical_op = logical_expression.operator
-        op_kwargs = logical_op.get_logical_op_params()
-        op_kwargs.update(
-            {
-                "verbose": physical_op_params["verbose"],
-                "logical_op_id": logical_op.get_logical_op_id(),
-                "logical_op_name": logical_op.logical_op_name(),
-            }
-        )
-        op = NonLLMFilter(**op_kwargs)
-
-        expression = PhysicalExpression(
-            operator=op,
-            input_group_ids=logical_expression.input_group_ids,
-            input_fields=logical_expression.input_fields,
-            depends_on_field_names=logical_expression.depends_on_field_names,
-            generated_fields=logical_expression.generated_fields,
-            group_id=logical_expression.group_id,
-        )
-        logger.debug(f"Done substituting NonLLMFilterRule for {logical_expression}")
-        deduped_physical_expressions = set([expression])
-
-        return deduped_physical_expressions
+        return cls._perform_substitution(logical_expression, NonLLMFilter, runtime_kwargs)
 
 
 class LLMFilterRule(ImplementationRule):
@@ -875,8 +717,8 @@ class LLMFilterRule(ImplementationRule):
     Substitute a logical expression for a FilteredScan with an llm filter physical implementation.
     """
 
-    @
-    def matches_pattern(logical_expression: LogicalExpression) -> bool:
+    @classmethod
+    def matches_pattern(cls, logical_expression: LogicalExpression) -> bool:
         is_match = (
             isinstance(logical_expression.operator, FilteredScan)
             and logical_expression.operator.filter.filter_condition is not None
@@ -884,82 +726,76 @@ class LLMFilterRule(ImplementationRule):
|
|
|
884
726
|
logger.debug(f"LLMFilterRule matches_pattern: {is_match} for {logical_expression}")
|
|
885
727
|
return is_match
|
|
886
728
|
|
|
887
|
-
@
|
|
888
|
-
def substitute(logical_expression: LogicalExpression, **
|
|
729
|
+
@classmethod
|
|
730
|
+
def substitute(cls, logical_expression: LogicalExpression, **runtime_kwargs) -> set[PhysicalExpression]:
|
|
889
731
|
logger.debug(f"Substituting LLMFilterRule for {logical_expression}")
|
|
890
732
|
|
|
891
|
-
|
|
892
|
-
|
|
893
|
-
|
|
733
|
+
# create variable physical operator kwargs for each model which can implement this logical_expression
|
|
734
|
+
models = [model for model in runtime_kwargs["available_models"] if cls._model_matches_input(model, logical_expression)]
|
|
735
|
+
# NOTE: right now we exclusively allow image or audio operations, but not both simultaneously
|
|
736
|
+
prompt_strategy, no_reasoning_prompt_strategy = None, None
|
|
737
|
+
no_reasoning = runtime_kwargs["reasoning_effort"] in [None, "minimal", "low"]
|
|
738
|
+
if cls._is_text_only_operation(logical_expression):
|
|
739
|
+
prompt_strategy = PromptStrategy.COT_BOOL
|
|
740
|
+
no_reasoning_prompt_strategy = PromptStrategy.COT_BOOL_NO_REASONING
|
|
741
|
+
elif cls._is_image_operation(logical_expression):
|
|
742
|
+
prompt_strategy = PromptStrategy.COT_BOOL_IMAGE
|
|
743
|
+
no_reasoning_prompt_strategy = PromptStrategy.COT_BOOL_IMAGE_NO_REASONING
|
|
744
|
+
elif cls._is_audio_operation(logical_expression):
|
|
745
|
+
prompt_strategy = PromptStrategy.COT_BOOL_AUDIO
|
|
746
|
+
no_reasoning_prompt_strategy = PromptStrategy.COT_BOOL_AUDIO_NO_REASONING
|
|
747
|
+
variable_op_kwargs = [
|
|
894
748
|
{
|
|
895
|
-
"
|
|
896
|
-
"
|
|
897
|
-
"
|
|
749
|
+
"model": model,
|
|
750
|
+
"prompt_strategy": no_reasoning_prompt_strategy if model.is_reasoning_model() and no_reasoning else prompt_strategy,
|
|
751
|
+
"reasoning_effort": runtime_kwargs["reasoning_effort"]
|
|
898
752
|
}
|
|
899
|
-
|
|
753
|
+
for model in models
|
|
754
|
+
]
|
|
900
755
|
|
|
901
|
-
|
|
902
|
-
vision_models = set(get_vision_models())
|
|
903
|
-
text_models = set(get_models())
|
|
904
|
-
pure_text_models = {model for model in text_models if model not in vision_models}
|
|
905
|
-
pure_vision_models = {model for model in vision_models if model not in text_models}
|
|
906
|
-
|
|
907
|
-
# compute attributes about this filter operation
|
|
908
|
-
is_image_filter = any(
|
|
909
|
-
[
|
|
910
|
-
field.is_image_field
|
|
911
|
-
for field_name, field in logical_expression.input_fields.items()
|
|
912
|
-
if field_name.split(".")[-1] in logical_expression.depends_on_field_names
|
|
913
|
-
]
|
|
914
|
-
)
|
|
915
|
-
num_image_fields = sum(
|
|
916
|
-
[
|
|
917
|
-
field.is_image_field
|
|
918
|
-
for field_name, field in logical_expression.input_fields.items()
|
|
919
|
-
if field_name.split(".")[-1] in logical_expression.depends_on_field_names
|
|
920
|
-
]
|
|
921
|
-
)
|
|
922
|
-
list_image_field = any(
|
|
923
|
-
[
|
|
924
|
-
field.is_image_field and hasattr(field, "element_type")
|
|
925
|
-
for field_name, field in logical_expression.input_fields.items()
|
|
926
|
-
if field_name.split(".")[-1] in logical_expression.depends_on_field_names
|
|
927
|
-
]
|
|
928
|
-
)
|
|
756
|
+
return cls._perform_substitution(logical_expression, LLMFilter, runtime_kwargs, variable_op_kwargs)
|
|
929
757
|
|
|
930
|
-
physical_expressions = []
|
|
931
|
-
for model in physical_op_params["available_models"]:
|
|
932
|
-
# skip this model if:
|
|
933
|
-
# 1. this is a pure vision model and we're not doing an image filter, or
|
|
934
|
-
# 2. this is a pure text model and we're doing an image filter, or
|
|
935
|
-
# 3. this is a vision model hosted by Together (i.e. LLAMA3 vision) and there is more than one image field
|
|
936
|
-
first_criteria = model in pure_vision_models and not is_image_filter
|
|
937
|
-
second_criteria = model in pure_text_models and is_image_filter
|
|
938
|
-
third_criteria = model.is_llama_model() and model.is_vision_model() and (num_image_fields > 1 or list_image_field)
|
|
939
|
-
fourth_criteria = model.is_embedding_model()
|
|
940
|
-
if first_criteria or second_criteria or third_criteria or fourth_criteria:
|
|
941
|
-
continue
|
|
942
758
|
|
|
943
|
-
|
|
944
|
-
|
|
945
|
-
|
|
946
|
-
|
|
947
|
-
**op_kwargs,
|
|
948
|
-
)
|
|
949
|
-
expression = PhysicalExpression(
|
|
950
|
-
operator=op,
|
|
951
|
-
input_group_ids=logical_expression.input_group_ids,
|
|
952
|
-
input_fields=logical_expression.input_fields,
|
|
953
|
-
depends_on_field_names=logical_expression.depends_on_field_names,
|
|
954
|
-
generated_fields=logical_expression.generated_fields,
|
|
955
|
-
group_id=logical_expression.group_id,
|
|
956
|
-
)
|
|
957
|
-
physical_expressions.append(expression)
|
|
759
|
+
class LLMJoinRule(ImplementationRule):
|
|
760
|
+
"""
|
|
761
|
+
Substitute a logical expression for a JoinOp with an (LLM) NestedLoopsJoin physical implementation.
|
|
762
|
+
"""
|
|
958
763
|
|
|
959
|
-
|
|
960
|
-
|
|
764
|
+
@classmethod
|
|
765
|
+
def matches_pattern(cls, logical_expression: LogicalExpression) -> bool:
|
|
766
|
+
is_match = isinstance(logical_expression.operator, JoinOp)
|
|
767
|
+
logger.debug(f"LLMJoinRule matches_pattern: {is_match} for {logical_expression}")
|
|
768
|
+
return is_match
|
|
961
769
|
|
|
962
|
-
|
|
770
|
+
@classmethod
|
|
771
|
+
def substitute(cls, logical_expression: LogicalExpression, **runtime_kwargs) -> set[PhysicalExpression]:
|
|
772
|
+
logger.debug(f"Substituting LLMJoinRule for {logical_expression}")
|
|
773
|
+
|
|
774
|
+
# create variable physical operator kwargs for each model which can implement this logical_expression
|
|
775
|
+
models = [model for model in runtime_kwargs["available_models"] if cls._model_matches_input(model, logical_expression)]
|
|
776
|
+
# NOTE: right now we exclusively allow image or audio operations, but not both simultaneously
|
|
777
|
+
prompt_strategy, no_reasoning_prompt_strategy = None, None
|
|
778
|
+
no_reasoning = runtime_kwargs["reasoning_effort"] in [None, "minimal", "low"]
|
|
779
|
+
if cls._is_text_only_operation(logical_expression):
|
|
780
|
+
prompt_strategy = PromptStrategy.COT_JOIN
|
|
781
|
+
no_reasoning_prompt_strategy = PromptStrategy.COT_JOIN_NO_REASONING
|
|
782
|
+
elif cls._is_image_operation(logical_expression):
|
|
783
|
+
prompt_strategy = PromptStrategy.COT_JOIN_IMAGE
|
|
784
|
+
no_reasoning_prompt_strategy = PromptStrategy.COT_JOIN_IMAGE_NO_REASONING
|
|
785
|
+
elif cls._is_audio_operation(logical_expression):
|
|
786
|
+
prompt_strategy = PromptStrategy.COT_JOIN_AUDIO
|
|
787
|
+
no_reasoning_prompt_strategy = PromptStrategy.COT_JOIN_AUDIO_NO_REASONING
|
|
788
|
+
variable_op_kwargs = [
|
|
789
|
+
{
|
|
790
|
+
"model": model,
|
|
791
|
+
"prompt_strategy": no_reasoning_prompt_strategy if model.is_reasoning_model() and no_reasoning else prompt_strategy,
|
|
792
|
+
"join_parallelism": runtime_kwargs["join_parallelism"],
|
|
793
|
+
"reasoning_effort": runtime_kwargs["reasoning_effort"],
|
|
794
|
+
}
|
|
795
|
+
for model in models
|
|
796
|
+
]
|
|
797
|
+
|
|
798
|
+
return cls._perform_substitution(logical_expression, NestedLoopsJoin, runtime_kwargs, variable_op_kwargs)
|
|
963
799
|
|
|
964
800
|
|
|
965
801
|
class AggregateRule(ImplementationRule):
|
|
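`LLMFilterRule` and the new `LLMJoinRule` share the same two-step pattern: keep only the models whose modality matches the input (via `_model_matches_input`), then choose a chain-of-thought prompt strategy, falling back to the `*_NO_REASONING` variant when a reasoning model runs at `None`, `"minimal"`, or `"low"` effort. The self-contained sketch below reproduces that gating with stand-in enum members; the real `PromptStrategy` enum and model helpers live elsewhere in palimpzest.

```python
from enum import Enum, auto
from typing import Optional


class PromptStrategy(Enum):  # stand-in members mirroring the names used in the diff
    COT_BOOL = auto()
    COT_BOOL_NO_REASONING = auto()
    COT_BOOL_IMAGE = auto()
    COT_BOOL_IMAGE_NO_REASONING = auto()


def choose_strategy(is_image: bool, is_reasoning_model: bool, reasoning_effort: Optional[str]) -> PromptStrategy:
    # base strategy follows the input modality (text-only vs. image; audio is analogous)
    strategy = PromptStrategy.COT_BOOL_IMAGE if is_image else PromptStrategy.COT_BOOL
    no_reasoning_strategy = (
        PromptStrategy.COT_BOOL_IMAGE_NO_REASONING if is_image else PromptStrategy.COT_BOOL_NO_REASONING
    )
    # reasoning models running at low effort fall back to the *_NO_REASONING prompt variant
    no_reasoning = reasoning_effort in [None, "minimal", "low"]
    return no_reasoning_strategy if is_reasoning_model and no_reasoning else strategy


print(choose_strategy(is_image=False, is_reasoning_model=True, reasoning_effort="low"))   # COT_BOOL_NO_REASONING
print(choose_strategy(is_image=True, is_reasoning_model=False, reasoning_effort="high"))  # COT_BOOL_IMAGE
```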
@@ -967,47 +803,71 @@ class AggregateRule(ImplementationRule):
     Substitute the logical expression for an aggregate with its physical counterpart.
     """
 
-    @staticmethod
-    def matches_pattern(logical_expression: LogicalExpression) -> bool:
+    @classmethod
+    def matches_pattern(cls, logical_expression: LogicalExpression) -> bool:
         is_match = isinstance(logical_expression.operator, Aggregate)
         logger.debug(f"AggregateRule matches_pattern: {is_match} for {logical_expression}")
         return is_match
 
-    @staticmethod
-    def substitute(logical_expression: LogicalExpression, **physical_op_params) -> set[PhysicalExpression]:
+    @classmethod
+    def substitute(cls, logical_expression: LogicalExpression, **runtime_kwargs) -> set[PhysicalExpression]:
         logger.debug(f"Substituting AggregateRule for {logical_expression}")
 
-
-
-
-
-
-
-                "logical_op_name": logical_op.logical_op_name(),
-            }
-        )
-
-        op = None
-        if logical_op.agg_func == AggFunc.COUNT:
-            op = CountAggregateOp(**op_kwargs)
-        elif logical_op.agg_func == AggFunc.AVERAGE:
-            op = AverageAggregateOp(**op_kwargs)
+        # get the physical op class based on the aggregation function
+        physical_op_class = None
+        if logical_expression.operator.agg_func == AggFunc.COUNT:
+            physical_op_class = CountAggregateOp
+        elif logical_expression.operator.agg_func == AggFunc.AVERAGE:
+            physical_op_class = AverageAggregateOp
         else:
-            raise Exception(f"Cannot support aggregate function: {logical_op.agg_func}")
-
-        expression = PhysicalExpression(
-            operator=op,
-            input_group_ids=logical_expression.input_group_ids,
-            input_fields=logical_expression.input_fields,
-            depends_on_field_names=logical_expression.depends_on_field_names,
-            generated_fields=logical_expression.generated_fields,
-            group_id=logical_expression.group_id,
-        )
+            raise Exception(f"Cannot support aggregate function: {logical_expression.operator.agg_func}")
 
-
-
+        # perform the substitution
+        return cls._perform_substitution(logical_expression, physical_op_class, runtime_kwargs)
+
+
+class AddContextsBeforeComputeRule(ImplementationRule):
+    """
+    Searches the ContextManager for additional contexts which may be useful for the given computation.
+
+    TODO: track cost of generating search query
+    """
+    k = 1
+    SEARCH_GENERATOR_PROMPT = CONTEXT_SEARCH_PROMPT
+
+    @classmethod
+    def matches_pattern(cls, logical_expression: LogicalExpression) -> bool:
+        is_match = isinstance(logical_expression.operator, ComputeOperator)
+        logger.debug(f"AddContextsBeforeComputeRule matches_pattern: {is_match} for {logical_expression}")
+        return is_match
 
-
+    @classmethod
+    def substitute(cls, logical_expression: LogicalExpression, **runtime_kwargs) -> set[PhysicalExpression]:
+        logger.debug(f"Substituting AddContextsBeforeComputeRule for {logical_expression}")
+
+        # load an LLM to generate a short search query
+        model = None
+        if os.getenv("OPENAI_API_KEY"):
+            model = "openai/gpt-4o-mini"
+        elif os.getenv("ANTHROPIC_API_KEY"):
+            model = "anthropic/claude-3-5-sonnet-20241022"
+        elif os.getenv("GEMINI_API_KEY"):
+            model = "vertex_ai/gemini-2.0-flash"
+        elif os.getenv("TOGETHER_API_KEY"):
+            model = "together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo"
+
+        # importing litellm here because importing above causes deprecation warning
+        import litellm
+
+        # retrieve any additional context which may be useful
+        cm = ContextManager()
+        response = litellm.completion(
+            model=model,
+            messages=[{"role": "user", "content": cls.SEARCH_GENERATOR_PROMPT.format(instruction=logical_expression.operator.instruction)}]
+        )
+        query = response.choices[0].message.content
+        variable_op_kwargs = {"additional_contexts": cm.search_context(query, k=cls.k, where={"materialized": True})}
+        return cls._perform_substitution(logical_expression, SmolAgentsCompute, runtime_kwargs, variable_op_kwargs)
 
 
 class BasicSubstitutionRule(ImplementationRule):
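The new `AddContextsBeforeComputeRule` first asks an LLM (through litellm) to turn the compute instruction into a short search query, then hands that query to the `ContextManager`. The standalone sketch below isolates just the query-generation step: the prompt string is a placeholder (the real rule formats `CONTEXT_SEARCH_PROMPT`), the provider-key fallback order mirrors the diff, and the explicit error when no key is configured is an addition for the sketch.

```python
import os

import litellm

# placeholder prompt; the real rule uses palimpzest's CONTEXT_SEARCH_PROMPT
PROMPT = "Write a short search query for the following instruction:\n{instruction}"


def generate_search_query(instruction: str) -> str:
    # choose a model based on which provider key is configured (same order as the diff)
    model = None
    if os.getenv("OPENAI_API_KEY"):
        model = "openai/gpt-4o-mini"
    elif os.getenv("ANTHROPIC_API_KEY"):
        model = "anthropic/claude-3-5-sonnet-20241022"
    elif os.getenv("GEMINI_API_KEY"):
        model = "vertex_ai/gemini-2.0-flash"
    elif os.getenv("TOGETHER_API_KEY"):
        model = "together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo"
    if model is None:
        raise RuntimeError("no supported provider API key is set")

    response = litellm.completion(
        model=model,
        messages=[{"role": "user", "content": PROMPT.format(instruction=instruction)}],
    )
    return response.choices[0].message.content


# query = generate_search_query("compute the average sentiment of each review")
```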
@@ -1018,11 +878,13 @@ class BasicSubstitutionRule(ImplementationRule):
 
     LOGICAL_OP_CLASS_TO_PHYSICAL_OP_CLASS_MAP = {
         BaseScan: MarshalAndScanDataOp,
-
+        # ComputeOperator: SmolAgentsCompute,
+        SearchOperator: SmolAgentsSearch, # SmolAgentsManagedSearch, # SmolAgentsCustomManagedSearch
+        ContextScan: ContextScanOp,
+        Distinct: DistinctOp,
         LimitScan: LimitScanOp,
         Project: ProjectOp,
         GroupByAggregate: ApplyGroupByOp,
-        MapScan: MapOp,
     }
 
     @classmethod
@@ -1033,31 +895,7 @@ class BasicSubstitutionRule(ImplementationRule):
         return is_match
 
     @classmethod
-    def substitute(cls, logical_expression: LogicalExpression, **physical_op_params) -> set[PhysicalExpression]:
+    def substitute(cls, logical_expression: LogicalExpression, **runtime_kwargs) -> set[PhysicalExpression]:
         logger.debug(f"Substituting BasicSubstitutionRule for {logical_expression}")
-
-
-        op_kwargs = logical_op.get_logical_op_params()
-        op_kwargs.update(
-            {
-                "verbose": physical_op_params["verbose"],
-                "logical_op_id": logical_op.get_logical_op_id(),
-                "logical_op_name": logical_op.logical_op_name(),
-            }
-        )
-        physical_op_class = cls.LOGICAL_OP_CLASS_TO_PHYSICAL_OP_CLASS_MAP[logical_op.__class__]
-        op = physical_op_class(**op_kwargs)
-
-        expression = PhysicalExpression(
-            operator=op,
-            input_group_ids=logical_expression.input_group_ids,
-            input_fields=logical_expression.input_fields,
-            depends_on_field_names=logical_expression.depends_on_field_names,
-            generated_fields=logical_expression.generated_fields,
-            group_id=logical_expression.group_id,
-        )
-
-        logger.debug(f"Done substituting BasicSubstitutionRule for {logical_expression}")
-        deduped_physical_expressions = set([expression])
-
-        return deduped_physical_expressions
+        physical_op_class = cls.LOGICAL_OP_CLASS_TO_PHYSICAL_OP_CLASS_MAP[logical_expression.operator.__class__]
+        return cls._perform_substitution(logical_expression, physical_op_class, runtime_kwargs)
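`BasicSubstitutionRule` remains a pure table lookup: the logical operator's class indexes `LOGICAL_OP_CLASS_TO_PHYSICAL_OP_CLASS_MAP` and the shared helper builds the physical expression. The toy example below shows the same dict-keyed-by-class dispatch with stand-in operator classes.

```python
class BaseScan: ...              # stand-ins for the palimpzest logical operators
class LimitScan: ...

class MarshalAndScanDataOp: ...  # stand-ins for the corresponding physical operators
class LimitScanOp: ...

LOGICAL_TO_PHYSICAL = {
    BaseScan: MarshalAndScanDataOp,
    LimitScan: LimitScanOp,
}

def substitute(logical_operator):
    # matches_pattern guarantees the operator's class is a key of the map
    physical_op_class = LOGICAL_TO_PHYSICAL[type(logical_operator)]
    return physical_op_class()

print(type(substitute(BaseScan())).__name__)   # MarshalAndScanDataOp
print(type(substitute(LimitScan())).__name__)  # LimitScanOp
```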