palimpzest 1.1.0__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
palimpzest/core/models.py CHANGED
@@ -35,12 +35,18 @@ class GenerationStats(BaseModel):
35
35
  # typed as a float because GenerationStats may be amortized (i.e. divided) across a number of output records
36
36
  total_output_tokens: float = 0.0
37
37
 
38
+ # the total number of input tokens processed by embedding models
39
+ total_embedding_input_tokens: float = 0.0
40
+
38
41
  # the total cost of processing the input tokens; None if this operation did not use an LLM
39
42
  total_input_cost: float = 0.0
40
43
 
41
44
  # the total cost of processing the output tokens; None if this operation did not use an LLM
42
45
  total_output_cost: float = 0.0
43
46
 
47
+ # the total cost of processing input tokens for embedding models
48
+ total_embedding_cost: float = 0.0
49
+
44
50
  # the total cost of processing the input and output tokens; None if this operation did not use an LLM
45
51
  cost_per_record: float = 0.0
46
52
 
@@ -68,6 +74,9 @@ class GenerationStats(BaseModel):
68
74
  "fn_call_duration_secs",
69
75
  "total_llm_calls",
70
76
  "total_embedding_llm_calls",
77
+ "total_embedding_input_tokens",
78
+ "total_embedding_cost"
79
+
71
80
  ]:
72
81
  setattr(self, model_field, getattr(self, model_field) + getattr(other, model_field))
73
82
  return self
@@ -85,6 +94,8 @@ class GenerationStats(BaseModel):
85
94
  "cost_per_record",
86
95
  "total_llm_calls",
87
96
  "total_embedding_llm_calls",
97
+ "total_embedding_input_tokens",
98
+ "total_embedding_cost"
88
99
  ]
89
100
  }
90
101
  # dct['raw_answers'] = self.raw_answers + other.raw_answers
@@ -107,6 +118,8 @@ class GenerationStats(BaseModel):
107
118
  "fn_call_duration_secs",
108
119
  "total_llm_calls",
109
120
  "total_embedding_llm_calls",
121
+ "total_embedding_input_tokens",
122
+ "total_embedding_cost"
110
123
  ]:
111
124
  setattr(self, model_field, getattr(self, model_field) / quotient)
112
125
  return self
@@ -128,6 +141,8 @@ class GenerationStats(BaseModel):
128
141
  "total_llm_calls",
129
142
  "total_embedding_llm_calls",
130
143
  "cost_per_record",
144
+ "total_embedding_input_tokens",
145
+ "total_embedding_cost"
131
146
  ]
132
147
  }
133
148
  dct["model_name"] = self.model_name
@@ -217,6 +232,10 @@ class RecordOpStats(BaseModel):
217
232
  # typed as a float because GenerationStats may be amortized (i.e. divided) across a number of output records
218
233
  total_output_tokens: float = 0.0
219
234
 
235
+ # the total number of input tokens processed by embedding models
236
+ # typed as a float because GenerationStats may be amortized (i.e. divided) across a number of output records
237
+ total_embedding_input_tokens: float = 0.0
238
+
220
239
  # the total cost of processing the input tokens; None if this operation did not use an LLM
221
240
  total_input_cost: float = 0.0
222
241
 
@@ -278,6 +297,9 @@ class OperatorStats(BaseModel):
278
297
  # the total output tokens processed by this operation
279
298
  total_output_tokens: int = 0
280
299
 
300
+ # the total embedding input tokens processed by this operation
301
+ total_embedding_input_tokens: int = 0
302
+
281
303
  # a list of RecordOpStats processed by the operation
282
304
  record_op_stats_lst: list[RecordOpStats] = Field(default_factory=list)
283
305
 
@@ -309,6 +331,7 @@ class OperatorStats(BaseModel):
309
331
  self.total_op_cost += stats.total_op_cost
310
332
  self.total_input_tokens += stats.total_input_tokens
311
333
  self.total_output_tokens += stats.total_output_tokens
334
+ self.total_embedding_input_tokens += stats.total_embedding_input_tokens
312
335
  self.record_op_stats_lst.extend(stats.record_op_stats_lst)
313
336
 
314
337
  elif isinstance(stats, RecordOpStats):
@@ -319,6 +342,7 @@ class OperatorStats(BaseModel):
319
342
  self.total_op_cost += stats.cost_per_record
320
343
  self.total_input_tokens += stats.total_input_tokens
321
344
  self.total_output_tokens += stats.total_output_tokens
345
+ self.total_embedding_input_tokens += stats.total_embedding_input_tokens
322
346
 
323
347
  else:
324
348
  raise TypeError(f"Cannot add {type(stats)} to OperatorStats")
@@ -370,6 +394,9 @@ class BasePlanStats(BaseModel):
370
394
  # total output tokens processed by this plan
371
395
  total_output_tokens: int = 0
372
396
 
397
+ # total embedding input tokens processed by this plan
398
+ total_embedding_input_tokens: int = 0
399
+
373
400
  # start time for the plan execution; should be set by calling PlanStats.start()
374
401
  start_time: float | None = None
375
402
 
@@ -385,6 +412,7 @@ class BasePlanStats(BaseModel):
385
412
  self.total_plan_cost = self.sum_op_costs() + self.sum_validation_costs()
386
413
  self.total_input_tokens = self.sum_input_tokens() + self.sum_validation_input_tokens()
387
414
  self.total_output_tokens = self.sum_output_tokens() + self.sum_validation_output_tokens()
415
+ self.total_embedding_input_tokens = self.sum_embedding_input_tokens() + self.sum_validation_embedding_input_tokens()
388
416
 
389
417
  @staticmethod
390
418
  @abstractmethod
@@ -415,6 +443,13 @@ class BasePlanStats(BaseModel):
415
443
  """
416
444
  pass
417
445
 
446
+ @abstractmethod
447
+ def sum_embedding_input_tokens(self) -> int:
448
+ """
449
+ Sum the input embedding tokens processed by all operators in this plan.
450
+ """
451
+ pass
452
+
418
453
  @abstractmethod
419
454
  def add_record_op_stats(self, unique_full_op_id: str, record_op_stats: RecordOpStats | list[RecordOpStats]) -> None:
420
455
  """
@@ -453,6 +488,12 @@ class BasePlanStats(BaseModel):
453
488
  Sum the output tokens processed by all validation generations in this plan.
454
489
  """
455
490
  return sum([gen_stats.total_output_tokens for _, gen_stats in self.validation_gen_stats.items()])
491
+
492
+ def sum_validation_embedding_input_tokens(self) -> int:
493
+ """
494
+ Sum the input embedding tokens processed by all validation generations in this plan.
495
+ """
496
+ return sum([gen_stats.total_embedding_input_tokens for _, gen_stats in self.validation_gen_stats.items()])
456
497
 
457
498
  def get_total_cost_so_far(self) -> float:
458
499
  """
@@ -501,6 +542,12 @@ class PlanStats(BasePlanStats):
501
542
  Sum the output tokens processed by all operators in this plan.
502
543
  """
503
544
  return sum([op_stats.total_output_tokens for _, op_stats in self.operator_stats.items()])
545
+
546
+ def sum_embedding_input_tokens(self) -> int:
547
+ """
548
+ Sum the input embedding tokens processed by all operators in this plan.
549
+ """
550
+ return sum([op_stats.total_embedding_input_tokens for _, op_stats in self.operator_stats.items()])
504
551
 
505
552
  def add_record_op_stats(self, unique_full_op_id: str, record_op_stats: RecordOpStats | list[RecordOpStats]) -> None:
506
553
  """
@@ -528,6 +575,7 @@ class PlanStats(BasePlanStats):
528
575
  self.total_plan_cost += plan_stats.total_plan_cost
529
576
  self.total_input_tokens += plan_stats.total_input_tokens
530
577
  self.total_output_tokens += plan_stats.total_output_tokens
578
+ self.total_embedding_input_tokens += plan_stats.total_embedding_input_tokens
531
579
  for unique_full_op_id, op_stats in plan_stats.operator_stats.items():
532
580
  if unique_full_op_id in self.operator_stats:
533
581
  self.operator_stats[unique_full_op_id] += op_stats
