palimpzest 0.8.7__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -24,6 +24,12 @@ from palimpzest.core.lib.schemas import (
24
24
  ImageFilepath,
25
25
  ImageURL,
26
26
  )
27
+ from palimpzest.prompts.aggregate_prompts import (
28
+ AGG_BASE_SYSTEM_PROMPT,
29
+ AGG_BASE_USER_PROMPT,
30
+ AGG_NO_REASONING_BASE_SYSTEM_PROMPT,
31
+ AGG_NO_REASONING_BASE_USER_PROMPT,
32
+ )
27
33
  from palimpzest.prompts.convert_prompts import (
28
34
  MAP_BASE_SYSTEM_PROMPT,
29
35
  MAP_BASE_USER_PROMPT,
@@ -79,6 +85,12 @@ from palimpzest.prompts.split_proposer_prompts import (
79
85
  MAP_SPLIT_PROPOSER_BASE_USER_PROMPT,
80
86
  )
81
87
  from palimpzest.prompts.utils import (
88
+ AGG_AUDIO_DISCLAIMER,
89
+ AGG_EXAMPLE_ANSWER,
90
+ AGG_EXAMPLE_OUTPUT_FIELDS,
91
+ AGG_EXAMPLE_REASONING,
92
+ AGG_IMAGE_DISCLAIMER,
93
+ AGG_JOB_INSTRUCTION,
82
94
  AUDIO_DISCLAIMER,
83
95
  AUDIO_EXAMPLE_ANSWER,
84
96
  AUDIO_EXAMPLE_CONTEXT,
@@ -87,6 +99,7 @@ from palimpzest.prompts.utils import (
87
99
  AUDIO_EXAMPLE_REASONING,
88
100
  AUDIO_SENTENCE_EXAMPLE_ANSWER,
89
101
  DESC_SECTION,
102
+ EXAMPLE_AGG_INSTRUCTION,
90
103
  EXAMPLE_FILTER_CONDITION,
91
104
  EXAMPLE_JOIN_CONDITION,
92
105
  FILTER_EXAMPLE_REASONING,
@@ -112,12 +125,18 @@ from palimpzest.prompts.utils import (
112
125
  RIGHT_IMAGE_EXAMPLE_INPUT_FIELDS,
113
126
  RIGHT_TEXT_EXAMPLE_CONTEXT,
114
127
  RIGHT_TEXT_EXAMPLE_INPUT_FIELDS,
128
+ SECOND_AUDIO_EXAMPLE_CONTEXT,
129
+ SECOND_IMAGE_EXAMPLE_CONTEXT,
130
+ SECOND_TEXT_EXAMPLE_CONTEXT,
115
131
  TEXT_EXAMPLE_ANSWER,
116
132
  TEXT_EXAMPLE_CONTEXT,
117
133
  TEXT_EXAMPLE_INPUT_FIELDS,
118
134
  TEXT_EXAMPLE_OUTPUT_FIELDS,
119
135
  TEXT_EXAMPLE_REASONING,
120
136
  TEXT_SENTENCE_EXAMPLE_ANSWER,
137
+ THIRD_AUDIO_EXAMPLE_CONTEXT,
138
+ THIRD_IMAGE_EXAMPLE_CONTEXT,
139
+ THIRD_TEXT_EXAMPLE_CONTEXT,
121
140
  )
122
141
 
123
142
 
@@ -125,6 +144,10 @@ class PromptFactory:
125
144
  """Factory class for generating prompts for the Generator given the input(s)."""
126
145
 
127
146
  BASE_SYSTEM_PROMPT_MAP = {
147
+ # agg system prompts
148
+ PromptStrategy.AGG: AGG_BASE_SYSTEM_PROMPT,
149
+ PromptStrategy.AGG_NO_REASONING: AGG_NO_REASONING_BASE_SYSTEM_PROMPT,
150
+
128
151
  # filter system prompts
129
152
  PromptStrategy.FILTER: FILTER_BASE_SYSTEM_PROMPT,
130
153
  PromptStrategy.FILTER_NO_REASONING: FILTER_NO_REASONING_BASE_SYSTEM_PROMPT,
@@ -150,6 +173,10 @@ class PromptFactory:
150
173
  PromptStrategy.MAP_SPLIT_MERGER: MAP_SPLIT_MERGER_BASE_SYSTEM_PROMPT,
151
174
  }
152
175
  BASE_USER_PROMPT_MAP = {
176
+ # agg user prompts
177
+ PromptStrategy.AGG: AGG_BASE_USER_PROMPT,
178
+ PromptStrategy.AGG_NO_REASONING: AGG_NO_REASONING_BASE_USER_PROMPT,
179
+
153
180
  # filter user prompts
154
181
  PromptStrategy.FILTER: FILTER_BASE_USER_PROMPT,
155
182
  PromptStrategy.FILTER_NO_REASONING: FILTER_NO_REASONING_BASE_USER_PROMPT,
@@ -181,7 +208,7 @@ class PromptFactory:
181
208
  self.cardinality = cardinality
182
209
  self.desc = desc
183
210
 
184
- def _get_context(self, candidate: DataRecord, input_fields: list[str]) -> str:
211
+ def _get_context(self, candidate: DataRecord | list[DataRecord], input_fields: list[str]) -> str:
185
212
  """
186
213
  Returns the context for the prompt.
187
214
 
@@ -194,7 +221,10 @@ class PromptFactory:
194
221
  """
195
222
  # TODO: remove mask_filepaths=True after SemBench evaluation
196
223
  # get context from input record (project_cols will be None if not provided in kwargs)
197
- context: dict = candidate.to_dict(include_bytes=False, project_cols=input_fields, mask_filepaths=True)
224
+ if isinstance(candidate, list):
225
+ context: list[dict] = [record.to_dict(include_bytes=False, project_cols=input_fields, mask_filepaths=True) for record in candidate]
226
+ else:
227
+ context: dict = candidate.to_dict(include_bytes=False, project_cols=input_fields, mask_filepaths=True)
198
228
 
199
229
  # TODO: MOVE THIS LOGIC INTO A CHUNKING / CONTEXT MANAGEMENT CLASS
200
230
  # - this class should be able to:
@@ -203,8 +233,10 @@ class PromptFactory:
203
233
  # - handle the issue with `original_messages` (ask Matt if this is not clear)
204
234
  # TODO: this does not work for image prompts
205
235
  # TODO: this ignores the size of the `original_messages` in critique and refine prompts
236
+ # NOTE: llama models are disallowed for aggregation so we can assume context is a dict here
206
237
  # cut down on context based on window length
207
238
  if self.model.is_llama_model():
239
+ assert isinstance(context, dict), "Llama models are not allowed for aggregation operations."
208
240
  total_context_len = len(json.dumps(context, indent=2))
209
241
 
210
242
  # sort fields by length and progressively strip from the longest field until it is short enough;
@@ -323,7 +355,7 @@ class PromptFactory:
323
355
  """
324
356
  output_fields_desc = ""
325
357
  output_schema: type[BaseModel] = kwargs.get("output_schema")
326
- if self.prompt_strategy.is_map_prompt():
358
+ if self.prompt_strategy.is_map_prompt() or self.prompt_strategy.is_agg_prompt():
327
359
  assert output_schema is not None, "Output schema must be provided for convert prompts."
328
360
 
329
361
  for field_name in sorted(output_fields):
@@ -333,6 +365,19 @@ class PromptFactory:
333
365
  # strip the last newline characters from the field descriptions and return
334
366
  return output_fields_desc[:-1]
335
367
 
368
+ def _get_agg_instruction(self, **kwargs) -> str | None:
369
+ """
370
+ Returns the aggregation instruction for the aggregation operation.
371
+
372
+ Returns:
373
+ str | None: The aggregation instruction (if applicable).
374
+ """
375
+ agg_instruction = kwargs.get("agg_instruction")
376
+ if self.prompt_strategy.is_agg_prompt():
377
+ assert agg_instruction is not None, "Aggregation instruction must be provided for aggregation operations."
378
+
379
+ return agg_instruction
380
+
336
381
  def _get_filter_condition(self, **kwargs) -> str | None:
337
382
  """
338
383
  Returns the filter condition for the filter operation.
@@ -464,6 +509,8 @@ class PromptFactory:
464
509
  job_instruction = FILTER_JOB_INSTRUCTION
465
510
  elif self.prompt_strategy.is_join_prompt():
466
511
  job_instruction = JOIN_JOB_INSTRUCTION
512
+ elif self.prompt_strategy.is_agg_prompt():
513
+ job_instruction = AGG_JOB_INSTRUCTION
467
514
 
468
515
  # format the job instruction based on the input modalities
469
516
  modalities = self._get_modalities_str(input_modalities)
@@ -557,6 +604,9 @@ class PromptFactory:
557
604
  Returns:
558
605
  str: The example output fields.
559
606
  """
607
+ if self.prompt_strategy.is_agg_prompt():
608
+ return AGG_EXAMPLE_OUTPUT_FIELDS
609
+
560
610
  input_modality_to_example_output_fields = {
561
611
  Modality.TEXT: TEXT_EXAMPLE_OUTPUT_FIELDS,
562
612
  Modality.IMAGE: IMAGE_EXAMPLE_OUTPUT_FIELDS,
@@ -570,17 +620,31 @@ class PromptFactory:
570
620
 
571
621
  return example_output_fields
572
622
 
573
- def _get_example_context(self, input_modalities: set[Modality], right: bool = False) -> str:
623
+ def _get_example_context(self, input_modalities: set[Modality], right: bool = False, second: bool = False, third: bool = False) -> str:
574
624
  """
575
625
  Returns the example context for the prompt.
576
626
 
577
627
  Returns:
578
628
  str: The example context.
579
629
  """
630
+ assert not (second and third), "Cannot have both second and third example contexts."
631
+ assert not (right and (second or third)), "Right context is only used for joins; second and third contexts only use for aggregations."
632
+ text_example_context = TEXT_EXAMPLE_CONTEXT
633
+ image_example_context = IMAGE_EXAMPLE_CONTEXT
634
+ audio_example_context = AUDIO_EXAMPLE_CONTEXT
635
+ if second:
636
+ text_example_context = SECOND_TEXT_EXAMPLE_CONTEXT
637
+ image_example_context = SECOND_IMAGE_EXAMPLE_CONTEXT
638
+ audio_example_context = SECOND_AUDIO_EXAMPLE_CONTEXT
639
+ elif third:
640
+ text_example_context = THIRD_TEXT_EXAMPLE_CONTEXT
641
+ image_example_context = THIRD_IMAGE_EXAMPLE_CONTEXT
642
+ audio_example_context = THIRD_AUDIO_EXAMPLE_CONTEXT
643
+
580
644
  input_modality_to_example_context = {
581
- Modality.TEXT: RIGHT_TEXT_EXAMPLE_CONTEXT if right else TEXT_EXAMPLE_CONTEXT,
582
- Modality.IMAGE: RIGHT_IMAGE_EXAMPLE_CONTEXT if right else IMAGE_EXAMPLE_CONTEXT,
583
- Modality.AUDIO: RIGHT_AUDIO_EXAMPLE_CONTEXT if right else AUDIO_EXAMPLE_CONTEXT,
645
+ Modality.TEXT: RIGHT_TEXT_EXAMPLE_CONTEXT if right else text_example_context,
646
+ Modality.IMAGE: RIGHT_IMAGE_EXAMPLE_CONTEXT if right else image_example_context,
647
+ Modality.AUDIO: RIGHT_AUDIO_EXAMPLE_CONTEXT if right else audio_example_context,
584
648
  }
585
649
 
586
650
  example_context = ""
@@ -590,7 +654,7 @@ class PromptFactory:
590
654
 
591
655
  return example_context
592
656
 
593
- def _get_image_disclaimer(self, input_modalities: set[Modality], right: bool = False) -> str:
657
+ def _get_image_disclaimer(self, input_modalities: set[Modality], right: bool = False, agg: bool = False) -> str:
594
658
  """
595
659
  Returns the image disclaimer for the prompt. The disclaimer must be an empty string
596
660
  for non-image prompts.
@@ -598,10 +662,12 @@ class PromptFactory:
598
662
  Returns:
599
663
  str: The image disclaimer. If this is a text prompt then it is an empty string.
600
664
  """
601
- image_disclaimer = RIGHT_IMAGE_DISCLAIMER if right else IMAGE_DISCLAIMER
665
+ assert not (right and agg), "Right image disclaimer is only used for joins; agg image disclaimer only used for aggregations."
666
+ image_disclaimer = AGG_IMAGE_DISCLAIMER if agg else IMAGE_DISCLAIMER
667
+ image_disclaimer = RIGHT_IMAGE_DISCLAIMER if right else image_disclaimer
602
668
  return image_disclaimer if Modality.IMAGE in input_modalities else ""
603
669
 
604
- def _get_audio_disclaimer(self, input_modalities: set[Modality], right: bool = False) -> str:
670
+ def _get_audio_disclaimer(self, input_modalities: set[Modality], right: bool = False, agg: bool = False) -> str:
605
671
  """
606
672
  Returns the audio disclaimer for the prompt. The disclaimer must be an empty string
607
673
  for non-audio prompts.
@@ -609,7 +675,9 @@ class PromptFactory:
609
675
  Returns:
610
676
  str: The audio disclaimer. If this is a text prompt then it is an empty string.
611
677
  """
612
- audio_disclaimer = RIGHT_AUDIO_DISCLAIMER if right else AUDIO_DISCLAIMER
678
+ assert not (right and agg), "Right audio disclaimer is only used for joins; agg audio disclaimer only used for aggregations."
679
+ audio_disclaimer = AGG_AUDIO_DISCLAIMER if agg else AUDIO_DISCLAIMER
680
+ audio_disclaimer = RIGHT_AUDIO_DISCLAIMER if right else audio_disclaimer
613
681
  return audio_disclaimer if Modality.AUDIO in input_modalities else ""
614
682
 
615
683
  def _get_example_reasoning(self, input_modalities: set[Modality]) -> str:
@@ -623,6 +691,8 @@ class PromptFactory:
623
691
  return FILTER_EXAMPLE_REASONING
624
692
  elif self.prompt_strategy.is_join_prompt():
625
693
  return JOIN_EXAMPLE_REASONING
694
+ elif self.prompt_strategy.is_agg_prompt():
695
+ return AGG_EXAMPLE_REASONING
626
696
 
627
697
  input_modality_to_example_reasoning = {
628
698
  Modality.TEXT: TEXT_EXAMPLE_REASONING,
@@ -644,6 +714,9 @@ class PromptFactory:
644
714
  Returns:
645
715
  str: The example answer.
646
716
  """
717
+ if self.prompt_strategy.is_agg_prompt():
718
+ return AGG_EXAMPLE_ANSWER
719
+
647
720
  use_sentence_answers = self.prompt_strategy.is_split_proposer_prompt() or self.prompt_strategy.is_moa_proposer_prompt()
648
721
  input_modality_to_example_answer = {
649
722
  Modality.TEXT: TEXT_SENTENCE_EXAMPLE_ANSWER if use_sentence_answers else TEXT_EXAMPLE_ANSWER,
@@ -662,7 +735,7 @@ class PromptFactory:
662
735
 
663
736
  def _get_all_format_kwargs(
664
737
  self,
665
- candidate: DataRecord,
738
+ candidate: DataRecord | list[DataRecord],
666
739
  input_fields: list[str],
667
740
  input_modalities: set[Modality],
668
741
  output_fields: list[str],
@@ -686,8 +759,9 @@ class PromptFactory:
686
759
  # get format kwargs which depend on the input data
687
760
  input_format_kwargs = {
688
761
  "context": self._get_context(candidate, input_fields),
689
- "input_fields_desc": self._get_input_fields_desc(candidate, input_fields),
762
+ "input_fields_desc": self._get_input_fields_desc(candidate[0] if isinstance(candidate, list) else candidate, input_fields),
690
763
  "output_fields_desc": self._get_output_fields_desc(output_fields, **kwargs),
764
+ "agg_instruction": self._get_agg_instruction(**kwargs),
691
765
  "filter_condition": self._get_filter_condition(**kwargs),
692
766
  "join_condition": self._get_join_condition(**kwargs),
693
767
  "original_output": self._get_original_output(**kwargs),
@@ -716,11 +790,14 @@ class PromptFactory:
716
790
  "right_example_input_fields": self._get_example_input_fields(right_input_modalities, right=True),
717
791
  "example_output_fields": self._get_example_output_fields(input_modalities),
718
792
  "example_context": self._get_example_context(input_modalities),
793
+ "second_example_context": self._get_example_context(input_modalities, second=True) if self.prompt_strategy.is_agg_prompt() else "",
794
+ "third_example_context": self._get_example_context(input_modalities, third=True) if self.prompt_strategy.is_agg_prompt() else "",
719
795
  "right_example_context": self._get_example_context(right_input_modalities, right=True),
720
- "image_disclaimer": self._get_image_disclaimer(input_modalities),
721
- "audio_disclaimer": self._get_audio_disclaimer(input_modalities),
796
+ "image_disclaimer": self._get_image_disclaimer(input_modalities, agg=self.prompt_strategy.is_agg_prompt()),
797
+ "audio_disclaimer": self._get_audio_disclaimer(input_modalities, agg=self.prompt_strategy.is_agg_prompt()),
722
798
  "right_image_disclaimer": self._get_image_disclaimer(right_input_modalities, right=True),
723
799
  "right_audio_disclaimer": self._get_audio_disclaimer(right_input_modalities, right=True),
800
+ "example_agg_instruction": EXAMPLE_AGG_INSTRUCTION,
724
801
  "example_filter_condition": EXAMPLE_FILTER_CONDITION,
725
802
  "example_join_condition": EXAMPLE_JOIN_CONDITION,
726
803
  "example_reasoning": self._get_example_reasoning(input_modalities),
@@ -730,106 +807,116 @@ class PromptFactory:
730
807
  # return all format kwargs
731
808
  return {**input_format_kwargs, **prompt_strategy_format_kwargs}
732
809
 
733
- def _create_audio_messages(self, candidate: DataRecord, input_fields: list[str]) -> list[dict]:
810
+ def _create_audio_messages(self, candidate: DataRecord | list[DataRecord], input_fields: list[str]) -> list[dict]:
734
811
  """
735
- Parses the candidate record and returns the audio messages for the chat payload.
812
+ Parses the candidate record(s) and returns the audio messages for the chat payload.
736
813
 
737
814
  Args:
738
- candidate (DataRecord): The input record.
815
+ candidate (DataRecord | list[DataRecord]): The input record(s).
739
816
  input_fields (list[str]): The list of input fields.
740
817
 
741
818
  Returns:
742
819
  list[dict]: The audio messages for the chat payload.
743
820
  """
821
+ # normalize type to be list[DataRecord]
822
+ if isinstance(candidate, DataRecord):
823
+ candidate = [candidate]
824
+
744
825
  # create a message for each audio recording in an input field with an audio (or list of audio) type
745
826
  audio_content = []
746
827
  for field_name in input_fields:
747
- field_value = candidate[field_name]
748
- field_type = candidate.get_field_type(field_name)
828
+ for dr in candidate:
829
+ field_value = dr[field_name]
830
+ field_type = dr.get_field_type(field_name)
749
831
 
750
- # audio filepath (or list of audio filepaths)
751
- if field_type.annotation in [AudioFilepath, AudioFilepath | None, AudioFilepath | Any]:
752
- with open(field_value, "rb") as f:
753
- base64_audio_str = base64.b64encode(f.read()).decode("utf-8")
754
- audio_content.append(
755
- {"type": "input_audio", "input_audio": {"data": base64_audio_str, "format": "wav"}}
756
- )
757
-
758
- elif field_type.annotation in [list[AudioFilepath], list[AudioFilepath] | None, list[AudioFilepath] | Any]:
759
- for audio_filepath in field_value:
760
- with open(audio_filepath, "rb") as f:
832
+ # audio filepath (or list of audio filepaths)
833
+ if field_type.annotation in [AudioFilepath, AudioFilepath | None, AudioFilepath | Any]:
834
+ with open(field_value, "rb") as f:
761
835
  base64_audio_str = base64.b64encode(f.read()).decode("utf-8")
762
836
  audio_content.append(
763
837
  {"type": "input_audio", "input_audio": {"data": base64_audio_str, "format": "wav"}}
764
838
  )
765
839
 
766
- # pre-encoded images (or list of pre-encoded images)
767
- elif field_type.annotation in [AudioBase64, AudioBase64 | None, AudioBase64 | Any]:
768
- audio_content.append(
769
- {"type": "input_audio", "input_audio": {"data": field_value, "format": "wav"}}
770
- )
840
+ elif field_type.annotation in [list[AudioFilepath], list[AudioFilepath] | None, list[AudioFilepath] | Any]:
841
+ for audio_filepath in field_value:
842
+ with open(audio_filepath, "rb") as f:
843
+ base64_audio_str = base64.b64encode(f.read()).decode("utf-8")
844
+ audio_content.append(
845
+ {"type": "input_audio", "input_audio": {"data": base64_audio_str, "format": "wav"}}
846
+ )
771
847
 
772
- elif field_type.annotation in [list[AudioBase64], list[AudioBase64] | None, list[AudioBase64] | Any]:
773
- for base64_audio in field_value:
848
+ # pre-encoded audio (or list of pre-encoded audio)
849
+ elif field_type.annotation in [AudioBase64, AudioBase64 | None, AudioBase64 | Any]:
774
850
  audio_content.append(
775
- {"type": "input_audio", "input_audio": {"data": base64_audio, "format": "wav"}}
851
+ {"type": "input_audio", "input_audio": {"data": field_value, "format": "wav"}}
776
852
  )
777
853
 
854
+ elif field_type.annotation in [list[AudioBase64], list[AudioBase64] | None, list[AudioBase64] | Any]:
855
+ for base64_audio in field_value:
856
+ audio_content.append(
857
+ {"type": "input_audio", "input_audio": {"data": base64_audio, "format": "wav"}}
858
+ )
859
+
778
860
  return [{"role": "user", "type": "input_audio", "content": audio_content}] if len(audio_content) > 0 else []
779
861
 
780
- def _create_image_messages(self, candidate: DataRecord, input_fields: list[str]) -> list[dict]:
862
+ def _create_image_messages(self, candidate: DataRecord | list[DataRecord], input_fields: list[str]) -> list[dict]:
781
863
  """
782
- Parses the candidate record and returns the image messages for the chat payload.
864
+ Parses the candidate record(s) and returns the image messages for the chat payload.
783
865
 
784
866
  Args:
785
- candidate (DataRecord): The input record.
867
+ candidate (DataRecord | list[DataRecord]): The input record(s).
786
868
  input_fields (list[str]): The list of input fields.
787
869
 
788
870
  Returns:
789
871
  list[dict]: The image messages for the chat payload.
790
872
  """
873
+ # normalize type to be list[DataRecord]
874
+ if isinstance(candidate, DataRecord):
875
+ candidate = [candidate]
876
+
791
877
  # create a message for each image in an input field with an image (or list of image) type
792
878
  image_content = []
793
879
  for field_name in input_fields:
794
- field_value = candidate[field_name]
795
- field_type = candidate.get_field_type(field_name)
880
+ for dr in candidate:
881
+ field_value = dr[field_name]
882
+ field_type = dr.get_field_type(field_name)
796
883
 
797
- # image filepath (or list of image filepaths)
798
- if field_type.annotation in [ImageFilepath, ImageFilepath | None, ImageFilepath | Any]:
799
- with open(field_value, "rb") as f:
800
- base64_image_str = base64.b64encode(f.read()).decode("utf-8")
801
- image_content.append(
802
- {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image_str}"}}
803
- )
804
-
805
- elif field_type.annotation in [list[ImageFilepath], list[ImageFilepath] | None, list[ImageFilepath] | Any]:
806
- for image_filepath in field_value:
807
- with open(image_filepath, "rb") as f:
884
+ # image filepath (or list of image filepaths)
885
+ if field_type.annotation in [ImageFilepath, ImageFilepath | None, ImageFilepath | Any]:
886
+ with open(field_value, "rb") as f:
808
887
  base64_image_str = base64.b64encode(f.read()).decode("utf-8")
809
888
  image_content.append(
810
889
  {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image_str}"}}
811
890
  )
812
891
 
813
- # image url (or list of image urls)
814
- elif field_type.annotation in [ImageURL, ImageURL | None, ImageURL | Any]:
815
- image_content.append({"type": "image_url", "image_url": {"url": field_value}})
892
+ elif field_type.annotation in [list[ImageFilepath], list[ImageFilepath] | None, list[ImageFilepath] | Any]:
893
+ for image_filepath in field_value:
894
+ with open(image_filepath, "rb") as f:
895
+ base64_image_str = base64.b64encode(f.read()).decode("utf-8")
896
+ image_content.append(
897
+ {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image_str}"}}
898
+ )
816
899
 
817
- elif field_type.annotation in [list[ImageURL], list[ImageURL] | None, list[ImageURL] | Any]:
818
- for image_url in field_value:
819
- image_content.append({"type": "image_url", "image_url": {"url": image_url}})
900
+ # image url (or list of image urls)
901
+ elif field_type.annotation in [ImageURL, ImageURL | None, ImageURL | Any]:
902
+ image_content.append({"type": "image_url", "image_url": {"url": field_value}})
820
903
 
821
- # pre-encoded images (or list of pre-encoded images)
822
- elif field_type.annotation in [ImageBase64, ImageBase64 | None, ImageBase64 | Any]:
823
- image_content.append(
824
- {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{field_value}"}}
825
- )
904
+ elif field_type.annotation in [list[ImageURL], list[ImageURL] | None, list[ImageURL] | Any]:
905
+ for image_url in field_value:
906
+ image_content.append({"type": "image_url", "image_url": {"url": image_url}})
826
907
 
827
- elif field_type.annotation in [list[ImageBase64], list[ImageBase64] | None, list[ImageBase64] | Any]:
828
- for base64_image in field_value:
908
+ # pre-encoded images (or list of pre-encoded images)
909
+ elif field_type.annotation in [ImageBase64, ImageBase64 | None, ImageBase64 | Any]:
829
910
  image_content.append(
830
- {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}}
911
+ {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{field_value}"}}
831
912
  )
832
913
 
914
+ elif field_type.annotation in [list[ImageBase64], list[ImageBase64] | None, list[ImageBase64] | Any]:
915
+ for base64_image in field_value:
916
+ image_content.append(
917
+ {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}}
918
+ )
919
+
833
920
  return [{"role": "user", "type": "image", "content": image_content}] if len(image_content) > 0 else []
834
921
 
835
922
  def _get_system_prompt(self, **format_kwargs) -> str | None:
@@ -849,12 +936,12 @@ class PromptFactory:
849
936
 
850
937
  return base_prompt.format(**format_kwargs)
851
938
 
852
- def _get_user_messages(self, candidate: DataRecord, input_fields: list[str], right_candidate: DataRecord | None, right_input_fields: list[str], **kwargs) -> str:
939
+ def _get_user_messages(self, candidate: DataRecord | list[DataRecord], input_fields: list[str], right_candidate: DataRecord | None, right_input_fields: list[str], **kwargs) -> str:
853
940
  """
854
941
  Returns a list of messages for the chat payload based on the prompt strategy.
855
942
 
856
943
  Args:
857
- candidate (DataRecord): The input record.
944
+ candidate (DataRecord | list[DataRecord]): The input record(s).
858
945
  input_fields (list[str]): The input fields.
859
946
  output_fields (list[str]): The output fields.
860
947
  kwargs: The formatting kwargs and some keyword arguments provided by the user.
@@ -943,7 +1030,7 @@ class PromptFactory:
943
1030
 
944
1031
  return user_messages
945
1032
 
946
- def create_messages(self, candidate: DataRecord, output_fields: list[str], right_candidate: DataRecord | None = None, **kwargs) -> list[dict]:
1033
+ def create_messages(self, candidate: DataRecord | list[DataRecord], output_fields: list[str], right_candidate: DataRecord | None = None, **kwargs) -> list[dict]:
947
1034
  """
948
1035
  Creates the messages for the chat payload based on the prompt strategy.
949
1036
 
@@ -955,7 +1042,7 @@ class PromptFactory:
955
1042
  }
956
1043
 
957
1044
  Args:
958
- candidate (DataRecord): The input record.
1045
+ candidate (DataRecord | list[DataRecord]): The input record(s).
959
1046
  output_fields (list[str]): The output fields.
960
1047
  right_candidate (DataRecord | None): The other join input record (only provided for joins).
961
1048
  kwargs: The keyword arguments provided by the user.
@@ -964,11 +1051,11 @@ class PromptFactory:
964
1051
  list[dict]: The messages for the chat payload.
965
1052
  """
966
1053
  # compute the set of input fields
967
- input_fields = self._get_input_fields(candidate, **kwargs)
1054
+ input_fields = self._get_input_fields(candidate[0] if isinstance(candidate, list) else candidate, **kwargs)
968
1055
  right_input_fields = [] if right_candidate is None else self._get_input_fields(right_candidate, **kwargs)
969
1056
 
970
1057
  # use input fields to determine the left / right input modalities
971
- input_modalities = self._get_input_modalities(candidate, input_fields)
1058
+ input_modalities = self._get_input_modalities(candidate[0] if isinstance(candidate, list) else candidate, input_fields)
972
1059
  right_input_modalities = set() if right_candidate is None else self._get_input_modalities(right_candidate, right_input_fields)
973
1060
 
974
1061
  # initialize messages
@@ -11,12 +11,14 @@ The user has additionally provided you with this description of the task you nee
11
11
  """
12
12
 
13
13
  ### JOB INSTRUCTIONS ###
14
+ AGG_JOB_INSTRUCTION = """analyze input {modalities} in order to perform an aggregation and generate a JSON object"""
14
15
  MAP_JOB_INSTRUCTION = """analyze input {modalities} in order to produce a JSON object"""
15
16
  FILTER_JOB_INSTRUCTION = """analyze input {modalities} in order to answer a TRUE / FALSE question"""
16
17
  JOIN_JOB_INSTRUCTION = """analyze input {modalities} in order to determine whether two data records satisfy a join condition"""
17
18
  PROPOSER_JOB_INSTRUCTION = """analyze input {modalities} in order to produce an answer to a question"""
18
19
 
19
- ### FILTER / JOIN CONDITIONS ###
20
+ ### AGG / FILTER / JOIN CONDITIONS ###
21
+ EXAMPLE_AGG_INSTRUCTION = "Count the distinct number of scientists in the input."
20
22
  EXAMPLE_FILTER_CONDITION = "The subject of the input is a foundational computer scientist."
21
23
  EXAMPLE_JOIN_CONDITION = "The two inputs are scientists in the same academic field."
22
24
 
@@ -48,6 +50,7 @@ TEXT_EXAMPLE_OUTPUT_FIELDS = """- name: the name of the scientist
48
50
  - birth_year: the year the scientist was born"""
49
51
  IMAGE_EXAMPLE_OUTPUT_FIELDS = """- is_bald: true if the scientist is bald and false otherwise"""
50
52
  AUDIO_EXAMPLE_OUTPUT_FIELDS = """- birthplace: the city where the scientist was born"""
53
+ AGG_EXAMPLE_OUTPUT_FIELDS = """- num_distinct_scientists: the number of distinct scientists mentioned in the input"""
51
54
 
52
55
  ### EXAMPLE CONTEXTS ###
53
56
  TEXT_EXAMPLE_CONTEXT = """
@@ -71,6 +74,30 @@ RIGHT_IMAGE_EXAMPLE_CONTEXT = """
71
74
  RIGHT_AUDIO_EXAMPLE_CONTEXT = """
72
75
  "podcast": <bytes>
73
76
  """
77
+ SECOND_TEXT_EXAMPLE_CONTEXT = """
78
+ "text": "Alan Turing was a pioneering computer scientist and mathematician. He is widely considered to be the father of theoretical computer science and artificial intelligence.",
79
+ "birthday": "June 23, 1912"
80
+ """
81
+ SECOND_IMAGE_EXAMPLE_CONTEXT = """
82
+ "image": <bytes>,
83
+ "photographer": "PhotoPro42"
84
+ """
85
+ SECOND_AUDIO_EXAMPLE_CONTEXT = """
86
+ "recording": <bytes>,
87
+ "speaker": "Barbara Walters"
88
+ """
89
+ THIRD_TEXT_EXAMPLE_CONTEXT = """
90
+ "text": "Ada Lovelace is a historically significant computer scientist.",
91
+ "birthday": "December 10, 1815"
92
+ """
93
+ THIRD_IMAGE_EXAMPLE_CONTEXT = """
94
+ "image": <bytes>,
95
+ "photographer": "PicturePerfect"
96
+ """
97
+ THIRD_AUDIO_EXAMPLE_CONTEXT = """
98
+ "recording": <bytes>,
99
+ "speaker": "Anderson Cooper"
100
+ """
74
101
 
75
102
  ### DISCLAIMERS ###
76
103
  IMAGE_DISCLAIMER = """
@@ -85,15 +112,25 @@ RIGHT_IMAGE_DISCLAIMER = """
85
112
  RIGHT_AUDIO_DISCLAIMER = """
86
113
  \n<audio content provided here; assume in this example the podcast is discussing Alan Turing's work on the Enigma code>
87
114
  """
115
+ AGG_IMAGE_DISCLAIMER = """
116
+ \n<image content provided here; assume in this example the first image shows Ada Lovelace, the second image shows Alan Turing, and the third image shows Ada Lovelace again>
117
+ """
118
+ AGG_AUDIO_DISCLAIMER = """
119
+ \n<audio content provided here; assume in this example the first recording is about Ada Lovelace, the second recording is about Alan Turing, and the third recording is about Ada Lovelace again>
120
+ """
88
121
 
89
122
  ### EXAMPLE REASONINGS ###
90
123
  TEXT_EXAMPLE_REASONING = """The text passage mentions the scientist's name as "Augusta Ada King, Countess of Lovelace, also known as Ada Lovelace" and the scientist's birthday as "December 10, 1815". Therefore, the name of the scientist is "Augusta Ada King" and the birth year is 1815."""
91
124
  IMAGE_EXAMPLE_REASONING = """The image shows hair on top of the scientist's head, so the is_bald field should be false."""
92
125
  AUDIO_EXAMPLE_REASONING = """The newscast recording discusses Ada Lovelace's upbringing in London, so the birthplace field should be "London"."""
126
+ AGG_EXAMPLE_REASONING = """The input contains two distinct scientists: "Augusta Ada King" and "Alan Turing". Although "Ada Lovelace" is mentioned twice, she should only be counted once. Therefore, the number of distinct scientists mentioned in the input is 2."""
93
127
  FILTER_EXAMPLE_REASONING = """Ada Lovelace is a foundational computer scientist, therefore the answer is TRUE."""
94
128
  JOIN_EXAMPLE_REASONING = """The subject of the left record is Ada Lovelace and the subject of the right record is Alan Turing. Since both inputs are about computer scientists, they satisfy the join condition. Therefore, the answer is TRUE."""
95
129
 
96
130
  ### EXAMPLE ANSWERS ###
131
+ AGG_EXAMPLE_ANSWER = """
132
+ "num_distinct_scientists": 2
133
+ """
97
134
  TEXT_EXAMPLE_ANSWER = """
98
135
  "name": "Augusta Ada King",
99
136
  "birth_year": 1815