palimpzest 0.7.21__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- palimpzest/__init__.py +37 -6
- palimpzest/agents/__init__.py +0 -0
- palimpzest/agents/compute_agents.py +0 -0
- palimpzest/agents/search_agents.py +637 -0
- palimpzest/constants.py +259 -197
- palimpzest/core/data/context.py +393 -0
- palimpzest/core/data/context_manager.py +163 -0
- palimpzest/core/data/dataset.py +634 -0
- palimpzest/core/data/{datareaders.py → iter_dataset.py} +202 -126
- palimpzest/core/elements/groupbysig.py +16 -13
- palimpzest/core/elements/records.py +166 -75
- palimpzest/core/lib/schemas.py +152 -390
- palimpzest/core/{data/dataclasses.py → models.py} +306 -170
- palimpzest/policy.py +2 -27
- palimpzest/prompts/__init__.py +35 -5
- palimpzest/prompts/agent_prompts.py +357 -0
- palimpzest/prompts/context_search.py +9 -0
- palimpzest/prompts/convert_prompts.py +61 -5
- palimpzest/prompts/filter_prompts.py +50 -5
- palimpzest/prompts/join_prompts.py +163 -0
- palimpzest/prompts/moa_proposer_convert_prompts.py +5 -5
- palimpzest/prompts/prompt_factory.py +358 -46
- palimpzest/prompts/validator.py +239 -0
- palimpzest/query/execution/all_sample_execution_strategy.py +134 -76
- palimpzest/query/execution/execution_strategy.py +210 -317
- palimpzest/query/execution/execution_strategy_type.py +5 -7
- palimpzest/query/execution/mab_execution_strategy.py +249 -136
- palimpzest/query/execution/parallel_execution_strategy.py +153 -244
- palimpzest/query/execution/single_threaded_execution_strategy.py +107 -64
- palimpzest/query/generators/generators.py +157 -330
- palimpzest/query/operators/__init__.py +15 -5
- palimpzest/query/operators/aggregate.py +50 -33
- palimpzest/query/operators/compute.py +201 -0
- palimpzest/query/operators/convert.py +27 -21
- palimpzest/query/operators/critique_and_refine_convert.py +7 -5
- palimpzest/query/operators/distinct.py +62 -0
- palimpzest/query/operators/filter.py +22 -13
- palimpzest/query/operators/join.py +402 -0
- palimpzest/query/operators/limit.py +3 -3
- palimpzest/query/operators/logical.py +198 -80
- palimpzest/query/operators/mixture_of_agents_convert.py +10 -8
- palimpzest/query/operators/physical.py +27 -21
- palimpzest/query/operators/project.py +3 -3
- palimpzest/query/operators/rag_convert.py +7 -7
- palimpzest/query/operators/retrieve.py +9 -9
- palimpzest/query/operators/scan.py +81 -42
- palimpzest/query/operators/search.py +524 -0
- palimpzest/query/operators/split_convert.py +10 -8
- palimpzest/query/optimizer/__init__.py +7 -9
- palimpzest/query/optimizer/cost_model.py +108 -441
- palimpzest/query/optimizer/optimizer.py +123 -181
- palimpzest/query/optimizer/optimizer_strategy.py +66 -61
- palimpzest/query/optimizer/plan.py +352 -67
- palimpzest/query/optimizer/primitives.py +43 -19
- palimpzest/query/optimizer/rules.py +484 -646
- palimpzest/query/optimizer/tasks.py +127 -58
- palimpzest/query/processor/config.py +41 -76
- palimpzest/query/processor/query_processor.py +73 -18
- palimpzest/query/processor/query_processor_factory.py +46 -38
- palimpzest/schemabuilder/schema_builder.py +15 -28
- palimpzest/utils/model_helpers.py +27 -77
- palimpzest/utils/progress.py +114 -102
- palimpzest/validator/__init__.py +0 -0
- palimpzest/validator/validator.py +306 -0
- {palimpzest-0.7.21.dist-info → palimpzest-0.8.0.dist-info}/METADATA +6 -1
- palimpzest-0.8.0.dist-info/RECORD +95 -0
- palimpzest/core/lib/fields.py +0 -141
- palimpzest/prompts/code_synthesis_prompts.py +0 -28
- palimpzest/query/execution/random_sampling_execution_strategy.py +0 -240
- palimpzest/query/generators/api_client_factory.py +0 -30
- palimpzest/query/operators/code_synthesis_convert.py +0 -488
- palimpzest/query/operators/map.py +0 -130
- palimpzest/query/processor/nosentinel_processor.py +0 -33
- palimpzest/query/processor/processing_strategy_type.py +0 -28
- palimpzest/query/processor/sentinel_processor.py +0 -88
- palimpzest/query/processor/streaming_processor.py +0 -149
- palimpzest/sets.py +0 -405
- palimpzest/utils/datareader_helpers.py +0 -61
- palimpzest/utils/demo_helpers.py +0 -75
- palimpzest/utils/field_helpers.py +0 -69
- palimpzest/utils/generation_helpers.py +0 -69
- palimpzest/utils/sandbox.py +0 -183
- palimpzest-0.7.21.dist-info/RECORD +0 -95
- /palimpzest/core/{elements/index.py → data/index_dataset.py} +0 -0
- {palimpzest-0.7.21.dist-info → palimpzest-0.8.0.dist-info}/WHEEL +0 -0
- {palimpzest-0.7.21.dist-info → palimpzest-0.8.0.dist-info}/licenses/LICENSE +0 -0
- {palimpzest-0.7.21.dist-info → palimpzest-0.8.0.dist-info}/top_level.txt +0 -0
|
@@ -1,22 +1,24 @@
|
|
|
1
1
|
import logging
|
|
2
2
|
from abc import ABC, abstractmethod
|
|
3
|
-
from concurrent.futures import ThreadPoolExecutor,
|
|
3
|
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
4
4
|
|
|
5
5
|
import numpy as np
|
|
6
6
|
from chromadb.api.models.Collection import Collection
|
|
7
7
|
|
|
8
|
-
from palimpzest.constants import
|
|
9
|
-
from palimpzest.core.data.
|
|
10
|
-
from palimpzest.core.data.datareaders import DataReader
|
|
8
|
+
from palimpzest.constants import Cardinality
|
|
9
|
+
from palimpzest.core.data.dataset import Dataset
|
|
11
10
|
from palimpzest.core.elements.records import DataRecord, DataRecordSet
|
|
11
|
+
from palimpzest.core.models import GenerationStats, PlanStats, SentinelPlanStats
|
|
12
12
|
from palimpzest.policy import Policy
|
|
13
13
|
from palimpzest.query.operators.convert import LLMConvert
|
|
14
|
-
from palimpzest.query.operators.filter import
|
|
14
|
+
from palimpzest.query.operators.filter import LLMFilter
|
|
15
|
+
from palimpzest.query.operators.join import JoinOp
|
|
15
16
|
from palimpzest.query.operators.physical import PhysicalOperator
|
|
16
17
|
from palimpzest.query.operators.retrieve import RetrieveOp
|
|
17
|
-
from palimpzest.query.operators.scan import ScanPhysicalOp
|
|
18
|
+
from palimpzest.query.operators.scan import ContextScanOp, ScanPhysicalOp
|
|
18
19
|
from palimpzest.query.optimizer.plan import PhysicalPlan, SentinelPlan
|
|
19
20
|
from palimpzest.utils.progress import PZSentinelProgressManager
|
|
21
|
+
from palimpzest.validator.validator import Validator
|
|
20
22
|
|
|
21
23
|
logger = logging.getLogger(__name__)
|
|
22
24
|
|
|
@@ -24,35 +26,20 @@ class BaseExecutionStrategy:
|
|
|
24
26
|
def __init__(self,
|
|
25
27
|
scan_start_idx: int = 0,
|
|
26
28
|
max_workers: int | None = None,
|
|
29
|
+
batch_size: int | None = None,
|
|
27
30
|
num_samples: int | None = None,
|
|
28
|
-
cache: bool = False,
|
|
29
31
|
verbose: bool = False,
|
|
30
32
|
progress: bool = True,
|
|
31
33
|
*args,
|
|
32
34
|
**kwargs):
|
|
33
35
|
self.scan_start_idx = scan_start_idx
|
|
34
36
|
self.max_workers = max_workers
|
|
37
|
+
self.batch_size = batch_size
|
|
35
38
|
self.num_samples = num_samples
|
|
36
|
-
self.cache = cache
|
|
37
39
|
self.verbose = verbose
|
|
38
40
|
self.progress = progress
|
|
39
41
|
|
|
40
42
|
|
|
41
|
-
def _add_records_to_cache(self, target_cache_id: str, records: list[DataRecord]) -> None:
|
|
42
|
-
"""Add each record (which isn't filtered) to the cache for the given target_cache_id."""
|
|
43
|
-
if self.cache:
|
|
44
|
-
for record in records:
|
|
45
|
-
if getattr(record, "passed_operator", True):
|
|
46
|
-
# self.datadir.append_cache(target_cache_id, record)
|
|
47
|
-
pass
|
|
48
|
-
|
|
49
|
-
def _close_cache(self, target_cache_ids: list[str]) -> None:
|
|
50
|
-
"""Close the cache for each of the given target_cache_ids"""
|
|
51
|
-
if self.cache:
|
|
52
|
-
for target_cache_id in target_cache_ids: # noqa: B007
|
|
53
|
-
# self.datadir.close_cache(target_cache_id)
|
|
54
|
-
pass
|
|
55
|
-
|
|
56
43
|
class ExecutionStrategy(BaseExecutionStrategy, ABC):
|
|
57
44
|
"""Base strategy for executing query plans. Defines how to execute a PhysicalPlan.
|
|
58
45
|
"""
|
|
@@ -66,19 +53,24 @@ class ExecutionStrategy(BaseExecutionStrategy, ABC):
|
|
|
66
53
|
"""Execute a single plan according to strategy"""
|
|
67
54
|
pass
|
|
68
55
|
|
|
69
|
-
def _create_input_queues(self, plan: PhysicalPlan) -> dict[str, list]:
|
|
56
|
+
def _create_input_queues(self, plan: PhysicalPlan) -> dict[str, dict[str, list]]:
|
|
70
57
|
"""Initialize input queues for each operator in the plan."""
|
|
71
|
-
input_queues = {}
|
|
72
|
-
for op in plan
|
|
73
|
-
|
|
58
|
+
input_queues = {f"{topo_idx}-{op.get_full_op_id()}": {} for topo_idx, op in enumerate(plan)}
|
|
59
|
+
for topo_idx, op in enumerate(plan):
|
|
60
|
+
full_op_id = op.get_full_op_id()
|
|
61
|
+
unique_op_id = f"{topo_idx}-{full_op_id}"
|
|
74
62
|
if isinstance(op, ScanPhysicalOp):
|
|
75
63
|
scan_end_idx = (
|
|
76
|
-
len(op.
|
|
64
|
+
len(op.datasource)
|
|
77
65
|
if self.num_samples is None
|
|
78
|
-
else min(self.scan_start_idx + self.num_samples, len(op.
|
|
66
|
+
else min(self.scan_start_idx + self.num_samples, len(op.datasource))
|
|
79
67
|
)
|
|
80
|
-
|
|
81
|
-
|
|
68
|
+
input_queues[unique_op_id][f"source_{full_op_id}"] = [idx for idx in range(self.scan_start_idx, scan_end_idx)]
|
|
69
|
+
elif isinstance(op, ContextScanOp):
|
|
70
|
+
input_queues[unique_op_id][f"source_{full_op_id}"] = [None]
|
|
71
|
+
else:
|
|
72
|
+
for source_unique_full_op_id in plan.get_source_unique_full_op_ids(topo_idx, op):
|
|
73
|
+
input_queues[unique_op_id][source_unique_full_op_id] = []
|
|
82
74
|
|
|
83
75
|
return input_queues
|
|
84
76
|
|
|
@@ -90,7 +82,6 @@ class SentinelExecutionStrategy(BaseExecutionStrategy, ABC):
|
|
|
90
82
|
"""
|
|
91
83
|
def __init__(
|
|
92
84
|
self,
|
|
93
|
-
val_datasource: DataReader,
|
|
94
85
|
k: int,
|
|
95
86
|
j: int,
|
|
96
87
|
sample_budget: int,
|
|
@@ -103,7 +94,6 @@ class SentinelExecutionStrategy(BaseExecutionStrategy, ABC):
|
|
|
103
94
|
**kwargs,
|
|
104
95
|
):
|
|
105
96
|
super().__init__(*args, **kwargs)
|
|
106
|
-
self.val_datasource = val_datasource
|
|
107
97
|
self.k = k
|
|
108
98
|
self.j = j
|
|
109
99
|
self.sample_budget = sample_budget
|
|
@@ -114,292 +104,200 @@ class SentinelExecutionStrategy(BaseExecutionStrategy, ABC):
|
|
|
114
104
|
self.rng = np.random.default_rng(seed=seed)
|
|
115
105
|
self.exp_name = exp_name
|
|
116
106
|
|
|
117
|
-
# special cache which is used for tracking the target record sets for each (source_idx, logical_op_id)
|
|
118
|
-
self.champion_output_cache: dict[int, dict[str, tuple[DataRecordSet, float]]] = {}
|
|
119
|
-
|
|
120
107
|
# general cache which maps hash(logical_op_id, phys_op_id, hash(input)) --> record_set
|
|
121
108
|
self.cache: dict[int, DataRecordSet] = {}
|
|
122
109
|
|
|
123
110
|
# progress manager used to track progress of the execution
|
|
124
111
|
self.progress_manager: PZSentinelProgressManager | None = None
|
|
125
112
|
|
|
126
|
-
def _compute_quality(
|
|
127
|
-
self,
|
|
128
|
-
physical_op_cls: type[PhysicalOperator],
|
|
129
|
-
record_set: DataRecordSet,
|
|
130
|
-
target_record_set: DataRecordSet,
|
|
131
|
-
) -> DataRecordSet:
|
|
132
|
-
"""
|
|
133
|
-
Compute the quality for the given `record_set` by comparing it to the `target_record_set`.
|
|
134
|
-
|
|
135
|
-
Update the record_set by assigning the quality to each entry in its record_op_stats and
|
|
136
|
-
returning the updated record_set.
|
|
137
|
-
"""
|
|
138
|
-
# if this operation failed
|
|
139
|
-
if len(record_set) == 0:
|
|
140
|
-
record_set.record_op_stats[0].quality = 0.0
|
|
141
|
-
|
|
142
|
-
# if this operation is a filter:
|
|
143
|
-
# - return 1.0 if there's a match in the expected output which this operator does not filter out and 0.0 otherwise
|
|
144
|
-
elif issubclass(physical_op_cls, FilterOp):
|
|
145
|
-
# NOTE: we know that record_set.data_records will contain a single entry for a filter op
|
|
146
|
-
record = record_set.data_records[0]
|
|
147
|
-
|
|
148
|
-
# search for a record in the target with the same set of fields
|
|
149
|
-
found_match_in_target = False
|
|
150
|
-
for target_record in target_record_set:
|
|
151
|
-
all_correct = True
|
|
152
|
-
for field, value in record.field_values.items():
|
|
153
|
-
if value != target_record[field]:
|
|
154
|
-
all_correct = False
|
|
155
|
-
break
|
|
156
|
-
|
|
157
|
-
if all_correct:
|
|
158
|
-
found_match_in_target = target_record.passed_operator
|
|
159
|
-
break
|
|
160
|
-
|
|
161
|
-
# set quality based on whether we found a match in the target and return
|
|
162
|
-
record_set.record_op_stats[0].quality = int(record.passed_operator == found_match_in_target)
|
|
163
|
-
|
|
164
|
-
return record_set
|
|
165
|
-
|
|
166
|
-
# if this is a successful convert operation
|
|
167
|
-
else:
|
|
168
|
-
# NOTE: the following computation assumes we do not project out computed values
|
|
169
|
-
# (and that the validation examples provide all computed fields); even if
|
|
170
|
-
# a user program does add projection, we can ignore the projection on the
|
|
171
|
-
# validation dataset and use the champion model (as opposed to the validation
|
|
172
|
-
# output) for scoring fields which have their values projected out
|
|
173
|
-
|
|
174
|
-
# GREEDY ALGORITHM
|
|
175
|
-
# for each record in the expected output, we look for the computed record which maximizes the quality metric;
|
|
176
|
-
# once we've identified that computed record we remove it from consideration for the next expected output
|
|
177
|
-
field_to_score_fn = target_record_set.get_field_to_score_fn()
|
|
178
|
-
for target_record in target_record_set:
|
|
179
|
-
best_quality, best_record_op_stats = 0.0, None
|
|
180
|
-
for record_op_stats in record_set.record_op_stats:
|
|
181
|
-
# if we already assigned this record a quality, skip it
|
|
182
|
-
if record_op_stats.quality is not None:
|
|
183
|
-
continue
|
|
184
|
-
|
|
185
|
-
# compute number of matches between this record's computed fields and this expected record's outputs
|
|
186
|
-
total_quality = 0
|
|
187
|
-
for field in record_op_stats.generated_fields:
|
|
188
|
-
computed_value = record_op_stats.record_state.get(field, None)
|
|
189
|
-
expected_value = target_record[field]
|
|
190
|
-
|
|
191
|
-
# get the metric function for this field
|
|
192
|
-
score_fn = field_to_score_fn.get(field, "exact")
|
|
193
|
-
|
|
194
|
-
# compute exact match
|
|
195
|
-
if score_fn == "exact":
|
|
196
|
-
total_quality += int(computed_value == expected_value)
|
|
197
|
-
|
|
198
|
-
# compute UDF metric
|
|
199
|
-
elif callable(score_fn):
|
|
200
|
-
total_quality += score_fn(computed_value, expected_value)
|
|
201
|
-
|
|
202
|
-
# otherwise, throw an exception
|
|
203
|
-
else:
|
|
204
|
-
raise Exception(f"Unrecognized score_fn: {score_fn}")
|
|
205
|
-
|
|
206
|
-
# compute recall and update best seen so far
|
|
207
|
-
quality = total_quality / len(record_op_stats.generated_fields)
|
|
208
|
-
if quality > best_quality:
|
|
209
|
-
best_quality = quality
|
|
210
|
-
best_record_op_stats = record_op_stats
|
|
211
|
-
|
|
212
|
-
# set best_quality as quality for the best_record_op_stats
|
|
213
|
-
if best_record_op_stats is not None:
|
|
214
|
-
best_record_op_stats.quality = best_quality
|
|
215
|
-
|
|
216
|
-
# for any records which did not receive a quality, set it to 0.0 as these are unexpected extras
|
|
217
|
-
for record_op_stats in record_set.record_op_stats:
|
|
218
|
-
if record_op_stats.quality is None:
|
|
219
|
-
record_op_stats.quality = 0.0
|
|
220
|
-
|
|
221
|
-
return record_set
|
|
222
|
-
|
|
223
113
|
def _score_quality(
|
|
224
114
|
self,
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
) -> dict[int, list[DataRecordSet]]:
|
|
229
|
-
"""
|
|
230
|
-
NOTE: This approach to cost modeling does not work directly for aggregation queries;
|
|
231
|
-
for these queries, we would ask the user to provide validation data for the step immediately
|
|
232
|
-
before a final aggregation
|
|
233
|
-
|
|
234
|
-
NOTE: This function currently assumes that one-to-many converts do NOT create duplicate outputs.
|
|
235
|
-
This assumption would break if, for example, we extracted the breed of every dog in an image.
|
|
236
|
-
If there were two golden retrievers and a bernoodle in an image and we extracted:
|
|
237
|
-
|
|
238
|
-
{"image": "file1.png", "breed": "Golden Retriever"}
|
|
239
|
-
{"image": "file1.png", "breed": "Golden Retriever"}
|
|
240
|
-
{"image": "file1.png", "breed": "Bernedoodle"}
|
|
241
|
-
|
|
242
|
-
This function would currently give perfect accuracy to the following output:
|
|
243
|
-
|
|
244
|
-
{"image": "file1.png", "breed": "Golden Retriever"}
|
|
245
|
-
{"image": "file1.png", "breed": "Bernedoodle"}
|
|
246
|
-
|
|
247
|
-
Even though it is missing one of the golden retrievers.
|
|
248
|
-
"""
|
|
115
|
+
validator: Validator,
|
|
116
|
+
source_indices_to_record_sets: dict[tuple[str], list[tuple[DataRecordSet, PhysicalOperator]]],
|
|
117
|
+
) -> tuple[dict[int, list[DataRecordSet]], GenerationStats]:
|
|
249
118
|
# extract information about the logical operation performed at this stage of the sentinel plan;
|
|
250
119
|
# NOTE: we can infer these fields from context clues, but in the long-term we should have a more
|
|
251
120
|
# principled way of getting these directly from attributes either stored in the sentinel_plan
|
|
252
121
|
# or in the PhysicalOperator
|
|
253
|
-
is_perfect_quality_op
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
122
|
+
def is_perfect_quality_op(op: PhysicalOperator):
|
|
123
|
+
return (
|
|
124
|
+
not isinstance(op, LLMConvert)
|
|
125
|
+
and not isinstance(op, LLMFilter)
|
|
126
|
+
and not isinstance(op, RetrieveOp)
|
|
127
|
+
and not isinstance(op, JoinOp)
|
|
128
|
+
)
|
|
129
|
+
|
|
130
|
+
# create minimal set of futures necessary to compute quality of each output record
|
|
131
|
+
futures, full_hashes, full_hash_to_bool_output = [], set(), {}
|
|
132
|
+
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
|
|
133
|
+
for _, record_set_tuples in source_indices_to_record_sets.items():
|
|
134
|
+
for record_set, op in record_set_tuples:
|
|
135
|
+
# if this operation does not involve an LLM, every record_op_stats object gets perfect quality
|
|
136
|
+
if is_perfect_quality_op(op):
|
|
137
|
+
for record_op_stats in record_set.record_op_stats:
|
|
138
|
+
record_op_stats.quality = 1.0
|
|
139
|
+
continue
|
|
258
140
|
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
for record_set in record_sets:
|
|
264
|
-
for record_op_stats in record_set.record_op_stats:
|
|
265
|
-
record_op_stats.quality = 1.0
|
|
266
|
-
continue
|
|
141
|
+
# if the operation failed, assign 0.0 quality
|
|
142
|
+
if len(record_set) == 0:
|
|
143
|
+
record_set.record_op_stats[0].quality = 0.0
|
|
144
|
+
continue
|
|
267
145
|
|
|
268
|
-
|
|
269
|
-
|
|
146
|
+
# create future for map
|
|
147
|
+
if isinstance(op, LLMConvert) and op.cardinality is Cardinality.ONE_TO_ONE:
|
|
148
|
+
fields = op.generated_fields
|
|
149
|
+
input_record: DataRecord = record_set.input
|
|
150
|
+
output = record_set.data_records[0].to_dict(project_cols=fields)
|
|
151
|
+
output_str = record_set.data_records[0].to_json_str(project_cols=fields, bytes_to_str=True, sorted=True)
|
|
152
|
+
full_hash = f"{hash(input_record)}{hash(output_str)}"
|
|
153
|
+
if full_hash not in full_hashes:
|
|
154
|
+
full_hashes.add(full_hash)
|
|
155
|
+
futures.append(executor.submit(validator._score_map, op, fields, input_record, output, full_hash))
|
|
156
|
+
|
|
157
|
+
# create future for flat map
|
|
158
|
+
elif isinstance(op, LLMConvert) and op.cardinality is Cardinality.ONE_TO_MANY:
|
|
159
|
+
fields = op.generated_fields
|
|
160
|
+
input_record: DataRecord = record_set.input
|
|
161
|
+
output, output_strs = [], []
|
|
162
|
+
for data_record in record_set.data_records:
|
|
163
|
+
output.append(data_record.to_dict(project_cols=fields))
|
|
164
|
+
output_strs.append(data_record.to_json_str(project_cols=fields, bytes_to_str=True, sorted=True))
|
|
165
|
+
full_hash = f"{hash(input_record)}{hash(tuple(sorted(output_strs)))}"
|
|
166
|
+
if full_hash not in full_hashes:
|
|
167
|
+
full_hashes.add(full_hash)
|
|
168
|
+
futures.append(executor.submit(validator._score_flat_map, op, fields, input_record, output, full_hash))
|
|
169
|
+
|
|
170
|
+
# create future for retrieve
|
|
171
|
+
elif isinstance(op, RetrieveOp):
|
|
172
|
+
fields = op.generated_fields
|
|
173
|
+
input_record: DataRecord = record_set.input
|
|
174
|
+
output = record_set.data_records[0].to_dict(project_cols=fields)
|
|
175
|
+
output_str = record_set.data_records[0].to_json_str(project_cols=fields, bytes_to_str=True, sorted=True)
|
|
176
|
+
full_hash = f"{hash(input_record)}{hash(output_str)}"
|
|
177
|
+
if full_hash not in full_hashes:
|
|
178
|
+
full_hashes.add(full_hash)
|
|
179
|
+
futures.append(executor.submit(validator._score_retrieve, op, fields, input_record, output, full_hash))
|
|
180
|
+
|
|
181
|
+
# create future for filter
|
|
182
|
+
elif isinstance(op, LLMFilter):
|
|
183
|
+
filter_str = op.filter_obj.filter_condition
|
|
184
|
+
input_record: DataRecord = record_set.input
|
|
185
|
+
output = record_set.data_records[0].passed_operator
|
|
186
|
+
full_hash = f"{filter_str}{hash(input_record)}"
|
|
187
|
+
if full_hash not in full_hashes:
|
|
188
|
+
full_hash_to_bool_output[full_hash] = output
|
|
189
|
+
full_hashes.add(full_hash)
|
|
190
|
+
futures.append(executor.submit(validator._score_filter, op, filter_str, input_record, output, full_hash))
|
|
191
|
+
|
|
192
|
+
# create future for join
|
|
193
|
+
elif isinstance(op, JoinOp):
|
|
194
|
+
condition = op.condition
|
|
195
|
+
for left_idx, left_input_record in enumerate(record_set.input[0]):
|
|
196
|
+
for right_idx, right_input_record in enumerate(record_set.input[1]):
|
|
197
|
+
record_idx = left_idx * len(record_set.input[1]) + right_idx
|
|
198
|
+
output = record_set.data_records[record_idx].passed_operator
|
|
199
|
+
full_hash = f"{condition}{hash(left_input_record)}{hash(right_input_record)}"
|
|
200
|
+
if full_hash not in full_hashes:
|
|
201
|
+
full_hash_to_bool_output[full_hash] = output
|
|
202
|
+
full_hashes.add(full_hash)
|
|
203
|
+
futures.append(executor.submit(validator._score_join, op, condition, left_input_record, right_input_record, output, full_hash))
|
|
204
|
+
|
|
205
|
+
# collect results from futures
|
|
206
|
+
full_hash_to_score, validation_gen_stats = {}, GenerationStats()
|
|
207
|
+
for future in as_completed(futures):
|
|
208
|
+
score, gen_stats, full_hash = future.result()
|
|
209
|
+
full_hash_to_score[full_hash] = score
|
|
210
|
+
validation_gen_stats += gen_stats
|
|
270
211
|
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
212
|
+
# compute quality of each output computed by this operator
|
|
213
|
+
for _, record_set_tuples in source_indices_to_record_sets.items():
|
|
214
|
+
for record_set, op in record_set_tuples:
|
|
215
|
+
if is_perfect_quality_op(op) or len(record_set) == 0:
|
|
216
|
+
continue
|
|
217
|
+
|
|
218
|
+
if isinstance(op, LLMConvert) and op.cardinality is Cardinality.ONE_TO_ONE:
|
|
219
|
+
fields = op.generated_fields
|
|
220
|
+
input_record: DataRecord = record_set.input
|
|
221
|
+
output_str = record_set.data_records[0].to_json_str(project_cols=fields, bytes_to_str=True, sorted=True)
|
|
222
|
+
full_hash = f"{hash(input_record)}{hash(output_str)}"
|
|
223
|
+
record_set.record_op_stats[0].quality = full_hash_to_score[full_hash]
|
|
224
|
+
|
|
225
|
+
elif isinstance(op, LLMConvert) and op.cardinality is Cardinality.ONE_TO_MANY:
|
|
226
|
+
fields = op.generated_fields
|
|
227
|
+
input_record: DataRecord = record_set.input
|
|
228
|
+
output_strs = []
|
|
229
|
+
for data_record in record_set.data_records:
|
|
230
|
+
output_strs.append(data_record.to_json_str(project_cols=fields, bytes_to_str=True, sorted=True))
|
|
231
|
+
full_hash = f"{hash(input_record)}{hash(tuple(sorted(output_strs)))}"
|
|
232
|
+
score = full_hash_to_score[full_hash]
|
|
233
|
+
for record_op_stats in record_set.record_op_stats:
|
|
234
|
+
record_op_stats.quality = score
|
|
235
|
+
|
|
236
|
+
# TODO: this scoring function will (likely) bias towards small values of k since it
|
|
237
|
+
# measures precision and not recall / F1; will need to revisit this in the future
|
|
238
|
+
elif isinstance(op, RetrieveOp):
|
|
239
|
+
fields = op.generated_fields
|
|
240
|
+
input_record: DataRecord = record_set.input
|
|
241
|
+
output_str = record_set.data_records[0].to_json_str(project_cols=fields, bytes_to_str=True, sorted=True)
|
|
242
|
+
full_hash = f"{hash(input_record)}{hash(output_str)}"
|
|
243
|
+
score = full_hash_to_score[full_hash]
|
|
244
|
+
record_set.record_op_stats[0].quality = score
|
|
245
|
+
|
|
246
|
+
elif isinstance(op, LLMFilter):
|
|
247
|
+
filter_str = op.filter_obj.filter_condition
|
|
248
|
+
input_record: DataRecord = record_set.input
|
|
249
|
+
output = record_set.data_records[0].passed_operator
|
|
250
|
+
full_hash = f"{filter_str}{hash(input_record)}"
|
|
251
|
+
if output == full_hash_to_bool_output[full_hash]:
|
|
252
|
+
record_set.record_op_stats[0].quality = full_hash_to_score[full_hash]
|
|
253
|
+
else:
|
|
254
|
+
record_set.record_op_stats[0].quality = 1.0 - full_hash_to_score[full_hash]
|
|
255
|
+
|
|
256
|
+
elif isinstance(op, JoinOp):
|
|
257
|
+
condition = op.condition
|
|
258
|
+
for left_idx, left_input_record in enumerate(record_set.input[0]):
|
|
259
|
+
for right_idx, right_input_record in enumerate(record_set.input[1]):
|
|
260
|
+
record_idx = left_idx * len(record_set.input[1]) + right_idx
|
|
261
|
+
output = record_set.data_records[record_idx].passed_operator
|
|
262
|
+
full_hash = f"{condition}{hash(left_input_record)}{hash(right_input_record)}"
|
|
263
|
+
if output == full_hash_to_bool_output[full_hash]:
|
|
264
|
+
record_set.record_op_stats[record_idx].quality = full_hash_to_score[full_hash]
|
|
265
|
+
else:
|
|
266
|
+
record_set.record_op_stats[record_idx].quality = 1.0 - full_hash_to_score[full_hash]
|
|
274
267
|
|
|
275
268
|
# return the quality annotated record sets
|
|
276
|
-
return
|
|
277
|
-
|
|
278
|
-
def
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
base_target_record = None
|
|
290
|
-
for record_set, _, _ in record_set_tuples:
|
|
291
|
-
if len(record_set) > 0:
|
|
292
|
-
base_target_record = record_set[0]
|
|
293
|
-
break
|
|
294
|
-
|
|
295
|
-
# compute availability of data
|
|
296
|
-
base_target_present = base_target_record is not None
|
|
297
|
-
labels_present = expected_outputs is not None
|
|
298
|
-
labels_for_source_present = False
|
|
299
|
-
if labels_present and source_idx in expected_outputs:
|
|
300
|
-
labels = expected_outputs[source_idx].get("labels", [])
|
|
301
|
-
labels_dict_lst = labels if isinstance(labels, list) else [labels]
|
|
302
|
-
labels_for_source_present = labels_dict_lst != [] and labels_dict_lst != [None]
|
|
303
|
-
|
|
304
|
-
# if we have a base target record and label info, use the label info to construct the target record set
|
|
305
|
-
if base_target_present and labels_for_source_present:
|
|
306
|
-
# get the field_to_score_fn
|
|
307
|
-
field_to_score_fn = expected_outputs[source_idx].get("score_fn", {})
|
|
308
|
-
|
|
309
|
-
# construct the target record set; we force passed_operator to be True for all target records
|
|
310
|
-
target_records = []
|
|
311
|
-
for labels_dict in labels_dict_lst:
|
|
312
|
-
target_record = base_target_record.copy()
|
|
313
|
-
for field, value in labels_dict.items():
|
|
314
|
-
target_record[field] = value
|
|
315
|
-
target_record.passed_operator = True
|
|
316
|
-
target_records.append(target_record)
|
|
317
|
-
|
|
318
|
-
source_idx_to_target_record_set[source_idx] = DataRecordSet(target_records, None, field_to_score_fn)
|
|
319
|
-
continue
|
|
320
|
-
|
|
321
|
-
# get the best computed output for this (source_idx, logical_op_id) so far (if one exists)
|
|
322
|
-
champion_record_set, champion_op_quality = None, None
|
|
323
|
-
if source_idx in self.champion_output_cache and logical_op_id in self.champion_output_cache[source_idx]:
|
|
324
|
-
champion_record_set, champion_op_quality = self.champion_output_cache[source_idx][logical_op_id]
|
|
325
|
-
|
|
326
|
-
# get the highest quality output that we just computed
|
|
327
|
-
max_quality_record_set, max_op_quality = self._pick_champion_output(record_set_tuples)
|
|
328
|
-
|
|
329
|
-
# if this new output is of higher quality than our previous champion (or if we didn't have
|
|
330
|
-
# a previous champion) then we update our champion record set
|
|
331
|
-
if champion_op_quality is None or (max_op_quality is not None and max_op_quality > champion_op_quality):
|
|
332
|
-
champion_record_set, champion_op_quality = max_quality_record_set, max_op_quality
|
|
333
|
-
|
|
334
|
-
# update the cache with the new champion record set and quality
|
|
335
|
-
if source_idx not in self.champion_output_cache:
|
|
336
|
-
self.champion_output_cache[source_idx] = {}
|
|
337
|
-
self.champion_output_cache[source_idx][logical_op_id] = (champion_record_set, champion_op_quality)
|
|
338
|
-
|
|
339
|
-
# set the target
|
|
340
|
-
source_idx_to_target_record_set[source_idx] = champion_record_set
|
|
341
|
-
|
|
342
|
-
return source_idx_to_target_record_set
|
|
343
|
-
|
|
344
|
-
def _pick_champion_output(self, record_set_tuples: list[tuple[DataRecordSet, PhysicalOperator, bool]]) -> tuple[DataRecordSet, float | None]:
|
|
345
|
-
# find the operator with the highest estimated quality and return its record_set
|
|
346
|
-
base_op_cost_est = OperatorCostEstimates(cardinality=1.0, cost_per_record=0.0, time_per_record=0.0, quality=1.0)
|
|
347
|
-
champion_record_set, champion_quality = None, None
|
|
348
|
-
for record_set, op, _ in record_set_tuples:
|
|
349
|
-
# skip failed operations
|
|
350
|
-
if len(record_set) == 0:
|
|
351
|
-
continue
|
|
352
|
-
|
|
353
|
-
# get the estimated quality of this operator
|
|
354
|
-
est_quality = op.naive_cost_estimates(base_op_cost_est).quality if self._is_llm_op(op) else 1.0
|
|
355
|
-
if champion_quality is None or est_quality > champion_quality:
|
|
356
|
-
champion_record_set, champion_quality = record_set, est_quality
|
|
357
|
-
|
|
358
|
-
return champion_record_set, champion_quality
|
|
359
|
-
|
|
360
|
-
def _flatten_record_sets(self, source_idx_to_record_sets: dict[int, list[DataRecordSet]]) -> tuple[list[DataRecord], list[RecordOpStats]]:
|
|
361
|
-
"""
|
|
362
|
-
Flatten the list of record sets and record op stats for each source_idx.
|
|
363
|
-
"""
|
|
364
|
-
all_records, all_record_op_stats = [], []
|
|
365
|
-
for _, record_sets in source_idx_to_record_sets.items():
|
|
366
|
-
for record_set in record_sets:
|
|
367
|
-
all_records.extend(record_set.data_records)
|
|
368
|
-
all_record_op_stats.extend(record_set.record_op_stats)
|
|
369
|
-
|
|
370
|
-
return all_records, all_record_op_stats
|
|
371
|
-
|
|
372
|
-
def _execute_op_set(self, op_input_pairs: list[tuple[PhysicalOperator, DataRecord | int]]) -> tuple[dict[int, list[tuple[DataRecordSet, PhysicalOperator, bool]]], dict[str, int]]:
|
|
373
|
-
def execute_op_wrapper(operator, input) -> tuple[DataRecordSet, PhysicalOperator, DataRecord | int]:
|
|
374
|
-
record_set = operator(input)
|
|
375
|
-
return record_set, operator, input
|
|
376
|
-
|
|
377
|
-
# TODO: modify unit tests to always have record_op_stats so we can use record_op_stats for source_idx
|
|
378
|
-
# for scan operators, `input` will be the source_idx
|
|
379
|
-
def get_source_idx(input):
|
|
380
|
-
return input.source_idx if isinstance(input, DataRecord) else input
|
|
381
|
-
|
|
382
|
-
def get_hash(operator, input):
|
|
269
|
+
return source_indices_to_record_sets, validation_gen_stats
|
|
270
|
+
|
|
271
|
+
def _execute_op_set(self, unique_logical_op_id: str, op_inputs: list[tuple[PhysicalOperator, str | tuple, int | DataRecord | list[DataRecord] | tuple[list[DataRecord]]]]) -> tuple[dict[int, list[tuple[DataRecordSet, PhysicalOperator, bool]]], dict[str, int]]:
|
|
272
|
+
def execute_op_wrapper(operator: PhysicalOperator, source_indices: str | tuple, input: int | DataRecord | list[DataRecord] | tuple[list[DataRecord]]) -> tuple[DataRecordSet, PhysicalOperator, list[DataRecord] | list[int]]:
|
|
273
|
+
# operator is a join
|
|
274
|
+
record_set = operator(input[0], input[1]) if isinstance(operator, JoinOp) else operator(input)
|
|
275
|
+
return record_set, operator, source_indices, input
|
|
276
|
+
|
|
277
|
+
def get_hash(operator: PhysicalOperator, input: int | DataRecord | list[DataRecord] | tuple[list[DataRecord]]):
|
|
278
|
+
if isinstance(input, list):
|
|
279
|
+
input = tuple(input)
|
|
280
|
+
elif isinstance(input, tuple):
|
|
281
|
+
input = (tuple(input[0]), tuple(input[1]))
|
|
383
282
|
return hash(f"{operator.get_full_op_id()}{hash(input)}")
|
|
384
283
|
|
|
385
284
|
# initialize mapping from source indices to output record sets
|
|
386
|
-
|
|
285
|
+
source_indices_to_record_sets_and_ops = {source_indices: [] for _, source_indices, _ in op_inputs}
|
|
387
286
|
|
|
388
287
|
# if any operations were previously executed, read the results from the cache
|
|
389
|
-
|
|
390
|
-
for operator, input in
|
|
288
|
+
final_op_inputs = []
|
|
289
|
+
for operator, source_indices, input in op_inputs:
|
|
391
290
|
# compute hash
|
|
392
291
|
op_input_hash = get_hash(operator, input)
|
|
393
292
|
|
|
394
293
|
# get result from cache
|
|
395
294
|
if op_input_hash in self.cache:
|
|
396
|
-
source_idx = get_source_idx(input)
|
|
397
295
|
record_set, operator = self.cache[op_input_hash]
|
|
398
|
-
|
|
296
|
+
source_indices_to_record_sets_and_ops[source_indices].append((record_set, operator, False))
|
|
399
297
|
|
|
400
|
-
# otherwise, add to
|
|
298
|
+
# otherwise, add to final_op_inputs
|
|
401
299
|
else:
|
|
402
|
-
|
|
300
|
+
final_op_inputs.append((operator, source_indices, input))
|
|
403
301
|
|
|
404
302
|
# keep track of the number of llm operations
|
|
405
303
|
num_llm_ops = 0
|
|
@@ -408,46 +306,41 @@ class SentinelExecutionStrategy(BaseExecutionStrategy, ABC):
|
|
|
408
306
|
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
|
|
409
307
|
# create futures
|
|
410
308
|
futures = [
|
|
411
|
-
executor.submit(execute_op_wrapper, operator, input)
|
|
412
|
-
for operator, input in
|
|
309
|
+
executor.submit(execute_op_wrapper, operator, source_indices, input)
|
|
310
|
+
for operator, source_indices, input in final_op_inputs
|
|
413
311
|
]
|
|
414
|
-
output_record_sets = []
|
|
415
|
-
while len(futures) > 0:
|
|
416
|
-
done_futures, not_done_futures = wait(futures, timeout=PARALLEL_EXECUTION_SLEEP_INTERVAL_SECS)
|
|
417
|
-
for future in done_futures:
|
|
418
|
-
# update output record sets
|
|
419
|
-
record_set, operator, input = future.result()
|
|
420
|
-
output_record_sets.append((record_set, operator, input))
|
|
421
|
-
|
|
422
|
-
# update cache
|
|
423
|
-
op_input_hash = get_hash(operator, input)
|
|
424
|
-
self.cache[op_input_hash] = (record_set, operator)
|
|
425
312
|
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
313
|
+
output_record_sets = []
|
|
314
|
+
for future in as_completed(futures):
|
|
315
|
+
# update output record sets
|
|
316
|
+
record_set, operator, source_indices, input = future.result()
|
|
317
|
+
output_record_sets.append((record_set, operator, source_indices, input))
|
|
430
318
|
|
|
431
|
-
# update
|
|
432
|
-
|
|
319
|
+
# update cache
|
|
320
|
+
op_input_hash = get_hash(operator, input)
|
|
321
|
+
self.cache[op_input_hash] = (record_set, operator)
|
|
433
322
|
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
|
|
323
|
+
# update progress manager
|
|
324
|
+
if self._is_llm_op(operator):
|
|
325
|
+
num_llm_ops += 1
|
|
326
|
+
self.progress_manager.incr(unique_logical_op_id, num_samples=1, total_cost=record_set.get_total_cost())
|
|
438
327
|
|
|
439
|
-
|
|
440
|
-
|
|
328
|
+
# update mapping from source_indices to record sets and operators
|
|
329
|
+
for record_set, operator, source_indices, input in output_record_sets:
|
|
330
|
+
# add record_set to mapping from source_indices --> record_sets
|
|
331
|
+
record_set.input = input
|
|
332
|
+
source_indices_to_record_sets_and_ops[source_indices].append((record_set, operator, True))
|
|
441
333
|
|
|
442
|
-
return
|
|
334
|
+
return source_indices_to_record_sets_and_ops, num_llm_ops
|
|
443
335
|
|
|
444
336
|
def _is_llm_op(self, physical_op: PhysicalOperator) -> bool:
|
|
445
337
|
is_llm_convert = isinstance(physical_op, LLMConvert)
|
|
446
338
|
is_llm_filter = isinstance(physical_op, LLMFilter)
|
|
447
339
|
is_llm_retrieve = isinstance(physical_op, RetrieveOp) and isinstance(physical_op.index, Collection)
|
|
448
|
-
|
|
340
|
+
is_llm_join = isinstance(physical_op, JoinOp)
|
|
341
|
+
return is_llm_convert or is_llm_filter or is_llm_retrieve or is_llm_join
|
|
449
342
|
|
|
450
343
|
@abstractmethod
|
|
451
|
-
def execute_sentinel_plan(self, sentinel_plan: SentinelPlan,
|
|
344
|
+
def execute_sentinel_plan(self, sentinel_plan: SentinelPlan, train_dataset: dict[str, Dataset], validator: Validator) -> SentinelPlanStats:
|
|
452
345
|
"""Execute a SentinelPlan according to strategy"""
|
|
453
346
|
pass
|