palimpzest 0.6.3__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64) hide show
  1. palimpzest/__init__.py +5 -0
  2. palimpzest/constants.py +110 -43
  3. palimpzest/core/__init__.py +0 -78
  4. palimpzest/core/data/dataclasses.py +382 -44
  5. palimpzest/core/elements/filters.py +7 -3
  6. palimpzest/core/elements/index.py +70 -0
  7. palimpzest/core/elements/records.py +33 -11
  8. palimpzest/core/lib/fields.py +1 -0
  9. palimpzest/core/lib/schemas.py +4 -3
  10. palimpzest/prompts/moa_proposer_convert_prompts.py +0 -4
  11. palimpzest/prompts/prompt_factory.py +44 -7
  12. palimpzest/prompts/split_merge_prompts.py +56 -0
  13. palimpzest/prompts/split_proposer_prompts.py +55 -0
  14. palimpzest/query/execution/execution_strategy.py +435 -53
  15. palimpzest/query/execution/execution_strategy_type.py +20 -0
  16. palimpzest/query/execution/mab_execution_strategy.py +532 -0
  17. palimpzest/query/execution/parallel_execution_strategy.py +143 -172
  18. palimpzest/query/execution/random_sampling_execution_strategy.py +240 -0
  19. palimpzest/query/execution/single_threaded_execution_strategy.py +173 -203
  20. palimpzest/query/generators/api_client_factory.py +31 -0
  21. palimpzest/query/generators/generators.py +256 -76
  22. palimpzest/query/operators/__init__.py +1 -2
  23. palimpzest/query/operators/code_synthesis_convert.py +33 -18
  24. palimpzest/query/operators/convert.py +30 -97
  25. palimpzest/query/operators/critique_and_refine_convert.py +5 -6
  26. palimpzest/query/operators/filter.py +7 -10
  27. palimpzest/query/operators/logical.py +54 -10
  28. palimpzest/query/operators/map.py +130 -0
  29. palimpzest/query/operators/mixture_of_agents_convert.py +6 -6
  30. palimpzest/query/operators/physical.py +3 -12
  31. palimpzest/query/operators/rag_convert.py +66 -18
  32. palimpzest/query/operators/retrieve.py +230 -34
  33. palimpzest/query/operators/scan.py +5 -2
  34. palimpzest/query/operators/split_convert.py +169 -0
  35. palimpzest/query/operators/token_reduction_convert.py +8 -14
  36. palimpzest/query/optimizer/__init__.py +4 -16
  37. palimpzest/query/optimizer/cost_model.py +73 -266
  38. palimpzest/query/optimizer/optimizer.py +87 -58
  39. palimpzest/query/optimizer/optimizer_strategy.py +18 -97
  40. palimpzest/query/optimizer/optimizer_strategy_type.py +37 -0
  41. palimpzest/query/optimizer/plan.py +2 -3
  42. palimpzest/query/optimizer/primitives.py +5 -3
  43. palimpzest/query/optimizer/rules.py +336 -172
  44. palimpzest/query/optimizer/tasks.py +30 -100
  45. palimpzest/query/processor/config.py +38 -22
  46. palimpzest/query/processor/nosentinel_processor.py +16 -520
  47. palimpzest/query/processor/processing_strategy_type.py +28 -0
  48. palimpzest/query/processor/query_processor.py +38 -206
  49. palimpzest/query/processor/query_processor_factory.py +117 -130
  50. palimpzest/query/processor/sentinel_processor.py +90 -0
  51. palimpzest/query/processor/streaming_processor.py +25 -32
  52. palimpzest/sets.py +88 -41
  53. palimpzest/utils/model_helpers.py +8 -7
  54. palimpzest/utils/progress.py +368 -152
  55. palimpzest/utils/token_reduction_helpers.py +1 -3
  56. {palimpzest-0.6.3.dist-info → palimpzest-0.7.0.dist-info}/METADATA +28 -24
  57. palimpzest-0.7.0.dist-info/RECORD +96 -0
  58. {palimpzest-0.6.3.dist-info → palimpzest-0.7.0.dist-info}/WHEEL +1 -1
  59. palimpzest/query/processor/mab_sentinel_processor.py +0 -884
  60. palimpzest/query/processor/random_sampling_sentinel_processor.py +0 -639
  61. palimpzest/utils/index_helpers.py +0 -6
  62. palimpzest-0.6.3.dist-info/RECORD +0 -87
  63. {palimpzest-0.6.3.dist-info → palimpzest-0.7.0.dist-info/licenses}/LICENSE +0 -0
  64. {palimpzest-0.6.3.dist-info → palimpzest-0.7.0.dist-info}/top_level.txt +0 -0
