palimpzest 0.6.4__py3-none-any.whl → 0.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- palimpzest/__init__.py +5 -0
- palimpzest/constants.py +110 -43
- palimpzest/core/__init__.py +0 -78
- palimpzest/core/data/dataclasses.py +382 -44
- palimpzest/core/elements/filters.py +7 -3
- palimpzest/core/elements/index.py +70 -0
- palimpzest/core/elements/records.py +33 -11
- palimpzest/core/lib/fields.py +1 -0
- palimpzest/core/lib/schemas.py +4 -3
- palimpzest/prompts/moa_proposer_convert_prompts.py +0 -4
- palimpzest/prompts/prompt_factory.py +44 -7
- palimpzest/prompts/split_merge_prompts.py +56 -0
- palimpzest/prompts/split_proposer_prompts.py +55 -0
- palimpzest/query/execution/execution_strategy.py +435 -53
- palimpzest/query/execution/execution_strategy_type.py +20 -0
- palimpzest/query/execution/mab_execution_strategy.py +532 -0
- palimpzest/query/execution/parallel_execution_strategy.py +143 -172
- palimpzest/query/execution/random_sampling_execution_strategy.py +240 -0
- palimpzest/query/execution/single_threaded_execution_strategy.py +173 -203
- palimpzest/query/generators/api_client_factory.py +31 -0
- palimpzest/query/generators/generators.py +256 -76
- palimpzest/query/operators/__init__.py +1 -2
- palimpzest/query/operators/code_synthesis_convert.py +33 -18
- palimpzest/query/operators/convert.py +30 -97
- palimpzest/query/operators/critique_and_refine_convert.py +5 -6
- palimpzest/query/operators/filter.py +7 -10
- palimpzest/query/operators/logical.py +54 -10
- palimpzest/query/operators/map.py +130 -0
- palimpzest/query/operators/mixture_of_agents_convert.py +6 -6
- palimpzest/query/operators/physical.py +3 -12
- palimpzest/query/operators/rag_convert.py +66 -18
- palimpzest/query/operators/retrieve.py +230 -34
- palimpzest/query/operators/scan.py +5 -2
- palimpzest/query/operators/split_convert.py +169 -0
- palimpzest/query/operators/token_reduction_convert.py +8 -14
- palimpzest/query/optimizer/__init__.py +4 -16
- palimpzest/query/optimizer/cost_model.py +73 -266
- palimpzest/query/optimizer/optimizer.py +87 -58
- palimpzest/query/optimizer/optimizer_strategy.py +18 -97
- palimpzest/query/optimizer/optimizer_strategy_type.py +37 -0
- palimpzest/query/optimizer/plan.py +2 -3
- palimpzest/query/optimizer/primitives.py +5 -3
- palimpzest/query/optimizer/rules.py +336 -172
- palimpzest/query/optimizer/tasks.py +30 -100
- palimpzest/query/processor/config.py +38 -22
- palimpzest/query/processor/nosentinel_processor.py +16 -520
- palimpzest/query/processor/processing_strategy_type.py +28 -0
- palimpzest/query/processor/query_processor.py +38 -206
- palimpzest/query/processor/query_processor_factory.py +117 -130
- palimpzest/query/processor/sentinel_processor.py +90 -0
- palimpzest/query/processor/streaming_processor.py +25 -32
- palimpzest/sets.py +88 -41
- palimpzest/utils/model_helpers.py +8 -7
- palimpzest/utils/progress.py +368 -152
- palimpzest/utils/token_reduction_helpers.py +1 -3
- {palimpzest-0.6.4.dist-info → palimpzest-0.7.1.dist-info}/METADATA +19 -9
- palimpzest-0.7.1.dist-info/RECORD +96 -0
- {palimpzest-0.6.4.dist-info → palimpzest-0.7.1.dist-info}/WHEEL +1 -1
- palimpzest/query/processor/mab_sentinel_processor.py +0 -884
- palimpzest/query/processor/random_sampling_sentinel_processor.py +0 -639
- palimpzest/utils/index_helpers.py +0 -6
- palimpzest-0.6.4.dist-info/RECORD +0 -87
- {palimpzest-0.6.4.dist-info → palimpzest-0.7.1.dist-info/licenses}/LICENSE +0 -0
- {palimpzest-0.6.4.dist-info → palimpzest-0.7.1.dist-info}/top_level.txt +0 -0
palimpzest/query/execution/single_threaded_execution_strategy.py:

@@ -1,13 +1,15 @@
-import
+import logging
 
-from palimpzest.core.data.dataclasses import
+from palimpzest.core.data.dataclasses import PlanStats
+from palimpzest.core.elements.records import DataRecord
 from palimpzest.query.execution.execution_strategy import ExecutionStrategy
 from palimpzest.query.operators.aggregate import AggregateOp
-from palimpzest.query.operators.filter import FilterOp
 from palimpzest.query.operators.limit import LimitScanOp
 from palimpzest.query.operators.scan import ScanPhysicalOp
 from palimpzest.query.optimizer.plan import PhysicalPlan
+from palimpzest.utils.progress import create_progress_manager
 
+logger = logging.getLogger(__name__)
 
 class SequentialSingleThreadExecutionStrategy(ExecutionStrategy):
     """
@@ -21,113 +23,100 @@ class SequentialSingleThreadExecutionStrategy(ExecutionStrategy):
     """
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
+        self.max_workers = 1
 
-    def
-        """Initialize the stats and the execute the plan."""
-        if self.verbose:
-            print("----------------------")
-            print(f"PLAN[{plan.plan_id}] (n={num_samples}):")
-            print(plan)
-            print("---")
-
-        plan_start_time = time.time()
-
-        # initialize plan stats and operator stats
-        plan_stats = PlanStats(plan_id=plan.plan_id, plan_str=str(plan))
-        for op in plan.operators:
-            op_id = op.get_op_id()
-            op_name = op.op_name()
-            op_details = {k: str(v) for k, v in op.get_id_params().items()}
-            plan_stats.operator_stats[op_id] = OperatorStats(op_id=op_id, op_name=op_name, op_details=op_details)
-
-        # initialize list of output records and intermediate variables
-        output_records = []
-        current_scan_idx = self.scan_start_idx
-
-        # get handle to scan operator and pre-compute its size
-        source_operator = plan.operators[0]
-        assert isinstance(source_operator, ScanPhysicalOp), "First operator in physical plan must be a ScanPhysicalOp"
-        datareader_len = len(source_operator.datareader)
-
-        # initialize processing queues for each operation
-        processing_queues = {op.get_op_id(): [] for op in plan.operators if not isinstance(op, ScanPhysicalOp)}
-
+    def _execute_plan(self, plan: PhysicalPlan, input_queues: dict[str, list], plan_stats: PlanStats) -> tuple[list[DataRecord], PlanStats]:
         # execute the plan one operator at a time
+        output_records = []
         for op_idx, operator in enumerate(plan.operators):
+            # if we've filtered out all records, terminate early
             op_id = operator.get_op_id()
-
-
+            num_inputs = len(input_queues[op_id])
+            if num_inputs == 0:
+                break
 
-            #
+            # begin to process this operator
             records, record_op_stats = [], []
+            logger.info(f"Processing operator {operator.op_name()} ({op_id})")
 
-            #
-            if isinstance(operator,
-
-                while keep_scanning_source_records:
-                    # run ScanPhysicalOp on current scan index
-                    record_set = operator(current_scan_idx)
-                    records.extend(record_set.data_records)
-                    record_op_stats.extend(record_set.record_op_stats)
-
-                    # update the current scan index
-                    current_scan_idx += 1
-
-                    # update whether to keep scanning source records
-                    keep_scanning_source_records = current_scan_idx < datareader_len and len(records) < num_samples
-
-            # aggregate operators accept all input records at once
-            elif isinstance(operator, AggregateOp):
-                record_set = operator(candidates=processing_queues[op_id])
+            # if this operator is an aggregate, process all the records in the input_queue
+            if isinstance(operator, AggregateOp):
+                record_set = operator(candidates=input_queues[op_id])
                 records = record_set.data_records
                 record_op_stats = record_set.record_op_stats
+                num_outputs = sum(record.passed_operator for record in records)
+
+                # update the progress manager
+                self.progress_manager.incr(op_id, num_outputs=num_outputs, total_cost=record_set.get_total_cost())
 
-            # otherwise, process the records in the
-
-                for input_record in
+            # otherwise, process the records in the input queue for this operator one at a time
+            else:
+                for input_record in input_queues[op_id]:
                     record_set = operator(input_record)
                     records.extend(record_set.data_records)
                     record_op_stats.extend(record_set.record_op_stats)
+                    num_outputs = sum(record.passed_operator for record in record_set.data_records)
+
+                    # update the progress manager
+                    self.progress_manager.incr(op_id, num_outputs=num_outputs, total_cost=record_set.get_total_cost())
 
+            # finish early if this is a limit
             if isinstance(operator, LimitScanOp) and len(records) == operator.limit:
                 break
 
             # update plan stats
-            plan_stats.
-                record_op_stats,
-                source_op_id=prev_op_id,
-                plan_id=plan.plan_id,
-            )
-
-            # add records (which are not filtered) to the cache, if allowed
-            if not self.nocache:
-                for record in records:
-                    if getattr(record, "passed_operator", True):
-                        # self.datadir.append_cache(operator.target_cache_id, record)
-                        pass
-
-            # update processing_queues or output_records
-            for record in records:
-                if isinstance(operator, FilterOp) and not record.passed_operator:
-                    continue
-                if next_op_id is not None:
-                    processing_queues[next_op_id].append(record)
-                else:
-                    output_records.append(record)
+            plan_stats.add_record_op_stats(record_op_stats)
 
-            #
-
-                break
+            # add records to the cache
+            self._add_records_to_cache(operator.target_cache_id, records)
 
-
-
-
-
-
+            # update next input_queue (if it exists)
+            output_records = [record for record in records if record.passed_operator]
+            if op_idx + 1 < len(plan.operators):
+                next_op_id = plan.operators[op_idx + 1].get_op_id()
+                input_queues[next_op_id] = output_records
+
+            logger.info(f"Finished processing operator {operator.op_name()} ({operator.get_op_id()}), and generated {len(records)} records")
+
+        # close the cache
+        self._close_cache([op.target_cache_id for op in plan.operators])
 
         # finalize plan stats
-
-
+        plan_stats.finish()
+
+        return output_records, plan_stats
+
+    def execute_plan(self, plan: PhysicalPlan) -> tuple[list[DataRecord], PlanStats]:
+        """Initialize the stats and execute the plan."""
+        # for now, assert that the first operator in the plan is a ScanPhysicalOp
+        assert isinstance(plan.operators[0], ScanPhysicalOp), "First operator in physical plan must be a ScanPhysicalOp"
+        logger.info(f"Executing plan {plan.plan_id} with {self.max_workers} workers")
+        logger.info(f"Plan Details: {plan}")
+
+        # initialize plan stats
+        plan_stats = PlanStats.from_plan(plan)
+        plan_stats.start()
+
+        # initialize input queues for each operation
+        input_queues = self._create_input_queues(plan)
+
+        # initialize and start the progress manager
+        self.progress_manager = create_progress_manager(plan, num_samples=self.num_samples, progress=self.progress)
+        self.progress_manager.start()
+
+        # NOTE: we must handle progress manager outside of _exeecute_plan to ensure that it is shut down correctly;
+        # if we don't have the `finally:` branch, then program crashes can cause future program runs to fail
+        # because the progress manager cannot get a handle to the console
+        try:
+            # execute plan
+            output_records, plan_stats = self._execute_plan(plan, input_queues, plan_stats)
+
+        finally:
+            # finish progress tracking
+            self.progress_manager.finish()
+
+        logger.info(f"Done executing plan: {plan.plan_id}")
+        logger.debug(f"Plan stats: (plan_cost={plan_stats.total_plan_cost}, plan_time={plan_stats.total_plan_time})")
 
         return output_records, plan_stats
 
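Both rewritten strategies call `self._create_input_queues(plan)`, which is defined in the reworked `execution_strategy.py` and not shown in this hunk. Below is a minimal sketch of what such a helper plausibly returns, given how the queues are consumed above: one list per `op_id`, with the scan operator's queue pre-seeded so that `operator(input_record)` can be invoked uniformly. The body is an illustration, not the packaged implementation.

```python
from palimpzest.query.operators.scan import ScanPhysicalOp
from palimpzest.query.optimizer.plan import PhysicalPlan


def _create_input_queues(plan: PhysicalPlan) -> dict[str, list]:
    """Hypothetical sketch: one input queue per operator, keyed by op_id."""
    input_queues = {}
    for op in plan.operators:
        if isinstance(op, ScanPhysicalOp):
            # assumption: the scan queue is pre-filled with indices into the datareader,
            # so _execute_plan can call `operator(input_record)` on each index
            input_queues[op.get_op_id()] = list(range(len(op.datareader)))
        else:
            # downstream operators start empty and are filled by upstream outputs
            input_queues[op.get_op_id()] = []
    return input_queues
```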
@@ -148,137 +137,118 @@ class PipelinedSingleThreadExecutionStrategy(ExecutionStrategy):
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.max_workers = 1
-
-    def
-        """
-
-            print("----------------------")
-            print(f"PLAN[{plan.plan_id}] (n={num_samples}):")
-            print(plan)
-            print("---")
-
-        plan_start_time = time.time()
-
-        # initialize plan stats and operator stats
-        plan_stats = PlanStats(plan_id=plan.plan_id, plan_str=str(plan))
-        for op in plan.operators:
-            op_id = op.get_op_id()
-            op_name = op.op_name()
-            op_details = {k: str(v) for k, v in op.get_id_params().items()}
-            plan_stats.operator_stats[op_id] = OperatorStats(op_id=op_id, op_name=op_name, op_details=op_details)
-
-        # initialize list of output records and intermediate variables
-        output_records = []
-        source_records_scanned = 0
-        current_scan_idx = self.scan_start_idx
+        self.max_workers = 1
+
+    def _any_queue_not_empty(self, queues: dict[str, list]) -> bool:
+        """Helper function to check if any queue is not empty."""
+        return any(len(queue) > 0 for queue in queues.values())
 
-
-
-
-
+    def _upstream_ops_finished(self, plan: PhysicalPlan, op_idx: int, input_queues: dict[str, list]) -> bool:
+        """Helper function to check if all upstream operators have finished processing their inputs."""
+        for upstream_op_idx in range(op_idx):
+            upstream_op_id = plan.operators[upstream_op_idx].get_op_id()
+            if len(input_queues[upstream_op_id]) > 0:
+                return False
 
-
-        processing_queues = {op.get_op_id(): [] for op in plan.operators if not isinstance(op, ScanPhysicalOp)}
+        return True
 
+    def _execute_plan(self, plan: PhysicalPlan, input_queues: dict[str, list], plan_stats: PlanStats) -> tuple[list[DataRecord], PlanStats]:
         # execute the plan until either:
         # 1. all records have been processed, or
-        # 2. the final limit operation has completed
-
-        while
+        # 2. the final limit operation has completed (we break out of the loop if this happens)
+        final_output_records = []
+        while self._any_queue_not_empty(input_queues):
             for op_idx, operator in enumerate(plan.operators):
+                # if this operator does not have enough inputs to execute, then skip it
                 op_id = operator.get_op_id()
-
-
-
+                num_inputs = len(input_queues[op_id])
+                agg_op_not_ready = isinstance(operator, AggregateOp) and not self._upstream_ops_finished(plan, op_idx, input_queues)
+                if num_inputs == 0 or agg_op_not_ready:
+                    continue
 
                 # create empty lists for records and execution stats generated by executing this operator on its next input(s)
                 records, record_op_stats = [], []
 
-                #
-                if isinstance(operator,
-
-
-
-
-
-
-
-
-
-
-
-                # only invoke aggregate operator(s) once there are no more source records and all
-                # upstream operators' processing queues are empty
-                elif isinstance(operator, AggregateOp):
-                    upstream_ops_are_finished = True
-                    for upstream_op_idx in range(op_idx):
-                        # scan operators do not have processing queues
-                        if isinstance(plan.operators[upstream_op_idx], ScanPhysicalOp):
-                            continue
-
-                        # check upstream ops which do have a processing queue
-                        upstream_op_id = plan.operators[upstream_op_idx].get_op_id()
-                        upstream_ops_are_finished = (
-                            upstream_ops_are_finished and len(processing_queues[upstream_op_id]) == 0
-                        )
-
-                    if not keep_scanning_source_records and upstream_ops_are_finished:
-                        record_set = operator(candidates=processing_queues[op_id])
-                        records = record_set.data_records
-                        record_op_stats = record_set.record_op_stats
-                        processing_queues[op_id] = []
-
-                # otherwise, process the next record in the processing queue for this operator
-                elif len(processing_queues[op_id]) > 0:
-                    input_record = processing_queues[op_id].pop(0)
+                # if the next operator is an aggregate, process all the records in the input_queue
+                if isinstance(operator, AggregateOp):
+                    input_records = [input_queues[op_id].pop(0) for _ in range(num_inputs)]
+                    record_set = operator(candidates=input_records)
+                    records = record_set.data_records
+                    record_op_stats = record_set.record_op_stats
+                    num_outputs = sum(record.passed_operator for record in records)
+
+                    # update the progress manager
+                    self.progress_manager.incr(op_id, num_outputs=num_outputs, total_cost=record_set.get_total_cost())
+
+                # otherwise, process the next record in the input queue for this operator
+                else:
+                    input_record = input_queues[op_id].pop(0)
                     record_set = operator(input_record)
                     records = record_set.data_records
                     record_op_stats = record_set.record_op_stats
+                    num_outputs = sum(record.passed_operator for record in records)
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                # update finished_executing based on whether all records have been processed
-                still_processing = any([len(queue) > 0 for queue in processing_queues.values()])
-                keep_scanning_source_records = current_scan_idx < datareader_len and source_records_scanned < num_samples
-                finished_executing = not keep_scanning_source_records and not still_processing
-
-                # update finished_executing based on limit
-                if isinstance(operator, LimitScanOp):
-                    finished_executing = len(output_records) == operator.limit
-
-        # if caching was allowed, close the cache
-        if not self.nocache:
-            for _ in plan.operators:
-                # self.datadir.close_cache(operator.target_cache_id)
-                pass
+                    # update the progress manager
+                    self.progress_manager.incr(op_id, num_outputs=num_outputs, total_cost=record_set.get_total_cost())
+
+                # update plan stats
+                plan_stats.add_record_op_stats(record_op_stats)
+
+                # add records to the cache
+                self._add_records_to_cache(operator.target_cache_id, records)
+
+                # update next input_queue or final_output_records
+                output_records = [record for record in records if record.passed_operator]
+                if op_idx + 1 < len(plan.operators):
+                    next_op_id = plan.operators[op_idx + 1].get_op_id()
+                    input_queues[next_op_id].extend(output_records)
+                else:
+                    final_output_records.extend(output_records)
+
+                logger.info(f"Finished processing operator {operator.op_name()} ({operator.get_op_id()}) on {num_inputs} records")
+
+            # break out of loop if the final operator is a LimitScanOp and we've reached its limit
+            if isinstance(plan.operators[-1], LimitScanOp) and len(final_output_records) == plan.operators[-1].limit:
+                break
+
+        # close the cache
+        self._close_cache([op.target_cache_id for op in plan.operators])
 
         # finalize plan stats
-
-
+        plan_stats.finish()
+
+        return final_output_records, plan_stats
+
+    def execute_plan(self, plan: PhysicalPlan):
+        """Initialize the stats and execute the plan."""
+        # for now, assert that the first operator in the plan is a ScanPhysicalOp
+        assert isinstance(plan.operators[0], ScanPhysicalOp), "First operator in physical plan must be a ScanPhysicalOp"
+        logger.info(f"Executing plan {plan.plan_id} with {self.max_workers} workers")
+        logger.info(f"Plan Details: {plan}")
+
+        # initialize plan stats
+        plan_stats = PlanStats.from_plan(plan)
+        plan_stats.start()
+
+        # initialize input queues for each operation
+        input_queues = self._create_input_queues(plan)
+
+        # initialize and start the progress manager
+        self.progress_manager = create_progress_manager(plan, self.num_samples, self.progress)
+        self.progress_manager.start()
+
+        # NOTE: we must handle progress manager outside of _exeecute_plan to ensure that it is shut down correctly;
+        # if we don't have the `finally:` branch, then program crashes can cause future program runs to fail
+        # because the progress manager cannot get a handle to the console
+        try:
+            # execute plan
+            output_records, plan_stats = self._execute_plan(plan, input_queues, plan_stats)
+
+        finally:
+            # finish progress tracking
+            self.progress_manager.finish()
+
+        logger.info(f"Done executing plan: {plan.plan_id}")
+        logger.debug(f"Plan stats: (plan_cost={plan_stats.total_plan_cost}, plan_time={plan_stats.total_plan_time})")
 
         return output_records, plan_stats
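To see why the pipelined loop needs `_upstream_ops_finished`, here is a small standalone simulation of the same draining pattern (plain Python, no palimpzest imports; the "scan", "filter", and "agg" names and data are invented for illustration): each pass moves one record per non-aggregate operator, and the aggregate only fires once every upstream queue has drained.

```python
# Standalone illustration (not package code) of the pipelined draining logic:
# the aggregate is skipped until every upstream queue is empty, mirroring
# _upstream_ops_finished above.
input_queues = {"scan": [1, 2, 3, 4], "filter": [], "agg": []}
plan = ["scan", "filter", "agg"]

def upstream_finished(op_idx: int) -> bool:
    return all(len(input_queues[plan[i]]) == 0 for i in range(op_idx))

outputs = []
while any(input_queues.values()):
    for op_idx, op in enumerate(plan):
        if not input_queues[op] or (op == "agg" and not upstream_finished(op_idx)):
            continue
        if op == "agg":
            # aggregates consume their entire input queue at once
            outputs.append(sum(input_queues[op]))
            input_queues[op].clear()
        else:
            # scan and filter process one input per pass
            record = input_queues[op].pop(0)
            if op == "scan" or record % 2 == 0:  # the "filter" keeps even records
                input_queues[plan[op_idx + 1]].append(record)

print(outputs)  # [6]: the aggregate ran only after scan and filter were drained
```

Without the gate, the aggregate would fire on a partial queue during the first pass in which it has any input.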
palimpzest/query/generators/api_client_factory.py (new file):

@@ -0,0 +1,31 @@
+from threading import Lock
+
+from openai import OpenAI
+from together import Together
+
+from palimpzest.constants import APIClient
+
+
+class APIClientFactory:
+    _instances = {}
+    _lock = Lock()
+
+    @classmethod
+    def get_client(cls, api_client: APIClient, api_key: str):
+        """Get a singleton instance of the requested API client."""
+        if api_client not in cls._instances:
+            with cls._lock:  # Ensure thread safety
+                if api_client not in cls._instances:  # Double-check inside the lock
+                    cls._instances[api_client] = cls._create_client(api_client, api_key)
+        return cls._instances[api_client]
+
+    @staticmethod
+    def _create_client(api_client: APIClient, api_key: str):
+        """Create a new client instance based on the api_client name."""
+        match api_client:
+            case APIClient.OPENAI:
+                return OpenAI(api_key=api_key)
+            case APIClient.TOGETHER:
+                return Together(api_key=api_key)
+            case _:
+                raise ValueError(f"Unknown api_client: {api_client}")
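The new factory caches one client per `APIClient` value behind a double-checked lock, so generator code can share clients across threads. A hedged usage sketch (the enum members follow the `APIClient.OPENAI` / `APIClient.TOGETHER` cases above; the key value is a placeholder):

```python
from palimpzest.constants import APIClient
from palimpzest.query.generators.api_client_factory import APIClientFactory

# the api_key is only used the first time a client for this APIClient is created;
# later calls return the cached instance
client_a = APIClientFactory.get_client(APIClient.OPENAI, api_key="sk-...")
client_b = APIClientFactory.get_client(APIClient.OPENAI, api_key="sk-...")
assert client_a is client_b  # same singleton instance
```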