palimpzest-0.8.1-py3-none-any.whl → palimpzest-0.8.3-py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as published in their respective public registries.
Files changed (61)
  1. palimpzest/constants.py +38 -62
  2. palimpzest/core/data/dataset.py +1 -1
  3. palimpzest/core/data/iter_dataset.py +5 -5
  4. palimpzest/core/elements/groupbysig.py +1 -1
  5. palimpzest/core/elements/records.py +91 -109
  6. palimpzest/core/lib/schemas.py +23 -0
  7. palimpzest/core/models.py +3 -3
  8. palimpzest/prompts/__init__.py +2 -6
  9. palimpzest/prompts/convert_prompts.py +10 -66
  10. palimpzest/prompts/critique_and_refine_prompts.py +66 -0
  11. palimpzest/prompts/filter_prompts.py +8 -46
  12. palimpzest/prompts/join_prompts.py +12 -75
  13. palimpzest/prompts/{moa_aggregator_convert_prompts.py → moa_aggregator_prompts.py} +51 -2
  14. palimpzest/prompts/moa_proposer_prompts.py +87 -0
  15. palimpzest/prompts/prompt_factory.py +351 -479
  16. palimpzest/prompts/split_merge_prompts.py +51 -2
  17. palimpzest/prompts/split_proposer_prompts.py +48 -16
  18. palimpzest/prompts/utils.py +109 -0
  19. palimpzest/query/execution/all_sample_execution_strategy.py +1 -1
  20. palimpzest/query/execution/execution_strategy.py +4 -4
  21. palimpzest/query/execution/mab_execution_strategy.py +47 -23
  22. palimpzest/query/execution/parallel_execution_strategy.py +3 -3
  23. palimpzest/query/execution/single_threaded_execution_strategy.py +8 -8
  24. palimpzest/query/generators/generators.py +31 -17
  25. palimpzest/query/operators/__init__.py +15 -2
  26. palimpzest/query/operators/aggregate.py +21 -19
  27. palimpzest/query/operators/compute.py +6 -8
  28. palimpzest/query/operators/convert.py +12 -37
  29. palimpzest/query/operators/critique_and_refine.py +194 -0
  30. palimpzest/query/operators/distinct.py +7 -7
  31. palimpzest/query/operators/filter.py +13 -25
  32. palimpzest/query/operators/join.py +321 -192
  33. palimpzest/query/operators/limit.py +4 -4
  34. palimpzest/query/operators/mixture_of_agents.py +246 -0
  35. palimpzest/query/operators/physical.py +25 -2
  36. palimpzest/query/operators/project.py +4 -4
  37. palimpzest/query/operators/{rag_convert.py → rag.py} +202 -5
  38. palimpzest/query/operators/retrieve.py +10 -9
  39. palimpzest/query/operators/scan.py +9 -10
  40. palimpzest/query/operators/search.py +18 -24
  41. palimpzest/query/operators/split.py +321 -0
  42. palimpzest/query/optimizer/__init__.py +12 -8
  43. palimpzest/query/optimizer/optimizer.py +12 -10
  44. palimpzest/query/optimizer/rules.py +201 -108
  45. palimpzest/query/optimizer/tasks.py +18 -6
  46. palimpzest/query/processor/config.py +2 -2
  47. palimpzest/query/processor/query_processor.py +2 -2
  48. palimpzest/query/processor/query_processor_factory.py +9 -5
  49. palimpzest/validator/validator.py +7 -9
  50. {palimpzest-0.8.1.dist-info → palimpzest-0.8.3.dist-info}/METADATA +3 -8
  51. palimpzest-0.8.3.dist-info/RECORD +95 -0
  52. palimpzest/prompts/critique_and_refine_convert_prompts.py +0 -216
  53. palimpzest/prompts/moa_proposer_convert_prompts.py +0 -75
  54. palimpzest/prompts/util_phrases.py +0 -19
  55. palimpzest/query/operators/critique_and_refine_convert.py +0 -113
  56. palimpzest/query/operators/mixture_of_agents_convert.py +0 -140
  57. palimpzest/query/operators/split_convert.py +0 -170
  58. palimpzest-0.8.1.dist-info/RECORD +0 -95
  59. {palimpzest-0.8.1.dist-info → palimpzest-0.8.3.dist-info}/WHEEL +0 -0
  60. {palimpzest-0.8.1.dist-info → palimpzest-0.8.3.dist-info}/licenses/LICENSE +0 -0
  61. {palimpzest-0.8.1.dist-info → palimpzest-0.8.3.dist-info}/top_level.txt +0 -0
@@ -4,7 +4,12 @@ import time
4
4
  from abc import ABC, abstractmethod
5
5
  from concurrent.futures import ThreadPoolExecutor, as_completed
6
6
 
7
+ import numpy as np
8
+ from numpy.linalg import norm
9
+ from openai import OpenAI
10
+ from PIL import Image
7
11
  from pydantic.fields import FieldInfo
12
+ from sentence_transformers import SentenceTransformer
8
13
 