@@ -2,31 +2,73 @@ from __future__ import annotations
2
2
 
3
3
  import os
4
4
  import time
5
+ from typing import Callable
5
6
 
6
- from palimpzest.core.data.dataclasses import OperatorCostEstimates, RecordOpStats
7
+ from chromadb.api.models.Collection import Collection
8
+ from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction
9
+ from chromadb.utils.embedding_functions.openai_embedding_function import OpenAIEmbeddingFunction
10
+ from openai import OpenAI
11
+ from ragatouille.RAGPretrainedModel import RAGPretrainedModel
12
+ from sentence_transformers import SentenceTransformer
13
+
14
+ from palimpzest.constants import MODEL_CARDS, Model
15
+ from palimpzest.core.data.dataclasses import GenerationStats, OperatorCostEstimates, RecordOpStats
7
16
  from palimpzest.core.elements.records import DataRecord, DataRecordSet
17
+ from palimpzest.core.lib.schemas import Schema
8
18
  from palimpzest.query.operators.physical import PhysicalOperator
9
19
 
10
20
 
11
21
  class RetrieveOp(PhysicalOperator):
12
- def __init__(self, index, search_func, search_attr, output_attr, k, *args, **kwargs):
22
def __init__(
    self,
    index: Collection | RAGPretrainedModel,
    search_attr: str,
    output_attrs: list[dict] | type[Schema],
    search_func: Callable | None,
    k: int,
    *args,
    **kwargs,
) -> None:
    """
    Initialize the RetrieveOp object.

    Args:
        index (Collection | RAGPretrainedModel): The PZ index to use for retrieval.
        search_attr (str): The attribute to search on.
        output_attrs (list[dict] | type[Schema]): The output fields containing the results of the
            search, given either as a list of dicts (each with a "name" key) or as a Schema.
        search_func (Callable | None): The function to use for searching the index. If None,
            `self.default_search_func` will be used.
        k (int): The number of top results to retrieve.

    Raises:
        ValueError: If `output_attrs` is neither a list nor a Schema, or if `search_func` is
            None and `output_attrs` does not describe exactly one field.
    """
    super().__init__(*args, **kwargs)

    # extract the field names from the output_attrs; the annotation permits a Schema *class*,
    # which ordinarily fails an `isinstance(..., Schema)` check, so accept subclasses as well
    is_schema_class = isinstance(output_attrs, type) and issubclass(output_attrs, Schema)
    if is_schema_class or isinstance(output_attrs, Schema):
        self.output_field_names = output_attrs.field_names()
    elif isinstance(output_attrs, list):
        self.output_field_names = [attr["name"] for attr in output_attrs]
    else:
        raise ValueError("`output_attrs` must be a list of dicts or a Schema object.")

    # the default search function writes its results to exactly one output field
    if len(self.output_field_names) != 1 and search_func is None:
        raise ValueError("If `search_func` is None, `output_attrs` must have a single field.")

    self.index = index
    self.search_attr = search_attr
    self.output_attrs = output_attrs
    self.search_func = search_func if search_func is not None else self.default_search_func
    self.k = k
19
60
 
20
61
  def __str__(self):
21
62
  op = super().__str__()
22
- op += f" Retrieve: {str(self.index)} with top {self.k}\n"
63
+ op += f" Retrieve: {self.index.__class__.__name__} with top {self.k}\n"
23
64
  return op
24
65
 
25
66
  def get_id_params(self):
26
67
  id_params = super().get_id_params()
27
68
  id_params = {
69
+ "index": self.index.__class__.__name__,
28
70
  "search_attr": self.search_attr,
29
- "output_attr": self.output_attr,
71
+ "output_attrs": self.output_attrs,
30
72
  "k": self.k,
31
73
  **id_params,
32
74
  }
@@ -39,7 +81,7 @@ class RetrieveOp(PhysicalOperator):
39
81
  "index": self.index,
40
82
  "search_func": self.search_func,
41
83
  "search_attr": self.search_attr,
42
- "output_attr": self.output_attr,
84
+ "output_attrs": self.output_attrs,
43
85
  "k": self.k,
44
86
  **op_params,
45
87
  }
