palimpzest 0.7.21__py3-none-any.whl → 0.8.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- palimpzest/__init__.py +37 -6
- palimpzest/agents/__init__.py +0 -0
- palimpzest/agents/compute_agents.py +0 -0
- palimpzest/agents/search_agents.py +637 -0
- palimpzest/constants.py +343 -209
- palimpzest/core/data/context.py +393 -0
- palimpzest/core/data/context_manager.py +163 -0
- palimpzest/core/data/dataset.py +639 -0
- palimpzest/core/data/{datareaders.py → iter_dataset.py} +202 -126
- palimpzest/core/elements/groupbysig.py +16 -13
- palimpzest/core/elements/records.py +166 -75
- palimpzest/core/lib/schemas.py +152 -390
- palimpzest/core/{data/dataclasses.py → models.py} +306 -170
- palimpzest/policy.py +2 -27
- palimpzest/prompts/__init__.py +35 -5
- palimpzest/prompts/agent_prompts.py +357 -0
- palimpzest/prompts/context_search.py +9 -0
- palimpzest/prompts/convert_prompts.py +62 -6
- palimpzest/prompts/filter_prompts.py +51 -6
- palimpzest/prompts/join_prompts.py +163 -0
- palimpzest/prompts/moa_proposer_convert_prompts.py +6 -6
- palimpzest/prompts/prompt_factory.py +375 -47
- palimpzest/prompts/split_proposer_prompts.py +1 -1
- palimpzest/prompts/util_phrases.py +5 -0
- palimpzest/prompts/validator.py +239 -0
- palimpzest/query/execution/all_sample_execution_strategy.py +134 -76
- palimpzest/query/execution/execution_strategy.py +210 -317
- palimpzest/query/execution/execution_strategy_type.py +5 -7
- palimpzest/query/execution/mab_execution_strategy.py +249 -136
- palimpzest/query/execution/parallel_execution_strategy.py +153 -244
- palimpzest/query/execution/single_threaded_execution_strategy.py +107 -64
- palimpzest/query/generators/generators.py +160 -331
- palimpzest/query/operators/__init__.py +15 -5
- palimpzest/query/operators/aggregate.py +50 -33
- palimpzest/query/operators/compute.py +201 -0
- palimpzest/query/operators/convert.py +33 -19
- palimpzest/query/operators/critique_and_refine_convert.py +7 -5
- palimpzest/query/operators/distinct.py +62 -0
- palimpzest/query/operators/filter.py +26 -16
- palimpzest/query/operators/join.py +403 -0
- palimpzest/query/operators/limit.py +3 -3
- palimpzest/query/operators/logical.py +205 -77
- palimpzest/query/operators/mixture_of_agents_convert.py +10 -8
- palimpzest/query/operators/physical.py +27 -21
- palimpzest/query/operators/project.py +3 -3
- palimpzest/query/operators/rag_convert.py +7 -7
- palimpzest/query/operators/retrieve.py +9 -9
- palimpzest/query/operators/scan.py +81 -42
- palimpzest/query/operators/search.py +524 -0
- palimpzest/query/operators/split_convert.py +10 -8
- palimpzest/query/optimizer/__init__.py +7 -9
- palimpzest/query/optimizer/cost_model.py +108 -441
- palimpzest/query/optimizer/optimizer.py +123 -181
- palimpzest/query/optimizer/optimizer_strategy.py +66 -61
- palimpzest/query/optimizer/plan.py +352 -67
- palimpzest/query/optimizer/primitives.py +43 -19
- palimpzest/query/optimizer/rules.py +484 -646
- palimpzest/query/optimizer/tasks.py +127 -58
- palimpzest/query/processor/config.py +42 -76
- palimpzest/query/processor/query_processor.py +73 -18
- palimpzest/query/processor/query_processor_factory.py +46 -38
- palimpzest/schemabuilder/schema_builder.py +15 -28
- palimpzest/utils/model_helpers.py +32 -77
- palimpzest/utils/progress.py +114 -102
- palimpzest/validator/__init__.py +0 -0
- palimpzest/validator/validator.py +306 -0
- {palimpzest-0.7.21.dist-info → palimpzest-0.8.1.dist-info}/METADATA +6 -1
- palimpzest-0.8.1.dist-info/RECORD +95 -0
- palimpzest/core/lib/fields.py +0 -141
- palimpzest/prompts/code_synthesis_prompts.py +0 -28
- palimpzest/query/execution/random_sampling_execution_strategy.py +0 -240
- palimpzest/query/generators/api_client_factory.py +0 -30
- palimpzest/query/operators/code_synthesis_convert.py +0 -488
- palimpzest/query/operators/map.py +0 -130
- palimpzest/query/processor/nosentinel_processor.py +0 -33
- palimpzest/query/processor/processing_strategy_type.py +0 -28
- palimpzest/query/processor/sentinel_processor.py +0 -88
- palimpzest/query/processor/streaming_processor.py +0 -149
- palimpzest/sets.py +0 -405
- palimpzest/utils/datareader_helpers.py +0 -61
- palimpzest/utils/demo_helpers.py +0 -75
- palimpzest/utils/field_helpers.py +0 -69
- palimpzest/utils/generation_helpers.py +0 -69
- palimpzest/utils/sandbox.py +0 -183
- palimpzest-0.7.21.dist-info/RECORD +0 -95
- /palimpzest/core/{elements/index.py → data/index_dataset.py} +0 -0
- {palimpzest-0.7.21.dist-info → palimpzest-0.8.1.dist-info}/WHEEL +0 -0
- {palimpzest-0.7.21.dist-info → palimpzest-0.8.1.dist-info}/licenses/LICENSE +0 -0
- {palimpzest-0.7.21.dist-info → palimpzest-0.8.1.dist-info}/top_level.txt +0 -0
|
@@ -3,11 +3,13 @@ from __future__ import annotations
|
|
|
3
3
|
import json
|
|
4
4
|
from typing import Callable
|
|
5
5
|
|
|
6
|
+
from pydantic import BaseModel
|
|
7
|
+
|
|
6
8
|
from palimpzest.constants import AggFunc, Cardinality
|
|
7
|
-
from palimpzest.core.data
|
|
9
|
+
from palimpzest.core.data import context, dataset
|
|
8
10
|
from palimpzest.core.elements.filters import Filter
|
|
9
11
|
from palimpzest.core.elements.groupbysig import GroupBySig
|
|
10
|
-
from palimpzest.core.lib.schemas import
|
|
12
|
+
from palimpzest.core.lib.schemas import Average, Count
|
|
11
13
|
from palimpzest.utils.hash_helpers import hash_for_id
|
|
12
14
|
|
|
13
15
|
|
|
@@ -16,8 +18,8 @@ class LogicalOperator:
|
|
|
16
18
|
A logical operator is an operator that operates on Sets.
|
|
17
19
|
|
|
18
20
|
Right now it can be one of:
|
|
19
|
-
- BaseScan (scans data from
|
|
20
|
-
-
|
|
21
|
+
- BaseScan (scans data from a root Dataset)
|
|
22
|
+
- ContextScan (loads the context for a root Dataset)
|
|
21
23
|
- FilteredScan (scans input Set and applies filter)
|
|
22
24
|
- ConvertScan (scans input Set and converts it to new Schema)
|
|
23
25
|
- LimitScan (scans up to N records from a Set)
|
|
@@ -25,6 +27,8 @@ class LogicalOperator:
|
|
|
25
27
|
- Aggregate (applies an aggregation on the Set)
|
|
26
28
|
- RetrieveScan (fetches documents from a provided input for a given query)
|
|
27
29
|
- Map (applies a function to each record in the Set without adding any new columns)
|
|
30
|
+
- ComputeOperator (executes a computation described in natural language)
|
|
31
|
+
- SearchOperator (executes a search query on the input Context)
|
|
28
32
|
|
|
29
33
|
Every logical operator must declare the get_logical_id_params() and get_logical_op_params() methods,
|
|
30
34
|
which return dictionaries of parameters that are used to compute the logical op id and to implement
|
|
@@ -33,17 +37,21 @@ class LogicalOperator:
|
|
|
33
37
|
|
|
34
38
|
def __init__(
|
|
35
39
|
self,
|
|
36
|
-
output_schema:
|
|
37
|
-
input_schema:
|
|
40
|
+
output_schema: type[BaseModel],
|
|
41
|
+
input_schema: type[BaseModel] | None = None,
|
|
42
|
+
depends_on: list[str] | None = None,
|
|
38
43
|
):
|
|
44
|
+
# TODO: can we eliminate input_schema?
|
|
39
45
|
self.output_schema = output_schema
|
|
40
46
|
self.input_schema = input_schema
|
|
47
|
+
self.depends_on = [] if depends_on is None else sorted(depends_on)
|
|
41
48
|
self.logical_op_id: str | None = None
|
|
49
|
+
self.unique_logical_op_id: str | None = None
|
|
42
50
|
|
|
43
51
|
# compute the fields generated by this logical operator
|
|
44
|
-
input_field_names = self.input_schema.
|
|
52
|
+
input_field_names = list(self.input_schema.model_fields) if self.input_schema is not None else []
|
|
45
53
|
self.generated_fields = sorted(
|
|
46
|
-
[field_name for field_name in self.output_schema.
|
|
54
|
+
[field_name for field_name in self.output_schema.model_fields if field_name not in input_field_names]
|
|
47
55
|
)
|
|
48
56
|
|
|
49
57
|
def __str__(self) -> str:
|
|
@@ -54,12 +62,28 @@ class LogicalOperator:
|
|
|
54
62
|
return isinstance(other, self.__class__) and all_id_params_match
|
|
55
63
|
|
|
56
64
|
def copy(self) -> LogicalOperator:
|
|
57
|
-
|
|
65
|
+
logical_op_copy = self.__class__(**self.get_logical_op_params())
|
|
66
|
+
logical_op_copy.logical_op_id = self.logical_op_id
|
|
67
|
+
logical_op_copy.unique_logical_op_id = self.unique_logical_op_id
|
|
68
|
+
return logical_op_copy
|
|
58
69
|
|
|
59
70
|
def logical_op_name(self) -> str:
|
|
60
71
|
"""Name of the logical operator."""
|
|
61
72
|
return str(self.__class__.__name__)
|
|
62
73
|
|
|
74
|
+
def get_unique_logical_op_id(self) -> str:
|
|
75
|
+
"""
|
|
76
|
+
Get the unique logical operator id for this logical operator.
|
|
77
|
+
"""
|
|
78
|
+
return self.unique_logical_op_id
|
|
79
|
+
|
|
80
|
+
def set_unique_logical_op_id(self, unique_logical_op_id: str) -> None:
|
|
81
|
+
"""
|
|
82
|
+
Set the unique logical operator id for this logical operator.
|
|
83
|
+
This is used to uniquely identify the logical operator in the query plan.
|
|
84
|
+
"""
|
|
85
|
+
self.unique_logical_op_id = unique_logical_op_id
|
|
86
|
+
|
|
63
87
|
def get_logical_id_params(self) -> dict:
|
|
64
88
|
"""
|
|
65
89
|
Returns a dictionary mapping of logical operator parameters which are relevant
|
|
@@ -69,6 +93,7 @@ class LogicalOperator:
|
|
|
69
93
|
NOTE: input_schema and output_schema are not included in the id params because
|
|
70
94
|
they depend on how the Optimizer orders operations.
|
|
71
95
|
"""
|
|
96
|
+
# TODO: should we use `generated_fields` after getting rid of them in PhysicalOperator?
|
|
72
97
|
return {"generated_fields": self.generated_fields}
|
|
73
98
|
|
|
74
99
|
def get_logical_op_params(self) -> dict:
|
|
@@ -78,10 +103,16 @@ class LogicalOperator:
|
|
|
78
103
|
|
|
79
104
|
NOTE: Should be overriden by subclasses to include class-specific parameters.
|
|
80
105
|
"""
|
|
81
|
-
return {
|
|
106
|
+
return {
|
|
107
|
+
"input_schema": self.input_schema,
|
|
108
|
+
"output_schema": self.output_schema,
|
|
109
|
+
"depends_on": self.depends_on,
|
|
110
|
+
}
|
|
82
111
|
|
|
83
112
|
def get_logical_op_id(self):
|
|
84
113
|
"""
|
|
114
|
+
TODO: turn this into a property?
|
|
115
|
+
|
|
85
116
|
NOTE: We do not call this in the __init__() method as subclasses may set parameters
|
|
86
117
|
returned by self.get_logical_op_params() after they call to super().__init__().
|
|
87
118
|
"""
|
|
@@ -119,13 +150,19 @@ class Aggregate(LogicalOperator):
|
|
|
119
150
|
def __init__(
|
|
120
151
|
self,
|
|
121
152
|
agg_func: AggFunc,
|
|
122
|
-
target_cache_id: str | None = None,
|
|
123
153
|
*args,
|
|
124
154
|
**kwargs,
|
|
125
155
|
):
|
|
156
|
+
if kwargs.get("output_schema") is None:
|
|
157
|
+
if agg_func == AggFunc.COUNT:
|
|
158
|
+
kwargs["output_schema"] = Count
|
|
159
|
+
elif agg_func == AggFunc.AVERAGE:
|
|
160
|
+
kwargs["output_schema"] = Average
|
|
161
|
+
else:
|
|
162
|
+
raise ValueError(f"Unsupported aggregation function: {agg_func}")
|
|
163
|
+
|
|
126
164
|
super().__init__(*args, **kwargs)
|
|
127
165
|
self.agg_func = agg_func
|
|
128
|
-
self.target_cache_id = target_cache_id
|
|
129
166
|
|
|
130
167
|
def __str__(self):
|
|
131
168
|
return f"{self.__class__.__name__}(function: {str(self.agg_func.value)})"
|
|
@@ -140,7 +177,6 @@ class Aggregate(LogicalOperator):
|
|
|
140
177
|
logical_op_params = super().get_logical_op_params()
|
|
141
178
|
logical_op_params = {
|
|
142
179
|
"agg_func": self.agg_func,
|
|
143
|
-
"target_cache_id": self.target_cache_id,
|
|
144
180
|
**logical_op_params,
|
|
145
181
|
}
|
|
146
182
|
|
|
@@ -148,81 +184,96 @@ class Aggregate(LogicalOperator):
|
|
|
148
184
|
|
|
149
185
|
|
|
150
186
|
class BaseScan(LogicalOperator):
|
|
151
|
-
"""A BaseScan is a logical operator that represents a scan of a particular
|
|
187
|
+
"""A BaseScan is a logical operator that represents a scan of a particular root Dataset."""
|
|
152
188
|
|
|
153
|
-
def __init__(self,
|
|
154
|
-
super().__init__(output_schema=output_schema)
|
|
155
|
-
self.
|
|
189
|
+
def __init__(self, datasource: dataset.Dataset, output_schema: type[BaseModel], *args, **kwargs):
|
|
190
|
+
super().__init__(*args, output_schema=output_schema, **kwargs)
|
|
191
|
+
self.datasource = datasource
|
|
156
192
|
|
|
157
193
|
def __str__(self):
|
|
158
|
-
return f"BaseScan({self.
|
|
194
|
+
return f"BaseScan({self.datasource},{self.output_schema})"
|
|
159
195
|
|
|
160
196
|
def __eq__(self, other) -> bool:
|
|
161
197
|
return (
|
|
162
198
|
isinstance(other, BaseScan)
|
|
163
|
-
and self.input_schema
|
|
164
|
-
and self.output_schema
|
|
165
|
-
and self.
|
|
199
|
+
and self.input_schema == other.input_schema
|
|
200
|
+
and self.output_schema == other.output_schema
|
|
201
|
+
and self.datasource == other.datasource
|
|
166
202
|
)
|
|
167
203
|
|
|
168
204
|
def get_logical_id_params(self) -> dict:
|
|
169
|
-
|
|
205
|
+
logical_id_params = super().get_logical_id_params()
|
|
206
|
+
logical_id_params = {
|
|
207
|
+
"id": self.datasource.id,
|
|
208
|
+
**logical_id_params,
|
|
209
|
+
}
|
|
210
|
+
|
|
211
|
+
return logical_id_params
|
|
170
212
|
|
|
171
213
|
def get_logical_op_params(self) -> dict:
|
|
172
214
|
logical_op_params = super().get_logical_op_params()
|
|
173
|
-
logical_op_params = {"
|
|
215
|
+
logical_op_params = {"datasource": self.datasource, **logical_op_params}
|
|
174
216
|
|
|
175
217
|
return logical_op_params
|
|
176
218
|
|
|
177
219
|
|
|
178
|
-
class
|
|
179
|
-
"""A
|
|
220
|
+
class ContextScan(LogicalOperator):
|
|
221
|
+
"""A ContextScan is a logical operator that loads the context for a particular root Dataset."""
|
|
180
222
|
|
|
181
|
-
def __init__(self,
|
|
182
|
-
super().__init__(output_schema=output_schema)
|
|
183
|
-
self.
|
|
223
|
+
def __init__(self, context: context.Context, output_schema: type[BaseModel], *args, **kwargs):
|
|
224
|
+
super().__init__(*args, output_schema=output_schema, **kwargs)
|
|
225
|
+
self.context = context
|
|
184
226
|
|
|
185
227
|
def __str__(self):
|
|
186
|
-
return f"
|
|
228
|
+
return f"ContextScan({self.context},{self.output_schema})"
|
|
229
|
+
|
|
230
|
+
def __eq__(self, other) -> bool:
|
|
231
|
+
return (
|
|
232
|
+
isinstance(other, ContextScan)
|
|
233
|
+
and self.context.id == other.context.id
|
|
234
|
+
)
|
|
187
235
|
|
|
188
236
|
def get_logical_id_params(self) -> dict:
|
|
189
|
-
|
|
237
|
+
logical_id_params = super().get_logical_id_params()
|
|
238
|
+
logical_id_params = {
|
|
239
|
+
"id": self.context.id,
|
|
240
|
+
**logical_id_params,
|
|
241
|
+
}
|
|
242
|
+
|
|
243
|
+
return logical_id_params
|
|
190
244
|
|
|
191
245
|
def get_logical_op_params(self) -> dict:
|
|
192
246
|
logical_op_params = super().get_logical_op_params()
|
|
193
|
-
logical_op_params = {"
|
|
247
|
+
logical_op_params = {"context": self.context, **logical_op_params}
|
|
194
248
|
|
|
195
249
|
return logical_op_params
|
|
196
250
|
|
|
197
251
|
|
|
198
252
|
class ConvertScan(LogicalOperator):
|
|
199
|
-
"""A ConvertScan is a logical operator that represents a scan of a particular
|
|
253
|
+
"""A ConvertScan is a logical operator that represents a scan of a particular input Dataset, with conversion applied."""
|
|
200
254
|
|
|
201
255
|
def __init__(
|
|
202
256
|
self,
|
|
203
257
|
cardinality: Cardinality = Cardinality.ONE_TO_ONE,
|
|
204
258
|
udf: Callable | None = None,
|
|
205
|
-
depends_on: list[str] | None = None,
|
|
206
259
|
desc: str | None = None,
|
|
207
|
-
target_cache_id: str | None = None,
|
|
208
260
|
*args,
|
|
209
261
|
**kwargs,
|
|
210
262
|
):
|
|
211
263
|
super().__init__(*args, **kwargs)
|
|
212
264
|
self.cardinality = cardinality
|
|
213
265
|
self.udf = udf
|
|
214
|
-
self.depends_on = [] if depends_on is None else sorted(depends_on)
|
|
215
266
|
self.desc = desc
|
|
216
|
-
self.target_cache_id = target_cache_id
|
|
217
267
|
|
|
218
268
|
def __str__(self):
|
|
219
|
-
return f"ConvertScan({self.input_schema} -> {str(self.output_schema)}
|
|
269
|
+
return f"ConvertScan({self.input_schema} -> {str(self.output_schema)})"
|
|
220
270
|
|
|
221
271
|
def get_logical_id_params(self) -> dict:
|
|
222
272
|
logical_id_params = super().get_logical_id_params()
|
|
223
273
|
logical_id_params = {
|
|
224
274
|
"cardinality": self.cardinality,
|
|
225
275
|
"udf": self.udf,
|
|
276
|
+
"desc": self.desc,
|
|
226
277
|
**logical_id_params,
|
|
227
278
|
}
|
|
228
279
|
|
|
@@ -233,9 +284,41 @@ class ConvertScan(LogicalOperator):
|
|
|
233
284
|
logical_op_params = {
|
|
234
285
|
"cardinality": self.cardinality,
|
|
235
286
|
"udf": self.udf,
|
|
236
|
-
"depends_on": self.depends_on,
|
|
237
287
|
"desc": self.desc,
|
|
238
|
-
|
|
288
|
+
**logical_op_params,
|
|
289
|
+
}
|
|
290
|
+
|
|
291
|
+
return logical_op_params
|
|
292
|
+
|
|
293
|
+
|
|
294
|
+
class Distinct(LogicalOperator):
|
|
295
|
+
def __init__(self, distinct_cols: list[str] | None, *args, **kwargs):
|
|
296
|
+
super().__init__(*args, **kwargs)
|
|
297
|
+
# if distinct_cols is not None, check that all columns are in the input schema
|
|
298
|
+
if distinct_cols is not None:
|
|
299
|
+
for col in distinct_cols:
|
|
300
|
+
assert col in self.input_schema.model_fields, f"Column {col} not found in input schema {self.input_schema} for Distinct operator"
|
|
301
|
+
|
|
302
|
+
# store the list of distinct columns, sorted
|
|
303
|
+
self.distinct_cols = (
|
|
304
|
+
sorted([field_name for field_name in self.input_schema.model_fields])
|
|
305
|
+
if distinct_cols is None
|
|
306
|
+
else sorted(distinct_cols)
|
|
307
|
+
)
|
|
308
|
+
|
|
309
|
+
def __str__(self):
|
|
310
|
+
return f"Distinct({self.distinct_cols})"
|
|
311
|
+
|
|
312
|
+
def get_logical_id_params(self) -> dict:
|
|
313
|
+
logical_id_params = super().get_logical_id_params()
|
|
314
|
+
logical_id_params = {"distinct_cols": self.distinct_cols, **logical_id_params}
|
|
315
|
+
|
|
316
|
+
return logical_id_params
|
|
317
|
+
|
|
318
|
+
def get_logical_op_params(self) -> dict:
|
|
319
|
+
logical_op_params = super().get_logical_op_params()
|
|
320
|
+
logical_op_params = {
|
|
321
|
+
"distinct_cols": self.distinct_cols,
|
|
239
322
|
**logical_op_params,
|
|
240
323
|
}
|
|
241
324
|
|
|
@@ -243,20 +326,18 @@ class ConvertScan(LogicalOperator):
|
|
|
243
326
|
|
|
244
327
|
|
|
245
328
|
class FilteredScan(LogicalOperator):
|
|
246
|
-
"""A FilteredScan is a logical operator that represents a scan of a particular
|
|
329
|
+
"""A FilteredScan is a logical operator that represents a scan of a particular input Dataset, with filters applied."""
|
|
247
330
|
|
|
248
331
|
def __init__(
|
|
249
332
|
self,
|
|
250
333
|
filter: Filter,
|
|
251
|
-
|
|
252
|
-
target_cache_id: str | None = None,
|
|
334
|
+
desc: str | None = None,
|
|
253
335
|
*args,
|
|
254
336
|
**kwargs,
|
|
255
337
|
):
|
|
256
338
|
super().__init__(*args, **kwargs)
|
|
257
339
|
self.filter = filter
|
|
258
|
-
self.
|
|
259
|
-
self.target_cache_id = target_cache_id
|
|
340
|
+
self.desc = desc
|
|
260
341
|
|
|
261
342
|
def __str__(self):
|
|
262
343
|
return f"FilteredScan({str(self.output_schema)}, {str(self.filter)})"
|
|
@@ -265,6 +346,7 @@ class FilteredScan(LogicalOperator):
|
|
|
265
346
|
logical_id_params = super().get_logical_id_params()
|
|
266
347
|
logical_id_params = {
|
|
267
348
|
"filter": self.filter,
|
|
349
|
+
"desc": self.desc,
|
|
268
350
|
**logical_id_params,
|
|
269
351
|
}
|
|
270
352
|
|
|
@@ -274,8 +356,7 @@ class FilteredScan(LogicalOperator):
|
|
|
274
356
|
logical_op_params = super().get_logical_op_params()
|
|
275
357
|
logical_op_params = {
|
|
276
358
|
"filter": self.filter,
|
|
277
|
-
"
|
|
278
|
-
"target_cache_id": self.target_cache_id,
|
|
359
|
+
"desc": self.desc,
|
|
279
360
|
**logical_op_params,
|
|
280
361
|
}
|
|
281
362
|
|
|
@@ -286,7 +367,6 @@ class GroupByAggregate(LogicalOperator):
|
|
|
286
367
|
def __init__(
|
|
287
368
|
self,
|
|
288
369
|
group_by_sig: GroupBySig,
|
|
289
|
-
target_cache_id: str | None = None,
|
|
290
370
|
*args,
|
|
291
371
|
**kwargs,
|
|
292
372
|
):
|
|
@@ -297,7 +377,6 @@ class GroupByAggregate(LogicalOperator):
|
|
|
297
377
|
if not valid:
|
|
298
378
|
raise TypeError(error)
|
|
299
379
|
self.group_by_sig = group_by_sig
|
|
300
|
-
self.target_cache_id = target_cache_id
|
|
301
380
|
|
|
302
381
|
def __str__(self):
|
|
303
382
|
return f"GroupBy({self.group_by_sig.serialize()})"
|
|
@@ -312,7 +391,32 @@ class GroupByAggregate(LogicalOperator):
|
|
|
312
391
|
logical_op_params = super().get_logical_op_params()
|
|
313
392
|
logical_op_params = {
|
|
314
393
|
"group_by_sig": self.group_by_sig,
|
|
315
|
-
|
|
394
|
+
**logical_op_params,
|
|
395
|
+
}
|
|
396
|
+
|
|
397
|
+
return logical_op_params
|
|
398
|
+
|
|
399
|
+
|
|
400
|
+
class JoinOp(LogicalOperator):
|
|
401
|
+
def __init__(self, condition: str, desc: str | None = None, *args, **kwargs):
|
|
402
|
+
super().__init__(*args, **kwargs)
|
|
403
|
+
self.condition = condition
|
|
404
|
+
self.desc = desc
|
|
405
|
+
|
|
406
|
+
def __str__(self):
|
|
407
|
+
return f"Join(condition={self.condition})"
|
|
408
|
+
|
|
409
|
+
def get_logical_id_params(self) -> dict:
|
|
410
|
+
logical_id_params = super().get_logical_id_params()
|
|
411
|
+
logical_id_params = {"condition": self.condition, "desc": self.desc, **logical_id_params}
|
|
412
|
+
|
|
413
|
+
return logical_id_params
|
|
414
|
+
|
|
415
|
+
def get_logical_op_params(self) -> dict:
|
|
416
|
+
logical_op_params = super().get_logical_op_params()
|
|
417
|
+
logical_op_params = {
|
|
418
|
+
"condition": self.condition,
|
|
419
|
+
"desc": self.desc,
|
|
316
420
|
**logical_op_params,
|
|
317
421
|
}
|
|
318
422
|
|
|
@@ -320,10 +424,9 @@ class GroupByAggregate(LogicalOperator):
|
|
|
320
424
|
|
|
321
425
|
|
|
322
426
|
class LimitScan(LogicalOperator):
|
|
323
|
-
def __init__(self, limit: int,
|
|
427
|
+
def __init__(self, limit: int, *args, **kwargs):
|
|
324
428
|
super().__init__(*args, **kwargs)
|
|
325
429
|
self.limit = limit
|
|
326
|
-
self.target_cache_id = target_cache_id
|
|
327
430
|
|
|
328
431
|
def __str__(self):
|
|
329
432
|
return f"LimitScan({str(self.input_schema)}, {str(self.output_schema)})"
|
|
@@ -338,7 +441,6 @@ class LimitScan(LogicalOperator):
|
|
|
338
441
|
logical_op_params = super().get_logical_op_params()
|
|
339
442
|
logical_op_params = {
|
|
340
443
|
"limit": self.limit,
|
|
341
|
-
"target_cache_id": self.target_cache_id,
|
|
342
444
|
**logical_op_params,
|
|
343
445
|
}
|
|
344
446
|
|
|
@@ -346,10 +448,9 @@ class LimitScan(LogicalOperator):
|
|
|
346
448
|
|
|
347
449
|
|
|
348
450
|
class Project(LogicalOperator):
|
|
349
|
-
def __init__(self, project_cols: list[str],
|
|
451
|
+
def __init__(self, project_cols: list[str], *args, **kwargs):
|
|
350
452
|
super().__init__(*args, **kwargs)
|
|
351
453
|
self.project_cols = project_cols
|
|
352
|
-
self.target_cache_id = target_cache_id
|
|
353
454
|
|
|
354
455
|
def __str__(self):
|
|
355
456
|
return f"Project({self.input_schema}, {self.project_cols})"
|
|
@@ -364,7 +465,6 @@ class Project(LogicalOperator):
|
|
|
364
465
|
logical_op_params = super().get_logical_op_params()
|
|
365
466
|
logical_op_params = {
|
|
366
467
|
"project_cols": self.project_cols,
|
|
367
|
-
"target_cache_id": self.target_cache_id,
|
|
368
468
|
**logical_op_params,
|
|
369
469
|
}
|
|
370
470
|
|
|
@@ -372,7 +472,7 @@ class Project(LogicalOperator):
|
|
|
372
472
|
|
|
373
473
|
|
|
374
474
|
class RetrieveScan(LogicalOperator):
|
|
375
|
-
"""A RetrieveScan is a logical operator that represents a scan of a particular
|
|
475
|
+
"""A RetrieveScan is a logical operator that represents a scan of a particular input Dataset, with a convert-like retrieve applied."""
|
|
376
476
|
|
|
377
477
|
def __init__(
|
|
378
478
|
self,
|
|
@@ -381,7 +481,6 @@ class RetrieveScan(LogicalOperator):
|
|
|
381
481
|
search_attr,
|
|
382
482
|
output_attrs,
|
|
383
483
|
k,
|
|
384
|
-
target_cache_id: str = None,
|
|
385
484
|
*args,
|
|
386
485
|
**kwargs,
|
|
387
486
|
):
|
|
@@ -391,10 +490,9 @@ class RetrieveScan(LogicalOperator):
|
|
|
391
490
|
self.search_attr = search_attr
|
|
392
491
|
self.output_attrs = output_attrs
|
|
393
492
|
self.k = k
|
|
394
|
-
self.target_cache_id = target_cache_id
|
|
395
493
|
|
|
396
494
|
def __str__(self):
|
|
397
|
-
return f"RetrieveScan({self.input_schema} -> {str(self.output_schema)}
|
|
495
|
+
return f"RetrieveScan({self.input_schema} -> {str(self.output_schema)})"
|
|
398
496
|
|
|
399
497
|
def get_logical_id_params(self) -> dict:
|
|
400
498
|
# NOTE: if we allow optimization over index, then we will need to include it in the id params
|
|
@@ -418,36 +516,31 @@ class RetrieveScan(LogicalOperator):
|
|
|
418
516
|
"search_attr": self.search_attr,
|
|
419
517
|
"output_attrs": self.output_attrs,
|
|
420
518
|
"k": self.k,
|
|
421
|
-
"target_cache_id": self.target_cache_id,
|
|
422
519
|
**logical_op_params,
|
|
423
520
|
}
|
|
424
521
|
|
|
425
522
|
return logical_op_params
|
|
426
523
|
|
|
427
524
|
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
525
|
+
class ComputeOperator(LogicalOperator):
|
|
526
|
+
"""
|
|
527
|
+
A ComputeOperator is a logical operator that performs a computation described in natural language
|
|
528
|
+
on a given Context.
|
|
529
|
+
"""
|
|
432
530
|
|
|
433
|
-
def __init__(
|
|
434
|
-
self,
|
|
435
|
-
udf: Callable | None = None,
|
|
436
|
-
target_cache_id: str | None = None,
|
|
437
|
-
*args,
|
|
438
|
-
**kwargs,
|
|
439
|
-
):
|
|
531
|
+
def __init__(self, context_id: str, instruction: str, *args, **kwargs):
|
|
440
532
|
super().__init__(*args, **kwargs)
|
|
441
|
-
self.
|
|
442
|
-
self.
|
|
533
|
+
self.context_id = context_id
|
|
534
|
+
self.instruction = instruction
|
|
443
535
|
|
|
444
536
|
def __str__(self):
|
|
445
|
-
return f"
|
|
537
|
+
return f"ComputeOperator(id={self.context_id}, instr={self.instruction:20s})"
|
|
446
538
|
|
|
447
539
|
def get_logical_id_params(self) -> dict:
|
|
448
540
|
logical_id_params = super().get_logical_id_params()
|
|
449
541
|
logical_id_params = {
|
|
450
|
-
"
|
|
542
|
+
"context_id": self.context_id,
|
|
543
|
+
"instruction": self.instruction,
|
|
451
544
|
**logical_id_params,
|
|
452
545
|
}
|
|
453
546
|
|
|
@@ -456,8 +549,43 @@ class MapScan(LogicalOperator):
|
|
|
456
549
|
def get_logical_op_params(self) -> dict:
|
|
457
550
|
logical_op_params = super().get_logical_op_params()
|
|
458
551
|
logical_op_params = {
|
|
459
|
-
"
|
|
460
|
-
"
|
|
552
|
+
"context_id": self.context_id,
|
|
553
|
+
"instruction": self.instruction,
|
|
554
|
+
**logical_op_params,
|
|
555
|
+
}
|
|
556
|
+
|
|
557
|
+
return logical_op_params
|
|
558
|
+
|
|
559
|
+
|
|
560
|
+
class SearchOperator(LogicalOperator):
|
|
561
|
+
"""
|
|
562
|
+
A SearchOperator is a logical operator that executes a search described in natural language
|
|
563
|
+
on a given Context.
|
|
564
|
+
"""
|
|
565
|
+
|
|
566
|
+
def __init__(self, context_id: str, search_query: str, *args, **kwargs):
|
|
567
|
+
super().__init__(*args, **kwargs)
|
|
568
|
+
self.context_id = context_id
|
|
569
|
+
self.search_query = search_query
|
|
570
|
+
|
|
571
|
+
def __str__(self):
|
|
572
|
+
return f"SearchOperator(id={self.context_id}, search_query={self.search_query:20s})"
|
|
573
|
+
|
|
574
|
+
def get_logical_id_params(self) -> dict:
|
|
575
|
+
logical_id_params = super().get_logical_id_params()
|
|
576
|
+
logical_id_params = {
|
|
577
|
+
"context_id": self.context_id,
|
|
578
|
+
"search_query": self.search_query,
|
|
579
|
+
**logical_id_params,
|
|
580
|
+
}
|
|
581
|
+
|
|
582
|
+
return logical_id_params
|
|
583
|
+
|
|
584
|
+
def get_logical_op_params(self) -> dict:
|
|
585
|
+
logical_op_params = super().get_logical_op_params()
|
|
586
|
+
logical_op_params = {
|
|
587
|
+
"context_id": self.context_id,
|
|
588
|
+
"search_query": self.search_query,
|
|
461
589
|
**logical_op_params,
|
|
462
590
|
}
|
|
463
591
|
|
|
@@ -1,10 +1,11 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
from pydantic.fields import FieldInfo
|
|
4
|
+
|
|
3
5
|
from palimpzest.constants import MODEL_CARDS, Model, PromptStrategy
|
|
4
|
-
from palimpzest.core.data.dataclasses import GenerationStats, OperatorCostEstimates
|
|
5
6
|
from palimpzest.core.elements.records import DataRecord
|
|
6
|
-
from palimpzest.core.
|
|
7
|
-
from palimpzest.query.generators.generators import
|
|
7
|
+
from palimpzest.core.models import GenerationStats, OperatorCostEstimates
|
|
8
|
+
from palimpzest.query.generators.generators import Generator
|
|
8
9
|
from palimpzest.query.operators.convert import LLMConvert
|
|
9
10
|
|
|
10
11
|
# TYPE DEFINITIONS
|
|
@@ -20,7 +21,6 @@ class MixtureOfAgentsConvert(LLMConvert):
|
|
|
20
21
|
aggregator_model: Model,
|
|
21
22
|
proposer_prompt_strategy: PromptStrategy = PromptStrategy.COT_MOA_PROPOSER,
|
|
22
23
|
aggregator_prompt_strategy: PromptStrategy = PromptStrategy.COT_MOA_AGG,
|
|
23
|
-
proposer_prompt: str | None = None,
|
|
24
24
|
*args,
|
|
25
25
|
**kwargs,
|
|
26
26
|
):
|
|
@@ -33,14 +33,13 @@ class MixtureOfAgentsConvert(LLMConvert):
|
|
|
33
33
|
self.aggregator_model = aggregator_model
|
|
34
34
|
self.proposer_prompt_strategy = proposer_prompt_strategy
|
|
35
35
|
self.aggregator_prompt_strategy = aggregator_prompt_strategy
|
|
36
|
-
self.proposer_prompt = proposer_prompt
|
|
37
36
|
|
|
38
37
|
# create generators
|
|
39
38
|
self.proposer_generators = [
|
|
40
|
-
|
|
39
|
+
Generator(model, self.proposer_prompt_strategy, self.reasoning_effort, self.api_base, self.cardinality, self.desc, self.verbose)
|
|
41
40
|
for model in proposer_models
|
|
42
41
|
]
|
|
43
|
-
self.aggregator_generator =
|
|
42
|
+
self.aggregator_generator = Generator(aggregator_model, self.aggregator_prompt_strategy, self.reasoning_effort, self.api_base, self.cardinality, self.desc, self.verbose)
|
|
44
43
|
|
|
45
44
|
def __str__(self):
|
|
46
45
|
op = super().__str__()
|
|
@@ -77,6 +76,9 @@ class MixtureOfAgentsConvert(LLMConvert):
|
|
|
77
76
|
|
|
78
77
|
return op_params
|
|
79
78
|
|
|
79
|
+
def is_image_conversion(self) -> bool:
|
|
80
|
+
return self.proposer_prompt_strategy.is_image_prompt()
|
|
81
|
+
|
|
80
82
|
def naive_cost_estimates(self, source_op_cost_estimates: OperatorCostEstimates) -> OperatorCostEstimates:
|
|
81
83
|
"""
|
|
82
84
|
Currently, we are using multiple proposer models with different temperatures to synthesize
|
|
@@ -111,7 +113,7 @@ class MixtureOfAgentsConvert(LLMConvert):
|
|
|
111
113
|
|
|
112
114
|
return naive_op_cost_estimates
|
|
113
115
|
|
|
114
|
-
def convert(self, candidate: DataRecord, fields: dict[str,
|
|
116
|
+
def convert(self, candidate: DataRecord, fields: dict[str, FieldInfo]) -> tuple[dict[str, list], GenerationStats]:
|
|
115
117
|
# get input fields
|
|
116
118
|
input_fields = self.get_input_fields()
|
|
117
119
|
|