palimpzest 0.6.3__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of this package as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (64)
  1. palimpzest/__init__.py +5 -0
  2. palimpzest/constants.py +110 -43
  3. palimpzest/core/__init__.py +0 -78
  4. palimpzest/core/data/dataclasses.py +382 -44
  5. palimpzest/core/elements/filters.py +7 -3
  6. palimpzest/core/elements/index.py +70 -0
  7. palimpzest/core/elements/records.py +33 -11
  8. palimpzest/core/lib/fields.py +1 -0
  9. palimpzest/core/lib/schemas.py +4 -3
  10. palimpzest/prompts/moa_proposer_convert_prompts.py +0 -4
  11. palimpzest/prompts/prompt_factory.py +44 -7
  12. palimpzest/prompts/split_merge_prompts.py +56 -0
  13. palimpzest/prompts/split_proposer_prompts.py +55 -0
  14. palimpzest/query/execution/execution_strategy.py +435 -53
  15. palimpzest/query/execution/execution_strategy_type.py +20 -0
  16. palimpzest/query/execution/mab_execution_strategy.py +532 -0
  17. palimpzest/query/execution/parallel_execution_strategy.py +143 -172
  18. palimpzest/query/execution/random_sampling_execution_strategy.py +240 -0
  19. palimpzest/query/execution/single_threaded_execution_strategy.py +173 -203
  20. palimpzest/query/generators/api_client_factory.py +31 -0
  21. palimpzest/query/generators/generators.py +256 -76
  22. palimpzest/query/operators/__init__.py +1 -2
  23. palimpzest/query/operators/code_synthesis_convert.py +33 -18
  24. palimpzest/query/operators/convert.py +30 -97
  25. palimpzest/query/operators/critique_and_refine_convert.py +5 -6
  26. palimpzest/query/operators/filter.py +7 -10
  27. palimpzest/query/operators/logical.py +54 -10
  28. palimpzest/query/operators/map.py +130 -0
  29. palimpzest/query/operators/mixture_of_agents_convert.py +6 -6
  30. palimpzest/query/operators/physical.py +3 -12
  31. palimpzest/query/operators/rag_convert.py +66 -18
  32. palimpzest/query/operators/retrieve.py +230 -34
  33. palimpzest/query/operators/scan.py +5 -2
  34. palimpzest/query/operators/split_convert.py +169 -0
  35. palimpzest/query/operators/token_reduction_convert.py +8 -14
  36. palimpzest/query/optimizer/__init__.py +4 -16
  37. palimpzest/query/optimizer/cost_model.py +73 -266
  38. palimpzest/query/optimizer/optimizer.py +87 -58
  39. palimpzest/query/optimizer/optimizer_strategy.py +18 -97
  40. palimpzest/query/optimizer/optimizer_strategy_type.py +37 -0
  41. palimpzest/query/optimizer/plan.py +2 -3
  42. palimpzest/query/optimizer/primitives.py +5 -3
  43. palimpzest/query/optimizer/rules.py +336 -172
  44. palimpzest/query/optimizer/tasks.py +30 -100
  45. palimpzest/query/processor/config.py +38 -22
  46. palimpzest/query/processor/nosentinel_processor.py +16 -520
  47. palimpzest/query/processor/processing_strategy_type.py +28 -0
  48. palimpzest/query/processor/query_processor.py +38 -206
  49. palimpzest/query/processor/query_processor_factory.py +117 -130
  50. palimpzest/query/processor/sentinel_processor.py +90 -0
  51. palimpzest/query/processor/streaming_processor.py +25 -32
  52. palimpzest/sets.py +88 -41
  53. palimpzest/utils/model_helpers.py +8 -7
  54. palimpzest/utils/progress.py +368 -152
  55. palimpzest/utils/token_reduction_helpers.py +1 -3
  56. {palimpzest-0.6.3.dist-info → palimpzest-0.7.0.dist-info}/METADATA +28 -24
  57. palimpzest-0.7.0.dist-info/RECORD +96 -0
  58. {palimpzest-0.6.3.dist-info → palimpzest-0.7.0.dist-info}/WHEEL +1 -1
  59. palimpzest/query/processor/mab_sentinel_processor.py +0 -884
  60. palimpzest/query/processor/random_sampling_sentinel_processor.py +0 -639
  61. palimpzest/utils/index_helpers.py +0 -6
  62. palimpzest-0.6.3.dist-info/RECORD +0 -87
  63. {palimpzest-0.6.3.dist-info → palimpzest-0.7.0.dist-info/licenses}/LICENSE +0 -0
  64. {palimpzest-0.6.3.dist-info → palimpzest-0.7.0.dist-info}/top_level.txt +0 -0
palimpzest/query/execution/execution_strategy.py
@@ -1,71 +1,453 @@
- import time
+ import logging
  from abc import ABC, abstractmethod
- from enum import Enum
+ from concurrent.futures import ThreadPoolExecutor, wait

- from palimpzest.core.data.dataclasses import ExecutionStats, PlanStats
- from palimpzest.core.elements.records import DataRecord
- from palimpzest.query.optimizer.plan import PhysicalPlan
+ import numpy as np
+ from chromadb.api.models.Collection import Collection

+ from palimpzest.constants import PARALLEL_EXECUTION_SLEEP_INTERVAL_SECS
+ from palimpzest.core.data.dataclasses import OperatorCostEstimates, PlanStats, RecordOpStats
+ from palimpzest.core.data.datareaders import DataReader
+ from palimpzest.core.elements.records import DataRecord, DataRecordSet
+ from palimpzest.policy import Policy
+ from palimpzest.query.operators.convert import LLMConvert
+ from palimpzest.query.operators.filter import FilterOp, LLMFilter
+ from palimpzest.query.operators.physical import PhysicalOperator
+ from palimpzest.query.operators.retrieve import RetrieveOp
+ from palimpzest.query.operators.scan import ScanPhysicalOp
+ from palimpzest.query.optimizer.plan import PhysicalPlan, SentinelPlan
+ from palimpzest.utils.progress import PZSentinelProgressManager

- class ExecutionStrategyType(str, Enum):
-     """Available execution strategy types"""
-     SEQUENTIAL = "sequential"
-     PIPELINED_SINGLE_THREAD = "pipelined"
-     PIPELINED_PARALLEL = "pipelined_parallel"
-     AUTO = "auto"
+ logger = logging.getLogger(__name__)

-
- class ExecutionStrategy(ABC):
-     """
-     Base strategy for executing query plans.
-     Defines how to execute a single plan.
-     """
-     def __init__(self,
+ class BaseExecutionStrategy:
+     def __init__(self,
                    scan_start_idx: int = 0,
                    max_workers: int | None = None,
-                  nocache: bool = True,
-                  verbose: bool = False):
+                  num_samples: int | None = None,
+                  cache: bool = False,
+                  verbose: bool = False,
+                  progress: bool = True,
+                  *args,
+                  **kwargs):
          self.scan_start_idx = scan_start_idx
-         self.nocache = nocache
-         self.verbose = verbose
          self.max_workers = max_workers
-         self.execution_stats = []
+         self.num_samples = num_samples
+         self.cache = cache
+         self.verbose = verbose
+         self.progress = progress
+

+     def _add_records_to_cache(self, target_cache_id: str, records: list[DataRecord]) -> None:
+         """Add each record (which isn't filtered) to the cache for the given target_cache_id."""
+         if self.cache:
+             for record in records:
+                 if getattr(record, "passed_operator", True):
+                     # self.datadir.append_cache(target_cache_id, record)
+                     pass
+
+     def _close_cache(self, target_cache_ids: list[str]) -> None:
+         """Close the cache for each of the given target_cache_ids"""
+         if self.cache:
+             for target_cache_id in target_cache_ids:  # noqa: B007
+                 # self.datadir.close_cache(target_cache_id)
+                 pass
+
+ class ExecutionStrategy(BaseExecutionStrategy, ABC):
+     """Base strategy for executing query plans. Defines how to execute a PhysicalPlan.
+     """
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         logger.info(f"Initialized ExecutionStrategy {self.__class__.__name__}")
+         logger.debug(f"ExecutionStrategy initialized with config: {self.__dict__}")

      @abstractmethod
-     def execute_plan(
-         self,
-         plan: PhysicalPlan,
-         num_samples: int | float = float("inf"),
-         workers: int = 1
-     ) -> tuple[list[DataRecord], PlanStats]:
+     def execute_plan(self, plan: PhysicalPlan) -> tuple[list[DataRecord], PlanStats]:
          """Execute a single plan according to strategy"""
          pass

+     def _create_input_queues(self, plan: PhysicalPlan) -> dict[str, list]:
+         """Initialize input queues for each operator in the plan."""
+         input_queues = {}
+         for op in plan.operators:
+             inputs = []
+             if isinstance(op, ScanPhysicalOp):
+                 scan_end_idx = (
+                     len(op.datareader)
+                     if self.num_samples is None
+                     else min(self.scan_start_idx + self.num_samples, len(op.datareader))
+                 )
+                 inputs = [idx for idx in range(self.scan_start_idx, scan_end_idx)]
+             input_queues[op.get_op_id()] = inputs
+
+         return input_queues

-     # TODO(chjun): use _create_execution_stats for execution stats setup.
-     ## aggregate plan stats
-     # aggregate_plan_stats = self.aggregate_plan_stats(plan_stats)
-
-     # # add sentinel records and plan stats (if captured) to plan execution data
-     # execution_stats = ExecutionStats(
-     #     execution_id=self.execution_id(),
-     #     plan_stats=aggregate_plan_stats,
-     #     total_execution_time=time.time() - execution_start_time,
-     #     total_execution_cost=sum(
-     #         list(map(lambda plan_stats: plan_stats.total_plan_cost, aggregate_plan_stats.values()))
-     #     ),
-     #     plan_strs={plan_id: plan_stats.plan_str for plan_id, plan_stats in aggregate_plan_stats.items()},
-     # )
-     def _create_execution_stats(
+ class SentinelExecutionStrategy(BaseExecutionStrategy, ABC):
+     """Base strategy for executing sentinel query plans. Defines how to execute a SentinelPlan."""
+     """
+     Specialized query processor that implements MAB sentinel strategy
+     for coordinating optimization and execution.
+     """
+     def __init__(
          self,
-         plan_stats: list[PlanStats],
-         start_time: float
-     ) -> ExecutionStats:
-         """Create execution statistics"""
-         return ExecutionStats(
-             execution_id=f"exec_{int(start_time)}",
-             plan_stats={ps.plan_id: ps for ps in plan_stats},
-             total_execution_time=time.time() - start_time,
-             total_execution_cost=sum(ps.total_cost for ps in plan_stats)
+         val_datasource: DataReader,
+         k: int,
+         j: int,
+         sample_budget: int,
+         policy: Policy,
+         use_final_op_quality: bool = False,
+         seed: int = 42,
+         exp_name: str | None = None,
+         *args,
+         **kwargs,
+     ):
+         super().__init__(*args, **kwargs)
+         self.val_datasource = val_datasource
+         self.k = k
+         self.j = j
+         self.sample_budget = sample_budget
+         self.policy = policy
+         self.use_final_op_quality = use_final_op_quality
+         self.seed = seed
+         self.rng = np.random.default_rng(seed=seed)
+         self.exp_name = exp_name
+
+         # special cache which is used for tracking the target record sets for each (source_idx, logical_op_id)
+         self.champion_output_cache: dict[int, dict[str, tuple[DataRecordSet, float]]] = {}
+
+         # general cache which maps hash(logical_op_id, phys_op_id, hash(input)) --> record_set
+         self.cache: dict[int, DataRecordSet] = {}
+
+         # progress manager used to track progress of the execution
+         self.progress_manager: PZSentinelProgressManager | None = None
+
+     def _compute_quality(
+         self,
+         physical_op_cls: type[PhysicalOperator],
+         record_set: DataRecordSet,
+         target_record_set: DataRecordSet,
+     ) -> DataRecordSet:
+         """
+         Compute the quality for the given `record_set` by comparing it to the `target_record_set`.
+
+         Update the record_set by assigning the quality to each entry in its record_op_stats and
+         returning the updated record_set.
+         """
+         # if this operation failed
+         if len(record_set) == 0:
+             record_set.record_op_stats[0].quality = 0.0
+
+         # if this operation is a filter:
+         # - return 1.0 if there's a match in the expected output which this operator does not filter out and 0.0 otherwise
+         elif issubclass(physical_op_cls, FilterOp):
+             # NOTE: we know that record_set.data_records will contain a single entry for a filter op
+             record = record_set.data_records[0]
+
+             # search for a record in the target with the same set of fields
+             found_match_in_target = False
+             for target_record in target_record_set:
+                 all_correct = True
+                 for field, value in record.field_values.items():
+                     if value != target_record[field]:
+                         all_correct = False
+                         break
+
+                 if all_correct:
+                     found_match_in_target = target_record.passed_operator
+                     break
+
+             # set quality based on whether we found a match in the target and return
+             record_set.record_op_stats[0].quality = int(record.passed_operator == found_match_in_target)
+
+             return record_set
+
+         # if this is a successful convert operation
+         else:
+             # NOTE: the following computation assumes we do not project out computed values
+             # (and that the validation examples provide all computed fields); even if
+             # a user program does add projection, we can ignore the projection on the
+             # validation dataset and use the champion model (as opposed to the validation
+             # output) for scoring fields which have their values projected out
+
+             # GREEDY ALGORITHM
+             # for each record in the expected output, we look for the computed record which maximizes the quality metric;
+             # once we've identified that computed record we remove it from consideration for the next expected output
+             field_to_score_fn = target_record_set.get_field_to_score_fn()
+             for target_record in target_record_set:
+                 best_quality, best_record_op_stats = 0.0, None
+                 for record_op_stats in record_set.record_op_stats:
+                     # if we already assigned this record a quality, skip it
+                     if record_op_stats.quality is not None:
+                         continue
+
+                     # compute number of matches between this record's computed fields and this expected record's outputs
+                     total_quality = 0
+                     for field in record_op_stats.generated_fields:
+                         computed_value = record_op_stats.record_state.get(field, None)
+                         expected_value = target_record[field]
+
+                         # get the metric function for this field
+                         score_fn = field_to_score_fn.get(field, "exact")
+
+                         # compute exact match
+                         if score_fn == "exact":
+                             total_quality += int(computed_value == expected_value)
+
+                         # compute UDF metric
+                         elif callable(score_fn):
+                             total_quality += score_fn(computed_value, expected_value)
+
+                         # otherwise, throw an exception
+                         else:
+                             raise Exception(f"Unrecognized score_fn: {score_fn}")
+
+                     # compute recall and update best seen so far
+                     quality = total_quality / len(record_op_stats.generated_fields)
+                     if quality > best_quality:
+                         best_quality = quality
+                         best_record_op_stats = record_op_stats
+
+                 # set best_quality as quality for the best_record_op_stats
+                 if best_record_op_stats is not None:
+                     best_record_op_stats.quality = best_quality
+
+             # for any records which did not receive a quality, set it to 0.0 as these are unexpected extras
+             for record_op_stats in record_set.record_op_stats:
+                 if record_op_stats.quality is None:
+                     record_op_stats.quality = 0.0
+
+             return record_set
+
+     def _score_quality(
+         self,
+         physical_op_cls: type[PhysicalOperator],
+         source_idx_to_record_sets: dict[int, list[DataRecordSet]],
+         source_idx_to_target_record_set: dict[int, DataRecordSet],
+     ) -> dict[int, list[DataRecordSet]]:
+         """
+         NOTE: This approach to cost modeling does not work directly for aggregation queries;
+         for these queries, we would ask the user to provide validation data for the step immediately
+         before a final aggregation
+
+         NOTE: This function currently assumes that one-to-many converts do NOT create duplicate outputs.
+         This assumption would break if, for example, we extracted the breed of every dog in an image.
+         If there were two golden retrievers and a bernoodle in an image and we extracted:
+
+             {"image": "file1.png", "breed": "Golden Retriever"}
+             {"image": "file1.png", "breed": "Golden Retriever"}
+             {"image": "file1.png", "breed": "Bernedoodle"}
+
+         This function would currently give perfect accuracy to the following output:
+
+             {"image": "file1.png", "breed": "Golden Retriever"}
+             {"image": "file1.png", "breed": "Bernedoodle"}
+
+         Even though it is missing one of the golden retrievers.
+         """
+         # extract information about the logical operation performed at this stage of the sentinel plan;
+         # NOTE: we can infer these fields from context clues, but in the long-term we should have a more
+         # principled way of getting these directly from attributes either stored in the sentinel_plan
+         # or in the PhysicalOperator
+         is_perfect_quality_op = (
+             not issubclass(physical_op_cls, LLMConvert)
+             and not issubclass(physical_op_cls, LLMFilter)
+             and not issubclass(physical_op_cls, RetrieveOp)
          )
+
+         # compute quality of each output computed by this operator
+         for source_idx, record_sets in source_idx_to_record_sets.items():
+             # if this operation does not involve an LLM, every record_op_stats object gets perfect quality
+             if is_perfect_quality_op:
+                 for record_set in record_sets:
+                     for record_op_stats in record_set.record_op_stats:
+                         record_op_stats.quality = 1.0
+                 continue
+
+             # extract target output for this record set
+             target_record_set = source_idx_to_target_record_set[source_idx]
+
+             # for each record_set produced by an operation, compute its quality
+             for record_set in record_sets:
+                 record_set = self._compute_quality(physical_op_cls, record_set, target_record_set)
+
+         # return the quality annotated record sets
+         return source_idx_to_record_sets
+
+     def _get_target_record_sets(
+         self,
+         logical_op_id: str,
+         source_idx_to_record_set_tuples: dict[int, list[tuple[DataRecordSet, PhysicalOperator, bool]]],
+         expected_outputs: dict[int, dict] | None,
+     ) -> dict[int, DataRecordSet]:
+         # initialize mapping from source index to target record sets
+         source_idx_to_target_record_set = {}
+
+         for source_idx, record_set_tuples in source_idx_to_record_set_tuples.items():
+             # get the first generated output for this source_idx
+             base_target_record = None
+             for record_set, _, _ in record_set_tuples:
+                 if len(record_set) > 0:
+                     base_target_record = record_set[0]
+                     break
+
+             # compute availability of data
+             base_target_present = base_target_record is not None
+             labels_present = expected_outputs is not None
+             labels_for_source_present = False
+             if labels_present and source_idx in expected_outputs:
+                 labels = expected_outputs[source_idx].get("labels", [])
+                 labels_dict_lst = labels if isinstance(labels, list) else [labels]
+                 labels_for_source_present = labels_dict_lst != [] and labels_dict_lst != [None]
+
+             # if we have a base target record and label info, use the label info to construct the target record set
+             if base_target_present and labels_for_source_present:
+                 # get the field_to_score_fn
+                 field_to_score_fn = expected_outputs[source_idx].get("score_fn", {})
+
+                 # construct the target record set; we force passed_operator to be True for all target records
+                 target_records = []
+                 for labels_dict in labels_dict_lst:
+                     target_record = base_target_record.copy()
+                     for field, value in labels_dict.items():
+                         target_record[field] = value
+                     target_record.passed_operator = True
+                     target_records.append(target_record)
+
+                 source_idx_to_target_record_set[source_idx] = DataRecordSet(target_records, None, field_to_score_fn)
+                 continue
+
+             # get the best computed output for this (source_idx, logical_op_id) so far (if one exists)
+             champion_record_set, champion_op_quality = None, None
+             if source_idx in self.champion_output_cache and logical_op_id in self.champion_output_cache[source_idx]:
+                 champion_record_set, champion_op_quality = self.champion_output_cache[source_idx][logical_op_id]
+
+             # get the highest quality output that we just computed
+             max_quality_record_set, max_op_quality = self._pick_champion_output(record_set_tuples)
+
+             # if this new output is of higher quality than our previous champion (or if we didn't have
+             # a previous champion) then we update our champion record set
+             if champion_op_quality is None or (max_op_quality is not None and max_op_quality > champion_op_quality):
+                 champion_record_set, champion_op_quality = max_quality_record_set, max_op_quality
+
+             # update the cache with the new champion record set and quality
+             if source_idx not in self.champion_output_cache:
+                 self.champion_output_cache[source_idx] = {}
+             self.champion_output_cache[source_idx][logical_op_id] = (champion_record_set, champion_op_quality)
+
+             # set the target
+             source_idx_to_target_record_set[source_idx] = champion_record_set
+
+         return source_idx_to_target_record_set
+
+     def _pick_champion_output(self, record_set_tuples: list[tuple[DataRecordSet, PhysicalOperator, bool]]) -> tuple[DataRecordSet, float | None]:
+         # find the operator with the highest estimated quality and return its record_set
+         base_op_cost_est = OperatorCostEstimates(cardinality=1.0, cost_per_record=0.0, time_per_record=0.0, quality=1.0)
+         champion_record_set, champion_quality = None, None
+         for record_set, op, _ in record_set_tuples:
+             # skip failed operations
+             if len(record_set) == 0:
+                 continue
+
+             # get the estimated quality of this operator
+             est_quality = op.naive_cost_estimates(base_op_cost_est).quality if self._is_llm_op(op) else 1.0
+             if champion_quality is None or est_quality > champion_quality:
+                 champion_record_set, champion_quality = record_set, est_quality
+
+         return champion_record_set, champion_quality
+
+     def _flatten_record_sets(self, source_idx_to_record_sets: dict[int, list[DataRecordSet]]) -> tuple[list[DataRecord], list[RecordOpStats]]:
+         """
+         Flatten the list of record sets and record op stats for each source_idx.
+         """
+         all_records, all_record_op_stats = [], []
+         for _, record_sets in source_idx_to_record_sets.items():
+             for record_set in record_sets:
+                 all_records.extend(record_set.data_records)
+                 all_record_op_stats.extend(record_set.record_op_stats)
+
+         return all_records, all_record_op_stats
+
+     def _execute_op_set(self, op_input_pairs: list[tuple[PhysicalOperator, DataRecord | int]]) -> tuple[dict[int, list[tuple[DataRecordSet, PhysicalOperator, bool]]], dict[str, int]]:
+         def execute_op_wrapper(operator, input) -> tuple[DataRecordSet, PhysicalOperator, DataRecord | int]:
+             record_set = operator(input)
+             return record_set, operator, input
+
+         # TODO: modify unit tests to always have record_op_stats so we can use record_op_stats for source_idx
+         # for scan operators, `input` will be the source_idx
+         def get_source_idx(input):
+             return input.source_idx if isinstance(input, DataRecord) else input
+
+         def get_hash(operator, input):
+             logical_op_id = operator.get_logical_op_id()
+             phys_op_id = operator.get_op_id()
+             return hash(f"{logical_op_id}{phys_op_id}{hash(input)}")
+
+         # initialize mapping from source indices to output record sets
+         source_idx_to_record_sets_and_ops = {get_source_idx(input): [] for _, input in op_input_pairs}
+
+         # if any operations were previously executed, read the results from the cache
+         final_op_input_pairs = []
+         for operator, input in op_input_pairs:
+             # compute hash
+             op_input_hash = get_hash(operator, input)
+
+             # get result from cache
+             if op_input_hash in self.cache:
+                 source_idx = get_source_idx(input)
+                 record_set, operator = self.cache[op_input_hash]
+                 source_idx_to_record_sets_and_ops[source_idx].append((record_set, operator, False))
+
+             # otherwise, add to final_op_input_pairs
+             else:
+                 final_op_input_pairs.append((operator, input))
+
+         # keep track of the number of llm operations
+         num_llm_ops = 0
+
+         # create thread pool w/max workers and run futures over worker pool
+         with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
+             # create futures
+             futures = [
+                 executor.submit(execute_op_wrapper, operator, input)
+                 for operator, input in final_op_input_pairs
+             ]
+             output_record_sets = []
+             while len(futures) > 0:
+                 done_futures, not_done_futures = wait(futures, timeout=PARALLEL_EXECUTION_SLEEP_INTERVAL_SECS)
+                 for future in done_futures:
+                     # update output record sets
+                     record_set, operator, input = future.result()
+                     output_record_sets.append((record_set, operator, input))
+
+                     # update cache
+                     op_input_hash = get_hash(operator, input)
+                     self.cache[op_input_hash] = (record_set, operator)
+
+                     # update progress manager
+                     if self._is_llm_op(operator):
+                         num_llm_ops += 1
+                         self.progress_manager.incr(operator.get_logical_op_id(), num_samples=1, total_cost=record_set.get_total_cost())
+
+                 # update futures
+                 futures = list(not_done_futures)
+
+         # update mapping from source_idx to record sets and operators
+         for record_set, operator, input in output_record_sets:
+             # get the source_idx associated with this input record;
+             source_idx = get_source_idx(input)
+
+             # add record_set to mapping from source_idx --> record_sets
+             source_idx_to_record_sets_and_ops[source_idx].append((record_set, operator, True))
+
+         return source_idx_to_record_sets_and_ops, num_llm_ops
+
+     def _is_llm_op(self, physical_op: PhysicalOperator) -> bool:
+         is_llm_convert = isinstance(physical_op, LLMConvert)
+         is_llm_filter = isinstance(physical_op, LLMFilter)
+         is_llm_retrieve = isinstance(physical_op, RetrieveOp) and isinstance(physical_op.index, Collection)
+         return is_llm_convert or is_llm_filter or is_llm_retrieve
+
+     @abstractmethod
+     def execute_sentinel_plan(self, sentinel_plan: SentinelPlan, expected_outputs: dict[str, dict]):
+         """Execute a SentinelPlan according to strategy"""
+         pass
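
The convert-operator scoring in _compute_quality above is easiest to see on a small input. The following standalone sketch (not code from this diff) mirrors the greedy per-field matching using plain dicts in place of DataRecordSet/RecordOpStats; the helper name greedy_convert_quality and the dict-based records are illustrative only.

# Illustrative sketch only: mimics the greedy matching in
# SentinelExecutionStrategy._compute_quality for convert operators,
# with exact-match scoring and plain dicts standing in for record sets.

def greedy_convert_quality(computed: list[dict], targets: list[dict], fields: list[str]) -> list[float]:
    """Assign each computed record a quality in [0, 1]; unmatched extras get 0.0."""
    qualities = [None] * len(computed)
    for target in targets:
        best_quality, best_idx = 0.0, None
        for idx, record in enumerate(computed):
            if qualities[idx] is not None:  # already matched to an earlier target
                continue
            # fraction of generated fields that exactly match this target record
            quality = sum(record.get(f) == target.get(f) for f in fields) / len(fields)
            if quality > best_quality:
                best_quality, best_idx = quality, idx
        if best_idx is not None:
            qualities[best_idx] = best_quality
    return [q if q is not None else 0.0 for q in qualities]


computed = [
    {"image": "file1.png", "breed": "Golden Retriever"},
    {"image": "file1.png", "breed": "Poodle"},
]
targets = [
    {"image": "file1.png", "breed": "Golden Retriever"},
    {"image": "file1.png", "breed": "Bernedoodle"},
]
print(greedy_convert_quality(computed, targets, ["image", "breed"]))  # [1.0, 0.5]
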
palimpzest/query/execution/execution_strategy_type.py (new file)
@@ -0,0 +1,20 @@
+ from enum import Enum
+
+ from palimpzest.query.execution.mab_execution_strategy import MABExecutionStrategy
+ from palimpzest.query.execution.parallel_execution_strategy import ParallelExecutionStrategy
+ from palimpzest.query.execution.random_sampling_execution_strategy import RandomSamplingExecutionStrategy
+ from palimpzest.query.execution.single_threaded_execution_strategy import (
+     PipelinedSingleThreadExecutionStrategy,
+     SequentialSingleThreadExecutionStrategy,
+ )
+
+
+ class ExecutionStrategyType(Enum):
+     """Available execution strategy types"""
+     SEQUENTIAL = SequentialSingleThreadExecutionStrategy
+     PIPELINED = PipelinedSingleThreadExecutionStrategy
+     PARALLEL = ParallelExecutionStrategy
+
+ class SentinelExecutionStrategyType(Enum):
+     MAB = MABExecutionStrategy
+     RANDOM = RandomSamplingExecutionStrategy
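
These class-valued enums replace the string-valued ExecutionStrategyType removed from execution_strategy.py above, so the enum itself maps a configuration name to a concrete strategy class. A minimal usage sketch (not code from this diff) follows; the strategy_name variable is illustrative, and it assumes the concrete strategies accept the BaseExecutionStrategy keyword arguments shown earlier without additional required parameters.

# Illustrative sketch only: resolve a strategy name to its class via the enum,
# then instantiate it. Keyword arguments mirror BaseExecutionStrategy.__init__;
# whether concrete strategies require extra arguments is not shown in this hunk.
from palimpzest.query.execution.execution_strategy_type import ExecutionStrategyType

strategy_name = "parallel"                                          # e.g. taken from user/processor config
strategy_cls = ExecutionStrategyType[strategy_name.upper()].value   # -> ParallelExecutionStrategy
strategy = strategy_cls(max_workers=8, num_samples=None, verbose=True, progress=True)
# records, plan_stats = strategy.execute_plan(plan)                 # `plan`: a PhysicalPlan built elsewhere
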