9
14
  from palimpzest.constants import (
10
15
  MODEL_CARDS,
@@ -15,136 +20,90 @@ from palimpzest.constants import (
15
20
  PromptStrategy,
16
21
  )
17
22
  from palimpzest.core.elements.records import DataRecord, DataRecordSet
18
- from palimpzest.core.models import OperatorCostEstimates, RecordOpStats
23
+ from palimpzest.core.lib.schemas import AUDIO_FIELD_TYPES, IMAGE_FIELD_TYPES, ImageFilepath
24
+ from palimpzest.core.models import GenerationStats, OperatorCostEstimates, RecordOpStats
19
25
  from palimpzest.query.generators.generators import Generator
20
26
  from palimpzest.query.operators.physical import PhysicalOperator
21
27
 
22
28
 
23
- class JoinOp(PhysicalOperator, ABC):
24
- def __init__(self, condition: str, desc: str | None = None, *args, **kwargs):
25
- super().__init__(*args, **kwargs)
26
- assert self.input_schema == self.output_schema, "Input and output schemas must match for JoinOp"
27
- self.condition = condition
28
- self.desc = desc
29
-
30
- def __str__(self):
31
- op = super().__str__()
32
- op += f" Condition: {self.condition}\n"
33
- return op
34
-
35
- def get_id_params(self):
36
- id_params = super().get_id_params()
37
- return {"condition": self.condition, "desc": self.desc, **id_params}
29
+ def compute_similarity(left_embedding: list[float], right_embedding: list[float]) -> float:
30
+ """
31
+ Compute the similarity between two embeddings using cosine similarity.
32
+ """
33
+ return np.dot(left_embedding, right_embedding) / (norm(left_embedding) * norm(right_embedding))
38
34
 
39
- def get_op_params(self):
40
- op_params = super().get_op_params()
41
- return {"condition": self.condition, "desc": self.desc, **op_params}
42
35
 
43
- @abstractmethod
44
- def is_image_join(self) -> bool:
45
- """Return True if the join operation processes image(s), False otherwise."""
46
- pass
47
-
48
- @abstractmethod
49
- def naive_cost_estimates(self, left_source_op_cost_estimates: OperatorCostEstimates, right_source_op_cost_estimates: OperatorCostEstimates) -> OperatorCostEstimates:
50
- pass
51
-
52
-
53
- class BlockingNestedLoopsJoin(JoinOp):
36
+ class JoinOp(PhysicalOperator, ABC):
54
37
  def __init__(
55
38
  self,
39
+ condition: str,
56
40
  model: Model,
57
- prompt_strategy: PromptStrategy = PromptStrategy.COT_JOIN,
41
+ prompt_strategy: PromptStrategy = PromptStrategy.JOIN,
58
42
  join_parallelism: int = 64,
59
43
  reasoning_effort: str | None = None,
44
+ desc: str | None = None,
60
45
  *args,
61
46
  **kwargs,
62
47
  ):
63
48
  super().__init__(*args, **kwargs)
49
+ assert self.input_schema == self.output_schema, "Input and output schemas must match for JoinOp"
50
+ self.condition = condition
64
51
  self.model = model
65
52
  self.prompt_strategy = prompt_strategy
66
53
  self.join_parallelism = join_parallelism
67
54
  self.reasoning_effort = reasoning_effort
55
+ self.desc = desc
68
56
  self.generator = Generator(model, prompt_strategy, reasoning_effort, self.api_base, Cardinality.ONE_TO_ONE, self.desc, self.verbose)
69
57
  self.join_idx = 0
70
58
 
59
+ # maintain list(s) of input records for the join
60
+ self._left_input_records: list[DataRecord] = []
61
+ self._right_input_records: list[DataRecord] = []
62
+
63
+ def __str__(self):
64
+ op = super().__str__()
65
+ op += f" Condition: {self.condition}\n"
66
+ return op
67
+
71
68
  def get_id_params(self):
72
69
  id_params = super().get_id_params()
73
70
  id_params = {
71
+ "condition": self.condition,
74
72
  "model": self.model.value,
75
73
  "prompt_strategy": self.prompt_strategy.value,
76
74
  "join_parallelism": self.join_parallelism,
77
75
  "reasoning_effort": self.reasoning_effort,
76
+ "desc": self.desc,
78
77
  **id_params,
79
78
  }
80
-
81
79
  return id_params
82
80
 
83
81
  def get_op_params(self):
84
82
  op_params = super().get_op_params()
85
83
  op_params = {
86
- "model": self.model,
87
- "prompt_strategy": self.prompt_strategy,
84
+ "condition": self.condition,
85
+ "model": self.model.value,
86
+ "prompt_strategy": self.prompt_strategy.value,
88
87
  "join_parallelism": self.join_parallelism,
89
88
  "reasoning_effort": self.reasoning_effort,
89
+ "desc": self.desc,
90
90
  **op_params,
91
91
  }
92
-
93
92
  return op_params
94
93
 
95
94
  def get_model_name(self):
96
95
  return self.model.value
97
96
 
98
- def is_image_join(self) -> bool:
99
- return self.prompt_strategy is PromptStrategy.COT_JOIN_IMAGE
100
-
101
- def naive_cost_estimates(self, left_source_op_cost_estimates: OperatorCostEstimates, right_source_op_cost_estimates: OperatorCostEstimates):
102
- # estimate number of input tokens from source
103
- est_num_input_tokens = 2 * NAIVE_EST_NUM_INPUT_TOKENS
104
- if self.is_image_join():
105
- est_num_input_tokens = 2 * 765 / 10 # 1024x1024 image is 765 tokens
106
-
107
- # NOTE: the output often generates an entire reasoning sentence, thus the true value may be higher
108
- # the filter operation's LLM call should only output TRUE or FALSE, thus we expect its
109
- # number of output tokens to be ~1.25
110
- est_num_output_tokens = 1.25
111
-
112
- # get est. of conversion time per record from model card;
113
- model_conversion_time_per_record = (
114
- MODEL_CARDS[self.model.value]["seconds_per_output_token"] * est_num_output_tokens
115
- )
116
-
117
- # get est. of conversion cost (in USD) per record from model card
118
- usd_per_input_token = (
119
- MODEL_CARDS[self.model.value]["usd_per_audio_input_token"]
120
- if self.prompt_strategy.is_audio_prompt()
121
- else MODEL_CARDS[self.model.value]["usd_per_input_token"]
122
- )
123
- model_conversion_usd_per_record = (
124
- usd_per_input_token * est_num_input_tokens
125
- + MODEL_CARDS[self.model.value]["usd_per_output_token"] * est_num_output_tokens
126
- )
127
-
128
- # estimate output cardinality using a constant assumption of the filter selectivity
129
- selectivity = NAIVE_EST_JOIN_SELECTIVITY
130
- cardinality = selectivity * (left_source_op_cost_estimates.cardinality * right_source_op_cost_estimates.cardinality)
131
-
132
- # estimate quality of output based on the strength of the model being used
133
- quality = (MODEL_CARDS[self.model.value]["overall"] / 100.0)
134
-
135
- return OperatorCostEstimates(
136
- cardinality=cardinality,
137
- time_per_record=model_conversion_time_per_record,
138
- cost_per_record=model_conversion_usd_per_record,
139
- quality=quality,
140
- )
97
+ @abstractmethod
98
+ def naive_cost_estimates(self, left_source_op_cost_estimates: OperatorCostEstimates, right_source_op_cost_estimates: OperatorCostEstimates) -> OperatorCostEstimates:
99
+ pass
141
100
 
142
101
  def _process_join_candidate_pair(
143
102
  self,
144
103
  left_candidate: DataRecord,
145
104
  right_candidate: DataRecord,
146
105
  gen_kwargs: dict,
147
- ) -> tuple[list[DataRecord], list[RecordOpStats]]:
106
+ ) -> tuple[DataRecord, RecordOpStats]:
148
107
  start_time = time.time()