@@ -539,6 +587,7 @@ class PlanStats(BasePlanStats):
539
587
  stats += f"total_plan_cost={self.total_plan_cost} \n"
540
588
  stats += f"total_input_tokens={self.total_input_tokens} \n"
541
589
  stats += f"total_output_tokens={self.total_output_tokens} \n"
590
+ stats += f"total_embedding_input_tokens={self.total_embedding_input_tokens} \n"
542
591
  for idx, op_stats in enumerate(self.operator_stats.values()):
543
592
  stats += f"{idx}. {op_stats.op_name} time={op_stats.total_op_time} cost={op_stats.total_op_cost} \n"
544
593
  return stats
@@ -586,6 +635,12 @@ class SentinelPlanStats(BasePlanStats):
586
635
  Sum the output tokens processed by all operators in this plan.
587
636
  """
588
637
  return sum(sum([op_stats.total_output_tokens for _, op_stats in phys_op_stats.items()]) for _, phys_op_stats in self.operator_stats.items())
638
+
639
+ def sum_embedding_input_tokens(self) -> int:
640
+ """
641
+ Sum the embedding input tokens processed by all operators in this plan.
642
+ """
643
+ return sum(sum([op_stats.total_embedding_input_tokens for _, op_stats in phys_op_stats.items()]) for _, phys_op_stats in self.operator_stats.items())
589
644
 
590
645
  def add_record_op_stats(self, unique_logical_op_id: str, record_op_stats: RecordOpStats | list[RecordOpStats]) -> None:
591
646
  """
@@ -627,6 +682,7 @@ class SentinelPlanStats(BasePlanStats):
627
682
  self.total_plan_cost += plan_stats.total_plan_cost
628
683
  self.total_input_tokens += plan_stats.total_input_tokens
629
684
  self.total_output_tokens += plan_stats.total_output_tokens
685
+ self.total_embedding_input_tokens += plan_stats.total_embedding_input_tokens
630
686
  for unique_logical_op_id, physical_op_stats in plan_stats.operator_stats.items():
631
687
  for full_op_id, op_stats in physical_op_stats.items():
632
688
  if unique_logical_op_id in self.operator_stats:
@@ -648,6 +704,7 @@ class SentinelPlanStats(BasePlanStats):
648
704
  stats += f"total_plan_cost={self.total_plan_cost} \n"
649
705
  stats += f"total_input_tokens={self.total_input_tokens} \n"
650
706
  stats += f"total_output_tokens={self.total_output_tokens} \n"
707
+ stats += f"total_embedding_input_tokens={self.total_embedding_input_tokens} \n"
651
708
  for outer_idx, physical_op_stats in enumerate(self.operator_stats.values()):
652
709
  total_time = sum([op_stats.total_op_time for op_stats in physical_op_stats.values()])
653
710
  total_cost = sum([op_stats.total_op_cost for op_stats in physical_op_stats.values()])
@@ -695,6 +752,9 @@ class ExecutionStats(BaseModel):
695
752
  # total number of output tokens processed
696
753
  total_output_tokens: int = 0
697
754
 
755
+ # total number of embedding input tokens processed
756
+ total_embedding_input_tokens: int = 0
757
+
698
758
  # total number of tokens processed
699
759
  total_tokens: int = 0
700
760
 
@@ -748,7 +808,8 @@ class ExecutionStats(BaseModel):
748
808
  # compute the tokens for total execution
749
809
  self.total_input_tokens = self.sum_input_tokens()
750
810
  self.total_output_tokens = self.sum_output_tokens()
751
- self.total_tokens = self.total_input_tokens + self.total_output_tokens
811
+ self.total_embedding_input_tokens = self.sum_embedding_input_tokens()
812
+ self.total_tokens = self.total_input_tokens + self.total_output_tokens + self.total_embedding_input_tokens
752
813
 
753
814
  # compute plan_strs
754
815
  self.plan_strs = {plan_id: plan_stats.plan_str for plan_id, plan_stats in self.plan_stats.items()}
@@ -780,6 +841,15 @@ class ExecutionStats(BaseModel):
780
841
  sentinel_plan_output_tokens = sum([plan_stats.sum_output_tokens() for _, plan_stats in self.sentinel_plan_stats.items()])
781
842
  plan_output_tokens = sum([plan_stats.sum_output_tokens() for _, plan_stats in self.plan_stats.items()])
782
843
  return plan_output_tokens + sentinel_plan_output_tokens
844
+
845
+
846
+ def sum_embedding_input_tokens(self) -> int:
847
+ """
848
+ Sum the embedding input tokens processed in this execution.
849
+ """
850
+ sentinel_plan_embedding_input_tokens = sum([plan_stats.sum_embedding_input_tokens() for _, plan_stats in self.sentinel_plan_stats.items()])
851
+ plan_embedding_input_tokens = sum([plan_stats.sum_embedding_input_tokens() for _, plan_stats in self.plan_stats.items()])
852
+ return plan_embedding_input_tokens + sentinel_plan_embedding_input_tokens
783
853
 
784
854
  def add_plan_stats(self, plan_stats: PlanStats | SentinelPlanStats | list[PlanStats] | list[SentinelPlanStats]) -> None:
785
855
  """
@@ -830,7 +830,7 @@ class PromptFactory:
830
830
  field_type = dr.get_field_type(field_name)
831
831
 
832
832
  # audio filepath (or list of audio filepaths)
833
- if field_type.annotation in [AudioFilepath, AudioFilepath | None, AudioFilepath | Any]:
833
+ if field_type.annotation in [AudioFilepath, AudioFilepath | None, AudioFilepath | Any] and field_value is not None:
834
834
  with open(field_value, "rb") as f:
835
835
  base64_audio_str = base64.b64encode(f.read()).decode("utf-8")
836
836
  audio_content.append(
@@ -839,6 +839,8 @@ class PromptFactory:
839
839
 
840
840
  elif field_type.annotation in [list[AudioFilepath], list[AudioFilepath] | None, list[AudioFilepath] | Any]:
841
841
  for audio_filepath in field_value:
842
+ if audio_filepath is None:
843
+ continue
842
844
  with open(audio_filepath, "rb") as f:
843
845
  base64_audio_str = base64.b64encode(f.read()).decode("utf-8")
844
846
  audio_content.append(
@@ -846,13 +848,15 @@ class PromptFactory:
846
848
  )
847
849
 
848
850
  # pre-encoded images (or list of pre-encoded images)
849
- elif field_type.annotation in [AudioBase64, AudioBase64 | None, AudioBase64 | Any]:
851
+ elif field_type.annotation in [AudioBase64, AudioBase64 | None, AudioBase64 | Any] and field_value is not None:
850
852
  audio_content.append(
851
853
  {"type": "input_audio", "input_audio": {"data": field_value, "format": "wav"}}
852
854
  )
853
855
 
854
856
  elif field_type.annotation in [list[AudioBase64], list[AudioBase64] | None, list[AudioBase64] | Any]:
855
857
  for base64_audio in field_value:
858
+ if base64_audio is None:
859
+ continue
856
860
  audio_content.append(
857
861
  {"type": "input_audio", "input_audio": {"data": base64_audio, "format": "wav"}}
858
862
  )
@@ -882,7 +886,7 @@ class PromptFactory:
882
886
  field_type = dr.get_field_type(field_name)
883
887
 
884
888
  # image filepath (or list of image filepaths)
885
- if field_type.annotation in [ImageFilepath, ImageFilepath | None, ImageFilepath | Any]:
889
+ if field_type.annotation in [ImageFilepath, ImageFilepath | None, ImageFilepath | Any] and field_value is not None:
886
890
  with open(field_value, "rb") as f:
887
891
  base64_image_str = base64.b64encode(f.read()).decode("utf-8")
888
892
  image_content.append(
@@ -891,6 +895,8 @@ class PromptFactory:
891
895
 
892
896
  elif field_type.annotation in [list[ImageFilepath], list[ImageFilepath] | None, list[ImageFilepath] | Any]:
893
897
  for image_filepath in field_value:
898
+ if image_filepath is None:
899
+ continue
894
900
  with open(image_filepath, "rb") as f:
895
901
  base64_image_str = base64.b64encode(f.read()).decode("utf-8")
896
902
  image_content.append(
@@ -898,21 +904,25 @@ class PromptFactory:
898
904
  )
899
905
 
900
906
  # image url (or list of image urls)
901
- elif field_type.annotation in [ImageURL, ImageURL | None, ImageURL | Any]:
907
+ elif field_type.annotation in [ImageURL, ImageURL | None, ImageURL | Any] and field_value is not None:
902
908
  image_content.append({"type": "image_url", "image_url": {"url": field_value}})
903
909
 
904
910
  elif field_type.annotation in [list[ImageURL], list[ImageURL] | None, list[ImageURL] | Any]:
905
911
  for image_url in field_value:
912
+ if image_url is None:
913
+ continue
906
914
  image_content.append({"type": "image_url", "image_url": {"url": image_url}})
907
915
 
908
916
  # pre-encoded images (or list of pre-encoded images)
909
- elif field_type.annotation in [ImageBase64, ImageBase64 | None, ImageBase64 | Any]:
917
+ elif field_type.annotation in [ImageBase64, ImageBase64 | None, ImageBase64 | Any] and field_value is not None:
910
918
  image_content.append(
911
919
  {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{field_value}"}}
912
920
  )
913
921
 
914
922
  elif field_type.annotation in [list[ImageBase64], list[ImageBase64] | None, list[ImageBase64] | Any]:
915
923
  for base64_image in field_value:
924
+ if base64_image is None:
925
+ continue
916
926
  image_content.append(
917
927
  {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}}
918
928
  )
@@ -91,6 +91,7 @@ class SentinelExecutionStrategy(BaseExecutionStrategy, ABC):
91
91
  use_final_op_quality: bool = False,
92
92
  seed: int = 42,
93
93
  exp_name: str | None = None,
94
+ dont_use_priors: bool = False,
94
95
  *args,
95
96
  **kwargs,
96
97
  ):
@@ -105,6 +106,7 @@ class SentinelExecutionStrategy(BaseExecutionStrategy, ABC):
105
106
  self.seed = seed
106
107
  self.rng = np.random.default_rng(seed=seed)
107
108
  self.exp_name = exp_name
109
+ self.dont_use_priors = dont_use_priors
108
110
 
109
111
  # general cache which maps hash(logical_op_id, phys_op_id, hash(input)) --> record_set
110
112
  self.cache: dict[int, DataRecordSet] = {}
@@ -44,6 +44,7 @@ class OpFrontier:
44
44
  seed: int,
45
45
  policy: Policy,
46
46
  priors: dict | None = None,
47
+ dont_use_priors: bool = False,
47
48
  ):
48
49
  # set k and j, which are the initial number of operators in the frontier and the
49
50
  # initial number of records to sample for each frontier operator
@@ -51,6 +52,7 @@ class OpFrontier:
51
52
  self.j = j
52
53
  self.source_indices = source_indices
53
54
  self.root_dataset_ids = root_dataset_ids
55
+ self.dont_use_priors = dont_use_priors
54
56
 
55
57
  # store the policy that we are optimizing under
56
58
  self.policy = policy
@@ -68,6 +70,7 @@ class OpFrontier:
68
70
  is_llm_filter = isinstance(sample_op, LLMFilter)
69
71
  is_llm_topk = isinstance(sample_op, TopKOp) and isinstance(sample_op.index, Collection)
70
72
  self.is_llm_op = is_llm_convert or is_llm_filter or is_llm_topk or self.is_llm_join
73
+ self.is_llm_convert = is_llm_convert
71
74
 
72
75
  # get order in which we will sample physical operators for this logical operator
73
76
  sample_op_indices = self._get_op_index_order(op_set, seed)
@@ -190,7 +193,9 @@ class OpFrontier:
190
193
  Returns a list of indices for the operators in the op_set.
191
194
  """
192
195
  # if this is not an llm-operator, we simply return the indices in random order
193
- if not self.is_llm_op:
196
+ if not self.is_llm_op or self.dont_use_priors:
197
+ if self.is_llm_convert:
198
+ print("Using NO PRIORS for operator sampling order")
194
199
  rng = np.random.default_rng(seed=seed)
195
200
  op_indices = np.arange(len(op_set))
196
201
  rng.shuffle(op_indices)
@@ -198,6 +203,8 @@ class OpFrontier:
198
203
 
199
204
  # if this is an llm-operator, but we do not have priors, we first compute naive priors
200
205
  if self.priors is None or any([op_id not in self.priors for op_id in map(lambda op: op.get_op_id(), op_set)]):
206
+ if self.is_llm_convert:
207
+ print("Using NAIVE PRIORS for operator sampling order")
201
208
  self.priors = self._compute_naive_priors(op_set)
202
209
 
203
210
  # NOTE: self.priors is a dictionary with format:
@@ -770,7 +777,7 @@ class MABExecutionStrategy(SentinelExecutionStrategy):
770
777
 
771
778
  # if the operator is a non-llm filter which has filtered out records, remove those records from
772
779
  # all downstream operators' full_op_id_to_sources_not_processed
773
- if isinstance(op_set[0], NonLLMFilter):
780
+ if isinstance(op_set[0], NonLLMFilter) and next_unique_logical_op_id is not None:
774
781
  self._remove_filtered_records_from_downstream_ops(topo_idx, plan, op_frontiers, source_indices_to_all_record_sets)
775
782
 
776
783
  # finalize plan stats
@@ -805,7 +812,7 @@ class MABExecutionStrategy(SentinelExecutionStrategy):
805
812
  assert len(root_dataset_ids) == 1, f"Scan for {sample_op} has {len(root_dataset_ids)} > 1 root dataset ids"
806
813
  root_dataset_id = root_dataset_ids[0]
807
814
  source_indices = dataset_id_to_shuffled_source_indices[root_dataset_id]
808
- op_frontiers[unique_logical_op_id] = OpFrontier(op_set, source_unique_logical_op_ids, root_dataset_ids, source_indices, self.k, self.j, self.seed, self.policy, self.priors)
815
+ op_frontiers[unique_logical_op_id] = OpFrontier(op_set, source_unique_logical_op_ids, root_dataset_ids, source_indices, self.k, self.j, self.seed, self.policy, self.priors, self.dont_use_priors)
809
816
  elif isinstance(sample_op, JoinOp):
810
817
  assert len(source_unique_logical_op_ids) == 2, f"Join for {sample_op} has {len(source_unique_logical_op_ids)} != 2 source logical operators"
811
818
  left_source_indices = op_frontiers[source_unique_logical_op_ids[0]].source_indices
@@ -814,10 +821,10 @@ class MABExecutionStrategy(SentinelExecutionStrategy):
814
821
  for left_source_idx in left_source_indices:
815
822
  for right_source_idx in right_source_indices:
816
823
  source_indices.append((left_source_idx, right_source_idx))
817
- op_frontiers[unique_logical_op_id] = OpFrontier(op_set, source_unique_logical_op_ids, root_dataset_ids, source_indices, self.k, self.j, self.seed, self.policy, self.priors)
824
+ op_frontiers[unique_logical_op_id] = OpFrontier(op_set, source_unique_logical_op_ids, root_dataset_ids, source_indices, self.k, self.j, self.seed, self.policy, self.priors, self.dont_use_priors)
818
825
  else:
819
826
  source_indices = op_frontiers[source_unique_logical_op_ids[0]].source_indices
820
- op_frontiers[unique_logical_op_id] = OpFrontier(op_set, source_unique_logical_op_ids, root_dataset_ids, source_indices, self.k, self.j, self.seed, self.policy, self.priors)
827
+ op_frontiers[unique_logical_op_id] = OpFrontier(op_set, source_unique_logical_op_ids, root_dataset_ids, source_indices, self.k, self.j, self.seed, self.policy, self.priors, self.dont_use_priors)
821
828
 
822
829
  # initialize and start the progress manager
823
830
  self.progress_manager = create_progress_manager(plan, sample_budget=self.sample_budget, sample_cost_budget=self.sample_cost_budget, progress=self.progress)
@@ -121,8 +121,10 @@ class ConvertOp(PhysicalOperator, ABC):
121
121
  generated_fields=field_names,
