palimpzest 0.8.2__py3-none-any.whl → 0.8.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57) hide show
  1. palimpzest/constants.py +38 -62
  2. palimpzest/core/data/iter_dataset.py +5 -5
  3. palimpzest/core/elements/groupbysig.py +1 -1
  4. palimpzest/core/elements/records.py +91 -109
  5. palimpzest/core/lib/schemas.py +23 -0
  6. palimpzest/core/models.py +3 -3
  7. palimpzest/prompts/__init__.py +2 -6
  8. palimpzest/prompts/convert_prompts.py +10 -66
  9. palimpzest/prompts/critique_and_refine_prompts.py +66 -0
  10. palimpzest/prompts/filter_prompts.py +8 -46
  11. palimpzest/prompts/join_prompts.py +12 -75
  12. palimpzest/prompts/{moa_aggregator_convert_prompts.py → moa_aggregator_prompts.py} +51 -2
  13. palimpzest/prompts/moa_proposer_prompts.py +87 -0
  14. palimpzest/prompts/prompt_factory.py +351 -479
  15. palimpzest/prompts/split_merge_prompts.py +51 -2
  16. palimpzest/prompts/split_proposer_prompts.py +48 -16
  17. palimpzest/prompts/utils.py +109 -0
  18. palimpzest/query/execution/all_sample_execution_strategy.py +1 -1
  19. palimpzest/query/execution/execution_strategy.py +4 -4
  20. palimpzest/query/execution/mab_execution_strategy.py +1 -2
  21. palimpzest/query/execution/parallel_execution_strategy.py +3 -3
  22. palimpzest/query/execution/single_threaded_execution_strategy.py +8 -8
  23. palimpzest/query/generators/generators.py +31 -17
  24. palimpzest/query/operators/__init__.py +15 -2
  25. palimpzest/query/operators/aggregate.py +21 -19
  26. palimpzest/query/operators/compute.py +6 -8
  27. palimpzest/query/operators/convert.py +12 -37
  28. palimpzest/query/operators/critique_and_refine.py +194 -0
  29. palimpzest/query/operators/distinct.py +7 -7
  30. palimpzest/query/operators/filter.py +13 -25
  31. palimpzest/query/operators/join.py +321 -192
  32. palimpzest/query/operators/limit.py +4 -4
  33. palimpzest/query/operators/mixture_of_agents.py +246 -0
  34. palimpzest/query/operators/physical.py +25 -2
  35. palimpzest/query/operators/project.py +4 -4
  36. palimpzest/query/operators/{rag_convert.py → rag.py} +202 -5
  37. palimpzest/query/operators/retrieve.py +10 -9
  38. palimpzest/query/operators/scan.py +9 -10
  39. palimpzest/query/operators/search.py +18 -24
  40. palimpzest/query/operators/split.py +321 -0
  41. palimpzest/query/optimizer/__init__.py +12 -8
  42. palimpzest/query/optimizer/optimizer.py +12 -10
  43. palimpzest/query/optimizer/rules.py +201 -108
  44. palimpzest/query/optimizer/tasks.py +18 -6
  45. palimpzest/validator/validator.py +7 -9
  46. {palimpzest-0.8.2.dist-info → palimpzest-0.8.3.dist-info}/METADATA +3 -8
  47. palimpzest-0.8.3.dist-info/RECORD +95 -0
  48. palimpzest/prompts/critique_and_refine_convert_prompts.py +0 -216
  49. palimpzest/prompts/moa_proposer_convert_prompts.py +0 -75
  50. palimpzest/prompts/util_phrases.py +0 -19
  51. palimpzest/query/operators/critique_and_refine_convert.py +0 -113
  52. palimpzest/query/operators/mixture_of_agents_convert.py +0 -140
  53. palimpzest/query/operators/split_convert.py +0 -170
  54. palimpzest-0.8.2.dist-info/RECORD +0 -95
  55. {palimpzest-0.8.2.dist-info → palimpzest-0.8.3.dist-info}/WHEEL +0 -0
  56. {palimpzest-0.8.2.dist-info → palimpzest-0.8.3.dist-info}/licenses/LICENSE +0 -0
  57. {palimpzest-0.8.2.dist-info → palimpzest-0.8.3.dist-info}/top_level.txt +0 -0
@@ -36,13 +36,13 @@ class LimitScanOp(PhysicalOperator):
36
36
  # NOTE: execution layer ensures that no more than self.limit
37
37
  # records are returned to the user by this operator.
38
38
  # create new DataRecord
39
- dr = DataRecord.from_parent(schema=candidate.schema, parent_record=candidate)
39
+ dr = DataRecord.from_parent(schema=candidate.schema, data_item={}, parent_record=candidate)
40
40
 
41
41
  # create RecordOpStats object
