palimpzest 0.8.4__py3-none-any.whl → 0.8.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- palimpzest/core/data/dataset.py +16 -1
- palimpzest/core/elements/records.py +3 -40
- palimpzest/core/lib/schemas.py +9 -0
- palimpzest/query/execution/execution_strategy.py +5 -0
- palimpzest/query/execution/mab_execution_strategy.py +56 -23
- palimpzest/query/operators/__init__.py +2 -1
- palimpzest/query/operators/join.py +13 -11
- palimpzest/query/optimizer/__init__.py +7 -3
- palimpzest/query/optimizer/optimizer.py +8 -0
- palimpzest/query/optimizer/optimizer_strategy.py +0 -3
- palimpzest/query/optimizer/plan.py +5 -6
- palimpzest/query/optimizer/rules.py +40 -6
- palimpzest/query/optimizer/tasks.py +9 -1
- palimpzest/query/processor/config.py +1 -0
- palimpzest/query/processor/query_processor_factory.py +7 -0
- palimpzest/validator/validator.py +14 -14
- {palimpzest-0.8.4.dist-info → palimpzest-0.8.6.dist-info}/METADATA +1 -1
- {palimpzest-0.8.4.dist-info → palimpzest-0.8.6.dist-info}/RECORD +21 -21
- {palimpzest-0.8.4.dist-info → palimpzest-0.8.6.dist-info}/WHEEL +0 -0
- {palimpzest-0.8.4.dist-info → palimpzest-0.8.6.dist-info}/licenses/LICENSE +0 -0
- {palimpzest-0.8.4.dist-info → palimpzest-0.8.6.dist-info}/top_level.txt +0 -0
palimpzest/core/data/dataset.py
CHANGED
|
@@ -10,7 +10,7 @@ from pydantic import BaseModel
|
|
|
10
10
|
from palimpzest.constants import AggFunc, Cardinality
|
|
11
11
|
from palimpzest.core.elements.filters import Filter
|
|
12
12
|
from palimpzest.core.elements.groupbysig import GroupBySig
|
|
13
|
-
from palimpzest.core.lib.schemas import create_schema_from_fields, project, union_schemas
|
|
13
|
+
from palimpzest.core.lib.schemas import create_schema_from_fields, project, relax_schema, union_schemas
|
|
14
14
|
from palimpzest.policy import construct_policy_from_kwargs
|
|
15
15
|
from palimpzest.query.operators.logical import (
|
|
16
16
|
Aggregate,
|
|
@@ -193,6 +193,21 @@ class Dataset:
|
|
|
193
193
|
|
|
194
194
|
return root_datasets
|
|
195
195
|
|
|
196
|
+
def relax_types(self) -> None:
|
|
197
|
+
"""
|
|
198
|
+
Relax the types in this Dataset's schema and all upstream Datasets' schemas to be more permissive.
|
|
199
|
+
"""
|
|
200
|
+
# relax the types in this dataset's schema
|
|
201
|
+
self._schema = relax_schema(self._schema)
|
|
202
|
+
|
|
203
|
+
# relax the types in dataset's operator's input and output schemas
|
|
204
|
+
self._operator.input_schema = None if self._operator.input_schema is None else relax_schema(self._operator.input_schema)
|
|
205
|
+
self._operator.output_schema = relax_schema(self._operator.output_schema)
|
|
206
|
+
|
|
207
|
+
# recursively relax the types in all upstream datasets
|
|
208
|
+
for source in self._sources:
|
|
209
|
+
source.relax_types()
|
|
210
|
+
|
|
196
211
|
def get_upstream_datasets(self) -> list[Dataset]:
|
|
197
212
|
"""
|
|
198
213
|
Get the list of all upstream datasets that are sources to this dataset.
|
|
@@ -16,12 +16,11 @@ from palimpzest.core.lib.schemas import (
|
|
|
16
16
|
ImageBase64,
|
|
17
17
|
ImageFilepath,
|
|
18
18
|
ImageURL,
|
|
19
|
-
create_schema_from_df,
|
|
20
19
|
project,
|
|
21
20
|
union_schemas,
|
|
22
21
|
)
|
|
23
22
|
from palimpzest.core.models import ExecutionStats, PlanStats, RecordOpStats
|
|
24
|
-
from palimpzest.utils.hash_helpers import hash_for_id
|
|
23
|
+
from palimpzest.utils.hash_helpers import hash_for_id
|
|
25
24
|
|
|
26
25
|
|
|
27
26
|
class DataRecord:
|
|
@@ -93,10 +92,8 @@ class DataRecord:
|
|
|
93
92
|
|
|
94
93
|
|
|
95
94
|
def __getattr__(self, name: str) -> Any:
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
return field
|
|
99
|
-
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
|
|
95
|
+
return getattr(self._data_item, name)
|
|
96
|
+
|
|
100
97
|
|
|
101
98
|
def __getitem__(self, field: str) -> Any:
|
|
102
99
|
return getattr(self._data_item, field)
|
|
@@ -266,40 +263,6 @@ class DataRecord:
|
|
|
266
263
|
|
|
267
264
|
return new_dr
|
|
268
265
|
|
|
269
|
-
# TODO: unused outside of unit tests
|
|
270
|
-
@staticmethod
|
|
271
|
-
def from_df(df: pd.DataFrame, schema: type[BaseModel] | None = None) -> list[DataRecord]:
|
|
272
|
-
"""Create a list of DataRecords from a pandas DataFrame
|
|
273
|
-
|
|
274
|
-
Args:
|
|
275
|
-
df (pd.DataFrame): Input DataFrame
|
|
276
|
-
schema (BaseModel, optional): Schema for the DataRecords. If None, will be derived from DataFrame
|
|
277
|
-
|
|
278
|
-
Returns:
|
|
279
|
-
list[DataRecord]: List of DataRecord instances
|
|
280
|
-
"""
|
|
281
|
-
if df is None:
|
|
282
|
-
raise ValueError("DataFrame is None!")
|
|
283
|
-
|
|
284
|
-
# create schema if one isn't provided
|
|
285
|
-
if schema is None:
|
|
286
|
-
schema = create_schema_from_df(df)
|
|
287
|
-
|
|
288
|
-
# create an id for the dataset from the schema
|
|
289
|
-
dataset_id = hash_for_serialized_dict({
|
|
290
|
-
k: {"annotation": str(v.annotation), "default": str(v.default), "description": v.description}
|
|
291
|
-
for k, v in schema.model_fields.items()
|
|
292
|
-
})
|
|
293
|
-
|
|
294
|
-
# create records
|
|
295
|
-
records = []
|
|
296
|
-
for idx, row in df.iterrows():
|
|
297
|
-
row_dict = row.to_dict()
|
|
298
|
-
record = DataRecord(schema(**row_dict), source_indices=[f"{dataset_id}-{idx}"])
|
|
299
|
-
records.append(record)
|
|
300
|
-
|
|
301
|
-
return records
|
|
302
|
-
|
|
303
266
|
@staticmethod
|
|
304
267
|
def to_df(records: list[DataRecord], project_cols: list[str] | None = None) -> pd.DataFrame:
|
|
305
268
|
if len(records) == 0:
|
palimpzest/core/lib/schemas.py
CHANGED
|
@@ -80,6 +80,15 @@ def _create_pickleable_model(fields: dict[str, tuple[type, FieldInfo]]) -> type[
|
|
|
80
80
|
return new_model
|
|
81
81
|
|
|
82
82
|
|
|
83
|
+
def relax_schema(model: type[BaseModel]) -> type[BaseModel]:
|
|
84
|
+
"""Updates the type annotation for every field in the BaseModel to include typing.Any"""
|
|
85
|
+
fields = {}
|
|
86
|
+
for field_name, field in model.model_fields.items():
|
|
87
|
+
fields[field_name] = (field.annotation | Any, field)
|
|
88
|
+
|
|
89
|
+
return _create_pickleable_model(fields)
|
|
90
|
+
|
|
91
|
+
|
|
83
92
|
def project(model: type[BaseModel], project_fields: list[str]) -> type[BaseModel]:
|
|
84
93
|
"""Project a Pydantic model to only the specified columns."""
|
|
85
94
|
# make sure projection column names are shortened
|
|
@@ -314,6 +314,11 @@ class SentinelExecutionStrategy(BaseExecutionStrategy, ABC):
|
|
|
314
314
|
for future in as_completed(futures):
|
|
315
315
|
# update output record sets
|
|
316
316
|
record_set, operator, source_indices, input = future.result()
|
|
317
|
+
|
|
318
|
+
# if the operator is a join, get record_set from tuple output
|
|
319
|
+
if isinstance(operator, JoinOp):
|
|
320
|
+
record_set = record_set[0]
|
|
321
|
+
|
|
317
322
|
output_record_sets.append((record_set, operator, source_indices, input))
|
|
318
323
|
|
|
319
324
|
# update cache
|
|
@@ -11,7 +11,7 @@ from palimpzest.policy import Policy
|
|
|
11
11
|
from palimpzest.query.execution.execution_strategy import SentinelExecutionStrategy
|
|
12
12
|
from palimpzest.query.operators.aggregate import AggregateOp
|
|
13
13
|
from palimpzest.query.operators.convert import LLMConvert
|
|
14
|
-
from palimpzest.query.operators.filter import FilterOp, LLMFilter
|
|
14
|
+
from palimpzest.query.operators.filter import FilterOp, LLMFilter, NonLLMFilter
|
|
15
15
|
from palimpzest.query.operators.join import JoinOp
|
|
16
16
|
from palimpzest.query.operators.physical import PhysicalOperator
|
|
17
17
|
from palimpzest.query.operators.retrieve import RetrieveOp
|
|
@@ -351,7 +351,7 @@ class OpFrontier:
|
|
|
351
351
|
|
|
352
352
|
return op_inputs
|
|
353
353
|
|
|
354
|
-
def update_frontier(self, unique_logical_op_id: str, plan_stats: SentinelPlanStats) -> None:
|
|
354
|
+
def update_frontier(self, unique_logical_op_id: str, plan_stats: SentinelPlanStats, full_op_id_to_source_indices_processed: dict[str, set[list]]) -> None:
|
|
355
355
|
"""
|
|
356
356
|
Update the set of frontier operators, pulling in new ones from the reservoir as needed.
|
|
357
357
|
This function will:
|
|
@@ -383,22 +383,14 @@ class OpFrontier:
|
|
|
383
383
|
# compute final list of record op stats
|
|
384
384
|
full_op_id_to_record_op_stats[full_op_id] = list(record_id_to_max_quality_record_op_stats.values())
|
|
385
385
|
|
|
386
|
-
#
|
|
387
|
-
|
|
388
|
-
full_op_id_to_num_samples, total_num_samples = {}, 0
|
|
389
|
-
for full_op_id, record_op_stats_lst in full_op_id_to_record_op_stats.items():
|
|
386
|
+
# update the set of source indices processed by each physical operator
|
|
387
|
+
for full_op_id, source_indices_processed in full_op_id_to_source_indices_processed.items():
|
|
390
388
|
# update the set of source indices processed
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
source_indices = record_op_stats.record_source_indices
|
|
394
|
-
|
|
395
|
-
if len(source_indices) == 1:
|
|
396
|
-
source_indices = source_indices[0]
|
|
397
|
-
elif self.is_llm_join or self.is_aggregate_op:
|
|
398
|
-
source_indices = tuple(source_indices)
|
|
399
|
-
|
|
389
|
+
for source_indices in source_indices_processed:
|
|
390
|
+
source_indices = source_indices[0] if len(source_indices) == 1 else tuple(source_indices)
|
|
400
391
|
self.full_op_id_to_sources_processed[full_op_id].add(source_indices)
|
|
401
|
-
|
|
392
|
+
if source_indices in self.full_op_id_to_sources_not_processed[full_op_id]:
|
|
393
|
+
self.full_op_id_to_sources_not_processed[full_op_id].remove(source_indices)
|
|
402
394
|
|
|
403
395
|
# update the set of source indices not processed
|
|
404
396
|
self.full_op_id_to_sources_not_processed[full_op_id] = [
|
|
@@ -406,8 +398,11 @@ class OpFrontier:
|
|
|
406
398
|
if indices not in source_indices_processed
|
|
407
399
|
]
|
|
408
400
|
|
|
409
|
-
|
|
410
|
-
|
|
401
|
+
# compute mapping of physical op to num samples and total samples drawn
|
|
402
|
+
full_op_id_to_num_samples, total_num_samples = {}, 0
|
|
403
|
+
for full_op_id, record_op_stats_lst in full_op_id_to_record_op_stats.items():
|
|
404
|
+
# compute the number of samples as the length of the record_op_stats_lst
|
|
405
|
+
num_samples = len(record_op_stats_lst)
|
|
411
406
|
full_op_id_to_num_samples[full_op_id] = num_samples
|
|
412
407
|
total_num_samples += num_samples
|
|
413
408
|
|
|
@@ -620,6 +615,28 @@ class MABExecutionStrategy(SentinelExecutionStrategy):
|
|
|
620
615
|
calls does not perfectly match the sample_budget. This may cause some minor discrepancies with
|
|
621
616
|
the progress manager as a result.
|
|
622
617
|
"""
|
|
618
|
+
def _remove_filtered_records_from_downstream_ops(self, topo_idx: int, plan: SentinelPlan, op_frontiers: dict[str, OpFrontier], source_indices_to_all_record_sets: dict[int, list[DataRecordSet]]) -> None:
|
|
619
|
+
"""Remove records which were filtered out by a NonLLMFilter from all downstream operators."""
|
|
620
|
+
filtered_source_indices = set()
|
|
621
|
+
|
|
622
|
+
# NonLLMFilter will have one record_set per source_indices with a single record
|
|
623
|
+
for source_indices, record_sets in source_indices_to_all_record_sets.items():
|
|
624
|
+
record: DataRecord = record_sets[0][0]
|
|
625
|
+
if not record._passed_operator:
|
|
626
|
+
filtered_source_indices.add(source_indices)
|
|
627
|
+
|
|
628
|
+
# remove filtered source indices from all downstream operators
|
|
629
|
+
if len(filtered_source_indices) > 0:
|
|
630
|
+
for downstream_topo_idx in range(topo_idx + 1, len(plan)):
|
|
631
|
+
downstream_logical_op_id = plan[downstream_topo_idx][0]
|
|
632
|
+
downstream_unique_logical_op_id = f"{downstream_topo_idx}-{downstream_logical_op_id}"
|
|
633
|
+
downstream_op_frontier = op_frontiers[downstream_unique_logical_op_id]
|
|
634
|
+
for full_op_id in downstream_op_frontier.full_op_id_to_sources_not_processed:
|
|
635
|
+
downstream_op_frontier.full_op_id_to_sources_not_processed[full_op_id] = [
|
|
636
|
+
indices for indices in downstream_op_frontier.full_op_id_to_sources_not_processed[full_op_id]
|
|
637
|
+
if indices not in filtered_source_indices
|
|
638
|
+
]
|
|
639
|
+
|
|
623
640
|
def _get_max_quality_op(self, unique_logical_op_id: str, op_frontiers: dict[str, OpFrontier], plan_stats: SentinelPlanStats) -> PhysicalOperator:
|
|
624
641
|
"""
|
|
625
642
|
Returns the operator in the frontier with the highest (estimated) quality.
|
|
@@ -639,7 +656,11 @@ class MABExecutionStrategy(SentinelExecutionStrategy):
|
|
|
639
656
|
for op in frontier_ops:
|
|
640
657
|
op_quality_stats = []
|
|
641
658
|
if op.get_full_op_id() in full_op_id_to_record_op_stats:
|
|
642
|
-
op_quality_stats = [
|
|
659
|
+
op_quality_stats = [
|
|
660
|
+
record_op_stats.quality
|
|
661
|
+
for record_op_stats in full_op_id_to_record_op_stats[op.get_full_op_id()]
|
|
662
|
+
if record_op_stats.quality is not None
|
|
663
|
+
]
|
|
643
664
|
avg_op_quality = sum(op_quality_stats) / len(op_quality_stats) if len(op_quality_stats) > 0 else 0.0
|
|
644
665
|
if max_avg_quality is None or avg_op_quality > max_avg_quality:
|
|
645
666
|
max_quality_op = op
|
|
@@ -664,7 +685,7 @@ class MABExecutionStrategy(SentinelExecutionStrategy):
|
|
|
664
685
|
source_indices_to_sample.update(source_indices)
|
|
665
686
|
|
|
666
687
|
# execute operator sets in sequence
|
|
667
|
-
for topo_idx, (logical_op_id,
|
|
688
|
+
for topo_idx, (logical_op_id, op_set) in enumerate(plan):
|
|
668
689
|
# compute unique logical op id within plan
|
|
669
690
|
unique_logical_op_id = f"{topo_idx}-{logical_op_id}"
|
|
670
691
|
|
|
@@ -672,8 +693,10 @@ class MABExecutionStrategy(SentinelExecutionStrategy):
|
|
|
672
693
|
max_quality_op = self._get_max_quality_op(unique_logical_op_id, op_frontiers, plan_stats)
|
|
673
694
|
|
|
674
695
|
# get frontier ops and their next input
|
|
696
|
+
def is_filtered_out(tup: tuple) -> bool:
|
|
697
|
+
return tup[-1] is None or isinstance(tup[-1], list) and all([record is None for record in tup[-1]])
|
|
675
698
|
frontier_op_inputs = op_frontiers[unique_logical_op_id].get_frontier_op_inputs(source_indices_to_sample, max_quality_op)
|
|
676
|
-
frontier_op_inputs = list(filter(lambda tup:
|
|
699
|
+
frontier_op_inputs = list(filter(lambda tup: not is_filtered_out(tup), frontier_op_inputs))
|
|
677
700
|
|
|
678
701
|
# break out of the loop if frontier_op_inputs is empty, as this means all records have been filtered out
|
|
679
702
|
if len(frontier_op_inputs) == 0:
|
|
@@ -711,7 +734,18 @@ class MABExecutionStrategy(SentinelExecutionStrategy):
|
|
|
711
734
|
op_frontiers[next_unique_logical_op_id].update_inputs(unique_logical_op_id, source_indices_to_all_record_sets)
|
|
712
735
|
|
|
713
736
|
# update the (pareto) frontier for each set of operators
|
|
714
|
-
|
|
737
|
+
full_op_id_to_source_indices_processed = {}
|
|
738
|
+
for source_indices, record_set_tuples in source_indices_to_record_set_tuples.items():
|
|
739
|
+
for _, op, _ in record_set_tuples:
|
|
740
|
+
if op.get_full_op_id() not in full_op_id_to_source_indices_processed:
|
|
741
|
+
full_op_id_to_source_indices_processed[op.get_full_op_id()] = set()
|
|
742
|
+
full_op_id_to_source_indices_processed[op.get_full_op_id()].add(source_indices)
|
|
743
|
+
op_frontiers[unique_logical_op_id].update_frontier(unique_logical_op_id, plan_stats, full_op_id_to_source_indices_processed)
|
|
744
|
+
|
|
745
|
+
# if the operator is a non-llm filter which has filtered out records, remove those records from
|
|
746
|
+
# all downstream operators' full_op_id_to_sources_not_processed
|
|
747
|
+
if isinstance(op_set[0], NonLLMFilter):
|
|
748
|
+
self._remove_filtered_records_from_downstream_ops(topo_idx, plan, op_frontiers, source_indices_to_all_record_sets)
|
|
715
749
|
|
|
716
750
|
# finalize plan stats
|
|
717
751
|
plan_stats.finish()
|
|
@@ -721,7 +755,6 @@ class MABExecutionStrategy(SentinelExecutionStrategy):
|
|
|
721
755
|
|
|
722
756
|
def execute_sentinel_plan(self, plan: SentinelPlan, train_dataset: dict[str, Dataset], validator: Validator) -> SentinelPlanStats:
|
|
723
757
|
logger.info(f"Executing plan {plan.plan_id} with {self.max_workers} workers")
|
|
724
|
-
logger.info(f"Plan Details: {plan}")
|
|
725
758
|
|
|
726
759
|
# initialize plan stats
|
|
727
760
|
plan_stats = SentinelPlanStats.from_plan(plan)
|
|
@@ -12,6 +12,7 @@ from palimpzest.query.operators.distinct import DistinctOp as _DistinctOp
|
|
|
12
12
|
from palimpzest.query.operators.filter import FilterOp as _FilterOp
|
|
13
13
|
from palimpzest.query.operators.filter import LLMFilter as _LLMFilter
|
|
14
14
|
from palimpzest.query.operators.filter import NonLLMFilter as _NonLLMFilter
|
|
15
|
+
from palimpzest.query.operators.join import EmbeddingJoin as _EmbeddingJoin
|
|
15
16
|
from palimpzest.query.operators.join import JoinOp as _JoinOp
|
|
16
17
|
from palimpzest.query.operators.join import NestedLoopsJoin as _NestedLoopsJoin
|
|
17
18
|
from palimpzest.query.operators.limit import LimitScanOp as _LimitScanOp
|
|
@@ -88,7 +89,7 @@ PHYSICAL_OPERATORS = (
|
|
|
88
89
|
# filter
|
|
89
90
|
+ [_FilterOp, _NonLLMFilter, _LLMFilter]
|
|
90
91
|
# join
|
|
91
|
-
+ [_JoinOp, _NestedLoopsJoin]
|
|
92
|
+
+ [_EmbeddingJoin, _JoinOp, _NestedLoopsJoin]
|
|
92
93
|
# limit
|
|
93
94
|
+ [_LimitScanOp]
|
|
94
95
|
# mixture-of-agents
|
|
@@ -41,6 +41,7 @@ class JoinOp(PhysicalOperator, ABC):
|
|
|
41
41
|
prompt_strategy: PromptStrategy = PromptStrategy.JOIN,
|
|
42
42
|
join_parallelism: int = 64,
|
|
43
43
|
reasoning_effort: str | None = None,
|
|
44
|
+
retain_inputs: bool = True,
|
|
44
45
|
desc: str | None = None,
|
|
45
46
|
*args,
|
|
46
47
|
**kwargs,
|
|
@@ -52,6 +53,7 @@ class JoinOp(PhysicalOperator, ABC):
|
|
|
52
53
|
self.prompt_strategy = prompt_strategy
|
|
53
54
|
self.join_parallelism = join_parallelism
|
|
54
55
|
self.reasoning_effort = reasoning_effort
|
|
56
|
+
self.retain_inputs = retain_inputs
|
|
55
57
|
self.desc = desc
|
|
56
58
|
self.generator = Generator(model, prompt_strategy, reasoning_effort, self.api_base, Cardinality.ONE_TO_ONE, self.desc, self.verbose)
|
|
57
59
|
self.join_idx = 0
|
|
@@ -82,10 +84,11 @@ class JoinOp(PhysicalOperator, ABC):
|
|
|
82
84
|
op_params = super().get_op_params()
|
|
83
85
|
op_params = {
|
|
84
86
|
"condition": self.condition,
|
|
85
|
-
"model": self.model
|
|
86
|
-
"prompt_strategy": self.prompt_strategy
|
|
87
|
+
"model": self.model,
|
|
88
|
+
"prompt_strategy": self.prompt_strategy,
|
|
87
89
|
"join_parallelism": self.join_parallelism,
|
|
88
90
|
"reasoning_effort": self.reasoning_effort,
|
|
91
|
+
"retain_inputs": self.retain_inputs,
|
|
89
92
|
"desc": self.desc,
|
|
90
93
|
**op_params,
|
|
91
94
|
}
|
|
@@ -227,8 +230,9 @@ class NestedLoopsJoin(JoinOp):
|
|
|
227
230
|
num_inputs_processed = len(join_candidates)
|
|
228
231
|
|
|
229
232
|
# store input records to join with new records added later
|
|
230
|
-
self.
|
|
231
|
-
|
|
233
|
+
if self.retain_inputs:
|
|
234
|
+
self._left_input_records.extend(left_candidates)
|
|
235
|
+
self._right_input_records.extend(right_candidates)
|
|
232
236
|
|
|
233
237
|
# return empty DataRecordSet if no output records were produced
|
|
234
238
|
if len(output_records) == 0:
|
|
@@ -242,7 +246,7 @@ class EmbeddingJoin(JoinOp):
|
|
|
242
246
|
# specialized use cases (e.g., speech-to-text) with strict requirements on things like e.g. sample rate
|
|
243
247
|
def __init__(
|
|
244
248
|
self,
|
|
245
|
-
num_samples: int =
|
|
249
|
+
num_samples: int = 10,
|
|
246
250
|
*args,
|
|
247
251
|
**kwargs,
|
|
248
252
|
):
|
|
@@ -307,10 +311,7 @@ class EmbeddingJoin(JoinOp):
|
|
|
307
311
|
)
|
|
308
312
|
|
|
309
313
|
# get est. of conversion cost (in USD) per record from model card
|
|
310
|
-
model_conversion_usd_per_record =
|
|
311
|
-
MODEL_CARDS[self.embedding_model.value]["usd_per_input_token"] * est_num_input_tokens
|
|
312
|
-
+ MODEL_CARDS[self.embedding_model.value]["usd_per_output_token"] * est_num_output_tokens
|
|
313
|
-
)
|
|
314
|
+
model_conversion_usd_per_record = MODEL_CARDS[self.embedding_model.value]["usd_per_input_token"] * est_num_input_tokens
|
|
314
315
|
|
|
315
316
|
# estimate output cardinality using a constant assumption of the filter selectivity
|
|
316
317
|
selectivity = NAIVE_EST_JOIN_SELECTIVITY
|
|
@@ -521,8 +522,9 @@ class EmbeddingJoin(JoinOp):
|
|
|
521
522
|
record_op_stats.total_embedding_cost = amortized_embedding_cost
|
|
522
523
|
|
|
523
524
|
# store input records to join with new records added later
|
|
524
|
-
self.
|
|
525
|
-
|
|
525
|
+
if self.retain_inputs:
|
|
526
|
+
self._left_input_records.extend(zip(left_candidates, left_embeddings))
|
|
527
|
+
self._right_input_records.extend(zip(right_candidates, right_embeddings))
|
|
526
528
|
|
|
527
529
|
# return empty DataRecordSet if no output records were produced
|
|
528
530
|
if len(output_records) == 0:
|
|
@@ -8,6 +8,9 @@ from palimpzest.query.optimizer.rules import (
|
|
|
8
8
|
from palimpzest.query.optimizer.rules import (
|
|
9
9
|
CritiqueAndRefineRule as _CritiqueAndRefineRule,
|
|
10
10
|
)
|
|
11
|
+
from palimpzest.query.optimizer.rules import (
|
|
12
|
+
EmbeddingJoinRule as _EmbeddingJoinRule,
|
|
13
|
+
)
|
|
11
14
|
from palimpzest.query.optimizer.rules import (
|
|
12
15
|
ImplementationRule as _ImplementationRule,
|
|
13
16
|
)
|
|
@@ -18,10 +21,10 @@ from palimpzest.query.optimizer.rules import (
|
|
|
18
21
|
LLMFilterRule as _LLMFilterRule,
|
|
19
22
|
)
|
|
20
23
|
from palimpzest.query.optimizer.rules import (
|
|
21
|
-
|
|
24
|
+
MixtureOfAgentsRule as _MixtureOfAgentsRule,
|
|
22
25
|
)
|
|
23
26
|
from palimpzest.query.optimizer.rules import (
|
|
24
|
-
|
|
27
|
+
NestedLoopsJoinRule as _NestedLoopsJoinRule,
|
|
25
28
|
)
|
|
26
29
|
from palimpzest.query.optimizer.rules import (
|
|
27
30
|
NonLLMConvertRule as _NonLLMConvertRule,
|
|
@@ -56,10 +59,11 @@ ALL_RULES = [
|
|
|
56
59
|
_AggregateRule,
|
|
57
60
|
_BasicSubstitutionRule,
|
|
58
61
|
_CritiqueAndRefineRule,
|
|
62
|
+
_EmbeddingJoinRule,
|
|
59
63
|
_ImplementationRule,
|
|
60
64
|
_LLMConvertBondedRule,
|
|
61
65
|
_LLMFilterRule,
|
|
62
|
-
|
|
66
|
+
_NestedLoopsJoinRule,
|
|
63
67
|
_MixtureOfAgentsRule,
|
|
64
68
|
_NonLLMConvertRule,
|
|
65
69
|
_NonLLMFilterRule,
|
|
@@ -181,6 +181,7 @@ class Optimizer:
|
|
|
181
181
|
"join_parallelism": self.join_parallelism,
|
|
182
182
|
"reasoning_effort": self.reasoning_effort,
|
|
183
183
|
"api_base": self.api_base,
|
|
184
|
+
"is_validation": self.optimizer_strategy == OptimizationStrategyType.SENTINEL,
|
|
184
185
|
}
|
|
185
186
|
|
|
186
187
|
def deepcopy_clean(self):
|
|
@@ -204,10 +205,17 @@ class Optimizer:
|
|
|
204
205
|
return optimizer
|
|
205
206
|
|
|
206
207
|
def update_strategy(self, optimizer_strategy: OptimizationStrategyType):
|
|
208
|
+
# set the optimizer_strategy
|
|
207
209
|
self.optimizer_strategy = optimizer_strategy
|
|
210
|
+
|
|
211
|
+
# get the strategy class associated with the optimizer strategy
|
|
208
212
|
optimizer_strategy_cls = optimizer_strategy.value
|
|
209
213
|
self.strategy = optimizer_strategy_cls()
|
|
210
214
|
|
|
215
|
+
# remove transformation rules for optimization strategies which do not require them
|
|
216
|
+
if optimizer_strategy.no_transformation():
|
|
217
|
+
self.transformation_rules = []
|
|
218
|
+
|
|
211
219
|
def construct_group_tree(self, dataset: Dataset) -> tuple[int, dict[str, FieldInfo], dict[str, set[str]]]:
|
|
212
220
|
logger.debug(f"Constructing group tree for dataset: {dataset}")
|
|
213
221
|
### convert node --> Group ###
|
|
@@ -58,7 +58,6 @@ class GreedyStrategy(OptimizationStrategy):
|
|
|
58
58
|
def get_optimal_plans(self, groups: dict, final_group_id: int, policy: Policy, use_final_op_quality: bool) -> list[PhysicalPlan]:
|
|
59
59
|
logger.info(f"Getting greedy optimal plans for final group id: {final_group_id}")
|
|
60
60
|
plans = [self._get_greedy_physical_plan(groups, final_group_id)]
|
|
61
|
-
logger.info(f"Greedy optimal plans: {plans}")
|
|
62
61
|
logger.info(f"Done getting greedy optimal plans for final group id: {final_group_id}")
|
|
63
62
|
|
|
64
63
|
return plans
|
|
@@ -137,7 +136,6 @@ class ParetoStrategy(OptimizationStrategy):
|
|
|
137
136
|
optimal_plan = optimal_plan if policy.choose(optimal_plan.plan_cost, plan.plan_cost) else plan
|
|
138
137
|
|
|
139
138
|
plans = [optimal_plan]
|
|
140
|
-
logger.info(f"Pareto optimal plans: {plans}")
|
|
141
139
|
logger.info(f"Done getting pareto optimal plans for final group id: {final_group_id}")
|
|
142
140
|
return plans
|
|
143
141
|
|
|
@@ -174,7 +172,6 @@ class SentinelStrategy(OptimizationStrategy):
|
|
|
174
172
|
def get_optimal_plans(self, groups: dict, final_group_id: int, policy: Policy, use_final_op_quality: bool) -> list[SentinelPlan]:
|
|
175
173
|
logger.info(f"Getting sentinel optimal plans for final group id: {final_group_id}")
|
|
176
174
|
plans = [self._get_sentinel_plan(groups, final_group_id)]
|
|
177
|
-
logger.info(f"Sentinel optimal plans: {plans}")
|
|
178
175
|
logger.info(f"Done getting sentinel optimal plans for final group id: {final_group_id}")
|
|
179
176
|
return plans
|
|
180
177
|
|
|
@@ -330,12 +330,11 @@ class SentinelPlan(Plan):
|
|
|
330
330
|
|
|
331
331
|
def _get_str(self, idx: int = 0, indent: int = 0) -> str:
|
|
332
332
|
indent_str = " " * (indent * 2)
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
plan_str += subplan._get_str(idx=idx + 1, indent=indent + 1)
|
|
333
|
+
operator = self.operator_set[0]
|
|
334
|
+
inner_idx_str = "" if len(self.operator_set) == 1 else f"1 - {len(self.operator_set)}."
|
|
335
|
+
plan_str = f"{indent_str}{idx}.{inner_idx_str} {str(operator)}\n"
|
|
336
|
+
for subplan in self.subplans:
|
|
337
|
+
plan_str += subplan._get_str(idx=idx + 1, indent=indent + 1)
|
|
339
338
|
|
|
340
339
|
return plan_str
|
|
341
340
|
|
|
@@ -18,7 +18,7 @@ from palimpzest.query.operators.convert import LLMConvertBonded, NonLLMConvert
|
|
|
18
18
|
from palimpzest.query.operators.critique_and_refine import CritiqueAndRefineConvert, CritiqueAndRefineFilter
|
|
19
19
|
from palimpzest.query.operators.distinct import DistinctOp
|
|
20
20
|
from palimpzest.query.operators.filter import LLMFilter, NonLLMFilter
|
|
21
|
-
from palimpzest.query.operators.join import NestedLoopsJoin
|
|
21
|
+
from palimpzest.query.operators.join import EmbeddingJoin, NestedLoopsJoin
|
|
22
22
|
from palimpzest.query.operators.limit import LimitScanOp
|
|
23
23
|
from palimpzest.query.operators.logical import (
|
|
24
24
|
Aggregate,
|
|
@@ -761,8 +761,8 @@ class SplitRule(ImplementationRule):
|
|
|
761
761
|
@classmethod
|
|
762
762
|
def matches_pattern(cls, logical_expression: LogicalExpression) -> bool:
|
|
763
763
|
logical_op = logical_expression.operator
|
|
764
|
-
is_map_match = isinstance(logical_op, ConvertScan) and cls._is_text_only_operation() and logical_op.udf is None
|
|
765
|
-
is_filter_match = isinstance(logical_op, FilteredScan) and cls._is_text_only_operation() and logical_op.filter.filter_fn is None
|
|
764
|
+
is_map_match = isinstance(logical_op, ConvertScan) and cls._is_text_only_operation(logical_expression) and logical_op.udf is None
|
|
765
|
+
is_filter_match = isinstance(logical_op, FilteredScan) and cls._is_text_only_operation(logical_expression) and logical_op.filter.filter_fn is None
|
|
766
766
|
logger.debug(f"SplitRule matches_pattern: {is_map_match or is_filter_match} for {logical_expression}")
|
|
767
767
|
return is_map_match or is_filter_match
|
|
768
768
|
|
|
@@ -860,7 +860,7 @@ class LLMFilterRule(ImplementationRule):
|
|
|
860
860
|
return cls._perform_substitution(logical_expression, LLMFilter, runtime_kwargs, variable_op_kwargs)
|
|
861
861
|
|
|
862
862
|
|
|
863
|
-
class
|
|
863
|
+
class NestedLoopsJoinRule(ImplementationRule):
|
|
864
864
|
"""
|
|
865
865
|
Substitute a logical expression for a JoinOp with an (LLM) NestedLoopsJoin physical implementation.
|
|
866
866
|
"""
|
|
@@ -868,12 +868,12 @@ class LLMJoinRule(ImplementationRule):
|
|
|
868
868
|
@classmethod
|
|
869
869
|
def matches_pattern(cls, logical_expression: LogicalExpression) -> bool:
|
|
870
870
|
is_match = isinstance(logical_expression.operator, JoinOp)
|
|
871
|
-
logger.debug(f"
|
|
871
|
+
logger.debug(f"NestedLoopsJoinRule matches_pattern: {is_match} for {logical_expression}")
|
|
872
872
|
return is_match
|
|
873
873
|
|
|
874
874
|
@classmethod
|
|
875
875
|
def substitute(cls, logical_expression: LogicalExpression, **runtime_kwargs) -> set[PhysicalExpression]:
|
|
876
|
-
logger.debug(f"Substituting
|
|
876
|
+
logger.debug(f"Substituting NestedLoopsJoinRule for {logical_expression}")
|
|
877
877
|
|
|
878
878
|
# create variable physical operator kwargs for each model which can implement this logical_expression
|
|
879
879
|
models = [model for model in runtime_kwargs["available_models"] if cls._model_matches_input(model, logical_expression)]
|
|
@@ -884,6 +884,7 @@ class LLMJoinRule(ImplementationRule):
|
|
|
884
884
|
"prompt_strategy": PromptStrategy.JOIN_NO_REASONING if model.is_reasoning_model() and no_reasoning else PromptStrategy.JOIN,
|
|
885
885
|
"join_parallelism": runtime_kwargs["join_parallelism"],
|
|
886
886
|
"reasoning_effort": runtime_kwargs["reasoning_effort"],
|
|
887
|
+
"retain_inputs": not runtime_kwargs["is_validation"],
|
|
887
888
|
}
|
|
888
889
|
for model in models
|
|
889
890
|
]
|
|
@@ -891,6 +892,39 @@ class LLMJoinRule(ImplementationRule):
|
|
|
891
892
|
return cls._perform_substitution(logical_expression, NestedLoopsJoin, runtime_kwargs, variable_op_kwargs)
|
|
892
893
|
|
|
893
894
|
|
|
895
|
+
class EmbeddingJoinRule(ImplementationRule):
|
|
896
|
+
"""
|
|
897
|
+
Substitute a logical expression for a JoinOp with an EmbeddingJoin physical implementation.
|
|
898
|
+
"""
|
|
899
|
+
|
|
900
|
+
@classmethod
|
|
901
|
+
def matches_pattern(cls, logical_expression: LogicalExpression) -> bool:
|
|
902
|
+
is_match = isinstance(logical_expression.operator, JoinOp) and not cls._is_audio_operation(logical_expression)
|
|
903
|
+
logger.debug(f"EmbeddingJoinRule matches_pattern: {is_match} for {logical_expression}")
|
|
904
|
+
return is_match
|
|
905
|
+
|
|
906
|
+
@classmethod
|
|
907
|
+
def substitute(cls, logical_expression: LogicalExpression, **runtime_kwargs) -> set[PhysicalExpression]:
|
|
908
|
+
logger.debug(f"Substituting EmbeddingJoinRule for {logical_expression}")
|
|
909
|
+
|
|
910
|
+
# create variable physical operator kwargs for each model which can implement this logical_expression
|
|
911
|
+
models = [model for model in runtime_kwargs["available_models"] if cls._model_matches_input(model, logical_expression)]
|
|
912
|
+
no_reasoning = runtime_kwargs["reasoning_effort"] in [None, "minimal", "low"]
|
|
913
|
+
variable_op_kwargs = [
|
|
914
|
+
{
|
|
915
|
+
"model": model,
|
|
916
|
+
"prompt_strategy": PromptStrategy.JOIN_NO_REASONING if model.is_reasoning_model() and no_reasoning else PromptStrategy.JOIN,
|
|
917
|
+
"join_parallelism": runtime_kwargs["join_parallelism"],
|
|
918
|
+
"reasoning_effort": runtime_kwargs["reasoning_effort"],
|
|
919
|
+
"retain_inputs": not runtime_kwargs["is_validation"],
|
|
920
|
+
"num_samples": 10, # TODO: iterate over different choices of num_samples
|
|
921
|
+
}
|
|
922
|
+
for model in models
|
|
923
|
+
]
|
|
924
|
+
|
|
925
|
+
return cls._perform_substitution(logical_expression, EmbeddingJoin, runtime_kwargs, variable_op_kwargs)
|
|
926
|
+
|
|
927
|
+
|
|
894
928
|
class AggregateRule(ImplementationRule):
|
|
895
929
|
"""
|
|
896
930
|
Substitute the logical expression for an aggregate with its physical counterpart.
|
|
@@ -247,8 +247,16 @@ class ApplyRule(Task):
|
|
|
247
247
|
# apply implementation rule
|
|
248
248
|
new_expressions = self.rule.substitute(self.logical_expression, **physical_op_params)
|
|
249
249
|
new_expressions = [expr for expr in new_expressions if expr.expr_id not in expressions]
|
|
250
|
+
|
|
251
|
+
# get the costed_full_op_ids from the context (if provided) and compute whether this
|
|
252
|
+
# logical expression has physical operators which have been costed
|
|
250
253
|
costed_full_op_ids = context['costed_full_op_ids']
|
|
251
|
-
|
|
254
|
+
logical_op_has_been_costed = costed_full_op_ids is not None and any([
|
|
255
|
+
op_id.split("-")[0] == self.logical_expression.operator.get_logical_op_id()
|
|
256
|
+
for op_id in costed_full_op_ids
|
|
257
|
+
])
|
|
258
|
+
|
|
259
|
+
if logical_op_has_been_costed:
|
|
252
260
|
new_expressions = [expr for expr in new_expressions if expr.operator.get_full_op_id() in costed_full_op_ids]
|
|
253
261
|
expressions.update({expr.expr_id: expr for expr in new_expressions})
|
|
254
262
|
group.physical_expressions.update(new_expressions)
|
|
@@ -16,6 +16,7 @@ class QueryProcessorConfig(BaseModel):
|
|
|
16
16
|
|
|
17
17
|
# general execution flags
|
|
18
18
|
policy: Policy = Field(default_factory=MaxQuality)
|
|
19
|
+
enforce_types: bool = Field(default=False)
|
|
19
20
|
scan_start_idx: int = Field(default=0)
|
|
20
21
|
num_samples: int = Field(default=None)
|
|
21
22
|
verbose: bool = Field(default=False)
|
|
@@ -149,6 +149,13 @@ class QueryProcessorFactory:
|
|
|
149
149
|
# apply any additional keyword arguments to the config and validate its contents
|
|
150
150
|
config, validator = cls._config_validation_and_normalization(config, train_dataset, validator)
|
|
151
151
|
|
|
152
|
+
# update the dataset's types if we're not enforcing types
|
|
153
|
+
if not config.enforce_types:
|
|
154
|
+
dataset.relax_types()
|
|
155
|
+
if train_dataset is not None:
|
|
156
|
+
for _, ds in train_dataset.items():
|
|
157
|
+
ds.relax_types()
|
|
158
|
+
|
|
152
159
|
# create the optimizer, execution strateg(ies), and processor
|
|
153
160
|
optimizer = cls._create_optimizer(config)
|
|
154
161
|
config.execution_strategy = cls._create_execution_strategy(dataset, config)
|
|
@@ -79,7 +79,7 @@ class Validator:
|
|
|
79
79
|
Compute the quality of the generated output for the given fields and input_record.
|
|
80
80
|
"""
|
|
81
81
|
# create prompt factory
|
|
82
|
-
factory = PromptFactory(PromptStrategy.MAP,
|
|
82
|
+
factory = PromptFactory(PromptStrategy.MAP, self.model, Cardinality.ONE_TO_ONE)
|
|
83
83
|
|
|
84
84
|
# get the input messages; strip out the system message(s)
|
|
85
85
|
msg_kwargs = {"output_schema": op.output_schema, "project_cols": op.get_input_fields()}
|
|
@@ -95,14 +95,14 @@ class Validator:
|
|
|
95
95
|
start_time = time.time()
|
|
96
96
|
validator_prompt = MAP_IMAGE_VALIDATOR_PROMPT if op.is_image_op() else MAP_VALIDATOR_PROMPT
|
|
97
97
|
val_messages = [{"role": "system", "content": validator_prompt}] + input_messages + [{"role": "user", "content": output_message}]
|
|
98
|
-
completion = litellm.completion(model=
|
|
98
|
+
completion = litellm.completion(model=self.model.value, messages=val_messages)
|
|
99
99
|
completion_text = completion.choices[0].message.content
|
|
100
100
|
gen_stats = self._get_gen_stats_from_completion(completion, start_time)
|
|
101
101
|
print(f"INPUT:\n{input_str}")
|
|
102
102
|
print(Fore.GREEN + f"{completion_text}\n" + Style.RESET_ALL)
|
|
103
103
|
|
|
104
104
|
# parse the evaluation
|
|
105
|
-
eval_dict: dict = get_json_from_answer(completion_text,
|
|
105
|
+
eval_dict: dict = get_json_from_answer(completion_text, self.model, Cardinality.ONE_TO_ONE)
|
|
106
106
|
score = sum(eval_dict.values()) / len(eval_dict)
|
|
107
107
|
|
|
108
108
|
except Exception:
|
|
@@ -115,7 +115,7 @@ class Validator:
|
|
|
115
115
|
Compute the quality for each record_op_stats object in the given record_set.
|
|
116
116
|
"""
|
|
117
117
|
# create prompt factory
|
|
118
|
-
factory = PromptFactory(PromptStrategy.MAP,
|
|
118
|
+
factory = PromptFactory(PromptStrategy.MAP, self.model, Cardinality.ONE_TO_MANY)
|
|
119
119
|
|
|
120
120
|
# get the input messages; strip out the system message(s)
|
|
121
121
|
msg_kwargs = {"output_schema": op.output_schema, "project_cols": op.get_input_fields()}
|
|
@@ -138,7 +138,7 @@ class Validator:
|
|
|
138
138
|
# print(Fore.GREEN + f"{completion_text}\n" + Style.RESET_ALL)
|
|
139
139
|
|
|
140
140
|
# parse the evaluation
|
|
141
|
-
eval_dicts: list[dict] = get_json_from_answer(completion_text,
|
|
141
|
+
eval_dicts: list[dict] = get_json_from_answer(completion_text, self.model, Cardinality.ONE_TO_MANY)
|
|
142
142
|
all_qualities = []
|
|
143
143
|
for record_eval_dict in eval_dicts:
|
|
144
144
|
all_qualities.extend(record_eval_dict.values())
|
|
@@ -158,12 +158,12 @@ class Validator:
|
|
|
158
158
|
label = self.filter_cache.get(filter_input_hash, None)
|
|
159
159
|
if label is None:
|
|
160
160
|
validator_op: LLMFilter = op.copy()
|
|
161
|
-
validator_op.model =
|
|
161
|
+
validator_op.model = self.model
|
|
162
162
|
try:
|
|
163
163
|
target_record_set = validator_op(input_record)
|
|
164
164
|
label = target_record_set[0]._passed_operator
|
|
165
165
|
self.filter_cache[filter_input_hash] = label
|
|
166
|
-
score = label == output
|
|
166
|
+
score = float(label == output)
|
|
167
167
|
record_op_stats = target_record_set.record_op_stats[0]
|
|
168
168
|
gen_stats = GenerationStats(
|
|
169
169
|
model_name=self.model.value,
|
|
@@ -181,7 +181,7 @@ class Validator:
|
|
|
181
181
|
pass
|
|
182
182
|
|
|
183
183
|
else:
|
|
184
|
-
score = label == output
|
|
184
|
+
score = float(label == output)
|
|
185
185
|
|
|
186
186
|
return score, gen_stats
|
|
187
187
|
|
|
@@ -191,12 +191,12 @@ class Validator:
|
|
|
191
191
|
label = self.join_cache.get(join_input_hash, None)
|
|
192
192
|
if label is None:
|
|
193
193
|
validator_op: JoinOp = op.copy()
|
|
194
|
-
validator_op.model =
|
|
194
|
+
validator_op.model = self.model
|
|
195
195
|
try:
|
|
196
|
-
target_record_set = validator_op([left_input_record], [right_input_record])
|
|
196
|
+
target_record_set, _ = validator_op([left_input_record], [right_input_record])
|
|
197
197
|
label = target_record_set[0]._passed_operator
|
|
198
198
|
self.join_cache[join_input_hash] = label
|
|
199
|
-
score = label == output
|
|
199
|
+
score = float(label == output)
|
|
200
200
|
record_op_stats = target_record_set.record_op_stats[0]
|
|
201
201
|
gen_stats = GenerationStats(
|
|
202
202
|
model_name=self.model.value,
|
|
@@ -214,7 +214,7 @@ class Validator:
|
|
|
214
214
|
pass
|
|
215
215
|
|
|
216
216
|
else:
|
|
217
|
-
score = label == output
|
|
217
|
+
score = float(label == output)
|
|
218
218
|
|
|
219
219
|
return score, gen_stats
|
|
220
220
|
|
|
@@ -225,7 +225,7 @@ class Validator:
|
|
|
225
225
|
# TODO: retrieve k=25; score each item based on relevance; compute F1
|
|
226
226
|
# TODO: support retrieval over images
|
|
227
227
|
# create prompt factory
|
|
228
|
-
factory = PromptFactory(PromptStrategy.MAP,
|
|
228
|
+
factory = PromptFactory(PromptStrategy.MAP, self.model, Cardinality.ONE_TO_ONE)
|
|
229
229
|
|
|
230
230
|
# get the input messages; strip out the system message(s)
|
|
231
231
|
msg_kwargs = {"output_schema": op.output_schema, "project_cols": op.get_input_fields()}
|
|
@@ -249,7 +249,7 @@ class Validator:
|
|
|
249
249
|
print(Fore.GREEN + f"{completion_text}\n" + Style.RESET_ALL)
|
|
250
250
|
|
|
251
251
|
# parse the evaluation
|
|
252
|
-
eval_dict: dict = get_json_from_answer(completion_text,
|
|
252
|
+
eval_dict: dict = get_json_from_answer(completion_text, self.model, Cardinality.ONE_TO_ONE)
|
|
253
253
|
score = sum(eval_dict.values()) / len(eval_dict)
|
|
254
254
|
|
|
255
255
|
except Exception:
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: palimpzest
|
|
3
|
-
Version: 0.8.
|
|
3
|
+
Version: 0.8.6
|
|
4
4
|
Summary: Palimpzest is a system which enables anyone to process AI-powered analytical queries simply by defining them in a declarative language
|
|
5
5
|
Author-email: MIT DSG Semantic Management Lab <michjc@csail.mit.edu>
|
|
6
6
|
Project-URL: homepage, https://palimpzest.org
|
|
@@ -9,15 +9,15 @@ palimpzest/core/models.py,sha256=VNi49i9xn_FxekyYrGPS1-_C_PaGXL8dz-dqjrIOk8g,424
|
|
|
9
9
|
palimpzest/core/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
10
10
|
palimpzest/core/data/context.py,sha256=x1xYyu9qW65dvtK_XayIfv_CgsCEPW6Qe0DTiSf9sjU,16207
|
|
11
11
|
palimpzest/core/data/context_manager.py,sha256=8hAKWD2jhFZgghTu7AYgjkvKDsJUPVxq8g4nG0HWvfo,6150
|
|
12
|
-
palimpzest/core/data/dataset.py,sha256=
|
|
12
|
+
palimpzest/core/data/dataset.py,sha256=0IMmV5_rheNb9ON8wZTy-h1VwWX9mRGkwgc93WGo73E,28881
|
|
13
13
|
palimpzest/core/data/index_dataset.py,sha256=adO67DgzHhA4lBME0-h4SjXfdz9UcNMSDGXTpUdKbgE,1929
|
|
14
14
|
palimpzest/core/data/iter_dataset.py,sha256=K47ajOXsCZV3WhOuDkw3xfiHzn8mXPU976uN3SjaP2U,20507
|
|
15
15
|
palimpzest/core/elements/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
16
16
|
palimpzest/core/elements/filters.py,sha256=fU2x0eWDwfP52_5fUmqJXTuhs4H0vvHtPZLdA3IIw8I,1642
|
|
17
17
|
palimpzest/core/elements/groupbysig.py,sha256=oFH5UkZzcR0msAgfQiRQOOvyJ3HaW4Dwr03h7tVOcrM,2324
|
|
18
|
-
palimpzest/core/elements/records.py,sha256=
|
|
18
|
+
palimpzest/core/elements/records.py,sha256=pqtuSgc-Jm5N57d6jtUXmQx0D-khqjOIQAFZjS1XmNM,17075
|
|
19
19
|
palimpzest/core/lib/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
20
|
-
palimpzest/core/lib/schemas.py,sha256=
|
|
20
|
+
palimpzest/core/lib/schemas.py,sha256=2fzbTZBTssKTl9CFGDEQneXasOwo-PLP2lCqHZn2eng,9318
|
|
21
21
|
palimpzest/prompts/__init__.py,sha256=942kdENfPU5mFjIxYm-FusL0FD6LNhoj6cYoSGiUsCI,1628
|
|
22
22
|
palimpzest/prompts/agent_prompts.py,sha256=CUzBVLBiPSw8OShtKp4VTpQwtrNMtcMglo-IZHMvuDM,17459
|
|
23
23
|
palimpzest/prompts/context_search.py,sha256=s3pti4XNRiIyiWzjVNL_NqmqEc31jzSKMF2SlN0Aaf8,357
|
|
@@ -35,21 +35,21 @@ palimpzest/prompts/validator.py,sha256=pJTZjlt_OiFM3IFOgsJ0jQdayra8iRVrpqENlXI9t
|
|
|
35
35
|
palimpzest/query/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
36
36
|
palimpzest/query/execution/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
37
37
|
palimpzest/query/execution/all_sample_execution_strategy.py,sha256=8a8-eKsndo_edCwIamNgcISLQmTzVSv5vmD6Ogl8a6k,14367
|
|
38
|
-
palimpzest/query/execution/execution_strategy.py,sha256=
|
|
38
|
+
palimpzest/query/execution/execution_strategy.py,sha256=XoRVNlJSAgON-NWis9SecFr0B7DlJIm-25u1v5rjvu8,19085
|
|
39
39
|
palimpzest/query/execution/execution_strategy_type.py,sha256=vRQBPCQN5_aoyD3TLIeW3VPo15mqF-5RBvEXkENz9FE,987
|
|
40
|
-
palimpzest/query/execution/mab_execution_strategy.py,sha256=
|
|
40
|
+
palimpzest/query/execution/mab_execution_strategy.py,sha256=YjUZ2qBGvQMVUxi7rQCSU8JKP1RtqhG8Owik8hKB_UU,46292
|
|
41
41
|
palimpzest/query/execution/parallel_execution_strategy.py,sha256=roZZy7wLcmAwm_ecYvqSJanRaiox3OoNPuXxvRZ5TXg,15710
|
|
42
42
|
palimpzest/query/execution/single_threaded_execution_strategy.py,sha256=sESji79ytKxth9Tpm02c34Mltw0YiFn4GL5h0MI5Noo,16255
|
|
43
43
|
palimpzest/query/generators/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
44
44
|
palimpzest/query/generators/generators.py,sha256=UldCUEwaiBfpvQDieA-h7SiC8KM76gCABPj-mvqAJus,21500
|
|
45
|
-
palimpzest/query/operators/__init__.py,sha256=
|
|
45
|
+
palimpzest/query/operators/__init__.py,sha256=T-OFUqWfbL_xqW1n7nkXCWu0JLRePEkMwVCEAl3JNeM,4356
|
|
46
46
|
palimpzest/query/operators/aggregate.py,sha256=NZ_rmi0YrbNFivbBgKtonrCrK6fZw4h9Pm4lMMI5XVc,11376
|
|
47
47
|
palimpzest/query/operators/compute.py,sha256=X_pWN45smg8L4dV54nOae7dldQGL1nJVlVyJ3ULWSmI,8432
|
|
48
48
|
palimpzest/query/operators/convert.py,sha256=VfrWUFyuZC8fPf7LR7mMfpOjqSfxAuTLUxw-S-pn7hk,16123
|
|
49
49
|
palimpzest/query/operators/critique_and_refine.py,sha256=Q-NhasVoD9meX7g36RPrv3q4R48_8XEU4d3TE46hRJI,8979
|
|
50
50
|
palimpzest/query/operators/distinct.py,sha256=ZTXlIS7IaFRTsWv9RemzCo1JLz25vEma-TB42CV5fJQ,2614
|
|
51
51
|
palimpzest/query/operators/filter.py,sha256=ufREsO2-8CBk4u4fabDBYpEvb806E11EOyW-wuRs4vw,10356
|
|
52
|
-
palimpzest/query/operators/join.py,sha256=
|
|
52
|
+
palimpzest/query/operators/join.py,sha256=A0f7d4Nmi-MRp80HD3BrglYZPbFzp5X2vA-X-5XxaGE,25658
|
|
53
53
|
palimpzest/query/operators/limit.py,sha256=pdo7WfWY97SW3c-WqZ4SIPw7lHIVbaXPEWqHyK8qkF8,2130
|
|
54
54
|
palimpzest/query/operators/logical.py,sha256=K_dRlNKkda35kQ7gYGsrW9PoFuDPzexpjtDq_FYdhVw,20223
|
|
55
55
|
palimpzest/query/operators/mixture_of_agents.py,sha256=TWdg6XEg2u4TQM4d94gmbYqnK15wC7Q4Cyefp8SA4i8,11547
|
|
@@ -60,19 +60,19 @@ palimpzest/query/operators/retrieve.py,sha256=-OvEWmxwbepGz0w40FpHbqcOHZQ4Bp-MdX
|
|
|
60
60
|
palimpzest/query/operators/scan.py,sha256=OqCiPRTvTY7SbauNMyFvGT5nRVeRzVsGYSrkoN1Ib_w,7407
|
|
61
61
|
palimpzest/query/operators/search.py,sha256=cQin-Qc9FT7V0Gv3-pxMLbVMjqE6ALe99V0OrQhA6CI,22711
|
|
62
62
|
palimpzest/query/operators/split.py,sha256=oLzwnYb8TNf3XA9TMKEAIw7EIA12wHneaD42BNLIHiI,15043
|
|
63
|
-
palimpzest/query/optimizer/__init__.py,sha256=
|
|
63
|
+
palimpzest/query/optimizer/__init__.py,sha256=COn-okHtnYEyoNFRt3o3SA7jI5Wssx9BUEgfOfP4dOE,2560
|
|
64
64
|
palimpzest/query/optimizer/cost_model.py,sha256=OldPy-TJdfsQbYRoKlb3yWeKbi15jcldTIUS6BTi9T8,12678
|
|
65
|
-
palimpzest/query/optimizer/optimizer.py,sha256=
|
|
66
|
-
palimpzest/query/optimizer/optimizer_strategy.py,sha256=
|
|
65
|
+
palimpzest/query/optimizer/optimizer.py,sha256=BrhljITlFC5S5euA01pv4dzlqxrtKNEt_0DmhRtcMTk,19966
|
|
66
|
+
palimpzest/query/optimizer/optimizer_strategy.py,sha256=0foDaBHqQehK_zz6IlDEbNIw-44wxY6LO5H1anJi56Y,10042
|
|
67
67
|
palimpzest/query/optimizer/optimizer_strategy_type.py,sha256=V-MMHvJdnfZKoUX1xxxwh66q1RjN2FL35IsiT1C62c8,1084
|
|
68
|
-
palimpzest/query/optimizer/plan.py,sha256=
|
|
68
|
+
palimpzest/query/optimizer/plan.py,sha256=NoCUS_lyZ7LFj15_qpZ_cOFHVkCFMcIn8A7EsNeD57c,22849
|
|
69
69
|
palimpzest/query/optimizer/primitives.py,sha256=jMMVq37y1tWiPU1lSSKQP9OP-mzkpSxSmUeDajRYYOQ,5445
|
|
70
|
-
palimpzest/query/optimizer/rules.py,sha256=
|
|
71
|
-
palimpzest/query/optimizer/tasks.py,sha256=
|
|
70
|
+
palimpzest/query/optimizer/rules.py,sha256=er8K47L-qdRn0hCra-2PaqxhQEvDwJ7IVzNEszWHJ48,50452
|
|
71
|
+
palimpzest/query/optimizer/tasks.py,sha256=GCRA4rK6Q8dBGj2FnsRJUk3IdKthNQgiK5lFEu7v0mI,30439
|
|
72
72
|
palimpzest/query/processor/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
73
|
-
palimpzest/query/processor/config.py,sha256=
|
|
73
|
+
palimpzest/query/processor/config.py,sha256=kr9UHQ947SJmI77wqomy310mSaKNIMPxh-5k9frMVII,2413
|
|
74
74
|
palimpzest/query/processor/query_processor.py,sha256=T4ffPbnOX23G8FDITzmM7Iw7DUEDWIHnwl8XLYllgjg,6240
|
|
75
|
-
palimpzest/query/processor/query_processor_factory.py,sha256=
|
|
75
|
+
palimpzest/query/processor/query_processor_factory.py,sha256=i9L9StqlUi7m1AqZMuYQWhunqOJi3nLK47skhxq9tIA,8317
|
|
76
76
|
palimpzest/schemabuilder/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
77
77
|
palimpzest/schemabuilder/schema_builder.py,sha256=QraGp66dcD-ej6Y2mER40o86G9JqlBkL7swkJzjUAIY,7968
|
|
78
78
|
palimpzest/tools/README.md,sha256=56_6LPG80uc0CLVhTBP6I1wgIffNv9cyTr0TmVZqmrM,483
|
|
@@ -87,9 +87,9 @@ palimpzest/utils/model_helpers.py,sha256=X6SlMgD5I5Aj_cxaFaoGaaNvOOqTNZVmjj6zbfn
|
|
|
87
87
|
palimpzest/utils/progress.py,sha256=7gucyZr82udMDZitrrkAOSKHZVljE3R2wv9nf5gA5TM,20807
|
|
88
88
|
palimpzest/utils/udfs.py,sha256=LjHic54B1az-rKgNLur0wOpaz2ko_UodjLEJrazkxvY,1854
|
|
89
89
|
palimpzest/validator/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
90
|
-
palimpzest/validator/validator.py,sha256=
|
|
91
|
-
palimpzest-0.8.
|
|
92
|
-
palimpzest-0.8.
|
|
93
|
-
palimpzest-0.8.
|
|
94
|
-
palimpzest-0.8.
|
|
95
|
-
palimpzest-0.8.
|
|
90
|
+
palimpzest/validator/validator.py,sha256=vasnvAzEv9tDNLGz2X7MpMJBpn8MqSNelQSXk3X6MBs,16002
|
|
91
|
+
palimpzest-0.8.6.dist-info/licenses/LICENSE,sha256=5GUlHy9lr-Py9kvV38FF1m3yy3NqM18fefuE9wkWumo,1079
|
|
92
|
+
palimpzest-0.8.6.dist-info/METADATA,sha256=NuqbbYGwNa5VlbFP3d59-1KdXA1LjrfChElaSTkmZBk,7048
|
|
93
|
+
palimpzest-0.8.6.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
94
|
+
palimpzest-0.8.6.dist-info/top_level.txt,sha256=raV06dJUgohefUn3ZyJS2uqp_Y76EOLA9Y2e_fxt8Ew,11
|
|
95
|
+
palimpzest-0.8.6.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|