122
122
  total_input_tokens=per_record_stats.total_input_tokens,
123
123
  total_output_tokens=per_record_stats.total_output_tokens,
124
+ total_embedding_input_tokens=per_record_stats.total_embedding_input_tokens,
124
125
  total_input_cost=per_record_stats.total_input_cost,
125
126
  total_output_cost=per_record_stats.total_output_cost,
127
+ total_embedding_cost=per_record_stats.total_embedding_cost,
126
128
  llm_call_duration_secs=per_record_stats.llm_call_duration_secs,
127
129
  fn_call_duration_secs=per_record_stats.fn_call_duration_secs,
128
130
  total_llm_calls=per_record_stats.total_llm_calls,
@@ -89,8 +89,10 @@ class FilterOp(PhysicalOperator, ABC):
89
89
  filter_str=self.filter_obj.get_filter_str(),
90
90
  total_input_tokens=generation_stats.total_input_tokens,
91
91
  total_output_tokens=generation_stats.total_output_tokens,
92
+ total_embedding_input_tokens=generation_stats.total_embedding_input_tokens,
92
93
  total_input_cost=generation_stats.total_input_cost,
93
94
  total_output_cost=generation_stats.total_output_cost,
95
+ total_embedding_cost=generation_stats.total_embedding_cost,
94
96
  llm_call_duration_secs=generation_stats.llm_call_duration_secs,
95
97
  fn_call_duration_secs=generation_stats.fn_call_duration_secs,
96
98
  total_llm_calls=generation_stats.total_llm_calls,
@@ -27,6 +27,25 @@ from palimpzest.query.generators.generators import Generator
27
27
  from palimpzest.query.operators.physical import PhysicalOperator
28
28
 
29
29
 
30
+ class Singleton:
31
+ def __new__(cls, *args, **kw):
32
+ if not hasattr(cls, '_instance'):
33
+ orig = super(Singleton, cls) # noqa: UP008
34
+ cls._instance = orig.__new__(cls, *args, **kw)
35
+ return cls._instance
36
+
37
+ class Locks(Singleton):
38
+ model = None
39
+ clip_lock = threading.Lock()
40
+ exec_lock = threading.Lock()
41
+
42
+ @classmethod
43
+ def get_model(cls, model_name: str):
44
+ with cls.clip_lock:
45
+ if cls.model is None:
46
+ cls.model = SentenceTransformer(model_name)
47
+ return cls.model
48
+
30
49
  def compute_similarity(left_embedding: list[float], right_embedding: list[float]) -> float:
31
50
  """
32
51
  Compute the similarity between two embeddings using cosine similarity.
@@ -357,8 +376,10 @@ class LLMJoin(JoinOp):
357
376
  join_condition=self.condition,
358
377
  total_input_tokens=generation_stats.total_input_tokens,
359
378
  total_output_tokens=generation_stats.total_output_tokens,
379
+ total_embedding_input_tokens=generation_stats.total_embedding_input_tokens,
360
380
  total_input_cost=generation_stats.total_input_cost,
361
381
  total_output_cost=generation_stats.total_output_cost,
382
+ total_embedding_cost=generation_stats.total_embedding_cost,
362
383
  llm_call_duration_secs=generation_stats.llm_call_duration_secs,
363
384
  fn_call_duration_secs=generation_stats.fn_call_duration_secs,
364
385
  total_llm_calls=generation_stats.total_llm_calls,
@@ -487,8 +508,7 @@ class EmbeddingJoin(LLMJoin):
487
508
  if field_name.split(".")[-1] in self.get_input_fields()
488
509
  ])
489
510
  self.embedding_model = Model.TEXT_EMBEDDING_3_SMALL if self.text_only else Model.CLIP_VIT_B_32
490
- self.clip_model = None
491
- self._lock = threading.Lock()
511
+ self.locks = Locks()
492
512
 
493
513
  # keep track of embedding costs that could not be amortized if no output records were produced
494
514
  self.residual_embedding_cost = 0.0
@@ -560,28 +580,22 @@ class EmbeddingJoin(LLMJoin):
560
580
  quality=quality,
561
581
  )
562
582
 
563
- def _get_clip_model(self):
564
- with self._lock:
565
- if self.clip_model is None:
566
- self.clip_model = SentenceTransformer(self.embedding_model.value)
567
- return self.clip_model
568
-
569
583
  def _compute_embeddings(self, candidates: list[DataRecord], input_fields: list[str]) -> tuple[np.ndarray, GenerationStats]:
570
584
  # return empty array and empty stats if no candidates
571
585
  if len(candidates) == 0:
572
586
  return np.zeros((0, 512)), GenerationStats()
573
587
 
574
588
  start_time = time.time()
575
- total_input_tokens = 0
589
+ total_embedding_input_tokens = 0
576
590
  embeddings = None
577
591
  if self.text_only:
578
592
  client = OpenAI()
579
593
  inputs = [dr.to_json_str(bytes_to_str=True, project_cols=input_fields, sorted=True) for dr in candidates]
580
594
  response = client.embeddings.create(input=inputs, model=self.embedding_model.value)
581
- total_input_tokens = response.usage.total_tokens
595
+ total_embedding_input_tokens = response.usage.total_tokens
582
596
  embeddings = np.array([item.embedding for item in response.data])
583
597
  else:
584
- model = self._get_clip_model()
598
+ model = self.locks.get_model(self.embedding_model.value)
585
599
  embeddings = np.zeros((len(candidates), 512)) # CLIP embeddings are 512-dimensional
586
600
  num_input_fields_present = 0
587
601
  for field in input_fields:
@@ -604,14 +618,16 @@ class EmbeddingJoin(LLMJoin):
604
618
 
605
619
  # compute cost of embedding(s)
606
620
  model_card = MODEL_CARDS[self.embedding_model.value]
607
- total_input_cost = model_card["usd_per_input_token"] * total_input_tokens
621
+ total_embedding_cost = model_card["usd_per_input_token"] * total_embedding_input_tokens
608
622
  embedding_gen_stats = GenerationStats(
609
623
  model_name=self.embedding_model.value,
610
- total_input_tokens=total_input_tokens,
624
+ total_input_tokens=0.0,
611
625
  total_output_tokens=0.0,
612
- total_input_cost=total_input_cost,
626
+ total_embedding_input_tokens=total_embedding_input_tokens,
627
+ total_input_cost=0.0,
613
628
  total_output_cost=0.0,
614
- cost_per_record=total_input_cost,
629
+ total_embedding_cost=total_embedding_cost,
630
+ cost_per_record=total_embedding_cost,
615
631
  llm_call_duration_secs=time.time() - start_time,
616
632
  total_llm_calls=1,
617
633
  total_embedding_llm_calls=len(candidates),
@@ -623,7 +639,7 @@ class EmbeddingJoin(LLMJoin):
623
639
  output_record, output_record_op_stats = super()._process_join_candidate_pair(left_candidate, right_candidate, gen_kwargs)
624
640
  return output_record, output_record_op_stats, embedding_sim
625
641
 
626
- def _process_join_candidate_with_sim(self, left_candidate: DataRecord, right_candidate: DataRecord, passed_operator: bool) -> tuple[DataRecord, RecordOpStats]:
642
+ def _process_join_candidate_with_sim(self, left_candidate: DataRecord, right_candidate: DataRecord, embedding_sim: float, passed_operator: bool) -> tuple[DataRecord, RecordOpStats]:
627
643
  # compute output record and add to output_records
628
644
  join_dr = DataRecord.from_join_parents(self.output_schema, left_candidate, right_candidate)
629
645
  join_dr._passed_operator = passed_operator
@@ -656,7 +672,7 @@ class EmbeddingJoin(LLMJoin):
656
672
  op_details={k: str(v) for k, v in self.get_id_params().items()},
657
673
  )
658
674
 
659
- return join_dr, record_op_stats
675
+ return join_dr, record_op_stats, embedding_sim
660
676
 
661
677
  def __call__(self, left_candidates: list[DataRecord], right_candidates: list[DataRecord], final: bool = False) -> tuple[DataRecordSet, int]:
662
678
  # get the set of input fields from both records in the join
@@ -690,36 +706,50 @@ class EmbeddingJoin(LLMJoin):
690
706
  output_records, output_record_op_stats, num_inputs_processed = [], [], 0
691
707
 
692
708
  # draw samples until num_samples is reached
693
- if self.samples_drawn < self.num_samples:
694
- samples_to_draw = min(self.num_samples - self.samples_drawn, len(join_candidates))
695
- join_candidate_samples = join_candidates[:samples_to_draw]
696
- join_candidates = join_candidates[samples_to_draw:]
697
-
698
- # apply the generator to each pair of candidates
699
- with ThreadPoolExecutor(max_workers=self.join_parallelism) as executor:
700
- futures = [
701
- executor.submit(self._process_join_candidate_pair, left_candidate, right_candidate, gen_kwargs, embedding_sim)
702
- for left_candidate, right_candidate, embedding_sim in join_candidate_samples
703
- ]
704
-
705
- # collect results as they complete
706
- for future in as_completed(futures):
707
- self.join_idx += 1
708
- join_output_record, join_output_record_op_stats, embedding_sim = future.result()
709
- output_records.append(join_output_record)
710
- output_record_op_stats.append(join_output_record_op_stats)
711
- print(f"{self.join_idx} JOINED")
712
-
713
- # update similarity thresholds
714
- records_joined = join_output_record._passed_operator
715
- if not records_joined and embedding_sim > self.max_non_matching_sim:
716
- self.max_non_matching_sim = embedding_sim
717
- if records_joined and embedding_sim < self.min_matching_sim:
718
- self.min_matching_sim = embedding_sim
719
-
720
- # update samples drawn and num_inputs_processed
721
- self.samples_drawn += samples_to_draw
722
- num_inputs_processed += samples_to_draw
709
+ with self.locks.exec_lock:
710
+ if self.samples_drawn < self.num_samples:
711
+ samples_to_draw = min(self.num_samples - self.samples_drawn, len(join_candidates))
712
+ join_candidate_samples = join_candidates[:samples_to_draw]
713
+ join_candidates = join_candidates[samples_to_draw:]
714
+
715
+ # apply the generator to each pair of candidates
716
+ with ThreadPoolExecutor(max_workers=self.join_parallelism) as executor:
717
+ futures = [
718
+ executor.submit(self._process_join_candidate_pair, left_candidate, right_candidate, gen_kwargs, embedding_sim)
719
+ for left_candidate, right_candidate, embedding_sim in join_candidate_samples
720
+ ]
721
+
722
+ # collect results as they complete
723
+ similarities, joined = [], []
724
+ for future in as_completed(futures):
725
+ self.join_idx += 1
726
+ join_output_record, join_output_record_op_stats, embedding_sim = future.result()
727
+ output_records.append(join_output_record)
728
+ output_record_op_stats.append(join_output_record_op_stats)
729
+ similarities.append(embedding_sim)
730
+ joined.append(join_output_record._passed_operator)
731
+ print(f"{self.join_idx} JOINED")
732
+
733
+ # sort join results by embedding similarity
734
+ sorted_sim_join_tuples = sorted(zip(similarities, joined), key=lambda x: x[0])
735
+
736
+ # compute threshold below which no records joined
737
+ for embedding_sim, records_joined in sorted_sim_join_tuples:
738
+ if records_joined:
739
+ break
740
+ if not records_joined and embedding_sim > self.max_non_matching_sim:
741
+ self.max_non_matching_sim = embedding_sim
742
+
743
+ # compute threshold above which all records joined
744
+ for embedding_sim, records_joined in reversed(sorted_sim_join_tuples):
745
+ if not records_joined:
746
+ break
747
+ if records_joined and embedding_sim < self.min_matching_sim:
748
+ self.min_matching_sim = embedding_sim
749
+
750
+ # update samples drawn and num_inputs_processed
751
+ self.samples_drawn += samples_to_draw
752
+ num_inputs_processed += samples_to_draw
723
753
 
724
754
  # process remaining candidates based on embedding similarity
725
755
  if len(join_candidates) > 0:
@@ -727,43 +757,48 @@ class EmbeddingJoin(LLMJoin):
727
757
  with ThreadPoolExecutor(max_workers=self.join_parallelism) as executor:
728
758
  futures = []
729
759
  for left_candidate, right_candidate, embedding_sim in join_candidates:
730
- llm_call_needed = (
731
- self.min_matching_sim == float("inf")
732
- or self.max_non_matching_sim == float("-inf")
733
- or self.min_matching_sim <= embedding_sim <= self.max_non_matching_sim
734
- )
760
+ # if the embedding similarity is lower than the threshold below which no records joined,
761
+ # then we can skip the LLM call and mark the records as not joined
762
+ if embedding_sim < self.max_non_matching_sim:
763
+ futures.append(executor.submit(self._process_join_candidate_with_sim, left_candidate, right_candidate, embedding_sim, passed_operator=False))
735
764
 
736
- if llm_call_needed:
737
- futures.append(executor.submit(self._process_join_candidate_pair, left_candidate, right_candidate, gen_kwargs, embedding_sim))
765
+ # if the embedding similarity is higher than the threshold above which all records joined,
766
+ # then we can skip the LLM call and mark the records as joined
767
+ elif embedding_sim > self.min_matching_sim:
768
+ futures.append(executor.submit(self._process_join_candidate_with_sim, left_candidate, right_candidate, embedding_sim, passed_operator=True))
738
769
 
739
- elif embedding_sim < self.min_matching_sim:
740
- self.join_idx += 1
741
- output_record, record_op_stats = self._process_join_candidate_with_sim(left_candidate, right_candidate, passed_operator=False)
742
- output_records.append(output_record)
743
- output_record_op_stats.append(record_op_stats)
744
- print(f"{self.join_idx} SKIPPED (low sim: {embedding_sim:.4f} < {self.min_matching_sim:.4f})")
745
-
746
- elif embedding_sim > self.max_non_matching_sim:
747
- self.join_idx += 1
748
- output_record, record_op_stats = self._process_join_candidate_with_sim(left_candidate, right_candidate, passed_operator=True)
749
- output_records.append(output_record)
750
- output_record_op_stats.append(record_op_stats)
751
- print(f"{self.join_idx} JOINED (high sim: {embedding_sim:.4f} > {self.max_non_matching_sim:.4f})")
770
+ # otherwise, we will process the LLM call
771
+ else:
772
+ futures.append(executor.submit(self._process_join_candidate_pair, left_candidate, right_candidate, gen_kwargs, embedding_sim))
752
773
 
753
774
  num_inputs_processed += 1
754
775
 
755
776
  # collect results as they complete
777
+ similarities, joined = [], []
756
778
  for future in as_completed(futures):
757
779
  self.join_idx += 1
758
780
  join_output_record, join_output_record_op_stats, embedding_sim = future.result()
759
781
  output_records.append(join_output_record)
760
782
  output_record_op_stats.append(join_output_record_op_stats)
783
+ similarities.append(embedding_sim)
784
+ joined.append(join_output_record._passed_operator)
761
785
  print(f"{self.join_idx} JOINED")
762
786
 
763
- # update similarity thresholds
764
- records_joined = join_output_record._passed_operator
787
+ ### update thresholds if there are llm calls which incrementally squeeze the boundaries ###
788
+ # sort join results by embedding similarity
789
+ sorted_sim_join_tuples = sorted(zip(similarities, joined), key=lambda x: x[0])
790
+
791
+ # potentially update threshold below which no records joined
792
+ for embedding_sim, records_joined in sorted_sim_join_tuples:
793
+ if records_joined:
794
+ break
765
795
  if not records_joined and embedding_sim > self.max_non_matching_sim:
766
796
  self.max_non_matching_sim = embedding_sim
797
+
798
+ # potentially update threshold above which all records joined
799
+ for embedding_sim, records_joined in reversed(sorted_sim_join_tuples):
800
+ if not records_joined:
801
+ break
767
802
  if records_joined and embedding_sim < self.min_matching_sim:
768
803
  self.min_matching_sim = embedding_sim
769
804
 
@@ -109,15 +109,17 @@ class RAGConvert(LLMConvert):
109
109
 
110
110
  # compute the generation stats object
111
111
  model_card = MODEL_CARDS[model_name]
112
- total_input_tokens = response.usage.total_tokens
113
- total_input_cost = model_card["usd_per_input_token"] * total_input_tokens
112
+ total_embedding_input_tokens = response.usage.total_tokens
113
+ total_embedding_cost = model_card["usd_per_input_token"] * total_embedding_input_tokens
114
114
  embed_stats = GenerationStats(
115
115
  model_name=model_name, # NOTE: this should be overwritten by generation model in convert()
116
- total_input_tokens=total_input_tokens,
116
+ total_input_tokens=0.0,
117
117
  total_output_tokens=0.0,
118
- total_input_cost=total_input_cost,
118
+ total_embedding_input_tokens=total_embedding_input_tokens,
119
+ total_input_cost=0.0,
119
120
  total_output_cost=0.0,
120
- cost_per_record=total_input_cost,
121
+ total_embedding_cost=total_embedding_cost,
122
+ cost_per_record=total_embedding_cost,
121
123
  llm_call_duration_secs=total_time,
122
124
  total_llm_calls=1,
123
125
  total_embedding_llm_calls=1,
@@ -156,7 +158,7 @@ class RAGConvert(LLMConvert):
156
158
  # skip this field if it is not a string or a list of strings
157
159
  is_string_field = field.annotation in [str, str | None, str | Any]
158
160
  is_list_string_field = field.annotation in [list[str], list[str] | None, list[str] | Any]
159
- if not (is_string_field or is_list_string_field):
161
+ if not (is_string_field or is_list_string_field) or candidate[field_name] is None:
160
162
  continue
161
163
 
162
164
  # if this is a list of strings, join the strings
@@ -318,15 +320,17 @@ class RAGFilter(LLMFilter):
318
320
 
319
321
  # compute the generation stats object
320
322
  model_card = MODEL_CARDS[model_name]
321
- total_input_tokens = response.usage.total_tokens
322
- total_input_cost = model_card["usd_per_input_token"] * total_input_tokens
323
+ total_embedding_input_tokens = response.usage.total_tokens
324
+ total_embedding_cost = model_card["usd_per_input_token"] * total_embedding_input_tokens
323
325
  embed_stats = GenerationStats(
324
326
  model_name=model_name, # NOTE: this should be overwritten by generation model in filter()
325
- total_input_tokens=total_input_tokens,
327
+ total_input_tokens=0.0,
326
328
  total_output_tokens=0.0,
327
- total_input_cost=total_input_cost,
329
+ total_embedding_input_tokens=total_embedding_input_tokens,
330
+ total_input_cost=0.0,
328
331
  total_output_cost=0.0,
329
- cost_per_record=total_input_cost,
332
+ total_embedding_cost=total_embedding_cost,
333
+ cost_per_record=total_embedding_cost,
330
334
  llm_call_duration_secs=total_time,
331
335
  total_llm_calls=1,
332
336
  total_embedding_llm_calls=1,
@@ -1,6 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import os
4
+ import threading
4
5
  import time
5
6
  from typing import Callable
6
7
 
@@ -17,6 +18,24 @@ from palimpzest.core.models import GenerationStats, OperatorCostEstimates, Recor
17
18
  from palimpzest.query.operators.physical import PhysicalOperator
18
19
 
19
20
 
21
+ class Singleton:
22
+ def __new__(cls, *args, **kw):
23
+ if not hasattr(cls, '_instance'):
24
+ orig = super(Singleton, cls) # noqa: UP008
25
+ cls._instance = orig.__new__(cls, *args, **kw)
26
+ return cls._instance
27
+
28
+ class ClipModel(Singleton):
29
+ model = None
30
+ lock = threading.Lock()
31
+
32
+ @classmethod
33
+ def get_model(cls, model_name: str):
34
+ with cls.lock:
35
+ if cls.model is None:
36
+ cls.model = SentenceTransformer(model_name)
37
+ return cls.model
38
+
20
39
  class TopKOp(PhysicalOperator):
21
40
  def __init__(
22
41
  self,
@@ -56,6 +75,7 @@ class TopKOp(PhysicalOperator):
56
75
  self.output_attrs = output_attrs
57
76
  self.search_func = search_func if search_func is not None else self.default_search_func
58
77
  self.k = k
78
+ self.clip_model = ClipModel()
59
79
 
60
80
  def __str__(self):
61
81
  op = super().__str__()
@@ -185,7 +205,6 @@ class TopKOp(PhysicalOperator):
185
205
  # construct and return the record set
186
206
  return DataRecordSet(drs, record_op_stats_lst)
187
207
 
188
-
189
208
  def __call__(self, candidate: DataRecord) -> DataRecordSet:
190
209
  start_time = time.time()
191
210
 
@@ -209,9 +228,9 @@ class TopKOp(PhysicalOperator):
209
228
  inputs, gen_stats = None, GenerationStats()
210
229
  if isinstance(self.index, Collection):
211
230
  uses_openai_embedding_fcn = isinstance(self.index._embedding_function, OpenAIEmbeddingFunction)
212
- uses_sentence_transformer_embedding_fcn = isinstance(self.index._embedding_function, SentenceTransformerEmbeddingFunction)
231
+ uses_clip_model = isinstance(self.index._embedding_function, SentenceTransformerEmbeddingFunction)
213
232
  error_msg = "ChromaDB index must use OpenAI or SentenceTransformer embedding function; see: https://docs.trychroma.com/integrations/embedding-models/openai"
214
- assert uses_openai_embedding_fcn or uses_sentence_transformer_embedding_fcn, error_msg
233
+ assert uses_openai_embedding_fcn or uses_clip_model, error_msg
215
234
 
216
235
  model_name = self.index._embedding_function.model_name if uses_openai_embedding_fcn else "clip-ViT-B-32"
217
236
  err_msg = f"For Chromadb, we currently only support `text-embedding-3-small` and `clip-ViT-B-32`; your index uses: {model_name}"
@@ -228,8 +247,8 @@ class TopKOp(PhysicalOperator):
228
247
  total_input_tokens = response.usage.total_tokens
229
248
  inputs = [item.embedding for item in response.data]
230
249
 
231
- elif uses_sentence_transformer_embedding_fcn:
232
- model = SentenceTransformer(model_name)
250
+ elif uses_clip_model:
251
+ model = self.clip_model.get_model(model_name)
233
252
  inputs = model.encode(query)
234
253
 
235
254
  embed_total_time = time.time() - embed_start_time
@@ -105,9 +105,10 @@ class SampleBasedCostModel:
105
105
  "time_per_record": record_op_stats.time_per_record,
106
106
  "quality": record_op_stats.quality,
107
107
  "passed_operator": record_op_stats.passed_operator,
108
- "source_indices": record_op_stats.record_source_indices, # TODO: remove
109
- "op_details": record_op_stats.op_details, # TODO: remove
110
- "answer": record_op_stats.answer, # TODO: remove
108
+ "source_indices": record_op_stats.record_source_indices,
109
+ "op_details": record_op_stats.op_details,
110
+ "answer": record_op_stats.answer,
111
+ "op_name": record_op_stats.op_name,
111
112
  }
112
113
  execution_record_op_stats.append(record_op_stats_dict)
113
114
 
@@ -128,8 +129,12 @@ class SampleBasedCostModel:
128
129
  else physical_op_df.source_indices.apply(tuple).nunique()
129
130
  )
130
131
 
131
- # compute selectivity
132
+ # compute selectivity; for filters this may be 1.0 on smalle samples;
133
+ # always put something slightly less than 1.0 to ensure that filters are pushed down when possible
132
134
  selectivity = physical_op_df.passed_operator.sum() / num_source_records
135
+ op_name = physical_op_df.op_name.iloc[0].lower()
136
+ if selectivity == 1.0 and "filter" in op_name:
137
+ selectivity -= 1e-3
133
138
 
134
139
  # compute quality; if all qualities are None then this will be NaN
135
140
  quality = physical_op_df.quality.mean()
@@ -48,6 +48,7 @@ class QueryProcessorConfig(BaseModel):
48
48
  seed: int = Field(default=42)
49
49
  exp_name: str | None = Field(default=None)
50
50
  priors: dict | None = Field(default=None)
51
+ dont_use_priors: bool = Field(default=False)
51
52
 
52
53
  def to_dict(self) -> dict:
53
54
  """Convert the config to a dict representation."""
@@ -1,4 +1,5 @@
1
1
  import logging
2
+ import os
2
3
  from enum import Enum
3
4
 
4
5
  from palimpzest.core.data.dataset import Dataset
@@ -91,6 +92,27 @@ class QueryProcessorFactory:
91
92
  # set the final set of available models in the config
92
93
  config.available_models = available_models
93
94
 
95
+ if len(config.available_models) == 0:
96
+ raise ValueError("No available models found.")
97
+
98
+ openai_key = os.getenv("OPENAI_API_KEY")
99
+ anthropic_key = os.getenv("ANTHROPIC_API_KEY")
100
+ together_key = os.getenv("TOGETHER_API_KEY")
101
+ gemini_key = os.getenv("GEMINI_API_KEY")
102
+ google_key = os.getenv("GOOGLE_API_KEY")
103
+
104
+ for model in config.available_models:
105
+ if model.is_openai_model() and not openai_key:
106
+ raise ValueError("OPENAI_API_KEY must be set to use OpenAI models.")
107
+ if model.is_anthropic_model() and not anthropic_key:
108
+ raise ValueError("ANTHROPIC_API_KEY must be set to use Anthropic models.")
109
+ if model.is_together_model() and not together_key:
110
+ raise ValueError("TOGETHER_API_KEY must be set to use Together models.")
111
+ if model.is_google_model() and not (gemini_key or google_key or config.gemini_credentials_path):
112
+ raise ValueError("GEMINI_API_KEY, GOOGLE_API_KEY, or gemini_credentials path must be set to use Google Gemini models.")
113
+ if model.is_vllm_model() and config.api_base is None:
114
+ raise ValueError("api_base must be set to use vLLM models.")
115
+
94
116
  return config, validator
95
117
 
96
118
  @classmethod
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: palimpzest
3
- Version: 1.1.0
3
+ Version: 1.2.0
4
4
  Summary: Palimpzest is a system which enables anyone to process AI-powered analytical queries simply by defining them in a declarative language
5
5
  Author-email: MIT DSG Semantic Management Lab <michjc@csail.mit.edu>
6
6
  Project-URL: homepage, https://palimpzest.org
@@ -31,7 +31,7 @@ Requires-Dist: pillow>=11.3.0
31
31
  Requires-Dist: prettytable>=3.9.0
32
32
  Requires-Dist: psutil==5.9.5
33
33
  Requires-Dist: PyLD>=2.0.4
34
- Requires-Dist: pyarrow==20.0.0
34
+ Requires-Dist: pyarrow>=20.0.0
35
35
  Requires-Dist: pypdf>=5.1.0
36
36
  Requires-Dist: pytest-mock>=3.14.0
37
37
  Requires-Dist: pyyaml>=6.0.1
@@ -5,7 +5,7 @@ palimpzest/agents/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU
5
5
  palimpzest/agents/compute_agents.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
6
6
  palimpzest/agents/search_agents.py,sha256=t2QMreB5Ph71aoNk5bBtV-0l8im79z-pMAR3JDAySDw,29418
7
7
  palimpzest/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
8
- palimpzest/core/models.py,sha256=FKyKW9PqmqpnDGWOINNT6XgBj0raaAskxtdNdFZ4Zyw,42688
8
+ palimpzest/core/models.py,sha256=t4zHPA-Nrz2Mmq2EZfJWU_CsSbzu4LFv6_Wob10MZnc,46110
9
9
  palimpzest/core/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
10
10
  palimpzest/core/data/context.py,sha256=x1xYyu9qW65dvtK_XayIfv_CgsCEPW6Qe0DTiSf9sjU,16207
11
11
  palimpzest/core/data/context_manager.py,sha256=8hAKWD2jhFZgghTu7AYgjkvKDsJUPVxq8g4nG0HWvfo,6150
@@ -28,7 +28,7 @@ palimpzest/prompts/filter_prompts.py,sha256=D-aY3-th1GzEHrVGbKORVN2R7x7coYGjp8Fr
28
28
  palimpzest/prompts/join_prompts.py,sha256=z-y4L1cw1O3I_F9DW6MvqeztdQoKDQawX6nK6vQAkdM,2916
29
29
  palimpzest/prompts/moa_aggregator_prompts.py,sha256=b5cz4G2oF86LlHOy8vmtxoMcZ9zaZoppKrURHgzCzNU,5248
30
30
  palimpzest/prompts/moa_proposer_prompts.py,sha256=yfZYwmCg-Tg9h0H7PJMEuDYPR45EbYnORmVX6cY2vRQ,3125
31
- palimpzest/prompts/prompt_factory.py,sha256=0xj3glD5Y7R7MUsmKxJCOa4q9VeIILDO2IVWz_4huYw,49355
31
+ palimpzest/prompts/prompt_factory.py,sha256=txtCvDI0sv_LEar0iK_E1_mlRMvuwZseM-6BSC9ugUs,49926
32
32
  palimpzest/prompts/split_merge_prompts.py,sha256=hX-MThmW4VU7rjgm7gb-bpniEMdj25mtp0o8qBeWvIQ,5573
33
33
  palimpzest/prompts/split_proposer_prompts.py,sha256=Ucqwfn4FqFk-b9E024EK4e_3_QndTJjggwiwa1x5CQs,3115
34
34
  palimpzest/prompts/utils.py,sha256=Eure2pqm8Ftme9lQlHwFL9EqK3yjH14WQHofnQINce4,7497
@@ -36,9 +36,9 @@ palimpzest/prompts/validator.py,sha256=OxebGjvXNBy0Cq79XI3aPRbongzOdtHH6mQctpbWc
36
36
  palimpzest/query/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
37
37
  palimpzest/query/execution/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
38
38
  palimpzest/query/execution/all_sample_execution_strategy.py,sha256=d2MO_AmXF_HbV4rUNkFqHsuoBCofU98zQ-D3Q06BXCc,14369
39
- palimpzest/query/execution/execution_strategy.py,sha256=PkXtO8wZpv6HHLlpxXwcc0t5pPCGafp0L_iVkG9bmXM,19162
39
+ palimpzest/query/execution/execution_strategy.py,sha256=TnSInUlcGZHn2GUpLiIFVgfPpmsNfIKKgElnRt6a6ss,19248
40
40
  palimpzest/query/execution/execution_strategy_type.py,sha256=vRQBPCQN5_aoyD3TLIeW3VPo15mqF-5RBvEXkENz9FE,987
41
- palimpzest/query/execution/mab_execution_strategy.py,sha256=i03LYRhaG2VLia-XSiYbKdlu3hLQZul75xMcRGm065M,47767
41
+ palimpzest/query/execution/mab_execution_strategy.py,sha256=BLRTSQXPeWBlJ_-8GAFHj2fbIY_eoPhuWeDcIdOokcg,48247
42
42
  palimpzest/query/execution/parallel_execution_strategy.py,sha256=Di-8d7waE0bev4kNDXEJJqQ0wwQ87_sPV-t5qFtAlPQ,17589
43
43
  palimpzest/query/execution/single_threaded_execution_strategy.py,sha256=1rjMel0-AI6KUi_SMNgPPXxMgG5-t9lenLKoYEClgjk,17464
44
44
  palimpzest/query/generators/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -46,23 +46,23 @@ palimpzest/query/generators/generators.py,sha256=OV2HBZvCQtvhj6rOwti_8dpQX_bqTZ-
46
46
  palimpzest/query/operators/__init__.py,sha256=B9zr_VmUs6YRep4fjbj7e0aTM6T9-GrqbY7tKWxEdkc,4734
47
47
  palimpzest/query/operators/aggregate.py,sha256=nQ6Zh1DYeqDoIDwkPQDw8QCwW0y52sGC-No6uPSRc8A,27367
48
48
  palimpzest/query/operators/compute.py,sha256=X_pWN45smg8L4dV54nOae7dldQGL1nJVlVyJ3ULWSmI,8432
49
- palimpzest/query/operators/convert.py,sha256=cjUPrSgvZBZXBbrbepIxZMBXjbWWPLuTX4JwLyvVg2U,16050
49
+ palimpzest/query/operators/convert.py,sha256=beJLS-vnVc_VLnngoDKTj-k_Ul7GmDA-yMKM6-mX5Ho,16218
50
50
  palimpzest/query/operators/critique_and_refine.py,sha256=Q-NhasVoD9meX7g36RPrv3q4R48_8XEU4d3TE46hRJI,8979
51
51
  palimpzest/query/operators/distinct.py,sha256=ZTXlIS7IaFRTsWv9RemzCo1JLz25vEma-TB42CV5fJQ,2614
52
- palimpzest/query/operators/filter.py,sha256=ufREsO2-8CBk4u4fabDBYpEvb806E11EOyW-wuRs4vw,10356
53
- palimpzest/query/operators/join.py,sha256=17BGzrxf_fkqhnEzhq-5b0qv2qQTw7z6job5YkBUrZE,36993
52
+ palimpzest/query/operators/filter.py,sha256=h559CweLdcWw_-LPyR2h04pKsm-jVM_Kazif-BXYBFo,10516
53
+ palimpzest/query/operators/join.py,sha256=hxIbSSDMy_5bch7kpYQ-iQS5XEth_T19cTSlXd76gdg,38845
54
54
  palimpzest/query/operators/limit.py,sha256=pdo7WfWY97SW3c-WqZ4SIPw7lHIVbaXPEWqHyK8qkF8,2130
55
55
  palimpzest/query/operators/logical.py,sha256=OtB82L1X19ibtLx1GIfeXXyO7YfjkFmh3puIUgqKQRE,21160
56
56
  palimpzest/query/operators/mixture_of_agents.py,sha256=KC-ZpjtGY28sfwlk2TpduLC_fADj_UASFCaicaKqSFc,11671
57
57
  palimpzest/query/operators/physical.py,sha256=0_BfFX9nKuN__440eAfEfApWAoGOunVSCZIQxS4HO2Y,9773
58
58
  palimpzest/query/operators/project.py,sha256=gxbjsHEACCP9wxATH-mw6wOSUi5s13JyWsLqqhAYWXQ,2111
59
- palimpzest/query/operators/rag.py,sha256=CJm83pBapA8HEGfhRnWjqt_ESS6hJAPvPJksRTOGL7M,20124
59
+ palimpzest/query/operators/rag.py,sha256=ZDloc4nC8foI3rTSHQxqduAVPj5LV8xMu_ng0EjDOA0,20409
60
60
  palimpzest/query/operators/scan.py,sha256=OqCiPRTvTY7SbauNMyFvGT5nRVeRzVsGYSrkoN1Ib_w,7407
61
61
  palimpzest/query/operators/search.py,sha256=cQin-Qc9FT7V0Gv3-pxMLbVMjqE6ALe99V0OrQhA6CI,22711
62
62
  palimpzest/query/operators/split.py,sha256=oLzwnYb8TNf3XA9TMKEAIw7EIA12wHneaD42BNLIHiI,15043
63
- palimpzest/query/operators/topk.py,sha256=92Bu98xc8CMlS9bf1xc0FxcfVuhv6j4x_303Aq1v-U0,13053
63
+ palimpzest/query/operators/topk.py,sha256=MZl83Cu43QmN4skjlfpR8EVFFCgA7sR6PbGgBGWC0tg,13564
64
64
  palimpzest/query/optimizer/__init__.py,sha256=v9fSBOL2p3sQew4LrN2DQUPe0WezO328Hr54qBTqrAs,2799
65
- palimpzest/query/optimizer/cost_model.py,sha256=p7AsR6f4VYdGjrUKPGN_VTErY36GjY90Bsvsys4le2M,12655
65
+ palimpzest/query/optimizer/cost_model.py,sha256=JaxdLuUZuq52BJ52YdW4ChfWptwXsh7Rk7oaPCn_gWc,12956
66
66
  palimpzest/query/optimizer/optimizer.py,sha256=ksLkzQ2sVgJFbkxGF3ncF74EsAHZFos8G19xlHQrtJo,20063
67
67
  palimpzest/query/optimizer/optimizer_strategy.py,sha256=0foDaBHqQehK_zz6IlDEbNIw-44wxY6LO5H1anJi56Y,10042
68
68
  palimpzest/query/optimizer/optimizer_strategy_type.py,sha256=V-MMHvJdnfZKoUX1xxxwh66q1RjN2FL35IsiT1C62c8,1084
@@ -71,9 +71,9 @@ palimpzest/query/optimizer/primitives.py,sha256=jMMVq37y1tWiPU1lSSKQP9OP-mzkpSxS
71
71
  palimpzest/query/optimizer/rules.py,sha256=awhe76trskv5Tq5E2QHpUN_YV6jH8INywa0Ige8IIhY,53341
72
72
  palimpzest/query/optimizer/tasks.py,sha256=DNJjY2QldfKFWj6INHElMh88dYc36Z5m3wHwbs4jyF4,30455
73
73
  palimpzest/query/processor/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
74
- palimpzest/query/processor/config.py,sha256=MkZ776VUk9tIOCZdVyH__H4Z0gO4c8fpehX2Gqywvks,2472
74
+ palimpzest/query/processor/config.py,sha256=8-MpPYHv2SI4dub4MP_gOYSRxO80_ALLuWRxD-F2YOg,2521
75
75
  palimpzest/query/processor/query_processor.py,sha256=T4ffPbnOX23G8FDITzmM7Iw7DUEDWIHnwl8XLYllgjg,6240
76
- palimpzest/query/processor/query_processor_factory.py,sha256=i9L9StqlUi7m1AqZMuYQWhunqOJi3nLK47skhxq9tIA,8317
76
+ palimpzest/query/processor/query_processor_factory.py,sha256=l9f0C0lngOihZDzH0TK9WdKR9CwwgB6IbNZftonSFR0,9576
77
77
  palimpzest/schemabuilder/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
78
78
  palimpzest/schemabuilder/schema_builder.py,sha256=QraGp66dcD-ej6Y2mER40o86G9JqlBkL7swkJzjUAIY,7968
79
79
  palimpzest/tools/README.md,sha256=56_6LPG80uc0CLVhTBP6I1wgIffNv9cyTr0TmVZqmrM,483
@@ -89,8 +89,8 @@ palimpzest/utils/progress.py,sha256=eHXrTPTCRHjMdK0EjYRUzSxcV6N1lK8TS3Ju_ZlQLhY,
89
89
  palimpzest/utils/udfs.py,sha256=LjHic54B1az-rKgNLur0wOpaz2ko_UodjLEJrazkxvY,1854
90
90
  palimpzest/validator/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
91
91
  palimpzest/validator/validator.py,sha256=SvjK09zCpGtK0yM0OasvQlSzyq3loy32DyOOKRmYXC0,15977
92
- palimpzest-1.1.0.dist-info/licenses/LICENSE,sha256=5GUlHy9lr-Py9kvV38FF1m3yy3NqM18fefuE9wkWumo,1079
93
- palimpzest-1.1.0.dist-info/METADATA,sha256=0AZq33WMFrxkarQADVPv2OFQu7ko38fzhBOtTQjc3Fw,5359
94
- palimpzest-1.1.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
95
- palimpzest-1.1.0.dist-info/top_level.txt,sha256=raV06dJUgohefUn3ZyJS2uqp_Y76EOLA9Y2e_fxt8Ew,11
96
- palimpzest-1.1.0.dist-info/RECORD,,
92
+ palimpzest-1.2.0.dist-info/licenses/LICENSE,sha256=5GUlHy9lr-Py9kvV38FF1m3yy3NqM18fefuE9wkWumo,1079
93
+ palimpzest-1.2.0.dist-info/METADATA,sha256=IKxg8RllEvn6dgboEJVnxdnd5RwYmXFhIL2FHvoYpWw,5359
94
+ palimpzest-1.2.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
95
+ palimpzest-1.2.0.dist-info/top_level.txt,sha256=raV06dJUgohefUn3ZyJS2uqp_Y76EOLA9Y2e_fxt8Ew,11
96
+ palimpzest-1.2.0.dist-info/RECORD,,