palimpzest 0.5.4__py3-none-any.whl → 0.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71)
  1. palimpzest/__init__.py +7 -9
  2. palimpzest/constants.py +47 -7
  3. palimpzest/core/__init__.py +20 -26
  4. palimpzest/core/data/dataclasses.py +9 -2
  5. palimpzest/core/data/datareaders.py +497 -0
  6. palimpzest/core/elements/records.py +29 -37
  7. palimpzest/core/lib/fields.py +14 -12
  8. palimpzest/core/lib/schemas.py +80 -94
  9. palimpzest/policy.py +58 -0
  10. palimpzest/prompts/__init__.py +22 -0
  11. palimpzest/prompts/code_synthesis_prompts.py +28 -0
  12. palimpzest/prompts/convert_prompts.py +87 -0
  13. palimpzest/prompts/critique_and_refine_convert_prompts.py +216 -0
  14. palimpzest/prompts/filter_prompts.py +69 -0
  15. palimpzest/prompts/moa_aggregator_convert_prompts.py +57 -0
  16. palimpzest/prompts/moa_proposer_convert_prompts.py +79 -0
  17. palimpzest/prompts/prompt_factory.py +732 -0
  18. palimpzest/prompts/util_phrases.py +14 -0
  19. palimpzest/query/execution/execution_strategy.py +0 -3
  20. palimpzest/query/execution/parallel_execution_strategy.py +12 -25
  21. palimpzest/query/execution/single_threaded_execution_strategy.py +31 -45
  22. palimpzest/query/generators/generators.py +71 -347
  23. palimpzest/query/operators/__init__.py +5 -5
  24. palimpzest/query/operators/aggregate.py +10 -5
  25. palimpzest/query/operators/code_synthesis_convert.py +4 -48
  26. palimpzest/query/operators/convert.py +5 -2
  27. palimpzest/query/operators/critique_and_refine_convert.py +112 -0
  28. palimpzest/query/operators/filter.py +1 -1
  29. palimpzest/query/operators/limit.py +1 -1
  30. palimpzest/query/operators/logical.py +28 -27
  31. palimpzest/query/operators/mixture_of_agents_convert.py +4 -1
  32. palimpzest/query/operators/physical.py +32 -20
  33. palimpzest/query/operators/project.py +1 -1
  34. palimpzest/query/operators/rag_convert.py +6 -3
  35. palimpzest/query/operators/retrieve.py +13 -31
  36. palimpzest/query/operators/scan.py +150 -0
  37. palimpzest/query/optimizer/__init__.py +5 -1
  38. palimpzest/query/optimizer/cost_model.py +18 -34
  39. palimpzest/query/optimizer/optimizer.py +40 -25
  40. palimpzest/query/optimizer/optimizer_strategy.py +26 -0
  41. palimpzest/query/optimizer/plan.py +2 -2
  42. palimpzest/query/optimizer/rules.py +118 -27
  43. palimpzest/query/processor/config.py +12 -1
  44. palimpzest/query/processor/mab_sentinel_processor.py +125 -112
  45. palimpzest/query/processor/nosentinel_processor.py +46 -62
  46. palimpzest/query/processor/query_processor.py +10 -20
  47. palimpzest/query/processor/query_processor_factory.py +12 -5
  48. palimpzest/query/processor/random_sampling_sentinel_processor.py +112 -91
  49. palimpzest/query/processor/streaming_processor.py +11 -17
  50. palimpzest/sets.py +170 -94
  51. palimpzest/tools/pdfparser.py +5 -64
  52. palimpzest/utils/datareader_helpers.py +61 -0
  53. palimpzest/utils/field_helpers.py +69 -0
  54. palimpzest/utils/hash_helpers.py +3 -2
  55. palimpzest/utils/udfs.py +0 -28
  56. {palimpzest-0.5.4.dist-info → palimpzest-0.6.1.dist-info}/METADATA +49 -49
  57. palimpzest-0.6.1.dist-info/RECORD +87 -0
  58. {palimpzest-0.5.4.dist-info → palimpzest-0.6.1.dist-info}/top_level.txt +0 -1
  59. cli/README.md +0 -156
  60. cli/__init__.py +0 -0
  61. cli/cli_main.py +0 -390
  62. palimpzest/config.py +0 -89
  63. palimpzest/core/data/datasources.py +0 -369
  64. palimpzest/datamanager/__init__.py +0 -0
  65. palimpzest/datamanager/datamanager.py +0 -300
  66. palimpzest/prompts.py +0 -397
  67. palimpzest/query/operators/datasource.py +0 -202
  68. palimpzest-0.5.4.dist-info/RECORD +0 -83
  69. palimpzest-0.5.4.dist-info/entry_points.txt +0 -2
  70. {palimpzest-0.5.4.dist-info → palimpzest-0.6.1.dist-info}/LICENSE +0 -0
  71. {palimpzest-0.5.4.dist-info → palimpzest-0.6.1.dist-info}/WHEEL +0 -0
palimpzest/query/processor/nosentinel_processor.py

@@ -1,17 +1,16 @@
 import time
 
 from palimpzest.core.data.dataclasses import ExecutionStats, OperatorStats, PlanStats
-from palimpzest.core.elements.records import DataRecord, DataRecordCollection
-from palimpzest.core.lib.schemas import SourceRecord
+from palimpzest.core.elements.records import DataRecordCollection
 from palimpzest.query.execution.parallel_execution_strategy import PipelinedParallelExecutionStrategy
 from palimpzest.query.execution.single_threaded_execution_strategy import (
     PipelinedSingleThreadExecutionStrategy,
     SequentialSingleThreadExecutionStrategy,
 )
 from palimpzest.query.operators.aggregate import AggregateOp
-from palimpzest.query.operators.datasource import DataSourcePhysicalOp
 from palimpzest.query.operators.filter import FilterOp
 from palimpzest.query.operators.limit import LimitScanOp
+from palimpzest.query.operators.scan import ScanPhysicalOp
 from palimpzest.query.optimizer.plan import PhysicalPlan
 from palimpzest.query.processor.query_processor import QueryProcessor
 from palimpzest.utils.progress import create_progress_manager
@@ -29,7 +28,8 @@ class NoSentinelQueryProcessor(QueryProcessor):
 
         # if nocache is True, make sure we do not re-use codegen examples
         if self.nocache:
-            self.clear_cached_examples()
+            # self.clear_cached_examples()
+            pass
 
         # execute plan(s) according to the optimization strategy
         records, plan_stats = self._execute_with_strategy(self.dataset, self.policy, self.optimizer)
@@ -60,7 +60,6 @@ class NoSentinelSequentialSingleThreadProcessor(NoSentinelQueryProcessor, Sequen
         SequentialSingleThreadExecutionStrategy.__init__(
             self,
             scan_start_idx=self.scan_start_idx,
-            datadir=self.datadir,
            max_workers=self.max_workers,
            nocache=self.nocache,
            verbose=self.verbose
@@ -92,21 +91,20 @@ class NoSentinelSequentialSingleThreadProcessor(NoSentinelQueryProcessor, Sequen
         output_records = []
         current_scan_idx = self.scan_start_idx
 
-        # get handle to DataSource and pre-compute its size
+        # get handle to scan operator and pre-compute its size
         source_operator = plan.operators[0]
-        assert isinstance(source_operator, DataSourcePhysicalOp), "First operator in physical plan must be a DataSourcePhysicalOp"
-        datasource = source_operator.get_datasource()
-        datasource_len = len(datasource)
+        assert isinstance(source_operator, ScanPhysicalOp), "First operator in physical plan must be a ScanPhysicalOp"
+        datareader_len = len(source_operator.datareader)
 
         # Calculate total work units - each record needs to go through each operator
         total_ops = len(plan.operators)
-        total_items = min(num_samples, datasource_len) if num_samples != float("inf") else datasource_len
+        total_items = min(num_samples, datareader_len) if num_samples != float("inf") else datareader_len
         total_work_units = total_items * total_ops
         self.progress_manager.start(total_work_units)
         work_units_completed = 0
 
         # initialize processing queues for each operation
-        processing_queues = {op.get_op_id(): [] for op in plan.operators if not isinstance(op, DataSourcePhysicalOp)}
+        processing_queues = {op.get_op_id(): [] for op in plan.operators if not isinstance(op, ScanPhysicalOp)}
 
         try:
             # execute the plan one operator at a time
@@ -122,19 +120,12 @@ class NoSentinelSequentialSingleThreadProcessor(NoSentinelQueryProcessor, Sequen
                 # initialize output records and record_op_stats for this operator
                 records, record_op_stats = [], []
 
-                # invoke datasource operator(s) until we run out of source records or hit the num_samples limit
-                if isinstance(operator, DataSourcePhysicalOp):
+                # invoke scan operator(s) until we run out of source records or hit the num_samples limit
+                if isinstance(operator, ScanPhysicalOp):
                     keep_scanning_source_records = True
                     while keep_scanning_source_records:
-                        # construct input DataRecord for DataSourcePhysicalOp
-                        # NOTE: this DataRecord will be discarded and replaced by the scan_operator;
-                        # it is simply a vessel to inform the scan_operator which record to fetch
-                        candidate = DataRecord(schema=SourceRecord, source_id=current_scan_idx)
-                        candidate.idx = current_scan_idx
-                        candidate.get_item_fn = datasource.get_item
-
-                        # run DataSourcePhysicalOp on record
-                        record_set = operator(candidate)
+                        # run ScanPhysicalOp on current scan index
+                        record_set = operator(current_scan_idx)
                         records.extend(record_set.data_records)
                         record_op_stats.extend(record_set.record_op_stats)
 
@@ -149,7 +140,7 @@ class NoSentinelSequentialSingleThreadProcessor(NoSentinelQueryProcessor, Sequen
                         current_scan_idx += 1
 
                         # update whether to keep scanning source records
-                        keep_scanning_source_records = current_scan_idx < datasource_len and len(records) < num_samples
+                        keep_scanning_source_records = current_scan_idx < datareader_len and len(records) < num_samples
 
                 # aggregate operators accept all input records at once
                 elif isinstance(operator, AggregateOp):
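
Taken together, the removed and added lines above replace the old pattern, in which a throwaway DataRecord built on the SourceRecord schema carried the scan index and a get_item callback into the DataSourcePhysicalOp, with a direct call that passes the index itself. The sketch below is assembled from the + lines only; scan_op is a local name used here for clarity, and plan, scan_start_idx, and num_samples are assumed to be in scope as in the surrounding method:

    # reconstruction of the 0.6.1 sequential scan loop, pieced together from the added lines above
    scan_op = plan.operators[0]
    assert isinstance(scan_op, ScanPhysicalOp), "First operator in physical plan must be a ScanPhysicalOp"
    datareader_len = len(scan_op.datareader)

    records, current_scan_idx = [], scan_start_idx
    keep_scanning_source_records = True
    while keep_scanning_source_records:
        # the scan operator is now invoked with an index into its DataReader and fetches the record itself
        record_set = scan_op(current_scan_idx)
        records.extend(record_set.data_records)
        current_scan_idx += 1
        keep_scanning_source_records = current_scan_idx < datareader_len and len(records) < num_samples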
@@ -193,7 +184,8 @@ class NoSentinelSequentialSingleThreadProcessor(NoSentinelQueryProcessor, Sequen
                 if not self.nocache:
                     for record in records:
                         if getattr(record, "passed_operator", True):
-                            self.datadir.append_cache(operator.target_cache_id, record)
+                            # self.datadir.append_cache(operator.target_cache_id, record)
+                            pass
 
                 # update processing_queues or output_records
                 for record in records:
@@ -210,8 +202,9 @@ class NoSentinelSequentialSingleThreadProcessor(NoSentinelQueryProcessor, Sequen
 
         # if caching was allowed, close the cache
         if not self.nocache:
-            for operator in plan.operators:
-                self.datadir.close_cache(operator.target_cache_id)
+            for _ in plan.operators:
+                # self.datadir.close_cache(operator.target_cache_id)
+                pass
 
         # finalize plan stats
         total_plan_time = time.time() - plan_start_time
@@ -234,7 +227,6 @@ class NoSentinelPipelinedSingleThreadProcessor(NoSentinelQueryProcessor, Pipelin
         PipelinedSingleThreadExecutionStrategy.__init__(
             self,
             scan_start_idx=self.scan_start_idx,
-            datadir=self.datadir,
            max_workers=self.max_workers,
            nocache=self.nocache,
            verbose=self.verbose
@@ -267,22 +259,21 @@ class NoSentinelPipelinedSingleThreadProcessor(NoSentinelQueryProcessor, Pipelin
         source_records_scanned = 0
         current_scan_idx = self.scan_start_idx
 
-        # get handle to DataSource and pre-compute its size
+        # get handle to scan operator and pre-compute its size
         source_operator = plan.operators[0]
-        assert isinstance(source_operator, DataSourcePhysicalOp), "First operator in physical plan must be a DataSourcePhysicalOp"
-        datasource = source_operator.get_datasource()
-        datasource_len = len(datasource)
+        assert isinstance(source_operator, ScanPhysicalOp), "First operator in physical plan must be a ScanPhysicalOp"
+        datareader_len = len(source_operator.datareader)
 
         # Calculate total work units - each record needs to go through each operator
         total_ops = len(plan.operators)
-        total_items = min(num_samples, datasource_len) if num_samples != float("inf") else datasource_len
+        total_items = min(num_samples, datareader_len) if num_samples != float("inf") else datareader_len
         total_work_units = total_items * total_ops
         self.progress_manager.start(total_work_units)
         work_units_completed = 0
 
         try:
             # initialize processing queues for each operation
-            processing_queues = {op.get_op_id(): [] for op in plan.operators if not isinstance(op, DataSourcePhysicalOp)}
+            processing_queues = {op.get_op_id(): [] for op in plan.operators if not isinstance(op, ScanPhysicalOp)}
 
             # execute the plan until either:
             # 1. all records have been processed, or
@@ -301,16 +292,11 @@ class NoSentinelPipelinedSingleThreadProcessor(NoSentinelQueryProcessor, Pipelin
                 # create empty lists for records and execution stats generated by executing this operator on its next input(s)
                 records, record_op_stats = [], []
 
-                # invoke datasource operator(s) until we run out of source records or hit the num_samples limit
-                if isinstance(operator, DataSourcePhysicalOp):
+                # invoke scan operator(s) until we run out of source records or hit the num_samples limit
+                if isinstance(operator, ScanPhysicalOp):
                     if keep_scanning_source_records:
-                        # construct input DataRecord for DataSourcePhysicalOp
-                        candidate = DataRecord(schema=SourceRecord, source_id=current_scan_idx)
-                        candidate.idx = current_scan_idx
-                        candidate.get_item_fn = datasource.get_item
-
-                        # run DataSourcePhysicalOp on record
-                        record_set = operator(candidate)
+                        # run ScanPhysicalOp on current scan index
+                        record_set = operator(current_scan_idx)
                         records = record_set.data_records
                         record_op_stats = record_set.record_op_stats
 
@@ -326,15 +312,15 @@ class NoSentinelPipelinedSingleThreadProcessor(NoSentinelQueryProcessor, Pipelin
                         current_scan_idx += 1
 
                         # update whether to keep scanning source records
-                        keep_scanning_source_records = current_scan_idx < datasource_len and source_records_scanned < num_samples
+                        keep_scanning_source_records = current_scan_idx < datareader_len and source_records_scanned < num_samples
 
                 # only invoke aggregate operator(s) once there are no more source records and all
                 # upstream operators' processing queues are empty
                 elif isinstance(operator, AggregateOp):
                     upstream_ops_are_finished = True
                     for upstream_op_idx in range(op_idx):
-                        # datasources do not have processing queues
-                        if isinstance(plan.operators[upstream_op_idx], DataSourcePhysicalOp):
+                        # scan operators do not have processing queues
+                        if isinstance(plan.operators[upstream_op_idx], ScanPhysicalOp):
                             continue
 
                         # check upstream ops which do have a processing queue
@@ -383,7 +369,8 @@ class NoSentinelPipelinedSingleThreadProcessor(NoSentinelQueryProcessor, Pipelin
                 if not self.nocache:
                     for record in records:
                         if getattr(record, "passed_operator", True):
-                            self.datadir.append_cache(operator.target_cache_id, record)
+                            # self.datadir.append_cache(operator.target_cache_id, record)
+                            pass
 
                 # update processing_queues or output_records
                 for record in records:
@@ -404,8 +391,9 @@ class NoSentinelPipelinedSingleThreadProcessor(NoSentinelQueryProcessor, Pipelin
 
         # if caching was allowed, close the cache
         if not self.nocache:
-            for operator in plan.operators:
-                self.datadir.close_cache(operator.target_cache_id)
+            for _ in plan.operators:
+                # self.datadir.close_cache(operator.target_cache_id)
+                pass
 
         # finalize plan stats
         total_plan_time = time.time() - plan_start_time
@@ -428,7 +416,6 @@ class NoSentinelPipelinedParallelProcessor(NoSentinelQueryProcessor, PipelinedPa
         PipelinedParallelExecutionStrategy.__init__(
             self,
             scan_start_idx=self.scan_start_idx,
-            datadir=self.datadir,
            max_workers=self.max_workers,
            nocache=self.nocache,
            verbose=self.verbose
@@ -461,15 +448,14 @@ class NoSentinelPipelinedParallelProcessor(NoSentinelQueryProcessor, PipelinedPa
         # source_records_scanned = 0
         # current_scan_idx = self.scan_start_idx
 
-        # # get handle to DataSource and pre-compute its size
+        # # get handle to scan operator and pre-compute its size
         # source_operator = plan.operators[0]
-        # assert isinstance(source_operator, DataSourcePhysicalOp), "First operator in physical plan must be a DataSourcePhysicalOp"
-        # datasource = source_operator.get_datasource()
-        # datasource_len = len(datasource)
+        # assert isinstance(source_operator, ScanPhysicalOp), "First operator in physical plan must be a ScanPhysicalOp"
+        # datareader_len = len(source_operator.datareader)
 
         # # Calculate total work units - each record needs to go through each operator
         # total_ops = len(plan.operators)
-        # total_items = min(num_samples, datasource_len) if num_samples != float("inf") else datasource_len
+        # total_items = min(num_samples, datareader_len) if num_samples != float("inf") else datareader_len
         # total_work_units = total_items * total_ops
         # self.progress_manager.start(total_work_units)
         # work_units_completed = 0
@@ -518,14 +504,11 @@ class NoSentinelPipelinedParallelProcessor(NoSentinelQueryProcessor, PipelinedPa
         # for _, operator in enumerate(plan.operators):
         #     op_id = operator.get_op_id()
 
-        #     if isinstance(operator, DataSourcePhysicalOp) and keep_scanning_source_records:
+        #     if isinstance(operator, ScanPhysicalOp) and keep_scanning_source_records:
         #         # Submit source operator task
-        #         candidate = DataRecord(schema=SourceRecord, source_id=current_scan_idx)
-        #         candidate.idx = current_scan_idx
-        #         candidate.get_item_fn = datasource.get_item
-        #         futures.append(executor.submit(PhysicalOperator.execute_op_wrapper, operator, candidate))
+        #         futures.append(executor.submit(PhysicalOperator.execute_op_wrapper, operator, current_scan_idx))
         #         current_scan_idx += 1
-        #         keep_scanning_source_records = current_scan_idx < datasource_len and source_records_scanned < num_samples
+        #         keep_scanning_source_records = current_scan_idx < datareader_len and source_records_scanned < num_samples
 
         #     elif len(processing_queues[op_id]) > 0:
         #         # Submit task for next record in queue
@@ -538,8 +521,9 @@ class NoSentinelPipelinedParallelProcessor(NoSentinelQueryProcessor, PipelinedPa
 
         # # if caching was allowed, close the cache
         # if not self.nocache:
-        #     for operator in plan.operators:
-        #         self.datadir.close_cache(operator.target_cache_id)
+        #     for _ in plan.operators:
+        #         # self.datadir.close_cache(operator.target_cache_id)
+        #         pass
 
         # # finalize plan stats
         # total_plan_time = time.time() - plan_start_time
palimpzest/query/processor/query_processor.py

@@ -2,9 +2,8 @@ from abc import abstractmethod
 from concurrent.futures import ThreadPoolExecutor
 
 from palimpzest.core.data.dataclasses import PlanStats, RecordOpStats
-from palimpzest.core.data.datasources import DataSource, ValidationDataSource
+from palimpzest.core.data.datareaders import DataReader
 from palimpzest.core.elements.records import DataRecord, DataRecordCollection
-from palimpzest.datamanager.datamanager import DataDirectory
 from palimpzest.policy import Policy
 from palimpzest.query.optimizer.cost_model import CostModel
 from palimpzest.query.optimizer.optimizer import Optimizer
@@ -44,16 +43,15 @@ class QueryProcessor:
 
         self.config = config or QueryProcessorConfig()
         self.dataset = dataset
-        self.datasource = self._get_datasource(self.dataset)
+        self.datareader = self._get_datareader(self.dataset)
         self.num_samples = self.config.num_samples
-        self.using_validation_data = isinstance(self.datasource, ValidationDataSource)
+        self.val_datasource = self.config.val_datasource
         self.scan_start_idx = self.config.scan_start_idx
         self.nocache = self.config.nocache
         self.verbose = self.config.verbose
         self.max_workers = self.config.max_workers
         self.num_workers_per_plan = self.config.num_workers_per_plan
         self.min_plans = self.config.min_plans
-        self.datadir = DataDirectory()
 
         self.policy = self.config.policy
 
@@ -70,16 +68,15 @@ class QueryProcessor:
         assert optimizer is not None, "Optimizer is required. Please use QueryProcessorFactory.create_processor() to initialize a QueryProcessor."
         self.optimizer = optimizer
 
-    def _get_datasource(self, dataset: Set | DataSource) -> str:
+    def _get_datareader(self, dataset: Set | DataReader) -> DataReader:
         """
-        Gets the DataSource for the given dataset.
+        Gets the DataReader for the given dataset.
         """
-        # iterate until we reach DataSource
+        # iterate until we reach DataReader
         while isinstance(dataset, Set):
             dataset = dataset._source
 
-        # this will throw an exception if datasource is not registered with PZ
-        return DataDirectory().get_registered_dataset(dataset.dataset_id)
+        return dataset
 
     def execution_id(self) -> str:
         """
@@ -92,13 +89,6 @@ class QueryProcessor:
 
         return hash_for_id(id_str)
 
-    def clear_cached_examples(self):
-        """
-        Clear cached codegen samples.
-        """
-        cache = self.datadir.get_cache_service()
-        cache.rm_cache()
-
     def get_max_quality_plan_id(self, plans: list[PhysicalPlan]) -> str:
         """
         Return the plan_id for the plan with the highest quality in the list of plans.
@@ -233,7 +223,7 @@ class QueryProcessor:
 
         # get the initial set of optimal plans according to the optimizer
         plans = optimizer.optimize(dataset, policy)
-        while len(plans) > 1 and self.scan_start_idx < len(self.datasource):
+        while len(plans) > 1 and self.scan_start_idx < len(self.datareader):
             # identify the plan with the highest quality in the set
             max_quality_plan_id = self.get_max_quality_plan_id(plans)
 
@@ -244,7 +234,7 @@ class QueryProcessor:
             records.extend(new_records)
             plan_stats.extend(new_plan_stats)
 
-            if self.scan_start_idx + self.num_samples < len(self.datasource):
+            if self.scan_start_idx + self.num_samples < len(self.datareader):
                 # update cost model and optimizer
                 execution_data.extend(new_execution_data)
                 cost_model = CostModel(sample_execution_data=execution_data)
@@ -256,7 +246,7 @@ class QueryProcessor:
             # update scan start idx
             self.scan_start_idx += self.num_samples
 
-        if self.scan_start_idx < len(self.datasource):
+        if self.scan_start_idx < len(self.datareader):
             # execute final plan until end
             final_plan = plans[0]
             new_records, new_plan_stats = self.execute_plan(
palimpzest/query/processor/query_processor_factory.py

@@ -77,18 +77,22 @@ class QueryProcessorFactory:
 
         Args:
             dataset: The dataset to process
-            config: Additional configuration parameters:
+            config: The user-provided QueryProcessorConfig; if it is None, the default config will be used
+            kwargs: Additional keyword arguments to pass to the QueryProcessorConfig
         """
         if config is None:
             config = QueryProcessorConfig()
 
+        # apply any additional keyword arguments to the config
+        config.update(**kwargs)
+
         config = cls._config_validation_and_normalization(config)
         processing_strategy, execution_strategy, optimizer_strategy = cls._normalize_strategies(config)
         optimizer = cls._create_optimizer(optimizer_strategy, config)
 
         processor_key = (processing_strategy, execution_strategy)
         processor_cls = cls.PROCESSOR_MAPPING.get(processor_key)
-
+
         if processor_cls is None:
             raise ValueError(f"Unsupported combination of processing strategy {processing_strategy} "
                              f"and execution strategy {execution_strategy}")
@@ -96,7 +100,7 @@ class QueryProcessorFactory:
         return processor_cls(dataset=dataset, optimizer=optimizer, config=config, **kwargs)
 
     @classmethod
-    def create_and_run_processor(cls, dataset: Dataset, config: QueryProcessorConfig, **kwargs) -> DataRecordCollection:
+    def create_and_run_processor(cls, dataset: Dataset, config: QueryProcessorConfig | None = None, **kwargs) -> DataRecordCollection:
         # TODO(Jun): Consider to use cache here.
         processor = cls.create_processor(dataset=dataset, config=config, **kwargs)
         return processor.execute()
@@ -153,8 +157,11 @@ class QueryProcessorFactory:
             raise ValueError("Policy is required for optimizer")
 
         if not config.nocache:
-            raise ValueError("nocache=False is not supported yet!!")
-
+            raise ValueError("nocache=False is not supported yet")
+
+        if config.val_datasource is None and config.processing_strategy in [ProcessingStrategyType.MAB_SENTINEL, ProcessingStrategyType.RANDOM_SAMPLING]:
+            raise ValueError("val_datasource is required for MAB_SENTINEL and RANDOM_SAMPLING processing strategies")
+
         available_models = getattr(config, 'available_models', [])
         if available_models is None or len(available_models) == 0:
             available_models = get_models(include_vision=True)
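
With config.update(**kwargs) applied in create_processor and the config parameter of create_and_run_processor now optional, callers can pass configuration overrides directly as keyword arguments. A hypothetical invocation under that reading (the dataset and val_reader objects, and the exact spelling of the processing_strategy value, are illustrative assumptions rather than values taken from the diff):

    from palimpzest.query.processor.query_processor_factory import QueryProcessorFactory

    # kwargs are folded into the QueryProcessorConfig via config.update(**kwargs).
    # Per the validation added above: nocache must stay True (nocache=False raises),
    # and val_datasource is required for the MAB_SENTINEL / RANDOM_SAMPLING strategies.
    output = QueryProcessorFactory.create_and_run_processor(
        dataset,                              # a palimpzest Dataset (illustrative)
        val_datasource=val_reader,            # validation data; required for sentinel strategies
        processing_strategy="mab_sentinel",   # illustrative value; see ProcessingStrategyType
    )
    # per its signature, create_and_run_processor returns a DataRecordCollection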