palimpzest 0.7.7__py3-none-any.whl → 0.7.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. palimpzest/constants.py +113 -75
  2. palimpzest/core/data/dataclasses.py +55 -38
  3. palimpzest/core/elements/index.py +5 -15
  4. palimpzest/core/elements/records.py +1 -1
  5. palimpzest/prompts/prompt_factory.py +1 -1
  6. palimpzest/query/execution/all_sample_execution_strategy.py +216 -0
  7. palimpzest/query/execution/execution_strategy.py +4 -4
  8. palimpzest/query/execution/execution_strategy_type.py +7 -1
  9. palimpzest/query/execution/mab_execution_strategy.py +184 -72
  10. palimpzest/query/execution/parallel_execution_strategy.py +182 -15
  11. palimpzest/query/execution/single_threaded_execution_strategy.py +21 -21
  12. palimpzest/query/generators/api_client_factory.py +6 -7
  13. palimpzest/query/generators/generators.py +5 -8
  14. palimpzest/query/operators/aggregate.py +4 -3
  15. palimpzest/query/operators/convert.py +1 -1
  16. palimpzest/query/operators/filter.py +1 -1
  17. palimpzest/query/operators/limit.py +1 -1
  18. palimpzest/query/operators/map.py +1 -1
  19. palimpzest/query/operators/physical.py +8 -4
  20. palimpzest/query/operators/project.py +1 -1
  21. palimpzest/query/operators/retrieve.py +7 -23
  22. palimpzest/query/operators/scan.py +1 -1
  23. palimpzest/query/optimizer/cost_model.py +54 -62
  24. palimpzest/query/optimizer/optimizer.py +2 -6
  25. palimpzest/query/optimizer/plan.py +4 -4
  26. palimpzest/query/optimizer/primitives.py +1 -1
  27. palimpzest/query/optimizer/rules.py +8 -26
  28. palimpzest/query/optimizer/tasks.py +3 -3
  29. palimpzest/query/processor/processing_strategy_type.py +2 -2
  30. palimpzest/query/processor/sentinel_processor.py +0 -2
  31. palimpzest/sets.py +2 -3
  32. palimpzest/utils/generation_helpers.py +1 -1
  33. palimpzest/utils/model_helpers.py +27 -9
  34. palimpzest/utils/progress.py +81 -72
  35. {palimpzest-0.7.7.dist-info → palimpzest-0.7.9.dist-info}/METADATA +4 -2
  36. {palimpzest-0.7.7.dist-info → palimpzest-0.7.9.dist-info}/RECORD +39 -38
  37. {palimpzest-0.7.7.dist-info → palimpzest-0.7.9.dist-info}/WHEEL +1 -1
  38. {palimpzest-0.7.7.dist-info → palimpzest-0.7.9.dist-info}/licenses/LICENSE +0 -0
  39. {palimpzest-0.7.7.dist-info → palimpzest-0.7.9.dist-info}/top_level.txt +0 -0
@@ -45,8 +45,8 @@ class ParallelExecutionStrategy(ExecutionStrategy):
     def _upstream_ops_finished(self, plan: PhysicalPlan, op_idx: int, input_queues: dict[str, list], future_queues: dict[str, list]) -> bool:
         """Helper function to check if all upstream operators have finished processing their inputs."""
         for upstream_op_idx in range(op_idx):
-            upstream_op_id = plan.operators[upstream_op_idx].get_op_id()
-            if len(input_queues[upstream_op_id]) > 0 or len(future_queues[upstream_op_id]) > 0:
+            upstream_full_op_id = plan.operators[upstream_op_idx].get_full_op_id()
+            if len(input_queues[upstream_full_op_id]) > 0 or len(future_queues[upstream_full_op_id]) > 0:
                 return False

         return True
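
The change that repeats throughout this release is keying per-operator state on get_full_op_id() instead of get_op_id(). A minimal sketch of the failure mode this guards against, assuming (hypothetically) that two logical operators can yield physical operators with the same op_id:

    # Hypothetical ids: two physical operators share an op_id but belong
    # to different logical operators.
    op_ids = ["9f2c", "9f2c"]
    logical_op_ids = ["filter-001", "filter-002"]

    queues_by_op_id = {op_id: [] for op_id in op_ids}
    assert len(queues_by_op_id) == 1  # collision: two operators, one queue

    full_op_ids = [f"{lid}-{oid}" for lid, oid in zip(logical_op_ids, op_ids)]
    queues_by_full_op_id = {fid: [] for fid in full_op_ids}
    assert len(queues_by_full_op_id) == 2  # one queue per operator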
@@ -57,14 +57,14 @@ class ParallelExecutionStrategy(ExecutionStrategy):
         the updates to plan stats and progress manager before returning the results from the finished futures.
         """
         # get the op_id for the operator
-        op_id = operator.get_op_id()
+        full_op_id = operator.get_full_op_id()

         # this function is called when the future queue is not empty
         # and the executor is not busy processing other futures
-        done_futures, not_done_futures = wait(future_queues[op_id], timeout=PARALLEL_EXECUTION_SLEEP_INTERVAL_SECS)
+        done_futures, not_done_futures = wait(future_queues[full_op_id], timeout=PARALLEL_EXECUTION_SLEEP_INTERVAL_SECS)

         # add the unfinished futures back to the previous op's future queue
-        future_queues[op_id] = list(not_done_futures)
+        future_queues[full_op_id] = list(not_done_futures)

         # add the finished futures to the input queue for this operator
         output_records = []
@@ -75,7 +75,7 @@ class ParallelExecutionStrategy(ExecutionStrategy):
             num_outputs = sum(record.passed_operator for record in records)

             # update the progress manager
-            self.progress_manager.incr(op_id, num_outputs=num_outputs, total_cost=record_set.get_total_cost())
+            self.progress_manager.incr(full_op_id, num_outputs=num_outputs, total_cost=record_set.get_total_cost())

             # update plan stats
             plan_stats.add_record_op_stats(record_op_stats)
@@ -106,35 +106,35 @@ class ParallelExecutionStrategy(ExecutionStrategy):
             final_op = plan.operators[-1]
             while self._any_queue_not_empty(input_queues) or self._any_queue_not_empty(future_queues):
                 for op_idx, operator in enumerate(plan.operators):
-                    op_id = operator.get_op_id()
+                    full_op_id = operator.get_full_op_id()

                     # get any finished futures from the previous operator and add them to the input queue for this operator
                     if not isinstance(operator, ScanPhysicalOp):
                         prev_operator = plan.operators[op_idx - 1]
                         records = self._process_future_results(prev_operator, future_queues, plan_stats)
-                        input_queues[op_id].extend(records)
+                        input_queues[full_op_id].extend(records)

                     # for the final operator, add any finished futures to the output_records
-                    if operator.get_op_id() == final_op.get_op_id():
+                    if full_op_id == final_op.get_full_op_id():
                         records = self._process_future_results(operator, future_queues, plan_stats)
                         output_records.extend(records)

                     # if this operator does not have enough inputs to execute, then skip it
-                    num_inputs = len(input_queues[op_id])
+                    num_inputs = len(input_queues[full_op_id])
                     agg_op_not_ready = isinstance(operator, AggregateOp) and not self._upstream_ops_finished(plan, op_idx, input_queues, future_queues)
                     if num_inputs == 0 or agg_op_not_ready:
                         continue

                     # if this operator is an aggregate, process all the records in the input queue
                     if isinstance(operator, AggregateOp):
-                        input_records = [input_queues[op_id].pop(0) for _ in range(num_inputs)]
+                        input_records = [input_queues[full_op_id].pop(0) for _ in range(num_inputs)]
                         future = executor.submit(operator, input_records)
-                        future_queues[op_id].append(future)
+                        future_queues[full_op_id].append(future)

                     else:
-                        input_record = input_queues[op_id].pop(0)
+                        input_record = input_queues[full_op_id].pop(0)
                         future = executor.submit(operator, input_record)
-                        future_queues[op_id].append(future)
+                        future_queues[full_op_id].append(future)

                 # break out of loop if the final operator is a LimitScanOp and we've reached its limit
                 if isinstance(final_op, LimitScanOp) and len(output_records) == final_op.limit:
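
The scheduling loop above polls each operator's future queue with concurrent.futures.wait, which returns a (done, not_done) pair after the timeout; unfinished futures are re-queued and finished ones are drained. A minimal standalone sketch of that polling idiom, independent of palimpzest:

    # Minimal sketch of the wait()-based polling used by the strategy.
    from concurrent.futures import ThreadPoolExecutor, wait
    import time

    def work(x: int) -> int:
        time.sleep(0.05 * x)
        return x * x

    with ThreadPoolExecutor(max_workers=4) as executor:
        pending = [executor.submit(work, x) for x in range(4)]
        results = []
        while pending:
            done, not_done = wait(pending, timeout=0.1)
            pending = list(not_done)  # re-queue unfinished futures
            results.extend(f.result() for f in done)

    print(sorted(results))  # [0, 1, 4, 9]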
@@ -162,7 +162,174 @@ class ParallelExecutionStrategy(ExecutionStrategy):

         # initialize input queues and future queues for each operation
         input_queues = self._create_input_queues(plan)
-        future_queues = {op.get_op_id(): [] for op in plan.operators}
+        future_queues = {op.get_full_op_id(): [] for op in plan.operators}
+
+        # initialize and start the progress manager
+        self.progress_manager = create_progress_manager(plan, num_samples=self.num_samples, progress=self.progress)
+        self.progress_manager.start()
+
+        # NOTE: we must handle progress manager outside of _execute_plan to ensure that it is shut down correctly;
+        # if we don't have the `finally:` branch, then program crashes can cause future program runs to fail
+        # because the progress manager cannot get a handle to the console
+        try:
+            # execute plan
+            output_records, plan_stats = self._execute_plan(plan, input_queues, future_queues, plan_stats)
+
+        finally:
+            # finish progress tracking
+            self.progress_manager.finish()
+
+        logger.info(f"Done executing plan: {plan.plan_id}")
+        logger.debug(f"Plan stats: (plan_cost={plan_stats.total_plan_cost}, plan_time={plan_stats.total_plan_time})")
+
+        return output_records, plan_stats
+
+
+class SequentialParallelExecutionStrategy(ExecutionStrategy):
+    """
+    A parallel execution strategy that processes operators sequentially.
+    """
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.max_workers = (
+            self._get_parallel_max_workers()
+            if self.max_workers is None
+            else self.max_workers
+        )
+
+    def _get_parallel_max_workers(self):
+        # for now, return the number of system CPUs;
+        # in the future, we may want to consider the models the user has access to
+        # and whether or not they will encounter rate-limits. If they will, we should
+        # set the max workers in a manner that is designed to avoid hitting them.
+        # Doing this "right" may require considering their logical, physical plan,
+        # and tier status with LLM providers. It may also be worth dynamically
+        # changing the max_workers in response to 429 errors.
+        return max(int(0.8 * multiprocessing.cpu_count()), 1)
+
+    def _any_queue_not_empty(self, queues: dict[str, list]) -> bool:
+        """Helper function to check if any queue is not empty."""
+        return any(len(queue) > 0 for queue in queues.values())
+
+    def _upstream_ops_finished(self, plan: PhysicalPlan, op_idx: int, input_queues: dict[str, list], future_queues: dict[str, list]) -> bool:
+        """Helper function to check if all upstream operators have finished processing their inputs."""
+        for upstream_op_idx in range(op_idx):
+            upstream_full_op_id = plan.operators[upstream_op_idx].get_full_op_id()
+            if len(input_queues[upstream_full_op_id]) > 0 or len(future_queues[upstream_full_op_id]) > 0:
+                return False
+
+        return True
+
+    def _process_future_results(self, operator: PhysicalOperator, future_queues: dict[str, list], plan_stats: PlanStats) -> list[DataRecord]:
+        """
+        Helper function which takes an operator, the future queues, and plan stats, and performs
+        the updates to plan stats and progress manager before returning the results from the finished futures.
+        """
+        # get the op_id for the operator
+        full_op_id = operator.get_full_op_id()
+
+        # this function is called when the future queue is not empty
+        # and the executor is not busy processing other futures
+        done_futures, not_done_futures = wait(future_queues[full_op_id], timeout=PARALLEL_EXECUTION_SLEEP_INTERVAL_SECS)
+
+        # add the unfinished futures back to the previous op's future queue
+        future_queues[full_op_id] = list(not_done_futures)
+
+        # add the finished futures to the input queue for this operator
+        output_records = []
+        for future in done_futures:
+            record_set: DataRecordSet = future.result()
+            records = record_set.data_records
+            record_op_stats = record_set.record_op_stats
+            num_outputs = sum(record.passed_operator for record in records)
+
+            # update the progress manager
+            self.progress_manager.incr(full_op_id, num_outputs=num_outputs, total_cost=record_set.get_total_cost())
+
+            # update plan stats
+            plan_stats.add_record_op_stats(record_op_stats)
+
+            # add records to the cache
+            self._add_records_to_cache(operator.target_cache_id, records)
+
+            # add records which aren't filtered to the output records
+            output_records.extend([record for record in records if record.passed_operator])
+
+        return output_records
+
+    def _execute_plan(
+        self,
+        plan: PhysicalPlan,
+        input_queues: dict[str, list],
+        future_queues: dict[str, list],
+        plan_stats: PlanStats,
+    ) -> tuple[list[DataRecord], PlanStats]:
+        # process all of the input records using a thread pool
+        output_records = []
+        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
+            logger.debug(f"Created thread pool with {self.max_workers} workers")
+
+            # execute the plan until either:
+            # 1. all records have been processed, or
+            # 2. the final limit operation has completed (we break out of the loop if this happens)
+            final_op = plan.operators[-1]
+            for op_idx, operator in enumerate(plan.operators):
+                full_op_id = operator.get_full_op_id()
+                input_queue = input_queues[full_op_id]
+
+                # if this operator is an aggregate, process all the records in the input queue
+                if isinstance(operator, AggregateOp):
+                    num_inputs = len(input_queue)
+                    input_records = [input_queue.pop(0) for _ in range(num_inputs)]
+                    future = executor.submit(operator, input_records)
+                    future_queues[full_op_id].append(future)
+
+                else:
+                    while len(input_queue) > 0:
+                        input_record = input_queue.pop(0)
+                        future = executor.submit(operator, input_record)
+                        future_queues[full_op_id].append(future)
+
+                # block until all futures for this operator have completed; and add finished futures to next operator's input
+                while len(future_queues[full_op_id]) > 0:
+                    records = self._process_future_results(operator, future_queues, plan_stats)
+
+                    # get any finished futures from the previous operator and add them to the input queue for this operator
+                    if full_op_id != final_op.get_full_op_id():
+                        next_op_id = plan.operators[op_idx + 1].get_full_op_id()
+                        input_queues[next_op_id].extend(records)
+
+                    # for the final operator, add any finished futures to the output_records
+                    else:
+                        output_records.extend(records)
+
+                # break out of loop if the final operator is a LimitScanOp and we've reached its limit
+                if isinstance(final_op, LimitScanOp) and len(output_records) == final_op.limit:
+                    break
+
+        # close the cache
+        self._close_cache([op.target_cache_id for op in plan.operators])
+
+        # finalize plan stats
+        plan_stats.finish()
+
+        return output_records, plan_stats
+
+
+    def execute_plan(self, plan: PhysicalPlan):
+        """Initialize the stats and execute the plan."""
+        # for now, assert that the first operator in the plan is a ScanPhysicalOp
+        assert isinstance(plan.operators[0], ScanPhysicalOp), "First operator in physical plan must be a ScanPhysicalOp"
+        logger.info(f"Executing plan {plan.plan_id} with {self.max_workers} workers")
+        logger.info(f"Plan Details: {plan}")
+
+        # initialize plan stats
+        plan_stats = PlanStats.from_plan(plan)
+        plan_stats.start()
+
+        # initialize input queues and future queues for each operation
+        input_queues = self._create_input_queues(plan)
+        future_queues = {op.get_full_op_id(): [] for op in plan.operators}

         # initialize and start the progress manager
         self.progress_manager = create_progress_manager(plan, num_samples=self.num_samples, progress=self.progress)
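
The NOTE in the added code is the interesting part: the progress manager is started before, and shut down after, _execute_plan, with the shutdown in a finally: block so that a crash cannot leave the console handle held. A minimal standalone sketch of the same lifecycle, using a hypothetical stand-in class rather than palimpzest's progress manager:

    # Sketch of the shutdown pattern; ProgressManager is a hypothetical stand-in.
    class ProgressManager:
        def start(self):
            print("progress started")

        def finish(self):
            print("progress finished")  # must run even if execution crashes

    def run_plan(fail: bool):
        pm = ProgressManager()
        pm.start()
        try:
            if fail:
                raise RuntimeError("plan crashed")
            return "ok"
        finally:
            pm.finish()  # releases the console handle in all cases

    run_plan(fail=False)
    try:
        run_plan(fail=True)
    except RuntimeError:
        pass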
@@ -30,35 +30,35 @@ class SequentialSingleThreadExecutionStrategy(ExecutionStrategy):
         output_records = []
         for op_idx, operator in enumerate(plan.operators):
             # if we've filtered out all records, terminate early
-            op_id = operator.get_op_id()
-            num_inputs = len(input_queues[op_id])
+            full_op_id = operator.get_full_op_id()
+            num_inputs = len(input_queues[full_op_id])
             if num_inputs == 0:
                 break

             # begin to process this operator
             records, record_op_stats = [], []
-            logger.info(f"Processing operator {operator.op_name()} ({op_id})")
+            logger.info(f"Processing operator {operator.op_name()} ({full_op_id})")

             # if this operator is an aggregate, process all the records in the input_queue
             if isinstance(operator, AggregateOp):
-                record_set = operator(candidates=input_queues[op_id])
+                record_set = operator(candidates=input_queues[full_op_id])
                 records = record_set.data_records
                 record_op_stats = record_set.record_op_stats
                 num_outputs = sum(record.passed_operator for record in records)

                 # update the progress manager
-                self.progress_manager.incr(op_id, num_outputs=num_outputs, total_cost=record_set.get_total_cost())
+                self.progress_manager.incr(full_op_id, num_outputs=num_outputs, total_cost=record_set.get_total_cost())

             # otherwise, process the records in the input queue for this operator one at a time
             else:
-                for input_record in input_queues[op_id]:
+                for input_record in input_queues[full_op_id]:
                     record_set = operator(input_record)
                     records.extend(record_set.data_records)
                     record_op_stats.extend(record_set.record_op_stats)
                     num_outputs = sum(record.passed_operator for record in record_set.data_records)

                     # update the progress manager
-                    self.progress_manager.incr(op_id, num_outputs=num_outputs, total_cost=record_set.get_total_cost())
+                    self.progress_manager.incr(full_op_id, num_outputs=num_outputs, total_cost=record_set.get_total_cost())

                     # finish early if this is a limit
                     if isinstance(operator, LimitScanOp) and len(records) == operator.limit:
@@ -73,10 +73,10 @@ class SequentialSingleThreadExecutionStrategy(ExecutionStrategy):
             # update next input_queue (if it exists)
             output_records = [record for record in records if record.passed_operator]
             if op_idx + 1 < len(plan.operators):
-                next_op_id = plan.operators[op_idx + 1].get_op_id()
-                input_queues[next_op_id] = output_records
+                next_full_op_id = plan.operators[op_idx + 1].get_full_op_id()
+                input_queues[next_full_op_id] = output_records

-            logger.info(f"Finished processing operator {operator.op_name()} ({operator.get_op_id()}), and generated {len(records)} records")
+            logger.info(f"Finished processing operator {operator.op_name()} ({operator.get_full_op_id()}), and generated {len(records)} records")

         # close the cache
         self._close_cache([op.target_cache_id for op in plan.operators])
@@ -146,8 +146,8 @@ class PipelinedSingleThreadExecutionStrategy(ExecutionStrategy):
     def _upstream_ops_finished(self, plan: PhysicalPlan, op_idx: int, input_queues: dict[str, list]) -> bool:
         """Helper function to check if all upstream operators have finished processing their inputs."""
         for upstream_op_idx in range(op_idx):
-            upstream_op_id = plan.operators[upstream_op_idx].get_op_id()
-            if len(input_queues[upstream_op_id]) > 0:
+            upstream_full_op_id = plan.operators[upstream_op_idx].get_full_op_id()
+            if len(input_queues[upstream_full_op_id]) > 0:
                 return False

         return True
@@ -160,8 +160,8 @@ class PipelinedSingleThreadExecutionStrategy(ExecutionStrategy):
         while self._any_queue_not_empty(input_queues):
             for op_idx, operator in enumerate(plan.operators):
                 # if this operator does not have enough inputs to execute, then skip it
-                op_id = operator.get_op_id()
-                num_inputs = len(input_queues[op_id])
+                full_op_id = operator.get_full_op_id()
+                num_inputs = len(input_queues[full_op_id])
                 agg_op_not_ready = isinstance(operator, AggregateOp) and not self._upstream_ops_finished(plan, op_idx, input_queues)
                 if num_inputs == 0 or agg_op_not_ready:
                     continue
@@ -171,25 +171,25 @@ class PipelinedSingleThreadExecutionStrategy(ExecutionStrategy):

                 # if the next operator is an aggregate, process all the records in the input_queue
                 if isinstance(operator, AggregateOp):
-                    input_records = [input_queues[op_id].pop(0) for _ in range(num_inputs)]
+                    input_records = [input_queues[full_op_id].pop(0) for _ in range(num_inputs)]
                     record_set = operator(candidates=input_records)
                     records = record_set.data_records
                     record_op_stats = record_set.record_op_stats
                     num_outputs = sum(record.passed_operator for record in records)

                     # update the progress manager
-                    self.progress_manager.incr(op_id, num_outputs=num_outputs, total_cost=record_set.get_total_cost())
+                    self.progress_manager.incr(full_op_id, num_outputs=num_outputs, total_cost=record_set.get_total_cost())

                 # otherwise, process the next record in the input queue for this operator
                 else:
-                    input_record = input_queues[op_id].pop(0)
+                    input_record = input_queues[full_op_id].pop(0)
                     record_set = operator(input_record)
                     records = record_set.data_records
                     record_op_stats = record_set.record_op_stats
                     num_outputs = sum(record.passed_operator for record in records)

                     # update the progress manager
-                    self.progress_manager.incr(op_id, num_outputs=num_outputs, total_cost=record_set.get_total_cost())
+                    self.progress_manager.incr(full_op_id, num_outputs=num_outputs, total_cost=record_set.get_total_cost())

                 # update plan stats
                 plan_stats.add_record_op_stats(record_op_stats)
@@ -200,12 +200,12 @@ class PipelinedSingleThreadExecutionStrategy(ExecutionStrategy):
                 # update next input_queue or final_output_records
                 output_records = [record for record in records if record.passed_operator]
                 if op_idx + 1 < len(plan.operators):
-                    next_op_id = plan.operators[op_idx + 1].get_op_id()
-                    input_queues[next_op_id].extend(output_records)
+                    next_full_op_id = plan.operators[op_idx + 1].get_full_op_id()
+                    input_queues[next_full_op_id].extend(output_records)
                 else:
                     final_output_records.extend(output_records)

-                logger.info(f"Finished processing operator {operator.op_name()} ({operator.get_op_id()}) on {num_inputs} records")
+                logger.info(f"Finished processing operator {operator.op_name()} ({operator.get_full_op_id()}) on {num_inputs} records")

             # break out of loop if the final operator is a LimitScanOp and we've reached its limit
             if isinstance(plan.operators[-1], LimitScanOp) and len(final_output_records) == plan.operators[-1].limit:
@@ -22,10 +22,9 @@ class APIClientFactory:
     @staticmethod
     def _create_client(api_client: APIClient, api_key: str):
         """Create a new client instance based on the api_client name."""
-        match api_client:
-            case APIClient.OPENAI:
-                return OpenAI(api_key=api_key)
-            case APIClient.TOGETHER:
-                return Together(api_key=api_key)
-            case _:
-                raise ValueError(f"Unknown api_client: {api_client}")
+        if api_client == APIClient.OPENAI:
+            return OpenAI(api_key=api_key)
+        elif api_client == APIClient.TOGETHER:
+            return Together(api_key=api_key)
+        else:
+            raise ValueError(f"Unknown api_client: {api_client}")
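
This rewrite of _create_client is behavior-preserving; match/case requires Python 3.10 or newer, so the if/elif chain presumably widens the range of interpreters the wheel supports. A minimal sketch of the equivalence, using a hypothetical stand-in enum and string "clients" rather than the real OpenAI/Together classes:

    # Dispatch without `match`, runnable on Python < 3.10.
    # APIClient here is a hypothetical stand-in for palimpzest's enum.
    from enum import Enum

    class APIClient(Enum):
        OPENAI = "openai"
        TOGETHER = "together"

    def create_client(api_client: APIClient, api_key: str) -> str:
        if api_client == APIClient.OPENAI:
            return f"OpenAI(api_key={api_key!r})"
        elif api_client == APIClient.TOGETHER:
            return f"Together(api_key={api_key!r})"
        else:
            raise ValueError(f"Unknown api_client: {api_client}")

    assert create_client(APIClient.OPENAI, "sk-test").startswith("OpenAI")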
@@ -49,10 +49,10 @@ def generator_factory(
     """
     Factory function to return the correct generator based on the model, strategy, and cardinality.
     """
-    if model in [Model.GPT_4o, Model.GPT_4o_MINI, Model.GPT_4o_V, Model.GPT_4o_MINI_V]:
+    if model.is_openai_model():
         return OpenAIGenerator(model, prompt_strategy, cardinality, verbose)

-    elif model in [Model.MIXTRAL, Model.LLAMA3, Model.LLAMA3_V, Model.DEEPSEEK]:
+    elif model.is_together_model():
         return TogetherGenerator(model, prompt_strategy, cardinality, verbose)

     raise Exception(f"Unsupported model: {model}")
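
Several hunks in generators.py replace hard-coded model lists with predicate methods on the Model enum. Those methods are not shown in this diff (they live in palimpzest/constants.py, which changed by +113/-75 above); a hedged sketch of how such family predicates can be attached to a Python enum, with made-up members:

    # Hypothetical sketch only; the real Model enum may differ.
    from enum import Enum

    class Model(Enum):
        GPT_4o = "gpt-4o"
        GPT_4o_MINI = "gpt-4o-mini"
        MIXTRAL = "mistralai/Mixtral-8x7B-Instruct-v0.1"
        LLAMA3 = "meta-llama/Meta-Llama-3-8B-Instruct"

        def is_openai_model(self) -> bool:
            return self in (Model.GPT_4o, Model.GPT_4o_MINI)

        def is_together_model(self) -> bool:
            return self in (Model.MIXTRAL, Model.LLAMA3)

    assert Model.GPT_4o.is_openai_model() and not Model.GPT_4o.is_together_model()

Centralizing the family test means a newly supported model is registered once in the enum instead of being appended to every call-site list.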
@@ -61,8 +61,6 @@ def generator_factory(
 def get_api_key(key: str) -> str:
     # get API key from environment or throw an exception if it's not set
     if key not in os.environ:
-        print(f"KEY: {key}")
-        print(f"{os.environ.keys()}")
         raise ValueError("key not found in environment variables")

     return os.environ[key]
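
The two removed print statements wrote the missing key name and the full list of environment variable names to stdout on every failed lookup; dropping them leaves the ValueError to report the failure without echoing environment details into logs.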
@@ -464,7 +462,7 @@ class OpenAIGenerator(BaseGenerator[str | list[str], str]):
         verbose: bool = False,
     ):
         # assert that model is an OpenAI model
-        assert model in [Model.GPT_4o, Model.GPT_4o_MINI, Model.GPT_4o_V, Model.GPT_4o_MINI_V]
+        assert model.is_openai_model()
         super().__init__(model, prompt_strategy, cardinality, verbose, "developer")

     def _get_client_or_model(self, **kwargs) -> OpenAI:
@@ -508,7 +506,7 @@ class TogetherGenerator(BaseGenerator[str | list[str], str]):
         verbose: bool = False,
     ):
         # assert that model is a model offered by Together
-        assert model in [Model.MIXTRAL, Model.LLAMA3, Model.LLAMA3_V, Model.DEEPSEEK]
+        assert model.is_together_model()
         super().__init__(model, prompt_strategy, cardinality, verbose, "system")

     def _generate_payload(self, messages: list[dict], **kwargs) -> dict:
@@ -525,7 +523,7 @@ class TogetherGenerator(BaseGenerator[str | list[str], str]):
         For LLAMA3, the payload needs to be in a {"role": <role>, "content": <content>} format.
         """
         # for other models, use our standard payload generation
-        if self.model != Model.LLAMA3:
+        if not self.model.is_llama_model():
             return super()._generate_payload(messages, **kwargs)

         # get basic parameters
@@ -593,7 +591,6 @@ def code_ensemble_execution(
         preds.append(pred)

     preds = [pred for pred in preds if pred is not None]
-    print(preds)

     if len(preds) == 1:
         majority_response = preds[0]
@@ -138,7 +138,7 @@ class ApplyGroupByOp(AggregateOp):
             record_parent_id=dr.parent_id,
             record_source_idx=dr.source_idx,
             record_state=dr.to_dict(include_bytes=False),
-            op_id=self.get_op_id(),
+            full_op_id=self.get_full_op_id(),
             logical_op_id=self.logical_op_id,
             op_name=self.op_name(),
             time_per_record=total_time / len(drs),
@@ -198,7 +198,8 @@ class AverageAggregateOp(AggregateOp):
             record_parent_id=dr.parent_id,
             record_source_idx=dr.source_idx,
             record_state=dr.to_dict(include_bytes=False),
-            op_id=self.get_op_id(),
+            full_op_id=self.get_full_op_id(),
+            logical_op_id=self.logical_op_id,
             op_name=self.op_name(),
             time_per_record=time.time() - start_time,
             cost_per_record=0.0,
@@ -251,7 +252,7 @@ class CountAggregateOp(AggregateOp):
             record_parent_id=dr.parent_id,
             record_source_idx=dr.source_idx,
             record_state=dr.to_dict(include_bytes=False),
-            op_id=self.get_op_id(),
+            full_op_id=self.get_full_op_id(),
             logical_op_id=self.logical_op_id,
             op_name=self.op_name(),
             time_per_record=time.time() - start_time,
@@ -115,7 +115,7 @@ class ConvertOp(PhysicalOperator, ABC):
             record_parent_id=dr.parent_id,
             record_source_idx=dr.source_idx,
             record_state=dr.to_dict(include_bytes=False),
-            op_id=self.get_op_id(),
+            full_op_id=self.get_full_op_id(),
             logical_op_id=self.logical_op_id,
             op_name=self.op_name(),
             time_per_record=time_per_record,
@@ -84,7 +84,7 @@ class FilterOp(PhysicalOperator, ABC):
             record_parent_id=dr.parent_id,
             record_source_idx=dr.source_idx,
             record_state=dr.to_dict(include_bytes=False),
-            op_id=self.get_op_id(),
+            full_op_id=self.get_full_op_id(),
             logical_op_id=self.logical_op_id,
             op_name=self.op_name(),
             time_per_record=total_time,
@@ -44,7 +44,7 @@ class LimitScanOp(PhysicalOperator):
             record_parent_id=dr.parent_id,
             record_source_idx=dr.source_idx,
             record_state=dr.to_dict(include_bytes=False),
-            op_id=self.get_op_id(),
+            full_op_id=self.get_full_op_id(),
             logical_op_id=self.logical_op_id,
             op_name=self.op_name(),
             time_per_record=0.0,
@@ -47,7 +47,7 @@ class MapOp(PhysicalOperator):
             record_parent_id=record.parent_id,
             record_source_idx=record.source_idx,
             record_state=record.to_dict(include_bytes=False),
-            op_id=self.get_op_id(),
+            full_op_id=self.get_full_op_id(),
             logical_op_id=self.logical_op_id,
             op_name=self.op_name(),
             time_per_record=total_time,
@@ -58,8 +58,8 @@ class PhysicalOperator:
         return op

     def __eq__(self, other) -> bool:
-        all_id_params_match = all(value == getattr(other, key) for key, value in self.get_id_params().items())
-        return isinstance(other, self.__class__) and all_id_params_match
+        all_op_params_match = all(value == getattr(other, key) for key, value in self.get_op_params().items())
+        return isinstance(other, self.__class__) and all_op_params_match

     def copy(self) -> PhysicalOperator:
         return self.__class__(**self.get_op_params())
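
With __eq__ now driven by get_op_params() (the constructor kwargs) instead of get_id_params(), an operator and its copy() compare equal by construction. A simplified sketch of the idiom with a hypothetical operator class:

    # Simplified sketch; FilterOp here is hypothetical, not palimpzest's.
    class FilterOp:
        def __init__(self, predicate: str, model: str):
            self.predicate = predicate
            self.model = model

        def get_op_params(self) -> dict:
            return {"predicate": self.predicate, "model": self.model}

        def __eq__(self, other) -> bool:
            all_op_params_match = all(
                value == getattr(other, key, None)
                for key, value in self.get_op_params().items()
            )
            return isinstance(other, self.__class__) and all_op_params_match

        def copy(self) -> "FilterOp":
            return self.__class__(**self.get_op_params())

    op = FilterOp("is_spam", "gpt-4o-mini")
    assert op == op.copy()  # copies round-trip under the new __eq__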
@@ -79,7 +79,8 @@ class PhysicalOperator:
         This is particularly true for convert operations, where the output schema
         is now the union of the input and output schemas of the logical operator.
         """
-        return {"generated_fields": self.generated_fields}
+        # return {"generated_fields": self.generated_fields}
+        return {}

     def get_op_params(self) -> dict:
         """
@@ -129,8 +130,11 @@ class PhysicalOperator:
     def get_logical_op_id(self) -> str | None:
         return self.logical_op_id

+    def get_full_op_id(self):
+        return f"{self.get_logical_op_id()}-{self.get_op_id()}"
+
     def __hash__(self):
-        return int(self.op_id, 16)
+        return int(self.op_id, 16)  # NOTE: should we use self.get_full_op_id() instead?

     def get_model_name(self) -> str | None:
         """Returns the name of the model used by the physical operator (if it sets self.model). Otherwise, it returns None."""
@@ -42,7 +42,7 @@ class ProjectOp(PhysicalOperator):
             record_parent_id=dr.parent_id,
             record_source_idx=dr.source_idx,
             record_state=dr.to_dict(include_bytes=False),
-            op_id=self.get_op_id(),
+            full_op_id=self.get_full_op_id(),
             logical_op_id=self.logical_op_id,
             op_name=self.op_name(),
             time_per_record=0.0,
@@ -8,7 +8,6 @@ from chromadb.api.models.Collection import Collection
 from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction
 from chromadb.utils.embedding_functions.openai_embedding_function import OpenAIEmbeddingFunction
 from openai import OpenAI
-from ragatouille.RAGPretrainedModel import RAGPretrainedModel
 from sentence_transformers import SentenceTransformer

 from palimpzest.constants import MODEL_CARDS, Model
@@ -21,7 +20,7 @@ from palimpzest.query.operators.physical import PhysicalOperator
 class RetrieveOp(PhysicalOperator):
     def __init__(
         self,
-        index: Collection | RAGPretrainedModel,
+        index: Collection,
         search_attr: str,
         output_attrs: list[dict] | type[Schema],
         search_func: Callable | None,
@@ -33,7 +32,7 @@ class RetrieveOp(PhysicalOperator):
         Initialize the RetrieveOp object.

         Args:
-            index (Collection | RAGPretrainedModel): The PZ index to use for retrieval.
+            index (Collection): The PZ index to use for retrieval.
             search_attr (str): The attribute to search on.
             output_attrs (list[dict]): The output fields containing the results of the search.
             search_func (Callable | None): The function to use for searching the index. If None, the default search function will be used.
@@ -100,7 +99,7 @@ class RetrieveOp(PhysicalOperator):
             quality=1.0,
         )

-    def default_search_func(self, index: Collection | RAGPretrainedModel, query: list[str] | list[list[float]], k: int) -> list[str] | list[list[str]]:
+    def default_search_func(self, index: Collection, query: list[str] | list[list[float]], k: int) -> list[str] | list[list[str]]:
         """
         Default search function for the Retrieve operation. This function uses the index to
         retrieve the top-k results for the given query. The query will be a (possibly singleton)
@@ -132,24 +131,8 @@ class RetrieveOp(PhysicalOperator):
             # NOTE: self.output_field_names must be a singleton for default_search_func to be used
             return {self.output_field_names[0]: final_results}

-        elif isinstance(index, RAGPretrainedModel):
-            # if the index is a rag model, use the rag model to get the top k results
-            results = index.search(query, k=k)
-
-            # the results will be a list[dict]; if the input is a singleton list, however
-            # it will be a list[list[dict]]; if the input is a list of lists
-            final_results = []
-            if is_singleton_list:
-                final_results = [result["content"] for result in results]
-            else:
-                for query_results in results:
-                    final_results.append([result["content"] for result in query_results])
-
-            # NOTE: self.output_field_names must be a singleton for default_search_func to be used
-            return {self.output_field_names[0]: final_results}
-
         else:
-            raise ValueError("Unsupported index type. Must be either a Collection or RAGPretrainedModel.")
+            raise ValueError("Unsupported index type. Must be a Collection.")

     def _create_record_set(
         self,
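
With the RAGPretrainedModel branch removed, chromadb's Collection is the only index type default_search_func accepts. For orientation, a hedged sketch of the kind of top-k lookup the surviving Collection branch performs (setup elided in the diff; chromadb's query returns a dict whose "documents" entry holds one result list per query):

    # Hedged sketch of a top-k chromadb lookup; collection contents are made up.
    import chromadb

    client = chromadb.Client()  # in-memory client
    collection = client.get_or_create_collection("pz_sketch")
    collection.add(ids=["1", "2"], documents=["a cat photo", "a tax form"])

    results = collection.query(query_texts=["feline picture"], n_results=1)
    top_k_docs = results["documents"][0]  # one list of documents per query
    print(top_k_docs)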
@@ -180,7 +163,7 @@ class RetrieveOp(PhysicalOperator):
             record_parent_id=output_dr.parent_id,
             record_source_idx=output_dr.source_idx,
             record_state=record_state,
-            op_id=self.get_op_id(),
+            full_op_id=self.get_full_op_id(),
             logical_op_id=self.logical_op_id,
             op_name=self.op_name(),
             time_per_record=total_time,
@@ -231,7 +214,8 @@ class RetrieveOp(PhysicalOperator):

         model_name = self.index._embedding_function._model_name if uses_openai_embedding_fcn else "clip-ViT-B-32"
         err_msg = f"For Chromadb, we currently only support `text-embedding-3-small` and `clip-ViT-B-32`; your index uses: {model_name}"
-        assert model_name in [Model.TEXT_EMBEDDING_3_SMALL.value, Model.CLIP_VIT_B_32.value], err_msg
+        embedding_model_names = [model.value for model in Model if model.is_embedding_model()]
+        assert model_name in embedding_model_names, err_msg

         # compute embeddings
         try:
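
Deriving the allowed embedding names from the enum keeps this assertion in sync as models are added, instead of hard-coding the two current values. The is_embedding_model() predicate is not shown in this diff; a hypothetical sketch consistent with the usage above:

    # Hypothetical sketch; the real Model enum lives in palimpzest/constants.py.
    from enum import Enum

    class Model(Enum):
        TEXT_EMBEDDING_3_SMALL = "text-embedding-3-small"
        CLIP_VIT_B_32 = "clip-ViT-B-32"
        GPT_4o = "gpt-4o"

        def is_embedding_model(self) -> bool:
            return self in (Model.TEXT_EMBEDDING_3_SMALL, Model.CLIP_VIT_B_32)

    embedding_model_names = [m.value for m in Model if m.is_embedding_model()]
    assert "text-embedding-3-small" in embedding_model_names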