palimpzest 0.5.4__py3-none-any.whl → 0.6.1__py3-none-any.whl

This diff compares the contents of two publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their public registry.
Files changed (71)
  1. palimpzest/__init__.py +7 -9
  2. palimpzest/constants.py +47 -7
  3. palimpzest/core/__init__.py +20 -26
  4. palimpzest/core/data/dataclasses.py +9 -2
  5. palimpzest/core/data/datareaders.py +497 -0
  6. palimpzest/core/elements/records.py +29 -37
  7. palimpzest/core/lib/fields.py +14 -12
  8. palimpzest/core/lib/schemas.py +80 -94
  9. palimpzest/policy.py +58 -0
  10. palimpzest/prompts/__init__.py +22 -0
  11. palimpzest/prompts/code_synthesis_prompts.py +28 -0
  12. palimpzest/prompts/convert_prompts.py +87 -0
  13. palimpzest/prompts/critique_and_refine_convert_prompts.py +216 -0
  14. palimpzest/prompts/filter_prompts.py +69 -0
  15. palimpzest/prompts/moa_aggregator_convert_prompts.py +57 -0
  16. palimpzest/prompts/moa_proposer_convert_prompts.py +79 -0
  17. palimpzest/prompts/prompt_factory.py +732 -0
  18. palimpzest/prompts/util_phrases.py +14 -0
  19. palimpzest/query/execution/execution_strategy.py +0 -3
  20. palimpzest/query/execution/parallel_execution_strategy.py +12 -25
  21. palimpzest/query/execution/single_threaded_execution_strategy.py +31 -45
  22. palimpzest/query/generators/generators.py +71 -347
  23. palimpzest/query/operators/__init__.py +5 -5
  24. palimpzest/query/operators/aggregate.py +10 -5
  25. palimpzest/query/operators/code_synthesis_convert.py +4 -48
  26. palimpzest/query/operators/convert.py +5 -2
  27. palimpzest/query/operators/critique_and_refine_convert.py +112 -0
  28. palimpzest/query/operators/filter.py +1 -1
  29. palimpzest/query/operators/limit.py +1 -1
  30. palimpzest/query/operators/logical.py +28 -27
  31. palimpzest/query/operators/mixture_of_agents_convert.py +4 -1
  32. palimpzest/query/operators/physical.py +32 -20
  33. palimpzest/query/operators/project.py +1 -1
  34. palimpzest/query/operators/rag_convert.py +6 -3
  35. palimpzest/query/operators/retrieve.py +13 -31
  36. palimpzest/query/operators/scan.py +150 -0
  37. palimpzest/query/optimizer/__init__.py +5 -1
  38. palimpzest/query/optimizer/cost_model.py +18 -34
  39. palimpzest/query/optimizer/optimizer.py +40 -25
  40. palimpzest/query/optimizer/optimizer_strategy.py +26 -0
  41. palimpzest/query/optimizer/plan.py +2 -2
  42. palimpzest/query/optimizer/rules.py +118 -27
  43. palimpzest/query/processor/config.py +12 -1
  44. palimpzest/query/processor/mab_sentinel_processor.py +125 -112
  45. palimpzest/query/processor/nosentinel_processor.py +46 -62
  46. palimpzest/query/processor/query_processor.py +10 -20
  47. palimpzest/query/processor/query_processor_factory.py +12 -5
  48. palimpzest/query/processor/random_sampling_sentinel_processor.py +112 -91
  49. palimpzest/query/processor/streaming_processor.py +11 -17
  50. palimpzest/sets.py +170 -94
  51. palimpzest/tools/pdfparser.py +5 -64
  52. palimpzest/utils/datareader_helpers.py +61 -0
  53. palimpzest/utils/field_helpers.py +69 -0
  54. palimpzest/utils/hash_helpers.py +3 -2
  55. palimpzest/utils/udfs.py +0 -28
  56. {palimpzest-0.5.4.dist-info → palimpzest-0.6.1.dist-info}/METADATA +49 -49
  57. palimpzest-0.6.1.dist-info/RECORD +87 -0
  58. {palimpzest-0.5.4.dist-info → palimpzest-0.6.1.dist-info}/top_level.txt +0 -1
  59. cli/README.md +0 -156
  60. cli/__init__.py +0 -0
  61. cli/cli_main.py +0 -390
  62. palimpzest/config.py +0 -89
  63. palimpzest/core/data/datasources.py +0 -369
  64. palimpzest/datamanager/__init__.py +0 -0
  65. palimpzest/datamanager/datamanager.py +0 -300
  66. palimpzest/prompts.py +0 -397
  67. palimpzest/query/operators/datasource.py +0 -202
  68. palimpzest-0.5.4.dist-info/RECORD +0 -83
  69. palimpzest-0.5.4.dist-info/entry_points.txt +0 -2
  70. {palimpzest-0.5.4.dist-info → palimpzest-0.6.1.dist-info}/LICENSE +0 -0
  71. {palimpzest-0.5.4.dist-info → palimpzest-0.6.1.dist-info}/WHEEL +0 -0
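Reading the processor diffs below: the central API shift in 0.6.1 replaces the old ValidationDataSource / source_id machinery with an index-based validation DataReader (self.val_datasource), whose entries are plain dicts carrying "labels" and per-field "score_fn" entries instead of DataRecordSet objects. The sketch below is an editor's illustration of that dict shape as inferred from the new compute_quality code; the field names (title, year) and the similarity UDF are invented for the example and are not part of the package.

    def title_score(computed: str, expected: str) -> float:
        # hypothetical per-field scoring UDF; any callable(computed, expected) -> float works
        return float(computed.strip().lower() == expected.strip().lower())

    # one validation example, shaped like what val_datasource[source_idx] returns in the new code path
    expected_output = {
        # "labels" may be a single dict or a list of dicts mapping field -> expected value
        "labels": [
            {"title": "Palimpzest", "year": 2024},
            {"title": "PZ", "year": 2024},
        ],
        # "score_fn" maps field -> "exact" or a callable; fields not listed default to "exact"
        "score_fn": {"title": title_score},
    }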
palimpzest/query/processor/random_sampling_sentinel_processor.py
@@ -1,15 +1,18 @@
  import time
  from concurrent.futures import ThreadPoolExecutor, wait
- from functools import partial
- from typing import Callable
+ from copy import deepcopy

  import numpy as np

  from palimpzest.constants import PARALLEL_EXECUTION_SLEEP_INTERVAL_SECS
- from palimpzest.core.data.dataclasses import ExecutionStats, OperatorStats, PlanStats, RecordOpStats
- from palimpzest.core.data.datasources import ValidationDataSource
- from palimpzest.core.elements.records import DataRecord, DataRecordCollection, DataRecordSet
- from palimpzest.core.lib.schemas import SourceRecord
+ from palimpzest.core.data.dataclasses import (
+ ExecutionStats,
+ OperatorCostEstimates,
+ OperatorStats,
+ PlanStats,
+ RecordOpStats,
+ )
+ from palimpzest.core.elements.records import DataRecordCollection, DataRecordSet
  from palimpzest.policy import Policy
  from palimpzest.query.execution.parallel_execution_strategy import PipelinedParallelExecutionStrategy
  from palimpzest.query.execution.single_threaded_execution_strategy import (
@@ -17,10 +20,10 @@ from palimpzest.query.execution.single_threaded_execution_strategy import (
  SequentialSingleThreadExecutionStrategy,
  )
  from palimpzest.query.operators.convert import ConvertOp, LLMConvert
- from palimpzest.query.operators.datasource import CacheScanDataOp, MarshalAndScanDataOp
  from palimpzest.query.operators.filter import FilterOp, LLMFilter
  from palimpzest.query.operators.physical import PhysicalOperator
  from palimpzest.query.operators.retrieve import RetrieveOp
+ from palimpzest.query.operators.scan import CacheScanDataOp, MarshalAndScanDataOp
  from palimpzest.query.optimizer.cost_model import SampleBasedCostModel
  from palimpzest.query.optimizer.optimizer_strategy import OptimizationStrategyType
  from palimpzest.query.optimizer.plan import SentinelPlan
@@ -50,8 +53,6 @@ class RandomSamplingSentinelQueryProcessor(QueryProcessor):
  # self.max_workers = self.get_parallel_max_workers()
  # TODO: undo
  # self.max_workers = 1
- assert isinstance(self.datasource, ValidationDataSource), "DataSource must be ValidationDataSource for sentinel execution"
-
  self.k = k
  self.sample_budget = sample_budget
  self.j = int(sample_budget / k)
@@ -68,20 +69,19 @@ class RandomSamplingSentinelQueryProcessor(QueryProcessor):
  def compute_quality(
  self,
  record_set: DataRecordSet,
- expected_record_set: DataRecordSet | None = None,
+ expected_output: dict | None = None,
  champion_record_set: DataRecordSet | None = None,
  is_filter_op: bool = False,
  is_convert_op: bool = False,
- field_to_metric_fn: dict[str, str | Callable] | None = None,
  ) -> DataRecordSet:
  """
- Compute the quality for the given `record_set` by comparing it to the `expected_record_set`.
+ Compute the quality for the given `record_set` by comparing it to the `expected_output`.

  Update the record_set by assigning the quality to each entry in its record_op_stats and
  returning the updated record_set.
  """
  # compute whether we can only use the champion
- only_using_champion = expected_record_set is None
+ only_using_champion = expected_output is None

  # if this operation is a failed convert
  if is_convert_op and len(record_set) == 0:
@@ -103,16 +103,17 @@ class RandomSamplingSentinelQueryProcessor(QueryProcessor):
  champion_record = champion_record_set[0]
  record_op_stats.quality = int(record_op_stats.passed_operator == champion_record.passed_operator)

- # - if we are using validation data, we may have multiple expected records in the expected_record_set for this source_id,
+ # - if we are using validation data, we may have multiple expected records in the expected_output for this source_idx,
  # thus, if we can identify an exact match, we can use that to evaluate the filter's quality
  # - if we are using validation data but we *cannot* find an exact match, then we will once again use the champion record set
  else:
  # compute number of matches between this record's computed fields and this expected record's outputs
  found_match_in_output = False
- for expected_record in expected_record_set:
+ labels_dict_lst = expected_output["labels"] if isinstance(expected_output["labels"], list) else [expected_output["labels"]]
+ for labels_dict in labels_dict_lst:
  all_correct = True
  for field, value in record_op_stats.record_state.items():
- if value != getattr(expected_record, field):
+ if value != labels_dict[field]:
  all_correct = False
  break

@@ -121,7 +122,7 @@ class RandomSamplingSentinelQueryProcessor(QueryProcessor):
  break

  if found_match_in_output:
- record_op_stats.quality = int(record_op_stats.passed_operator == expected_record.passed_operator)
+ record_op_stats.quality = int(record_op_stats.passed_operator)
  else:
  champion_record = champion_record_set[0]
  record_op_stats.quality = int(record_op_stats.passed_operator == champion_record.passed_operator)
@@ -134,13 +135,23 @@ class RandomSamplingSentinelQueryProcessor(QueryProcessor):
  # validation dataset and use the champion model (as opposed to the validation
  # output) for scoring fields which have their values projected out

- # set the expected_record_set to be the champion_record_set if we do not have validation data
- expected_record_set = champion_record_set if only_using_champion else expected_record_set
+ # create list of dictionaries of labels for each expected / champion output
+ labels_dict_lst = []
+ if only_using_champion:
+ for champion_record in champion_record_set:
+ labels_dict_lst.append(champion_record.to_dict())
+ else:
+ labels_dict_lst = (
+ expected_output["labels"]
+ if isinstance(expected_output["labels"], list)
+ else [expected_output["labels"]]
+ )

  # GREEDY ALGORITHM
  # for each record in the expected output, we look for the computed record which maximizes the quality metric;
  # once we've identified that computed record we remove it from consideration for the next expected output
- for expected_record in expected_record_set:
+ field_to_score_fn = {} if only_using_champion else expected_output["score_fn"]
+ for labels_dict in labels_dict_lst:
  best_quality, best_record_op_stats = 0.0, None
  for record_op_stats in record_set.record_op_stats:
  # if we already assigned this record a quality, skip it
@@ -151,26 +162,22 @@ class RandomSamplingSentinelQueryProcessor(QueryProcessor):
  total_quality = 0
  for field in record_op_stats.generated_fields:
  computed_value = record_op_stats.record_state.get(field, None)
- expected_value = getattr(expected_record, field)
+ expected_value = labels_dict[field]

  # get the metric function for this field
- metric_fn = (
- field_to_metric_fn[field]
- if field_to_metric_fn is not None and field in field_to_metric_fn
- else "exact"
- )
+ score_fn = field_to_score_fn.get(field, "exact")

  # compute exact match
- if metric_fn == "exact":
+ if score_fn == "exact":
  total_quality += int(computed_value == expected_value)

  # compute UDF metric
- elif callable(metric_fn):
- total_quality += metric_fn(computed_value, expected_value)
+ elif callable(score_fn):
+ total_quality += score_fn(computed_value, expected_value)

  # otherwise, throw an exception
  else:
- raise Exception(f"Unrecognized metric_fn: {metric_fn}")
+ raise Exception(f"Unrecognized score_fn: {score_fn}")

  # compute recall and update best seen so far
  quality = total_quality / len(record_op_stats.generated_fields)
@@ -195,8 +202,7 @@ class RandomSamplingSentinelQueryProcessor(QueryProcessor):
  operator_sets: list[list[PhysicalOperator]],
  execution_data: dict[str, dict[str, list[DataRecordSet]]],
  champion_outputs: dict[str, dict[str, DataRecordSet]],
- expected_outputs: dict[str, DataRecordSet] | None = None,
- field_to_metric_fn: dict[str, str | Callable] | None = None,
+ expected_outputs: dict[str, dict],
  ) -> list[RecordOpStats]:
  """
  NOTE: This approach to cost modeling does not work directly for aggregation queries;
@@ -242,9 +248,9 @@ class RandomSamplingSentinelQueryProcessor(QueryProcessor):
  this_op_execution_data = execution_data[logical_op_id]

  # compute quality of each output computed by this operator
- for source_id, record_sets in this_op_execution_data.items():
+ for source_idx, record_sets in this_op_execution_data.items():
  # NOTE
- # source_id is a particular input, for which we may have computed multiple output record_sets;
+ # source_idx is a particular input, for which we may have computed multiple output record_sets;
  # each of these record_sets may contain more than one record (b/c one-to-many) and we have one
  # record_set per operator in the op_set

@@ -255,30 +261,45 @@ class RandomSamplingSentinelQueryProcessor(QueryProcessor):
  record_op_stats.quality = 1.0
  continue

- # get the expected output for this source_id if we have one
- expected_record_set = (
- expected_outputs[source_id]
- if expected_outputs is not None and source_id in expected_outputs
+ # get the expected output for this source_idx if we have one
+ expected_output = (
+ expected_outputs[source_idx]
+ if expected_outputs is not None and source_idx in expected_outputs
  else None
  )

  # extract champion output for this record set
- champion_record_set = champion_outputs[logical_op_id][source_id]
+ champion_record_set = champion_outputs[logical_op_id][source_idx]

  # for each record_set produced by an operation, compute its quality
  for record_set in record_sets:
- record_set = self.compute_quality(record_set, expected_record_set, champion_record_set, is_filter_op, is_convert_op, field_to_metric_fn)
+ record_set = self.compute_quality(record_set, expected_output, champion_record_set, is_filter_op, is_convert_op)

  # if this operator is a source op (i.e. has no input logical operator), return the execution data
  if is_source_op:
  return execution_data

  # recursively call the function on the next logical operator until you reach a scan
- execution_data = self.score_quality(operator_sets[:-1], execution_data, champion_outputs, expected_outputs, field_to_metric_fn)
+ execution_data = self.score_quality(operator_sets[:-1], execution_data, champion_outputs, expected_outputs)

  # return the quality annotated record op stats
  return execution_data

+ def pick_champion_output(self, op_set_record_sets: list[tuple[DataRecordSet, PhysicalOperator]]) -> DataRecordSet:
+ # if there's only one operator in the set, we return its record_set
+ if len(op_set_record_sets) == 1:
+ record_set, _ = op_set_record_sets[0]
+ return record_set
+
+ # find the operator with the highest average quality and return its record_set
+ base_op_cost_est = OperatorCostEstimates(cardinality=1.0, cost_per_record=0.0, time_per_record=0.0, quality=1.0)
+ champion_record_set, champion_quality = None, -1.0
+ for record_set, op in op_set_record_sets:
+ op_cost_estimates = op.naive_cost_estimates(base_op_cost_est)
+ if op_cost_estimates.quality > champion_quality:
+ champion_record_set, champion_quality = record_set, op_cost_estimates.quality
+
+ return champion_record_set

  def pick_ensemble_output(self, op_set_record_sets: list[tuple[DataRecordSet, PhysicalOperator]]) -> DataRecordSet:
  # if there's only one operator in the set, we return its record_set
@@ -341,7 +362,7 @@ class RandomSamplingSentinelQueryProcessor(QueryProcessor):
  # update list of futures
  futures = not_done_futures

- # compute mapping from source_id to record sets for all operators and for champion operator
+ # compute mapping from source_idx to record sets for all operators and for champion operator
  all_record_sets, champion_record_sets = {}, {}
  for candidate in candidates:
  candidate_output_record_sets = []
@@ -352,19 +373,19 @@ class RandomSamplingSentinelQueryProcessor(QueryProcessor):
  # select the champion (i.e. best) record_set from all the record sets computed for this operator
  champion_record_set = self.pick_output_fn(candidate_output_record_sets)

- # get the source_id associated with this input record
- source_id = candidate.source_id
+ # get the source_idx associated with this input record
+ source_idx = candidate.source_idx

- # add champion record_set to mapping from source_id --> champion record_set
- champion_record_sets[source_id] = champion_record_set
+ # add champion record_set to mapping from source_idx --> champion record_set
+ champion_record_sets[source_idx] = champion_record_set

- # add all record_sets computed for this source_id to mapping from source_id --> record_sets
- all_record_sets[source_id] = [tup[0] for tup in candidate_output_record_sets]
+ # add all record_sets computed for this source_idx to mapping from source_idx --> record_sets
+ all_record_sets[source_idx] = [tup[0] for tup in candidate_output_record_sets]

  return all_record_sets, champion_record_sets


- def execute_sentinel_plan(self, plan: SentinelPlan, expected_outputs: dict[str, DataRecordSet], policy: Policy):
+ def execute_sentinel_plan(self, plan: SentinelPlan, expected_outputs: dict[str, dict], policy: Policy):
  """
  """
  if self.verbose:
@@ -389,26 +410,23 @@ class RandomSamplingSentinelQueryProcessor(QueryProcessor):
  )

  # sample validation records
- total_num_samples = self.datasource.get_val_length()
- sample_indices = np.arange(total_num_samples)
+ total_num_samples = len(self.val_datasource)
+ source_indices = np.arange(total_num_samples)
  if self.sample_start_idx is not None:
  assert self.sample_end_idx is not None
- sample_indices = sample_indices[self.sample_start_idx:self.sample_end_idx]
+ source_indices = source_indices[self.sample_start_idx:self.sample_end_idx]
  elif not self.sample_all_records:
- self.rng.shuffle(sample_indices)
- j = min(self.j, len(sample_indices))
- sample_indices = sample_indices[:j]
+ self.rng.shuffle(source_indices)
+ j = min(self.j, len(source_indices))
+ source_indices = source_indices[:j]

  # initialize output variables
  all_outputs, champion_outputs = {}, {}

  # create initial set of candidates for source scan operator
  candidates = []
- for sample_idx in sample_indices:
- candidate = DataRecord(schema=SourceRecord, source_id=sample_idx)
- candidate.idx = sample_idx
- candidate.get_item_fn = partial(self.datasource.get_item, val=True)
- candidates.append(candidate)
+ for source_idx in source_indices:
+ candidates.append(source_idx)

  # NOTE: because we need to dynamically create sample matrices for each operator,
  # sentinel execution must be executed one operator at a time (i.e. sequentially)
@@ -422,24 +440,24 @@ class RandomSamplingSentinelQueryProcessor(QueryProcessor):
  sampled_ops = self.rng.choice(op_set, size=k, replace=False)

  # run sampled operators on sampled candidates
- source_id_to_record_sets, source_id_to_champion_record_set = self.execute_op_set(candidates, sampled_ops)
+ source_idx_to_record_sets, source_idx_to_champion_record_set = self.execute_op_set(candidates, sampled_ops)

  # update all_outputs and champion_outputs dictionary
  if logical_op_id not in all_outputs:
- all_outputs[logical_op_id] = source_id_to_record_sets
- champion_outputs[logical_op_id] = source_id_to_champion_record_set
+ all_outputs[logical_op_id] = source_idx_to_record_sets
+ champion_outputs[logical_op_id] = source_idx_to_champion_record_set
  else:
- for source_id, record_sets in source_id_to_record_sets.items():
- if source_id not in all_outputs[logical_op_id]:
- all_outputs[logical_op_id][source_id] = record_sets
- champion_outputs[logical_op_id][source_id] = source_id_to_champion_record_set[source_id]
+ for source_idx, record_sets in source_idx_to_record_sets.items():
+ if source_idx not in all_outputs[logical_op_id]:
+ all_outputs[logical_op_id][source_idx] = record_sets
+ champion_outputs[logical_op_id][source_idx] = source_idx_to_champion_record_set[source_idx]
  else:
- all_outputs[logical_op_id][source_id].extend(record_sets)
- champion_outputs[logical_op_id][source_id].extend(source_id_to_champion_record_set[source_id])
+ all_outputs[logical_op_id][source_idx].extend(record_sets)
+ champion_outputs[logical_op_id][source_idx].extend(source_idx_to_champion_record_set[source_idx])

  # flatten lists of records and record_op_stats
  all_records, all_record_op_stats = [], []
- for _, record_sets in source_id_to_record_sets.items():
+ for _, record_sets in source_idx_to_record_sets.items():
  for record_set in record_sets:
  all_records.extend(record_set.data_records)
  all_record_op_stats.extend(record_set.record_op_stats)
@@ -455,12 +473,13 @@ class RandomSamplingSentinelQueryProcessor(QueryProcessor):
  if not self.nocache:
  for record in all_records:
  if getattr(record, "passed_operator", True):
- self.datadir.append_cache(logical_op_id, record)
+ # self.datadir.append_cache(logical_op_id, record)
+ pass

  # update candidates for next operator; we use champion outputs as input
  candidates = []
  if next_logical_op_id is not None:
- for _, record_set in source_id_to_champion_record_set.items():
+ for _, record_set in source_idx_to_champion_record_set.items():
  for record in record_set:
  if isinstance(op_set[0], FilterOp) and not record.passed_operator:
  continue
@@ -471,13 +490,13 @@ class RandomSamplingSentinelQueryProcessor(QueryProcessor):
  break

  # compute quality for each operator
- field_to_metric_fn = self.datasource.get_field_to_metric_fn()
- all_outputs = self.score_quality(plan.operator_sets, all_outputs, champion_outputs, expected_outputs, field_to_metric_fn)
+ all_outputs = self.score_quality(plan.operator_sets, all_outputs, champion_outputs, expected_outputs)

  # if caching was allowed, close the cache
  if not self.nocache:
- for logical_op_id, _, _ in plan:
- self.datadir.close_cache(logical_op_id)
+ for _, _, _ in plan:
+ # self.datadir.close_cache(logical_op_id)
+ pass

  # finalize plan stats
  total_plan_time = time.time() - plan_start_time
@@ -497,28 +516,32 @@ class RandomSamplingSentinelQueryProcessor(QueryProcessor):
  """
  # if we're using validation data, get the set of expected output records
  expected_outputs = {}
- for idx in range(self.datasource.get_val_length()):
- data_records = self.datasource.get_item(idx, val=True, include_label=True)
- if not isinstance(data_records, list):
- data_records = [data_records]
- record_set = DataRecordSet(data_records, None)
- expected_outputs[record_set.source_id] = record_set
+ for source_idx in range(len(self.val_datasource)):
+ # TODO: make sure execute_op_set uses self.val_datasource
+ expected_output = self.val_datasource[source_idx]
+ expected_outputs[source_idx] = expected_output

  # run sentinel plan
  execution_data, plan_stats = self.execute_sentinel_plan(sentinel_plan, expected_outputs, policy)

  return execution_data, plan_stats

-
+
  def create_sentinel_plan(self, dataset: Set, policy: Policy) -> SentinelPlan:
  """
  Generates and returns a SentinelPlan for the given dataset.
  """
  # TODO: explicitly pull up filters; for SIGMOD we can explicitly write plans w/filters pulled up
- # initialize the optimizer
- # TODO: Do we need to re-initialize the optimizer here?
+
+ # create a new optimizer and update its strategy to SENTINEL
  optimizer = self.optimizer.deepcopy_clean()
  optimizer.update_strategy(OptimizationStrategyType.SENTINEL)
+
+ # create copy of dataset, but change its data source to the validation data source
+ dataset = deepcopy(dataset)
+ dataset._set_data_source(self.val_datasource)
+
+ # get the sentinel plan for the given dataset
  sentinel_plans = optimizer.optimize(dataset, policy)
  sentinel_plan = sentinel_plans[0]

@@ -529,12 +552,13 @@ class RandomSamplingSentinelQueryProcessor(QueryProcessor):
  execution_start_time = time.time()

  # for now, enforce that we are using validation data; we can relax this after paper submission
- if not self.using_validation_data:
- raise Exception("Make sure you are using ValidationDataSource with MABSentinelExecutionEngine")
+ if self.val_datasource is None:
+ raise Exception("Make sure you are using validation data with MABSentinelExecutionEngine")

  # if nocache is True, make sure we do not re-use codegen examples
  if self.nocache:
- self.clear_cached_examples()
+ # self.clear_cached_examples()
+ pass

  # create sentinel plan
  sentinel_plan = self.create_sentinel_plan(self.dataset, self.policy)
@@ -582,7 +606,6 @@ class RandomSamplingSentinelSequentialSingleThreadProcessor(RandomSamplingSentin
  SequentialSingleThreadExecutionStrategy.__init__(
  self,
  scan_start_idx=self.scan_start_idx,
- datadir=self.datadir,
  max_workers=self.max_workers,
  verbose=self.verbose
  )
@@ -597,7 +620,6 @@ class RandomSamplingSentinelPipelinedParallelProcessor(RandomSamplingSentinelQue
  PipelinedParallelExecutionStrategy.__init__(
  self,
  scan_start_idx=self.scan_start_idx,
- datadir=self.datadir,
  max_workers=self.max_workers,
  verbose=self.verbose
  )
@@ -612,7 +634,6 @@ class RandomSamplingSentinelPipelinedSingleThreadProcessor(RandomSamplingSentine
  PipelinedSingleThreadExecutionStrategy.__init__(
  self,
  scan_start_idx=self.scan_start_idx,
- datadir=self.datadir,
  max_workers=self.max_workers,
  verbose=self.verbose
  )
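For readers following the compute_quality changes above: a record's quality against one labels dict is simply the average of its per-field scores, where each field uses either exact matching or the field's score_fn. The following is a simplified, standalone restatement of that scoring loop, not the package's code; the names mirror the diff.

    def score_record(record_state: dict, generated_fields: list[str],
                     labels_dict: dict, field_to_score_fn: dict) -> float:
        # average the per-field scores, as in the greedy loop inside compute_quality
        total_quality = 0.0
        for field in generated_fields:
            computed_value = record_state.get(field, None)
            expected_value = labels_dict[field]
            score_fn = field_to_score_fn.get(field, "exact")
            if score_fn == "exact":
                total_quality += int(computed_value == expected_value)
            elif callable(score_fn):
                total_quality += score_fn(computed_value, expected_value)
            else:
                raise Exception(f"Unrecognized score_fn: {score_fn}")
        return total_quality / len(generated_fields)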
palimpzest/query/processor/streaming_processor.py
@@ -1,13 +1,12 @@
  import time

  from palimpzest.core.data.dataclasses import OperatorStats, PlanStats
- from palimpzest.core.elements.records import DataRecord, DataRecordCollection
- from palimpzest.core.lib.schemas import SourceRecord
+ from palimpzest.core.elements.records import DataRecordCollection
  from palimpzest.policy import Policy
  from palimpzest.query.operators.aggregate import AggregateOp
- from palimpzest.query.operators.datasource import DataSourcePhysicalOp
  from palimpzest.query.operators.filter import FilterOp
  from palimpzest.query.operators.limit import LimitScanOp
+ from palimpzest.query.operators.scan import ScanPhysicalOp
  from palimpzest.query.optimizer.plan import PhysicalPlan
  from palimpzest.query.processor.query_processor import QueryProcessor
  from palimpzest.sets import Dataset
@@ -47,7 +46,7 @@ class StreamingQueryProcessor(QueryProcessor):
  self._plan_stats = plan_stats

  def generate_plan(self, dataset: Dataset, policy: Policy):
- self.clear_cached_examples()
+ # self.clear_cached_examples()
  start_time = time.time()

  # TODO: Do we need to re-initialize the optimizer here?
@@ -90,21 +89,16 @@ class StreamingQueryProcessor(QueryProcessor):

  def get_input_records(self):
  scan_operator = self.plan.operators[0]
- assert isinstance(scan_operator, DataSourcePhysicalOp), "First operator in physical plan must be a DataSourcePhysicalOp"
- datasource = scan_operator.get_datasource()
- if not datasource:
- raise Exception("Data source not found")
- datasource_len = len(datasource)
+ assert isinstance(scan_operator, ScanPhysicalOp), "First operator in physical plan must be a ScanPhysicalOp"
+ datareader = scan_operator.datareader
+ if not datareader:
+ raise Exception("DataReader not found")
+ datareader_len = len(datareader)

  input_records = []
  record_op_stats = []
- for idx in range(datasource_len):
- # NOTE: this DataRecord will be discarded and replaced by the scan_operator;
- # it is simply a vessel to inform the scan_operator which record to fetch
- candidate = DataRecord(schema=SourceRecord, source_id=idx)
- candidate.idx = idx
- candidate.get_item_fn = datasource.get_item
- record_set = scan_operator(candidate)
+ for source_idx in range(datareader_len):
+ record_set = scan_operator(source_idx)
  input_records += record_set.data_records
  record_op_stats += record_set.record_op_stats

@@ -130,7 +124,7 @@ class StreamingQueryProcessor(QueryProcessor):
  op_id = operator.get_op_id()
  prev_op_id = plan.operators[op_idx - 1].get_op_id() if op_idx > 1 else None

- if isinstance(operator, DataSourcePhysicalOp):
+ if isinstance(operator, ScanPhysicalOp):
  continue
  # only invoke aggregate operator(s) once there are no more source records and all
  # upstream operators' processing queues are empty
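The streaming_processor.py hunks above show the other half of the same refactor: scan operators are now invoked directly with an integer source_idx and pull items from a DataReader that supports len() and indexing, rather than from a placeholder DataRecord built on SourceRecord. Below is a minimal illustration of that reader protocol only; the class, data, and loop are hypothetical and not taken from the package.

    class ListReader:
        # hypothetical reader used only to illustrate the len()/indexing protocol
        # that the new scan-operator code path relies on
        def __init__(self, items: list[dict]):
            self.items = items

        def __len__(self) -> int:
            return len(self.items)

        def __getitem__(self, source_idx: int) -> dict:
            return self.items[source_idx]

    reader = ListReader([{"contents": "first doc"}, {"contents": "second doc"}])
    for source_idx in range(len(reader)):
        item = reader[source_idx]  # what a scan operator fetches for this index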