42
42
  record_op_stats = RecordOpStats(
43
- record_id=dr.id,
44
- record_parent_ids=dr.parent_ids,
45
- record_source_indices=dr.source_indices,
43
+ record_id=dr._id,
44
+ record_parent_ids=dr._parent_ids,
45
+ record_source_indices=dr._source_indices,
46
46
  record_state=dr.to_dict(include_bytes=False),
47
47
  full_op_id=self.get_full_op_id(),
48
48
  logical_op_id=self.logical_op_id,
@@ -0,0 +1,246 @@
1
+ from __future__ import annotations
2
+
3
+ from pydantic.fields import FieldInfo
4
+
5
+ from palimpzest.constants import MODEL_CARDS, Cardinality, Model, PromptStrategy
6
+ from palimpzest.core.elements.records import DataRecord
7
+ from palimpzest.core.models import GenerationStats, OperatorCostEstimates
8
+ from palimpzest.query.generators.generators import Generator
9
+ from palimpzest.query.operators.convert import LLMConvert
10
+ from palimpzest.query.operators.filter import LLMFilter
11
+
12
+ # TYPE DEFINITIONS
13
+ FieldName = str
14
+
15
+
16
+ class MixtureOfAgentsConvert(LLMConvert):
17
+
18
+ def __init__(
19
+ self,
20
+ proposer_models: list[Model],
21
+ temperatures: list[float],
22
+ aggregator_model: Model,
23
+ *args,
24
+ **kwargs,
25
+ ):
26
+ kwargs["model"] = None
27
+ kwargs["prompt_strategy"] = None
28
+ super().__init__(*args, **kwargs)
29
+ sorted_proposers, sorted_temps = zip(*[(m, t) for m, t in sorted(zip(proposer_models, temperatures), key=lambda pair: pair[0])])
30
+ self.proposer_models = list(sorted_proposers)
31
+ self.temperatures = list(sorted_temps)
32
+ self.aggregator_model = aggregator_model
33
+
34
+ # create generators
35
+ self.proposer_generators = [
36
+ Generator(model, PromptStrategy.MAP_MOA_PROPOSER, self.reasoning_effort, self.api_base, self.cardinality, self.desc, self.verbose)
37
+ for model in proposer_models
38
+ ]
39
+ self.aggregator_generator = Generator(aggregator_model, PromptStrategy.MAP_MOA_AGG, self.reasoning_effort, self.api_base, self.cardinality, self.desc, self.verbose)
40
+
41
+ def __str__(self):
42
+ op = super().__str__()
43
+ op += f" Proposer Models: {self.proposer_models}\n"
44
+ op += f" Temperatures: {self.temperatures}\n"
45
+ op += f" Aggregator Model: {self.aggregator_model}\n"
46
+ return op
47
+
48
+ def get_id_params(self):
49
+ id_params = super().get_id_params()
50
+ id_params = {
51
+ "proposer_models": [model.value for model in self.proposer_models],
52
+ "temperatures": self.temperatures,
53
+ "aggregator_model": self.aggregator_model.value,
54
+ **id_params,
55
+ }
56
+
57
+ return id_params
58
+
59
+ def get_op_params(self):
60
+ op_params = super().get_op_params()
61
+ op_params = {
62
+ "proposer_models": self.proposer_models,
63
+ "temperatures": self.temperatures,
64
+ "aggregator_model": self.aggregator_model,
65
+ **op_params,
66
+ }
67
+
68
+ return op_params
69
+
70
+ def naive_cost_estimates(self, source_op_cost_estimates: OperatorCostEstimates) -> OperatorCostEstimates:
71
+ """
72
+ Currently, we are using multiple proposer models with different temperatures to synthesize
73
+ answers, which are then aggregated and summarized by a single aggregator model. Thus, we
74
+ roughly expect to incur the cost and time of an LLMConvert * (len(proposer_models) + 1).
75
+ In practice, this naive quality estimate will be overwritten by the CostModel's estimate
76
+ once it executes a few instances of the operator.
77
+ """
78
+ # temporarily set self.model so that super().naive_cost_estimates(...) can compute an estimate
79
+ self.model = self.proposer_models[0]
80
+
81
+ # get naive cost estimates for single LLM call and scale it by number of LLMs used in MoA
82
+ naive_op_cost_estimates = super().naive_cost_estimates(source_op_cost_estimates)
83
+ naive_op_cost_estimates.time_per_record *= (len(self.proposer_models) + 1)
84
+ naive_op_cost_estimates.time_per_record_lower_bound = naive_op_cost_estimates.time_per_record
85
+ naive_op_cost_estimates.time_per_record_upper_bound = naive_op_cost_estimates.time_per_record
86
+ naive_op_cost_estimates.cost_per_record *= (len(self.proposer_models) + 1)
87
+ naive_op_cost_estimates.cost_per_record_lower_bound = naive_op_cost_estimates.cost_per_record
88
+ naive_op_cost_estimates.cost_per_record_upper_bound = naive_op_cost_estimates.cost_per_record
89
+
90
+ # for naive setting, estimate quality as mean of all model qualities
91
+ model_qualities = [
92
+ MODEL_CARDS[model.value]["overall"] / 100.0
93
+ for model in self.proposer_models + [self.aggregator_model]
94
+ ]
95
+ naive_op_cost_estimates.quality = sum(model_qualities)/(len(self.proposer_models) + 1)
96
+ naive_op_cost_estimates.quality_lower_bound = naive_op_cost_estimates.quality
97
+ naive_op_cost_estimates.quality_upper_bound = naive_op_cost_estimates.quality
98
+
99
+ # reset self.model to be None
100
+ self.model = None
101
+
102
+ return naive_op_cost_estimates
103
+
104
+ def convert(self, candidate: DataRecord, fields: dict[str, FieldInfo]) -> tuple[dict[str, list], GenerationStats]:
105
+ # get input fields
106
+ input_fields = self.get_input_fields()
107
+
108
+ # execute generator models in sequence
109
+ proposer_model_final_answers, proposer_model_generation_stats = [], []
110
+ for proposer_generator, temperature in zip(self.proposer_generators, self.temperatures):
111
+ gen_kwargs = {"project_cols": input_fields, "output_schema": self.output_schema, "temperature": temperature}
112
+ _, reasoning, generation_stats, _ = proposer_generator(candidate, fields, json_output=False, **gen_kwargs)
113
+ proposer_text = f"REASONING: {reasoning}\n"
114
+ proposer_model_final_answers.append(proposer_text)
115
+ proposer_model_generation_stats.append(generation_stats)
116
+
117
+ # call the aggregator
118
+ gen_kwargs = {
119
+ "project_cols": input_fields,
120
+ "output_schema": self.output_schema,
121
+ "model_responses": proposer_model_final_answers,
122
+ }
123
+ field_answers, _, aggregator_gen_stats, _ = self.aggregator_generator(candidate, fields, **gen_kwargs)
124
+
125
+ # compute the total generation stats
126
+ generation_stats = sum(proposer_model_generation_stats) + aggregator_gen_stats
127
+
128
+ return field_answers, generation_stats
129
+
130
+
131
+ class MixtureOfAgentsFilter(LLMFilter):
132
+
133
+ def __init__(
134
+ self,
135
+ proposer_models: list[Model],
136
+ temperatures: list[float],
137
+ aggregator_model: Model,
138
+ *args,
139
+ **kwargs,
140
+ ):
141
+ kwargs["model"] = None
142
+ kwargs["prompt_strategy"] = None
143
+ super().__init__(*args, **kwargs)
144
+ sorted_proposers, sorted_temps = zip(*[(m, t) for m, t in sorted(zip(proposer_models, temperatures), key=lambda pair: pair[0])])
145
+ self.proposer_models = list(sorted_proposers)
146
+ self.temperatures = list(sorted_temps)
147
+ self.aggregator_model = aggregator_model
148
+
149
+ # create generators
150
+ self.proposer_generators = [
151
+ Generator(model, PromptStrategy.FILTER_MOA_PROPOSER, self.reasoning_effort, self.api_base, Cardinality.ONE_TO_ONE, self.desc, self.verbose)
152
+ for model in proposer_models
153
+ ]
154
+ self.aggregator_generator = Generator(aggregator_model, PromptStrategy.FILTER_MOA_AGG, self.reasoning_effort, self.api_base, Cardinality.ONE_TO_ONE, self.desc, self.verbose)
155
+
156
+ def __str__(self):
157
+ op = super().__str__()
158
+ op += f" Proposer Models: {self.proposer_models}\n"
159
+ op += f" Temperatures: {self.temperatures}\n"
160
+ op += f" Aggregator Model: {self.aggregator_model}\n"
161
+ return op
162
+
163
+ def get_id_params(self):
164
+ id_params = super().get_id_params()
165
+ id_params = {
166
+ "proposer_models": [model.value for model in self.proposer_models],
167
+ "temperatures": self.temperatures,
168
+ "aggregator_model": self.aggregator_model.value,
169
+ **id_params,
170
+ }
171
+
172
+ return id_params
173
+
174
+ def get_op_params(self):
175
+ op_params = super().get_op_params()
176
+ op_params = {
177
+ "proposer_models": self.proposer_models,
178
+ "temperatures": self.temperatures,
179
+ "aggregator_model": self.aggregator_model,
180
+ **op_params,
181
+ }
182
+
183
+ return op_params
184
+
185
+ def naive_cost_estimates(self, source_op_cost_estimates: OperatorCostEstimates) -> OperatorCostEstimates:
186
+ """
187
+ Currently, we are using multiple proposer models with different temperatures to synthesize
188
+ answers, which are then aggregated and summarized by a single aggregator model. Thus, we
189
+ roughly expect to incur the cost and time of an LLMFilter * (len(proposer_models) + 1).
190
+ In practice, this naive quality estimate will be overwritten by the CostModel's estimate
191
+ once it executes a few instances of the operator.
192
+ """
193
+ # temporarily set self.model so that super().naive_cost_estimates(...) can compute an estimate
194
+ self.model = self.proposer_models[0]
195
+
196
+ # get naive cost estimates for single LLM call and scale it by number of LLMs used in MoA
197
+ naive_op_cost_estimates = super().naive_cost_estimates(source_op_cost_estimates)
198
+ naive_op_cost_estimates.time_per_record *= (len(self.proposer_models) + 1)
199
+ naive_op_cost_estimates.time_per_record_lower_bound = naive_op_cost_estimates.time_per_record
200
+ naive_op_cost_estimates.time_per_record_upper_bound = naive_op_cost_estimates.time_per_record
201
+ naive_op_cost_estimates.cost_per_record *= (len(self.proposer_models) + 1)
202
+ naive_op_cost_estimates.cost_per_record_lower_bound = naive_op_cost_estimates.cost_per_record
203
+ naive_op_cost_estimates.cost_per_record_upper_bound = naive_op_cost_estimates.cost_per_record
204
+
205
+ # for naive setting, estimate quality as mean of all model qualities
206
+ model_qualities = [
207
+ MODEL_CARDS[model.value]["overall"] / 100.0
208
+ for model in self.proposer_models + [self.aggregator_model]
209
+ ]
210
+ naive_op_cost_estimates.quality = sum(model_qualities)/(len(self.proposer_models) + 1)
211
+ naive_op_cost_estimates.quality_lower_bound = naive_op_cost_estimates.quality
212
+ naive_op_cost_estimates.quality_upper_bound = naive_op_cost_estimates.quality
213
+
214
+ # reset self.model to be None
215
+ self.model = None
216
+
217
+ return naive_op_cost_estimates
218
+
219
+ def filter(self, candidate: DataRecord) -> tuple[dict[str, bool], GenerationStats]:
220
+ # get input fields
221
+ input_fields = self.get_input_fields()
222
+
223
+ # construct output fields
224
+ fields = {"passed_operator": FieldInfo(annotation=bool, description="Whether the record passed the filter operation")}
225
+
226
+ # execute generator models in sequence
227
+ proposer_model_final_answers, proposer_model_generation_stats = [], []
228
+ for proposer_generator, temperature in zip(self.proposer_generators, self.temperatures):
229
+ gen_kwargs = {"project_cols": input_fields, "filter_condition": self.filter_obj.filter_condition, "temperature": temperature}
230
+ _, reasoning, generation_stats, _ = proposer_generator(candidate, fields, json_output=False, **gen_kwargs)
231
+ proposer_text = f"REASONING: {reasoning}\n"
232
+ proposer_model_final_answers.append(proposer_text)
233
+ proposer_model_generation_stats.append(generation_stats)
234
+
235
+ # call the aggregator
236
+ gen_kwargs = {
237
+ "project_cols": input_fields,
238
+ "filter_condition": self.filter_obj.filter_condition,
239
+ "model_responses": proposer_model_final_answers,
240
+ }
241
+ field_answers, _, aggregator_gen_stats, _ = self.aggregator_generator(candidate, fields, **gen_kwargs)
242
+
243
+ # compute the total generation stats
244
+ generation_stats = sum(proposer_model_generation_stats) + aggregator_gen_stats
245
+
246
+ return field_answers, generation_stats
@@ -4,7 +4,9 @@ import json
4
4
 
5
5
  from pydantic import BaseModel
6
6
 
7
+ from palimpzest.constants import Modality
7
8
  from palimpzest.core.elements.records import DataRecord, DataRecordSet
9
+ from palimpzest.core.lib.schemas import AUDIO_FIELD_TYPES, IMAGE_FIELD_TYPES
8
10
  from palimpzest.core.models import OperatorCostEstimates
9
11
  from palimpzest.utils.hash_helpers import hash_for_id
10
12
 
@@ -18,8 +20,8 @@ class PhysicalOperator:
18
20
 
19
21
  def __init__(
20
22
  self,
21
- output_schema: BaseModel,
22
- input_schema: BaseModel | None = None,
23
+ output_schema: type[BaseModel],
24
+ input_schema: type[BaseModel] | None = None,
23
25
  depends_on: list[str] | None = None,
24
26
  logical_op_id: str | None = None,
25
27
  unique_logical_op_id: str | None = None,
@@ -39,6 +41,19 @@ class PhysicalOperator:
39
41
  self.verbose = verbose
40
42
  self.op_id = None
41
43
 
44
+ # compute the input modalities (if any) for this physical operator
45
+ self.input_modalities = None
46
+ if self.input_schema is not None:
47
+ self.input_modalities = set()
48
+ for field in self.input_schema.model_fields.values():
49
+ field_type = field.annotation
50
+ if field_type in IMAGE_FIELD_TYPES:
51
+ self.input_modalities.add(Modality.IMAGE)
52
+ elif field_type in AUDIO_FIELD_TYPES:
53
+ self.input_modalities.add(Modality.AUDIO)
54
+ else:
55
+ self.input_modalities.add(Modality.TEXT)
56
+
42
57
  # compute the fields generated by this physical operator
43
58
  input_field_names = list(self.input_schema.model_fields) if self.input_schema is not None else []
44
59
  self.generated_fields = sorted([
@@ -139,6 +154,14 @@ class PhysicalOperator:
139
154
  def get_full_op_id(self):
140
155
  return f"{self.get_logical_op_id()}-{self.get_op_id()}"
141
156
 
157
+ def is_image_op(self) -> bool:
158
+ """Returns True if this physical operator is designed to handle image data."""
159
+ return self.input_modalities is not None and Modality.IMAGE in self.input_modalities
160
+
161
+ def is_audio_op(self) -> bool:
162
+ """Returns True if this physical operator is designed to handle audio data."""
163
+ return self.input_modalities is not None and Modality.AUDIO in self.input_modalities
164
+
142
165
  def __hash__(self):
143
166
  return int(self.op_id, 16) # NOTE: should we use self.get_full_op_id() instead?
144
167
 
@@ -34,13 +34,13 @@ class ProjectOp(PhysicalOperator):
34
34
 
35
35
  def __call__(self, candidate: DataRecord) -> DataRecordSet:
36
36
  # create new DataRecord with projection applied
37
- dr = DataRecord.from_parent(schema=candidate.schema, parent_record=candidate, project_cols=self.project_cols)
37
+ dr = DataRecord.from_parent(schema=candidate.schema, data_item={}, parent_record=candidate, project_cols=self.project_cols)
38
38
 
39
39
  # create RecordOpStats object
40
40
  record_op_stats = RecordOpStats(
41
- record_id=dr.id,
42
- record_parent_ids=dr.parent_ids,
43
- record_source_indices=dr.source_indices,
41
+ record_id=dr._id,
42
+ record_parent_ids=dr._parent_ids,
43
+ record_source_indices=dr._source_indices,
44
44
  record_state=dr.to_dict(include_bytes=False),
45
45
  full_op_id=self.get_full_op_id(),
46
46
  logical_op_id=self.logical_op_id,
@@ -15,6 +15,7 @@ from palimpzest.constants import (
15
15
  from palimpzest.core.elements.records import DataRecord
16
16
  from palimpzest.core.models import GenerationStats, OperatorCostEstimates
17
17
  from palimpzest.query.operators.convert import LLMConvert
18
+ from palimpzest.query.operators.filter import LLMFilter
18
19
 
19
20
 
20
21
  class RAGConvert(LLMConvert):
@@ -26,7 +27,7 @@ class RAGConvert(LLMConvert):
26
27
  self.num_chunks_per_field = num_chunks_per_field
27
28
  self.chunk_size = chunk_size
28
29
 
29
- # crude adjustment factor for naive estimation in no-sentinel setting
30
+ # crude adjustment factor for naive estimation in unoptimized setting
30
31
  self.naive_quality_adjustment = 0.6
31
32
 
32
33
  def __str__(self):
@@ -74,10 +75,6 @@ class RAGConvert(LLMConvert):
74
75
 
75
76
  return naive_op_cost_estimates
76
77
 
77
- def is_image_conversion(self) -> bool:
78
- """RAGConvert is currently disallowed on image conversions, so this must be False."""
79
- return False
80
-
81
78
  def chunk_text(self, text: str, chunk_size: int) -> list[str]:
82
79
  """
83
80
  Given a text string, chunk it into substrings of length chunk_size.
@@ -228,3 +225,203 @@ class RAGConvert(LLMConvert):
228
225
  generation_stats += single_field_stats
229
226
 
230
227
  return field_answers, generation_stats
228
+
229
+
230
+ class RAGFilter(LLMFilter):
231
+ def __init__(self, num_chunks_per_field: int, chunk_size: int = 1000, *args, **kwargs):
232
+ super().__init__(*args, **kwargs)
233
+ # NOTE: in the future, we should abstract the embedding model to allow for different models
234
+ self.client = None
235
+ self.embedding_model = Model.TEXT_EMBEDDING_3_SMALL
236
+ self.num_chunks_per_field = num_chunks_per_field
237
+ self.chunk_size = chunk_size
238
+
239
+ # crude adjustment factor for naive estimation in no-sentinel setting
240
+ self.naive_quality_adjustment = 0.6
241
+
242
+ def __str__(self):
243
+ op = super().__str__()
244
+ op += f" Number of Chunks: {str(self.num_chunks_per_field)}\n"
245
+ op += f" Chunk Size: {str(self.chunk_size)}\n"
246
+ return op
247
+
248
+ def get_id_params(self):
249
+ id_params = super().get_id_params()
250
+ id_params = {"num_chunks_per_field": self.num_chunks_per_field, "chunk_size": self.chunk_size, **id_params}
251
+
252
+ return id_params
253
+
254
+ def get_op_params(self):
255
+ op_params = super().get_op_params()
256
+ return {"num_chunks_per_field": self.num_chunks_per_field, "chunk_size": self.chunk_size, **op_params}
257
+
258
+ def naive_cost_estimates(self, source_op_cost_estimates: OperatorCostEstimates) -> OperatorCostEstimates:
259
+ """
260
+ Update the cost per record and quality estimates produced by LLMFilter's naive estimates.
261
+ We adjust the cost per record to account for the reduced number of input tokens following
262
+ the retrieval of relevant chunks, and we make a crude estimate of the quality degradation
263
+ that results from using a downsized input (although this may in fact improve quality in
264
+ some cases).
265
+ """
266
+ # get naive cost estimates from LLMFilter
267
+ naive_op_cost_estimates = super().naive_cost_estimates(source_op_cost_estimates)
268
+
269
+ # re-compute cost per record assuming we use fewer input tokens; naively assume a single input field
270
+ est_num_input_tokens = self.num_chunks_per_field * self.chunk_size
271
+ est_num_output_tokens = NAIVE_EST_NUM_OUTPUT_TOKENS
272
+ model_conversion_usd_per_record = (
273
+ MODEL_CARDS[self.model.value]["usd_per_input_token"] * est_num_input_tokens
274
+ + MODEL_CARDS[self.model.value]["usd_per_output_token"] * est_num_output_tokens
275
+ )
276
+
277
+ # set refined estimate of cost per record
278
+ naive_op_cost_estimates.cost_per_record = model_conversion_usd_per_record
279
+ naive_op_cost_estimates.cost_per_record_lower_bound = naive_op_cost_estimates.cost_per_record
280
+ naive_op_cost_estimates.cost_per_record_upper_bound = naive_op_cost_estimates.cost_per_record
281
+ naive_op_cost_estimates.quality = (naive_op_cost_estimates.quality) * self.naive_quality_adjustment
282
+ naive_op_cost_estimates.quality_lower_bound = naive_op_cost_estimates.quality
283
+ naive_op_cost_estimates.quality_upper_bound = naive_op_cost_estimates.quality
284
+
285
+ return naive_op_cost_estimates
286
+
287
+ def chunk_text(self, text: str, chunk_size: int) -> list[str]:
288
+ """
289
+ Given a text string, chunk it into substrings of length chunk_size.
290
+ """
291
+ chunks = []
292
+ idx = 0
293
+ while idx + chunk_size < len(text):
294
+ chunks.append(text[idx : idx + chunk_size])
295
+ idx += chunk_size
296
+
297
+ if idx < len(text):
298
+ chunks.append(text[idx:])
299
+
300
+ return chunks
301
+
302
+ def compute_embedding(self, text: str) -> tuple[list[float], GenerationStats]:
303
+ """
304
+ Compute the embedding for a text string. Return the embedding and the GenerationStats object
305
+ that captures the cost of the operation.
306
+ """
307
+ # get the embedding model name
308
+ model_name = self.embedding_model.value
309
+
310
+ # compute the embedding
311
+ start_time = time.time()
312
+ response = self.client.embeddings.create(input=text, model=model_name)
313
+ total_time = time.time() - start_time
314
+
315
+ # extract the embedding
316
+ embedding = response.data[0].embedding
317
+
318
+ # compute the generation stats object
319
+ model_card = MODEL_CARDS[model_name]
320
+ total_input_tokens = response.usage.total_tokens
321
+ total_input_cost = model_card["usd_per_input_token"] * total_input_tokens
322
+ embed_stats = GenerationStats(
323
+ model_name=model_name, # NOTE: this should be overwritten by generation model in filter()
324
+ total_input_tokens=total_input_tokens,
325
+ total_output_tokens=0.0,
326
+ total_input_cost=total_input_cost,
327
+ total_output_cost=0.0,
328
+ cost_per_record=total_input_cost,
329
+ llm_call_duration_secs=total_time,
330
+ total_llm_calls=1,
331
+ total_embedding_llm_calls=1,
332
+ )
333
+
334
+ return embedding, embed_stats
335
+
336
+ def compute_similarity(self, query_embedding: list[float], chunk_embedding: list[float]) -> float:
337
+ """
338
+ Compute the similarity between the query and chunk embeddings.
339
+ """
340
+ return dot(query_embedding, chunk_embedding) / (norm(query_embedding) * norm(chunk_embedding))
341
+
342
+ def get_chunked_candidate(self, candidate: DataRecord, input_fields: list[str]) -> tuple[DataRecord, GenerationStats]:
343
+ """
344
+ For each text field, chunk the content and compute the chunk embeddings. Then select the top-k chunks
345
+ for each field. If a field is smaller than the chunk size, simply include the full field.
346
+ """
347
+ # initialize stats for embedding costs
348
+ embed_stats = GenerationStats()
349
+
350
+ # compute embedding for filter condition
351
+ query_embedding, query_embed_stats = self.compute_embedding(self.filter_obj.filter_condition)
352
+
353
+ # add cost of embedding the query to embed_stats
354
+ embed_stats += query_embed_stats
355
+
356
+ # for each input field, chunk its content and compute the (per-chunk) embeddings
357
+ for field_name in input_fields:
358
+ field = candidate.get_field_type(field_name)
359
+
360
+ # skip this field if it is not a string or a list of strings
361
+ is_string_field = field.annotation in [str, str | None]
362
+ is_list_string_field = field.annotation in [list[str], list[str] | None]
363
+ if not (is_string_field or is_list_string_field):
364
+ continue
365
+
366
+ # if this is a list of strings, join the strings
367
+ if is_list_string_field:
368
+ candidate[field_name] = "[" + ", ".join(candidate[field_name]) + "]"
369
+
370
+ # skip this field if it is a string field and its length is less than the chunk size
371
+ if len(candidate[field_name]) < self.chunk_size:
372
+ continue
373
+
374
+ # chunk the content
375
+ chunks = self.chunk_text(candidate[field_name], self.chunk_size)
376
+
377
+ # compute embeddings for each chunk
378
+ chunk_embeddings, chunk_embed_stats_lst = zip(*[self.compute_embedding(chunk) for chunk in chunks])
379
+
380
+ # add cost of embedding each chunk to embed_stats
381
+ for chunk_embed_stats in chunk_embed_stats_lst:
382
+ embed_stats += chunk_embed_stats
383
+
384
+ # select the top-k chunks
385
+ sorted_chunks = sorted(
386
+ zip(range(len(chunks)), chunks, chunk_embeddings),
387
+ key=lambda tup: self.compute_similarity(query_embedding, tup[2]),
388
+ reverse=True,
389
+ )
390
+ top_k_chunks = [(chunk_idx, chunk) for chunk_idx, chunk, _ in sorted_chunks[:self.num_chunks_per_field]]
391
+
392
+ # sort the top-k chunks by their original index in the content, and join them with ellipses
393
+ top_k_chunks = [chunk for _, chunk in sorted(top_k_chunks, key=lambda tup: tup[0])]
394
+ candidate[field_name] = "...".join(top_k_chunks)
395
+
396
+ return candidate, embed_stats
397
+
398
+ def filter(self, candidate: DataRecord) -> tuple[dict[str, bool], GenerationStats]:
399
+ # set client
400
+ self.client = OpenAI() if self.client is None else self.client
401
+
402
+ # get the set of input fields to use for the filter operation
403
+ input_fields = self.get_input_fields()
404
+
405
+ # construct output fields
406
+ fields = {"passed_operator": FieldInfo(annotation=bool, description="Whether the record passed the filter operation")}
407
+
408
+ # lookup most relevant chunks for each field using embedding search
409
+ candidate_copy = candidate.copy()
410
+ candidate_copy, embed_stats = self.get_chunked_candidate(candidate_copy, input_fields)
411
+
412
+ # construct kwargs for generation
413
+ gen_kwargs = {"project_cols": input_fields, "filter_condition": self.filter_obj.filter_condition}
414
+
415
+ # generate outputs for all fields in a single query
416
+ field_answers, _, generation_stats, _ = self.generator(candidate_copy, fields, **gen_kwargs)
417
+
418
+ # NOTE: summing embedding stats with generation stats is messy because it will lead to misleading
419
+ # measurements of total_input_tokens and total_output_tokens. We should fix this in the future.
420
+ # The good news: as long as we compute the cost_per_record of each GenerationStats object correctly,
421
+ # then the total cost of the operation will be correct (which will roll-up to correctly computing
422
+ # the total cost of the operator, plan, and execution).
423
+ #
424
+ # combine stats from embedding with stats for generation
425
+ generation_stats += embed_stats
426
+
427
+ return field_answers, generation_stats
@@ -145,11 +145,11 @@ class RetrieveOp(PhysicalOperator):
145
145
  Given an input DataRecord and the top_k_results, construct the resulting RecordSet.
146
146
  """
147
147
  # create output DataRecord and set the output attribute
148
- output_dr, answer = DataRecord.from_parent(self.output_schema, parent_record=candidate), {}
149
- for output_field_name in self.output_field_names:
150
- top_k_attr_results = None if top_k_results is None else top_k_results[output_field_name]
151
- setattr(output_dr, output_field_name, top_k_attr_results)
152
- answer[output_field_name] = top_k_attr_results
148
+ data_item = {
149
+ output_field_name: None if top_k_results is None else top_k_results[output_field_name]
150
+ for output_field_name in self.output_field_names
151
+ }
152
+ output_dr = DataRecord.from_parent(self.output_schema, data_item, parent_record=candidate)
153
153
 
154
154
  # get the record_state and generated fields
155
155
  record_state = output_dr.to_dict(include_bytes=False)
@@ -159,16 +159,17 @@ class RetrieveOp(PhysicalOperator):
159
159
 
160
160
  # construct the RecordOpStats object
161
161
  record_op_stats = RecordOpStats(
162
- record_id=output_dr.id,
163
- record_parent_ids=output_dr.parent_ids,
164
- record_source_indices=output_dr.source_indices,
162
+ record_id=output_dr._id,
163
+ record_parent_ids=output_dr._parent_ids,
164
+ record_source_indices=output_dr._source_indices,
165
165
  record_state=record_state,
166
166
  full_op_id=self.get_full_op_id(),
167
167
  logical_op_id=self.logical_op_id,
168
168
  op_name=self.op_name(),
169
169
  time_per_record=total_time,
170
170
  cost_per_record=generation_stats.cost_per_record,
171
- answer=answer,
171
+ total_embedding_cost=generation_stats.cost_per_record,
172
+ answer=data_item,
172
173
  input_fields=list(self.input_schema.model_fields),
173
174
  generated_fields=generated_fields,
174
175
  fn_call_duration_secs=total_time - generation_stats.llm_call_duration_secs,
@@ -71,15 +71,14 @@ class ScanPhysicalOp(PhysicalOperator, ABC):
71
71
  assert all([field in item for field in output_field_names]), f"Some fields in Dataset schema not present in item!\n - Dataset fields: {output_field_names}\n - Item fields: {list(item.keys())}"
72
72
 
73
73
  # construct a DataRecord from the item
74
- dr = DataRecord(self.output_schema, source_indices=[f"{self.datasource.id}-{idx}"])
75
- for field in output_field_names:
76
- setattr(dr, field, item[field])
74
+ data_item = self.output_schema(**{field: item[field] for field in output_field_names})
75
+ dr = DataRecord(data_item, source_indices=[f"{self.datasource.id}-{idx}"])
77
76
 
78
77
  # create RecordOpStats objects
79
78
  record_op_stats = RecordOpStats(
80
- record_id=dr.id,
81
- record_parent_ids=dr.parent_ids,
82
- record_source_indices=dr.source_indices,
79
+ record_id=dr._id,
80
+ record_parent_ids=dr._parent_ids,
81
+ record_source_indices=dr._source_indices,
83
82
  record_state=dr.to_dict(include_bytes=False),
84
83
  full_op_id=self.get_full_op_id(),
85
84
  logical_op_id=self.logical_op_id,
@@ -170,15 +169,15 @@ class ContextScanOp(PhysicalOperator):
170
169
  """
171
170
  # construct a DataRecord from the context
172
171
  start_time = time.time()
173
- dr = DataRecord(self.output_schema, source_indices=[f"{self.context.id}-{0}"])
172
+ dr = DataRecord(self.output_schema(), source_indices=[f"{self.context.id}-{0}"])
174
173
  dr.context = self.context
175
174
  end_time = time.time()
176
175
 
177
176
  # create RecordOpStats objects
178
177
  record_op_stats = RecordOpStats(
179
- record_id=dr.id,
180
- record_parent_ids=dr.parent_ids,
181
- record_source_indices=dr.source_indices,
178
+ record_id=dr._id,
179
+ record_parent_ids=dr._parent_ids,
180
+ record_source_indices=dr._source_indices,
182
181
  record_state=dr.to_dict(include_bytes=False),
183
182
  full_op_id=self.get_full_op_id(),
184
183
  logical_op_id=self.logical_op_id,