149
108
 
150
109
  # generate output; NOTE: FieldInfo is used to indicate the output type; thus, the desc is not needed
@@ -156,13 +115,13 @@ class BlockingNestedLoopsJoin(JoinOp):
156
115
 
157
116
  # compute output record and add to output_records
158
117
  join_dr = DataRecord.from_join_parents(self.output_schema, left_candidate, right_candidate)
159
- join_dr.passed_operator = passed_operator
118
+ join_dr._passed_operator = passed_operator
160
119
 
161
120
  # compute record stats and add to output_record_op_stats
162
121
  record_op_stats = RecordOpStats(
163
- record_id=join_dr.id,
164
- record_parent_ids=join_dr.parent_ids,
165
- record_source_indices=join_dr.source_indices,
122
+ record_id=join_dr._id,
123
+ record_parent_ids=join_dr._parent_ids,
124
+ record_source_indices=join_dr._source_indices,
166
125
  record_state=join_dr.to_dict(include_bytes=False),
167
126
  full_op_id=self.get_full_op_id(),
168
127
  logical_op_id=self.logical_op_id,
@@ -181,11 +140,54 @@ class BlockingNestedLoopsJoin(JoinOp):
181
140
  total_embedding_llm_calls=generation_stats.total_embedding_llm_calls,
182
141
  answer=field_answers,
183
142
  passed_operator=passed_operator,
184
- image_operation=self.is_image_join(),
185
143
  op_details={k: str(v) for k, v in self.get_id_params().items()},
186
144
  )
187
145
 
