palimpzest 0.5.4__py3-none-any.whl → 0.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71)
  1. palimpzest/__init__.py +7 -9
  2. palimpzest/constants.py +47 -7
  3. palimpzest/core/__init__.py +20 -26
  4. palimpzest/core/data/dataclasses.py +9 -2
  5. palimpzest/core/data/datareaders.py +497 -0
  6. palimpzest/core/elements/records.py +29 -37
  7. palimpzest/core/lib/fields.py +14 -12
  8. palimpzest/core/lib/schemas.py +80 -94
  9. palimpzest/policy.py +58 -0
  10. palimpzest/prompts/__init__.py +22 -0
  11. palimpzest/prompts/code_synthesis_prompts.py +28 -0
  12. palimpzest/prompts/convert_prompts.py +87 -0
  13. palimpzest/prompts/critique_and_refine_convert_prompts.py +216 -0
  14. palimpzest/prompts/filter_prompts.py +69 -0
  15. palimpzest/prompts/moa_aggregator_convert_prompts.py +57 -0
  16. palimpzest/prompts/moa_proposer_convert_prompts.py +79 -0
  17. palimpzest/prompts/prompt_factory.py +732 -0
  18. palimpzest/prompts/util_phrases.py +14 -0
  19. palimpzest/query/execution/execution_strategy.py +0 -3
  20. palimpzest/query/execution/parallel_execution_strategy.py +12 -25
  21. palimpzest/query/execution/single_threaded_execution_strategy.py +31 -45
  22. palimpzest/query/generators/generators.py +71 -347
  23. palimpzest/query/operators/__init__.py +5 -5
  24. palimpzest/query/operators/aggregate.py +10 -5
  25. palimpzest/query/operators/code_synthesis_convert.py +4 -48
  26. palimpzest/query/operators/convert.py +5 -2
  27. palimpzest/query/operators/critique_and_refine_convert.py +112 -0
  28. palimpzest/query/operators/filter.py +1 -1
  29. palimpzest/query/operators/limit.py +1 -1
  30. palimpzest/query/operators/logical.py +28 -27
  31. palimpzest/query/operators/mixture_of_agents_convert.py +4 -1
  32. palimpzest/query/operators/physical.py +32 -20
  33. palimpzest/query/operators/project.py +1 -1
  34. palimpzest/query/operators/rag_convert.py +6 -3
  35. palimpzest/query/operators/retrieve.py +13 -31
  36. palimpzest/query/operators/scan.py +150 -0
  37. palimpzest/query/optimizer/__init__.py +5 -1
  38. palimpzest/query/optimizer/cost_model.py +18 -34
  39. palimpzest/query/optimizer/optimizer.py +40 -25
  40. palimpzest/query/optimizer/optimizer_strategy.py +26 -0
  41. palimpzest/query/optimizer/plan.py +2 -2
  42. palimpzest/query/optimizer/rules.py +118 -27
  43. palimpzest/query/processor/config.py +12 -1
  44. palimpzest/query/processor/mab_sentinel_processor.py +125 -112
  45. palimpzest/query/processor/nosentinel_processor.py +46 -62
  46. palimpzest/query/processor/query_processor.py +10 -20
  47. palimpzest/query/processor/query_processor_factory.py +12 -5
  48. palimpzest/query/processor/random_sampling_sentinel_processor.py +112 -91
  49. palimpzest/query/processor/streaming_processor.py +11 -17
  50. palimpzest/sets.py +170 -94
  51. palimpzest/tools/pdfparser.py +5 -64
  52. palimpzest/utils/datareader_helpers.py +61 -0
  53. palimpzest/utils/field_helpers.py +69 -0
  54. palimpzest/utils/hash_helpers.py +3 -2
  55. palimpzest/utils/udfs.py +0 -28
  56. {palimpzest-0.5.4.dist-info → palimpzest-0.6.1.dist-info}/METADATA +49 -49
  57. palimpzest-0.6.1.dist-info/RECORD +87 -0
  58. {palimpzest-0.5.4.dist-info → palimpzest-0.6.1.dist-info}/top_level.txt +0 -1
  59. cli/README.md +0 -156
  60. cli/__init__.py +0 -0
  61. cli/cli_main.py +0 -390
  62. palimpzest/config.py +0 -89
  63. palimpzest/core/data/datasources.py +0 -369
  64. palimpzest/datamanager/__init__.py +0 -0
  65. palimpzest/datamanager/datamanager.py +0 -300
  66. palimpzest/prompts.py +0 -397
  67. palimpzest/query/operators/datasource.py +0 -202
  68. palimpzest-0.5.4.dist-info/RECORD +0 -83
  69. palimpzest-0.5.4.dist-info/entry_points.txt +0 -2
  70. {palimpzest-0.5.4.dist-info → palimpzest-0.6.1.dist-info}/LICENSE +0 -0
  71. {palimpzest-0.5.4.dist-info → palimpzest-0.6.1.dist-info}/WHEEL +0 -0
palimpzest/query/operators/code_synthesis_convert.py
@@ -5,7 +5,6 @@ from typing import Any
 from palimpzest.constants import Cardinality, GPT_4o_MODEL_CARD, Model
 from palimpzest.core.data.dataclasses import GenerationStats, OperatorCostEstimates
 from palimpzest.core.elements.records import DataRecord
-from palimpzest.datamanager.datamanager import DataDirectory
 from palimpzest.prompts import ADVICEGEN_PROMPT, CODEGEN_PROMPT, EXAMPLE_PROMPT
 from palimpzest.query.generators.generators import code_ensemble_execution, generator_factory
 from palimpzest.query.operators.convert import LLMConvert, LLMConvertBonded, LLMConvertConventional
@@ -26,18 +25,18 @@ class CodeSynthesisConvert(LLMConvert):
         exemplar_generation_model: Model = Model.GPT_4o,
         code_synth_model: Model = Model.GPT_4o,
         conventional_fallback_model: Model = Model.GPT_4o_MINI,
-        cache_across_plans: bool = True,
         *args,
         **kwargs,
     ):
         kwargs["model"] = None
         super().__init__(*args, **kwargs)
+
+        # set models
         self.exemplar_generation_model = exemplar_generation_model
         self.code_synth_model = code_synth_model
         self.conventional_fallback_model = conventional_fallback_model
-        self.cache_across_plans = cache_across_plans

-        # initialize optimization-specific parameters
+        # initialize parameters
         self.field_to_code_ensemble = None
         self.exemplars = []
         self.code_synthesized = False
@@ -47,15 +46,6 @@ class CodeSynthesisConvert(LLMConvert):
             cardinality=Cardinality.ONE_TO_ONE,
             verbose=self.verbose,
         )
-
-        # read the list of exemplars already generated by this operator if present
-        if self.cache_across_plans:
-            cache = DataDirectory().get_cache_service()
-            exemplars_cache_id = self.get_op_id()
-            exemplars = cache.get_cached_data("codeExemplars", exemplars_cache_id)
-            # set and return exemplars if it is not empty
-            if exemplars is not None and isinstance(exemplars, list) and len(exemplars) > 0:
-                self.exemplars = exemplars
         self.field_to_code_ensemble = {}

     def __str__(self):
@@ -80,7 +70,6 @@ class CodeSynthesisConvert(LLMConvert):
             "exemplar_generation_model": self.exemplar_generation_model,
             "code_synth_model": self.code_synth_model,
             "conventional_fallback_model": self.conventional_fallback_model,
-            "cache_across_plans": self.cache_across_plans,
             **op_params,
         }

@@ -109,23 +98,6 @@ class CodeSynthesisConvert(LLMConvert):

         return naive_op_cost_estimates

-    def _fetch_cached_code(self, fields_to_generate: list[str]) -> dict[CodeName, Code]:
-        # if we are allowed to cache synthesized code across plan executions, check the cache
-        field_to_code_ensemble = {}
-        cache = DataDirectory().get_cache_service()
-        for field_name in fields_to_generate:
-            code_ensemble_cache_id = "_".join([self.get_op_id(), field_name])
-            code_ensemble = cache.get_cached_data("codeEnsembles", code_ensemble_cache_id)
-            if code_ensemble is not None:
-                field_to_code_ensemble[field_name] = code_ensemble
-
-        # set and return field_to_code_ensemble if all fields are present and have code
-        if all([field_to_code_ensemble.get(field_name) is not None for field_name in fields_to_generate]):
-            self.field_to_code_ensemble = field_to_code_ensemble
-            return self.field_to_code_ensemble
-        else:
-            return {}
-
     def _should_synthesize(
         self, exemplars: list[Exemplar], num_exemplars: int = 1, code_regenerate_frequency: int = 200, *args, **kwargs
     ) -> bool:
@@ -168,12 +140,6 @@ class CodeSynthesisConvert(LLMConvert):
             field_to_code_ensemble[field_name] = code_ensemble
             generation_stats += code_synth_stats

-            # add code ensemble to the cache
-            if self.cache_across_plans:
-                cache = DataDirectory().get_cache_service()
-                code_ensemble_cache_id = "_".join([self.get_op_id(), field_name])
-                cache.put_cached_data("codeEnsembles", code_ensemble_cache_id, code_ensemble)
-
             if self.verbose:
                 for code_name, code in code_ensemble.items():
                     print(f"CODE NAME: {code_name}")
@@ -184,7 +150,7 @@ class CodeSynthesisConvert(LLMConvert):
         return field_to_code_ensemble, generation_stats

     def _bonded_query_fallback(self, candidate: DataRecord) -> tuple[dict[FieldName, list[Any] | None], GenerationStats]:
-        fields_to_generate = self.get_fields_to_generate(candidate, self.input_schema, self.output_schema)
+        fields_to_generate = self.get_fields_to_generate(candidate)
         projected_candidate = candidate.copy(include_bytes=False, project_cols=self.depends_on)

         # execute the bonded convert
@@ -209,12 +175,6 @@ class CodeSynthesisConvert(LLMConvert):
         exemplars = [(projected_candidate.to_dict(include_bytes=False), dr.to_dict(include_bytes=False)) for dr in drs]
         self.exemplars.extend(exemplars)

-        # if we are allowed to cache exemplars across plan executions, add exemplars to cache
-        if self.cache_across_plans:
-            cache = DataDirectory().get_cache_service()
-            exemplars_cache_id = self.get_op_id()
-            cache.put_cached_data("codeExemplars", exemplars_cache_id, exemplars)
-
         return field_answers, generation_stats

     def is_image_conversion(self):
@@ -231,10 +191,6 @@ class CodeSynthesisConvert(LLMConvert):
             self.field_to_code_ensemble, total_code_synth_stats = self.synthesize_code_ensemble(fields, candidate)
             self.code_synthesized = True
             generation_stats += total_code_synth_stats
-        else:
-            # read the dictionary of ensembles already synthesized by this operator if present
-            if self.cache_across_plans:
-                self.field_to_code_ensemble = self._fetch_cached_code(fields)

         # if we have yet to synthesize code (perhaps b/c we are waiting for more exemplars),
         # use the exemplar generation model to perform the convert (and generate high-quality
palimpzest/query/operators/convert.py
@@ -120,7 +120,7 @@ class ConvertOp(PhysicalOperator, ABC):
             RecordOpStats(
                 record_id=dr.id,
                 record_parent_id=dr.parent_id,
-                record_source_id=dr.source_id,
+                record_source_idx=dr.source_idx,
                 record_state=dr.to_dict(include_bytes=False),
                 op_id=self.get_op_id(),
                 logical_op_id=self.logical_op_id,
@@ -183,7 +183,7 @@ class ConvertOp(PhysicalOperator, ABC):
         start_time = time.time()

         # get fields to generate with this convert
-        fields_to_generate = self.get_fields_to_generate(candidate, self.input_schema, self.output_schema)
+        fields_to_generate = self.get_fields_to_generate(candidate)

         # execute the convert
         field_answers: dict[str, list]
@@ -276,6 +276,9 @@ class NonLLMConvert(ConvertOp):


 class LLMConvert(ConvertOp):
+    """
+    This is the base class for convert operations which use an LLM to generate the output fields.
+    """
     def __init__(
         self,
         model: Model,
palimpzest/query/operators/critique_and_refine_convert.py (new file)
@@ -0,0 +1,112 @@
+from __future__ import annotations
+
+from typing import Any
+
+from palimpzest.constants import MODEL_CARDS, Model, PromptStrategy
+from palimpzest.core.data.dataclasses import GenerationStats, OperatorCostEstimates
+from palimpzest.core.elements.records import DataRecord
+from palimpzest.query.generators.generators import generator_factory
+from palimpzest.query.operators.convert import LLMConvert
+
+# TYPE DEFINITIONS
+FieldName = str
+
+
+class CriticAndRefineConvert(LLMConvert):
+
+    def __init__(
+        self,
+        critic_model: Model,
+        refine_model: Model,
+        *args,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+        self.critic_model = critic_model
+        self.refine_model = refine_model
+
+        if self.prompt_strategy == PromptStrategy.COT_QA:
+            self.critic_prompt_strategy = PromptStrategy.COT_QA_CRITIC
+            self.refinement_prompt_strategy = PromptStrategy.COT_QA_REFINE
+        elif self.prompt_strategy == PromptStrategy.COT_QA_IMAGE:
+            self.critic_prompt_strategy = PromptStrategy.COT_QA_IMAGE_CRITIC
+            self.refinement_prompt_strategy = PromptStrategy.COT_QA_IMAGE_REFINE
+        else:
+            raise ValueError(f"Unsupported prompt strategy: {self.prompt_strategy}")
+
+        # create generators
+        self.critic_generator = generator_factory(self.critic_model, self.critic_prompt_strategy, self.cardinality, self.verbose)
+        self.refine_generator = generator_factory(self.refine_model, self.refinement_prompt_strategy, self.cardinality, self.verbose)
+
+    def __str__(self):
+        op = super().__str__()
+        op += f" Critic Model: {self.critic_model}\n"
+        op += f" Critic Prompt Strategy: {self.critic_prompt_strategy}\n"
+        op += f" Refine Model: {self.refine_model}\n"
+        op += f" Refinement Prompt Strategy: {self.refinement_prompt_strategy}\n"
+        return op
+
+    def get_id_params(self):
+        id_params = super().get_id_params()
+        id_params = {
+            "critic_model": self.critic_model.value,
+            "refine_model": self.refine_model.value,
+            **id_params,
+        }
+
+        return id_params
+
+    def get_op_params(self):
+        op_params = super().get_op_params()
+        op_params = {
+            "critic_model": self.critic_model,
+            "refine_model": self.refine_model,
+            **op_params,
+        }
+
+        return op_params
+
+    def naive_cost_estimates(self, source_op_cost_estimates: OperatorCostEstimates) -> OperatorCostEstimates:
+        """
+        Currently, we are invoking `self.model`, then critiquing its output with `self.critic_model`, and
+        finally refining the output with `self.refine_model`. Thus, we roughly expect to incur the cost
+        and time of three LLMConverts. In practice, this naive quality estimate will be overwritten by the
+        CostModel's estimate once it executes a few instances of the operator.
+        """
+        # get naive cost estimates for first LLM call and multiply by 3 for now;
+        # of course we should sum individual estimates for each model, but this is a rough estimate
+        # and in practice we will need to revamp our naive cost estimates in the near future
+        naive_op_cost_estimates = 3 * super().naive_cost_estimates(source_op_cost_estimates)
+
+        # for naive setting, estimate quality as quality of refine model
+        model_quality = MODEL_CARDS[self.refine_model.value]["overall"] / 100.0
+        naive_op_cost_estimates.quality = model_quality
+        naive_op_cost_estimates.quality_lower_bound = naive_op_cost_estimates.quality
+        naive_op_cost_estimates.quality_upper_bound = naive_op_cost_estimates.quality
+
+        return naive_op_cost_estimates
+
+    def convert(self, candidate: DataRecord, fields: list[str]) -> tuple[dict[FieldName, list[Any]], GenerationStats]:
+        # get input fields
+        input_fields = self.get_input_fields()
+
+        # NOTE: when I merge in the `abacus` branch, I will want to update this to reflect the changes I made to reasoning extraction
+        # execute the initial model
+        original_gen_kwargs = {"project_cols": input_fields, "output_schema": self.output_schema}
+        field_answers, reasoning, original_gen_stats = self.generator(candidate, fields, **original_gen_kwargs)
+        original_output = f"REASONING: {reasoning}\nANSWER:{field_answers}\n"
+        original_messages = self.generator.get_messages()
+
+        # execute the critic model
+        critic_gen_kwargs = {"original_output": original_output, "original_messages": original_messages, **original_gen_kwargs}
+        field_answers, reasoning, critic_gen_stats = self.critic_generator(candidate, fields, **critic_gen_kwargs)
+        critique_output = f"REASONING: {reasoning}\nANSWER:{field_answers}\n"
+
+        # execute the refinement model
+        refine_gen_kwargs = {"critique_output": critique_output, **critic_gen_kwargs}
+        field_answers, reasoning, refine_gen_stats = self.refine_generator(candidate, fields, **refine_gen_kwargs)
+
+        # compute the total generation stats
+        generation_stats = original_gen_stats + critic_gen_stats + refine_gen_stats
+
+        return field_answers, generation_stats
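The new operator's convert method chains three generations: an initial answer, a critique of that answer, and a refinement conditioned on the critique. A minimal standalone sketch of the same three-stage pattern (the call_llm helper and the prompt wording are hypothetical stand-ins, not palimpzest API):

def call_llm(prompt: str) -> str:
    # hypothetical stand-in for a real LLM client call
    return f"<model output for: {prompt[:40]}>"

def critique_and_refine(question: str) -> str:
    # 1. initial answer from the base model
    original = call_llm(f"Answer the question.\nQUESTION: {question}")
    # 2. critic model reviews the original answer
    critique = call_llm(f"Critique this answer.\nQUESTION: {question}\nANSWER: {original}")
    # 3. refine model revises the answer, conditioned on the critique
    return call_llm(
        f"Revise the answer using the critique.\n"
        f"QUESTION: {question}\nANSWER: {original}\nCRITIQUE: {critique}"
    )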
palimpzest/query/operators/filter.py
@@ -81,7 +81,7 @@ class FilterOp(PhysicalOperator, ABC):
         record_op_stats = RecordOpStats(
             record_id=dr.id,
             record_parent_id=dr.parent_id,
-            record_source_id=dr.source_id,
+            record_source_idx=dr.source_idx,
             record_state=dr.to_dict(include_bytes=False),
             op_id=self.get_op_id(),
             logical_op_id=self.logical_op_id,
palimpzest/query/operators/limit.py
@@ -42,7 +42,7 @@ class LimitScanOp(PhysicalOperator):
         record_op_stats = RecordOpStats(
             record_id=dr.id,
             record_parent_id=dr.parent_id,
-            record_source_id=dr.source_id,
+            record_source_idx=dr.source_idx,
             record_state=dr.to_dict(include_bytes=False),
             op_id=self.get_op_id(),
             logical_op_id=self.logical_op_id,
palimpzest/query/operators/logical.py
@@ -4,6 +4,7 @@ import json
 from typing import Callable

 from palimpzest.constants import AggFunc, Cardinality
+from palimpzest.core.data.datareaders import DataReader
 from palimpzest.core.elements.filters import Filter
 from palimpzest.core.elements.groupbysig import GroupBySig
 from palimpzest.core.lib.schemas import Schema
@@ -15,7 +16,7 @@ class LogicalOperator:
     A logical operator is an operator that operates on Sets.

     Right now it can be one of:
-    - BaseScan (scans data from DataSource)
+    - BaseScan (scans data from DataReader)
     - CacheScan (scans cached Set)
     - FilteredScan (scans input Set and applies filter)
     - ConvertScan (scans input Set and converts it to new Schema)
@@ -38,6 +39,14 @@ class LogicalOperator:
         self.input_schema = input_schema
         self.logical_op_id: str | None = None

+        # compute the fields generated by this logical operator
+        input_field_names = self.input_schema.field_names() if self.input_schema is not None else []
+        self.generated_fields = sorted([
+            field_name
+            for field_name in self.output_schema.field_names()
+            if field_name not in input_field_names
+        ])
+
     def __str__(self) -> str:
         raise NotImplementedError("Abstract method")

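The generated_fields computation added above is just a sorted set difference between the output and input schema field names. A self-contained illustration, with plain lists standing in for Schema.field_names() and hypothetical field names:

# plain lists stand in for Schema.field_names()
input_field_names = ["title", "contents"]
output_field_names = ["title", "contents", "sentiment", "summary"]

generated_fields = sorted(
    field_name
    for field_name in output_field_names
    if field_name not in input_field_names
)
assert generated_fields == ["sentiment", "summary"]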
@@ -58,9 +67,10 @@ class LogicalOperator:
         for computing the logical operator id.

         NOTE: Should be overriden by subclasses to include class-specific parameters.
-        NOTE: input_schema is not included in the id params because it depends on how the Optimizer orders operations.
+        NOTE: input_schema and output_schema are not included in the id params because
+        they depend on how the Optimizer orders operations.
         """
-        return {"output_schema": self.output_schema}
+        return {"generated_fields": self.generated_fields}

     def get_logical_op_params(self) -> dict:
         """
@@ -137,30 +147,27 @@ class Aggregate(LogicalOperator):
 class BaseScan(LogicalOperator):
     """A BaseScan is a logical operator that represents a scan of a particular data source."""

-    def __init__(self, dataset_id: str, output_schema: Schema):
+    def __init__(self, datareader: DataReader, output_schema: Schema):
         super().__init__(output_schema=output_schema)
-        self.dataset_id = dataset_id
+        self.datareader = datareader

     def __str__(self):
-        return f"BaseScan({self.dataset_id},{str(self.output_schema)})"
+        return f"BaseScan({self.datareader},{self.output_schema})"

     def __eq__(self, other) -> bool:
         return (
             isinstance(other, BaseScan)
             and self.input_schema.get_desc() == other.input_schema.get_desc()
             and self.output_schema.get_desc() == other.output_schema.get_desc()
-            and self.dataset_id == other.dataset_id
+            and self.datareader == other.datareader
         )

     def get_logical_id_params(self) -> dict:
-        logical_id_params = super().get_logical_id_params()
-        logical_id_params = {"dataset_id": self.dataset_id, **logical_id_params}
-
-        return logical_id_params
+        return super().get_logical_id_params()

     def get_logical_op_params(self) -> dict:
         logical_op_params = super().get_logical_op_params()
-        logical_op_params = {"dataset_id": self.dataset_id, **logical_op_params}
+        logical_op_params = {"datareader": self.datareader, **logical_op_params}

         return logical_op_params

@@ -168,28 +175,19 @@ class BaseScan(LogicalOperator):
 class CacheScan(LogicalOperator):
     """A CacheScan is a logical operator that represents a scan of a cached Set."""

-    def __init__(self, dataset_id: str, *args, **kwargs):
-        if kwargs.get("input_schema") is not None:
-            raise Exception(
-                f"CacheScan must be initialized with `input_schema=None` but was initialized with "
-                f"`input_schema={kwargs.get('input_schema')}`"
-            )
-
-        super().__init__(*args, **kwargs)
-        self.dataset_id = dataset_id
+    def __init__(self, datareader: DataReader, output_schema: Schema):
+        super().__init__(output_schema=output_schema)
+        self.datareader = datareader

     def __str__(self):
-        return f"CacheScan({str(self.output_schema)},{str(self.dataset_id)})"
+        return f"CacheScan({self.datareader},{self.output_schema})"

     def get_logical_id_params(self) -> dict:
-        logical_id_params = super().get_logical_id_params()
-        logical_id_params = {"dataset_id": self.dataset_id, **logical_id_params}
-
-        return logical_id_params
+        return super().get_logical_id_params()

     def get_logical_op_params(self) -> dict:
         logical_op_params = super().get_logical_op_params()
-        logical_op_params = {"dataset_id": self.dataset_id, **logical_op_params}
+        logical_op_params = {"datareader": self.datareader, **logical_op_params}

         return logical_op_params

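Both scan operators now hold the DataReader object itself rather than a registry-level dataset_id string, which is also why dataset_id disappears from the id and op params above. A hedged sketch of the constructor change (FakeReader is a hypothetical stand-in for a palimpzest DataReader subclass, and TextFile a hypothetical schema):

class FakeReader:
    # hypothetical stand-in for a palimpzest DataReader subclass
    def __init__(self, paths: list[str]):
        self.paths = paths

reader = FakeReader(["docs/a.txt", "docs/b.txt"])
# 0.5.4: BaseScan(dataset_id="mydocs", output_schema=TextFile)
# 0.6.1: BaseScan(datareader=reader, output_schema=TextFile)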
@@ -374,6 +372,7 @@ class RetrieveScan(LogicalOperator):
     def __init__(
         self,
         index,
+        search_func,
         search_attr,
         output_attr,
         k,
@@ -383,6 +382,7 @@ class RetrieveScan(LogicalOperator):
     ):
         super().__init__(*args, **kwargs)
         self.index = index
+        self.search_func = search_func
         self.search_attr = search_attr
         self.output_attr = output_attr
         self.k = k
@@ -409,6 +409,7 @@ class RetrieveScan(LogicalOperator):
         logical_op_params = super().get_logical_op_params()
         logical_op_params = {
             "index": self.index,
+            "search_func": self.search_func,
             "search_attr": self.search_attr,
             "output_attr": self.output_attr,
             "k": self.k,
palimpzest/query/operators/mixture_of_agents_convert.py
@@ -84,7 +84,7 @@ class MixtureOfAgentsConvert(LLMConvert):
         answers, which are then aggregated and summarized by a single aggregator model. Thus, we
         roughly expect to incur the cost and time of an LLMConvert * (len(proposer_models) + 1).
         In practice, this naive quality estimate will be overwritten by the CostModel's estimate
-        once it executes a few code generated examples.
+        once it executes a few instances of the operator.
         """
         # temporarily set self.model so that super().naive_cost_estimates(...) can compute an estimate
         self.model = self.proposer_models[0]
@@ -107,6 +107,9 @@ class MixtureOfAgentsConvert(LLMConvert):
         naive_op_cost_estimates.quality_lower_bound = naive_op_cost_estimates.quality
         naive_op_cost_estimates.quality_upper_bound = naive_op_cost_estimates.quality

+        # reset self.model to be None
+        self.model = None
+
         return naive_op_cost_estimates

     def convert(self, candidate: DataRecord, fields: list[str]) -> tuple[dict[FieldName, list[Any]], GenerationStats]:
palimpzest/query/operators/physical.py
@@ -5,7 +5,6 @@ import json
 from palimpzest.core.data.dataclasses import OperatorCostEstimates
 from palimpzest.core.elements.records import DataRecord, DataRecordSet
 from palimpzest.core.lib.schemas import Schema
-from palimpzest.datamanager.datamanager import DataDirectory
 from palimpzest.utils.hash_helpers import hash_for_id


@@ -36,7 +35,14 @@ class PhysicalOperator:
         self.target_cache_id = target_cache_id
         self.verbose = verbose
         self.op_id = None
-        self.datadir = DataDirectory()
+
+        # compute the fields generated by this physical operator
+        input_field_names = self.input_schema.field_names() if self.input_schema is not None else []
+        self.generated_fields = sorted([
+            field_name
+            for field_name in self.output_schema.field_names()
+            if field_name not in input_field_names
+        ])

         # sets __hash__() for each child Operator to be the base class' __hash__() method;
         # by default, if a subclass defines __eq__() but not __hash__() Python will set that
@@ -68,9 +74,12 @@ class PhysicalOperator:
         for computing the physical operator id.

         NOTE: Should be overriden by subclasses to include class-specific parameters.
-        NOTE: input_schema is not included in the id params because it depends on how the Optimizer orders operations.
+        NOTE: input_schema and output_schema are not included in the id params by default,
+        because they may depend on the order of operations chosen by the Optimizer.
+        This is particularly true for convert operations, where the output schema
+        is now the union of the input and output schemas of the logical operator.
         """
-        return {"output_schema": self.output_schema}
+        return {"generated_fields": self.generated_fields}

     def get_op_params(self) -> dict:
         """
@@ -106,7 +115,10 @@ class PhysicalOperator:
         # get op name and op parameters which are relevant for computing the id
         op_name = self.op_name()
         id_params = self.get_id_params()
-        id_params = {k: str(v) for k, v in id_params.items()}
+        id_params = {
+            k: str(v) if k != "output_schema" else sorted(v.field_names())
+            for k, v in id_params.items()
+        }

         # compute, set, and return the op_id
         hash_str = json.dumps({"op_name": op_name, **id_params}, sort_keys=True)
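get_op_id() thus derives a stable id by hashing the op name together with the stringified id params. A self-contained sketch of the same recipe (hash_for_id is approximated with hashlib here; the real helper lives in palimpzest/utils/hash_helpers.py, and the op name and params are hypothetical):

import hashlib
import json

def hash_for_id(s: str) -> str:
    # approximation of palimpzest.utils.hash_helpers.hash_for_id
    return hashlib.sha256(s.encode()).hexdigest()[:10]

op_name = "LLMConvertBonded"  # hypothetical op
id_params = {"model": "gpt-4o", "generated_fields": ["sentiment", "summary"]}
id_params = {k: str(v) for k, v in id_params.items()}

hash_str = json.dumps({"op_name": op_name, **id_params}, sort_keys=True)
op_id = hash_for_id(hash_str)  # deterministic for identical name + params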
@@ -136,20 +148,20 @@ class PhysicalOperator:

         return input_fields

-    def get_fields_to_generate(self, candidate: DataRecord, input_schema: Schema, output_schema: Schema) -> list[str]:
+    def get_fields_to_generate(self, candidate: DataRecord) -> list[str]:
         """
-        Creates the list of field names that an operation needs to generate. Right now this is only used
-        by convert and retrieve operators.
+        Returns the list of field names that this operator needs to generate for the given candidate.
+        This function returns only the fields in self.generated_fields which are not already present
+        in the candidate. This is important for operators with retry logic, where we may only need to
+        recompute a subset of self.generated_fields.
+
+        Right now this is only used by convert and retrieve operators.
         """
-        # construct the list of fields in output_schema which will need to be generated;
-        # specifically, this is the set of fields which are:
-        # 1. not declared in the input schema, and
-        # 2. not present in the candidate's attributes
-        #    a. if the field is present, but its value is None --> we will try to generate it
-        fields_to_generate = []
-        for field_name in output_schema.field_names():
-            if field_name not in input_schema.field_names() and getattr(candidate, field_name, None) is None:
-                fields_to_generate.append(field_name)
+        fields_to_generate = [
+            field_name
+            for field_name in self.generated_fields
+            if getattr(candidate, field_name, None) is None
+        ]

         return fields_to_generate

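The rewritten method simply filters the precomputed self.generated_fields down to whatever is still missing on the candidate, which is what makes partial retries cheap. An illustration with a plain object standing in for a DataRecord (field names hypothetical):

from types import SimpleNamespace

generated_fields = ["sentiment", "summary"]
# on retry, "summary" already succeeded; only "sentiment" is still None
candidate = SimpleNamespace(title="...", contents="...", summary="done", sentiment=None)

fields_to_generate = [
    field_name
    for field_name in generated_fields
    if getattr(candidate, field_name, None) is None
]
assert fields_to_generate == ["sentiment"]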
@@ -168,8 +180,8 @@ class PhysicalOperator:
         when PZ does not have sample execution data -- and it will be necessary
         in some cases even when sample execution data is present. (For example,
         the cardinality of each operator cannot be estimated based on sample
-        execution data alone -- thus DataSourcePhysicalOps need to give
-        at least ballpark correct estimates of this quantity).
+        execution data alone -- thus ScanPhysicalOps need to give at least ballpark
+        correct estimates of this quantity).
         """
         raise NotImplementedError("CostEstimates from abstract method")

@@ -177,7 +189,7 @@ class PhysicalOperator:
         raise NotImplementedError("Calling __call__ from abstract method")

     @staticmethod
-    def execute_op_wrapper(operator: PhysicalOperator, op_input: DataRecord | list[DataRecord]) -> tuple[DataRecordSet, PhysicalOperator]:
+    def execute_op_wrapper(operator: PhysicalOperator, op_input: DataRecord | list[DataRecord] | int) -> tuple[DataRecordSet, PhysicalOperator]:
         """
         Wrapper function around operator execution which also and returns the operator.
         This is useful in the parallel setting(s) where operators are executed by a worker pool,
palimpzest/query/operators/project.py
@@ -40,7 +40,7 @@ class ProjectOp(PhysicalOperator):
         record_op_stats = RecordOpStats(
             record_id=dr.id,
             record_parent_id=dr.parent_id,
-            record_source_id=dr.source_id,
+            record_source_idx=dr.source_idx,
             record_state=dr.to_dict(include_bytes=False),
             op_id=self.get_op_id(),
             logical_op_id=self.logical_op_id,
palimpzest/query/operators/rag_convert.py
@@ -12,7 +12,7 @@ from palimpzest.constants import (
 )
 from palimpzest.core.data.dataclasses import GenerationStats, OperatorCostEstimates
 from palimpzest.core.elements.records import DataRecord
-from palimpzest.core.lib.fields import ListField, StringField
+from palimpzest.core.lib.fields import StringField
 from palimpzest.query.operators.convert import FieldName, LLMConvert


@@ -20,7 +20,7 @@ class RAGConvert(LLMConvert):
     def __init__(self, num_chunks_per_field: int, chunk_size: int = 1000, *args, **kwargs):
         super().__init__(*args, **kwargs)
         # NOTE: in the future, we should abstract the embedding model to allow for different models
-        self.client = OpenAI()
+        self.client = None
         self.embedding_model = "text-embedding-3-small"
         self.num_chunks_per_field = num_chunks_per_field
         self.chunk_size = chunk_size
@@ -124,7 +124,7 @@ class RAGConvert(LLMConvert):

             # skip this field if it is not a string or a list of strings
             is_string_field = isinstance(field, StringField)
-            is_list_string_field = isinstance(field, ListField) and isinstance(field.element_type, StringField)
+            is_list_string_field = hasattr(field, "element_type") and isinstance(field.element_type, StringField)
             if not (is_string_field or is_list_string_field):
                 continue

@@ -157,6 +157,9 @@ class RAGConvert(LLMConvert):
         return candidate

     def convert(self, candidate: DataRecord, fields: list[str]) -> tuple[dict[FieldName, list[Any]], GenerationStats]:
+        # set client
+        self.client = OpenAI() if self.client is None else self.client
+
         # get the set of input fields to use for the convert operation
         input_fields = self.get_input_fields()
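Moving the OpenAI() construction from __init__ into convert means building a plan no longer requires an API key or network access; the client is created on first use. The same lazy-initialization pattern in isolation (ExpensiveClient is a hypothetical stand-in for openai.OpenAI):

class ExpensiveClient:
    # hypothetical stand-in for openai.OpenAI
    def __init__(self):
        print("connecting...")  # e.g., reads API keys, opens connections

class Operator:
    def __init__(self):
        self.client = None  # cheap: nothing external touched at plan-build time

    def run(self):
        # create the client only on first execution, then reuse it
        self.client = ExpensiveClient() if self.client is None else self.client
        return self.client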