@@ -53,37 +95,86 @@ class RetrieveOp(PhysicalOperator):
53
95
  """
54
96
  return OperatorCostEstimates(
55
97
  cardinality=source_op_cost_estimates.cardinality,
56
- time_per_record=0.001, # estimate 1 ms single-threaded execution for index lookup
57
- cost_per_record=0.0,
98
+ time_per_record=0.01 * self.k, # estimate 10 ms execution lookup per output
99
+ cost_per_record=0.001 * self.k, # estimate small marginal cost of lookups
58
100
  quality=1.0,
59
101
  )
60
102
 
61
- def __call__(self, candidate: DataRecord) -> DataRecordSet:
62
- start_time = time.time()
103
def default_search_func(
    self,
    index: Collection | RAGPretrainedModel,
    query: list[str] | list[list[float]],
    k: int,
) -> dict[str, list[str] | list[list[str]]]:
    """
    Default search function for the Retrieve operation. This function uses the index to
    retrieve the top-k results for the given query. The query will be a (possibly singleton)
    list of strings or a list of lists of floats (i.e., embeddings).

    The results are returned in a dictionary keyed by the operator's single output field name.
    If the input is a singleton list, the value is a flat list of strings; otherwise the value
    is a list with one list of strings per query.

    Args:
        index (Collection | RAGPretrainedModel): The index to use for retrieval.
        query (list[str] | list[list[float]]): The query (or queries) to search for.
        k (int): The maximum number of results the retrieve operator will return.

    Returns:
        dict[str, list[str] | list[list[str]]]: A mapping from the output field name to the
            results produced by the index for the query (or queries).

    Raises:
        ValueError: If `index` is neither a chromadb Collection nor a RAGPretrainedModel.
    """
    # a single query yields a flat list of results; multiple queries yield one list per query
    is_singleton_list = len(query) == 1

    if isinstance(index, Collection):
        # the first positional parameter of Collection.query is `query_embeddings`, so string
        # queries must be passed explicitly as `query_texts` to be embedded by the collection
        if len(query) > 0 and isinstance(query[0], str):
            results = index.query(query_texts=query, n_results=k)
        else:
            results = index.query(query_embeddings=query, n_results=k)

        # results["documents"] is a list[list[str]] with one inner list per query
        documents = results["documents"]
        final_results = documents[0] if is_singleton_list else documents

    elif isinstance(index, RAGPretrainedModel):
        results = index.search(query, k=k)

        # the results are a list[dict] when the input is a singleton list, and a
        # list[list[dict]] (one inner list per query) otherwise
        if is_singleton_list:
            final_results = [result["content"] for result in results]
        else:
            final_results = [[result["content"] for result in query_results] for query_results in results]

    else:
        raise ValueError("Unsupported index type. Must be either a Collection or RAGPretrainedModel.")

    # NOTE: self.output_field_names must be a singleton for default_search_func to be used
    return {self.output_field_names[0]: final_results}
153
+
154
+ def _create_record_set(
155
+ self,
156
+ candidate: DataRecord,
157
+ top_k_results: dict[str, list[str] | list[list[str]]] | None,
158
+ generation_stats: GenerationStats,
159
+ total_time: float,
160
+ ) -> DataRecordSet:
161
+ """
162
+ Given an input DataRecord and the top_k_results, construct the resulting RecordSet.
163
+ """
164
+ # create output DataRecord and set the output attribute
165
+ output_dr, answer = DataRecord.from_parent(self.output_schema, parent_record=candidate), {}
166
+ for output_field_name in self.output_field_names:
167
+ top_k_attr_results = None if top_k_results is None else top_k_results[output_field_name]
168
+ setattr(output_dr, output_field_name, top_k_attr_results)
169
+ answer[output_field_name] = top_k_attr_results
77
170
 
78
- duration_secs = time.time() - start_time
79
- answer = {self.output_attr: top_k_results}
171
+ # get the record_state and generated fields
80
172
  record_state = output_dr.to_dict(include_bytes=False)
81
173
 
82
- # NOTE: right now this should be equivalent to [self.output_attr], but in the future we may
83
- # want to support the RetrieveOp generating multiple fields. (Also, the function will
84
- # return the full field name (as opposed to the short field name))
85
- generated_fields = self.get_fields_to_generate(candidate)
174
+ # NOTE: this should be equivalent to self.get_fields_to_generate()
175
+ generated_fields = self.output_field_names
86
176
 
177
+ # construct the RecordOpStats object
87
178
  record_op_stats = RecordOpStats(
88
179
  record_id=output_dr.id,
89
180
  record_parent_id=output_dr.parent_id,
@@ -92,19 +183,124 @@ class RetrieveOp(PhysicalOperator):
92
183
  op_id=self.get_op_id(),
93
184
  logical_op_id=self.logical_op_id,
94
185
  op_name=self.op_name(),
95
- time_per_record=duration_secs,
96
- cost_per_record=0.0,
186
+ time_per_record=total_time,
187
+ cost_per_record=generation_stats.cost_per_record,
97
188
  answer=answer,
98
189
  input_fields=self.input_schema.field_names(),
99
190
  generated_fields=generated_fields,
100
- fn_call_duration_secs=duration_secs,
191
+ fn_call_duration_secs=total_time - generation_stats.llm_call_duration_secs,
192
+ llm_call_duration_secs=generation_stats.llm_call_duration_secs,
193
+ total_llm_calls=generation_stats.total_llm_calls,
194
+ total_embedding_llm_calls=generation_stats.total_embedding_llm_calls,
101
195
  op_details={k: str(v) for k, v in self.get_id_params().items()},
102
196
  )
103
197
 
104
198
  drs = [output_dr]
105
199
  record_op_stats_lst = [record_op_stats]
106
200
 
107
- # construct record set
108
- record_set = DataRecordSet(drs, record_op_stats_lst)
201
+ # construct and return the record set
202
+ return DataRecordSet(drs, record_op_stats_lst)
109
203
 
110
- return record_set
204
+
205
def __call__(self, candidate: DataRecord) -> DataRecordSet:
    """
    Run the retrieve operation on a single input record.

    The value of `self.search_attr` on the candidate is used as the query. If the index is a
    chromadb Collection, the query is first embedded (with OpenAI or SentenceTransformer) and
    the embeddings are passed to `self.search_func`; otherwise the query strings are passed.
    Each output field is truncated to at most `self.k` results (per query).

    Args:
        candidate (DataRecord): The input record.

    Returns:
        DataRecordSet: A record set with a single output record. The output fields are None if
            the query is not a string or list of strings, and empty lists if embedding or
            searching failed (in which case the query is written under `retrieve-errors/`).
    """
    start_time = time.time()

    # check that query is a string or list of strings, otherwise return output with self.output_field_names set to None
    query = getattr(candidate, self.search_attr)
    query_is_str = isinstance(query, str)
    query_is_list_of_str = isinstance(query, list) and all(isinstance(q, str) for q in query)
    if not query_is_str and not query_is_list_of_str:
        return self._create_record_set(
            candidate=candidate,
            top_k_results=None,
            generation_stats=GenerationStats(),
            total_time=time.time() - start_time,
        )

    # if query is a string, convert it to a list of strings
    if query_is_str:
        query = [query]

    # compute input/query embedding(s) if the index is a chromadb collection
    inputs, gen_stats, embedding_failed = None, GenerationStats(), False
    if isinstance(self.index, Collection):
        uses_openai_embedding_fcn = isinstance(self.index._embedding_function, OpenAIEmbeddingFunction)
        uses_sentence_transformer_embedding_fcn = isinstance(self.index._embedding_function, SentenceTransformerEmbeddingFunction)
        error_msg = "ChromaDB index must use OpenAI or SentenceTransformer embedding function; see: https://docs.trychroma.com/integrations/embedding-models/openai"
        assert uses_openai_embedding_fcn or uses_sentence_transformer_embedding_fcn, error_msg

        model_name = self.index._embedding_function._model_name if uses_openai_embedding_fcn else "clip-ViT-B-32"
        err_msg = f"For Chromadb, we currently only support `text-embedding-3-small` and `clip-ViT-B-32`; your index uses: {model_name}"
        assert model_name in [Model.TEXT_EMBEDDING_3_SMALL.value, Model.CLIP_VIT_B_32.value], err_msg

        # compute embeddings
        try:
            embed_start_time = time.time()
            total_input_tokens = 0.0
            if uses_openai_embedding_fcn:
                client = OpenAI()
                response = client.embeddings.create(input=query, model=model_name)
                total_input_tokens = response.usage.total_tokens
                inputs = [item.embedding for item in response.data]

            elif uses_sentence_transformer_embedding_fcn:
                model = SentenceTransformer(model_name)
                inputs = model.encode(query)

            embed_total_time = time.time() - embed_start_time

            # compute cost of embedding(s)
            model_card = MODEL_CARDS[model_name]
            total_input_cost = model_card["usd_per_input_token"] * total_input_tokens
            gen_stats = GenerationStats(
                model_name=model_name,
                total_input_tokens=total_input_tokens,
                total_output_tokens=0.0,
                total_input_cost=total_input_cost,
                total_output_cost=0.0,
                cost_per_record=total_input_cost,
                llm_call_duration_secs=embed_total_time,
                total_llm_calls=1,
                total_embedding_llm_calls=len(query),
            )
        except Exception:
            # remember the failure instead of discarding `query`, so that the error log below
            # still records the query which could not be embedded
            embedding_failed = True

    # in the default case, pass string inputs rather than embeddings
    if inputs is None and not embedding_failed:
        inputs = query

    # run the search; a missing input (embedding generation failed) is handled like a failed
    # search, and an empty mapping makes every output field fall back to an empty list below
    top_results, search_failed = {}, inputs is None
    if not search_failed:
        try:
            top_results = self.search_func(self.index, inputs, self.k)
        except Exception:
            top_results, search_failed = {}, True

    if search_failed:
        os.makedirs("retrieve-errors", exist_ok=True)
        ts = time.time()
        with open(f"retrieve-errors/error-{ts}.txt", "w", encoding="utf-8") as f:
            f.write(str(query))

    # TODO: the user is always right! let's drop this post-processing in the future
    # filter top_results for the top_k_results
    top_k_results = {}
    for output_field_name in self.output_field_names:
        field_results = top_results[output_field_name] if output_field_name in top_results else []
        if all(isinstance(result, list) for result in field_results):
            # one list of results per query: truncate each of them
            top_k_results[output_field_name] = [result[: self.k] for result in field_results]
        else:
            top_k_results[output_field_name] = field_results[: self.k]

    if self.verbose:
        print(f"Top {self.k} results: {top_k_results}")

    # construct and return the record set
    return self._create_record_set(
        candidate=candidate,
        top_k_results=top_k_results,
        generation_stats=gen_stats,
        total_time=time.time() - start_time,
    )
@@ -69,14 +69,17 @@ class ScanPhysicalOp(PhysicalOperator, ABC):
69
69
  item = self.datareader[idx]
70
70
  end_time = time.time()
71
71
 
72
+ # TODO: remove once validation data is refactored
73
+ item_field_dict = item.get("fields", item)
74
+
72
75
  # check that item covers fields in output schema
73
76
  output_field_names = self.output_schema.field_names()
74
- assert all([field in item for field in output_field_names]), f"Some fields in DataReader schema not present in item!\n - DataReader fields: {output_field_names}\n - Item fields: {list(item.keys())}"
77
+ assert all([field in item_field_dict for field in output_field_names]), f"Some fields in DataReader schema not present in item!\n - DataReader fields: {output_field_names}\n - Item fields: {list(item.keys())}"
75
78
 
76
79
  # construct a DataRecord from the item
77
80
  dr = DataRecord(self.output_schema, source_idx=idx)
78
81
  for field in output_field_names:
79
- setattr(dr, field, item[field])
82
+ setattr(dr, field, item_field_dict[field])
80
83
 
81
84
  # create RecordOpStats objects
82
85
  record_op_stats = RecordOpStats(
@@ -0,0 +1,169 @@
1
+ from __future__ import annotations
2
+
3
+ import math
4
+
5
+ from palimpzest.constants import (
6
+ MODEL_CARDS,
7
+ NAIVE_EST_NUM_INPUT_TOKENS,
8
+ NAIVE_EST_NUM_OUTPUT_TOKENS,
9
+ PromptStrategy,
10
+ )
11
+ from palimpzest.core.data.dataclasses import GenerationStats, OperatorCostEstimates
12
+ from palimpzest.core.elements.records import DataRecord
13
+ from palimpzest.core.lib.fields import Field, StringField
14
+ from palimpzest.query.generators.generators import generator_factory
15
+ from palimpzest.query.operators.convert import LLMConvert
16
+
17
+
18
class SplitConvert(LLMConvert):
    """
    LLM convert operator which splits long text inputs into chunks, runs the split-proposer
    generator on each chunk, and then runs the split-merger generator on the collected
    per-chunk outputs to produce the final field answers.
    """

    def __init__(self, num_chunks: int = 2, min_size_to_chunk: int = 1000, *args, **kwargs):
        """
        Args:
            num_chunks (int): The number of chunks to split each sufficiently long text field into.
            min_size_to_chunk (int): Fields whose content is shorter than this many characters
                are passed through without being chunked.
        """
        super().__init__(*args, **kwargs)
        self.num_chunks = num_chunks
        self.min_size_to_chunk = min_size_to_chunk
        self.split_generator = generator_factory(self.model, PromptStrategy.SPLIT_PROPOSER, self.cardinality, self.verbose)
        self.split_merge_generator = generator_factory(self.model, PromptStrategy.SPLIT_MERGER, self.cardinality, self.verbose)

        # crude adjustment factor for naive estimation in no-sentinel setting
        self.naive_quality_adjustment = 0.6

    def __str__(self):
        """Return the parent operator description followed by the chunking parameters."""
        op = super().__str__()
        op += f" Num Chunks: {str(self.num_chunks)}\n"
        op += f" Min Size to Chunk: {str(self.min_size_to_chunk)}\n"
        return op

    def get_id_params(self):
        """Return the parent's id params extended with the chunking parameters."""
        id_params = super().get_id_params()
        id_params = {"num_chunks": self.num_chunks, "min_size_to_chunk": self.min_size_to_chunk, **id_params}

        return id_params

    def get_op_params(self):
        """Return the parent's op params extended with the chunking parameters."""
        op_params = super().get_op_params()
        return {"num_chunks": self.num_chunks, "min_size_to_chunk": self.min_size_to_chunk, **op_params}

    def naive_cost_estimates(self, source_op_cost_estimates: OperatorCostEstimates) -> OperatorCostEstimates:
        """
        Update the cost per record and quality estimates produced by LLMConvert's naive estimates.
        The cost per record is re-computed from the naive input / output token estimates and this
        model's per-token prices, and the quality is scaled by `self.naive_quality_adjustment`.

        NOTE(review): the cost is priced as a single LLM call; it does not appear to account for
        one proposer call per chunk plus the merger call -- confirm this is intended.
        """
        # get naive cost estimates from LLMConvert
        naive_op_cost_estimates = super().naive_cost_estimates(source_op_cost_estimates)

        # re-compute cost per record; naively assume a single input field
        est_num_input_tokens = NAIVE_EST_NUM_INPUT_TOKENS
        est_num_output_tokens = NAIVE_EST_NUM_OUTPUT_TOKENS
        model_conversion_usd_per_record = (
            MODEL_CARDS[self.model.value]["usd_per_input_token"] * est_num_input_tokens
            + MODEL_CARDS[self.model.value]["usd_per_output_token"] * est_num_output_tokens
        )

        # set refined estimate of cost per record and, for now, scale the quality by a
        # constant adjustment factor; the bounds collapse onto the point estimates
        naive_op_cost_estimates.cost_per_record = model_conversion_usd_per_record
        naive_op_cost_estimates.cost_per_record_lower_bound = naive_op_cost_estimates.cost_per_record
        naive_op_cost_estimates.cost_per_record_upper_bound = naive_op_cost_estimates.cost_per_record
        naive_op_cost_estimates.quality = (naive_op_cost_estimates.quality) * self.naive_quality_adjustment
        naive_op_cost_estimates.quality_lower_bound = naive_op_cost_estimates.quality
        naive_op_cost_estimates.quality_upper_bound = naive_op_cost_estimates.quality

        return naive_op_cost_estimates

    def is_image_conversion(self) -> bool:
        """SplitConvert is currently disallowed on image conversions, so this must be False."""
        return False

    def get_text_chunks(self, text: str, num_chunks: int) -> list[str]:
        """
        Given a text string, chunk it into (at most) num_chunks substrings of roughly equal size.
        Because the chunk size is rounded up, fewer than num_chunks chunks may be returned; an
        empty string yields an empty list.

        Raises:
            ValueError: If num_chunks is less than 1.
        """
        # a non-positive chunk count would divide by zero or never advance through the text
        if num_chunks < 1:
            raise ValueError(f"`num_chunks` must be at least 1; got: {num_chunks}")

        chunk_size = math.ceil(len(text) / num_chunks)
        if chunk_size == 0:
            return []

        return [text[idx : idx + chunk_size] for idx in range(0, len(text), chunk_size)]

    def get_chunked_candidate(self, candidate: DataRecord, input_fields: list[str]) -> list[DataRecord]:
        """
        For each text field, chunk the content. If a field is smaller than the min size to
        chunk (or is not a string / list of strings), simply include the full field in every
        chunked candidate.

        Returns:
            list[DataRecord]: One copy of the candidate per chunk index, with each input field
                replaced by its content for that chunk.
        """
        # compute mapping from each field to its chunked content
        field_name_to_chunked_content = {}
        for field_name in input_fields:
            field = candidate.get_field_type(field_name)
            content = candidate[field_name]

            # do not chunk this field if it is empty or not a string or a list of strings
            is_string_field = isinstance(field, StringField)
            is_list_string_field = hasattr(field, "element_type") and isinstance(field.element_type, StringField)
            if content is None or not (is_string_field or is_list_string_field):
                field_name_to_chunked_content[field_name] = [content]
                continue

            # if this is a list of strings, join the strings
            if is_list_string_field:
                content = "[" + ", ".join(content) + "]"

            # skip this field if its length is less than the min size to chunk
            if len(content) < self.min_size_to_chunk:
                field_name_to_chunked_content[field_name] = [content]
                continue

            # chunk the content; fall back to the full content if no chunks were produced
            field_name_to_chunked_content[field_name] = self.get_text_chunks(content, self.num_chunks) or [content]

        # compute the true number of chunks (may be 1 if all fields are not chunked)
        num_chunks = max(len(chunks) for chunks in field_name_to_chunked_content.values())

        # create the chunked candidates
        candidates = []
        for chunk_idx in range(num_chunks):
            candidate_copy = candidate.copy()
            for field_name in input_fields:
                field_chunks = field_name_to_chunked_content[field_name]

                # un-chunked fields hold a single entry which every candidate shares; a chunked
                # field may also have fewer chunks than the longest one (chunk sizes are rounded
                # up), so clamp the index instead of running past the end of the list
                candidate_copy[field_name] = field_chunks[min(chunk_idx, len(field_chunks) - 1)]

            candidates.append(candidate_copy)

        return candidates

    def convert(self, candidate: DataRecord, fields: dict[str, Field]) -> tuple[dict[str, list], GenerationStats]:
        """
        Generate the output fields by running the proposer on each chunk of the input and
        merging the per-chunk outputs.

        Returns:
            tuple[dict[str, list], GenerationStats]: The merged field answers and the combined
                generation stats of all proposer calls and the merger call.
        """
        # get the set of input fields to use for the convert operation
        input_fields = self.get_input_fields()

        # split the input fields of (a copy of) the candidate into chunks
        candidate_copy = candidate.copy()
        chunked_candidates = self.get_chunked_candidate(candidate_copy, input_fields)

        # construct kwargs for generation
        gen_kwargs = {"project_cols": input_fields, "output_schema": self.output_schema}

        # generate outputs for each chunk separately
        chunk_outputs, chunk_generation_stats_lst = [], []
        for chunk_candidate in chunked_candidates:
            _, reasoning, chunk_generation_stats, _ = self.split_generator(chunk_candidate, fields, json_output=False, **gen_kwargs)
            chunk_outputs.append(reasoning)
            chunk_generation_stats_lst.append(chunk_generation_stats)

        # NOTE(review): the merger has always been invoked with the last chunked candidate
        # (the loop variable used to shadow `candidate`); this is preserved here -- confirm
        # whether the original candidate was intended instead
        merge_candidate = chunked_candidates[-1] if chunked_candidates else candidate

        # call the merger
        gen_kwargs = {
            "project_cols": input_fields,
            "output_schema": self.output_schema,
            "chunk_outputs": chunk_outputs,
        }
        field_answers, _, merger_gen_stats, _ = self.split_merge_generator(merge_candidate, fields, **gen_kwargs)

        # compute the total generation stats
        generation_stats = sum(chunk_generation_stats_lst) + merger_gen_stats

        return field_answers, generation_stats
