palimpzest-0.9.0-py3-none-any.whl → palimpzest-1.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35):
  1. palimpzest/constants.py +1 -0
  2. palimpzest/core/data/dataset.py +33 -5
  3. palimpzest/core/elements/groupbysig.py +10 -1
  4. palimpzest/core/elements/records.py +16 -7
  5. palimpzest/core/lib/schemas.py +20 -3
  6. palimpzest/core/models.py +10 -4
  7. palimpzest/query/execution/all_sample_execution_strategy.py +1 -1
  8. palimpzest/query/execution/execution_strategy.py +13 -11
  9. palimpzest/query/execution/mab_execution_strategy.py +40 -14
  10. palimpzest/query/execution/parallel_execution_strategy.py +31 -7
  11. palimpzest/query/execution/single_threaded_execution_strategy.py +23 -6
  12. palimpzest/query/generators/generators.py +1 -1
  13. palimpzest/query/operators/__init__.py +7 -6
  14. palimpzest/query/operators/aggregate.py +110 -5
  15. palimpzest/query/operators/convert.py +1 -1
  16. palimpzest/query/operators/join.py +279 -23
  17. palimpzest/query/operators/logical.py +20 -8
  18. palimpzest/query/operators/mixture_of_agents.py +3 -1
  19. palimpzest/query/operators/physical.py +5 -2
  20. palimpzest/query/operators/rag.py +5 -4
  21. palimpzest/query/operators/{retrieve.py → topk.py} +10 -10
  22. palimpzest/query/optimizer/__init__.py +7 -3
  23. palimpzest/query/optimizer/cost_model.py +5 -5
  24. palimpzest/query/optimizer/optimizer.py +3 -2
  25. palimpzest/query/optimizer/plan.py +2 -3
  26. palimpzest/query/optimizer/rules.py +31 -11
  27. palimpzest/query/optimizer/tasks.py +4 -4
  28. palimpzest/query/processor/config.py +1 -0
  29. palimpzest/utils/progress.py +51 -23
  30. palimpzest/validator/validator.py +7 -7
  31. {palimpzest-0.9.0.dist-info → palimpzest-1.1.0.dist-info}/METADATA +26 -66
  32. {palimpzest-0.9.0.dist-info → palimpzest-1.1.0.dist-info}/RECORD +35 -35
  33. {palimpzest-0.9.0.dist-info → palimpzest-1.1.0.dist-info}/WHEEL +0 -0
  34. {palimpzest-0.9.0.dist-info → palimpzest-1.1.0.dist-info}/licenses/LICENSE +0 -0
  35. {palimpzest-0.9.0.dist-info → palimpzest-1.1.0.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import threading
3
4
  import time
4
5
  from abc import ABC, abstractmethod
5
6
  from concurrent.futures import ThreadPoolExecutor, as_completed
@@ -37,10 +38,9 @@ class JoinOp(PhysicalOperator, ABC):
37
38
  def __init__(
38
39
  self,
39
40
  condition: str,
40
- model: Model,
41
- prompt_strategy: PromptStrategy = PromptStrategy.JOIN,
41
+ how: str = "inner",
42
+ on: list[str] | None = None,
42
43
  join_parallelism: int = 64,
43
- reasoning_effort: str | None = None,
44
44
  retain_inputs: bool = True,
45
45
  desc: str | None = None,
46
46
  *args,
@@ -49,33 +49,37 @@ class JoinOp(PhysicalOperator, ABC):
49
49
  super().__init__(*args, **kwargs)
50
50
  assert self.input_schema == self.output_schema, "Input and output schemas must match for JoinOp"
51
51
  self.condition = condition
52
- self.model = model
53
- self.prompt_strategy = prompt_strategy
52
+ self.how = how
53
+ self.on = on
54
54
  self.join_parallelism = join_parallelism
55
- self.reasoning_effort = reasoning_effort
56
55
  self.retain_inputs = retain_inputs
57
56
  self.desc = desc
58
- self.generator = Generator(model, prompt_strategy, reasoning_effort, self.api_base, Cardinality.ONE_TO_ONE, self.desc, self.verbose)
59
57
  self.join_idx = 0
58
+ self.finished = False
60
59
 
61
60
  # maintain list(s) of input records for the join
62
61
  self._left_input_records: list[DataRecord] = []
63
62
  self._right_input_records: list[DataRecord] = []
64
63
 
64
+ # maintain set of left/right record ids that have been joined (for left/right/outer joins)
65
+ self._left_joined_record_ids: set[str] = set()
66
+ self._right_joined_record_ids: set[str] = set()
67
+
65
68
  def __str__(self):
66
69
  op = super().__str__()
67
70
  op += f" Condition: {self.condition}\n"
71
+ op += f" How: {self.how}\n"
72
+ op += f" On: {self.on}\n"
68
73
  return op
69
74
 
70
75
  def get_id_params(self):
71
76
  id_params = super().get_id_params()
72
77
  id_params = {
73
78
  "condition": self.condition,
74
- "model": self.model.value,
75
- "prompt_strategy": self.prompt_strategy.value,
76
79
  "join_parallelism": self.join_parallelism,
77
- "reasoning_effort": self.reasoning_effort,
78
80
  "desc": self.desc,
81
+ "how": self.how,
82
+ "on": self.on,
79
83
  **id_params,
80
84
  }
81
85
  return id_params
@@ -84,23 +88,232 @@ class JoinOp(PhysicalOperator, ABC):
84
88
  op_params = super().get_op_params()
85
89
  op_params = {
86
90
  "condition": self.condition,
87
- "model": self.model,
88
- "prompt_strategy": self.prompt_strategy,
89
91
  "join_parallelism": self.join_parallelism,
90
- "reasoning_effort": self.reasoning_effort,
91
92
  "retain_inputs": self.retain_inputs,
92
93
  "desc": self.desc,
94
+ "how": self.how,
95
+ "on": self.on,
93
96
  **op_params,
94
97
  }
95
98
  return op_params
96
99
 
97
- def get_model_name(self):
98
- return self.model.value
100
+ def _compute_unmatched_records(self) -> DataRecordSet:
101
+ """Helper function to compute unmatched records for left/right/outer joins."""
102
+ def join_unmatched_records(input_records: list[DataRecord] | list[tuple[DataRecord, list[float]]], joined_record_ids: set[str], left: bool = True):
103
+ records, record_op_stats_lst = [], []
104
+ for record in input_records:
105
+ start_time = time.time()
106
+ record = record[0] if isinstance(record, tuple) else record
107
+ if record._id not in joined_record_ids:
108
+ unmatched_dr = (
109
+ DataRecord.from_join_parents(self.output_schema, record, None)
110
+ if left
111
+ else DataRecord.from_join_parents(self.output_schema, None, record)
112
+ )
113
+ unmatched_dr._passed_operator = True
114
+
115
+ # compute record stats and add to output_record_op_stats
116
+ time_per_record = time.time() - start_time
117
+ record_op_stats = RecordOpStats(
118
+ record_id=unmatched_dr._id,
119
+ record_parent_ids=unmatched_dr._parent_ids,
120
+ record_source_indices=unmatched_dr._source_indices,
121
+ record_state=unmatched_dr.to_dict(include_bytes=False),
122
+ full_op_id=self.get_full_op_id(),
123
+ logical_op_id=self.logical_op_id,
124
+ op_name=self.op_name(),
125
+ time_per_record=time_per_record,
126
+ cost_per_record=0.0,
127
+ model_name=self.get_model_name(),
128
+ join_condition=str(self.on),
129
+ fn_call_duration_secs=time_per_record,
130
+ answer={"passed_operator": True},
131
+ passed_operator=True,
132
+ op_details={k: str(v) for k, v in self.get_id_params().items()},
133
+ )
134
+ records.append(unmatched_dr)
135
+ record_op_stats_lst.append(record_op_stats)
136
+ return records, record_op_stats_lst
137
+
138
+ records, record_op_stats = [], []
139
+ if self.how == "left":
140
+ records, record_op_stats = join_unmatched_records(self._left_input_records, self._left_joined_record_ids, left=True)
141
+
142
+ elif self.how == "right":
143
+ records, record_op_stats = join_unmatched_records(self._right_input_records, self._right_joined_record_ids, left=False)
144
+
145
+ elif self.how == "outer":
146
+ records, record_op_stats = join_unmatched_records(self._left_input_records, self._left_joined_record_ids, left=True)
147
+ right_records, right_record_op_stats = join_unmatched_records(self._right_input_records, self._right_joined_record_ids, left=False)
148
+ records.extend(right_records)
149
+ record_op_stats.extend(right_record_op_stats)
150
+
151
+ return DataRecordSet(records, record_op_stats)
99
152
 
100
153
  @abstractmethod
101
154
  def naive_cost_estimates(self, left_source_op_cost_estimates: OperatorCostEstimates, right_source_op_cost_estimates: OperatorCostEstimates) -> OperatorCostEstimates:
102
155
  pass
103
156
 
157
+ def set_finished(self):
158
+ """Mark the operator as finished after computing left/right/outer join logic."""
159
+ self.finished = True
160
+
161
+ class RelationalJoin(JoinOp):
162
+
163
+ def get_model_name(self):
164
+ return None
165
+
166
+ def _process_join_candidate_pair(self, left_candidate, right_candidate) -> tuple[DataRecord, RecordOpStats]:
167
+ start_time = time.time()
168
+
169
+ # determine whether or not the join was satisfied
170
+ passed_operator = all(
171
+ left_candidate[field] == right_candidate[field]
172
+ for field in self.on
173
+ )
174
+
175
+ # handle different join types
176
+ if self.how == "left" and passed_operator:
177
+ self._left_joined_record_ids.add(left_candidate._id)
178
+ elif self.how == "right" and passed_operator:
179
+ self._right_joined_record_ids.add(right_candidate._id)
180
+ elif self.how == "outer" and passed_operator:
181
+ self._left_joined_record_ids.add(left_candidate._id)
182
+ self._right_joined_record_ids.add(right_candidate._id)
183
+
184
+ # compute output record and add to output_records
185
+ join_dr = DataRecord.from_join_parents(self.output_schema, left_candidate, right_candidate)
186
+ join_dr._passed_operator = passed_operator
187
+
188
+ # compute record stats and add to output_record_op_stats
189
+ time_per_record = time.time() - start_time
190
+ record_op_stats = RecordOpStats(
191
+ record_id=join_dr._id,
192
+ record_parent_ids=join_dr._parent_ids,
193
+ record_source_indices=join_dr._source_indices,
194
+ record_state=join_dr.to_dict(include_bytes=False),
195
+ full_op_id=self.get_full_op_id(),
196
+ logical_op_id=self.logical_op_id,
197
+ op_name=self.op_name(),
198
+ time_per_record=time_per_record,
199
+ cost_per_record=0.0,
200
+ model_name=self.get_model_name(),
201
+ join_condition=str(self.on),
202
+ fn_call_duration_secs=time_per_record,
203
+ answer={"passed_operator": passed_operator},
204
+ passed_operator=passed_operator,
205
+ op_details={k: str(v) for k, v in self.get_id_params().items()},
206
+ )
207
+
208
+ return join_dr, record_op_stats
209
+
210
+ def naive_cost_estimates(self, left_source_op_cost_estimates: OperatorCostEstimates, right_source_op_cost_estimates: OperatorCostEstimates):
211
+ # estimate output cardinality using a constant assumption of the filter selectivity
212
+ selectivity = NAIVE_EST_JOIN_SELECTIVITY
213
+ cardinality = selectivity * (left_source_op_cost_estimates.cardinality * right_source_op_cost_estimates.cardinality)
214
+
215
+ # estimate 1 ms execution time per input record pair
216
+ time_per_record = 0.001 * (left_source_op_cost_estimates.cardinality + right_source_op_cost_estimates.cardinality)
217
+
218
+ return OperatorCostEstimates(
219
+ cardinality=cardinality,
220
+ time_per_record=time_per_record,
221
+ cost_per_record=0.0,
222
+ quality=1.0,
223
+ )
224
+
225
+ def __call__(self, left_candidates: list[DataRecord], right_candidates: list[DataRecord], final: bool = False) -> tuple[DataRecordSet, int]:
226
+ # create the set of candidates to join
227
+ join_candidates = []
228
+ for candidate in left_candidates:
229
+ for right_candidate in right_candidates:
230
+ join_candidates.append((candidate, right_candidate))
231
+ for right_candidate in self._right_input_records:
232
+ join_candidates.append((candidate, right_candidate))
233
+ for candidate in self._left_input_records:
234
+ for right_candidate in right_candidates:
235
+ join_candidates.append((candidate, right_candidate))
236
+
237
+ # apply the join logic to each pair of candidates
238
+ output_records, output_record_op_stats = [], []
239
+ with ThreadPoolExecutor(max_workers=self.join_parallelism) as executor:
240
+ futures = [
241
+ executor.submit(self._process_join_candidate_pair, candidate, right_candidate)
242
+ for candidate, right_candidate in join_candidates
243
+ ]
244
+
245
+ # collect results as they complete
246
+ for future in as_completed(futures):
247
+ self.join_idx += 1
248
+ join_output_record, join_output_record_op_stats = future.result()
249
+ output_records.append(join_output_record)
250
+ output_record_op_stats.append(join_output_record_op_stats)
251
+
252
+ # compute the number of inputs processed
253
+ num_inputs_processed = len(join_candidates)
254
+
255
+ # store input records to join with new records added later
256
+ if self.retain_inputs:
257
+ self._left_input_records.extend(left_candidates)
258
+ self._right_input_records.extend(right_candidates)
259
+
260
+ # if this is the final call, then add in any left/right/outer join records that did not match
261
+ if final:
262
+ return self._compute_unmatched_records(), 0
263
+
264
+ # return empty DataRecordSet if no output records were produced
265
+ if len(output_records) == 0:
266
+ return DataRecordSet([], []), num_inputs_processed
267
+
268
+ return DataRecordSet(output_records, output_record_op_stats), num_inputs_processed
269
+
270
+
271
+
272
+ class LLMJoin(JoinOp):
273
+ def __init__(
274
+ self,
275
+ model: Model,
276
+ prompt_strategy: PromptStrategy = PromptStrategy.JOIN,
277
+ reasoning_effort: str | None = None,
278
+ *args,
279
+ **kwargs,
280
+ ):
281
+ super().__init__(*args, **kwargs)
282
+ self.model = model
283
+ self.prompt_strategy = prompt_strategy
284
+ self.reasoning_effort = reasoning_effort
285
+ self.generator = Generator(model, prompt_strategy, reasoning_effort, self.api_base, Cardinality.ONE_TO_ONE, self.desc, self.verbose)
286
+
287
+ def __str__(self):
288
+ op = super().__str__()
289
+ op += f" Model: {self.model.value}\n"
290
+ op += f" Reasoning Effort: {self.reasoning_effort}\n"
291
+ op += f" Prompt Strategy: {self.prompt_strategy.value}\n"
292
+ return op
293
+
294
+ def get_id_params(self):
295
+ id_params = super().get_id_params()
296
+ id_params = {
297
+ "model": self.model.value,
298
+ "prompt_strategy": self.prompt_strategy.value,
299
+ "reasoning_effort": self.reasoning_effort,
300
+ **id_params,
301
+ }
302
+ return id_params
303
+
304
+ def get_op_params(self):
305
+ op_params = super().get_op_params()
306
+ op_params = {
307
+ "model": self.model,
308
+ "prompt_strategy": self.prompt_strategy,
309
+ "reasoning_effort": self.reasoning_effort,
310
+ **op_params,
311
+ }
312
+ return op_params
313
+
314
+ def get_model_name(self):
315
+ return self.model.value
316
+
104
317
  def _process_join_candidate_pair(
105
318
  self,
106
319
  left_candidate: DataRecord,
@@ -116,6 +329,15 @@ class JoinOp(PhysicalOperator, ABC):
116
329
  # determine whether or not the join was satisfied
117
330
  passed_operator = field_answers["passed_operator"]
118
331
 
332
+ # handle different join types
333
+ if self.how == "left" and passed_operator:
334
+ self._left_joined_record_ids.add(left_candidate._id)
335
+ elif self.how == "right" and passed_operator:
336
+ self._right_joined_record_ids.add(right_candidate._id)
337
+ elif self.how == "outer" and passed_operator:
338
+ self._left_joined_record_ids.add(left_candidate._id)
339
+ self._right_joined_record_ids.add(right_candidate._id)
340
+
119
341
  # compute output record and add to output_records
120
342
  join_dr = DataRecord.from_join_parents(self.output_schema, left_candidate, right_candidate)
121
343
  join_dr._passed_operator = passed_operator
@@ -149,7 +371,7 @@ class JoinOp(PhysicalOperator, ABC):
149
371
  return join_dr, record_op_stats
150
372
 
151
373
 
152
- class NestedLoopsJoin(JoinOp):
374
+ class NestedLoopsJoin(LLMJoin):
153
375
 
154
376
  def naive_cost_estimates(self, left_source_op_cost_estimates: OperatorCostEstimates, right_source_op_cost_estimates: OperatorCostEstimates):
155
377
  # estimate number of input tokens from source
@@ -192,7 +414,7 @@ class NestedLoopsJoin(JoinOp):
192
414
  quality=quality,
193
415
  )
194
416
 
195
- def __call__(self, left_candidates: list[DataRecord], right_candidates: list[DataRecord]) -> tuple[DataRecordSet, int]:
417
+ def __call__(self, left_candidates: list[DataRecord], right_candidates: list[DataRecord], final: bool = False) -> tuple[DataRecordSet, int]:
196
418
  # get the set of input fields from both records in the join
197
419
  input_fields = self.get_input_fields()
198
420
 
@@ -234,6 +456,10 @@ class NestedLoopsJoin(JoinOp):
234
456
  self._left_input_records.extend(left_candidates)
235
457
  self._right_input_records.extend(right_candidates)
236
458
 
459
+ # if this is the final call, then add in any left/right/outer join records that did not match
460
+ if final:
461
+ return self._compute_unmatched_records(), 0
462
+
237
463
  # return empty DataRecordSet if no output records were produced
238
464
  if len(output_records) == 0:
239
465
  return DataRecordSet([], []), num_inputs_processed
@@ -241,7 +467,7 @@ class NestedLoopsJoin(JoinOp):
241
467
  return DataRecordSet(output_records, output_record_op_stats), num_inputs_processed
242
468
 
243
469
 
244
- class EmbeddingJoin(JoinOp):
470
+ class EmbeddingJoin(LLMJoin):
245
471
  # NOTE: we currently do not support audio joins as embedding models for audio seem to have
246
472
  # specialized use cases (e.g., speech-to-text) with strict requirements on things like e.g. sample rate
247
473
  def __init__(
@@ -261,6 +487,8 @@ class EmbeddingJoin(JoinOp):
261
487
  if field_name.split(".")[-1] in self.get_input_fields()
262
488
  ])
263
489
  self.embedding_model = Model.TEXT_EMBEDDING_3_SMALL if self.text_only else Model.CLIP_VIT_B_32
490
+ self.clip_model = None
491
+ self._lock = threading.Lock()
264
492
 
265
493
  # keep track of embedding costs that could not be amortized if no output records were produced
266
494
  self.residual_embedding_cost = 0.0
@@ -276,6 +504,11 @@ class EmbeddingJoin(JoinOp):
276
504
  self.min_matching_sim = float("inf")
277
505
  self.max_non_matching_sim = float("-inf")
278
506
 
507
+ def __str__(self):
508
+ op = super().__str__()
509
+ op += f" Num Samples: {self.num_samples}\n"
510
+ return op
511
+
279
512
  def get_id_params(self):
280
513
  id_params = super().get_id_params()
281
514
  id_params = {
@@ -327,6 +560,12 @@ class EmbeddingJoin(JoinOp):
327
560
  quality=quality,
328
561
  )
329
562
 
563
+ def _get_clip_model(self):
564
+ with self._lock:
565
+ if self.clip_model is None:
566
+ self.clip_model = SentenceTransformer(self.embedding_model.value)
567
+ return self.clip_model
568
+
330
569
  def _compute_embeddings(self, candidates: list[DataRecord], input_fields: list[str]) -> tuple[np.ndarray, GenerationStats]:
331
570
  # return empty array and empty stats if no candidates
332
571
  if len(candidates) == 0:
@@ -342,7 +581,7 @@ class EmbeddingJoin(JoinOp):
342
581
  total_input_tokens = response.usage.total_tokens
343
582
  embeddings = np.array([item.embedding for item in response.data])
344
583
  else:
345
- model = SentenceTransformer(self.embedding_model.value)
584
+ model = self._get_clip_model()
346
585
  embeddings = np.zeros((len(candidates), 512)) # CLIP embeddings are 512-dimensional
347
586
  num_input_fields_present = 0
348
587
  for field in input_fields:
@@ -389,6 +628,15 @@ class EmbeddingJoin(JoinOp):
389
628
  join_dr = DataRecord.from_join_parents(self.output_schema, left_candidate, right_candidate)
390
629
  join_dr._passed_operator = passed_operator
391
630
 
631
+ # handle different join types
632
+ if self.how == "left" and passed_operator:
633
+ self._left_joined_record_ids.add(left_candidate._id)
634
+ elif self.how == "right" and passed_operator:
635
+ self._right_joined_record_ids.add(right_candidate._id)
636
+ elif self.how == "outer" and passed_operator:
637
+ self._left_joined_record_ids.add(left_candidate._id)
638
+ self._right_joined_record_ids.add(right_candidate._id)
639
+
392
640
  # NOTE: embedding costs are amortized over all records and added at the end of __call__
393
641
  # compute record stats and add to output_record_op_stats
394
642
  record_op_stats = RecordOpStats(
@@ -410,7 +658,7 @@ class EmbeddingJoin(JoinOp):
410
658
 
411
659
  return join_dr, record_op_stats
412
660
 
413
- def __call__(self, left_candidates: list[DataRecord], right_candidates: list[DataRecord]) -> tuple[DataRecordSet, int]:
661
+ def __call__(self, left_candidates: list[DataRecord], right_candidates: list[DataRecord], final: bool = False) -> tuple[DataRecordSet, int]:
414
662
  # get the set of input fields from both records in the join
415
663
  input_fields = self.get_input_fields()
416
664
 
@@ -468,18 +716,22 @@ class EmbeddingJoin(JoinOp):
468
716
  self.max_non_matching_sim = embedding_sim
469
717
  if records_joined and embedding_sim < self.min_matching_sim:
470
718
  self.min_matching_sim = embedding_sim
471
-
719
+
472
720
  # update samples drawn and num_inputs_processed
473
721
  self.samples_drawn += samples_to_draw
474
722
  num_inputs_processed += samples_to_draw
475
723
 
476
724
  # process remaining candidates based on embedding similarity
477
725
  if len(join_candidates) > 0:
478
- assert self.samples_drawn == self.num_samples, "All samples should have been drawn before processing remaining candidates"
726
+ assert self.samples_drawn >= self.num_samples, "All samples should have been drawn before processing remaining candidates"
479
727
  with ThreadPoolExecutor(max_workers=self.join_parallelism) as executor:
480
728
  futures = []
481
729
  for left_candidate, right_candidate, embedding_sim in join_candidates:
482
- llm_call_needed = self.min_matching_sim <= embedding_sim <= self.max_non_matching_sim
730
+ llm_call_needed = (
731
+ self.min_matching_sim == float("inf")
732
+ or self.max_non_matching_sim == float("-inf")
733
+ or self.min_matching_sim <= embedding_sim <= self.max_non_matching_sim
734
+ )
483
735
 
484
736
  if llm_call_needed:
485
737
  futures.append(executor.submit(self._process_join_candidate_pair, left_candidate, right_candidate, gen_kwargs, embedding_sim))
@@ -526,6 +778,10 @@ class EmbeddingJoin(JoinOp):
526
778
  self._left_input_records.extend(zip(left_candidates, left_embeddings))
527
779
  self._right_input_records.extend(zip(right_candidates, right_embeddings))
528
780
 
781
+ # if this is the final call, then add in any left/right/outer join records that did not match
782
+ if final:
783
+ return self._compute_unmatched_records(), 0
784
+
529
785
  # return empty DataRecordSet if no output records were produced
530
786
  if len(output_records) == 0:
531
787
  self.residual_embedding_cost = total_embedding_cost
@@ -9,7 +9,7 @@ from palimpzest.constants import AggFunc, Cardinality
9
9
  from palimpzest.core.data import context, dataset
10
10
  from palimpzest.core.elements.filters import Filter
11
11
  from palimpzest.core.elements.groupbysig import GroupBySig
12
- from palimpzest.core.lib.schemas import Average, Count, Max, Min
12
+ from palimpzest.core.lib.schemas import Average, Count, Max, Min, Sum
13
13
  from palimpzest.utils.hash_helpers import hash_for_id
14
14
 
15
15
 
@@ -25,7 +25,7 @@ class LogicalOperator:
25
25
  - LimitScan (scans up to N records from a Set)
26
26
  - GroupByAggregate (applies a group by on the Set)
27
27
  - Aggregate (applies an aggregation on the Set)
28
- - RetrieveScan (fetches documents from a provided input for a given query)
28
+ - TopKScan (fetches documents from a provided input for a given query)
29
29
  - Map (applies a function to each record in the Set without adding any new columns)
30
30
  - ComputeOperator (executes a computation described in natural language)
31
31
  - SearchOperator (executes a search query on the input Context)
@@ -160,6 +160,8 @@ class Aggregate(LogicalOperator):
160
160
  kwargs["output_schema"] = Count
161
161
  elif agg_func == AggFunc.AVERAGE:
162
162
  kwargs["output_schema"] = Average
163
+ elif agg_func == AggFunc.SUM:
164
+ kwargs["output_schema"] = Sum
163
165
  elif agg_func == AggFunc.MIN:
164
166
  kwargs["output_schema"] = Min
165
167
  elif agg_func == AggFunc.MAX:
@@ -411,17 +413,25 @@ class GroupByAggregate(LogicalOperator):
411
413
 
412
414
 
413
415
  class JoinOp(LogicalOperator):
414
- def __init__(self, condition: str, desc: str | None = None, *args, **kwargs):
416
+ def __init__(self, condition: str, on: list[str] | None = None, how: str = "inner", desc: str | None = None, *args, **kwargs):
415
417
  super().__init__(*args, **kwargs)
416
418
  self.condition = condition
419
+ self.on = on
420
+ self.how = how
417
421
  self.desc = desc
418
422
 
419
423
  def __str__(self):
420
- return f"Join(condition={self.condition})"
424
+ return f"Join(condition={self.condition})" if self.on is None else f"Join(on={self.on}, how={self.how})"
421
425
 
422
426
  def get_logical_id_params(self) -> dict:
423
427
  logical_id_params = super().get_logical_id_params()
424
- logical_id_params = {"condition": self.condition, "desc": self.desc, **logical_id_params}
428
+ logical_id_params = {
429
+ "condition": self.condition,
430
+ "on": self.on,
431
+ "how": self.how,
432
+ "desc": self.desc,
433
+ **logical_id_params,
434
+ }
425
435
 
426
436
  return logical_id_params
427
437
 
@@ -429,6 +439,8 @@ class JoinOp(LogicalOperator):
429
439
  logical_op_params = super().get_logical_op_params()
430
440
  logical_op_params = {
431
441
  "condition": self.condition,
442
+ "on": self.on,
443
+ "how": self.how,
432
444
  "desc": self.desc,
433
445
  **logical_op_params,
434
446
  }
@@ -484,8 +496,8 @@ class Project(LogicalOperator):
484
496
  return logical_op_params
485
497
 
486
498
 
487
- class RetrieveScan(LogicalOperator):
488
- """A RetrieveScan is a logical operator that represents a scan of a particular input Dataset, with a convert-like retrieve applied."""
499
+ class TopKScan(LogicalOperator):
500
+ """A TopKScan is a logical operator that represents a scan of a particular input Dataset, with a top-k operation applied."""
489
501
 
490
502
  def __init__(
491
503
  self,
@@ -505,7 +517,7 @@ class RetrieveScan(LogicalOperator):
505
517
  self.k = k
506
518
 
507
519
  def __str__(self):
508
- return f"RetrieveScan({self.input_schema} -> {str(self.output_schema)})"
520
+ return f"TopKScan({self.input_schema} -> {str(self.output_schema)})"
509
521
 
510
522
  def get_logical_id_params(self) -> dict:
511
523
  # NOTE: if we allow optimization over index, then we will need to include it in the id params
@@ -75,8 +75,9 @@ class MixtureOfAgentsConvert(LLMConvert):
75
75
  In practice, this naive quality estimate will be overwritten by the CostModel's estimate
76
76
  once it executes a few instances of the operator.
77
77
  """
78
- # temporarily set self.model so that super().naive_cost_estimates(...) can compute an estimate
78
+ # temporarily set self.model and self.prompt_strategy so that super().naive_cost_estimates(...) can compute an estimate
79
79
  self.model = self.proposer_models[0]
80
+ self.prompt_strategy = PromptStrategy.MAP_MOA_PROPOSER
80
81
 
81
82
  # get naive cost estimates for single LLM call and scale it by number of LLMs used in MoA
82
83
  naive_op_cost_estimates = super().naive_cost_estimates(source_op_cost_estimates)
@@ -98,6 +99,7 @@ class MixtureOfAgentsConvert(LLMConvert):
98
99
 
99
100
  # reset self.model to be None
100
101
  self.model = None
102
+ self.prompt_strategy = None
101
103
 
102
104
  return naive_op_cost_estimates
103
105
 
@@ -42,10 +42,13 @@ class PhysicalOperator:
42
42
  self.op_id = None
43
43
 
44
44
  # compute the input modalities (if any) for this physical operator
45
+ depends_on_short_field_names = [field.split(".")[-1] for field in self.depends_on] if self.depends_on is not None else None
45
46
  self.input_modalities = None
46
47
  if self.input_schema is not None:
47
48
  self.input_modalities = set()
48
- for field in self.input_schema.model_fields.values():
49
+ for field_name, field in self.input_schema.model_fields.items():
50
+ if self.depends_on is not None and field_name not in depends_on_short_field_names:
51
+ continue
49
52
  field_type = field.annotation
50
53
  if field_type in IMAGE_FIELD_TYPES:
51
54
  self.input_modalities.add(Modality.IMAGE)
@@ -191,7 +194,7 @@ class PhysicalOperator:
191
194
  in the candidate. This is important for operators with retry logic, where we may only need to
192
195
  recompute a subset of self.generated_fields.
193
196
 
194
- Right now this is only used by convert and retrieve operators.
197
+ Right now this is only used by convert and top-k operators.
195
198
  """
196
199
  fields_to_generate = [
197
200
  field_name
@@ -1,6 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import time
4
+ from typing import Any
4
5
 
5
6
  from numpy import dot
6
7
  from numpy.linalg import norm
@@ -153,8 +154,8 @@ class RAGConvert(LLMConvert):
153
154
  field = candidate.get_field_type(field_name)
154
155
 
155
156
  # skip this field if it is not a string or a list of strings
156
- is_string_field = field.annotation in [str, str | None]
157
- is_list_string_field = field.annotation in [list[str], list[str] | None]
157
+ is_string_field = field.annotation in [str, str | None, str | Any]
158
+ is_list_string_field = field.annotation in [list[str], list[str] | None, list[str] | Any]
158
159
  if not (is_string_field or is_list_string_field):
159
160
  continue
160
161
 
@@ -358,8 +359,8 @@ class RAGFilter(LLMFilter):
358
359
  field = candidate.get_field_type(field_name)
359
360
 
360
361
  # skip this field if it is not a string or a list of strings
361
- is_string_field = field.annotation in [str, str | None]
362
- is_list_string_field = field.annotation in [list[str], list[str] | None]
362
+ is_string_field = field.annotation in [str, str | None, str | Any]
363
+ is_list_string_field = field.annotation in [list[str], list[str] | None, list[str] | Any]
363
364
  if not (is_string_field or is_list_string_field):
364
365
  continue
365
366