palimpzest 0.9.0__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. palimpzest/constants.py +1 -0
  2. palimpzest/core/data/dataset.py +33 -5
  3. palimpzest/core/elements/groupbysig.py +10 -1
  4. palimpzest/core/elements/records.py +16 -7
  5. palimpzest/core/lib/schemas.py +20 -3
  6. palimpzest/core/models.py +10 -4
  7. palimpzest/query/execution/all_sample_execution_strategy.py +1 -1
  8. palimpzest/query/execution/execution_strategy.py +13 -11
  9. palimpzest/query/execution/mab_execution_strategy.py +40 -14
  10. palimpzest/query/execution/parallel_execution_strategy.py +31 -7
  11. palimpzest/query/execution/single_threaded_execution_strategy.py +23 -6
  12. palimpzest/query/generators/generators.py +1 -1
  13. palimpzest/query/operators/__init__.py +7 -6
  14. palimpzest/query/operators/aggregate.py +110 -5
  15. palimpzest/query/operators/convert.py +1 -1
  16. palimpzest/query/operators/join.py +279 -23
  17. palimpzest/query/operators/logical.py +20 -8
  18. palimpzest/query/operators/mixture_of_agents.py +3 -1
  19. palimpzest/query/operators/physical.py +5 -2
  20. palimpzest/query/operators/rag.py +5 -4
  21. palimpzest/query/operators/{retrieve.py → topk.py} +10 -10
  22. palimpzest/query/optimizer/__init__.py +7 -3
  23. palimpzest/query/optimizer/cost_model.py +5 -5
  24. palimpzest/query/optimizer/optimizer.py +3 -2
  25. palimpzest/query/optimizer/plan.py +2 -3
  26. palimpzest/query/optimizer/rules.py +31 -11
  27. palimpzest/query/optimizer/tasks.py +4 -4
  28. palimpzest/query/processor/config.py +1 -0
  29. palimpzest/utils/progress.py +51 -23
  30. palimpzest/validator/validator.py +7 -7
  31. {palimpzest-0.9.0.dist-info → palimpzest-1.1.0.dist-info}/METADATA +26 -66
  32. {palimpzest-0.9.0.dist-info → palimpzest-1.1.0.dist-info}/RECORD +35 -35
  33. {palimpzest-0.9.0.dist-info → palimpzest-1.1.0.dist-info}/WHEEL +0 -0
  34. {palimpzest-0.9.0.dist-info → palimpzest-1.1.0.dist-info}/licenses/LICENSE +0 -0
  35. {palimpzest-0.9.0.dist-info → palimpzest-1.1.0.dist-info}/top_level.txt +0 -0
@@ -338,7 +338,7 @@ class Generator(Generic[ContextType, InputType]):
338
338
  reasoning_effort = "minimal" if self.reasoning_effort is None else self.reasoning_effort
339
339
  completion_kwargs = {"reasoning_effort": reasoning_effort, **completion_kwargs}
340
340
  if self.model.is_vllm_model():
341
- completion_kwargs = {"api_base": self.api_base, "api_key": os.environ.get("VLLM_API_KEY", "fake-api-key") **completion_kwargs}
341
+ completion_kwargs = {"api_base": self.api_base, "api_key": os.environ.get("VLLM_API_KEY", "fake-api-key"), **completion_kwargs}
342
342
  completion = litellm.completion(model=self.model_name, messages=messages, **completion_kwargs)
343
343
  end_time = time.time()
344
344
  logger.debug(f"Generated completion in {end_time - start_time:.2f} seconds")
@@ -5,6 +5,7 @@ from palimpzest.query.operators.aggregate import CountAggregateOp as _CountAggre
5
5
  from palimpzest.query.operators.aggregate import MaxAggregateOp as _MaxAggregateOp
6
6
  from palimpzest.query.operators.aggregate import MinAggregateOp as _MinAggregateOp
7
7
  from palimpzest.query.operators.aggregate import SemanticAggregate as _SemanticAggregate
8
+ from palimpzest.query.operators.aggregate import SumAggregateOp as _SumAggregateOp
8
9
  from palimpzest.query.operators.convert import ConvertOp as _ConvertOp
9
10
  from palimpzest.query.operators.convert import LLMConvert as _LLMConvert
10
11
  from palimpzest.query.operators.convert import LLMConvertBonded as _LLMConvertBonded
@@ -50,7 +51,7 @@ from palimpzest.query.operators.logical import (
50
51
  Project as _Project,
51
52
  )
52
53
  from palimpzest.query.operators.logical import (
53
- RetrieveScan as _RetrieveScan,
54
+ TopKScan as _TopKScan,
54
55
  )
55
56
  from palimpzest.query.operators.mixture_of_agents import MixtureOfAgentsConvert as _MixtureOfAgentsConvert
56
57
  from palimpzest.query.operators.mixture_of_agents import MixtureOfAgentsFilter as _MixtureOfAgentsFilter