188
- return [join_dr], [record_op_stats]
146
+ return join_dr, record_op_stats
147
+
148
+
149
+ class NestedLoopsJoin(JoinOp):
150
+
151
+ def naive_cost_estimates(self, left_source_op_cost_estimates: OperatorCostEstimates, right_source_op_cost_estimates: OperatorCostEstimates):
152
+ # estimate number of input tokens from source
153
+ est_num_input_tokens = 2 * NAIVE_EST_NUM_INPUT_TOKENS
154
+ if self.is_image_op():
155
+ est_num_input_tokens = 2 * 765 / 10 # 1024x1024 image is 765 tokens
156
+
157
+ # NOTE: the output often generates an entire reasoning sentence, thus the true value may be higher
158
+ # the filter operation's LLM call should only output TRUE or FALSE, thus we expect its
159
+ # number of output tokens to be ~1.25
160
+ est_num_output_tokens = 1.25
161
+
162
+ # get est. of conversion time per record from model card;
163
+ model_conversion_time_per_record = (
164
+ MODEL_CARDS[self.model.value]["seconds_per_output_token"] * est_num_output_tokens
165
+ )
166
+
167
+ # get est. of conversion cost (in USD) per record from model card
168
+ usd_per_input_token = (
169
+ MODEL_CARDS[self.model.value]["usd_per_audio_input_token"]
170
+ if self.is_audio_op()
171
+ else MODEL_CARDS[self.model.value]["usd_per_input_token"]
172
+ )
173
+ model_conversion_usd_per_record = (
174
+ usd_per_input_token * est_num_input_tokens
175
+ + MODEL_CARDS[self.model.value]["usd_per_output_token"] * est_num_output_tokens
176
+ )
177
+
178
+ # estimate output cardinality using a constant assumption of the filter selectivity
179
+ selectivity = NAIVE_EST_JOIN_SELECTIVITY
180
+ cardinality = selectivity * (left_source_op_cost_estimates.cardinality * right_source_op_cost_estimates.cardinality)
181
+
182
+ # estimate quality of output based on the strength of the model being used
183
+ quality = (MODEL_CARDS[self.model.value]["overall"] / 100.0)
184
+
185
+ return OperatorCostEstimates(
186
+ cardinality=cardinality,
187
+ time_per_record=model_conversion_time_per_record,
188
+ cost_per_record=model_conversion_usd_per_record,
189
+ quality=quality,
190
+ )
189
191
 
190
192
  def __call__(self, left_candidates: list[DataRecord], right_candidates: list[DataRecord]) -> tuple[DataRecordSet, int]:
191
193
  # get the set of input fields from both records in the join
@@ -194,55 +196,86 @@ class BlockingNestedLoopsJoin(JoinOp):
194
196
  # construct kwargs for generation
195
197
  gen_kwargs = {"project_cols": input_fields, "join_condition": self.condition}
196
198
 
199
+ # create the set of candidates to join
200
+ join_candidates = []
201
+ for candidate in left_candidates:
202
+ for right_candidate in right_candidates:
203
+ join_candidates.append((candidate, right_candidate))
204
+ for right_candidate in self._right_input_records:
205
+ join_candidates.append((candidate, right_candidate))
206
+ for candidate in self._left_input_records:
207
+ for right_candidate in right_candidates:
208
+ join_candidates.append((candidate, right_candidate))
209
+
197
210
  # apply the generator to each pair of candidates
198
- output_records, output_record_op_stats, num_inputs_processed = [], [], 0
199
- total_join_candidates = len(left_candidates) * len(right_candidates)
211
+ output_records, output_record_op_stats = [], []
200
212
  with ThreadPoolExecutor(max_workers=self.join_parallelism) as executor:
201
- futures = []
202
- for candidate in left_candidates:
203
- for right_candidate in right_candidates:
204
- futures.append(executor.submit(self._process_join_candidate_pair, candidate, right_candidate, gen_kwargs))
205
- num_inputs_processed += 1
206
-
213
+ futures = [
214
+ executor.submit(self._process_join_candidate_pair, candidate, right_candidate, gen_kwargs)
215
+ for candidate, right_candidate in join_candidates
216
+ ]
217
+
218
+ # collect results as they complete
207
219
  for future in as_completed(futures):
208
220
  self.join_idx += 1
209
- join_output_records, join_output_record_op_stats = future.result()
210
- output_records.extend(join_output_records)
211
- output_record_op_stats.extend(join_output_record_op_stats)
212
- print(f"{self.join_idx}/{total_join_candidates} JOINED")
221
+ join_output_record, join_output_record_op_stats = future.result()
222
+ output_records.append(join_output_record)
223
+ output_record_op_stats.append(join_output_record_op_stats)
224
+ print(f"{self.join_idx} JOINED")
225
+
226
+ # compute the number of inputs processed
227
+ num_inputs_processed = len(join_candidates)
228
+
229
+ # store input records to join with new records added later
230
+ self._left_input_records.extend(left_candidates)
231
+ self._right_input_records.extend(right_candidates)
232
+
233
+ # return empty DataRecordSet if no output records were produced
234
+ if len(output_records) == 0:
235
+ return DataRecordSet([], []), num_inputs_processed
213
236
 
214
237
  return DataRecordSet(output_records, output_record_op_stats), num_inputs_processed
215
238
 
216
239
 
217
- class NestedLoopsJoin(JoinOp):
240
+ class EmbeddingJoin(JoinOp):
241
+ # NOTE: we currently do not support audio joins as embedding models for audio seem to have
242
+ # specialized use cases (e.g., speech-to-text) with strict requirements on things like e.g. sample rate
218
243
  def __init__(
219
244
  self,
220
- model: Model,
221
- prompt_strategy: PromptStrategy = PromptStrategy.COT_JOIN,
222
- join_parallelism: int = 64,
223
- reasoning_effort: str | None = None,
245
+ num_samples: int = 100,
224
246
  *args,
225
247
  **kwargs,
226
248
  ):
