palimpzest 0.7.7__py3-none-any.whl → 0.7.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- palimpzest/constants.py +113 -75
- palimpzest/core/data/dataclasses.py +55 -38
- palimpzest/core/elements/index.py +5 -15
- palimpzest/core/elements/records.py +1 -1
- palimpzest/prompts/prompt_factory.py +1 -1
- palimpzest/query/execution/all_sample_execution_strategy.py +216 -0
- palimpzest/query/execution/execution_strategy.py +4 -4
- palimpzest/query/execution/execution_strategy_type.py +7 -1
- palimpzest/query/execution/mab_execution_strategy.py +184 -72
- palimpzest/query/execution/parallel_execution_strategy.py +182 -15
- palimpzest/query/execution/single_threaded_execution_strategy.py +21 -21
- palimpzest/query/generators/api_client_factory.py +6 -7
- palimpzest/query/generators/generators.py +5 -8
- palimpzest/query/operators/aggregate.py +4 -3
- palimpzest/query/operators/convert.py +1 -1
- palimpzest/query/operators/filter.py +1 -1
- palimpzest/query/operators/limit.py +1 -1
- palimpzest/query/operators/map.py +1 -1
- palimpzest/query/operators/physical.py +8 -4
- palimpzest/query/operators/project.py +1 -1
- palimpzest/query/operators/retrieve.py +7 -23
- palimpzest/query/operators/scan.py +1 -1
- palimpzest/query/optimizer/cost_model.py +54 -62
- palimpzest/query/optimizer/optimizer.py +2 -6
- palimpzest/query/optimizer/plan.py +4 -4
- palimpzest/query/optimizer/primitives.py +1 -1
- palimpzest/query/optimizer/rules.py +8 -26
- palimpzest/query/optimizer/tasks.py +3 -3
- palimpzest/query/processor/processing_strategy_type.py +2 -2
- palimpzest/query/processor/sentinel_processor.py +0 -2
- palimpzest/sets.py +2 -3
- palimpzest/utils/generation_helpers.py +1 -1
- palimpzest/utils/model_helpers.py +27 -9
- palimpzest/utils/progress.py +81 -72
- {palimpzest-0.7.7.dist-info → palimpzest-0.7.9.dist-info}/METADATA +4 -2
- {palimpzest-0.7.7.dist-info → palimpzest-0.7.9.dist-info}/RECORD +39 -38
- {palimpzest-0.7.7.dist-info → palimpzest-0.7.9.dist-info}/WHEEL +1 -1
- {palimpzest-0.7.7.dist-info → palimpzest-0.7.9.dist-info}/licenses/LICENSE +0 -0
- {palimpzest-0.7.7.dist-info → palimpzest-0.7.9.dist-info}/top_level.txt +0 -0
palimpzest/query/execution/parallel_execution_strategy.py
@@ -45,8 +45,8 @@ class ParallelExecutionStrategy(ExecutionStrategy):
     def _upstream_ops_finished(self, plan: PhysicalPlan, op_idx: int, input_queues: dict[str, list], future_queues: dict[str, list]) -> bool:
         """Helper function to check if all upstream operators have finished processing their inputs."""
         for upstream_op_idx in range(op_idx):
-
-            if len(input_queues[
+            upstream_full_op_id = plan.operators[upstream_op_idx].get_full_op_id()
+            if len(input_queues[upstream_full_op_id]) > 0 or len(future_queues[upstream_full_op_id]) > 0:
                 return False
 
         return True
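
This release switches every per-operator queue from being keyed by the physical op id to the new full op id (defined on PhysicalOperator later in this diff). A minimal sketch of the keying scheme, with a stand-in operator class in place of a real PhysicalOperator:

class StubOp:
    """Stand-in for a PhysicalOperator; only the id plumbing is reproduced."""
    def __init__(self, logical_op_id: str, op_id: str):
        self.logical_op_id = logical_op_id
        self.op_id = op_id

    def get_full_op_id(self) -> str:
        # same format as PhysicalOperator.get_full_op_id() added in 0.7.9
        return f"{self.logical_op_id}-{self.op_id}"

operators = [StubOp("scan-0", "ab12"), StubOp("filter-1", "cd34")]
input_queues = {op.get_full_op_id(): [] for op in operators}
future_queues = {op.get_full_op_id(): [] for op in operators}
assert "scan-0-ab12" in input_queues and "filter-1-cd34" in future_queues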
@@ -57,14 +57,14 @@ class ParallelExecutionStrategy(ExecutionStrategy):
         the updates to plan stats and progress manager before returning the results from the finished futures.
         """
         # get the op_id for the operator
-
+        full_op_id = operator.get_full_op_id()
 
         # this function is called when the future queue is not empty
         # and the executor is not busy processing other futures
-        done_futures, not_done_futures = wait(future_queues[
+        done_futures, not_done_futures = wait(future_queues[full_op_id], timeout=PARALLEL_EXECUTION_SLEEP_INTERVAL_SECS)
 
         # add the unfinished futures back to the previous op's future queue
-        future_queues[
+        future_queues[full_op_id] = list(not_done_futures)
 
         # add the finished futures to the input queue for this operator
         output_records = []
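
This helper (and the new strategy added later in this diff) drives on concurrent.futures.wait with a timeout. A self-contained sketch of the done/not-done bookkeeping pattern; the constant comes from palimpzest.constants in the real code, its value here is assumed:

import time
from concurrent.futures import ThreadPoolExecutor, wait

PARALLEL_EXECUTION_SLEEP_INTERVAL_SECS = 0.1  # assumed value for illustration

def slow_double(x: int) -> int:
    time.sleep(0.05 * x)
    return 2 * x

with ThreadPoolExecutor(max_workers=4) as executor:
    future_queue = [executor.submit(slow_double, x) for x in range(8)]
    results = []
    while future_queue:
        # wait() returns (done, not_done); the timeout keeps the driver loop
        # responsive instead of blocking on the slowest future
        done, not_done = wait(future_queue, timeout=PARALLEL_EXECUTION_SLEEP_INTERVAL_SECS)
        future_queue = list(not_done)  # re-queue unfinished work
        results.extend(f.result() for f in done)

print(sorted(results))  # [0, 2, 4, 6, 8, 10, 12, 14]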
@@ -75,7 +75,7 @@ class ParallelExecutionStrategy(ExecutionStrategy):
             num_outputs = sum(record.passed_operator for record in records)
 
             # update the progress manager
-            self.progress_manager.incr(
+            self.progress_manager.incr(full_op_id, num_outputs=num_outputs, total_cost=record_set.get_total_cost())
 
             # update plan stats
             plan_stats.add_record_op_stats(record_op_stats)
@@ -106,35 +106,35 @@ class ParallelExecutionStrategy(ExecutionStrategy):
             final_op = plan.operators[-1]
             while self._any_queue_not_empty(input_queues) or self._any_queue_not_empty(future_queues):
                 for op_idx, operator in enumerate(plan.operators):
-
+                    full_op_id = operator.get_full_op_id()
 
                     # get any finished futures from the previous operator and add them to the input queue for this operator
                     if not isinstance(operator, ScanPhysicalOp):
                         prev_operator = plan.operators[op_idx - 1]
                         records = self._process_future_results(prev_operator, future_queues, plan_stats)
-                        input_queues[
+                        input_queues[full_op_id].extend(records)
 
                     # for the final operator, add any finished futures to the output_records
-                    if
+                    if full_op_id == final_op.get_full_op_id():
                         records = self._process_future_results(operator, future_queues, plan_stats)
                         output_records.extend(records)
 
                     # if this operator does not have enough inputs to execute, then skip it
-                    num_inputs = len(input_queues[
+                    num_inputs = len(input_queues[full_op_id])
                     agg_op_not_ready = isinstance(operator, AggregateOp) and not self._upstream_ops_finished(plan, op_idx, input_queues, future_queues)
                     if num_inputs == 0 or agg_op_not_ready:
                         continue
 
                     # if this operator is an aggregate, process all the records in the input queue
                     if isinstance(operator, AggregateOp):
-                        input_records = [input_queues[
+                        input_records = [input_queues[full_op_id].pop(0) for _ in range(num_inputs)]
                         future = executor.submit(operator, input_records)
-                        future_queues[
+                        future_queues[full_op_id].append(future)
 
                     else:
-                        input_record = input_queues[
+                        input_record = input_queues[full_op_id].pop(0)
                         future = executor.submit(operator, input_record)
-                        future_queues[
+                        future_queues[full_op_id].append(future)
 
                 # break out of loop if the final operator is a LimitScanOp and we've reached its limit
                 if isinstance(final_op, LimitScanOp) and len(output_records) == final_op.limit:
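
An aggregate is only dispatched once every upstream operator has drained both its input queue and its future queue; otherwise it would aggregate a partial set of records. A toy version of that readiness gate, with plain lists standing in for the per-operator queues:

def upstream_ops_finished(op_idx: int, input_queues: list[list], future_queues: list[list]) -> bool:
    # mirrors _upstream_ops_finished: any pending upstream work blocks the aggregate
    return all(
        len(input_queues[i]) == 0 and len(future_queues[i]) == 0
        for i in range(op_idx)
    )

input_queues = [[], ["pending-record"], []]
future_queues = [[], [], []]
assert not upstream_ops_finished(2, input_queues, future_queues)  # op 1 still has input
input_queues[1].clear()
assert upstream_ops_finished(2, input_queues, future_queues)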
@@ -162,7 +162,174 @@ class ParallelExecutionStrategy(ExecutionStrategy):
 
         # initialize input queues and future queues for each operation
         input_queues = self._create_input_queues(plan)
-        future_queues = {op.
+        future_queues = {op.get_full_op_id(): [] for op in plan.operators}
+
+        # initialize and start the progress manager
+        self.progress_manager = create_progress_manager(plan, num_samples=self.num_samples, progress=self.progress)
+        self.progress_manager.start()
+
+        # NOTE: we must handle progress manager outside of _exeecute_plan to ensure that it is shut down correctly;
+        # if we don't have the `finally:` branch, then program crashes can cause future program runs to fail
+        # because the progress manager cannot get a handle to the console
+        try:
+            # execute plan
+            output_records, plan_stats = self._execute_plan(plan, input_queues, future_queues, plan_stats)
+
+        finally:
+            # finish progress tracking
+            self.progress_manager.finish()
+
+        logger.info(f"Done executing plan: {plan.plan_id}")
+        logger.debug(f"Plan stats: (plan_cost={plan_stats.total_plan_cost}, plan_time={plan_stats.total_plan_time})")
+
+        return output_records, plan_stats
+
+
+class SequentialParallelExecutionStrategy(ExecutionStrategy):
+    """
+    A parallel execution strategy that processes operators sequentially.
+    """
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.max_workers = (
+            self._get_parallel_max_workers()
+            if self.max_workers is None
+            else self.max_workers
+        )
+
+    def _get_parallel_max_workers(self):
+        # for now, return the number of system CPUs;
+        # in the future, we may want to consider the models the user has access to
+        # and whether or not they will encounter rate-limits. If they will, we should
+        # set the max workers in a manner that is designed to avoid hitting them.
+        # Doing this "right" may require considering their logical, physical plan,
+        # and tier status with LLM providers. It may also be worth dynamically
+        # changing the max_workers in response to 429 errors.
+        return max(int(0.8 * multiprocessing.cpu_count()), 1)
+
+    def _any_queue_not_empty(self, queues: dict[str, list]) -> bool:
+        """Helper function to check if any queue is not empty."""
+        return any(len(queue) > 0 for queue in queues.values())
+
+    def _upstream_ops_finished(self, plan: PhysicalPlan, op_idx: int, input_queues: dict[str, list], future_queues: dict[str, list]) -> bool:
+        """Helper function to check if all upstream operators have finished processing their inputs."""
+        for upstream_op_idx in range(op_idx):
+            upstream_full_op_id = plan.operators[upstream_op_idx].get_full_op_id()
+            if len(input_queues[upstream_full_op_id]) > 0 or len(future_queues[upstream_full_op_id]) > 0:
+                return False
+
+        return True
+
+    def _process_future_results(self, operator: PhysicalOperator, future_queues: dict[str, list], plan_stats: PlanStats) -> list[DataRecord]:
+        """
+        Helper function which takes an operator, the future queues, and plan stats, and performs
+        the updates to plan stats and progress manager before returning the results from the finished futures.
+        """
+        # get the op_id for the operator
+        full_op_id = operator.get_full_op_id()
+
+        # this function is called when the future queue is not empty
+        # and the executor is not busy processing other futures
+        done_futures, not_done_futures = wait(future_queues[full_op_id], timeout=PARALLEL_EXECUTION_SLEEP_INTERVAL_SECS)
+
+        # add the unfinished futures back to the previous op's future queue
+        future_queues[full_op_id] = list(not_done_futures)
+
+        # add the finished futures to the input queue for this operator
+        output_records = []
+        for future in done_futures:
+            record_set: DataRecordSet = future.result()
+            records = record_set.data_records
+            record_op_stats = record_set.record_op_stats
+            num_outputs = sum(record.passed_operator for record in records)
+
+            # update the progress manager
+            self.progress_manager.incr(full_op_id, num_outputs=num_outputs, total_cost=record_set.get_total_cost())
+
+            # update plan stats
+            plan_stats.add_record_op_stats(record_op_stats)
+
+            # add records to the cache
+            self._add_records_to_cache(operator.target_cache_id, records)
+
+            # add records which aren't filtered to the output records
+            output_records.extend([record for record in records if record.passed_operator])
+
+        return output_records
+
+    def _execute_plan(
+        self,
+        plan: PhysicalPlan,
+        input_queues: dict[str, list],
+        future_queues: dict[str, list],
+        plan_stats: PlanStats,
+    ) -> tuple[list[DataRecord], PlanStats]:
+        # process all of the input records using a thread pool
+        output_records = []
+        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
+            logger.debug(f"Created thread pool with {self.max_workers} workers")
+
+            # execute the plan until either:
+            # 1. all records have been processed, or
+            # 2. the final limit operation has completed (we break out of the loop if this happens)
+            final_op = plan.operators[-1]
+            for op_idx, operator in enumerate(plan.operators):
+                full_op_id = operator.get_full_op_id()
+                input_queue = input_queues[full_op_id]
+
+                # if this operator is an aggregate, process all the records in the input queue
+                if isinstance(operator, AggregateOp):
+                    num_inputs = len(input_queue)
+                    input_records = [input_queue.pop(0) for _ in range(num_inputs)]
+                    future = executor.submit(operator, input_records)
+                    future_queues[full_op_id].append(future)
+
+                else:
+                    while len(input_queue) > 0:
+                        input_record = input_queue.pop(0)
+                        future = executor.submit(operator, input_record)
+                        future_queues[full_op_id].append(future)
+
+                # block until all futures for this operator have completed; and add finished futures to next operator's input
+                while len(future_queues[full_op_id]) > 0:
+                    records = self._process_future_results(operator, future_queues, plan_stats)
+
+                    # get any finished futures from the previous operator and add them to the input queue for this operator
+                    if full_op_id != final_op.get_full_op_id():
+                        next_op_id = plan.operators[op_idx + 1].get_full_op_id()
+                        input_queues[next_op_id].extend(records)
+
+                    # for the final operator, add any finished futures to the output_records
+                    else:
+                        output_records.extend(records)
+
+                # break out of loop if the final operator is a LimitScanOp and we've reached its limit
+                if isinstance(final_op, LimitScanOp) and len(output_records) == final_op.limit:
+                    break
+
+        # close the cache
+        self._close_cache([op.target_cache_id for op in plan.operators])
+
+        # finalize plan stats
+        plan_stats.finish()
+
+        return output_records, plan_stats
+
+
+    def execute_plan(self, plan: PhysicalPlan):
+        """Initialize the stats and execute the plan."""
+        # for now, assert that the first operator in the plan is a ScanPhysicalOp
+        assert isinstance(plan.operators[0], ScanPhysicalOp), "First operator in physical plan must be a ScanPhysicalOp"
+        logger.info(f"Executing plan {plan.plan_id} with {self.max_workers} workers")
+        logger.info(f"Plan Details: {plan}")
+
+        # initialize plan stats
+        plan_stats = PlanStats.from_plan(plan)
+        plan_stats.start()
+
+        # initialize input queues and future queues for each operation
+        input_queues = self._create_input_queues(plan)
+        future_queues = {op.get_full_op_id(): [] for op in plan.operators}
 
         # initialize and start the progress manager
         self.progress_manager = create_progress_manager(plan, num_samples=self.num_samples, progress=self.progress)
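
The new SequentialParallelExecutionStrategy differs from the pipelined strategy above in that it drains each operator completely before moving on to the next. A minimal sketch of that stage-by-stage pattern, with plain callables standing in for PZ operators:

from concurrent.futures import ThreadPoolExecutor, wait

def run_stages(stages, inputs, max_workers: int = 4):
    """Run each stage over all records, finishing a stage before starting the next."""
    records = list(inputs)
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        for stage in stages:
            futures = [executor.submit(stage, record) for record in records]
            wait(futures)  # block until the whole stage completes
            records = [future.result() for future in futures]
    return records

assert run_stages([lambda x: x + 1, lambda x: x * 2], [1, 2, 3]) == [4, 6, 8]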
palimpzest/query/execution/single_threaded_execution_strategy.py
@@ -30,35 +30,35 @@ class SequentialSingleThreadExecutionStrategy(ExecutionStrategy):
         output_records = []
         for op_idx, operator in enumerate(plan.operators):
             # if we've filtered out all records, terminate early
-
-            num_inputs = len(input_queues[
+            full_op_id = operator.get_full_op_id()
+            num_inputs = len(input_queues[full_op_id])
             if num_inputs == 0:
                 break
 
             # begin to process this operator
             records, record_op_stats = [], []
-            logger.info(f"Processing operator {operator.op_name()} ({
+            logger.info(f"Processing operator {operator.op_name()} ({full_op_id})")
 
             # if this operator is an aggregate, process all the records in the input_queue
             if isinstance(operator, AggregateOp):
-                record_set = operator(candidates=input_queues[
+                record_set = operator(candidates=input_queues[full_op_id])
                 records = record_set.data_records
                 record_op_stats = record_set.record_op_stats
                 num_outputs = sum(record.passed_operator for record in records)
 
                 # update the progress manager
-                self.progress_manager.incr(
+                self.progress_manager.incr(full_op_id, num_outputs=num_outputs, total_cost=record_set.get_total_cost())
 
             # otherwise, process the records in the input queue for this operator one at a time
             else:
-                for input_record in input_queues[
+                for input_record in input_queues[full_op_id]:
                     record_set = operator(input_record)
                     records.extend(record_set.data_records)
                     record_op_stats.extend(record_set.record_op_stats)
                     num_outputs = sum(record.passed_operator for record in record_set.data_records)
 
                     # update the progress manager
-                    self.progress_manager.incr(
+                    self.progress_manager.incr(full_op_id, num_outputs=num_outputs, total_cost=record_set.get_total_cost())
 
                     # finish early if this is a limit
                     if isinstance(operator, LimitScanOp) and len(records) == operator.limit:
@@ -73,10 +73,10 @@ class SequentialSingleThreadExecutionStrategy(ExecutionStrategy):
             # update next input_queue (if it exists)
             output_records = [record for record in records if record.passed_operator]
             if op_idx + 1 < len(plan.operators):
-
-                input_queues[
+                next_full_op_id = plan.operators[op_idx + 1].get_full_op_id()
+                input_queues[next_full_op_id] = output_records
 
-            logger.info(f"Finished processing operator {operator.op_name()} ({operator.
+            logger.info(f"Finished processing operator {operator.op_name()} ({operator.get_full_op_id()}), and generated {len(records)} records")
 
         # close the cache
         self._close_cache([op.target_cache_id for op in plan.operators])
@@ -146,8 +146,8 @@ class PipelinedSingleThreadExecutionStrategy(ExecutionStrategy):
     def _upstream_ops_finished(self, plan: PhysicalPlan, op_idx: int, input_queues: dict[str, list]) -> bool:
         """Helper function to check if all upstream operators have finished processing their inputs."""
         for upstream_op_idx in range(op_idx):
-
-            if len(input_queues[
+            upstream_full_op_id = plan.operators[upstream_op_idx].get_full_op_id()
+            if len(input_queues[upstream_full_op_id]) > 0:
                 return False
 
         return True
@@ -160,8 +160,8 @@ class PipelinedSingleThreadExecutionStrategy(ExecutionStrategy):
         while self._any_queue_not_empty(input_queues):
             for op_idx, operator in enumerate(plan.operators):
                 # if this operator does not have enough inputs to execute, then skip it
-
-                num_inputs = len(input_queues[
+                full_op_id = operator.get_full_op_id()
+                num_inputs = len(input_queues[full_op_id])
                 agg_op_not_ready = isinstance(operator, AggregateOp) and not self._upstream_ops_finished(plan, op_idx, input_queues)
                 if num_inputs == 0 or agg_op_not_ready:
                     continue
@@ -171,25 +171,25 @@ class PipelinedSingleThreadExecutionStrategy(ExecutionStrategy):
 
                 # if the next operator is an aggregate, process all the records in the input_queue
                 if isinstance(operator, AggregateOp):
-                    input_records = [input_queues[
+                    input_records = [input_queues[full_op_id].pop(0) for _ in range(num_inputs)]
                     record_set = operator(candidates=input_records)
                     records = record_set.data_records
                     record_op_stats = record_set.record_op_stats
                     num_outputs = sum(record.passed_operator for record in records)
 
                     # update the progress manager
-                    self.progress_manager.incr(
+                    self.progress_manager.incr(full_op_id, num_outputs=num_outputs, total_cost=record_set.get_total_cost())
 
                 # otherwise, process the next record in the input queue for this operator
                 else:
-                    input_record = input_queues[
+                    input_record = input_queues[full_op_id].pop(0)
                     record_set = operator(input_record)
                     records = record_set.data_records
                     record_op_stats = record_set.record_op_stats
                     num_outputs = sum(record.passed_operator for record in records)
 
                     # update the progress manager
-                    self.progress_manager.incr(
+                    self.progress_manager.incr(full_op_id, num_outputs=num_outputs, total_cost=record_set.get_total_cost())
 
                 # update plan stats
                 plan_stats.add_record_op_stats(record_op_stats)
@@ -200,12 +200,12 @@ class PipelinedSingleThreadExecutionStrategy(ExecutionStrategy):
                 # update next input_queue or final_output_records
                 output_records = [record for record in records if record.passed_operator]
                 if op_idx + 1 < len(plan.operators):
-
-                    input_queues[
+                    next_full_op_id = plan.operators[op_idx + 1].get_full_op_id()
+                    input_queues[next_full_op_id].extend(output_records)
                 else:
                     final_output_records.extend(output_records)
 
-                logger.info(f"Finished processing operator {operator.op_name()} ({operator.
+                logger.info(f"Finished processing operator {operator.op_name()} ({operator.get_full_op_id()}) on {num_inputs} records")
 
             # break out of loop if the final operator is a LimitScanOp and we've reached its limit
             if isinstance(plan.operators[-1], LimitScanOp) and len(final_output_records) == plan.operators[-1].limit:
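
For contrast with the sequential strategy, this pipelined driver sweeps over all operators repeatedly, handing each output to the next operator's queue as soon as it is produced. A toy single-threaded version of that loop:

def run_pipelined(ops, inputs):
    input_queues = {op_idx: [] for op_idx in range(len(ops))}
    input_queues[0] = list(inputs)
    final_output_records = []
    while any(queue for queue in input_queues.values()):
        for op_idx, op in enumerate(ops):
            if not input_queues[op_idx]:
                continue  # nothing pending for this operator; skip it
            record = op(input_queues[op_idx].pop(0))
            if op_idx + 1 < len(ops):
                input_queues[op_idx + 1].append(record)  # feed the next operator
            else:
                final_output_records.append(record)
    return final_output_records

assert run_pipelined([str.strip, str.upper], [" a ", " b "]) == ["A", "B"]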
palimpzest/query/generators/api_client_factory.py
@@ -22,10 +22,9 @@ class APIClientFactory:
    @staticmethod
    def _create_client(api_client: APIClient, api_key: str):
        """Create a new client instance based on the api_client name."""
-
-
-
-
-
-
-        raise ValueError(f"Unknown api_client: {api_client}")
+        if api_client == APIClient.OPENAI:
+            return OpenAI(api_key=api_key)
+        elif api_client == APIClient.TOGETHER:
+            return Together(api_key=api_key)
+        else:
+            raise ValueError(f"Unknown api_client: {api_client}")
palimpzest/query/generators/generators.py
@@ -49,10 +49,10 @@ def generator_factory(
    """
    Factory function to return the correct generator based on the model, strategy, and cardinality.
    """
-    if model
+    if model.is_openai_model():
        return OpenAIGenerator(model, prompt_strategy, cardinality, verbose)
 
-    elif model
+    elif model.is_together_model():
        return TogetherGenerator(model, prompt_strategy, cardinality, verbose)
 
    raise Exception(f"Unsupported model: {model}")
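
The factory and the generator constructors now dispatch on helper predicates of the Model enum (is_openai_model, is_together_model, is_llama_model). A hedged stand-in showing the contract those predicates satisfy; the name checks here are illustrative, not palimpzest's actual logic:

class Model:
    """Illustrative stand-in for palimpzest.constants.Model."""
    def __init__(self, name: str):
        self.name = name

    def is_openai_model(self) -> bool:
        return self.name.startswith("gpt-")  # assumed heuristic

    def is_together_model(self) -> bool:
        return "llama" in self.name.lower()  # assumed heuristic

    def is_llama_model(self) -> bool:
        return "llama" in self.name.lower()  # assumed heuristic

assert Model("gpt-4o-mini").is_openai_model()
assert Model("Llama-3.1-8B").is_together_model()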
@@ -61,8 +61,6 @@ def generator_factory(
 def get_api_key(key: str) -> str:
     # get API key from environment or throw an exception if it's not set
     if key not in os.environ:
-        print(f"KEY: {key}")
-        print(f"{os.environ.keys()}")
         raise ValueError("key not found in environment variables")
 
     return os.environ[key]
@@ -464,7 +462,7 @@ class OpenAIGenerator(BaseGenerator[str | list[str], str]):
         verbose: bool = False,
     ):
         # assert that model is an OpenAI model
-        assert model
+        assert model.is_openai_model()
         super().__init__(model, prompt_strategy, cardinality, verbose, "developer")
 
     def _get_client_or_model(self, **kwargs) -> OpenAI:
@@ -508,7 +506,7 @@ class TogetherGenerator(BaseGenerator[str | list[str], str]):
         verbose: bool = False,
     ):
         # assert that model is a model offered by Together
-        assert model
+        assert model.is_together_model()
         super().__init__(model, prompt_strategy, cardinality, verbose, "system")
 
     def _generate_payload(self, messages: list[dict], **kwargs) -> dict:
@@ -525,7 +523,7 @@ class TogetherGenerator(BaseGenerator[str | list[str], str]):
         For LLAMA3, the payload needs to be in a {"role": <role>, "content": <content>} format.
         """
         # for other models, use our standard payload generation
-        if self.model
+        if not self.model.is_llama_model():
             return super()._generate_payload(messages, **kwargs)
 
         # get basic parameters
@@ -593,7 +591,6 @@ def code_ensemble_execution(
             preds.append(pred)
 
     preds = [pred for pred in preds if pred is not None]
-    print(preds)
 
     if len(preds) == 1:
         majority_response = preds[0]
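
code_ensemble_execution takes a majority vote over the ensemble's predictions after dropping failed (None) runs; a minimal stand-in for that step (the real function's tie-breaking and downstream handling are outside this hunk):

from collections import Counter

def majority_vote(preds: list) -> str:
    preds = [pred for pred in preds if pred is not None]  # drop failed runs
    if len(preds) == 1:
        return preds[0]
    # most_common(1) returns [(value, count)]; ties resolve by insertion order
    return Counter(preds).most_common(1)[0][0]

assert majority_vote(["a", None, "b", "a"]) == "a"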
palimpzest/query/operators/aggregate.py
@@ -138,7 +138,7 @@ class ApplyGroupByOp(AggregateOp):
                 record_parent_id=dr.parent_id,
                 record_source_idx=dr.source_idx,
                 record_state=dr.to_dict(include_bytes=False),
-
+                full_op_id=self.get_full_op_id(),
                 logical_op_id=self.logical_op_id,
                 op_name=self.op_name(),
                 time_per_record=total_time / len(drs),
@@ -198,7 +198,8 @@ class AverageAggregateOp(AggregateOp):
             record_parent_id=dr.parent_id,
             record_source_idx=dr.source_idx,
             record_state=dr.to_dict(include_bytes=False),
-
+            full_op_id=self.get_full_op_id(),
+            logical_op_id=self.logical_op_id,
             op_name=self.op_name(),
             time_per_record=time.time() - start_time,
             cost_per_record=0.0,
@@ -251,7 +252,7 @@ class CountAggregateOp(AggregateOp):
             record_parent_id=dr.parent_id,
             record_source_idx=dr.source_idx,
             record_state=dr.to_dict(include_bytes=False),
-
+            full_op_id=self.get_full_op_id(),
             logical_op_id=self.logical_op_id,
             op_name=self.op_name(),
             time_per_record=time.time() - start_time,
palimpzest/query/operators/convert.py
@@ -115,7 +115,7 @@ class ConvertOp(PhysicalOperator, ABC):
             record_parent_id=dr.parent_id,
             record_source_idx=dr.source_idx,
             record_state=dr.to_dict(include_bytes=False),
-
+            full_op_id=self.get_full_op_id(),
             logical_op_id=self.logical_op_id,
             op_name=self.op_name(),
             time_per_record=time_per_record,
palimpzest/query/operators/filter.py
@@ -84,7 +84,7 @@ class FilterOp(PhysicalOperator, ABC):
             record_parent_id=dr.parent_id,
             record_source_idx=dr.source_idx,
             record_state=dr.to_dict(include_bytes=False),
-
+            full_op_id=self.get_full_op_id(),
             logical_op_id=self.logical_op_id,
             op_name=self.op_name(),
             time_per_record=total_time,
palimpzest/query/operators/limit.py
@@ -44,7 +44,7 @@ class LimitScanOp(PhysicalOperator):
             record_parent_id=dr.parent_id,
             record_source_idx=dr.source_idx,
             record_state=dr.to_dict(include_bytes=False),
-
+            full_op_id=self.get_full_op_id(),
             logical_op_id=self.logical_op_id,
             op_name=self.op_name(),
             time_per_record=0.0,
palimpzest/query/operators/map.py
@@ -47,7 +47,7 @@ class MapOp(PhysicalOperator):
             record_parent_id=record.parent_id,
             record_source_idx=record.source_idx,
             record_state=record.to_dict(include_bytes=False),
-
+            full_op_id=self.get_full_op_id(),
             logical_op_id=self.logical_op_id,
             op_name=self.op_name(),
             time_per_record=total_time,
palimpzest/query/operators/physical.py
@@ -58,8 +58,8 @@ class PhysicalOperator:
         return op
 
     def __eq__(self, other) -> bool:
-
-        return isinstance(other, self.__class__) and
+        all_op_params_match = all(value == getattr(other, key) for key, value in self.get_op_params().items())
+        return isinstance(other, self.__class__) and all_op_params_match
 
     def copy(self) -> PhysicalOperator:
         return self.__class__(**self.get_op_params())
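
The rewritten __eq__ makes operator equality structural: two operators are equal when they are the same class and agree on every op param. A toy class illustrating that contract:

class StubOp:
    def __init__(self, model: str, k: int):
        self.model, self.k = model, k

    def get_op_params(self) -> dict:
        return {"model": self.model, "k": self.k}

    def __eq__(self, other) -> bool:
        # same pattern as PhysicalOperator.__eq__ above
        all_op_params_match = all(value == getattr(other, key) for key, value in self.get_op_params().items())
        return isinstance(other, self.__class__) and all_op_params_match

assert StubOp("gpt-4o", 3) == StubOp("gpt-4o", 3)
assert StubOp("gpt-4o", 3) != StubOp("gpt-4o", 5)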
@@ -79,7 +79,8 @@ class PhysicalOperator:
         This is particularly true for convert operations, where the output schema
         is now the union of the input and output schemas of the logical operator.
         """
-        return {"generated_fields": self.generated_fields}
+        # return {"generated_fields": self.generated_fields}
+        return {}
 
     def get_op_params(self) -> dict:
         """
@@ -129,8 +130,11 @@ class PhysicalOperator:
     def get_logical_op_id(self) -> str | None:
         return self.logical_op_id
 
+    def get_full_op_id(self):
+        return f"{self.get_logical_op_id()}-{self.get_op_id()}"
+
     def __hash__(self):
-        return int(self.op_id, 16)
+        return int(self.op_id, 16)  # NOTE: should we use self.get_full_op_id() instead?
 
     def get_model_name(self) -> str | None:
         """Returns the name of the model used by the physical operator (if it sets self.model). Otherwise, it returns None."""
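
get_full_op_id() concatenates the logical op id with the physical op id, while __hash__ still interprets the hex op_id alone, as the NOTE in the diff points out. A concrete illustration with made-up ids:

logical_op_id, op_id = "filter-2", "a3f21b"   # made-up ids for illustration

full_op_id = f"{logical_op_id}-{op_id}"       # format used by get_full_op_id()
assert full_op_id == "filter-2-a3f21b"

# __hash__ parses the hex digest; two physical ops sharing op_id would collide
# even when their logical op ids (and hence full op ids) differ
assert int(op_id, 16) == 10744347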
palimpzest/query/operators/project.py
@@ -42,7 +42,7 @@ class ProjectOp(PhysicalOperator):
             record_parent_id=dr.parent_id,
             record_source_idx=dr.source_idx,
             record_state=dr.to_dict(include_bytes=False),
-
+            full_op_id=self.get_full_op_id(),
             logical_op_id=self.logical_op_id,
             op_name=self.op_name(),
             time_per_record=0.0,
palimpzest/query/operators/retrieve.py
@@ -8,7 +8,6 @@ from chromadb.api.models.Collection import Collection
 from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction
 from chromadb.utils.embedding_functions.openai_embedding_function import OpenAIEmbeddingFunction
 from openai import OpenAI
-from ragatouille.RAGPretrainedModel import RAGPretrainedModel
 from sentence_transformers import SentenceTransformer
 
 from palimpzest.constants import MODEL_CARDS, Model
@@ -21,7 +20,7 @@ from palimpzest.query.operators.physical import PhysicalOperator
 class RetrieveOp(PhysicalOperator):
     def __init__(
         self,
-        index: Collection
+        index: Collection,
         search_attr: str,
         output_attrs: list[dict] | type[Schema],
         search_func: Callable | None,
@@ -33,7 +32,7 @@ class RetrieveOp(PhysicalOperator):
         Initialize the RetrieveOp object.
 
         Args:
-            index (Collection
+            index (Collection): The PZ index to use for retrieval.
             search_attr (str): The attribute to search on.
             output_attrs (list[dict]): The output fields containing the results of the search.
             search_func (Callable | None): The function to use for searching the index. If None, the default search function will be used.
@@ -100,7 +99,7 @@ class RetrieveOp(PhysicalOperator):
             quality=1.0,
         )
 
-    def default_search_func(self, index: Collection
+    def default_search_func(self, index: Collection, query: list[str] | list[list[float]], k: int) -> list[str] | list[list[str]]:
         """
         Default search function for the Retrieve operation. This function uses the index to
         retrieve the top-k results for the given query. The query will be a (possibly singleton)
@@ -132,24 +131,8 @@ class RetrieveOp(PhysicalOperator):
             # NOTE: self.output_field_names must be a singleton for default_search_func to be used
             return {self.output_field_names[0]: final_results}
 
-        elif isinstance(index, RAGPretrainedModel):
-            # if the index is a rag model, use the rag model to get the top k results
-            results = index.search(query, k=k)
-
-            # the results will be a list[dict]; if the input is a singleton list, however
-            # it will be a list[list[dict]]; if the input is a list of lists
-            final_results = []
-            if is_singleton_list:
-                final_results = [result["content"] for result in results]
-            else:
-                for query_results in results:
-                    final_results.append([result["content"] for result in query_results])
-
-            # NOTE: self.output_field_names must be a singleton for default_search_func to be used
-            return {self.output_field_names[0]: final_results}
-
         else:
-            raise ValueError("Unsupported index type. Must be either a Collection
+            raise ValueError("Unsupported index type. Must be either a Collection.")
 
     def _create_record_set(
         self,
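
With the ragatouille branch removed, default_search_func only supports ChromaDB collections. A hedged sketch of the underlying top-k lookup against a Chroma Collection; this uses chromadb's public API directly rather than palimpzest's wrapper, and assumes the default embedding function is available locally:

import chromadb

client = chromadb.Client()  # in-memory client
collection = client.create_collection("demo")
collection.add(ids=["1", "2"], documents=["palimpzest optimizes pipelines", "bananas are yellow"])

# top-k retrieval for a (singleton) list of text queries
results = collection.query(query_texts=["query optimization"], n_results=1)
print(results["documents"])  # [['palimpzest optimizes pipelines']]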
@@ -180,7 +163,7 @@ class RetrieveOp(PhysicalOperator):
             record_parent_id=output_dr.parent_id,
             record_source_idx=output_dr.source_idx,
             record_state=record_state,
-
+            full_op_id=self.get_full_op_id(),
             logical_op_id=self.logical_op_id,
             op_name=self.op_name(),
             time_per_record=total_time,
@@ -231,7 +214,8 @@ class RetrieveOp(PhysicalOperator):
 
         model_name = self.index._embedding_function._model_name if uses_openai_embedding_fcn else "clip-ViT-B-32"
         err_msg = f"For Chromadb, we currently only support `text-embedding-3-small` and `clip-ViT-B-32`; your index uses: {model_name}"
-
+        embedding_model_names = [model.value for model in Model if model.is_embedding_model()]
+        assert model_name in embedding_model_names, err_msg
 
         # compute embeddings
         try:
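
The restored assertion builds its allow-list from the Model enum via is_embedding_model(). A toy equivalent with a hypothetical enum containing the two embedding models named in the error message:

from enum import Enum

class Model(Enum):
    """Hypothetical stand-in for palimpzest.constants.Model."""
    GPT_4O = "gpt-4o"
    TEXT_EMBEDDING_3_SMALL = "text-embedding-3-small"
    CLIP_VIT_B_32 = "clip-ViT-B-32"

    def is_embedding_model(self) -> bool:
        return self in (Model.TEXT_EMBEDDING_3_SMALL, Model.CLIP_VIT_B_32)

embedding_model_names = [model.value for model in Model if model.is_embedding_model()]
assert embedding_model_names == ["text-embedding-3-small", "clip-ViT-B-32"]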