palimpzest 1.1.0__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -830,7 +830,7 @@ class PromptFactory:
  field_type = dr.get_field_type(field_name)

  # audio filepath (or list of audio filepaths)
- if field_type.annotation in [AudioFilepath, AudioFilepath | None, AudioFilepath | Any]:
+ if field_type.annotation in [AudioFilepath, AudioFilepath | None, AudioFilepath | Any] and field_value is not None:
      with open(field_value, "rb") as f:
          base64_audio_str = base64.b64encode(f.read()).decode("utf-8")
      audio_content.append(
@@ -839,6 +839,8 @@ class PromptFactory:

  elif field_type.annotation in [list[AudioFilepath], list[AudioFilepath] | None, list[AudioFilepath] | Any]:
      for audio_filepath in field_value:
+         if audio_filepath is None:
+             continue
          with open(audio_filepath, "rb") as f:
              base64_audio_str = base64.b64encode(f.read()).decode("utf-8")
          audio_content.append(
@@ -846,13 +848,15 @@ class PromptFactory:
      )

  # pre-encoded audio (or list of pre-encoded audio)
- elif field_type.annotation in [AudioBase64, AudioBase64 | None, AudioBase64 | Any]:
+ elif field_type.annotation in [AudioBase64, AudioBase64 | None, AudioBase64 | Any] and field_value is not None:
      audio_content.append(
          {"type": "input_audio", "input_audio": {"data": field_value, "format": "wav"}}
      )

  elif field_type.annotation in [list[AudioBase64], list[AudioBase64] | None, list[AudioBase64] | Any]:
      for base64_audio in field_value:
+         if base64_audio is None:
+             continue
          audio_content.append(
              {"type": "input_audio", "input_audio": {"data": base64_audio, "format": "wav"}}
          )
@@ -882,7 +886,7 @@ class PromptFactory:
  field_type = dr.get_field_type(field_name)

  # image filepath (or list of image filepaths)
- if field_type.annotation in [ImageFilepath, ImageFilepath | None, ImageFilepath | Any]:
+ if field_type.annotation in [ImageFilepath, ImageFilepath | None, ImageFilepath | Any] and field_value is not None:
      with open(field_value, "rb") as f:
          base64_image_str = base64.b64encode(f.read()).decode("utf-8")
      image_content.append(
@@ -891,6 +895,8 @@ class PromptFactory:

  elif field_type.annotation in [list[ImageFilepath], list[ImageFilepath] | None, list[ImageFilepath] | Any]:
      for image_filepath in field_value:
+         if image_filepath is None:
+             continue
          with open(image_filepath, "rb") as f:
              base64_image_str = base64.b64encode(f.read()).decode("utf-8")
          image_content.append(
@@ -898,21 +904,25 @@ class PromptFactory:
      )

  # image url (or list of image urls)
- elif field_type.annotation in [ImageURL, ImageURL | None, ImageURL | Any]:
+ elif field_type.annotation in [ImageURL, ImageURL | None, ImageURL | Any] and field_value is not None:
      image_content.append({"type": "image_url", "image_url": {"url": field_value}})

  elif field_type.annotation in [list[ImageURL], list[ImageURL] | None, list[ImageURL] | Any]:
      for image_url in field_value:
+         if image_url is None:
+             continue
          image_content.append({"type": "image_url", "image_url": {"url": image_url}})

  # pre-encoded images (or list of pre-encoded images)
- elif field_type.annotation in [ImageBase64, ImageBase64 | None, ImageBase64 | Any]:
+ elif field_type.annotation in [ImageBase64, ImageBase64 | None, ImageBase64 | Any] and field_value is not None:
      image_content.append(
          {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{field_value}"}}
      )

  elif field_type.annotation in [list[ImageBase64], list[ImageBase64] | None, list[ImageBase64] | Any]:
      for base64_image in field_value:
+         if base64_image is None:
+             continue
          image_content.append(
              {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}}
          )
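
All of the hunks above apply one guard pattern: a field whose annotation admits None (e.g. `ImageFilepath | None`) can legitimately hold a missing value, and 1.1.0 would pass that None straight into `open()` or the message payload. The snippet below is a minimal, self-contained sketch of the pattern, not the PromptFactory API; `build_image_content` and its argument shape are hypothetical:

    import base64
    from typing import Any

    def build_image_content(field_value: Any) -> list[dict]:
        """Build image message parts, tolerating a None scalar and None list entries."""
        content: list[dict] = []
        if field_value is None:  # scalar guard, mirrors the added `and field_value is not None`
            return content
        paths = field_value if isinstance(field_value, list) else [field_value]
        for path in paths:
            if path is None:  # per-entry guard, mirrors the added `continue`
                continue
            with open(path, "rb") as f:
                b64 = base64.b64encode(f.read()).decode("utf-8")
            content.append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{b64}"}})
        return content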
@@ -91,6 +91,7 @@ class SentinelExecutionStrategy(BaseExecutionStrategy, ABC):
  use_final_op_quality: bool = False,
  seed: int = 42,
  exp_name: str | None = None,
+ dont_use_priors: bool = False,
  *args,
  **kwargs,
  ):
@@ -105,6 +106,7 @@ class SentinelExecutionStrategy(BaseExecutionStrategy, ABC):
  self.seed = seed
  self.rng = np.random.default_rng(seed=seed)
  self.exp_name = exp_name
+ self.dont_use_priors = dont_use_priors

  # general cache which maps hash(logical_op_id, phys_op_id, hash(input)) --> record_set
  self.cache: dict[int, DataRecordSet] = {}
@@ -44,6 +44,7 @@ class OpFrontier:
  seed: int,
  policy: Policy,
  priors: dict | None = None,
+ dont_use_priors: bool = False,
  ):
      # set k and j, which are the initial number of operators in the frontier and the
      # initial number of records to sample for each frontier operator
@@ -51,6 +52,7 @@ class OpFrontier:
  self.j = j
  self.source_indices = source_indices
  self.root_dataset_ids = root_dataset_ids
+ self.dont_use_priors = dont_use_priors

  # store the policy that we are optimizing under
  self.policy = policy
@@ -68,6 +70,7 @@ class OpFrontier:
  is_llm_filter = isinstance(sample_op, LLMFilter)
  is_llm_topk = isinstance(sample_op, TopKOp) and isinstance(sample_op.index, Collection)
  self.is_llm_op = is_llm_convert or is_llm_filter or is_llm_topk or self.is_llm_join
+ self.is_llm_convert = is_llm_convert

  # get order in which we will sample physical operators for this logical operator
  sample_op_indices = self._get_op_index_order(op_set, seed)
@@ -190,7 +193,9 @@ class OpFrontier:
  Returns a list of indices for the operators in the op_set.
  """
  # if this is not an llm-operator, we simply return the indices in random order
- if not self.is_llm_op:
+ if not self.is_llm_op or self.dont_use_priors:
+     if self.is_llm_convert:
+         print("Using NO PRIORS for operator sampling order")
      rng = np.random.default_rng(seed=seed)
      op_indices = np.arange(len(op_set))
      rng.shuffle(op_indices)
@@ -198,6 +203,8 @@ class OpFrontier:

  # if this is an llm-operator, but we do not have priors, we first compute naive priors
  if self.priors is None or any([op_id not in self.priors for op_id in map(lambda op: op.get_op_id(), op_set)]):
+     if self.is_llm_convert:
+         print("Using NAIVE PRIORS for operator sampling order")
      self.priors = self._compute_naive_priors(op_set)

  # NOTE: self.priors is a dictionary with format:
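
Together these hunks give `_get_op_index_order` three paths: purely random order when the operator is not LLM-backed or `dont_use_priors` is set, naive priors when usable priors were not supplied, and the caller's priors otherwise. A condensed sketch of the branch, written as a hypothetical free function rather than the real OpFrontier method:

    import numpy as np

    def op_index_order(num_ops: int, seed: int, is_llm_op: bool,
                       dont_use_priors: bool, priors: dict | None) -> list[int]:
        if not is_llm_op or dont_use_priors:
            # priors are ignored entirely: uniformly random sampling order
            rng = np.random.default_rng(seed=seed)
            op_indices = np.arange(num_ops)
            rng.shuffle(op_indices)
            return op_indices.tolist()
        # prior-weighted ordering (naive or supplied) elided in this sketch
        return list(range(num_ops))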
@@ -805,7 +812,7 @@ class MABExecutionStrategy(SentinelExecutionStrategy):
      assert len(root_dataset_ids) == 1, f"Scan for {sample_op} has {len(root_dataset_ids)} > 1 root dataset ids"
      root_dataset_id = root_dataset_ids[0]
      source_indices = dataset_id_to_shuffled_source_indices[root_dataset_id]
-     op_frontiers[unique_logical_op_id] = OpFrontier(op_set, source_unique_logical_op_ids, root_dataset_ids, source_indices, self.k, self.j, self.seed, self.policy, self.priors)
+     op_frontiers[unique_logical_op_id] = OpFrontier(op_set, source_unique_logical_op_ids, root_dataset_ids, source_indices, self.k, self.j, self.seed, self.policy, self.priors, self.dont_use_priors)
  elif isinstance(sample_op, JoinOp):
      assert len(source_unique_logical_op_ids) == 2, f"Join for {sample_op} has {len(source_unique_logical_op_ids)} != 2 source logical operators"
      left_source_indices = op_frontiers[source_unique_logical_op_ids[0]].source_indices
@@ -814,10 +821,10 @@ class MABExecutionStrategy(SentinelExecutionStrategy):
      for left_source_idx in left_source_indices:
          for right_source_idx in right_source_indices:
              source_indices.append((left_source_idx, right_source_idx))
-     op_frontiers[unique_logical_op_id] = OpFrontier(op_set, source_unique_logical_op_ids, root_dataset_ids, source_indices, self.k, self.j, self.seed, self.policy, self.priors)
+     op_frontiers[unique_logical_op_id] = OpFrontier(op_set, source_unique_logical_op_ids, root_dataset_ids, source_indices, self.k, self.j, self.seed, self.policy, self.priors, self.dont_use_priors)
  else:
      source_indices = op_frontiers[source_unique_logical_op_ids[0]].source_indices
-     op_frontiers[unique_logical_op_id] = OpFrontier(op_set, source_unique_logical_op_ids, root_dataset_ids, source_indices, self.k, self.j, self.seed, self.policy, self.priors)
+     op_frontiers[unique_logical_op_id] = OpFrontier(op_set, source_unique_logical_op_ids, root_dataset_ids, source_indices, self.k, self.j, self.seed, self.policy, self.priors, self.dont_use_priors)

  # initialize and start the progress manager
  self.progress_manager = create_progress_manager(plan, sample_budget=self.sample_budget, sample_cost_budget=self.sample_cost_budget, progress=self.progress)
@@ -27,6 +27,25 @@ from palimpzest.query.generators.generators import Generator
  from palimpzest.query.operators.physical import PhysicalOperator


+ class Singleton:
+     def __new__(cls, *args, **kw):
+         if not hasattr(cls, '_instance'):
+             orig = super(Singleton, cls) # noqa: UP008
+             cls._instance = orig.__new__(cls, *args, **kw)
+         return cls._instance
+
+ class Locks(Singleton):
+     model = None
+     clip_lock = threading.Lock()
+     exec_lock = threading.Lock()
+
+     @classmethod
+     def get_model(cls, model_name: str):
+         with cls.clip_lock:
+             if cls.model is None:
+                 cls.model = SentenceTransformer(model_name)
+             return cls.model
+
  def compute_similarity(left_embedding: list[float], right_embedding: list[float]) -> float:
      """
      Compute the similarity between two embeddings using cosine similarity.
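
Because `Locks` inherits from `Singleton`, every `Locks()` call in the process returns the same object, so all EmbeddingJoin instances share one `clip_lock`, one `exec_lock`, and one lazily loaded model. A self-contained illustration of that behavior, with `DummyModel` standing in for SentenceTransformer so the snippet runs without downloading weights:

    import threading

    class Singleton:
        def __new__(cls, *args, **kw):
            if not hasattr(cls, "_instance"):
                cls._instance = super().__new__(cls)
            return cls._instance

    class DummyModel:
        pass

    class Locks(Singleton):
        model = None
        clip_lock = threading.Lock()

        @classmethod
        def get_model(cls):
            with cls.clip_lock:
                if cls.model is None:
                    cls.model = DummyModel()  # constructed once, shared by every caller
                return cls.model

    assert Locks() is Locks()                      # one instance per process
    assert Locks.get_model() is Locks.get_model()  # model loaded exactly once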
@@ -487,8 +506,7 @@ class EmbeddingJoin(LLMJoin):
      if field_name.split(".")[-1] in self.get_input_fields()
  ])
  self.embedding_model = Model.TEXT_EMBEDDING_3_SMALL if self.text_only else Model.CLIP_VIT_B_32
- self.clip_model = None
- self._lock = threading.Lock()
+ self.locks = Locks()

  # keep track of embedding costs that could not be amortized if no output records were produced
  self.residual_embedding_cost = 0.0
@@ -560,12 +578,6 @@ class EmbeddingJoin(LLMJoin):
          quality=quality,
      )

- def _get_clip_model(self):
-     with self._lock:
-         if self.clip_model is None:
-             self.clip_model = SentenceTransformer(self.embedding_model.value)
-         return self.clip_model
-
  def _compute_embeddings(self, candidates: list[DataRecord], input_fields: list[str]) -> tuple[np.ndarray, GenerationStats]:
      # return empty array and empty stats if no candidates
      if len(candidates) == 0:
@@ -581,7 +593,7 @@ class EmbeddingJoin(LLMJoin):
          total_input_tokens = response.usage.total_tokens
          embeddings = np.array([item.embedding for item in response.data])
      else:
-         model = self._get_clip_model()
+         model = self.locks.get_model(self.embedding_model.value)
          embeddings = np.zeros((len(candidates), 512)) # CLIP embeddings are 512-dimensional
          num_input_fields_present = 0
          for field in input_fields:
@@ -623,7 +635,7 @@ class EmbeddingJoin(LLMJoin):
      output_record, output_record_op_stats = super()._process_join_candidate_pair(left_candidate, right_candidate, gen_kwargs)
      return output_record, output_record_op_stats, embedding_sim

- def _process_join_candidate_with_sim(self, left_candidate: DataRecord, right_candidate: DataRecord, passed_operator: bool) -> tuple[DataRecord, RecordOpStats]:
+ def _process_join_candidate_with_sim(self, left_candidate: DataRecord, right_candidate: DataRecord, embedding_sim: float, passed_operator: bool) -> tuple[DataRecord, RecordOpStats, float]:
      # compute output record and add to output_records
      join_dr = DataRecord.from_join_parents(self.output_schema, left_candidate, right_candidate)
      join_dr._passed_operator = passed_operator
@@ -656,7 +668,7 @@ class EmbeddingJoin(LLMJoin):
          op_details={k: str(v) for k, v in self.get_id_params().items()},
      )

-     return join_dr, record_op_stats
+     return join_dr, record_op_stats, embedding_sim

  def __call__(self, left_candidates: list[DataRecord], right_candidates: list[DataRecord], final: bool = False) -> tuple[DataRecordSet, int]:
      # get the set of input fields from both records in the join
@@ -690,36 +702,50 @@ class EmbeddingJoin(LLMJoin):
  output_records, output_record_op_stats, num_inputs_processed = [], [], 0

  # draw samples until num_samples is reached
- if self.samples_drawn < self.num_samples:
-     samples_to_draw = min(self.num_samples - self.samples_drawn, len(join_candidates))
-     join_candidate_samples = join_candidates[:samples_to_draw]
-     join_candidates = join_candidates[samples_to_draw:]
-
-     # apply the generator to each pair of candidates
-     with ThreadPoolExecutor(max_workers=self.join_parallelism) as executor:
-         futures = [
-             executor.submit(self._process_join_candidate_pair, left_candidate, right_candidate, gen_kwargs, embedding_sim)
-             for left_candidate, right_candidate, embedding_sim in join_candidate_samples
-         ]
-
-         # collect results as they complete
-         for future in as_completed(futures):
-             self.join_idx += 1
-             join_output_record, join_output_record_op_stats, embedding_sim = future.result()
-             output_records.append(join_output_record)
-             output_record_op_stats.append(join_output_record_op_stats)
-             print(f"{self.join_idx} JOINED")
-
-             # update similarity thresholds
-             records_joined = join_output_record._passed_operator
-             if not records_joined and embedding_sim > self.max_non_matching_sim:
-                 self.max_non_matching_sim = embedding_sim
-             if records_joined and embedding_sim < self.min_matching_sim:
-                 self.min_matching_sim = embedding_sim
-
-     # update samples drawn and num_inputs_processed
-     self.samples_drawn += samples_to_draw
-     num_inputs_processed += samples_to_draw
+ with self.locks.exec_lock:
+     if self.samples_drawn < self.num_samples:
+         samples_to_draw = min(self.num_samples - self.samples_drawn, len(join_candidates))
+         join_candidate_samples = join_candidates[:samples_to_draw]
+         join_candidates = join_candidates[samples_to_draw:]
+
+         # apply the generator to each pair of candidates
+         with ThreadPoolExecutor(max_workers=self.join_parallelism) as executor:
+             futures = [
+                 executor.submit(self._process_join_candidate_pair, left_candidate, right_candidate, gen_kwargs, embedding_sim)
+                 for left_candidate, right_candidate, embedding_sim in join_candidate_samples
+             ]
+
+             # collect results as they complete
+             similarities, joined = [], []
+             for future in as_completed(futures):
+                 self.join_idx += 1
+                 join_output_record, join_output_record_op_stats, embedding_sim = future.result()
+                 output_records.append(join_output_record)
+                 output_record_op_stats.append(join_output_record_op_stats)
+                 similarities.append(embedding_sim)
+                 joined.append(join_output_record._passed_operator)
+                 print(f"{self.join_idx} JOINED")
+
+         # sort join results by embedding similarity
+         sorted_sim_join_tuples = sorted(zip(similarities, joined), key=lambda x: x[0])
+
+         # compute threshold below which no records joined
+         for embedding_sim, records_joined in sorted_sim_join_tuples:
+             if records_joined:
+                 break
+             if not records_joined and embedding_sim > self.max_non_matching_sim:
+                 self.max_non_matching_sim = embedding_sim
+
+         # compute threshold above which all records joined
+         for embedding_sim, records_joined in reversed(sorted_sim_join_tuples):
+             if not records_joined:
+                 break
+             if records_joined and embedding_sim < self.min_matching_sim:
+                 self.min_matching_sim = embedding_sim
+
+         # update samples drawn and num_inputs_processed
+         self.samples_drawn += samples_to_draw
+         num_inputs_processed += samples_to_draw

  # process remaining candidates based on embedding similarity
  if len(join_candidates) > 0:
@@ -727,43 +753,48 @@ class EmbeddingJoin(LLMJoin):
      with ThreadPoolExecutor(max_workers=self.join_parallelism) as executor:
          futures = []
          for left_candidate, right_candidate, embedding_sim in join_candidates:
-             llm_call_needed = (
-                 self.min_matching_sim == float("inf")
-                 or self.max_non_matching_sim == float("-inf")
-                 or self.min_matching_sim <= embedding_sim <= self.max_non_matching_sim
-             )
+             # if the embedding similarity is lower than the threshold below which no records joined,
+             # then we can skip the LLM call and mark the records as not joined
+             if embedding_sim < self.max_non_matching_sim:
+                 futures.append(executor.submit(self._process_join_candidate_with_sim, left_candidate, right_candidate, embedding_sim, passed_operator=False))

-             if llm_call_needed:
-                 futures.append(executor.submit(self._process_join_candidate_pair, left_candidate, right_candidate, gen_kwargs, embedding_sim))
+             # if the embedding similarity is higher than the threshold above which all records joined,
+             # then we can skip the LLM call and mark the records as joined
+             elif embedding_sim > self.min_matching_sim:
+                 futures.append(executor.submit(self._process_join_candidate_with_sim, left_candidate, right_candidate, embedding_sim, passed_operator=True))

-             elif embedding_sim < self.min_matching_sim:
-                 self.join_idx += 1
-                 output_record, record_op_stats = self._process_join_candidate_with_sim(left_candidate, right_candidate, passed_operator=False)
-                 output_records.append(output_record)
-                 output_record_op_stats.append(record_op_stats)
-                 print(f"{self.join_idx} SKIPPED (low sim: {embedding_sim:.4f} < {self.min_matching_sim:.4f})")
-
-             elif embedding_sim > self.max_non_matching_sim:
-                 self.join_idx += 1
-                 output_record, record_op_stats = self._process_join_candidate_with_sim(left_candidate, right_candidate, passed_operator=True)
-                 output_records.append(output_record)
-                 output_record_op_stats.append(record_op_stats)
-                 print(f"{self.join_idx} JOINED (high sim: {embedding_sim:.4f} > {self.max_non_matching_sim:.4f})")
+             # otherwise, we will process the LLM call
+             else:
+                 futures.append(executor.submit(self._process_join_candidate_pair, left_candidate, right_candidate, gen_kwargs, embedding_sim))

              num_inputs_processed += 1

          # collect results as they complete
+         similarities, joined = [], []
          for future in as_completed(futures):
              self.join_idx += 1
              join_output_record, join_output_record_op_stats, embedding_sim = future.result()
              output_records.append(join_output_record)
              output_record_op_stats.append(join_output_record_op_stats)
+             similarities.append(embedding_sim)
+             joined.append(join_output_record._passed_operator)
              print(f"{self.join_idx} JOINED")

-         # update similarity thresholds
-         records_joined = join_output_record._passed_operator
+         ### update thresholds if there are llm calls which incrementally squeeze the boundaries ###
+         # sort join results by embedding similarity
+         sorted_sim_join_tuples = sorted(zip(similarities, joined), key=lambda x: x[0])
+
+         # potentially update threshold below which no records joined
+         for embedding_sim, records_joined in sorted_sim_join_tuples:
+             if records_joined:
+                 break
              if not records_joined and embedding_sim > self.max_non_matching_sim:
                  self.max_non_matching_sim = embedding_sim
+
+         # potentially update threshold above which all records joined
+         for embedding_sim, records_joined in reversed(sorted_sim_join_tuples):
+             if not records_joined:
+                 break
              if records_joined and embedding_sim < self.min_matching_sim:
                  self.min_matching_sim = embedding_sim

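The effect of the two loops is easiest to see on a small sorted example. With illustrative values (not from the package), the run of non-matches at the low end tops out at 0.35 and the run of matches at the high end bottoms out at 0.77; later candidates below 0.35 skip the LLM and are rejected, those above 0.77 skip the LLM and are accepted, and only the band in between still pays for an LLM call:

    similarities = [0.12, 0.35, 0.48, 0.61, 0.77, 0.90]
    joined       = [False, False, True, False, True, True]

    pairs = sorted(zip(similarities, joined), key=lambda x: x[0])

    # highest similarity before the first match: reject below this without an LLM call
    max_non_matching_sim = float("-inf")
    for sim, was_joined in pairs:
        if was_joined:
            break
        max_non_matching_sim = max(max_non_matching_sim, sim)

    # lowest similarity after the last non-match: accept above this without an LLM call
    min_matching_sim = float("inf")
    for sim, was_joined in reversed(pairs):
        if not was_joined:
            break
        min_matching_sim = min(min_matching_sim, sim)

    print(max_non_matching_sim, min_matching_sim)  # 0.35 0.77
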
@@ -156,7 +156,7 @@ class RAGConvert(LLMConvert):
  # skip this field if it is not a string or a list of strings
  is_string_field = field.annotation in [str, str | None, str | Any]
  is_list_string_field = field.annotation in [list[str], list[str] | None, list[str] | Any]
- if not (is_string_field or is_list_string_field):
+ if not (is_string_field or is_list_string_field) or candidate[field_name] is None:
      continue

  # if this is a list of strings, join the strings
@@ -1,6 +1,7 @@
  from __future__ import annotations

  import os
+ import threading
  import time
  from typing import Callable

@@ -17,6 +18,24 @@ from palimpzest.core.models import GenerationStats, OperatorCostEstimates, Recor
  from palimpzest.query.operators.physical import PhysicalOperator


+ class Singleton:
+     def __new__(cls, *args, **kw):
+         if not hasattr(cls, '_instance'):
+             orig = super(Singleton, cls) # noqa: UP008
+             cls._instance = orig.__new__(cls, *args, **kw)
+         return cls._instance
+
+ class ClipModel(Singleton):
+     model = None
+     lock = threading.Lock()
+
+     @classmethod
+     def get_model(cls, model_name: str):
+         with cls.lock:
+             if cls.model is None:
+                 cls.model = SentenceTransformer(model_name)
+             return cls.model
+
  class TopKOp(PhysicalOperator):
      def __init__(
          self,
@@ -56,6 +75,7 @@ class TopKOp(PhysicalOperator):
  self.output_attrs = output_attrs
  self.search_func = search_func if search_func is not None else self.default_search_func
  self.k = k
+ self.clip_model = ClipModel()

  def __str__(self):
      op = super().__str__()
@@ -185,7 +205,6 @@ class TopKOp(PhysicalOperator):
      # construct and return the record set
      return DataRecordSet(drs, record_op_stats_lst)

-
  def __call__(self, candidate: DataRecord) -> DataRecordSet:
      start_time = time.time()

@@ -209,9 +228,9 @@ class TopKOp(PhysicalOperator):
  inputs, gen_stats = None, GenerationStats()
  if isinstance(self.index, Collection):
      uses_openai_embedding_fcn = isinstance(self.index._embedding_function, OpenAIEmbeddingFunction)
-     uses_sentence_transformer_embedding_fcn = isinstance(self.index._embedding_function, SentenceTransformerEmbeddingFunction)
+     uses_clip_model = isinstance(self.index._embedding_function, SentenceTransformerEmbeddingFunction)
      error_msg = "ChromaDB index must use OpenAI or SentenceTransformer embedding function; see: https://docs.trychroma.com/integrations/embedding-models/openai"
-     assert uses_openai_embedding_fcn or uses_sentence_transformer_embedding_fcn, error_msg
+     assert uses_openai_embedding_fcn or uses_clip_model, error_msg

      model_name = self.index._embedding_function.model_name if uses_openai_embedding_fcn else "clip-ViT-B-32"
      err_msg = f"For Chromadb, we currently only support `text-embedding-3-small` and `clip-ViT-B-32`; your index uses: {model_name}"
@@ -228,8 +247,8 @@ class TopKOp(PhysicalOperator):
          total_input_tokens = response.usage.total_tokens
          inputs = [item.embedding for item in response.data]

-     elif uses_sentence_transformer_embedding_fcn:
-         model = SentenceTransformer(model_name)
+     elif uses_clip_model:
+         model = self.clip_model.get_model(model_name)
          inputs = model.encode(query)

      embed_total_time = time.time() - embed_start_time
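
Previously each `__call__` on a SentenceTransformer-backed index constructed a fresh CLIP model, reloading weights per query; the shared `ClipModel` pays that cost once per process. A runnable sketch of the difference, with `FakeTransformer` standing in for SentenceTransformer (names illustrative):

    import threading

    class FakeTransformer:
        load_count = 0
        def __init__(self, name: str):
            FakeTransformer.load_count += 1  # stand-in for the expensive weight load

    # old behavior: one load per query
    for _ in range(3):
        FakeTransformer("clip-ViT-B-32")
    assert FakeTransformer.load_count == 3

    # new behavior: a single guarded load, shared by concurrent callers
    _model, _lock = None, threading.Lock()

    def get_model(name: str) -> FakeTransformer:
        global _model
        with _lock:
            if _model is None:
                _model = FakeTransformer(name)
            return _model

    assert get_model("clip-ViT-B-32") is get_model("clip-ViT-B-32")
    assert FakeTransformer.load_count == 4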
@@ -48,6 +48,7 @@ class QueryProcessorConfig(BaseModel):
  seed: int = Field(default=42)
  exp_name: str | None = Field(default=None)
  priors: dict | None = Field(default=None)
+ dont_use_priors: bool = Field(default=False)

  def to_dict(self) -> dict:
      """Convert the config to a dict representation."""
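
Exposing the flag on QueryProcessorConfig makes the opt-out reachable from user code. A hedged usage sketch, assuming the import path implied by the RECORD entry for `palimpzest/query/processor/config.py`; other config fields are left at their defaults:

    from palimpzest.query.processor.config import QueryProcessorConfig

    config = QueryProcessorConfig(seed=42, dont_use_priors=True)
    # downstream, SentinelExecutionStrategy forwards this to each OpFrontier,
    # which then shuffles physical operators randomly instead of using priors
    print(config.dont_use_priors)  # True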
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: palimpzest
- Version: 1.1.0
+ Version: 1.1.1
  Summary: Palimpzest is a system which enables anyone to process AI-powered analytical queries simply by defining them in a declarative language
  Author-email: MIT DSG Semantic Management Lab <michjc@csail.mit.edu>
  Project-URL: homepage, https://palimpzest.org
@@ -28,7 +28,7 @@ palimpzest/prompts/filter_prompts.py,sha256=D-aY3-th1GzEHrVGbKORVN2R7x7coYGjp8Fr
  palimpzest/prompts/join_prompts.py,sha256=z-y4L1cw1O3I_F9DW6MvqeztdQoKDQawX6nK6vQAkdM,2916
  palimpzest/prompts/moa_aggregator_prompts.py,sha256=b5cz4G2oF86LlHOy8vmtxoMcZ9zaZoppKrURHgzCzNU,5248
  palimpzest/prompts/moa_proposer_prompts.py,sha256=yfZYwmCg-Tg9h0H7PJMEuDYPR45EbYnORmVX6cY2vRQ,3125
- palimpzest/prompts/prompt_factory.py,sha256=0xj3glD5Y7R7MUsmKxJCOa4q9VeIILDO2IVWz_4huYw,49355
+ palimpzest/prompts/prompt_factory.py,sha256=txtCvDI0sv_LEar0iK_E1_mlRMvuwZseM-6BSC9ugUs,49926
  palimpzest/prompts/split_merge_prompts.py,sha256=hX-MThmW4VU7rjgm7gb-bpniEMdj25mtp0o8qBeWvIQ,5573
  palimpzest/prompts/split_proposer_prompts.py,sha256=Ucqwfn4FqFk-b9E024EK4e_3_QndTJjggwiwa1x5CQs,3115
  palimpzest/prompts/utils.py,sha256=Eure2pqm8Ftme9lQlHwFL9EqK3yjH14WQHofnQINce4,7497
@@ -36,9 +36,9 @@ palimpzest/prompts/validator.py,sha256=OxebGjvXNBy0Cq79XI3aPRbongzOdtHH6mQctpbWc
  palimpzest/query/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  palimpzest/query/execution/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  palimpzest/query/execution/all_sample_execution_strategy.py,sha256=d2MO_AmXF_HbV4rUNkFqHsuoBCofU98zQ-D3Q06BXCc,14369
- palimpzest/query/execution/execution_strategy.py,sha256=PkXtO8wZpv6HHLlpxXwcc0t5pPCGafp0L_iVkG9bmXM,19162
+ palimpzest/query/execution/execution_strategy.py,sha256=TnSInUlcGZHn2GUpLiIFVgfPpmsNfIKKgElnRt6a6ss,19248
  palimpzest/query/execution/execution_strategy_type.py,sha256=vRQBPCQN5_aoyD3TLIeW3VPo15mqF-5RBvEXkENz9FE,987
- palimpzest/query/execution/mab_execution_strategy.py,sha256=i03LYRhaG2VLia-XSiYbKdlu3hLQZul75xMcRGm065M,47767
+ palimpzest/query/execution/mab_execution_strategy.py,sha256=eKWEEFjOSpkaqkwuoFbg5yA5jdkki4FfRtCD7P4VaeI,48205
  palimpzest/query/execution/parallel_execution_strategy.py,sha256=Di-8d7waE0bev4kNDXEJJqQ0wwQ87_sPV-t5qFtAlPQ,17589
  palimpzest/query/execution/single_threaded_execution_strategy.py,sha256=1rjMel0-AI6KUi_SMNgPPXxMgG5-t9lenLKoYEClgjk,17464
  palimpzest/query/generators/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -50,17 +50,17 @@ palimpzest/query/operators/convert.py,sha256=cjUPrSgvZBZXBbrbepIxZMBXjbWWPLuTX4J
  palimpzest/query/operators/critique_and_refine.py,sha256=Q-NhasVoD9meX7g36RPrv3q4R48_8XEU4d3TE46hRJI,8979
  palimpzest/query/operators/distinct.py,sha256=ZTXlIS7IaFRTsWv9RemzCo1JLz25vEma-TB42CV5fJQ,2614
  palimpzest/query/operators/filter.py,sha256=ufREsO2-8CBk4u4fabDBYpEvb806E11EOyW-wuRs4vw,10356
- palimpzest/query/operators/join.py,sha256=17BGzrxf_fkqhnEzhq-5b0qv2qQTw7z6job5YkBUrZE,36993
+ palimpzest/query/operators/join.py,sha256=CPqqIQrxWNjslWkxnZrbW-DKFkJ8iiDsWLJeZmUs5gI,38549
  palimpzest/query/operators/limit.py,sha256=pdo7WfWY97SW3c-WqZ4SIPw7lHIVbaXPEWqHyK8qkF8,2130
  palimpzest/query/operators/logical.py,sha256=OtB82L1X19ibtLx1GIfeXXyO7YfjkFmh3puIUgqKQRE,21160
  palimpzest/query/operators/mixture_of_agents.py,sha256=KC-ZpjtGY28sfwlk2TpduLC_fADj_UASFCaicaKqSFc,11671
  palimpzest/query/operators/physical.py,sha256=0_BfFX9nKuN__440eAfEfApWAoGOunVSCZIQxS4HO2Y,9773
  palimpzest/query/operators/project.py,sha256=gxbjsHEACCP9wxATH-mw6wOSUi5s13JyWsLqqhAYWXQ,2111
- palimpzest/query/operators/rag.py,sha256=CJm83pBapA8HEGfhRnWjqt_ESS6hJAPvPJksRTOGL7M,20124
+ palimpzest/query/operators/rag.py,sha256=jJ09uCkjCr5KuiFMPhWcZAQNBR3_5eNb0ALkn_IoXbU,20157
  palimpzest/query/operators/scan.py,sha256=OqCiPRTvTY7SbauNMyFvGT5nRVeRzVsGYSrkoN1Ib_w,7407
  palimpzest/query/operators/search.py,sha256=cQin-Qc9FT7V0Gv3-pxMLbVMjqE6ALe99V0OrQhA6CI,22711
  palimpzest/query/operators/split.py,sha256=oLzwnYb8TNf3XA9TMKEAIw7EIA12wHneaD42BNLIHiI,15043
- palimpzest/query/operators/topk.py,sha256=92Bu98xc8CMlS9bf1xc0FxcfVuhv6j4x_303Aq1v-U0,13053
+ palimpzest/query/operators/topk.py,sha256=MZl83Cu43QmN4skjlfpR8EVFFCgA7sR6PbGgBGWC0tg,13564
  palimpzest/query/optimizer/__init__.py,sha256=v9fSBOL2p3sQew4LrN2DQUPe0WezO328Hr54qBTqrAs,2799
  palimpzest/query/optimizer/cost_model.py,sha256=p7AsR6f4VYdGjrUKPGN_VTErY36GjY90Bsvsys4le2M,12655
  palimpzest/query/optimizer/optimizer.py,sha256=ksLkzQ2sVgJFbkxGF3ncF74EsAHZFos8G19xlHQrtJo,20063
@@ -71,7 +71,7 @@ palimpzest/query/optimizer/primitives.py,sha256=jMMVq37y1tWiPU1lSSKQP9OP-mzkpSxS
  palimpzest/query/optimizer/rules.py,sha256=awhe76trskv5Tq5E2QHpUN_YV6jH8INywa0Ige8IIhY,53341
  palimpzest/query/optimizer/tasks.py,sha256=DNJjY2QldfKFWj6INHElMh88dYc36Z5m3wHwbs4jyF4,30455
  palimpzest/query/processor/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- palimpzest/query/processor/config.py,sha256=MkZ776VUk9tIOCZdVyH__H4Z0gO4c8fpehX2Gqywvks,2472
+ palimpzest/query/processor/config.py,sha256=8-MpPYHv2SI4dub4MP_gOYSRxO80_ALLuWRxD-F2YOg,2521
  palimpzest/query/processor/query_processor.py,sha256=T4ffPbnOX23G8FDITzmM7Iw7DUEDWIHnwl8XLYllgjg,6240
  palimpzest/query/processor/query_processor_factory.py,sha256=i9L9StqlUi7m1AqZMuYQWhunqOJi3nLK47skhxq9tIA,8317
  palimpzest/schemabuilder/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -89,8 +89,8 @@ palimpzest/utils/progress.py,sha256=eHXrTPTCRHjMdK0EjYRUzSxcV6N1lK8TS3Ju_ZlQLhY,
  palimpzest/utils/udfs.py,sha256=LjHic54B1az-rKgNLur0wOpaz2ko_UodjLEJrazkxvY,1854
  palimpzest/validator/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  palimpzest/validator/validator.py,sha256=SvjK09zCpGtK0yM0OasvQlSzyq3loy32DyOOKRmYXC0,15977
- palimpzest-1.1.0.dist-info/licenses/LICENSE,sha256=5GUlHy9lr-Py9kvV38FF1m3yy3NqM18fefuE9wkWumo,1079
- palimpzest-1.1.0.dist-info/METADATA,sha256=0AZq33WMFrxkarQADVPv2OFQu7ko38fzhBOtTQjc3Fw,5359
- palimpzest-1.1.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- palimpzest-1.1.0.dist-info/top_level.txt,sha256=raV06dJUgohefUn3ZyJS2uqp_Y76EOLA9Y2e_fxt8Ew,11
- palimpzest-1.1.0.dist-info/RECORD,,
+ palimpzest-1.1.1.dist-info/licenses/LICENSE,sha256=5GUlHy9lr-Py9kvV38FF1m3yy3NqM18fefuE9wkWumo,1079
+ palimpzest-1.1.1.dist-info/METADATA,sha256=lZssP9vUbcnMNl_6vwpDvPyiI1IH65nDwggeSndSOcw,5359
+ palimpzest-1.1.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ palimpzest-1.1.1.dist-info/top_level.txt,sha256=raV06dJUgohefUn3ZyJS2uqp_Y76EOLA9Y2e_fxt8Ew,11
+ palimpzest-1.1.1.dist-info/RECORD,,