@@ -58,11 +59,11 @@ from palimpzest.query.operators.physical import PhysicalOperator as _PhysicalOpe
58
59
  from palimpzest.query.operators.project import ProjectOp as _ProjectOp
59
60
  from palimpzest.query.operators.rag import RAGConvert as _RAGConvert
60
61
  from palimpzest.query.operators.rag import RAGFilter as _RAGFilter
61
- from palimpzest.query.operators.retrieve import RetrieveOp as _RetrieveOp
62
62
  from palimpzest.query.operators.scan import MarshalAndScanDataOp as _MarshalAndScanDataOp
63
63
  from palimpzest.query.operators.scan import ScanPhysicalOp as _ScanPhysicalOp
64
64
  from palimpzest.query.operators.split import SplitConvert as _SplitConvert
65
65
  from palimpzest.query.operators.split import SplitFilter as _SplitFilter
66
+ from palimpzest.query.operators.topk import TopKOp as _TopKOp
66
67
 
67
68
  LOGICAL_OPERATORS = [
68
69
  _LogicalOperator,
@@ -75,12 +76,12 @@ LOGICAL_OPERATORS = [
75
76
  _LogicalJoinOp,
76
77
  _LimitScan,
77
78
  _Project,
78
- _RetrieveScan,
79
+ _TopKScan,
79
80
  ]
80
81
 
81
82
  PHYSICAL_OPERATORS = (
82
83
  # aggregate
83
- [_AggregateOp, _ApplyGroupByOp, _AverageAggregateOp, _CountAggregateOp, _MaxAggregateOp, _MinAggregateOp, _SemanticAggregate]
84
+ [_AggregateOp, _ApplyGroupByOp, _AverageAggregateOp, _CountAggregateOp, _MaxAggregateOp, _MinAggregateOp, _SemanticAggregate, _SumAggregateOp]
84
85
  # convert
85
86
  + [_ConvertOp, _NonLLMConvert, _LLMConvert, _LLMConvertBonded]
86
87
  # critique and refine
@@ -103,8 +104,8 @@ PHYSICAL_OPERATORS = (
103
104
  + [_ProjectOp]
104
105
  # rag
105
106
  + [_RAGConvert, _RAGFilter]
106
- # retrieve
107
- + [_RetrieveOp]
107
+ # top-k
108
+ + [_TopKOp]
108
109
  # split
109
110
  + [_SplitConvert, _SplitFilter]
110
111
  )
@@ -1,5 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import contextlib
3
4
  import time
4
5
  from typing import Any
5
6
 
@@ -14,7 +15,7 @@ from palimpzest.constants import (
14
15
  )
15
16
  from palimpzest.core.elements.groupbysig import GroupBySig
16
17
  from palimpzest.core.elements.records import DataRecord, DataRecordSet
17
- from palimpzest.core.lib.schemas import Average, Count, Max, Min
18
+ from palimpzest.core.lib.schemas import Average, Count, Max, Min, Sum
18
19
  from palimpzest.core.models import OperatorCostEstimates, RecordOpStats
19
20
  from palimpzest.query.generators.generators import Generator
20
21
  from palimpzest.query.operators.physical import PhysicalOperator
@@ -68,6 +69,16 @@ class ApplyGroupByOp(AggregateOp):
68
69
  return 0
69
70
  elif func.lower() == "average":
70
71
  return (0, 0)
72
+ elif func.lower() == "sum":
73
+ return 0
74
+ elif func.lower() == "min":
75
+ return float("inf")
76
+ elif func.lower() == "max":
77
+ return float("-inf")
78
+ elif func.lower() == "list":
79
+ return []
80
+ elif func.lower() == "set":
81
+ return set()
71
82
  else:
72
83
  raise Exception("Unknown agg function " + func)
73
84
 
@@ -76,16 +87,34 @@ class ApplyGroupByOp(AggregateOp):
76
87
  if func.lower() == "count":
77
88
  return state + 1
78
89
  elif func.lower() == "average":
79
- sum, cnt = state
90
+ sum_, cnt = state
91
+ if val is None:
92
+ return (sum_, cnt)
93
+ return (sum_ + val, cnt + 1)
94
+ elif func.lower() == "sum":
95
+ if val is None:
96
+ return state
97
+ return state + sum(val) if isinstance(val, list) else state + val
98
+ elif func.lower() == "min":
99
+ if val is None:
100
+ return state
101
+ return min(state, min(val) if isinstance(val, list) else val)
102
+ elif func.lower() == "max":
80
103
  if val is None:
81
- return (sum, cnt)
82
- return (sum + val, cnt + 1)
104
+ return state
105
+ return max(state, max(val) if isinstance(val, list) else val)
106
+ elif func.lower() == "list":
107
+ state.append(val)
108
+ return state
109
+ elif func.lower() == "set":
110
+ state.add(val)
111
+ return state
83
112
  else:
84
113
  raise Exception("Unknown agg function " + func)
85
114
 
86
115
  @staticmethod
87
116
  def agg_final(func, state):
88
- if func.lower() == "count":
117
+ if func.lower() in ["count", "sum", "min", "max", "list", "set"]:
89
118
  return state
90
119
  elif func.lower() == "average":
91
120
  sum, cnt = state
@@ -240,6 +269,82 @@ class AverageAggregateOp(AggregateOp):
240
269
  return DataRecordSet([dr], [record_op_stats])
241
270
 
242
271
 
272
+ class SumAggregateOp(AggregateOp):
273
+ # NOTE: we don't actually need / use agg_func here (yet)
274
+
275
+ def __init__(self, agg_func: AggFunc, *args, **kwargs):
276
+ # enforce that output schema is correct
277
+ assert kwargs["output_schema"].model_fields.keys() == Sum.model_fields.keys(), "SumAggregateOp requires output_schema to be Sum"
278
+
279
+ # enforce that input schema is a single numeric field
280
+ input_field_types = list(kwargs["input_schema"].model_fields.values())
281
+ assert len(input_field_types) == 1, "SumAggregateOp requires input_schema to have exactly one field"
282
+ numeric_field_types = [
283
+ bool, int, float, int | float,
284
+ bool | None, int | None, float | None, int | float | None,
285
+ bool | Any, int | Any, float | Any, int | float | Any,
286
+ bool | None | Any, int | None | Any, float | None | Any, int | float | None | Any,
287
+ ]
288
+ is_numeric = input_field_types[0].annotation in numeric_field_types
289
+ assert is_numeric, f"SumAggregateOp requires input_schema to have a numeric field type, i.e. one of: {numeric_field_types}\nGot: {input_field_types[0]}"
290
+
291
+ # call parent constructor
292
+ super().__init__(*args, **kwargs)
293
+ self.agg_func = agg_func
294
+
295
+ def __str__(self):
296
+ op = super().__str__()
297
+ op += f" Function: {str(self.agg_func)}\n"
298
+ return op
299
+
300
+ def get_id_params(self):
301
+ id_params = super().get_id_params()
302
+ return {"agg_func": str(self.agg_func), **id_params}
303
+
304
+ def get_op_params(self):
305
+ op_params = super().get_op_params()
306
+ return {"agg_func": self.agg_func, **op_params}
307
+
308
+ def naive_cost_estimates(self, source_op_cost_estimates: OperatorCostEstimates) -> OperatorCostEstimates:
309
+ # for now, assume applying the aggregation takes negligible additional time (and no cost in USD)
310
+ return OperatorCostEstimates(
311
+ cardinality=1,
312
+ time_per_record=0,
313
+ cost_per_record=0,
314
+ quality=1.0,
315
+ )
316
+
317
+ def __call__(self, candidates: list[DataRecord]) -> DataRecordSet:
318
+ start_time = time.time()
319
+
320
+ # NOTE: we currently do not guarantee that input values conform to their specified type;
321
+ # as a result, we simply omit any values which do not parse to a float from the average
322
+ # NOTE: right now we perform a check in the constructor which enforces that the input_schema
323
+ # has a single field which is numeric in nature; in the future we may want to have a
324
+ # cleaner way of computing the value (rather than `float(list(candidate...))` below)
325
+ summation = 0
326
+ for candidate in candidates:
327
+ with contextlib.suppress(Exception):
328
+ summation += float(list(candidate.to_dict().values())[0])
329
+ data_item = Sum(sum=summation)
330
+ dr = DataRecord.from_agg_parents(data_item, parent_records=candidates)
331
+
332
+ # create RecordOpStats object
333
+ record_op_stats = RecordOpStats(
334
+ record_id=dr._id,
335
+ record_parent_ids=dr._parent_ids,
336
+ record_source_indices=dr._source_indices,
337
+ record_state=dr.to_dict(include_bytes=False),
338
+ full_op_id=self.get_full_op_id(),
339
+ logical_op_id=self.logical_op_id,
340
+ op_name=self.op_name(),
341
+ time_per_record=time.time() - start_time,
342
+ cost_per_record=0.0,
343
+ )
344
+
345
+ return DataRecordSet([dr], [record_op_stats])
346
+
347
+
243
348
  class CountAggregateOp(AggregateOp):
244
349
  # NOTE: we don't actually need / use agg_func here (yet)
245
350
 
@@ -320,7 +320,7 @@ class LLMConvert(ConvertOp):
320
320
  est_num_output_tokens = NAIVE_EST_NUM_OUTPUT_TOKENS
321
321
 
322
322
  # get est. of conversion time per record from model card;
323
- model_name = self.model.value if getattr(self, "model", None) is not None else Model.GPT_4o_MINI.value
323
+ model_name = self.model.value
324
324
  model_conversion_time_per_record = MODEL_CARDS[model_name]["seconds_per_output_token"] * est_num_output_tokens
325
325
 
326
326
  # get est. of conversion cost (in USD) per record from model card