palimpzest 0.8.6__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -22,17 +22,17 @@ OUTPUT FIELDS:
22
22
  - birth_year: the year the scientist was born
23
23
 
24
24
  CONTEXT:
25
- {{
25
+ {
26
26
  "text": "Augusta Ada King, Countess of Lovelace, also known as Ada Lovelace, was an English mathematician and writer chiefly known for her work on Charles Babbage's proposed mechanical general-purpose computer, the Analytical Engine. She was the first to recognise that the machine had applications beyond pure calculation.",
27
27
  "birthday": "December 10, 1815"
28
- }}
28
+ }
29
29
 
30
30
  OUTPUT:
31
31
  --------
32
- {{
32
+ {
33
33
  "name": "Charles Babbage",
34
34
  "birth_year": 1815
35
- }}
35
+ }
36
36
 
37
37
  EVALUATION: {"name": 0.0, "birth_year": 1.0}
38
38
 
@@ -66,18 +66,18 @@ OUTPUT FIELDS:
66
66
  - person_in_image: true if a person is in the image and false otherwise
67
67
 
68
68
  CONTEXT:
69
- {{
69
+ {
70
70
  "image": <bytes>,
71
71
  "photographer": "CameraEnthusiast1"
72
- }}
72
+ }
73
73
  <image content provided here; assume in this example the image shows a dog and a cat playing>
74
74
 
75
75
  OUTPUT:
76
76
  --------
77
- {{
77
+ {
78
78
  "dog_in_image": true,
79
79
  "person_in_image": true
80
- }}
80
+ }
81
81
 
82
82
  EVALUATION: {"dog_in_image": 1.0, "person_in_image": 0.0}
83
83
 
@@ -113,22 +113,22 @@ OUTPUT FIELDS:
113
113
  - birth_year: the year the scientist was born
114
114
 
115
115
  CONTEXT:
116
- {{
116
+ {
117
117
  "text": "Augusta Ada King, Countess of Lovelace, also known as Ada Lovelace, was an English mathematician and writer chiefly known for her work on Charles Babbage's proposed mechanical general-purpose computer, the Analytical Engine. She was the first to recognise that the machine had applications beyond pure calculation.",
118
118
  "birthdays": "...Lovelace was born on December 10, 1815, almost exactly 24 years after Babbage's birth on 26 December 1791..."
119
- }}
119
+ }
120
120
 
121
121
  OUTPUTS:
122
122
  --------
123
123
  [
124
- {{
124
+ {
125
125
  "name": "Ada Lovelace",
126
126
  "birth_year": 1815
127
- }},
128
- {{
127
+ },
128
+ {
129
129
  "name": "Charles Babbage",
130
130
  "birth_year": 1790
131
- }}
131
+ }
132
132
  ]
133
133
 
134
134
  EVALUATION: [{"name": 1.0, "birth_year": 1.0}, {"name": 1.0, "birth_year": 0.0}]
@@ -163,23 +163,23 @@ OUTPUT FIELDS:
163
163
  - animal_is_canine: true if the animal is a canine and false otherwise
164
164
 
165
165
  CONTEXT:
166
- {{
166
+ {
167
167
  "image": <bytes>,
168
168
  "photographer": "CameraEnthusiast1"
169
- }}
169
+ }
170
170
  <image content provided here; assume in this example the image shows a dog and a cat playing>
171
171
 
172
172
  OUTPUT:
173
173
  --------
174
174
  [
175
- {{
175
+ {
176
176
  "animal": "dog",
177
177
  "animal_is_canine": true
178
- }},
179
- {{
178
+ },
179
+ {
180
180
  "animal": "cat",
181
181
  "animal_is_canine": true
182
- }}
182
+ }
183
183
  ]
184
184
 
185
185
  EVALUATION: [{"animal": 1.0, "animal_is_canine": 1.0}, {"animal": 1.0, "animal_is_canine": 0.0}]
@@ -214,20 +214,20 @@ OUTPUT FIELDS:
214
214
  - related_scientists: list of scientists who perform similar work as the scientist described in the text
215
215
 
216
216
  CONTEXT:
217
- {{
217
+ {
218
218
  "text": "Augusta Ada King, Countess of Lovelace, also known as Ada Lovelace, was an English mathematician and writer chiefly known for her work on Charles Babbage's proposed mechanical general-purpose computer, the Analytical Engine. She was the first to recognise that the machine had applications beyond pure calculation.",
219
- }}
219
+ }
220
220
 
221
221
  OUTPUT:
222
222
  --------
223
- {{
223
+ {
224
224
  "related_scientists": [
225
225
  "Charles Babbage",
226
226
  "Alan Turing",
227
227
  "Charles Darwin",
228
228
  "John von Neumann",
229
229
  ]
230
- }}
230
+ }
231
231
 
232
232
  EVALUATION: {"related_scientists": 0.75}
233
233
 
@@ -296,9 +296,9 @@ class Generator(Generic[ContextType, InputType]):
296
296
 
297
297
  return field_answers
298
298
 
299
- def __call__(self, candidate: DataRecord, fields: dict[str, FieldInfo] | None, right_candidate: DataRecord | None = None, json_output: bool=True, **kwargs) -> GenerationOutput:
300
- """Take the input record (`candidate`), generate the output `fields`, and return the generated output."""
301
- logger.debug(f"Generating for candidate {candidate} with fields {fields}")
299
+ def __call__(self, candidate: DataRecord | list[DataRecord], fields: dict[str, FieldInfo] | None, right_candidate: DataRecord | None = None, json_output: bool=True, **kwargs) -> GenerationOutput:
300
+ """Take the input record(s) (`candidate`), generate the output `fields`, and return the generated output."""
301
+ logger.debug(f"Generating for candidate(s) {candidate} with fields {fields}")
302
302
 
303
303
  # fields can only be None if the user provides an answer parser
304
304
  fields_check = fields is not None or "parse_answer" in kwargs
@@ -338,7 +338,7 @@ class Generator(Generic[ContextType, InputType]):
338
338
  reasoning_effort = "minimal" if self.reasoning_effort is None else self.reasoning_effort
339
339
  completion_kwargs = {"reasoning_effort": reasoning_effort, **completion_kwargs}
340
340
  if self.model.is_vllm_model():
341
- completion_kwargs = {"api_base": self.api_base, **completion_kwargs}
341
+ completion_kwargs = {"api_base": self.api_base, "api_key": os.environ.get("VLLM_API_KEY", "fake-api-key"), **completion_kwargs}
342
342
  completion = litellm.completion(model=self.model_name, messages=messages, **completion_kwargs)
343
343
  end_time = time.time()
344
344
  logger.debug(f"Generated completion in {end_time - start_time:.2f} seconds")
@@ -405,15 +405,17 @@ class Generator(Generic[ContextType, InputType]):
405
405
 
406
406
  # pretty print prompt + full completion output for debugging
407
407
  completion_text = completion.choices[0].message.content
408
- prompt = ""
408
+ prompt, system_prompt = "", ""
409
409
  for message in messages:
410
+ if message["role"] == "system":
411
+ system_prompt += message["content"] + "\n"
410
412
  if message["role"] == "user":
411
413
  if message["type"] == "text":
412
414
  prompt += message["content"] + "\n"
413
415
  elif message["type"] == "image":
414
- prompt += "<image>\n"
416
+ prompt += "<image>\n" * len(message["content"])
415
417
  elif message["type"] == "input_audio":
416
- prompt += "<audio>\n"
418
+ prompt += "<audio>\n" * len(message["content"])
417
419
  logger.debug(f"PROMPT:\n{prompt}")
418
420
  logger.debug(Fore.GREEN + f"{completion_text}\n" + Style.RESET_ALL)
419
421
 
@@ -2,6 +2,9 @@ from palimpzest.query.operators.aggregate import AggregateOp as _AggregateOp
2
2
  from palimpzest.query.operators.aggregate import ApplyGroupByOp as _ApplyGroupByOp
3
3
  from palimpzest.query.operators.aggregate import AverageAggregateOp as _AverageAggregateOp
4
4
  from palimpzest.query.operators.aggregate import CountAggregateOp as _CountAggregateOp
