palimpzest 0.6.4__py3-none-any.whl → 0.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- palimpzest/__init__.py +5 -0
- palimpzest/constants.py +110 -43
- palimpzest/core/__init__.py +0 -78
- palimpzest/core/data/dataclasses.py +382 -44
- palimpzest/core/elements/filters.py +7 -3
- palimpzest/core/elements/index.py +70 -0
- palimpzest/core/elements/records.py +33 -11
- palimpzest/core/lib/fields.py +1 -0
- palimpzest/core/lib/schemas.py +4 -3
- palimpzest/prompts/moa_proposer_convert_prompts.py +0 -4
- palimpzest/prompts/prompt_factory.py +44 -7
- palimpzest/prompts/split_merge_prompts.py +56 -0
- palimpzest/prompts/split_proposer_prompts.py +55 -0
- palimpzest/query/execution/execution_strategy.py +435 -53
- palimpzest/query/execution/execution_strategy_type.py +20 -0
- palimpzest/query/execution/mab_execution_strategy.py +532 -0
- palimpzest/query/execution/parallel_execution_strategy.py +143 -172
- palimpzest/query/execution/random_sampling_execution_strategy.py +240 -0
- palimpzest/query/execution/single_threaded_execution_strategy.py +173 -203
- palimpzest/query/generators/api_client_factory.py +31 -0
- palimpzest/query/generators/generators.py +256 -76
- palimpzest/query/operators/__init__.py +1 -2
- palimpzest/query/operators/code_synthesis_convert.py +33 -18
- palimpzest/query/operators/convert.py +30 -97
- palimpzest/query/operators/critique_and_refine_convert.py +5 -6
- palimpzest/query/operators/filter.py +7 -10
- palimpzest/query/operators/logical.py +54 -10
- palimpzest/query/operators/map.py +130 -0
- palimpzest/query/operators/mixture_of_agents_convert.py +6 -6
- palimpzest/query/operators/physical.py +3 -12
- palimpzest/query/operators/rag_convert.py +66 -18
- palimpzest/query/operators/retrieve.py +230 -34
- palimpzest/query/operators/scan.py +5 -2
- palimpzest/query/operators/split_convert.py +169 -0
- palimpzest/query/operators/token_reduction_convert.py +8 -14
- palimpzest/query/optimizer/__init__.py +4 -16
- palimpzest/query/optimizer/cost_model.py +73 -266
- palimpzest/query/optimizer/optimizer.py +87 -58
- palimpzest/query/optimizer/optimizer_strategy.py +18 -97
- palimpzest/query/optimizer/optimizer_strategy_type.py +37 -0
- palimpzest/query/optimizer/plan.py +2 -3
- palimpzest/query/optimizer/primitives.py +5 -3
- palimpzest/query/optimizer/rules.py +336 -172
- palimpzest/query/optimizer/tasks.py +30 -100
- palimpzest/query/processor/config.py +38 -22
- palimpzest/query/processor/nosentinel_processor.py +16 -520
- palimpzest/query/processor/processing_strategy_type.py +28 -0
- palimpzest/query/processor/query_processor.py +38 -206
- palimpzest/query/processor/query_processor_factory.py +117 -130
- palimpzest/query/processor/sentinel_processor.py +90 -0
- palimpzest/query/processor/streaming_processor.py +25 -32
- palimpzest/sets.py +88 -41
- palimpzest/utils/model_helpers.py +8 -7
- palimpzest/utils/progress.py +368 -152
- palimpzest/utils/token_reduction_helpers.py +1 -3
- {palimpzest-0.6.4.dist-info → palimpzest-0.7.1.dist-info}/METADATA +19 -9
- palimpzest-0.7.1.dist-info/RECORD +96 -0
- {palimpzest-0.6.4.dist-info → palimpzest-0.7.1.dist-info}/WHEEL +1 -1
- palimpzest/query/processor/mab_sentinel_processor.py +0 -884
- palimpzest/query/processor/random_sampling_sentinel_processor.py +0 -639
- palimpzest/utils/index_helpers.py +0 -6
- palimpzest-0.6.4.dist-info/RECORD +0 -87
- {palimpzest-0.6.4.dist-info → palimpzest-0.7.1.dist-info/licenses}/LICENSE +0 -0
- {palimpzest-0.6.4.dist-info → palimpzest-0.7.1.dist-info}/top_level.txt +0 -0
palimpzest/query/processor/streaming_processor.py
CHANGED
@@ -1,8 +1,8 @@
+import logging
 import time
 
-from palimpzest.core.data.dataclasses import
+from palimpzest.core.data.dataclasses import PlanStats
 from palimpzest.core.elements.records import DataRecordCollection
-from palimpzest.policy import Policy
 from palimpzest.query.operators.aggregate import AggregateOp
 from palimpzest.query.operators.filter import FilterOp
 from palimpzest.query.operators.limit import LimitScanOp
@@ -11,6 +11,7 @@ from palimpzest.query.optimizer.plan import PhysicalPlan
 from palimpzest.query.processor.query_processor import QueryProcessor
 from palimpzest.sets import Dataset
 
+logger = logging.getLogger(__name__)
 
 class StreamingQueryProcessor(QueryProcessor):
     """This class can be used for a streaming, record-based execution.
@@ -24,6 +25,7 @@ class StreamingQueryProcessor(QueryProcessor):
         self.current_scan_idx: int = 0
         self.plan_generated: bool = False
         self.records_count: int = 0
+        logger.info("Initialized StreamingQueryProcessor")
 
     @property
     def plan(self) -> PhysicalPlan:
@@ -45,33 +47,32 @@ class StreamingQueryProcessor(QueryProcessor):
     def plan_stats(self, plan_stats: PlanStats):
         self._plan_stats = plan_stats
 
-    def generate_plan(self, dataset: Dataset
+    def generate_plan(self, dataset: Dataset):
         # self.clear_cached_examples()
         start_time = time.time()
 
+        # check that the plan does not contain any aggregation operators
+        for op in self.plan.operators:
+            if isinstance(op, AggregateOp):
+                raise Exception("You cannot have a Streaming Execution if there is an Aggregation Operator")
+
         # TODO: Do we need to re-initialize the optimizer here?
         # Effectively always use the optimal strategy
         optimizer = self.optimizer.deepcopy_clean()
-        plans = optimizer.optimize(dataset
+        plans = optimizer.optimize(dataset)
         self.plan = plans[0]
-        self.plan_stats = PlanStats(
-
-
-                raise Exception("You cannot have a Streaming Execution if there is an Aggregation Operator")
-            op_id = op.get_op_id()
-            op_name = op.op_name()
-            op_details = {k: str(v) for k, v in op.get_id_params().items()}
-            self.plan_stats.operator_stats[op_id] = OperatorStats(op_id=op_id, op_name=op_name, op_details=op_details)
-        print("Time for planning: ", time.time() - start_time)
+        self.plan_stats = PlanStats.from_plan(self.plan)
+        self.plan_stats.start()
+        logger.info(f"Time for planning: {time.time() - start_time:.2f} seconds")
         self.plan_generated = True
-
+        logger.info(f"Generated plan:\n{self.plan}")
         return self.plan
 
     def execute(self):
-
+        logger.info("Executing StreamingQueryProcessor")
         # Always delete cache
         if not self.plan_generated:
-            self.generate_plan(self.dataset
+            self.generate_plan(self.dataset)
 
         # if dry_run:
         #     yield [], self.plan, self.plan_stats
@@ -82,11 +83,14 @@ class StreamingQueryProcessor(QueryProcessor):
             # print("Iteration number: ", idx+1, "out of", len(input_records))
             output_records = self.execute_opstream(self.plan, record)
             if idx == len(input_records) - 1:
-
-                self.plan_stats.
+                # finalize plan stats
+                self.plan_stats.finish()
                 self.plan_stats.plan_str = str(self.plan)
             yield DataRecordCollection(output_records, plan_stats=self.plan_stats)
 
+        logger.info("Done executing StreamingQueryProcessor")
+
+
     def get_input_records(self):
         scan_operator = self.plan.operators[0]
         assert isinstance(scan_operator, ScanPhysicalOp), "First operator in physical plan must be a ScanPhysicalOp"
@@ -102,12 +106,7 @@ class StreamingQueryProcessor(QueryProcessor):
             input_records += record_set.data_records
             record_op_stats += record_set.record_op_stats
 
-
-        self.plan_stats.operator_stats[op_id].add_record_op_stats(
-            record_op_stats,
-            source_op_id=None,
-            plan_id=self.plan.plan_id,
-        )
+        self.plan_stats.add_record_op_stats(record_op_stats)
 
         return input_records
 
@@ -116,13 +115,11 @@ class StreamingQueryProcessor(QueryProcessor):
         input_records = [record]
         record_op_stats_lst = []
 
-        for
+        for operator in plan.operators:
             # TODO: this being defined in the for loop potentially makes the return
             # unbounded if plan.operators is empty. This should be defined outside the loop
             # and the loop refactored to account for not redeclaring this for each operator
             output_records = []
-            op_id = operator.get_op_id()
-            prev_op_id = plan.operators[op_idx - 1].get_op_id() if op_idx > 1 else None
 
             if isinstance(operator, ScanPhysicalOp):
                 continue
@@ -145,11 +142,7 @@ class StreamingQueryProcessor(QueryProcessor):
             if not output_records:
                 break
 
-            self.plan_stats.
-                record_op_stats_lst,
-                source_op_id=prev_op_id,
-                plan_id=plan.plan_id,
-            )
+            self.plan_stats.add_record_op_stats(record_op_stats_lst)
             input_records = output_records
             self.records_count += len(output_records)
 
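The streaming-processor hunks above replace the hand-rolled OperatorStats bookkeeping with the new PlanStats lifecycle. A minimal sketch of that lifecycle, using only the calls visible in this diff (from_plan, start, add_record_op_stats, finish); run_plan_with_stats and execute_fn are illustrative names, not part of the package:

    from palimpzest.core.data.dataclasses import PlanStats

    def run_plan_with_stats(plan, execute_fn):
        # build per-operator stats containers directly from the physical plan,
        # replacing the old per-operator OperatorStats(...) construction
        plan_stats = PlanStats.from_plan(plan)
        plan_stats.start()  # begin timing the plan execution

        output_records, record_op_stats = execute_fn(plan)
        # 0.7.1 aggregates operator stats without op_id / source_op_id plumbing
        plan_stats.add_record_op_stats(record_op_stats)

        plan_stats.finish()  # finalize timing
        plan_stats.plan_str = str(plan)
        return output_records, plan_stats
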
palimpzest/sets.py
CHANGED
@@ -4,18 +4,18 @@ from pathlib import Path
 from typing import Callable
 
 import pandas as pd
+from chromadb.api.models.Collection import Collection
+from ragatouille.RAGPretrainedModel import RAGPretrainedModel
 
 from palimpzest.constants import AggFunc, Cardinality
 from palimpzest.core.data.datareaders import DataReader
 from palimpzest.core.elements.filters import Filter
 from palimpzest.core.elements.groupbysig import GroupBySig
-from palimpzest.core.lib.fields import ListField, StringField
 from palimpzest.core.lib.schemas import Number, Schema
 from palimpzest.policy import construct_policy_from_kwargs
 from palimpzest.query.processor.config import QueryProcessorConfig
 from palimpzest.utils.datareader_helpers import get_local_datareader
 from palimpzest.utils.hash_helpers import hash_for_serialized_dict
-from palimpzest.utils.index_helpers import get_index_str
 
 
 #####################################################
@@ -35,15 +35,15 @@ class Set:
         agg_func: AggFunc | None = None,
         group_by: GroupBySig | None = None,
         project_cols: list[str] | None = None,
-        index
+        index: Collection | RAGPretrainedModel | None = None,
         search_func: Callable | None = None,
         search_attr: str | None = None,
-
+        output_attrs: list[dict] | None = None,
         k: int | None = None,  # TODO: disambiguate `k` to be something like `retrieve_k`
         limit: int | None = None,
         cardinality: Cardinality = Cardinality.ONE_TO_ONE,
         depends_on: list[str] | None = None,
-
+        cache: bool = False,
     ):
         self._schema = schema
         self._source = source
@@ -56,12 +56,12 @@ class Set:
         self._index = index
         self._search_func = search_func
         self._search_attr = search_attr
-        self.
+        self._output_attrs = output_attrs
         self._k = k
         self._limit = limit
         self._cardinality = cardinality
         self._depends_on = [] if depends_on is None else sorted(depends_on)
-        self.
+        self._cache = cache
 
     @property
     def schema(self) -> Schema:
@@ -83,16 +83,16 @@ class Set:
             "source": self._source.serialize(),
             "desc": repr(self._desc),
             "filter": None if self._filter is None else self._filter.serialize(),
-            "udf": None if self._udf is None else
+            "udf": None if self._udf is None else self._udf.__name__,
             "agg_func": None if self._agg_func is None else self._agg_func.value,
             "cardinality": self._cardinality,
             "limit": self._limit,
-            "group_by":
-            "project_cols":
-            "index": None if self._index is None else
-            "search_func": None if self._search_func is None else
+            "group_by": None if self._group_by is None else self._group_by.serialize(),
+            "project_cols": None if self._project_cols is None else self._project_cols,
+            "index": None if self._index is None else self._index.__class__.__name__,
+            "search_func": None if self._search_func is None else self._search_func.__name__,
             "search_attr": self._search_attr,
-            "
+            "output_attrs": None if self._output_attrs is None else str(self._output_attrs),
             "k": self._k,
         }
 
@@ -132,10 +132,31 @@ class Dataset(Set):
 
         # get the schema
         schema = updated_source.schema if schema is None else schema
-
+
         # intialize class
         super().__init__(updated_source, schema, *args, **kwargs)
 
+    def copy(self):
+        return Dataset(
+            source=self._source.copy() if isinstance(self._source, Set) else self._source,
+            schema=self._schema,
+            desc=self._desc,
+            filter=self._filter,
+            udf=self._udf,
+            agg_func=self._agg_func,
+            group_by=self._group_by,
+            project_cols=self._project_cols,
+            index=self._index,
+            search_func=self._search_func,
+            search_attr=self._search_attr,
+            output_attrs=self._output_attrs,
+            k=self._k,
+            limit=self._limit,
+            cardinality=self._cardinality,
+            depends_on=self._depends_on,
+            cache=self._cache,
+        )
+
     def filter(
         self,
         _filter: Callable,
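Dataset.copy() (added above) rebuilds the node with identical parameters and recursively copies any upstream Set source; non-Set sources (e.g. a DataReader) are carried over by reference. A hedged usage sketch, where docs stands in for any existing Dataset:

    # Illustrative only: `docs` stands in for any existing Dataset.
    snapshot = docs.copy()        # independent copy of the logical plan
    assert snapshot is not docs   # a new Dataset node, same parameters
    # upstream Set sources were copied recursively, so extending the copy
    # cannot alias state on the original chain
    branch = snapshot.sem_filter("The document mentions a merger.")
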
@@ -159,9 +180,9 @@ class Dataset(Set):
             schema=self.schema,
             filter=f,
             depends_on=depends_on,
-
+            cache=self._cache,
         )
-
+
     def sem_filter(
         self,
         _filter: str,
@@ -173,7 +194,7 @@ class Dataset(Set):
             f = Filter(_filter)
         else:
             raise Exception("sem_filter() only supports `str` input for _filter.", type(_filter))
-
+
         if isinstance(depends_on, str):
             depends_on = [depends_on]
 
@@ -182,11 +203,11 @@ class Dataset(Set):
             schema=self.schema,
             filter=f,
             depends_on=depends_on,
-
+            cache=self._cache,
         )
 
     def sem_add_columns(self, cols: list[dict] | type[Schema],
-                        cardinality: Cardinality = Cardinality.ONE_TO_ONE,
+                        cardinality: Cardinality = Cardinality.ONE_TO_ONE,
                         depends_on: str | list[str] | None = None,
                         desc: str = "Add new columns via semantic reasoning") -> Dataset:
         """
@@ -217,7 +238,7 @@ class Dataset(Set):
             cardinality=cardinality,
             depends_on=depends_on,
             desc=desc,
-
+            cache=self._cache,
         )
 
     def add_columns(self, udf: Callable,
@@ -254,7 +275,7 @@ class Dataset(Set):
                 col_dict["desc"] = col_dict.get("desc", "New column: " + col_dict["name"])
                 updated_cols.append(col_dict)
             new_output_schema = self.schema.add_fields(updated_cols)
-
+
         elif issubclass(cols, Schema):
             new_output_schema = self.schema.union(cols)
 
@@ -268,7 +289,24 @@ class Dataset(Set):
             cardinality=cardinality,
             desc=desc,
             depends_on=depends_on,
-
+            cache=self._cache,
+        )
+
+    def map(self, udf: Callable) -> Dataset:
+        """
+        Apply a UDF map function.
+
+        Examples:
+            map(udf=clean_column_values)
+        """
+        if udf is None:
+            raise ValueError("`udf` must be provided for map.")
+
+        return Dataset(
+            source=self,
+            schema=self.schema,
+            udf=udf,
+            cache=self._cache,
         )
 
     def count(self) -> Dataset:
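The new map() above wraps a UDF while reusing self.schema, so the output schema is unchanged. A usage sketch in the spirit of the docstring's clean_column_values example; the UDF body and record access are illustrative, since the record interface is not part of this diff:

    # Illustrative UDF: normalize one column in place; the exact record
    # representation passed to 0.7.1 UDFs may differ.
    def clean_column_values(record):
        record["title"] = record["title"].strip().lower()
        return record

    cleaned = dataset.map(udf=clean_column_values)  # schema unchanged
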
@@ -278,7 +316,7 @@ class Dataset(Set):
             schema=Number,
             desc="Count results",
             agg_func=AggFunc.COUNT,
-
+            cache=self._cache,
         )
 
     def average(self) -> Dataset:
@@ -288,7 +326,7 @@ class Dataset(Set):
             schema=Number,
             desc="Average results",
             agg_func=AggFunc.AVERAGE,
-
+            cache=self._cache,
         )
 
     def groupby(self, groupby: GroupBySig) -> Dataset:
@@ -297,34 +335,43 @@ class Dataset(Set):
             schema=groupby.output_schema(),
             desc="Group By",
             group_by=groupby,
-
+            cache=self._cache,
         )
 
     def retrieve(
-        self,
+        self,
+        index: Collection | RAGPretrainedModel,
+        search_attr: str,
+        output_attrs: list[dict] | type[Schema],
+        search_func: Callable | None = None,
+        k: int = -1,
     ) -> Dataset:
         """
-        Retrieve the top
-
-        and the `output_attr` with type ListField(StringField). `search_func` is a function of
-        type (index, query: str | list(str), k: int) -> list[str]. It should implement the lookup
-        logic for the index and return the top k results. The value of the `search_attr` field is
-        used as the query to lookup in the index. The results are stored in the `output_attr`
-        field. `output_attr_desc` is the description of the `output_attr` field.
+        Retrieve the top-k nearest neighbors of the value of the `search_attr` from the `index` and
+        use these results to construct the `output_attrs` field(s).
         """
-
-
-
+        new_output_schema = None
+        if isinstance(output_attrs, list):
+            new_output_schema = self.schema.add_fields(output_attrs)
+        elif issubclass(output_attrs, Schema):
+            new_output_schema = self.schema.union(output_attrs)
+        else:
+            raise ValueError("`cols` must be a list of dictionaries or a Schema.")
+
+        # TODO: revisit once we can think through abstraction(s)
+        # # construct the PZIndex from the user-provided index
+        # index = index_factory(index)
+
         return Dataset(
             source=self,
-            schema=
+            schema=new_output_schema,
             desc="Retrieve",
             index=index,
             search_func=search_func,
             search_attr=search_attr,
-
+            output_attrs=output_attrs,
             k=k,
-
+            cache=self._cache,
         )
 
     def limit(self, n: int) -> Dataset:
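The reworked retrieve() above accepts the index object directly (a chromadb Collection or a RAGatouille RAGPretrainedModel, matching the new imports) plus search_attr and output_attrs. A hedged sketch of the new call shape; the collection name and the field dictionary are illustrative, and field-dict keys beyond name and desc are assumed from the add_columns hunk:

    import chromadb

    client = chromadb.Client()
    collection = client.get_or_create_collection("legal-cases")  # assumed populated

    matches = dataset.retrieve(
        index=collection,
        search_attr="claim",  # this field's value is used as the query
        output_attrs=[
            {"name": "relevant_cases", "desc": "Cases related to the claim"},
        ],
        k=5,  # top-k nearest neighbors
    )
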
@@ -334,7 +381,7 @@ class Dataset(Set):
             schema=self.schema,
             desc="LIMIT " + str(n),
             limit=n,
-
+            cache=self._cache,
         )
 
     def project(self, project_cols: list[str] | str) -> Dataset:
@@ -343,7 +390,7 @@ class Dataset(Set):
             source=self,
             schema=self.schema.project(project_cols),
             project_cols=project_cols if isinstance(project_cols, list) else [project_cols],
-
+            cache=self._cache,
         )
 
     def run(self, config: QueryProcessorConfig | None = None, **kwargs):
palimpzest/utils/model_helpers.py
CHANGED
@@ -26,7 +26,7 @@ def get_models(include_vision: bool = False) -> list[Model]:
         models.extend([Model.GPT_4o, Model.GPT_4o_MINI])
 
     if os.getenv("TOGETHER_API_KEY") is not None:
-        models.extend([Model.LLAMA3, Model.MIXTRAL])
+        models.extend([Model.LLAMA3, Model.MIXTRAL, Model.DEEPSEEK])
 
     if include_vision:
         vision_models = get_vision_models()
@@ -39,23 +39,24 @@ TEXT_MODEL_PRIORITY = [
     Model.GPT_4o,
     Model.GPT_4o_MINI,
     Model.LLAMA3,
-    Model.MIXTRAL
+    Model.MIXTRAL,
+    Model.DEEPSEEK,
 ]
 
 VISION_MODEL_PRIORITY = [
     Model.GPT_4o_V,
     Model.GPT_4o_MINI_V,
-    Model.LLAMA3_V
+    Model.LLAMA3_V,
 ]
-def get_champion_model(available_models, vision=False):
+def get_champion_model(available_models, vision=False):
     # Select appropriate priority list based on task
     model_priority = VISION_MODEL_PRIORITY if vision else TEXT_MODEL_PRIORITY
-
+
     # Return first available model from priority list
     for model in model_priority:
        if model in available_models:
            return model
-
+
     # If no suitable model found, raise informative error
     task_type = "vision" if vision else "text"
     raise Exception(
@@ -66,7 +67,7 @@ def get_champion_model(available_models, vision=False):
     )
 
 
-def
+def get_fallback_model(available_models, vision=False):
     return get_champion_model(available_models, vision)
 
 
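With Model.DEEPSEEK appended to both get_models() and TEXT_MODEL_PRIORITY, champion selection still walks the priority list top-down and returns the first available model. A small sketch of that behavior, using only the helpers shown above:

    from palimpzest.utils.model_helpers import get_champion_model, get_models

    available = get_models()                  # contents depend on which API keys are set
    champion = get_champion_model(available)  # first hit in TEXT_MODEL_PRIORITY
    # e.g. with only TOGETHER_API_KEY set, `available` is [LLAMA3, MIXTRAL, DEEPSEEK]
    # and LLAMA3 wins, since it precedes MIXTRAL and DEEPSEEK in the priority list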