palimpzest 0.8.2__py3-none-any.whl → 0.8.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. palimpzest/constants.py +38 -62
  2. palimpzest/core/data/iter_dataset.py +5 -5
  3. palimpzest/core/elements/groupbysig.py +1 -1
  4. palimpzest/core/elements/records.py +91 -109
  5. palimpzest/core/lib/schemas.py +23 -0
  6. palimpzest/core/models.py +3 -3
  7. palimpzest/prompts/__init__.py +2 -6
  8. palimpzest/prompts/convert_prompts.py +10 -66
  9. palimpzest/prompts/critique_and_refine_prompts.py +66 -0
  10. palimpzest/prompts/filter_prompts.py +8 -46
  11. palimpzest/prompts/join_prompts.py +12 -75
  12. palimpzest/prompts/{moa_aggregator_convert_prompts.py → moa_aggregator_prompts.py} +51 -2
  13. palimpzest/prompts/moa_proposer_prompts.py +87 -0
  14. palimpzest/prompts/prompt_factory.py +351 -479
  15. palimpzest/prompts/split_merge_prompts.py +51 -2
  16. palimpzest/prompts/split_proposer_prompts.py +48 -16
  17. palimpzest/prompts/utils.py +109 -0
  18. palimpzest/query/execution/all_sample_execution_strategy.py +1 -1
  19. palimpzest/query/execution/execution_strategy.py +4 -4
  20. palimpzest/query/execution/mab_execution_strategy.py +1 -2
  21. palimpzest/query/execution/parallel_execution_strategy.py +3 -3
  22. palimpzest/query/execution/single_threaded_execution_strategy.py +8 -8
  23. palimpzest/query/generators/generators.py +31 -17
  24. palimpzest/query/operators/__init__.py +15 -2
  25. palimpzest/query/operators/aggregate.py +21 -19
  26. palimpzest/query/operators/compute.py +6 -8
  27. palimpzest/query/operators/convert.py +12 -37
  28. palimpzest/query/operators/critique_and_refine.py +194 -0
  29. palimpzest/query/operators/distinct.py +7 -7
  30. palimpzest/query/operators/filter.py +13 -25
  31. palimpzest/query/operators/join.py +321 -192
  32. palimpzest/query/operators/limit.py +4 -4
  33. palimpzest/query/operators/mixture_of_agents.py +246 -0
  34. palimpzest/query/operators/physical.py +25 -2
  35. palimpzest/query/operators/project.py +4 -4
  36. palimpzest/query/operators/{rag_convert.py → rag.py} +202 -5
  37. palimpzest/query/operators/retrieve.py +10 -9
  38. palimpzest/query/operators/scan.py +9 -10
  39. palimpzest/query/operators/search.py +18 -24
  40. palimpzest/query/operators/split.py +321 -0
  41. palimpzest/query/optimizer/__init__.py +12 -8
  42. palimpzest/query/optimizer/optimizer.py +12 -10
  43. palimpzest/query/optimizer/rules.py +201 -108
  44. palimpzest/query/optimizer/tasks.py +18 -6
  45. palimpzest/validator/validator.py +7 -9
  46. {palimpzest-0.8.2.dist-info → palimpzest-0.8.4.dist-info}/METADATA +3 -8
  47. palimpzest-0.8.4.dist-info/RECORD +95 -0
  48. palimpzest/prompts/critique_and_refine_convert_prompts.py +0 -216
  49. palimpzest/prompts/moa_proposer_convert_prompts.py +0 -75
  50. palimpzest/prompts/util_phrases.py +0 -19
  51. palimpzest/query/operators/critique_and_refine_convert.py +0 -113
  52. palimpzest/query/operators/mixture_of_agents_convert.py +0 -140
  53. palimpzest/query/operators/split_convert.py +0 -170
  54. palimpzest-0.8.2.dist-info/RECORD +0 -95
  55. {palimpzest-0.8.2.dist-info → palimpzest-0.8.4.dist-info}/WHEEL +0 -0
  56. {palimpzest-0.8.2.dist-info → palimpzest-0.8.4.dist-info}/licenses/LICENSE +0 -0
  57. {palimpzest-0.8.2.dist-info → palimpzest-0.8.4.dist-info}/top_level.txt +0 -0
