palimpzest 0.7.21__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (87):
  1. palimpzest/__init__.py +37 -6
  2. palimpzest/agents/__init__.py +0 -0
  3. palimpzest/agents/compute_agents.py +0 -0
  4. palimpzest/agents/search_agents.py +637 -0
  5. palimpzest/constants.py +259 -197
  6. palimpzest/core/data/context.py +393 -0
  7. palimpzest/core/data/context_manager.py +163 -0
  8. palimpzest/core/data/dataset.py +634 -0
  9. palimpzest/core/data/{datareaders.py → iter_dataset.py} +202 -126
  10. palimpzest/core/elements/groupbysig.py +16 -13
  11. palimpzest/core/elements/records.py +166 -75
  12. palimpzest/core/lib/schemas.py +152 -390
  13. palimpzest/core/{data/dataclasses.py → models.py} +306 -170
  14. palimpzest/policy.py +2 -27
  15. palimpzest/prompts/__init__.py +35 -5
  16. palimpzest/prompts/agent_prompts.py +357 -0
  17. palimpzest/prompts/context_search.py +9 -0
  18. palimpzest/prompts/convert_prompts.py +61 -5
  19. palimpzest/prompts/filter_prompts.py +50 -5
  20. palimpzest/prompts/join_prompts.py +163 -0
  21. palimpzest/prompts/moa_proposer_convert_prompts.py +5 -5
  22. palimpzest/prompts/prompt_factory.py +358 -46
  23. palimpzest/prompts/validator.py +239 -0
  24. palimpzest/query/execution/all_sample_execution_strategy.py +134 -76
  25. palimpzest/query/execution/execution_strategy.py +210 -317
  26. palimpzest/query/execution/execution_strategy_type.py +5 -7
  27. palimpzest/query/execution/mab_execution_strategy.py +249 -136
  28. palimpzest/query/execution/parallel_execution_strategy.py +153 -244
  29. palimpzest/query/execution/single_threaded_execution_strategy.py +107 -64
  30. palimpzest/query/generators/generators.py +157 -330
  31. palimpzest/query/operators/__init__.py +15 -5
  32. palimpzest/query/operators/aggregate.py +50 -33
  33. palimpzest/query/operators/compute.py +201 -0
  34. palimpzest/query/operators/convert.py +27 -21
  35. palimpzest/query/operators/critique_and_refine_convert.py +7 -5
  36. palimpzest/query/operators/distinct.py +62 -0
  37. palimpzest/query/operators/filter.py +22 -13
  38. palimpzest/query/operators/join.py +402 -0
  39. palimpzest/query/operators/limit.py +3 -3
  40. palimpzest/query/operators/logical.py +198 -80
  41. palimpzest/query/operators/mixture_of_agents_convert.py +10 -8
  42. palimpzest/query/operators/physical.py +27 -21
  43. palimpzest/query/operators/project.py +3 -3
  44. palimpzest/query/operators/rag_convert.py +7 -7
  45. palimpzest/query/operators/retrieve.py +9 -9
  46. palimpzest/query/operators/scan.py +81 -42
  47. palimpzest/query/operators/search.py +524 -0
  48. palimpzest/query/operators/split_convert.py +10 -8
  49. palimpzest/query/optimizer/__init__.py +7 -9
  50. palimpzest/query/optimizer/cost_model.py +108 -441
  51. palimpzest/query/optimizer/optimizer.py +123 -181
  52. palimpzest/query/optimizer/optimizer_strategy.py +66 -61
  53. palimpzest/query/optimizer/plan.py +352 -67
  54. palimpzest/query/optimizer/primitives.py +43 -19
  55. palimpzest/query/optimizer/rules.py +484 -646
  56. palimpzest/query/optimizer/tasks.py +127 -58
  57. palimpzest/query/processor/config.py +41 -76
  58. palimpzest/query/processor/query_processor.py +73 -18
  59. palimpzest/query/processor/query_processor_factory.py +46 -38
  60. palimpzest/schemabuilder/schema_builder.py +15 -28
  61. palimpzest/utils/model_helpers.py +27 -77
  62. palimpzest/utils/progress.py +114 -102
  63. palimpzest/validator/__init__.py +0 -0
  64. palimpzest/validator/validator.py +306 -0
  65. {palimpzest-0.7.21.dist-info → palimpzest-0.8.0.dist-info}/METADATA +6 -1
  66. palimpzest-0.8.0.dist-info/RECORD +95 -0
  67. palimpzest/core/lib/fields.py +0 -141
  68. palimpzest/prompts/code_synthesis_prompts.py +0 -28
  69. palimpzest/query/execution/random_sampling_execution_strategy.py +0 -240
  70. palimpzest/query/generators/api_client_factory.py +0 -30
  71. palimpzest/query/operators/code_synthesis_convert.py +0 -488
  72. palimpzest/query/operators/map.py +0 -130
  73. palimpzest/query/processor/nosentinel_processor.py +0 -33
  74. palimpzest/query/processor/processing_strategy_type.py +0 -28
  75. palimpzest/query/processor/sentinel_processor.py +0 -88
  76. palimpzest/query/processor/streaming_processor.py +0 -149
  77. palimpzest/sets.py +0 -405
  78. palimpzest/utils/datareader_helpers.py +0 -61
  79. palimpzest/utils/demo_helpers.py +0 -75
  80. palimpzest/utils/field_helpers.py +0 -69
  81. palimpzest/utils/generation_helpers.py +0 -69
  82. palimpzest/utils/sandbox.py +0 -183
  83. palimpzest-0.7.21.dist-info/RECORD +0 -95
  84. palimpzest/core/{elements/index.py → data/index_dataset.py} +0 -0
  85. {palimpzest-0.7.21.dist-info → palimpzest-0.8.0.dist-info}/WHEEL +0 -0
  86. {palimpzest-0.7.21.dist-info → palimpzest-0.8.0.dist-info}/licenses/LICENSE +0 -0
  87. {palimpzest-0.7.21.dist-info → palimpzest-0.8.0.dist-info}/top_level.txt +0 -0
