palimpzest 0.7.21__py3-none-any.whl → 0.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89)
  1. palimpzest/__init__.py +37 -6
  2. palimpzest/agents/__init__.py +0 -0
  3. palimpzest/agents/compute_agents.py +0 -0
  4. palimpzest/agents/search_agents.py +637 -0
  5. palimpzest/constants.py +343 -209
  6. palimpzest/core/data/context.py +393 -0
  7. palimpzest/core/data/context_manager.py +163 -0
  8. palimpzest/core/data/dataset.py +639 -0
  9. palimpzest/core/data/{datareaders.py → iter_dataset.py} +202 -126
  10. palimpzest/core/elements/groupbysig.py +16 -13
  11. palimpzest/core/elements/records.py +166 -75
  12. palimpzest/core/lib/schemas.py +152 -390
  13. palimpzest/core/{data/dataclasses.py → models.py} +306 -170
  14. palimpzest/policy.py +2 -27
  15. palimpzest/prompts/__init__.py +35 -5
  16. palimpzest/prompts/agent_prompts.py +357 -0
  17. palimpzest/prompts/context_search.py +9 -0
  18. palimpzest/prompts/convert_prompts.py +62 -6
  19. palimpzest/prompts/filter_prompts.py +51 -6
  20. palimpzest/prompts/join_prompts.py +163 -0
  21. palimpzest/prompts/moa_proposer_convert_prompts.py +6 -6
  22. palimpzest/prompts/prompt_factory.py +375 -47
  23. palimpzest/prompts/split_proposer_prompts.py +1 -1
  24. palimpzest/prompts/util_phrases.py +5 -0
  25. palimpzest/prompts/validator.py +239 -0
  26. palimpzest/query/execution/all_sample_execution_strategy.py +134 -76
  27. palimpzest/query/execution/execution_strategy.py +210 -317
  28. palimpzest/query/execution/execution_strategy_type.py +5 -7
  29. palimpzest/query/execution/mab_execution_strategy.py +249 -136
  30. palimpzest/query/execution/parallel_execution_strategy.py +153 -244
  31. palimpzest/query/execution/single_threaded_execution_strategy.py +107 -64
  32. palimpzest/query/generators/generators.py +160 -331
  33. palimpzest/query/operators/__init__.py +15 -5
  34. palimpzest/query/operators/aggregate.py +50 -33
  35. palimpzest/query/operators/compute.py +201 -0
  36. palimpzest/query/operators/convert.py +33 -19
  37. palimpzest/query/operators/critique_and_refine_convert.py +7 -5
  38. palimpzest/query/operators/distinct.py +62 -0
  39. palimpzest/query/operators/filter.py +26 -16
  40. palimpzest/query/operators/join.py +403 -0
  41. palimpzest/query/operators/limit.py +3 -3
  42. palimpzest/query/operators/logical.py +205 -77
  43. palimpzest/query/operators/mixture_of_agents_convert.py +10 -8
  44. palimpzest/query/operators/physical.py +27 -21
  45. palimpzest/query/operators/project.py +3 -3
  46. palimpzest/query/operators/rag_convert.py +7 -7
  47. palimpzest/query/operators/retrieve.py +9 -9
  48. palimpzest/query/operators/scan.py +81 -42
  49. palimpzest/query/operators/search.py +524 -0
  50. palimpzest/query/operators/split_convert.py +10 -8
  51. palimpzest/query/optimizer/__init__.py +7 -9
  52. palimpzest/query/optimizer/cost_model.py +108 -441
  53. palimpzest/query/optimizer/optimizer.py +123 -181
  54. palimpzest/query/optimizer/optimizer_strategy.py +66 -61
  55. palimpzest/query/optimizer/plan.py +352 -67
  56. palimpzest/query/optimizer/primitives.py +43 -19
  57. palimpzest/query/optimizer/rules.py +484 -646
  58. palimpzest/query/optimizer/tasks.py +127 -58
  59. palimpzest/query/processor/config.py +42 -76
  60. palimpzest/query/processor/query_processor.py +73 -18
  61. palimpzest/query/processor/query_processor_factory.py +46 -38
  62. palimpzest/schemabuilder/schema_builder.py +15 -28
  63. palimpzest/utils/model_helpers.py +32 -77
  64. palimpzest/utils/progress.py +114 -102
  65. palimpzest/validator/__init__.py +0 -0
  66. palimpzest/validator/validator.py +306 -0
  67. {palimpzest-0.7.21.dist-info → palimpzest-0.8.1.dist-info}/METADATA +6 -1
  68. palimpzest-0.8.1.dist-info/RECORD +95 -0
  69. palimpzest/core/lib/fields.py +0 -141
  70. palimpzest/prompts/code_synthesis_prompts.py +0 -28
  71. palimpzest/query/execution/random_sampling_execution_strategy.py +0 -240
  72. palimpzest/query/generators/api_client_factory.py +0 -30
  73. palimpzest/query/operators/code_synthesis_convert.py +0 -488
  74. palimpzest/query/operators/map.py +0 -130
  75. palimpzest/query/processor/nosentinel_processor.py +0 -33
  76. palimpzest/query/processor/processing_strategy_type.py +0 -28
  77. palimpzest/query/processor/sentinel_processor.py +0 -88
  78. palimpzest/query/processor/streaming_processor.py +0 -149
  79. palimpzest/sets.py +0 -405
  80. palimpzest/utils/datareader_helpers.py +0 -61
  81. palimpzest/utils/demo_helpers.py +0 -75
  82. palimpzest/utils/field_helpers.py +0 -69
  83. palimpzest/utils/generation_helpers.py +0 -69
  84. palimpzest/utils/sandbox.py +0 -183
  85. palimpzest-0.7.21.dist-info/RECORD +0 -95
  86. /palimpzest/core/{elements/index.py → data/index_dataset.py} +0 -0
  87. {palimpzest-0.7.21.dist-info → palimpzest-0.8.1.dist-info}/WHEEL +0 -0
  88. {palimpzest-0.7.21.dist-info → palimpzest-0.8.1.dist-info}/licenses/LICENSE +0 -0
  89. {palimpzest-0.7.21.dist-info → palimpzest-0.8.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,62 @@
1
+ from __future__ import annotations
2
+
3
+ from palimpzest.core.elements.records import DataRecord, DataRecordSet
4
+ from palimpzest.core.models import OperatorCostEstimates, RecordOpStats
5
+ from palimpzest.query.operators.physical import PhysicalOperator
6
+
7
+
8
+ class DistinctOp(PhysicalOperator):
9
+ def __init__(self, distinct_cols: list[str], distinct_seen: set | None = None, *args, **kwargs):
10
+ super().__init__(*args, **kwargs)
11
+ self.distinct_cols = distinct_cols
12
+ self._distinct_seen = set() if distinct_seen is None else distinct_seen
13
+
14
+ def __str__(self):
15
+ op = super().__str__()
16
+ op += f" Distinct Cols: {self.distinct_cols}\n"
17
+ return op
18
+
19
+ def get_id_params(self):
20
+ id_params = super().get_id_params()
21
+ return {"distinct_cols": self.distinct_cols, **id_params}
22
+
23
+ def get_op_params(self):
24
+ op_params = super().get_op_params()
25
+ return {"distinct_cols": self.distinct_cols, "distinct_seen": self._distinct_seen, **op_params}
26
+
27
+ def naive_cost_estimates(self, source_op_cost_estimates: OperatorCostEstimates) -> OperatorCostEstimates:
28
+ # assume applying the distinct operator takes negligible additional time (and no cost in USD)
29
+ return OperatorCostEstimates(
30
+ cardinality=source_op_cost_estimates.cardinality,
31
+ time_per_record=0,
32
+ cost_per_record=0,
33
+ quality=1.0,
34
+ )
35
+
36
+ def __call__(self, candidate: DataRecord) -> DataRecordSet:
37
+ # create new DataRecord
38
+ dr = DataRecord.from_parent(schema=candidate.schema, parent_record=candidate)
39
+
40
+ # output record only if it has not been seen before
41
+ record_str = dr.to_json_str(project_cols=self.distinct_cols, bytes_to_str=True, sorted=True)
42
+ record_hash = f"{hash(record_str)}"
43
+ dr.passed_operator = record_hash not in self._distinct_seen
44
+ if dr.passed_operator:
45
+ self._distinct_seen.add(record_hash)
46
+
47
+ # create RecordOpStats object
48
+ record_op_stats = RecordOpStats(
49
+ record_id=dr.id,
50
+ record_parent_ids=dr.parent_ids,
51
+ record_source_indices=dr.source_indices,
52
+ record_state=dr.to_dict(include_bytes=False),
53
+ full_op_id=self.get_full_op_id(),
54
+ logical_op_id=self.logical_op_id,
55
+ op_name=self.op_name(),
56
+ time_per_record=0.0,
57
+ cost_per_record=0.0,
58
+ passed_operator=dr.passed_operator,
59
+ op_details={k: str(v) for k, v in self.get_id_params().items()},
60
+ )
61
+
62
+ return DataRecordSet([dr], [record_op_stats])
@@ -4,6 +4,8 @@ import time
4
4
  from abc import ABC, abstractmethod
5
5
  from typing import Any
6
6
 
7
+ from pydantic.fields import FieldInfo
8
+
7
9
  from palimpzest.constants import (
8
10
  MODEL_CARDS,
9
11
  NAIVE_EST_FILTER_SELECTIVITY,
@@ -12,20 +14,19 @@ from palimpzest.constants import (
12
14
  Model,
13
15
  PromptStrategy,
14
16
  )
15
- from palimpzest.core.data.dataclasses import GenerationStats, OperatorCostEstimates, RecordOpStats
16
17
  from palimpzest.core.elements.filters import Filter
17
18
  from palimpzest.core.elements.records import DataRecord, DataRecordSet
18
- from palimpzest.core.lib.fields import BooleanField
19
- from palimpzest.query.generators.generators import generator_factory
19
+ from palimpzest.core.models import GenerationStats, OperatorCostEstimates, RecordOpStats
20
+ from palimpzest.query.generators.generators import Generator
20
21
  from palimpzest.query.operators.physical import PhysicalOperator
21
- from palimpzest.utils.model_helpers import get_vision_models
22
22
 
23
23
 
24
24
  class FilterOp(PhysicalOperator, ABC):
25
- def __init__(self, filter: Filter, *args, **kwargs):
25
+ def __init__(self, filter: Filter, desc: str | None = None, *args, **kwargs):
26
26
  super().__init__(*args, **kwargs)
27
- assert self.input_schema.get_desc() == self.output_schema.get_desc(), "Input and output schemas must match for FilterOp"
27
+ assert self.input_schema == self.output_schema, "Input and output schemas must match for FilterOp"
28
28
  self.filter_obj = filter
29
+ self.desc = desc
29
30
 
30
31
  def __str__(self):
31
32
  op = super().__str__()
@@ -34,11 +35,11 @@ class FilterOp(PhysicalOperator, ABC):
34
35
 
35
36
  def get_id_params(self):
36
37
  id_params = super().get_id_params()
37
- return {"filter": str(self.filter_obj), **id_params}
38
+ return {"filter": str(self.filter_obj), "desc": self.desc, **id_params}
38
39
 
39
40
  def get_op_params(self):
40
41
  op_params = super().get_op_params()
41
- return {"filter": self.filter_obj, **op_params}
42
+ return {"filter": self.filter_obj, "desc": self.desc, **op_params}
42
43
 
43
44
  @abstractmethod
44
45
  def is_image_filter(self) -> bool:
@@ -81,8 +82,8 @@ class FilterOp(PhysicalOperator, ABC):
81
82
  # create RecordOpStats object
82
83
  record_op_stats = RecordOpStats(
83
84
  record_id=dr.id,
84
- record_parent_id=dr.parent_id,
85
- record_source_idx=dr.source_idx,
85
+ record_parent_ids=dr.parent_ids,
86
+ record_source_indices=dr.source_indices,
86
87
  record_state=dr.to_dict(include_bytes=False),
87
88
  full_op_id=self.get_full_op_id(),
88
89
  logical_op_id=self.logical_op_id,
@@ -174,19 +175,22 @@ class LLMFilter(FilterOp):
174
175
  self,
175
176
  model: Model,
176
177
  prompt_strategy: PromptStrategy = PromptStrategy.COT_BOOL,
178
+ reasoning_effort: str | None = None,
177
179
  *args,
178
180
  **kwargs,
179
181
  ):
180
182
  super().__init__(*args, **kwargs)
181
183
  self.model = model
182
184
  self.prompt_strategy = prompt_strategy
183
- self.generator = generator_factory(model, prompt_strategy, Cardinality.ONE_TO_ONE, self.verbose)
185
+ self.reasoning_effort = reasoning_effort
186
+ self.generator = Generator(model, prompt_strategy, reasoning_effort, self.api_base, Cardinality.ONE_TO_ONE, self.desc, self.verbose)
184
187
 
185
188
  def get_id_params(self):
186
189
  id_params = super().get_id_params()
187
190
  id_params = {
188
191
  "model": self.model.value,
189
192
  "prompt_strategy": self.prompt_strategy.value,
193
+ "reasoning_effort": self.reasoning_effort,
190
194
  **id_params,
191
195
  }
192
196
 
@@ -197,6 +201,7 @@ class LLMFilter(FilterOp):
197
201
  op_params = {
198
202
  "model": self.model,
199
203
  "prompt_strategy": self.prompt_strategy,
204
+ "reasoning_effort": self.reasoning_effort,
200
205
  **op_params,
201
206
  }
202
207
 
@@ -206,7 +211,7 @@ class LLMFilter(FilterOp):
206
211
  return self.model.value
207
212
 
208
213
  def is_image_filter(self) -> bool:
209
- return self.model in get_vision_models()
214
+ return self.prompt_strategy is PromptStrategy.COT_BOOL_IMAGE
210
215
 
211
216
  def naive_cost_estimates(self, source_op_cost_estimates: OperatorCostEstimates):
212
217
  # estimate number of input tokens from source
@@ -225,8 +230,13 @@ class LLMFilter(FilterOp):
225
230
  )
226
231
 
227
232
  # get est. of conversion cost (in USD) per record from model card
233
+ usd_per_input_token = (
234
+ MODEL_CARDS[self.model.value]["usd_per_audio_input_token"]
235
+ if self.prompt_strategy.is_audio_prompt()
236
+ else MODEL_CARDS[self.model.value]["usd_per_input_token"]
237
+ )
228
238
  model_conversion_usd_per_record = (
229
- MODEL_CARDS[self.model.value]["usd_per_input_token"] * est_num_input_tokens
239
+ usd_per_input_token * est_num_input_tokens
230
240
  + MODEL_CARDS[self.model.value]["usd_per_output_token"] * est_num_output_tokens
231
241
  )
232
242
 
@@ -235,7 +245,7 @@ class LLMFilter(FilterOp):
235
245
  cardinality = selectivity * source_op_cost_estimates.cardinality
236
246
 
237
247
  # estimate quality of output based on the strength of the model being used
238
- quality = (MODEL_CARDS[self.model.value]["overall"] / 100.0) * source_op_cost_estimates.quality
248
+ quality = (MODEL_CARDS[self.model.value]["overall"] / 100.0)
239
249
 
240
250
  return OperatorCostEstimates(
241
251
  cardinality=cardinality,
@@ -251,8 +261,8 @@ class LLMFilter(FilterOp):
251
261
  # construct kwargs for generation
252
262
  gen_kwargs = {"project_cols": input_fields, "filter_condition": self.filter_obj.filter_condition}
253
263
 
254
- # generate output; NOTE: BooleanField is used to indicate the output type; thus, the desc is not needed
255
- fields = {"passed_operator": BooleanField(desc="")}
264
+ # generate output; NOTE: FieldInfo is used to indicate the output type; thus, the desc is not needed
265
+ fields = {"passed_operator": FieldInfo(annotation=bool, description="Whether the record passed the filter operation")}
256
266
  field_answers, _, generation_stats, _ = self.generator(candidate, fields, **gen_kwargs)
257
267
 
258
268
  return field_answers, generation_stats
@@ -0,0 +1,403 @@
1
+ from __future__ import annotations
2
+
3
+ import time
4
+ from abc import ABC, abstractmethod
5
+ from concurrent.futures import ThreadPoolExecutor, as_completed
6
+
7
+ from pydantic.fields import FieldInfo
8
+
9
+ from palimpzest.constants import (
10
+ MODEL_CARDS,
11
+ NAIVE_EST_JOIN_SELECTIVITY,
12
+ NAIVE_EST_NUM_INPUT_TOKENS,
13
+ Cardinality,
14
+ Model,
15
+ PromptStrategy,
16
+ )
17
+ from palimpzest.core.elements.records import DataRecord, DataRecordSet
18
+ from palimpzest.core.models import OperatorCostEstimates, RecordOpStats
19
+ from palimpzest.query.generators.generators import Generator
20
+ from palimpzest.query.operators.physical import PhysicalOperator
21
+
22
+
23
class JoinOp(PhysicalOperator, ABC):
    """Abstract base for physical join operators.

    Concrete subclasses evaluate a natural-language join `condition` over pairs
    of records; input and output schemas must be identical.
    """

    def __init__(self, condition: str, desc: str | None = None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert self.input_schema == self.output_schema, "Input and output schemas must match for JoinOp"
        self.condition = condition
        self.desc = desc

    def __str__(self):
        return super().__str__() + f" Condition: {self.condition}\n"

    def get_id_params(self):
        return {"condition": self.condition, "desc": self.desc, **super().get_id_params()}

    def get_op_params(self):
        return {"condition": self.condition, "desc": self.desc, **super().get_op_params()}

    @abstractmethod
    def is_image_join(self) -> bool:
        """Return True if the join operation processes image(s), False otherwise."""
        ...

    @abstractmethod
    def naive_cost_estimates(self, left_source_op_cost_estimates: OperatorCostEstimates, right_source_op_cost_estimates: OperatorCostEstimates) -> OperatorCostEstimates:
        """Return naive cost estimates for joining the two given input streams."""
        ...
51
+
52
+
53
class BlockingNestedLoopsJoin(JoinOp):
    """Nested-loops join evaluating every (left, right) candidate pair with an LLM call.

    All pairs for the given input batches are fanned out to a thread pool and the
    call blocks until every pair has been evaluated.
    """

    def __init__(
        self,
        model: Model,
        prompt_strategy: PromptStrategy = PromptStrategy.COT_JOIN,
        join_parallelism: int = 64,
        reasoning_effort: str | None = None,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.model = model
        self.prompt_strategy = prompt_strategy
        self.join_parallelism = join_parallelism
        self.reasoning_effort = reasoning_effort
        self.generator = Generator(model, prompt_strategy, reasoning_effort, self.api_base, Cardinality.ONE_TO_ONE, self.desc, self.verbose)
        # running count of completed pairwise comparisons (drives progress output)
        self.join_idx = 0

    def get_id_params(self):
        return {
            "model": self.model.value,
            "prompt_strategy": self.prompt_strategy.value,
            "join_parallelism": self.join_parallelism,
            "reasoning_effort": self.reasoning_effort,
            **super().get_id_params(),
        }

    def get_op_params(self):
        return {
            "model": self.model,
            "prompt_strategy": self.prompt_strategy,
            "join_parallelism": self.join_parallelism,
            "reasoning_effort": self.reasoning_effort,
            **super().get_op_params(),
        }

    def get_model_name(self):
        return self.model.value

    def is_image_join(self) -> bool:
        return self.prompt_strategy is PromptStrategy.COT_JOIN_IMAGE

    def naive_cost_estimates(self, left_source_op_cost_estimates: OperatorCostEstimates, right_source_op_cost_estimates: OperatorCostEstimates):
        # two records contribute input tokens to each join prompt
        est_num_input_tokens = 2 * NAIVE_EST_NUM_INPUT_TOKENS
        if self.is_image_join():
            est_num_input_tokens = 2 * 765 / 10  # 1024x1024 image is 765 tokens

        # NOTE: the output often generates an entire reasoning sentence, thus the true value may be higher;
        # the join's LLM call should only output TRUE or FALSE, so we expect ~1.25 output tokens
        est_num_output_tokens = 1.25

        # est. of time per record from the model card
        model_conversion_time_per_record = (
            MODEL_CARDS[self.model.value]["seconds_per_output_token"] * est_num_output_tokens
        )

        # est. of cost (in USD) per record from the model card; audio prompts are priced separately
        usd_per_input_token = (
            MODEL_CARDS[self.model.value]["usd_per_audio_input_token"]
            if self.prompt_strategy.is_audio_prompt()
            else MODEL_CARDS[self.model.value]["usd_per_input_token"]
        )
        model_conversion_usd_per_record = (
            usd_per_input_token * est_num_input_tokens
            + MODEL_CARDS[self.model.value]["usd_per_output_token"] * est_num_output_tokens
        )

        # output cardinality: constant join selectivity applied to the full cross product
        cardinality = NAIVE_EST_JOIN_SELECTIVITY * (
            left_source_op_cost_estimates.cardinality * right_source_op_cost_estimates.cardinality
        )

        # estimated quality scales with the strength of the model being used
        quality = MODEL_CARDS[self.model.value]["overall"] / 100.0

        return OperatorCostEstimates(
            cardinality=cardinality,
            time_per_record=model_conversion_time_per_record,
            cost_per_record=model_conversion_usd_per_record,
            quality=quality,
        )

    def _process_join_candidate_pair(
        self,
        left_candidate: DataRecord,
        right_candidate: DataRecord,
        gen_kwargs: dict,
    ) -> tuple[list[DataRecord], list[RecordOpStats]]:
        """Evaluate the join condition for one (left, right) pair; return the joined record and its stats."""
        t0 = time.time()

        # generate output; NOTE: FieldInfo is used to indicate the output type; thus, the desc is not needed
        fields = {"passed_operator": FieldInfo(annotation=bool, description="Whether the records satisfy the join condition")}
        field_answers, _, generation_stats, _ = self.generator(left_candidate, fields, right_candidate=right_candidate, **gen_kwargs)

        # did this pair satisfy the join condition?
        passed_operator = field_answers["passed_operator"]

        # build the joined output record
        join_dr = DataRecord.from_join_parents(self.output_schema, left_candidate, right_candidate)
        join_dr.passed_operator = passed_operator

        # per-record execution statistics
        record_op_stats = RecordOpStats(
            record_id=join_dr.id,
            record_parent_ids=join_dr.parent_ids,
            record_source_indices=join_dr.source_indices,
            record_state=join_dr.to_dict(include_bytes=False),
            full_op_id=self.get_full_op_id(),
            logical_op_id=self.logical_op_id,
            op_name=self.op_name(),
            time_per_record=time.time() - t0,
            cost_per_record=generation_stats.cost_per_record,
            model_name=self.get_model_name(),
            join_condition=self.condition,
            total_input_tokens=generation_stats.total_input_tokens,
            total_output_tokens=generation_stats.total_output_tokens,
            total_input_cost=generation_stats.total_input_cost,
            total_output_cost=generation_stats.total_output_cost,
            llm_call_duration_secs=generation_stats.llm_call_duration_secs,
            fn_call_duration_secs=generation_stats.fn_call_duration_secs,
            total_llm_calls=generation_stats.total_llm_calls,
            total_embedding_llm_calls=generation_stats.total_embedding_llm_calls,
            answer=field_answers,
            passed_operator=passed_operator,
            image_operation=self.is_image_join(),
            op_details={k: str(v) for k, v in self.get_id_params().items()},
        )

        return [join_dr], [record_op_stats]

    def __call__(self, left_candidates: list[DataRecord], right_candidates: list[DataRecord]) -> tuple[DataRecordSet, int]:
        # fields from both sides of the join are projected into the prompt
        gen_kwargs = {"project_cols": self.get_input_fields(), "join_condition": self.condition}

        total_join_candidates = len(left_candidates) * len(right_candidates)
        output_records: list[DataRecord] = []
        output_record_op_stats: list[RecordOpStats] = []
        with ThreadPoolExecutor(max_workers=self.join_parallelism) as executor:
            # fan out one task per pair in the cross product
            futures = [
                executor.submit(self._process_join_candidate_pair, left_candidate, right_candidate, gen_kwargs)
                for left_candidate in left_candidates
                for right_candidate in right_candidates
            ]
            num_inputs_processed = len(futures)

            for future in as_completed(futures):
                self.join_idx += 1
                pair_records, pair_stats = future.result()
                output_records.extend(pair_records)
                output_record_op_stats.extend(pair_stats)
                print(f"{self.join_idx}/{total_join_candidates} JOINED")

        return DataRecordSet(output_records, output_record_op_stats), num_inputs_processed
215
+
216
+
217
class NestedLoopsJoin(JoinOp):
    """Incremental nested-loops join evaluating candidate pairs with an LLM call.

    Unlike the blocking variant, this operator remembers every input record it has
    seen so that newly arriving candidates on either side are joined against all
    previously stored records from the opposite side.
    """

    def __init__(
        self,
        model: Model,
        prompt_strategy: PromptStrategy = PromptStrategy.COT_JOIN,
        join_parallelism: int = 64,
        reasoning_effort: str | None = None,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.model = model
        self.prompt_strategy = prompt_strategy
        self.join_parallelism = join_parallelism
        self.reasoning_effort = reasoning_effort
        self.generator = Generator(model, prompt_strategy, reasoning_effort, self.api_base, Cardinality.ONE_TO_ONE, self.desc, self.verbose)
        # running count of completed pairwise comparisons (drives progress output)
        self.join_idx = 0

        # all input records seen so far, kept to join against future arrivals
        self._left_input_records: list[DataRecord] = []
        self._right_input_records: list[DataRecord] = []

    def get_id_params(self):
        return {
            "model": self.model.value,
            "prompt_strategy": self.prompt_strategy.value,
            "join_parallelism": self.join_parallelism,
            "reasoning_effort": self.reasoning_effort,
            **super().get_id_params(),
        }

    def get_op_params(self):
        return {
            "model": self.model,
            "prompt_strategy": self.prompt_strategy,
            "join_parallelism": self.join_parallelism,
            "reasoning_effort": self.reasoning_effort,
            **super().get_op_params(),
        }

    def get_model_name(self):
        return self.model.value

    def is_image_join(self) -> bool:
        return self.prompt_strategy is PromptStrategy.COT_JOIN_IMAGE

    def naive_cost_estimates(self, left_source_op_cost_estimates: OperatorCostEstimates, right_source_op_cost_estimates: OperatorCostEstimates):
        # two records contribute input tokens to each join prompt
        est_num_input_tokens = 2 * NAIVE_EST_NUM_INPUT_TOKENS
        if self.is_image_join():
            est_num_input_tokens = 2 * 765 / 10  # 1024x1024 image is 765 tokens

        # NOTE: the output often generates an entire reasoning sentence, thus the true value may be higher;
        # the join's LLM call should only output TRUE or FALSE, so we expect ~1.25 output tokens
        est_num_output_tokens = 1.25

        # est. of time per record from the model card
        model_conversion_time_per_record = (
            MODEL_CARDS[self.model.value]["seconds_per_output_token"] * est_num_output_tokens
        )

        # est. of cost (in USD) per record from the model card; audio prompts are priced separately
        usd_per_input_token = (
            MODEL_CARDS[self.model.value]["usd_per_audio_input_token"]
            if self.prompt_strategy.is_audio_prompt()
            else MODEL_CARDS[self.model.value]["usd_per_input_token"]
        )
        model_conversion_usd_per_record = (
            usd_per_input_token * est_num_input_tokens
            + MODEL_CARDS[self.model.value]["usd_per_output_token"] * est_num_output_tokens
        )

        # output cardinality: constant join selectivity applied to the full cross product
        cardinality = NAIVE_EST_JOIN_SELECTIVITY * (
            left_source_op_cost_estimates.cardinality * right_source_op_cost_estimates.cardinality
        )

        # estimated quality scales with the strength of the model being used
        quality = MODEL_CARDS[self.model.value]["overall"] / 100.0

        return OperatorCostEstimates(
            cardinality=cardinality,
            time_per_record=model_conversion_time_per_record,
            cost_per_record=model_conversion_usd_per_record,
            quality=quality,
        )

    def _process_join_candidate_pair(
        self,
        left_candidate: DataRecord,
        right_candidate: DataRecord,
        gen_kwargs: dict,
    ) -> tuple[list[DataRecord], list[RecordOpStats]]:
        """Evaluate the join condition for one (left, right) pair; return the joined record and its stats."""
        t0 = time.time()

        # generate output; NOTE: FieldInfo is used to indicate the output type; thus, the desc is not needed
        fields = {"passed_operator": FieldInfo(annotation=bool, description="Whether the records satisfy the join condition")}
        field_answers, _, generation_stats, _ = self.generator(left_candidate, fields, right_candidate=right_candidate, **gen_kwargs)

        # did this pair satisfy the join condition?
        passed_operator = field_answers["passed_operator"]

        # build the joined output record
        join_dr = DataRecord.from_join_parents(self.output_schema, left_candidate, right_candidate)
        join_dr.passed_operator = passed_operator

        # per-record execution statistics
        record_op_stats = RecordOpStats(
            record_id=join_dr.id,
            record_parent_ids=join_dr.parent_ids,
            record_source_indices=join_dr.source_indices,
            record_state=join_dr.to_dict(include_bytes=False),
            full_op_id=self.get_full_op_id(),
            logical_op_id=self.logical_op_id,
            op_name=self.op_name(),
            time_per_record=time.time() - t0,
            cost_per_record=generation_stats.cost_per_record,
            model_name=self.get_model_name(),
            join_condition=self.condition,
            total_input_tokens=generation_stats.total_input_tokens,
            total_output_tokens=generation_stats.total_output_tokens,
            total_input_cost=generation_stats.total_input_cost,
            total_output_cost=generation_stats.total_output_cost,
            llm_call_duration_secs=generation_stats.llm_call_duration_secs,
            fn_call_duration_secs=generation_stats.fn_call_duration_secs,
            total_llm_calls=generation_stats.total_llm_calls,
            total_embedding_llm_calls=generation_stats.total_embedding_llm_calls,
            answer=field_answers,
            passed_operator=passed_operator,
            image_operation=self.is_image_join(),
            op_details={k: str(v) for k, v in self.get_id_params().items()},
        )

        return [join_dr], [record_op_stats]

    def __call__(self, left_candidates: list[DataRecord], right_candidates: list[DataRecord]) -> tuple[DataRecordSet | None, int]:
        # fields from both sides of the join are projected into the prompt
        gen_kwargs = {"project_cols": self.get_input_fields(), "join_condition": self.condition}

        # every new candidate must be paired with both the other side's new batch and
        # the records stored from previous calls
        pairings = [
            (left_candidate, right_candidate)
            for left_candidate in left_candidates
            for right_candidate in right_candidates
        ]
        pairings += [
            (left_candidate, right_candidate)
            for left_candidate in left_candidates
            for right_candidate in self._right_input_records
        ]
        pairings += [
            (left_candidate, right_candidate)
            for left_candidate in self._left_input_records
            for right_candidate in right_candidates
        ]
        num_inputs_processed = len(pairings)

        output_records: list[DataRecord] = []
        output_record_op_stats: list[RecordOpStats] = []
        with ThreadPoolExecutor(max_workers=self.join_parallelism) as executor:
            futures = [
                executor.submit(self._process_join_candidate_pair, left_candidate, right_candidate, gen_kwargs)
                for left_candidate, right_candidate in pairings
            ]

            # collect results as they complete
            for future in as_completed(futures):
                self.join_idx += 1
                pair_records, pair_stats = future.result()
                output_records.extend(pair_records)
                output_record_op_stats.extend(pair_stats)
                print(f"{self.join_idx} JOINED")

        # remember the new inputs so later arrivals can join against them
        self._left_input_records.extend(left_candidates)
        self._right_input_records.extend(right_candidates)

        # return None if no output records were produced
        if not output_records:
            return None, num_inputs_processed

        return DataRecordSet(output_records, output_record_op_stats), num_inputs_processed
@@ -1,7 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
- from palimpzest.core.data.dataclasses import OperatorCostEstimates, RecordOpStats
4
3
  from palimpzest.core.elements.records import DataRecord, DataRecordSet
4
+ from palimpzest.core.models import OperatorCostEstimates, RecordOpStats
5
5
  from palimpzest.query.operators.physical import PhysicalOperator
6
6
 
7
7
 
@@ -41,8 +41,8 @@ class LimitScanOp(PhysicalOperator):
41
41
  # create RecordOpStats object
42
42
  record_op_stats = RecordOpStats(
43
43
  record_id=dr.id,
44
- record_parent_id=dr.parent_id,
45
- record_source_idx=dr.source_idx,
44
+ record_parent_ids=dr.parent_ids,
45
+ record_source_indices=dr.source_indices,
46
46
  record_state=dr.to_dict(include_bytes=False),
47
47
  full_op_id=self.get_full_op_id(),
48
48
  logical_op_id=self.logical_op_id,