palimpzest 0.7.20__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- palimpzest/__init__.py +37 -6
- palimpzest/agents/__init__.py +0 -0
- palimpzest/agents/compute_agents.py +0 -0
- palimpzest/agents/search_agents.py +637 -0
- palimpzest/constants.py +259 -197
- palimpzest/core/data/context.py +393 -0
- palimpzest/core/data/context_manager.py +163 -0
- palimpzest/core/data/dataset.py +634 -0
- palimpzest/core/data/{datareaders.py → iter_dataset.py} +202 -126
- palimpzest/core/elements/groupbysig.py +16 -13
- palimpzest/core/elements/records.py +166 -75
- palimpzest/core/lib/schemas.py +152 -390
- palimpzest/core/{data/dataclasses.py → models.py} +306 -170
- palimpzest/policy.py +2 -27
- palimpzest/prompts/__init__.py +35 -5
- palimpzest/prompts/agent_prompts.py +357 -0
- palimpzest/prompts/context_search.py +9 -0
- palimpzest/prompts/convert_prompts.py +61 -5
- palimpzest/prompts/filter_prompts.py +50 -5
- palimpzest/prompts/join_prompts.py +163 -0
- palimpzest/prompts/moa_proposer_convert_prompts.py +5 -5
- palimpzest/prompts/prompt_factory.py +358 -46
- palimpzest/prompts/validator.py +239 -0
- palimpzest/query/execution/all_sample_execution_strategy.py +134 -76
- palimpzest/query/execution/execution_strategy.py +210 -317
- palimpzest/query/execution/execution_strategy_type.py +5 -7
- palimpzest/query/execution/mab_execution_strategy.py +249 -136
- palimpzest/query/execution/parallel_execution_strategy.py +153 -244
- palimpzest/query/execution/single_threaded_execution_strategy.py +107 -64
- palimpzest/query/generators/generators.py +157 -330
- palimpzest/query/operators/__init__.py +15 -5
- palimpzest/query/operators/aggregate.py +50 -33
- palimpzest/query/operators/compute.py +201 -0
- palimpzest/query/operators/convert.py +27 -21
- palimpzest/query/operators/critique_and_refine_convert.py +7 -5
- palimpzest/query/operators/distinct.py +62 -0
- palimpzest/query/operators/filter.py +22 -13
- palimpzest/query/operators/join.py +402 -0
- palimpzest/query/operators/limit.py +3 -3
- palimpzest/query/operators/logical.py +198 -80
- palimpzest/query/operators/mixture_of_agents_convert.py +10 -8
- palimpzest/query/operators/physical.py +27 -21
- palimpzest/query/operators/project.py +3 -3
- palimpzest/query/operators/rag_convert.py +7 -7
- palimpzest/query/operators/retrieve.py +9 -9
- palimpzest/query/operators/scan.py +81 -42
- palimpzest/query/operators/search.py +524 -0
- palimpzest/query/operators/split_convert.py +10 -8
- palimpzest/query/optimizer/__init__.py +7 -9
- palimpzest/query/optimizer/cost_model.py +108 -441
- palimpzest/query/optimizer/optimizer.py +123 -181
- palimpzest/query/optimizer/optimizer_strategy.py +66 -61
- palimpzest/query/optimizer/plan.py +352 -67
- palimpzest/query/optimizer/primitives.py +43 -19
- palimpzest/query/optimizer/rules.py +484 -646
- palimpzest/query/optimizer/tasks.py +127 -58
- palimpzest/query/processor/config.py +41 -76
- palimpzest/query/processor/query_processor.py +73 -18
- palimpzest/query/processor/query_processor_factory.py +46 -38
- palimpzest/schemabuilder/schema_builder.py +15 -28
- palimpzest/utils/model_helpers.py +27 -77
- palimpzest/utils/progress.py +114 -102
- palimpzest/validator/__init__.py +0 -0
- palimpzest/validator/validator.py +306 -0
- {palimpzest-0.7.20.dist-info → palimpzest-0.8.0.dist-info}/METADATA +6 -1
- palimpzest-0.8.0.dist-info/RECORD +95 -0
- palimpzest/core/lib/fields.py +0 -141
- palimpzest/prompts/code_synthesis_prompts.py +0 -28
- palimpzest/query/execution/random_sampling_execution_strategy.py +0 -240
- palimpzest/query/generators/api_client_factory.py +0 -30
- palimpzest/query/operators/code_synthesis_convert.py +0 -488
- palimpzest/query/operators/map.py +0 -130
- palimpzest/query/processor/nosentinel_processor.py +0 -33
- palimpzest/query/processor/processing_strategy_type.py +0 -28
- palimpzest/query/processor/sentinel_processor.py +0 -88
- palimpzest/query/processor/streaming_processor.py +0 -149
- palimpzest/sets.py +0 -405
- palimpzest/utils/datareader_helpers.py +0 -61
- palimpzest/utils/demo_helpers.py +0 -75
- palimpzest/utils/field_helpers.py +0 -69
- palimpzest/utils/generation_helpers.py +0 -69
- palimpzest/utils/sandbox.py +0 -183
- palimpzest-0.7.20.dist-info/RECORD +0 -95
- /palimpzest/core/{elements/index.py → data/index_dataset.py} +0 -0
- {palimpzest-0.7.20.dist-info → palimpzest-0.8.0.dist-info}/WHEEL +0 -0
- {palimpzest-0.7.20.dist-info → palimpzest-0.8.0.dist-info}/licenses/LICENSE +0 -0
- {palimpzest-0.7.20.dist-info → palimpzest-0.8.0.dist-info}/top_level.txt +0 -0
|
@@ -3,11 +3,13 @@ from __future__ import annotations
|
|
|
3
3
|
import json
|
|
4
4
|
from typing import Callable
|
|
5
5
|
|
|
6
|
+
from pydantic import BaseModel
|
|
7
|
+
|
|
6
8
|
from palimpzest.constants import AggFunc, Cardinality
|
|
7
|
-
from palimpzest.core.data
|
|
9
|
+
from palimpzest.core.data import context, dataset
|
|
8
10
|
from palimpzest.core.elements.filters import Filter
|
|
9
11
|
from palimpzest.core.elements.groupbysig import GroupBySig
|
|
10
|
-
from palimpzest.core.lib.schemas import
|
|
12
|
+
from palimpzest.core.lib.schemas import Average, Count
|
|
11
13
|
from palimpzest.utils.hash_helpers import hash_for_id
|
|
12
14
|
|
|
13
15
|
|
|
@@ -16,8 +18,8 @@ class LogicalOperator:
|
|
|
16
18
|
A logical operator is an operator that operates on Sets.
|
|
17
19
|
|
|
18
20
|
Right now it can be one of:
|
|
19
|
-
- BaseScan (scans data from
|
|
20
|
-
-
|
|
21
|
+
- BaseScan (scans data from a root Dataset)
|
|
22
|
+
- ContextScan (loads the context for a root Dataset)
|
|
21
23
|
- FilteredScan (scans input Set and applies filter)
|
|
22
24
|
- ConvertScan (scans input Set and converts it to new Schema)
|
|
23
25
|
- LimitScan (scans up to N records from a Set)
|
|
@@ -25,6 +27,8 @@ class LogicalOperator:
|
|
|
25
27
|
- Aggregate (applies an aggregation on the Set)
|
|
26
28
|
- RetrieveScan (fetches documents from a provided input for a given query)
|
|
27
29
|
- Map (applies a function to each record in the Set without adding any new columns)
|
|
30
|
+
- ComputeOperator (executes a computation described in natural language)
|
|
31
|
+
- SearchOperator (executes a search query on the input Context)
|
|
28
32
|
|
|
29
33
|
Every logical operator must declare the get_logical_id_params() and get_logical_op_params() methods,
|
|
30
34
|
which return dictionaries of parameters that are used to compute the logical op id and to implement
|
|
@@ -33,17 +37,21 @@ class LogicalOperator:
|
|
|
33
37
|
|
|
34
38
|
def __init__(
|
|
35
39
|
self,
|
|
36
|
-
output_schema:
|
|
37
|
-
input_schema:
|
|
40
|
+
output_schema: type[BaseModel],
|
|
41
|
+
input_schema: type[BaseModel] | None = None,
|
|
42
|
+
depends_on: list[str] | None = None,
|
|
38
43
|
):
|
|
44
|
+
# TODO: can we eliminate input_schema?
|
|
39
45
|
self.output_schema = output_schema
|
|
40
46
|
self.input_schema = input_schema
|
|
47
|
+
self.depends_on = [] if depends_on is None else sorted(depends_on)
|
|
41
48
|
self.logical_op_id: str | None = None
|
|
49
|
+
self.unique_logical_op_id: str | None = None
|
|
42
50
|
|
|
43
51
|
# compute the fields generated by this logical operator
|
|
44
|
-
input_field_names = self.input_schema.
|
|
52
|
+
input_field_names = list(self.input_schema.model_fields) if self.input_schema is not None else []
|
|
45
53
|
self.generated_fields = sorted(
|
|
46
|
-
[field_name for field_name in self.output_schema.
|
|
54
|
+
[field_name for field_name in self.output_schema.model_fields if field_name not in input_field_names]
|
|
47
55
|
)
|
|
48
56
|
|
|
49
57
|
def __str__(self) -> str:
|
|
@@ -54,12 +62,28 @@ class LogicalOperator:
|
|
|
54
62
|
return isinstance(other, self.__class__) and all_id_params_match
|
|
55
63
|
|
|
56
64
|
def copy(self) -> LogicalOperator:
|
|
57
|
-
|
|
65
|
+
logical_op_copy = self.__class__(**self.get_logical_op_params())
|
|
66
|
+
logical_op_copy.logical_op_id = self.logical_op_id
|
|
67
|
+
logical_op_copy.unique_logical_op_id = self.unique_logical_op_id
|
|
68
|
+
return logical_op_copy
|
|
58
69
|
|
|
59
70
|
def logical_op_name(self) -> str:
|
|
60
71
|
"""Name of the logical operator."""
|
|
61
72
|
return str(self.__class__.__name__)
|
|
62
73
|
|
|
74
|
+
def get_unique_logical_op_id(self) -> str:
|
|
75
|
+
"""
|
|
76
|
+
Get the unique logical operator id for this logical operator.
|
|
77
|
+
"""
|
|
78
|
+
return self.unique_logical_op_id
|
|
79
|
+
|
|
80
|
+
def set_unique_logical_op_id(self, unique_logical_op_id: str) -> None:
|
|
81
|
+
"""
|
|
82
|
+
Set the unique logical operator id for this logical operator.
|
|
83
|
+
This is used to uniquely identify the logical operator in the query plan.
|
|
84
|
+
"""
|
|
85
|
+
self.unique_logical_op_id = unique_logical_op_id
|
|
86
|
+
|
|
63
87
|
def get_logical_id_params(self) -> dict:
|
|
64
88
|
"""
|
|
65
89
|
Returns a dictionary mapping of logical operator parameters which are relevant
|
|
@@ -69,6 +93,7 @@ class LogicalOperator:
|
|
|
69
93
|
NOTE: input_schema and output_schema are not included in the id params because
|
|
70
94
|
they depend on how the Optimizer orders operations.
|
|
71
95
|
"""
|
|
96
|
+
# TODO: should we use `generated_fields` after getting rid of them in PhysicalOperator?
|
|
72
97
|
return {"generated_fields": self.generated_fields}
|
|
73
98
|
|
|
74
99
|
def get_logical_op_params(self) -> dict:
|
|
@@ -78,10 +103,16 @@ class LogicalOperator:
|
|
|
78
103
|
|
|
79
104
|
NOTE: Should be overriden by subclasses to include class-specific parameters.
|
|
80
105
|
"""
|
|
81
|
-
return {
|
|
106
|
+
return {
|
|
107
|
+
"input_schema": self.input_schema,
|
|
108
|
+
"output_schema": self.output_schema,
|
|
109
|
+
"depends_on": self.depends_on,
|
|
110
|
+
}
|
|
82
111
|
|
|
83
112
|
def get_logical_op_id(self):
|
|
84
113
|
"""
|
|
114
|
+
TODO: turn this into a property?
|
|
115
|
+
|
|
85
116
|
NOTE: We do not call this in the __init__() method as subclasses may set parameters
|
|
86
117
|
returned by self.get_logical_op_params() after they call to super().__init__().
|
|
87
118
|
"""
|
|
@@ -119,13 +150,19 @@ class Aggregate(LogicalOperator):
|
|
|
119
150
|
def __init__(
|
|
120
151
|
self,
|
|
121
152
|
agg_func: AggFunc,
|
|
122
|
-
target_cache_id: str | None = None,
|
|
123
153
|
*args,
|
|
124
154
|
**kwargs,
|
|
125
155
|
):
|
|
156
|
+
if kwargs.get("output_schema") is None:
|
|
157
|
+
if agg_func == AggFunc.COUNT:
|
|
158
|
+
kwargs["output_schema"] = Count
|
|
159
|
+
elif agg_func == AggFunc.AVERAGE:
|
|
160
|
+
kwargs["output_schema"] = Average
|
|
161
|
+
else:
|
|
162
|
+
raise ValueError(f"Unsupported aggregation function: {agg_func}")
|
|
163
|
+
|
|
126
164
|
super().__init__(*args, **kwargs)
|
|
127
165
|
self.agg_func = agg_func
|
|
128
|
-
self.target_cache_id = target_cache_id
|
|
129
166
|
|
|
130
167
|
def __str__(self):
|
|
131
168
|
return f"{self.__class__.__name__}(function: {str(self.agg_func.value)})"
|
|
@@ -140,7 +177,6 @@ class Aggregate(LogicalOperator):
|
|
|
140
177
|
logical_op_params = super().get_logical_op_params()
|
|
141
178
|
logical_op_params = {
|
|
142
179
|
"agg_func": self.agg_func,
|
|
143
|
-
"target_cache_id": self.target_cache_id,
|
|
144
180
|
**logical_op_params,
|
|
145
181
|
}
|
|
146
182
|
|
|
@@ -148,75 +184,87 @@ class Aggregate(LogicalOperator):
|
|
|
148
184
|
|
|
149
185
|
|
|
150
186
|
class BaseScan(LogicalOperator):
|
|
151
|
-
"""A BaseScan is a logical operator that represents a scan of a particular
|
|
187
|
+
"""A BaseScan is a logical operator that represents a scan of a particular root Dataset."""
|
|
152
188
|
|
|
153
|
-
def __init__(self,
|
|
154
|
-
super().__init__(output_schema=output_schema)
|
|
155
|
-
self.
|
|
189
|
+
def __init__(self, datasource: dataset.Dataset, output_schema: type[BaseModel], *args, **kwargs):
|
|
190
|
+
super().__init__(*args, output_schema=output_schema, **kwargs)
|
|
191
|
+
self.datasource = datasource
|
|
156
192
|
|
|
157
193
|
def __str__(self):
|
|
158
|
-
return f"BaseScan({self.
|
|
194
|
+
return f"BaseScan({self.datasource},{self.output_schema})"
|
|
159
195
|
|
|
160
196
|
def __eq__(self, other) -> bool:
|
|
161
197
|
return (
|
|
162
198
|
isinstance(other, BaseScan)
|
|
163
|
-
and self.input_schema
|
|
164
|
-
and self.output_schema
|
|
165
|
-
and self.
|
|
199
|
+
and self.input_schema == other.input_schema
|
|
200
|
+
and self.output_schema == other.output_schema
|
|
201
|
+
and self.datasource == other.datasource
|
|
166
202
|
)
|
|
167
203
|
|
|
168
204
|
def get_logical_id_params(self) -> dict:
|
|
169
|
-
|
|
205
|
+
logical_id_params = super().get_logical_id_params()
|
|
206
|
+
logical_id_params = {
|
|
207
|
+
"id": self.datasource.id,
|
|
208
|
+
**logical_id_params,
|
|
209
|
+
}
|
|
210
|
+
|
|
211
|
+
return logical_id_params
|
|
170
212
|
|
|
171
213
|
def get_logical_op_params(self) -> dict:
|
|
172
214
|
logical_op_params = super().get_logical_op_params()
|
|
173
|
-
logical_op_params = {"
|
|
215
|
+
logical_op_params = {"datasource": self.datasource, **logical_op_params}
|
|
174
216
|
|
|
175
217
|
return logical_op_params
|
|
176
218
|
|
|
177
219
|
|
|
178
|
-
class
|
|
179
|
-
"""A
|
|
220
|
+
class ContextScan(LogicalOperator):
|
|
221
|
+
"""A ContextScan is a logical operator that loads the context for a particular root Dataset."""
|
|
180
222
|
|
|
181
|
-
def __init__(self,
|
|
182
|
-
super().__init__(output_schema=output_schema)
|
|
183
|
-
self.
|
|
223
|
+
def __init__(self, context: context.Context, output_schema: type[BaseModel], *args, **kwargs):
|
|
224
|
+
super().__init__(*args, output_schema=output_schema, **kwargs)
|
|
225
|
+
self.context = context
|
|
184
226
|
|
|
185
227
|
def __str__(self):
|
|
186
|
-
return f"
|
|
228
|
+
return f"ContextScan({self.context},{self.output_schema})"
|
|
229
|
+
|
|
230
|
+
def __eq__(self, other) -> bool:
|
|
231
|
+
return (
|
|
232
|
+
isinstance(other, ContextScan)
|
|
233
|
+
and self.context.id == other.context.id
|
|
234
|
+
)
|
|
187
235
|
|
|
188
236
|
def get_logical_id_params(self) -> dict:
|
|
189
|
-
|
|
237
|
+
logical_id_params = super().get_logical_id_params()
|
|
238
|
+
logical_id_params = {
|
|
239
|
+
"id": self.context.id,
|
|
240
|
+
**logical_id_params,
|
|
241
|
+
}
|
|
242
|
+
|
|
243
|
+
return logical_id_params
|
|
190
244
|
|
|
191
245
|
def get_logical_op_params(self) -> dict:
|
|
192
246
|
logical_op_params = super().get_logical_op_params()
|
|
193
|
-
logical_op_params = {"
|
|
247
|
+
logical_op_params = {"context": self.context, **logical_op_params}
|
|
194
248
|
|
|
195
249
|
return logical_op_params
|
|
196
250
|
|
|
197
251
|
|
|
198
252
|
class ConvertScan(LogicalOperator):
|
|
199
|
-
"""A ConvertScan is a logical operator that represents a scan of a particular
|
|
253
|
+
"""A ConvertScan is a logical operator that represents a scan of a particular input Dataset, with conversion applied."""
|
|
200
254
|
|
|
201
255
|
def __init__(
|
|
202
256
|
self,
|
|
203
257
|
cardinality: Cardinality = Cardinality.ONE_TO_ONE,
|
|
204
258
|
udf: Callable | None = None,
|
|
205
|
-
depends_on: list[str] | None = None,
|
|
206
|
-
desc: str | None = None,
|
|
207
|
-
target_cache_id: str | None = None,
|
|
208
259
|
*args,
|
|
209
260
|
**kwargs,
|
|
210
261
|
):
|
|
211
262
|
super().__init__(*args, **kwargs)
|
|
212
263
|
self.cardinality = cardinality
|
|
213
264
|
self.udf = udf
|
|
214
|
-
self.depends_on = [] if depends_on is None else sorted(depends_on)
|
|
215
|
-
self.desc = desc
|
|
216
|
-
self.target_cache_id = target_cache_id
|
|
217
265
|
|
|
218
266
|
def __str__(self):
|
|
219
|
-
return f"ConvertScan({self.input_schema} -> {str(self.output_schema)}
|
|
267
|
+
return f"ConvertScan({self.input_schema} -> {str(self.output_schema)})"
|
|
220
268
|
|
|
221
269
|
def get_logical_id_params(self) -> dict:
|
|
222
270
|
logical_id_params = super().get_logical_id_params()
|
|
@@ -233,9 +281,40 @@ class ConvertScan(LogicalOperator):
|
|
|
233
281
|
logical_op_params = {
|
|
234
282
|
"cardinality": self.cardinality,
|
|
235
283
|
"udf": self.udf,
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
284
|
+
**logical_op_params,
|
|
285
|
+
}
|
|
286
|
+
|
|
287
|
+
return logical_op_params
|
|
288
|
+
|
|
289
|
+
|
|
290
|
+
class Distinct(LogicalOperator):
|
|
291
|
+
def __init__(self, distinct_cols: list[str] | None, *args, **kwargs):
|
|
292
|
+
super().__init__(*args, **kwargs)
|
|
293
|
+
# if distinct_cols is not None, check that all columns are in the input schema
|
|
294
|
+
if distinct_cols is not None:
|
|
295
|
+
for col in distinct_cols:
|
|
296
|
+
assert col in self.input_schema.model_fields, f"Column {col} not found in input schema {self.input_schema} for Distinct operator"
|
|
297
|
+
|
|
298
|
+
# store the list of distinct columns, sorted
|
|
299
|
+
self.distinct_cols = (
|
|
300
|
+
sorted([field_name for field_name in self.input_schema.model_fields])
|
|
301
|
+
if distinct_cols is None
|
|
302
|
+
else sorted(distinct_cols)
|
|
303
|
+
)
|
|
304
|
+
|
|
305
|
+
def __str__(self):
|
|
306
|
+
return f"Distinct({self.distinct_cols})"
|
|
307
|
+
|
|
308
|
+
def get_logical_id_params(self) -> dict:
|
|
309
|
+
logical_id_params = super().get_logical_id_params()
|
|
310
|
+
logical_id_params = {"distinct_cols": self.distinct_cols, **logical_id_params}
|
|
311
|
+
|
|
312
|
+
return logical_id_params
|
|
313
|
+
|
|
314
|
+
def get_logical_op_params(self) -> dict:
|
|
315
|
+
logical_op_params = super().get_logical_op_params()
|
|
316
|
+
logical_op_params = {
|
|
317
|
+
"distinct_cols": self.distinct_cols,
|
|
239
318
|
**logical_op_params,
|
|
240
319
|
}
|
|
241
320
|
|
|
@@ -243,20 +322,16 @@ class ConvertScan(LogicalOperator):
|
|
|
243
322
|
|
|
244
323
|
|
|
245
324
|
class FilteredScan(LogicalOperator):
|
|
246
|
-
"""A FilteredScan is a logical operator that represents a scan of a particular
|
|
325
|
+
"""A FilteredScan is a logical operator that represents a scan of a particular input Dataset, with filters applied."""
|
|
247
326
|
|
|
248
327
|
def __init__(
|
|
249
328
|
self,
|
|
250
329
|
filter: Filter,
|
|
251
|
-
depends_on: list[str] | None = None,
|
|
252
|
-
target_cache_id: str | None = None,
|
|
253
330
|
*args,
|
|
254
331
|
**kwargs,
|
|
255
332
|
):
|
|
256
333
|
super().__init__(*args, **kwargs)
|
|
257
334
|
self.filter = filter
|
|
258
|
-
self.depends_on = [] if depends_on is None else sorted(depends_on)
|
|
259
|
-
self.target_cache_id = target_cache_id
|
|
260
335
|
|
|
261
336
|
def __str__(self):
|
|
262
337
|
return f"FilteredScan({str(self.output_schema)}, {str(self.filter)})"
|
|
@@ -274,8 +349,6 @@ class FilteredScan(LogicalOperator):
|
|
|
274
349
|
logical_op_params = super().get_logical_op_params()
|
|
275
350
|
logical_op_params = {
|
|
276
351
|
"filter": self.filter,
|
|
277
|
-
"depends_on": self.depends_on,
|
|
278
|
-
"target_cache_id": self.target_cache_id,
|
|
279
352
|
**logical_op_params,
|
|
280
353
|
}
|
|
281
354
|
|
|
@@ -286,7 +359,6 @@ class GroupByAggregate(LogicalOperator):
|
|
|
286
359
|
def __init__(
|
|
287
360
|
self,
|
|
288
361
|
group_by_sig: GroupBySig,
|
|
289
|
-
target_cache_id: str | None = None,
|
|
290
362
|
*args,
|
|
291
363
|
**kwargs,
|
|
292
364
|
):
|
|
@@ -297,7 +369,6 @@ class GroupByAggregate(LogicalOperator):
|
|
|
297
369
|
if not valid:
|
|
298
370
|
raise TypeError(error)
|
|
299
371
|
self.group_by_sig = group_by_sig
|
|
300
|
-
self.target_cache_id = target_cache_id
|
|
301
372
|
|
|
302
373
|
def __str__(self):
|
|
303
374
|
return f"GroupBy({self.group_by_sig.serialize()})"
|
|
@@ -312,7 +383,30 @@ class GroupByAggregate(LogicalOperator):
|
|
|
312
383
|
logical_op_params = super().get_logical_op_params()
|
|
313
384
|
logical_op_params = {
|
|
314
385
|
"group_by_sig": self.group_by_sig,
|
|
315
|
-
|
|
386
|
+
**logical_op_params,
|
|
387
|
+
}
|
|
388
|
+
|
|
389
|
+
return logical_op_params
|
|
390
|
+
|
|
391
|
+
|
|
392
|
+
class JoinOp(LogicalOperator):
|
|
393
|
+
def __init__(self, condition: str, *args, **kwargs):
|
|
394
|
+
super().__init__(*args, **kwargs)
|
|
395
|
+
self.condition = condition
|
|
396
|
+
|
|
397
|
+
def __str__(self):
|
|
398
|
+
return f"Join(condition={self.condition})"
|
|
399
|
+
|
|
400
|
+
def get_logical_id_params(self) -> dict:
|
|
401
|
+
logical_id_params = super().get_logical_id_params()
|
|
402
|
+
logical_id_params = {"condition": self.condition, **logical_id_params}
|
|
403
|
+
|
|
404
|
+
return logical_id_params
|
|
405
|
+
|
|
406
|
+
def get_logical_op_params(self) -> dict:
|
|
407
|
+
logical_op_params = super().get_logical_op_params()
|
|
408
|
+
logical_op_params = {
|
|
409
|
+
"condition": self.condition,
|
|
316
410
|
**logical_op_params,
|
|
317
411
|
}
|
|
318
412
|
|
|
@@ -320,10 +414,9 @@ class GroupByAggregate(LogicalOperator):
|
|
|
320
414
|
|
|
321
415
|
|
|
322
416
|
class LimitScan(LogicalOperator):
|
|
323
|
-
def __init__(self, limit: int,
|
|
417
|
+
def __init__(self, limit: int, *args, **kwargs):
|
|
324
418
|
super().__init__(*args, **kwargs)
|
|
325
419
|
self.limit = limit
|
|
326
|
-
self.target_cache_id = target_cache_id
|
|
327
420
|
|
|
328
421
|
def __str__(self):
|
|
329
422
|
return f"LimitScan({str(self.input_schema)}, {str(self.output_schema)})"
|
|
@@ -338,7 +431,6 @@ class LimitScan(LogicalOperator):
|
|
|
338
431
|
logical_op_params = super().get_logical_op_params()
|
|
339
432
|
logical_op_params = {
|
|
340
433
|
"limit": self.limit,
|
|
341
|
-
"target_cache_id": self.target_cache_id,
|
|
342
434
|
**logical_op_params,
|
|
343
435
|
}
|
|
344
436
|
|
|
@@ -346,10 +438,9 @@ class LimitScan(LogicalOperator):
|
|
|
346
438
|
|
|
347
439
|
|
|
348
440
|
class Project(LogicalOperator):
|
|
349
|
-
def __init__(self, project_cols: list[str],
|
|
441
|
+
def __init__(self, project_cols: list[str], *args, **kwargs):
|
|
350
442
|
super().__init__(*args, **kwargs)
|
|
351
443
|
self.project_cols = project_cols
|
|
352
|
-
self.target_cache_id = target_cache_id
|
|
353
444
|
|
|
354
445
|
def __str__(self):
|
|
355
446
|
return f"Project({self.input_schema}, {self.project_cols})"
|
|
@@ -364,7 +455,6 @@ class Project(LogicalOperator):
|
|
|
364
455
|
logical_op_params = super().get_logical_op_params()
|
|
365
456
|
logical_op_params = {
|
|
366
457
|
"project_cols": self.project_cols,
|
|
367
|
-
"target_cache_id": self.target_cache_id,
|
|
368
458
|
**logical_op_params,
|
|
369
459
|
}
|
|
370
460
|
|
|
@@ -372,7 +462,7 @@ class Project(LogicalOperator):
|
|
|
372
462
|
|
|
373
463
|
|
|
374
464
|
class RetrieveScan(LogicalOperator):
|
|
375
|
-
"""A RetrieveScan is a logical operator that represents a scan of a particular
|
|
465
|
+
"""A RetrieveScan is a logical operator that represents a scan of a particular input Dataset, with a convert-like retrieve applied."""
|
|
376
466
|
|
|
377
467
|
def __init__(
|
|
378
468
|
self,
|
|
@@ -381,7 +471,6 @@ class RetrieveScan(LogicalOperator):
|
|
|
381
471
|
search_attr,
|
|
382
472
|
output_attrs,
|
|
383
473
|
k,
|
|
384
|
-
target_cache_id: str = None,
|
|
385
474
|
*args,
|
|
386
475
|
**kwargs,
|
|
387
476
|
):
|
|
@@ -391,10 +480,9 @@ class RetrieveScan(LogicalOperator):
|
|
|
391
480
|
self.search_attr = search_attr
|
|
392
481
|
self.output_attrs = output_attrs
|
|
393
482
|
self.k = k
|
|
394
|
-
self.target_cache_id = target_cache_id
|
|
395
483
|
|
|
396
484
|
def __str__(self):
|
|
397
|
-
return f"RetrieveScan({self.input_schema} -> {str(self.output_schema)}
|
|
485
|
+
return f"RetrieveScan({self.input_schema} -> {str(self.output_schema)})"
|
|
398
486
|
|
|
399
487
|
def get_logical_id_params(self) -> dict:
|
|
400
488
|
# NOTE: if we allow optimization over index, then we will need to include it in the id params
|
|
@@ -418,36 +506,31 @@ class RetrieveScan(LogicalOperator):
|
|
|
418
506
|
"search_attr": self.search_attr,
|
|
419
507
|
"output_attrs": self.output_attrs,
|
|
420
508
|
"k": self.k,
|
|
421
|
-
"target_cache_id": self.target_cache_id,
|
|
422
509
|
**logical_op_params,
|
|
423
510
|
}
|
|
424
511
|
|
|
425
512
|
return logical_op_params
|
|
426
513
|
|
|
427
514
|
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
515
|
+
class ComputeOperator(LogicalOperator):
|
|
516
|
+
"""
|
|
517
|
+
A ComputeOperator is a logical operator that performs a computation described in natural language
|
|
518
|
+
on a given Context.
|
|
519
|
+
"""
|
|
432
520
|
|
|
433
|
-
def __init__(
|
|
434
|
-
self,
|
|
435
|
-
udf: Callable | None = None,
|
|
436
|
-
target_cache_id: str | None = None,
|
|
437
|
-
*args,
|
|
438
|
-
**kwargs,
|
|
439
|
-
):
|
|
521
|
+
def __init__(self, context_id: str, instruction: str, *args, **kwargs):
|
|
440
522
|
super().__init__(*args, **kwargs)
|
|
441
|
-
self.
|
|
442
|
-
self.
|
|
523
|
+
self.context_id = context_id
|
|
524
|
+
self.instruction = instruction
|
|
443
525
|
|
|
444
526
|
def __str__(self):
|
|
445
|
-
return f"
|
|
527
|
+
return f"ComputeOperator(id={self.context_id}, instr={self.instruction:20s})"
|
|
446
528
|
|
|
447
529
|
def get_logical_id_params(self) -> dict:
|
|
448
530
|
logical_id_params = super().get_logical_id_params()
|
|
449
531
|
logical_id_params = {
|
|
450
|
-
"
|
|
532
|
+
"context_id": self.context_id,
|
|
533
|
+
"instruction": self.instruction,
|
|
451
534
|
**logical_id_params,
|
|
452
535
|
}
|
|
453
536
|
|
|
@@ -456,8 +539,43 @@ class MapScan(LogicalOperator):
|
|
|
456
539
|
def get_logical_op_params(self) -> dict:
|
|
457
540
|
logical_op_params = super().get_logical_op_params()
|
|
458
541
|
logical_op_params = {
|
|
459
|
-
"
|
|
460
|
-
"
|
|
542
|
+
"context_id": self.context_id,
|
|
543
|
+
"instruction": self.instruction,
|
|
544
|
+
**logical_op_params,
|
|
545
|
+
}
|
|
546
|
+
|
|
547
|
+
return logical_op_params
|
|
548
|
+
|
|
549
|
+
|
|
550
|
+
class SearchOperator(LogicalOperator):
|
|
551
|
+
"""
|
|
552
|
+
A SearchOperator is a logical operator that executes a search described in natural language
|
|
553
|
+
on a given Context.
|
|
554
|
+
"""
|
|
555
|
+
|
|
556
|
+
def __init__(self, context_id: str, search_query: str, *args, **kwargs):
|
|
557
|
+
super().__init__(*args, **kwargs)
|
|
558
|
+
self.context_id = context_id
|
|
559
|
+
self.search_query = search_query
|
|
560
|
+
|
|
561
|
+
def __str__(self):
|
|
562
|
+
return f"SearchOperator(id={self.context_id}, search_query={self.search_query:20s})"
|
|
563
|
+
|
|
564
|
+
def get_logical_id_params(self) -> dict:
|
|
565
|
+
logical_id_params = super().get_logical_id_params()
|
|
566
|
+
logical_id_params = {
|
|
567
|
+
"context_id": self.context_id,
|
|
568
|
+
"search_query": self.search_query,
|
|
569
|
+
**logical_id_params,
|
|
570
|
+
}
|
|
571
|
+
|
|
572
|
+
return logical_id_params
|
|
573
|
+
|
|
574
|
+
def get_logical_op_params(self) -> dict:
|
|
575
|
+
logical_op_params = super().get_logical_op_params()
|
|
576
|
+
logical_op_params = {
|
|
577
|
+
"context_id": self.context_id,
|
|
578
|
+
"search_query": self.search_query,
|
|
461
579
|
**logical_op_params,
|
|
462
580
|
}
|
|
463
581
|
|
|
@@ -1,10 +1,11 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
from pydantic.fields import FieldInfo
|
|
4
|
+
|
|
3
5
|
from palimpzest.constants import MODEL_CARDS, Model, PromptStrategy
|
|
4
|
-
from palimpzest.core.data.dataclasses import GenerationStats, OperatorCostEstimates
|
|
5
6
|
from palimpzest.core.elements.records import DataRecord
|
|
6
|
-
from palimpzest.core.
|
|
7
|
-
from palimpzest.query.generators.generators import
|
|
7
|
+
from palimpzest.core.models import GenerationStats, OperatorCostEstimates
|
|
8
|
+
from palimpzest.query.generators.generators import Generator
|
|
8
9
|
from palimpzest.query.operators.convert import LLMConvert
|
|
9
10
|
|
|
10
11
|
# TYPE DEFINITIONS
|
|
@@ -20,7 +21,6 @@ class MixtureOfAgentsConvert(LLMConvert):
|
|
|
20
21
|
aggregator_model: Model,
|
|
21
22
|
proposer_prompt_strategy: PromptStrategy = PromptStrategy.COT_MOA_PROPOSER,
|
|
22
23
|
aggregator_prompt_strategy: PromptStrategy = PromptStrategy.COT_MOA_AGG,
|
|
23
|
-
proposer_prompt: str | None = None,
|
|
24
24
|
*args,
|
|
25
25
|
**kwargs,
|
|
26
26
|
):
|
|
@@ -33,14 +33,13 @@ class MixtureOfAgentsConvert(LLMConvert):
|
|
|
33
33
|
self.aggregator_model = aggregator_model
|
|
34
34
|
self.proposer_prompt_strategy = proposer_prompt_strategy
|
|
35
35
|
self.aggregator_prompt_strategy = aggregator_prompt_strategy
|
|
36
|
-
self.proposer_prompt = proposer_prompt
|
|
37
36
|
|
|
38
37
|
# create generators
|
|
39
38
|
self.proposer_generators = [
|
|
40
|
-
|
|
39
|
+
Generator(model, self.proposer_prompt_strategy, self.reasoning_effort, self.api_base, self.cardinality, self.verbose)
|
|
41
40
|
for model in proposer_models
|
|
42
41
|
]
|
|
43
|
-
self.aggregator_generator =
|
|
42
|
+
self.aggregator_generator = Generator(aggregator_model, self.aggregator_prompt_strategy, self.reasoning_effort, self.api_base, self.cardinality, self.verbose)
|
|
44
43
|
|
|
45
44
|
def __str__(self):
|
|
46
45
|
op = super().__str__()
|
|
@@ -77,6 +76,9 @@ class MixtureOfAgentsConvert(LLMConvert):
|
|
|
77
76
|
|
|
78
77
|
return op_params
|
|
79
78
|
|
|
79
|
+
def is_image_conversion(self) -> bool:
|
|
80
|
+
return self.proposer_prompt_strategy.is_image_prompt()
|
|
81
|
+
|
|
80
82
|
def naive_cost_estimates(self, source_op_cost_estimates: OperatorCostEstimates) -> OperatorCostEstimates:
|
|
81
83
|
"""
|
|
82
84
|
Currently, we are using multiple proposer models with different temperatures to synthesize
|
|
@@ -111,7 +113,7 @@ class MixtureOfAgentsConvert(LLMConvert):
|
|
|
111
113
|
|
|
112
114
|
return naive_op_cost_estimates
|
|
113
115
|
|
|
114
|
-
def convert(self, candidate: DataRecord, fields: dict[str,
|
|
116
|
+
def convert(self, candidate: DataRecord, fields: dict[str, FieldInfo]) -> tuple[dict[str, list], GenerationStats]:
|
|
115
117
|
# get input fields
|
|
116
118
|
input_fields = self.get_input_fields()
|
|
117
119
|
|