palimpzest 0.6.3__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. palimpzest/__init__.py +5 -0
  2. palimpzest/constants.py +110 -43
  3. palimpzest/core/__init__.py +0 -78
  4. palimpzest/core/data/dataclasses.py +382 -44
  5. palimpzest/core/elements/filters.py +7 -3
  6. palimpzest/core/elements/index.py +70 -0
  7. palimpzest/core/elements/records.py +33 -11
  8. palimpzest/core/lib/fields.py +1 -0
  9. palimpzest/core/lib/schemas.py +4 -3
  10. palimpzest/prompts/moa_proposer_convert_prompts.py +0 -4
  11. palimpzest/prompts/prompt_factory.py +44 -7
  12. palimpzest/prompts/split_merge_prompts.py +56 -0
  13. palimpzest/prompts/split_proposer_prompts.py +55 -0
  14. palimpzest/query/execution/execution_strategy.py +435 -53
  15. palimpzest/query/execution/execution_strategy_type.py +20 -0
  16. palimpzest/query/execution/mab_execution_strategy.py +532 -0
  17. palimpzest/query/execution/parallel_execution_strategy.py +143 -172
  18. palimpzest/query/execution/random_sampling_execution_strategy.py +240 -0
  19. palimpzest/query/execution/single_threaded_execution_strategy.py +173 -203
  20. palimpzest/query/generators/api_client_factory.py +31 -0
  21. palimpzest/query/generators/generators.py +256 -76
  22. palimpzest/query/operators/__init__.py +1 -2
  23. palimpzest/query/operators/code_synthesis_convert.py +33 -18
  24. palimpzest/query/operators/convert.py +30 -97
  25. palimpzest/query/operators/critique_and_refine_convert.py +5 -6
  26. palimpzest/query/operators/filter.py +7 -10
  27. palimpzest/query/operators/logical.py +54 -10
  28. palimpzest/query/operators/map.py +130 -0
  29. palimpzest/query/operators/mixture_of_agents_convert.py +6 -6
  30. palimpzest/query/operators/physical.py +3 -12
  31. palimpzest/query/operators/rag_convert.py +66 -18
  32. palimpzest/query/operators/retrieve.py +230 -34
  33. palimpzest/query/operators/scan.py +5 -2
  34. palimpzest/query/operators/split_convert.py +169 -0
  35. palimpzest/query/operators/token_reduction_convert.py +8 -14
  36. palimpzest/query/optimizer/__init__.py +4 -16
  37. palimpzest/query/optimizer/cost_model.py +73 -266
  38. palimpzest/query/optimizer/optimizer.py +87 -58
  39. palimpzest/query/optimizer/optimizer_strategy.py +18 -97
  40. palimpzest/query/optimizer/optimizer_strategy_type.py +37 -0
  41. palimpzest/query/optimizer/plan.py +2 -3
  42. palimpzest/query/optimizer/primitives.py +5 -3
  43. palimpzest/query/optimizer/rules.py +336 -172
  44. palimpzest/query/optimizer/tasks.py +30 -100
  45. palimpzest/query/processor/config.py +38 -22
  46. palimpzest/query/processor/nosentinel_processor.py +16 -520
  47. palimpzest/query/processor/processing_strategy_type.py +28 -0
  48. palimpzest/query/processor/query_processor.py +38 -206
  49. palimpzest/query/processor/query_processor_factory.py +117 -130
  50. palimpzest/query/processor/sentinel_processor.py +90 -0
  51. palimpzest/query/processor/streaming_processor.py +25 -32
  52. palimpzest/sets.py +88 -41
  53. palimpzest/utils/model_helpers.py +8 -7
  54. palimpzest/utils/progress.py +368 -152
  55. palimpzest/utils/token_reduction_helpers.py +1 -3
  56. {palimpzest-0.6.3.dist-info → palimpzest-0.7.0.dist-info}/METADATA +28 -24
  57. palimpzest-0.7.0.dist-info/RECORD +96 -0
  58. {palimpzest-0.6.3.dist-info → palimpzest-0.7.0.dist-info}/WHEEL +1 -1
  59. palimpzest/query/processor/mab_sentinel_processor.py +0 -884
  60. palimpzest/query/processor/random_sampling_sentinel_processor.py +0 -639
  61. palimpzest/utils/index_helpers.py +0 -6
  62. palimpzest-0.6.3.dist-info/RECORD +0 -87
  63. {palimpzest-0.6.3.dist-info → palimpzest-0.7.0.dist-info/licenses}/LICENSE +0 -0
  64. {palimpzest-0.6.3.dist-info → palimpzest-0.7.0.dist-info}/top_level.txt +0 -0
palimpzest/query/execution/single_threaded_execution_strategy.py

@@ -1,13 +1,15 @@
-import time
+import logging

-from palimpzest.core.data.dataclasses import OperatorStats, PlanStats
+from palimpzest.core.data.dataclasses import PlanStats
+from palimpzest.core.elements.records import DataRecord
 from palimpzest.query.execution.execution_strategy import ExecutionStrategy
 from palimpzest.query.operators.aggregate import AggregateOp
-from palimpzest.query.operators.filter import FilterOp
 from palimpzest.query.operators.limit import LimitScanOp
 from palimpzest.query.operators.scan import ScanPhysicalOp
 from palimpzest.query.optimizer.plan import PhysicalPlan
+from palimpzest.utils.progress import create_progress_manager

+logger = logging.getLogger(__name__)

 class SequentialSingleThreadExecutionStrategy(ExecutionStrategy):
     """
@@ -21,113 +23,100 @@ class SequentialSingleThreadExecutionStrategy(ExecutionStrategy):
     """
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
+        self.max_workers = 1

-    def execute_plan(self, plan: PhysicalPlan, num_samples: int | float = float("inf"), plan_workers: int = 1):
-        """Initialize the stats and the execute the plan."""
-        if self.verbose:
-            print("----------------------")
-            print(f"PLAN[{plan.plan_id}] (n={num_samples}):")
-            print(plan)
-            print("---")
-
-        plan_start_time = time.time()
-
-        # initialize plan stats and operator stats
-        plan_stats = PlanStats(plan_id=plan.plan_id, plan_str=str(plan))
-        for op in plan.operators:
-            op_id = op.get_op_id()
-            op_name = op.op_name()
-            op_details = {k: str(v) for k, v in op.get_id_params().items()}
-            plan_stats.operator_stats[op_id] = OperatorStats(op_id=op_id, op_name=op_name, op_details=op_details)
-
-        # initialize list of output records and intermediate variables
-        output_records = []
-        current_scan_idx = self.scan_start_idx
-
-        # get handle to scan operator and pre-compute its size
-        source_operator = plan.operators[0]
-        assert isinstance(source_operator, ScanPhysicalOp), "First operator in physical plan must be a ScanPhysicalOp"
-        datareader_len = len(source_operator.datareader)
-
-        # initialize processing queues for each operation
-        processing_queues = {op.get_op_id(): [] for op in plan.operators if not isinstance(op, ScanPhysicalOp)}
-
+    def _execute_plan(self, plan: PhysicalPlan, input_queues: dict[str, list], plan_stats: PlanStats) -> tuple[list[DataRecord], PlanStats]:
         # execute the plan one operator at a time
+        output_records = []
         for op_idx, operator in enumerate(plan.operators):
+            # if we've filtered out all records, terminate early
             op_id = operator.get_op_id()
-            prev_op_id = plan.operators[op_idx - 1].get_op_id() if op_idx > 1 else None
-            next_op_id = plan.operators[op_idx + 1].get_op_id() if op_idx + 1 < len(plan.operators) else None
+            num_inputs = len(input_queues[op_id])
+            if num_inputs == 0:
+                break

-            # initialize output records and record_op_stats for this operator
+            # begin to process this operator
             records, record_op_stats = [], []
+            logger.info(f"Processing operator {operator.op_name()} ({op_id})")

-            # invoke scan operator(s) until we run out of source records or hit the num_samples limit
-            if isinstance(operator, ScanPhysicalOp):
-                keep_scanning_source_records = True
-                while keep_scanning_source_records:
-                    # run ScanPhysicalOp on current scan index
-                    record_set = operator(current_scan_idx)
-                    records.extend(record_set.data_records)
-                    record_op_stats.extend(record_set.record_op_stats)
-
-                    # update the current scan index
-                    current_scan_idx += 1
-
-                    # update whether to keep scanning source records
-                    keep_scanning_source_records = current_scan_idx < datareader_len and len(records) < num_samples
-
-            # aggregate operators accept all input records at once
-            elif isinstance(operator, AggregateOp):
-                record_set = operator(candidates=processing_queues[op_id])
+            # if this operator is an aggregate, process all the records in the input_queue
+            if isinstance(operator, AggregateOp):
+                record_set = operator(candidates=input_queues[op_id])
                 records = record_set.data_records
                 record_op_stats = record_set.record_op_stats
+                num_outputs = sum(record.passed_operator for record in records)
+
+                # update the progress manager
+                self.progress_manager.incr(op_id, num_outputs=num_outputs, total_cost=record_set.get_total_cost())

-            # otherwise, process the records in the processing queue for this operator one at a time
-            elif len(processing_queues[op_id]) > 0:
-                for input_record in processing_queues[op_id]:
+            # otherwise, process the records in the input queue for this operator one at a time
+            else:
+                for input_record in input_queues[op_id]:
                     record_set = operator(input_record)
                     records.extend(record_set.data_records)
                     record_op_stats.extend(record_set.record_op_stats)
+                    num_outputs = sum(record.passed_operator for record in record_set.data_records)
+
+                    # update the progress manager
+                    self.progress_manager.incr(op_id, num_outputs=num_outputs, total_cost=record_set.get_total_cost())

+                    # finish early if this is a limit
                     if isinstance(operator, LimitScanOp) and len(records) == operator.limit:
                         break

             # update plan stats
-            plan_stats.operator_stats[op_id].add_record_op_stats(
-                record_op_stats,
-                source_op_id=prev_op_id,
-                plan_id=plan.plan_id,
-            )
-
-            # add records (which are not filtered) to the cache, if allowed
-            if not self.nocache:
-                for record in records:
-                    if getattr(record, "passed_operator", True):
-                        # self.datadir.append_cache(operator.target_cache_id, record)
-                        pass
-
-            # update processing_queues or output_records
-            for record in records:
-                if isinstance(operator, FilterOp) and not record.passed_operator:
-                    continue
-                if next_op_id is not None:
-                    processing_queues[next_op_id].append(record)
-                else:
-                    output_records.append(record)
+            plan_stats.add_record_op_stats(record_op_stats)

-            # if we've filtered out all records, terminate early
-            if next_op_id is not None and processing_queues[next_op_id] == []:
-                break
+            # add records to the cache
+            self._add_records_to_cache(operator.target_cache_id, records)

-        # if caching was allowed, close the cache
-        if not self.nocache:
-            for _ in plan.operators:
-                # self.datadir.close_cache(operator.target_cache_id)
-                pass
+            # update next input_queue (if it exists)
+            output_records = [record for record in records if record.passed_operator]
+            if op_idx + 1 < len(plan.operators):
+                next_op_id = plan.operators[op_idx + 1].get_op_id()
+                input_queues[next_op_id] = output_records
+
+            logger.info(f"Finished processing operator {operator.op_name()} ({operator.get_op_id()}), and generated {len(records)} records")
+
+        # close the cache
+        self._close_cache([op.target_cache_id for op in plan.operators])

         # finalize plan stats
-        total_plan_time = time.time() - plan_start_time
-        plan_stats.finalize(total_plan_time)
+        plan_stats.finish()
+
+        return output_records, plan_stats
+
+    def execute_plan(self, plan: PhysicalPlan) -> tuple[list[DataRecord], PlanStats]:
+        """Initialize the stats and execute the plan."""
+        # for now, assert that the first operator in the plan is a ScanPhysicalOp
+        assert isinstance(plan.operators[0], ScanPhysicalOp), "First operator in physical plan must be a ScanPhysicalOp"
+        logger.info(f"Executing plan {plan.plan_id} with {self.max_workers} workers")
+        logger.info(f"Plan Details: {plan}")
+
+        # initialize plan stats
+        plan_stats = PlanStats.from_plan(plan)
+        plan_stats.start()
+
+        # initialize input queues for each operation
+        input_queues = self._create_input_queues(plan)
+
+        # initialize and start the progress manager
+        self.progress_manager = create_progress_manager(plan, num_samples=self.num_samples, progress=self.progress)
+        self.progress_manager.start()
+
+        # NOTE: we must handle progress manager outside of _execute_plan to ensure that it is shut down correctly;
+        # if we don't have the `finally:` branch, then program crashes can cause future program runs to fail
+        # because the progress manager cannot get a handle to the console
+        try:
+            # execute plan
+            output_records, plan_stats = self._execute_plan(plan, input_queues, plan_stats)
+
+        finally:
+            # finish progress tracking
+            self.progress_manager.finish()
+
+        logger.info(f"Done executing plan: {plan.plan_id}")
+        logger.debug(f"Plan stats: (plan_cost={plan_stats.total_plan_cost}, plan_time={plan_stats.total_plan_time})")

         return output_records, plan_stats
@@ -148,137 +137,118 @@ class PipelinedSingleThreadExecutionStrategy(ExecutionStrategy):

     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.max_workers = 1 if self.max_workers is None else self.max_workers
-
-    def execute_plan(self, plan: PhysicalPlan, num_samples: int | float = float("inf"), plan_workers: int = 1):
-        """Initialize the stats and the execute the plan."""
-        if self.verbose:
-            print("----------------------")
-            print(f"PLAN[{plan.plan_id}] (n={num_samples}):")
-            print(plan)
-            print("---")
-
-        plan_start_time = time.time()
-
-        # initialize plan stats and operator stats
-        plan_stats = PlanStats(plan_id=plan.plan_id, plan_str=str(plan))
-        for op in plan.operators:
-            op_id = op.get_op_id()
-            op_name = op.op_name()
-            op_details = {k: str(v) for k, v in op.get_id_params().items()}
-            plan_stats.operator_stats[op_id] = OperatorStats(op_id=op_id, op_name=op_name, op_details=op_details)
-
-        # initialize list of output records and intermediate variables
-        output_records = []
-        source_records_scanned = 0
-        current_scan_idx = self.scan_start_idx
+        self.max_workers = 1
+
+    def _any_queue_not_empty(self, queues: dict[str, list]) -> bool:
+        """Helper function to check if any queue is not empty."""
+        return any(len(queue) > 0 for queue in queues.values())

-        # get handle to scan operator and pre-compute its size
-        source_operator = plan.operators[0]
-        assert isinstance(source_operator, ScanPhysicalOp), "First operator in physical plan must be a ScanPhysicalOp"
-        datareader_len = len(source_operator.datareader)
+    def _upstream_ops_finished(self, plan: PhysicalPlan, op_idx: int, input_queues: dict[str, list]) -> bool:
+        """Helper function to check if all upstream operators have finished processing their inputs."""
+        for upstream_op_idx in range(op_idx):
+            upstream_op_id = plan.operators[upstream_op_idx].get_op_id()
+            if len(input_queues[upstream_op_id]) > 0:
+                return False

-        # initialize processing queues for each operation
-        processing_queues = {op.get_op_id(): [] for op in plan.operators if not isinstance(op, ScanPhysicalOp)}
+        return True

+    def _execute_plan(self, plan: PhysicalPlan, input_queues: dict[str, list], plan_stats: PlanStats) -> tuple[list[DataRecord], PlanStats]:
         # execute the plan until either:
         # 1. all records have been processed, or
-        # 2. the final limit operation has completed
-        finished_executing, keep_scanning_source_records = False, True
-        while not finished_executing:
+        # 2. the final limit operation has completed (we break out of the loop if this happens)
+        final_output_records = []
+        while self._any_queue_not_empty(input_queues):
             for op_idx, operator in enumerate(plan.operators):
+                # if this operator does not have enough inputs to execute, then skip it
                 op_id = operator.get_op_id()
-
-                prev_op_id = plan.operators[op_idx - 1].get_op_id() if op_idx > 1 else None
-                next_op_id = plan.operators[op_idx + 1].get_op_id() if op_idx + 1 < len(plan.operators) else None
+                num_inputs = len(input_queues[op_id])
+                agg_op_not_ready = isinstance(operator, AggregateOp) and not self._upstream_ops_finished(plan, op_idx, input_queues)
+                if num_inputs == 0 or agg_op_not_ready:
+                    continue

                 # create empty lists for records and execution stats generated by executing this operator on its next input(s)
                 records, record_op_stats = [], []

-                # invoke scan operator(s) until we run out of source records or hit the num_samples limit
-                if isinstance(operator, ScanPhysicalOp):
-                    if keep_scanning_source_records:
-                        # run ScanPhysicalOp on current scan index
-                        record_set = operator(current_scan_idx)
-                        records = record_set.data_records
-                        record_op_stats = record_set.record_op_stats
-
-                        # update number of source records scanned and the current index
-                        source_records_scanned += len(records)
-                        current_scan_idx += 1
-                    else:
-                        continue
-
-                # only invoke aggregate operator(s) once there are no more source records and all
-                # upstream operators' processing queues are empty
-                elif isinstance(operator, AggregateOp):
-                    upstream_ops_are_finished = True
-                    for upstream_op_idx in range(op_idx):
-                        # scan operators do not have processing queues
-                        if isinstance(plan.operators[upstream_op_idx], ScanPhysicalOp):
-                            continue
-
-                        # check upstream ops which do have a processing queue
-                        upstream_op_id = plan.operators[upstream_op_idx].get_op_id()
-                        upstream_ops_are_finished = (
-                            upstream_ops_are_finished and len(processing_queues[upstream_op_id]) == 0
-                        )
-
-                    if not keep_scanning_source_records and upstream_ops_are_finished:
-                        record_set = operator(candidates=processing_queues[op_id])
-                        records = record_set.data_records
-                        record_op_stats = record_set.record_op_stats
-                        processing_queues[op_id] = []
-
-                # otherwise, process the next record in the processing queue for this operator
-                elif len(processing_queues[op_id]) > 0:
-                    input_record = processing_queues[op_id].pop(0)
+                # if the next operator is an aggregate, process all the records in the input_queue
+                if isinstance(operator, AggregateOp):
+                    input_records = [input_queues[op_id].pop(0) for _ in range(num_inputs)]
+                    record_set = operator(candidates=input_records)
+                    records = record_set.data_records
+                    record_op_stats = record_set.record_op_stats
+                    num_outputs = sum(record.passed_operator for record in records)
+
+                    # update the progress manager
+                    self.progress_manager.incr(op_id, num_outputs=num_outputs, total_cost=record_set.get_total_cost())
+
+                # otherwise, process the next record in the input queue for this operator
+                else:
+                    input_record = input_queues[op_id].pop(0)
                     record_set = operator(input_record)
                     records = record_set.data_records
                     record_op_stats = record_set.record_op_stats
+                    num_outputs = sum(record.passed_operator for record in records)

-                # if records were generated by this operator, process them
-                if len(records) > 0:
-                    # update plan stats
-                    plan_stats.operator_stats[op_id].add_record_op_stats(
-                        record_op_stats,
-                        source_op_id=prev_op_id,
-                        plan_id=plan.plan_id,
-                    )
-
-                    # add records (which are not filtered) to the cache, if allowed
-                    if not self.nocache:
-                        for record in records:
-                            if getattr(record, "passed_operator", True):
-                                # self.datadir.append_cache(operator.target_cache_id, record)
-                                pass
-
-                    # update processing_queues or output_records
-                    for record in records:
-                        if isinstance(operator, FilterOp) and not record.passed_operator:
-                            continue
-                        if next_op_id is not None:
-                            processing_queues[next_op_id].append(record)
-                        else:
-                            output_records.append(record)
-
-            # update finished_executing based on whether all records have been processed
-            still_processing = any([len(queue) > 0 for queue in processing_queues.values()])
-            keep_scanning_source_records = current_scan_idx < datareader_len and source_records_scanned < num_samples
-            finished_executing = not keep_scanning_source_records and not still_processing
-
-            # update finished_executing based on limit
-            if isinstance(operator, LimitScanOp):
-                finished_executing = len(output_records) == operator.limit
-
-        # if caching was allowed, close the cache
-        if not self.nocache:
-            for _ in plan.operators:
-                # self.datadir.close_cache(operator.target_cache_id)
-                pass
+                    # update the progress manager
+                    self.progress_manager.incr(op_id, num_outputs=num_outputs, total_cost=record_set.get_total_cost())
+
+                # update plan stats
+                plan_stats.add_record_op_stats(record_op_stats)
+
+                # add records to the cache
+                self._add_records_to_cache(operator.target_cache_id, records)
+
+                # update next input_queue or final_output_records
+                output_records = [record for record in records if record.passed_operator]
+                if op_idx + 1 < len(plan.operators):
+                    next_op_id = plan.operators[op_idx + 1].get_op_id()
+                    input_queues[next_op_id].extend(output_records)
+                else:
+                    final_output_records.extend(output_records)
+
+                logger.info(f"Finished processing operator {operator.op_name()} ({operator.get_op_id()}) on {num_inputs} records")
+
+            # break out of loop if the final operator is a LimitScanOp and we've reached its limit
+            if isinstance(plan.operators[-1], LimitScanOp) and len(final_output_records) == plan.operators[-1].limit:
+                break
+
+        # close the cache
+        self._close_cache([op.target_cache_id for op in plan.operators])

         # finalize plan stats
-        total_plan_time = time.time() - plan_start_time
-        plan_stats.finalize(total_plan_time)
+        plan_stats.finish()
+
+        return final_output_records, plan_stats
+
+    def execute_plan(self, plan: PhysicalPlan):
+        """Initialize the stats and execute the plan."""
+        # for now, assert that the first operator in the plan is a ScanPhysicalOp
+        assert isinstance(plan.operators[0], ScanPhysicalOp), "First operator in physical plan must be a ScanPhysicalOp"
+        logger.info(f"Executing plan {plan.plan_id} with {self.max_workers} workers")
+        logger.info(f"Plan Details: {plan}")
+
+        # initialize plan stats
+        plan_stats = PlanStats.from_plan(plan)
+        plan_stats.start()
+
+        # initialize input queues for each operation
+        input_queues = self._create_input_queues(plan)
+
+        # initialize and start the progress manager
+        self.progress_manager = create_progress_manager(plan, self.num_samples, self.progress)
+        self.progress_manager.start()
+
+        # NOTE: we must handle progress manager outside of _execute_plan to ensure that it is shut down correctly;
+        # if we don't have the `finally:` branch, then program crashes can cause future program runs to fail
+        # because the progress manager cannot get a handle to the console
+        try:
+            # execute plan
+            output_records, plan_stats = self._execute_plan(plan, input_queues, plan_stats)
+
+        finally:
+            # finish progress tracking
+            self.progress_manager.finish()
+
+        logger.info(f"Done executing plan: {plan.plan_id}")
+        logger.debug(f"Plan stats: (plan_cost={plan_stats.total_plan_cost}, plan_time={plan_stats.total_plan_time})")

         return output_records, plan_stats
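The pipelined strategy's scheduling rule differs from the sequential one: operators are visited round-robin, each consuming from its input queue as items arrive, and an aggregate only fires once every upstream queue has drained. A toy simulation (illustrative only, with hypothetical stage names; no palimpzest types involved) of that rule:

    ops = ["scan", "convert", "aggregate"]
    queues = {"scan": [0, 1, 2], "convert": [], "aggregate": []}
    outputs = []

    def upstream_empty(op):
        # mirrors _upstream_ops_finished: all queues before this op are drained
        return all(len(queues[o]) == 0 for o in ops[:ops.index(op)])

    while any(queues.values()):
        for op in ops:
            if not queues[op] or (op == "aggregate" and not upstream_empty(op)):
                continue
            if op == "aggregate":
                outputs.append(sum(queues[op]))  # aggregates consume their whole queue at once
                queues[op].clear()
            else:
                item = queues[op].pop(0)
                queues[ops[ops.index(op) + 1]].append(item + 1 if op == "convert" else item)

    print(outputs)  # [6]: the aggregate fired exactly once, after all upstream work drained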
palimpzest/query/generators/api_client_factory.py (new file)

@@ -0,0 +1,31 @@
+from threading import Lock
+
+from openai import OpenAI
+from together import Together
+
+from palimpzest.constants import APIClient
+
+
+class APIClientFactory:
+    _instances = {}
+    _lock = Lock()
+
+    @classmethod
+    def get_client(cls, api_client: APIClient, api_key: str):
+        """Get a singleton instance of the requested API client."""
+        if api_client not in cls._instances:
+            with cls._lock:  # Ensure thread safety
+                if api_client not in cls._instances:  # Double-check inside the lock
+                    cls._instances[api_client] = cls._create_client(api_client, api_key)
+        return cls._instances[api_client]
+
+    @staticmethod
+    def _create_client(api_client: APIClient, api_key: str):
+        """Create a new client instance based on the api_client name."""
+        match api_client:
+            case APIClient.OPENAI:
+                return OpenAI(api_key=api_key)
+            case APIClient.TOGETHER:
+                return Together(api_key=api_key)
+            case _:
+                raise ValueError(f"Unknown api_client: {api_client}")
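This new factory uses double-checked locking: the unlocked membership test makes the common path cheap, and the second test inside the lock prevents two threads from racing to construct the same client. A usage sketch (assumes an OPENAI_API_KEY environment variable; APIClient.OPENAI comes from palimpzest/constants.py per this diff):

    import os

    from palimpzest.constants import APIClient
    from palimpzest.query.generators.api_client_factory import APIClientFactory

    client_a = APIClientFactory.get_client(APIClient.OPENAI, api_key=os.environ["OPENAI_API_KEY"])
    client_b = APIClientFactory.get_client(APIClient.OPENAI, api_key="ignored-after-first-call")
    assert client_a is client_b  # one cached client per APIClient value

Note that because `_instances` is keyed on `api_client` alone, the `api_key` from the first call for a given client type wins for the life of the process; later calls with a different key return the originally constructed client.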