@@ -3,11 +3,13 @@ from __future__ import annotations
3
3
  import json
4
4
  from typing import Callable
5
5
 
6
+ from pydantic import BaseModel
7
+
6
8
  from palimpzest.constants import AggFunc, Cardinality
7
- from palimpzest.core.data.datareaders import DataReader
9
+ from palimpzest.core.data import context, dataset
8
10
  from palimpzest.core.elements.filters import Filter
9
11
  from palimpzest.core.elements.groupbysig import GroupBySig
10
- from palimpzest.core.lib.schemas import Schema
12
+ from palimpzest.core.lib.schemas import Average, Count
11
13
  from palimpzest.utils.hash_helpers import hash_for_id
12
14
 
13
15
 
@@ -16,8 +18,8 @@ class LogicalOperator:
16
18
  A logical operator is an operator that operates on Sets.
17
19
 
18
20
  Right now it can be one of:
19
- - BaseScan (scans data from DataReader)
20
- - CacheScan (scans cached Set)
21
+ - BaseScan (scans data from a root Dataset)
22
+ - ContextScan (loads the context for a root Dataset)
21
23
  - FilteredScan (scans input Set and applies filter)
22
24
  - ConvertScan (scans input Set and converts it to new Schema)
23
25
  - LimitScan (scans up to N records from a Set)
@@ -25,6 +27,8 @@ class LogicalOperator:
25
27
  - Aggregate (applies an aggregation on the Set)
26
28
  - RetrieveScan (fetches documents from a provided input for a given query)
27
29
  - Map (applies a function to each record in the Set without adding any new columns)
30
+ - ComputeOperator (executes a computation described in natural language)
31
+ - SearchOperator (executes a search query on the input Context)
28
32
 
29
33
  Every logical operator must declare the get_logical_id_params() and get_logical_op_params() methods,
30
34
  which return dictionaries of parameters that are used to compute the logical op id and to implement
@@ -33,17 +37,21 @@ class LogicalOperator:
33
37
 
34
38
  def __init__(
35
39
  self,
36
- output_schema: Schema,
37
- input_schema: Schema | None = None,
40
+ output_schema: type[BaseModel],
41
+ input_schema: type[BaseModel] | None = None,
42
+ depends_on: list[str] | None = None,
38
43
  ):
