palimpzest 0.7.21__py3-none-any.whl → 0.8.1__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.
Files changed (89)
  1. palimpzest/__init__.py +37 -6
  2. palimpzest/agents/__init__.py +0 -0
  3. palimpzest/agents/compute_agents.py +0 -0
  4. palimpzest/agents/search_agents.py +637 -0
  5. palimpzest/constants.py +343 -209
  6. palimpzest/core/data/context.py +393 -0
  7. palimpzest/core/data/context_manager.py +163 -0
  8. palimpzest/core/data/dataset.py +639 -0
  9. palimpzest/core/data/{datareaders.py → iter_dataset.py} +202 -126
  10. palimpzest/core/elements/groupbysig.py +16 -13
  11. palimpzest/core/elements/records.py +166 -75
  12. palimpzest/core/lib/schemas.py +152 -390
  13. palimpzest/core/{data/dataclasses.py → models.py} +306 -170
  14. palimpzest/policy.py +2 -27
  15. palimpzest/prompts/__init__.py +35 -5
  16. palimpzest/prompts/agent_prompts.py +357 -0
  17. palimpzest/prompts/context_search.py +9 -0
  18. palimpzest/prompts/convert_prompts.py +62 -6
  19. palimpzest/prompts/filter_prompts.py +51 -6
  20. palimpzest/prompts/join_prompts.py +163 -0
  21. palimpzest/prompts/moa_proposer_convert_prompts.py +6 -6
  22. palimpzest/prompts/prompt_factory.py +375 -47
  23. palimpzest/prompts/split_proposer_prompts.py +1 -1
  24. palimpzest/prompts/util_phrases.py +5 -0
  25. palimpzest/prompts/validator.py +239 -0
  26. palimpzest/query/execution/all_sample_execution_strategy.py +134 -76
  27. palimpzest/query/execution/execution_strategy.py +210 -317
  28. palimpzest/query/execution/execution_strategy_type.py +5 -7
  29. palimpzest/query/execution/mab_execution_strategy.py +249 -136
  30. palimpzest/query/execution/parallel_execution_strategy.py +153 -244
  31. palimpzest/query/execution/single_threaded_execution_strategy.py +107 -64
  32. palimpzest/query/generators/generators.py +160 -331
  33. palimpzest/query/operators/__init__.py +15 -5
  34. palimpzest/query/operators/aggregate.py +50 -33
  35. palimpzest/query/operators/compute.py +201 -0
  36. palimpzest/query/operators/convert.py +33 -19
  37. palimpzest/query/operators/critique_and_refine_convert.py +7 -5
  38. palimpzest/query/operators/distinct.py +62 -0
  39. palimpzest/query/operators/filter.py +26 -16
  40. palimpzest/query/operators/join.py +403 -0
  41. palimpzest/query/operators/limit.py +3 -3
  42. palimpzest/query/operators/logical.py +205 -77
  43. palimpzest/query/operators/mixture_of_agents_convert.py +10 -8
  44. palimpzest/query/operators/physical.py +27 -21
  45. palimpzest/query/operators/project.py +3 -3
  46. palimpzest/query/operators/rag_convert.py +7 -7
  47. palimpzest/query/operators/retrieve.py +9 -9
  48. palimpzest/query/operators/scan.py +81 -42
  49. palimpzest/query/operators/search.py +524 -0
  50. palimpzest/query/operators/split_convert.py +10 -8
  51. palimpzest/query/optimizer/__init__.py +7 -9
  52. palimpzest/query/optimizer/cost_model.py +108 -441
  53. palimpzest/query/optimizer/optimizer.py +123 -181
  54. palimpzest/query/optimizer/optimizer_strategy.py +66 -61
  55. palimpzest/query/optimizer/plan.py +352 -67
  56. palimpzest/query/optimizer/primitives.py +43 -19
  57. palimpzest/query/optimizer/rules.py +484 -646
  58. palimpzest/query/optimizer/tasks.py +127 -58
  59. palimpzest/query/processor/config.py +42 -76
  60. palimpzest/query/processor/query_processor.py +73 -18
  61. palimpzest/query/processor/query_processor_factory.py +46 -38
  62. palimpzest/schemabuilder/schema_builder.py +15 -28
  63. palimpzest/utils/model_helpers.py +32 -77
  64. palimpzest/utils/progress.py +114 -102
  65. palimpzest/validator/__init__.py +0 -0
  66. palimpzest/validator/validator.py +306 -0
  67. {palimpzest-0.7.21.dist-info → palimpzest-0.8.1.dist-info}/METADATA +6 -1
  68. palimpzest-0.8.1.dist-info/RECORD +95 -0
  69. palimpzest/core/lib/fields.py +0 -141
  70. palimpzest/prompts/code_synthesis_prompts.py +0 -28
  71. palimpzest/query/execution/random_sampling_execution_strategy.py +0 -240
  72. palimpzest/query/generators/api_client_factory.py +0 -30
  73. palimpzest/query/operators/code_synthesis_convert.py +0 -488
  74. palimpzest/query/operators/map.py +0 -130
  75. palimpzest/query/processor/nosentinel_processor.py +0 -33
  76. palimpzest/query/processor/processing_strategy_type.py +0 -28
  77. palimpzest/query/processor/sentinel_processor.py +0 -88
  78. palimpzest/query/processor/streaming_processor.py +0 -149
  79. palimpzest/sets.py +0 -405
  80. palimpzest/utils/datareader_helpers.py +0 -61
  81. palimpzest/utils/demo_helpers.py +0 -75
  82. palimpzest/utils/field_helpers.py +0 -69
  83. palimpzest/utils/generation_helpers.py +0 -69
  84. palimpzest/utils/sandbox.py +0 -183
  85. palimpzest-0.7.21.dist-info/RECORD +0 -95
  86. /palimpzest/core/{elements/index.py → data/index_dataset.py} +0 -0
  87. {palimpzest-0.7.21.dist-info → palimpzest-0.8.1.dist-info}/WHEEL +0 -0
  88. {palimpzest-0.7.21.dist-info → palimpzest-0.8.1.dist-info}/licenses/LICENSE +0 -0
  89. {palimpzest-0.7.21.dist-info → palimpzest-0.8.1.dist-info}/top_level.txt +0 -0
palimpzest/query/operators/logical.py
@@ -3,11 +3,13 @@ from __future__ import annotations
 import json
 from typing import Callable
 
+from pydantic import BaseModel
+
 from palimpzest.constants import AggFunc, Cardinality
-from palimpzest.core.data.datareaders import DataReader
+from palimpzest.core.data import context, dataset
 from palimpzest.core.elements.filters import Filter
 from palimpzest.core.elements.groupbysig import GroupBySig
-from palimpzest.core.lib.schemas import Schema
+from palimpzest.core.lib.schemas import Average, Count
 from palimpzest.utils.hash_helpers import hash_for_id
 
 
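A recurring change in this diff: operator schemas are now plain pydantic models rather than palimpzest's former Schema/Field classes. A minimal sketch of what a 0.8.x-style output schema looks like (the model, field names, and descriptions here are illustrative, not from the package):

    from pydantic import BaseModel, Field

    # Illustrative schema: any pydantic model class can serve as an
    # operator's output_schema under the new type[BaseModel] annotations.
    class EmailSummary(BaseModel):
        sender: str = Field(description="email address of the sender")
        subject: str = Field(description="subject line of the email")

    # Logical operators read field names off the class, not an instance:
    print(list(EmailSummary.model_fields))  # ['sender', 'subject']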
@@ -16,8 +18,8 @@ class LogicalOperator:
     A logical operator is an operator that operates on Sets.
 
     Right now it can be one of:
-    - BaseScan (scans data from DataReader)
-    - CacheScan (scans cached Set)
+    - BaseScan (scans data from a root Dataset)
+    - ContextScan (loads the context for a root Dataset)
     - FilteredScan (scans input Set and applies filter)
     - ConvertScan (scans input Set and converts it to new Schema)
     - LimitScan (scans up to N records from a Set)
@@ -25,6 +27,8 @@ class LogicalOperator:
     - Aggregate (applies an aggregation on the Set)
     - RetrieveScan (fetches documents from a provided input for a given query)
     - Map (applies a function to each record in the Set without adding any new columns)
+    - ComputeOperator (executes a computation described in natural language)
+    - SearchOperator (executes a search query on the input Context)
 
     Every logical operator must declare the get_logical_id_params() and get_logical_op_params() methods,
     which return dictionaries of parameters that are used to compute the logical op id and to implement
@@ -33,17 +37,21 @@ class LogicalOperator:
 
     def __init__(
         self,
-        output_schema: Schema,
-        input_schema: Schema | None = None,
+        output_schema: type[BaseModel],
+        input_schema: type[BaseModel] | None = None,
+        depends_on: list[str] | None = None,
     ):
+        # TODO: can we eliminate input_schema?
         self.output_schema = output_schema
        self.input_schema = input_schema
+        self.depends_on = [] if depends_on is None else sorted(depends_on)
         self.logical_op_id: str | None = None
+        self.unique_logical_op_id: str | None = None
 
         # compute the fields generated by this logical operator
-        input_field_names = self.input_schema.field_names() if self.input_schema is not None else []
+        input_field_names = list(self.input_schema.model_fields) if self.input_schema is not None else []
         self.generated_fields = sorted(
-            [field_name for field_name in self.output_schema.field_names() if field_name not in input_field_names]
+            [field_name for field_name in self.output_schema.model_fields if field_name not in input_field_names]
         )
 
     def __str__(self) -> str:
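To see what the rewritten generated_fields computation does, here is the same set difference run on two standalone pydantic models (the model and field names are hypothetical; the logic mirrors the __init__ above):

    from pydantic import BaseModel

    class InputSchema(BaseModel):
        filename: str
        contents: str

    class OutputSchema(BaseModel):
        filename: str
        contents: str
        summary: str  # the only field this operator would generate

    # mirrors LogicalOperator.__init__: model_fields keys minus input fields
    input_field_names = list(InputSchema.model_fields)
    generated_fields = sorted(
        f for f in OutputSchema.model_fields if f not in input_field_names
    )
    print(generated_fields)  # ['summary']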
@@ -54,12 +62,28 @@ class LogicalOperator:
         return isinstance(other, self.__class__) and all_id_params_match
 
     def copy(self) -> LogicalOperator:
-        return self.__class__(**self.get_logical_op_params())
+        logical_op_copy = self.__class__(**self.get_logical_op_params())
+        logical_op_copy.logical_op_id = self.logical_op_id
+        logical_op_copy.unique_logical_op_id = self.unique_logical_op_id
+        return logical_op_copy
 
     def logical_op_name(self) -> str:
         """Name of the logical operator."""
         return str(self.__class__.__name__)
 
+    def get_unique_logical_op_id(self) -> str:
+        """
+        Get the unique logical operator id for this logical operator.
+        """
+        return self.unique_logical_op_id
+
+    def set_unique_logical_op_id(self, unique_logical_op_id: str) -> None:
+        """
+        Set the unique logical operator id for this logical operator.
+        This is used to uniquely identify the logical operator in the query plan.
+        """
+        self.unique_logical_op_id = unique_logical_op_id
+
     def get_logical_id_params(self) -> dict:
         """
         Returns a dictionary mapping of logical operator parameters which are relevant
@@ -69,6 +93,7 @@ class LogicalOperator:
         NOTE: input_schema and output_schema are not included in the id params because
         they depend on how the Optimizer orders operations.
         """
+        # TODO: should we use `generated_fields` after getting rid of them in PhysicalOperator?
         return {"generated_fields": self.generated_fields}
 
     def get_logical_op_params(self) -> dict:
@@ -78,10 +103,16 @@ class LogicalOperator:
 
         NOTE: Should be overriden by subclasses to include class-specific parameters.
         """
-        return {"input_schema": self.input_schema, "output_schema": self.output_schema}
+        return {
+            "input_schema": self.input_schema,
+            "output_schema": self.output_schema,
+            "depends_on": self.depends_on,
+        }
 
     def get_logical_op_id(self):
         """
+        TODO: turn this into a property?
+
         NOTE: We do not call this in the __init__() method as subclasses may set parameters
         returned by self.get_logical_op_params() after they call to super().__init__().
         """
@@ -119,13 +150,19 @@ class Aggregate(LogicalOperator):
     def __init__(
         self,
         agg_func: AggFunc,
-        target_cache_id: str | None = None,
         *args,
         **kwargs,
     ):
+        if kwargs.get("output_schema") is None:
+            if agg_func == AggFunc.COUNT:
+                kwargs["output_schema"] = Count
+            elif agg_func == AggFunc.AVERAGE:
+                kwargs["output_schema"] = Average
+            else:
+                raise ValueError(f"Unsupported aggregation function: {agg_func}")
+
         super().__init__(*args, **kwargs)
         self.agg_func = agg_func
-        self.target_cache_id = target_cache_id
 
     def __str__(self):
         return f"{self.__class__.__name__}(function: {str(self.agg_func.value)})"
@@ -140,7 +177,6 @@ class Aggregate(LogicalOperator):
         logical_op_params = super().get_logical_op_params()
         logical_op_params = {
             "agg_func": self.agg_func,
-            "target_cache_id": self.target_cache_id,
             **logical_op_params,
         }
 
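Aggregate now derives a default output schema from the aggregation function rather than taking a target_cache_id. A standalone mirror of that dispatch, using stand-in Count and Average models (the real ones live in palimpzest.core.lib.schemas; their exact fields are not shown in this diff, so the field names below are assumptions):

    from enum import Enum
    from pydantic import BaseModel

    class AggFunc(str, Enum):   # stand-in for palimpzest.constants.AggFunc
        COUNT = "count"
        AVERAGE = "average"

    class Count(BaseModel):     # stand-in; field name assumed
        count: int

    class Average(BaseModel):   # stand-in; field name assumed
        average: float

    def default_output_schema(agg_func: AggFunc) -> type[BaseModel]:
        # mirrors the kwargs.get("output_schema") fallback in Aggregate.__init__
        if agg_func == AggFunc.COUNT:
            return Count
        if agg_func == AggFunc.AVERAGE:
            return Average
        raise ValueError(f"Unsupported aggregation function: {agg_func}")

    print(default_output_schema(AggFunc.COUNT).__name__)  # Count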
@@ -148,81 +184,96 @@
 
 
 class BaseScan(LogicalOperator):
-    """A BaseScan is a logical operator that represents a scan of a particular data source."""
+    """A BaseScan is a logical operator that represents a scan of a particular root Dataset."""
 
-    def __init__(self, datareader: DataReader, output_schema: Schema):
-        super().__init__(output_schema=output_schema)
-        self.datareader = datareader
+    def __init__(self, datasource: dataset.Dataset, output_schema: type[BaseModel], *args, **kwargs):
+        super().__init__(*args, output_schema=output_schema, **kwargs)
+        self.datasource = datasource
 
     def __str__(self):
-        return f"BaseScan({self.datareader},{self.output_schema})"
+        return f"BaseScan({self.datasource},{self.output_schema})"
 
     def __eq__(self, other) -> bool:
         return (
             isinstance(other, BaseScan)
-            and self.input_schema.get_desc() == other.input_schema.get_desc()
-            and self.output_schema.get_desc() == other.output_schema.get_desc()
-            and self.datareader == other.datareader
+            and self.input_schema == other.input_schema
+            and self.output_schema == other.output_schema
+            and self.datasource == other.datasource
         )
 
     def get_logical_id_params(self) -> dict:
-        return super().get_logical_id_params()
+        logical_id_params = super().get_logical_id_params()
+        logical_id_params = {
+            "id": self.datasource.id,
+            **logical_id_params,
+        }
+
+        return logical_id_params
 
     def get_logical_op_params(self) -> dict:
         logical_op_params = super().get_logical_op_params()
-        logical_op_params = {"datareader": self.datareader, **logical_op_params}
+        logical_op_params = {"datasource": self.datasource, **logical_op_params}
 
         return logical_op_params
 
 
-class CacheScan(LogicalOperator):
-    """A CacheScan is a logical operator that represents a scan of a cached Set."""
+class ContextScan(LogicalOperator):
+    """A ContextScan is a logical operator that loads the context for a particular root Dataset."""
 
-    def __init__(self, datareader: DataReader, output_schema: Schema):
-        super().__init__(output_schema=output_schema)
-        self.datareader = datareader
+    def __init__(self, context: context.Context, output_schema: type[BaseModel], *args, **kwargs):
+        super().__init__(*args, output_schema=output_schema, **kwargs)
+        self.context = context
 
     def __str__(self):
-        return f"CacheScan({self.datareader},{self.output_schema})"
+        return f"ContextScan({self.context},{self.output_schema})"
+
+    def __eq__(self, other) -> bool:
+        return (
+            isinstance(other, ContextScan)
+            and self.context.id == other.context.id
+        )
 
     def get_logical_id_params(self) -> dict:
-        return super().get_logical_id_params()
+        logical_id_params = super().get_logical_id_params()
+        logical_id_params = {
+            "id": self.context.id,
+            **logical_id_params,
+        }
+
+        return logical_id_params
 
     def get_logical_op_params(self) -> dict:
         logical_op_params = super().get_logical_op_params()
-        logical_op_params = {"datareader": self.datareader, **logical_op_params}
+        logical_op_params = {"context": self.context, **logical_op_params}
 
         return logical_op_params
 
 
 class ConvertScan(LogicalOperator):
-    """A ConvertScan is a logical operator that represents a scan of a particular data source, with conversion applied."""
+    """A ConvertScan is a logical operator that represents a scan of a particular input Dataset, with conversion applied."""
 
     def __init__(
         self,
         cardinality: Cardinality = Cardinality.ONE_TO_ONE,
         udf: Callable | None = None,
-        depends_on: list[str] | None = None,
         desc: str | None = None,
-        target_cache_id: str | None = None,
         *args,
         **kwargs,
     ):
         super().__init__(*args, **kwargs)
         self.cardinality = cardinality
         self.udf = udf
-        self.depends_on = [] if depends_on is None else sorted(depends_on)
         self.desc = desc
-        self.target_cache_id = target_cache_id
 
     def __str__(self):
-        return f"ConvertScan({self.input_schema} -> {str(self.output_schema)},{str(self.desc)})"
+        return f"ConvertScan({self.input_schema} -> {str(self.output_schema)})"
 
     def get_logical_id_params(self) -> dict:
         logical_id_params = super().get_logical_id_params()
         logical_id_params = {
             "cardinality": self.cardinality,
             "udf": self.udf,
+            "desc": self.desc,
             **logical_id_params,
         }
 
@@ -233,9 +284,41 @@ class ConvertScan(LogicalOperator):
         logical_op_params = {
             "cardinality": self.cardinality,
             "udf": self.udf,
-            "depends_on": self.depends_on,
             "desc": self.desc,
-            "target_cache_id": self.target_cache_id,
+            **logical_op_params,
+        }
+
+        return logical_op_params
+
+
+class Distinct(LogicalOperator):
+    def __init__(self, distinct_cols: list[str] | None, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # if distinct_cols is not None, check that all columns are in the input schema
+        if distinct_cols is not None:
+            for col in distinct_cols:
+                assert col in self.input_schema.model_fields, f"Column {col} not found in input schema {self.input_schema} for Distinct operator"
+
+        # store the list of distinct columns, sorted
+        self.distinct_cols = (
+            sorted([field_name for field_name in self.input_schema.model_fields])
+            if distinct_cols is None
+            else sorted(distinct_cols)
+        )
+
+    def __str__(self):
+        return f"Distinct({self.distinct_cols})"
+
+    def get_logical_id_params(self) -> dict:
+        logical_id_params = super().get_logical_id_params()
+        logical_id_params = {"distinct_cols": self.distinct_cols, **logical_id_params}
+
+        return logical_id_params
+
+    def get_logical_op_params(self) -> dict:
+        logical_op_params = super().get_logical_op_params()
+        logical_op_params = {
+            "distinct_cols": self.distinct_cols,
             **logical_op_params,
         }
 
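The column handling in the new Distinct operator (validate against the input schema, default to every input field, always sort) can be exercised on its own; the schema below is illustrative:

    from pydantic import BaseModel

    class InputSchema(BaseModel):
        city: str
        state: str
        population: int

    def normalize_distinct_cols(
        input_schema: type[BaseModel], distinct_cols: list[str] | None
    ) -> list[str]:
        # mirrors Distinct.__init__: validate, then default to all fields, sorted
        if distinct_cols is not None:
            for col in distinct_cols:
                assert col in input_schema.model_fields, f"Column {col} not found"
            return sorted(distinct_cols)
        return sorted(input_schema.model_fields)

    print(normalize_distinct_cols(InputSchema, None))       # ['city', 'population', 'state']
    print(normalize_distinct_cols(InputSchema, ["state"]))  # ['state']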
@@ -243,20 +326,18 @@ class ConvertScan(LogicalOperator):
 
 
 class FilteredScan(LogicalOperator):
-    """A FilteredScan is a logical operator that represents a scan of a particular data source, with filters applied."""
+    """A FilteredScan is a logical operator that represents a scan of a particular input Dataset, with filters applied."""
 
     def __init__(
         self,
         filter: Filter,
-        depends_on: list[str] | None = None,
-        target_cache_id: str | None = None,
+        desc: str | None = None,
         *args,
         **kwargs,
     ):
         super().__init__(*args, **kwargs)
         self.filter = filter
-        self.depends_on = [] if depends_on is None else sorted(depends_on)
-        self.target_cache_id = target_cache_id
+        self.desc = desc
 
     def __str__(self):
         return f"FilteredScan({str(self.output_schema)}, {str(self.filter)})"
@@ -265,6 +346,7 @@ class FilteredScan(LogicalOperator):
         logical_id_params = super().get_logical_id_params()
         logical_id_params = {
             "filter": self.filter,
+            "desc": self.desc,
             **logical_id_params,
         }
 
@@ -274,8 +356,7 @@ class FilteredScan(LogicalOperator):
         logical_op_params = super().get_logical_op_params()
         logical_op_params = {
             "filter": self.filter,
-            "depends_on": self.depends_on,
-            "target_cache_id": self.target_cache_id,
+            "desc": self.desc,
             **logical_op_params,
         }
 
@@ -286,7 +367,6 @@ class GroupByAggregate(LogicalOperator):
     def __init__(
         self,
         group_by_sig: GroupBySig,
-        target_cache_id: str | None = None,
         *args,
         **kwargs,
     ):
@@ -297,7 +377,6 @@ class GroupByAggregate(LogicalOperator):
         if not valid:
             raise TypeError(error)
         self.group_by_sig = group_by_sig
-        self.target_cache_id = target_cache_id
 
     def __str__(self):
         return f"GroupBy({self.group_by_sig.serialize()})"
@@ -312,7 +391,32 @@ class GroupByAggregate(LogicalOperator):
         logical_op_params = super().get_logical_op_params()
         logical_op_params = {
             "group_by_sig": self.group_by_sig,
-            "target_cache_id": self.target_cache_id,
+            **logical_op_params,
+        }
+
+        return logical_op_params
+
+
+class JoinOp(LogicalOperator):
+    def __init__(self, condition: str, desc: str | None = None, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.condition = condition
+        self.desc = desc
+
+    def __str__(self):
+        return f"Join(condition={self.condition})"
+
+    def get_logical_id_params(self) -> dict:
+        logical_id_params = super().get_logical_id_params()
+        logical_id_params = {"condition": self.condition, "desc": self.desc, **logical_id_params}
+
+        return logical_id_params
+
+    def get_logical_op_params(self) -> dict:
+        logical_op_params = super().get_logical_op_params()
+        logical_op_params = {
+            "condition": self.condition,
+            "desc": self.desc,
             **logical_op_params,
         }
 
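JoinOp is new in this release and carries a natural-language join condition rather than a key equality. If palimpzest 0.8.1 is installed, it can be instantiated directly; the schema and condition below are made up for illustration:

    from pydantic import BaseModel

    from palimpzest.query.operators.logical import JoinOp

    class Paper(BaseModel):
        title: str
        abstract: str

    join_op = JoinOp(
        condition="the two abstracts study the same disease",  # natural-language predicate
        output_schema=Paper,
        input_schema=Paper,
    )
    print(join_op)                          # Join(condition=...)
    print(join_op.get_logical_op_params())  # condition, desc, depends_on, schemas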
@@ -320,10 +424,9 @@ class GroupByAggregate(LogicalOperator):
 
 
 class LimitScan(LogicalOperator):
-    def __init__(self, limit: int, target_cache_id: str | None = None, *args, **kwargs):
+    def __init__(self, limit: int, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.limit = limit
-        self.target_cache_id = target_cache_id
 
     def __str__(self):
         return f"LimitScan({str(self.input_schema)}, {str(self.output_schema)})"
@@ -338,7 +441,6 @@
         logical_op_params = super().get_logical_op_params()
         logical_op_params = {
             "limit": self.limit,
-            "target_cache_id": self.target_cache_id,
             **logical_op_params,
         }
 
@@ -346,10 +448,9 @@
 
 
 class Project(LogicalOperator):
-    def __init__(self, project_cols: list[str], target_cache_id: str | None = None, *args, **kwargs):
+    def __init__(self, project_cols: list[str], *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.project_cols = project_cols
-        self.target_cache_id = target_cache_id
 
     def __str__(self):
         return f"Project({self.input_schema}, {self.project_cols})"
@@ -364,7 +465,6 @@
         logical_op_params = super().get_logical_op_params()
         logical_op_params = {
             "project_cols": self.project_cols,
-            "target_cache_id": self.target_cache_id,
             **logical_op_params,
         }
 
@@ -372,7 +472,7 @@
 
 
 class RetrieveScan(LogicalOperator):
-    """A RetrieveScan is a logical operator that represents a scan of a particular data source, with a convert-like retrieve applied."""
+    """A RetrieveScan is a logical operator that represents a scan of a particular input Dataset, with a convert-like retrieve applied."""
 
     def __init__(
         self,
@@ -381,7 +481,6 @@ class RetrieveScan(LogicalOperator):
         search_attr,
         output_attrs,
         k,
-        target_cache_id: str = None,
         *args,
         **kwargs,
     ):
@@ -391,10 +490,9 @@ class RetrieveScan(LogicalOperator):
         self.search_attr = search_attr
         self.output_attrs = output_attrs
         self.k = k
-        self.target_cache_id = target_cache_id
 
     def __str__(self):
-        return f"RetrieveScan({self.input_schema} -> {str(self.output_schema)},{str(self.desc)})"
+        return f"RetrieveScan({self.input_schema} -> {str(self.output_schema)})"
 
     def get_logical_id_params(self) -> dict:
         # NOTE: if we allow optimization over index, then we will need to include it in the id params
@@ -418,36 +516,31 @@ class RetrieveScan(LogicalOperator):
             "search_attr": self.search_attr,
             "output_attrs": self.output_attrs,
             "k": self.k,
-            "target_cache_id": self.target_cache_id,
             **logical_op_params,
         }
 
         return logical_op_params
 
 
-# TODO: (near-term) maybe we should try to fold this into ConvertScan, and make the internals of PZ
-# amenable to a convert operator (with a UDF) that does not add new columns?
-class MapScan(LogicalOperator):
-    """A MapScan is a logical operator that applies a UDF to each input record without adding new columns."""
+class ComputeOperator(LogicalOperator):
+    """
+    A ComputeOperator is a logical operator that performs a computation described in natural language
+    on a given Context.
+    """
 
-    def __init__(
-        self,
-        udf: Callable | None = None,
-        target_cache_id: str | None = None,
-        *args,
-        **kwargs,
-    ):
+    def __init__(self, context_id: str, instruction: str, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.udf = udf
-        self.target_cache_id = target_cache_id
+        self.context_id = context_id
+        self.instruction = instruction
 
     def __str__(self):
-        return f"MapScan({self.output_schema}, {self.udf.__name__})"
+        return f"ComputeOperator(id={self.context_id}, instr={self.instruction:20s})"
 
     def get_logical_id_params(self) -> dict:
         logical_id_params = super().get_logical_id_params()
         logical_id_params = {
-            "udf": self.udf,
+            "context_id": self.context_id,
+            "instruction": self.instruction,
             **logical_id_params,
         }
 
@@ -456,8 +549,43 @@ class MapScan(LogicalOperator):
     def get_logical_op_params(self) -> dict:
         logical_op_params = super().get_logical_op_params()
         logical_op_params = {
-            "udf": self.udf,
-            "target_cache_id": self.target_cache_id,
+            "context_id": self.context_id,
+            "instruction": self.instruction,
+            **logical_op_params,
+        }
+
+        return logical_op_params
+
+
+class SearchOperator(LogicalOperator):
+    """
+    A SearchOperator is a logical operator that executes a search described in natural language
+    on a given Context.
+    """
+
+    def __init__(self, context_id: str, search_query: str, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.context_id = context_id
+        self.search_query = search_query
+
+    def __str__(self):
+        return f"SearchOperator(id={self.context_id}, search_query={self.search_query:20s})"
+
+    def get_logical_id_params(self) -> dict:
+        logical_id_params = super().get_logical_id_params()
+        logical_id_params = {
+            "context_id": self.context_id,
+            "search_query": self.search_query,
+            **logical_id_params,
+        }
+
+        return logical_id_params
+
+    def get_logical_op_params(self) -> dict:
+        logical_op_params = super().get_logical_op_params()
+        logical_op_params = {
+            "context_id": self.context_id,
+            "search_query": self.search_query,
             **logical_op_params,
         }
 
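One contract worth noting across all of these classes: each subclass folds its constructor arguments into get_logical_op_params(), which is exactly what LogicalOperator.copy() unpacks to rebuild the operator before carrying the ids over. A stripped-down illustration of that round-trip (not palimpzest code):

    class Op:
        def __init__(self, limit: int):
            self.limit = limit
            self.logical_op_id: str | None = None

        def get_logical_op_params(self) -> dict:
            return {"limit": self.limit}

        def copy(self) -> "Op":
            # mirrors LogicalOperator.copy(): rebuild from params, then keep ids
            op_copy = self.__class__(**self.get_logical_op_params())
            op_copy.logical_op_id = self.logical_op_id
            return op_copy

    op = Op(limit=5)
    op.logical_op_id = "op-123"
    assert op.copy().logical_op_id == "op-123"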
palimpzest/query/operators/mixture_of_agents_convert.py
@@ -1,10 +1,11 @@
 from __future__ import annotations
 
+from pydantic.fields import FieldInfo
+
 from palimpzest.constants import MODEL_CARDS, Model, PromptStrategy
-from palimpzest.core.data.dataclasses import GenerationStats, OperatorCostEstimates
 from palimpzest.core.elements.records import DataRecord
-from palimpzest.core.lib.fields import Field
-from palimpzest.query.generators.generators import generator_factory
+from palimpzest.core.models import GenerationStats, OperatorCostEstimates
+from palimpzest.query.generators.generators import Generator
 from palimpzest.query.operators.convert import LLMConvert
 
 # TYPE DEFINITIONS
@@ -20,7 +21,6 @@ class MixtureOfAgentsConvert(LLMConvert):
         aggregator_model: Model,
         proposer_prompt_strategy: PromptStrategy = PromptStrategy.COT_MOA_PROPOSER,
         aggregator_prompt_strategy: PromptStrategy = PromptStrategy.COT_MOA_AGG,
-        proposer_prompt: str | None = None,
         *args,
         **kwargs,
     ):
@@ -33,14 +33,13 @@ class MixtureOfAgentsConvert(LLMConvert):
         self.aggregator_model = aggregator_model
         self.proposer_prompt_strategy = proposer_prompt_strategy
         self.aggregator_prompt_strategy = aggregator_prompt_strategy
-        self.proposer_prompt = proposer_prompt
 
         # create generators
         self.proposer_generators = [
-            generator_factory(model, self.proposer_prompt_strategy, self.cardinality, self.verbose)
+            Generator(model, self.proposer_prompt_strategy, self.reasoning_effort, self.api_base, self.cardinality, self.desc, self.verbose)
             for model in proposer_models
         ]
-        self.aggregator_generator = generator_factory(aggregator_model, self.aggregator_prompt_strategy, self.cardinality, self.verbose)
+        self.aggregator_generator = Generator(aggregator_model, self.aggregator_prompt_strategy, self.reasoning_effort, self.api_base, self.cardinality, self.desc, self.verbose)
 
     def __str__(self):
         op = super().__str__()
@@ -77,6 +76,9 @@ class MixtureOfAgentsConvert(LLMConvert):
 
         return op_params
 
+    def is_image_conversion(self) -> bool:
+        return self.proposer_prompt_strategy.is_image_prompt()
+
 
     def naive_cost_estimates(self, source_op_cost_estimates: OperatorCostEstimates) -> OperatorCostEstimates:
         """
         Currently, we are using multiple proposer models with different temperatures to synthesize
@@ -111,7 +113,7 @@
 
         return naive_op_cost_estimates
 
-    def convert(self, candidate: DataRecord, fields: dict[str, Field]) -> tuple[dict[str, list], GenerationStats]:
+    def convert(self, candidate: DataRecord, fields: dict[str, FieldInfo]) -> tuple[dict[str, list], GenerationStats]:
         # get input fields
         input_fields = self.get_input_fields()
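The Field → FieldInfo swap in convert()'s signature lines up with pydantic v2, where Model.model_fields is already a dict[str, FieldInfo]; a small self-contained check (the model here is illustrative):

    from pydantic import BaseModel, Field
    from pydantic.fields import FieldInfo

    class Email(BaseModel):
        sender: str = Field(description="email address of the sender")

    fields: dict[str, FieldInfo] = dict(Email.model_fields)
    print(fields["sender"].annotation)   # <class 'str'>
    print(fields["sender"].description)  # email address of the sender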