palimpzest 0.8.1__py3-none-any.whl → 0.8.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- palimpzest/constants.py +38 -62
- palimpzest/core/data/dataset.py +1 -1
- palimpzest/core/data/iter_dataset.py +5 -5
- palimpzest/core/elements/groupbysig.py +1 -1
- palimpzest/core/elements/records.py +91 -109
- palimpzest/core/lib/schemas.py +23 -0
- palimpzest/core/models.py +3 -3
- palimpzest/prompts/__init__.py +2 -6
- palimpzest/prompts/convert_prompts.py +10 -66
- palimpzest/prompts/critique_and_refine_prompts.py +66 -0
- palimpzest/prompts/filter_prompts.py +8 -46
- palimpzest/prompts/join_prompts.py +12 -75
- palimpzest/prompts/{moa_aggregator_convert_prompts.py → moa_aggregator_prompts.py} +51 -2
- palimpzest/prompts/moa_proposer_prompts.py +87 -0
- palimpzest/prompts/prompt_factory.py +351 -479
- palimpzest/prompts/split_merge_prompts.py +51 -2
- palimpzest/prompts/split_proposer_prompts.py +48 -16
- palimpzest/prompts/utils.py +109 -0
- palimpzest/query/execution/all_sample_execution_strategy.py +1 -1
- palimpzest/query/execution/execution_strategy.py +4 -4
- palimpzest/query/execution/mab_execution_strategy.py +47 -23
- palimpzest/query/execution/parallel_execution_strategy.py +3 -3
- palimpzest/query/execution/single_threaded_execution_strategy.py +8 -8
- palimpzest/query/generators/generators.py +31 -17
- palimpzest/query/operators/__init__.py +15 -2
- palimpzest/query/operators/aggregate.py +21 -19
- palimpzest/query/operators/compute.py +6 -8
- palimpzest/query/operators/convert.py +12 -37
- palimpzest/query/operators/critique_and_refine.py +194 -0
- palimpzest/query/operators/distinct.py +7 -7
- palimpzest/query/operators/filter.py +13 -25
- palimpzest/query/operators/join.py +321 -192
- palimpzest/query/operators/limit.py +4 -4
- palimpzest/query/operators/mixture_of_agents.py +246 -0
- palimpzest/query/operators/physical.py +25 -2
- palimpzest/query/operators/project.py +4 -4
- palimpzest/query/operators/{rag_convert.py → rag.py} +202 -5
- palimpzest/query/operators/retrieve.py +10 -9
- palimpzest/query/operators/scan.py +9 -10
- palimpzest/query/operators/search.py +18 -24
- palimpzest/query/operators/split.py +321 -0
- palimpzest/query/optimizer/__init__.py +12 -8
- palimpzest/query/optimizer/optimizer.py +12 -10
- palimpzest/query/optimizer/rules.py +201 -108
- palimpzest/query/optimizer/tasks.py +18 -6
- palimpzest/query/processor/config.py +2 -2
- palimpzest/query/processor/query_processor.py +2 -2
- palimpzest/query/processor/query_processor_factory.py +9 -5
- palimpzest/validator/validator.py +7 -9
- {palimpzest-0.8.1.dist-info → palimpzest-0.8.3.dist-info}/METADATA +3 -8
- palimpzest-0.8.3.dist-info/RECORD +95 -0
- palimpzest/prompts/critique_and_refine_convert_prompts.py +0 -216
- palimpzest/prompts/moa_proposer_convert_prompts.py +0 -75
- palimpzest/prompts/util_phrases.py +0 -19
- palimpzest/query/operators/critique_and_refine_convert.py +0 -113
- palimpzest/query/operators/mixture_of_agents_convert.py +0 -140
- palimpzest/query/operators/split_convert.py +0 -170
- palimpzest-0.8.1.dist-info/RECORD +0 -95
- {palimpzest-0.8.1.dist-info → palimpzest-0.8.3.dist-info}/WHEEL +0 -0
- {palimpzest-0.8.1.dist-info → palimpzest-0.8.3.dist-info}/licenses/LICENSE +0 -0
- {palimpzest-0.8.1.dist-info → palimpzest-0.8.3.dist-info}/top_level.txt +0 -0
|
@@ -91,17 +91,15 @@ class SmolAgentsSearch(PhysicalOperator):
|
|
|
91
91
|
Given an input DataRecord and a determination of whether it passed the filter or not,
|
|
92
92
|
construct the resulting RecordSet.
|
|
93
93
|
"""
|
|
94
|
-
# create new DataRecord
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
if field in answer:
|
|
98
|
-
dr[field] = answer[field]
|
|
94
|
+
# create new DataRecord
|
|
95
|
+
data_item = {field: answer[field] for field in self.output_schema.model_fields if field in answer}
|
|
96
|
+
dr = DataRecord.from_parent(self.output_schema, data_item, parent_record=candidate)
|
|
99
97
|
|
|
100
98
|
# create RecordOpStats object
|
|
101
99
|
record_op_stats = RecordOpStats(
|
|
102
|
-
record_id=dr.
|
|
103
|
-
record_parent_ids=dr.
|
|
104
|
-
record_source_indices=dr.
|
|
100
|
+
record_id=dr._id,
|
|
101
|
+
record_parent_ids=dr._parent_ids,
|
|
102
|
+
record_source_indices=dr._source_indices,
|
|
105
103
|
record_state=dr.to_dict(include_bytes=False),
|
|
106
104
|
full_op_id=self.get_full_op_id(),
|
|
107
105
|
logical_op_id=self.logical_op_id,
|
|
@@ -248,17 +246,15 @@ class SmolAgentsSearch(PhysicalOperator):
|
|
|
248
246
|
# Given an input DataRecord and a determination of whether it passed the filter or not,
|
|
249
247
|
# construct the resulting RecordSet.
|
|
250
248
|
# """
|
|
251
|
-
# # create new DataRecord
|
|
252
|
-
#
|
|
253
|
-
#
|
|
254
|
-
# if field in answer:
|
|
255
|
-
# dr[field] = answer[field]
|
|
249
|
+
# # create new DataRecord
|
|
250
|
+
# data_item = {field: answer[field] for field in self.output_schema.model_fields if field in answer}
|
|
251
|
+
# dr = DataRecord.from_parent(self.output_schema, data_item, parent_record=candidate)
|
|
256
252
|
|
|
257
253
|
# # create RecordOpStats object
|
|
258
254
|
# record_op_stats = RecordOpStats(
|
|
259
|
-
# record_id=dr.
|
|
260
|
-
# record_parent_ids=dr.
|
|
261
|
-
# record_source_indices=dr.
|
|
255
|
+
# record_id=dr._id,
|
|
256
|
+
# record_parent_ids=dr._parent_ids,
|
|
257
|
+
# record_source_indices=dr._source_indices,
|
|
262
258
|
# record_state=dr.to_dict(include_bytes=False),
|
|
263
259
|
# full_op_id=self.get_full_op_id(),
|
|
264
260
|
# logical_op_id=self.logical_op_id,
|
|
@@ -440,17 +436,15 @@ class SmolAgentsSearch(PhysicalOperator):
|
|
|
440
436
|
# Given an input DataRecord and a determination of whether it passed the filter or not,
|
|
441
437
|
# construct the resulting RecordSet.
|
|
442
438
|
# """
|
|
443
|
-
# # create new DataRecord
|
|
444
|
-
#
|
|
445
|
-
#
|
|
446
|
-
# if field in answer:
|
|
447
|
-
# dr[field] = answer[field]
|
|
439
|
+
# # create new DataRecord
|
|
440
|
+
# data_item = {field: answer[field] for field in self.output_schema.model_fields if field in answer}
|
|
441
|
+
# dr = DataRecord.from_parent(self.output_schema, data_item, parent_record=candidate)
|
|
448
442
|
|
|
449
443
|
# # create RecordOpStats object
|
|
450
444
|
# record_op_stats = RecordOpStats(
|
|
451
|
-
# record_id=dr.
|
|
452
|
-
# record_parent_ids=dr.
|
|
453
|
-
# record_source_indices=dr.
|
|
445
|
+
# record_id=dr._id,
|
|
446
|
+
# record_parent_ids=dr._parent_ids,
|
|
447
|
+
# record_source_indices=dr._source_indices,
|
|
454
448
|
# record_state=dr.to_dict(include_bytes=False),
|
|
455
449
|
# full_op_id=self.get_full_op_id(),
|
|
456
450
|
# logical_op_id=self.logical_op_id,
|
|
@@ -0,0 +1,321 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import math
|
|
4
|
+
|
|
5
|
+
from pydantic.fields import FieldInfo
|
|
6
|
+
|
|
7
|
+
from palimpzest.constants import (
|
|
8
|
+
MODEL_CARDS,
|
|
9
|
+
NAIVE_EST_NUM_INPUT_TOKENS,
|
|
10
|
+
NAIVE_EST_NUM_OUTPUT_TOKENS,
|
|
11
|
+
Cardinality,
|
|
12
|
+
PromptStrategy,
|
|
13
|
+
)
|
|
14
|
+
from palimpzest.core.elements.records import DataRecord
|
|
15
|
+
from palimpzest.core.models import GenerationStats, OperatorCostEstimates
|
|
16
|
+
from palimpzest.query.generators.generators import Generator
|
|
17
|
+
from palimpzest.query.operators.convert import LLMConvert
|
|
18
|
+
from palimpzest.query.operators.filter import LLMFilter
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class SplitConvert(LLMConvert):
    """
    An LLMConvert that splits large text inputs into chunks, runs the split (proposer) generator
    on each chunk independently, and then runs the split-merge generator to combine the
    per-chunk outputs into the final field answers.
    """

    def __init__(self, num_chunks: int = 2, min_size_to_chunk: int = 1000, *args, **kwargs):
        # the parent's prompt strategy is overridden with None; this operator generates through
        # the split / split-merge generators constructed below
        kwargs["prompt_strategy"] = None
        super().__init__(*args, **kwargs)
        self.num_chunks = num_chunks
        self.min_size_to_chunk = min_size_to_chunk
        self.split_generator = Generator(
            self.model, PromptStrategy.MAP_SPLIT_PROPOSER, self.reasoning_effort, self.api_base, self.cardinality, self.desc, self.verbose
        )
        self.split_merge_generator = Generator(
            self.model, PromptStrategy.MAP_SPLIT_MERGER, self.reasoning_effort, self.api_base, self.cardinality, self.desc, self.verbose
        )

        # crude adjustment factor for naive estimation in unoptimized setting
        self.naive_quality_adjustment = 0.6

    def __str__(self):
        op = super().__str__()
        # `num_chunks` is the number of chunks, not a chunk size
        op += f" Num Chunks: {self.num_chunks}\n"
        op += f" Min Size to Chunk: {self.min_size_to_chunk}\n"
        return op

    def get_id_params(self):
        id_params = super().get_id_params()
        return {"num_chunks": self.num_chunks, "min_size_to_chunk": self.min_size_to_chunk, **id_params}

    def get_op_params(self):
        op_params = super().get_op_params()
        return {"num_chunks": self.num_chunks, "min_size_to_chunk": self.min_size_to_chunk, **op_params}

    def naive_cost_estimates(self, source_op_cost_estimates: OperatorCostEstimates) -> OperatorCostEstimates:
        """
        Refine the naive cost and quality estimates produced by LLMConvert.

        The cost per record (and its bounds) is replaced by the cost of a single LLM call using the
        naive input/output token estimates. This estimate does not scale with the number of chunks
        and does not include the merge call. The quality (and its bounds) is scaled by
        `self.naive_quality_adjustment` as a crude penalty for operating on chunks of the input
        rather than on the full input. The chunking may in fact improve quality in some cases.
        """
        # get naive cost estimates from LLMConvert
        naive_op_cost_estimates = super().naive_cost_estimates(source_op_cost_estimates)

        # cost of one call with the naive token estimates (naively assumes a single input field)
        est_num_input_tokens = NAIVE_EST_NUM_INPUT_TOKENS
        est_num_output_tokens = NAIVE_EST_NUM_OUTPUT_TOKENS
        model_conversion_usd_per_record = (
            MODEL_CARDS[self.model.value]["usd_per_input_token"] * est_num_input_tokens
            + MODEL_CARDS[self.model.value]["usd_per_output_token"] * est_num_output_tokens
        )

        # set refined estimate of cost per record
        naive_op_cost_estimates.cost_per_record = model_conversion_usd_per_record
        naive_op_cost_estimates.cost_per_record_lower_bound = naive_op_cost_estimates.cost_per_record
        naive_op_cost_estimates.cost_per_record_upper_bound = naive_op_cost_estimates.cost_per_record
        naive_op_cost_estimates.quality = naive_op_cost_estimates.quality * self.naive_quality_adjustment
        naive_op_cost_estimates.quality_lower_bound = naive_op_cost_estimates.quality
        naive_op_cost_estimates.quality_upper_bound = naive_op_cost_estimates.quality

        return naive_op_cost_estimates

    def get_text_chunks(self, text: str, num_chunks: int) -> list[str]:
        """
        Split `text` into exactly `num_chunks` contiguous substrings whose lengths differ by at most one.

        The chunks concatenate back to `text`. If `text` is shorter than `num_chunks`, the trailing
        chunks are empty strings. This guarantees that every chunked field yields the same number
        of chunks, so `get_chunked_candidate` can index them in lockstep.

        Raises:
            ValueError: if `num_chunks` is less than 1.
        """
        if num_chunks < 1:
            raise ValueError(f"num_chunks must be >= 1, got {num_chunks}")

        base_size, remainder = divmod(len(text), num_chunks)
        chunks, start = [], 0
        for chunk_idx in range(num_chunks):
            # the first `remainder` chunks each absorb one extra character
            end = start + base_size + (1 if chunk_idx < remainder else 0)
            chunks.append(text[start:end])
            start = end

        return chunks

    def get_chunked_candidate(self, candidate: DataRecord, input_fields: list[str]) -> list[DataRecord]:
        """
        Build one copy of `candidate` per chunk. Each copy holds the corresponding chunk of every
        chunked input field.

        A field is chunked only if all of the following hold:
        - it is annotated as `str` or `list[str]` (optionally `| None`);
        - its value is not None;
        - its length is at least `self.min_size_to_chunk`.

        A `list[str]` value is first joined into a single "[a, b, ...]" string. Fields that are
        not chunked are copied unchanged into every chunked record.

        Returns:
            The list of chunked records (length 1 when no field was chunked).
        """
        # compute mapping from each field to its chunked content
        field_name_to_chunked_content = {}
        for field_name in input_fields:
            field = candidate.get_field_type(field_name)
            content = candidate[field_name]

            # do not chunk fields that are not (optional) strings / lists of strings, or whose value is missing
            is_string_field = field.annotation in [str, str | None]
            is_list_string_field = field.annotation in [list[str], list[str] | None]
            if content is None or not (is_string_field or is_list_string_field):
                field_name_to_chunked_content[field_name] = [content]
                continue

            # if this is a list of strings, join the strings
            if is_list_string_field:
                content = "[" + ", ".join(content) + "]"

            # skip this field if its length is less than the min size to chunk
            if len(content) < self.min_size_to_chunk:
                field_name_to_chunked_content[field_name] = [content]
                continue

            # chunk the content
            field_name_to_chunked_content[field_name] = self.get_text_chunks(content, self.num_chunks)

        # compute the true number of chunks (1 if no field was chunked, or if there are no input fields)
        num_chunks = max((len(chunks) for chunks in field_name_to_chunked_content.values()), default=1)

        # create the chunked candidates; unchunked (single-entry) fields are broadcast to every chunk
        candidates = []
        for chunk_idx in range(num_chunks):
            candidate_copy = candidate.copy()
            for field_name in input_fields:
                field_chunks = field_name_to_chunked_content[field_name]
                candidate_copy[field_name] = field_chunks[chunk_idx] if len(field_chunks) > 1 else field_chunks[0]
            candidates.append(candidate_copy)

        return candidates

    def convert(self, candidate: DataRecord, fields: dict[str, FieldInfo]) -> tuple[dict[str, list], GenerationStats]:
        """
        Run the split-and-merge convert on `candidate`.

        Each chunked copy of the record is sent to the split generator (with `json_output=False`),
        and the `reasoning` value it returns is collected. The split-merge generator then receives
        those per-chunk outputs via the `chunk_outputs` kwarg and produces the final field answers.

        Returns:
            A tuple of:
            - the merged field answers;
            - the sum of the generation stats from every chunk call plus the merge call.
        """
        # get the set of input fields to use for the convert operation
        input_fields = self.get_input_fields()

        # split the input fields of (a copy of) the record into per-chunk records
        chunked_candidates = self.get_chunked_candidate(candidate.copy(), input_fields)

        # generate outputs for each chunk separately
        gen_kwargs = {"project_cols": input_fields, "output_schema": self.output_schema}
        chunk_outputs, chunk_generation_stats_lst = [], []
        for chunk_candidate in chunked_candidates:
            _, reasoning, chunk_generation_stats, _ = self.split_generator(chunk_candidate, fields, json_output=False, **gen_kwargs)
            chunk_outputs.append(reasoning)
            chunk_generation_stats_lst.append(chunk_generation_stats)

        # NOTE(review): the merger is handed the last chunked record rather than the full input record,
        # which keeps its input bounded; confirm this is intended.
        merge_candidate = chunked_candidates[-1]
        merge_gen_kwargs = {
            "project_cols": input_fields,
            "output_schema": self.output_schema,
            "chunk_outputs": chunk_outputs,
        }
        field_answers, _, merger_gen_stats, _ = self.split_merge_generator(merge_candidate, fields, **merge_gen_kwargs)

        # compute the total generation stats
        generation_stats = sum(chunk_generation_stats_lst) + merger_gen_stats

        return field_answers, generation_stats
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
class SplitFilter(LLMFilter):
    """
    An LLMFilter that splits large text inputs into chunks, runs the split (proposer) generator
    on each chunk independently, and then runs the split-merge generator to combine the
    per-chunk outputs into the final `passed_operator` decision.
    """

    def __init__(self, num_chunks: int = 2, min_size_to_chunk: int = 1000, *args, **kwargs):
        # the parent's prompt strategy is overridden with None; this operator generates through
        # the split / split-merge generators constructed below
        kwargs["prompt_strategy"] = None
        super().__init__(*args, **kwargs)
        self.num_chunks = num_chunks
        self.min_size_to_chunk = min_size_to_chunk
        self.split_generator = Generator(
            self.model, PromptStrategy.FILTER_SPLIT_PROPOSER, self.reasoning_effort, self.api_base, Cardinality.ONE_TO_ONE, self.desc, self.verbose
        )
        self.split_merge_generator = Generator(
            self.model, PromptStrategy.FILTER_SPLIT_MERGER, self.reasoning_effort, self.api_base, Cardinality.ONE_TO_ONE, self.desc, self.verbose
        )

        # crude adjustment factor for naive estimation in no-sentinel setting
        self.naive_quality_adjustment = 0.6

    def __str__(self):
        op = super().__str__()
        # `num_chunks` is the number of chunks, not a chunk size
        op += f" Num Chunks: {self.num_chunks}\n"
        op += f" Min Size to Chunk: {self.min_size_to_chunk}\n"
        return op

    def get_id_params(self):
        id_params = super().get_id_params()
        return {"num_chunks": self.num_chunks, "min_size_to_chunk": self.min_size_to_chunk, **id_params}

    def get_op_params(self):
        op_params = super().get_op_params()
        return {"num_chunks": self.num_chunks, "min_size_to_chunk": self.min_size_to_chunk, **op_params}

    def naive_cost_estimates(self, source_op_cost_estimates: OperatorCostEstimates) -> OperatorCostEstimates:
        """
        Refine the naive cost and quality estimates produced by LLMFilter.

        The cost per record (and its bounds) is replaced by the cost of a single LLM call using the
        naive input/output token estimates. This estimate does not scale with the number of chunks
        and does not include the merge call. The quality (and its bounds) is scaled by
        `self.naive_quality_adjustment` as a crude penalty for operating on chunks of the input
        rather than on the full input. The chunking may in fact improve quality in some cases.
        """
        # get naive cost estimates from LLMFilter
        naive_op_cost_estimates = super().naive_cost_estimates(source_op_cost_estimates)

        # cost of one call with the naive token estimates (naively assumes a single input field)
        est_num_input_tokens = NAIVE_EST_NUM_INPUT_TOKENS
        est_num_output_tokens = NAIVE_EST_NUM_OUTPUT_TOKENS
        model_conversion_usd_per_record = (
            MODEL_CARDS[self.model.value]["usd_per_input_token"] * est_num_input_tokens
            + MODEL_CARDS[self.model.value]["usd_per_output_token"] * est_num_output_tokens
        )

        # set refined estimate of cost per record
        naive_op_cost_estimates.cost_per_record = model_conversion_usd_per_record
        naive_op_cost_estimates.cost_per_record_lower_bound = naive_op_cost_estimates.cost_per_record
        naive_op_cost_estimates.cost_per_record_upper_bound = naive_op_cost_estimates.cost_per_record
        naive_op_cost_estimates.quality = naive_op_cost_estimates.quality * self.naive_quality_adjustment
        naive_op_cost_estimates.quality_lower_bound = naive_op_cost_estimates.quality
        naive_op_cost_estimates.quality_upper_bound = naive_op_cost_estimates.quality

        return naive_op_cost_estimates

    def get_text_chunks(self, text: str, num_chunks: int) -> list[str]:
        """
        Split `text` into exactly `num_chunks` contiguous substrings whose lengths differ by at most one.

        The chunks concatenate back to `text`. If `text` is shorter than `num_chunks`, the trailing
        chunks are empty strings. This guarantees that every chunked field yields the same number
        of chunks, so `get_chunked_candidate` can index them in lockstep.

        Raises:
            ValueError: if `num_chunks` is less than 1.
        """
        if num_chunks < 1:
            raise ValueError(f"num_chunks must be >= 1, got {num_chunks}")

        base_size, remainder = divmod(len(text), num_chunks)
        chunks, start = [], 0
        for chunk_idx in range(num_chunks):
            # the first `remainder` chunks each absorb one extra character
            end = start + base_size + (1 if chunk_idx < remainder else 0)
            chunks.append(text[start:end])
            start = end

        return chunks

    def get_chunked_candidate(self, candidate: DataRecord, input_fields: list[str]) -> list[DataRecord]:
        """
        Build one copy of `candidate` per chunk. Each copy holds the corresponding chunk of every
        chunked input field.

        A field is chunked only if all of the following hold:
        - it is annotated as `str` or `list[str]` (optionally `| None`);
        - its value is not None;
        - its length is at least `self.min_size_to_chunk`.

        A `list[str]` value is first joined into a single "[a, b, ...]" string. Fields that are
        not chunked are copied unchanged into every chunked record.

        Returns:
            The list of chunked records (length 1 when no field was chunked).
        """
        # compute mapping from each field to its chunked content
        field_name_to_chunked_content = {}
        for field_name in input_fields:
            field = candidate.get_field_type(field_name)
            content = candidate[field_name]

            # do not chunk fields that are not (optional) strings / lists of strings, or whose value is missing
            is_string_field = field.annotation in [str, str | None]
            is_list_string_field = field.annotation in [list[str], list[str] | None]
            if content is None or not (is_string_field or is_list_string_field):
                field_name_to_chunked_content[field_name] = [content]
                continue

            # if this is a list of strings, join the strings
            if is_list_string_field:
                content = "[" + ", ".join(content) + "]"

            # skip this field if its length is less than the min size to chunk
            if len(content) < self.min_size_to_chunk:
                field_name_to_chunked_content[field_name] = [content]
                continue

            # chunk the content
            field_name_to_chunked_content[field_name] = self.get_text_chunks(content, self.num_chunks)

        # compute the true number of chunks (1 if no field was chunked, or if there are no input fields)
        num_chunks = max((len(chunks) for chunks in field_name_to_chunked_content.values()), default=1)

        # create the chunked candidates; unchunked (single-entry) fields are broadcast to every chunk
        candidates = []
        for chunk_idx in range(num_chunks):
            candidate_copy = candidate.copy()
            for field_name in input_fields:
                field_chunks = field_name_to_chunked_content[field_name]
                candidate_copy[field_name] = field_chunks[chunk_idx] if len(field_chunks) > 1 else field_chunks[0]
            candidates.append(candidate_copy)

        return candidates

    def filter(self, candidate: DataRecord) -> tuple[dict[str, bool], GenerationStats]:
        """
        Run the split-and-merge filter on `candidate`.

        Each chunked copy of the record is sent to the split generator (with `json_output=False`),
        and the `reasoning` value it returns is collected. The split-merge generator then receives
        those per-chunk outputs via the `chunk_outputs` kwarg and produces the `passed_operator`
        answer.

        Returns:
            A tuple of:
            - the merged field answers;
            - the sum of the generation stats from every chunk call plus the merge call.
        """
        # get the set of input fields to use for the filter operation
        input_fields = self.get_input_fields()

        # construct output fields
        fields = {"passed_operator": FieldInfo(annotation=bool, description="Whether the record passed the filter operation")}

        # split the input fields of (a copy of) the record into per-chunk records
        chunked_candidates = self.get_chunked_candidate(candidate.copy(), input_fields)

        # generate outputs for each chunk separately
        gen_kwargs = {"project_cols": input_fields, "filter_condition": self.filter_obj.filter_condition}
        chunk_outputs, chunk_generation_stats_lst = [], []
        for chunk_candidate in chunked_candidates:
            _, reasoning, chunk_generation_stats, _ = self.split_generator(chunk_candidate, fields, json_output=False, **gen_kwargs)
            chunk_outputs.append(reasoning)
            chunk_generation_stats_lst.append(chunk_generation_stats)

        # NOTE(review): the merger is handed the last chunked record rather than the full input record,
        # which keeps its input bounded; confirm this is intended.
        merge_candidate = chunked_candidates[-1]
        merge_gen_kwargs = {
            "project_cols": input_fields,
            "filter_condition": self.filter_obj.filter_condition,
            "chunk_outputs": chunk_outputs,
        }
        field_answers, _, merger_gen_stats, _ = self.split_merge_generator(merge_candidate, fields, **merge_gen_kwargs)

        # compute the total generation stats
        generation_stats = sum(chunk_generation_stats_lst) + merger_gen_stats

        return field_answers, generation_stats