@@ -9,7 +9,7 @@ from palimpzest.constants import (
9
9
  NAIVE_EST_NUM_OUTPUT_TOKENS,
10
10
  )
11
11
  from palimpzest.core.data.dataclasses import OperatorCostEstimates
12
- from palimpzest.query.operators.convert import LLMConvert, LLMConvertBonded, LLMConvertConventional
12
+ from palimpzest.query.operators.convert import LLMConvertBonded
13
13
  from palimpzest.utils.token_reduction_helpers import best_substring_match, find_best_range
14
14
 
15
15
 
@@ -32,8 +32,8 @@ from palimpzest.utils.token_reduction_helpers import best_substring_match, find_
32
32
  # - this also creates difficulties in properly performing cost-estimation for this operator; e.g. if we use
33
33
  # n <= MAX_HEATMAP_UPDATES samples to cost this operator, then we will never actually measure its performance
34
34
  # in the token reduction phase -- which could have a serious degradation in quality that our optimizer doesn't see
35
- class TokenReducedConvert(LLMConvert):
36
- # NOTE: moving these closer to the TokenReducedConvert class for now (in part to make
35
+ class TokenReducedConvertBonded(LLMConvertBonded):
36
+ # NOTE: moving these closer to the TokenReducedConvertBonded class for now (in part to make
37
37
  # them easier to mock); we can make these parameterized as well