227
249
  super().__init__(*args, **kwargs)
228
- self.model = model
229
- self.prompt_strategy = prompt_strategy
230
- self.join_parallelism = join_parallelism
231
- self.reasoning_effort = reasoning_effort
232
- self.generator = Generator(model, prompt_strategy, reasoning_effort, self.api_base, Cardinality.ONE_TO_ONE, self.desc, self.verbose)
233
- self.join_idx = 0
250
+ self.num_samples = num_samples
251
+ self.samples_drawn = 0
234
252
 
235
- # maintain list(s) of input records for the join
236
- self._left_input_records: list[DataRecord] = []
237
- self._right_input_records: list[DataRecord] = []
253
+ # compute whether all fields are text fields
254
+ self.text_only = all([
255
+ field.annotation not in IMAGE_FIELD_TYPES + AUDIO_FIELD_TYPES
256
+ for field_name, field in self.input_schema.model_fields.items()
257
+ if field_name.split(".")[-1] in self.get_input_fields()
258
+ ])
259
+ self.embedding_model = Model.TEXT_EMBEDDING_3_SMALL if self.text_only else Model.CLIP_VIT_B_32
260
+
261
+ # keep track of embedding costs that could not be amortized if no output records were produced
262
+ self.residual_embedding_cost = 0.0
263
+
264
+ # crude adjustment factor for naive estimation in unoptimized setting
265
+ self.naive_quality_adjustment = 0.6
266
+
267
+ # maintain list(s) of input records and their embeddings for the join
268
+ self._left_input_records: list[tuple[DataRecord, list[float]]] = []
269
+ self._right_input_records: list[tuple[DataRecord, list[float]]] = []
270
+
271
+ # maintain lowest and highest embedding similarities for matching and non-matching pairs
272
+ self.min_matching_sim = float("inf")
273
+ self.max_non_matching_sim = float("-inf")
238
274
 
239
275
  def get_id_params(self):
240
276
  id_params = super().get_id_params()
241
277
  id_params = {
242
- "model": self.model.value,
243
- "prompt_strategy": self.prompt_strategy.value,
244
- "join_parallelism": self.join_parallelism,
245
- "reasoning_effort": self.reasoning_effort,
278
+ "num_samples": self.num_samples,
246
279
  **id_params,
247
280
  }
248
281
 
@@ -251,25 +284,16 @@ class NestedLoopsJoin(JoinOp):
251
284
  def get_op_params(self):
252
285
  op_params = super().get_op_params()
253
286
  op_params = {
254
- "model": self.model,
255
- "prompt_strategy": self.prompt_strategy,
256
- "join_parallelism": self.join_parallelism,
257
- "reasoning_effort": self.reasoning_effort,
287
+ "num_samples": self.num_samples,
258
288
  **op_params,
259
289
  }
260
290
 
261
291
  return op_params
262
292
 
263
- def get_model_name(self):
264
- return self.model.value
265
-
266
- def is_image_join(self) -> bool:
267
- return self.prompt_strategy is PromptStrategy.COT_JOIN_IMAGE
268
-
269
293
  def naive_cost_estimates(self, left_source_op_cost_estimates: OperatorCostEstimates, right_source_op_cost_estimates: OperatorCostEstimates):
270
294
  # estimate number of input tokens from source
271
295
  est_num_input_tokens = 2 * NAIVE_EST_NUM_INPUT_TOKENS
272
- if self.is_image_join():
296
+ if self.is_image_op():
273
297
  est_num_input_tokens = 2 * 765 / 10 # 1024x1024 image is 765 tokens
274
298
 
275
299
  # NOTE: the output often generates an entire reasoning sentence, thus the true value may be higher
@@ -279,18 +303,13 @@ class NestedLoopsJoin(JoinOp):
279
303
 
280
304
  # get est. of conversion time per record from model card;
281
305
  model_conversion_time_per_record = (
282
- MODEL_CARDS[self.model.value]["seconds_per_output_token"] * est_num_output_tokens
306
+ MODEL_CARDS[self.embedding_model.value]["seconds_per_output_token"] * est_num_output_tokens
283
307
  )
284
308
 
285
309
  # get est. of conversion cost (in USD) per record from model card
286
- usd_per_input_token = (
287
- MODEL_CARDS[self.model.value]["usd_per_audio_input_token"]
288
- if self.prompt_strategy.is_audio_prompt()
289
- else MODEL_CARDS[self.model.value]["usd_per_input_token"]
290
- )
291
310
  model_conversion_usd_per_record = (
292
- usd_per_input_token * est_num_input_tokens
293
- + MODEL_CARDS[self.model.value]["usd_per_output_token"] * est_num_output_tokens
311
+ MODEL_CARDS[self.embedding_model.value]["usd_per_input_token"] * est_num_input_tokens
312
+ + MODEL_CARDS[self.embedding_model.value]["usd_per_output_token"] * est_num_output_tokens
294
313
  )
295
314
 
296
315
  # estimate output cardinality using a constant assumption of the filter selectivity
