palimpzest 0.6.1__py3-none-any.whl → 0.6.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,4 +1,5 @@
1
1
  """This file contains factory methods which return template prompts and return messages for chat payloads."""
2
+
2
3
  import base64
3
4
  import json
4
5
  from string import Formatter
@@ -82,6 +83,7 @@ from palimpzest.prompts.util_phrases import (
82
83
 
83
84
  class PromptFactory:
84
85
  """Factory class for generating prompts for the Generator given the input(s)."""
86
+
85
87
  BASE_SYSTEM_PROMPT_MAP = {
86
88
  PromptStrategy.COT_BOOL: COT_BOOL_BASE_SYSTEM_PROMPT,
87
89
  PromptStrategy.COT_BOOL_IMAGE: COT_BOOL_BASE_SYSTEM_PROMPT,
@@ -148,13 +150,13 @@ class PromptFactory:
148
150
  longest_field_name, longest_field_length = sorted_fields[0]
149
151
 
150
152
  # trim the field
151
- context_factor = MIXTRAL_LLAMA_CONTEXT_TOKENS_LIMIT / (total_context_len * TOKENS_PER_CHARACTER)
153
+ context_factor = MIXTRAL_LLAMA_CONTEXT_TOKENS_LIMIT / (total_context_len * TOKENS_PER_CHARACTER)
152
154
  keep_frac_idx = int(longest_field_length * context_factor)
153
155
  context[longest_field_name] = context[longest_field_name][:keep_frac_idx]
154
156
 
155
157
  # update total context length
156
158
  total_context_len = len(json.dumps(context, indent=2))
157
-
159
+
158
160
  return json.dumps(context, indent=2)
159
161
 
160
162
  def _get_input_fields(self, candidate: DataRecord, **kwargs) -> list[str]:
@@ -201,7 +203,11 @@ class PromptFactory:
201
203
  """
202
204
  output_fields_desc = ""
203
205
  output_schema: Schema = kwargs.get("output_schema")
204
- if self.prompt_strategy.is_cot_qa_prompt():
206
+ if (
207
+ self.prompt_strategy.is_cot_qa_prompt()
208
+ or self.prompt_strategy.is_moa_proposer_prompt()
209
+ or self.prompt_strategy.is_moa_aggregator_prompt()
210
+ ):
205
211
  assert output_schema is not None, "Output schema must be provided for convert prompts."
206
212
 
207
213
  field_desc_map = output_schema.field_desc_map()
@@ -230,14 +236,16 @@ class PromptFactory:
230
236
 
231
237
  Args:
232
238
  kwargs: The keyword arguments provided by the user.
233
-
239
+
234
240
  Returns:
235
241
  str | None: The original output.
236
242
  """
237
243
  original_output = kwargs.get("original_output")
238
244
  if self.prompt_strategy.is_critic_prompt() or self.prompt_strategy.is_refine_prompt():
239
- assert original_output is not None, "Original output must be provided for critique and refinement operations."
240
-
245
+ assert original_output is not None, (
246
+ "Original output must be provided for critique and refinement operations."
247
+ )
248
+
241
249
  return original_output
242
250
 
243
251
  def _get_critique_output(self, **kwargs) -> str | None:
@@ -246,7 +254,7 @@ class PromptFactory:
246
254
 
247
255
  Args:
248
256
  kwargs: The keyword arguments provided by the user.
249
-
257
+
250
258
  Returns:
251
259
  str | None: The critique output.
252
260
  """
@@ -259,10 +267,10 @@ class PromptFactory:
259
267
  def _get_model_responses(self, **kwargs) -> str | None:
260
268
  """
261
269
  Returns the model responses for the mixture-of-agents aggregation operation.
262
-
270
+
263
271
  Args:
264
272
  kwargs: The keyword arguments provided by the user.
265
-
273
+
266
274
  Returns:
267
275
  str | None: The model responses.
268
276
  """
@@ -314,9 +322,7 @@ class PromptFactory:
314
322
  critique_criteria = None
315
323
  if self.prompt_strategy.is_critic_prompt():
316
324
  critique_criteria = (
317
- COT_QA_IMAGE_CRITIQUE_CRITERIA
318
- if self.prompt_strategy.is_image_prompt()
319
- else COT_QA_CRITIQUE_CRITERIA
325
+ COT_QA_IMAGE_CRITIQUE_CRITERIA if self.prompt_strategy.is_image_prompt() else COT_QA_CRITIQUE_CRITERIA
320
326
  )
321
327
 
322
328
  return critique_criteria
@@ -467,16 +473,18 @@ class PromptFactory:
467
473
 
468
474
  return prompt_strategy_to_example_answer.get(self.prompt_strategy)
469
475
 
470
- def _get_all_format_kwargs(self, candidate: DataRecord, input_fields: list[str], output_fields: list[str], **kwargs) -> dict:
476
+ def _get_all_format_kwargs(
477
+ self, candidate: DataRecord, input_fields: list[str], output_fields: list[str], **kwargs
478
+ ) -> dict:
471
479
  """
472
480
  Returns a dictionary containing all the format kwargs for templating the prompts.
473
-
481
+
474
482
  Args:
475
483
  candidate (DataRecord): The input record.
476
484
  input_fields (list[str]): The input fields.
477
485
  output_fields (list[str]): The output fields.
478
486
  kwargs: The keyword arguments provided by the user.
479
-
487
+
480
488
  Returns:
481
489
  dict: The dictionary containing all the format kwargs.
482
490
  """
@@ -517,7 +525,7 @@ class PromptFactory:
517
525
  Args:
518
526
  candidate (DataRecord): The input record.
519
527
  input_fields (list[str]): The list of input fields.
520
-
528
+
521
529
  Returns:
522
530
  list[dict]: The image messages for the chat payload.
523
531
  """
@@ -529,15 +537,19 @@ class PromptFactory:
529
537
 
530
538
  # image filepath (or list of image filepaths)
531
539
  if isinstance(field_type, ImageFilepathField):
532
- with open(field_value, 'rb') as f:
533
- base64_image_str = base64.b64encode(f.read()).decode('utf-8')
534
- image_messages.append({"role": "user", "type": "image", "content": f"data:image/jpeg;base64,{base64_image_str}"})
540
+ with open(field_value, "rb") as f:
541
+ base64_image_str = base64.b64encode(f.read()).decode("utf-8")
542
+ image_messages.append(
543
+ {"role": "user", "type": "image", "content": f"data:image/jpeg;base64,{base64_image_str}"}
544
+ )
535
545
 
536
546
  elif hasattr(field_type, "element_type") and issubclass(field_type.element_type, ImageFilepathField):
537
547
  for image_filepath in field_value:
538
- with open(image_filepath, 'rb') as f:
539
- base64_image_str = base64.b64encode(f.read()).decode('utf-8')
540
- image_messages.append({"role": "user", "type": "image", "content": f"data:image/jpeg;base64,{base64_image_str}"})
548
+ with open(image_filepath, "rb") as f:
549
+ base64_image_str = base64.b64encode(f.read()).decode("utf-8")
550
+ image_messages.append(
551
+ {"role": "user", "type": "image", "content": f"data:image/jpeg;base64,{base64_image_str}"}
552
+ )
541
553
 
542
554
  # image url (or list of image urls)
543
555
  elif isinstance(field_type, ImageURLField):
@@ -550,12 +562,16 @@ class PromptFactory:
550
562
  # pre-encoded images (or list of pre-encoded images)
551
563
  elif isinstance(field_type, ImageBase64Field):
552
564
  base64_image_str = field_value.decode("utf-8")
553
- image_messages.append({"role": "user", "type": "image", "content": f"data:image/jpeg;base64,{base64_image_str}"})
565
+ image_messages.append(
566
+ {"role": "user", "type": "image", "content": f"data:image/jpeg;base64,{base64_image_str}"}
567
+ )
554
568
 
555
569
  elif hasattr(field_type, "element_type") and issubclass(field_type.element_type, ImageBase64Field):
556
570
  for base64_image in field_value:
557
571
  base64_image_str = base64_image.decode("utf-8")
558
- image_messages.append({"role": "user", "type": "image", "content": f"data:image/jpeg;base64,{base64_image_str}"})
572
+ image_messages.append(
573
+ {"role": "user", "type": "image", "content": f"data:image/jpeg;base64,{base64_image_str}"}
574
+ )
559
575
 
560
576
  return image_messages
561
577
 
@@ -595,15 +611,15 @@ class PromptFactory:
595
611
 
596
612
  # get any image messages for the chat payload (will be an empty list if this is not an image prompt)
597
613
  image_messages = (
598
- self._create_image_messages(candidate, input_fields)
599
- if self.prompt_strategy.is_image_prompt()
600
- else []
614
+ self._create_image_messages(candidate, input_fields) if self.prompt_strategy.is_image_prompt() else []
601
615
  )
602
616
 
603
617
  # get any original messages for critique and refinement operations
604
618
  original_messages = kwargs.get("original_messages")
605
619
  if self.prompt_strategy.is_critic_prompt() or self.prompt_strategy.is_refine_prompt():
606
- assert original_messages is not None, "Original messages must be provided for critique and refinement operations."
620
+ assert original_messages is not None, (
621
+ "Original messages must be provided for critique and refinement operations."
622
+ )
607
623
 
608
624
  # construct the user messages based on the prompt strategy
609
625
  user_messages = []
@@ -661,17 +677,19 @@ class PromptFactory:
661
677
  f"Input fields: {input_fields}\n"
662
678
  )
663
679
  assert fields_check, err_msg
664
-
680
+
665
681
  # build set of format kwargs
666
682
  format_kwargs = {
667
- field_name: "<bytes>" if isinstance(candidate.get_field_type(field_name), BytesField) else candidate[field_name]
683
+ field_name: "<bytes>"
684
+ if isinstance(candidate.get_field_type(field_name), BytesField)
685
+ else candidate[field_name]
668
686
  for field_name in input_fields
669
687
  }
670
688
 
671
689
  # split prompt on <<image-placeholder>> if it exists
672
690
  if "<<image-placeholder>>" in user_prompt:
673
691
  raise NotImplementedError("Image prompts are not yet supported.")
674
-
692
+
675
693
  prompt_sections = user_prompt.split("<<image-placeholder>>")
676
694
  messages = [{"role": "user", "type": "text", "content": prompt_sections[0].format(**format_kwargs)}]
677
695
 
@@ -686,7 +704,7 @@ class PromptFactory:
686
704
  def create_messages(self, candidate: DataRecord, output_fields: list[str], **kwargs) -> list[dict]:
687
705
  """
688
706
  Creates the messages for the chat payload based on the prompt strategy.
689
-
707
+
690
708
  Each message will be a dictionary with the following format:
691
709
  {
692
710
  "role": "user" | "system",
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: palimpzest
3
- Version: 0.6.1
3
+ Version: 0.6.2
4
4
  Summary: Palimpzest is a system which enables anyone to process AI-powered analytical queries simply by defining them in a declarative language
5
5
  Author-email: MIT DSG Semantic Management Lab <michjc@csail.mit.edu>
6
6
  Project-URL: homepage, https://palimpzest.org
@@ -20,7 +20,7 @@ palimpzest/prompts/critique_and_refine_convert_prompts.py,sha256=WoXExBxQ7twswd9
20
20
  palimpzest/prompts/filter_prompts.py,sha256=iQjn-39h3L0E5wng_UPgAXRHrP1ok329TXpOgZ6Wn1w,2372
21
21
  palimpzest/prompts/moa_aggregator_convert_prompts.py,sha256=BQRrtGdr53PTqvXzmFh8kfQ_w9KoKw-zTtmdo-8RFjo,2887
22
22
  palimpzest/prompts/moa_proposer_convert_prompts.py,sha256=d_hOh0-0m6HWBDAxUu7W3WyQtSTlUvqio3nzpnX2bxM,3642
23
- palimpzest/prompts/prompt_factory.py,sha256=VzZNH9kblFXYn4YKVKudJ21Y5Q-3tL6ZgFmNhBNTGjQ,31921
23
+ palimpzest/prompts/prompt_factory.py,sha256=Y1R3sRLoeQt6YUbw-4Tv5oj57Hu39IA9shKkgTtoMks,32184
24
24
  palimpzest/prompts/util_phrases.py,sha256=NWrcHfjJyiOY16Jyt7R50moVnlJDyvSBZ9kBqyX2WQo,751
25
25
  palimpzest/query/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
26
26
  palimpzest/query/execution/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -80,8 +80,8 @@ palimpzest/utils/progress.py,sha256=GYmPUBdG7xmqbqj1UiSNP-pWZKmRMLX797MBgrOPugM,
80
80
  palimpzest/utils/sandbox.py,sha256=Ge96gmzqeOGlNkMCG9A95_PB8wRQbvTFua136of8FcA,6465
81
81
  palimpzest/utils/token_reduction_helpers.py,sha256=Ob95PcqCsbGLiBdQ-4YQsWGWRppb2hvQyt0gi1fzL-Y,3855
82
82
  palimpzest/utils/udfs.py,sha256=LjHic54B1az-rKgNLur0wOpaz2ko_UodjLEJrazkxvY,1854
83
- palimpzest-0.6.1.dist-info/LICENSE,sha256=5GUlHy9lr-Py9kvV38FF1m3yy3NqM18fefuE9wkWumo,1079
84
- palimpzest-0.6.1.dist-info/METADATA,sha256=VxPI4-vfq3Fm3l3PjxTpdHGbDclIQNHo1Ag1enfAyMU,7837
85
- palimpzest-0.6.1.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
86
- palimpzest-0.6.1.dist-info/top_level.txt,sha256=raV06dJUgohefUn3ZyJS2uqp_Y76EOLA9Y2e_fxt8Ew,11
87
- palimpzest-0.6.1.dist-info/RECORD,,
83
+ palimpzest-0.6.2.dist-info/LICENSE,sha256=5GUlHy9lr-Py9kvV38FF1m3yy3NqM18fefuE9wkWumo,1079
84
+ palimpzest-0.6.2.dist-info/METADATA,sha256=yBFEceRsylGwr8q8W0LqugpBmnqsey5SgKiQhPZhYgI,7837
85
+ palimpzest-0.6.2.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
86
+ palimpzest-0.6.2.dist-info/top_level.txt,sha256=raV06dJUgohefUn3ZyJS2uqp_Y76EOLA9Y2e_fxt8Ew,11
87
+ palimpzest-0.6.2.dist-info/RECORD,,