palimpzest 0.8.7__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. palimpzest/constants.py +13 -4
  2. palimpzest/core/data/dataset.py +75 -5
  3. palimpzest/core/elements/groupbysig.py +5 -1
  4. palimpzest/core/elements/records.py +16 -7
  5. palimpzest/core/lib/schemas.py +26 -3
  6. palimpzest/core/models.py +4 -4
  7. palimpzest/prompts/aggregate_prompts.py +99 -0
  8. palimpzest/prompts/prompt_factory.py +162 -75
  9. palimpzest/prompts/utils.py +38 -1
  10. palimpzest/prompts/validator.py +24 -24
  11. palimpzest/query/execution/all_sample_execution_strategy.py +1 -1
  12. palimpzest/query/execution/execution_strategy.py +8 -8
  13. palimpzest/query/execution/mab_execution_strategy.py +30 -11
  14. palimpzest/query/execution/parallel_execution_strategy.py +31 -7
  15. palimpzest/query/execution/single_threaded_execution_strategy.py +23 -6
  16. palimpzest/query/generators/generators.py +9 -7
  17. palimpzest/query/operators/__init__.py +10 -6
  18. palimpzest/query/operators/aggregate.py +394 -10
  19. palimpzest/query/operators/convert.py +1 -1
  20. palimpzest/query/operators/join.py +279 -23
  21. palimpzest/query/operators/logical.py +36 -11
  22. palimpzest/query/operators/mixture_of_agents.py +3 -1
  23. palimpzest/query/operators/physical.py +5 -2
  24. palimpzest/query/operators/{retrieve.py → topk.py} +10 -10
  25. palimpzest/query/optimizer/__init__.py +11 -3
  26. palimpzest/query/optimizer/cost_model.py +5 -5
  27. palimpzest/query/optimizer/optimizer.py +3 -2
  28. palimpzest/query/optimizer/plan.py +2 -3
  29. palimpzest/query/optimizer/rules.py +73 -13
  30. palimpzest/query/optimizer/tasks.py +4 -4
  31. palimpzest/utils/progress.py +19 -17
  32. palimpzest/validator/validator.py +7 -7
  33. {palimpzest-0.8.7.dist-info → palimpzest-1.0.0.dist-info}/METADATA +26 -66
  34. {palimpzest-0.8.7.dist-info → palimpzest-1.0.0.dist-info}/RECORD +37 -36
  35. {palimpzest-0.8.7.dist-info → palimpzest-1.0.0.dist-info}/WHEEL +0 -0
  36. {palimpzest-0.8.7.dist-info → palimpzest-1.0.0.dist-info}/licenses/LICENSE +0 -0
  37. {palimpzest-0.8.7.dist-info → palimpzest-1.0.0.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import threading
3
4
  import time
4
5
  from abc import ABC, abstractmethod
5
6
  from concurrent.futures import ThreadPoolExecutor, as_completed
@@ -37,10 +38,9 @@ class JoinOp(PhysicalOperator, ABC):
37
38
  def __init__(
38
39
  self,
39
40
  condition: str,
40
- model: Model,
41
- prompt_strategy: PromptStrategy = PromptStrategy.JOIN,
41
+ how: str = "inner",
42
+ on: list[str] | None = None,
42
43
  join_parallelism: int = 64,
43
- reasoning_effort: str | None = None,
44
44
  retain_inputs: bool = True,
45
45
  desc: str | None = None,
46
46
  *args,
@@ -49,33 +49,37 @@ class JoinOp(PhysicalOperator, ABC):
49
49
  super().__init__(*args, **kwargs)
50
50
  assert self.input_schema == self.output_schema, "Input and output schemas must match for JoinOp"
51
51
  self.condition = condition
52
- self.model = model
53
- self.prompt_strategy = prompt_strategy
52
+ self.how = how
53
+ self.on = on
54
54
  self.join_parallelism = join_parallelism
55
- self.reasoning_effort = reasoning_effort
56
55
  self.retain_inputs = retain_inputs
57
56
  self.desc = desc
58
- self.generator = Generator(model, prompt_strategy, reasoning_effort, self.api_base, Cardinality.ONE_TO_ONE, self.desc, self.verbose)
59
57
  self.join_idx = 0
58
+ self.finished = False
60
59
 
61
60
  # maintain list(s) of input records for the join
62
61
  self._left_input_records: list[DataRecord] = []
63
62
  self._right_input_records: list[DataRecord] = []
64
63
 
64
+ # maintain set of left/right record ids that have been joined (for left/right/outer joins)
65
+ self._left_joined_record_ids: set[str] = set()
66
+ self._right_joined_record_ids: set[str] = set()
67
+
65
68
  def __str__(self):
66
69
  op = super().__str__()
67
70
  op += f" Condition: {self.condition}\n"
71
+ op += f" How: {self.how}\n"
72
+ op += f" On: {self.on}\n"
68
73
  return op
69
74
 
70
75
  def get_id_params(self):
71
76
  id_params = super().get_id_params()
72
77
  id_params = {
73
78
  "condition": self.condition,
74
- "model": self.model.value,
75
- "prompt_strategy": self.prompt_strategy.value,
76
79
  "join_parallelism": self.join_parallelism,
77
- "reasoning_effort": self.reasoning_effort,
78
80
  "desc": self.desc,
81
+ "how": self.how,
82
+ "on": self.on,
79
83
  **id_params,
80
84
  }
81
85
  return id_params
@@ -84,23 +88,232 @@ class JoinOp(PhysicalOperator, ABC):
84
88
  op_params = super().get_op_params()
85
89
  op_params = {
86
90
  "condition": self.condition,
87
- "model": self.model,
88
- "prompt_strategy": self.prompt_strategy,
89
91
  "join_parallelism": self.join_parallelism,
90
- "reasoning_effort": self.reasoning_effort,
91
92
  "retain_inputs": self.retain_inputs,
92
93
  "desc": self.desc,
94
+ "how": self.how,
95
+ "on": self.on,
93
96
  **op_params,
94
97
  }
95
98
  return op_params
96
99
 
97
- def get_model_name(self):
98
- return self.model.value
100
+ def _compute_unmatched_records(self) -> DataRecordSet:
101
+ """Helper function to compute unmatched records for left/right/outer joins."""
102
+ def join_unmatched_records(input_records: list[DataRecord] | list[tuple[DataRecord, list[float]]], joined_record_ids: set[str], left: bool = True):
103
+ records, record_op_stats_lst = [], []
104
+ for record in input_records:
105
+ start_time = time.time()
106
+ record = record[0] if isinstance(record, tuple) else record
107
+ if record._id not in joined_record_ids:
108
+ unmatched_dr = (
109
+ DataRecord.from_join_parents(self.output_schema, record, None)
110
+ if left
111
+ else DataRecord.from_join_parents(self.output_schema, None, record)
112
+ )
113
+ unmatched_dr._passed_operator = True
114
+
115
+ # compute record stats and add to output_record_op_stats
116
+ time_per_record = time.time() - start_time
117
+ record_op_stats = RecordOpStats(
118
+ record_id=unmatched_dr._id,
119
+ record_parent_ids=unmatched_dr._parent_ids,
120
+ record_source_indices=unmatched_dr._source_indices,
121
+ record_state=unmatched_dr.to_dict(include_bytes=False),
122
+ full_op_id=self.get_full_op_id(),
123
+ logical_op_id=self.logical_op_id,
124
+ op_name=self.op_name(),
125
+ time_per_record=time_per_record,
126
+ cost_per_record=0.0,
127
+ model_name=self.get_model_name(),
128
+ join_condition=str(self.on),
129
+ fn_call_duration_secs=time_per_record,
130
+ answer={"passed_operator": True},
131
+ passed_operator=True,
132
+ op_details={k: str(v) for k, v in self.get_id_params().items()},
133
+ )
134
+ records.append(unmatched_dr)
135
+ record_op_stats_lst.append(record_op_stats)
136
+ return records, record_op_stats_lst
137
+
138
+ records, record_op_stats = [], []
139
+ if self.how == "left":
140
+ records, record_op_stats = join_unmatched_records(self._left_input_records, self._left_joined_record_ids, left=True)
141
+
142
+ elif self.how == "right":
143
+ records, record_op_stats = join_unmatched_records(self._right_input_records, self._right_joined_record_ids, left=False)
144
+
145
+ elif self.how == "outer":
146
+ records, record_op_stats = join_unmatched_records(self._left_input_records, self._left_joined_record_ids, left=True)
147
+ right_records, right_record_op_stats = join_unmatched_records(self._right_input_records, self._right_joined_record_ids, left=False)
148
+ records.extend(right_records)
149
+ record_op_stats.extend(right_record_op_stats)
150
+
151
+ return DataRecordSet(records, record_op_stats)
99
152
 
100
153
  @abstractmethod
101
154
  def naive_cost_estimates(self, left_source_op_cost_estimates: OperatorCostEstimates, right_source_op_cost_estimates: OperatorCostEstimates) -> OperatorCostEstimates:
102
155
  pass
103
156
 
157
+ def set_finished(self):
158
+ """Mark the operator as finished after computing left/right/outer join logic."""
159
+ self.finished = True
160
+
161
+ class RelationalJoin(JoinOp):
162
+
163
+ def get_model_name(self):
164
+ return None
165
+
166
+ def _process_join_candidate_pair(self, left_candidate, right_candidate) -> tuple[DataRecord, RecordOpStats]:
167
+ start_time = time.time()
168
+
169
+ # determine whether or not the join was satisfied
170
+ passed_operator = all(
171
+ left_candidate[field] == right_candidate[field]
172
+ for field in self.on
173
+ )
174
+
175
+ # handle different join types
176
+ if self.how == "left" and passed_operator:
177
+ self._left_joined_record_ids.add(left_candidate._id)
178
+ elif self.how == "right" and passed_operator:
179
+ self._right_joined_record_ids.add(right_candidate._id)
180
+ elif self.how == "outer" and passed_operator:
181
+ self._left_joined_record_ids.add(left_candidate._id)
182
+ self._right_joined_record_ids.add(right_candidate._id)
183
+
184
+ # compute output record and add to output_records
185
+ join_dr = DataRecord.from_join_parents(self.output_schema, left_candidate, right_candidate)
186
+ join_dr._passed_operator = passed_operator
187
+
188
+ # compute record stats and add to output_record_op_stats
189
+ time_per_record = time.time() - start_time
190
+ record_op_stats = RecordOpStats(
191
+ record_id=join_dr._id,
192
+ record_parent_ids=join_dr._parent_ids,
193
+ record_source_indices=join_dr._source_indices,
194
+ record_state=join_dr.to_dict(include_bytes=False),
195
+ full_op_id=self.get_full_op_id(),
196
+ logical_op_id=self.logical_op_id,
197
+ op_name=self.op_name(),
198
+ time_per_record=time_per_record,
199
+ cost_per_record=0.0,
200
+ model_name=self.get_model_name(),
201
+ join_condition=str(self.on),
202
+ fn_call_duration_secs=time_per_record,
203
+ answer={"passed_operator": passed_operator},
204
+ passed_operator=passed_operator,
205
+ op_details={k: str(v) for k, v in self.get_id_params().items()},
206
+ )
207
+
208
+ return join_dr, record_op_stats
209
+
210
+ def naive_cost_estimates(self, left_source_op_cost_estimates: OperatorCostEstimates, right_source_op_cost_estimates: OperatorCostEstimates):
211
+ # estimate output cardinality using a constant assumption of the filter selectivity
212
+ selectivity = NAIVE_EST_JOIN_SELECTIVITY
213
+ cardinality = selectivity * (left_source_op_cost_estimates.cardinality * right_source_op_cost_estimates.cardinality)
214
+
215
+ # estimate 1 ms execution time per input record pair
216
+ time_per_record = 0.001 * (left_source_op_cost_estimates.cardinality + right_source_op_cost_estimates.cardinality)
217
+
218
+ return OperatorCostEstimates(
219
+ cardinality=cardinality,
220
+ time_per_record=time_per_record,
221
+ cost_per_record=0.0,
222
+ quality=1.0,
223
+ )
224
+
225
+ def __call__(self, left_candidates: list[DataRecord], right_candidates: list[DataRecord], final: bool = False) -> tuple[DataRecordSet, int]:
226
+ # create the set of candidates to join
227
+ join_candidates = []
228
+ for candidate in left_candidates:
229
+ for right_candidate in right_candidates:
230
+ join_candidates.append((candidate, right_candidate))
231
+ for right_candidate in self._right_input_records:
232
+ join_candidates.append((candidate, right_candidate))
233
+ for candidate in self._left_input_records:
234
+ for right_candidate in right_candidates:
235
+ join_candidates.append((candidate, right_candidate))
236
+
237
+ # apply the join logic to each pair of candidates
238
+ output_records, output_record_op_stats = [], []
239
+ with ThreadPoolExecutor(max_workers=self.join_parallelism) as executor:
240
+ futures = [
241
+ executor.submit(self._process_join_candidate_pair, candidate, right_candidate)
242
+ for candidate, right_candidate in join_candidates
243
+ ]
244
+
245
+ # collect results as they complete
246
+ for future in as_completed(futures):
247
+ self.join_idx += 1
248
+ join_output_record, join_output_record_op_stats = future.result()
249
+ output_records.append(join_output_record)
250
+ output_record_op_stats.append(join_output_record_op_stats)
251
+
252
+ # compute the number of inputs processed
253
+ num_inputs_processed = len(join_candidates)
254
+
255
+ # store input records to join with new records added later
256
+ if self.retain_inputs:
257
+ self._left_input_records.extend(left_candidates)
258
+ self._right_input_records.extend(right_candidates)
259
+
260
+ # if this is the final call, then add in any left/right/outer join records that did not match
261
+ if final:
262
+ return self._compute_unmatched_records(), 0
263
+
264
+ # return empty DataRecordSet if no output records were produced
265
+ if len(output_records) == 0:
266
+ return DataRecordSet([], []), num_inputs_processed
267
+
268
+ return DataRecordSet(output_records, output_record_op_stats), num_inputs_processed
269
+
270
+
271
+
272
+ class LLMJoin(JoinOp):
273
+ def __init__(
274
+ self,
275
+ model: Model,
276
+ prompt_strategy: PromptStrategy = PromptStrategy.JOIN,
277
+ reasoning_effort: str | None = None,
278
+ *args,
279
+ **kwargs,
280
+ ):
281
+ super().__init__(*args, **kwargs)
282
+ self.model = model
283
+ self.prompt_strategy = prompt_strategy
284
+ self.reasoning_effort = reasoning_effort
285
+ self.generator = Generator(model, prompt_strategy, reasoning_effort, self.api_base, Cardinality.ONE_TO_ONE, self.desc, self.verbose)
286
+
287
+ def __str__(self):
288
+ op = super().__str__()
289
+ op += f" Model: {self.model.value}\n"
290
+ op += f" Reasoning Effort: {self.reasoning_effort}\n"
291
+ op += f" Prompt Strategy: {self.prompt_strategy.value}\n"
292
+ return op
293
+
294
+ def get_id_params(self):
295
+ id_params = super().get_id_params()
296
+ id_params = {
297
+ "model": self.model.value,
298
+ "prompt_strategy": self.prompt_strategy.value,
299
+ "reasoning_effort": self.reasoning_effort,
300
+ **id_params,
301
+ }
302
+ return id_params
303
+
304
+ def get_op_params(self):
305
+ op_params = super().get_op_params()
306
+ op_params = {
307
+ "model": self.model,
308
+ "prompt_strategy": self.prompt_strategy,
309
+ "reasoning_effort": self.reasoning_effort,
310
+ **op_params,
311
+ }
312
+ return op_params
313
+
314
+ def get_model_name(self):
315
+ return self.model.value
316
+
104
317
  def _process_join_candidate_pair(
105
318
  self,
106
319
  left_candidate: DataRecord,
@@ -116,6 +329,15 @@ class JoinOp(PhysicalOperator, ABC):
116
329
  # determine whether or not the join was satisfied
117
330
  passed_operator = field_answers["passed_operator"]
118
331
 
332
+ # handle different join types
333
+ if self.how == "left" and passed_operator:
334
+ self._left_joined_record_ids.add(left_candidate._id)
335
+ elif self.how == "right" and passed_operator:
336
+ self._right_joined_record_ids.add(right_candidate._id)
337
+ elif self.how == "outer" and passed_operator:
338
+ self._left_joined_record_ids.add(left_candidate._id)
339
+ self._right_joined_record_ids.add(right_candidate._id)
340
+
119
341
  # compute output record and add to output_records
120
342
  join_dr = DataRecord.from_join_parents(self.output_schema, left_candidate, right_candidate)
121
343
  join_dr._passed_operator = passed_operator
@@ -149,7 +371,7 @@ class JoinOp(PhysicalOperator, ABC):
149
371
  return join_dr, record_op_stats
150
372
 
151
373
 
152
- class NestedLoopsJoin(JoinOp):
374
+ class NestedLoopsJoin(LLMJoin):
153
375
 
154
376
  def naive_cost_estimates(self, left_source_op_cost_estimates: OperatorCostEstimates, right_source_op_cost_estimates: OperatorCostEstimates):
155
377
  # estimate number of input tokens from source
@@ -192,7 +414,7 @@ class NestedLoopsJoin(JoinOp):
192
414
  quality=quality,
193
415
  )
194
416
 
195
- def __call__(self, left_candidates: list[DataRecord], right_candidates: list[DataRecord]) -> tuple[DataRecordSet, int]:
417
+ def __call__(self, left_candidates: list[DataRecord], right_candidates: list[DataRecord], final: bool = False) -> tuple[DataRecordSet, int]:
196
418
  # get the set of input fields from both records in the join
197
419
  input_fields = self.get_input_fields()
198
420
 
@@ -234,6 +456,10 @@ class NestedLoopsJoin(JoinOp):
234
456
  self._left_input_records.extend(left_candidates)
235
457
  self._right_input_records.extend(right_candidates)
236
458
 
459
+ # if this is the final call, then add in any left/right/outer join records that did not match
460
+ if final:
461
+ return self._compute_unmatched_records(), 0
462
+
237
463
  # return empty DataRecordSet if no output records were produced
238
464
  if len(output_records) == 0:
239
465
  return DataRecordSet([], []), num_inputs_processed
@@ -241,7 +467,7 @@ class NestedLoopsJoin(JoinOp):
241
467
  return DataRecordSet(output_records, output_record_op_stats), num_inputs_processed
242
468
 
243
469
 
244
- class EmbeddingJoin(JoinOp):
470
+ class EmbeddingJoin(LLMJoin):
245
471
  # NOTE: we currently do not support audio joins as embedding models for audio seem to have
246
472
  # specialized use cases (e.g., speech-to-text) with strict requirements on things like e.g. sample rate
247
473
  def __init__(
@@ -261,6 +487,8 @@ class EmbeddingJoin(JoinOp):
261
487
  if field_name.split(".")[-1] in self.get_input_fields()
262
488
  ])
263
489
  self.embedding_model = Model.TEXT_EMBEDDING_3_SMALL if self.text_only else Model.CLIP_VIT_B_32
490
+ self.clip_model = None
491
+ self._lock = threading.Lock()
264
492
 
265
493
  # keep track of embedding costs that could not be amortized if no output records were produced
266
494
  self.residual_embedding_cost = 0.0
@@ -276,6 +504,11 @@ class EmbeddingJoin(JoinOp):
276
504
  self.min_matching_sim = float("inf")
277
505
  self.max_non_matching_sim = float("-inf")
278
506
 
507
+ def __str__(self):
508
+ op = super().__str__()
509
+ op += f" Num Samples: {self.num_samples}\n"
510
+ return op
511
+
279
512
  def get_id_params(self):
280
513
  id_params = super().get_id_params()
281
514
  id_params = {
@@ -327,6 +560,12 @@ class EmbeddingJoin(JoinOp):
327
560
  quality=quality,
328
561
  )
329
562
 
563
+ def _get_clip_model(self):
564
+ with self._lock:
565
+ if self.clip_model is None:
566
+ self.clip_model = SentenceTransformer(self.embedding_model.value)
567
+ return self.clip_model
568
+
330
569
  def _compute_embeddings(self, candidates: list[DataRecord], input_fields: list[str]) -> tuple[np.ndarray, GenerationStats]:
331
570
  # return empty array and empty stats if no candidates
332
571
  if len(candidates) == 0:
@@ -342,7 +581,7 @@ class EmbeddingJoin(JoinOp):
342
581
  total_input_tokens = response.usage.total_tokens
343
582
  embeddings = np.array([item.embedding for item in response.data])
344
583
  else:
345
- model = SentenceTransformer(self.embedding_model.value)
584
+ model = self._get_clip_model()
346
585
  embeddings = np.zeros((len(candidates), 512)) # CLIP embeddings are 512-dimensional
347
586
  num_input_fields_present = 0
348
587
  for field in input_fields:
@@ -389,6 +628,15 @@ class EmbeddingJoin(JoinOp):
389
628
  join_dr = DataRecord.from_join_parents(self.output_schema, left_candidate, right_candidate)
390
629
  join_dr._passed_operator = passed_operator
391
630
 
631
+ # handle different join types
632
+ if self.how == "left" and passed_operator:
633
+ self._left_joined_record_ids.add(left_candidate._id)
634
+ elif self.how == "right" and passed_operator:
635
+ self._right_joined_record_ids.add(right_candidate._id)
636
+ elif self.how == "outer" and passed_operator:
637
+ self._left_joined_record_ids.add(left_candidate._id)
638
+ self._right_joined_record_ids.add(right_candidate._id)
639
+
392
640
  # NOTE: embedding costs are amortized over all records and added at the end of __call__
393
641
  # compute record stats and add to output_record_op_stats
394
642
  record_op_stats = RecordOpStats(
@@ -410,7 +658,7 @@ class EmbeddingJoin(JoinOp):
410
658
 
411
659
  return join_dr, record_op_stats
412
660
 
413
- def __call__(self, left_candidates: list[DataRecord], right_candidates: list[DataRecord]) -> tuple[DataRecordSet, int]:
661
+ def __call__(self, left_candidates: list[DataRecord], right_candidates: list[DataRecord], final: bool = False) -> tuple[DataRecordSet, int]:
414
662
  # get the set of input fields from both records in the join
415
663
  input_fields = self.get_input_fields()
416
664
 
@@ -468,18 +716,22 @@ class EmbeddingJoin(JoinOp):
468
716
  self.max_non_matching_sim = embedding_sim
469
717
  if records_joined and embedding_sim < self.min_matching_sim:
470
718
  self.min_matching_sim = embedding_sim
471
-
719
+
472
720
  # update samples drawn and num_inputs_processed
473
721
  self.samples_drawn += samples_to_draw
474
722
  num_inputs_processed += samples_to_draw
475
723
 
476
724
  # process remaining candidates based on embedding similarity
477
725
  if len(join_candidates) > 0:
478
- assert self.samples_drawn == self.num_samples, "All samples should have been drawn before processing remaining candidates"
726
+ assert self.samples_drawn >= self.num_samples, "All samples should have been drawn before processing remaining candidates"
479
727
  with ThreadPoolExecutor(max_workers=self.join_parallelism) as executor:
480
728
  futures = []
481
729
  for left_candidate, right_candidate, embedding_sim in join_candidates:
482
- llm_call_needed = self.min_matching_sim <= embedding_sim <= self.max_non_matching_sim
730
+ llm_call_needed = (
731
+ self.min_matching_sim == float("inf")
732
+ or self.max_non_matching_sim == float("-inf")
733
+ or self.min_matching_sim <= embedding_sim <= self.max_non_matching_sim
734
+ )
483
735
 
484
736
  if llm_call_needed:
485
737
  futures.append(executor.submit(self._process_join_candidate_pair, left_candidate, right_candidate, gen_kwargs, embedding_sim))
@@ -526,6 +778,10 @@ class EmbeddingJoin(JoinOp):
526
778
  self._left_input_records.extend(zip(left_candidates, left_embeddings))
527
779
  self._right_input_records.extend(zip(right_candidates, right_embeddings))
528
780
 
781
+ # if this is the final call, then add in any left/right/outer join records that did not match
782
+ if final:
783
+ return self._compute_unmatched_records(), 0
784
+
529
785
  # return empty DataRecordSet if no output records were produced
530
786
  if len(output_records) == 0:
531
787
  self.residual_embedding_cost = total_embedding_cost
@@ -9,7 +9,7 @@ from palimpzest.constants import AggFunc, Cardinality
9
9
  from palimpzest.core.data import context, dataset
10
10
  from palimpzest.core.elements.filters import Filter
11
11
  from palimpzest.core.elements.groupbysig import GroupBySig
12
- from palimpzest.core.lib.schemas import Average, Count
12
+ from palimpzest.core.lib.schemas import Average, Count, Max, Min, Sum
13
13
  from palimpzest.utils.hash_helpers import hash_for_id
14
14
 
15
15
 
@@ -25,7 +25,7 @@ class LogicalOperator:
25
25
  - LimitScan (scans up to N records from a Set)
26
26
  - GroupByAggregate (applies a group by on the Set)
27
27
  - Aggregate (applies an aggregation on the Set)
28
- - RetrieveScan (fetches documents from a provided input for a given query)
28
+ - TopKScan (fetches documents from a provided input for a given query)
29
29
  - Map (applies a function to each record in the Set without adding any new columns)
30
30
  - ComputeOperator (executes a computation described in natural language)
31
31
  - SearchOperator (executes a search query on the input Context)
@@ -149,27 +149,41 @@ class Aggregate(LogicalOperator):
149
149
 
150
150
  def __init__(
151
151
  self,
152
- agg_func: AggFunc,
152
+ agg_func: AggFunc | None = None,
153
+ agg_str: str | None = None,
153
154
  *args,
154
155
  **kwargs,
155
156
  ):
157
+ assert agg_func is not None or agg_str is not None, "Either agg_func or agg_str must be provided"
156
158
  if kwargs.get("output_schema") is None:
157
159
  if agg_func == AggFunc.COUNT:
158
160
  kwargs["output_schema"] = Count
159
161
  elif agg_func == AggFunc.AVERAGE:
160
162
  kwargs["output_schema"] = Average
163
+ elif agg_func == AggFunc.SUM:
164
+ kwargs["output_schema"] = Sum
165
+ elif agg_func == AggFunc.MIN:
166
+ kwargs["output_schema"] = Min
167
+ elif agg_func == AggFunc.MAX:
168
+ kwargs["output_schema"] = Max
161
169
  else:
162
170
  raise ValueError(f"Unsupported aggregation function: {agg_func}")
163
171
 
164
172
  super().__init__(*args, **kwargs)
165
173
  self.agg_func = agg_func
174
+ self.agg_str = agg_str
166
175
 
167
176
  def __str__(self):
168
- return f"{self.__class__.__name__}(function: {str(self.agg_func.value)})"
177
+ desc = f"function: {str(self.agg_func.value)}" if self.agg_func else f"agg: {self.agg_str}"
178
+ return f"{self.__class__.__name__}({desc})"
169
179
 
170
180
  def get_logical_id_params(self) -> dict:
171
181
  logical_id_params = super().get_logical_id_params()
172
- logical_id_params = {"agg_func": self.agg_func, **logical_id_params}
182
+ logical_id_params = {
183
+ "agg_func": self.agg_func,
184
+ "agg_str": self.agg_str,
185
+ **logical_id_params,
186
+ }
173
187
 
174
188
  return logical_id_params
175
189
 
@@ -177,6 +191,7 @@ class Aggregate(LogicalOperator):
177
191
  logical_op_params = super().get_logical_op_params()
178
192
  logical_op_params = {
179
193
  "agg_func": self.agg_func,
194
+ "agg_str": self.agg_str,
180
195
  **logical_op_params,
181
196
  }
182
197
 
@@ -398,17 +413,25 @@ class GroupByAggregate(LogicalOperator):
398
413
 
399
414
 
400
415
  class JoinOp(LogicalOperator):
401
- def __init__(self, condition: str, desc: str | None = None, *args, **kwargs):
416
+ def __init__(self, condition: str, on: list[str] | None = None, how: str = "inner", desc: str | None = None, *args, **kwargs):
402
417
  super().__init__(*args, **kwargs)
403
418
  self.condition = condition
419
+ self.on = on
420
+ self.how = how
404
421
  self.desc = desc
405
422
 
406
423
  def __str__(self):
407
- return f"Join(condition={self.condition})"
424
+ return f"Join(condition={self.condition})" if self.on is None else f"Join(on={self.on}, how={self.how})"
408
425
 
409
426
  def get_logical_id_params(self) -> dict:
410
427
  logical_id_params = super().get_logical_id_params()
411
- logical_id_params = {"condition": self.condition, "desc": self.desc, **logical_id_params}
428
+ logical_id_params = {
429
+ "condition": self.condition,
430
+ "on": self.on,
431
+ "how": self.how,
432
+ "desc": self.desc,
433
+ **logical_id_params,
434
+ }
412
435
 
413
436
  return logical_id_params
414
437
 
@@ -416,6 +439,8 @@ class JoinOp(LogicalOperator):
416
439
  logical_op_params = super().get_logical_op_params()
417
440
  logical_op_params = {
418
441
  "condition": self.condition,
442
+ "on": self.on,
443
+ "how": self.how,
419
444
  "desc": self.desc,
420
445
  **logical_op_params,
421
446
  }
@@ -471,8 +496,8 @@ class Project(LogicalOperator):
471
496
  return logical_op_params
472
497
 
473
498
 
474
- class RetrieveScan(LogicalOperator):
475
- """A RetrieveScan is a logical operator that represents a scan of a particular input Dataset, with a convert-like retrieve applied."""
499
+ class TopKScan(LogicalOperator):
500
+ """A TopKScan is a logical operator that represents a scan of a particular input Dataset, with a top-k operation applied."""
476
501
 
477
502
  def __init__(
478
503
  self,
@@ -492,7 +517,7 @@ class RetrieveScan(LogicalOperator):
492
517
  self.k = k
493
518
 
494
519
  def __str__(self):
495
- return f"RetrieveScan({self.input_schema} -> {str(self.output_schema)})"
520
+ return f"TopKScan({self.input_schema} -> {str(self.output_schema)})"
496
521
 
497
522
  def get_logical_id_params(self) -> dict:
498
523
  # NOTE: if we allow optimization over index, then we will need to include it in the id params
@@ -75,8 +75,9 @@ class MixtureOfAgentsConvert(LLMConvert):
75
75
  In practice, this naive quality estimate will be overwritten by the CostModel's estimate
76
76
  once it executes a few instances of the operator.
77
77
  """
78
- # temporarily set self.model so that super().naive_cost_estimates(...) can compute an estimate
78
+ # temporarily set self.model and self.prompt_strategy so that super().naive_cost_estimates(...) can compute an estimate
79
79
  self.model = self.proposer_models[0]
80
+ self.prompt_strategy = PromptStrategy.MAP_MOA_PROPOSER
80
81
 
81
82
  # get naive cost estimates for single LLM call and scale it by number of LLMs used in MoA
82
83
  naive_op_cost_estimates = super().naive_cost_estimates(source_op_cost_estimates)
@@ -98,6 +99,7 @@ class MixtureOfAgentsConvert(LLMConvert):
98
99
 
99
100
  # reset self.model to be None
100
101
  self.model = None
102
+ self.prompt_strategy = None
101
103
 
102
104
  return naive_op_cost_estimates
103
105
 
@@ -42,10 +42,13 @@ class PhysicalOperator:
42
42
  self.op_id = None
43
43
 
44
44
  # compute the input modalities (if any) for this physical operator
45
+ depends_on_short_field_names = [field.split(".")[-1] for field in self.depends_on] if self.depends_on is not None else None
45
46
  self.input_modalities = None
46
47
  if self.input_schema is not None:
47
48
  self.input_modalities = set()
48
- for field in self.input_schema.model_fields.values():
49
+ for field_name, field in self.input_schema.model_fields.items():
50
+ if self.depends_on is not None and field_name not in depends_on_short_field_names:
51
+ continue
49
52
  field_type = field.annotation
50
53
  if field_type in IMAGE_FIELD_TYPES:
51
54
  self.input_modalities.add(Modality.IMAGE)
@@ -191,7 +194,7 @@ class PhysicalOperator:
191
194
  in the candidate. This is important for operators with retry logic, where we may only need to
192
195
  recompute a subset of self.generated_fields.
193
196
 
194
- Right now this is only used by convert and retrieve operators.
197
+ Right now this is only used by convert and top-k operators.
195
198
  """
196
199
  fields_to_generate = [
197
200
  field_name