44
+ # TODO: can we eliminate input_schema?
39
45
  self.output_schema = output_schema
40
46
  self.input_schema = input_schema
47
+ self.depends_on = [] if depends_on is None else sorted(depends_on)
41
48
  self.logical_op_id: str | None = None
49
+ self.unique_logical_op_id: str | None = None
42
50
 
43
51
  # compute the fields generated by this logical operator
44
- input_field_names = self.input_schema.field_names() if self.input_schema is not None else []
52
+ input_field_names = list(self.input_schema.model_fields) if self.input_schema is not None else []
45
53
  self.generated_fields = sorted(
46
- [field_name for field_name in self.output_schema.field_names() if field_name not in input_field_names]
54
+ [field_name for field_name in self.output_schema.model_fields if field_name not in input_field_names]
47
55
  )
48
56
 
49
57
  def __str__(self) -> str:
@@ -54,12 +62,28 @@ class LogicalOperator:
54
62
  return isinstance(other, self.__class__) and all_id_params_match
55
63
 
56
64
  def copy(self) -> LogicalOperator:
57
- return self.__class__(**self.get_logical_op_params())
65
+ logical_op_copy = self.__class__(**self.get_logical_op_params())
66
+ logical_op_copy.logical_op_id = self.logical_op_id
67
+ logical_op_copy.unique_logical_op_id = self.unique_logical_op_id
68
+ return logical_op_copy
58
69
 
59
70
  def logical_op_name(self) -> str:
60
71
  """Name of the logical operator."""
61
72
  return str(self.__class__.__name__)
62
73
 
74
+ def get_unique_logical_op_id(self) -> str:
75
+ """
76
+ Get the unique logical operator id for this logical operator.
77
+ """
78
+ return self.unique_logical_op_id
79
+
80
+ def set_unique_logical_op_id(self, unique_logical_op_id: str) -> None:
81
+ """
82
+ Set the unique logical operator id for this logical operator.
83
+ This is used to uniquely identify the logical operator in the query plan.
84
+ """
85
+ self.unique_logical_op_id = unique_logical_op_id
86
+
63
87
  def get_logical_id_params(self) -> dict:
64
88
  """
65
89
  Returns a dictionary mapping of logical operator parameters which are relevant
@@ -69,6 +93,7 @@ class LogicalOperator:
69
93
  NOTE: input_schema and output_schema are not included in the id params because
70
94
  they depend on how the Optimizer orders operations.
71
95
  """
96
+ # TODO: should we use `generated_fields` after getting rid of them in PhysicalOperator?
72
97
  return {"generated_fields": self.generated_fields}
73
98
 
74
99
  def get_logical_op_params(self) -> dict:
@@ -78,10 +103,16 @@ class LogicalOperator:
78
103
 
79
104
  NOTE: Should be overriden by subclasses to include class-specific parameters.
80
105
  """
81
- return {"input_schema": self.input_schema, "output_schema": self.output_schema}
106
+ return {
107
+ "input_schema": self.input_schema,
108
+ "output_schema": self.output_schema,
109
+ "depends_on": self.depends_on,
110
+ }
82
111
 
83
112
  def get_logical_op_id(self):
84
113
  """
114
+ TODO: turn this into a property?
115
+
85
116
  NOTE: We do not call this in the __init__() method as subclasses may set parameters
86
117
  returned by self.get_logical_op_params() after they call to super().__init__().