@@ -298,7 +317,7 @@ class NestedLoopsJoin(JoinOp):
298
317
  cardinality = selectivity * (left_source_op_cost_estimates.cardinality * right_source_op_cost_estimates.cardinality)
299
318
 
300
319
  # estimate quality of output based on the strength of the model being used
301
- quality = (MODEL_CARDS[self.model.value]["overall"] / 100.0)
320
+ quality = (MODEL_CARDS[self.model.value]["overall"] / 100.0) * self.naive_quality_adjustment
302
321
 
303
322
  return OperatorCostEstimates(
304
323
  cardinality=cardinality,
@@ -307,97 +326,207 @@ class NestedLoopsJoin(JoinOp):
307
326
  quality=quality,
308
327
  )
309
328
 
310
- def _process_join_candidate_pair(
311
- self,
312
- left_candidate: DataRecord,
313
- right_candidate: DataRecord,
314
- gen_kwargs: dict,
315
- ) -> tuple[list[DataRecord], list[RecordOpStats]]:
329
+ def _compute_embeddings(self, candidates: list[DataRecord], input_fields: list[str]) -> tuple[np.ndarray, GenerationStats]:
330
+ # return empty array and empty stats if no candidates
331
+ if len(candidates) == 0:
332
+ return np.zeros((0, 512)), GenerationStats()
333
+
316
334
  start_time = time.time()
335
+ total_input_tokens = 0
336
+ embeddings = None
337
+ if self.text_only:
338
+ client = OpenAI()
339
+ inputs = [dr.to_json_str(bytes_to_str=True, project_cols=input_fields, sorted=True) for dr in candidates]
340
+ response = client.embeddings.create(input=inputs, model=self.embedding_model.value)
341
+ total_input_tokens = response.usage.total_tokens
342
+ embeddings = np.array([item.embedding for item in response.data])
343
+ else:
344
+ model = SentenceTransformer(self.embedding_model.value)
345
+ embeddings = np.zeros((len(candidates), 512)) # CLIP embeddings are 512-dimensional
346
+ num_input_fields_present = 0
347
+ for field in input_fields:
348
+ field_inputs = []
349
+ for candidate in candidates:
350
+ if field not in candidate.get_field_names():
351
+ continue
352
+ num_input_fields_present += 1
353
+ field_type = candidate.get_field_type(field)
354
+ if field_type in [ImageFilepath]:
355
+ field_inputs.append(Image.open(candidate[field]))
356
+ else:
357
+ field_inputs.append(str(candidate[field]))
358
+
359
+ if len(field_inputs) > 0:
360
+ embeddings += model.encode(field_inputs, convert_to_numpy=True)
361
+
362
+ # average embeddings over input fields present in candidates
363
+ embeddings /= num_input_fields_present
364
+
365
+ # compute cost of embedding(s)
366
+ model_card = MODEL_CARDS[self.embedding_model.value]
367
+ total_input_cost = model_card["usd_per_input_token"] * total_input_tokens
368
+ embedding_gen_stats = GenerationStats(
369
+ model_name=self.embedding_model.value,
370
+ total_input_tokens=total_input_tokens,
371
+ total_output_tokens=0.0,
372
+ total_input_cost=total_input_cost,
373
+ total_output_cost=0.0,
374
+ cost_per_record=total_input_cost,
375
+ llm_call_duration_secs=time.time() - start_time,
376
+ total_llm_calls=1,
377
+ total_embedding_llm_calls=len(candidates),
378
+ )
317
379
 
318
- # generate output; NOTE: FieldInfo is used to indicate the output type; thus, the desc is not needed
319
- fields = {"passed_operator": FieldInfo(annotation=bool, description="Whether the records satisfy the join condition")}
320
- field_answers, _, generation_stats, _ = self.generator(left_candidate, fields, right_candidate=right_candidate, **gen_kwargs)
380
+ return embeddings, embedding_gen_stats
321
381
 
322
- # determine whether or not the join was satisfied
323
- passed_operator = field_answers["passed_operator"]
382
+ def _process_join_candidate_pair(self, left_candidate, right_candidate, gen_kwargs, embedding_sim):
383
+ output_record, output_record_op_stats = super()._process_join_candidate_pair(left_candidate, right_candidate, gen_kwargs)
384
+ return output_record, output_record_op_stats, embedding_sim
324
385
 
386
+ def _process_join_candidate_with_sim(self, left_candidate: DataRecord, right_candidate: DataRecord, passed_operator: bool) -> tuple[DataRecord, RecordOpStats]:
325
387
  # compute output record and add to output_records
326
388
  join_dr = DataRecord.from_join_parents(self.output_schema, left_candidate, right_candidate)
327
- join_dr.passed_operator = passed_operator
389
+ join_dr._passed_operator = passed_operator
328
390
 
391
+ # NOTE: embedding costs are amortized over all records and added at the end of __call__
329
392
  # compute record stats and add to output_record_op_stats
