palimpzest 1.0.0__tar.gz → 1.1.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {palimpzest-1.0.0/src/palimpzest.egg-info → palimpzest-1.1.1}/PKG-INFO +1 -1
- {palimpzest-1.0.0 → palimpzest-1.1.1}/pyproject.toml +1 -1
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/core/elements/groupbysig.py +5 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/core/models.py +6 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/prompts/prompt_factory.py +15 -5
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/query/execution/execution_strategy.py +7 -3
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/query/execution/mab_execution_strategy.py +21 -7
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/query/generators/generators.py +1 -1
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/query/operators/join.py +94 -63
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/query/operators/rag.py +6 -5
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/query/operators/topk.py +24 -5
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/query/processor/config.py +2 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/utils/progress.py +32 -6
- {palimpzest-1.0.0 → palimpzest-1.1.1/src/palimpzest.egg-info}/PKG-INFO +1 -1
- {palimpzest-1.0.0 → palimpzest-1.1.1}/LICENSE +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/README.md +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/setup.cfg +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/__init__.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/agents/__init__.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/agents/compute_agents.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/agents/search_agents.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/constants.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/core/__init__.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/core/data/__init__.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/core/data/context.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/core/data/context_manager.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/core/data/dataset.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/core/data/index_dataset.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/core/data/iter_dataset.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/core/elements/__init__.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/core/elements/filters.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/core/elements/records.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/core/lib/__init__.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/core/lib/schemas.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/policy.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/prompts/__init__.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/prompts/agent_prompts.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/prompts/aggregate_prompts.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/prompts/context_search.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/prompts/convert_prompts.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/prompts/critique_and_refine_prompts.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/prompts/filter_prompts.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/prompts/join_prompts.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/prompts/moa_aggregator_prompts.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/prompts/moa_proposer_prompts.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/prompts/split_merge_prompts.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/prompts/split_proposer_prompts.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/prompts/utils.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/prompts/validator.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/query/__init__.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/query/execution/__init__.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/query/execution/all_sample_execution_strategy.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/query/execution/execution_strategy_type.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/query/execution/parallel_execution_strategy.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/query/execution/single_threaded_execution_strategy.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/query/generators/__init__.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/query/operators/__init__.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/query/operators/aggregate.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/query/operators/compute.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/query/operators/convert.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/query/operators/critique_and_refine.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/query/operators/distinct.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/query/operators/filter.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/query/operators/limit.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/query/operators/logical.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/query/operators/mixture_of_agents.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/query/operators/physical.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/query/operators/project.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/query/operators/scan.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/query/operators/search.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/query/operators/split.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/query/optimizer/__init__.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/query/optimizer/cost_model.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/query/optimizer/optimizer.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/query/optimizer/optimizer_strategy.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/query/optimizer/optimizer_strategy_type.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/query/optimizer/plan.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/query/optimizer/primitives.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/query/optimizer/rules.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/query/optimizer/tasks.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/query/processor/__init__.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/query/processor/query_processor.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/query/processor/query_processor_factory.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/schemabuilder/__init__.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/schemabuilder/schema_builder.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/tools/README.md +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/tools/__init__.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/tools/allenpdf.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/tools/pdfparser.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/tools/skema_tools.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/utils/__init__.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/utils/env_helpers.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/utils/hash_helpers.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/utils/model_helpers.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/utils/udfs.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/validator/__init__.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/validator/validator.py +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest.egg-info/SOURCES.txt +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest.egg-info/dependency_links.txt +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest.egg-info/requires.txt +0 -0
- {palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest.egg-info/top_level.txt +0 -0
{palimpzest-1.0.0/src/palimpzest.egg-info → palimpzest-1.1.1}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: palimpzest
-Version: 1.0.0
+Version: 1.1.1
 Summary: Palimpzest is a system which enables anyone to process AI-powered analytical queries simply by defining them in a declarative language
 Author-email: MIT DSG Semantic Management Lab <michjc@csail.mit.edu>
 Project-URL: homepage, https://palimpzest.org
{palimpzest-1.0.0 → palimpzest-1.1.1}/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "palimpzest"
-version = "1.0.0"
+version = "1.1.1"
 description = "Palimpzest is a system which enables anyone to process AI-powered analytical queries simply by defining them in a declarative language"
 readme = "README.md"
 requires-python = ">=3.12"
{palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/core/elements/groupbysig.py
@@ -11,6 +11,11 @@ from palimpzest.core.lib.schemas import create_schema_from_fields
 # - construct the correct output schema using the input schema and the group by and aggregation fields
 # - remove/update all other references to GroupBySig in the codebase
 
+# TODO:
+# - move the arguments for group_by_fields, agg_funcs, and agg_fields into the Dataset.groupby() operator
+# - construct the correct output schema using the input schema and the group by and aggregation fields
+# - remove/update all other references to GroupBySig in the codebase
+
 # signature for a group by aggregate that applies
 # group and aggregation to an input tuple
 class GroupBySig:
{palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/core/models.py
@@ -454,6 +454,12 @@ class BasePlanStats(BaseModel):
         """
         return sum([gen_stats.total_output_tokens for _, gen_stats in self.validation_gen_stats.items()])
 
+    def get_total_cost_so_far(self) -> float:
+        """
+        Get the total cost incurred so far in this plan execution.
+        """
+        return self.sum_op_costs() + self.sum_validation_costs()
+
 
 class PlanStats(BasePlanStats):
     """
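The new BasePlanStats.get_total_cost_so_far() gives the MAB execution strategy (further down in this diff) a single number to compare against the new sample_cost_budget. A minimal sketch of the accounting it performs, with made-up dollar figures; sum_op_costs() and sum_validation_costs() are taken from the diff, everything else here is illustrative:

class PlanStatsSketch:
    def __init__(self):
        self.op_costs = {"op-0": 0.12, "op-1": 0.30}          # per-operator LLM spend ($)
        self.validation_costs = {"op-0": 0.05, "op-1": 0.02}  # validator LLM spend ($)

    def sum_op_costs(self) -> float:
        return sum(self.op_costs.values())

    def sum_validation_costs(self) -> float:
        return sum(self.validation_costs.values())

    def get_total_cost_so_far(self) -> float:
        # operator cost + validation cost, exactly as in the diff
        return self.sum_op_costs() + self.sum_validation_costs()

assert abs(PlanStatsSketch().get_total_cost_so_far() - 0.49) < 1e-9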
{palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/prompts/prompt_factory.py
@@ -830,7 +830,7 @@ class PromptFactory:
         field_type = dr.get_field_type(field_name)
 
         # audio filepath (or list of audio filepaths)
-        if field_type.annotation in [AudioFilepath, AudioFilepath | None, AudioFilepath | Any]:
+        if field_type.annotation in [AudioFilepath, AudioFilepath | None, AudioFilepath | Any] and field_value is not None:
             with open(field_value, "rb") as f:
                 base64_audio_str = base64.b64encode(f.read()).decode("utf-8")
             audio_content.append(
@@ -839,6 +839,8 @@ class PromptFactory:
 
         elif field_type.annotation in [list[AudioFilepath], list[AudioFilepath] | None, list[AudioFilepath] | Any]:
             for audio_filepath in field_value:
+                if audio_filepath is None:
+                    continue
                 with open(audio_filepath, "rb") as f:
                     base64_audio_str = base64.b64encode(f.read()).decode("utf-8")
                 audio_content.append(
@@ -846,13 +848,15 @@ class PromptFactory:
             )
 
         # pre-encoded images (or list of pre-encoded images)
-        elif field_type.annotation in [AudioBase64, AudioBase64 | None, AudioBase64 | Any]:
+        elif field_type.annotation in [AudioBase64, AudioBase64 | None, AudioBase64 | Any] and field_value is not None:
             audio_content.append(
                 {"type": "input_audio", "input_audio": {"data": field_value, "format": "wav"}}
             )
 
         elif field_type.annotation in [list[AudioBase64], list[AudioBase64] | None, list[AudioBase64] | Any]:
             for base64_audio in field_value:
+                if base64_audio is None:
+                    continue
                 audio_content.append(
                     {"type": "input_audio", "input_audio": {"data": base64_audio, "format": "wav"}}
                 )
@@ -882,7 +886,7 @@ class PromptFactory:
         field_type = dr.get_field_type(field_name)
 
         # image filepath (or list of image filepaths)
-        if field_type.annotation in [ImageFilepath, ImageFilepath | None, ImageFilepath | Any]:
+        if field_type.annotation in [ImageFilepath, ImageFilepath | None, ImageFilepath | Any] and field_value is not None:
             with open(field_value, "rb") as f:
                 base64_image_str = base64.b64encode(f.read()).decode("utf-8")
             image_content.append(
@@ -891,6 +895,8 @@ class PromptFactory:
 
         elif field_type.annotation in [list[ImageFilepath], list[ImageFilepath] | None, list[ImageFilepath] | Any]:
             for image_filepath in field_value:
+                if image_filepath is None:
+                    continue
                 with open(image_filepath, "rb") as f:
                     base64_image_str = base64.b64encode(f.read()).decode("utf-8")
                 image_content.append(
@@ -898,21 +904,25 @@ class PromptFactory:
             )
 
         # image url (or list of image urls)
-        elif field_type.annotation in [ImageURL, ImageURL | None, ImageURL | Any]:
+        elif field_type.annotation in [ImageURL, ImageURL | None, ImageURL | Any] and field_value is not None:
             image_content.append({"type": "image_url", "image_url": {"url": field_value}})
 
         elif field_type.annotation in [list[ImageURL], list[ImageURL] | None, list[ImageURL] | Any]:
             for image_url in field_value:
+                if image_url is None:
+                    continue
                 image_content.append({"type": "image_url", "image_url": {"url": image_url}})
 
         # pre-encoded images (or list of pre-encoded images)
-        elif field_type.annotation in [ImageBase64, ImageBase64 | None, ImageBase64 | Any]:
+        elif field_type.annotation in [ImageBase64, ImageBase64 | None, ImageBase64 | Any] and field_value is not None:
             image_content.append(
                 {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{field_value}"}}
             )
 
         elif field_type.annotation in [list[ImageBase64], list[ImageBase64] | None, list[ImageBase64] | Any]:
             for base64_image in field_value:
+                if base64_image is None:
+                    continue
                 image_content.append(
                     {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}}
                 )
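Every multimodal branch in PromptFactory now guards against None field values, so a record with a missing audio or image field (or a None entry inside a list field) is skipped instead of crashing in open(None) or sending an empty payload to the model. A minimal sketch of the guarded pattern, using the message-content shape from the diff; the helper function itself is hypothetical:

import base64

def append_audio_content(field_value, audio_content):
    """Skip None values instead of passing them to open()."""
    if field_value is None:  # scalar field left unset on this record
        return
    values = field_value if isinstance(field_value, list) else [field_value]
    for audio_filepath in values:
        if audio_filepath is None:  # None entry inside a list field
            continue
        with open(audio_filepath, "rb") as f:
            b64 = base64.b64encode(f.read()).decode("utf-8")
        audio_content.append({"type": "input_audio", "input_audio": {"data": b64, "format": "wav"}})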
{palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/query/execution/execution_strategy.py
@@ -82,14 +82,16 @@ class SentinelExecutionStrategy(BaseExecutionStrategy, ABC):
     """
     def __init__(
         self,
-        k: int,
-        j: int,
-        sample_budget: int,
         policy: Policy,
+        k: int = 6,
+        j: int = 4,
+        sample_budget: int = 100,
+        sample_cost_budget: float | None = None,
         priors: dict | None = None,
         use_final_op_quality: bool = False,
         seed: int = 42,
         exp_name: str | None = None,
+        dont_use_priors: bool = False,
         *args,
         **kwargs,
     ):
@@ -97,12 +99,14 @@ class SentinelExecutionStrategy(BaseExecutionStrategy, ABC):
         self.k = k
         self.j = j
         self.sample_budget = sample_budget
+        self.sample_cost_budget = sample_cost_budget
         self.policy = policy
         self.priors = priors
         self.use_final_op_quality = use_final_op_quality
         self.seed = seed
         self.rng = np.random.default_rng(seed=seed)
         self.exp_name = exp_name
+        self.dont_use_priors = dont_use_priors
 
         # general cache which maps hash(logical_op_id, phys_op_id, hash(input)) --> record_set
         self.cache: dict[int, DataRecordSet] = {}
{palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/query/execution/mab_execution_strategy.py
RENAMED
@@ -44,6 +44,7 @@ class OpFrontier:
         seed: int,
         policy: Policy,
         priors: dict | None = None,
+        dont_use_priors: bool = False,
     ):
         # set k and j, which are the initial number of operators in the frontier and the
         # initial number of records to sample for each frontier operator
@@ -51,6 +52,7 @@ class OpFrontier:
         self.j = j
         self.source_indices = source_indices
         self.root_dataset_ids = root_dataset_ids
+        self.dont_use_priors = dont_use_priors
 
         # store the policy that we are optimizing under
         self.policy = policy
@@ -68,6 +70,7 @@ class OpFrontier:
         is_llm_filter = isinstance(sample_op, LLMFilter)
         is_llm_topk = isinstance(sample_op, TopKOp) and isinstance(sample_op.index, Collection)
         self.is_llm_op = is_llm_convert or is_llm_filter or is_llm_topk or self.is_llm_join
+        self.is_llm_convert = is_llm_convert
 
         # get order in which we will sample physical operators for this logical operator
         sample_op_indices = self._get_op_index_order(op_set, seed)
@@ -190,7 +193,9 @@ class OpFrontier:
         Returns a list of indices for the operators in the op_set.
         """
         # if this is not an llm-operator, we simply return the indices in random order
-        if not self.is_llm_op:
+        if not self.is_llm_op or self.dont_use_priors:
+            if self.is_llm_convert:
+                print("Using NO PRIORS for operator sampling order")
             rng = np.random.default_rng(seed=seed)
             op_indices = np.arange(len(op_set))
             rng.shuffle(op_indices)
@@ -198,6 +203,8 @@ class OpFrontier:
 
         # if this is an llm-operator, but we do not have priors, we first compute naive priors
         if self.priors is None or any([op_id not in self.priors for op_id in map(lambda op: op.get_op_id(), op_set)]):
+            if self.is_llm_convert:
+                print("Using NAIVE PRIORS for operator sampling order")
             self.priors = self._compute_naive_priors(op_set)
 
         # NOTE: self.priors is a dictionary with format:
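OpFrontier now threads through a dont_use_priors flag: when it is set (or the operator is not an LLM operator), _get_op_index_order falls back to a seeded random permutation instead of prior-based ranking. A free-standing sketch of that fallback branch; the prior-based path is elided, and the function name here is illustrative:

import numpy as np

def op_index_order(num_ops: int, seed: int, is_llm_op: bool, dont_use_priors: bool) -> list[int]:
    """Non-LLM operators *or* dont_use_priors=True get a seeded random order."""
    if not is_llm_op or dont_use_priors:
        rng = np.random.default_rng(seed=seed)
        op_indices = np.arange(num_ops)
        rng.shuffle(op_indices)
        return list(op_indices)
    # the real implementation ranks operators by (possibly naive) priors here
    raise NotImplementedError("prior-based ordering elided in this sketch")

print(op_index_order(5, seed=42, is_llm_op=True, dont_use_priors=True))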
@@ -680,6 +687,9 @@ class MABExecutionStrategy(SentinelExecutionStrategy):
 
         return max_quality_op
 
+    def _compute_termination_condition(self, samples_drawn: int, sampling_cost: float) -> bool:
+        return (samples_drawn >= self.sample_budget) if self.sample_cost_budget is None else (sampling_cost >= self.sample_cost_budget)
+
     def _execute_sentinel_plan(
         self,
         plan: SentinelPlan,
@@ -688,8 +698,8 @@ class MABExecutionStrategy(SentinelExecutionStrategy):
         plan_stats: SentinelPlanStats,
     ) -> SentinelPlanStats:
         # sample records and operators and update the frontiers
-        samples_drawn = 0
-        while samples_drawn < self.sample_budget:
+        samples_drawn, sampling_cost = 0, 0.0
+        while not self._compute_termination_condition(samples_drawn, sampling_cost):
             # pre-compute the set of source indices which will need to be sampled
             source_indices_to_sample = set()
             for op_frontier in op_frontiers.values():
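_compute_termination_condition is the crux of the new cost-budgeted optimization: when sample_cost_budget is None the sampling loop stops after sample_budget record samples, otherwise it stops once the accumulated dollar cost (operator plus validation spend, via get_total_cost_so_far()) crosses the cost budget, and the sample budget is ignored. Restated as a free-standing function with a few worked checks:

def should_stop(samples_drawn: int, sampling_cost: float,
                sample_budget: int, sample_cost_budget: float | None) -> bool:
    """A cost budget, when given, takes precedence over the sample budget."""
    if sample_cost_budget is None:
        return samples_drawn >= sample_budget
    return sampling_cost >= sample_cost_budget

assert should_stop(100, 0.0, sample_budget=100, sample_cost_budget=None)       # sample budget hit
assert not should_stop(100, 0.50, sample_budget=100, sample_cost_budget=1.00)  # $0.50 < $1.00, keep going
assert should_stop(3, 1.25, sample_budget=100, sample_cost_budget=1.00)        # $1.25 >= $1.00, stop early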
@@ -732,6 +742,9 @@ class MABExecutionStrategy(SentinelExecutionStrategy):
         }
         source_indices_to_all_record_sets, val_gen_stats = self._score_quality(validator, source_indices_to_all_record_sets)
 
+        # update the progress manager with validation cost
+        self.progress_manager.incr_overall_progress_cost(val_gen_stats.cost_per_record)
+
         # remove records that were read from the execution cache before adding to record op stats
         new_record_op_stats = []
         for _, record_set_tuples in source_indices_to_record_set_tuples.items():
@@ -742,6 +755,7 @@ class MABExecutionStrategy(SentinelExecutionStrategy):
         # update plan stats
         plan_stats.add_record_op_stats(unique_logical_op_id, new_record_op_stats)
         plan_stats.add_validation_gen_stats(unique_logical_op_id, val_gen_stats)
+        sampling_cost = plan_stats.get_total_cost_so_far()
 
         # provide the best record sets as inputs to the next logical operator
         next_unique_logical_op_id = plan.get_next_unique_logical_op_id(unique_logical_op_id)
@@ -798,7 +812,7 @@ class MABExecutionStrategy(SentinelExecutionStrategy):
         assert len(root_dataset_ids) == 1, f"Scan for {sample_op} has {len(root_dataset_ids)} > 1 root dataset ids"
         root_dataset_id = root_dataset_ids[0]
         source_indices = dataset_id_to_shuffled_source_indices[root_dataset_id]
-        op_frontiers[unique_logical_op_id] = OpFrontier(op_set, source_unique_logical_op_ids, root_dataset_ids, source_indices, self.k, self.j, self.seed, self.policy, self.priors)
+        op_frontiers[unique_logical_op_id] = OpFrontier(op_set, source_unique_logical_op_ids, root_dataset_ids, source_indices, self.k, self.j, self.seed, self.policy, self.priors, self.dont_use_priors)
     elif isinstance(sample_op, JoinOp):
         assert len(source_unique_logical_op_ids) == 2, f"Join for {sample_op} has {len(source_unique_logical_op_ids)} != 2 source logical operators"
         left_source_indices = op_frontiers[source_unique_logical_op_ids[0]].source_indices
@@ -807,13 +821,13 @@ class MABExecutionStrategy(SentinelExecutionStrategy):
         for left_source_idx in left_source_indices:
             for right_source_idx in right_source_indices:
                 source_indices.append((left_source_idx, right_source_idx))
-        op_frontiers[unique_logical_op_id] = OpFrontier(op_set, source_unique_logical_op_ids, root_dataset_ids, source_indices, self.k, self.j, self.seed, self.policy, self.priors)
+        op_frontiers[unique_logical_op_id] = OpFrontier(op_set, source_unique_logical_op_ids, root_dataset_ids, source_indices, self.k, self.j, self.seed, self.policy, self.priors, self.dont_use_priors)
     else:
         source_indices = op_frontiers[source_unique_logical_op_ids[0]].source_indices
-        op_frontiers[unique_logical_op_id] = OpFrontier(op_set, source_unique_logical_op_ids, root_dataset_ids, source_indices, self.k, self.j, self.seed, self.policy, self.priors)
+        op_frontiers[unique_logical_op_id] = OpFrontier(op_set, source_unique_logical_op_ids, root_dataset_ids, source_indices, self.k, self.j, self.seed, self.policy, self.priors, self.dont_use_priors)
 
     # initialize and start the progress manager
-    self.progress_manager = create_progress_manager(plan, sample_budget=self.sample_budget, progress=self.progress)
+    self.progress_manager = create_progress_manager(plan, sample_budget=self.sample_budget, sample_cost_budget=self.sample_cost_budget, progress=self.progress)
     self.progress_manager.start()
 
     # NOTE: we must handle progress manager outside of _execute_sentinel_plan to ensure that it is shut down correctly;
{palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/query/generators/generators.py
@@ -338,7 +338,7 @@ class Generator(Generic[ContextType, InputType]):
         reasoning_effort = "minimal" if self.reasoning_effort is None else self.reasoning_effort
         completion_kwargs = {"reasoning_effort": reasoning_effort, **completion_kwargs}
         if self.model.is_vllm_model():
-            completion_kwargs = {"api_base": self.api_base, "api_key": os.environ.get("VLLM_API_KEY", "fake-api-key") **completion_kwargs}
+            completion_kwargs = {"api_base": self.api_base, "api_key": os.environ.get("VLLM_API_KEY", "fake-api-key"), **completion_kwargs}
         completion = litellm.completion(model=self.model_name, messages=messages, **completion_kwargs)
         end_time = time.time()
         logger.debug(f"Generated completion in {end_time - start_time:.2f} seconds")
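This one-character change fixes a real 1.0.0 bug: without the comma, ** parses as the power operator rather than dict unpacking, so the expression tried to raise a string to a dict power and every vLLM completion failed before reaching litellm. A small repro of the failure mode, with illustrative values:

# Why the missing comma mattered: ** binds as exponentiation, not unpacking.
completion_kwargs = {"temperature": 0.0}
try:
    bad = {"api_base": "http://localhost:8000", "api_key": "fake-api-key" **completion_kwargs}
except TypeError as e:
    print(e)  # unsupported operand type(s) for ** or pow(): 'str' and 'dict'

# The 1.1.1 form merges completion_kwargs into the dict as intended:
good = {"api_base": "http://localhost:8000", "api_key": "fake-api-key", **completion_kwargs}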
{palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/query/operators/join.py
@@ -27,6 +27,25 @@ from palimpzest.query.generators.generators import Generator
 from palimpzest.query.operators.physical import PhysicalOperator
 
 
+class Singleton:
+    def __new__(cls, *args, **kw):
+        if not hasattr(cls, '_instance'):
+            orig = super(Singleton, cls)  # noqa: UP008
+            cls._instance = orig.__new__(cls, *args, **kw)
+        return cls._instance
+
+class Locks(Singleton):
+    model = None
+    clip_lock = threading.Lock()
+    exec_lock = threading.Lock()
+
+    @classmethod
+    def get_model(cls, model_name: str):
+        with cls.clip_lock:
+            if cls.model is None:
+                cls.model = SentenceTransformer(model_name)
+        return cls.model
+
 def compute_similarity(left_embedding: list[float], right_embedding: list[float]) -> float:
     """
     Compute the similarity between two embeddings using cosine similarity.
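The new Locks singleton replaces EmbeddingJoin's per-instance _get_clip_model (removed below): the SentenceTransformer is now loaded once per process behind clip_lock and shared by every join operator, and exec_lock serializes the sampling phase across worker threads. A usage sketch under the diff's API; the commented lines assume sentence-transformers is installed and would download clip-ViT-B-32 on first use:

locks_a, locks_b = Locks(), Locks()
assert locks_a is locks_b  # Singleton.__new__ hands back the one shared instance

# model = locks_a.get_model("clip-ViT-B-32")   # loads once, then cached
# assert model is locks_b.get_model("clip-ViT-B-32")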
@@ -487,8 +506,7 @@ class EmbeddingJoin(LLMJoin):
             if field_name.split(".")[-1] in self.get_input_fields()
         ])
         self.embedding_model = Model.TEXT_EMBEDDING_3_SMALL if self.text_only else Model.CLIP_VIT_B_32
-        self.clip_model = None
-        self._lock = threading.Lock()
+        self.locks = Locks()
 
         # keep track of embedding costs that could not be amortized if no output records were produced
         self.residual_embedding_cost = 0.0
@@ -560,12 +578,6 @@ class EmbeddingJoin(LLMJoin):
             quality=quality,
         )
 
-    def _get_clip_model(self):
-        with self._lock:
-            if self.clip_model is None:
-                self.clip_model = SentenceTransformer(self.embedding_model.value)
-        return self.clip_model
-
     def _compute_embeddings(self, candidates: list[DataRecord], input_fields: list[str]) -> tuple[np.ndarray, GenerationStats]:
         # return empty array and empty stats if no candidates
         if len(candidates) == 0:
@@ -581,7 +593,7 @@ class EmbeddingJoin(LLMJoin):
             total_input_tokens = response.usage.total_tokens
             embeddings = np.array([item.embedding for item in response.data])
         else:
-            model = self._get_clip_model()
+            model = self.locks.get_model(self.embedding_model.value)
             embeddings = np.zeros((len(candidates), 512))  # CLIP embeddings are 512-dimensional
             num_input_fields_present = 0
             for field in input_fields:
@@ -623,7 +635,7 @@ class EmbeddingJoin(LLMJoin):
         output_record, output_record_op_stats = super()._process_join_candidate_pair(left_candidate, right_candidate, gen_kwargs)
         return output_record, output_record_op_stats, embedding_sim
 
-    def _process_join_candidate_with_sim(self, left_candidate: DataRecord, right_candidate: DataRecord, passed_operator: bool) -> tuple[DataRecord, RecordOpStats]:
+    def _process_join_candidate_with_sim(self, left_candidate: DataRecord, right_candidate: DataRecord, embedding_sim: float, passed_operator: bool) -> tuple[DataRecord, RecordOpStats]:
         # compute output record and add to output_records
         join_dr = DataRecord.from_join_parents(self.output_schema, left_candidate, right_candidate)
         join_dr._passed_operator = passed_operator
@@ -656,7 +668,7 @@ class EmbeddingJoin(LLMJoin):
             op_details={k: str(v) for k, v in self.get_id_params().items()},
         )
 
-        return join_dr, record_op_stats
+        return join_dr, record_op_stats, embedding_sim
 
     def __call__(self, left_candidates: list[DataRecord], right_candidates: list[DataRecord], final: bool = False) -> tuple[DataRecordSet, int]:
         # get the set of input fields from both records in the join
@@ -690,36 +702,50 @@ class EmbeddingJoin(LLMJoin):
         output_records, output_record_op_stats, num_inputs_processed = [], [], 0
 
         # draw samples until num_samples is reached
-        … (removed lines not preserved in this diff view)
+        with self.locks.exec_lock:
+            if self.samples_drawn < self.num_samples:
+                samples_to_draw = min(self.num_samples - self.samples_drawn, len(join_candidates))
+                join_candidate_samples = join_candidates[:samples_to_draw]
+                join_candidates = join_candidates[samples_to_draw:]
+
+                # apply the generator to each pair of candidates
+                with ThreadPoolExecutor(max_workers=self.join_parallelism) as executor:
+                    futures = [
+                        executor.submit(self._process_join_candidate_pair, left_candidate, right_candidate, gen_kwargs, embedding_sim)
+                        for left_candidate, right_candidate, embedding_sim in join_candidate_samples
+                    ]
+
+                    # collect results as they complete
+                    similarities, joined = [], []
+                    for future in as_completed(futures):
+                        self.join_idx += 1
+                        join_output_record, join_output_record_op_stats, embedding_sim = future.result()
+                        output_records.append(join_output_record)
+                        output_record_op_stats.append(join_output_record_op_stats)
+                        similarities.append(embedding_sim)
+                        joined.append(join_output_record._passed_operator)
+                        print(f"{self.join_idx} JOINED")
+
+                # sort join results by embedding similarity
+                sorted_sim_join_tuples = sorted(zip(similarities, joined), key=lambda x: x[0])
+
+                # compute threshold below which no records joined
+                for embedding_sim, records_joined in sorted_sim_join_tuples:
+                    if records_joined:
+                        break
+                    if not records_joined and embedding_sim > self.max_non_matching_sim:
+                        self.max_non_matching_sim = embedding_sim
+
+                # compute threshold above which all records joined
+                for embedding_sim, records_joined in reversed(sorted_sim_join_tuples):
+                    if not records_joined:
+                        break
+                    if records_joined and embedding_sim < self.min_matching_sim:
+                        self.min_matching_sim = embedding_sim
+
+                # update samples drawn and num_inputs_processed
+                self.samples_drawn += samples_to_draw
+                num_inputs_processed += samples_to_draw
 
         # process remaining candidates based on embedding similarity
         if len(join_candidates) > 0:
@@ -727,43 +753,48 @@ class EmbeddingJoin(LLMJoin):
             with ThreadPoolExecutor(max_workers=self.join_parallelism) as executor:
                 futures = []
                 for left_candidate, right_candidate, embedding_sim in join_candidates:
-                    … (removed lines not preserved in this diff view)
-                    if embedding_sim < self.min_matching_sim:
-                        output_records.append(output_record)
-                        output_record_op_stats.append(record_op_stats)
-                        print(f"{self.join_idx} SKIPPED (low sim: {embedding_sim:.4f} < {self.min_matching_sim:.4f})")
-
-                    elif embedding_sim > self.max_non_matching_sim:
-                        self.join_idx += 1
-                        output_record, record_op_stats = self._process_join_candidate_with_sim(left_candidate, right_candidate, passed_operator=True)
-                        output_records.append(output_record)
-                        output_record_op_stats.append(record_op_stats)
-                        print(f"{self.join_idx} JOINED (high sim: {embedding_sim:.4f} > {self.max_non_matching_sim:.4f})")
+                    # if the embedding similarity is lower than the threshold below which no records joined,
+                    # then we can skip the LLM call and mark the records as not joined
+                    if embedding_sim < self.max_non_matching_sim:
+                        futures.append(executor.submit(self._process_join_candidate_with_sim, left_candidate, right_candidate, embedding_sim, passed_operator=False))
+
+                    # if the embedding similarity is higher than the threshold above which all records joined,
+                    # then we can skip the LLM call and mark the records as joined
+                    elif embedding_sim > self.min_matching_sim:
+                        futures.append(executor.submit(self._process_join_candidate_with_sim, left_candidate, right_candidate, embedding_sim, passed_operator=True))
+
+                    # otherwise, we will process the LLM call
+                    else:
+                        futures.append(executor.submit(self._process_join_candidate_pair, left_candidate, right_candidate, gen_kwargs, embedding_sim))
 
                     num_inputs_processed += 1
 
                 # collect results as they complete
+                similarities, joined = [], []
                 for future in as_completed(futures):
                     self.join_idx += 1
                     join_output_record, join_output_record_op_stats, embedding_sim = future.result()
                     output_records.append(join_output_record)
                     output_record_op_stats.append(join_output_record_op_stats)
+                    similarities.append(embedding_sim)
+                    joined.append(join_output_record._passed_operator)
                     print(f"{self.join_idx} JOINED")
 
+                ### update thresholds if there are llm calls which incrementally squeeze the boundaries ###
+                # sort join results by embedding similarity
+                sorted_sim_join_tuples = sorted(zip(similarities, joined), key=lambda x: x[0])
+
+                # potentially update threshold below which no records joined
+                for embedding_sim, records_joined in sorted_sim_join_tuples:
+                    if records_joined:
+                        break
                     if not records_joined and embedding_sim > self.max_non_matching_sim:
                         self.max_non_matching_sim = embedding_sim
+
+                # potentially update threshold above which all records joined
+                for embedding_sim, records_joined in reversed(sorted_sim_join_tuples):
+                    if not records_joined:
+                        break
                     if records_joined and embedding_sim < self.min_matching_sim:
                         self.min_matching_sim = embedding_sim
 
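Taken together, the two hunks above give EmbeddingJoin an adaptive shortcut: sampled LLM join decisions are sorted by embedding similarity, the highest similarity in the all-non-matching prefix becomes max_non_matching_sim, the lowest similarity in the all-matching suffix becomes min_matching_sim, and later pairs falling outside the ambiguous band between the two skip the LLM entirely. A toy walk-through of the threshold bookkeeping with invented similarities; the real code starts from previously learned thresholds rather than ±inf:

# (similarity, did_the_LLM_say_joined), already sorted ascending by similarity
sorted_sim_join_tuples = [(0.21, False), (0.34, False), (0.58, True),
                          (0.61, False), (0.70, True), (0.83, True)]

max_non_matching_sim = float("-inf")
for sim, joined in sorted_sim_join_tuples:            # walk up from the bottom
    if joined:
        break
    if sim > max_non_matching_sim:
        max_non_matching_sim = sim                    # -> 0.34

min_matching_sim = float("inf")
for sim, joined in reversed(sorted_sim_join_tuples):  # walk down from the top
    if not joined:
        break
    if sim < min_matching_sim:
        min_matching_sim = sim                        # -> 0.70

# future pairs with sim < 0.34 are marked "not joined" without an LLM call,
# pairs with sim > 0.70 are marked "joined"; only [0.34, 0.70] needs the LLM
print(max_non_matching_sim, min_matching_sim)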
{palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/query/operators/rag.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import time
+from typing import Any
 
 from numpy import dot
 from numpy.linalg import norm
@@ -153,9 +154,9 @@ class RAGConvert(LLMConvert):
         field = candidate.get_field_type(field_name)
 
         # skip this field if it is not a string or a list of strings
-        is_string_field = field.annotation in [str, str | None]
-        is_list_string_field = field.annotation in [list[str], list[str] | None]
-        if not (is_string_field or is_list_string_field):
+        is_string_field = field.annotation in [str, str | None, str | Any]
+        is_list_string_field = field.annotation in [list[str], list[str] | None, list[str] | Any]
+        if not (is_string_field or is_list_string_field) or candidate[field_name] is None:
             continue
 
         # if this is a list of strings, join the strings
@@ -358,8 +359,8 @@ class RAGFilter(LLMFilter):
         field = candidate.get_field_type(field_name)
 
         # skip this field if it is not a string or a list of strings
-        is_string_field = field.annotation in [str, str | None]
-        is_list_string_field = field.annotation in [list[str], list[str] | None]
+        is_string_field = field.annotation in [str, str | None, str | Any]
+        is_list_string_field = field.annotation in [list[str], list[str] | None, list[str] | Any]
         if not (is_string_field or is_list_string_field):
             continue
 
{palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/query/operators/topk.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import os
+import threading
 import time
 from typing import Callable
 
@@ -17,6 +18,24 @@ from palimpzest.core.models import GenerationStats, OperatorCostEstimates, RecordOpStats
 from palimpzest.query.operators.physical import PhysicalOperator
 
 
+class Singleton:
+    def __new__(cls, *args, **kw):
+        if not hasattr(cls, '_instance'):
+            orig = super(Singleton, cls)  # noqa: UP008
+            cls._instance = orig.__new__(cls, *args, **kw)
+        return cls._instance
+
+class ClipModel(Singleton):
+    model = None
+    lock = threading.Lock()
+
+    @classmethod
+    def get_model(cls, model_name: str):
+        with cls.lock:
+            if cls.model is None:
+                cls.model = SentenceTransformer(model_name)
+        return cls.model
+
 class TopKOp(PhysicalOperator):
     def __init__(
         self,
@@ -56,6 +75,7 @@ class TopKOp(PhysicalOperator):
         self.output_attrs = output_attrs
         self.search_func = search_func if search_func is not None else self.default_search_func
         self.k = k
+        self.clip_model = ClipModel()
 
     def __str__(self):
         op = super().__str__()
@@ -185,7 +205,6 @@ class TopKOp(PhysicalOperator):
         # construct and return the record set
         return DataRecordSet(drs, record_op_stats_lst)
 
-
     def __call__(self, candidate: DataRecord) -> DataRecordSet:
         start_time = time.time()
 
@@ -209,9 +228,9 @@ class TopKOp(PhysicalOperator):
         inputs, gen_stats = None, GenerationStats()
         if isinstance(self.index, Collection):
             uses_openai_embedding_fcn = isinstance(self.index._embedding_function, OpenAIEmbeddingFunction)
-            … (removed line not preserved in this diff view)
+            uses_clip_model = isinstance(self.index._embedding_function, SentenceTransformerEmbeddingFunction)
             error_msg = "ChromaDB index must use OpenAI or SentenceTransformer embedding function; see: https://docs.trychroma.com/integrations/embedding-models/openai"
-            assert uses_openai_embedding_fcn or …
+            assert uses_openai_embedding_fcn or uses_clip_model, error_msg
 
             model_name = self.index._embedding_function.model_name if uses_openai_embedding_fcn else "clip-ViT-B-32"
             err_msg = f"For Chromadb, we currently only support `text-embedding-3-small` and `clip-ViT-B-32`; your index uses: {model_name}"
@@ -228,8 +247,8 @@ class TopKOp(PhysicalOperator):
                 total_input_tokens = response.usage.total_tokens
                 inputs = [item.embedding for item in response.data]
 
-            elif …
-                model = …
+            elif uses_clip_model:
+                model = self.clip_model.get_model(model_name)
                 inputs = model.encode(query)
 
             embed_total_time = time.time() - embed_start_time
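topk.py gets the same process-wide singleton treatment as join.py, so concurrent TopKOp instances share one CLIP encoder instead of each loading its own copy. A brief usage sketch under the diff's API:

clip = ClipModel()
assert clip is ClipModel()  # Singleton reuses the one shared instance
# model = clip.get_model("clip-ViT-B-32")  # loaded on first call, cached after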
{palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/query/processor/config.py
@@ -44,9 +44,11 @@ class QueryProcessorConfig(BaseModel):
     k: int = Field(default=6)
     j: int = Field(default=4)
     sample_budget: int = Field(default=100)
+    sample_cost_budget: float | None = Field(default=None)
     seed: int = Field(default=42)
     exp_name: str | None = Field(default=None)
     priors: dict | None = Field(default=None)
+    dont_use_priors: bool = Field(default=False)
 
     def to_dict(self) -> dict:
         """Convert the config to a dict representation."""
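These two fields are the user-facing switches for the behavior above. A hedged usage sketch: the two new field names come from this diff, while how the config is ultimately passed into a query run depends on your palimpzest entry point:

from palimpzest.query.processor.config import QueryProcessorConfig

# Cap sentinel sampling at $2.00 of LLM spend instead of 100 record samples,
# and disable prior-based operator ordering during optimization.
config = QueryProcessorConfig(
    sample_cost_budget=2.00,  # new in 1.1.1: dollar budget for optimization
    dont_use_priors=True,     # new in 1.1.1: seeded random operator sampling order
)
print(config.to_dict())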
{palimpzest-1.0.0 → palimpzest-1.1.1}/src/palimpzest/utils/progress.py
@@ -283,7 +283,7 @@ class PZProgressManager(ProgressManager):
         self.unique_full_op_id_to_stats[unique_full_op_id].memory_usage_mb = get_memory_usage()
 
 class PZSentinelProgressManager(ProgressManager):
-    def __init__(self, plan: SentinelPlan, sample_budget: int):
+    def __init__(self, plan: SentinelPlan, sample_budget: int | None, sample_cost_budget: float | None):
         # overall progress bar
         self.overall_progress = RichProgress(
             SpinnerColumn(),
@@ -298,7 +298,9 @@ class PZSentinelProgressManager(ProgressManager):
             refresh_per_second=10,
             expand=True,  # Use full width
         )
-        self.overall_task_id = self.overall_progress.add_task("", total=sample_budget, cost=0.0, recent="")
+        self.use_cost_budget = sample_cost_budget is not None
+        total = sample_cost_budget if self.use_cost_budget else sample_budget
+        self.overall_task_id = self.overall_progress.add_task("", total=total, cost=0.0, recent="")
 
         # logical operator progress bars
         self.op_progress = RichProgress(
@@ -334,6 +336,9 @@ class PZSentinelProgressManager(ProgressManager):
         # initialize start time
         self.start_time = None
 
+        # initialize validation cost
+        self.validation_cost = 0.0
+
         # add a task to the progress manager for each operator in the plan
         for topo_idx, (logical_op_id, op_set) in enumerate(plan):
             unique_logical_op_id = f"{topo_idx}-{logical_op_id}"
@@ -387,15 +392,34 @@ class PZSentinelProgressManager(ProgressManager):
         # start progress bars
         self.live_display.start()
 
+    def incr_overall_progress_cost(self, cost_delta: float):
+        """Advance the overall progress bar by the given cost delta"""
+        self.validation_cost += cost_delta
+        self.overall_progress.update(
+            self.overall_task_id,
+            advance=cost_delta,
+            cost=sum(stats.total_cost for _, stats in self.unique_logical_op_id_to_stats.items()) + self.validation_cost,
+            refresh=True,
+        )
+
+        # force the live display to refresh
+        self.live_display.refresh()
+
     def incr(self, unique_logical_op_id: str, num_samples: int, display_text: str | None = None, **kwargs):
         # TODO: (above) organize progress bars into a Live / Table / Panel or something
         # get the task for the given operation
         task = self.unique_logical_op_id_to_task.get(unique_logical_op_id)
 
+        # store the cost before updating stats
+        previous_total_cost = self.unique_logical_op_id_to_stats[unique_logical_op_id].total_cost
+
         # update statistics with any additional keyword arguments
         if kwargs != {}:
             self.update_stats(unique_logical_op_id, **kwargs)
 
+        # compute the cost delta
+        cost_delta = self.unique_logical_op_id_to_stats[unique_logical_op_id].total_cost - previous_total_cost
+
         # update progress bar and recent text in one update
         if display_text is not None:
             self.unique_logical_op_id_to_stats[unique_logical_op_id].recent_text = display_text
@@ -414,10 +438,11 @@ class PZSentinelProgressManager(ProgressManager):
         )
 
         # advance the overall progress bar
+        advance = cost_delta if self.use_cost_budget else num_samples
        self.overall_progress.update(
             self.overall_task_id,
-            advance=num_samples,
-            cost=sum(stats.total_cost for _, stats in self.unique_logical_op_id_to_stats.items()),
+            advance=advance,
+            cost=sum(stats.total_cost for _, stats in self.unique_logical_op_id_to_stats.items()) + self.validation_cost,
             refresh=True,
         )
 
@@ -451,6 +476,7 @@ def create_progress_manager(
     plan: PhysicalPlan | SentinelPlan,
     num_samples: int | None = None,
     sample_budget: int | None = None,
+    sample_cost_budget: float | None = None,
     progress: bool = True,
 ) -> ProgressManager:
     """Factory function to create appropriate progress manager based on environment"""
@@ -458,7 +484,7 @@ def create_progress_manager(
         return MockProgressManager(plan, num_samples)
 
     if isinstance(plan, SentinelPlan):
-        assert sample_budget is not None, "Sample budget must be specified for SentinelPlan progress manager"
-        return PZSentinelProgressManager(plan, sample_budget)
+        assert sample_budget is not None or sample_cost_budget is not None, "Sample budget must be specified for SentinelPlan progress manager"
+        return PZSentinelProgressManager(plan, sample_budget, sample_cost_budget)
 
     return PZProgressManager(plan, num_samples)
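With a cost budget active, the sentinel progress bar's total is denominated in dollars: incr() advances by the cost delta of the operator's stats rather than by record samples, and validation spend (tracked via incr_overall_progress_cost) is folded into the displayed total. The advance rule, restated with a couple of checks:

def advance_amount(use_cost_budget: bool, cost_delta: float, num_samples: int) -> float:
    """Dollars when a cost budget is active, record samples otherwise."""
    return cost_delta if use_cost_budget else num_samples

assert advance_amount(True, cost_delta=0.07, num_samples=5) == 0.07
assert advance_amount(False, cost_delta=0.07, num_samples=5) == 5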
{palimpzest-1.0.0 → palimpzest-1.1.1/src/palimpzest.egg-info}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: palimpzest
-Version: 1.0.0
+Version: 1.1.1
 Summary: Palimpzest is a system which enables anyone to process AI-powered analytical queries simply by defining them in a declarative language
 Author-email: MIT DSG Semantic Management Lab <michjc@csail.mit.edu>
 Project-URL: homepage, https://palimpzest.org
All remaining files are unchanged between 1.0.0 and 1.1.1 (see the +0 -0 entries in the file list above).