palimpzest 0.8.2__py3-none-any.whl → 0.8.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- palimpzest/constants.py +38 -62
- palimpzest/core/data/iter_dataset.py +5 -5
- palimpzest/core/elements/groupbysig.py +1 -1
- palimpzest/core/elements/records.py +91 -109
- palimpzest/core/lib/schemas.py +23 -0
- palimpzest/core/models.py +3 -3
- palimpzest/prompts/__init__.py +2 -6
- palimpzest/prompts/convert_prompts.py +10 -66
- palimpzest/prompts/critique_and_refine_prompts.py +66 -0
- palimpzest/prompts/filter_prompts.py +8 -46
- palimpzest/prompts/join_prompts.py +12 -75
- palimpzest/prompts/{moa_aggregator_convert_prompts.py → moa_aggregator_prompts.py} +51 -2
- palimpzest/prompts/moa_proposer_prompts.py +87 -0
- palimpzest/prompts/prompt_factory.py +351 -479
- palimpzest/prompts/split_merge_prompts.py +51 -2
- palimpzest/prompts/split_proposer_prompts.py +48 -16
- palimpzest/prompts/utils.py +109 -0
- palimpzest/query/execution/all_sample_execution_strategy.py +1 -1
- palimpzest/query/execution/execution_strategy.py +4 -4
- palimpzest/query/execution/mab_execution_strategy.py +1 -2
- palimpzest/query/execution/parallel_execution_strategy.py +3 -3
- palimpzest/query/execution/single_threaded_execution_strategy.py +8 -8
- palimpzest/query/generators/generators.py +31 -17
- palimpzest/query/operators/__init__.py +15 -2
- palimpzest/query/operators/aggregate.py +21 -19
- palimpzest/query/operators/compute.py +6 -8
- palimpzest/query/operators/convert.py +12 -37
- palimpzest/query/operators/critique_and_refine.py +194 -0
- palimpzest/query/operators/distinct.py +7 -7
- palimpzest/query/operators/filter.py +13 -25
- palimpzest/query/operators/join.py +321 -192
- palimpzest/query/operators/limit.py +4 -4
- palimpzest/query/operators/mixture_of_agents.py +246 -0
- palimpzest/query/operators/physical.py +25 -2
- palimpzest/query/operators/project.py +4 -4
- palimpzest/query/operators/{rag_convert.py → rag.py} +202 -5
- palimpzest/query/operators/retrieve.py +10 -9
- palimpzest/query/operators/scan.py +9 -10
- palimpzest/query/operators/search.py +18 -24
- palimpzest/query/operators/split.py +321 -0
- palimpzest/query/optimizer/__init__.py +12 -8
- palimpzest/query/optimizer/optimizer.py +12 -10
- palimpzest/query/optimizer/rules.py +201 -108
- palimpzest/query/optimizer/tasks.py +18 -6
- palimpzest/validator/validator.py +7 -9
- {palimpzest-0.8.2.dist-info → palimpzest-0.8.4.dist-info}/METADATA +3 -8
- palimpzest-0.8.4.dist-info/RECORD +95 -0
- palimpzest/prompts/critique_and_refine_convert_prompts.py +0 -216
- palimpzest/prompts/moa_proposer_convert_prompts.py +0 -75
- palimpzest/prompts/util_phrases.py +0 -19
- palimpzest/query/operators/critique_and_refine_convert.py +0 -113
- palimpzest/query/operators/mixture_of_agents_convert.py +0 -140
- palimpzest/query/operators/split_convert.py +0 -170
- palimpzest-0.8.2.dist-info/RECORD +0 -95
- {palimpzest-0.8.2.dist-info → palimpzest-0.8.4.dist-info}/WHEEL +0 -0
- {palimpzest-0.8.2.dist-info → palimpzest-0.8.4.dist-info}/licenses/LICENSE +0 -0
- {palimpzest-0.8.2.dist-info → palimpzest-0.8.4.dist-info}/top_level.txt +0 -0
|
@@ -113,18 +113,20 @@ class ApplyGroupByOp(AggregateOp):
|
|
|
113
113
|
group_by_fields = self.group_by_sig.group_by_fields
|
|
114
114
|
agg_fields = self.group_by_sig.get_agg_field_names()
|
|
115
115
|
for g in agg_state:
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
parent_records=candidates,
|
|
119
|
-
)
|
|
116
|
+
# build up data item
|
|
117
|
+
data_item = {}
|
|
120
118
|
for i in range(0, len(g)):
|
|
121
119
|
k = g[i]
|
|
122
|
-
|
|
120
|
+
data_item[group_by_fields[i]] = k
|
|
123
121
|
vals = agg_state[g]
|
|
124
122
|
for i in range(0, len(vals)):
|
|
125
123
|
v = ApplyGroupByOp.agg_final(self.group_by_sig.agg_funcs[i], vals[i])
|
|
126
|
-
|
|
124
|
+
data_item[agg_fields[i]] = v
|
|
127
125
|
|
|
126
|
+
# create new DataRecord
|
|
127
|
+
schema = self.group_by_sig.output_schema()
|
|
128
|
+
data_item = schema(**data_item)
|
|
129
|
+
dr = DataRecord.from_agg_parents(data_item, parent_records=candidates)
|
|
128
130
|
drs.append(dr)
|
|
129
131
|
|
|
130
132
|
# create RecordOpStats objects
|
|
@@ -132,9 +134,9 @@ class ApplyGroupByOp(AggregateOp):
|
|
|
132
134
|
record_op_stats_lst = []
|
|
133
135
|
for dr in drs:
|
|
134
136
|
record_op_stats = RecordOpStats(
|
|
135
|
-
record_id=dr.
|
|
136
|
-
record_parent_ids=dr.
|
|
137
|
-
record_source_indices=dr.
|
|
137
|
+
record_id=dr._id,
|
|
138
|
+
record_parent_ids=dr._parent_ids,
|
|
139
|
+
record_source_indices=dr._source_indices,
|
|
138
140
|
record_state=dr.to_dict(include_bytes=False),
|
|
139
141
|
full_op_id=self.get_full_op_id(),
|
|
140
142
|
logical_op_id=self.logical_op_id,
|
|
@@ -197,7 +199,6 @@ class AverageAggregateOp(AggregateOp):
|
|
|
197
199
|
# NOTE: right now we perform a check in the constructor which enforces that the input_schema
|
|
198
200
|
# has a single field which is numeric in nature; in the future we may want to have a
|
|
199
201
|
# cleaner way of computing the value (rather than `float(list(candidate...))` below)
|
|
200
|
-
dr = DataRecord.from_agg_parents(schema=Average, parent_records=candidates)
|
|
201
202
|
summation, total = 0, 0
|
|
202
203
|
for candidate in candidates:
|
|
203
204
|
try:
|
|
@@ -205,13 +206,14 @@ class AverageAggregateOp(AggregateOp):
|
|
|
205
206
|
total += 1
|
|
206
207
|
except Exception:
|
|
207
208
|
pass
|
|
208
|
-
|
|
209
|
+
data_item = Average(average=summation / total)
|
|
210
|
+
dr = DataRecord.from_agg_parents(data_item, parent_records=candidates)
|
|
209
211
|
|
|
210
212
|
# create RecordOpStats object
|
|
211
213
|
record_op_stats = RecordOpStats(
|
|
212
|
-
record_id=dr.
|
|
213
|
-
record_parent_ids=dr.
|
|
214
|
-
record_source_indices=dr.
|
|
214
|
+
record_id=dr._id,
|
|
215
|
+
record_parent_ids=dr._parent_ids,
|
|
216
|
+
record_source_indices=dr._source_indices,
|
|
215
217
|
record_state=dr.to_dict(include_bytes=False),
|
|
216
218
|
full_op_id=self.get_full_op_id(),
|
|
217
219
|
logical_op_id=self.logical_op_id,
|
|
@@ -260,14 +262,14 @@ class CountAggregateOp(AggregateOp):
|
|
|
260
262
|
start_time = time.time()
|
|
261
263
|
|
|
262
264
|
# create new DataRecord
|
|
263
|
-
|
|
264
|
-
dr
|
|
265
|
+
data_item = Count(count=len(candidates))
|
|
266
|
+
dr = DataRecord.from_agg_parents(data_item, parent_records=candidates)
|
|
265
267
|
|
|
266
268
|
# create RecordOpStats object
|
|
267
269
|
record_op_stats = RecordOpStats(
|
|
268
|
-
record_id=dr.
|
|
269
|
-
record_parent_ids=dr.
|
|
270
|
-
record_source_indices=dr.
|
|
270
|
+
record_id=dr._id,
|
|
271
|
+
record_parent_ids=dr._parent_ids,
|
|
272
|
+
record_source_indices=dr._source_indices,
|
|
271
273
|
record_state=dr.to_dict(include_bytes=False),
|
|
272
274
|
full_op_id=self.get_full_op_id(),
|
|
273
275
|
logical_op_id=self.logical_op_id,
|
|
@@ -93,17 +93,15 @@ class SmolAgentsCompute(PhysicalOperator):
|
|
|
93
93
|
Given an input DataRecord and a determination of whether it passed the filter or not,
|
|
94
94
|
construct the resulting RecordSet.
|
|
95
95
|
"""
|
|
96
|
-
# create new DataRecord
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
if field in answer:
|
|
100
|
-
dr[field] = answer[field]
|
|
96
|
+
# create new DataRecord
|
|
97
|
+
data_item = {field: answer[field] for field in self.output_schema.model_fields if field in answer}
|
|
98
|
+
dr = DataRecord.from_parent(self.output_schema, data_item, parent_record=candidate)
|
|
101
99
|
|
|
102
100
|
# create RecordOpStats object
|
|
103
101
|
record_op_stats = RecordOpStats(
|
|
104
|
-
record_id=dr.
|
|
105
|
-
record_parent_ids=dr.
|
|
106
|
-
record_source_indices=dr.
|
|
102
|
+
record_id=dr._id,
|
|
103
|
+
record_parent_ids=dr._parent_ids,
|
|
104
|
+
record_source_indices=dr._source_indices,
|
|
107
105
|
record_state=dr.to_dict(include_bytes=False),
|
|
108
106
|
full_op_id=self.get_full_op_id(),
|
|
109
107
|
logical_op_id=self.logical_op_id,
|
|
@@ -74,25 +74,14 @@ class ConvertOp(PhysicalOperator, ABC):
|
|
|
74
74
|
|
|
75
75
|
drs = []
|
|
76
76
|
for idx in range(max(n_records, 1)):
|
|
77
|
-
# initialize record with the correct output schema, parent record, and cardinality idx
|
|
78
|
-
dr = DataRecord.from_parent(self.output_schema, parent_record=candidate, cardinality_idx=idx)
|
|
79
|
-
|
|
80
|
-
# copy all fields from the input record
|
|
81
|
-
# NOTE: this means that records processed by PZ converts will inherit all pre-computed fields
|
|
82
|
-
# in an incremental fashion; this is a design choice which may be revisited in the future
|
|
83
|
-
for field in candidate.get_field_names():
|
|
84
|
-
setattr(dr, field, getattr(candidate, field))
|
|
85
|
-
|
|
86
|
-
# get input field names and output field names
|
|
87
|
-
input_fields = list(self.input_schema.model_fields)
|
|
88
|
-
output_fields = list(self.output_schema.model_fields)
|
|
89
|
-
|
|
90
77
|
# parse newly generated fields from the field_answers dictionary for this field; if the list
|
|
91
78
|
# of generated values is shorter than the number of records, we fill in with None
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
79
|
+
data_item = {}
|
|
80
|
+
for field in self.generated_fields:
|
|
81
|
+
data_item[field] = field_answers[field][idx] if idx < len(field_answers[field]) else None
|
|
82
|
+
|
|
83
|
+
# initialize record with the correct output schema, data_item, parent record, and cardinality idx
|
|
84
|
+
dr = DataRecord.from_parent(self.output_schema, data_item, parent_record=candidate, cardinality_idx=idx)
|
|
96
85
|
|
|
97
86
|
# append data record to list of output data records
|
|
98
87
|
drs.append(dr)
|
|
@@ -117,9 +106,9 @@ class ConvertOp(PhysicalOperator, ABC):
|
|
|
117
106
|
# create the RecordOpStats objects for each output record
|
|
118
107
|
record_op_stats_lst = [
|
|
119
108
|
RecordOpStats(
|
|
120
|
-
record_id=dr.
|
|
121
|
-
record_parent_ids=dr.
|
|
122
|
-
record_source_indices=dr.
|
|
109
|
+
record_id=dr._id,
|
|
110
|
+
record_parent_ids=dr._parent_ids,
|
|
111
|
+
record_source_indices=dr._source_indices,
|
|
123
112
|
record_state=dr.to_dict(include_bytes=False),
|
|
124
113
|
full_op_id=self.get_full_op_id(),
|
|
125
114
|
logical_op_id=self.logical_op_id,
|
|
@@ -127,7 +116,7 @@ class ConvertOp(PhysicalOperator, ABC):
|
|
|
127
116
|
time_per_record=time_per_record,
|
|
128
117
|
cost_per_record=per_record_stats.cost_per_record,
|
|
129
118
|
model_name=self.get_model_name(),
|
|
130
|
-
answer={field_name: getattr(dr, field_name) for field_name in field_names},
|
|
119
|
+
answer={field_name: getattr(dr, field_name, None) for field_name in field_names},
|
|
131
120
|
input_fields=list(self.input_schema.model_fields),
|
|
132
121
|
generated_fields=field_names,
|
|
133
122
|
total_input_tokens=per_record_stats.total_input_tokens,
|
|
@@ -139,7 +128,6 @@ class ConvertOp(PhysicalOperator, ABC):
|
|
|
139
128
|
total_llm_calls=per_record_stats.total_llm_calls,
|
|
140
129
|
total_embedding_llm_calls=per_record_stats.total_embedding_llm_calls,
|
|
141
130
|
failed_convert=(not successful_convert),
|
|
142
|
-
image_operation=self.is_image_conversion(),
|
|
143
131
|
op_details={k: str(v) for k, v in self.get_id_params().items()},
|
|
144
132
|
)
|
|
145
133
|
for dr in records
|
|
@@ -148,11 +136,6 @@ class ConvertOp(PhysicalOperator, ABC):
|
|
|
148
136
|
# create and return the DataRecordSet
|
|
149
137
|
return DataRecordSet(records, record_op_stats_lst)
|
|
150
138
|
|
|
151
|
-
@abstractmethod
|
|
152
|
-
def is_image_conversion(self) -> bool:
|
|
153
|
-
"""Return True if the convert operation processes an image, False otherwise."""
|
|
154
|
-
pass
|
|
155
|
-
|
|
156
139
|
@abstractmethod
|
|
157
140
|
def convert(self, candidate: DataRecord, fields: dict[str, FieldInfo]) -> tuple[dict[str, list], GenerationStats]:
|
|
158
141
|
"""
|
|
@@ -216,11 +199,6 @@ class NonLLMConvert(ConvertOp):
|
|
|
216
199
|
op += f" UDF: {self.udf.__name__}\n"
|
|
217
200
|
return op
|
|
218
201
|
|
|
219
|
-
def is_image_conversion(self) -> bool:
|
|
220
|
-
# NOTE: even if the UDF is processing an image, we do not consider this an image conversion
|
|
221
|
-
# (the output of this function will be used by the CostModel in a way which does not apply to UDFs)
|
|
222
|
-
return False
|
|
223
|
-
|
|
224
202
|
def naive_cost_estimates(self, source_op_cost_estimates: OperatorCostEstimates) -> OperatorCostEstimates:
|
|
225
203
|
"""
|
|
226
204
|
Compute naive cost estimates for the NonLLMConvert operation. These estimates assume
|
|
@@ -287,7 +265,7 @@ class LLMConvert(ConvertOp):
|
|
|
287
265
|
def __init__(
|
|
288
266
|
self,
|
|
289
267
|
model: Model,
|
|
290
|
-
prompt_strategy: PromptStrategy = PromptStrategy.
|
|
268
|
+
prompt_strategy: PromptStrategy = PromptStrategy.MAP,
|
|
291
269
|
reasoning_effort: str | None = None,
|
|
292
270
|
*args,
|
|
293
271
|
**kwargs,
|
|
@@ -330,9 +308,6 @@ class LLMConvert(ConvertOp):
|
|
|
330
308
|
def get_model_name(self):
|
|
331
309
|
return None if self.model is None else self.model.value
|
|
332
310
|
|
|
333
|
-
def is_image_conversion(self) -> bool:
|
|
334
|
-
return self.prompt_strategy.is_image_prompt()
|
|
335
|
-
|
|
336
311
|
def naive_cost_estimates(self, source_op_cost_estimates: OperatorCostEstimates) -> OperatorCostEstimates:
|
|
337
312
|
"""
|
|
338
313
|
Compute naive cost estimates for the LLMConvert operation. Implicitly, these estimates
|
|
@@ -350,7 +325,7 @@ class LLMConvert(ConvertOp):
|
|
|
350
325
|
|
|
351
326
|
# get est. of conversion cost (in USD) per record from model card
|
|
352
327
|
usd_per_input_token = MODEL_CARDS[model_name].get("usd_per_input_token")
|
|
353
|
-
if getattr(self, "prompt_strategy", None) is not None and self.
|
|
328
|
+
if getattr(self, "prompt_strategy", None) is not None and self.is_audio_op():
|
|
354
329
|
usd_per_input_token = MODEL_CARDS[model_name]["usd_per_audio_input_token"]
|
|
355
330
|
|
|
356
331
|
model_conversion_usd_per_record = (
|
|
@@ -0,0 +1,194 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from pydantic.fields import FieldInfo
|
|
6
|
+
|
|
7
|
+
from palimpzest.constants import MODEL_CARDS, Cardinality, Model, PromptStrategy
|
|
8
|
+
from palimpzest.core.elements.records import DataRecord
|
|
9
|
+
from palimpzest.core.models import GenerationStats, OperatorCostEstimates
|
|
10
|
+
from palimpzest.query.generators.generators import Generator
|
|
11
|
+
from palimpzest.query.operators.convert import LLMConvert
|
|
12
|
+
from palimpzest.query.operators.filter import LLMFilter
|
|
13
|
+
|
|
14
|
+
# TYPE DEFINITIONS
|
|
15
|
+
FieldName = str
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class CritiqueAndRefineConvert(LLMConvert):
|
|
19
|
+
|
|
20
|
+
def __init__(
|
|
21
|
+
self,
|
|
22
|
+
critic_model: Model,
|
|
23
|
+
refine_model: Model,
|
|
24
|
+
*args,
|
|
25
|
+
**kwargs,
|
|
26
|
+
):
|
|
27
|
+
super().__init__(*args, **kwargs)
|
|
28
|
+
self.critic_model = critic_model
|
|
29
|
+
self.refine_model = refine_model
|
|
30
|
+
|
|
31
|
+
# create generators
|
|
32
|
+
self.critic_generator = Generator(self.critic_model, PromptStrategy.MAP_CRITIC, self.reasoning_effort, self.api_base, self.cardinality, self.desc, self.verbose)
|
|
33
|
+
self.refine_generator = Generator(self.refine_model, PromptStrategy.MAP_REFINE, self.reasoning_effort, self.api_base, self.cardinality, self.desc, self.verbose)
|
|
34
|
+
|
|
35
|
+
def __str__(self):
|
|
36
|
+
op = super().__str__()
|
|
37
|
+
op += f" Critic Model: {self.critic_model}\n"
|
|
38
|
+
op += f" Refine Model: {self.refine_model}\n"
|
|
39
|
+
return op
|
|
40
|
+
|
|
41
|
+
def get_id_params(self):
|
|
42
|
+
id_params = super().get_id_params()
|
|
43
|
+
id_params = {
|
|
44
|
+
"critic_model": self.critic_model.value,
|
|
45
|
+
"refine_model": self.refine_model.value,
|
|
46
|
+
**id_params,
|
|
47
|
+
}
|
|
48
|
+
|
|
49
|
+
return id_params
|
|
50
|
+
|
|
51
|
+
def get_op_params(self):
|
|
52
|
+
op_params = super().get_op_params()
|
|
53
|
+
op_params = {
|
|
54
|
+
"critic_model": self.critic_model,
|
|
55
|
+
"refine_model": self.refine_model,
|
|
56
|
+
**op_params,
|
|
57
|
+
}
|
|
58
|
+
|
|
59
|
+
return op_params
|
|
60
|
+
|
|
61
|
+
def naive_cost_estimates(self, source_op_cost_estimates: OperatorCostEstimates) -> OperatorCostEstimates:
|
|
62
|
+
"""
|
|
63
|
+
Currently, we are invoking `self.model`, then critiquing its output with `self.critic_model`, and
|
|
64
|
+
finally refining the output with `self.refine_model`. Thus, we roughly expect to incur the cost
|
|
65
|
+
and time of three LLMConverts. In practice, this naive quality estimate will be overwritten by the
|
|
66
|
+
CostModel's estimate once it executes a few instances of the operator.
|
|
67
|
+
"""
|
|
68
|
+
# get naive cost estimates for first LLM call and multiply by 3 for now;
|
|
69
|
+
# of course we should sum individual estimates for each model, but this is a rough estimate
|
|
70
|
+
# and in practice we will need to revamp our naive cost estimates in the near future
|
|
71
|
+
naive_op_cost_estimates = 3 * super().naive_cost_estimates(source_op_cost_estimates)
|
|
72
|
+
|
|
73
|
+
# for naive setting, estimate quality as quality of refine model
|
|
74
|
+
model_quality = MODEL_CARDS[self.refine_model.value]["overall"] / 100.0
|
|
75
|
+
naive_op_cost_estimates.quality = model_quality
|
|
76
|
+
naive_op_cost_estimates.quality_lower_bound = naive_op_cost_estimates.quality
|
|
77
|
+
naive_op_cost_estimates.quality_upper_bound = naive_op_cost_estimates.quality
|
|
78
|
+
|
|
79
|
+
return naive_op_cost_estimates
|
|
80
|
+
|
|
81
|
+
def convert(self, candidate: DataRecord, fields: dict[str, FieldInfo]) -> tuple[dict[FieldName, list[Any]], GenerationStats]:
|
|
82
|
+
# get input fields
|
|
83
|
+
input_fields = self.get_input_fields()
|
|
84
|
+
|
|
85
|
+
# NOTE: when I merge in the `abacus` branch, I will want to update this to reflect the changes I made to reasoning extraction
|
|
86
|
+
# execute the initial model
|
|
87
|
+
original_gen_kwargs = {"project_cols": input_fields, "output_schema": self.output_schema}
|
|
88
|
+
field_answers, reasoning, original_gen_stats, original_messages = self.generator(candidate, fields, **original_gen_kwargs)
|
|
89
|
+
original_output = f"REASONING: {reasoning}\nANSWER: {field_answers}\n"
|
|
90
|
+
|
|
91
|
+
# execute the critic model
|
|
92
|
+
critic_gen_kwargs = {"original_output": original_output, "original_messages": original_messages, **original_gen_kwargs}
|
|
93
|
+
_, reasoning, critic_gen_stats, _ = self.critic_generator(candidate, fields, json_output=False, **critic_gen_kwargs)
|
|
94
|
+
critique_output = f"CRITIQUE: {reasoning}\n"
|
|
95
|
+
|
|
96
|
+
# execute the refinement model
|
|
97
|
+
refine_gen_kwargs = {"critique_output": critique_output, **critic_gen_kwargs}
|
|
98
|
+
field_answers, reasoning, refine_gen_stats, _ = self.refine_generator(candidate, fields, **refine_gen_kwargs)
|
|
99
|
+
|
|
100
|
+
# compute the total generation stats
|
|
101
|
+
generation_stats = original_gen_stats + critic_gen_stats + refine_gen_stats
|
|
102
|
+
|
|
103
|
+
return field_answers, generation_stats
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
class CritiqueAndRefineFilter(LLMFilter):
|
|
107
|
+
|
|
108
|
+
def __init__(
|
|
109
|
+
self,
|
|
110
|
+
critic_model: Model,
|
|
111
|
+
refine_model: Model,
|
|
112
|
+
*args,
|
|
113
|
+
**kwargs,
|
|
114
|
+
):
|
|
115
|
+
super().__init__(*args, **kwargs)
|
|
116
|
+
self.critic_model = critic_model
|
|
117
|
+
self.refine_model = refine_model
|
|
118
|
+
|
|
119
|
+
# create generators
|
|
120
|
+
self.critic_generator = Generator(self.critic_model, PromptStrategy.FILTER_CRITIC, self.reasoning_effort, self.api_base, Cardinality.ONE_TO_ONE, self.desc, self.verbose)
|
|
121
|
+
self.refine_generator = Generator(self.refine_model, PromptStrategy.FILTER_REFINE, self.reasoning_effort, self.api_base, Cardinality.ONE_TO_ONE, self.desc, self.verbose)
|
|
122
|
+
|
|
123
|
+
def __str__(self):
|
|
124
|
+
op = super().__str__()
|
|
125
|
+
op += f" Critic Model: {self.critic_model}\n"
|
|
126
|
+
op += f" Refine Model: {self.refine_model}\n"
|
|
127
|
+
return op
|
|
128
|
+
|
|
129
|
+
def get_id_params(self):
|
|
130
|
+
id_params = super().get_id_params()
|
|
131
|
+
id_params = {
|
|
132
|
+
"critic_model": self.critic_model.value,
|
|
133
|
+
"refine_model": self.refine_model.value,
|
|
134
|
+
**id_params,
|
|
135
|
+
}
|
|
136
|
+
|
|
137
|
+
return id_params
|
|
138
|
+
|
|
139
|
+
def get_op_params(self):
|
|
140
|
+
op_params = super().get_op_params()
|
|
141
|
+
op_params = {
|
|
142
|
+
"critic_model": self.critic_model,
|
|
143
|
+
"refine_model": self.refine_model,
|
|
144
|
+
**op_params,
|
|
145
|
+
}
|
|
146
|
+
|
|
147
|
+
return op_params
|
|
148
|
+
|
|
149
|
+
def naive_cost_estimates(self, source_op_cost_estimates: OperatorCostEstimates) -> OperatorCostEstimates:
|
|
150
|
+
"""
|
|
151
|
+
Currently, we are invoking `self.model`, then critiquing its output with `self.critic_model`, and
|
|
152
|
+
finally refining the output with `self.refine_model`. Thus, we roughly expect to incur the cost
|
|
153
|
+
and time of three LLMFilters. In practice, this naive quality estimate will be overwritten by the
|
|
154
|
+
CostModel's estimate once it executes a few instances of the operator.
|
|
155
|
+
"""
|
|
156
|
+
# get naive cost estimates for first LLM call and multiply by 3 for now;
|
|
157
|
+
# of course we should sum individual estimates for each model, but this is a rough estimate
|
|
158
|
+
# and in practice we will need to revamp our naive cost estimates in the near future
|
|
159
|
+
naive_op_cost_estimates = 3 * super().naive_cost_estimates(source_op_cost_estimates)
|
|
160
|
+
|
|
161
|
+
# for naive setting, estimate quality as quality of refine model
|
|
162
|
+
model_quality = MODEL_CARDS[self.refine_model.value]["overall"] / 100.0
|
|
163
|
+
naive_op_cost_estimates.quality = model_quality
|
|
164
|
+
naive_op_cost_estimates.quality_lower_bound = naive_op_cost_estimates.quality
|
|
165
|
+
naive_op_cost_estimates.quality_upper_bound = naive_op_cost_estimates.quality
|
|
166
|
+
|
|
167
|
+
return naive_op_cost_estimates
|
|
168
|
+
|
|
169
|
+
def filter(self, candidate: DataRecord) -> tuple[dict[str, bool], GenerationStats]:
|
|
170
|
+
# get input fields
|
|
171
|
+
input_fields = self.get_input_fields()
|
|
172
|
+
|
|
173
|
+
# construct output fields
|
|
174
|
+
fields = {"passed_operator": FieldInfo(annotation=bool, description="Whether the record passed the filter operation")}
|
|
175
|
+
|
|
176
|
+
# NOTE: when I merge in the `abacus` branch, I will want to update this to reflect the changes I made to reasoning extraction
|
|
177
|
+
# execute the initial model
|
|
178
|
+
original_gen_kwargs = {"project_cols": input_fields, "filter_condition": self.filter_obj.filter_condition}
|
|
179
|
+
field_answers, reasoning, original_gen_stats, original_messages = self.generator(candidate, fields, **original_gen_kwargs)
|
|
180
|
+
original_output = f"REASONING: {reasoning}\nANSWER: {str(field_answers['passed_operator']).upper()}\n"
|
|
181
|
+
|
|
182
|
+
# execute the critic model
|
|
183
|
+
critic_gen_kwargs = {"original_output": original_output, "original_messages": original_messages, **original_gen_kwargs}
|
|
184
|
+
_, reasoning, critic_gen_stats, _ = self.critic_generator(candidate, fields, json_output=False, **critic_gen_kwargs)
|
|
185
|
+
critique_output = f"CRITIQUE: {reasoning}\n"
|
|
186
|
+
|
|
187
|
+
# execute the refinement model
|
|
188
|
+
refine_gen_kwargs = {"critique_output": critique_output, **critic_gen_kwargs}
|
|
189
|
+
field_answers, reasoning, refine_gen_stats, _ = self.refine_generator(candidate, fields, **refine_gen_kwargs)
|
|
190
|
+
|
|
191
|
+
# compute the total generation stats
|
|
192
|
+
generation_stats = original_gen_stats + critic_gen_stats + refine_gen_stats
|
|
193
|
+
|
|
194
|
+
return field_answers, generation_stats
|
|
@@ -35,27 +35,27 @@ class DistinctOp(PhysicalOperator):
|
|
|
35
35
|
|
|
36
36
|
def __call__(self, candidate: DataRecord) -> DataRecordSet:
|
|
37
37
|
# create new DataRecord
|
|
38
|
-
dr = DataRecord.from_parent(schema=candidate.schema, parent_record=candidate)
|
|
38
|
+
dr = DataRecord.from_parent(schema=candidate.schema, data_item={}, parent_record=candidate)
|
|
39
39
|
|
|
40
40
|
# output record only if it has not been seen before
|
|
41
41
|
record_str = dr.to_json_str(project_cols=self.distinct_cols, bytes_to_str=True, sorted=True)
|
|
42
42
|
record_hash = f"{hash(record_str)}"
|
|
43
|
-
dr.
|
|
44
|
-
if dr.
|
|
43
|
+
dr._passed_operator = record_hash not in self._distinct_seen
|
|
44
|
+
if dr._passed_operator:
|
|
45
45
|
self._distinct_seen.add(record_hash)
|
|
46
46
|
|
|
47
47
|
# create RecordOpStats object
|
|
48
48
|
record_op_stats = RecordOpStats(
|
|
49
|
-
record_id=dr.
|
|
50
|
-
record_parent_ids=dr.
|
|
51
|
-
record_source_indices=dr.
|
|
49
|
+
record_id=dr._id,
|
|
50
|
+
record_parent_ids=dr._parent_ids,
|
|
51
|
+
record_source_indices=dr._source_indices,
|
|
52
52
|
record_state=dr.to_dict(include_bytes=False),
|
|
53
53
|
full_op_id=self.get_full_op_id(),
|
|
54
54
|
logical_op_id=self.logical_op_id,
|
|
55
55
|
op_name=self.op_name(),
|
|
56
56
|
time_per_record=0.0,
|
|
57
57
|
cost_per_record=0.0,
|
|
58
|
-
passed_operator=dr.
|
|
58
|
+
passed_operator=dr._passed_operator,
|
|
59
59
|
op_details={k: str(v) for k, v in self.get_id_params().items()},
|
|
60
60
|
)
|
|
61
61
|
|
|
@@ -41,11 +41,6 @@ class FilterOp(PhysicalOperator, ABC):
|
|
|
41
41
|
op_params = super().get_op_params()
|
|
42
42
|
return {"filter": self.filter_obj, "desc": self.desc, **op_params}
|
|
43
43
|
|
|
44
|
-
@abstractmethod
|
|
45
|
-
def is_image_filter(self) -> bool:
|
|
46
|
-
"""Return True if the filter operation processes an image, False otherwise."""
|
|
47
|
-
pass
|
|
48
|
-
|
|
49
44
|
@abstractmethod
|
|
50
45
|
def filter(self, candidate: DataRecord) -> tuple[dict[str, bool], GenerationStats]:
|
|
51
46
|
"""
|
|
@@ -76,14 +71,14 @@ class FilterOp(PhysicalOperator, ABC):
|
|
|
76
71
|
construct the resulting RecordSet.
|
|
77
72
|
"""
|
|
78
73
|
# create new DataRecord and set passed_operator attribute
|
|
79
|
-
dr = DataRecord.from_parent(candidate.schema, parent_record=candidate)
|
|
80
|
-
dr.
|
|
74
|
+
dr = DataRecord.from_parent(schema=candidate.schema, data_item={}, parent_record=candidate)
|
|
75
|
+
dr._passed_operator = passed_operator
|
|
81
76
|
|
|
82
77
|
# create RecordOpStats object
|
|
83
78
|
record_op_stats = RecordOpStats(
|
|
84
|
-
record_id=dr.
|
|
85
|
-
record_parent_ids=dr.
|
|
86
|
-
record_source_indices=dr.
|
|
79
|
+
record_id=dr._id,
|
|
80
|
+
record_parent_ids=dr._parent_ids,
|
|
81
|
+
record_source_indices=dr._source_indices,
|
|
87
82
|
record_state=dr.to_dict(include_bytes=False),
|
|
88
83
|
full_op_id=self.get_full_op_id(),
|
|
89
84
|
logical_op_id=self.logical_op_id,
|
|
@@ -102,7 +97,6 @@ class FilterOp(PhysicalOperator, ABC):
|
|
|
102
97
|
total_embedding_llm_calls=generation_stats.total_embedding_llm_calls,
|
|
103
98
|
answer=answer,
|
|
104
99
|
passed_operator=passed_operator,
|
|
105
|
-
image_operation=self.is_image_filter(),
|
|
106
100
|
op_details={k: str(v) for k, v in self.get_id_params().items()},
|
|
107
101
|
)
|
|
108
102
|
|
|
@@ -127,10 +121,6 @@ class FilterOp(PhysicalOperator, ABC):
|
|
|
127
121
|
|
|
128
122
|
|
|
129
123
|
class NonLLMFilter(FilterOp):
|
|
130
|
-
def is_image_filter(self) -> bool:
|
|
131
|
-
# NOTE: even if the UDF is processing an image, we do not consider this an image filter
|
|
132
|
-
# (the output of this function will be used by the CostModel in a way which does not apply to UDFs)
|
|
133
|
-
return False
|
|
134
124
|
|
|
135
125
|
def naive_cost_estimates(self, source_op_cost_estimates: OperatorCostEstimates):
|
|
136
126
|
# estimate output cardinality using a constant assumption of the filter selectivity
|
|
@@ -174,7 +164,7 @@ class LLMFilter(FilterOp):
|
|
|
174
164
|
def __init__(
|
|
175
165
|
self,
|
|
176
166
|
model: Model,
|
|
177
|
-
prompt_strategy: PromptStrategy = PromptStrategy.
|
|
167
|
+
prompt_strategy: PromptStrategy = PromptStrategy.FILTER,
|
|
178
168
|
reasoning_effort: str | None = None,
|
|
179
169
|
*args,
|
|
180
170
|
**kwargs,
|
|
@@ -183,13 +173,14 @@ class LLMFilter(FilterOp):
|
|
|
183
173
|
self.model = model
|
|
184
174
|
self.prompt_strategy = prompt_strategy
|
|
185
175
|
self.reasoning_effort = reasoning_effort
|
|
186
|
-
|
|
176
|
+
if model is not None:
|
|
177
|
+
self.generator = Generator(model, prompt_strategy, reasoning_effort, self.api_base, Cardinality.ONE_TO_ONE, self.desc, self.verbose)
|
|
187
178
|
|
|
188
179
|
def get_id_params(self):
|
|
189
180
|
id_params = super().get_id_params()
|
|
190
181
|
id_params = {
|
|
191
|
-
"model": self.model.value,
|
|
192
|
-
"prompt_strategy": self.prompt_strategy.value,
|
|
182
|
+
"model": None if self.model is None else self.model.value,
|
|
183
|
+
"prompt_strategy": None if self.prompt_strategy is None else self.prompt_strategy.value,
|
|
193
184
|
"reasoning_effort": self.reasoning_effort,
|
|
194
185
|
**id_params,
|
|
195
186
|
}
|
|
@@ -208,15 +199,12 @@ class LLMFilter(FilterOp):
|
|
|
208
199
|
return op_params
|
|
209
200
|
|
|
210
201
|
def get_model_name(self):
|
|
211
|
-
return self.model.value
|
|
212
|
-
|
|
213
|
-
def is_image_filter(self) -> bool:
|
|
214
|
-
return self.prompt_strategy is PromptStrategy.COT_BOOL_IMAGE
|
|
202
|
+
return None if self.model is None else self.model.value
|
|
215
203
|
|
|
216
204
|
def naive_cost_estimates(self, source_op_cost_estimates: OperatorCostEstimates):
|
|
217
205
|
# estimate number of input tokens from source
|
|
218
206
|
est_num_input_tokens = NAIVE_EST_NUM_INPUT_TOKENS
|
|
219
|
-
if self.
|
|
207
|
+
if self.is_image_op():
|
|
220
208
|
est_num_input_tokens = 765 / 10 # 1024x1024 image is 765 tokens
|
|
221
209
|
|
|
222
210
|
# NOTE: the output often generates an entire reasoning sentence, thus the true value may be higher
|
|
@@ -232,7 +220,7 @@ class LLMFilter(FilterOp):
|
|
|
232
220
|
# get est. of conversion cost (in USD) per record from model card
|
|
233
221
|
usd_per_input_token = (
|
|
234
222
|
MODEL_CARDS[self.model.value]["usd_per_audio_input_token"]
|
|
235
|
-
if self.
|
|
223
|
+
if self.is_audio_op()
|
|
236
224
|
else MODEL_CARDS[self.model.value]["usd_per_input_token"]
|
|
237
225
|
)
|
|
238
226
|
model_conversion_usd_per_record = (
|