palimpzest 0.7.21__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as they appear in the public registry to which they were published. It is provided for informational purposes only.
Files changed (87)
  1. palimpzest/__init__.py +37 -6
  2. palimpzest/agents/__init__.py +0 -0
  3. palimpzest/agents/compute_agents.py +0 -0
  4. palimpzest/agents/search_agents.py +637 -0
  5. palimpzest/constants.py +259 -197
  6. palimpzest/core/data/context.py +393 -0
  7. palimpzest/core/data/context_manager.py +163 -0
  8. palimpzest/core/data/dataset.py +634 -0
  9. palimpzest/core/data/{datareaders.py → iter_dataset.py} +202 -126
  10. palimpzest/core/elements/groupbysig.py +16 -13
  11. palimpzest/core/elements/records.py +166 -75
  12. palimpzest/core/lib/schemas.py +152 -390
  13. palimpzest/core/{data/dataclasses.py → models.py} +306 -170
  14. palimpzest/policy.py +2 -27
  15. palimpzest/prompts/__init__.py +35 -5
  16. palimpzest/prompts/agent_prompts.py +357 -0
  17. palimpzest/prompts/context_search.py +9 -0
  18. palimpzest/prompts/convert_prompts.py +61 -5
  19. palimpzest/prompts/filter_prompts.py +50 -5
  20. palimpzest/prompts/join_prompts.py +163 -0
  21. palimpzest/prompts/moa_proposer_convert_prompts.py +5 -5
  22. palimpzest/prompts/prompt_factory.py +358 -46
  23. palimpzest/prompts/validator.py +239 -0
  24. palimpzest/query/execution/all_sample_execution_strategy.py +134 -76
  25. palimpzest/query/execution/execution_strategy.py +210 -317
  26. palimpzest/query/execution/execution_strategy_type.py +5 -7
  27. palimpzest/query/execution/mab_execution_strategy.py +249 -136
  28. palimpzest/query/execution/parallel_execution_strategy.py +153 -244
  29. palimpzest/query/execution/single_threaded_execution_strategy.py +107 -64
  30. palimpzest/query/generators/generators.py +157 -330
  31. palimpzest/query/operators/__init__.py +15 -5
  32. palimpzest/query/operators/aggregate.py +50 -33
  33. palimpzest/query/operators/compute.py +201 -0
  34. palimpzest/query/operators/convert.py +27 -21
  35. palimpzest/query/operators/critique_and_refine_convert.py +7 -5
  36. palimpzest/query/operators/distinct.py +62 -0
  37. palimpzest/query/operators/filter.py +22 -13
  38. palimpzest/query/operators/join.py +402 -0
  39. palimpzest/query/operators/limit.py +3 -3
  40. palimpzest/query/operators/logical.py +198 -80
  41. palimpzest/query/operators/mixture_of_agents_convert.py +10 -8
  42. palimpzest/query/operators/physical.py +27 -21
  43. palimpzest/query/operators/project.py +3 -3
  44. palimpzest/query/operators/rag_convert.py +7 -7
  45. palimpzest/query/operators/retrieve.py +9 -9
  46. palimpzest/query/operators/scan.py +81 -42
  47. palimpzest/query/operators/search.py +524 -0
  48. palimpzest/query/operators/split_convert.py +10 -8
  49. palimpzest/query/optimizer/__init__.py +7 -9
  50. palimpzest/query/optimizer/cost_model.py +108 -441
  51. palimpzest/query/optimizer/optimizer.py +123 -181
  52. palimpzest/query/optimizer/optimizer_strategy.py +66 -61
  53. palimpzest/query/optimizer/plan.py +352 -67
  54. palimpzest/query/optimizer/primitives.py +43 -19
  55. palimpzest/query/optimizer/rules.py +484 -646
  56. palimpzest/query/optimizer/tasks.py +127 -58
  57. palimpzest/query/processor/config.py +41 -76
  58. palimpzest/query/processor/query_processor.py +73 -18
  59. palimpzest/query/processor/query_processor_factory.py +46 -38
  60. palimpzest/schemabuilder/schema_builder.py +15 -28
  61. palimpzest/utils/model_helpers.py +27 -77
  62. palimpzest/utils/progress.py +114 -102
  63. palimpzest/validator/__init__.py +0 -0
  64. palimpzest/validator/validator.py +306 -0
  65. {palimpzest-0.7.21.dist-info → palimpzest-0.8.0.dist-info}/METADATA +6 -1
  66. palimpzest-0.8.0.dist-info/RECORD +95 -0
  67. palimpzest/core/lib/fields.py +0 -141
  68. palimpzest/prompts/code_synthesis_prompts.py +0 -28
  69. palimpzest/query/execution/random_sampling_execution_strategy.py +0 -240
  70. palimpzest/query/generators/api_client_factory.py +0 -30
  71. palimpzest/query/operators/code_synthesis_convert.py +0 -488
  72. palimpzest/query/operators/map.py +0 -130
  73. palimpzest/query/processor/nosentinel_processor.py +0 -33
  74. palimpzest/query/processor/processing_strategy_type.py +0 -28
  75. palimpzest/query/processor/sentinel_processor.py +0 -88
  76. palimpzest/query/processor/streaming_processor.py +0 -149
  77. palimpzest/sets.py +0 -405
  78. palimpzest/utils/datareader_helpers.py +0 -61
  79. palimpzest/utils/demo_helpers.py +0 -75
  80. palimpzest/utils/field_helpers.py +0 -69
  81. palimpzest/utils/generation_helpers.py +0 -69
  82. palimpzest/utils/sandbox.py +0 -183
  83. palimpzest-0.7.21.dist-info/RECORD +0 -95
  84. /palimpzest/core/{elements/index.py → data/index_dataset.py} +0 -0
  85. {palimpzest-0.7.21.dist-info → palimpzest-0.8.0.dist-info}/WHEEL +0 -0
  86. {palimpzest-0.7.21.dist-info → palimpzest-0.8.0.dist-info}/licenses/LICENSE +0 -0
  87. {palimpzest-0.7.21.dist-info → palimpzest-0.8.0.dist-info}/top_level.txt +0 -0
palimpzest/query/operators/distinct.py
@@ -0,0 +1,62 @@
+ from __future__ import annotations
+
+ from palimpzest.core.elements.records import DataRecord, DataRecordSet
+ from palimpzest.core.models import OperatorCostEstimates, RecordOpStats
+ from palimpzest.query.operators.physical import PhysicalOperator
+
+
+ class DistinctOp(PhysicalOperator):
+     def __init__(self, distinct_cols: list[str], distinct_seen: set | None = None, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         self.distinct_cols = distinct_cols
+         self._distinct_seen = set() if distinct_seen is None else distinct_seen
+
+     def __str__(self):
+         op = super().__str__()
+         op += f"    Distinct Cols: {self.distinct_cols}\n"
+         return op
+
+     def get_id_params(self):
+         id_params = super().get_id_params()
+         return {"distinct_cols": self.distinct_cols, **id_params}
+
+     def get_op_params(self):
+         op_params = super().get_op_params()
+         return {"distinct_cols": self.distinct_cols, "distinct_seen": self._distinct_seen, **op_params}
+
+     def naive_cost_estimates(self, source_op_cost_estimates: OperatorCostEstimates) -> OperatorCostEstimates:
+         # assume applying the distinct operator takes negligible additional time (and no cost in USD)
+         return OperatorCostEstimates(
+             cardinality=source_op_cost_estimates.cardinality,
+             time_per_record=0,
+             cost_per_record=0,
+             quality=1.0,
+         )
+
+     def __call__(self, candidate: DataRecord) -> DataRecordSet:
+         # create new DataRecord
+         dr = DataRecord.from_parent(schema=candidate.schema, parent_record=candidate)
+
+         # output record only if it has not been seen before
+         record_str = dr.to_json_str(project_cols=self.distinct_cols, bytes_to_str=True, sorted=True)
+         record_hash = f"{hash(record_str)}"
+         dr.passed_operator = record_hash not in self._distinct_seen
+         if dr.passed_operator:
+             self._distinct_seen.add(record_hash)
+
+         # create RecordOpStats object
+         record_op_stats = RecordOpStats(
+             record_id=dr.id,
+             record_parent_ids=dr.parent_ids,
+             record_source_indices=dr.source_indices,
+             record_state=dr.to_dict(include_bytes=False),
+             full_op_id=self.get_full_op_id(),
+             logical_op_id=self.logical_op_id,
+             op_name=self.op_name(),
+             time_per_record=0.0,
+             cost_per_record=0.0,
+             passed_operator=dr.passed_operator,
+             op_details={k: str(v) for k, v in self.get_id_params().items()},
+         )
+
+         return DataRecordSet([dr], [record_op_stats])
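
The new DistinctOp deduplicates records by hashing a canonical JSON projection of the chosen columns and only letting the first occurrence of each hash pass the operator. A minimal, self-contained sketch of the same idea in plain Python follows; the names rows and distinct_cols are illustrative and not part of the palimpzest API, and the sketch keys on the JSON string directly rather than its hash:

import json

def distinct(rows: list[dict], distinct_cols: list[str]) -> list[dict]:
    # mirror DistinctOp: project the distinct columns, serialize with sorted keys,
    # and keep a record only if its serialized projection has not been seen before
    seen: set[str] = set()
    out = []
    for row in rows:
        key = json.dumps({col: row.get(col) for col in distinct_cols}, sort_keys=True)
        if key not in seen:
            seen.add(key)
            out.append(row)
    return out

rows = [{"a": 1, "b": "x"}, {"a": 1, "b": "y"}, {"a": 1, "b": "x"}]
print(distinct(rows, ["a", "b"]))  # keeps the first two rows; the duplicate is dropped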
palimpzest/query/operators/filter.py
@@ -4,6 +4,8 @@ import time
  from abc import ABC, abstractmethod
  from typing import Any
 
+ from pydantic.fields import FieldInfo
+
  from palimpzest.constants import (
      MODEL_CARDS,
      NAIVE_EST_FILTER_SELECTIVITY,
@@ -12,19 +14,17 @@ from palimpzest.constants import (
      Model,
      PromptStrategy,
  )
- from palimpzest.core.data.dataclasses import GenerationStats, OperatorCostEstimates, RecordOpStats
  from palimpzest.core.elements.filters import Filter
  from palimpzest.core.elements.records import DataRecord, DataRecordSet
- from palimpzest.core.lib.fields import BooleanField
- from palimpzest.query.generators.generators import generator_factory
+ from palimpzest.core.models import GenerationStats, OperatorCostEstimates, RecordOpStats
+ from palimpzest.query.generators.generators import Generator
  from palimpzest.query.operators.physical import PhysicalOperator
- from palimpzest.utils.model_helpers import get_vision_models
 
 
  class FilterOp(PhysicalOperator, ABC):
      def __init__(self, filter: Filter, *args, **kwargs):
          super().__init__(*args, **kwargs)
-         assert self.input_schema.get_desc() == self.output_schema.get_desc(), "Input and output schemas must match for FilterOp"
+         assert self.input_schema == self.output_schema, "Input and output schemas must match for FilterOp"
          self.filter_obj = filter
 
      def __str__(self):
@@ -81,8 +81,8 @@ class FilterOp(PhysicalOperator, ABC):
          # create RecordOpStats object
          record_op_stats = RecordOpStats(
              record_id=dr.id,
-             record_parent_id=dr.parent_id,
-             record_source_idx=dr.source_idx,
+             record_parent_ids=dr.parent_ids,
+             record_source_indices=dr.source_indices,
              record_state=dr.to_dict(include_bytes=False),
              full_op_id=self.get_full_op_id(),
              logical_op_id=self.logical_op_id,
@@ -174,19 +174,22 @@ class LLMFilter(FilterOp):
          self,
          model: Model,
          prompt_strategy: PromptStrategy = PromptStrategy.COT_BOOL,
+         reasoning_effort: str | None = None,
          *args,
          **kwargs,
      ):
          super().__init__(*args, **kwargs)
          self.model = model
          self.prompt_strategy = prompt_strategy
-         self.generator = generator_factory(model, prompt_strategy, Cardinality.ONE_TO_ONE, self.verbose)
+         self.reasoning_effort = reasoning_effort
+         self.generator = Generator(model, prompt_strategy, reasoning_effort, self.api_base, Cardinality.ONE_TO_ONE, self.verbose)
 
      def get_id_params(self):
          id_params = super().get_id_params()
          id_params = {
              "model": self.model.value,
              "prompt_strategy": self.prompt_strategy.value,
+             "reasoning_effort": self.reasoning_effort,
              **id_params,
          }
 
@@ -197,6 +200,7 @@
          op_params = {
              "model": self.model,
              "prompt_strategy": self.prompt_strategy,
+             "reasoning_effort": self.reasoning_effort,
              **op_params,
          }
 
@@ -206,7 +210,7 @@
          return self.model.value
 
      def is_image_filter(self) -> bool:
-         return self.model in get_vision_models()
+         return self.prompt_strategy is PromptStrategy.COT_BOOL_IMAGE
 
      def naive_cost_estimates(self, source_op_cost_estimates: OperatorCostEstimates):
          # estimate number of input tokens from source
@@ -225,8 +229,13 @@
          )
 
          # get est. of conversion cost (in USD) per record from model card
+         usd_per_input_token = (
+             MODEL_CARDS[self.model.value]["usd_per_audio_input_token"]
+             if self.prompt_strategy.is_audio_prompt()
+             else MODEL_CARDS[self.model.value]["usd_per_input_token"]
+         )
          model_conversion_usd_per_record = (
-             MODEL_CARDS[self.model.value]["usd_per_input_token"] * est_num_input_tokens
+             usd_per_input_token * est_num_input_tokens
              + MODEL_CARDS[self.model.value]["usd_per_output_token"] * est_num_output_tokens
          )
 
@@ -235,7 +244,7 @@
          cardinality = selectivity * source_op_cost_estimates.cardinality
 
          # estimate quality of output based on the strength of the model being used
-         quality = (MODEL_CARDS[self.model.value]["overall"] / 100.0) * source_op_cost_estimates.quality
+         quality = (MODEL_CARDS[self.model.value]["overall"] / 100.0)
 
          return OperatorCostEstimates(
              cardinality=cardinality,
@@ -251,8 +260,8 @@
          # construct kwargs for generation
          gen_kwargs = {"project_cols": input_fields, "filter_condition": self.filter_obj.filter_condition}
 
-         # generate output; NOTE: BooleanField is used to indicate the output type; thus, the desc is not needed
-         fields = {"passed_operator": BooleanField(desc="")}
+         # generate output; NOTE: FieldInfo is used to indicate the output type; thus, the desc is not needed
+         fields = {"passed_operator": FieldInfo(annotation=bool, description="Whether the record passed the filter operation")}
          field_answers, _, generation_stats, _ = self.generator(candidate, fields, **gen_kwargs)
 
          return field_answers, generation_stats
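
The updated naive cost estimate multiplies per-token prices from MODEL_CARDS by estimated token counts, using the audio input price when the prompt strategy is an audio prompt. A back-of-the-envelope sketch of the arithmetic is below; the prices and the input-token count are made-up stand-ins, not the actual MODEL_CARDS values or palimpzest constants:

# illustrative numbers only; real values come from palimpzest.constants.MODEL_CARDS
usd_per_input_token = 0.15 / 1_000_000   # hypothetical $0.15 per 1M input tokens
usd_per_output_token = 0.60 / 1_000_000  # hypothetical $0.60 per 1M output tokens
seconds_per_output_token = 0.005         # hypothetical decode speed

est_num_input_tokens = 1000   # stand-in for NAIVE_EST_NUM_INPUT_TOKENS
est_num_output_tokens = 1.25  # the filter call should emit roughly TRUE or FALSE

cost_per_record = (
    usd_per_input_token * est_num_input_tokens
    + usd_per_output_token * est_num_output_tokens
)
time_per_record = seconds_per_output_token * est_num_output_tokens
print(f"~${cost_per_record:.8f} and ~{time_per_record:.4f}s per filtered record")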
palimpzest/query/operators/join.py
@@ -0,0 +1,402 @@
+ from __future__ import annotations
+
+ import time
+ from abc import ABC, abstractmethod
+ from concurrent.futures import ThreadPoolExecutor, as_completed
+
+ from pydantic.fields import FieldInfo
+
+ from palimpzest.constants import (
+     MODEL_CARDS,
+     NAIVE_EST_JOIN_SELECTIVITY,
+     NAIVE_EST_NUM_INPUT_TOKENS,
+     Cardinality,
+     Model,
+     PromptStrategy,
+ )
+ from palimpzest.core.elements.records import DataRecord, DataRecordSet
+ from palimpzest.core.models import OperatorCostEstimates, RecordOpStats
+ from palimpzest.query.generators.generators import Generator
+ from palimpzest.query.operators.physical import PhysicalOperator
+
+
+ class JoinOp(PhysicalOperator, ABC):
+     def __init__(self, condition: str, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         assert self.input_schema == self.output_schema, "Input and output schemas must match for JoinOp"
+         self.condition = condition
+
+     def __str__(self):
+         op = super().__str__()
+         op += f"    Condition: {self.condition}\n"
+         return op
+
+     def get_id_params(self):
+         id_params = super().get_id_params()
+         return {"condition": self.condition, **id_params}
+
+     def get_op_params(self):
+         op_params = super().get_op_params()
+         return {"condition": self.condition, **op_params}
+
+     @abstractmethod
+     def is_image_join(self) -> bool:
+         """Return True if the join operation processes image(s), False otherwise."""
+         pass
+
+     @abstractmethod
+     def naive_cost_estimates(self, left_source_op_cost_estimates: OperatorCostEstimates, right_source_op_cost_estimates: OperatorCostEstimates) -> OperatorCostEstimates:
+         pass
+
+
+ class BlockingNestedLoopsJoin(JoinOp):
+     def __init__(
+         self,
+         model: Model,
+         prompt_strategy: PromptStrategy = PromptStrategy.COT_JOIN,
+         join_parallelism: int = 64,
+         reasoning_effort: str | None = None,
+         *args,
+         **kwargs,
+     ):
+         super().__init__(*args, **kwargs)
+         self.model = model
+         self.prompt_strategy = prompt_strategy
+         self.join_parallelism = join_parallelism
+         self.reasoning_effort = reasoning_effort
+         self.generator = Generator(model, prompt_strategy, reasoning_effort, self.api_base, Cardinality.ONE_TO_ONE, self.verbose)
+         self.join_idx = 0
+
+     def get_id_params(self):
+         id_params = super().get_id_params()
+         id_params = {
+             "model": self.model.value,
+             "prompt_strategy": self.prompt_strategy.value,
+             "join_parallelism": self.join_parallelism,
+             "reasoning_effort": self.reasoning_effort,
+             **id_params,
+         }
+
+         return id_params
+
+     def get_op_params(self):
+         op_params = super().get_op_params()
+         op_params = {
+             "model": self.model,
+             "prompt_strategy": self.prompt_strategy,
+             "join_parallelism": self.join_parallelism,
+             "reasoning_effort": self.reasoning_effort,
+             **op_params,
+         }
+
+         return op_params
+
+     def get_model_name(self):
+         return self.model.value
+
+     def is_image_join(self) -> bool:
+         return self.prompt_strategy is PromptStrategy.COT_JOIN_IMAGE
+
+     def naive_cost_estimates(self, left_source_op_cost_estimates: OperatorCostEstimates, right_source_op_cost_estimates: OperatorCostEstimates):
+         # estimate number of input tokens from source
+         est_num_input_tokens = 2 * NAIVE_EST_NUM_INPUT_TOKENS
+         if self.is_image_join():
+             est_num_input_tokens = 2 * 765 / 10  # 1024x1024 image is 765 tokens
+
+         # NOTE: the output often generates an entire reasoning sentence, thus the true value may be higher
+         # the filter operation's LLM call should only output TRUE or FALSE, thus we expect its
+         # number of output tokens to be ~1.25
+         est_num_output_tokens = 1.25
+
+         # get est. of conversion time per record from model card;
+         model_conversion_time_per_record = (
+             MODEL_CARDS[self.model.value]["seconds_per_output_token"] * est_num_output_tokens
+         )
+
+         # get est. of conversion cost (in USD) per record from model card
+         usd_per_input_token = (
+             MODEL_CARDS[self.model.value]["usd_per_audio_input_token"]
+             if self.prompt_strategy.is_audio_prompt()
+             else MODEL_CARDS[self.model.value]["usd_per_input_token"]
+         )
+         model_conversion_usd_per_record = (
+             usd_per_input_token * est_num_input_tokens
+             + MODEL_CARDS[self.model.value]["usd_per_output_token"] * est_num_output_tokens
+         )
+
+         # estimate output cardinality using a constant assumption of the filter selectivity
+         selectivity = NAIVE_EST_JOIN_SELECTIVITY
+         cardinality = selectivity * (left_source_op_cost_estimates.cardinality * right_source_op_cost_estimates.cardinality)
+
+         # estimate quality of output based on the strength of the model being used
+         quality = (MODEL_CARDS[self.model.value]["overall"] / 100.0)
+
+         return OperatorCostEstimates(
+             cardinality=cardinality,
+             time_per_record=model_conversion_time_per_record,
+             cost_per_record=model_conversion_usd_per_record,
+             quality=quality,
+         )
+
+     def _process_join_candidate_pair(
+         self,
+         left_candidate: DataRecord,
+         right_candidate: DataRecord,
+         gen_kwargs: dict,
+     ) -> tuple[list[DataRecord], list[RecordOpStats]]:
+         start_time = time.time()
+
+         # generate output; NOTE: FieldInfo is used to indicate the output type; thus, the desc is not needed
+         fields = {"passed_operator": FieldInfo(annotation=bool, description="Whether the records satisfy the join condition")}
+         field_answers, _, generation_stats, _ = self.generator(left_candidate, fields, right_candidate=right_candidate, **gen_kwargs)
+
+         # determine whether or not the join was satisfied
+         passed_operator = field_answers["passed_operator"]
+
+         # compute output record and add to output_records
+         join_dr = DataRecord.from_join_parents(self.output_schema, left_candidate, right_candidate)
+         join_dr.passed_operator = passed_operator
+
+         # compute record stats and add to output_record_op_stats
+         record_op_stats = RecordOpStats(
+             record_id=join_dr.id,
+             record_parent_ids=join_dr.parent_ids,
+             record_source_indices=join_dr.source_indices,
+             record_state=join_dr.to_dict(include_bytes=False),
+             full_op_id=self.get_full_op_id(),
+             logical_op_id=self.logical_op_id,
+             op_name=self.op_name(),
+             time_per_record=time.time() - start_time,
+             cost_per_record=generation_stats.cost_per_record,
+             model_name=self.get_model_name(),
+             join_condition=self.condition,
+             total_input_tokens=generation_stats.total_input_tokens,
+             total_output_tokens=generation_stats.total_output_tokens,
+             total_input_cost=generation_stats.total_input_cost,
+             total_output_cost=generation_stats.total_output_cost,
+             llm_call_duration_secs=generation_stats.llm_call_duration_secs,
+             fn_call_duration_secs=generation_stats.fn_call_duration_secs,
+             total_llm_calls=generation_stats.total_llm_calls,
+             total_embedding_llm_calls=generation_stats.total_embedding_llm_calls,
+             answer=field_answers,
+             passed_operator=passed_operator,
+             image_operation=self.is_image_join(),
+             op_details={k: str(v) for k, v in self.get_id_params().items()},
+         )
+
+         return [join_dr], [record_op_stats]
+
+     def __call__(self, left_candidates: list[DataRecord], right_candidates: list[DataRecord]) -> tuple[DataRecordSet, int]:
+         # get the set of input fields from both records in the join
+         input_fields = self.get_input_fields()
+
+         # construct kwargs for generation
+         gen_kwargs = {"project_cols": input_fields, "join_condition": self.condition}
+
+         # apply the generator to each pair of candidates
+         output_records, output_record_op_stats, num_inputs_processed = [], [], 0
+         total_join_candidates = len(left_candidates) * len(right_candidates)
+         with ThreadPoolExecutor(max_workers=self.join_parallelism) as executor:
+             futures = []
+             for candidate in left_candidates:
+                 for right_candidate in right_candidates:
+                     futures.append(executor.submit(self._process_join_candidate_pair, candidate, right_candidate, gen_kwargs))
+                     num_inputs_processed += 1
+
+             for future in as_completed(futures):
+                 self.join_idx += 1
+                 join_output_records, join_output_record_op_stats = future.result()
+                 output_records.extend(join_output_records)
+                 output_record_op_stats.extend(join_output_record_op_stats)
+                 print(f"{self.join_idx}/{total_join_candidates} JOINED")
+
+         return DataRecordSet(output_records, output_record_op_stats), num_inputs_processed
+
+
+ class NestedLoopsJoin(JoinOp):
+     def __init__(
+         self,
+         model: Model,
+         prompt_strategy: PromptStrategy = PromptStrategy.COT_JOIN,
+         join_parallelism: int = 64,
+         reasoning_effort: str | None = None,
+         *args,
+         **kwargs,
+     ):
+         super().__init__(*args, **kwargs)
+         self.model = model
+         self.prompt_strategy = prompt_strategy
+         self.join_parallelism = join_parallelism
+         self.reasoning_effort = reasoning_effort
+         self.generator = Generator(model, prompt_strategy, reasoning_effort, self.api_base, Cardinality.ONE_TO_ONE, self.verbose)
+         self.join_idx = 0
+
+         # maintain list(s) of input records for the join
+         self._left_input_records: list[DataRecord] = []
+         self._right_input_records: list[DataRecord] = []
+
+     def get_id_params(self):
+         id_params = super().get_id_params()
+         id_params = {
+             "model": self.model.value,
+             "prompt_strategy": self.prompt_strategy.value,
+             "join_parallelism": self.join_parallelism,
+             "reasoning_effort": self.reasoning_effort,
+             **id_params,
+         }
+
+         return id_params
+
+     def get_op_params(self):
+         op_params = super().get_op_params()
+         op_params = {
+             "model": self.model,
+             "prompt_strategy": self.prompt_strategy,
+             "join_parallelism": self.join_parallelism,
+             "reasoning_effort": self.reasoning_effort,
+             **op_params,
+         }
+
+         return op_params
+
+     def get_model_name(self):
+         return self.model.value
+
+     def is_image_join(self) -> bool:
+         return self.prompt_strategy is PromptStrategy.COT_JOIN_IMAGE
+
+     def naive_cost_estimates(self, left_source_op_cost_estimates: OperatorCostEstimates, right_source_op_cost_estimates: OperatorCostEstimates):
+         # estimate number of input tokens from source
+         est_num_input_tokens = 2 * NAIVE_EST_NUM_INPUT_TOKENS
+         if self.is_image_join():
+             est_num_input_tokens = 2 * 765 / 10  # 1024x1024 image is 765 tokens
+
+         # NOTE: the output often generates an entire reasoning sentence, thus the true value may be higher
+         # the filter operation's LLM call should only output TRUE or FALSE, thus we expect its
+         # number of output tokens to be ~1.25
+         est_num_output_tokens = 1.25
+
+         # get est. of conversion time per record from model card;
+         model_conversion_time_per_record = (
+             MODEL_CARDS[self.model.value]["seconds_per_output_token"] * est_num_output_tokens
+         )
+
+         # get est. of conversion cost (in USD) per record from model card
+         usd_per_input_token = (
+             MODEL_CARDS[self.model.value]["usd_per_audio_input_token"]
+             if self.prompt_strategy.is_audio_prompt()
+             else MODEL_CARDS[self.model.value]["usd_per_input_token"]
+         )
+         model_conversion_usd_per_record = (
+             usd_per_input_token * est_num_input_tokens
+             + MODEL_CARDS[self.model.value]["usd_per_output_token"] * est_num_output_tokens
+         )
+
+         # estimate output cardinality using a constant assumption of the filter selectivity
+         selectivity = NAIVE_EST_JOIN_SELECTIVITY
+         cardinality = selectivity * (left_source_op_cost_estimates.cardinality * right_source_op_cost_estimates.cardinality)
+
+         # estimate quality of output based on the strength of the model being used
+         quality = (MODEL_CARDS[self.model.value]["overall"] / 100.0)
+
+         return OperatorCostEstimates(
+             cardinality=cardinality,
+             time_per_record=model_conversion_time_per_record,
+             cost_per_record=model_conversion_usd_per_record,
+             quality=quality,
+         )
+
+     def _process_join_candidate_pair(
+         self,
+         left_candidate: DataRecord,
+         right_candidate: DataRecord,
+         gen_kwargs: dict,
+     ) -> tuple[list[DataRecord], list[RecordOpStats]]:
+         start_time = time.time()
+
+         # generate output; NOTE: FieldInfo is used to indicate the output type; thus, the desc is not needed
+         fields = {"passed_operator": FieldInfo(annotation=bool, description="Whether the records satisfy the join condition")}
+         field_answers, _, generation_stats, _ = self.generator(left_candidate, fields, right_candidate=right_candidate, **gen_kwargs)
+
+         # determine whether or not the join was satisfied
+         passed_operator = field_answers["passed_operator"]
+
+         # compute output record and add to output_records
+         join_dr = DataRecord.from_join_parents(self.output_schema, left_candidate, right_candidate)
+         join_dr.passed_operator = passed_operator
+
+         # compute record stats and add to output_record_op_stats
+         record_op_stats = RecordOpStats(
+             record_id=join_dr.id,
+             record_parent_ids=join_dr.parent_ids,
+             record_source_indices=join_dr.source_indices,
+             record_state=join_dr.to_dict(include_bytes=False),
+             full_op_id=self.get_full_op_id(),
+             logical_op_id=self.logical_op_id,
+             op_name=self.op_name(),
+             time_per_record=time.time() - start_time,
+             cost_per_record=generation_stats.cost_per_record,
+             model_name=self.get_model_name(),
+             join_condition=self.condition,
+             total_input_tokens=generation_stats.total_input_tokens,
+             total_output_tokens=generation_stats.total_output_tokens,
+             total_input_cost=generation_stats.total_input_cost,
+             total_output_cost=generation_stats.total_output_cost,
+             llm_call_duration_secs=generation_stats.llm_call_duration_secs,
+             fn_call_duration_secs=generation_stats.fn_call_duration_secs,
+             total_llm_calls=generation_stats.total_llm_calls,
+             total_embedding_llm_calls=generation_stats.total_embedding_llm_calls,
+             answer=field_answers,
+             passed_operator=passed_operator,
+             image_operation=self.is_image_join(),
+             op_details={k: str(v) for k, v in self.get_id_params().items()},
+         )
+
+         return [join_dr], [record_op_stats]
+
+     def __call__(self, left_candidates: list[DataRecord], right_candidates: list[DataRecord]) -> tuple[DataRecordSet | None, int]:
+         # get the set of input fields from both records in the join
+         input_fields = self.get_input_fields()
+
+         # construct kwargs for generation
+         gen_kwargs = {"project_cols": input_fields, "join_condition": self.condition}
+
+         # apply the generator to each pair of candidates
+         output_records, output_record_op_stats, num_inputs_processed = [], [], 0
+         with ThreadPoolExecutor(max_workers=self.join_parallelism) as executor:
+             futures = []
+             # join new left candidates with new right candidates
+             for candidate in left_candidates:
+                 for right_candidate in right_candidates:
+                     futures.append(executor.submit(self._process_join_candidate_pair, candidate, right_candidate, gen_kwargs))
+                     num_inputs_processed += 1
+
+             # join new left candidates with stored right input records
+             for candidate in left_candidates:
+                 for right_candidate in self._right_input_records:
+                     futures.append(executor.submit(self._process_join_candidate_pair, candidate, right_candidate, gen_kwargs))
+                     num_inputs_processed += 1
+
+             # join new right candidates with stored left input records
+             for candidate in self._left_input_records:
+                 for right_candidate in right_candidates:
+                     futures.append(executor.submit(self._process_join_candidate_pair, candidate, right_candidate, gen_kwargs))
+                     num_inputs_processed += 1
+
+             # collect results as they complete
+             for future in as_completed(futures):
+                 self.join_idx += 1
+                 join_output_records, join_output_record_op_stats = future.result()
+                 output_records.extend(join_output_records)
+                 output_record_op_stats.extend(join_output_record_op_stats)
+                 print(f"{self.join_idx} JOINED")
+
+         # store input records to join with new records added later
+         self._left_input_records.extend(left_candidates)
+         self._right_input_records.extend(right_candidates)
+
+         # return None if no output records were produced
+         if len(output_records) == 0:
+             return None, num_inputs_processed
+
+         return DataRecordSet(output_records, output_record_op_stats), num_inputs_processed
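
The non-blocking NestedLoopsJoin accepts candidates incrementally: each call pairs the new left records with the new right records and with the previously stored records from the other side, then remembers the new inputs so later calls never re-evaluate a pair. The sketch below illustrates that pairing pattern with a plain Python predicate standing in for the LLM-evaluated join condition; the class and variable names are illustrative, not part of the palimpzest API:

from concurrent.futures import ThreadPoolExecutor, as_completed

class IncrementalNestedLoopsJoin:
    """Illustrative sketch of the incremental nested-loops pairing logic."""

    def __init__(self, predicate, parallelism: int = 8):
        self.predicate = predicate
        self.parallelism = parallelism
        self.left_seen, self.right_seen = [], []

    def __call__(self, new_left, new_right):
        # new-left x new-right, new-left x stored-right, stored-left x new-right:
        # across calls, every (left, right) pair is evaluated exactly once
        pairs = (
            [(l, r) for l in new_left for r in new_right]
            + [(l, r) for l in new_left for r in self.right_seen]
            + [(l, r) for l in self.left_seen for r in new_right]
        )
        matches = []
        with ThreadPoolExecutor(max_workers=self.parallelism) as executor:
            futures = {executor.submit(self.predicate, l, r): (l, r) for l, r in pairs}
            for future in as_completed(futures):
                if future.result():
                    matches.append(futures[future])
        self.left_seen.extend(new_left)
        self.right_seen.extend(new_right)
        return matches

join = IncrementalNestedLoopsJoin(lambda l, r: l["id"] == r["ref_id"])
print(join([{"id": 1}], [{"ref_id": 1}, {"ref_id": 2}]))  # [({'id': 1}, {'ref_id': 1})]
print(join([{"id": 2}], []))  # pairs id=2 with stored right rows -> [({'id': 2}, {'ref_id': 2})]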
palimpzest/query/operators/limit.py
@@ -1,7 +1,7 @@
  from __future__ import annotations
 
- from palimpzest.core.data.dataclasses import OperatorCostEstimates, RecordOpStats
  from palimpzest.core.elements.records import DataRecord, DataRecordSet
+ from palimpzest.core.models import OperatorCostEstimates, RecordOpStats
  from palimpzest.query.operators.physical import PhysicalOperator
 
 
@@ -41,8 +41,8 @@ class LimitScanOp(PhysicalOperator):
          # create RecordOpStats object
          record_op_stats = RecordOpStats(
              record_id=dr.id,
-             record_parent_id=dr.parent_id,
-             record_source_idx=dr.source_idx,
+             record_parent_ids=dr.parent_ids,
+             record_source_indices=dr.source_indices,
              record_state=dr.to_dict(include_bytes=False),
              full_op_id=self.get_full_op_id(),
              logical_op_id=self.logical_op_id,