palimpzest 0.8.1__py3-none-any.whl → 0.8.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. palimpzest/constants.py +38 -62
  2. palimpzest/core/data/dataset.py +1 -1
  3. palimpzest/core/data/iter_dataset.py +5 -5
  4. palimpzest/core/elements/groupbysig.py +1 -1
  5. palimpzest/core/elements/records.py +91 -109
  6. palimpzest/core/lib/schemas.py +23 -0
  7. palimpzest/core/models.py +3 -3
  8. palimpzest/prompts/__init__.py +2 -6
  9. palimpzest/prompts/convert_prompts.py +10 -66
  10. palimpzest/prompts/critique_and_refine_prompts.py +66 -0
  11. palimpzest/prompts/filter_prompts.py +8 -46
  12. palimpzest/prompts/join_prompts.py +12 -75
  13. palimpzest/prompts/{moa_aggregator_convert_prompts.py → moa_aggregator_prompts.py} +51 -2
  14. palimpzest/prompts/moa_proposer_prompts.py +87 -0
  15. palimpzest/prompts/prompt_factory.py +351 -479
  16. palimpzest/prompts/split_merge_prompts.py +51 -2
  17. palimpzest/prompts/split_proposer_prompts.py +48 -16
  18. palimpzest/prompts/utils.py +109 -0
  19. palimpzest/query/execution/all_sample_execution_strategy.py +1 -1
  20. palimpzest/query/execution/execution_strategy.py +4 -4
  21. palimpzest/query/execution/mab_execution_strategy.py +47 -23
  22. palimpzest/query/execution/parallel_execution_strategy.py +3 -3
  23. palimpzest/query/execution/single_threaded_execution_strategy.py +8 -8
  24. palimpzest/query/generators/generators.py +31 -17
  25. palimpzest/query/operators/__init__.py +15 -2
  26. palimpzest/query/operators/aggregate.py +21 -19
  27. palimpzest/query/operators/compute.py +6 -8
  28. palimpzest/query/operators/convert.py +12 -37
  29. palimpzest/query/operators/critique_and_refine.py +194 -0
  30. palimpzest/query/operators/distinct.py +7 -7
  31. palimpzest/query/operators/filter.py +13 -25
  32. palimpzest/query/operators/join.py +321 -192
  33. palimpzest/query/operators/limit.py +4 -4
  34. palimpzest/query/operators/mixture_of_agents.py +246 -0
  35. palimpzest/query/operators/physical.py +25 -2
  36. palimpzest/query/operators/project.py +4 -4
  37. palimpzest/query/operators/{rag_convert.py → rag.py} +202 -5
  38. palimpzest/query/operators/retrieve.py +10 -9
  39. palimpzest/query/operators/scan.py +9 -10
  40. palimpzest/query/operators/search.py +18 -24
  41. palimpzest/query/operators/split.py +321 -0
  42. palimpzest/query/optimizer/__init__.py +12 -8
  43. palimpzest/query/optimizer/optimizer.py +12 -10
  44. palimpzest/query/optimizer/rules.py +201 -108
  45. palimpzest/query/optimizer/tasks.py +18 -6
  46. palimpzest/query/processor/config.py +2 -2
  47. palimpzest/query/processor/query_processor.py +2 -2
  48. palimpzest/query/processor/query_processor_factory.py +9 -5
  49. palimpzest/validator/validator.py +7 -9
  50. {palimpzest-0.8.1.dist-info → palimpzest-0.8.3.dist-info}/METADATA +3 -8
  51. palimpzest-0.8.3.dist-info/RECORD +95 -0
  52. palimpzest/prompts/critique_and_refine_convert_prompts.py +0 -216
  53. palimpzest/prompts/moa_proposer_convert_prompts.py +0 -75
  54. palimpzest/prompts/util_phrases.py +0 -19
  55. palimpzest/query/operators/critique_and_refine_convert.py +0 -113
  56. palimpzest/query/operators/mixture_of_agents_convert.py +0 -140
  57. palimpzest/query/operators/split_convert.py +0 -170
  58. palimpzest-0.8.1.dist-info/RECORD +0 -95
  59. {palimpzest-0.8.1.dist-info → palimpzest-0.8.3.dist-info}/WHEEL +0 -0
  60. {palimpzest-0.8.1.dist-info → palimpzest-0.8.3.dist-info}/licenses/LICENSE +0 -0
  61. {palimpzest-0.8.1.dist-info → palimpzest-0.8.3.dist-info}/top_level.txt +0 -0
@@ -101,7 +101,7 @@ def get_json_from_answer(answer: str, model: Model, cardinality: Cardinality) ->
101
101
  # TODO: make sure answer parsing works with custom prompts / parsers (can defer this)
102
102
  class Generator(Generic[ContextType, InputType]):
103
103
  """
104
- Abstract base class for Generators.
104
+ Class for generating new fields for a record using an LLM.
105
105
  """
106
106
 
