palimpzest 0.8.6__py3-none-any.whl → 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- palimpzest/constants.py +12 -4
- palimpzest/core/data/dataset.py +42 -0
- palimpzest/core/elements/records.py +5 -1
- palimpzest/core/lib/schemas.py +13 -0
- palimpzest/prompts/aggregate_prompts.py +99 -0
- palimpzest/prompts/prompt_factory.py +163 -75
- palimpzest/prompts/utils.py +38 -1
- palimpzest/prompts/validator.py +24 -24
- palimpzest/query/generators/generators.py +9 -7
- palimpzest/query/operators/__init__.py +4 -1
- palimpzest/query/operators/aggregate.py +285 -6
- palimpzest/query/operators/logical.py +17 -4
- palimpzest/query/optimizer/__init__.py +4 -0
- palimpzest/query/optimizer/rules.py +42 -2
- palimpzest/validator/validator.py +7 -7
- {palimpzest-0.8.6.dist-info → palimpzest-0.9.0.dist-info}/METADATA +1 -1
- {palimpzest-0.8.6.dist-info → palimpzest-0.9.0.dist-info}/RECORD +20 -19
- {palimpzest-0.8.6.dist-info → palimpzest-0.9.0.dist-info}/WHEEL +0 -0
- {palimpzest-0.8.6.dist-info → palimpzest-0.9.0.dist-info}/licenses/LICENSE +0 -0
- {palimpzest-0.8.6.dist-info → palimpzest-0.9.0.dist-info}/top_level.txt +0 -0
palimpzest/prompts/validator.py
CHANGED
|
@@ -22,17 +22,17 @@ OUTPUT FIELDS:
|
|
|
22
22
|
- birth_year: the year the scientist was born
|
|
23
23
|
|
|
24
24
|
CONTEXT:
|
|
25
|
-
{
|
|
25
|
+
{
|
|
26
26
|
"text": "Augusta Ada King, Countess of Lovelace, also known as Ada Lovelace, was an English mathematician and writer chiefly known for her work on Charles Babbage's proposed mechanical general-purpose computer, the Analytical Engine. She was the first to recognise that the machine had applications beyond pure calculation.",
|
|
27
27
|
"birthday": "December 10, 1815"
|
|
28
|
-
}
|
|
28
|
+
}
|
|
29
29
|
|
|
30
30
|
OUTPUT:
|
|
31
31
|
--------
|
|
32
|
-
{
|
|
32
|
+
{
|
|
33
33
|
"name": "Charles Babbage",
|
|
34
34
|
"birth_year": 1815
|
|
35
|
-
}
|
|
35
|
+
}
|
|
36
36
|
|
|
37
37
|
EVALUATION: {"name": 0.0, "birth_year": 1.0}
|
|
38
38
|
|
|
@@ -66,18 +66,18 @@ OUTPUT FIELDS:
|
|
|
66
66
|
- person_in_image: true if a person is in the image and false otherwise
|
|
67
67
|
|
|
68
68
|
CONTEXT:
|
|
69
|
-
{
|
|
69
|
+
{
|
|
70
70
|
"image": <bytes>,
|
|
71
71
|
"photographer": "CameraEnthusiast1"
|
|
72
|
-
}
|
|
72
|
+
}
|
|
73
73
|
<image content provided here; assume in this example the image shows a dog and a cat playing>
|
|
74
74
|
|
|
75
75
|
OUTPUT:
|
|
76
76
|
--------
|
|
77
|
-
{
|
|
77
|
+
{
|
|
78
78
|
"dog_in_image": true,
|
|
79
79
|
"person_in_image": true
|
|
80
|
-
}
|
|
80
|
+
}
|
|
81
81
|
|
|
82
82
|
EVALUATION: {"dog_in_image": 1.0, "person_in_image": 0.0}
|
|
83
83
|
|
|
@@ -113,22 +113,22 @@ OUTPUT FIELDS:
|
|
|
113
113
|
- birth_year: the year the scientist was born
|
|
114
114
|
|
|
115
115
|
CONTEXT:
|
|
116
|
-
{
|
|
116
|
+
{
|
|
117
117
|
"text": "Augusta Ada King, Countess of Lovelace, also known as Ada Lovelace, was an English mathematician and writer chiefly known for her work on Charles Babbage's proposed mechanical general-purpose computer, the Analytical Engine. She was the first to recognise that the machine had applications beyond pure calculation.",
|
|
118
118
|
"birthdays": "...Lovelace was born on December 10, 1815, almost exactly 24 years after Babbage's birth on 26 December 1791..."
|
|
119
|
-
}
|
|
119
|
+
}
|
|
120
120
|
|
|
121
121
|
OUTPUTS:
|
|
122
122
|
--------
|
|
123
123
|
[
|
|
124
|
-
{
|
|
124
|
+
{
|
|
125
125
|
"name": "Ada Lovelace",
|
|
126
126
|
"birth_year": 1815
|
|
127
|
-
}
|
|
128
|
-
{
|
|
127
|
+
},
|
|
128
|
+
{
|
|
129
129
|
"name": "Charles Babbage",
|
|
130
130
|
"birth_year": 1790
|
|
131
|
-
}
|
|
131
|
+
}
|
|
132
132
|
]
|
|
133
133
|
|
|
134
134
|
EVALUATION: [{"name": 1.0, "birth_year": 1.0}, {"name": 1.0, "birth_year": 0.0}]
|
|
@@ -163,23 +163,23 @@ OUTPUT FIELDS:
|
|
|
163
163
|
- animal_is_canine: true if the animal is a canine and false otherwise
|
|
164
164
|
|
|
165
165
|
CONTEXT:
|
|
166
|
-
{
|
|
166
|
+
{
|
|
167
167
|
"image": <bytes>,
|
|
168
168
|
"photographer": "CameraEnthusiast1"
|
|
169
|
-
}
|
|
169
|
+
}
|
|
170
170
|
<image content provided here; assume in this example the image shows a dog and a cat playing>
|
|
171
171
|
|
|
172
172
|
OUTPUT:
|
|
173
173
|
--------
|
|
174
174
|
[
|
|
175
|
-
{
|
|
175
|
+
{
|
|
176
176
|
"animal": "dog",
|
|
177
177
|
"animal_is_canine": true
|
|
178
|
-
}
|
|
179
|
-
{
|
|
178
|
+
},
|
|
179
|
+
{
|
|
180
180
|
"animal": "cat",
|
|
181
181
|
"animal_is_canine": true
|
|
182
|
-
}
|
|
182
|
+
}
|
|
183
183
|
]
|
|
184
184
|
|
|
185
185
|
EVALUATION: [{"animal": 1.0, "animal_is_canine": 1.0}, {"animal": 1.0, "animal_is_canine": 0.0}]
|
|
@@ -214,20 +214,20 @@ OUTPUT FIELDS:
|
|
|
214
214
|
- related_scientists: list of scientists who perform similar work as the scientist described in the text
|
|
215
215
|
|
|
216
216
|
CONTEXT:
|
|
217
|
-
{
|
|
217
|
+
{
|
|
218
218
|
"text": "Augusta Ada King, Countess of Lovelace, also known as Ada Lovelace, was an English mathematician and writer chiefly known for her work on Charles Babbage's proposed mechanical general-purpose computer, the Analytical Engine. She was the first to recognise that the machine had applications beyond pure calculation.",
|
|
219
|
-
}
|
|
219
|
+
}
|
|
220
220
|
|
|
221
221
|
OUTPUT:
|
|
222
222
|
--------
|
|
223
|
-
{
|
|
223
|
+
{
|
|
224
224
|
"related_scientists": [
|
|
225
225
|
"Charles Babbage",
|
|
226
226
|
"Alan Turing",
|
|
227
227
|
"Charles Darwin",
|
|
228
228
|
"John von Neumann"
|
|
229
229
|
]
|
|
230
|
-
}
|
|
230
|
+
}
|
|
231
231
|
|
|
232
232
|
EVALUATION: {"related_scientists": 0.75}
|
|
233
233
|
|
|
@@ -296,9 +296,9 @@ class Generator(Generic[ContextType, InputType]):
|
|
|
296
296
|
|
|
297
297
|
return field_answers
|
|
298
298
|
|
|
299
|
-
def __call__(self, candidate: DataRecord, fields: dict[str, FieldInfo] | None, right_candidate: DataRecord | None = None, json_output: bool=True, **kwargs) -> GenerationOutput:
|
|
300
|
-
"""Take the input record (`candidate`), generate the output `fields`, and return the generated output."""
|
|
301
|
-
logger.debug(f"Generating for candidate {candidate} with fields {fields}")
|
|
299
|
+
def __call__(self, candidate: DataRecord | list[DataRecord], fields: dict[str, FieldInfo] | None, right_candidate: DataRecord | None = None, json_output: bool=True, **kwargs) -> GenerationOutput:
|
|
300
|
+
"""Take the input record(s) (`candidate`), generate the output `fields`, and return the generated output."""
|
|
301
|
+
logger.debug(f"Generating for candidate(s) {candidate} with fields {fields}")
|
|
302
302
|
|
|
303
303
|
# fields can only be None if the user provides an answer parser
|
|
304
304
|
fields_check = fields is not None or "parse_answer" in kwargs
|
|
@@ -338,7 +338,7 @@ class Generator(Generic[ContextType, InputType]):
|
|
|
338
338
|
reasoning_effort = "minimal" if self.reasoning_effort is None else self.reasoning_effort
|
|
339
339
|
completion_kwargs = {"reasoning_effort": reasoning_effort, **completion_kwargs}
|
|
340
340
|
if self.model.is_vllm_model():
|
|
341
|
-
completion_kwargs = {"api_base": self.api_base, **completion_kwargs}
|
|
341
|
+
completion_kwargs = {"api_base": self.api_base, "api_key": os.environ.get("VLLM_API_KEY", "fake-api-key"), **completion_kwargs}
|
|
342
342
|
completion = litellm.completion(model=self.model_name, messages=messages, **completion_kwargs)
|
|
343
343
|
end_time = time.time()
|
|
344
344
|
logger.debug(f"Generated completion in {end_time - start_time:.2f} seconds")
|
|
@@ -405,15 +405,17 @@ class Generator(Generic[ContextType, InputType]):
|
|
|
405
405
|
|
|
406
406
|
# pretty print prompt + full completion output for debugging
|
|
407
407
|
completion_text = completion.choices[0].message.content
|
|
408
|
-
prompt = ""
|
|
408
|
+
prompt, system_prompt = "", ""
|
|
409
409
|
for message in messages:
|
|
410
|
+
if message["role"] == "system":
|
|
411
|
+
system_prompt += message["content"] + "\n"
|
|
410
412
|
if message["role"] == "user":
|
|
411
413
|
if message["type"] == "text":
|
|
412
414
|
prompt += message["content"] + "\n"
|
|
413
415
|
elif message["type"] == "image":
|
|
414
|
-
prompt += "<image>\n"
|
|
416
|
+
prompt += "<image>\n" * len(message["content"])
|
|
415
417
|
elif message["type"] == "input_audio":
|
|
416
|
-
prompt += "<audio>\n"
|
|
418
|
+
prompt += "<audio>\n" * len(message["content"])
|
|
417
419
|
logger.debug(f"PROMPT:\n{prompt}")
|
|
418
420
|
logger.debug(Fore.GREEN + f"{completion_text}\n" + Style.RESET_ALL)
|
|
419
421
|
|
|
@@ -2,6 +2,9 @@ from palimpzest.query.operators.aggregate import AggregateOp as _AggregateOp
|
|
|
2
2
|
from palimpzest.query.operators.aggregate import ApplyGroupByOp as _ApplyGroupByOp
|
|
3
3
|
from palimpzest.query.operators.aggregate import AverageAggregateOp as _AverageAggregateOp
|
|
4
4
|
from palimpzest.query.operators.aggregate import CountAggregateOp as _CountAggregateOp
|
|
5
|
+
from palimpzest.query.operators.aggregate import MaxAggregateOp as _MaxAggregateOp
|
|
6
|
+
from palimpzest.query.operators.aggregate import MinAggregateOp as _MinAggregateOp
|
|
7
|
+
from palimpzest.query.operators.aggregate import SemanticAggregate as _SemanticAggregate
|
|
5
8
|
from palimpzest.query.operators.convert import ConvertOp as _ConvertOp
|
|
6
9
|
from palimpzest.query.operators.convert import LLMConvert as _LLMConvert
|
|
7
10
|
from palimpzest.query.operators.convert import LLMConvertBonded as _LLMConvertBonded
|
|
@@ -77,7 +80,7 @@ LOGICAL_OPERATORS = [
|
|
|
77
80
|
|
|
78
81
|
PHYSICAL_OPERATORS = (
|
|
79
82
|
# aggregate
|
|
80
|
-
[_AggregateOp, _ApplyGroupByOp, _AverageAggregateOp, _CountAggregateOp]
|
|
83
|
+
[_AggregateOp, _ApplyGroupByOp, _AverageAggregateOp, _CountAggregateOp, _MaxAggregateOp, _MinAggregateOp, _SemanticAggregate]
|
|
81
84
|
# convert
|
|
82
85
|
+ [_ConvertOp, _NonLLMConvert, _LLMConvert, _LLMConvertBonded]
|
|
83
86
|
# critique and refine
|
|
@@ -1,12 +1,22 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import time
|
|
4
|
-
|
|
5
|
-
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from palimpzest.constants import (
|
|
7
|
+
MODEL_CARDS,
|
|
8
|
+
NAIVE_EST_NUM_GROUPS,
|
|
9
|
+
NAIVE_EST_NUM_INPUT_TOKENS,
|
|
10
|
+
NAIVE_EST_NUM_OUTPUT_TOKENS,
|
|
11
|
+
AggFunc,
|
|
12
|
+
Model,
|
|
13
|
+
PromptStrategy,
|
|
14
|
+
)
|
|
6
15
|
from palimpzest.core.elements.groupbysig import GroupBySig
|
|
7
16
|
from palimpzest.core.elements.records import DataRecord, DataRecordSet
|
|
8
|
-
from palimpzest.core.lib.schemas import Average, Count
|
|
17
|
+
from palimpzest.core.lib.schemas import Average, Count, Max, Min
|
|
9
18
|
from palimpzest.core.models import OperatorCostEstimates, RecordOpStats
|
|
19
|
+
from palimpzest.query.generators.generators import Generator
|
|
10
20
|
from palimpzest.query.operators.physical import PhysicalOperator
|
|
11
21
|
|
|
12
22
|
|
|
@@ -156,12 +166,17 @@ class AverageAggregateOp(AggregateOp):
|
|
|
156
166
|
|
|
157
167
|
def __init__(self, agg_func: AggFunc, *args, **kwargs):
|
|
158
168
|
# enforce that output schema is correct
|
|
159
|
-
assert kwargs["output_schema"] == Average, "AverageAggregateOp requires output_schema to be Average"
|
|
169
|
+
assert kwargs["output_schema"].model_fields.keys() == Average.model_fields.keys(), "AverageAggregateOp requires output_schema to be Average"
|
|
160
170
|
|
|
161
171
|
# enforce that input schema is a single numeric field
|
|
162
172
|
input_field_types = list(kwargs["input_schema"].model_fields.values())
|
|
163
173
|
assert len(input_field_types) == 1, "AverageAggregateOp requires input_schema to have exactly one field"
|
|
164
|
-
numeric_field_types = [
|
|
174
|
+
numeric_field_types = [
|
|
175
|
+
bool, int, float, int | float,
|
|
176
|
+
bool | None, int | None, float | None, int | float | None,
|
|
177
|
+
bool | Any, int | Any, float | Any, int | float | Any,
|
|
178
|
+
bool | None | Any, int | None | Any, float | None | Any, int | float | None | Any,
|
|
179
|
+
]
|
|
165
180
|
is_numeric = input_field_types[0].annotation in numeric_field_types
|
|
166
181
|
assert is_numeric, f"AverageAggregateOp requires input_schema to have a numeric field type, i.e. one of: {numeric_field_types}\nGot: {input_field_types[0]}"
|
|
167
182
|
|
|
@@ -230,7 +245,7 @@ class CountAggregateOp(AggregateOp):
|
|
|
230
245
|
|
|
231
246
|
def __init__(self, agg_func: AggFunc, *args, **kwargs):
|
|
232
247
|
# enforce that output schema is correct
|
|
233
|
-
assert kwargs["output_schema"] == Count, "CountAggregateOp requires output_schema to be Count"
|
|
248
|
+
assert kwargs["output_schema"].model_fields.keys() == Count.model_fields.keys(), "CountAggregateOp requires output_schema to be Count"
|
|
234
249
|
|
|
235
250
|
# call parent constructor
|
|
236
251
|
super().__init__(*args, **kwargs)
|
|
@@ -280,3 +295,267 @@ class CountAggregateOp(AggregateOp):
|
|
|
280
295
|
)
|
|
281
296
|
|
|
282
297
|
return DataRecordSet([dr], [record_op_stats])
|
|
298
|
+
|
|
299
|
+
|
|
300
|
+
class MinAggregateOp(AggregateOp):
|
|
301
|
+
# NOTE: we don't actually need / use agg_func here (yet)
|
|
302
|
+
|
|
303
|
+
def __init__(self, agg_func: AggFunc, *args, **kwargs):
|
|
304
|
+
# enforce that output schema is correct
|
|
305
|
+
assert kwargs["output_schema"].model_fields.keys() == Min.model_fields.keys(), "MinAggregateOp requires output_schema to be Min"
|
|
306
|
+
|
|
307
|
+
# call parent constructor
|
|
308
|
+
super().__init__(*args, **kwargs)
|
|
309
|
+
self.agg_func = agg_func
|
|
310
|
+
|
|
311
|
+
def __str__(self):
|
|
312
|
+
op = super().__str__()
|
|
313
|
+
op += f" Function: {str(self.agg_func)}\n"
|
|
314
|
+
return op
|
|
315
|
+
|
|
316
|
+
def get_id_params(self):
|
|
317
|
+
id_params = super().get_id_params()
|
|
318
|
+
return {"agg_func": str(self.agg_func), **id_params}
|
|
319
|
+
|
|
320
|
+
def get_op_params(self):
|
|
321
|
+
op_params = super().get_op_params()
|
|
322
|
+
return {"agg_func": self.agg_func, **op_params}
|
|
323
|
+
|
|
324
|
+
def naive_cost_estimates(self, source_op_cost_estimates: OperatorCostEstimates) -> OperatorCostEstimates:
|
|
325
|
+
# for now, assume applying the aggregation takes negligible additional time (and no cost in USD)
|
|
326
|
+
return OperatorCostEstimates(
|
|
327
|
+
cardinality=1,
|
|
328
|
+
time_per_record=0,
|
|
329
|
+
cost_per_record=0,
|
|
330
|
+
quality=1.0,
|
|
331
|
+
)
|
|
332
|
+
|
|
333
|
+
def __call__(self, candidates: list[DataRecord]) -> DataRecordSet:
|
|
334
|
+
start_time = time.time()
|
|
335
|
+
|
|
336
|
+
# create new DataRecord
|
|
337
|
+
min = float("inf")
|
|
338
|
+
for candidate in candidates:
|
|
339
|
+
try: # noqa: SIM105
|
|
340
|
+
min = v if (v := float(list(candidate.to_dict().values())[0])) < min else min
|
|
341
|
+
except Exception:
|
|
342
|
+
pass
|
|
343
|
+
data_item = Min(min=min if min != float("inf") else None)
|
|
344
|
+
dr = DataRecord.from_agg_parents(data_item, parent_records=candidates)
|
|
345
|
+
|
|
346
|
+
# create RecordOpStats object
|
|
347
|
+
record_op_stats = RecordOpStats(
|
|
348
|
+
record_id=dr.id,
|
|
349
|
+
record_parent_ids=dr.parent_ids,
|
|
350
|
+
record_source_indices=dr.source_indices,
|
|
351
|
+
record_state=dr.to_dict(include_bytes=False),
|
|
352
|
+
full_op_id=self.get_full_op_id(),
|
|
353
|
+
logical_op_id=self.logical_op_id,
|
|
354
|
+
op_name=self.op_name(),
|
|
355
|
+
time_per_record=time.time() - start_time,
|
|
356
|
+
cost_per_record=0.0,
|
|
357
|
+
op_details={k: str(v) for k, v in self.get_id_params().items()},
|
|
358
|
+
)
|
|
359
|
+
|
|
360
|
+
return DataRecordSet([dr], [record_op_stats])
|
|
361
|
+
|
|
362
|
+
|
|
363
|
+
class MaxAggregateOp(AggregateOp):
|
|
364
|
+
# NOTE: we don't actually need / use agg_func here (yet)
|
|
365
|
+
|
|
366
|
+
def __init__(self, agg_func: AggFunc, *args, **kwargs):
|
|
367
|
+
# enforce that output schema is correct
|
|
368
|
+
assert kwargs["output_schema"].model_fields.keys() == Max.model_fields.keys(), "MaxAggregateOp requires output_schema to be Max"
|
|
369
|
+
|
|
370
|
+
# call parent constructor
|
|
371
|
+
super().__init__(*args, **kwargs)
|
|
372
|
+
self.agg_func = agg_func
|
|
373
|
+
|
|
374
|
+
def __str__(self):
|
|
375
|
+
op = super().__str__()
|
|
376
|
+
op += f" Function: {str(self.agg_func)}\n"
|
|
377
|
+
return op
|
|
378
|
+
|
|
379
|
+
def get_id_params(self):
|
|
380
|
+
id_params = super().get_id_params()
|
|
381
|
+
return {"agg_func": str(self.agg_func), **id_params}
|
|
382
|
+
|
|
383
|
+
def get_op_params(self):
|
|
384
|
+
op_params = super().get_op_params()
|
|
385
|
+
return {"agg_func": self.agg_func, **op_params}
|
|
386
|
+
|
|
387
|
+
def naive_cost_estimates(self, source_op_cost_estimates: OperatorCostEstimates) -> OperatorCostEstimates:
|
|
388
|
+
# for now, assume applying the aggregation takes negligible additional time (and no cost in USD)
|
|
389
|
+
return OperatorCostEstimates(
|
|
390
|
+
cardinality=1,
|
|
391
|
+
time_per_record=0,
|
|
392
|
+
cost_per_record=0,
|
|
393
|
+
quality=1.0,
|
|
394
|
+
)
|
|
395
|
+
|
|
396
|
+
def __call__(self, candidates: list[DataRecord]) -> DataRecordSet:
|
|
397
|
+
start_time = time.time()
|
|
398
|
+
|
|
399
|
+
# create new DataRecord
|
|
400
|
+
|
|
401
|
+
max = float("-inf")
|
|
402
|
+
for candidate in candidates:
|
|
403
|
+
try: # noqa: SIM105
|
|
404
|
+
max = v if (v := float(list(candidate.to_dict().values())[0])) > max else max
|
|
405
|
+
except Exception:
|
|
406
|
+
pass
|
|
407
|
+
data_item = Max(max=max if max != float("-inf") else None)
|
|
408
|
+
dr = DataRecord.from_agg_parents(data_item, parent_records=candidates)
|
|
409
|
+
|
|
410
|
+
# create RecordOpStats object
|
|
411
|
+
record_op_stats = RecordOpStats(
|
|
412
|
+
record_id=dr.id,
|
|
413
|
+
record_parent_ids=dr.parent_ids,
|
|
414
|
+
record_source_indices=dr.source_indices,
|
|
415
|
+
record_state=dr.to_dict(include_bytes=False),
|
|
416
|
+
full_op_id=self.get_full_op_id(),
|
|
417
|
+
logical_op_id=self.logical_op_id,
|
|
418
|
+
op_name=self.op_name(),
|
|
419
|
+
time_per_record=time.time() - start_time,
|
|
420
|
+
cost_per_record=0.0,
|
|
421
|
+
op_details={k: str(v) for k, v in self.get_id_params().items()},
|
|
422
|
+
)
|
|
423
|
+
|
|
424
|
+
return DataRecordSet([dr], [record_op_stats])
|
|
425
|
+
|
|
426
|
+
|
|
427
|
+
class SemanticAggregate(AggregateOp):
|
|
428
|
+
|
|
429
|
+
def __init__(self, agg_str: str, model: Model, prompt_strategy: PromptStrategy = PromptStrategy.AGG, reasoning_effort: str | None = None, *args, **kwargs):
|
|
430
|
+
# call parent constructor
|
|
431
|
+
super().__init__(*args, **kwargs)
|
|
432
|
+
self.agg_str = agg_str
|
|
433
|
+
self.model = model
|
|
434
|
+
self.prompt_strategy = prompt_strategy
|
|
435
|
+
self.reasoning_effort = reasoning_effort
|
|
436
|
+
if model is not None:
|
|
437
|
+
self.generator = Generator(model, prompt_strategy, reasoning_effort, self.api_base)
|
|
438
|
+
|
|
439
|
+
def __str__(self):
|
|
440
|
+
op = super().__str__()
|
|
441
|
+
op += f" Prompt Strategy: {self.prompt_strategy}\n"
|
|
442
|
+
op += f" Reasoning Effort: {self.reasoning_effort}\n"
|
|
443
|
+
op += f" Agg: {str(self.agg_str)}\n"
|
|
444
|
+
return op
|
|
445
|
+
|
|
446
|
+
def get_id_params(self):
|
|
447
|
+
id_params = super().get_id_params()
|
|
448
|
+
id_params = {
|
|
449
|
+
"agg_str": self.agg_str,
|
|
450
|
+
"model": None if self.model is None else self.model.value,
|
|
451
|
+
"prompt_strategy": None if self.prompt_strategy is None else self.prompt_strategy.value,
|
|
452
|
+
"reasoning_effort": self.reasoning_effort,
|
|
453
|
+
**id_params,
|
|
454
|
+
}
|
|
455
|
+
|
|
456
|
+
return id_params
|
|
457
|
+
|
|
458
|
+
def get_op_params(self):
|
|
459
|
+
op_params = super().get_op_params()
|
|
460
|
+
op_params = {
|
|
461
|
+
"agg_str": self.agg_str,
|
|
462
|
+
"model": self.model,
|
|
463
|
+
"prompt_strategy": self.prompt_strategy,
|
|
464
|
+
"reasoning_effort": self.reasoning_effort,
|
|
465
|
+
**op_params,
|
|
466
|
+
}
|
|
467
|
+
|
|
468
|
+
return op_params
|
|
469
|
+
|
|
470
|
+
def get_model_name(self) -> str:
|
|
471
|
+
return self.model.value
|
|
472
|
+
|
|
473
|
+
def naive_cost_estimates(self, source_op_cost_estimates: OperatorCostEstimates) -> OperatorCostEstimates:
|
|
474
|
+
"""
|
|
475
|
+
Compute naive cost estimates for the LLMConvert operation. Implicitly, these estimates
|
|
476
|
+
assume the use of a single LLM call for each input record. Child classes of LLMConvert
|
|
477
|
+
may call this function through super() and adjust these estimates as needed (or they can
|
|
478
|
+
completely override this function).
|
|
479
|
+
"""
|
|
480
|
+
# estimate number of input and output tokens from source
|
|
481
|
+
est_num_input_tokens = NAIVE_EST_NUM_INPUT_TOKENS * source_op_cost_estimates.cardinality
|
|
482
|
+
est_num_output_tokens = NAIVE_EST_NUM_OUTPUT_TOKENS
|
|
483
|
+
|
|
484
|
+
# get est. of conversion time per record from model card;
|
|
485
|
+
model_name = self.model.value
|
|
486
|
+
model_conversion_time_per_record = MODEL_CARDS[model_name]["seconds_per_output_token"] * est_num_output_tokens
|
|
487
|
+
|
|
488
|
+
# get est. of conversion cost (in USD) per record from model card
|
|
489
|
+
usd_per_input_token = MODEL_CARDS[model_name].get("usd_per_input_token")
|
|
490
|
+
if getattr(self, "prompt_strategy", None) is not None and self.prompt_strategy.is_audio_prompt():
|
|
491
|
+
usd_per_input_token = MODEL_CARDS[model_name]["usd_per_audio_input_token"]
|
|
492
|
+
|
|
493
|
+
model_conversion_usd_per_record = (
|
|
494
|
+
usd_per_input_token * est_num_input_tokens
|
|
495
|
+
+ MODEL_CARDS[model_name]["usd_per_output_token"] * est_num_output_tokens
|
|
496
|
+
)
|
|
497
|
+
|
|
498
|
+
# estimate quality of output based on the strength of the model being used
|
|
499
|
+
quality = (MODEL_CARDS[model_name]["overall"] / 100.0)
|
|
500
|
+
|
|
501
|
+
return OperatorCostEstimates(
|
|
502
|
+
cardinality=1.0,
|
|
503
|
+
time_per_record=model_conversion_time_per_record,
|
|
504
|
+
cost_per_record=model_conversion_usd_per_record,
|
|
505
|
+
quality=quality,
|
|
506
|
+
)
|
|
507
|
+
|
|
508
|
+
def __call__(self, candidates: list[DataRecord]) -> DataRecordSet:
|
|
509
|
+
start_time = time.time()
|
|
510
|
+
|
|
511
|
+
# if candidates is an empty list, return an empty DataRecordSet
|
|
512
|
+
if len(candidates) == 0:
|
|
513
|
+
return DataRecordSet([], [])
|
|
514
|
+
|
|
515
|
+
# get the set of input fields to use for the operation
|
|
516
|
+
input_fields = self.get_input_fields()
|
|
517
|
+
|
|
518
|
+
# get the set of output fields to use for the operation
|
|
519
|
+
fields_to_generate = self.get_fields_to_generate(candidates[0])
|
|
520
|
+
fields = {field: field_type for field, field_type in self.output_schema.model_fields.items() if field in fields_to_generate}
|
|
521
|
+
|
|
522
|
+
# construct kwargs for generation
|
|
523
|
+
gen_kwargs = {"project_cols": input_fields, "output_schema": self.output_schema, "agg_instruction": self.agg_str}
|
|
524
|
+
|
|
525
|
+
# generate outputs for all fields in a single query
|
|
526
|
+
field_answers, _, generation_stats, _ = self.generator(candidates, fields, **gen_kwargs)
|
|
527
|
+
assert all([field in field_answers for field in fields]), "Not all fields were generated!"
|
|
528
|
+
|
|
529
|
+
# construct data record for the output
|
|
530
|
+
field, value = fields_to_generate[0], field_answers[fields_to_generate[0]][0]
|
|
531
|
+
data_item = self.output_schema(**{field: value})
|
|
532
|
+
dr = DataRecord.from_agg_parents(data_item, parent_records=candidates)
|
|
533
|
+
|
|
534
|
+
# create RecordOpStats object
|
|
535
|
+
record_op_stats = RecordOpStats(
|
|
536
|
+
record_id=dr._id,
|
|
537
|
+
record_parent_ids=dr._parent_ids,
|
|
538
|
+
record_source_indices=dr._source_indices,
|
|
539
|
+
record_state=dr.to_dict(include_bytes=False),
|
|
540
|
+
full_op_id=self.get_full_op_id(),
|
|
541
|
+
logical_op_id=self.logical_op_id,
|
|
542
|
+
op_name=self.op_name(),
|
|
543
|
+
time_per_record=time.time() - start_time,
|
|
544
|
+
cost_per_record=generation_stats.cost_per_record,
|
|
545
|
+
model_name=self.get_model_name(),
|
|
546
|
+
answer={field: value},
|
|
547
|
+
input_fields=input_fields,
|
|
548
|
+
generated_fields=fields_to_generate,
|
|
549
|
+
total_input_tokens=generation_stats.total_input_tokens,
|
|
550
|
+
total_output_tokens=generation_stats.total_output_tokens,
|
|
551
|
+
total_input_cost=generation_stats.total_input_cost,
|
|
552
|
+
total_output_cost=generation_stats.total_output_cost,
|
|
553
|
+
llm_call_duration_secs=generation_stats.llm_call_duration_secs,
|
|
554
|
+
fn_call_duration_secs=generation_stats.fn_call_duration_secs,
|
|
555
|
+
total_llm_calls=generation_stats.total_llm_calls,
|
|
556
|
+
total_embedding_llm_calls=generation_stats.total_embedding_llm_calls,
|
|
557
|
+
image_operation=self.is_image_op(),
|
|
558
|
+
op_details={k: str(v) for k, v in self.get_id_params().items()},
|
|
559
|
+
)
|
|
560
|
+
|
|
561
|
+
return DataRecordSet([dr], [record_op_stats])
|
|
@@ -9,7 +9,7 @@ from palimpzest.constants import AggFunc, Cardinality
|
|
|
9
9
|
from palimpzest.core.data import context, dataset
|
|
10
10
|
from palimpzest.core.elements.filters import Filter
|
|
11
11
|
from palimpzest.core.elements.groupbysig import GroupBySig
|
|
12
|
-
from palimpzest.core.lib.schemas import Average, Count
|
|
12
|
+
from palimpzest.core.lib.schemas import Average, Count, Max, Min
|
|
13
13
|
from palimpzest.utils.hash_helpers import hash_for_id
|
|
14
14
|
|
|
15
15
|
|
|
@@ -149,27 +149,39 @@ class Aggregate(LogicalOperator):
|
|
|
149
149
|
|
|
150
150
|
def __init__(
|
|
151
151
|
self,
|
|
152
|
-
agg_func: AggFunc,
|
|
152
|
+
agg_func: AggFunc | None = None,
|
|
153
|
+
agg_str: str | None = None,
|
|
153
154
|
*args,
|
|
154
155
|
**kwargs,
|
|
155
156
|
):
|
|
157
|
+
assert agg_func is not None or agg_str is not None, "Either agg_func or agg_str must be provided"
|
|
156
158
|
if kwargs.get("output_schema") is None:
|
|
157
159
|
if agg_func == AggFunc.COUNT:
|
|
158
160
|
kwargs["output_schema"] = Count
|
|
159
161
|
elif agg_func == AggFunc.AVERAGE:
|
|
160
162
|
kwargs["output_schema"] = Average
|
|
163
|
+
elif agg_func == AggFunc.MIN:
|
|
164
|
+
kwargs["output_schema"] = Min
|
|
165
|
+
elif agg_func == AggFunc.MAX:
|
|
166
|
+
kwargs["output_schema"] = Max
|
|
161
167
|
else:
|
|
162
168
|
raise ValueError(f"Unsupported aggregation function: {agg_func}")
|
|
163
169
|
|
|
164
170
|
super().__init__(*args, **kwargs)
|
|
165
171
|
self.agg_func = agg_func
|
|
172
|
+
self.agg_str = agg_str
|
|
166
173
|
|
|
167
174
|
def __str__(self):
|
|
168
|
-
|
|
175
|
+
desc = f"function: {str(self.agg_func.value)}" if self.agg_func else f"agg: {self.agg_str}"
|
|
176
|
+
return f"{self.__class__.__name__}({desc})"
|
|
169
177
|
|
|
170
178
|
def get_logical_id_params(self) -> dict:
|
|
171
179
|
logical_id_params = super().get_logical_id_params()
|
|
172
|
-
logical_id_params = {
|
|
180
|
+
logical_id_params = {
|
|
181
|
+
"agg_func": self.agg_func,
|
|
182
|
+
"agg_str": self.agg_str,
|
|
183
|
+
**logical_id_params,
|
|
184
|
+
}
|
|
173
185
|
|
|
174
186
|
return logical_id_params
|
|
175
187
|
|
|
@@ -177,6 +189,7 @@ class Aggregate(LogicalOperator):
|
|
|
177
189
|
logical_op_params = super().get_logical_op_params()
|
|
178
190
|
logical_op_params = {
|
|
179
191
|
"agg_func": self.agg_func,
|
|
192
|
+
"agg_str": self.agg_str,
|
|
180
193
|
**logical_op_params,
|
|
181
194
|
}
|
|
182
195
|
|
|
@@ -47,6 +47,9 @@ from palimpzest.query.optimizer.rules import (
|
|
|
47
47
|
from palimpzest.query.optimizer.rules import (
|
|
48
48
|
Rule as _Rule,
|
|
49
49
|
)
|
|
50
|
+
from palimpzest.query.optimizer.rules import (
|
|
51
|
+
SemanticAggregateRule as _SemanticAggregateRule,
|
|
52
|
+
)
|
|
50
53
|
from palimpzest.query.optimizer.rules import (
|
|
51
54
|
SplitRule as _SplitRule,
|
|
52
55
|
)
|
|
@@ -72,6 +75,7 @@ ALL_RULES = [
|
|
|
72
75
|
_ReorderConverts,
|
|
73
76
|
_RetrieveRule,
|
|
74
77
|
_Rule,
|
|
78
|
+
_SemanticAggregateRule,
|
|
75
79
|
_SplitRule,
|
|
76
80
|
_TransformationRule,
|
|
77
81
|
]
|
|
@@ -12,7 +12,14 @@ from palimpzest.core.lib.schemas import (
|
|
|
12
12
|
IMAGE_LIST_FIELD_TYPES,
|
|
13
13
|
)
|
|
14
14
|
from palimpzest.prompts import CONTEXT_SEARCH_PROMPT
|
|
15
|
-
from palimpzest.query.operators.aggregate import
|
|
15
|
+
from palimpzest.query.operators.aggregate import (
|
|
16
|
+
ApplyGroupByOp,
|
|
17
|
+
AverageAggregateOp,
|
|
18
|
+
CountAggregateOp,
|
|
19
|
+
MaxAggregateOp,
|
|
20
|
+
MinAggregateOp,
|
|
21
|
+
SemanticAggregate,
|
|
22
|
+
)
|
|
16
23
|
from palimpzest.query.operators.compute import SmolAgentsCompute
|
|
17
24
|
from palimpzest.query.operators.convert import LLMConvertBonded, NonLLMConvert
|
|
18
25
|
from palimpzest.query.operators.critique_and_refine import CritiqueAndRefineConvert, CritiqueAndRefineFilter
|
|
@@ -924,6 +931,35 @@ class EmbeddingJoinRule(ImplementationRule):
|
|
|
924
931
|
|
|
925
932
|
return cls._perform_substitution(logical_expression, EmbeddingJoin, runtime_kwargs, variable_op_kwargs)
|
|
926
933
|
|
|
934
|
+
class SemanticAggregateRule(ImplementationRule):
|
|
935
|
+
"""
|
|
936
|
+
Substitute a logical expression for a SemanticAggregate with an llm physical implementation.
|
|
937
|
+
"""
|
|
938
|
+
|
|
939
|
+
@classmethod
|
|
940
|
+
def matches_pattern(cls, logical_expression: LogicalExpression) -> bool:
|
|
941
|
+
is_match = isinstance(logical_expression.operator, Aggregate) and logical_expression.operator.agg_str is not None
|
|
942
|
+
logger.debug(f"SemanticAggregateRule matches_pattern: {is_match} for {logical_expression}")
|
|
943
|
+
return is_match
|
|
944
|
+
|
|
945
|
+
@classmethod
|
|
946
|
+
def substitute(cls, logical_expression: LogicalExpression, **runtime_kwargs) -> set[PhysicalExpression]:
|
|
947
|
+
logger.debug(f"Substituting SemanticAggregateRule for {logical_expression}")
|
|
948
|
+
|
|
949
|
+
# create variable physical operator kwargs for each model which can implement this logical_expression
|
|
950
|
+
models = [model for model in runtime_kwargs["available_models"] if cls._model_matches_input(model, logical_expression) and not model.is_llama_model()]
|
|
951
|
+
no_reasoning = runtime_kwargs["reasoning_effort"] in [None, "minimal", "low"]
|
|
952
|
+
variable_op_kwargs = [
|
|
953
|
+
{
|
|
954
|
+
"model": model,
|
|
955
|
+
"prompt_strategy": PromptStrategy.AGG_NO_REASONING if model.is_reasoning_model() and no_reasoning else PromptStrategy.AGG,
|
|
956
|
+
"reasoning_effort": runtime_kwargs["reasoning_effort"]
|
|
957
|
+
}
|
|
958
|
+
for model in models
|
|
959
|
+
]
|
|
960
|
+
|
|
961
|
+
return cls._perform_substitution(logical_expression, SemanticAggregate, runtime_kwargs, variable_op_kwargs)
|
|
962
|
+
|
|
927
963
|
|
|
928
964
|
class AggregateRule(ImplementationRule):
|
|
929
965
|
"""
|
|
@@ -932,7 +968,7 @@ class AggregateRule(ImplementationRule):
|
|
|
932
968
|
|
|
933
969
|
@classmethod
|
|
934
970
|
def matches_pattern(cls, logical_expression: LogicalExpression) -> bool:
|
|
935
|
-
is_match = isinstance(logical_expression.operator, Aggregate)
|
|
971
|
+
is_match = isinstance(logical_expression.operator, Aggregate) and logical_expression.operator.agg_func is not None
|
|
936
972
|
logger.debug(f"AggregateRule matches_pattern: {is_match} for {logical_expression}")
|
|
937
973
|
return is_match
|
|
938
974
|
|
|
@@ -946,6 +982,10 @@ class AggregateRule(ImplementationRule):
|
|
|
946
982
|
physical_op_class = CountAggregateOp
|
|
947
983
|
elif logical_expression.operator.agg_func == AggFunc.AVERAGE:
|
|
948
984
|
physical_op_class = AverageAggregateOp
|
|
985
|
+
elif logical_expression.operator.agg_func == AggFunc.MIN:
|
|
986
|
+
physical_op_class = MinAggregateOp
|
|
987
|
+
elif logical_expression.operator.agg_func == AggFunc.MAX:
|
|
988
|
+
physical_op_class = MaxAggregateOp
|
|
949
989
|
else:
|
|
950
990
|
raise Exception(f"Cannot support aggregate function: {logical_expression.operator.agg_func}")
|
|
951
991
|
|