palimpzest 0.9.0__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- palimpzest/constants.py +1 -0
- palimpzest/core/data/dataset.py +33 -5
- palimpzest/core/elements/groupbysig.py +5 -1
- palimpzest/core/elements/records.py +16 -7
- palimpzest/core/lib/schemas.py +20 -3
- palimpzest/core/models.py +4 -4
- palimpzest/query/execution/all_sample_execution_strategy.py +1 -1
- palimpzest/query/execution/execution_strategy.py +8 -8
- palimpzest/query/execution/mab_execution_strategy.py +30 -11
- palimpzest/query/execution/parallel_execution_strategy.py +31 -7
- palimpzest/query/execution/single_threaded_execution_strategy.py +23 -6
- palimpzest/query/operators/__init__.py +7 -6
- palimpzest/query/operators/aggregate.py +110 -5
- palimpzest/query/operators/convert.py +1 -1
- palimpzest/query/operators/join.py +279 -23
- palimpzest/query/operators/logical.py +20 -8
- palimpzest/query/operators/mixture_of_agents.py +3 -1
- palimpzest/query/operators/physical.py +5 -2
- palimpzest/query/operators/{retrieve.py → topk.py} +10 -10
- palimpzest/query/optimizer/__init__.py +7 -3
- palimpzest/query/optimizer/cost_model.py +5 -5
- palimpzest/query/optimizer/optimizer.py +3 -2
- palimpzest/query/optimizer/plan.py +2 -3
- palimpzest/query/optimizer/rules.py +31 -11
- palimpzest/query/optimizer/tasks.py +4 -4
- palimpzest/utils/progress.py +19 -17
- palimpzest/validator/validator.py +7 -7
- {palimpzest-0.9.0.dist-info → palimpzest-1.0.0.dist-info}/METADATA +26 -66
- {palimpzest-0.9.0.dist-info → palimpzest-1.0.0.dist-info}/RECORD +32 -32
- {palimpzest-0.9.0.dist-info → palimpzest-1.0.0.dist-info}/WHEEL +0 -0
- {palimpzest-0.9.0.dist-info → palimpzest-1.0.0.dist-info}/licenses/LICENSE +0 -0
- {palimpzest-0.9.0.dist-info → palimpzest-1.0.0.dist-info}/top_level.txt +0 -0
|
@@ -9,7 +9,7 @@ from palimpzest.constants import AggFunc, Cardinality
|
|
|
9
9
|
from palimpzest.core.data import context, dataset
|
|
10
10
|
from palimpzest.core.elements.filters import Filter
|
|
11
11
|
from palimpzest.core.elements.groupbysig import GroupBySig
|
|
12
|
-
from palimpzest.core.lib.schemas import Average, Count, Max, Min
|
|
12
|
+
from palimpzest.core.lib.schemas import Average, Count, Max, Min, Sum
|
|
13
13
|
from palimpzest.utils.hash_helpers import hash_for_id
|
|
14
14
|
|
|
15
15
|
|
|
@@ -25,7 +25,7 @@ class LogicalOperator:
|
|
|
25
25
|
- LimitScan (scans up to N records from a Set)
|
|
26
26
|
- GroupByAggregate (applies a group by on the Set)
|
|
27
27
|
- Aggregate (applies an aggregation on the Set)
|
|
28
|
-
-
|
|
28
|
+
- TopKScan (fetches documents from a provided input for a given query)
|
|
29
29
|
- Map (applies a function to each record in the Set without adding any new columns)
|
|
30
30
|
- ComputeOperator (executes a computation described in natural language)
|
|
31
31
|
- SearchOperator (executes a search query on the input Context)
|
|
@@ -160,6 +160,8 @@ class Aggregate(LogicalOperator):
|
|
|
160
160
|
kwargs["output_schema"] = Count
|
|
161
161
|
elif agg_func == AggFunc.AVERAGE:
|
|
162
162
|
kwargs["output_schema"] = Average
|
|
163
|
+
elif agg_func == AggFunc.SUM:
|
|
164
|
+
kwargs["output_schema"] = Sum
|
|
163
165
|
elif agg_func == AggFunc.MIN:
|
|
164
166
|
kwargs["output_schema"] = Min
|
|
165
167
|
elif agg_func == AggFunc.MAX:
|
|
@@ -411,17 +413,25 @@ class GroupByAggregate(LogicalOperator):
|
|
|
411
413
|
|
|
412
414
|
|
|
413
415
|
class JoinOp(LogicalOperator):
|
|
414
|
-
def __init__(self, condition: str, desc: str | None = None, *args, **kwargs):
|
|
416
|
+
def __init__(self, condition: str, on: list[str] | None = None, how: str = "inner", desc: str | None = None, *args, **kwargs):
|
|
415
417
|
super().__init__(*args, **kwargs)
|
|
416
418
|
self.condition = condition
|
|
419
|
+
self.on = on
|
|
420
|
+
self.how = how
|
|
417
421
|
self.desc = desc
|
|
418
422
|
|
|
419
423
|
def __str__(self):
|
|
420
|
-
return f"Join(condition={self.condition})"
|
|
424
|
+
return f"Join(condition={self.condition})" if self.on is None else f"Join(on={self.on}, how={self.how})"
|
|
421
425
|
|
|
422
426
|
def get_logical_id_params(self) -> dict:
|
|
423
427
|
logical_id_params = super().get_logical_id_params()
|
|
424
|
-
logical_id_params = {
|
|
428
|
+
logical_id_params = {
|
|
429
|
+
"condition": self.condition,
|
|
430
|
+
"on": self.on,
|
|
431
|
+
"how": self.how,
|
|
432
|
+
"desc": self.desc,
|
|
433
|
+
**logical_id_params,
|
|
434
|
+
}
|
|
425
435
|
|
|
426
436
|
return logical_id_params
|
|
427
437
|
|
|
@@ -429,6 +439,8 @@ class JoinOp(LogicalOperator):
|
|
|
429
439
|
logical_op_params = super().get_logical_op_params()
|
|
430
440
|
logical_op_params = {
|
|
431
441
|
"condition": self.condition,
|
|
442
|
+
"on": self.on,
|
|
443
|
+
"how": self.how,
|
|
432
444
|
"desc": self.desc,
|
|
433
445
|
**logical_op_params,
|
|
434
446
|
}
|
|
@@ -484,8 +496,8 @@ class Project(LogicalOperator):
|
|
|
484
496
|
return logical_op_params
|
|
485
497
|
|
|
486
498
|
|
|
487
|
-
class
|
|
488
|
-
"""A
|
|
499
|
+
class TopKScan(LogicalOperator):
|
|
500
|
+
"""A TopKScan is a logical operator that represents a scan of a particular input Dataset, with a top-k operation applied."""
|
|
489
501
|
|
|
490
502
|
def __init__(
|
|
491
503
|
self,
|
|
@@ -505,7 +517,7 @@ class RetrieveScan(LogicalOperator):
|
|
|
505
517
|
self.k = k
|
|
506
518
|
|
|
507
519
|
def __str__(self):
|
|
508
|
-
return f"
|
|
520
|
+
return f"TopKScan({self.input_schema} -> {str(self.output_schema)})"
|
|
509
521
|
|
|
510
522
|
def get_logical_id_params(self) -> dict:
|
|
511
523
|
# NOTE: if we allow optimization over index, then we will need to include it in the id params
|
|
@@ -75,8 +75,9 @@ class MixtureOfAgentsConvert(LLMConvert):
|
|
|
75
75
|
In practice, this naive quality estimate will be overwritten by the CostModel's estimate
|
|
76
76
|
once it executes a few instances of the operator.
|
|
77
77
|
"""
|
|
78
|
-
# temporarily set self.model so that super().naive_cost_estimates(...) can compute an estimate
|
|
78
|
+
# temporarily set self.model and self.prompt_strategy so that super().naive_cost_estimates(...) can compute an estimate
|
|
79
79
|
self.model = self.proposer_models[0]
|
|
80
|
+
self.prompt_strategy = PromptStrategy.MAP_MOA_PROPOSER
|
|
80
81
|
|
|
81
82
|
# get naive cost estimates for single LLM call and scale it by number of LLMs used in MoA
|
|
82
83
|
naive_op_cost_estimates = super().naive_cost_estimates(source_op_cost_estimates)
|
|
@@ -98,6 +99,7 @@ class MixtureOfAgentsConvert(LLMConvert):
|
|
|
98
99
|
|
|
99
100
|
# reset self.model to be None
|
|
100
101
|
self.model = None
|
|
102
|
+
self.prompt_strategy = None
|
|
101
103
|
|
|
102
104
|
return naive_op_cost_estimates
|
|
103
105
|
|
|
@@ -42,10 +42,13 @@ class PhysicalOperator:
|
|
|
42
42
|
self.op_id = None
|
|
43
43
|
|
|
44
44
|
# compute the input modalities (if any) for this physical operator
|
|
45
|
+
depends_on_short_field_names = [field.split(".")[-1] for field in self.depends_on] if self.depends_on is not None else None
|
|
45
46
|
self.input_modalities = None
|
|
46
47
|
if self.input_schema is not None:
|
|
47
48
|
self.input_modalities = set()
|
|
48
|
-
for field in self.input_schema.model_fields.
|
|
49
|
+
for field_name, field in self.input_schema.model_fields.items():
|
|
50
|
+
if self.depends_on is not None and field_name not in depends_on_short_field_names:
|
|
51
|
+
continue
|
|
49
52
|
field_type = field.annotation
|
|
50
53
|
if field_type in IMAGE_FIELD_TYPES:
|
|
51
54
|
self.input_modalities.add(Modality.IMAGE)
|
|
@@ -191,7 +194,7 @@ class PhysicalOperator:
|
|
|
191
194
|
in the candidate. This is important for operators with retry logic, where we may only need to
|
|
192
195
|
recompute a subset of self.generated_fields.
|
|
193
196
|
|
|
194
|
-
Right now this is only used by convert and
|
|
197
|
+
Right now this is only used by convert and top-k operators.
|
|
195
198
|
"""
|
|
196
199
|
fields_to_generate = [
|
|
197
200
|
field_name
|
|
@@ -17,7 +17,7 @@ from palimpzest.core.models import GenerationStats, OperatorCostEstimates, Recor
|
|
|
17
17
|
from palimpzest.query.operators.physical import PhysicalOperator
|
|
18
18
|
|
|
19
19
|
|
|
20
|
-
class
|
|
20
|
+
class TopKOp(PhysicalOperator):
|
|
21
21
|
def __init__(
|
|
22
22
|
self,
|
|
23
23
|
index: Collection,
|
|
@@ -29,7 +29,7 @@ class RetrieveOp(PhysicalOperator):
|
|
|
29
29
|
**kwargs,
|
|
30
30
|
) -> None:
|
|
31
31
|
"""
|
|
32
|
-
Initialize the
|
|
32
|
+
Initialize the TopKOp object.
|
|
33
33
|
|
|
34
34
|
Args:
|
|
35
35
|
index (Collection): The PZ index to use for retrieval.
|
|
@@ -59,7 +59,7 @@ class RetrieveOp(PhysicalOperator):
|
|
|
59
59
|
|
|
60
60
|
def __str__(self):
|
|
61
61
|
op = super().__str__()
|
|
62
|
-
op += f"
|
|
62
|
+
op += f" Top-K: {self.index.__class__.__name__} with k={self.k}\n"
|
|
63
63
|
return op
|
|
64
64
|
|
|
65
65
|
def get_id_params(self):
|
|
@@ -89,8 +89,8 @@ class RetrieveOp(PhysicalOperator):
|
|
|
89
89
|
|
|
90
90
|
def naive_cost_estimates(self, source_op_cost_estimates: OperatorCostEstimates) -> OperatorCostEstimates:
|
|
91
91
|
"""
|
|
92
|
-
Compute naive cost estimates for the
|
|
93
|
-
that the
|
|
92
|
+
Compute naive cost estimates for the Top-K operation. These estimates assume
|
|
93
|
+
that the Top-K (1) has negligible cost and (2) has perfect quality.
|
|
94
94
|
"""
|
|
95
95
|
return OperatorCostEstimates(
|
|
96
96
|
cardinality=source_op_cost_estimates.cardinality,
|
|
@@ -101,7 +101,7 @@ class RetrieveOp(PhysicalOperator):
|
|
|
101
101
|
|
|
102
102
|
def default_search_func(self, index: Collection, query: list[str] | list[list[float]], k: int) -> list[str] | list[list[str]]:
|
|
103
103
|
"""
|
|
104
|
-
Default search function for the
|
|
104
|
+
Default search function for the Top-K operation. This function uses the index to
|
|
105
105
|
retrieve the top-k results for the given query. The query will be a (possibly singleton)
|
|
106
106
|
list of strings or a list of lists of floats (i.e., embeddings). The function will return
|
|
107
107
|
the top-k results per-query in (descending) sorted order. If the input is a singleton list,
|
|
@@ -111,7 +111,7 @@ class RetrieveOp(PhysicalOperator):
|
|
|
111
111
|
Args:
|
|
112
112
|
index (PZIndex): The index to use for retrieval.
|
|
113
113
|
query (list[str] | list[list[float]]): The query (or queries) to search for.
|
|
114
|
-
k (int): The maximum number of results the
|
|
114
|
+
k (int): The maximum number of results the top-k operator will return.
|
|
115
115
|
|
|
116
116
|
Returns:
|
|
117
117
|
list[str] | list[list[str]]: The top results in (descending) sorted order per query.
|
|
@@ -260,10 +260,10 @@ class RetrieveOp(PhysicalOperator):
|
|
|
260
260
|
top_results = self.search_func(self.index, inputs, self.k)
|
|
261
261
|
|
|
262
262
|
except Exception:
|
|
263
|
-
top_results = ["error-in-
|
|
264
|
-
os.makedirs("
|
|
263
|
+
top_results = ["error-in-topk"]
|
|
264
|
+
os.makedirs("topk-errors", exist_ok=True)
|
|
265
265
|
ts = time.time()
|
|
266
|
-
with open(f"
|
|
266
|
+
with open(f"topk-errors/error-{ts}.txt", "w") as f:
|
|
267
267
|
f.write(str(query))
|
|
268
268
|
|
|
269
269
|
# TODO: the user is always right! let's drop this post-processing in the future
|
|
@@ -39,10 +39,10 @@ from palimpzest.query.optimizer.rules import (
|
|
|
39
39
|
RAGRule as _RAGRule,
|
|
40
40
|
)
|
|
41
41
|
from palimpzest.query.optimizer.rules import (
|
|
42
|
-
|
|
42
|
+
RelationalJoinRule as _RelationalJoinRule,
|
|
43
43
|
)
|
|
44
44
|
from palimpzest.query.optimizer.rules import (
|
|
45
|
-
|
|
45
|
+
ReorderConverts as _ReorderConverts,
|
|
46
46
|
)
|
|
47
47
|
from palimpzest.query.optimizer.rules import (
|
|
48
48
|
Rule as _Rule,
|
|
@@ -53,6 +53,9 @@ from palimpzest.query.optimizer.rules import (
|
|
|
53
53
|
from palimpzest.query.optimizer.rules import (
|
|
54
54
|
SplitRule as _SplitRule,
|
|
55
55
|
)
|
|
56
|
+
from palimpzest.query.optimizer.rules import (
|
|
57
|
+
TopKRule as _TopKRule,
|
|
58
|
+
)
|
|
56
59
|
from palimpzest.query.optimizer.rules import (
|
|
57
60
|
TransformationRule as _TransformationRule,
|
|
58
61
|
)
|
|
@@ -72,8 +75,9 @@ ALL_RULES = [
|
|
|
72
75
|
_NonLLMFilterRule,
|
|
73
76
|
_PushDownFilter,
|
|
74
77
|
_RAGRule,
|
|
78
|
+
_RelationalJoinRule,
|
|
75
79
|
_ReorderConverts,
|
|
76
|
-
|
|
80
|
+
_TopKRule,
|
|
77
81
|
_Rule,
|
|
78
82
|
_SemanticAggregateRule,
|
|
79
83
|
_SplitRule,
|
|
@@ -131,17 +131,17 @@ class SampleBasedCostModel:
|
|
|
131
131
|
# compute selectivity
|
|
132
132
|
selectivity = physical_op_df.passed_operator.sum() / num_source_records
|
|
133
133
|
|
|
134
|
+
# compute quality; if all qualities are None then this will be NaN
|
|
135
|
+
quality = physical_op_df.quality.mean()
|
|
136
|
+
|
|
137
|
+
# set operator stats for this physical operator
|
|
134
138
|
operator_to_stats[unique_logical_op_id][full_op_id] = {
|
|
135
139
|
"cost": physical_op_df.cost_per_record.mean(),
|
|
136
140
|
"time": physical_op_df.time_per_record.mean(),
|
|
137
|
-
"quality":
|
|
141
|
+
"quality": 1.0 if pd.isna(quality) else quality,
|
|
138
142
|
"selectivity": selectivity,
|
|
139
143
|
}
|
|
140
144
|
|
|
141
|
-
# if this is an experiment, log the dataframe and operator_to_stats dictionary
|
|
142
|
-
if self.exp_name is not None:
|
|
143
|
-
operator_stats_df.to_csv(f"opt-profiling-data/{self.exp_name}-operator-stats.csv", index=False)
|
|
144
|
-
|
|
145
145
|
logger.debug(f"Done computing operator statistics for {len(operator_to_stats)} operators!")
|
|
146
146
|
return operator_to_stats
|
|
147
147
|
|
|
@@ -284,10 +284,11 @@ class Optimizer:
|
|
|
284
284
|
all_properties["filters"] = set([op_filter_str])
|
|
285
285
|
|
|
286
286
|
elif isinstance(op, JoinOp):
|
|
287
|
+
unique_join_str = str(sorted(op.on)) if op.condition is None else op.condition
|
|
287
288
|
if "joins" in all_properties:
|
|
288
|
-
all_properties["joins"].add(
|
|
289
|
+
all_properties["joins"].add(unique_join_str)
|
|
289
290
|
else:
|
|
290
|
-
all_properties["joins"] = set([
|
|
291
|
+
all_properties["joins"] = set([unique_join_str])
|
|
291
292
|
|
|
292
293
|
elif isinstance(op, LimitScan):
|
|
293
294
|
op_limit_str = op.get_logical_op_id()
|
|
@@ -203,9 +203,8 @@ class PhysicalPlan(Plan):
|
|
|
203
203
|
# return the current index and the upstream unique full_op_ids for this operator
|
|
204
204
|
return current_idx, self.operator.get_full_op_id(), upstream_map[this_unique_full_op_id]
|
|
205
205
|
|
|
206
|
-
def get_upstream_unique_full_op_ids(self,
|
|
207
|
-
"""Return the list of unique full_op_ids for the upstream operators of
|
|
208
|
-
unique_full_op_id = f"{topo_idx}-{operator.get_full_op_id()}"
|
|
206
|
+
def get_upstream_unique_full_op_ids(self, unique_full_op_id: str) -> list[str]:
|
|
207
|
+
"""Return the list of unique full_op_ids for the upstream operators of the operator specified by `unique_full_op_id`."""
|
|
209
208
|
return self.unique_full_op_id_to_upstream_full_op_ids[unique_full_op_id]
|
|
210
209
|
|
|
211
210
|
def _compute_source_unique_full_op_ids_map(self, source_map: dict[str, list[str]], current_idx: int | None = None) -> tuple[int, str]:
|
|
@@ -19,13 +19,14 @@ from palimpzest.query.operators.aggregate import (
|
|
|
19
19
|
MaxAggregateOp,
|
|
20
20
|
MinAggregateOp,
|
|
21
21
|
SemanticAggregate,
|
|
22
|
+
SumAggregateOp,
|
|
22
23
|
)
|
|
23
24
|
from palimpzest.query.operators.compute import SmolAgentsCompute
|
|
24
25
|
from palimpzest.query.operators.convert import LLMConvertBonded, NonLLMConvert
|
|
25
26
|
from palimpzest.query.operators.critique_and_refine import CritiqueAndRefineConvert, CritiqueAndRefineFilter
|
|
26
27
|
from palimpzest.query.operators.distinct import DistinctOp
|
|
27
28
|
from palimpzest.query.operators.filter import LLMFilter, NonLLMFilter
|
|
28
|
-
from palimpzest.query.operators.join import EmbeddingJoin, NestedLoopsJoin
|
|
29
|
+
from palimpzest.query.operators.join import EmbeddingJoin, NestedLoopsJoin, RelationalJoin
|
|
29
30
|
from palimpzest.query.operators.limit import LimitScanOp
|
|
30
31
|
from palimpzest.query.operators.logical import (
|
|
31
32
|
Aggregate,
|
|
@@ -39,19 +40,19 @@ from palimpzest.query.operators.logical import (
|
|
|
39
40
|
JoinOp,
|
|
40
41
|
LimitScan,
|
|
41
42
|
Project,
|
|
42
|
-
RetrieveScan,
|
|
43
43
|
SearchOperator,
|
|
44
|
+
TopKScan,
|
|
44
45
|
)
|
|
45
46
|
from palimpzest.query.operators.mixture_of_agents import MixtureOfAgentsConvert, MixtureOfAgentsFilter
|
|
46
47
|
from palimpzest.query.operators.physical import PhysicalOperator
|
|
47
48
|
from palimpzest.query.operators.project import ProjectOp
|
|
48
49
|
from palimpzest.query.operators.rag import RAGConvert, RAGFilter
|
|
49
|
-
from palimpzest.query.operators.retrieve import RetrieveOp
|
|
50
50
|
from palimpzest.query.operators.scan import ContextScanOp, MarshalAndScanDataOp
|
|
51
51
|
from palimpzest.query.operators.search import (
|
|
52
52
|
SmolAgentsSearch, # SmolAgentsCustomManagedSearch, # SmolAgentsManagedSearch
|
|
53
53
|
)
|
|
54
54
|
from palimpzest.query.operators.split import SplitConvert, SplitFilter
|
|
55
|
+
from palimpzest.query.operators.topk import TopKOp
|
|
55
56
|
from palimpzest.query.optimizer.primitives import Expression, Group, LogicalExpression, PhysicalExpression
|
|
56
57
|
|
|
57
58
|
logger = logging.getLogger(__name__)
|
|
@@ -796,26 +797,26 @@ class SplitRule(ImplementationRule):
|
|
|
796
797
|
return cls._perform_substitution(logical_expression, phys_op_cls, runtime_kwargs, variable_op_kwargs)
|
|
797
798
|
|
|
798
799
|
|
|
799
|
-
class
|
|
800
|
+
class TopKRule(ImplementationRule):
|
|
800
801
|
"""
|
|
801
|
-
Substitute a logical expression for a
|
|
802
|
+
Substitute a logical expression for a TopKScan with a TopK physical implementation.
|
|
802
803
|
"""
|
|
803
804
|
k_budgets = [1, 3, 5, 10, 15, 20, 25]
|
|
804
805
|
|
|
805
806
|
@classmethod
|
|
806
807
|
def matches_pattern(cls, logical_expression: LogicalExpression) -> bool:
|
|
807
|
-
is_match = isinstance(logical_expression.operator,
|
|
808
|
-
logger.debug(f"
|
|
808
|
+
is_match = isinstance(logical_expression.operator, TopKScan)
|
|
809
|
+
logger.debug(f"TopKRule matches_pattern: {is_match} for {logical_expression}")
|
|
809
810
|
return is_match
|
|
810
811
|
|
|
811
812
|
@classmethod
|
|
812
813
|
def substitute(cls, logical_expression: LogicalExpression, **runtime_kwargs) -> set[PhysicalExpression]:
|
|
813
|
-
logger.debug(f"Substituting
|
|
814
|
+
logger.debug(f"Substituting TopKRule for {logical_expression}")
|
|
814
815
|
|
|
815
816
|
# create variable physical operator kwargs for each model which can implement this logical_expression
|
|
816
817
|
ks = cls.k_budgets if logical_expression.operator.k == -1 else [logical_expression.operator.k]
|
|
817
818
|
variable_op_kwargs = [{"k": k} for k in ks]
|
|
818
|
-
return cls._perform_substitution(logical_expression,
|
|
819
|
+
return cls._perform_substitution(logical_expression, TopKOp, runtime_kwargs, variable_op_kwargs)
|
|
819
820
|
|
|
820
821
|
|
|
821
822
|
class NonLLMFilterRule(ImplementationRule):
|
|
@@ -867,6 +868,23 @@ class LLMFilterRule(ImplementationRule):
|
|
|
867
868
|
return cls._perform_substitution(logical_expression, LLMFilter, runtime_kwargs, variable_op_kwargs)
|
|
868
869
|
|
|
869
870
|
|
|
871
|
+
class RelationalJoinRule(ImplementationRule):
|
|
872
|
+
"""
|
|
873
|
+
Substitute a logical expression for a JoinOp with a RelationalJoin physical implementation.
|
|
874
|
+
"""
|
|
875
|
+
|
|
876
|
+
@classmethod
|
|
877
|
+
def matches_pattern(cls, logical_expression: LogicalExpression) -> bool:
|
|
878
|
+
is_match = isinstance(logical_expression.operator, JoinOp) and logical_expression.operator.condition == ""
|
|
879
|
+
logger.debug(f"RelationalJoinRule matches_pattern: {is_match} for {logical_expression}")
|
|
880
|
+
return is_match
|
|
881
|
+
|
|
882
|
+
@classmethod
|
|
883
|
+
def substitute(cls, logical_expression: LogicalExpression, **runtime_kwargs) -> set[PhysicalExpression]:
|
|
884
|
+
logger.debug(f"Substituting RelationalJoinRule for {logical_expression}")
|
|
885
|
+
return cls._perform_substitution(logical_expression, RelationalJoin, runtime_kwargs)
|
|
886
|
+
|
|
887
|
+
|
|
870
888
|
class NestedLoopsJoinRule(ImplementationRule):
|
|
871
889
|
"""
|
|
872
890
|
Substitute a logical expression for a JoinOp with an (LLM) NestedLoopsJoin physical implementation.
|
|
@@ -874,7 +892,7 @@ class NestedLoopsJoinRule(ImplementationRule):
|
|
|
874
892
|
|
|
875
893
|
@classmethod
|
|
876
894
|
def matches_pattern(cls, logical_expression: LogicalExpression) -> bool:
|
|
877
|
-
is_match = isinstance(logical_expression.operator, JoinOp)
|
|
895
|
+
is_match = isinstance(logical_expression.operator, JoinOp) and logical_expression.operator.condition != ""
|
|
878
896
|
logger.debug(f"NestedLoopsJoinRule matches_pattern: {is_match} for {logical_expression}")
|
|
879
897
|
return is_match
|
|
880
898
|
|
|
@@ -906,7 +924,7 @@ class EmbeddingJoinRule(ImplementationRule):
|
|
|
906
924
|
|
|
907
925
|
@classmethod
|
|
908
926
|
def matches_pattern(cls, logical_expression: LogicalExpression) -> bool:
|
|
909
|
-
is_match = isinstance(logical_expression.operator, JoinOp) and not cls._is_audio_operation(logical_expression)
|
|
927
|
+
is_match = isinstance(logical_expression.operator, JoinOp) and logical_expression.operator.condition != "" and not cls._is_audio_operation(logical_expression)
|
|
910
928
|
logger.debug(f"EmbeddingJoinRule matches_pattern: {is_match} for {logical_expression}")
|
|
911
929
|
return is_match
|
|
912
930
|
|
|
@@ -982,6 +1000,8 @@ class AggregateRule(ImplementationRule):
|
|
|
982
1000
|
physical_op_class = CountAggregateOp
|
|
983
1001
|
elif logical_expression.operator.agg_func == AggFunc.AVERAGE:
|
|
984
1002
|
physical_op_class = AverageAggregateOp
|
|
1003
|
+
elif logical_expression.operator.agg_func == AggFunc.SUM:
|
|
1004
|
+
physical_op_class = SumAggregateOp
|
|
985
1005
|
elif logical_expression.operator.agg_func == AggFunc.MIN:
|
|
986
1006
|
physical_op_class = MinAggregateOp
|
|
987
1007
|
elif logical_expression.operator.agg_func == AggFunc.MAX:
|
|
@@ -501,8 +501,8 @@ class OptimizePhysicalExpression(Task):
|
|
|
501
501
|
|
|
502
502
|
# compute the total cost for this physical expression by summing its operator's PlanCost
|
|
503
503
|
# with the input groups' total PlanCost; also set the op_estimates for this expression's operator
|
|
504
|
-
|
|
505
|
-
full_plan_cost = op_plan_cost.join_add(left_input_plan_cost, right_input_plan_cost,
|
|
504
|
+
execution_strategy_str = "parallel" if execution_strategy.is_fully_parallel() else "sequential"
|
|
505
|
+
full_plan_cost = op_plan_cost.join_add(left_input_plan_cost, right_input_plan_cost, execution_strategy_str)
|
|
506
506
|
full_plan_cost.op_estimates = op_plan_cost.op_estimates
|
|
507
507
|
all_possible_plan_costs.append((full_plan_cost, (left_input_plan_cost, right_input_plan_cost)))
|
|
508
508
|
|
|
@@ -570,8 +570,8 @@ class OptimizePhysicalExpression(Task):
|
|
|
570
570
|
|
|
571
571
|
# compute the total cost for this physical expression by summing its operator's PlanCost
|
|
572
572
|
# with the input groups' total PlanCost; also set the op_estimates for this expression's operator
|
|
573
|
-
|
|
574
|
-
full_plan_cost = op_plan_cost.join_add(left_best_input_plan_cost, right_best_input_plan_cost,
|
|
573
|
+
execution_strategy_str = "parallel" if execution_strategy.is_fully_parallel() else "sequential"
|
|
574
|
+
full_plan_cost = op_plan_cost.join_add(left_best_input_plan_cost, right_best_input_plan_cost, execution_strategy_str)
|
|
575
575
|
full_plan_cost.op_estimates = op_plan_cost.op_estimates
|
|
576
576
|
|
|
577
577
|
else:
|
palimpzest/utils/progress.py
CHANGED
|
@@ -24,7 +24,7 @@ from palimpzest.query.operators.filter import LLMFilter
|
|
|
24
24
|
from palimpzest.query.operators.join import JoinOp
|
|
25
25
|
from palimpzest.query.operators.limit import LimitScanOp
|
|
26
26
|
from palimpzest.query.operators.physical import PhysicalOperator
|
|
27
|
-
from palimpzest.query.operators.
|
|
27
|
+
from palimpzest.query.operators.topk import TopKOp
|
|
28
28
|
from palimpzest.query.optimizer.plan import PhysicalPlan, SentinelPlan
|
|
29
29
|
|
|
30
30
|
|
|
@@ -225,20 +225,22 @@ class PZProgressManager(ProgressManager):
|
|
|
225
225
|
current_unique_full_op_id = unique_full_op_id
|
|
226
226
|
next_op, next_unique_full_op_id = self.unique_full_op_id_to_next_op_and_id[unique_full_op_id]
|
|
227
227
|
while next_op is not None:
|
|
228
|
-
if
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
228
|
+
if isinstance(next_op, (AggregateOp, LimitScanOp)):
|
|
229
|
+
break
|
|
230
|
+
|
|
231
|
+
next_task = self.unique_full_op_id_to_task[next_unique_full_op_id]
|
|
232
|
+
multiplier = 1
|
|
233
|
+
if isinstance(next_op, JoinOp):
|
|
234
|
+
# for joins, scale the delta by the number of inputs from the other side of the join
|
|
235
|
+
left_input_unique_full_op_id, right_input_unique_input_op_id = self.unique_full_op_id_to_input_unique_full_op_ids[next_unique_full_op_id]
|
|
236
|
+
if current_unique_full_op_id == left_input_unique_full_op_id:
|
|
237
|
+
multiplier = self.get_task_total(right_input_unique_input_op_id)
|
|
238
|
+
elif current_unique_full_op_id == right_input_unique_input_op_id:
|
|
239
|
+
multiplier = self.get_task_total(left_input_unique_full_op_id)
|
|
240
|
+
else:
|
|
241
|
+
raise ValueError(f"Current op ID {current_unique_full_op_id} not found in join inputs {left_input_unique_full_op_id}, {right_input_unique_input_op_id}")
|
|
242
|
+
delta_adjusted = delta * multiplier
|
|
243
|
+
self.progress.update(next_task, total=self.get_task_total(next_unique_full_op_id) + delta_adjusted)
|
|
242
244
|
|
|
243
245
|
# move to the next operator in the plan
|
|
244
246
|
current_unique_full_op_id = next_unique_full_op_id
|
|
@@ -348,9 +350,9 @@ class PZSentinelProgressManager(ProgressManager):
|
|
|
348
350
|
def _is_llm_op(self, physical_op: PhysicalOperator) -> bool:
|
|
349
351
|
is_llm_convert = isinstance(physical_op, LLMConvert)
|
|
350
352
|
is_llm_filter = isinstance(physical_op, LLMFilter)
|
|
351
|
-
|
|
353
|
+
is_llm_topk = isinstance(physical_op, TopKOp) and isinstance(physical_op.index, Collection)
|
|
352
354
|
is_llm_join = isinstance(physical_op, JoinOp)
|
|
353
|
-
return is_llm_convert or is_llm_filter or
|
|
355
|
+
return is_llm_convert or is_llm_filter or is_llm_topk or is_llm_join
|
|
354
356
|
|
|
355
357
|
def get_task_description(self, unique_logical_op_id: str) -> str:
|
|
356
358
|
"""Return the current description for the given task."""
|
|
@@ -19,7 +19,7 @@ from palimpzest.query.generators.generators import get_json_from_answer
|
|
|
19
19
|
from palimpzest.query.operators.convert import LLMConvert
|
|
20
20
|
from palimpzest.query.operators.filter import LLMFilter
|
|
21
21
|
from palimpzest.query.operators.join import JoinOp
|
|
22
|
-
from palimpzest.query.operators.
|
|
22
|
+
from palimpzest.query.operators.topk import TopKOp
|
|
23
23
|
|
|
24
24
|
|
|
25
25
|
class Validator:
|
|
@@ -47,7 +47,7 @@ class Validator:
|
|
|
47
47
|
def join_score_fn(self, condition: str, left_input_record: dict, right_input_record: dict, output: bool) -> float | None:
|
|
48
48
|
raise NotImplementedError("Validator.join_score_fn not implemented.")
|
|
49
49
|
|
|
50
|
-
def
|
|
50
|
+
def topk_score_fn(self, fields: list[str], input_record: dict, output: dict) -> float | None:
|
|
51
51
|
raise NotImplementedError("Validator.map_score_fn not implemented.")
|
|
52
52
|
|
|
53
53
|
def _get_gen_stats_from_completion(self, completion, start_time: float) -> GenerationStats:
|
|
@@ -218,11 +218,11 @@ class Validator:
|
|
|
218
218
|
|
|
219
219
|
return score, gen_stats
|
|
220
220
|
|
|
221
|
-
def
|
|
221
|
+
def _default_topk_score_fn(self, op: TopKOp, fields: list[str], input_record: DataRecord, output: dict) -> tuple[float | None, GenerationStats]:
|
|
222
222
|
"""
|
|
223
223
|
Compute the quality of the generated output for the given fields and input_record.
|
|
224
224
|
"""
|
|
225
|
-
# TODO:
|
|
225
|
+
# TODO: top-k k=25; score each item based on relevance; compute F1
|
|
226
226
|
# TODO: support retrieval over images
|
|
227
227
|
# create prompt factory
|
|
228
228
|
factory = PromptFactory(PromptStrategy.MAP, self.model, Cardinality.ONE_TO_ONE)
|
|
@@ -294,11 +294,11 @@ class Validator:
|
|
|
294
294
|
score, gen_stats = self._default_join_score_fn(op, condition, left_input_record, right_input_record, output)
|
|
295
295
|
return score, gen_stats, full_hash
|
|
296
296
|
|
|
297
|
-
def
|
|
297
|
+
def _score_topk(self, op: TopKOp, fields: list[str], input_record: DataRecord, output: dict, full_hash: str) -> tuple[float | None, GenerationStats, str]:
|
|
298
298
|
try:
|
|
299
|
-
out = self.
|
|
299
|
+
out = self.topk_score_fn(fields, input_record.to_dict(), output)
|
|
300
300
|
score, gen_stats = out if isinstance(out, tuple) else (out, GenerationStats())
|
|
301
301
|
return score, gen_stats, full_hash
|
|
302
302
|
except NotImplementedError:
|
|
303
|
-
score, gen_stats = self.
|
|
303
|
+
score, gen_stats = self._default_topk_score_fn(op, fields, input_record, output)
|
|
304
304
|
return score, gen_stats, full_hash
|