palimpzest 0.9.0__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- palimpzest/constants.py +1 -0
- palimpzest/core/data/dataset.py +33 -5
- palimpzest/core/elements/groupbysig.py +5 -1
- palimpzest/core/elements/records.py +16 -7
- palimpzest/core/lib/schemas.py +20 -3
- palimpzest/core/models.py +4 -4
- palimpzest/query/execution/all_sample_execution_strategy.py +1 -1
- palimpzest/query/execution/execution_strategy.py +8 -8
- palimpzest/query/execution/mab_execution_strategy.py +30 -11
- palimpzest/query/execution/parallel_execution_strategy.py +31 -7
- palimpzest/query/execution/single_threaded_execution_strategy.py +23 -6
- palimpzest/query/operators/__init__.py +7 -6
- palimpzest/query/operators/aggregate.py +110 -5
- palimpzest/query/operators/convert.py +1 -1
- palimpzest/query/operators/join.py +279 -23
- palimpzest/query/operators/logical.py +20 -8
- palimpzest/query/operators/mixture_of_agents.py +3 -1
- palimpzest/query/operators/physical.py +5 -2
- palimpzest/query/operators/{retrieve.py → topk.py} +10 -10
- palimpzest/query/optimizer/__init__.py +7 -3
- palimpzest/query/optimizer/cost_model.py +5 -5
- palimpzest/query/optimizer/optimizer.py +3 -2
- palimpzest/query/optimizer/plan.py +2 -3
- palimpzest/query/optimizer/rules.py +31 -11
- palimpzest/query/optimizer/tasks.py +4 -4
- palimpzest/utils/progress.py +19 -17
- palimpzest/validator/validator.py +7 -7
- {palimpzest-0.9.0.dist-info → palimpzest-1.0.0.dist-info}/METADATA +26 -66
- {palimpzest-0.9.0.dist-info → palimpzest-1.0.0.dist-info}/RECORD +32 -32
- {palimpzest-0.9.0.dist-info → palimpzest-1.0.0.dist-info}/WHEEL +0 -0
- {palimpzest-0.9.0.dist-info → palimpzest-1.0.0.dist-info}/licenses/LICENSE +0 -0
- {palimpzest-0.9.0.dist-info → palimpzest-1.0.0.dist-info}/top_level.txt +0 -0
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
import contextlib
|
|
3
4
|
import time
|
|
4
5
|
from typing import Any
|
|
5
6
|
|
|
@@ -14,7 +15,7 @@ from palimpzest.constants import (
|
|
|
14
15
|
)
|
|
15
16
|
from palimpzest.core.elements.groupbysig import GroupBySig
|
|
16
17
|
from palimpzest.core.elements.records import DataRecord, DataRecordSet
|
|
17
|
-
from palimpzest.core.lib.schemas import Average, Count, Max, Min
|
|
18
|
+
from palimpzest.core.lib.schemas import Average, Count, Max, Min, Sum
|
|
18
19
|
from palimpzest.core.models import OperatorCostEstimates, RecordOpStats
|
|
19
20
|
from palimpzest.query.generators.generators import Generator
|
|
20
21
|
from palimpzest.query.operators.physical import PhysicalOperator
|
|
@@ -68,6 +69,16 @@ class ApplyGroupByOp(AggregateOp):
|
|
|
68
69
|
return 0
|
|
69
70
|
elif func.lower() == "average":
|
|
70
71
|
return (0, 0)
|
|
72
|
+
elif func.lower() == "sum":
|
|
73
|
+
return 0
|
|
74
|
+
elif func.lower() == "min":
|
|
75
|
+
return float("inf")
|
|
76
|
+
elif func.lower() == "max":
|
|
77
|
+
return float("-inf")
|
|
78
|
+
elif func.lower() == "list":
|
|
79
|
+
return []
|
|
80
|
+
elif func.lower() == "set":
|
|
81
|
+
return set()
|
|
71
82
|
else:
|
|
72
83
|
raise Exception("Unknown agg function " + func)
|
|
73
84
|
|
|
@@ -76,16 +87,34 @@ class ApplyGroupByOp(AggregateOp):
|
|
|
76
87
|
if func.lower() == "count":
|
|
77
88
|
return state + 1
|
|
78
89
|
elif func.lower() == "average":
|
|
79
|
-
|
|
90
|
+
sum_, cnt = state
|
|
91
|
+
if val is None:
|
|
92
|
+
return (sum_, cnt)
|
|
93
|
+
return (sum_ + val, cnt + 1)
|
|
94
|
+
elif func.lower() == "sum":
|
|
95
|
+
if val is None:
|
|
96
|
+
return state
|
|
97
|
+
return state + sum(val) if isinstance(val, list) else state + val
|
|
98
|
+
elif func.lower() == "min":
|
|
99
|
+
if val is None:
|
|
100
|
+
return state
|
|
101
|
+
return min(state, min(val) if isinstance(val, list) else val)
|
|
102
|
+
elif func.lower() == "max":
|
|
80
103
|
if val is None:
|
|
81
|
-
return
|
|
82
|
-
return (
|
|
104
|
+
return state
|
|
105
|
+
return max(state, max(val) if isinstance(val, list) else val)
|
|
106
|
+
elif func.lower() == "list":
|
|
107
|
+
state.append(val)
|
|
108
|
+
return state
|
|
109
|
+
elif func.lower() == "set":
|
|
110
|
+
state.add(val)
|
|
111
|
+
return state
|
|
83
112
|
else:
|
|
84
113
|
raise Exception("Unknown agg function " + func)
|
|
85
114
|
|
|
86
115
|
@staticmethod
|
|
87
116
|
def agg_final(func, state):
|
|
88
|
-
if func.lower()
|
|
117
|
+
if func.lower() in ["count", "sum", "min", "max", "list", "set"]:
|
|
89
118
|
return state
|
|
90
119
|
elif func.lower() == "average":
|
|
91
120
|
sum, cnt = state
|
|
@@ -240,6 +269,82 @@ class AverageAggregateOp(AggregateOp):
|
|
|
240
269
|
return DataRecordSet([dr], [record_op_stats])
|
|
241
270
|
|
|
242
271
|
|
|
272
|
+
class SumAggregateOp(AggregateOp):
    """Aggregate operator that sums the single numeric field of its input records.

    Produces exactly one output record conforming to the ``Sum`` schema. The constructor
    asserts that ``output_schema`` has the same fields as ``Sum`` and that ``input_schema``
    has exactly one field whose annotation is one of the accepted numeric types.
    """

    # NOTE: we don't actually need / use agg_func here (yet)

    def __init__(self, agg_func: AggFunc, *args, **kwargs):
        # enforce that output schema is correct
        assert kwargs["output_schema"].model_fields.keys() == Sum.model_fields.keys(), "SumAggregateOp requires output_schema to be Sum"

        # enforce that input schema is a single numeric field
        input_field_types = list(kwargs["input_schema"].model_fields.values())
        assert len(input_field_types) == 1, "SumAggregateOp requires input_schema to have exactly one field"
        numeric_field_types = [
            bool, int, float, int | float,
            bool | None, int | None, float | None, int | float | None,
            bool | Any, int | Any, float | Any, int | float | Any,
            bool | None | Any, int | None | Any, float | None | Any, int | float | None | Any,
        ]
        is_numeric = input_field_types[0].annotation in numeric_field_types
        assert is_numeric, f"SumAggregateOp requires input_schema to have a numeric field type, i.e. one of: {numeric_field_types}\nGot: {input_field_types[0]}"

        # call parent constructor
        super().__init__(*args, **kwargs)
        self.agg_func = agg_func

    def __str__(self):
        op = super().__str__()
        op += f" Function: {str(self.agg_func)}\n"
        return op

    def get_id_params(self):
        """Return the parent's id params, prefixed with the stringified aggregation function."""
        id_params = super().get_id_params()
        return {"agg_func": str(self.agg_func), **id_params}

    def get_op_params(self):
        """Return the parent's op params, prefixed with the aggregation function object."""
        op_params = super().get_op_params()
        return {"agg_func": self.agg_func, **op_params}

    def naive_cost_estimates(self, source_op_cost_estimates: OperatorCostEstimates) -> OperatorCostEstimates:
        # for now, assume applying the aggregation takes negligible additional time (and no cost in USD)
        return OperatorCostEstimates(
            cardinality=1,
            time_per_record=0,
            cost_per_record=0,
            quality=1.0,
        )

    def __call__(self, candidates: list[DataRecord]) -> DataRecordSet:
        """Sum the single field value of every candidate into one ``Sum`` record.

        Values that ``float()`` cannot convert (e.g. ``None`` or non-numeric strings) are
        skipped rather than raising. With no convertible values the sum is ``0``.

        Returns:
            DataRecordSet holding the single aggregate record and its RecordOpStats.
        """
        start_time = time.time()

        # NOTE: we currently do not guarantee that input values conform to their specified type;
        # as a result, we simply omit any values which do not parse to a float from the sum
        # NOTE: right now we perform a check in the constructor which enforces that the input_schema
        # has a single field which is numeric in nature; in the future we may want to have a
        # cleaner way of computing the value (rather than `float(list(candidate...))` below)
        summation = 0
        for candidate in candidates:
            with contextlib.suppress(Exception):
                summation += float(list(candidate.to_dict().values())[0])
        data_item = Sum(sum=summation)
        dr = DataRecord.from_agg_parents(data_item, parent_records=candidates)

        # create RecordOpStats object
        record_op_stats = RecordOpStats(
            record_id=dr._id,
            record_parent_ids=dr._parent_ids,
            record_source_indices=dr._source_indices,
            record_state=dr.to_dict(include_bytes=False),
            full_op_id=self.get_full_op_id(),
            logical_op_id=self.logical_op_id,
            op_name=self.op_name(),
            time_per_record=time.time() - start_time,
            cost_per_record=0.0,
        )

        return DataRecordSet([dr], [record_op_stats])
|
|
346
|
+
|
|
347
|
+
|
|
243
348
|
class CountAggregateOp(AggregateOp):
|
|
244
349
|
# NOTE: we don't actually need / use agg_func here (yet)
|
|
245
350
|
|
|
@@ -320,7 +320,7 @@ class LLMConvert(ConvertOp):
|
|
|
320
320
|
est_num_output_tokens = NAIVE_EST_NUM_OUTPUT_TOKENS
|
|
321
321
|
|
|
322
322
|
# get est. of conversion time per record from model card;
|
|
323
|
-
model_name = self.model.value
|
|
323
|
+
model_name = self.model.value
|
|
324
324
|
model_conversion_time_per_record = MODEL_CARDS[model_name]["seconds_per_output_token"] * est_num_output_tokens
|
|
325
325
|
|
|
326
326
|
# get est. of conversion cost (in USD) per record from model card
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
import threading
|
|
3
4
|
import time
|
|
4
5
|
from abc import ABC, abstractmethod
|
|
5
6
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
@@ -37,10 +38,9 @@ class JoinOp(PhysicalOperator, ABC):
|
|
|
37
38
|
def __init__(
|
|
38
39
|
self,
|
|
39
40
|
condition: str,
|
|
40
|
-
|
|
41
|
-
|
|
41
|
+
how: str = "inner",
|
|
42
|
+
on: list[str] | None = None,
|
|
42
43
|
join_parallelism: int = 64,
|
|
43
|
-
reasoning_effort: str | None = None,
|
|
44
44
|
retain_inputs: bool = True,
|
|
45
45
|
desc: str | None = None,
|
|
46
46
|
*args,
|
|
@@ -49,33 +49,37 @@ class JoinOp(PhysicalOperator, ABC):
|
|
|
49
49
|
super().__init__(*args, **kwargs)
|
|
50
50
|
assert self.input_schema == self.output_schema, "Input and output schemas must match for JoinOp"
|
|
51
51
|
self.condition = condition
|
|
52
|
-
self.
|
|
53
|
-
self.
|
|
52
|
+
self.how = how
|
|
53
|
+
self.on = on
|
|
54
54
|
self.join_parallelism = join_parallelism
|
|
55
|
-
self.reasoning_effort = reasoning_effort
|
|
56
55
|
self.retain_inputs = retain_inputs
|
|
57
56
|
self.desc = desc
|
|
58
|
-
self.generator = Generator(model, prompt_strategy, reasoning_effort, self.api_base, Cardinality.ONE_TO_ONE, self.desc, self.verbose)
|
|
59
57
|
self.join_idx = 0
|
|
58
|
+
self.finished = False
|
|
60
59
|
|
|
61
60
|
# maintain list(s) of input records for the join
|
|
62
61
|
self._left_input_records: list[DataRecord] = []
|
|
63
62
|
self._right_input_records: list[DataRecord] = []
|
|
64
63
|
|
|
64
|
+
# maintain set of left/right record ids that have been joined (for left/right/outer joins)
|
|
65
|
+
self._left_joined_record_ids: set[str] = set()
|
|
66
|
+
self._right_joined_record_ids: set[str] = set()
|
|
67
|
+
|
|
65
68
|
def __str__(self):
    """Describe the join: the parent description followed by condition, join type and key fields."""
    details = [
        f" Condition: {self.condition}\n",
        f" How: {self.how}\n",
        f" On: {self.on}\n",
    ]
    return super().__str__() + "".join(details)
|
|
69
74
|
|
|
70
75
|
def get_id_params(self):
    """Return the parameters that identify this join operator.

    The join-specific entries (condition, parallelism, description, join type, key fields)
    come first; the parent's id params are merged in afterwards and win on key collisions.
    """
    parent_params = super().get_id_params()
    join_params = {
        "condition": self.condition,
        "join_parallelism": self.join_parallelism,
        "desc": self.desc,
        "how": self.how,
        "on": self.on,
    }
    join_params.update(parent_params)
    return join_params
|
|
@@ -84,23 +88,232 @@ class JoinOp(PhysicalOperator, ABC):
|
|
|
84
88
|
op_params = super().get_op_params()
|
|
85
89
|
op_params = {
|
|
86
90
|
"condition": self.condition,
|
|
87
|
-
"model": self.model,
|
|
88
|
-
"prompt_strategy": self.prompt_strategy,
|
|
89
91
|
"join_parallelism": self.join_parallelism,
|
|
90
|
-
"reasoning_effort": self.reasoning_effort,
|
|
91
92
|
"retain_inputs": self.retain_inputs,
|
|
92
93
|
"desc": self.desc,
|
|
94
|
+
"how": self.how,
|
|
95
|
+
"on": self.on,
|
|
93
96
|
**op_params,
|
|
94
97
|
}
|
|
95
98
|
return op_params
|
|
96
99
|
|
|
97
|
-
def
|
|
98
|
-
|
|
100
|
+
def _compute_unmatched_records(self) -> DataRecordSet:
    """Helper function to compute unmatched records for left/right/outer joins.

    For ``how == "left"``, every retained left input whose id is not in
    ``self._left_joined_record_ids`` is emitted, joined with ``None`` on the right side.
    ``how == "right"`` mirrors this for the right side, and ``how == "outer"`` does both
    (left records first, then right). Every emitted record is marked as having passed the
    operator. Any other ``how`` (e.g. ``"inner"``) yields an empty DataRecordSet.

    Returns:
        DataRecordSet containing the unmatched output records and their RecordOpStats.
    """
    if self.how not in ("left", "right", "outer"):
        return DataRecordSet([], [])

    # these values are identical for every unmatched record, so compute them once
    full_op_id = self.get_full_op_id()
    op_name = self.op_name()
    model_name = self.get_model_name()
    op_details = {k: str(v) for k, v in self.get_id_params().items()}
    # `on` defaults to None; for joins without key fields the natural-language `condition`
    # is the meaningful description, so report it instead of the string "None"
    join_condition = str(self.on) if self.on is not None else self.condition

    def join_unmatched_records(input_records: list[DataRecord] | list[tuple[DataRecord, list[float]]], joined_record_ids: set[str], left: bool = True):
        records, record_op_stats_lst = [], []
        for record in input_records:
            start_time = time.time()
            # EmbeddingJoin retains (record, embedding) tuples instead of bare records
            record = record[0] if isinstance(record, tuple) else record
            if record._id in joined_record_ids:
                continue

            unmatched_dr = (
                DataRecord.from_join_parents(self.output_schema, record, None)
                if left
                else DataRecord.from_join_parents(self.output_schema, None, record)
            )
            unmatched_dr._passed_operator = True

            # compute record stats and add to output_record_op_stats
            time_per_record = time.time() - start_time
            record_op_stats = RecordOpStats(
                record_id=unmatched_dr._id,
                record_parent_ids=unmatched_dr._parent_ids,
                record_source_indices=unmatched_dr._source_indices,
                record_state=unmatched_dr.to_dict(include_bytes=False),
                full_op_id=full_op_id,
                logical_op_id=self.logical_op_id,
                op_name=op_name,
                time_per_record=time_per_record,
                cost_per_record=0.0,
                model_name=model_name,
                join_condition=join_condition,
                fn_call_duration_secs=time_per_record,
                answer={"passed_operator": True},
                passed_operator=True,
                # give each stats object its own dict so they never alias one another
                op_details=dict(op_details),
            )
            records.append(unmatched_dr)
            record_op_stats_lst.append(record_op_stats)
        return records, record_op_stats_lst

    records, record_op_stats = [], []
    if self.how in ("left", "outer"):
        left_records, left_stats = join_unmatched_records(self._left_input_records, self._left_joined_record_ids, left=True)
        records.extend(left_records)
        record_op_stats.extend(left_stats)

    if self.how in ("right", "outer"):
        right_records, right_stats = join_unmatched_records(self._right_input_records, self._right_joined_record_ids, left=False)
        records.extend(right_records)
        record_op_stats.extend(right_stats)

    return DataRecordSet(records, record_op_stats)
|
|
99
152
|
|
|
100
153
|
@abstractmethod
|
|
101
154
|
def naive_cost_estimates(self, left_source_op_cost_estimates: OperatorCostEstimates, right_source_op_cost_estimates: OperatorCostEstimates) -> OperatorCostEstimates:
|
|
102
155
|
pass
|
|
103
156
|
|
|
157
|
+
def set_finished(self):
    """Flag this operator as done once the final left/right/outer join pass has been handled."""
    self.finished = True
|
|
160
|
+
|
|
161
|
+
class RelationalJoin(JoinOp):
    """Exact-match join on the fields listed in ``self.on``.

    Two records join when every field in ``self.on`` compares equal (``==``). No model is
    invoked, so the operator reports no model name and zero USD cost.
    """

    def get_model_name(self):
        # relational joins do not use a model
        return None

    def _process_join_candidate_pair(self, left_candidate, right_candidate) -> tuple[DataRecord, RecordOpStats]:
        """Evaluate the equality predicate for one (left, right) pair.

        Returns the joined DataRecord (``_passed_operator`` set to the predicate result) and
        its RecordOpStats. For left/right/outer joins, the ids of matching records are recorded
        so that unmatched records can be emitted on the final call.
        """
        start_time = time.time()

        # determine whether or not the join was satisfied
        passed_operator = all(left_candidate[field] == right_candidate[field] for field in self.on)

        # track matched records for left/right/outer joins
        if passed_operator:
            if self.how in ("left", "outer"):
                self._left_joined_record_ids.add(left_candidate._id)
            if self.how in ("right", "outer"):
                self._right_joined_record_ids.add(right_candidate._id)

        # compute output record and add to output_records
        join_dr = DataRecord.from_join_parents(self.output_schema, left_candidate, right_candidate)
        join_dr._passed_operator = passed_operator

        # compute record stats and add to output_record_op_stats
        time_per_record = time.time() - start_time
        record_op_stats = RecordOpStats(
            record_id=join_dr._id,
            record_parent_ids=join_dr._parent_ids,
            record_source_indices=join_dr._source_indices,
            record_state=join_dr.to_dict(include_bytes=False),
            full_op_id=self.get_full_op_id(),
            logical_op_id=self.logical_op_id,
            op_name=self.op_name(),
            time_per_record=time_per_record,
            cost_per_record=0.0,
            model_name=self.get_model_name(),
            join_condition=str(self.on),
            fn_call_duration_secs=time_per_record,
            answer={"passed_operator": passed_operator},
            passed_operator=passed_operator,
            op_details={k: str(v) for k, v in self.get_id_params().items()},
        )

        return join_dr, record_op_stats

    def naive_cost_estimates(self, left_source_op_cost_estimates: OperatorCostEstimates, right_source_op_cost_estimates: OperatorCostEstimates):
        # estimate output cardinality using a constant assumption of the join selectivity
        selectivity = NAIVE_EST_JOIN_SELECTIVITY
        cardinality = selectivity * (left_source_op_cost_estimates.cardinality * right_source_op_cost_estimates.cardinality)

        # estimate 1 ms of execution time per input record, summed over the left and right inputs
        # NOTE(review): an earlier comment said "per input record pair"; confirm whether a product
        # of the two cardinalities was intended here
        time_per_record = 0.001 * (left_source_op_cost_estimates.cardinality + right_source_op_cost_estimates.cardinality)

        return OperatorCostEstimates(
            cardinality=cardinality,
            time_per_record=time_per_record,
            cost_per_record=0.0,
            quality=1.0,
        )

    def __call__(self, left_candidates: list[DataRecord], right_candidates: list[DataRecord], final: bool = False) -> tuple[DataRecordSet, int]:
        """Join new candidates with each other and with previously retained inputs.

        Args:
            left_candidates: new records arriving on the left input.
            right_candidates: new records arriving on the right input.
            final: when True, also emit the unmatched records required by left/right/outer joins.

        Returns:
            Tuple of (DataRecordSet with one record per evaluated pair, plus the unmatched
            records when ``final`` is True; number of candidate pairs evaluated in this call).
        """
        # pair new left x (new right + retained right), then retained left x new right;
        # retained x retained pairs were already evaluated when those records were new
        join_candidates = []
        for left_candidate in left_candidates:
            for right_candidate in right_candidates:
                join_candidates.append((left_candidate, right_candidate))
            for right_candidate in self._right_input_records:
                join_candidates.append((left_candidate, right_candidate))
        for left_candidate in self._left_input_records:
            for right_candidate in right_candidates:
                join_candidates.append((left_candidate, right_candidate))

        # apply the join logic to each pair of candidates
        output_records, output_record_op_stats = [], []
        with ThreadPoolExecutor(max_workers=self.join_parallelism) as executor:
            futures = [
                executor.submit(self._process_join_candidate_pair, left_candidate, right_candidate)
                for left_candidate, right_candidate in join_candidates
            ]

            # collect results as they complete
            for future in as_completed(futures):
                self.join_idx += 1
                join_output_record, join_output_record_op_stats = future.result()
                output_records.append(join_output_record)
                output_record_op_stats.append(join_output_record_op_stats)

        # compute the number of inputs processed
        num_inputs_processed = len(join_candidates)

        # store input records to join with new records added later
        if self.retain_inputs:
            self._left_input_records.extend(left_candidates)
            self._right_input_records.extend(right_candidates)

        # if this is the final call, then add in any left/right/outer join records that did not match;
        # the pairs evaluated in this call are kept rather than discarded
        if final:
            unmatched_record_set = self._compute_unmatched_records()
            if len(output_records) == 0:
                return unmatched_record_set, num_inputs_processed
            # NOTE(review): assumes DataRecordSet exposes `data_records` / `record_op_stats` lists
            return DataRecordSet(
                output_records + list(unmatched_record_set.data_records),
                output_record_op_stats + list(unmatched_record_set.record_op_stats),
            ), num_inputs_processed

        # return empty DataRecordSet if no output records were produced
        if len(output_records) == 0:
            return DataRecordSet([], []), num_inputs_processed

        return DataRecordSet(output_records, output_record_op_stats), num_inputs_processed
|
|
269
|
+
|
|
270
|
+
|
|
271
|
+
|
|
272
|
+
class LLMJoin(JoinOp):
|
|
273
|
+
def __init__(
    self,
    model: Model,
    prompt_strategy: PromptStrategy = PromptStrategy.JOIN,
    reasoning_effort: str | None = None,
    *args,
    **kwargs,
):
    """Configure an LLM-backed join.

    Remaining positional/keyword arguments are forwarded to ``JoinOp``. A one-to-one
    ``Generator`` is built from the model, prompt strategy and reasoning effort.
    """
    super().__init__(*args, **kwargs)
    self.model, self.prompt_strategy, self.reasoning_effort = model, prompt_strategy, reasoning_effort
    self.generator = Generator(
        model,
        prompt_strategy,
        reasoning_effort,
        self.api_base,
        Cardinality.ONE_TO_ONE,
        self.desc,
        self.verbose,
    )
|
|
286
|
+
|
|
287
|
+
def __str__(self):
    """Extend the base join description with the model, reasoning effort and prompt strategy."""
    base = super().__str__()
    extra = (
        f" Model: {self.model.value}\n"
        f" Reasoning Effort: {self.reasoning_effort}\n"
        f" Prompt Strategy: {self.prompt_strategy.value}\n"
    )
    return base + extra
|
|
293
|
+
|
|
294
|
+
def get_id_params(self):
    """Return identifying params: model/prompt/reasoning settings, then the parent's params (which win on collisions)."""
    parent_params = super().get_id_params()
    llm_params = {
        "model": self.model.value,
        "prompt_strategy": self.prompt_strategy.value,
        "reasoning_effort": self.reasoning_effort,
    }
    llm_params.update(parent_params)
    return llm_params
|
|
303
|
+
|
|
304
|
+
def get_op_params(self):
    """Return constructor params: model/prompt/reasoning objects, then the parent's params (which win on collisions)."""
    parent_params = super().get_op_params()
    llm_params = {
        "model": self.model,
        "prompt_strategy": self.prompt_strategy,
        "reasoning_effort": self.reasoning_effort,
    }
    llm_params.update(parent_params)
    return llm_params
|
|
313
|
+
|
|
314
|
+
def get_model_name(self):
    """Return the string identifier (``.value``) of the model used by this join."""
    model = self.model
    return model.value
|
|
316
|
+
|
|
104
317
|
def _process_join_candidate_pair(
|
|
105
318
|
self,
|
|
106
319
|
left_candidate: DataRecord,
|
|
@@ -116,6 +329,15 @@ class JoinOp(PhysicalOperator, ABC):
|
|
|
116
329
|
# determine whether or not the join was satisfied
|
|
117
330
|
passed_operator = field_answers["passed_operator"]
|
|
118
331
|
|
|
332
|
+
# handle different join types
|
|
333
|
+
if self.how == "left" and passed_operator:
|
|
334
|
+
self._left_joined_record_ids.add(left_candidate._id)
|
|
335
|
+
elif self.how == "right" and passed_operator:
|
|
336
|
+
self._right_joined_record_ids.add(right_candidate._id)
|
|
337
|
+
elif self.how == "outer" and passed_operator:
|
|
338
|
+
self._left_joined_record_ids.add(left_candidate._id)
|
|
339
|
+
self._right_joined_record_ids.add(right_candidate._id)
|
|
340
|
+
|
|
119
341
|
# compute output record and add to output_records
|
|
120
342
|
join_dr = DataRecord.from_join_parents(self.output_schema, left_candidate, right_candidate)
|
|
121
343
|
join_dr._passed_operator = passed_operator
|
|
@@ -149,7 +371,7 @@ class JoinOp(PhysicalOperator, ABC):
|
|
|
149
371
|
return join_dr, record_op_stats
|
|
150
372
|
|
|
151
373
|
|
|
152
|
-
class NestedLoopsJoin(
|
|
374
|
+
class NestedLoopsJoin(LLMJoin):
|
|
153
375
|
|
|
154
376
|
def naive_cost_estimates(self, left_source_op_cost_estimates: OperatorCostEstimates, right_source_op_cost_estimates: OperatorCostEstimates):
|
|
155
377
|
# estimate number of input tokens from source
|
|
@@ -192,7 +414,7 @@ class NestedLoopsJoin(JoinOp):
|
|
|
192
414
|
quality=quality,
|
|
193
415
|
)
|
|
194
416
|
|
|
195
|
-
def __call__(self, left_candidates: list[DataRecord], right_candidates: list[DataRecord]) -> tuple[DataRecordSet, int]:
|
|
417
|
+
def __call__(self, left_candidates: list[DataRecord], right_candidates: list[DataRecord], final: bool = False) -> tuple[DataRecordSet, int]:
|
|
196
418
|
# get the set of input fields from both records in the join
|
|
197
419
|
input_fields = self.get_input_fields()
|
|
198
420
|
|
|
@@ -234,6 +456,10 @@ class NestedLoopsJoin(JoinOp):
|
|
|
234
456
|
self._left_input_records.extend(left_candidates)
|
|
235
457
|
self._right_input_records.extend(right_candidates)
|
|
236
458
|
|
|
459
|
+
# if this is the final call, then add in any left/right/outer join records that did not match
|
|
460
|
+
if final:
|
|
461
|
+
return self._compute_unmatched_records(), 0
|
|
462
|
+
|
|
237
463
|
# return empty DataRecordSet if no output records were produced
|
|
238
464
|
if len(output_records) == 0:
|
|
239
465
|
return DataRecordSet([], []), num_inputs_processed
|
|
@@ -241,7 +467,7 @@ class NestedLoopsJoin(JoinOp):
|
|
|
241
467
|
return DataRecordSet(output_records, output_record_op_stats), num_inputs_processed
|
|
242
468
|
|
|
243
469
|
|
|
244
|
-
class EmbeddingJoin(
|
|
470
|
+
class EmbeddingJoin(LLMJoin):
|
|
245
471
|
# NOTE: we currently do not support audio joins as embedding models for audio seem to have
|
|
246
472
|
# specialized use cases (e.g., speech-to-text) with strict requirements on things like e.g. sample rate
|
|
247
473
|
def __init__(
|
|
@@ -261,6 +487,8 @@ class EmbeddingJoin(JoinOp):
|
|
|
261
487
|
if field_name.split(".")[-1] in self.get_input_fields()
|
|
262
488
|
])
|
|
263
489
|
self.embedding_model = Model.TEXT_EMBEDDING_3_SMALL if self.text_only else Model.CLIP_VIT_B_32
|
|
490
|
+
self.clip_model = None
|
|
491
|
+
self._lock = threading.Lock()
|
|
264
492
|
|
|
265
493
|
# keep track of embedding costs that could not be amortized if no output records were produced
|
|
266
494
|
self.residual_embedding_cost = 0.0
|
|
@@ -276,6 +504,11 @@ class EmbeddingJoin(JoinOp):
|
|
|
276
504
|
self.min_matching_sim = float("inf")
|
|
277
505
|
self.max_non_matching_sim = float("-inf")
|
|
278
506
|
|
|
507
|
+
def __str__(self):
    """Extend the LLM join description with the number of samples drawn for threshold calibration."""
    return f"{super().__str__()} Num Samples: {self.num_samples}\n"
|
|
511
|
+
|
|
279
512
|
def get_id_params(self):
|
|
280
513
|
id_params = super().get_id_params()
|
|
281
514
|
id_params = {
|
|
@@ -327,6 +560,12 @@ class EmbeddingJoin(JoinOp):
|
|
|
327
560
|
quality=quality,
|
|
328
561
|
)
|
|
329
562
|
|
|
563
|
+
def _get_clip_model(self):
    """Return the cached CLIP SentenceTransformer, loading it on first use.

    Loading and reading the cache both happen while holding ``self._lock``, so concurrent
    callers share a single model instance.
    """
    with self._lock:
        model = self.clip_model
        if model is None:
            model = SentenceTransformer(self.embedding_model.value)
            self.clip_model = model
        return model
|
|
568
|
+
|
|
330
569
|
def _compute_embeddings(self, candidates: list[DataRecord], input_fields: list[str]) -> tuple[np.ndarray, GenerationStats]:
|
|
331
570
|
# return empty array and empty stats if no candidates
|
|
332
571
|
if len(candidates) == 0:
|
|
@@ -342,7 +581,7 @@ class EmbeddingJoin(JoinOp):
|
|
|
342
581
|
total_input_tokens = response.usage.total_tokens
|
|
343
582
|
embeddings = np.array([item.embedding for item in response.data])
|
|
344
583
|
else:
|
|
345
|
-
model =
|
|
584
|
+
model = self._get_clip_model()
|
|
346
585
|
embeddings = np.zeros((len(candidates), 512)) # CLIP embeddings are 512-dimensional
|
|
347
586
|
num_input_fields_present = 0
|
|
348
587
|
for field in input_fields:
|
|
@@ -389,6 +628,15 @@ class EmbeddingJoin(JoinOp):
|
|
|
389
628
|
join_dr = DataRecord.from_join_parents(self.output_schema, left_candidate, right_candidate)
|
|
390
629
|
join_dr._passed_operator = passed_operator
|
|
391
630
|
|
|
631
|
+
# handle different join types
|
|
632
|
+
if self.how == "left" and passed_operator:
|
|
633
|
+
self._left_joined_record_ids.add(left_candidate._id)
|
|
634
|
+
elif self.how == "right" and passed_operator:
|
|
635
|
+
self._right_joined_record_ids.add(right_candidate._id)
|
|
636
|
+
elif self.how == "outer" and passed_operator:
|
|
637
|
+
self._left_joined_record_ids.add(left_candidate._id)
|
|
638
|
+
self._right_joined_record_ids.add(right_candidate._id)
|
|
639
|
+
|
|
392
640
|
# NOTE: embedding costs are amortized over all records and added at the end of __call__
|
|
393
641
|
# compute record stats and add to output_record_op_stats
|
|
394
642
|
record_op_stats = RecordOpStats(
|
|
@@ -410,7 +658,7 @@ class EmbeddingJoin(JoinOp):
|
|
|
410
658
|
|
|
411
659
|
return join_dr, record_op_stats
|
|
412
660
|
|
|
413
|
-
def __call__(self, left_candidates: list[DataRecord], right_candidates: list[DataRecord]) -> tuple[DataRecordSet, int]:
|
|
661
|
+
def __call__(self, left_candidates: list[DataRecord], right_candidates: list[DataRecord], final: bool = False) -> tuple[DataRecordSet, int]:
|
|
414
662
|
# get the set of input fields from both records in the join
|
|
415
663
|
input_fields = self.get_input_fields()
|
|
416
664
|
|
|
@@ -468,18 +716,22 @@ class EmbeddingJoin(JoinOp):
|
|
|
468
716
|
self.max_non_matching_sim = embedding_sim
|
|
469
717
|
if records_joined and embedding_sim < self.min_matching_sim:
|
|
470
718
|
self.min_matching_sim = embedding_sim
|
|
471
|
-
|
|
719
|
+
|
|
472
720
|
# update samples drawn and num_inputs_processed
|
|
473
721
|
self.samples_drawn += samples_to_draw
|
|
474
722
|
num_inputs_processed += samples_to_draw
|
|
475
723
|
|
|
476
724
|
# process remaining candidates based on embedding similarity
|
|
477
725
|
if len(join_candidates) > 0:
|
|
478
|
-
assert self.samples_drawn
|
|
726
|
+
assert self.samples_drawn >= self.num_samples, "All samples should have been drawn before processing remaining candidates"
|
|
479
727
|
with ThreadPoolExecutor(max_workers=self.join_parallelism) as executor:
|
|
480
728
|
futures = []
|
|
481
729
|
for left_candidate, right_candidate, embedding_sim in join_candidates:
|
|
482
|
-
llm_call_needed =
|
|
730
|
+
llm_call_needed = (
|
|
731
|
+
self.min_matching_sim == float("inf")
|
|
732
|
+
or self.max_non_matching_sim == float("-inf")
|
|
733
|
+
or self.min_matching_sim <= embedding_sim <= self.max_non_matching_sim
|
|
734
|
+
)
|
|
483
735
|
|
|
484
736
|
if llm_call_needed:
|
|
485
737
|
futures.append(executor.submit(self._process_join_candidate_pair, left_candidate, right_candidate, gen_kwargs, embedding_sim))
|
|
@@ -526,6 +778,10 @@ class EmbeddingJoin(JoinOp):
|
|
|
526
778
|
self._left_input_records.extend(zip(left_candidates, left_embeddings))
|
|
527
779
|
self._right_input_records.extend(zip(right_candidates, right_embeddings))
|
|
528
780
|
|
|
781
|
+
# if this is the final call, then add in any left/right/outer join records that did not match
|
|
782
|
+
if final:
|
|
783
|
+
return self._compute_unmatched_records(), 0
|
|
784
|
+
|
|
529
785
|
# return empty DataRecordSet if no output records were produced
|
|
530
786
|
if len(output_records) == 0:
|
|
531
787
|
self.residual_embedding_cost = total_embedding_cost
|