38
38
  MAX_HEATMAP_UPDATES: int = 5
39
39
  TOKEN_REDUCTION_SAMPLE: int = 0
@@ -90,9 +90,9 @@ class TokenReducedConvert(LLMConvert):
90
90
  naive_op_cost_estimates.quality_upper_bound = naive_op_cost_estimates.quality
91
91
 
92
92
  return naive_op_cost_estimates
93
-
93
+
94
94
  def is_image_conversion(self) -> bool:
95
- """TokenReducedConvert is currently disallowed on image conversions, so this must be False."""
95
+ """TokenReducedConvertBonded is currently disallowed on image conversions, so this must be False."""
96
96
  return False
97
97
 
98
98
  def reduce_context(self, full_context: str) -> str:
@@ -119,7 +119,9 @@ class TokenReducedConvert(LLMConvert):
119
119
  return sample
120
120
 
121
121
  def _dspy_generate_fields(self, prompt: str, content: str | list[str]) -> tuple[list[dict[str, list]] | Any]:
122
- raise Exception("TokenReducedConvert is executing despite being deprecated until implementation changes can be made.")
122
+ raise Exception(
123
+ "TokenReducedConvertBonded is executing despite being deprecated until implementation changes can be made."
124
+ )
123
125
  answer, query_stats = None, None