107
107
  def __init__(
@@ -181,11 +181,11 @@ class Generator(Generic[ContextType, InputType]):
181
181
 
182
182
  return None
183
183
 
184
- def _check_bool_answer_text(self, answer_text: str) -> dict | None:
184
+ def _check_bool_answer_text(self, answer_text: str, throw_exception: bool=False) -> dict | None:
185
185
  """
186
186
  Return {"passed_operator": True} if and only if "true" is in the answer text.
187
187
  Return {"passed_operator": False} if and only if "false" is in the answer text.
188
- Otherwise, return None.
188
+ Otherwise, raise an exception.
189
189
  """
190
190
  # NOTE: we may be able to eliminate this condition by specifying this JSON output in the prompt;
191
191
  # however, that would also need to coincide with a change to allow the parse_answer_fn to set "passed_operator"
@@ -194,6 +194,9 @@ class Generator(Generic[ContextType, InputType]):
194
194
  elif "false" in answer_text.lower():
195
195
  return {"passed_operator": False}
196
196
 
197
+ if throw_exception:
198
+ raise Exception(f"Could not parse answer from completion text: {answer_text}")
199
+
197
200
  return None
198
201
 
199
202
  def _parse_convert_answer(self, completion_text: str, fields: dict[str, FieldInfo], json_output: bool) -> dict[str, list]:
@@ -235,7 +238,7 @@ class Generator(Generic[ContextType, InputType]):
235
238
 
236
239
  return self._check_convert_answer_text(completion_text, fields, throw_exception=True)
237
240
 
238
- def _parse_bool_answer(self, completion_text: str) -> dict[str, list]:
241
+ def _parse_bool_answer(self, completion_text: str, json_output: bool) -> dict[str, list]:
239
242
  """Extract the answer from the completion object for filter and join operations."""
240
243
  # if the model followed the default instructions, the completion text will place
241
244
  # its answer between "ANSWER:" and "---"
@@ -243,6 +246,12 @@ class Generator(Generic[ContextType, InputType]):
243
246
  matches = regex.findall(completion_text)
244
247
  if len(matches) > 0:
245
248
  answer_text = matches[0].strip()
249
+
250
+ # if we don't expect a JSON output, return the answer text as is
251
+ if not json_output:
252
+ return answer_text
253
+
254
+ # otherwise, try to parse the answer text into a JSON object
246
255
  field_answers = self._check_bool_answer_text(answer_text)
247
256
  if field_answers is not None:
248
257
  return field_answers
@@ -252,16 +261,21 @@ class Generator(Generic[ContextType, InputType]):
252
261
  matches = regex.findall(completion_text)
253
262
  if len(matches) > 0:
254
263
  answer_text = matches[0].strip()
264
+
265
+ # if we don't expect a JSON output, return the answer text as is
266
+ if not json_output:
267
+ return answer_text
268
+
269
+ # otherwise, try to parse the answer text into a JSON object
255
270
  field_answers = self._check_bool_answer_text(answer_text)
256
271
  if field_answers is not None:
257
272
  return field_answers
258
273
 
259
- # finally, try taking all of the text; throw an exception if this doesn't work
260
- field_answers = self._check_bool_answer_text(completion_text)
261
- if field_answers is None:
262
- raise Exception(f"Could not parse answer from completion text: {completion_text}")
274
+ # finally, try taking all of the text; for JSON output, throw an exception if parsing fails
275
+ if not json_output:
276
+ return completion_text
263
277
 
264
- return field_answers
278
+ return self._check_bool_answer_text(completion_text, throw_exception=True)
265
279
 
266
280
  def _parse_answer(self, completion_text: str, fields: dict[str, FieldInfo] | None, json_output: bool, **kwargs) -> dict[str, list]:
267
281
  """Extract the answer from the completion object."""
@@ -275,8 +289,8 @@ class Generator(Generic[ContextType, InputType]):
275
289
 
276
290
  # extract the per-field answers from the completion text
277
291
  field_answers = (
278
- self._parse_bool_answer(completion_text)
279
- if self.prompt_strategy.is_bool_prompt() or self.prompt_strategy.is_join_prompt()
292
+ self._parse_bool_answer(completion_text, json_output)
293
+ if self.prompt_strategy.is_filter_prompt() or self.prompt_strategy.is_join_prompt()
280
294
  else self._parse_convert_answer(completion_text, fields, json_output)
281
295
  )
282
296
 
@@ -299,6 +313,7 @@ class Generator(Generic[ContextType, InputType]):
299
313
 
300
314
  # generate a list of messages which can be used to construct a payload
301
315
  messages = self.prompt_factory.create_messages(candidate, fields, right_candidate, **kwargs)
316
+ is_audio_op = any(msg.get("type") == "input_audio" for msg in messages)
302
317
 
303
318
  # generate the text completion
304
319
  start_time = time.time()
@@ -307,7 +322,7 @@ class Generator(Generic[ContextType, InputType]):
307
322
  completion_kwargs = {}
308
323
  if not self.model.is_o_model() and not self.model.is_gpt_5_model():
309
324
  completion_kwargs = {"temperature": kwargs.get("temperature", 0.0), **completion_kwargs}
310
- if self.prompt_strategy.is_audio_prompt():
325
+ if is_audio_op:
311
326
  completion_kwargs = {"modalities": ["text"], **completion_kwargs}
312
327
  if self.model.is_reasoning_model():
313
328
  if self.model.is_vertex_model():
@@ -330,11 +345,10 @@ class Generator(Generic[ContextType, InputType]):
330
345
  # if there's an error generating the completion, we have to return an empty answer
331
346
  # and can only account for the time spent performing the failed generation
332
347
  except Exception as e:
333
- print(f"Error generating completion: {e}")
334
348
  logger.error(f"Error generating completion: {e}")
335
349
  field_answers = (
336
350
  {"passed_operator": False}
337
- if self.prompt_strategy.is_bool_prompt() or self.prompt_strategy.is_join_prompt()
351
+ if self.prompt_strategy.is_filter_prompt() or self.prompt_strategy.is_join_prompt()
338
352
  else {field_name: None for field_name in fields}
339
353
  )
340
354
  reasoning = None
@@ -360,7 +374,7 @@ class Generator(Generic[ContextType, InputType]):
360
374
  # for now, we only use tokens from prompt_token_details if it's an audio prompt
361
375
  # get output tokens (all text) and input tokens by modality
362
376
  output_tokens = usage["completion_tokens"]
363
- if self.prompt_strategy.is_audio_prompt():
377
+ if is_audio_op:
364
378
  input_audio_tokens = usage["prompt_tokens_details"].get("audio_tokens", 0)
365
379
  input_text_tokens = usage["prompt_tokens_details"].get("text_tokens", 0)
366
380
  input_image_tokens = 0
@@ -413,9 +427,9 @@ class Generator(Generic[ContextType, InputType]):
413
427
 
414
428
  # parse field answers
415
429
  field_answers = None
416
- if fields is not None and (self.prompt_strategy.is_bool_prompt() or self.prompt_strategy.is_join_prompt()):
430
+ if fields is not None and (self.prompt_strategy.is_filter_prompt() or self.prompt_strategy.is_join_prompt()):
417
431
  field_answers = {"passed_operator": False}
418
- elif fields is not None and not (self.prompt_strategy.is_bool_prompt() or self.prompt_strategy.is_join_prompt()):
432
+ elif fields is not None and not (self.prompt_strategy.is_filter_prompt() or self.prompt_strategy.is_join_prompt()):
419
433
  field_answers = {field_name: None for field_name in fields}
420
434
  try:
421
435
  field_answers = self._parse_answer(completion_text, fields, json_output, **kwargs)
@@ -6,6 +6,8 @@ from palimpzest.query.operators.convert import ConvertOp as _ConvertOp
6
6
  from palimpzest.query.operators.convert import LLMConvert as _LLMConvert
7
7
  from palimpzest.query.operators.convert import LLMConvertBonded as _LLMConvertBonded
8
8
  from palimpzest.query.operators.convert import NonLLMConvert as _NonLLMConvert
9
+ from palimpzest.query.operators.critique_and_refine import CritiqueAndRefineConvert as _CritiqueAndRefineConvert
10
+ from palimpzest.query.operators.critique_and_refine import CritiqueAndRefineFilter as _CritiqueAndRefineFilter
9
11
  from palimpzest.query.operators.distinct import DistinctOp as _DistinctOp
10
12
  from palimpzest.query.operators.filter import FilterOp as _FilterOp
11
13
  from palimpzest.query.operators.filter import LLMFilter as _LLMFilter
@@ -46,12 +48,17 @@ from palimpzest.query.operators.logical import (
46
48
  from palimpzest.query.operators.logical import (
47
49
  RetrieveScan as _RetrieveScan,
48
50
  )
49
- from palimpzest.query.operators.mixture_of_agents_convert import MixtureOfAgentsConvert as _MixtureOfAgentsConvert
51
+ from palimpzest.query.operators.mixture_of_agents import MixtureOfAgentsConvert as _MixtureOfAgentsConvert
52
+ from palimpzest.query.operators.mixture_of_agents import MixtureOfAgentsFilter as _MixtureOfAgentsFilter
50
53
  from palimpzest.query.operators.physical import PhysicalOperator as _PhysicalOperator
51
54
  from palimpzest.query.operators.project import ProjectOp as _ProjectOp
55
+ from palimpzest.query.operators.rag import RAGConvert as _RAGConvert
56
+ from palimpzest.query.operators.rag import RAGFilter as _RAGFilter
52
57
  from palimpzest.query.operators.retrieve import RetrieveOp as _RetrieveOp
53
58
  from palimpzest.query.operators.scan import MarshalAndScanDataOp as _MarshalAndScanDataOp
54
59
  from palimpzest.query.operators.scan import ScanPhysicalOp as _ScanPhysicalOp
60
+ from palimpzest.query.operators.split import SplitConvert as _SplitConvert
61
+ from palimpzest.query.operators.split import SplitFilter as _SplitFilter
55
62
 
56
63
  LOGICAL_OPERATORS = [
57
64
  _LogicalOperator,
@@ -72,6 +79,8 @@ PHYSICAL_OPERATORS = (
72
79
  [_AggregateOp, _ApplyGroupByOp, _AverageAggregateOp, _CountAggregateOp]
73
80
  # convert
74
81
  + [_ConvertOp, _NonLLMConvert, _LLMConvert, _LLMConvertBonded]
82
+ # critique and refine
83
+ + [_CritiqueAndRefineConvert, _CritiqueAndRefineFilter]
75
84
  # distinct
76
85
  + [_DistinctOp]
77
86
  # scan
@@ -83,13 +92,17 @@ PHYSICAL_OPERATORS = (
83
92
  # limit
84
93
  + [_LimitScanOp]
85
94
  # mixture-of-agents
86
- + [_MixtureOfAgentsConvert]
95
+ + [_MixtureOfAgentsConvert, _MixtureOfAgentsFilter]
87
96
  # physical
88
97
  + [_PhysicalOperator]
89
98
  # project
90
99
  + [_ProjectOp]
100
+ # rag
101
+ + [_RAGConvert, _RAGFilter]
91
102
  # retrieve
92
103
  + [_RetrieveOp]
104
+ # split
105
+ + [_SplitConvert, _SplitFilter]
93
106
  )
94
107
 
95
108
  __all__ = [
@@ -113,18 +113,20 @@ class ApplyGroupByOp(AggregateOp):
113
113
  group_by_fields = self.group_by_sig.group_by_fields
114
114
  agg_fields = self.group_by_sig.get_agg_field_names()
115
115
  for g in agg_state:
116
- dr = DataRecord.from_agg_parents(
117
- schema=self.group_by_sig.output_schema(),
118
- parent_records=candidates,
119
- )
116
+ # build up data item
117
+ data_item = {}
120
118
  for i in range(0, len(g)):
121
119
  k = g[i]
122
- setattr(dr, group_by_fields[i], k)
120
+ data_item[group_by_fields[i]] = k
123
121
  vals = agg_state[g]
124
122
  for i in range(0, len(vals)):
125
123
  v = ApplyGroupByOp.agg_final(self.group_by_sig.agg_funcs[i], vals[i])
126
- setattr(dr, agg_fields[i], v)
124
+ data_item[agg_fields[i]] = v
127
125
 
126
+ # create new DataRecord
127
+ schema = self.group_by_sig.output_schema()
128
+ data_item = schema(**data_item)
129
+ dr = DataRecord.from_agg_parents(data_item, parent_records=candidates)
128
130
  drs.append(dr)
129
131
 
130
132
  # create RecordOpStats objects
@@ -132,9 +134,9 @@ class ApplyGroupByOp(AggregateOp):
132
134
  record_op_stats_lst = []
133
135
  for dr in drs:
134
136
  record_op_stats = RecordOpStats(
135
- record_id=dr.id,
136
- record_parent_ids=dr.parent_ids,
137
- record_source_indices=dr.source_indices,
137
+ record_id=dr._id,
138
+ record_parent_ids=dr._parent_ids,
139
+ record_source_indices=dr._source_indices,
138
140
  record_state=dr.to_dict(include_bytes=False),
139
141
  full_op_id=self.get_full_op_id(),
140
142
  logical_op_id=self.logical_op_id,
@@ -197,7 +199,6 @@ class AverageAggregateOp(AggregateOp):
197
199
  # NOTE: right now we perform a check in the constructor which enforces that the input_schema
198
200
  # has a single field which is numeric in nature; in the future we may want to have a
199
201
  # cleaner way of computing the value (rather than `float(list(candidate...))` below)
200
- dr = DataRecord.from_agg_parents(schema=Average, parent_records=candidates)
201
202
  summation, total = 0, 0
202
203
  for candidate in candidates:
203
204
  try:
@@ -205,13 +206,14 @@ class AverageAggregateOp(AggregateOp):
205
206
  total += 1
206
207
  except Exception:
207
208
  pass
208
- dr.average = summation / total
209
+ data_item = Average(average=summation / total)
210
+ dr = DataRecord.from_agg_parents(data_item, parent_records=candidates)
209
211
 
210
212
  # create RecordOpStats object
211
213
  record_op_stats = RecordOpStats(
212
- record_id=dr.id,
213
- record_parent_ids=dr.parent_ids,
214
- record_source_indices=dr.source_indices,
214
+ record_id=dr._id,
215
+ record_parent_ids=dr._parent_ids,
216
+ record_source_indices=dr._source_indices,
215
217
  record_state=dr.to_dict(include_bytes=False),
216
218
  full_op_id=self.get_full_op_id(),
217
219
  logical_op_id=self.logical_op_id,
@@ -260,14 +262,14 @@ class CountAggregateOp(AggregateOp):
260
262
  start_time = time.time()
261
263
 
262
264
  # create new DataRecord
263
- dr = DataRecord.from_agg_parents(schema=Count, parent_records=candidates)
264
- dr.count = len(candidates)
265
+ data_item = Count(count=len(candidates))
266
+ dr = DataRecord.from_agg_parents(data_item, parent_records=candidates)
265
267
 
266
268
  # create RecordOpStats object
267
269
  record_op_stats = RecordOpStats(
268
- record_id=dr.id,
269
- record_parent_ids=dr.parent_ids,
270
- record_source_indices=dr.source_indices,
270
+ record_id=dr._id,
271
+ record_parent_ids=dr._parent_ids,
272
+ record_source_indices=dr._source_indices,
271
273
  record_state=dr.to_dict(include_bytes=False),
272
274
  full_op_id=self.get_full_op_id(),
273
275
  logical_op_id=self.logical_op_id,
@@ -93,17 +93,15 @@ class SmolAgentsCompute(PhysicalOperator):
93
93
  Given an input DataRecord and a determination of whether it passed the filter or not,
94
94
  construct the resulting RecordSet.
95
95
  """
96
- # create new DataRecord and set passed_operator attribute
97
- dr = DataRecord.from_parent(self.output_schema, parent_record=candidate)
98
- for field in self.output_schema.model_fields:
99
- if field in answer:
100
- dr[field] = answer[field]
96
+ # create new DataRecord
97
+ data_item = {field: answer[field] for field in self.output_schema.model_fields if field in answer}
98
+ dr = DataRecord.from_parent(self.output_schema, data_item, parent_record=candidate)
101
99
 
102
100
  # create RecordOpStats object
103
101
  record_op_stats = RecordOpStats(
104
- record_id=dr.id,
105
- record_parent_ids=dr.parent_ids,
106
- record_source_indices=dr.source_indices,
102
+ record_id=dr._id,
103
+ record_parent_ids=dr._parent_ids,
104
+ record_source_indices=dr._source_indices,
107
105
  record_state=dr.to_dict(include_bytes=False),
108
106
  full_op_id=self.get_full_op_id(),
109
107
  logical_op_id=self.logical_op_id,
@@ -74,25 +74,14 @@ class ConvertOp(PhysicalOperator, ABC):
74
74
 
75
75
  drs = []
76
76
  for idx in range(max(n_records, 1)):
77
- # initialize record with the correct output schema, parent record, and cardinality idx
78
- dr = DataRecord.from_parent(self.output_schema, parent_record=candidate, cardinality_idx=idx)
79
-
80
- # copy all fields from the input record
81
- # NOTE: this means that records processed by PZ converts will inherit all pre-computed fields
82
- # in an incremental fashion; this is a design choice which may be revisited in the future
83
- for field in candidate.get_field_names():
84
- setattr(dr, field, getattr(candidate, field))
85
-
86
- # get input field names and output field names
87
- input_fields = list(self.input_schema.model_fields)
88
- output_fields = list(self.output_schema.model_fields)
89
-
90
77
  # parse newly generated fields from the field_answers dictionary for this field; if the list
91
78
  # of generated values is shorter than the number of records, we fill in with None
92
- for field in output_fields:
93
- if field not in input_fields:
94
- value = field_answers[field][idx] if idx < len(field_answers[field]) else None
95
- setattr(dr, field, value)
79
+ data_item = {}
80
+ for field in self.generated_fields:
81
+ data_item[field] = field_answers[field][idx] if idx < len(field_answers[field]) else None
82
+
83
+ # initialize record with the correct output schema, data_item, parent record, and cardinality idx
84
+ dr = DataRecord.from_parent(self.output_schema, data_item, parent_record=candidate, cardinality_idx=idx)
96
85
 
97
86
  # append data record to list of output data records
98
87
  drs.append(dr)
@@ -117,9 +106,9 @@ class ConvertOp(PhysicalOperator, ABC):
117
106
  # create the RecordOpStats objects for each output record
118
107
  record_op_stats_lst = [
119
108
  RecordOpStats(
120
- record_id=dr.id,
121
- record_parent_ids=dr.parent_ids,
122
- record_source_indices=dr.source_indices,
109
+ record_id=dr._id,
110
+ record_parent_ids=dr._parent_ids,
111
+ record_source_indices=dr._source_indices,
123
112
  record_state=dr.to_dict(include_bytes=False),
124
113
  full_op_id=self.get_full_op_id(),
125
114
  logical_op_id=self.logical_op_id,
@@ -127,7 +116,7 @@ class ConvertOp(PhysicalOperator, ABC):
127
116
  time_per_record=time_per_record,
128
117
  cost_per_record=per_record_stats.cost_per_record,
129
118
  model_name=self.get_model_name(),
130
- answer={field_name: getattr(dr, field_name) for field_name in field_names},
119
+ answer={field_name: getattr(dr, field_name, None) for field_name in field_names},
131
120
  input_fields=list(self.input_schema.model_fields),
132
121
  generated_fields=field_names,
133
122
  total_input_tokens=per_record_stats.total_input_tokens,
@@ -139,7 +128,6 @@ class ConvertOp(PhysicalOperator, ABC):
139
128
  total_llm_calls=per_record_stats.total_llm_calls,
140
129
  total_embedding_llm_calls=per_record_stats.total_embedding_llm_calls,
141
130
  failed_convert=(not successful_convert),
142
- image_operation=self.is_image_conversion(),
143
131
  op_details={k: str(v) for k, v in self.get_id_params().items()},
144
132
  )
145
133
  for dr in records
@@ -148,11 +136,6 @@ class ConvertOp(PhysicalOperator, ABC):
148
136
  # create and return the DataRecordSet
149
137
  return DataRecordSet(records, record_op_stats_lst)
150
138
 
151
- @abstractmethod
152
- def is_image_conversion(self) -> bool:
153
- """Return True if the convert operation processes an image, False otherwise."""
154
- pass
155
-
156
139
  @abstractmethod
157
140
  def convert(self, candidate: DataRecord, fields: dict[str, FieldInfo]) -> tuple[dict[str, list], GenerationStats]:
158
141
  """
@@ -216,11 +199,6 @@ class NonLLMConvert(ConvertOp):
216
199
  op += f" UDF: {self.udf.__name__}\n"
217
200
  return op
218
201
 
219
- def is_image_conversion(self) -> bool:
220
- # NOTE: even if the UDF is processing an image, we do not consider this an image conversion
221
- # (the output of this function will be used by the CostModel in a way which does not apply to UDFs)
222
- return False
223
-
224
202
  def naive_cost_estimates(self, source_op_cost_estimates: OperatorCostEstimates) -> OperatorCostEstimates:
225
203
  """
226
204
  Compute naive cost estimates for the NonLLMConvert operation. These estimates assume
@@ -287,7 +265,7 @@ class LLMConvert(ConvertOp):
287
265
  def __init__(
288
266
  self,
289
267
  model: Model,
290
- prompt_strategy: PromptStrategy = PromptStrategy.COT_QA,
268
+ prompt_strategy: PromptStrategy = PromptStrategy.MAP,
291
269
  reasoning_effort: str | None = None,
292
270
  *args,
293
271
  **kwargs,
@@ -330,9 +308,6 @@ class LLMConvert(ConvertOp):
330
308
  def get_model_name(self):
331
309
  return None if self.model is None else self.model.value
332
310
 
333
- def is_image_conversion(self) -> bool:
334
- return self.prompt_strategy.is_image_prompt()
335
-
336
311
  def naive_cost_estimates(self, source_op_cost_estimates: OperatorCostEstimates) -> OperatorCostEstimates:
337
312
  """
338
313
  Compute naive cost estimates for the LLMConvert operation. Implicitly, these estimates
@@ -350,7 +325,7 @@ class LLMConvert(ConvertOp):
350
325
 
351
326
  # get est. of conversion cost (in USD) per record from model card
352
327
  usd_per_input_token = MODEL_CARDS[model_name].get("usd_per_input_token")
353
- if getattr(self, "prompt_strategy", None) is not None and self.prompt_strategy.is_audio_prompt():
328
+ if getattr(self, "prompt_strategy", None) is not None and self.is_audio_op():
354
329
  usd_per_input_token = MODEL_CARDS[model_name]["usd_per_audio_input_token"]
355
330
 
356
331
  model_conversion_usd_per_record = (
@@ -0,0 +1,194 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any
4
+
5
+ from pydantic.fields import FieldInfo
6
+
7
+ from palimpzest.constants import MODEL_CARDS, Cardinality, Model, PromptStrategy
8
+ from palimpzest.core.elements.records import DataRecord
9
+ from palimpzest.core.models import GenerationStats, OperatorCostEstimates
10
+ from palimpzest.query.generators.generators import Generator
11
+ from palimpzest.query.operators.convert import LLMConvert
12
+ from palimpzest.query.operators.filter import LLMFilter
13
+
14
+ # TYPE DEFINITIONS
15
+ FieldName = str
16
+
17
+
18
+ class CritiqueAndRefineConvert(LLMConvert):
19
+
20
+ def __init__(
21
+ self,
22
+ critic_model: Model,
23
+ refine_model: Model,
24
+ *args,
25
+ **kwargs,
26
+ ):
27
+ super().__init__(*args, **kwargs)
28
+ self.critic_model = critic_model
29
+ self.refine_model = refine_model
30
+
31
+ # create generators
32
+ self.critic_generator = Generator(self.critic_model, PromptStrategy.MAP_CRITIC, self.reasoning_effort, self.api_base, self.cardinality, self.desc, self.verbose)
33
+ self.refine_generator = Generator(self.refine_model, PromptStrategy.MAP_REFINE, self.reasoning_effort, self.api_base, self.cardinality, self.desc, self.verbose)
34
+
35
+ def __str__(self):
36
+ op = super().__str__()
37
+ op += f" Critic Model: {self.critic_model}\n"
38
+ op += f" Refine Model: {self.refine_model}\n"
39
+ return op
40
+
41
+ def get_id_params(self):
42
+ id_params = super().get_id_params()
43
+ id_params = {
44
+ "critic_model": self.critic_model.value,
45
+ "refine_model": self.refine_model.value,
46
+ **id_params,
47
+ }
48
+
49
+ return id_params
50
+
51
+ def get_op_params(self):
52
+ op_params = super().get_op_params()
53
+ op_params = {
54
+ "critic_model": self.critic_model,
55
+ "refine_model": self.refine_model,
56
+ **op_params,
57
+ }
58
+
59
+ return op_params
60
+
61
+ def naive_cost_estimates(self, source_op_cost_estimates: OperatorCostEstimates) -> OperatorCostEstimates:
62
+ """
63
+ Currently, we are invoking `self.model`, then critiquing its output with `self.critic_model`, and
64
+ finally refining the output with `self.refine_model`. Thus, we roughly expect to incur the cost
65
+ and time of three LLMConverts. In practice, this naive quality estimate will be overwritten by the
66
+ CostModel's estimate once it executes a few instances of the operator.
67
+ """
68
+ # get naive cost estimates for first LLM call and multiply by 3 for now;
69
+ # of course we should sum individual estimates for each model, but this is a rough estimate
70
+ # and in practice we will need to revamp our naive cost estimates in the near future
71
+ naive_op_cost_estimates = 3 * super().naive_cost_estimates(source_op_cost_estimates)
72
+
73
+ # for naive setting, estimate quality as quality of refine model
74
+ model_quality = MODEL_CARDS[self.refine_model.value]["overall"] / 100.0
75
+ naive_op_cost_estimates.quality = model_quality
76
+ naive_op_cost_estimates.quality_lower_bound = naive_op_cost_estimates.quality
77
+ naive_op_cost_estimates.quality_upper_bound = naive_op_cost_estimates.quality
78
+
79
+ return naive_op_cost_estimates
80
+
81
+ def convert(self, candidate: DataRecord, fields: dict[str, FieldInfo]) -> tuple[dict[FieldName, list[Any]], GenerationStats]:
82
+ # get input fields
83
+ input_fields = self.get_input_fields()
84
+
85
+ # NOTE: when I merge in the `abacus` branch, I will want to update this to reflect the changes I made to reasoning extraction
86
+ # execute the initial model
87
+ original_gen_kwargs = {"project_cols": input_fields, "output_schema": self.output_schema}
88
+ field_answers, reasoning, original_gen_stats, original_messages = self.generator(candidate, fields, **original_gen_kwargs)
89
+ original_output = f"REASONING: {reasoning}\nANSWER: {field_answers}\n"
90
+
91
+ # execute the critic model
92
+ critic_gen_kwargs = {"original_output": original_output, "original_messages": original_messages, **original_gen_kwargs}
93
+ _, reasoning, critic_gen_stats, _ = self.critic_generator(candidate, fields, json_output=False, **critic_gen_kwargs)
94
+ critique_output = f"CRITIQUE: {reasoning}\n"
95
+
96
+ # execute the refinement model
97
+ refine_gen_kwargs = {"critique_output": critique_output, **critic_gen_kwargs}
98
+ field_answers, reasoning, refine_gen_stats, _ = self.refine_generator(candidate, fields, **refine_gen_kwargs)
99
+
100
+ # compute the total generation stats
101
+ generation_stats = original_gen_stats + critic_gen_stats + refine_gen_stats
102
+
103
+ return field_answers, generation_stats
104
+
105
+
106
+ class CritiqueAndRefineFilter(LLMFilter):
107
+
108
+ def __init__(
109
+ self,
110
+ critic_model: Model,
111
+ refine_model: Model,
112
+ *args,
113
+ **kwargs,
114
+ ):
115
+ super().__init__(*args, **kwargs)
116
+ self.critic_model = critic_model
117
+ self.refine_model = refine_model
118
+
119
+ # create generators
120
+ self.critic_generator = Generator(self.critic_model, PromptStrategy.FILTER_CRITIC, self.reasoning_effort, self.api_base, Cardinality.ONE_TO_ONE, self.desc, self.verbose)
121
+ self.refine_generator = Generator(self.refine_model, PromptStrategy.FILTER_REFINE, self.reasoning_effort, self.api_base, Cardinality.ONE_TO_ONE, self.desc, self.verbose)
122
+
123
+ def __str__(self):
124
+ op = super().__str__()
125
+ op += f" Critic Model: {self.critic_model}\n"
126
+ op += f" Refine Model: {self.refine_model}\n"
127
+ return op
128
+
129
+ def get_id_params(self):
130
+ id_params = super().get_id_params()
131
+ id_params = {
132
+ "critic_model": self.critic_model.value,
133
+ "refine_model": self.refine_model.value,
134
+ **id_params,
135
+ }
136
+
137
+ return id_params
138
+
139
+ def get_op_params(self):
140
+ op_params = super().get_op_params()
141
+ op_params = {
142
+ "critic_model": self.critic_model,
143
+ "refine_model": self.refine_model,
144
+ **op_params,
145
+ }
146
+
147
+ return op_params
148
+
149
+ def naive_cost_estimates(self, source_op_cost_estimates: OperatorCostEstimates) -> OperatorCostEstimates:
150
+ """
151
+ Currently, we are invoking `self.model`, then critiquing its output with `self.critic_model`, and
152
+ finally refining the output with `self.refine_model`. Thus, we roughly expect to incur the cost
153
+ and time of three LLMFilters. In practice, this naive quality estimate will be overwritten by the
154
+ CostModel's estimate once it executes a few instances of the operator.
155
+ """
156
+ # get naive cost estimates for first LLM call and multiply by 3 for now;
157
+ # of course we should sum individual estimates for each model, but this is a rough estimate
158
+ # and in practice we will need to revamp our naive cost estimates in the near future
159
+ naive_op_cost_estimates = 3 * super().naive_cost_estimates(source_op_cost_estimates)
160
+
161
+ # for naive setting, estimate quality as quality of refine model
162
+ model_quality = MODEL_CARDS[self.refine_model.value]["overall"] / 100.0
163
+ naive_op_cost_estimates.quality = model_quality
164
+ naive_op_cost_estimates.quality_lower_bound = naive_op_cost_estimates.quality
165
+ naive_op_cost_estimates.quality_upper_bound = naive_op_cost_estimates.quality
166
+
167
+ return naive_op_cost_estimates
168
+
169
+ def filter(self, candidate: DataRecord) -> tuple[dict[str, bool], GenerationStats]:
170
+ # get input fields
171
+ input_fields = self.get_input_fields()
172
+
173
+ # construct output fields
174
+ fields = {"passed_operator": FieldInfo(annotation=bool, description="Whether the record passed the filter operation")}
175
+
176
+ # NOTE: when I merge in the `abacus` branch, I will want to update this to reflect the changes I made to reasoning extraction
177
+ # execute the initial model
178
+ original_gen_kwargs = {"project_cols": input_fields, "filter_condition": self.filter_obj.filter_condition}
179
+ field_answers, reasoning, original_gen_stats, original_messages = self.generator(candidate, fields, **original_gen_kwargs)
180
+ original_output = f"REASONING: {reasoning}\nANSWER: {str(field_answers['passed_operator']).upper()}\n"
181
+
182
+ # execute the critic model
183
+ critic_gen_kwargs = {"original_output": original_output, "original_messages": original_messages, **original_gen_kwargs}
184
+ _, reasoning, critic_gen_stats, _ = self.critic_generator(candidate, fields, json_output=False, **critic_gen_kwargs)
185
+ critique_output = f"CRITIQUE: {reasoning}\n"
186
+
187
+ # execute the refinement model
188
+ refine_gen_kwargs = {"critique_output": critique_output, **critic_gen_kwargs}
189
+ field_answers, reasoning, refine_gen_stats, _ = self.refine_generator(candidate, fields, **refine_gen_kwargs)
190
+
191
+ # compute the total generation stats
192
+ generation_stats = original_gen_stats + critic_gen_stats + refine_gen_stats
193
+
194
+ return field_answers, generation_stats
@@ -35,27 +35,27 @@ class DistinctOp(PhysicalOperator):
35
35
 
36
36
  def __call__(self, candidate: DataRecord) -> DataRecordSet:
37
37
  # create new DataRecord
38
- dr = DataRecord.from_parent(schema=candidate.schema, parent_record=candidate)
38
+ dr = DataRecord.from_parent(schema=candidate.schema, data_item={}, parent_record=candidate)
39
39
 
40
40
  # output record only if it has not been seen before
41
41
  record_str = dr.to_json_str(project_cols=self.distinct_cols, bytes_to_str=True, sorted=True)
42
42
  record_hash = f"{hash(record_str)}"
43
- dr.passed_operator = record_hash not in self._distinct_seen
44
- if dr.passed_operator:
43
+ dr._passed_operator = record_hash not in self._distinct_seen
44
+ if dr._passed_operator:
45
45
  self._distinct_seen.add(record_hash)
46
46
 
47
47
  # create RecordOpStats object
48
48
  record_op_stats = RecordOpStats(
49
- record_id=dr.id,
50
- record_parent_ids=dr.parent_ids,
51
- record_source_indices=dr.source_indices,
49
+ record_id=dr._id,
50
+ record_parent_ids=dr._parent_ids,
51
+ record_source_indices=dr._source_indices,
52
52
  record_state=dr.to_dict(include_bytes=False),
53
53
  full_op_id=self.get_full_op_id(),
54
54
  logical_op_id=self.logical_op_id,
55
55
  op_name=self.op_name(),
56
56
  time_per_record=0.0,
57
57
  cost_per_record=0.0,
58
- passed_operator=dr.passed_operator,
58
+ passed_operator=dr._passed_operator,
59
59
  op_details={k: str(v) for k, v in self.get_id_params().items()},
60
60
  )
61
61