palimpzest 0.7.21__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (87)
  1. palimpzest/__init__.py +37 -6
  2. palimpzest/agents/__init__.py +0 -0
  3. palimpzest/agents/compute_agents.py +0 -0
  4. palimpzest/agents/search_agents.py +637 -0
  5. palimpzest/constants.py +259 -197
  6. palimpzest/core/data/context.py +393 -0
  7. palimpzest/core/data/context_manager.py +163 -0
  8. palimpzest/core/data/dataset.py +634 -0
  9. palimpzest/core/data/{datareaders.py → iter_dataset.py} +202 -126
  10. palimpzest/core/elements/groupbysig.py +16 -13
  11. palimpzest/core/elements/records.py +166 -75
  12. palimpzest/core/lib/schemas.py +152 -390
  13. palimpzest/core/{data/dataclasses.py → models.py} +306 -170
  14. palimpzest/policy.py +2 -27
  15. palimpzest/prompts/__init__.py +35 -5
  16. palimpzest/prompts/agent_prompts.py +357 -0
  17. palimpzest/prompts/context_search.py +9 -0
  18. palimpzest/prompts/convert_prompts.py +61 -5
  19. palimpzest/prompts/filter_prompts.py +50 -5
  20. palimpzest/prompts/join_prompts.py +163 -0
  21. palimpzest/prompts/moa_proposer_convert_prompts.py +5 -5
  22. palimpzest/prompts/prompt_factory.py +358 -46
  23. palimpzest/prompts/validator.py +239 -0
  24. palimpzest/query/execution/all_sample_execution_strategy.py +134 -76
  25. palimpzest/query/execution/execution_strategy.py +210 -317
  26. palimpzest/query/execution/execution_strategy_type.py +5 -7
  27. palimpzest/query/execution/mab_execution_strategy.py +249 -136
  28. palimpzest/query/execution/parallel_execution_strategy.py +153 -244
  29. palimpzest/query/execution/single_threaded_execution_strategy.py +107 -64
  30. palimpzest/query/generators/generators.py +157 -330
  31. palimpzest/query/operators/__init__.py +15 -5
  32. palimpzest/query/operators/aggregate.py +50 -33
  33. palimpzest/query/operators/compute.py +201 -0
  34. palimpzest/query/operators/convert.py +27 -21
  35. palimpzest/query/operators/critique_and_refine_convert.py +7 -5
  36. palimpzest/query/operators/distinct.py +62 -0
  37. palimpzest/query/operators/filter.py +22 -13
  38. palimpzest/query/operators/join.py +402 -0
  39. palimpzest/query/operators/limit.py +3 -3
  40. palimpzest/query/operators/logical.py +198 -80
  41. palimpzest/query/operators/mixture_of_agents_convert.py +10 -8
  42. palimpzest/query/operators/physical.py +27 -21
  43. palimpzest/query/operators/project.py +3 -3
  44. palimpzest/query/operators/rag_convert.py +7 -7
  45. palimpzest/query/operators/retrieve.py +9 -9
  46. palimpzest/query/operators/scan.py +81 -42
  47. palimpzest/query/operators/search.py +524 -0
  48. palimpzest/query/operators/split_convert.py +10 -8
  49. palimpzest/query/optimizer/__init__.py +7 -9
  50. palimpzest/query/optimizer/cost_model.py +108 -441
  51. palimpzest/query/optimizer/optimizer.py +123 -181
  52. palimpzest/query/optimizer/optimizer_strategy.py +66 -61
  53. palimpzest/query/optimizer/plan.py +352 -67
  54. palimpzest/query/optimizer/primitives.py +43 -19
  55. palimpzest/query/optimizer/rules.py +484 -646
  56. palimpzest/query/optimizer/tasks.py +127 -58
  57. palimpzest/query/processor/config.py +41 -76
  58. palimpzest/query/processor/query_processor.py +73 -18
  59. palimpzest/query/processor/query_processor_factory.py +46 -38
  60. palimpzest/schemabuilder/schema_builder.py +15 -28
  61. palimpzest/utils/model_helpers.py +27 -77
  62. palimpzest/utils/progress.py +114 -102
  63. palimpzest/validator/__init__.py +0 -0
  64. palimpzest/validator/validator.py +306 -0
  65. {palimpzest-0.7.21.dist-info → palimpzest-0.8.0.dist-info}/METADATA +6 -1
  66. palimpzest-0.8.0.dist-info/RECORD +95 -0
  67. palimpzest/core/lib/fields.py +0 -141
  68. palimpzest/prompts/code_synthesis_prompts.py +0 -28
  69. palimpzest/query/execution/random_sampling_execution_strategy.py +0 -240
  70. palimpzest/query/generators/api_client_factory.py +0 -30
  71. palimpzest/query/operators/code_synthesis_convert.py +0 -488
  72. palimpzest/query/operators/map.py +0 -130
  73. palimpzest/query/processor/nosentinel_processor.py +0 -33
  74. palimpzest/query/processor/processing_strategy_type.py +0 -28
  75. palimpzest/query/processor/sentinel_processor.py +0 -88
  76. palimpzest/query/processor/streaming_processor.py +0 -149
  77. palimpzest/sets.py +0 -405
  78. palimpzest/utils/datareader_helpers.py +0 -61
  79. palimpzest/utils/demo_helpers.py +0 -75
  80. palimpzest/utils/field_helpers.py +0 -69
  81. palimpzest/utils/generation_helpers.py +0 -69
  82. palimpzest/utils/sandbox.py +0 -183
  83. palimpzest-0.7.21.dist-info/RECORD +0 -95
  84. /palimpzest/core/{elements/index.py → data/index_dataset.py} +0 -0
  85. {palimpzest-0.7.21.dist-info → palimpzest-0.8.0.dist-info}/WHEEL +0 -0
  86. {palimpzest-0.7.21.dist-info → palimpzest-0.8.0.dist-info}/licenses/LICENSE +0 -0
  87. {palimpzest-0.7.21.dist-info → palimpzest-0.8.0.dist-info}/top_level.txt +0 -0
@@ -1,22 +1,24 @@
  import logging
  from abc import ABC, abstractmethod
- from concurrent.futures import ThreadPoolExecutor, wait
+ from concurrent.futures import ThreadPoolExecutor, as_completed

  import numpy as np
  from chromadb.api.models.Collection import Collection

- from palimpzest.constants import PARALLEL_EXECUTION_SLEEP_INTERVAL_SECS
- from palimpzest.core.data.dataclasses import OperatorCostEstimates, PlanStats, RecordOpStats
- from palimpzest.core.data.datareaders import DataReader
+ from palimpzest.constants import Cardinality
+ from palimpzest.core.data.dataset import Dataset
  from palimpzest.core.elements.records import DataRecord, DataRecordSet
+ from palimpzest.core.models import GenerationStats, PlanStats, SentinelPlanStats
  from palimpzest.policy import Policy
  from palimpzest.query.operators.convert import LLMConvert
- from palimpzest.query.operators.filter import FilterOp, LLMFilter
+ from palimpzest.query.operators.filter import LLMFilter
+ from palimpzest.query.operators.join import JoinOp
  from palimpzest.query.operators.physical import PhysicalOperator
  from palimpzest.query.operators.retrieve import RetrieveOp
- from palimpzest.query.operators.scan import ScanPhysicalOp
+ from palimpzest.query.operators.scan import ContextScanOp, ScanPhysicalOp
  from palimpzest.query.optimizer.plan import PhysicalPlan, SentinelPlan
  from palimpzest.utils.progress import PZSentinelProgressManager
+ from palimpzest.validator.validator import Validator

  logger = logging.getLogger(__name__)

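The import change above reflects a broader change in this file: instead of polling futures with wait() and a fixed sleep interval (the removed PARALLEL_EXECUTION_SLEEP_INTERVAL_SECS constant), results are now consumed as they finish via as_completed(). A minimal sketch of that pattern, with an illustrative run_op stand-in for a physical-operator call (not a palimpzest API):

    from concurrent.futures import ThreadPoolExecutor, as_completed

    def run_op(x):
        # stand-in for invoking a physical operator on one input
        return x * 2

    with ThreadPoolExecutor(max_workers=4) as executor:
        futures = [executor.submit(run_op, x) for x in range(8)]
        for future in as_completed(futures):
            # handle each result as soon as its future finishes,
            # instead of repeatedly calling wait() with a timeout
            result = future.result()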
@@ -24,35 +26,20 @@ class BaseExecutionStrategy:
  def __init__(self,
  scan_start_idx: int = 0,
  max_workers: int | None = None,
+ batch_size: int | None = None,
  num_samples: int | None = None,
- cache: bool = False,
  verbose: bool = False,
  progress: bool = True,
  *args,
  **kwargs):
  self.scan_start_idx = scan_start_idx
  self.max_workers = max_workers
+ self.batch_size = batch_size
  self.num_samples = num_samples
- self.cache = cache
  self.verbose = verbose
  self.progress = progress


- def _add_records_to_cache(self, target_cache_id: str, records: list[DataRecord]) -> None:
- """Add each record (which isn't filtered) to the cache for the given target_cache_id."""
- if self.cache:
- for record in records:
- if getattr(record, "passed_operator", True):
- # self.datadir.append_cache(target_cache_id, record)
- pass
-
- def _close_cache(self, target_cache_ids: list[str]) -> None:
- """Close the cache for each of the given target_cache_ids"""
- if self.cache:
- for target_cache_id in target_cache_ids: # noqa: B007
- # self.datadir.close_cache(target_cache_id)
- pass
-
  class ExecutionStrategy(BaseExecutionStrategy, ABC):
  """Base strategy for executing query plans. Defines how to execute a PhysicalPlan.
  """
@@ -66,19 +53,24 @@ class ExecutionStrategy(BaseExecutionStrategy, ABC):
  """Execute a single plan according to strategy"""
  pass

- def _create_input_queues(self, plan: PhysicalPlan) -> dict[str, list]:
+ def _create_input_queues(self, plan: PhysicalPlan) -> dict[str, dict[str, list]]:
  """Initialize input queues for each operator in the plan."""
- input_queues = {}
- for op in plan.operators:
- inputs = []
+ input_queues = {f"{topo_idx}-{op.get_full_op_id()}": {} for topo_idx, op in enumerate(plan)}
+ for topo_idx, op in enumerate(plan):
+ full_op_id = op.get_full_op_id()
+ unique_op_id = f"{topo_idx}-{full_op_id}"
  if isinstance(op, ScanPhysicalOp):
  scan_end_idx = (
- len(op.datareader)
+ len(op.datasource)
  if self.num_samples is None
- else min(self.scan_start_idx + self.num_samples, len(op.datareader))
+ else min(self.scan_start_idx + self.num_samples, len(op.datasource))
  )
- inputs = [idx for idx in range(self.scan_start_idx, scan_end_idx)]
- input_queues[op.get_full_op_id()] = inputs
+ input_queues[unique_op_id][f"source_{full_op_id}"] = [idx for idx in range(self.scan_start_idx, scan_end_idx)]
+ elif isinstance(op, ContextScanOp):
+ input_queues[unique_op_id][f"source_{full_op_id}"] = [None]
+ else:
+ for source_unique_full_op_id in plan.get_source_unique_full_op_ids(topo_idx, op):
+ input_queues[unique_op_id][source_unique_full_op_id] = []

  return input_queues

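The reworked _create_input_queues keys each operator's queue by its topological index plus full op id, and nests one sub-queue per upstream source. A hypothetical shape of the returned mapping for a simple scan -> convert plan (the op ids below are invented for illustration, not real identifiers from the package):

    input_queues = {
        "0-scan_abc123": {
            "source_scan_abc123": [0, 1, 2, 3],  # scan indices to read
        },
        "1-convert_def456": {
            "0-scan_abc123": [],  # later filled with records emitted by the scan
        },
    }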
@@ -90,7 +82,6 @@ class SentinelExecutionStrategy(BaseExecutionStrategy, ABC):
  """
  def __init__(
  self,
- val_datasource: DataReader,
  k: int,
  j: int,
  sample_budget: int,
@@ -103,7 +94,6 @@ class SentinelExecutionStrategy(BaseExecutionStrategy, ABC):
  **kwargs,
  ):
  super().__init__(*args, **kwargs)
- self.val_datasource = val_datasource
  self.k = k
  self.j = j
  self.sample_budget = sample_budget
@@ -114,292 +104,200 @@ class SentinelExecutionStrategy(BaseExecutionStrategy, ABC):
  self.rng = np.random.default_rng(seed=seed)
  self.exp_name = exp_name

- # special cache which is used for tracking the target record sets for each (source_idx, logical_op_id)
- self.champion_output_cache: dict[int, dict[str, tuple[DataRecordSet, float]]] = {}
-
  # general cache which maps hash(logical_op_id, phys_op_id, hash(input)) --> record_set
  self.cache: dict[int, DataRecordSet] = {}

  # progress manager used to track progress of the execution
  self.progress_manager: PZSentinelProgressManager | None = None

- def _compute_quality(
- self,
- physical_op_cls: type[PhysicalOperator],
- record_set: DataRecordSet,
- target_record_set: DataRecordSet,
- ) -> DataRecordSet:
- """
- Compute the quality for the given `record_set` by comparing it to the `target_record_set`.
-
- Update the record_set by assigning the quality to each entry in its record_op_stats and
- returning the updated record_set.
- """
- # if this operation failed
- if len(record_set) == 0:
- record_set.record_op_stats[0].quality = 0.0
-
- # if this operation is a filter:
- # - return 1.0 if there's a match in the expected output which this operator does not filter out and 0.0 otherwise
- elif issubclass(physical_op_cls, FilterOp):
- # NOTE: we know that record_set.data_records will contain a single entry for a filter op
- record = record_set.data_records[0]
-
- # search for a record in the target with the same set of fields
- found_match_in_target = False
- for target_record in target_record_set:
- all_correct = True
- for field, value in record.field_values.items():
- if value != target_record[field]:
- all_correct = False
- break
-
- if all_correct:
- found_match_in_target = target_record.passed_operator
- break
-
- # set quality based on whether we found a match in the target and return
- record_set.record_op_stats[0].quality = int(record.passed_operator == found_match_in_target)
-
- return record_set
-
- # if this is a successful convert operation
- else:
- # NOTE: the following computation assumes we do not project out computed values
- # (and that the validation examples provide all computed fields); even if
- # a user program does add projection, we can ignore the projection on the
- # validation dataset and use the champion model (as opposed to the validation
- # output) for scoring fields which have their values projected out
-
- # GREEDY ALGORITHM
- # for each record in the expected output, we look for the computed record which maximizes the quality metric;
- # once we've identified that computed record we remove it from consideration for the next expected output
- field_to_score_fn = target_record_set.get_field_to_score_fn()
- for target_record in target_record_set:
- best_quality, best_record_op_stats = 0.0, None
- for record_op_stats in record_set.record_op_stats:
- # if we already assigned this record a quality, skip it
- if record_op_stats.quality is not None:
- continue
-
- # compute number of matches between this record's computed fields and this expected record's outputs
- total_quality = 0
- for field in record_op_stats.generated_fields:
- computed_value = record_op_stats.record_state.get(field, None)
- expected_value = target_record[field]
-
- # get the metric function for this field
- score_fn = field_to_score_fn.get(field, "exact")
-
- # compute exact match
- if score_fn == "exact":
- total_quality += int(computed_value == expected_value)
-
- # compute UDF metric
- elif callable(score_fn):
- total_quality += score_fn(computed_value, expected_value)
-
- # otherwise, throw an exception
- else:
- raise Exception(f"Unrecognized score_fn: {score_fn}")
-
- # compute recall and update best seen so far
- quality = total_quality / len(record_op_stats.generated_fields)
- if quality > best_quality:
- best_quality = quality
- best_record_op_stats = record_op_stats
-
- # set best_quality as quality for the best_record_op_stats
- if best_record_op_stats is not None:
- best_record_op_stats.quality = best_quality
-
- # for any records which did not receive a quality, set it to 0.0 as these are unexpected extras
- for record_op_stats in record_set.record_op_stats:
- if record_op_stats.quality is None:
- record_op_stats.quality = 0.0
-
- return record_set
-
  def _score_quality(
  self,
- physical_op_cls: type[PhysicalOperator],
- source_idx_to_record_sets: dict[int, list[DataRecordSet]],
- source_idx_to_target_record_set: dict[int, DataRecordSet],
- ) -> dict[int, list[DataRecordSet]]:
- """
- NOTE: This approach to cost modeling does not work directly for aggregation queries;
- for these queries, we would ask the user to provide validation data for the step immediately
- before a final aggregation
-
- NOTE: This function currently assumes that one-to-many converts do NOT create duplicate outputs.
- This assumption would break if, for example, we extracted the breed of every dog in an image.
- If there were two golden retrievers and a bernoodle in an image and we extracted:
-
- {"image": "file1.png", "breed": "Golden Retriever"}
- {"image": "file1.png", "breed": "Golden Retriever"}
- {"image": "file1.png", "breed": "Bernedoodle"}
-
- This function would currently give perfect accuracy to the following output:
-
- {"image": "file1.png", "breed": "Golden Retriever"}
- {"image": "file1.png", "breed": "Bernedoodle"}
-
- Even though it is missing one of the golden retrievers.
- """
+ validator: Validator,
+ source_indices_to_record_sets: dict[tuple[str], list[tuple[DataRecordSet, PhysicalOperator]]],
+ ) -> tuple[dict[int, list[DataRecordSet]], GenerationStats]:
  # extract information about the logical operation performed at this stage of the sentinel plan;
  # NOTE: we can infer these fields from context clues, but in the long-term we should have a more
  # principled way of getting these directly from attributes either stored in the sentinel_plan
  # or in the PhysicalOperator
- is_perfect_quality_op = (
- not issubclass(physical_op_cls, LLMConvert)
- and not issubclass(physical_op_cls, LLMFilter)
- and not issubclass(physical_op_cls, RetrieveOp)
- )
+ def is_perfect_quality_op(op: PhysicalOperator):
+ return (
+ not isinstance(op, LLMConvert)
+ and not isinstance(op, LLMFilter)
+ and not isinstance(op, RetrieveOp)
+ and not isinstance(op, JoinOp)
+ )
+
+ # create minimal set of futures necessary to compute quality of each output record
+ futures, full_hashes, full_hash_to_bool_output = [], set(), {}
+ with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
+ for _, record_set_tuples in source_indices_to_record_sets.items():
+ for record_set, op in record_set_tuples:
+ # if this operation does not involve an LLM, every record_op_stats object gets perfect quality
+ if is_perfect_quality_op(op):
+ for record_op_stats in record_set.record_op_stats:
+ record_op_stats.quality = 1.0
+ continue

- # compute quality of each output computed by this operator
- for source_idx, record_sets in source_idx_to_record_sets.items():
- # if this operation does not involve an LLM, every record_op_stats object gets perfect quality
- if is_perfect_quality_op:
- for record_set in record_sets:
- for record_op_stats in record_set.record_op_stats:
- record_op_stats.quality = 1.0
- continue
+ # if the operation failed, assign 0.0 quality
+ if len(record_set) == 0:
+ record_set.record_op_stats[0].quality = 0.0
+ continue

- # extract target output for this record set
- target_record_set = source_idx_to_target_record_set[source_idx]
+ # create future for map
+ if isinstance(op, LLMConvert) and op.cardinality is Cardinality.ONE_TO_ONE:
+ fields = op.generated_fields
+ input_record: DataRecord = record_set.input
+ output = record_set.data_records[0].to_dict(project_cols=fields)
+ output_str = record_set.data_records[0].to_json_str(project_cols=fields, bytes_to_str=True, sorted=True)
+ full_hash = f"{hash(input_record)}{hash(output_str)}"
+ if full_hash not in full_hashes:
+ full_hashes.add(full_hash)
+ futures.append(executor.submit(validator._score_map, op, fields, input_record, output, full_hash))
+
+ # create future for flat map
+ elif isinstance(op, LLMConvert) and op.cardinality is Cardinality.ONE_TO_MANY:
+ fields = op.generated_fields
+ input_record: DataRecord = record_set.input
+ output, output_strs = [], []
+ for data_record in record_set.data_records:
+ output.append(data_record.to_dict(project_cols=fields))
+ output_strs.append(data_record.to_json_str(project_cols=fields, bytes_to_str=True, sorted=True))
+ full_hash = f"{hash(input_record)}{hash(tuple(sorted(output_strs)))}"
+ if full_hash not in full_hashes:
+ full_hashes.add(full_hash)
+ futures.append(executor.submit(validator._score_flat_map, op, fields, input_record, output, full_hash))
+
+ # create future for retrieve
+ elif isinstance(op, RetrieveOp):
+ fields = op.generated_fields
+ input_record: DataRecord = record_set.input
+ output = record_set.data_records[0].to_dict(project_cols=fields)
+ output_str = record_set.data_records[0].to_json_str(project_cols=fields, bytes_to_str=True, sorted=True)
+ full_hash = f"{hash(input_record)}{hash(output_str)}"
+ if full_hash not in full_hashes:
+ full_hashes.add(full_hash)
+ futures.append(executor.submit(validator._score_retrieve, op, fields, input_record, output, full_hash))
+
+ # create future for filter
+ elif isinstance(op, LLMFilter):
+ filter_str = op.filter_obj.filter_condition
+ input_record: DataRecord = record_set.input
+ output = record_set.data_records[0].passed_operator
+ full_hash = f"{filter_str}{hash(input_record)}"
+ if full_hash not in full_hashes:
+ full_hash_to_bool_output[full_hash] = output
+ full_hashes.add(full_hash)
+ futures.append(executor.submit(validator._score_filter, op, filter_str, input_record, output, full_hash))
+
+ # create future for join
+ elif isinstance(op, JoinOp):
+ condition = op.condition
+ for left_idx, left_input_record in enumerate(record_set.input[0]):
+ for right_idx, right_input_record in enumerate(record_set.input[1]):
+ record_idx = left_idx * len(record_set.input[1]) + right_idx
+ output = record_set.data_records[record_idx].passed_operator
+ full_hash = f"{condition}{hash(left_input_record)}{hash(right_input_record)}"
+ if full_hash not in full_hashes:
+ full_hash_to_bool_output[full_hash] = output
+ full_hashes.add(full_hash)
+ futures.append(executor.submit(validator._score_join, op, condition, left_input_record, right_input_record, output, full_hash))
+
+ # collect results from futures
+ full_hash_to_score, validation_gen_stats = {}, GenerationStats()
+ for future in as_completed(futures):
+ score, gen_stats, full_hash = future.result()
+ full_hash_to_score[full_hash] = score
+ validation_gen_stats += gen_stats

- # for each record_set produced by an operation, compute its quality
- for record_set in record_sets:
- record_set = self._compute_quality(physical_op_cls, record_set, target_record_set)
+ # compute quality of each output computed by this operator
+ for _, record_set_tuples in source_indices_to_record_sets.items():
+ for record_set, op in record_set_tuples:
+ if is_perfect_quality_op(op) or len(record_set) == 0:
+ continue
+
+ if isinstance(op, LLMConvert) and op.cardinality is Cardinality.ONE_TO_ONE:
+ fields = op.generated_fields
+ input_record: DataRecord = record_set.input
+ output_str = record_set.data_records[0].to_json_str(project_cols=fields, bytes_to_str=True, sorted=True)
+ full_hash = f"{hash(input_record)}{hash(output_str)}"
+ record_set.record_op_stats[0].quality = full_hash_to_score[full_hash]
+
+ elif isinstance(op, LLMConvert) and op.cardinality is Cardinality.ONE_TO_MANY:
+ fields = op.generated_fields
+ input_record: DataRecord = record_set.input
+ output_strs = []
+ for data_record in record_set.data_records:
+ output_strs.append(data_record.to_json_str(project_cols=fields, bytes_to_str=True, sorted=True))
+ full_hash = f"{hash(input_record)}{hash(tuple(sorted(output_strs)))}"
+ score = full_hash_to_score[full_hash]
+ for record_op_stats in record_set.record_op_stats:
+ record_op_stats.quality = score
+
+ # TODO: this scoring function will (likely) bias towards small values of k since it
+ # measures precision and not recall / F1; will need to revisit this in the future
+ elif isinstance(op, RetrieveOp):
+ fields = op.generated_fields
+ input_record: DataRecord = record_set.input
+ output_str = record_set.data_records[0].to_json_str(project_cols=fields, bytes_to_str=True, sorted=True)
+ full_hash = f"{hash(input_record)}{hash(output_str)}"
+ score = full_hash_to_score[full_hash]
+ record_set.record_op_stats[0].quality = score
+
+ elif isinstance(op, LLMFilter):
+ filter_str = op.filter_obj.filter_condition
+ input_record: DataRecord = record_set.input
+ output = record_set.data_records[0].passed_operator
+ full_hash = f"{filter_str}{hash(input_record)}"
+ if output == full_hash_to_bool_output[full_hash]:
+ record_set.record_op_stats[0].quality = full_hash_to_score[full_hash]
+ else:
+ record_set.record_op_stats[0].quality = 1.0 - full_hash_to_score[full_hash]
+
+ elif isinstance(op, JoinOp):
+ condition = op.condition
+ for left_idx, left_input_record in enumerate(record_set.input[0]):
+ for right_idx, right_input_record in enumerate(record_set.input[1]):
+ record_idx = left_idx * len(record_set.input[1]) + right_idx
+ output = record_set.data_records[record_idx].passed_operator
+ full_hash = f"{condition}{hash(left_input_record)}{hash(right_input_record)}"
+ if output == full_hash_to_bool_output[full_hash]:
+ record_set.record_op_stats[record_idx].quality = full_hash_to_score[full_hash]
+ else:
+ record_set.record_op_stats[record_idx].quality = 1.0 - full_hash_to_score[full_hash]

  # return the quality annotated record sets
- return source_idx_to_record_sets
-
- def _get_target_record_sets(
- self,
- logical_op_id: str,
- source_idx_to_record_set_tuples: dict[int, list[tuple[DataRecordSet, PhysicalOperator, bool]]],
- expected_outputs: dict[int, dict] | None,
- ) -> dict[int, DataRecordSet]:
- # initialize mapping from source index to target record sets
- source_idx_to_target_record_set = {}
-
- for source_idx, record_set_tuples in source_idx_to_record_set_tuples.items():
- # get the first generated output for this source_idx
- base_target_record = None
- for record_set, _, _ in record_set_tuples:
- if len(record_set) > 0:
- base_target_record = record_set[0]
- break
-
- # compute availability of data
- base_target_present = base_target_record is not None
- labels_present = expected_outputs is not None
- labels_for_source_present = False
- if labels_present and source_idx in expected_outputs:
- labels = expected_outputs[source_idx].get("labels", [])
- labels_dict_lst = labels if isinstance(labels, list) else [labels]
- labels_for_source_present = labels_dict_lst != [] and labels_dict_lst != [None]
-
- # if we have a base target record and label info, use the label info to construct the target record set
- if base_target_present and labels_for_source_present:
- # get the field_to_score_fn
- field_to_score_fn = expected_outputs[source_idx].get("score_fn", {})
-
- # construct the target record set; we force passed_operator to be True for all target records
- target_records = []
- for labels_dict in labels_dict_lst:
- target_record = base_target_record.copy()
- for field, value in labels_dict.items():
- target_record[field] = value
- target_record.passed_operator = True
- target_records.append(target_record)
-
- source_idx_to_target_record_set[source_idx] = DataRecordSet(target_records, None, field_to_score_fn)
- continue
-
- # get the best computed output for this (source_idx, logical_op_id) so far (if one exists)
- champion_record_set, champion_op_quality = None, None
- if source_idx in self.champion_output_cache and logical_op_id in self.champion_output_cache[source_idx]:
- champion_record_set, champion_op_quality = self.champion_output_cache[source_idx][logical_op_id]
-
- # get the highest quality output that we just computed
- max_quality_record_set, max_op_quality = self._pick_champion_output(record_set_tuples)
-
- # if this new output is of higher quality than our previous champion (or if we didn't have
- # a previous champion) then we update our champion record set
- if champion_op_quality is None or (max_op_quality is not None and max_op_quality > champion_op_quality):
- champion_record_set, champion_op_quality = max_quality_record_set, max_op_quality
-
- # update the cache with the new champion record set and quality
- if source_idx not in self.champion_output_cache:
- self.champion_output_cache[source_idx] = {}
- self.champion_output_cache[source_idx][logical_op_id] = (champion_record_set, champion_op_quality)
-
- # set the target
- source_idx_to_target_record_set[source_idx] = champion_record_set
-
- return source_idx_to_target_record_set
-
- def _pick_champion_output(self, record_set_tuples: list[tuple[DataRecordSet, PhysicalOperator, bool]]) -> tuple[DataRecordSet, float | None]:
- # find the operator with the highest estimated quality and return its record_set
- base_op_cost_est = OperatorCostEstimates(cardinality=1.0, cost_per_record=0.0, time_per_record=0.0, quality=1.0)
- champion_record_set, champion_quality = None, None
- for record_set, op, _ in record_set_tuples:
- # skip failed operations
- if len(record_set) == 0:
- continue
-
- # get the estimated quality of this operator
- est_quality = op.naive_cost_estimates(base_op_cost_est).quality if self._is_llm_op(op) else 1.0
- if champion_quality is None or est_quality > champion_quality:
- champion_record_set, champion_quality = record_set, est_quality
-
- return champion_record_set, champion_quality
-
- def _flatten_record_sets(self, source_idx_to_record_sets: dict[int, list[DataRecordSet]]) -> tuple[list[DataRecord], list[RecordOpStats]]:
- """
- Flatten the list of record sets and record op stats for each source_idx.
- """
- all_records, all_record_op_stats = [], []
- for _, record_sets in source_idx_to_record_sets.items():
- for record_set in record_sets:
- all_records.extend(record_set.data_records)
- all_record_op_stats.extend(record_set.record_op_stats)
-
- return all_records, all_record_op_stats
-
- def _execute_op_set(self, op_input_pairs: list[tuple[PhysicalOperator, DataRecord | int]]) -> tuple[dict[int, list[tuple[DataRecordSet, PhysicalOperator, bool]]], dict[str, int]]:
- def execute_op_wrapper(operator, input) -> tuple[DataRecordSet, PhysicalOperator, DataRecord | int]:
- record_set = operator(input)
- return record_set, operator, input
-
- # TODO: modify unit tests to always have record_op_stats so we can use record_op_stats for source_idx
- # for scan operators, `input` will be the source_idx
- def get_source_idx(input):
- return input.source_idx if isinstance(input, DataRecord) else input
-
- def get_hash(operator, input):
+ return source_indices_to_record_sets, validation_gen_stats
+
+ def _execute_op_set(self, unique_logical_op_id: str, op_inputs: list[tuple[PhysicalOperator, str | tuple, int | DataRecord | list[DataRecord] | tuple[list[DataRecord]]]]) -> tuple[dict[int, list[tuple[DataRecordSet, PhysicalOperator, bool]]], dict[str, int]]:
+ def execute_op_wrapper(operator: PhysicalOperator, source_indices: str | tuple, input: int | DataRecord | list[DataRecord] | tuple[list[DataRecord]]) -> tuple[DataRecordSet, PhysicalOperator, list[DataRecord] | list[int]]:
+ # operator is a join
+ record_set = operator(input[0], input[1]) if isinstance(operator, JoinOp) else operator(input)
+ return record_set, operator, source_indices, input
+
+ def get_hash(operator: PhysicalOperator, input: int | DataRecord | list[DataRecord] | tuple[list[DataRecord]]):
+ if isinstance(input, list):
+ input = tuple(input)
+ elif isinstance(input, tuple):
+ input = (tuple(input[0]), tuple(input[1]))
  return hash(f"{operator.get_full_op_id()}{hash(input)}")

  # initialize mapping from source indices to output record sets
- source_idx_to_record_sets_and_ops = {get_source_idx(input): [] for _, input in op_input_pairs}
+ source_indices_to_record_sets_and_ops = {source_indices: [] for _, source_indices, _ in op_inputs}

  # if any operations were previously executed, read the results from the cache
- final_op_input_pairs = []
- for operator, input in op_input_pairs:
+ final_op_inputs = []
+ for operator, source_indices, input in op_inputs:
  # compute hash
  op_input_hash = get_hash(operator, input)

  # get result from cache
  if op_input_hash in self.cache:
- source_idx = get_source_idx(input)
  record_set, operator = self.cache[op_input_hash]
- source_idx_to_record_sets_and_ops[source_idx].append((record_set, operator, False))
+ source_indices_to_record_sets_and_ops[source_indices].append((record_set, operator, False))

- # otherwise, add to final_op_input_pairs
+ # otherwise, add to final_op_inputs
  else:
- final_op_input_pairs.append((operator, input))
+ final_op_inputs.append((operator, source_indices, input))

  # keep track of the number of llm operations
  num_llm_ops = 0
@@ -408,46 +306,41 @@ class SentinelExecutionStrategy(BaseExecutionStrategy, ABC):
  with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
  # create futures
  futures = [
- executor.submit(execute_op_wrapper, operator, input)
- for operator, input in final_op_input_pairs
+ executor.submit(execute_op_wrapper, operator, source_indices, input)
+ for operator, source_indices, input in final_op_inputs
  ]
- output_record_sets = []
- while len(futures) > 0:
- done_futures, not_done_futures = wait(futures, timeout=PARALLEL_EXECUTION_SLEEP_INTERVAL_SECS)
- for future in done_futures:
- # update output record sets
- record_set, operator, input = future.result()
- output_record_sets.append((record_set, operator, input))
-
- # update cache
- op_input_hash = get_hash(operator, input)
- self.cache[op_input_hash] = (record_set, operator)

- # update progress manager
- if self._is_llm_op(operator):
- num_llm_ops += 1
- self.progress_manager.incr(operator.get_logical_op_id(), num_samples=1, total_cost=record_set.get_total_cost())
+ output_record_sets = []
+ for future in as_completed(futures):
+ # update output record sets
+ record_set, operator, source_indices, input = future.result()
+ output_record_sets.append((record_set, operator, source_indices, input))

- # update futures
- futures = list(not_done_futures)
+ # update cache
+ op_input_hash = get_hash(operator, input)
+ self.cache[op_input_hash] = (record_set, operator)

- # update mapping from source_idx to record sets and operators
- for record_set, operator, input in output_record_sets:
- # get the source_idx associated with this input record;
- source_idx = get_source_idx(input)
+ # update progress manager
+ if self._is_llm_op(operator):
+ num_llm_ops += 1
+ self.progress_manager.incr(unique_logical_op_id, num_samples=1, total_cost=record_set.get_total_cost())

- # add record_set to mapping from source_idx --> record_sets
- source_idx_to_record_sets_and_ops[source_idx].append((record_set, operator, True))
+ # update mapping from source_indices to record sets and operators
+ for record_set, operator, source_indices, input in output_record_sets:
+ # add record_set to mapping from source_indices --> record_sets
+ record_set.input = input
+ source_indices_to_record_sets_and_ops[source_indices].append((record_set, operator, True))

- return source_idx_to_record_sets_and_ops, num_llm_ops
+ return source_indices_to_record_sets_and_ops, num_llm_ops

  def _is_llm_op(self, physical_op: PhysicalOperator) -> bool:
  is_llm_convert = isinstance(physical_op, LLMConvert)
  is_llm_filter = isinstance(physical_op, LLMFilter)
  is_llm_retrieve = isinstance(physical_op, RetrieveOp) and isinstance(physical_op.index, Collection)
- return is_llm_convert or is_llm_filter or is_llm_retrieve
+ is_llm_join = isinstance(physical_op, JoinOp)
+ return is_llm_convert or is_llm_filter or is_llm_retrieve or is_llm_join

  @abstractmethod
- def execute_sentinel_plan(self, sentinel_plan: SentinelPlan, expected_outputs: dict[str, dict]):
+ def execute_sentinel_plan(self, sentinel_plan: SentinelPlan, train_dataset: dict[str, Dataset], validator: Validator) -> SentinelPlanStats:
  """Execute a SentinelPlan according to strategy"""
  pass
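A recurring pattern in the new _score_quality and _execute_op_set code is deduplicating repeated LLM work by hashing an operator id together with its input and reusing the first result. A stripped-down sketch of that memoization idea (the names and signature below are illustrative, not the package's API):

    # illustrative memoization keyed on operator id + input hash
    _cache: dict[int, object] = {}

    def cached_execute(operator_id: str, record, execute):
        # record must be hashable for this sketch to work
        key = hash(f"{operator_id}{hash(record)}")
        if key in _cache:
            return _cache[key]      # reuse the earlier result for identical work
        result = execute(record)    # e.g. an operator call or a validator scoring call
        _cache[key] = result
        return result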