palimpzest 0.6.3__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- palimpzest/__init__.py +5 -0
- palimpzest/constants.py +110 -43
- palimpzest/core/__init__.py +0 -78
- palimpzest/core/data/dataclasses.py +382 -44
- palimpzest/core/elements/filters.py +7 -3
- palimpzest/core/elements/index.py +70 -0
- palimpzest/core/elements/records.py +33 -11
- palimpzest/core/lib/fields.py +1 -0
- palimpzest/core/lib/schemas.py +4 -3
- palimpzest/prompts/moa_proposer_convert_prompts.py +0 -4
- palimpzest/prompts/prompt_factory.py +44 -7
- palimpzest/prompts/split_merge_prompts.py +56 -0
- palimpzest/prompts/split_proposer_prompts.py +55 -0
- palimpzest/query/execution/execution_strategy.py +435 -53
- palimpzest/query/execution/execution_strategy_type.py +20 -0
- palimpzest/query/execution/mab_execution_strategy.py +532 -0
- palimpzest/query/execution/parallel_execution_strategy.py +143 -172
- palimpzest/query/execution/random_sampling_execution_strategy.py +240 -0
- palimpzest/query/execution/single_threaded_execution_strategy.py +173 -203
- palimpzest/query/generators/api_client_factory.py +31 -0
- palimpzest/query/generators/generators.py +256 -76
- palimpzest/query/operators/__init__.py +1 -2
- palimpzest/query/operators/code_synthesis_convert.py +33 -18
- palimpzest/query/operators/convert.py +30 -97
- palimpzest/query/operators/critique_and_refine_convert.py +5 -6
- palimpzest/query/operators/filter.py +7 -10
- palimpzest/query/operators/logical.py +54 -10
- palimpzest/query/operators/map.py +130 -0
- palimpzest/query/operators/mixture_of_agents_convert.py +6 -6
- palimpzest/query/operators/physical.py +3 -12
- palimpzest/query/operators/rag_convert.py +66 -18
- palimpzest/query/operators/retrieve.py +230 -34
- palimpzest/query/operators/scan.py +5 -2
- palimpzest/query/operators/split_convert.py +169 -0
- palimpzest/query/operators/token_reduction_convert.py +8 -14
- palimpzest/query/optimizer/__init__.py +4 -16
- palimpzest/query/optimizer/cost_model.py +73 -266
- palimpzest/query/optimizer/optimizer.py +87 -58
- palimpzest/query/optimizer/optimizer_strategy.py +18 -97
- palimpzest/query/optimizer/optimizer_strategy_type.py +37 -0
- palimpzest/query/optimizer/plan.py +2 -3
- palimpzest/query/optimizer/primitives.py +5 -3
- palimpzest/query/optimizer/rules.py +336 -172
- palimpzest/query/optimizer/tasks.py +30 -100
- palimpzest/query/processor/config.py +38 -22
- palimpzest/query/processor/nosentinel_processor.py +16 -520
- palimpzest/query/processor/processing_strategy_type.py +28 -0
- palimpzest/query/processor/query_processor.py +38 -206
- palimpzest/query/processor/query_processor_factory.py +117 -130
- palimpzest/query/processor/sentinel_processor.py +90 -0
- palimpzest/query/processor/streaming_processor.py +25 -32
- palimpzest/sets.py +88 -41
- palimpzest/utils/model_helpers.py +8 -7
- palimpzest/utils/progress.py +368 -152
- palimpzest/utils/token_reduction_helpers.py +1 -3
- {palimpzest-0.6.3.dist-info → palimpzest-0.7.0.dist-info}/METADATA +28 -24
- palimpzest-0.7.0.dist-info/RECORD +96 -0
- {palimpzest-0.6.3.dist-info → palimpzest-0.7.0.dist-info}/WHEEL +1 -1
- palimpzest/query/processor/mab_sentinel_processor.py +0 -884
- palimpzest/query/processor/random_sampling_sentinel_processor.py +0 -639
- palimpzest/utils/index_helpers.py +0 -6
- palimpzest-0.6.3.dist-info/RECORD +0 -87
- {palimpzest-0.6.3.dist-info → palimpzest-0.7.0.dist-info/licenses}/LICENSE +0 -0
- {palimpzest-0.6.3.dist-info → palimpzest-0.7.0.dist-info}/top_level.txt +0 -0
**palimpzest/query/operators/retrieve.py**

```diff
@@ -2,31 +2,73 @@ from __future__ import annotations
 
 import os
 import time
+from typing import Callable
 
-from
+from chromadb.api.models.Collection import Collection
+from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction
+from chromadb.utils.embedding_functions.openai_embedding_function import OpenAIEmbeddingFunction
+from openai import OpenAI
+from ragatouille.RAGPretrainedModel import RAGPretrainedModel
+from sentence_transformers import SentenceTransformer
+
+from palimpzest.constants import MODEL_CARDS, Model
+from palimpzest.core.data.dataclasses import GenerationStats, OperatorCostEstimates, RecordOpStats
 from palimpzest.core.elements.records import DataRecord, DataRecordSet
+from palimpzest.core.lib.schemas import Schema
 from palimpzest.query.operators.physical import PhysicalOperator
 
 
 class RetrieveOp(PhysicalOperator):
-    def __init__(
+    def __init__(
+        self,
+        index: Collection | RAGPretrainedModel,
+        search_attr: str,
+        output_attrs: list[dict] | type[Schema],
+        search_func: Callable | None,
+        k: int,
+        *args,
+        **kwargs,
+    ) -> None:
+        """
+        Initialize the RetrieveOp object.
+
+        Args:
+            index (Collection | RAGPretrainedModel): The PZ index to use for retrieval.
+            search_attr (str): The attribute to search on.
+            output_attrs (list[dict]): The output fields containing the results of the search.
+            search_func (Callable | None): The function to use for searching the index. If None, the default search function will be used.
+            k (int): The number of top results to retrieve.
+        """
         super().__init__(*args, **kwargs)
+
+        # extract the field names from the output_attrs
+        if isinstance(output_attrs, Schema):
+            self.output_field_names = output_attrs.field_names()
+        elif isinstance(output_attrs, list):
+            self.output_field_names = [attr["name"] for attr in output_attrs]
+        else:
+            raise ValueError("`output_attrs` must be a list of dicts or a Schema object.")
+
+        if len(self.output_field_names) != 1 and search_func is None:
+            raise ValueError("If `search_func` is None, `output_attrs` must have a single field.")
+
         self.index = index
-        self.search_func = search_func
         self.search_attr = search_attr
-        self.
+        self.output_attrs = output_attrs
+        self.search_func = search_func if search_func is not None else self.default_search_func
        self.k = k
 
     def __str__(self):
         op = super().__str__()
-        op += f" Retrieve: {
+        op += f" Retrieve: {self.index.__class__.__name__} with top {self.k}\n"
         return op
 
     def get_id_params(self):
         id_params = super().get_id_params()
         id_params = {
+            "index": self.index.__class__.__name__,
             "search_attr": self.search_attr,
-            "
+            "output_attrs": self.output_attrs,
             "k": self.k,
             **id_params,
         }
```
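The constructor now accepts either a `Schema` or a list of field dicts for `output_attrs` and normalizes both to a flat list of field names, rejecting anything else. A minimal sketch of that normalization, using a hypothetical `FakeSchema` stand-in for palimpzest's real `Schema` (which also exposes `field_names()`):

```python
# Sketch of the output_attrs normalization added to RetrieveOp.__init__.
# FakeSchema is a made-up stand-in for palimpzest.core.lib.schemas.Schema.
class FakeSchema:
    @classmethod
    def field_names(cls) -> list[str]:
        return ["snippet"]


def extract_output_field_names(output_attrs) -> list[str]:
    # Schema path: the schema knows its own field names
    if hasattr(output_attrs, "field_names"):
        return output_attrs.field_names()
    # list-of-dicts path: each dict carries the field name under "name"
    elif isinstance(output_attrs, list):
        return [attr["name"] for attr in output_attrs]
    raise ValueError("`output_attrs` must be a list of dicts or a Schema object.")


assert extract_output_field_names(FakeSchema) == ["snippet"]
assert extract_output_field_names([{"name": "snippet", "desc": "top matches"}]) == ["snippet"]
```

The constructor additionally requires exactly one output field whenever `search_func` is None, because the built-in `default_search_func` writes its results to a single field.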
```diff
@@ -39,7 +81,7 @@ class RetrieveOp(PhysicalOperator):
             "index": self.index,
             "search_func": self.search_func,
             "search_attr": self.search_attr,
-            "
+            "output_attrs": self.output_attrs,
             "k": self.k,
             **op_params,
         }
@@ -53,37 +95,86 @@ class RetrieveOp(PhysicalOperator):
         """
         return OperatorCostEstimates(
             cardinality=source_op_cost_estimates.cardinality,
-            time_per_record=0.
-            cost_per_record=0.
+            time_per_record=0.01 * self.k,  # estimate 10 ms execution lookup per output
+            cost_per_record=0.001 * self.k,  # estimate small marginal cost of lookups
             quality=1.0,
         )
 
-    def
-
+    def default_search_func(self, index: Collection | RAGPretrainedModel, query: list[str] | list[list[float]], k: int) -> list[str] | list[list[str]]:
+        """
+        Default search function for the Retrieve operation. This function uses the index to
+        retrieve the top-k results for the given query. The query will be a (possibly singleton)
+        list of strings or a list of lists of floats (i.e., embeddings). The function will return
+        the top-k results per-query in (descending) sorted order. If the input is a singleton list,
+        then the output will be a list of strings. If the input is a list of lists, then the output
+        will be a list of lists of strings.
 
-
+        Args:
+            index (PZIndex): The index to use for retrieval.
+            query (list[str] | list[list[float]]): The query (or queries) to search for.
+            k (int): The maximum number of results the retrieve operator will return.
 
-
-
-
-
-
-
-
-
+        Returns:
+            list[str] | list[list[str]]: The top results in (descending) sorted order per query.
+        """
+        # check if the input is a singleton list or a list of lists
+        is_singleton_list = len(query) == 1
+
+        if isinstance(index, Collection):
+            # if the index is a chromadb collection, use the query method
+            results = index.query(query, n_results=k)
+
+            # the results["documents"] will be a list[list[str]]; if the input is a singleton list,
+            # then we output the list of strings (i.e., the first element of the list), otherwise
+            # we output the list of lists
+            final_results = results["documents"][0] if is_singleton_list else results["documents"]
+
+            # NOTE: self.output_field_names must be a singleton for default_search_func to be used
+            return {self.output_field_names[0]: final_results}
 
-
-
+        elif isinstance(index, RAGPretrainedModel):
+            # if the index is a rag model, use the rag model to get the top k results
+            results = index.search(query, k=k)
+
+            # the results will be a list[dict] if the input is a singleton list; however,
+            # it will be a list[list[dict]] if the input is a list of lists
+            final_results = []
+            if is_singleton_list:
+                final_results = [result["content"] for result in results]
+            else:
+                for query_results in results:
+                    final_results.append([result["content"] for result in query_results])
+
+            # NOTE: self.output_field_names must be a singleton for default_search_func to be used
+            return {self.output_field_names[0]: final_results}
+
+        else:
+            raise ValueError("Unsupported index type. Must be either a Collection or RAGPretrainedModel.")
+
+    def _create_record_set(
+        self,
+        candidate: DataRecord,
+        top_k_results: dict[str, list[str] | list[list[str]]] | None,
+        generation_stats: GenerationStats,
+        total_time: float,
+    ) -> DataRecordSet:
+        """
+        Given an input DataRecord and the top_k_results, construct the resulting RecordSet.
+        """
+        # create output DataRecord and set the output attribute
+        output_dr, answer = DataRecord.from_parent(self.output_schema, parent_record=candidate), {}
+        for output_field_name in self.output_field_names:
+            top_k_attr_results = None if top_k_results is None else top_k_results[output_field_name]
+            setattr(output_dr, output_field_name, top_k_attr_results)
+            answer[output_field_name] = top_k_attr_results
 
-
-        answer = {self.output_attr: top_k_results}
+        # get the record_state and generated fields
         record_state = output_dr.to_dict(include_bytes=False)
 
-        # NOTE:
-
-        # return the full field name (as opposed to the short field name))
-        generated_fields = self.get_fields_to_generate(candidate)
+        # NOTE: this should be equivalent to self.get_fields_to_generate()
+        generated_fields = self.output_field_names
 
+        # construct the RecordOpStats object
         record_op_stats = RecordOpStats(
             record_id=output_dr.id,
             record_parent_id=output_dr.parent_id,
```
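`default_search_func` has to reconcile two result shapes: chromadb's `Collection.query` always returns batched results, where `results["documents"][i]` is the top-k list for query *i*, while a single-query caller expects a flat list. A small sketch of that shape handling, using a stubbed result rather than a live index:

```python
# Stubbed chromadb-style result: two queries, k=2 documents each.
batched_documents = [["doc-a", "doc-b"], ["doc-c", "doc-d"]]


def flatten_documents(documents: list[list[str]], is_singleton_list: bool) -> list:
    # one query in -> flat list of documents out; a batch in -> list of lists
    return documents[0] if is_singleton_list else documents


assert flatten_documents(batched_documents, is_singleton_list=False) == [["doc-a", "doc-b"], ["doc-c", "doc-d"]]
assert flatten_documents(batched_documents[:1], is_singleton_list=True) == ["doc-a", "doc-b"]
```

The RAGatouille branch applies the same rule in reverse: `RAGPretrainedModel.search` returns a flat `list[dict]` for one query and a `list[list[dict]]` for a batch, so only the batched case needs the inner loop over per-query results.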
```diff
@@ -92,19 +183,124 @@ class RetrieveOp(PhysicalOperator):
             op_id=self.get_op_id(),
             logical_op_id=self.logical_op_id,
             op_name=self.op_name(),
-            time_per_record=
-            cost_per_record=
+            time_per_record=total_time,
+            cost_per_record=generation_stats.cost_per_record,
             answer=answer,
             input_fields=self.input_schema.field_names(),
             generated_fields=generated_fields,
-            fn_call_duration_secs=
+            fn_call_duration_secs=total_time - generation_stats.llm_call_duration_secs,
+            llm_call_duration_secs=generation_stats.llm_call_duration_secs,
+            total_llm_calls=generation_stats.total_llm_calls,
+            total_embedding_llm_calls=generation_stats.total_embedding_llm_calls,
             op_details={k: str(v) for k, v in self.get_id_params().items()},
         )
 
         drs = [output_dr]
         record_op_stats_lst = [record_op_stats]
 
-        # construct record set
-
+        # construct and return the record set
+        return DataRecordSet(drs, record_op_stats_lst)
 
-
+
+    def __call__(self, candidate: DataRecord) -> DataRecordSet:
+        start_time = time.time()
+
+        # check that query is a string or list of strings, otherwise return output with self.output_field_names set to None
+        query = getattr(candidate, self.search_attr)
+        query_is_str = isinstance(query, str)
+        query_is_list_of_str = isinstance(query, list) and all(isinstance(q, str) for q in query)
+        if not query_is_str and not query_is_list_of_str:
+            return self._create_record_set(
+                candidate=candidate,
+                top_k_results=None,
+                generation_stats=GenerationStats(),
+                total_time=time.time() - start_time,
+            )
+
+        # if query is a string, convert it to a list of strings
+        if query_is_str:
+            query = [query]
+
+        # compute input/query embedding(s) if the index is a chromadb collection
+        inputs, gen_stats = None, GenerationStats()
+        if isinstance(self.index, Collection):
+            uses_openai_embedding_fcn = isinstance(self.index._embedding_function, OpenAIEmbeddingFunction)
+            uses_sentence_transformer_embedding_fcn = isinstance(self.index._embedding_function, SentenceTransformerEmbeddingFunction)
+            error_msg = "ChromaDB index must use OpenAI or SentenceTransformer embedding function; see: https://docs.trychroma.com/integrations/embedding-models/openai"
+            assert uses_openai_embedding_fcn or uses_sentence_transformer_embedding_fcn, error_msg
+
+            model_name = self.index._embedding_function._model_name if uses_openai_embedding_fcn else "clip-ViT-B-32"
+            err_msg = f"For Chromadb, we currently only support `text-embedding-3-small` and `clip-ViT-B-32`; your index uses: {model_name}"
+            assert model_name in [Model.TEXT_EMBEDDING_3_SMALL.value, Model.CLIP_VIT_B_32.value], err_msg
+
+            # compute embeddings
+            try:
+                embed_start_time = time.time()
+                total_input_tokens = 0.0
+                if uses_openai_embedding_fcn:
+                    client = OpenAI()
+                    response = client.embeddings.create(input=query, model=model_name)
+                    total_input_tokens = response.usage.total_tokens
+                    inputs = [item.embedding for item in response.data]
+
+                elif uses_sentence_transformer_embedding_fcn:
+                    model = SentenceTransformer(model_name)
+                    inputs = model.encode(query)
+
+                embed_total_time = time.time() - embed_start_time
+
+                # compute cost of embedding(s)
+                model_card = MODEL_CARDS[model_name]
+                total_input_cost = model_card["usd_per_input_token"] * total_input_tokens
+                gen_stats = GenerationStats(
+                    model_name=model_name,
+                    total_input_tokens=total_input_tokens,
+                    total_output_tokens=0.0,
+                    total_input_cost=total_input_cost,
+                    total_output_cost=0.0,
+                    cost_per_record=total_input_cost,
+                    llm_call_duration_secs=embed_total_time,
+                    total_llm_calls=1,
+                    total_embedding_llm_calls=len(query),
+                )
+            except Exception:
+                query = None
+
+        # in the default case, pass string inputs rather than embeddings
+        if inputs is None:
+            inputs = query
+
+        try:
+            assert inputs is not None, "Error: inputs is None (likely because embedding generation failed)"
+            top_results = self.search_func(self.index, inputs, self.k)
+
+        except Exception:
+            top_results = ["error-in-retrieve"]
+            os.makedirs("retrieve-errors", exist_ok=True)
+            ts = time.time()
+            with open(f"retrieve-errors/error-{ts}.txt", "w") as f:
+                f.write(str(query))
+
+        # TODO: the user is always right! let's drop this post-processing in the future
+        # filter top_results for the top_k_results
+        top_k_results = {output_field_name: [] for output_field_name in self.output_field_names}
+        for output_field_name in self.output_field_names:
+            if output_field_name in top_results:
+                if all([isinstance(result, list) for result in top_results[output_field_name]]):
+                    for result in top_results[output_field_name]:
+                        top_k_results[output_field_name].append(result[:self.k])
+                else:
+                    top_k_results[output_field_name] = top_results[output_field_name][:self.k]
+            else:
+                top_k_results[output_field_name] = []
+
+        if self.verbose:
+            print(f"Top {self.k} results: {top_k_results}")
+
+        # construct and return the record set
+        return self._create_record_set(
+            candidate=candidate,
+            top_k_results=top_k_results,
+            generation_stats=gen_stats,
+            total_time=time.time() - start_time,
+        )
```
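The post-processing at the end of `__call__` defensively clips whatever a (possibly user-supplied) `search_func` returns down to `k` results per query, handling both flat and batched shapes and tolerating missing fields. A self-contained sketch of that logic (the function name `clip_to_k` is ours, not the library's):

```python
def clip_to_k(top_results: dict[str, list], field_names: list[str], k: int) -> dict[str, list]:
    """Clip each output field's results to at most k entries per query."""
    top_k_results: dict[str, list] = {name: [] for name in field_names}
    for name in field_names:
        results = top_results.get(name, [])  # missing field -> empty results
        if all(isinstance(r, list) for r in results):
            # batched shape: clip each query's result list separately
            top_k_results[name] = [r[:k] for r in results]
        else:
            # single-query shape: clip the flat list
            top_k_results[name] = results[:k]
    return top_k_results


assert clip_to_k({"snippet": ["a", "b", "c"]}, ["snippet"], k=2) == {"snippet": ["a", "b"]}
assert clip_to_k({"snippet": [["a", "b", "c"], ["d"]]}, ["snippet"], k=2) == {"snippet": [["a", "b"], ["d"]]}
```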
**palimpzest/query/operators/scan.py**

```diff
@@ -69,14 +69,17 @@ class ScanPhysicalOp(PhysicalOperator, ABC):
         item = self.datareader[idx]
         end_time = time.time()
 
+        # TODO: remove once validation data is refactored
+        item_field_dict = item.get("fields", item)
+
         # check that item covers fields in output schema
         output_field_names = self.output_schema.field_names()
-        assert all([field in
+        assert all([field in item_field_dict for field in output_field_names]), f"Some fields in DataReader schema not present in item!\n - DataReader fields: {output_field_names}\n - Item fields: {list(item.keys())}"
 
         # construct a DataRecord from the item
         dr = DataRecord(self.output_schema, source_idx=idx)
         for field in output_field_names:
-            setattr(dr, field,
+            setattr(dr, field, item_field_dict[field])
 
         # create RecordOpStats objects
         record_op_stats = RecordOpStats(
```
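The scan change is a small backward-compatibility shim: items from a `DataReader` may now arrive either as a flat dict of fields or wrapped under a `"fields"` key (the shape the TODO says will disappear once validation data is refactored). The one-liner `item.get("fields", item)` covers both, as this sketch shows; the `"labels"` key is a made-up example of extra wrapped content:

```python
flat_item = {"title": "a", "text": "b"}
wrapped_item = {"fields": {"title": "a", "text": "b"}, "labels": {"relevant": True}}

for item in (flat_item, wrapped_item):
    # unwrap the "fields" dict if present, otherwise treat the item itself as the fields
    item_field_dict = item.get("fields", item)
    assert item_field_dict == {"title": "a", "text": "b"}
```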
**palimpzest/query/operators/split_convert.py** (new file)

```diff
@@ -0,0 +1,169 @@
+from __future__ import annotations
+
+import math
+
+from palimpzest.constants import (
+    MODEL_CARDS,
+    NAIVE_EST_NUM_INPUT_TOKENS,
+    NAIVE_EST_NUM_OUTPUT_TOKENS,
+    PromptStrategy,
+)
+from palimpzest.core.data.dataclasses import GenerationStats, OperatorCostEstimates
+from palimpzest.core.elements.records import DataRecord
+from palimpzest.core.lib.fields import Field, StringField
+from palimpzest.query.generators.generators import generator_factory
+from palimpzest.query.operators.convert import LLMConvert
+
+
+class SplitConvert(LLMConvert):
+    def __init__(self, num_chunks: int = 2, min_size_to_chunk: int = 1000, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.num_chunks = num_chunks
+        self.min_size_to_chunk = min_size_to_chunk
+        self.split_generator = generator_factory(self.model, PromptStrategy.SPLIT_PROPOSER, self.cardinality, self.verbose)
+        self.split_merge_generator = generator_factory(self.model, PromptStrategy.SPLIT_MERGER, self.cardinality, self.verbose)
+
+        # crude adjustment factor for naive estimation in no-sentinel setting
+        self.naive_quality_adjustment = 0.6
+
+    def __str__(self):
+        op = super().__str__()
+        op += f" Chunk Size: {str(self.num_chunks)}\n"
+        op += f" Min Size to Chunk: {str(self.min_size_to_chunk)}\n"
+        return op
+
+    def get_id_params(self):
+        id_params = super().get_id_params()
+        id_params = {"num_chunks": self.num_chunks, "min_size_to_chunk": self.min_size_to_chunk, **id_params}
+
+        return id_params
+
+    def get_op_params(self):
+        op_params = super().get_op_params()
+        return {"num_chunks": self.num_chunks, "min_size_to_chunk": self.min_size_to_chunk, **op_params}
+
+    def naive_cost_estimates(self, source_op_cost_estimates: OperatorCostEstimates) -> OperatorCostEstimates:
+        """
+        Update the cost per record and quality estimates produced by LLMConvert's naive estimates.
+        We adjust the cost per record to account for the reduced number of input tokens following
+        the retrieval of relevant chunks, and we make a crude estimate of the quality degradation
+        that results from using a downsized input (although this may in fact improve quality in
+        some cases).
+        """
+        # get naive cost estimates from LLMConvert
+        naive_op_cost_estimates = super().naive_cost_estimates(source_op_cost_estimates)
+
+        # re-compute cost per record assuming we use fewer input tokens; naively assume a single input field
+        est_num_input_tokens = NAIVE_EST_NUM_INPUT_TOKENS
+        est_num_output_tokens = NAIVE_EST_NUM_OUTPUT_TOKENS
+        model_conversion_usd_per_record = (
+            MODEL_CARDS[self.model.value]["usd_per_input_token"] * est_num_input_tokens
+            + MODEL_CARDS[self.model.value]["usd_per_output_token"] * est_num_output_tokens
+        )
+
+        # set refined estimate of cost per record and, for now,
+        # assume quality multiplier is proportional to sqrt(sqrt(token_budget))
+        naive_op_cost_estimates.cost_per_record = model_conversion_usd_per_record
+        naive_op_cost_estimates.cost_per_record_lower_bound = naive_op_cost_estimates.cost_per_record
+        naive_op_cost_estimates.cost_per_record_upper_bound = naive_op_cost_estimates.cost_per_record
+        naive_op_cost_estimates.quality = (naive_op_cost_estimates.quality) * self.naive_quality_adjustment
+        naive_op_cost_estimates.quality_lower_bound = naive_op_cost_estimates.quality
+        naive_op_cost_estimates.quality_upper_bound = naive_op_cost_estimates.quality
+
+        return naive_op_cost_estimates
+
+    def is_image_conversion(self) -> bool:
+        """SplitConvert is currently disallowed on image conversions, so this must be False."""
+        return False
+
+    def get_text_chunks(self, text: str, num_chunks: int) -> list[str]:
+        """
+        Given a text string, chunk it into num_chunks substrings of roughly equal size.
+        """
+        chunks = []
+
+        idx, chunk_size = 0, math.ceil(len(text) / num_chunks)
+        while idx + chunk_size < len(text):
+            chunks.append(text[idx : idx + chunk_size])
+            idx += chunk_size
+
+        if idx < len(text):
+            chunks.append(text[idx:])
+
+        return chunks
+
+    def get_chunked_candidate(self, candidate: DataRecord, input_fields: list[str]) -> list[DataRecord]:
+        """
+        For each text field, chunk the content. If a field is smaller than the chunk size,
+        simply include the full field.
+        """
+        # compute mapping from each field to its chunked content
+        field_name_to_chunked_content = {}
+        for field_name in input_fields:
+            field = candidate.get_field_type(field_name)
+            content = candidate[field_name]
+
+            # do not chunk this field if it is not a string or a list of strings
+            is_string_field = isinstance(field, StringField)
+            is_list_string_field = hasattr(field, "element_type") and isinstance(field.element_type, StringField)
+            if not (is_string_field or is_list_string_field):
+                field_name_to_chunked_content[field_name] = [content]
+                continue
+
+            # if this is a list of strings, join the strings
+            if is_list_string_field:
+                content = "[" + ", ".join(content) + "]"
+
+            # skip this field if its length is less than the min size to chunk
+            if len(content) < self.min_size_to_chunk:
+                field_name_to_chunked_content[field_name] = [content]
+                continue
+
+            # chunk the content
+            field_name_to_chunked_content[field_name] = self.get_text_chunks(content, self.num_chunks)
+
+        # compute the true number of chunks (may be 1 if all fields are not chunked)
+        num_chunks = max(len(chunks) for chunks in field_name_to_chunked_content.values())
+
+        # create the chunked candidates
+        candidates = []
+        for chunk_idx in range(num_chunks):
+            candidate_copy = candidate.copy()
+            for field_name in input_fields:
+                field_chunks = field_name_to_chunked_content[field_name]
+                candidate_copy[field_name] = field_chunks[chunk_idx] if len(field_chunks) > 1 else field_chunks[0]
+
+            candidates.append(candidate_copy)
+
+        return candidates
+
+    def convert(self, candidate: DataRecord, fields: dict[str, Field]) -> tuple[dict[str, list], GenerationStats]:
+        # get the set of input fields to use for the convert operation
+        input_fields = self.get_input_fields()
+
+        # lookup most relevant chunks for each field using embedding search
+        candidate_copy = candidate.copy()
+        chunked_candidates = self.get_chunked_candidate(candidate_copy, input_fields)
+
+        # construct kwargs for generation
+        gen_kwargs = {"project_cols": input_fields, "output_schema": self.output_schema}
+
+        # generate outputs for each chunk separately
+        chunk_outputs, chunk_generation_stats_lst = [], []
+        for candidate in chunked_candidates:
+            _, reasoning, chunk_generation_stats, _ = self.split_generator(candidate, fields, json_output=False, **gen_kwargs)
+            chunk_outputs.append(reasoning)
+            chunk_generation_stats_lst.append(chunk_generation_stats)
+
+        # call the merger
+        gen_kwargs = {
+            "project_cols": input_fields,
+            "output_schema": self.output_schema,
+            "chunk_outputs": chunk_outputs,
+        }
+        field_answers, _, merger_gen_stats, _ = self.split_merge_generator(candidate, fields, **gen_kwargs)
+
+        # compute the total generation stats
+        generation_stats = sum(chunk_generation_stats_lst) + merger_gen_stats
+
+        return field_answers, generation_stats
```
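`SplitConvert`'s chunker is simple enough to lift out and run standalone: it cuts the text into `num_chunks` pieces of `ceil(len(text) / num_chunks)` characters, with the final chunk absorbing any remainder. A sketch copied in spirit from `get_text_chunks`:

```python
import math


def get_text_chunks(text: str, num_chunks: int) -> list[str]:
    """Split text into num_chunks substrings of roughly equal size."""
    chunks = []
    idx, chunk_size = 0, math.ceil(len(text) / num_chunks)
    while idx + chunk_size < len(text):
        chunks.append(text[idx : idx + chunk_size])
        idx += chunk_size
    if idx < len(text):
        chunks.append(text[idx:])  # final (possibly shorter) chunk
    return chunks


assert get_text_chunks("abcdefghij", 3) == ["abcd", "efgh", "ij"]
assert get_text_chunks("short", 2) == ["sho", "rt"]
```

`convert` then runs the SPLIT_PROPOSER generator over each chunked candidate, collects the per-chunk reasoning, and hands all of it to the SPLIT_MERGER generator to produce the final field answers, summing the per-chunk generation stats with the merger's.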
**palimpzest/query/operators/token_reduction_convert.py**

```diff
@@ -9,7 +9,7 @@ from palimpzest.constants import (
     NAIVE_EST_NUM_OUTPUT_TOKENS,
 )
 from palimpzest.core.data.dataclasses import OperatorCostEstimates
-from palimpzest.query.operators.convert import
+from palimpzest.query.operators.convert import LLMConvertBonded
 from palimpzest.utils.token_reduction_helpers import best_substring_match, find_best_range
 
 
@@ -32,8 +32,8 @@ from palimpzest.utils.token_reduction_helpers import best_substring_match, find_best_range
 # - this also creates difficulties in properly performing cost-estimation for this operator; e.g. if we use
 #   n <= MAX_HEATMAP_UPDATES samples to cost this operator, then we will never actually measure its performance
 #   in the token reduction phase -- which could have a serious degradation in quality that our optimizer doesn't see
-class
-    # NOTE: moving these closer to the
+class TokenReducedConvertBonded(LLMConvertBonded):
+    # NOTE: moving these closer to the TokenReducedConvertBonded class for now (in part to make
     # them easier to mock); we can make these parameterized as well
     MAX_HEATMAP_UPDATES: int = 5
     TOKEN_REDUCTION_SAMPLE: int = 0
@@ -90,9 +90,9 @@ class TokenReducedConvert(LLMConvert):
         naive_op_cost_estimates.quality_upper_bound = naive_op_cost_estimates.quality
 
         return naive_op_cost_estimates
-
+
     def is_image_conversion(self) -> bool:
-        """
+        """TokenReducedConvertBonded is currently disallowed on image conversions, so this must be False."""
         return False
 
     def reduce_context(self, full_context: str) -> str:
@@ -119,7 +119,9 @@ class TokenReducedConvert(LLMConvert):
         return sample
 
     def _dspy_generate_fields(self, prompt: str, content: str | list[str]) -> tuple[list[dict[str, list]] | Any]:
-        raise Exception(
+        raise Exception(
+            "TokenReducedConvertBonded is executing despite being deprecated until implementation changes can be made."
+        )
         answer, query_stats = None, None
         if self.first_execution or self.count < self.MAX_HEATMAP_UPDATES:
             if self.verbose:
@@ -165,11 +167,3 @@ class TokenReducedConvert(LLMConvert):
             self.heatmap[norm_si:norm_ei] = map(lambda x: x + 1, self.heatmap[norm_si:norm_ei])
 
         return answer, query_stats
-
-
-class TokenReducedConvertConventional(TokenReducedConvert, LLMConvertConventional):
-    pass
-
-
-class TokenReducedConvertBonded(TokenReducedConvert, LLMConvertBonded):
-    pass
```
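The net effect of this file's changes is a flatter class hierarchy: the shared `TokenReducedConvert` base and the two `pass`-body subclasses are gone, leaving a single `TokenReducedConvertBonded(LLMConvertBonded)`. A toy sketch of the before/after method resolution order, with illustrative stand-in classes rather than the real operators:

```python
class LLMConvert:  # stand-ins for the real convert operators
    pass


class LLMConvertBonded(LLMConvert):
    pass


# 0.6.3-style: shared token-reduction logic in a common base, combined
# with each convert flavor via multiple inheritance
class TokenReducedConvert(LLMConvert):
    MAX_HEATMAP_UPDATES = 5


class OldTokenReducedConvertBonded(TokenReducedConvert, LLMConvertBonded):
    pass


# 0.7.0-style: one concrete class with a single base
class NewTokenReducedConvertBonded(LLMConvertBonded):
    MAX_HEATMAP_UPDATES = 5


print([c.__name__ for c in OldTokenReducedConvertBonded.__mro__])
# ['OldTokenReducedConvertBonded', 'TokenReducedConvert', 'LLMConvertBonded', 'LLMConvert', 'object']
print([c.__name__ for c in NewTokenReducedConvertBonded.__mro__])
# ['NewTokenReducedConvertBonded', 'LLMConvertBonded', 'LLMConvert', 'object']
```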
**palimpzest/query/optimizer/__init__.py**

```diff
@@ -19,12 +19,6 @@ from palimpzest.query.optimizer.rules import (
 from palimpzest.query.optimizer.rules import (
     LLMConvertBondedRule as _LLMConvertBondedRule,
 )
-from palimpzest.query.optimizer.rules import (
-    LLMConvertConventionalRule as _LLMConvertConventionalRule,
-)
-from palimpzest.query.optimizer.rules import (
-    LLMConvertRule as _LLMConvertRule,
-)
 from palimpzest.query.optimizer.rules import (
     LLMFilterRule as _LLMFilterRule,
 )
@@ -50,13 +44,10 @@ from palimpzest.query.optimizer.rules import (
     Rule as _Rule,
 )
 from palimpzest.query.optimizer.rules import (
-
-)
-from palimpzest.query.optimizer.rules import (
-    TokenReducedConvertConventionalRule as _TokenReducedConvertConventionalRule,
+    SplitConvertRule as _SplitConvertRule,
 )
 from palimpzest.query.optimizer.rules import (
-
+    TokenReducedConvertBondedRule as _TokenReducedConvertBondedRule,
 )
 from palimpzest.query.optimizer.rules import (
     TransformationRule as _TransformationRule,
@@ -70,8 +61,6 @@ ALL_RULES = [
     _CriticAndRefineConvertRule,
     _ImplementationRule,
     _LLMConvertBondedRule,
-    _LLMConvertConventionalRule,
-    _LLMConvertRule,
     _LLMFilterRule,
     _MixtureOfAgentsConvertRule,
     _NonLLMConvertRule,
@@ -80,9 +69,8 @@ ALL_RULES = [
     _RAGConvertRule,
     _RetrieveRule,
     _Rule,
+    _SplitConvertRule,
     _TokenReducedConvertBondedRule,
-    _TokenReducedConvertConventionalRule,
-    _TokenReducedConvertRule,
     _TransformationRule,
 ]
 
@@ -90,7 +78,7 @@ IMPLEMENTATION_RULES = [
     rule
     for rule in ALL_RULES
     if issubclass(rule, _ImplementationRule)
-    and rule not in [_CodeSynthesisConvertRule, _ImplementationRule
+    and rule not in [_CodeSynthesisConvertRule, _ImplementationRule]
 ]
 
 TRANSFORMATION_RULES = [
```