palimpzest 0.8.1__py3-none-any.whl → 0.8.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- palimpzest/constants.py +38 -62
- palimpzest/core/data/dataset.py +1 -1
- palimpzest/core/data/iter_dataset.py +5 -5
- palimpzest/core/elements/groupbysig.py +1 -1
- palimpzest/core/elements/records.py +91 -109
- palimpzest/core/lib/schemas.py +23 -0
- palimpzest/core/models.py +3 -3
- palimpzest/prompts/__init__.py +2 -6
- palimpzest/prompts/convert_prompts.py +10 -66
- palimpzest/prompts/critique_and_refine_prompts.py +66 -0
- palimpzest/prompts/filter_prompts.py +8 -46
- palimpzest/prompts/join_prompts.py +12 -75
- palimpzest/prompts/{moa_aggregator_convert_prompts.py → moa_aggregator_prompts.py} +51 -2
- palimpzest/prompts/moa_proposer_prompts.py +87 -0
- palimpzest/prompts/prompt_factory.py +351 -479
- palimpzest/prompts/split_merge_prompts.py +51 -2
- palimpzest/prompts/split_proposer_prompts.py +48 -16
- palimpzest/prompts/utils.py +109 -0
- palimpzest/query/execution/all_sample_execution_strategy.py +1 -1
- palimpzest/query/execution/execution_strategy.py +4 -4
- palimpzest/query/execution/mab_execution_strategy.py +47 -23
- palimpzest/query/execution/parallel_execution_strategy.py +3 -3
- palimpzest/query/execution/single_threaded_execution_strategy.py +8 -8
- palimpzest/query/generators/generators.py +31 -17
- palimpzest/query/operators/__init__.py +15 -2
- palimpzest/query/operators/aggregate.py +21 -19
- palimpzest/query/operators/compute.py +6 -8
- palimpzest/query/operators/convert.py +12 -37
- palimpzest/query/operators/critique_and_refine.py +194 -0
- palimpzest/query/operators/distinct.py +7 -7
- palimpzest/query/operators/filter.py +13 -25
- palimpzest/query/operators/join.py +321 -192
- palimpzest/query/operators/limit.py +4 -4
- palimpzest/query/operators/mixture_of_agents.py +246 -0
- palimpzest/query/operators/physical.py +25 -2
- palimpzest/query/operators/project.py +4 -4
- palimpzest/query/operators/{rag_convert.py → rag.py} +202 -5
- palimpzest/query/operators/retrieve.py +10 -9
- palimpzest/query/operators/scan.py +9 -10
- palimpzest/query/operators/search.py +18 -24
- palimpzest/query/operators/split.py +321 -0
- palimpzest/query/optimizer/__init__.py +12 -8
- palimpzest/query/optimizer/optimizer.py +12 -10
- palimpzest/query/optimizer/rules.py +201 -108
- palimpzest/query/optimizer/tasks.py +18 -6
- palimpzest/query/processor/config.py +2 -2
- palimpzest/query/processor/query_processor.py +2 -2
- palimpzest/query/processor/query_processor_factory.py +9 -5
- palimpzest/validator/validator.py +7 -9
- {palimpzest-0.8.1.dist-info → palimpzest-0.8.3.dist-info}/METADATA +3 -8
- palimpzest-0.8.3.dist-info/RECORD +95 -0
- palimpzest/prompts/critique_and_refine_convert_prompts.py +0 -216
- palimpzest/prompts/moa_proposer_convert_prompts.py +0 -75
- palimpzest/prompts/util_phrases.py +0 -19
- palimpzest/query/operators/critique_and_refine_convert.py +0 -113
- palimpzest/query/operators/mixture_of_agents_convert.py +0 -140
- palimpzest/query/operators/split_convert.py +0 -170
- palimpzest-0.8.1.dist-info/RECORD +0 -95
- {palimpzest-0.8.1.dist-info → palimpzest-0.8.3.dist-info}/WHEEL +0 -0
- {palimpzest-0.8.1.dist-info → palimpzest-0.8.3.dist-info}/licenses/LICENSE +0 -0
- {palimpzest-0.8.1.dist-info → palimpzest-0.8.3.dist-info}/top_level.txt +0 -0
|
@@ -101,7 +101,7 @@ def get_json_from_answer(answer: str, model: Model, cardinality: Cardinality) ->
|
|
|
101
101
|
# TODO: make sure answer parsing works with custom prompts / parsers (can defer this)
|
|
102
102
|
class Generator(Generic[ContextType, InputType]):
|
|
103
103
|
"""
|
|
104
|
-
|
|
104
|
+
Class for generating new fields for a record using an LLM.
|
|
105
105
|
"""
|
|
106
106
|
|
|
107
107
|
def __init__(
|
|
@@ -181,11 +181,11 @@ class Generator(Generic[ContextType, InputType]):
|
|
|
181
181
|
|
|
182
182
|
return None
|
|
183
183
|
|
|
184
|
-
def _check_bool_answer_text(self, answer_text: str) -> dict | None:
|
|
184
|
+
def _check_bool_answer_text(self, answer_text: str, throw_exception: bool=False) -> dict | None:
|
|
185
185
|
"""
|
|
186
186
|
Return {"passed_operator": True} if and only if "true" is in the answer text.
|
|
187
187
|
Return {"passed_operator": False} if and only if "false" is in the answer text.
|
|
188
|
-
Otherwise,
|
|
188
|
+
Otherwise, raise an exception.
|
|
189
189
|
"""
|
|
190
190
|
# NOTE: we may be able to eliminate this condition by specifying this JSON output in the prompt;
|
|
191
191
|
# however, that would also need to coincide with a change to allow the parse_answer_fn to set "passed_operator"
|
|
@@ -194,6 +194,9 @@ class Generator(Generic[ContextType, InputType]):
|
|
|
194
194
|
elif "false" in answer_text.lower():
|
|
195
195
|
return {"passed_operator": False}
|
|
196
196
|
|
|
197
|
+
if throw_exception:
|
|
198
|
+
raise Exception(f"Could not parse answer from completion text: {answer_text}")
|
|
199
|
+
|
|
197
200
|
return None
|
|
198
201
|
|
|
199
202
|
def _parse_convert_answer(self, completion_text: str, fields: dict[str, FieldInfo], json_output: bool) -> dict[str, list]:
|
|
@@ -235,7 +238,7 @@ class Generator(Generic[ContextType, InputType]):
|
|
|
235
238
|
|
|
236
239
|
return self._check_convert_answer_text(completion_text, fields, throw_exception=True)
|
|
237
240
|
|
|
238
|
-
def _parse_bool_answer(self, completion_text: str) -> dict[str, list]:
|
|
241
|
+
def _parse_bool_answer(self, completion_text: str, json_output: bool) -> dict[str, list]:
|
|
239
242
|
"""Extract the answer from the completion object for filter and join operations."""
|
|
240
243
|
# if the model followed the default instructions, the completion text will place
|
|
241
244
|
# its answer between "ANSWER:" and "---"
|
|
@@ -243,6 +246,12 @@ class Generator(Generic[ContextType, InputType]):
|
|
|
243
246
|
matches = regex.findall(completion_text)
|
|
244
247
|
if len(matches) > 0:
|
|
245
248
|
answer_text = matches[0].strip()
|
|
249
|
+
|
|
250
|
+
# if we don't expect a JSON output, return the answer text as is
|
|
251
|
+
if not json_output:
|
|
252
|
+
return answer_text
|
|
253
|
+
|
|
254
|
+
# otherwise, try to parse the answer text into a JSON object
|
|
246
255
|
field_answers = self._check_bool_answer_text(answer_text)
|
|
247
256
|
if field_answers is not None:
|
|
248
257
|
return field_answers
|
|
@@ -252,16 +261,21 @@ class Generator(Generic[ContextType, InputType]):
|
|
|
252
261
|
matches = regex.findall(completion_text)
|
|
253
262
|
if len(matches) > 0:
|
|
254
263
|
answer_text = matches[0].strip()
|
|
264
|
+
|
|
265
|
+
# if we don't expect a JSON output, return the answer text as is
|
|
266
|
+
if not json_output:
|
|
267
|
+
return answer_text
|
|
268
|
+
|
|
269
|
+
# otherwise, try to parse the answer text into a JSON object
|
|
255
270
|
field_answers = self._check_bool_answer_text(answer_text)
|
|
256
271
|
if field_answers is not None:
|
|
257
272
|
return field_answers
|
|
258
273
|
|
|
259
|
-
# finally, try taking all of the text; throw an exception if
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
raise Exception(f"Could not parse answer from completion text: {completion_text}")
|
|
274
|
+
# finally, try taking all of the text; for JSON output, throw an exception if parsing fails
|
|
275
|
+
if not json_output:
|
|
276
|
+
return completion_text
|
|
263
277
|
|
|
264
|
-
return
|
|
278
|
+
return self._check_bool_answer_text(completion_text, throw_exception=True)
|
|
265
279
|
|
|
266
280
|
def _parse_answer(self, completion_text: str, fields: dict[str, FieldInfo] | None, json_output: bool, **kwargs) -> dict[str, list]:
|
|
267
281
|
"""Extract the answer from the completion object."""
|
|
@@ -275,8 +289,8 @@ class Generator(Generic[ContextType, InputType]):
|
|
|
275
289
|
|
|
276
290
|
# extract the per-field answers from the completion text
|
|
277
291
|
field_answers = (
|
|
278
|
-
self._parse_bool_answer(completion_text)
|
|
279
|
-
if self.prompt_strategy.
|
|
292
|
+
self._parse_bool_answer(completion_text, json_output)
|
|
293
|
+
if self.prompt_strategy.is_filter_prompt() or self.prompt_strategy.is_join_prompt()
|
|
280
294
|
else self._parse_convert_answer(completion_text, fields, json_output)
|
|
281
295
|
)
|
|
282
296
|
|
|
@@ -299,6 +313,7 @@ class Generator(Generic[ContextType, InputType]):
|
|
|
299
313
|
|
|
300
314
|
# generate a list of messages which can be used to construct a payload
|
|
301
315
|
messages = self.prompt_factory.create_messages(candidate, fields, right_candidate, **kwargs)
|
|
316
|
+
is_audio_op = any(msg.get("type") == "input_audio" for msg in messages)
|
|
302
317
|
|
|
303
318
|
# generate the text completion
|
|
304
319
|
start_time = time.time()
|
|
@@ -307,7 +322,7 @@ class Generator(Generic[ContextType, InputType]):
|
|
|
307
322
|
completion_kwargs = {}
|
|
308
323
|
if not self.model.is_o_model() and not self.model.is_gpt_5_model():
|
|
309
324
|
completion_kwargs = {"temperature": kwargs.get("temperature", 0.0), **completion_kwargs}
|
|
310
|
-
if
|
|
325
|
+
if is_audio_op:
|
|
311
326
|
completion_kwargs = {"modalities": ["text"], **completion_kwargs}
|
|
312
327
|
if self.model.is_reasoning_model():
|
|
313
328
|
if self.model.is_vertex_model():
|
|
@@ -330,11 +345,10 @@ class Generator(Generic[ContextType, InputType]):
|
|
|
330
345
|
# if there's an error generating the completion, we have to return an empty answer
|
|
331
346
|
# and can only account for the time spent performing the failed generation
|
|
332
347
|
except Exception as e:
|
|
333
|
-
print(f"Error generating completion: {e}")
|
|
334
348
|
logger.error(f"Error generating completion: {e}")
|
|
335
349
|
field_answers = (
|
|
336
350
|
{"passed_operator": False}
|
|
337
|
-
if self.prompt_strategy.
|
|
351
|
+
if self.prompt_strategy.is_filter_prompt() or self.prompt_strategy.is_join_prompt()
|
|
338
352
|
else {field_name: None for field_name in fields}
|
|
339
353
|
)
|
|
340
354
|
reasoning = None
|
|
@@ -360,7 +374,7 @@ class Generator(Generic[ContextType, InputType]):
|
|
|
360
374
|
# for now, we only use tokens from prompt_token_details if it's an audio prompt
|
|
361
375
|
# get output tokens (all text) and input tokens by modality
|
|
362
376
|
output_tokens = usage["completion_tokens"]
|
|
363
|
-
if
|
|
377
|
+
if is_audio_op:
|
|
364
378
|
input_audio_tokens = usage["prompt_tokens_details"].get("audio_tokens", 0)
|
|
365
379
|
input_text_tokens = usage["prompt_tokens_details"].get("text_tokens", 0)
|
|
366
380
|
input_image_tokens = 0
|
|
@@ -413,9 +427,9 @@ class Generator(Generic[ContextType, InputType]):
|
|
|
413
427
|
|
|
414
428
|
# parse field answers
|
|
415
429
|
field_answers = None
|
|
416
|
-
if fields is not None and (self.prompt_strategy.
|
|
430
|
+
if fields is not None and (self.prompt_strategy.is_filter_prompt() or self.prompt_strategy.is_join_prompt()):
|
|
417
431
|
field_answers = {"passed_operator": False}
|
|
418
|
-
elif fields is not None and not (self.prompt_strategy.
|
|
432
|
+
elif fields is not None and not (self.prompt_strategy.is_filter_prompt() or self.prompt_strategy.is_join_prompt()):
|
|
419
433
|
field_answers = {field_name: None for field_name in fields}
|
|
420
434
|
try:
|
|
421
435
|
field_answers = self._parse_answer(completion_text, fields, json_output, **kwargs)
|
|
@@ -6,6 +6,8 @@ from palimpzest.query.operators.convert import ConvertOp as _ConvertOp
|
|
|
6
6
|
from palimpzest.query.operators.convert import LLMConvert as _LLMConvert
|
|
7
7
|
from palimpzest.query.operators.convert import LLMConvertBonded as _LLMConvertBonded
|
|
8
8
|
from palimpzest.query.operators.convert import NonLLMConvert as _NonLLMConvert
|
|
9
|
+
from palimpzest.query.operators.critique_and_refine import CritiqueAndRefineConvert as _CritiqueAndRefineConvert
|
|
10
|
+
from palimpzest.query.operators.critique_and_refine import CritiqueAndRefineFilter as _CritiqueAndRefineFilter
|
|
9
11
|
from palimpzest.query.operators.distinct import DistinctOp as _DistinctOp
|
|
10
12
|
from palimpzest.query.operators.filter import FilterOp as _FilterOp
|
|
11
13
|
from palimpzest.query.operators.filter import LLMFilter as _LLMFilter
|
|
@@ -46,12 +48,17 @@ from palimpzest.query.operators.logical import (
|
|
|
46
48
|
from palimpzest.query.operators.logical import (
|
|
47
49
|
RetrieveScan as _RetrieveScan,
|
|
48
50
|
)
|
|
49
|
-
from palimpzest.query.operators.
|
|
51
|
+
from palimpzest.query.operators.mixture_of_agents import MixtureOfAgentsConvert as _MixtureOfAgentsConvert
|
|
52
|
+
from palimpzest.query.operators.mixture_of_agents import MixtureOfAgentsFilter as _MixtureOfAgentsFilter
|
|
50
53
|
from palimpzest.query.operators.physical import PhysicalOperator as _PhysicalOperator
|
|
51
54
|
from palimpzest.query.operators.project import ProjectOp as _ProjectOp
|
|
55
|
+
from palimpzest.query.operators.rag import RAGConvert as _RAGConvert
|
|
56
|
+
from palimpzest.query.operators.rag import RAGFilter as _RAGFilter
|
|
52
57
|
from palimpzest.query.operators.retrieve import RetrieveOp as _RetrieveOp
|
|
53
58
|
from palimpzest.query.operators.scan import MarshalAndScanDataOp as _MarshalAndScanDataOp
|
|
54
59
|
from palimpzest.query.operators.scan import ScanPhysicalOp as _ScanPhysicalOp
|
|
60
|
+
from palimpzest.query.operators.split import SplitConvert as _SplitConvert
|
|
61
|
+
from palimpzest.query.operators.split import SplitFilter as _SplitFilter
|
|
55
62
|
|
|
56
63
|
LOGICAL_OPERATORS = [
|
|
57
64
|
_LogicalOperator,
|
|
@@ -72,6 +79,8 @@ PHYSICAL_OPERATORS = (
|
|
|
72
79
|
[_AggregateOp, _ApplyGroupByOp, _AverageAggregateOp, _CountAggregateOp]
|
|
73
80
|
# convert
|
|
74
81
|
+ [_ConvertOp, _NonLLMConvert, _LLMConvert, _LLMConvertBonded]
|
|
82
|
+
# critique and refine
|
|
83
|
+
+ [_CritiqueAndRefineConvert, _CritiqueAndRefineFilter]
|
|
75
84
|
# distinct
|
|
76
85
|
+ [_DistinctOp]
|
|
77
86
|
# scan
|
|
@@ -83,13 +92,17 @@ PHYSICAL_OPERATORS = (
|
|
|
83
92
|
# limit
|
|
84
93
|
+ [_LimitScanOp]
|
|
85
94
|
# mixture-of-agents
|
|
86
|
-
+ [_MixtureOfAgentsConvert]
|
|
95
|
+
+ [_MixtureOfAgentsConvert, _MixtureOfAgentsFilter]
|
|
87
96
|
# physical
|
|
88
97
|
+ [_PhysicalOperator]
|
|
89
98
|
# project
|
|
90
99
|
+ [_ProjectOp]
|
|
100
|
+
# rag
|
|
101
|
+
+ [_RAGConvert, _RAGFilter]
|
|
91
102
|
# retrieve
|
|
92
103
|
+ [_RetrieveOp]
|
|
104
|
+
# split
|
|
105
|
+
+ [_SplitConvert, _SplitFilter]
|
|
93
106
|
)
|
|
94
107
|
|
|
95
108
|
__all__ = [
|
|
@@ -113,18 +113,20 @@ class ApplyGroupByOp(AggregateOp):
|
|
|
113
113
|
group_by_fields = self.group_by_sig.group_by_fields
|
|
114
114
|
agg_fields = self.group_by_sig.get_agg_field_names()
|
|
115
115
|
for g in agg_state:
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
parent_records=candidates,
|
|
119
|
-
)
|
|
116
|
+
# build up data item
|
|
117
|
+
data_item = {}
|
|
120
118
|
for i in range(0, len(g)):
|
|
121
119
|
k = g[i]
|
|
122
|
-
|
|
120
|
+
data_item[group_by_fields[i]] = k
|
|
123
121
|
vals = agg_state[g]
|
|
124
122
|
for i in range(0, len(vals)):
|
|
125
123
|
v = ApplyGroupByOp.agg_final(self.group_by_sig.agg_funcs[i], vals[i])
|
|
126
|
-
|
|
124
|
+
data_item[agg_fields[i]] = v
|
|
127
125
|
|
|
126
|
+
# create new DataRecord
|
|
127
|
+
schema = self.group_by_sig.output_schema()
|
|
128
|
+
data_item = schema(**data_item)
|
|
129
|
+
dr = DataRecord.from_agg_parents(data_item, parent_records=candidates)
|
|
128
130
|
drs.append(dr)
|
|
129
131
|
|
|
130
132
|
# create RecordOpStats objects
|
|
@@ -132,9 +134,9 @@ class ApplyGroupByOp(AggregateOp):
|
|
|
132
134
|
record_op_stats_lst = []
|
|
133
135
|
for dr in drs:
|
|
134
136
|
record_op_stats = RecordOpStats(
|
|
135
|
-
record_id=dr.
|
|
136
|
-
record_parent_ids=dr.
|
|
137
|
-
record_source_indices=dr.
|
|
137
|
+
record_id=dr._id,
|
|
138
|
+
record_parent_ids=dr._parent_ids,
|
|
139
|
+
record_source_indices=dr._source_indices,
|
|
138
140
|
record_state=dr.to_dict(include_bytes=False),
|
|
139
141
|
full_op_id=self.get_full_op_id(),
|
|
140
142
|
logical_op_id=self.logical_op_id,
|
|
@@ -197,7 +199,6 @@ class AverageAggregateOp(AggregateOp):
|
|
|
197
199
|
# NOTE: right now we perform a check in the constructor which enforces that the input_schema
|
|
198
200
|
# has a single field which is numeric in nature; in the future we may want to have a
|
|
199
201
|
# cleaner way of computing the value (rather than `float(list(candidate...))` below)
|
|
200
|
-
dr = DataRecord.from_agg_parents(schema=Average, parent_records=candidates)
|
|
201
202
|
summation, total = 0, 0
|
|
202
203
|
for candidate in candidates:
|
|
203
204
|
try:
|
|
@@ -205,13 +206,14 @@ class AverageAggregateOp(AggregateOp):
|
|
|
205
206
|
total += 1
|
|
206
207
|
except Exception:
|
|
207
208
|
pass
|
|
208
|
-
|
|
209
|
+
data_item = Average(average=summation / total)
|
|
210
|
+
dr = DataRecord.from_agg_parents(data_item, parent_records=candidates)
|
|
209
211
|
|
|
210
212
|
# create RecordOpStats object
|
|
211
213
|
record_op_stats = RecordOpStats(
|
|
212
|
-
record_id=dr.
|
|
213
|
-
record_parent_ids=dr.
|
|
214
|
-
record_source_indices=dr.
|
|
214
|
+
record_id=dr._id,
|
|
215
|
+
record_parent_ids=dr._parent_ids,
|
|
216
|
+
record_source_indices=dr._source_indices,
|
|
215
217
|
record_state=dr.to_dict(include_bytes=False),
|
|
216
218
|
full_op_id=self.get_full_op_id(),
|
|
217
219
|
logical_op_id=self.logical_op_id,
|
|
@@ -260,14 +262,14 @@ class CountAggregateOp(AggregateOp):
|
|
|
260
262
|
start_time = time.time()
|
|
261
263
|
|
|
262
264
|
# create new DataRecord
|
|
263
|
-
|
|
264
|
-
dr
|
|
265
|
+
data_item = Count(count=len(candidates))
|
|
266
|
+
dr = DataRecord.from_agg_parents(data_item, parent_records=candidates)
|
|
265
267
|
|
|
266
268
|
# create RecordOpStats object
|
|
267
269
|
record_op_stats = RecordOpStats(
|
|
268
|
-
record_id=dr.
|
|
269
|
-
record_parent_ids=dr.
|
|
270
|
-
record_source_indices=dr.
|
|
270
|
+
record_id=dr._id,
|
|
271
|
+
record_parent_ids=dr._parent_ids,
|
|
272
|
+
record_source_indices=dr._source_indices,
|
|
271
273
|
record_state=dr.to_dict(include_bytes=False),
|
|
272
274
|
full_op_id=self.get_full_op_id(),
|
|
273
275
|
logical_op_id=self.logical_op_id,
|
|
@@ -93,17 +93,15 @@ class SmolAgentsCompute(PhysicalOperator):
|
|
|
93
93
|
Given an input DataRecord and a determination of whether it passed the filter or not,
|
|
94
94
|
construct the resulting RecordSet.
|
|
95
95
|
"""
|
|
96
|
-
# create new DataRecord
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
if field in answer:
|
|
100
|
-
dr[field] = answer[field]
|
|
96
|
+
# create new DataRecord
|
|
97
|
+
data_item = {field: answer[field] for field in self.output_schema.model_fields if field in answer}
|
|
98
|
+
dr = DataRecord.from_parent(self.output_schema, data_item, parent_record=candidate)
|
|
101
99
|
|
|
102
100
|
# create RecordOpStats object
|
|
103
101
|
record_op_stats = RecordOpStats(
|
|
104
|
-
record_id=dr.
|
|
105
|
-
record_parent_ids=dr.
|
|
106
|
-
record_source_indices=dr.
|
|
102
|
+
record_id=dr._id,
|
|
103
|
+
record_parent_ids=dr._parent_ids,
|
|
104
|
+
record_source_indices=dr._source_indices,
|
|
107
105
|
record_state=dr.to_dict(include_bytes=False),
|
|
108
106
|
full_op_id=self.get_full_op_id(),
|
|
109
107
|
logical_op_id=self.logical_op_id,
|
|
@@ -74,25 +74,14 @@ class ConvertOp(PhysicalOperator, ABC):
|
|
|
74
74
|
|
|
75
75
|
drs = []
|
|
76
76
|
for idx in range(max(n_records, 1)):
|
|
77
|
-
# initialize record with the correct output schema, parent record, and cardinality idx
|
|
78
|
-
dr = DataRecord.from_parent(self.output_schema, parent_record=candidate, cardinality_idx=idx)
|
|
79
|
-
|
|
80
|
-
# copy all fields from the input record
|
|
81
|
-
# NOTE: this means that records processed by PZ converts will inherit all pre-computed fields
|
|
82
|
-
# in an incremental fashion; this is a design choice which may be revisited in the future
|
|
83
|
-
for field in candidate.get_field_names():
|
|
84
|
-
setattr(dr, field, getattr(candidate, field))
|
|
85
|
-
|
|
86
|
-
# get input field names and output field names
|
|
87
|
-
input_fields = list(self.input_schema.model_fields)
|
|
88
|
-
output_fields = list(self.output_schema.model_fields)
|
|
89
|
-
|
|
90
77
|
# parse newly generated fields from the field_answers dictionary for this field; if the list
|
|
91
78
|
# of generated values is shorter than the number of records, we fill in with None
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
79
|
+
data_item = {}
|
|
80
|
+
for field in self.generated_fields:
|
|
81
|
+
data_item[field] = field_answers[field][idx] if idx < len(field_answers[field]) else None
|
|
82
|
+
|
|
83
|
+
# initialize record with the correct output schema, data_item, parent record, and cardinality idx
|
|
84
|
+
dr = DataRecord.from_parent(self.output_schema, data_item, parent_record=candidate, cardinality_idx=idx)
|
|
96
85
|
|
|
97
86
|
# append data record to list of output data records
|
|
98
87
|
drs.append(dr)
|
|
@@ -117,9 +106,9 @@ class ConvertOp(PhysicalOperator, ABC):
|
|
|
117
106
|
# create the RecordOpStats objects for each output record
|
|
118
107
|
record_op_stats_lst = [
|
|
119
108
|
RecordOpStats(
|
|
120
|
-
record_id=dr.
|
|
121
|
-
record_parent_ids=dr.
|
|
122
|
-
record_source_indices=dr.
|
|
109
|
+
record_id=dr._id,
|
|
110
|
+
record_parent_ids=dr._parent_ids,
|
|
111
|
+
record_source_indices=dr._source_indices,
|
|
123
112
|
record_state=dr.to_dict(include_bytes=False),
|
|
124
113
|
full_op_id=self.get_full_op_id(),
|
|
125
114
|
logical_op_id=self.logical_op_id,
|
|
@@ -127,7 +116,7 @@ class ConvertOp(PhysicalOperator, ABC):
|
|
|
127
116
|
time_per_record=time_per_record,
|
|
128
117
|
cost_per_record=per_record_stats.cost_per_record,
|
|
129
118
|
model_name=self.get_model_name(),
|
|
130
|
-
answer={field_name: getattr(dr, field_name) for field_name in field_names},
|
|
119
|
+
answer={field_name: getattr(dr, field_name, None) for field_name in field_names},
|
|
131
120
|
input_fields=list(self.input_schema.model_fields),
|
|
132
121
|
generated_fields=field_names,
|
|
133
122
|
total_input_tokens=per_record_stats.total_input_tokens,
|
|
@@ -139,7 +128,6 @@ class ConvertOp(PhysicalOperator, ABC):
|
|
|
139
128
|
total_llm_calls=per_record_stats.total_llm_calls,
|
|
140
129
|
total_embedding_llm_calls=per_record_stats.total_embedding_llm_calls,
|
|
141
130
|
failed_convert=(not successful_convert),
|
|
142
|
-
image_operation=self.is_image_conversion(),
|
|
143
131
|
op_details={k: str(v) for k, v in self.get_id_params().items()},
|
|
144
132
|
)
|
|
145
133
|
for dr in records
|
|
@@ -148,11 +136,6 @@ class ConvertOp(PhysicalOperator, ABC):
|
|
|
148
136
|
# create and return the DataRecordSet
|
|
149
137
|
return DataRecordSet(records, record_op_stats_lst)
|
|
150
138
|
|
|
151
|
-
@abstractmethod
|
|
152
|
-
def is_image_conversion(self) -> bool:
|
|
153
|
-
"""Return True if the convert operation processes an image, False otherwise."""
|
|
154
|
-
pass
|
|
155
|
-
|
|
156
139
|
@abstractmethod
|
|
157
140
|
def convert(self, candidate: DataRecord, fields: dict[str, FieldInfo]) -> tuple[dict[str, list], GenerationStats]:
|
|
158
141
|
"""
|
|
@@ -216,11 +199,6 @@ class NonLLMConvert(ConvertOp):
|
|
|
216
199
|
op += f" UDF: {self.udf.__name__}\n"
|
|
217
200
|
return op
|
|
218
201
|
|
|
219
|
-
def is_image_conversion(self) -> bool:
|
|
220
|
-
# NOTE: even if the UDF is processing an image, we do not consider this an image conversion
|
|
221
|
-
# (the output of this function will be used by the CostModel in a way which does not apply to UDFs)
|
|
222
|
-
return False
|
|
223
|
-
|
|
224
202
|
def naive_cost_estimates(self, source_op_cost_estimates: OperatorCostEstimates) -> OperatorCostEstimates:
|
|
225
203
|
"""
|
|
226
204
|
Compute naive cost estimates for the NonLLMConvert operation. These estimates assume
|
|
@@ -287,7 +265,7 @@ class LLMConvert(ConvertOp):
|
|
|
287
265
|
def __init__(
|
|
288
266
|
self,
|
|
289
267
|
model: Model,
|
|
290
|
-
prompt_strategy: PromptStrategy = PromptStrategy.
|
|
268
|
+
prompt_strategy: PromptStrategy = PromptStrategy.MAP,
|
|
291
269
|
reasoning_effort: str | None = None,
|
|
292
270
|
*args,
|
|
293
271
|
**kwargs,
|
|
@@ -330,9 +308,6 @@ class LLMConvert(ConvertOp):
|
|
|
330
308
|
def get_model_name(self):
|
|
331
309
|
return None if self.model is None else self.model.value
|
|
332
310
|
|
|
333
|
-
def is_image_conversion(self) -> bool:
|
|
334
|
-
return self.prompt_strategy.is_image_prompt()
|
|
335
|
-
|
|
336
311
|
def naive_cost_estimates(self, source_op_cost_estimates: OperatorCostEstimates) -> OperatorCostEstimates:
|
|
337
312
|
"""
|
|
338
313
|
Compute naive cost estimates for the LLMConvert operation. Implicitly, these estimates
|
|
@@ -350,7 +325,7 @@ class LLMConvert(ConvertOp):
|
|
|
350
325
|
|
|
351
326
|
# get est. of conversion cost (in USD) per record from model card
|
|
352
327
|
usd_per_input_token = MODEL_CARDS[model_name].get("usd_per_input_token")
|
|
353
|
-
if getattr(self, "prompt_strategy", None) is not None and self.
|
|
328
|
+
if getattr(self, "prompt_strategy", None) is not None and self.is_audio_op():
|
|
354
329
|
usd_per_input_token = MODEL_CARDS[model_name]["usd_per_audio_input_token"]
|
|
355
330
|
|
|
356
331
|
model_conversion_usd_per_record = (
|
|
@@ -0,0 +1,194 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from pydantic.fields import FieldInfo
|
|
6
|
+
|
|
7
|
+
from palimpzest.constants import MODEL_CARDS, Cardinality, Model, PromptStrategy
|
|
8
|
+
from palimpzest.core.elements.records import DataRecord
|
|
9
|
+
from palimpzest.core.models import GenerationStats, OperatorCostEstimates
|
|
10
|
+
from palimpzest.query.generators.generators import Generator
|
|
11
|
+
from palimpzest.query.operators.convert import LLMConvert
|
|
12
|
+
from palimpzest.query.operators.filter import LLMFilter
|
|
13
|
+
|
|
14
|
+
# TYPE DEFINITIONS
|
|
15
|
+
FieldName = str
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class CritiqueAndRefineConvert(LLMConvert):
|
|
19
|
+
|
|
20
|
+
def __init__(
|
|
21
|
+
self,
|
|
22
|
+
critic_model: Model,
|
|
23
|
+
refine_model: Model,
|
|
24
|
+
*args,
|
|
25
|
+
**kwargs,
|
|
26
|
+
):
|
|
27
|
+
super().__init__(*args, **kwargs)
|
|
28
|
+
self.critic_model = critic_model
|
|
29
|
+
self.refine_model = refine_model
|
|
30
|
+
|
|
31
|
+
# create generators
|
|
32
|
+
self.critic_generator = Generator(self.critic_model, PromptStrategy.MAP_CRITIC, self.reasoning_effort, self.api_base, self.cardinality, self.desc, self.verbose)
|
|
33
|
+
self.refine_generator = Generator(self.refine_model, PromptStrategy.MAP_REFINE, self.reasoning_effort, self.api_base, self.cardinality, self.desc, self.verbose)
|
|
34
|
+
|
|
35
|
+
def __str__(self):
|
|
36
|
+
op = super().__str__()
|
|
37
|
+
op += f" Critic Model: {self.critic_model}\n"
|
|
38
|
+
op += f" Refine Model: {self.refine_model}\n"
|
|
39
|
+
return op
|
|
40
|
+
|
|
41
|
+
def get_id_params(self):
|
|
42
|
+
id_params = super().get_id_params()
|
|
43
|
+
id_params = {
|
|
44
|
+
"critic_model": self.critic_model.value,
|
|
45
|
+
"refine_model": self.refine_model.value,
|
|
46
|
+
**id_params,
|
|
47
|
+
}
|
|
48
|
+
|
|
49
|
+
return id_params
|
|
50
|
+
|
|
51
|
+
def get_op_params(self):
|
|
52
|
+
op_params = super().get_op_params()
|
|
53
|
+
op_params = {
|
|
54
|
+
"critic_model": self.critic_model,
|
|
55
|
+
"refine_model": self.refine_model,
|
|
56
|
+
**op_params,
|
|
57
|
+
}
|
|
58
|
+
|
|
59
|
+
return op_params
|
|
60
|
+
|
|
61
|
+
def naive_cost_estimates(self, source_op_cost_estimates: OperatorCostEstimates) -> OperatorCostEstimates:
|
|
62
|
+
"""
|
|
63
|
+
Currently, we are invoking `self.model`, then critiquing its output with `self.critic_model`, and
|
|
64
|
+
finally refining the output with `self.refine_model`. Thus, we roughly expect to incur the cost
|
|
65
|
+
and time of three LLMConverts. In practice, this naive quality estimate will be overwritten by the
|
|
66
|
+
CostModel's estimate once it executes a few instances of the operator.
|
|
67
|
+
"""
|
|
68
|
+
# get naive cost estimates for first LLM call and multiply by 3 for now;
|
|
69
|
+
# of course we should sum individual estimates for each model, but this is a rough estimate
|
|
70
|
+
# and in practice we will need to revamp our naive cost estimates in the near future
|
|
71
|
+
naive_op_cost_estimates = 3 * super().naive_cost_estimates(source_op_cost_estimates)
|
|
72
|
+
|
|
73
|
+
# for naive setting, estimate quality as quality of refine model
|
|
74
|
+
model_quality = MODEL_CARDS[self.refine_model.value]["overall"] / 100.0
|
|
75
|
+
naive_op_cost_estimates.quality = model_quality
|
|
76
|
+
naive_op_cost_estimates.quality_lower_bound = naive_op_cost_estimates.quality
|
|
77
|
+
naive_op_cost_estimates.quality_upper_bound = naive_op_cost_estimates.quality
|
|
78
|
+
|
|
79
|
+
return naive_op_cost_estimates
|
|
80
|
+
|
|
81
|
+
def convert(self, candidate: DataRecord, fields: dict[str, FieldInfo]) -> tuple[dict[FieldName, list[Any]], GenerationStats]:
|
|
82
|
+
# get input fields
|
|
83
|
+
input_fields = self.get_input_fields()
|
|
84
|
+
|
|
85
|
+
# NOTE: when I merge in the `abacus` branch, I will want to update this to reflect the changes I made to reasoning extraction
|
|
86
|
+
# execute the initial model
|
|
87
|
+
original_gen_kwargs = {"project_cols": input_fields, "output_schema": self.output_schema}
|
|
88
|
+
field_answers, reasoning, original_gen_stats, original_messages = self.generator(candidate, fields, **original_gen_kwargs)
|
|
89
|
+
original_output = f"REASONING: {reasoning}\nANSWER: {field_answers}\n"
|
|
90
|
+
|
|
91
|
+
# execute the critic model
|
|
92
|
+
critic_gen_kwargs = {"original_output": original_output, "original_messages": original_messages, **original_gen_kwargs}
|
|
93
|
+
_, reasoning, critic_gen_stats, _ = self.critic_generator(candidate, fields, json_output=False, **critic_gen_kwargs)
|
|
94
|
+
critique_output = f"CRITIQUE: {reasoning}\n"
|
|
95
|
+
|
|
96
|
+
# execute the refinement model
|
|
97
|
+
refine_gen_kwargs = {"critique_output": critique_output, **critic_gen_kwargs}
|
|
98
|
+
field_answers, reasoning, refine_gen_stats, _ = self.refine_generator(candidate, fields, **refine_gen_kwargs)
|
|
99
|
+
|
|
100
|
+
# compute the total generation stats
|
|
101
|
+
generation_stats = original_gen_stats + critic_gen_stats + refine_gen_stats
|
|
102
|
+
|
|
103
|
+
return field_answers, generation_stats
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
class CritiqueAndRefineFilter(LLMFilter):
    """LLM filter that wraps the base filter call in a critique-and-refine loop.

    Three LLM calls are made per record:
      1. the base filter model (``self.generator`` inherited from ``LLMFilter``)
         produces an initial boolean answer plus reasoning;
      2. ``critic_model`` critiques that output;
      3. ``refine_model`` produces the final answer given the critique.

    The returned GenerationStats is the sum across all three calls.
    """

    def __init__(
        self,
        critic_model: Model,
        refine_model: Model,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.critic_model = critic_model
        self.refine_model = refine_model

        # create generators for the critique and refinement stages; the initial
        # filter generator is constructed by the LLMFilter superclass
        self.critic_generator = Generator(self.critic_model, PromptStrategy.FILTER_CRITIC, self.reasoning_effort, self.api_base, Cardinality.ONE_TO_ONE, self.desc, self.verbose)
        self.refine_generator = Generator(self.refine_model, PromptStrategy.FILTER_REFINE, self.reasoning_effort, self.api_base, Cardinality.ONE_TO_ONE, self.desc, self.verbose)

    def __str__(self):
        # extend the base operator description with the two extra models
        op = super().__str__()
        op += f" Critic Model: {self.critic_model}\n"
        op += f" Refine Model: {self.refine_model}\n"
        return op

    def get_id_params(self):
        """Return the identifying parameters for this operator, including both extra models."""
        id_params = super().get_id_params()
        id_params = {
            "critic_model": self.critic_model.value,
            "refine_model": self.refine_model.value,
            **id_params,
        }

        return id_params

    def get_op_params(self):
        """Return the constructor parameters for this operator, including both extra models."""
        op_params = super().get_op_params()
        op_params = {
            "critic_model": self.critic_model,
            "refine_model": self.refine_model,
            **op_params,
        }

        return op_params

    def naive_cost_estimates(self, source_op_cost_estimates: OperatorCostEstimates) -> OperatorCostEstimates:
        """
        Currently, we are invoking `self.model`, then critiquing its output with `self.critic_model`, and
        finally refining the output with `self.refine_model`. Thus, we roughly expect to incur the cost
        and time of three LLMFilters. In practice, this naive quality estimate will be overwritten by the
        CostModel's estimate once it executes a few instances of the operator.
        """
        # get naive cost estimates for first LLM call and multiply by 3 for now;
        # of course we should sum individual estimates for each model, but this is a rough estimate
        # and in practice we will need to revamp our naive cost estimates in the near future
        naive_op_cost_estimates = 3 * super().naive_cost_estimates(source_op_cost_estimates)

        # for naive setting, estimate quality as quality of refine model
        model_quality = MODEL_CARDS[self.refine_model.value]["overall"] / 100.0
        naive_op_cost_estimates.quality = model_quality
        naive_op_cost_estimates.quality_lower_bound = naive_op_cost_estimates.quality
        naive_op_cost_estimates.quality_upper_bound = naive_op_cost_estimates.quality

        return naive_op_cost_estimates

    def filter(self, candidate: DataRecord) -> tuple[dict[str, bool], GenerationStats]:
        """Filter ``candidate`` via the initial model, critique, and refinement passes.

        Returns the refined ``{"passed_operator": bool}`` field answers and the
        summed GenerationStats across all three LLM invocations.
        """
        # get input fields
        input_fields = self.get_input_fields()

        # construct output fields
        fields = {"passed_operator": FieldInfo(annotation=bool, description="Whether the record passed the filter operation")}

        # NOTE: when I merge in the `abacus` branch, I will want to update this to reflect the changes I made to reasoning extraction
        # execute the initial model
        original_gen_kwargs = {"project_cols": input_fields, "filter_condition": self.filter_obj.filter_condition}
        field_answers, reasoning, original_gen_stats, original_messages = self.generator(candidate, fields, **original_gen_kwargs)
        original_output = f"REASONING: {reasoning}\nANSWER: {str(field_answers['passed_operator']).upper()}\n"

        # execute the critic model; only its reasoning (the critique text) is kept
        critic_gen_kwargs = {"original_output": original_output, "original_messages": original_messages, **original_gen_kwargs}
        _, reasoning, critic_gen_stats, _ = self.critic_generator(candidate, fields, json_output=False, **critic_gen_kwargs)
        critique_output = f"CRITIQUE: {reasoning}\n"

        # execute the refinement model; its field answers replace the original ones
        refine_gen_kwargs = {"critique_output": critique_output, **critic_gen_kwargs}
        field_answers, reasoning, refine_gen_stats, _ = self.refine_generator(candidate, fields, **refine_gen_kwargs)

        # compute the total generation stats across all three LLM calls
        generation_stats = original_gen_stats + critic_gen_stats + refine_gen_stats

        return field_answers, generation_stats
|
|
@@ -35,27 +35,27 @@ class DistinctOp(PhysicalOperator):
|
|
|
35
35
|
|
|
36
36
|
def __call__(self, candidate: DataRecord) -> DataRecordSet:
|
|
37
37
|
# create new DataRecord
|
|
38
|
-
dr = DataRecord.from_parent(schema=candidate.schema, parent_record=candidate)
|
|
38
|
+
dr = DataRecord.from_parent(schema=candidate.schema, data_item={}, parent_record=candidate)
|
|
39
39
|
|
|
40
40
|
# output record only if it has not been seen before
|
|
41
41
|
record_str = dr.to_json_str(project_cols=self.distinct_cols, bytes_to_str=True, sorted=True)
|
|
42
42
|
record_hash = f"{hash(record_str)}"
|
|
43
|
-
dr.
|
|
44
|
-
if dr.
|
|
43
|
+
dr._passed_operator = record_hash not in self._distinct_seen
|
|
44
|
+
if dr._passed_operator:
|
|
45
45
|
self._distinct_seen.add(record_hash)
|
|
46
46
|
|
|
47
47
|
# create RecordOpStats object
|
|
48
48
|
record_op_stats = RecordOpStats(
|
|
49
|
-
record_id=dr.
|
|
50
|
-
record_parent_ids=dr.
|
|
51
|
-
record_source_indices=dr.
|
|
49
|
+
record_id=dr._id,
|
|
50
|
+
record_parent_ids=dr._parent_ids,
|
|
51
|
+
record_source_indices=dr._source_indices,
|
|
52
52
|
record_state=dr.to_dict(include_bytes=False),
|
|
53
53
|
full_op_id=self.get_full_op_id(),
|
|
54
54
|
logical_op_id=self.logical_op_id,
|
|
55
55
|
op_name=self.op_name(),
|
|
56
56
|
time_per_record=0.0,
|
|
57
57
|
cost_per_record=0.0,
|
|
58
|
-
passed_operator=dr.
|
|
58
|
+
passed_operator=dr._passed_operator,
|
|
59
59
|
op_details={k: str(v) for k, v in self.get_id_params().items()},
|
|
60
60
|
)
|
|
61
61
|
|