palimpzest 0.5.4__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71)
  1. palimpzest/__init__.py +7 -9
  2. palimpzest/constants.py +47 -7
  3. palimpzest/core/__init__.py +20 -26
  4. palimpzest/core/data/dataclasses.py +9 -2
  5. palimpzest/core/data/datareaders.py +497 -0
  6. palimpzest/core/elements/records.py +29 -37
  7. palimpzest/core/lib/fields.py +14 -12
  8. palimpzest/core/lib/schemas.py +80 -94
  9. palimpzest/policy.py +58 -0
  10. palimpzest/prompts/__init__.py +22 -0
  11. palimpzest/prompts/code_synthesis_prompts.py +28 -0
  12. palimpzest/prompts/convert_prompts.py +87 -0
  13. palimpzest/prompts/critique_and_refine_convert_prompts.py +216 -0
  14. palimpzest/prompts/filter_prompts.py +69 -0
  15. palimpzest/prompts/moa_aggregator_convert_prompts.py +57 -0
  16. palimpzest/prompts/moa_proposer_convert_prompts.py +79 -0
  17. palimpzest/prompts/prompt_factory.py +732 -0
  18. palimpzest/prompts/util_phrases.py +14 -0
  19. palimpzest/query/execution/execution_strategy.py +0 -3
  20. palimpzest/query/execution/parallel_execution_strategy.py +12 -25
  21. palimpzest/query/execution/single_threaded_execution_strategy.py +31 -45
  22. palimpzest/query/generators/generators.py +71 -347
  23. palimpzest/query/operators/__init__.py +5 -5
  24. palimpzest/query/operators/aggregate.py +10 -5
  25. palimpzest/query/operators/code_synthesis_convert.py +4 -48
  26. palimpzest/query/operators/convert.py +5 -2
  27. palimpzest/query/operators/critique_and_refine_convert.py +112 -0
  28. palimpzest/query/operators/filter.py +1 -1
  29. palimpzest/query/operators/limit.py +1 -1
  30. palimpzest/query/operators/logical.py +28 -27
  31. palimpzest/query/operators/mixture_of_agents_convert.py +4 -1
  32. palimpzest/query/operators/physical.py +32 -20
  33. palimpzest/query/operators/project.py +1 -1
  34. palimpzest/query/operators/rag_convert.py +6 -3
  35. palimpzest/query/operators/retrieve.py +13 -31
  36. palimpzest/query/operators/scan.py +150 -0
  37. palimpzest/query/optimizer/__init__.py +5 -1
  38. palimpzest/query/optimizer/cost_model.py +18 -34
  39. palimpzest/query/optimizer/optimizer.py +40 -25
  40. palimpzest/query/optimizer/optimizer_strategy.py +26 -0
  41. palimpzest/query/optimizer/plan.py +2 -2
  42. palimpzest/query/optimizer/rules.py +118 -27
  43. palimpzest/query/processor/config.py +12 -1
  44. palimpzest/query/processor/mab_sentinel_processor.py +125 -112
  45. palimpzest/query/processor/nosentinel_processor.py +46 -62
  46. palimpzest/query/processor/query_processor.py +10 -20
  47. palimpzest/query/processor/query_processor_factory.py +12 -5
  48. palimpzest/query/processor/random_sampling_sentinel_processor.py +112 -91
  49. palimpzest/query/processor/streaming_processor.py +11 -17
  50. palimpzest/sets.py +170 -94
  51. palimpzest/tools/pdfparser.py +5 -64
  52. palimpzest/utils/datareader_helpers.py +61 -0
  53. palimpzest/utils/field_helpers.py +69 -0
  54. palimpzest/utils/hash_helpers.py +3 -2
  55. palimpzest/utils/udfs.py +0 -28
  56. {palimpzest-0.5.4.dist-info → palimpzest-0.6.0.dist-info}/METADATA +49 -49
  57. palimpzest-0.6.0.dist-info/RECORD +87 -0
  58. {palimpzest-0.5.4.dist-info → palimpzest-0.6.0.dist-info}/top_level.txt +0 -1
  59. cli/README.md +0 -156
  60. cli/__init__.py +0 -0
  61. cli/cli_main.py +0 -390
  62. palimpzest/config.py +0 -89
  63. palimpzest/core/data/datasources.py +0 -369
  64. palimpzest/datamanager/__init__.py +0 -0
  65. palimpzest/datamanager/datamanager.py +0 -300
  66. palimpzest/prompts.py +0 -397
  67. palimpzest/query/operators/datasource.py +0 -202
  68. palimpzest-0.5.4.dist-info/RECORD +0 -83
  69. palimpzest-0.5.4.dist-info/entry_points.txt +0 -2
  70. {palimpzest-0.5.4.dist-info → palimpzest-0.6.0.dist-info}/LICENSE +0 -0
  71. {palimpzest-0.5.4.dist-info → palimpzest-0.6.0.dist-info}/WHEEL +0 -0
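The most visible breaking change for downstream code sits in entries 36, 63, and 67: `palimpzest/query/operators/datasource.py` and `palimpzest/core/data/datasources.py` are removed, with `scan.py` and `datareaders.py` taking their place. A minimal sketch of the import migration (only the scan-operator names are confirmed by the hunks below; any other renames are assumptions to verify against the 0.6.0 RECORD):

```python
# palimpzest 0.5.4 (module removed in 0.6.0):
# from palimpzest.query.operators.datasource import CacheScanDataOp, MarshalAndScanDataOp

# palimpzest 0.6.0:
from palimpzest.query.operators.scan import CacheScanDataOp, MarshalAndScanDataOp
```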
palimpzest/query/processor/mab_sentinel_processor.py
@@ -1,22 +1,26 @@
 import time
 from concurrent.futures import ThreadPoolExecutor, wait
-from functools import partial
-from typing import Callable
+from copy import deepcopy

 import numpy as np

 from palimpzest.constants import PARALLEL_EXECUTION_SLEEP_INTERVAL_SECS
-from palimpzest.core.data.dataclasses import ExecutionStats, OperatorStats, PlanStats, RecordOpStats
-from palimpzest.core.elements.records import DataRecord, DataRecordCollection, DataRecordSet
-from palimpzest.core.lib.schemas import SourceRecord
+from palimpzest.core.data.dataclasses import (
+    ExecutionStats,
+    OperatorCostEstimates,
+    OperatorStats,
+    PlanStats,
+    RecordOpStats,
+)
+from palimpzest.core.elements.records import DataRecordCollection, DataRecordSet
 from palimpzest.policy import Policy
 from palimpzest.query.execution.parallel_execution_strategy import PipelinedParallelExecutionStrategy
 from palimpzest.query.execution.single_threaded_execution_strategy import SequentialSingleThreadExecutionStrategy
 from palimpzest.query.operators.convert import ConvertOp, LLMConvert
-from palimpzest.query.operators.datasource import CacheScanDataOp, MarshalAndScanDataOp
 from palimpzest.query.operators.filter import FilterOp, LLMFilter
 from palimpzest.query.operators.physical import PhysicalOperator
 from palimpzest.query.operators.retrieve import RetrieveOp
+from palimpzest.query.operators.scan import CacheScanDataOp, MarshalAndScanDataOp
 from palimpzest.query.optimizer.cost_model import SampleBasedCostModel
 from palimpzest.query.optimizer.optimizer_strategy import OptimizationStrategyType
 from palimpzest.query.optimizer.plan import SentinelPlan
@@ -49,7 +53,7 @@ class MABSentinelQueryProcessor(QueryProcessor):
         self.sample_budget = sample_budget
         self.early_stop_iters = early_stop_iters
         self.use_final_op_quality = use_final_op_quality
-        self.pick_output_fn = self.pick_ensemble_output
+        self.pick_output_fn = self.pick_champion_output
         self.rng = np.random.default_rng(seed=seed)


@@ -72,10 +76,10 @@ class MABSentinelQueryProcessor(QueryProcessor):
         """
         # compute metrics for each operator in all_outputs
         logical_op_id_to_op_metrics = {}
-        for logical_op_id, source_id_to_record_sets in all_outputs.items():
+        for logical_op_id, source_idx_to_record_sets in all_outputs.items():
             # compute selectivity for each physical operator
             phys_op_to_num_inputs, phys_op_to_num_outputs = {}, {}
-            for _, record_sets in source_id_to_record_sets.items():
+            for _, record_sets in source_idx_to_record_sets.items():
                 for record_set in record_sets:
                     op_id = record_set.record_op_stats[0].op_id
                     num_outputs = sum([record_op_stats.passed_operator for record_op_stats in record_set.record_op_stats])
@@ -93,7 +97,7 @@ class MABSentinelQueryProcessor(QueryProcessor):

             # compute average cost, time, and quality
             phys_op_to_costs, phys_op_to_times, phys_op_to_qualities = {}, {}, {}
-            for _, record_sets in source_id_to_record_sets.items():
+            for _, record_sets in source_idx_to_record_sets.items():
                 for record_set in record_sets:
                     for record_op_stats in record_set.record_op_stats:
                         op_id = record_op_stats.op_id
@@ -228,20 +232,19 @@ class MABSentinelQueryProcessor(QueryProcessor):
     def compute_quality(
         self,
         record_set: DataRecordSet,
-        expected_record_set: DataRecordSet | None = None,
+        expected_output: dict | None = None,
         champion_record_set: DataRecordSet | None = None,
         is_filter_op: bool = False,
         is_convert_op: bool = False,
-        field_to_metric_fn: dict[str, str | Callable] | None = None,
     ) -> DataRecordSet:
         """
-        Compute the quality for the given `record_set` by comparing it to the `expected_record_set`.
+        Compute the quality for the given `record_set` by comparing it to the `expected_output`.

         Update the record_set by assigning the quality to each entry in its record_op_stats and
         returning the updated record_set.
         """
         # compute whether we can only use the champion
-        only_using_champion = expected_record_set is None
+        only_using_champion = expected_output is None

         # if this operation is a failed convert
         if is_convert_op and len(record_set) == 0:
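As the hunks below show, `expected_output` replaces the old `expected_record_set`: each validation entry is now a plain dict with a `"labels"` key (one dict of expected field values, or a list of dicts for one-to-many operators) and a `"score_fn"` key mapping field names to a scoring callable, with `"exact"` as the fallback. A hypothetical entry, with invented field names, might look like:

```python
# Hypothetical validation entry; "title" and "year" are invented field names.
expected_output = {
    # one dict of expected values, or a list of dicts for one-to-many outputs
    "labels": [{"title": "On Sentinel Plans", "year": 2024}],
    "score_fn": {
        # per-field scorer: (computed_value, expected_value) -> score in [0, 1]
        "title": lambda computed, expected: float(str(computed).strip() == expected),
        # "year" is omitted, so compute_quality falls back to "exact" matching
    },
}
```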
@@ -263,16 +266,17 @@ class MABSentinelQueryProcessor(QueryProcessor):
                    champion_record = champion_record_set[0]
                    record_op_stats.quality = int(record_op_stats.passed_operator == champion_record.passed_operator)

-                # - if we are using validation data, we may have multiple expected records in the expected_record_set for this source_id,
+                # - if we are using validation data, we may have multiple expected records in the expected_output for this source_idx,
                #   thus, if we can identify an exact match, we can use that to evaluate the filter's quality
                # - if we are using validation data but we *cannot* find an exact match, then we will once again use the champion record set
                else:
                    # compute number of matches between this record's computed fields and this expected record's outputs
                    found_match_in_output = False
-                    for expected_record in expected_record_set:
+                    labels_dict_lst = expected_output["labels"] if isinstance(expected_output["labels"], list) else [expected_output["labels"]]
+                    for labels_dict in labels_dict_lst:
                        all_correct = True
                        for field, value in record_op_stats.record_state.items():
-                            if value != getattr(expected_record, field):
+                            if value != labels_dict[field]:
                                all_correct = False
                                break

@@ -281,7 +285,7 @@ class MABSentinelQueryProcessor(QueryProcessor):
                            break

                    if found_match_in_output:
-                        record_op_stats.quality = int(record_op_stats.passed_operator == expected_record.passed_operator)
+                        record_op_stats.quality = int(record_op_stats.passed_operator)
                    else:
                        champion_record = champion_record_set[0]
                        record_op_stats.quality = int(record_op_stats.passed_operator == champion_record.passed_operator)
@@ -294,13 +298,23 @@ class MABSentinelQueryProcessor(QueryProcessor):
        # validation dataset and use the champion model (as opposed to the validation
        # output) for scoring fields which have their values projected out

-        # set the expected_record_set to be the champion_record_set if we do not have validation data
-        expected_record_set = champion_record_set if only_using_champion else expected_record_set
+        # create list of dictionaries of labels for each expected / champion output
+        labels_dict_lst = []
+        if only_using_champion:
+            for champion_record in champion_record_set:
+                labels_dict_lst.append(champion_record.to_dict())
+        else:
+            labels_dict_lst = (
+                expected_output["labels"]
+                if isinstance(expected_output["labels"], list)
+                else [expected_output["labels"]]
+            )

        # GREEDY ALGORITHM
        # for each record in the expected output, we look for the computed record which maximizes the quality metric;
        # once we've identified that computed record we remove it from consideration for the next expected output
-        for expected_record in expected_record_set:
+        field_to_score_fn = {} if only_using_champion else expected_output["score_fn"]
+        for labels_dict in labels_dict_lst:
            best_quality, best_record_op_stats = 0.0, None
            for record_op_stats in record_set.record_op_stats:
                # if we already assigned this record a quality, skip it
@@ -311,26 +325,22 @@ class MABSentinelQueryProcessor(QueryProcessor):
                total_quality = 0
                for field in record_op_stats.generated_fields:
                    computed_value = record_op_stats.record_state.get(field, None)
-                    expected_value = getattr(expected_record, field)
+                    expected_value = labels_dict[field]

                    # get the metric function for this field
-                    metric_fn = (
-                        field_to_metric_fn[field]
-                        if field_to_metric_fn is not None and field in field_to_metric_fn
-                        else "exact"
-                    )
+                    score_fn = field_to_score_fn.get(field, "exact")

                    # compute exact match
-                    if metric_fn == "exact":
+                    if score_fn == "exact":
                        total_quality += int(computed_value == expected_value)

                    # compute UDF metric
-                    elif callable(metric_fn):
-                        total_quality += metric_fn(computed_value, expected_value)
+                    elif callable(score_fn):
+                        total_quality += score_fn(computed_value, expected_value)

                    # otherwise, throw an exception
                    else:
-                        raise Exception(f"Unrecognized metric_fn: {metric_fn}")
+                        raise Exception(f"Unrecognized score_fn: {score_fn}")

                # compute recall and update best seen so far
                quality = total_quality / len(record_op_stats.generated_fields)
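Concretely, a record's quality is the mean per-field score across its generated fields: exact matches contribute 0 or 1, and callable score functions contribute whatever they return. A self-contained sketch of the loop above, with plain dicts standing in for `RecordOpStats`:

```python
def score_record(record_state: dict, labels_dict: dict, generated_fields: list, field_to_score_fn: dict) -> float:
    """Mean per-field score, mirroring the scoring loop in compute_quality."""
    total_quality = 0.0
    for field in generated_fields:
        computed_value = record_state.get(field, None)
        expected_value = labels_dict[field]
        score_fn = field_to_score_fn.get(field, "exact")
        if score_fn == "exact":
            total_quality += int(computed_value == expected_value)  # 0/1 exact match
        elif callable(score_fn):
            total_quality += score_fn(computed_value, expected_value)  # UDF score
        else:
            raise Exception(f"Unrecognized score_fn: {score_fn}")
    return total_quality / len(generated_fields)

# two of three fields match exactly -> quality = 2/3
print(score_record({"a": 1, "b": 2, "c": 3}, {"a": 1, "b": 2, "c": 9}, ["a", "b", "c"], {}))
```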
@@ -351,14 +361,13 @@ class MABSentinelQueryProcessor(QueryProcessor):


    def score_quality(
-            self,
-            op_set: list[PhysicalOperator],
-            logical_op_id: str,
-            execution_data: dict[str, dict[str, list[DataRecordSet]]],
-            champion_outputs: dict[str, dict[str, DataRecordSet]],
-            expected_outputs: dict[str, DataRecordSet] | None = None,
-            field_to_metric_fn: dict[str, str | Callable] | None = None,
-    ) -> list[RecordOpStats]:
+        self,
+        op_set: list[PhysicalOperator],
+        logical_op_id: str,
+        execution_data: dict[str, dict[str, list[DataRecordSet]]],
+        champion_outputs: dict[str, dict[str, DataRecordSet]],
+        expected_outputs: dict[str, dict],
+    ) -> list[RecordOpStats]:
        """
        NOTE: This approach to cost modeling does not work directly for aggregation queries;
        for these queries, we would ask the user to provide validation data for the step immediately
@@ -396,9 +405,9 @@ class MABSentinelQueryProcessor(QueryProcessor):
        this_op_execution_data = execution_data[logical_op_id]

        # compute quality of each output computed by this operator
-        for source_id, record_sets in this_op_execution_data.items():
+        for source_idx, record_sets in this_op_execution_data.items():
            # NOTE
-            # source_id is a particular input, for which we may have computed multiple output record_sets;
+            # source_idx is a particular input, for which we may have computed multiple output record_sets;
            # each of these record_sets may contain more than one record (b/c one-to-many) and we have one
            # record_set per operator in the op_set

@@ -409,23 +418,38 @@ class MABSentinelQueryProcessor(QueryProcessor):
                    record_op_stats.quality = 1.0
                continue

-            # get the expected output for this source_id if we have one
-            expected_record_set = (
-                expected_outputs[source_id]
-                if expected_outputs is not None and source_id in expected_outputs
+            # get the expected output for this source_idx if we have one
+            expected_output = (
+                expected_outputs[source_idx]
+                if expected_outputs is not None and source_idx in expected_outputs
                else None
            )

            # extract champion output for this record set
-            champion_record_set = champion_outputs[logical_op_id][source_id]
+            champion_record_set = champion_outputs[logical_op_id][source_idx]

            # for each record_set produced by an operation, compute its quality
            for record_set in record_sets:
-                record_set = self.compute_quality(record_set, expected_record_set, champion_record_set, is_filter_op, is_convert_op, field_to_metric_fn)
+                record_set = self.compute_quality(record_set, expected_output, champion_record_set, is_filter_op, is_convert_op)

        # return the quality annotated record op stats
        return execution_data

+    def pick_champion_output(self, op_set_record_sets: list[tuple[DataRecordSet, PhysicalOperator]]) -> DataRecordSet:
+        # if there's only one operator in the set, we return its record_set
+        if len(op_set_record_sets) == 1:
+            record_set, _ = op_set_record_sets[0]
+            return record_set
+
+        # find the operator with the highest average quality and return its record_set
+        base_op_cost_est = OperatorCostEstimates(cardinality=1.0, cost_per_record=0.0, time_per_record=0.0, quality=1.0)
+        champion_record_set, champion_quality = None, -1.0
+        for record_set, op in op_set_record_sets:
+            op_cost_estimates = op.naive_cost_estimates(base_op_cost_est)
+            if op_cost_estimates.quality > champion_quality:
+                champion_record_set, champion_quality = record_set, op_cost_estimates.quality
+
+        return champion_record_set

    def pick_ensemble_output(self, op_set_record_sets: list[tuple[DataRecordSet, PhysicalOperator]]) -> DataRecordSet:
        # if there's only one operator in the set, we return its record_set
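Note what the new default `pick_output_fn` (see the `__init__` hunk above) actually does: `pick_champion_output` chooses by the operator's prior quality estimate from `naive_cost_estimates`, without comparing the produced records at all. A stub sketch of that selection rule; the stub classes are illustrative stand-ins, not palimpzest APIs:

```python
# Stub sketch: the champion record_set comes from the op with the highest
# a-priori quality estimate; the outputs themselves are never inspected.
class StubEstimates:
    def __init__(self, quality):
        self.quality = quality

class StubOp:
    def __init__(self, quality):
        self._quality = quality

    def naive_cost_estimates(self, base_op_cost_est):
        return StubEstimates(self._quality)

op_set_record_sets = [("records_from_op_a", StubOp(0.80)), ("records_from_op_b", StubOp(0.95))]
champion_record_set, _ = max(op_set_record_sets, key=lambda pair: pair[1].naive_cost_estimates(None).quality)
assert champion_record_set == "records_from_op_b"
```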
@@ -463,6 +487,10 @@ class MABSentinelQueryProcessor(QueryProcessor):
            record_set, _ = op_set_record_sets[0]
            return record_set

+        # NOTE: I don't like that this assumes the models are consistent in
+        # how they order their record outputs for one-to-many converts;
+        # eventually we can try out more robust schemes to account for
+        # differences in ordering
        # aggregate records at each index in the response
        idx_to_records = {}
        for record_set, _ in op_set_record_sets:
@@ -523,37 +551,30 @@ class MABSentinelQueryProcessor(QueryProcessor):
                # update list of futures
                futures = not_done_futures

-        # compute mapping from source_id to record sets for all operators and for champion operator
+        # compute mapping from source_idx to record sets for all operators and for champion operator
        all_record_sets, champion_record_sets = {}, {}
-        for op, candidate in op_candidate_pairs:
-            candidate_output_record_sets, source_id = [], None
+        for _, candidate in op_candidate_pairs:
+            candidate_output_record_sets, source_idx = [], None
            for record_set, operator, candidate_ in output_record_sets:
                if candidate == candidate_:
                    candidate_output_record_sets.append((record_set, operator))

-            # NOTE: we should resolve this issue in a more thoughtful way, but currently the source_id for
-            # scan candidate records is the sample_idx, when we want it to be the source_id which is
-            # set by the scan operator when it returns the record
-            # get the source_id associated with this input record
-            source_id = (
-                candidate.source_id
-                if not (isinstance(op, (MarshalAndScanDataOp, CacheScanDataOp)))
-                else record_set[0].source_id
-            )
+            # get the source_idx associated with this input record
+            source_idx = candidate.source_idx

            # select the champion (i.e. best) record_set from all the record sets computed for this candidate
            champion_record_set = self.pick_output_fn(candidate_output_record_sets)

-            # add champion record_set to mapping from source_id --> champion record_set
-            champion_record_sets[source_id] = champion_record_set
+            # add champion record_set to mapping from source_idx --> champion record_set
+            champion_record_sets[source_idx] = champion_record_set

-            # add all record_sets computed for this source_id to mapping from source_id --> record_sets
-            all_record_sets[source_id] = [tup[0] for tup in candidate_output_record_sets]
+            # add all record_sets computed for this source_idx to mapping from source_idx --> record_sets
+            all_record_sets[source_idx] = [tup[0] for tup in candidate_output_record_sets]

        return all_record_sets, champion_record_sets


-    def execute_sentinel_plan(self, plan: SentinelPlan, expected_outputs: dict[str, DataRecordSet], policy: Policy):
+    def execute_sentinel_plan(self, plan: SentinelPlan, expected_outputs: dict[str, dict], policy: Policy):
        """
        """
        if self.verbose:
@@ -578,9 +599,9 @@ class MABSentinelQueryProcessor(QueryProcessor):
        )

        # shuffle the indices of records to sample
-        total_num_samples = self.datasource.get_val_length()
-        shuffled_sample_indices = [int(idx) for idx in np.arange(total_num_samples)]
-        self.rng.shuffle(shuffled_sample_indices)
+        total_num_samples = len(self.val_datasource)
+        shuffled_source_indices = [int(idx) for idx in np.arange(total_num_samples)]
+        self.rng.shuffle(shuffled_source_indices)

        # sample k initial operators for each operator set; for each operator maintain a tuple of:
        # (operator, next_shuffled_sample_idx, new_operator); new_operator is True when an operator
@@ -601,12 +622,6 @@ class MABSentinelQueryProcessor(QueryProcessor):
            for logical_op_id, _, op_set in plan
        }

-        # TODO: long-term, we should do something which does not rely on scanning validation source to build this mapping
-        sample_idx_to_source_id = {
-            sample_idx: self.datasource.get_item(sample_idx, val=True).source_id
-            for sample_idx in range(total_num_samples)
-        }
-
        # NOTE: to maintain parity with our count of samples drawn in the random sampling execution,
        # for each logical_op_id, we count the number of (record, op) executions as the number of samples within that op_set;
        # the samples drawn is equal to the max of that number across all operator sets
@@ -623,26 +638,22 @@ class MABSentinelQueryProcessor(QueryProcessor):
            updated_frontier_ops_lst = []
            for op, next_shuffled_sample_idx, new_operator, fully_sampled in frontier_ops[logical_op_id]:
                # execute new operators on first j candidates, and previously sampled operators on one additional candidate
-                j = min(self.j, len(shuffled_sample_indices)) if new_operator else 1
+                j = min(self.j, len(shuffled_source_indices)) if new_operator else 1
                for j_idx in range(j):
                    candidates = []
                    if isinstance(op, (MarshalAndScanDataOp, CacheScanDataOp)):
-                        sample_idx = shuffled_sample_indices[(next_shuffled_sample_idx + j_idx) % len(shuffled_sample_indices)]
-                        candidate = DataRecord(schema=SourceRecord, source_id=sample_idx)
-                        candidate.idx = sample_idx
-                        candidate.get_item_fn = partial(self.datasource.get_item, val=True)
-                        candidates = [candidate]
+                        source_idx = shuffled_source_indices[(next_shuffled_sample_idx + j_idx) % len(shuffled_source_indices)]
+                        candidates = [source_idx]
                        logical_op_id_to_num_samples[logical_op_id] += 1
                        phys_op_id_to_num_samples[op.get_op_id()] += 1
                    else:
-                        if next_shuffled_sample_idx + j_idx == len(shuffled_sample_indices):
+                        if next_shuffled_sample_idx + j_idx == len(shuffled_source_indices):
                            fully_sampled = True
                            break

                        # pick best output from all_outputs from previous logical operator
-                        sample_idx = shuffled_sample_indices[next_shuffled_sample_idx + j_idx]
-                        source_id = sample_idx_to_source_id[sample_idx]
-                        record_sets = all_outputs[prev_logical_op_id][source_id]
+                        source_idx = shuffled_source_indices[next_shuffled_sample_idx + j_idx]
+                        record_sets = all_outputs[prev_logical_op_id][source_idx]
                        all_source_record_sets = [(record_set, None) for record_set in record_sets]
                        max_quality_record_set = self.pick_highest_quality_output(all_source_record_sets)
                        if (
@@ -672,26 +683,26 @@ class MABSentinelQueryProcessor(QueryProcessor):
                continue

            # run sampled operators on sampled candidates
-            source_id_to_record_sets, source_id_to_champion_record_set = self.execute_op_set(op_candidate_pairs)
+            source_idx_to_record_sets, source_idx_to_champion_record_set = self.execute_op_set(op_candidate_pairs)

            # update all_outputs and champion_outputs dictionary
            if logical_op_id not in all_outputs:
-                all_outputs[logical_op_id] = source_id_to_record_sets
-                champion_outputs[logical_op_id] = source_id_to_champion_record_set
+                all_outputs[logical_op_id] = source_idx_to_record_sets
+                champion_outputs[logical_op_id] = source_idx_to_champion_record_set
            else:
-                for source_id, record_sets in source_id_to_record_sets.items():
-                    if source_id not in all_outputs[logical_op_id]:
-                        all_outputs[logical_op_id][source_id] = record_sets
-                        champion_outputs[logical_op_id][source_id] = source_id_to_champion_record_set[source_id]
+                for source_idx, record_sets in source_idx_to_record_sets.items():
+                    if source_idx not in all_outputs[logical_op_id]:
+                        all_outputs[logical_op_id][source_idx] = record_sets
+                        champion_outputs[logical_op_id][source_idx] = source_idx_to_champion_record_set[source_idx]
                    else:
-                        all_outputs[logical_op_id][source_id].extend(record_sets)
+                        all_outputs[logical_op_id][source_idx].extend(record_sets)
                        # NOTE: short-term solution; in practice we can get multiple champion records from different
                        # sets of operators, so we should try to find a way to only take one
-                        champion_outputs[logical_op_id][source_id] = source_id_to_champion_record_set[source_id]
+                        champion_outputs[logical_op_id][source_idx] = source_idx_to_champion_record_set[source_idx]

            # flatten lists of records and record_op_stats
            all_records, all_record_op_stats = [], []
-            for _, record_sets in source_id_to_record_sets.items():
+            for _, record_sets in source_idx_to_record_sets.items():
                for record_set in record_sets:
                    all_records.extend(record_set.data_records)
                    all_record_op_stats.extend(record_set.record_op_stats)
@@ -707,17 +718,16 @@ class MABSentinelQueryProcessor(QueryProcessor):
            if not self.nocache:
                for record in all_records:
                    if getattr(record, "passed_operator", True):
-                        self.datadir.append_cache(logical_op_id, record)
+                        # self.datadir.append_cache(logical_op_id, record)
+                        pass

            # compute quality for each operator
-            field_to_metric_fn = self.datasource.get_field_to_metric_fn()
            all_outputs = self.score_quality(
                op_set,
                logical_op_id,
                all_outputs,
                champion_outputs,
                expected_outputs,
-                field_to_metric_fn,
            )

            # update the (pareto) frontier for each set of operators
@@ -736,8 +746,9 @@ class MABSentinelQueryProcessor(QueryProcessor):

        # if caching was allowed, close the cache
        if not self.nocache:
-            for logical_op_id, _, _ in plan:
-                self.datadir.close_cache(logical_op_id)
+            for _, _, _ in plan:
+                # self.datadir.close_cache(logical_op_id)
+                pass

        # finalize plan stats
        total_plan_time = time.time() - plan_start_time
@@ -757,12 +768,10 @@ class MABSentinelQueryProcessor(QueryProcessor):
        """
        # if we're using validation data, get the set of expected output records
        expected_outputs = {}
-        for idx in range(self.datasource.get_val_length()):
-            data_records = self.datasource.get_item(idx, val=True, include_label=True)
-            if not isinstance(data_records, list):
-                data_records = [data_records]
-            record_set = DataRecordSet(data_records, None)
-            expected_outputs[record_set.source_id] = record_set
+        for source_idx in range(len(self.val_datasource)):
+            # TODO: make sure execute_op_set uses self.val_datasource
+            expected_output = self.val_datasource[source_idx]
+            expected_outputs[source_idx] = expected_output

        # run sentinel plan
        execution_data, plan_stats = self.execute_sentinel_plan(sentinel_plan, expected_outputs, policy)
@@ -775,11 +784,16 @@ class MABSentinelQueryProcessor(QueryProcessor):
        Generates and returns a SentinelPlan for the given dataset.
        """
        # TODO: explicitly pull up filters; for SIGMOD we can explicitly write plans w/filters pulled up
-        # initialize the optimizer
-        # use optimizer to generate sentinel plans
-        # TODO: Do we need to re-initialize the optimizer here?
+
+        # create a new optimizer and update its strategy to SENTINEL
        optimizer = self.optimizer.deepcopy_clean()
        optimizer.update_strategy(OptimizationStrategyType.SENTINEL)
+
+        # create copy of dataset, but change its data source to the validation data source
+        dataset = deepcopy(dataset)
+        dataset._set_data_source(self.val_datasource)
+
+        # get the sentinel plan for the given dataset
        sentinel_plans = optimizer.optimize(dataset, policy)
        sentinel_plan = sentinel_plans[0]

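The dataset copy above is the crux of `create_sentinel_plan` in 0.6.0: the sentinel plan is optimized over the (smaller) validation source rather than the full data source, and the deep copy keeps the caller's dataset untouched. A sketch of that pattern as a hypothetical helper, using only the calls shown in this hunk:

```python
from copy import deepcopy

def build_sentinel_dataset(dataset, val_datasource):
    # copy the dataset so the caller's pipeline keeps its original source,
    # then point the copy at the validation data
    val_dataset = deepcopy(dataset)
    val_dataset._set_data_source(val_datasource)  # private hook used in this hunk
    return val_dataset
```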
@@ -790,12 +804,13 @@ class MABSentinelQueryProcessor(QueryProcessor):
        execution_start_time = time.time()

        # for now, enforce that we are using validation data; we can relax this after paper submission
-        if not self.using_validation_data:
-            raise Exception("Make sure you are using ValidationDataSource with MABSentinelExecutionEngine")
+        if self.val_datasource is None:
+            raise Exception("Make sure you are using validation data with MABSentinelExecutionEngine")

        # if nocache is True, make sure we do not re-use codegen examples
        if self.nocache:
-            self.clear_cached_examples()
+            # self.clear_cached_examples()
+            pass

        # create sentinel plan
        sentinel_plan = self.create_sentinel_plan(self.dataset, self.policy)
@@ -846,7 +861,6 @@ class MABSentinelSequentialSingleThreadProcessor(MABSentinelQueryProcessor, SequentialSingleThreadExecutionStrategy):
        SequentialSingleThreadExecutionStrategy.__init__(
            self,
            scan_start_idx=self.scan_start_idx,
-            datadir=self.datadir,
            max_workers=self.max_workers,
            nocache=self.nocache,
            verbose=self.verbose
@@ -863,7 +877,6 @@ class MABSentinelPipelinedParallelProcessor(MABSentinelQueryProcessor, PipelinedParallelExecutionStrategy):
        PipelinedParallelExecutionStrategy.__init__(
            self,
            scan_start_idx=self.scan_start_idx,
-            datadir=self.datadir,
            max_workers=self.max_workers,
            nocache=self.nocache,
            verbose=self.verbose