124
126
  if self.first_execution or self.count < self.MAX_HEATMAP_UPDATES:
125
127
  if self.verbose:
@@ -165,11 +167,3 @@ class TokenReducedConvert(LLMConvert):
165
167
  self.heatmap[norm_si:norm_ei] = map(lambda x: x + 1, self.heatmap[norm_si:norm_ei])
166
168
 
167
169
  return answer, query_stats
168
-
169
-
170
- class TokenReducedConvertConventional(TokenReducedConvert, LLMConvertConventional):
171
- pass
172
-
173
-
174
- class TokenReducedConvertBonded(TokenReducedConvert, LLMConvertBonded):
175
- pass
@@ -19,12 +19,6 @@ from palimpzest.query.optimizer.rules import (
19
19
  from palimpzest.query.optimizer.rules import (
20
20
  LLMConvertBondedRule as _LLMConvertBondedRule,
21
21
  )
22
- from palimpzest.query.optimizer.rules import (
23
- LLMConvertConventionalRule as _LLMConvertConventionalRule,
24
- )
25
- from palimpzest.query.optimizer.rules import (
26
- LLMConvertRule as _LLMConvertRule,
27
- )
28
22
  from palimpzest.query.optimizer.rules import (
29
23
  LLMFilterRule as _LLMFilterRule,
30
24
  )
@@ -50,13 +44,10 @@ from palimpzest.query.optimizer.rules import (
50
44
  Rule as _Rule,
51
45
  )
52
46
  from palimpzest.query.optimizer.rules import (
53
- TokenReducedConvertBondedRule as _TokenReducedConvertBondedRule,
54
- )
55
- from palimpzest.query.optimizer.rules import (
56
- TokenReducedConvertConventionalRule as _TokenReducedConvertConventionalRule,
47
+ SplitConvertRule as _SplitConvertRule,
57
48
  )
58
49
  from palimpzest.query.optimizer.rules import (
59
- TokenReducedConvertRule as _TokenReducedConvertRule,
50
+ TokenReducedConvertBondedRule as _TokenReducedConvertBondedRule,
60
51
  )
61
52
  from palimpzest.query.optimizer.rules import (
62
53
  TransformationRule as _TransformationRule,
@@ -70,8 +61,6 @@ ALL_RULES = [
70
61
  _CriticAndRefineConvertRule,
71
62
  _ImplementationRule,
72
63
  _LLMConvertBondedRule,
73
- _LLMConvertConventionalRule,
74
- _LLMConvertRule,
75
64
  _LLMFilterRule,
76
65
  _MixtureOfAgentsConvertRule,
77
66
  _NonLLMConvertRule,
@@ -80,9 +69,8 @@ ALL_RULES = [
80
69
  _RAGConvertRule,
81
70
  _RetrieveRule,
82
71
  _Rule,
72
+ _SplitConvertRule,
83
73
  _TokenReducedConvertBondedRule,
84
- _TokenReducedConvertConventionalRule,
85
- _TokenReducedConvertRule,
86
74
  _TransformationRule,
87
75
  ]
88
76
 
@@ -90,7 +78,7 @@ IMPLEMENTATION_RULES = [
90
78
  rule
91
79
  for rule in ALL_RULES
92
80
  if issubclass(rule, _ImplementationRule)
93
- and rule not in [_CodeSynthesisConvertRule, _ImplementationRule, _LLMConvertRule, _TokenReducedConvertRule]
81
+ and rule not in [_CodeSynthesisConvertRule, _ImplementationRule]
94
82
  ]
95
83
 
96
84
  TRANSFORMATION_RULES = [