|
|
@@ -6,7 +6,7 @@ from palimpzest.query.optimizer.rules import (
|
|
|
6
6
|
BasicSubstitutionRule as _BasicSubstitutionRule,
|
|
7
7
|
)
|
|
8
8
|
from palimpzest.query.optimizer.rules import (
|
|
9
|
-
|
|
9
|
+
CritiqueAndRefineRule as _CritiqueAndRefineRule,
|
|
10
10
|
)
|
|
11
11
|
from palimpzest.query.optimizer.rules import (
|
|
12
12
|
ImplementationRule as _ImplementationRule,
|
|
@@ -21,7 +21,7 @@ from palimpzest.query.optimizer.rules import (
|
|
|
21
21
|
LLMJoinRule as _LLMJoinRule,
|
|
22
22
|
)
|
|
23
23
|
from palimpzest.query.optimizer.rules import (
|
|
24
|
-
|
|
24
|
+
MixtureOfAgentsRule as _MixtureOfAgentsRule,
|
|
25
25
|
)
|
|
26
26
|
from palimpzest.query.optimizer.rules import (
|
|
27
27
|
NonLLMConvertRule as _NonLLMConvertRule,
|
|
@@ -33,7 +33,10 @@ from palimpzest.query.optimizer.rules import (
|
|
|
33
33
|
PushDownFilter as _PushDownFilter,
|
|
34
34
|
)
|
|
35
35
|
from palimpzest.query.optimizer.rules import (
|
|
36
|
-
|
|
36
|
+
RAGRule as _RAGRule,
|
|
37
|
+
)
|
|
38
|
+
from palimpzest.query.optimizer.rules import (
|
|
39
|
+
ReorderConverts as _ReorderConverts,
|
|
37
40
|
)
|
|
38
41
|
from palimpzest.query.optimizer.rules import (
|
|
39
42
|
RetrieveRule as _RetrieveRule,
|
|
@@ -42,7 +45,7 @@ from palimpzest.query.optimizer.rules import (
|
|
|
42
45
|
Rule as _Rule,
|
|
43
46
|
)
|
|
44
47
|
from palimpzest.query.optimizer.rules import (
|
|
45
|
-
|
|
48
|
+
SplitRule as _SplitRule,
|
|
46
49
|
)
|
|
47
50
|
from palimpzest.query.optimizer.rules import (
|
|
48
51
|
TransformationRule as _TransformationRule,
|
|
@@ -52,19 +55,20 @@ ALL_RULES = [
|
|
|
52
55
|
_AddContextsBeforeComputeRule,
|
|
53
56
|
_AggregateRule,
|
|
54
57
|
_BasicSubstitutionRule,
|
|
55
|
-
|
|
58
|
+
_CritiqueAndRefineRule,
|
|
56
59
|
_ImplementationRule,
|
|
57
60
|
_LLMConvertBondedRule,
|
|
58
61
|
_LLMFilterRule,
|
|
59
62
|
_LLMJoinRule,
|
|
60
|
-
|
|
63
|
+
_MixtureOfAgentsRule,
|
|
61
64
|
_NonLLMConvertRule,
|
|
62
65
|
_NonLLMFilterRule,
|
|
63
66
|
_PushDownFilter,
|
|
64
|
-
|
|
67
|
+
_RAGRule,
|
|
68
|
+
_ReorderConverts,
|
|
65
69
|
_RetrieveRule,
|
|
66
70
|
_Rule,
|
|
67
|
-
|
|
71
|
+
_SplitRule,
|
|
68
72
|
_TransformationRule,
|
|
69
73
|
]
|
|
70
74
|
|
|
@@ -29,15 +29,15 @@ from palimpzest.query.optimizer.optimizer_strategy_type import OptimizationStrat
|
|
|
29
29
|
from palimpzest.query.optimizer.plan import PhysicalPlan
|
|
30
30
|
from palimpzest.query.optimizer.primitives import Group, LogicalExpression
|
|
31
31
|
from palimpzest.query.optimizer.rules import (
|
|
32
|
-
|
|
32
|
+
CritiqueAndRefineRule,
|
|
33
33
|
LLMConvertBondedRule,
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
34
|
+
MixtureOfAgentsRule,
|
|
35
|
+
RAGRule,
|
|
36
|
+
SplitRule,
|
|
37
37
|
)
|
|
38
38
|
from palimpzest.query.optimizer.tasks import (
|
|
39
39
|
ApplyRule,
|
|
40
|
-
|
|
40
|
+
ExploreGroup,
|
|
41
41
|
OptimizeGroup,
|
|
42
42
|
OptimizeLogicalExpression,
|
|
43
43
|
OptimizePhysicalExpression,
|
|
@@ -150,22 +150,22 @@ class Optimizer:
|
|
|
150
150
|
|
|
151
151
|
if not self.allow_rag_reduction:
|
|
152
152
|
self.implementation_rules = [
|
|
153
|
-
rule for rule in self.implementation_rules if not issubclass(rule,
|
|
153
|
+
rule for rule in self.implementation_rules if not issubclass(rule, RAGRule)
|
|
154
154
|
]
|
|
155
155
|
|
|
156
156
|
if not self.allow_mixtures:
|
|
157
157
|
self.implementation_rules = [
|
|
158
|
-
rule for rule in self.implementation_rules if not issubclass(rule,
|
|
158
|
+
rule for rule in self.implementation_rules if not issubclass(rule, MixtureOfAgentsRule)
|
|
159
159
|
]
|
|
160
160
|
|
|
161
161
|
if not self.allow_critic:
|
|
162
162
|
self.implementation_rules = [
|
|
163
|
-
rule for rule in self.implementation_rules if not issubclass(rule,
|
|
163
|
+
rule for rule in self.implementation_rules if not issubclass(rule, CritiqueAndRefineRule)
|
|
164
164
|
]
|
|
165
165
|
|
|
166
166
|
if not self.allow_split_merge:
|
|
167
167
|
self.implementation_rules = [
|
|
168
|
-
rule for rule in self.implementation_rules if not issubclass(rule,
|
|
168
|
+
rule for rule in self.implementation_rules if not issubclass(rule, SplitRule)
|
|
169
169
|
]
|
|
170
170
|
|
|
171
171
|
logger.info(f"Initialized Optimizer with verbose={self.verbose}")
|
|
@@ -396,8 +396,9 @@ class Optimizer:
|
|
|
396
396
|
# TODO: conditionally stop when X number of tasks have been executed to limit exhaustive search
|
|
397
397
|
while len(self.tasks_stack) > 0:
|
|
398
398
|
task = self.tasks_stack.pop(-1)
|
|
399
|
+
|
|
399
400
|
new_tasks = []
|
|
400
|
-
if isinstance(task, (OptimizeGroup,
|
|
401
|
+
if isinstance(task, (OptimizeGroup, ExploreGroup)):
|
|
401
402
|
new_tasks = task.perform(self.groups)
|
|
402
403
|
elif isinstance(task, OptimizeLogicalExpression):
|
|
403
404
|
new_tasks = task.perform(self.transformation_rules, self.implementation_rules)
|
|
@@ -409,6 +410,7 @@ class Optimizer:
|
|
|
409
410
|
elif isinstance(task, OptimizePhysicalExpression):
|
|
410
411
|
context = {"optimizer_strategy": self.optimizer_strategy, "execution_strategy": self.execution_strategy}
|
|
411
412
|
new_tasks = task.perform(self.cost_model, self.groups, self.policy, context=context)
|
|
413
|
+
|
|
412
414
|
self.tasks_stack.extend(new_tasks)
|
|
413
415
|
|
|
414
416
|
logger.debug(f"Done searching optimization space for group_id: {group_id}")
|