330
393
  record_op_stats = RecordOpStats(
331
- record_id=join_dr.id,
332
- record_parent_ids=join_dr.parent_ids,
333
- record_source_indices=join_dr.source_indices,
394
+ record_id=join_dr._id,
395
+ record_parent_ids=join_dr._parent_ids,
396
+ record_source_indices=join_dr._source_indices,
334
397
  record_state=join_dr.to_dict(include_bytes=False),
335
398
  full_op_id=self.get_full_op_id(),
336
399
  logical_op_id=self.logical_op_id,
337
400
  op_name=self.op_name(),
338
- time_per_record=time.time() - start_time,
339
- cost_per_record=generation_stats.cost_per_record,
401
+ time_per_record=0.0,
402
+ cost_per_record=0.0,
340
403
  model_name=self.get_model_name(),
341
404
  join_condition=self.condition,
342
- total_input_tokens=generation_stats.total_input_tokens,
343
- total_output_tokens=generation_stats.total_output_tokens,
344
- total_input_cost=generation_stats.total_input_cost,
345
- total_output_cost=generation_stats.total_output_cost,
346
- llm_call_duration_secs=generation_stats.llm_call_duration_secs,
347
- fn_call_duration_secs=generation_stats.fn_call_duration_secs,
348
- total_llm_calls=generation_stats.total_llm_calls,
349
- total_embedding_llm_calls=generation_stats.total_embedding_llm_calls,
350
- answer=field_answers,
405
+ answer={"passed_operator": passed_operator},
351
406
  passed_operator=passed_operator,
352
- image_operation=self.is_image_join(),
353
407
  op_details={k: str(v) for k, v in self.get_id_params().items()},
354
408
  )
355
409
 
356
- return [join_dr], [record_op_stats]
410
+ return join_dr, record_op_stats
357
411
 
358
- def __call__(self, left_candidates: list[DataRecord], right_candidates: list[DataRecord]) -> tuple[DataRecordSet | None, int]:
412
+ def __call__(self, left_candidates: list[DataRecord], right_candidates: list[DataRecord]) -> tuple[DataRecordSet, int]:
359
413
  # get the set of input fields from both records in the join
360
414
  input_fields = self.get_input_fields()
361
415
 
416
+ # compute the embeding for each candidate
417
+ left_embeddings, left_embedding_gen_stats = self._compute_embeddings(left_candidates, input_fields)
418
+ right_embeddings, right_embedding_gen_stats = self._compute_embeddings(right_candidates, input_fields)
419
+ total_embedding_cost = left_embedding_gen_stats.cost_per_record + right_embedding_gen_stats.cost_per_record + self.residual_embedding_cost
420
+ self.residual_embedding_cost = 0.0
421
+
362
422
  # construct kwargs for generation
363
423
  gen_kwargs = {"project_cols": input_fields, "join_condition": self.condition}
364
424
 
365
- # apply the generator to each pair of candidates
425
+ # TODO: add embeddings to join candidates
426
+ # create the set of candidates to join
427
+ join_candidates = []
428
+ for candidate, embedding in zip(left_candidates, left_embeddings):
429
+ for right_candidate, right_embedding in zip(right_candidates, right_embeddings):
430
+ embedding_sim = compute_similarity(embedding, right_embedding)
431
+ join_candidates.append((candidate, right_candidate, embedding_sim))
432
+ for right_candidate, right_embedding in self._right_input_records:
433
+ embedding_sim = compute_similarity(embedding, right_embedding)
434
+ join_candidates.append((candidate, right_candidate, embedding_sim))
435
+ for candidate, embedding in self._left_input_records:
436
+ for right_candidate, right_embedding in zip(right_candidates, right_embeddings):
437
+ embedding_sim = compute_similarity(embedding, right_embedding)
438
+ join_candidates.append((candidate, right_candidate, embedding_sim))
439
+
440
+ # prepare list of output records and their stats
366
441
  output_records, output_record_op_stats, num_inputs_processed = [], [], 0
367
- with ThreadPoolExecutor(max_workers=self.join_parallelism) as executor:
368
- futures = []
369
- # join new left candidates with new right candidates
370
- for candidate in left_candidates:
371
- for right_candidate in right_candidates:
372
- futures.append(executor.submit(self._process_join_candidate_pair, candidate, right_candidate, gen_kwargs))
373
- num_inputs_processed += 1
374
442
 