@@ -113,18 +113,20 @@ class ApplyGroupByOp(AggregateOp):
113
113
  group_by_fields = self.group_by_sig.group_by_fields
114
114
  agg_fields = self.group_by_sig.get_agg_field_names()
115
115
  for g in agg_state:
116
- dr = DataRecord.from_agg_parents(
117
- schema=self.group_by_sig.output_schema(),
118
- parent_records=candidates,
119
- )
116
+ # build up data item
117
+ data_item = {}
120
118
  for i in range(0, len(g)):
121
119
  k = g[i]
122
- setattr(dr, group_by_fields[i], k)
120
+ data_item[group_by_fields[i]] = k
123
121
  vals = agg_state[g]
124
122
  for i in range(0, len(vals)):
125
123
  v = ApplyGroupByOp.agg_final(self.group_by_sig.agg_funcs[i], vals[i])
126
- setattr(dr, agg_fields[i], v)
124
+ data_item[agg_fields[i]] = v
127
125
 
126
+ # create new DataRecord
127
+ schema = self.group_by_sig.output_schema()
128
+ data_item = schema(**data_item)
129
+ dr = DataRecord.from_agg_parents(data_item, parent_records=candidates)
128
130
  drs.append(dr)
129
131
 
130
132
  # create RecordOpStats objects
@@ -132,9 +134,9 @@ class ApplyGroupByOp(AggregateOp):
132
134
  record_op_stats_lst = []
133
135
  for dr in drs:
134
136
  record_op_stats = RecordOpStats(
135
- record_id=dr.id,
136
- record_parent_ids=dr.parent_ids,
137
- record_source_indices=dr.source_indices,
137
+ record_id=dr._id,
138
+ record_parent_ids=dr._parent_ids,
139
+ record_source_indices=dr._source_indices,
138
140
  record_state=dr.to_dict(include_bytes=False),
139
141
  full_op_id=self.get_full_op_id(),
140
142
  logical_op_id=self.logical_op_id,
@@ -197,7 +199,6 @@ class AverageAggregateOp(AggregateOp):
197
199
  # NOTE: right now we perform a check in the constructor which enforces that the input_schema
198
200
  # has a single field which is numeric in nature; in the future we may want to have a
199
201
  # cleaner way of computing the value (rather than `float(list(candidate...))` below)
200
- dr = DataRecord.from_agg_parents(schema=Average, parent_records=candidates)
201
202
  summation, total = 0, 0
202
203
  for candidate in candidates:
203
204
  try:
@@ -205,13 +206,14 @@ class AverageAggregateOp(AggregateOp):
205
206
  total += 1
206
207
  except Exception:
207
208
  pass
208
- dr.average = summation / total
209
+ data_item = Average(average=summation / total)
210
+ dr = DataRecord.from_agg_parents(data_item, parent_records=candidates)
209
211
 
210
212
  # create RecordOpStats object
211
213
  record_op_stats = RecordOpStats(
212
- record_id=dr.id,
213
- record_parent_ids=dr.parent_ids,
214
- record_source_indices=dr.source_indices,
214
+ record_id=dr._id,
215
+ record_parent_ids=dr._parent_ids,
216
+ record_source_indices=dr._source_indices,
215
217
  record_state=dr.to_dict(include_bytes=False),
216
218
  full_op_id=self.get_full_op_id(),
217
219
  logical_op_id=self.logical_op_id,
@@ -260,14 +262,14 @@ class CountAggregateOp(AggregateOp):
260
262
  start_time = time.time()
261
263
 
262
264
  # create new DataRecord
263
- dr = DataRecord.from_agg_parents(schema=Count, parent_records=candidates)
264
- dr.count = len(candidates)
265
+ data_item = Count(count=len(candidates))
266
+ dr = DataRecord.from_agg_parents(data_item, parent_records=candidates)
265
267
 
266
268
  # create RecordOpStats object
267
269
  record_op_stats = RecordOpStats(
268
- record_id=dr.id,
269
- record_parent_ids=dr.parent_ids,
270
- record_source_indices=dr.source_indices,
270
+ record_id=dr._id,
271
+ record_parent_ids=dr._parent_ids,
272
+ record_source_indices=dr._source_indices,
271
273
  record_state=dr.to_dict(include_bytes=False),
272
274
  full_op_id=self.get_full_op_id(),
273
275
  logical_op_id=self.logical_op_id,
@@ -93,17 +93,15 @@ class SmolAgentsCompute(PhysicalOperator):
93
93
  Given an input DataRecord and a determination of whether it passed the filter or not,
94
94
  construct the resulting RecordSet.
95
95
  """
96
- # create new DataRecord and set passed_operator attribute
97
- dr = DataRecord.from_parent(self.output_schema, parent_record=candidate)
98
- for field in self.output_schema.model_fields:
99
- if field in answer:
100
- dr[field] = answer[field]
96
+ # create new DataRecord
97
+ data_item = {field: answer[field] for field in self.output_schema.model_fields if field in answer}
98
+ dr = DataRecord.from_parent(self.output_schema, data_item, parent_record=candidate)
101
99
 
102
100
  # create RecordOpStats object
103
101
  record_op_stats = RecordOpStats(
104
- record_id=dr.id,
105
- record_parent_ids=dr.parent_ids,
106
- record_source_indices=dr.source_indices,
102
+ record_id=dr._id,
103
+ record_parent_ids=dr._parent_ids,
104
+ record_source_indices=dr._source_indices,
107
105
  record_state=dr.to_dict(include_bytes=False),
108
106
  full_op_id=self.get_full_op_id(),
109
107
  logical_op_id=self.logical_op_id,
@@ -74,25 +74,14 @@ class ConvertOp(PhysicalOperator, ABC):
74
74
 
75
75
  drs = []
76
76
  for idx in range(max(n_records, 1)):
77
- # initialize record with the correct output schema, parent record, and cardinality idx
78
- dr = DataRecord.from_parent(self.output_schema, parent_record=candidate, cardinality_idx=idx)
79
-
80
- # copy all fields from the input record
81
- # NOTE: this means that records processed by PZ converts will inherit all pre-computed fields
82
- # in an incremental fashion; this is a design choice which may be revisited in the future
83
- for field in candidate.get_field_names():
84
- setattr(dr, field, getattr(candidate, field))
85
-
86
- # get input field names and output field names
87
- input_fields = list(self.input_schema.model_fields)
88
- output_fields = list(self.output_schema.model_fields)
89
-
90
77
  # parse newly generated fields from the field_answers dictionary for this field; if the list
91
78
  # of generated values is shorter than the number of records, we fill in with None
92
- for field in output_fields:
93
- if field not in input_fields:
94
- value = field_answers[field][idx] if idx < len(field_answers[field]) else None
95
- setattr(dr, field, value)
79
+ data_item = {}
80
+ for field in self.generated_fields:
81
+ data_item[field] = field_answers[field][idx] if idx < len(field_answers[field]) else None
82
+
83
+ # initialize record with the correct output schema, data_item, parent record, and cardinality idx
84
+ dr = DataRecord.from_parent(self.output_schema, data_item, parent_record=candidate, cardinality_idx=idx)
96
85
 
97
86
  # append data record to list of output data records
98
87
  drs.append(dr)
@@ -117,9 +106,9 @@ class ConvertOp(PhysicalOperator, ABC):
117
106
  # create the RecordOpStats objects for each output record
118
107
  record_op_stats_lst = [
119
108
  RecordOpStats(
120
- record_id=dr.id,
121
- record_parent_ids=dr.parent_ids,
122
- record_source_indices=dr.source_indices,
109
+ record_id=dr._id,
110
+ record_parent_ids=dr._parent_ids,
111
+ record_source_indices=dr._source_indices,
123
112
  record_state=dr.to_dict(include_bytes=False),
124
113
  full_op_id=self.get_full_op_id(),
125
114
  logical_op_id=self.logical_op_id,
@@ -127,7 +116,7 @@ class ConvertOp(PhysicalOperator, ABC):
127
116
  time_per_record=time_per_record,
128
117
  cost_per_record=per_record_stats.cost_per_record,
129
118
  model_name=self.get_model_name(),
130
- answer={field_name: getattr(dr, field_name) for field_name in field_names},
119
+ answer={field_name: getattr(dr, field_name, None) for field_name in field_names},
131
120
  input_fields=list(self.input_schema.model_fields),
132
121
  generated_fields=field_names,
133
122
  total_input_tokens=per_record_stats.total_input_tokens,
@@ -139,7 +128,6 @@ class ConvertOp(PhysicalOperator, ABC):
139
128
  total_llm_calls=per_record_stats.total_llm_calls,
140
129
  total_embedding_llm_calls=per_record_stats.total_embedding_llm_calls,
141
130
  failed_convert=(not successful_convert),
142
- image_operation=self.is_image_conversion(),
143
131
  op_details={k: str(v) for k, v in self.get_id_params().items()},
144
132
  )
145
133
  for dr in records
@@ -148,11 +136,6 @@ class ConvertOp(PhysicalOperator, ABC):
148
136
  # create and return the DataRecordSet
149
137
  return DataRecordSet(records, record_op_stats_lst)
150
138
 
151
- @abstractmethod
152
- def is_image_conversion(self) -> bool:
153
- """Return True if the convert operation processes an image, False otherwise."""
154
- pass
155
-
156
139
  @abstractmethod
157
140
  def convert(self, candidate: DataRecord, fields: dict[str, FieldInfo]) -> tuple[dict[str, list], GenerationStats]:
158
141
  """
@@ -216,11 +199,6 @@ class NonLLMConvert(ConvertOp):
216
199
  op += f" UDF: {self.udf.__name__}\n"
217
200
  return op
218
201
 
219
- def is_image_conversion(self) -> bool:
220
- # NOTE: even if the UDF is processing an image, we do not consider this an image conversion
221
- # (the output of this function will be used by the CostModel in a way which does not apply to UDFs)
222
- return False
223
-
224
202
  def naive_cost_estimates(self, source_op_cost_estimates: OperatorCostEstimates) -> OperatorCostEstimates:
225
203
  """
226
204
  Compute naive cost estimates for the NonLLMConvert operation. These estimates assume
@@ -287,7 +265,7 @@ class LLMConvert(ConvertOp):
287
265
  def __init__(
288
266
  self,
289
267
  model: Model,
290
- prompt_strategy: PromptStrategy = PromptStrategy.COT_QA,
268
+ prompt_strategy: PromptStrategy = PromptStrategy.MAP,
291
269
  reasoning_effort: str | None = None,
292
270
  *args,
293
271
  **kwargs,
@@ -330,9 +308,6 @@ class LLMConvert(ConvertOp):
330
308
  def get_model_name(self):
331
309
  return None if self.model is None else self.model.value
332
310
 
333
- def is_image_conversion(self) -> bool:
334
- return self.prompt_strategy.is_image_prompt()
335
-
336
311
  def naive_cost_estimates(self, source_op_cost_estimates: OperatorCostEstimates) -> OperatorCostEstimates:
337
312
  """
338
313
  Compute naive cost estimates for the LLMConvert operation. Implicitly, these estimates
@@ -350,7 +325,7 @@ class LLMConvert(ConvertOp):
350
325
 
351
326
  # get est. of conversion cost (in USD) per record from model card
352
327
  usd_per_input_token = MODEL_CARDS[model_name].get("usd_per_input_token")
353
- if getattr(self, "prompt_strategy", None) is not None and self.prompt_strategy.is_audio_prompt():
328
+ if getattr(self, "prompt_strategy", None) is not None and self.is_audio_op():
354
329
  usd_per_input_token = MODEL_CARDS[model_name]["usd_per_audio_input_token"]
355
330
 
356
331
  model_conversion_usd_per_record = (
@@ -0,0 +1,194 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any
4
+
5
+ from pydantic.fields import FieldInfo
6
+
7
+ from palimpzest.constants import MODEL_CARDS, Cardinality, Model, PromptStrategy
8
+ from palimpzest.core.elements.records import DataRecord
9
+ from palimpzest.core.models import GenerationStats, OperatorCostEstimates
10
+ from palimpzest.query.generators.generators import Generator
11
+ from palimpzest.query.operators.convert import LLMConvert
12
+ from palimpzest.query.operators.filter import LLMFilter
13
+
14
+ # TYPE DEFINITIONS
15
+ FieldName = str
16
+
17
+
18
+ class CritiqueAndRefineConvert(LLMConvert):
19
+
20
+ def __init__(
21
+ self,
22
+ critic_model: Model,
23
+ refine_model: Model,
24
+ *args,
25
+ **kwargs,
26
+ ):
27
+ super().__init__(*args, **kwargs)
28
+ self.critic_model = critic_model
29
+ self.refine_model = refine_model
30
+
31
+ # create generators
32
+ self.critic_generator = Generator(self.critic_model, PromptStrategy.MAP_CRITIC, self.reasoning_effort, self.api_base, self.cardinality, self.desc, self.verbose)
33
+ self.refine_generator = Generator(self.refine_model, PromptStrategy.MAP_REFINE, self.reasoning_effort, self.api_base, self.cardinality, self.desc, self.verbose)
34
+
35
+ def __str__(self):
36
+ op = super().__str__()
37
+ op += f" Critic Model: {self.critic_model}\n"
38
+ op += f" Refine Model: {self.refine_model}\n"
39
+ return op
40
+
41
+ def get_id_params(self):
42
+ id_params = super().get_id_params()
43
+ id_params = {
44
+ "critic_model": self.critic_model.value,
45
+ "refine_model": self.refine_model.value,
46
+ **id_params,
47
+ }
48
+
49
+ return id_params
50
+
51
+ def get_op_params(self):
52
+ op_params = super().get_op_params()
53
+ op_params = {
54
+ "critic_model": self.critic_model,
55
+ "refine_model": self.refine_model,
56
+ **op_params,
57
+ }
58
+
59
+ return op_params
60
+
61
+ def naive_cost_estimates(self, source_op_cost_estimates: OperatorCostEstimates) -> OperatorCostEstimates:
62
+ """
63
+ Currently, we are invoking `self.model`, then critiquing its output with `self.critic_model`, and
64
+ finally refining the output with `self.refine_model`. Thus, we roughly expect to incur the cost
65
+ and time of three LLMConverts. In practice, this naive quality estimate will be overwritten by the
66
+ CostModel's estimate once it executes a few instances of the operator.
67
+ """
68
+ # get naive cost estimates for first LLM call and multiply by 3 for now;
69
+ # of course we should sum individual estimates for each model, but this is a rough estimate
70
+ # and in practice we will need to revamp our naive cost estimates in the near future
71
+ naive_op_cost_estimates = 3 * super().naive_cost_estimates(source_op_cost_estimates)
72
+
73
+ # for naive setting, estimate quality as quality of refine model
74
+ model_quality = MODEL_CARDS[self.refine_model.value]["overall"] / 100.0
75
+ naive_op_cost_estimates.quality = model_quality
76
+ naive_op_cost_estimates.quality_lower_bound = naive_op_cost_estimates.quality
77
+ naive_op_cost_estimates.quality_upper_bound = naive_op_cost_estimates.quality
78
+
79
+ return naive_op_cost_estimates
80
+
81
+ def convert(self, candidate: DataRecord, fields: dict[str, FieldInfo]) -> tuple[dict[FieldName, list[Any]], GenerationStats]:
82
+ # get input fields
83
+ input_fields = self.get_input_fields()
84
+
85
+ # NOTE: when I merge in the `abacus` branch, I will want to update this to reflect the changes I made to reasoning extraction
86
+ # execute the initial model
87
+ original_gen_kwargs = {"project_cols": input_fields, "output_schema": self.output_schema}
88
+ field_answers, reasoning, original_gen_stats, original_messages = self.generator(candidate, fields, **original_gen_kwargs)
89
+ original_output = f"REASONING: {reasoning}\nANSWER: {field_answers}\n"
90
+
91
+ # execute the critic model
92
+ critic_gen_kwargs = {"original_output": original_output, "original_messages": original_messages, **original_gen_kwargs}
93
+ _, reasoning, critic_gen_stats, _ = self.critic_generator(candidate, fields, json_output=False, **critic_gen_kwargs)
94
+ critique_output = f"CRITIQUE: {reasoning}\n"
95
+
96
+ # execute the refinement model
97
+ refine_gen_kwargs = {"critique_output": critique_output, **critic_gen_kwargs}
98
+ field_answers, reasoning, refine_gen_stats, _ = self.refine_generator(candidate, fields, **refine_gen_kwargs)
99
+
100
+ # compute the total generation stats
101
+ generation_stats = original_gen_stats + critic_gen_stats + refine_gen_stats
102
+
103
+ return field_answers, generation_stats
104
+
105
+
106
+ class CritiqueAndRefineFilter(LLMFilter):
107
+
108
+ def __init__(
109
+ self,
110
+ critic_model: Model,
111
+ refine_model: Model,
112
+ *args,
113
+ **kwargs,
114
+ ):
115
+ super().__init__(*args, **kwargs)
116
+ self.critic_model = critic_model
117
+ self.refine_model = refine_model
118
+
119
+ # create generators
120
+ self.critic_generator = Generator(self.critic_model, PromptStrategy.FILTER_CRITIC, self.reasoning_effort, self.api_base, Cardinality.ONE_TO_ONE, self.desc, self.verbose)
121
+ self.refine_generator = Generator(self.refine_model, PromptStrategy.FILTER_REFINE, self.reasoning_effort, self.api_base, Cardinality.ONE_TO_ONE, self.desc, self.verbose)
122
+
123
+ def __str__(self):
124
+ op = super().__str__()
125
+ op += f" Critic Model: {self.critic_model}\n"
126
+ op += f" Refine Model: {self.refine_model}\n"
127
+ return op
128
+
129
+ def get_id_params(self):
130
+ id_params = super().get_id_params()
131
+ id_params = {
132
+ "critic_model": self.critic_model.value,
133
+ "refine_model": self.refine_model.value,
134
+ **id_params,
135
+ }
136
+
137
+ return id_params
138
+
139
+ def get_op_params(self):
140
+ op_params = super().get_op_params()
141
+ op_params = {
142
+ "critic_model": self.critic_model,
143
+ "refine_model": self.refine_model,
144
+ **op_params,
145
+ }
146
+
147
+ return op_params
148
+
149
+ def naive_cost_estimates(self, source_op_cost_estimates: OperatorCostEstimates) -> OperatorCostEstimates:
150
+ """
151
+ Currently, we are invoking `self.model`, then critiquing its output with `self.critic_model`, and
152
+ finally refining the output with `self.refine_model`. Thus, we roughly expect to incur the cost
153
+ and time of three LLMFilters. In practice, this naive quality estimate will be overwritten by the
154
+ CostModel's estimate once it executes a few instances of the operator.
155
+ """
156
+ # get naive cost estimates for first LLM call and multiply by 3 for now;
157
+ # of course we should sum individual estimates for each model, but this is a rough estimate
158
+ # and in practice we will need to revamp our naive cost estimates in the near future
159
+ naive_op_cost_estimates = 3 * super().naive_cost_estimates(source_op_cost_estimates)
160
+
161
+ # for naive setting, estimate quality as quality of refine model
162
+ model_quality = MODEL_CARDS[self.refine_model.value]["overall"] / 100.0
163
+ naive_op_cost_estimates.quality = model_quality
164
+ naive_op_cost_estimates.quality_lower_bound = naive_op_cost_estimates.quality
165
+ naive_op_cost_estimates.quality_upper_bound = naive_op_cost_estimates.quality
166
+
167
+ return naive_op_cost_estimates
168
+
169
+ def filter(self, candidate: DataRecord) -> tuple[dict[str, bool], GenerationStats]:
170
+ # get input fields
171
+ input_fields = self.get_input_fields()
172
+
173
+ # construct output fields
174
+ fields = {"passed_operator": FieldInfo(annotation=bool, description="Whether the record passed the filter operation")}
175
+
176
+ # NOTE: when I merge in the `abacus` branch, I will want to update this to reflect the changes I made to reasoning extraction
177
+ # execute the initial model
178
+ original_gen_kwargs = {"project_cols": input_fields, "filter_condition": self.filter_obj.filter_condition}
179
+ field_answers, reasoning, original_gen_stats, original_messages = self.generator(candidate, fields, **original_gen_kwargs)
180
+ original_output = f"REASONING: {reasoning}\nANSWER: {str(field_answers['passed_operator']).upper()}\n"
181
+
182
+ # execute the critic model
183
+ critic_gen_kwargs = {"original_output": original_output, "original_messages": original_messages, **original_gen_kwargs}
184
+ _, reasoning, critic_gen_stats, _ = self.critic_generator(candidate, fields, json_output=False, **critic_gen_kwargs)
185
+ critique_output = f"CRITIQUE: {reasoning}\n"
186
+
187
+ # execute the refinement model
188
+ refine_gen_kwargs = {"critique_output": critique_output, **critic_gen_kwargs}
189
+ field_answers, reasoning, refine_gen_stats, _ = self.refine_generator(candidate, fields, **refine_gen_kwargs)
190
+
191
+ # compute the total generation stats
192
+ generation_stats = original_gen_stats + critic_gen_stats + refine_gen_stats
193
+
194
+ return field_answers, generation_stats
@@ -35,27 +35,27 @@ class DistinctOp(PhysicalOperator):
35
35
 
36
36
  def __call__(self, candidate: DataRecord) -> DataRecordSet:
37
37
  # create new DataRecord
38
- dr = DataRecord.from_parent(schema=candidate.schema, parent_record=candidate)
38
+ dr = DataRecord.from_parent(schema=candidate.schema, data_item={}, parent_record=candidate)
39
39
 
40
40
  # output record only if it has not been seen before
41
41
  record_str = dr.to_json_str(project_cols=self.distinct_cols, bytes_to_str=True, sorted=True)
42
42
  record_hash = f"{hash(record_str)}"
43
- dr.passed_operator = record_hash not in self._distinct_seen
44
- if dr.passed_operator:
43
+ dr._passed_operator = record_hash not in self._distinct_seen
44
+ if dr._passed_operator:
45
45
  self._distinct_seen.add(record_hash)
46
46
 
47
47
  # create RecordOpStats object
48
48
  record_op_stats = RecordOpStats(
49
- record_id=dr.id,
50
- record_parent_ids=dr.parent_ids,
51
- record_source_indices=dr.source_indices,
49
+ record_id=dr._id,
50
+ record_parent_ids=dr._parent_ids,
51
+ record_source_indices=dr._source_indices,
52
52
  record_state=dr.to_dict(include_bytes=False),
53
53
  full_op_id=self.get_full_op_id(),
54
54
  logical_op_id=self.logical_op_id,
55
55
  op_name=self.op_name(),
56
56
  time_per_record=0.0,
57
57
  cost_per_record=0.0,
58
- passed_operator=dr.passed_operator,
58
+ passed_operator=dr._passed_operator,
59
59
  op_details={k: str(v) for k, v in self.get_id_params().items()},
60
60
  )
61
61
 
@@ -41,11 +41,6 @@ class FilterOp(PhysicalOperator, ABC):
41
41
  op_params = super().get_op_params()
42
42
  return {"filter": self.filter_obj, "desc": self.desc, **op_params}
43
43
 
44
- @abstractmethod
45
- def is_image_filter(self) -> bool:
46
- """Return True if the filter operation processes an image, False otherwise."""
47
- pass
48
-
49
44
  @abstractmethod
50
45
  def filter(self, candidate: DataRecord) -> tuple[dict[str, bool], GenerationStats]:
51
46
  """
@@ -76,14 +71,14 @@ class FilterOp(PhysicalOperator, ABC):
76
71
  construct the resulting RecordSet.
77
72
  """
78
73
  # create new DataRecord and set passed_operator attribute
79
- dr = DataRecord.from_parent(candidate.schema, parent_record=candidate)
80
- dr.passed_operator = passed_operator
74
+ dr = DataRecord.from_parent(schema=candidate.schema, data_item={}, parent_record=candidate)
75
+ dr._passed_operator = passed_operator
81
76
 
82
77
  # create RecordOpStats object
83
78
  record_op_stats = RecordOpStats(
84
- record_id=dr.id,
85
- record_parent_ids=dr.parent_ids,
86
- record_source_indices=dr.source_indices,
79
+ record_id=dr._id,
80
+ record_parent_ids=dr._parent_ids,
81
+ record_source_indices=dr._source_indices,
87
82
  record_state=dr.to_dict(include_bytes=False),
88
83
  full_op_id=self.get_full_op_id(),
89
84
  logical_op_id=self.logical_op_id,
@@ -102,7 +97,6 @@ class FilterOp(PhysicalOperator, ABC):
102
97
  total_embedding_llm_calls=generation_stats.total_embedding_llm_calls,
103
98
  answer=answer,
104
99
  passed_operator=passed_operator,
105
- image_operation=self.is_image_filter(),
106
100
  op_details={k: str(v) for k, v in self.get_id_params().items()},
107
101
  )
108
102
 
@@ -127,10 +121,6 @@ class FilterOp(PhysicalOperator, ABC):
127
121
 
128
122
 
129
123
  class NonLLMFilter(FilterOp):
130
- def is_image_filter(self) -> bool:
131
- # NOTE: even if the UDF is processing an image, we do not consider this an image filter
132
- # (the output of this function will be used by the CostModel in a way which does not apply to UDFs)
133
- return False
134
124
 
135
125
  def naive_cost_estimates(self, source_op_cost_estimates: OperatorCostEstimates):
136
126
  # estimate output cardinality using a constant assumption of the filter selectivity
@@ -174,7 +164,7 @@ class LLMFilter(FilterOp):
174
164
  def __init__(
175
165
  self,
176
166
  model: Model,
177
- prompt_strategy: PromptStrategy = PromptStrategy.COT_BOOL,
167
+ prompt_strategy: PromptStrategy = PromptStrategy.FILTER,
178
168
  reasoning_effort: str | None = None,
179
169
  *args,
180
170
  **kwargs,
@@ -183,13 +173,14 @@ class LLMFilter(FilterOp):
183
173
  self.model = model
184
174
  self.prompt_strategy = prompt_strategy
185
175
  self.reasoning_effort = reasoning_effort
186
- self.generator = Generator(model, prompt_strategy, reasoning_effort, self.api_base, Cardinality.ONE_TO_ONE, self.desc, self.verbose)
176
+ if model is not None:
177
+ self.generator = Generator(model, prompt_strategy, reasoning_effort, self.api_base, Cardinality.ONE_TO_ONE, self.desc, self.verbose)
187
178
 
188
179
  def get_id_params(self):
189
180
  id_params = super().get_id_params()
190
181
  id_params = {
191
- "model": self.model.value,
192
- "prompt_strategy": self.prompt_strategy.value,
182
+ "model": None if self.model is None else self.model.value,
183
+ "prompt_strategy": None if self.prompt_strategy is None else self.prompt_strategy.value,
193
184
  "reasoning_effort": self.reasoning_effort,
194
185
  **id_params,
195
186
  }
@@ -208,15 +199,12 @@ class LLMFilter(FilterOp):
208
199
  return op_params
209
200
 
210
201
  def get_model_name(self):
211
- return self.model.value
212
-
213
- def is_image_filter(self) -> bool:
214
- return self.prompt_strategy is PromptStrategy.COT_BOOL_IMAGE
202
+ return None if self.model is None else self.model.value
215
203
 
216
204
  def naive_cost_estimates(self, source_op_cost_estimates: OperatorCostEstimates):
217
205
  # estimate number of input tokens from source
218
206
  est_num_input_tokens = NAIVE_EST_NUM_INPUT_TOKENS
219
- if self.is_image_filter():
207
+ if self.is_image_op():
220
208
  est_num_input_tokens = 765 / 10 # 1024x1024 image is 765 tokens
221
209
 
222
210
  # NOTE: the output often generates an entire reasoning sentence, thus the true value may be higher
@@ -232,7 +220,7 @@ class LLMFilter(FilterOp):
232
220
  # get est. of conversion cost (in USD) per record from model card
233
221
  usd_per_input_token = (
234
222
  MODEL_CARDS[self.model.value]["usd_per_audio_input_token"]
235
- if self.prompt_strategy.is_audio_prompt()
223
+ if self.is_audio_op()
236
224
  else MODEL_CARDS[self.model.value]["usd_per_input_token"]
237
225
  )
238
226
  model_conversion_usd_per_record = (