5
+ from palimpzest.query.operators.aggregate import MaxAggregateOp as _MaxAggregateOp
6
+ from palimpzest.query.operators.aggregate import MinAggregateOp as _MinAggregateOp
7
+ from palimpzest.query.operators.aggregate import SemanticAggregate as _SemanticAggregate
5
8
  from palimpzest.query.operators.convert import ConvertOp as _ConvertOp
6
9
  from palimpzest.query.operators.convert import LLMConvert as _LLMConvert
7
10
  from palimpzest.query.operators.convert import LLMConvertBonded as _LLMConvertBonded
@@ -77,7 +80,7 @@ LOGICAL_OPERATORS = [
77
80
 
78
81
  PHYSICAL_OPERATORS = (
79
82
  # aggregate
80
- [_AggregateOp, _ApplyGroupByOp, _AverageAggregateOp, _CountAggregateOp]
83
+ [_AggregateOp, _ApplyGroupByOp, _AverageAggregateOp, _CountAggregateOp, _MaxAggregateOp, _MinAggregateOp, _SemanticAggregate]
81
84
  # convert
82
85
  + [_ConvertOp, _NonLLMConvert, _LLMConvert, _LLMConvertBonded]
83
86
  # critique and refine
@@ -1,12 +1,22 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import time
4
-
5
- from palimpzest.constants import NAIVE_EST_NUM_GROUPS, AggFunc
4
+ from typing import Any
5
+
6
+ from palimpzest.constants import (
7
+ MODEL_CARDS,
8
+ NAIVE_EST_NUM_GROUPS,
9
+ NAIVE_EST_NUM_INPUT_TOKENS,
10
+ NAIVE_EST_NUM_OUTPUT_TOKENS,
11
+ AggFunc,
12
+ Model,
13
+ PromptStrategy,
14
+ )
6
15
  from palimpzest.core.elements.groupbysig import GroupBySig
7
16
  from palimpzest.core.elements.records import DataRecord, DataRecordSet
8
- from palimpzest.core.lib.schemas import Average, Count
17
+ from palimpzest.core.lib.schemas import Average, Count, Max, Min
9
18
  from palimpzest.core.models import OperatorCostEstimates, RecordOpStats
19
+ from palimpzest.query.generators.generators import Generator
10
20
  from palimpzest.query.operators.physical import PhysicalOperator
11
21
 
12
22
 
@@ -156,12 +166,17 @@ class AverageAggregateOp(AggregateOp):
156
166
 
157
167
  def __init__(self, agg_func: AggFunc, *args, **kwargs):
158
168
  # enforce that output schema is correct
159
- assert kwargs["output_schema"] == Average, "AverageAggregateOp requires output_schema to be Average"
169
+ assert kwargs["output_schema"].model_fields.keys() == Average.model_fields.keys(), "AverageAggregateOp requires output_schema to be Average"
160
170
 
161
171
  # enforce that input schema is a single numeric field
162
172
  input_field_types = list(kwargs["input_schema"].model_fields.values())
163
173
  assert len(input_field_types) == 1, "AverageAggregateOp requires input_schema to have exactly one field"
164
- numeric_field_types = [bool, int, float, bool | None, int | None, float | None, int | float, int | float | None]
174
+ numeric_field_types = [
175
+ bool, int, float, int | float,
176
+ bool | None, int | None, float | None, int | float | None,
177
+ bool | Any, int | Any, float | Any, int | float | Any,
178
+ bool | None | Any, int | None | Any, float | None | Any, int | float | None | Any,
179
+ ]
165
180
  is_numeric = input_field_types[0].annotation in numeric_field_types
166
181
  assert is_numeric, f"AverageAggregateOp requires input_schema to have a numeric field type, i.e. one of: {numeric_field_types}\nGot: {input_field_types[0]}"
167
182
 
@@ -230,7 +245,7 @@ class CountAggregateOp(AggregateOp):
230
245
 
231
246
  def __init__(self, agg_func: AggFunc, *args, **kwargs):
232
247
  # enforce that output schema is correct
233
- assert kwargs["output_schema"] == Count, "CountAggregateOp requires output_schema to be Count"
248
+ assert kwargs["output_schema"].model_fields.keys() == Count.model_fields.keys(), "CountAggregateOp requires output_schema to be Count"
234
249
 
235
250
  # call parent constructor
236
251
  super().__init__(*args, **kwargs)
@@ -280,3 +295,267 @@ class CountAggregateOp(AggregateOp):
280
295
  )
281
296
 
282
297
  return DataRecordSet([dr], [record_op_stats])
298
+
299
+
300
class MinAggregateOp(AggregateOp):
    """Aggregate a list of records into a single record holding the minimum value
    of the (single) input field. Non-numeric or missing values are skipped.
    """
    # NOTE: we don't actually need / use agg_func here (yet)

    def __init__(self, agg_func: AggFunc, *args, **kwargs):
        # enforce that output schema is correct
        assert kwargs["output_schema"].model_fields.keys() == Min.model_fields.keys(), "MinAggregateOp requires output_schema to be Min"

        # call parent constructor
        super().__init__(*args, **kwargs)
        self.agg_func = agg_func

    def __str__(self):
        op = super().__str__()
        op += f" Function: {str(self.agg_func)}\n"
        return op

    def get_id_params(self):
        id_params = super().get_id_params()
        return {"agg_func": str(self.agg_func), **id_params}

    def get_op_params(self):
        op_params = super().get_op_params()
        return {"agg_func": self.agg_func, **op_params}

    def naive_cost_estimates(self, source_op_cost_estimates: OperatorCostEstimates) -> OperatorCostEstimates:
        # for now, assume applying the aggregation takes negligible additional time (and no cost in USD)
        return OperatorCostEstimates(
            cardinality=1,
            time_per_record=0,
            cost_per_record=0,
            quality=1.0,
        )

    def __call__(self, candidates: list[DataRecord]) -> DataRecordSet:
        """Compute the minimum over candidates' first field and return it as a single record."""
        start_time = time.time()

        # create new DataRecord
        # BUGFIX: the accumulator was previously named `min`, shadowing the builtin;
        # every call to min(...) then raised TypeError (swallowed by the except),
        # so the result was always None. Use a distinct local name instead.
        minimum = float("inf")
        for candidate in candidates:
            try:  # noqa: SIM105
                # best-effort: skip records whose first field is not coercible to float
                minimum = min(float(list(candidate.to_dict().values())[0]), minimum)
            except Exception:
                pass
        # if no candidate yielded a numeric value, report None rather than inf
        data_item = Min(min=minimum if minimum != float("inf") else None)
        dr = DataRecord.from_agg_parents(data_item, parent_records=candidates)

        # create RecordOpStats object
        record_op_stats = RecordOpStats(
            record_id=dr.id,
            record_parent_ids=dr.parent_ids,
            record_source_indices=dr.source_indices,
            record_state=dr.to_dict(include_bytes=False),
            full_op_id=self.get_full_op_id(),
            logical_op_id=self.logical_op_id,
            op_name=self.op_name(),
            time_per_record=time.time() - start_time,
            cost_per_record=0.0,
            op_details={k: str(v) for k, v in self.get_id_params().items()},
        )

        return DataRecordSet([dr], [record_op_stats])
361
+
362
+
363
class MaxAggregateOp(AggregateOp):
    """Aggregate a list of records into a single record holding the maximum value
    of the (single) input field. Non-numeric or missing values are skipped.
    """
    # NOTE: we don't actually need / use agg_func here (yet)

    def __init__(self, agg_func: AggFunc, *args, **kwargs):
        # enforce that output schema is correct
        assert kwargs["output_schema"].model_fields.keys() == Max.model_fields.keys(), "MaxAggregateOp requires output_schema to be Max"

        # call parent constructor
        super().__init__(*args, **kwargs)
        self.agg_func = agg_func

    def __str__(self):
        op = super().__str__()
        op += f" Function: {str(self.agg_func)}\n"
        return op

    def get_id_params(self):
        id_params = super().get_id_params()
        return {"agg_func": str(self.agg_func), **id_params}

    def get_op_params(self):
        op_params = super().get_op_params()
        return {"agg_func": self.agg_func, **op_params}

    def naive_cost_estimates(self, source_op_cost_estimates: OperatorCostEstimates) -> OperatorCostEstimates:
        # for now, assume applying the aggregation takes negligible additional time (and no cost in USD)
        return OperatorCostEstimates(
            cardinality=1,
            time_per_record=0,
            cost_per_record=0,
            quality=1.0,
        )

    def __call__(self, candidates: list[DataRecord]) -> DataRecordSet:
        """Compute the maximum over candidates' first field and return it as a single record."""
        start_time = time.time()

        # create new DataRecord
        # BUGFIX: the accumulator was previously named `max`, shadowing the builtin;
        # every call to max(...) then raised TypeError (swallowed by the except),
        # so the result was always None. Use a distinct local name instead.
        maximum = float("-inf")
        for candidate in candidates:
            try:  # noqa: SIM105
                # best-effort: skip records whose first field is not coercible to float
                maximum = max(float(list(candidate.to_dict().values())[0]), maximum)
            except Exception:
                pass
        # if no candidate yielded a numeric value, report None rather than -inf
        data_item = Max(max=maximum if maximum != float("-inf") else None)
        dr = DataRecord.from_agg_parents(data_item, parent_records=candidates)

        # create RecordOpStats object
        record_op_stats = RecordOpStats(
            record_id=dr.id,
            record_parent_ids=dr.parent_ids,
            record_source_indices=dr.source_indices,
            record_state=dr.to_dict(include_bytes=False),
            full_op_id=self.get_full_op_id(),
            logical_op_id=self.logical_op_id,
            op_name=self.op_name(),
            time_per_record=time.time() - start_time,
            cost_per_record=0.0,
            op_details={k: str(v) for k, v in self.get_id_params().items()},
        )

        return DataRecordSet([dr], [record_op_stats])
425
+
426
+
427
class SemanticAggregate(AggregateOp):
    """Aggregate a list of records into a single output record by prompting an LLM
    with a natural-language aggregation instruction (`agg_str`).
    """

    def __init__(self, agg_str: str, model: Model, prompt_strategy: PromptStrategy = PromptStrategy.AGG, reasoning_effort: str | None = None, *args, **kwargs):
        # call parent constructor
        super().__init__(*args, **kwargs)
        self.agg_str = agg_str
        self.model = model
        self.prompt_strategy = prompt_strategy
        self.reasoning_effort = reasoning_effort
        if model is not None:
            # NOTE(review): assumes self.api_base is set by the parent constructor — confirm;
            # it is not assigned anywhere in this class before use
            self.generator = Generator(model, prompt_strategy, reasoning_effort, self.api_base)

    def __str__(self):
        op = super().__str__()
        op += f" Prompt Strategy: {self.prompt_strategy}\n"
        op += f" Reasoning Effort: {self.reasoning_effort}\n"
        op += f" Agg: {str(self.agg_str)}\n"
        return op

    def get_id_params(self):
        # identity parameters (stringified model/strategy) used to compute the op id
        id_params = super().get_id_params()
        id_params = {
            "agg_str": self.agg_str,
            "model": None if self.model is None else self.model.value,
            "prompt_strategy": None if self.prompt_strategy is None else self.prompt_strategy.value,
            "reasoning_effort": self.reasoning_effort,
            **id_params,
        }

        return id_params

    def get_op_params(self):
        # constructor parameters (live objects) needed to re-instantiate this operator
        op_params = super().get_op_params()
        op_params = {
            "agg_str": self.agg_str,
            "model": self.model,
            "prompt_strategy": self.prompt_strategy,
            "reasoning_effort": self.reasoning_effort,
            **op_params,
        }

        return op_params

    def get_model_name(self) -> str:
        """Return the identifier string of the model backing this aggregate."""
        return self.model.value

    def naive_cost_estimates(self, source_op_cost_estimates: OperatorCostEstimates) -> OperatorCostEstimates:
        """
        Compute naive cost estimates for this semantic aggregation. Implicitly, these
        estimates assume a single LLM call whose input contains every source record
        (hence input tokens scale with source cardinality). Mirrors the estimation
        logic used by LLMConvert.
        """
        # estimate number of input and output tokens from source
        est_num_input_tokens = NAIVE_EST_NUM_INPUT_TOKENS * source_op_cost_estimates.cardinality
        est_num_output_tokens = NAIVE_EST_NUM_OUTPUT_TOKENS

        # get est. of conversion time per record from model card;
        model_name = self.model.value
        model_conversion_time_per_record = MODEL_CARDS[model_name]["seconds_per_output_token"] * est_num_output_tokens

        # get est. of conversion cost (in USD) per record from model card
        usd_per_input_token = MODEL_CARDS[model_name].get("usd_per_input_token")
        if getattr(self, "prompt_strategy", None) is not None and self.prompt_strategy.is_audio_prompt():
            # audio inputs are billed at a different per-token rate
            usd_per_input_token = MODEL_CARDS[model_name]["usd_per_audio_input_token"]

        model_conversion_usd_per_record = (
            usd_per_input_token * est_num_input_tokens
            + MODEL_CARDS[model_name]["usd_per_output_token"] * est_num_output_tokens
        )

        # estimate quality of output based on the strength of the model being used
        quality = (MODEL_CARDS[model_name]["overall"] / 100.0)

        return OperatorCostEstimates(
            cardinality=1.0,
            time_per_record=model_conversion_time_per_record,
            cost_per_record=model_conversion_usd_per_record,
            quality=quality,
        )

    def __call__(self, candidates: list[DataRecord]) -> DataRecordSet:
        """Run one LLM call over all candidates and return a single aggregated record."""
        start_time = time.time()

        # TODO: if candidates is an empty list, return an empty DataRecordSet
        if len(candidates) == 0:
            return DataRecordSet([], [])

        # get the set of input fields to use for the operation
        input_fields = self.get_input_fields()

        # get the set of output fields to use for the operation
        fields_to_generate = self.get_fields_to_generate(candidates[0])
        fields = {field: field_type for field, field_type in self.output_schema.model_fields.items() if field in fields_to_generate}

        # construct kwargs for generation
        gen_kwargs = {"project_cols": input_fields, "output_schema": self.output_schema, "agg_instruction": self.agg_str}

        # generate outputs for all fields in a single query
        field_answers, _, generation_stats, _ = self.generator(candidates, fields, **gen_kwargs)
        assert all([field in field_answers for field in fields]), "Not all fields were generated!"

        # construct data record for the output; only the first generated answer for the
        # (single) output field is used
        field, value = fields_to_generate[0], field_answers[fields_to_generate[0]][0]
        data_item = self.output_schema(**{field: value})
        dr = DataRecord.from_agg_parents(data_item, parent_records=candidates)

        # create RecordOpStats object
        # NOTE(review): MinAggregateOp/MaxAggregateOp use dr.id / dr.parent_ids /
        # dr.source_indices (public attrs) here, while this op uses the underscore
        # variants — confirm which accessors DataRecord actually exposes
        record_op_stats = RecordOpStats(
            record_id=dr._id,
            record_parent_ids=dr._parent_ids,
            record_source_indices=dr._source_indices,
            record_state=dr.to_dict(include_bytes=False),
            full_op_id=self.get_full_op_id(),
            logical_op_id=self.logical_op_id,
            op_name=self.op_name(),
            time_per_record=time.time() - start_time,
            cost_per_record=generation_stats.cost_per_record,
            model_name=self.get_model_name(),
            answer={field: value},
            input_fields=input_fields,
            generated_fields=fields_to_generate,
            total_input_tokens=generation_stats.total_input_tokens,
            total_output_tokens=generation_stats.total_output_tokens,
            total_input_cost=generation_stats.total_input_cost,
            total_output_cost=generation_stats.total_output_cost,
            llm_call_duration_secs=generation_stats.llm_call_duration_secs,
            fn_call_duration_secs=generation_stats.fn_call_duration_secs,
            total_llm_calls=generation_stats.total_llm_calls,
            total_embedding_llm_calls=generation_stats.total_embedding_llm_calls,
            image_operation=self.is_image_op(),
            op_details={k: str(v) for k, v in self.get_id_params().items()},
        )

        return DataRecordSet([dr], [record_op_stats])
@@ -9,7 +9,7 @@ from palimpzest.constants import AggFunc, Cardinality
9
9
  from palimpzest.core.data import context, dataset
10
10
  from palimpzest.core.elements.filters import Filter
11
11
  from palimpzest.core.elements.groupbysig import GroupBySig
12
- from palimpzest.core.lib.schemas import Average, Count
12
+ from palimpzest.core.lib.schemas import Average, Count, Max, Min
13
13
  from palimpzest.utils.hash_helpers import hash_for_id
14
14
 
15
15
 
@@ -149,27 +149,39 @@ class Aggregate(LogicalOperator):
149
149
 
150
150
  def __init__(
151
151
  self,
152
- agg_func: AggFunc,
152
+ agg_func: AggFunc | None = None,
153
+ agg_str: str | None = None,
153
154
  *args,
154
155
  **kwargs,
155
156
  ):
157
+ assert agg_func is not None or agg_str is not None, "Either agg_func or agg_str must be provided"
156
158
  if kwargs.get("output_schema") is None:
157
159
  if agg_func == AggFunc.COUNT:
158
160
  kwargs["output_schema"] = Count
159
161
  elif agg_func == AggFunc.AVERAGE:
160
162
  kwargs["output_schema"] = Average
163
+ elif agg_func == AggFunc.MIN:
164
+ kwargs["output_schema"] = Min
165
+ elif agg_func == AggFunc.MAX:
166
+ kwargs["output_schema"] = Max
161
167
  else:
162
168
  raise ValueError(f"Unsupported aggregation function: {agg_func}")
163
169
 
164
170
  super().__init__(*args, **kwargs)
165
171
  self.agg_func = agg_func
172
+ self.agg_str = agg_str
166
173
 
167
174
  def __str__(self):
168
- return f"{self.__class__.__name__}(function: {str(self.agg_func.value)})"
175
+ desc = f"function: {str(self.agg_func.value)}" if self.agg_func else f"agg: {self.agg_str}"
176
+ return f"{self.__class__.__name__}({desc})"
169
177
 
170
178
  def get_logical_id_params(self) -> dict:
171
179
  logical_id_params = super().get_logical_id_params()
172
- logical_id_params = {"agg_func": self.agg_func, **logical_id_params}
180
+ logical_id_params = {
181
+ "agg_func": self.agg_func,
182
+ "agg_str": self.agg_str,
183
+ **logical_id_params,
184
+ }
173
185
 
174
186
  return logical_id_params
175
187
 
@@ -177,6 +189,7 @@ class Aggregate(LogicalOperator):
177
189
  logical_op_params = super().get_logical_op_params()
178
190
  logical_op_params = {
179
191
  "agg_func": self.agg_func,
192
+ "agg_str": self.agg_str,
180
193
  **logical_op_params,
181
194
  }
182
195
 
@@ -47,6 +47,9 @@ from palimpzest.query.optimizer.rules import (
47
47
  from palimpzest.query.optimizer.rules import (
48
48
  Rule as _Rule,
49
49
  )
50
+ from palimpzest.query.optimizer.rules import (
51
+ SemanticAggregateRule as _SemanticAggregateRule,
52
+ )
50
53
  from palimpzest.query.optimizer.rules import (
51
54
  SplitRule as _SplitRule,
52
55
  )
@@ -72,6 +75,7 @@ ALL_RULES = [
72
75
  _ReorderConverts,
73
76
  _RetrieveRule,
74
77
  _Rule,
78
+ _SemanticAggregateRule,
75
79
  _SplitRule,
76
80
  _TransformationRule,
77
81
  ]
@@ -12,7 +12,14 @@ from palimpzest.core.lib.schemas import (
12
12
  IMAGE_LIST_FIELD_TYPES,
13
13
  )
14
14
  from palimpzest.prompts import CONTEXT_SEARCH_PROMPT
15
- from palimpzest.query.operators.aggregate import ApplyGroupByOp, AverageAggregateOp, CountAggregateOp
15
+ from palimpzest.query.operators.aggregate import (
16
+ ApplyGroupByOp,
17
+ AverageAggregateOp,
18
+ CountAggregateOp,
19
+ MaxAggregateOp,
20
+ MinAggregateOp,
21
+ SemanticAggregate,
22
+ )
16
23
  from palimpzest.query.operators.compute import SmolAgentsCompute
17
24
  from palimpzest.query.operators.convert import LLMConvertBonded, NonLLMConvert
18
25
  from palimpzest.query.operators.critique_and_refine import CritiqueAndRefineConvert, CritiqueAndRefineFilter
@@ -924,6 +931,35 @@ class EmbeddingJoinRule(ImplementationRule):
924
931
 
925
932
  return cls._perform_substitution(logical_expression, EmbeddingJoin, runtime_kwargs, variable_op_kwargs)
926
933
 
934
class SemanticAggregateRule(ImplementationRule):
    """
    Substitute a logical expression for a SemanticAggregate with an llm physical implementation.
    """

    @classmethod
    def matches_pattern(cls, logical_expression: LogicalExpression) -> bool:
        # this rule fires only for aggregates driven by a natural-language
        # instruction (agg_str), as opposed to builtin AggFunc aggregates
        operator = logical_expression.operator
        is_match = isinstance(operator, Aggregate) and operator.agg_str is not None
        logger.debug(f"SemanticAggregateRule matches_pattern: {is_match} for {logical_expression}")
        return is_match

    @classmethod
    def substitute(cls, logical_expression: LogicalExpression, **runtime_kwargs) -> set[PhysicalExpression]:
        logger.debug(f"Substituting SemanticAggregateRule for {logical_expression}")

        # a model is eligible if it can handle this expression's input modality
        # and is not a llama model
        eligible_models = []
        for model in runtime_kwargs["available_models"]:
            if cls._model_matches_input(model, logical_expression) and not model.is_llama_model():
                eligible_models.append(model)

        # decide whether reasoning models should use the no-reasoning prompt
        reasoning_effort = runtime_kwargs["reasoning_effort"]
        skip_reasoning = reasoning_effort in [None, "minimal", "low"]

        # build one physical-operator kwargs dict per eligible model
        variable_op_kwargs = []
        for model in eligible_models:
            use_no_reasoning = model.is_reasoning_model() and skip_reasoning
            variable_op_kwargs.append(
                {
                    "model": model,
                    "prompt_strategy": PromptStrategy.AGG_NO_REASONING if use_no_reasoning else PromptStrategy.AGG,
                    "reasoning_effort": reasoning_effort,
                }
            )

        return cls._perform_substitution(logical_expression, SemanticAggregate, runtime_kwargs, variable_op_kwargs)
962
+
927
963
 
928
964
  class AggregateRule(ImplementationRule):
929
965
  """
@@ -932,7 +968,7 @@ class AggregateRule(ImplementationRule):
932
968
 
933
969
  @classmethod
934
970
  def matches_pattern(cls, logical_expression: LogicalExpression) -> bool:
935
- is_match = isinstance(logical_expression.operator, Aggregate)
971
+ is_match = isinstance(logical_expression.operator, Aggregate) and logical_expression.operator.agg_func is not None
936
972
  logger.debug(f"AggregateRule matches_pattern: {is_match} for {logical_expression}")
937
973
  return is_match
938
974
 
@@ -946,6 +982,10 @@ class AggregateRule(ImplementationRule):
946
982
  physical_op_class = CountAggregateOp
947
983
  elif logical_expression.operator.agg_func == AggFunc.AVERAGE:
948
984
  physical_op_class = AverageAggregateOp
985
+ elif logical_expression.operator.agg_func == AggFunc.MIN:
986
+ physical_op_class = MinAggregateOp
987
+ elif logical_expression.operator.agg_func == AggFunc.MAX:
988
+ physical_op_class = MaxAggregateOp
949
989
  else:
950
990
  raise Exception(f"Cannot support aggregate function: {logical_expression.operator.agg_func}")
951
991