375
- # join new left candidates with stored right input records
376
- for candidate in left_candidates:
377
- for right_candidate in self._right_input_records:
378
- futures.append(executor.submit(self._process_join_candidate_pair, candidate, right_candidate, gen_kwargs))
379
- num_inputs_processed += 1
443
+ # draw samples until num_samples is reached
444
+ if self.samples_drawn < self.num_samples:
445
+ samples_to_draw = min(self.num_samples - self.samples_drawn, len(join_candidates))
446
+ join_candidate_samples = join_candidates[:samples_to_draw]
447
+ join_candidates = join_candidates[samples_to_draw:]
448
+
449
+ # apply the generator to each pair of candidates
450
+ with ThreadPoolExecutor(max_workers=self.join_parallelism) as executor:
451
+ futures = [
452
+ executor.submit(self._process_join_candidate_pair, left_candidate, right_candidate, gen_kwargs, embedding_sim)
453
+ for left_candidate, right_candidate, embedding_sim in join_candidate_samples
454
+ ]
455
+
456
+ # collect results as they complete
457
+ for future in as_completed(futures):
458
+ self.join_idx += 1
459
+ join_output_record, join_output_record_op_stats, embedding_sim = future.result()
460
+ output_records.append(join_output_record)
461
+ output_record_op_stats.append(join_output_record_op_stats)
462
+ print(f"{self.join_idx} JOINED")
463
+
464
+ # update similarity thresholds
465
+ records_joined = join_output_record._passed_operator
466
+ if not records_joined and embedding_sim > self.max_non_matching_sim:
467
+ self.max_non_matching_sim = embedding_sim
468
+ if records_joined and embedding_sim < self.min_matching_sim:
469
+ self.min_matching_sim = embedding_sim
470
+
471
+ # update samples drawn and num_inputs_processed
472
+ self.samples_drawn += samples_to_draw
473
+ num_inputs_processed += samples_to_draw
474
+
475
+ # process remaining candidates based on embedding similarity
476
+ if len(join_candidates) > 0:
477
+ assert self.samples_drawn == self.num_samples, "All samples should have been drawn before processing remaining candidates"
478
+ with ThreadPoolExecutor(max_workers=self.join_parallelism) as executor:
479
+ futures = []
480
+ for left_candidate, right_candidate, embedding_sim in join_candidates:
481
+ llm_call_needed = self.min_matching_sim <= embedding_sim <= self.max_non_matching_sim
482
+
483
+ if llm_call_needed:
484
+ futures.append(executor.submit(self._process_join_candidate_pair, left_candidate, right_candidate, gen_kwargs, embedding_sim))
485
+
486
+ elif embedding_sim < self.min_matching_sim:
487
+ self.join_idx += 1
488
+ output_record, record_op_stats = self._process_join_candidate_with_sim(left_candidate, right_candidate, passed_operator=False)
489
+ output_records.append(output_record)
490
+ output_record_op_stats.append(record_op_stats)
491
+ print(f"{self.join_idx} SKIPPED (low sim: {embedding_sim:.4f} < {self.min_matching_sim:.4f})")
492
+
493
+ elif embedding_sim > self.max_non_matching_sim:
494
+ self.join_idx += 1
495
+ output_record, record_op_stats = self._process_join_candidate_with_sim(left_candidate, right_candidate, passed_operator=True)
496
+ output_records.append(output_record)
497
+ output_record_op_stats.append(record_op_stats)
498
+ print(f"{self.join_idx} JOINED (high sim: {embedding_sim:.4f} > {self.max_non_matching_sim:.4f})")
380
499
 
381
- # join new right candidates with stored left input records
382
- for candidate in self._left_input_records:
383
- for right_candidate in right_candidates:
384
- futures.append(executor.submit(self._process_join_candidate_pair, candidate, right_candidate, gen_kwargs))
385
500
  num_inputs_processed += 1
386
501
 
387
- # collect results as they complete
388
- for future in as_completed(futures):
389
- self.join_idx += 1
390
- join_output_records, join_output_record_op_stats = future.result()
391
- output_records.extend(join_output_records)
392
- output_record_op_stats.extend(join_output_record_op_stats)
393
- print(f"{self.join_idx} JOINED")
502
+ # collect results as they complete
503
+ for future in as_completed(futures):
504
+ self.join_idx += 1
505
+ join_output_record, join_output_record_op_stats, embedding_sim = future.result()
506
+ output_records.append(join_output_record)
507
+ output_record_op_stats.append(join_output_record_op_stats)
508
+ print(f"{self.join_idx} JOINED")
509
+
510
+ # update similarity thresholds
511
+ records_joined = join_output_record._passed_operator
512
+ if not records_joined and embedding_sim > self.max_non_matching_sim:
513
+ self.max_non_matching_sim = embedding_sim
514
+ if records_joined and embedding_sim < self.min_matching_sim:
515
+ self.min_matching_sim = embedding_sim
516
+
517
+ # amortize embedding costs over all output records and add to each record's op stats
518
+ amortized_embedding_cost = total_embedding_cost / len(output_record_op_stats) if len(output_record_op_stats) > 0 else 0.0
519
+ for record_op_stats in output_record_op_stats:
520
+ record_op_stats.cost_per_record += amortized_embedding_cost
521
+ record_op_stats.total_embedding_cost = amortized_embedding_cost
394
522
 
395
523
  # store input records to join with new records added later
396
- self._left_input_records.extend(left_candidates)
397
- self._right_input_records.extend(right_candidates)
524
+ self._left_input_records.extend(zip(left_candidates, left_embeddings))
525
+ self._right_input_records.extend(zip(right_candidates, right_embeddings))
398
526
 
399
- # return None if no output records were produced
527
+ # return empty DataRecordSet if no output records were produced
400
528
  if len(output_records) == 0:
401
- return None, num_inputs_processed
529
+ self.residual_embedding_cost = total_embedding_cost
530
+ return DataRecordSet([], []), num_inputs_processed
402
531
 
403
532
  return DataRecordSet(output_records, output_record_op_stats), num_inputs_processed