palimpzest 0.9.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32)
  1. palimpzest/constants.py +1 -0
  2. palimpzest/core/data/dataset.py +33 -5
  3. palimpzest/core/elements/groupbysig.py +5 -1
  4. palimpzest/core/elements/records.py +16 -7
  5. palimpzest/core/lib/schemas.py +20 -3
  6. palimpzest/core/models.py +4 -4
  7. palimpzest/query/execution/all_sample_execution_strategy.py +1 -1
  8. palimpzest/query/execution/execution_strategy.py +8 -8
  9. palimpzest/query/execution/mab_execution_strategy.py +30 -11
  10. palimpzest/query/execution/parallel_execution_strategy.py +31 -7
  11. palimpzest/query/execution/single_threaded_execution_strategy.py +23 -6
  12. palimpzest/query/operators/__init__.py +7 -6
  13. palimpzest/query/operators/aggregate.py +110 -5
  14. palimpzest/query/operators/convert.py +1 -1
  15. palimpzest/query/operators/join.py +279 -23
  16. palimpzest/query/operators/logical.py +20 -8
  17. palimpzest/query/operators/mixture_of_agents.py +3 -1
  18. palimpzest/query/operators/physical.py +5 -2
  19. palimpzest/query/operators/{retrieve.py → topk.py} +10 -10
  20. palimpzest/query/optimizer/__init__.py +7 -3
  21. palimpzest/query/optimizer/cost_model.py +5 -5
  22. palimpzest/query/optimizer/optimizer.py +3 -2
  23. palimpzest/query/optimizer/plan.py +2 -3
  24. palimpzest/query/optimizer/rules.py +31 -11
  25. palimpzest/query/optimizer/tasks.py +4 -4
  26. palimpzest/utils/progress.py +19 -17
  27. palimpzest/validator/validator.py +7 -7
  28. {palimpzest-0.9.0.dist-info → palimpzest-1.0.0.dist-info}/METADATA +26 -66
  29. {palimpzest-0.9.0.dist-info → palimpzest-1.0.0.dist-info}/RECORD +32 -32
  30. {palimpzest-0.9.0.dist-info → palimpzest-1.0.0.dist-info}/WHEEL +0 -0
  31. {palimpzest-0.9.0.dist-info → palimpzest-1.0.0.dist-info}/licenses/LICENSE +0 -0
  32. {palimpzest-0.9.0.dist-info → palimpzest-1.0.0.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import contextlib
3
4
  import time
4
5
  from typing import Any
5
6
 
@@ -14,7 +15,7 @@ from palimpzest.constants import (
14
15
  )
15
16
  from palimpzest.core.elements.groupbysig import GroupBySig
16
17
  from palimpzest.core.elements.records import DataRecord, DataRecordSet
17
- from palimpzest.core.lib.schemas import Average, Count, Max, Min
18
+ from palimpzest.core.lib.schemas import Average, Count, Max, Min, Sum
18
19
  from palimpzest.core.models import OperatorCostEstimates, RecordOpStats
19
20
  from palimpzest.query.generators.generators import Generator
20
21
  from palimpzest.query.operators.physical import PhysicalOperator
@@ -68,6 +69,16 @@ class ApplyGroupByOp(AggregateOp):
68
69
  return 0
69
70
  elif func.lower() == "average":
70
71
  return (0, 0)
72
+ elif func.lower() == "sum":
73
+ return 0
74
+ elif func.lower() == "min":
75
+ return float("inf")
76
+ elif func.lower() == "max":
77
+ return float("-inf")
78
+ elif func.lower() == "list":
79
+ return []
80
+ elif func.lower() == "set":
81
+ return set()
71
82
  else:
72
83
  raise Exception("Unknown agg function " + func)
73
84
 
@@ -76,16 +87,34 @@ class ApplyGroupByOp(AggregateOp):
76
87
  if func.lower() == "count":
77
88
  return state + 1
78
89
  elif func.lower() == "average":
79
- sum, cnt = state
90
+ sum_, cnt = state
91
+ if val is None:
92
+ return (sum_, cnt)
93
+ return (sum_ + val, cnt + 1)
94
+ elif func.lower() == "sum":
95
+ if val is None:
96
+ return state
97
+ return state + sum(val) if isinstance(val, list) else state + val
98
+ elif func.lower() == "min":
99
+ if val is None:
100
+ return state
101
+ return min(state, min(val) if isinstance(val, list) else val)
102
+ elif func.lower() == "max":
80
103
  if val is None:
81
- return (sum, cnt)
82
- return (sum + val, cnt + 1)
104
+ return state
105
+ return max(state, max(val) if isinstance(val, list) else val)
106
+ elif func.lower() == "list":
107
+ state.append(val)
108
+ return state
109
+ elif func.lower() == "set":
110
+ state.add(val)
111
+ return state
83
112
  else:
84
113
  raise Exception("Unknown agg function " + func)
85
114
 
86
115
  @staticmethod
87
116
  def agg_final(func, state):
88
- if func.lower() == "count":
117
+ if func.lower() in ["count", "sum", "min", "max", "list", "set"]:
89
118
  return state
90
119
  elif func.lower() == "average":
91
120
  sum, cnt = state
@@ -240,6 +269,82 @@ class AverageAggregateOp(AggregateOp):
240
269
  return DataRecordSet([dr], [record_op_stats])
241
270
 
242
271
 
272
class SumAggregateOp(AggregateOp):
    """Aggregate operator that folds every input record into one `Sum` record.

    The constructor asserts that the output schema has the same field names as
    `Sum` and that the input schema consists of a single field whose annotation
    is one of the accepted numeric annotations.
    """
    # NOTE: agg_func is stored and reported in the op's parameters, but the
    #       summation itself does not consult it (yet)

    def __init__(self, agg_func: AggFunc, *args, **kwargs):
        """Validate the schemas passed by keyword, then initialize the operator.

        Args:
            agg_func: the aggregation function associated with this operator.
            *args: forwarded unchanged to the parent constructor.
            **kwargs: forwarded unchanged to the parent constructor; must contain
                `output_schema` and `input_schema`.

        Raises:
            AssertionError: if the output schema's field names differ from those of
                `Sum`, or the input schema is not a single numeric field.
            KeyError: if `output_schema` or `input_schema` is not passed by keyword.
        """
        # the output schema must expose exactly the field names of Sum
        output_fields = kwargs["output_schema"].model_fields
        assert output_fields.keys() == Sum.model_fields.keys(), "SumAggregateOp requires output_schema to be Sum"

        # the input schema must hold one field, and that field must be numeric
        input_field_types = list(kwargs["input_schema"].model_fields.values())
        assert len(input_field_types) == 1, "SumAggregateOp requires input_schema to have exactly one field"
        numeric_field_types = [
            bool,
            int,
            float,
            int | float,
            bool | None,
            int | None,
            float | None,
            int | float | None,
            bool | Any,
            int | Any,
            float | Any,
            int | float | Any,
            bool | None | Any,
            int | None | Any,
            float | None | Any,
            int | float | None | Any,
        ]
        only_field = input_field_types[0]
        is_numeric = only_field.annotation in numeric_field_types
        assert is_numeric, f"SumAggregateOp requires input_schema to have a numeric field type, i.e. one of: {numeric_field_types}\nGot: {input_field_types[0]}"

        # hand the remaining arguments to the parent, then remember agg_func
        super().__init__(*args, **kwargs)
        self.agg_func = agg_func

    def __str__(self):
        """Return the parent's string form followed by a line naming the aggregation function."""
        return super().__str__() + f" Function: {self.agg_func!s}\n"

    def get_id_params(self):
        """Return the parent's id parameters plus `agg_func` rendered as a string.

        Parent entries are merged last, so a parent key named `agg_func` would
        override the value set here.
        """
        inherited = super().get_id_params()
        merged = {"agg_func": str(self.agg_func)}
        merged.update(inherited)
        return merged

    def get_op_params(self):
        """Return the parent's op parameters plus the `agg_func` object itself.

        Parent entries are merged last, so a parent key named `agg_func` would
        override the value set here.
        """
        inherited = super().get_op_params()
        merged = {"agg_func": self.agg_func}
        merged.update(inherited)
        return merged

    def naive_cost_estimates(self, source_op_cost_estimates: OperatorCostEstimates) -> OperatorCostEstimates:
        """Return a fixed estimate: one output record, zero time, zero cost, quality 1.0.

        The source operator's estimates are accepted but not used.
        """
        # for now the aggregation is assumed to add negligible time and no cost in USD
        estimates = OperatorCostEstimates(
            cardinality=1,
            time_per_record=0,
            cost_per_record=0,
            quality=1.0,
        )
        return estimates

    def __call__(self, candidates: list[DataRecord]) -> DataRecordSet:
        """Sum the first field value of every candidate and emit a single `Sum` record.

        Args:
            candidates: the input records to aggregate.

        Returns:
            A DataRecordSet holding one record (built with all candidates as
            parents) and the matching RecordOpStats.
        """
        start_time = time.time()

        # NOTE: input values are not guaranteed to conform to their declared type, so any
        #       candidate whose first value cannot be turned into a float is left out of the sum
        # NOTE: the constructor already enforces a single numeric input field; a cleaner way of
        #       reading the value than `float(list(...)[0])` may be desirable in the future
        total = 0
        for candidate in candidates:
            try:
                total += float(list(candidate.to_dict().values())[0])
            except Exception:
                continue

        data_item = Sum(sum=total)
        dr = DataRecord.from_agg_parents(data_item, parent_records=candidates)

        # record the stats for the single output record
        elapsed = time.time() - start_time
        record_op_stats = RecordOpStats(
            record_id=dr._id,
            record_parent_ids=dr._parent_ids,
            record_source_indices=dr._source_indices,
            record_state=dr.to_dict(include_bytes=False),
            full_op_id=self.get_full_op_id(),
            logical_op_id=self.logical_op_id,
            op_name=self.op_name(),
            time_per_record=elapsed,
            cost_per_record=0.0,
        )

        return DataRecordSet([dr], [record_op_stats])
346
+
347
+
243
348
  class CountAggregateOp(AggregateOp):
244
349
  # NOTE: we don't actually need / use agg_func here (yet)
245
350
 
@@ -320,7 +320,7 @@ class LLMConvert(ConvertOp):
320
320
  est_num_output_tokens = NAIVE_EST_NUM_OUTPUT_TOKENS
321
321
 
322
322
  # get est. of conversion time per record from model card;
323
- model_name = self.model.value if getattr(self, "model", None) is not None else Model.GPT_4o_MINI.value
323
+ model_name = self.model.value
324
324
  model_conversion_time_per_record = MODEL_CARDS[model_name]["seconds_per_output_token"] * est_num_output_tokens
325
325
 
326
326
  # get est. of conversion cost (in USD) per record from model card
@@ -1,5 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import threading
3
4
  import time
4
5
  from abc import ABC, abstractmethod
5
6
  from concurrent.futures import ThreadPoolExecutor, as_completed
@@ -37,10 +38,9 @@ class JoinOp(PhysicalOperator, ABC):
37
38
  def __init__(
38
39
  self,
39
40
  condition: str,
40
- model: Model,
41
- prompt_strategy: PromptStrategy = PromptStrategy.JOIN,
41
+ how: str = "inner",
42
+ on: list[str] | None = None,
42
43
  join_parallelism: int = 64,
43
- reasoning_effort: str | None = None,
44
44
  retain_inputs: bool = True,
45
45
  desc: str | None = None,
46
46
  *args,
@@ -49,33 +49,37 @@ class JoinOp(PhysicalOperator, ABC):
49
49
  super().__init__(*args, **kwargs)
50
50
  assert self.input_schema == self.output_schema, "Input and output schemas must match for JoinOp"
51
51
  self.condition = condition
52
- self.model = model
53
- self.prompt_strategy = prompt_strategy
52
+ self.how = how
53
+ self.on = on
54
54
  self.join_parallelism = join_parallelism
55
- self.reasoning_effort = reasoning_effort
56
55
  self.retain_inputs = retain_inputs
57
56
  self.desc = desc
58
- self.generator = Generator(model, prompt_strategy, reasoning_effort, self.api_base, Cardinality.ONE_TO_ONE, self.desc, self.verbose)
59
57
  self.join_idx = 0
58
+ self.finished = False
60
59
 
61
60
  # maintain list(s) of input records for the join
62
61
  self._left_input_records: list[DataRecord] = []
63
62
  self._right_input_records: list[DataRecord] = []
64
63
 
64
+ # maintain set of left/right record ids that have been joined (for left/right/outer joins)
65
+ self._left_joined_record_ids: set[str] = set()
66
+ self._right_joined_record_ids: set[str] = set()
67
+
65
68
  def __str__(self):
66
69
  op = super().__str__()
67
70
  op += f" Condition: {self.condition}\n"
71
+ op += f" How: {self.how}\n"
72
+ op += f" On: {self.on}\n"
68
73
  return op
69
74
 
70
75
  def get_id_params(self):
71
76
  id_params = super().get_id_params()
72
77
  id_params = {
73
78
  "condition": self.condition,
74
- "model": self.model.value,
75
- "prompt_strategy": self.prompt_strategy.value,
76
79
  "join_parallelism": self.join_parallelism,
77
- "reasoning_effort": self.reasoning_effort,
78
80
  "desc": self.desc,
81
+ "how": self.how,
82
+ "on": self.on,
79
83
  **id_params,
80
84
  }
81
85
  return id_params
@@ -84,23 +88,232 @@ class JoinOp(PhysicalOperator, ABC):
84
88
  op_params = super().get_op_params()
85
89
  op_params = {
86
90
  "condition": self.condition,
87
- "model": self.model,
88
- "prompt_strategy": self.prompt_strategy,
89
91
  "join_parallelism": self.join_parallelism,
90
- "reasoning_effort": self.reasoning_effort,
91
92
  "retain_inputs": self.retain_inputs,
92
93
  "desc": self.desc,
94
+ "how": self.how,
95
+ "on": self.on,
93
96
  **op_params,
94
97
  }
95
98
  return op_params
96
99
 
97
- def get_model_name(self):
98
- return self.model.value
100
def _compute_unmatched_records(self) -> DataRecordSet:
    """Build output records for retained inputs that never satisfied the join.

    For `how == "left"` the retained left inputs are scanned, for `"right"` the
    retained right inputs, and for `"outer"` both (left first). Any other value
    of `how` yields an empty DataRecordSet. Each unmatched input becomes a join
    record with the other side set to None, marked as having passed the operator.

    Returns:
        A DataRecordSet with the unmatched records and their RecordOpStats.
    """
    def emit_unmatched(input_records: list[DataRecord] | list[tuple[DataRecord, list[float]]], joined_record_ids: set[str], left: bool = True):
        unmatched, unmatched_stats = [], []
        for entry in input_records:
            start_time = time.time()

            # retained inputs may be stored as (record, embedding) tuples; keep only the record
            if isinstance(entry, tuple):
                entry = entry[0]

            # records which already joined with something need no padding row
            if entry._id in joined_record_ids:
                continue

            left_parent, right_parent = (entry, None) if left else (None, entry)
            unmatched_dr = DataRecord.from_join_parents(self.output_schema, left_parent, right_parent)
            unmatched_dr._passed_operator = True

            # compute the stats for this record
            # NOTE(review): join_condition is taken from `self.on`, which defaults to None;
            #               confirm this is the intended value for joins that only set `condition`
            elapsed = time.time() - start_time
            stats = RecordOpStats(
                record_id=unmatched_dr._id,
                record_parent_ids=unmatched_dr._parent_ids,
                record_source_indices=unmatched_dr._source_indices,
                record_state=unmatched_dr.to_dict(include_bytes=False),
                full_op_id=self.get_full_op_id(),
                logical_op_id=self.logical_op_id,
                op_name=self.op_name(),
                time_per_record=elapsed,
                cost_per_record=0.0,
                model_name=self.get_model_name(),
                join_condition=str(self.on),
                fn_call_duration_secs=elapsed,
                answer={"passed_operator": True},
                passed_operator=True,
                op_details={k: str(v) for k, v in self.get_id_params().items()},
            )
            unmatched.append(unmatched_dr)
            unmatched_stats.append(stats)
        return unmatched, unmatched_stats

    records, record_op_stats = [], []

    # left side first, so that outer joins list unmatched left records before right ones
    if self.how in ("left", "outer"):
        records, record_op_stats = emit_unmatched(self._left_input_records, self._left_joined_record_ids, left=True)

    if self.how in ("right", "outer"):
        right_records, right_record_op_stats = emit_unmatched(self._right_input_records, self._right_joined_record_ids, left=False)
        records.extend(right_records)
        record_op_stats.extend(right_record_op_stats)

    return DataRecordSet(records, record_op_stats)
99
152
 
100
153
  @abstractmethod
101
154
  def naive_cost_estimates(self, left_source_op_cost_estimates: OperatorCostEstimates, right_source_op_cost_estimates: OperatorCostEstimates) -> OperatorCostEstimates:
102
155
  pass
103
156
 
157
def set_finished(self):
    """Flag this join operator as finished by setting `self.finished` to True."""
    self.finished = True
160
+
161
class RelationalJoin(JoinOp):
    """Join operator which decides matches by field equality rather than by calling a model.

    A (left, right) pair satisfies the join when both records hold equal values
    for every field named in `self.on`.
    """

    def get_model_name(self):
        """Return None, as this operator reports no model."""
        return None

    def _process_join_candidate_pair(self, left_candidate, right_candidate) -> tuple[DataRecord, RecordOpStats]:
        """Evaluate one (left, right) pair and build its join record and stats.

        A record is produced for every pair; `_passed_operator` on the record
        tells whether the pair actually matched.

        Returns:
            The joined DataRecord and its RecordOpStats.
        """
        start_time = time.time()

        # the pair matches only if every join field compares equal
        passed_operator = all(left_candidate[name] == right_candidate[name] for name in self.on)

        # remember which inputs found a partner; only left/right/outer joins need this
        if passed_operator:
            if self.how in ("left", "outer"):
                self._left_joined_record_ids.add(left_candidate._id)
            if self.how in ("right", "outer"):
                self._right_joined_record_ids.add(right_candidate._id)

        # build the output record
        join_dr = DataRecord.from_join_parents(self.output_schema, left_candidate, right_candidate)
        join_dr._passed_operator = passed_operator

        # build the stats for the output record
        elapsed = time.time() - start_time
        stats = RecordOpStats(
            record_id=join_dr._id,
            record_parent_ids=join_dr._parent_ids,
            record_source_indices=join_dr._source_indices,
            record_state=join_dr.to_dict(include_bytes=False),
            full_op_id=self.get_full_op_id(),
            logical_op_id=self.logical_op_id,
            op_name=self.op_name(),
            time_per_record=elapsed,
            cost_per_record=0.0,
            model_name=self.get_model_name(),
            join_condition=str(self.on),
            fn_call_duration_secs=elapsed,
            answer={"passed_operator": passed_operator},
            passed_operator=passed_operator,
            op_details={k: str(v) for k, v in self.get_id_params().items()},
        )

        return join_dr, stats

    def naive_cost_estimates(self, left_source_op_cost_estimates: OperatorCostEstimates, right_source_op_cost_estimates: OperatorCostEstimates):
        """Estimate cardinality, time, cost and quality of the join from its inputs' estimates.

        Cardinality is the constant join selectivity times the product of the two
        input cardinalities. Time per record is 0.001 times the sum of the two
        input cardinalities. Cost is zero and quality is 1.0.
        """
        left_card = left_source_op_cost_estimates.cardinality
        right_card = right_source_op_cost_estimates.cardinality

        # output size under a constant selectivity assumption
        est_cardinality = NAIVE_EST_JOIN_SELECTIVITY * (left_card * right_card)

        # one millisecond scaled by the combined number of input records
        est_time_per_record = 0.001 * (left_card + right_card)

        return OperatorCostEstimates(
            cardinality=est_cardinality,
            time_per_record=est_time_per_record,
            cost_per_record=0.0,
            quality=1.0,
        )

    def __call__(self, left_candidates: list[DataRecord], right_candidates: list[DataRecord], final: bool = False) -> tuple[DataRecordSet, int]:
        """Join newly arrived records against each other and against retained inputs.

        Args:
            left_candidates: new records for the left side of the join.
            right_candidates: new records for the right side of the join.
            final: when True, the unmatched records for left/right/outer joins
                are returned instead of this call's pair results.

        Returns:
            A DataRecordSet and the number of pairs processed. On the final call
            the count is reported as 0.
        """
        # new left records meet the new right records and then the retained right records;
        # retained left records only need to meet the new right records
        pairs = [
            (left, right)
            for left in left_candidates
            for right in (*right_candidates, *self._right_input_records)
        ]
        pairs.extend(
            (left, right)
            for left in self._left_input_records
            for right in right_candidates
        )

        # evaluate every pair on the thread pool, gathering results as they finish
        output_records, output_record_op_stats = [], []
        with ThreadPoolExecutor(max_workers=self.join_parallelism) as executor:
            futures = [
                executor.submit(self._process_join_candidate_pair, left, right)
                for left, right in pairs
            ]
            for done in as_completed(futures):
                self.join_idx += 1
                record, stats = done.result()
                output_records.append(record)
                output_record_op_stats.append(stats)

        num_inputs_processed = len(pairs)

        # keep the inputs around so that records arriving later can join with them
        if self.retain_inputs:
            self._left_input_records.extend(left_candidates)
            self._right_input_records.extend(right_candidates)

        # NOTE(review): on the final call only the unmatched records are returned; results for
        #               pairs passed in with final=True are dropped -- confirm callers pass no new inputs
        if final:
            return self._compute_unmatched_records(), 0

        # both lists are empty when nothing was produced, which yields an empty DataRecordSet
        return DataRecordSet(output_records, output_record_op_stats), num_inputs_processed
269
+
270
+
271
+
272
+ class LLMJoin(JoinOp):
273
def __init__(
    self,
    model: Model,
    prompt_strategy: PromptStrategy = PromptStrategy.JOIN,
    reasoning_effort: str | None = None,
    *args,
    **kwargs,
):
    """Initialize the join, then attach the model settings and a generator built from them.

    Args:
        model: the model used by this join.
        prompt_strategy: the prompt strategy handed to the generator.
        reasoning_effort: optional reasoning-effort setting handed to the generator.
        *args: forwarded unchanged to the parent constructor.
        **kwargs: forwarded unchanged to the parent constructor.
    """
    super().__init__(*args, **kwargs)
    self.model, self.prompt_strategy, self.reasoning_effort = model, prompt_strategy, reasoning_effort

    # the generator is created after the parent constructor has run, because it reads
    # self.api_base, self.desc and self.verbose
    self.generator = Generator(
        model,
        prompt_strategy,
        reasoning_effort,
        self.api_base,
        Cardinality.ONE_TO_ONE,
        self.desc,
        self.verbose,
    )
286
+
287
def __str__(self):
    """Return the parent's string form followed by the model, reasoning effort, and prompt strategy."""
    pieces = [
        super().__str__(),
        f" Model: {self.model.value}\n",
        f" Reasoning Effort: {self.reasoning_effort}\n",
        f" Prompt Strategy: {self.prompt_strategy.value}\n",
    ]
    return "".join(pieces)
293
+
294
def get_id_params(self):
    """Return the parent's id parameters plus the model, prompt strategy (both as values) and reasoning effort.

    Parent entries are unpacked last, so for a key present in both places the
    parent's value is the one that is kept.
    """
    inherited = super().get_id_params()
    return {
        "model": self.model.value,
        "prompt_strategy": self.prompt_strategy.value,
        "reasoning_effort": self.reasoning_effort,
        **inherited,
    }
303
+
304
def get_op_params(self):
    """Return the parent's op parameters plus the model, prompt strategy and reasoning effort objects.

    Parent entries are unpacked last, so for a key present in both places the
    parent's value is the one that is kept.
    """
    inherited = super().get_op_params()
    return {
        "model": self.model,
        "prompt_strategy": self.prompt_strategy,
        "reasoning_effort": self.reasoning_effort,
        **inherited,
    }
313
+
314
def get_model_name(self):
    """Return the `value` of the model configured on this join."""
    model = self.model
    return model.value
316
+
104
317
  def _process_join_candidate_pair(
105
318
  self,
106
319
  left_candidate: DataRecord,
@@ -116,6 +329,15 @@ class JoinOp(PhysicalOperator, ABC):
116
329
  # determine whether or not the join was satisfied
117
330
  passed_operator = field_answers["passed_operator"]
118
331
 
332
+ # handle different join types
333
+ if self.how == "left" and passed_operator:
334
+ self._left_joined_record_ids.add(left_candidate._id)
335
+ elif self.how == "right" and passed_operator:
336
+ self._right_joined_record_ids.add(right_candidate._id)
337
+ elif self.how == "outer" and passed_operator:
338
+ self._left_joined_record_ids.add(left_candidate._id)
339
+ self._right_joined_record_ids.add(right_candidate._id)
340
+
119
341
  # compute output record and add to output_records
120
342
  join_dr = DataRecord.from_join_parents(self.output_schema, left_candidate, right_candidate)
121
343
  join_dr._passed_operator = passed_operator
@@ -149,7 +371,7 @@ class JoinOp(PhysicalOperator, ABC):
149
371
  return join_dr, record_op_stats
150
372
 
151
373
 
152
- class NestedLoopsJoin(JoinOp):
374
+ class NestedLoopsJoin(LLMJoin):
153
375
 
154
376
  def naive_cost_estimates(self, left_source_op_cost_estimates: OperatorCostEstimates, right_source_op_cost_estimates: OperatorCostEstimates):
155
377
  # estimate number of input tokens from source
@@ -192,7 +414,7 @@ class NestedLoopsJoin(JoinOp):
192
414
  quality=quality,
193
415
  )
194
416
 
195
- def __call__(self, left_candidates: list[DataRecord], right_candidates: list[DataRecord]) -> tuple[DataRecordSet, int]:
417
+ def __call__(self, left_candidates: list[DataRecord], right_candidates: list[DataRecord], final: bool = False) -> tuple[DataRecordSet, int]:
196
418
  # get the set of input fields from both records in the join
197
419
  input_fields = self.get_input_fields()
198
420
 
@@ -234,6 +456,10 @@ class NestedLoopsJoin(JoinOp):
234
456
  self._left_input_records.extend(left_candidates)
235
457
  self._right_input_records.extend(right_candidates)
236
458
 
459
+ # if this is the final call, then add in any left/right/outer join records that did not match
460
+ if final:
461
+ return self._compute_unmatched_records(), 0
462
+
237
463
  # return empty DataRecordSet if no output records were produced
238
464
  if len(output_records) == 0:
239
465
  return DataRecordSet([], []), num_inputs_processed
@@ -241,7 +467,7 @@ class NestedLoopsJoin(JoinOp):
241
467
  return DataRecordSet(output_records, output_record_op_stats), num_inputs_processed
242
468
 
243
469
 
244
- class EmbeddingJoin(JoinOp):
470
+ class EmbeddingJoin(LLMJoin):
245
471
  # NOTE: we currently do not support audio joins as embedding models for audio seem to have
246
472
  # specialized use cases (e.g., speech-to-text) with strict requirements on things like e.g. sample rate
247
473
  def __init__(
@@ -261,6 +487,8 @@ class EmbeddingJoin(JoinOp):
261
487
  if field_name.split(".")[-1] in self.get_input_fields()
262
488
  ])
263
489
  self.embedding_model = Model.TEXT_EMBEDDING_3_SMALL if self.text_only else Model.CLIP_VIT_B_32
490
+ self.clip_model = None
491
+ self._lock = threading.Lock()
264
492
 
265
493
  # keep track of embedding costs that could not be amortized if no output records were produced
266
494
  self.residual_embedding_cost = 0.0
@@ -276,6 +504,11 @@ class EmbeddingJoin(JoinOp):
276
504
  self.min_matching_sim = float("inf")
277
505
  self.max_non_matching_sim = float("-inf")
278
506
 
507
def __str__(self):
    """Return the parent's string form followed by a line with the configured number of samples."""
    return super().__str__() + f" Num Samples: {self.num_samples}\n"
511
+
279
512
  def get_id_params(self):
280
513
  id_params = super().get_id_params()
281
514
  id_params = {
@@ -327,6 +560,12 @@ class EmbeddingJoin(JoinOp):
327
560
  quality=quality,
328
561
  )
329
562
 
563
+ def _get_clip_model(self):
564
+ with self._lock:
565
+ if self.clip_model is None:
566
+ self.clip_model = SentenceTransformer(self.embedding_model.value)
567
+ return self.clip_model
568
+
330
569
  def _compute_embeddings(self, candidates: list[DataRecord], input_fields: list[str]) -> tuple[np.ndarray, GenerationStats]:
331
570
  # return empty array and empty stats if no candidates
332
571
  if len(candidates) == 0:
@@ -342,7 +581,7 @@ class EmbeddingJoin(JoinOp):
342
581
  total_input_tokens = response.usage.total_tokens
343
582
  embeddings = np.array([item.embedding for item in response.data])
344
583
  else:
345
- model = SentenceTransformer(self.embedding_model.value)
584
+ model = self._get_clip_model()
346
585
  embeddings = np.zeros((len(candidates), 512)) # CLIP embeddings are 512-dimensional
347
586
  num_input_fields_present = 0
348
587
  for field in input_fields:
@@ -389,6 +628,15 @@ class EmbeddingJoin(JoinOp):
389
628
  join_dr = DataRecord.from_join_parents(self.output_schema, left_candidate, right_candidate)
390
629
  join_dr._passed_operator = passed_operator
391
630
 
631
+ # handle different join types
632
+ if self.how == "left" and passed_operator:
633
+ self._left_joined_record_ids.add(left_candidate._id)
634
+ elif self.how == "right" and passed_operator:
635
+ self._right_joined_record_ids.add(right_candidate._id)
636
+ elif self.how == "outer" and passed_operator:
637
+ self._left_joined_record_ids.add(left_candidate._id)
638
+ self._right_joined_record_ids.add(right_candidate._id)
639
+
392
640
  # NOTE: embedding costs are amortized over all records and added at the end of __call__
393
641
  # compute record stats and add to output_record_op_stats
394
642
  record_op_stats = RecordOpStats(
@@ -410,7 +658,7 @@ class EmbeddingJoin(JoinOp):
410
658
 
411
659
  return join_dr, record_op_stats
412
660
 
413
- def __call__(self, left_candidates: list[DataRecord], right_candidates: list[DataRecord]) -> tuple[DataRecordSet, int]:
661
+ def __call__(self, left_candidates: list[DataRecord], right_candidates: list[DataRecord], final: bool = False) -> tuple[DataRecordSet, int]:
414
662
  # get the set of input fields from both records in the join
415
663
  input_fields = self.get_input_fields()
416
664
 
@@ -468,18 +716,22 @@ class EmbeddingJoin(JoinOp):
468
716
  self.max_non_matching_sim = embedding_sim
469
717
  if records_joined and embedding_sim < self.min_matching_sim:
470
718
  self.min_matching_sim = embedding_sim
471
-
719
+
472
720
  # update samples drawn and num_inputs_processed
473
721
  self.samples_drawn += samples_to_draw
474
722
  num_inputs_processed += samples_to_draw
475
723
 
476
724
  # process remaining candidates based on embedding similarity
477
725
  if len(join_candidates) > 0:
478
- assert self.samples_drawn == self.num_samples, "All samples should have been drawn before processing remaining candidates"
726
+ assert self.samples_drawn >= self.num_samples, "All samples should have been drawn before processing remaining candidates"
479
727
  with ThreadPoolExecutor(max_workers=self.join_parallelism) as executor:
480
728
  futures = []
481
729
  for left_candidate, right_candidate, embedding_sim in join_candidates:
482
- llm_call_needed = self.min_matching_sim <= embedding_sim <= self.max_non_matching_sim
730
+ llm_call_needed = (
731
+ self.min_matching_sim == float("inf")
732
+ or self.max_non_matching_sim == float("-inf")
733
+ or self.min_matching_sim <= embedding_sim <= self.max_non_matching_sim
734
+ )
483
735
 
484
736
  if llm_call_needed:
485
737
  futures.append(executor.submit(self._process_join_candidate_pair, left_candidate, right_candidate, gen_kwargs, embedding_sim))
@@ -526,6 +778,10 @@ class EmbeddingJoin(JoinOp):
526
778
  self._left_input_records.extend(zip(left_candidates, left_embeddings))
527
779
  self._right_input_records.extend(zip(right_candidates, right_embeddings))
528
780
 
781
+ # if this is the final call, then add in any left/right/outer join records that did not match
782
+ if final:
783
+ return self._compute_unmatched_records(), 0
784
+
529
785
  # return empty DataRecordSet if no output records were produced
530
786
  if len(output_records) == 0:
531
787
  self.residual_embedding_cost = total_embedding_cost