87
118
  """
@@ -119,13 +150,19 @@ class Aggregate(LogicalOperator):
119
150
  def __init__(
120
151
  self,
121
152
  agg_func: AggFunc,
122
- target_cache_id: str | None = None,
123
153
  *args,
124
154
  **kwargs,
125
155
  ):
156
+ if kwargs.get("output_schema") is None:
157
+ if agg_func == AggFunc.COUNT:
158
+ kwargs["output_schema"] = Count
159
+ elif agg_func == AggFunc.AVERAGE:
160
+ kwargs["output_schema"] = Average
161
+ else:
162
+ raise ValueError(f"Unsupported aggregation function: {agg_func}")
163
+
126
164
  super().__init__(*args, **kwargs)
127
165
  self.agg_func = agg_func
128
- self.target_cache_id = target_cache_id
129
166
 
130
167
  def __str__(self):
131
168
  return f"{self.__class__.__name__}(function: {str(self.agg_func.value)})"
@@ -140,7 +177,6 @@ class Aggregate(LogicalOperator):
140
177
  logical_op_params = super().get_logical_op_params()
141
178
  logical_op_params = {
142
179
  "agg_func": self.agg_func,
143
- "target_cache_id": self.target_cache_id,
144
180
  **logical_op_params,
145
181
  }
146
182
 
@@ -148,75 +184,87 @@ class Aggregate(LogicalOperator):
148
184
 
149
185
 
150
186
  class BaseScan(LogicalOperator):
151
- """A BaseScan is a logical operator that represents a scan of a particular data source."""
187
+ """A BaseScan is a logical operator that represents a scan of a particular root Dataset."""
152
188
 
153
- def __init__(self, datareader: DataReader, output_schema: Schema):
154
- super().__init__(output_schema=output_schema)
155
- self.datareader = datareader
189
+ def __init__(self, datasource: dataset.Dataset, output_schema: type[BaseModel], *args, **kwargs):
190
+ super().__init__(*args, output_schema=output_schema, **kwargs)
191
+ self.datasource = datasource
156
192
 
157
193
  def __str__(self):
158
- return f"BaseScan({self.datareader},{self.output_schema})"
194
+ return f"BaseScan({self.datasource},{self.output_schema})"
159
195
 
160
196
  def __eq__(self, other) -> bool:
161
197
  return (
162
198
  isinstance(other, BaseScan)
163
- and self.input_schema.get_desc() == other.input_schema.get_desc()
164
- and self.output_schema.get_desc() == other.output_schema.get_desc()
165
- and self.datareader == other.datareader
199
+ and self.input_schema == other.input_schema
200
+ and self.output_schema == other.output_schema
201
+ and self.datasource == other.datasource
166
202
  )
167
203
 
168
204
  def get_logical_id_params(self) -> dict:
169
- return super().get_logical_id_params()
205
+ logical_id_params = super().get_logical_id_params()
206
+ logical_id_params = {
207
+ "id": self.datasource.id,
208
+ **logical_id_params,
209
+ }
210
+
211
+ return logical_id_params
170
212
 
171
213
  def get_logical_op_params(self) -> dict:
172
214
  logical_op_params = super().get_logical_op_params()
173
- logical_op_params = {"datareader": self.datareader, **logical_op_params}
215
+ logical_op_params = {"datasource": self.datasource, **logical_op_params}
174
216
 
175
217
  return logical_op_params
176
218
 
177
219
 
178
- class CacheScan(LogicalOperator):
179
- """A CacheScan is a logical operator that represents a scan of a cached Set."""
220
+ class ContextScan(LogicalOperator):
221
+ """A ContextScan is a logical operator that loads the context for a particular root Dataset."""
180
222
 
181
- def __init__(self, datareader: DataReader, output_schema: Schema):
182
- super().__init__(output_schema=output_schema)
183
- self.datareader = datareader
223
+ def __init__(self, context: context.Context, output_schema: type[BaseModel], *args, **kwargs):
224
+ super().__init__(*args, output_schema=output_schema, **kwargs)
225
+ self.context = context
184
226
 
185
227
  def __str__(self):
186
- return f"CacheScan({self.datareader},{self.output_schema})"
228
+ return f"ContextScan({self.context},{self.output_schema})"
229
+
230
+ def __eq__(self, other) -> bool:
231
+ return (
232
+ isinstance(other, ContextScan)
233
+ and self.context.id == other.context.id
234
+ )
187
235
 
188
236
  def get_logical_id_params(self) -> dict:
189
- return super().get_logical_id_params()
237
+ logical_id_params = super().get_logical_id_params()
238
+ logical_id_params = {
239
+ "id": self.context.id,
240
+ **logical_id_params,
241
+ }
242
+
243
+ return logical_id_params
190
244
 
191
245
  def get_logical_op_params(self) -> dict:
192
246
  logical_op_params = super().get_logical_op_params()
193
- logical_op_params = {"datareader": self.datareader, **logical_op_params}
247
+ logical_op_params = {"context": self.context, **logical_op_params}
194
248
 
195
249
  return logical_op_params
196
250
 
197
251
 
198
252
  class ConvertScan(LogicalOperator):
199
- """A ConvertScan is a logical operator that represents a scan of a particular data source, with conversion applied."""
253
+ """A ConvertScan is a logical operator that represents a scan of a particular input Dataset, with conversion applied."""
200
254
 
201
255
  def __init__(
202
256
  self,
203
257
  cardinality: Cardinality = Cardinality.ONE_TO_ONE,
204
258
  udf: Callable | None = None,
205
- depends_on: list[str] | None = None,
206
- desc: str | None = None,
207
- target_cache_id: str | None = None,
208
259
  *args,
209
260
  **kwargs,
210
261
  ):
211
262
  super().__init__(*args, **kwargs)
212
263
  self.cardinality = cardinality
213
264
  self.udf = udf
214
- self.depends_on = [] if depends_on is None else sorted(depends_on)
215
- self.desc = desc
216
- self.target_cache_id = target_cache_id
217
265
 
218
266
  def __str__(self):
219
- return f"ConvertScan({self.input_schema} -> {str(self.output_schema)},{str(self.desc)})"
267
+ return f"ConvertScan({self.input_schema} -> {str(self.output_schema)})"
220
268
 
221
269
  def get_logical_id_params(self) -> dict:
222
270
  logical_id_params = super().get_logical_id_params()
@@ -233,9 +281,40 @@ class ConvertScan(LogicalOperator):
233
281
  logical_op_params = {
234
282
  "cardinality": self.cardinality,
235
283
  "udf": self.udf,
236
- "depends_on": self.depends_on,
237
- "desc": self.desc,
238
- "target_cache_id": self.target_cache_id,
284
+ **logical_op_params,
285
+ }
286
+
287
+ return logical_op_params
288
+
289
+
290
+ class Distinct(LogicalOperator):
291
+ def __init__(self, distinct_cols: list[str] | None, *args, **kwargs):
292
+ super().__init__(*args, **kwargs)
293
+ # if distinct_cols is not None, check that all columns are in the input schema
294
+ if distinct_cols is not None:
295
+ for col in distinct_cols:
296
+ assert col in self.input_schema.model_fields, f"Column {col} not found in input schema {self.input_schema} for Distinct operator"
297
+
298
+ # store the list of distinct columns, sorted
299
+ self.distinct_cols = (
300
+ sorted([field_name for field_name in self.input_schema.model_fields])
301
+ if distinct_cols is None
302
+ else sorted(distinct_cols)
303
+ )
304
+
305
+ def __str__(self):
306
+ return f"Distinct({self.distinct_cols})"
307
+
308
+ def get_logical_id_params(self) -> dict:
309
+ logical_id_params = super().get_logical_id_params()
310
+ logical_id_params = {"distinct_cols": self.distinct_cols, **logical_id_params}
311
+
312
+ return logical_id_params
313
+
314
+ def get_logical_op_params(self) -> dict:
315
+ logical_op_params = super().get_logical_op_params()
316
+ logical_op_params = {
317
+ "distinct_cols": self.distinct_cols,
239
318
  **logical_op_params,
240
319
  }
241
320
 
@@ -243,20 +322,16 @@ class ConvertScan(LogicalOperator):
243
322
 
244
323
 
245
324
  class FilteredScan(LogicalOperator):
246
- """A FilteredScan is a logical operator that represents a scan of a particular data source, with filters applied."""
325
+ """A FilteredScan is a logical operator that represents a scan of a particular input Dataset, with filters applied."""
247
326
 
248
327
  def __init__(
249
328
  self,
250
329
  filter: Filter,
251
- depends_on: list[str] | None = None,
252
- target_cache_id: str | None = None,
253
330
  *args,
254
331
  **kwargs,
255
332
  ):
256
333
  super().__init__(*args, **kwargs)
257
334
  self.filter = filter
258
- self.depends_on = [] if depends_on is None else sorted(depends_on)
259
- self.target_cache_id = target_cache_id
260
335
 
261
336
  def __str__(self):
262
337
  return f"FilteredScan({str(self.output_schema)}, {str(self.filter)})"
@@ -274,8 +349,6 @@ class FilteredScan(LogicalOperator):
274
349
  logical_op_params = super().get_logical_op_params()
275
350
  logical_op_params = {
276
351
  "filter": self.filter,
277
- "depends_on": self.depends_on,
278
- "target_cache_id": self.target_cache_id,
279
352
  **logical_op_params,
280
353
  }
281
354
 
@@ -286,7 +359,6 @@ class GroupByAggregate(LogicalOperator):
286
359
  def __init__(
287
360
  self,
288
361
  group_by_sig: GroupBySig,
289
- target_cache_id: str | None = None,
290
362
  *args,
291
363
  **kwargs,
292
364
  ):
@@ -297,7 +369,6 @@ class GroupByAggregate(LogicalOperator):
297
369
  if not valid:
298
370
  raise TypeError(error)
299
371
  self.group_by_sig = group_by_sig
300
- self.target_cache_id = target_cache_id
301
372
 
302
373
  def __str__(self):
303
374
  return f"GroupBy({self.group_by_sig.serialize()})"
@@ -312,7 +383,30 @@ class GroupByAggregate(LogicalOperator):
312
383
  logical_op_params = super().get_logical_op_params()
313
384
  logical_op_params = {
314
385
  "group_by_sig": self.group_by_sig,
315
- "target_cache_id": self.target_cache_id,
386
+ **logical_op_params,
387
+ }
388
+
389
+ return logical_op_params
390
+
391
+
392
+ class JoinOp(LogicalOperator):
393
+ def __init__(self, condition: str, *args, **kwargs):
394
+ super().__init__(*args, **kwargs)
395
+ self.condition = condition
396
+
397
+ def __str__(self):
398
+ return f"Join(condition={self.condition})"
399
+
400
+ def get_logical_id_params(self) -> dict:
401
+ logical_id_params = super().get_logical_id_params()
402
+ logical_id_params = {"condition": self.condition, **logical_id_params}
403
+
404
+ return logical_id_params
405
+
406
+ def get_logical_op_params(self) -> dict:
407
+ logical_op_params = super().get_logical_op_params()
408
+ logical_op_params = {
409
+ "condition": self.condition,
316
410
  **logical_op_params,
317
411
  }
318
412
 
@@ -320,10 +414,9 @@ class GroupByAggregate(LogicalOperator):
320
414
 
321
415
 
322
416
  class LimitScan(LogicalOperator):
323
- def __init__(self, limit: int, target_cache_id: str | None = None, *args, **kwargs):
417
+ def __init__(self, limit: int, *args, **kwargs):
324
418
  super().__init__(*args, **kwargs)
325
419
  self.limit = limit
326
- self.target_cache_id = target_cache_id
327
420
 
328
421
  def __str__(self):
329
422
  return f"LimitScan({str(self.input_schema)}, {str(self.output_schema)})"
@@ -338,7 +431,6 @@ class LimitScan(LogicalOperator):
338
431
  logical_op_params = super().get_logical_op_params()
339
432
  logical_op_params = {
340
433
  "limit": self.limit,
341
- "target_cache_id": self.target_cache_id,
342
434
  **logical_op_params,
343
435
  }
344
436
 
@@ -346,10 +438,9 @@ class LimitScan(LogicalOperator):
346
438
 
347
439
 
348
440
  class Project(LogicalOperator):
349
- def __init__(self, project_cols: list[str], target_cache_id: str | None = None, *args, **kwargs):
441
+ def __init__(self, project_cols: list[str], *args, **kwargs):
350
442
  super().__init__(*args, **kwargs)
351
443
  self.project_cols = project_cols
352
- self.target_cache_id = target_cache_id
353
444
 
354
445
  def __str__(self):
355
446
  return f"Project({self.input_schema}, {self.project_cols})"
@@ -364,7 +455,6 @@ class Project(LogicalOperator):
364
455
  logical_op_params = super().get_logical_op_params()
365
456
  logical_op_params = {
366
457
  "project_cols": self.project_cols,
367
- "target_cache_id": self.target_cache_id,
368
458
  **logical_op_params,
369
459
  }
370
460
 
@@ -372,7 +462,7 @@ class Project(LogicalOperator):
372
462
 
373
463
 
374
464
  class RetrieveScan(LogicalOperator):
375
- """A RetrieveScan is a logical operator that represents a scan of a particular data source, with a convert-like retrieve applied."""
465
+ """A RetrieveScan is a logical operator that represents a scan of a particular input Dataset, with a convert-like retrieve applied."""
376
466
 
377
467
  def __init__(
378
468
  self,
@@ -381,7 +471,6 @@ class RetrieveScan(LogicalOperator):
381
471
  search_attr,
382
472
  output_attrs,
383
473
  k,
384
- target_cache_id: str = None,
385
474
  *args,
386
475
  **kwargs,
387
476
  ):
@@ -391,10 +480,9 @@ class RetrieveScan(LogicalOperator):
391
480
  self.search_attr = search_attr
392
481
  self.output_attrs = output_attrs
393
482
  self.k = k
394
- self.target_cache_id = target_cache_id
395
483
 
396
484
  def __str__(self):
397
- return f"RetrieveScan({self.input_schema} -> {str(self.output_schema)},{str(self.desc)})"
485
+ return f"RetrieveScan({self.input_schema} -> {str(self.output_schema)})"
398
486
 
399
487
  def get_logical_id_params(self) -> dict:
400
488
  # NOTE: if we allow optimization over index, then we will need to include it in the id params
@@ -418,36 +506,31 @@ class RetrieveScan(LogicalOperator):
418
506
  "search_attr": self.search_attr,
419
507
  "output_attrs": self.output_attrs,
420
508
  "k": self.k,
421
- "target_cache_id": self.target_cache_id,
422
509
  **logical_op_params,
423
510
  }
424
511
 
425
512
  return logical_op_params
426
513
 
427
514
 
428
- # TODO: (near-term) maybe we should try to fold this into ConvertScan, and make the internals of PZ
429
- # amenable to a convert operator (with a UDF) that does not add new columns?
430
- class MapScan(LogicalOperator):
431
- """A MapScan is a logical operator that applies a UDF to each input record without adding new columns."""
515
+ class ComputeOperator(LogicalOperator):
516
+ """
517
+ A ComputeOperator is a logical operator that performs a computation described in natural language
518
+ on a given Context.
519
+ """
432
520
 
433
- def __init__(
434
- self,
435
- udf: Callable | None = None,
436
- target_cache_id: str | None = None,
437
- *args,
438
- **kwargs,
439
- ):
521
+ def __init__(self, context_id: str, instruction: str, *args, **kwargs):
440
522
  super().__init__(*args, **kwargs)
441
- self.udf = udf
442
- self.target_cache_id = target_cache_id
523
+ self.context_id = context_id
524
+ self.instruction = instruction
443
525
 
444
526
  def __str__(self):
445
- return f"MapScan({self.output_schema}, {self.udf.__name__})"
527
+ return f"ComputeOperator(id={self.context_id}, instr={self.instruction:20s})"
446
528
 
447
529
  def get_logical_id_params(self) -> dict:
448
530
  logical_id_params = super().get_logical_id_params()
449
531
  logical_id_params = {
450
- "udf": self.udf,
532
+ "context_id": self.context_id,
533
+ "instruction": self.instruction,
451
534
  **logical_id_params,
452
535
  }
453
536
 
@@ -456,8 +539,43 @@ class MapScan(LogicalOperator):
456
539
  def get_logical_op_params(self) -> dict:
457
540
  logical_op_params = super().get_logical_op_params()
458
541
  logical_op_params = {
459
- "udf": self.udf,
460
- "target_cache_id": self.target_cache_id,
542
+ "context_id": self.context_id,
543
+ "instruction": self.instruction,
544
+ **logical_op_params,
545
+ }
546
+
547
+ return logical_op_params
548
+
549
+
550
+ class SearchOperator(LogicalOperator):
551
+ """
552
+ A SearchOperator is a logical operator that executes a search described in natural language
553
+ on a given Context.
554
+ """
555
+
556
+ def __init__(self, context_id: str, search_query: str, *args, **kwargs):
557
+ super().__init__(*args, **kwargs)
558
+ self.context_id = context_id
559
+ self.search_query = search_query
560
+
561
+ def __str__(self):
562
+ return f"SearchOperator(id={self.context_id}, search_query={self.search_query:20s})"
563
+
564
+ def get_logical_id_params(self) -> dict:
565
+ logical_id_params = super().get_logical_id_params()
566
+ logical_id_params = {
567
+ "context_id": self.context_id,
568
+ "search_query": self.search_query,
569
+ **logical_id_params,
570
+ }
571
+
572
+ return logical_id_params
573
+
574
+ def get_logical_op_params(self) -> dict:
575
+ logical_op_params = super().get_logical_op_params()
576
+ logical_op_params = {
577
+ "context_id": self.context_id,
578
+ "search_query": self.search_query,
461
579
  **logical_op_params,
462
580
  }
463
581
 
@@ -1,10 +1,11 @@
1
1
  from __future__ import annotations
2
2
 
3
+ from pydantic.fields import FieldInfo
4
+
3
5
  from palimpzest.constants import MODEL_CARDS, Model, PromptStrategy
4
- from palimpzest.core.data.dataclasses import GenerationStats, OperatorCostEstimates
5
6
  from palimpzest.core.elements.records import DataRecord
6
- from palimpzest.core.lib.fields import Field
7
- from palimpzest.query.generators.generators import generator_factory
7
+ from palimpzest.core.models import GenerationStats, OperatorCostEstimates
8
+ from palimpzest.query.generators.generators import Generator
8
9
  from palimpzest.query.operators.convert import LLMConvert
9
10
 
10
11
  # TYPE DEFINITIONS
@@ -20,7 +21,6 @@ class MixtureOfAgentsConvert(LLMConvert):
20
21
  aggregator_model: Model,
21
22
  proposer_prompt_strategy: PromptStrategy = PromptStrategy.COT_MOA_PROPOSER,
22
23
  aggregator_prompt_strategy: PromptStrategy = PromptStrategy.COT_MOA_AGG,
23
- proposer_prompt: str | None = None,
24
24
  *args,
25
25
  **kwargs,
26
26
  ):
@@ -33,14 +33,13 @@ class MixtureOfAgentsConvert(LLMConvert):
33
33
  self.aggregator_model = aggregator_model
34
34
  self.proposer_prompt_strategy = proposer_prompt_strategy
35
35
  self.aggregator_prompt_strategy = aggregator_prompt_strategy
36
- self.proposer_prompt = proposer_prompt
37
36
 
38
37
  # create generators
39
38
  self.proposer_generators = [
40
- generator_factory(model, self.proposer_prompt_strategy, self.cardinality, self.verbose)
39
+ Generator(model, self.proposer_prompt_strategy, self.reasoning_effort, self.api_base, self.cardinality, self.verbose)
41
40
  for model in proposer_models
42
41
  ]
43
- self.aggregator_generator = generator_factory(aggregator_model, self.aggregator_prompt_strategy, self.cardinality, self.verbose)
42
+ self.aggregator_generator = Generator(aggregator_model, self.aggregator_prompt_strategy, self.reasoning_effort, self.api_base, self.cardinality, self.verbose)
44
43
 
45
44
  def __str__(self):
46
45
  op = super().__str__()
@@ -77,6 +76,9 @@ class MixtureOfAgentsConvert(LLMConvert):
77
76
 
78
77
  return op_params
79
78
 
79
+ def is_image_conversion(self) -> bool:
80
+ return self.proposer_prompt_strategy.is_image_prompt()
81
+
80
82
  def naive_cost_estimates(self, source_op_cost_estimates: OperatorCostEstimates) -> OperatorCostEstimates:
81
83
  """
82
84
  Currently, we are using multiple proposer models with different temperatures to synthesize
@@ -111,7 +113,7 @@ class MixtureOfAgentsConvert(LLMConvert):
111
113
 
112
114
  return naive_op_cost_estimates
113
115
 
114
- def convert(self, candidate: DataRecord, fields: dict[str, Field]) -> tuple[dict[str, list], GenerationStats]:
116
+ def convert(self, candidate: DataRecord, fields: dict[str, FieldInfo]) -> tuple[dict[str, list], GenerationStats]:
115
117
  # get input fields
116
118
  input_fields = self.get_input_fields()
117
119