palimpzest 0.8.6__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,6 +2,7 @@
2
2
 
3
3
  import base64
4
4
  import json
5
+ from typing import Any
5
6
 
6
7
  from pydantic import BaseModel
7
8
 
@@ -23,6 +24,12 @@ from palimpzest.core.lib.schemas import (
23
24
  ImageFilepath,
24
25
  ImageURL,
25
26
  )
27
+ from palimpzest.prompts.aggregate_prompts import (
28
+ AGG_BASE_SYSTEM_PROMPT,
29
+ AGG_BASE_USER_PROMPT,
30
+ AGG_NO_REASONING_BASE_SYSTEM_PROMPT,
31
+ AGG_NO_REASONING_BASE_USER_PROMPT,
32
+ )
26
33
  from palimpzest.prompts.convert_prompts import (
27
34
  MAP_BASE_SYSTEM_PROMPT,
28
35
  MAP_BASE_USER_PROMPT,
@@ -78,6 +85,12 @@ from palimpzest.prompts.split_proposer_prompts import (
78
85
  MAP_SPLIT_PROPOSER_BASE_USER_PROMPT,
79
86
  )
80
87
  from palimpzest.prompts.utils import (
88
+ AGG_AUDIO_DISCLAIMER,
89
+ AGG_EXAMPLE_ANSWER,
90
+ AGG_EXAMPLE_OUTPUT_FIELDS,
91
+ AGG_EXAMPLE_REASONING,
92
+ AGG_IMAGE_DISCLAIMER,
93
+ AGG_JOB_INSTRUCTION,
81
94
  AUDIO_DISCLAIMER,
82
95
  AUDIO_EXAMPLE_ANSWER,
83
96
  AUDIO_EXAMPLE_CONTEXT,
@@ -86,6 +99,7 @@ from palimpzest.prompts.utils import (
86
99
  AUDIO_EXAMPLE_REASONING,
87
100
  AUDIO_SENTENCE_EXAMPLE_ANSWER,
88
101
  DESC_SECTION,
102
+ EXAMPLE_AGG_INSTRUCTION,
89
103
  EXAMPLE_FILTER_CONDITION,
90
104
  EXAMPLE_JOIN_CONDITION,
91
105
  FILTER_EXAMPLE_REASONING,
@@ -111,12 +125,18 @@ from palimpzest.prompts.utils import (
111
125
  RIGHT_IMAGE_EXAMPLE_INPUT_FIELDS,
112
126
  RIGHT_TEXT_EXAMPLE_CONTEXT,
113
127
  RIGHT_TEXT_EXAMPLE_INPUT_FIELDS,
128
+ SECOND_AUDIO_EXAMPLE_CONTEXT,
129
+ SECOND_IMAGE_EXAMPLE_CONTEXT,
130
+ SECOND_TEXT_EXAMPLE_CONTEXT,
114
131
  TEXT_EXAMPLE_ANSWER,
115
132
  TEXT_EXAMPLE_CONTEXT,
116
133
  TEXT_EXAMPLE_INPUT_FIELDS,
117
134
  TEXT_EXAMPLE_OUTPUT_FIELDS,
118
135
  TEXT_EXAMPLE_REASONING,
119
136
  TEXT_SENTENCE_EXAMPLE_ANSWER,
137
+ THIRD_AUDIO_EXAMPLE_CONTEXT,
138
+ THIRD_IMAGE_EXAMPLE_CONTEXT,
139
+ THIRD_TEXT_EXAMPLE_CONTEXT,
120
140
  )
121
141
 
122
142
 
@@ -124,6 +144,10 @@ class PromptFactory:
124
144
  """Factory class for generating prompts for the Generator given the input(s)."""
125
145
 
126
146
  BASE_SYSTEM_PROMPT_MAP = {
147
+ # agg system prompts
148
+ PromptStrategy.AGG: AGG_BASE_SYSTEM_PROMPT,
149
+ PromptStrategy.AGG_NO_REASONING: AGG_NO_REASONING_BASE_SYSTEM_PROMPT,
150
+
127
151
  # filter system prompts
128
152
  PromptStrategy.FILTER: FILTER_BASE_SYSTEM_PROMPT,
129
153
  PromptStrategy.FILTER_NO_REASONING: FILTER_NO_REASONING_BASE_SYSTEM_PROMPT,
@@ -149,6 +173,10 @@ class PromptFactory:
149
173
  PromptStrategy.MAP_SPLIT_MERGER: MAP_SPLIT_MERGER_BASE_SYSTEM_PROMPT,
150
174
  }
151
175
  BASE_USER_PROMPT_MAP = {
176
+ # agg user prompts
177
+ PromptStrategy.AGG: AGG_BASE_USER_PROMPT,
178
+ PromptStrategy.AGG_NO_REASONING: AGG_NO_REASONING_BASE_USER_PROMPT,
179
+
152
180
  # filter user prompts
153
181
  PromptStrategy.FILTER: FILTER_BASE_USER_PROMPT,
154
182
  PromptStrategy.FILTER_NO_REASONING: FILTER_NO_REASONING_BASE_USER_PROMPT,
@@ -180,7 +208,7 @@ class PromptFactory:
180
208
  self.cardinality = cardinality
181
209
  self.desc = desc
182
210
 
183
- def _get_context(self, candidate: DataRecord, input_fields: list[str]) -> str:
211
+ def _get_context(self, candidate: DataRecord | list[DataRecord], input_fields: list[str]) -> str:
184
212
  """
185
213
  Returns the context for the prompt.
186
214
 
@@ -193,7 +221,10 @@ class PromptFactory:
193
221
  """
194
222
  # TODO: remove mask_filepaths=True after SemBench evaluation
195
223
  # get context from input record (project_cols will be None if not provided in kwargs)
196
- context: dict = candidate.to_dict(include_bytes=False, project_cols=input_fields, mask_filepaths=True)
224
+ if isinstance(candidate, list):
225
+ context: list[dict] = [record.to_dict(include_bytes=False, project_cols=input_fields, mask_filepaths=True) for record in candidate]
226
+ else:
227
+ context: dict = candidate.to_dict(include_bytes=False, project_cols=input_fields, mask_filepaths=True)
197
228
 
198
229
  # TODO: MOVE THIS LOGIC INTO A CHUNKING / CONTEXT MANAGEMENT CLASS
199
230
  # - this class should be able to:
@@ -202,8 +233,10 @@ class PromptFactory:
202
233
  # - handle the issue with `original_messages` (ask Matt if this is not clear)
203
234
  # TODO: this does not work for image prompts
204
235
  # TODO: this ignores the size of the `original_messages` in critique and refine prompts
236
+ # NOTE: llama models are disallowed for aggregation so we can assume context is a dict here
205
237
  # cut down on context based on window length
206
238
  if self.model.is_llama_model():
239
+ assert isinstance(context, dict), "Llama models are not allowed for aggregation operations."
207
240
  total_context_len = len(json.dumps(context, indent=2))
208
241
 
209
242
  # sort fields by length and progressively strip from the longest field until it is short enough;
@@ -322,7 +355,7 @@ class PromptFactory:
322
355
  """
323
356
  output_fields_desc = ""
324
357
  output_schema: type[BaseModel] = kwargs.get("output_schema")
325
- if self.prompt_strategy.is_map_prompt():
358
+ if self.prompt_strategy.is_map_prompt() or self.prompt_strategy.is_agg_prompt():
326
359
  assert output_schema is not None, "Output schema must be provided for convert prompts."
327
360
 
328
361
  for field_name in sorted(output_fields):
@@ -332,6 +365,19 @@ class PromptFactory:
332
365
  # strip the last newline characters from the field descriptions and return
333
366
  return output_fields_desc[:-1]
334
367
 
368
+ def _get_agg_instruction(self, **kwargs) -> str | None:
369
+ """
370
+ Returns the aggregation instruction for the aggregation operation.
371
+
372
+ Returns:
373
+ str | None: The aggregation instruction (if applicable).
374
+ """
375
+ agg_instruction = kwargs.get("agg_instruction")
376
+ if self.prompt_strategy.is_agg_prompt():
377
+ assert agg_instruction is not None, "Aggregation instruction must be provided for aggregation operations."
378
+
379
+ return agg_instruction
380
+
335
381
  def _get_filter_condition(self, **kwargs) -> str | None:
336
382
  """
337
383
  Returns the filter condition for the filter operation.
@@ -463,6 +509,8 @@ class PromptFactory:
463
509
  job_instruction = FILTER_JOB_INSTRUCTION
464
510
  elif self.prompt_strategy.is_join_prompt():
465
511
  job_instruction = JOIN_JOB_INSTRUCTION
512
+ elif self.prompt_strategy.is_agg_prompt():
513
+ job_instruction = AGG_JOB_INSTRUCTION
466
514
 
467
515
  # format the job instruction based on the input modalities
468
516
  modalities = self._get_modalities_str(input_modalities)
@@ -556,6 +604,9 @@ class PromptFactory:
556
604
  Returns:
557
605
  str: The example output fields.
558
606
  """
607
+ if self.prompt_strategy.is_agg_prompt():
608
+ return AGG_EXAMPLE_OUTPUT_FIELDS
609
+
559
610
  input_modality_to_example_output_fields = {
560
611
  Modality.TEXT: TEXT_EXAMPLE_OUTPUT_FIELDS,
561
612
  Modality.IMAGE: IMAGE_EXAMPLE_OUTPUT_FIELDS,
@@ -569,17 +620,31 @@ class PromptFactory:
569
620
 
570
621
  return example_output_fields
571
622
 
572
- def _get_example_context(self, input_modalities: set[Modality], right: bool = False) -> str:
623
+ def _get_example_context(self, input_modalities: set[Modality], right: bool = False, second: bool = False, third: bool = False) -> str:
573
624
  """
574
625
  Returns the example context for the prompt.
575
626
 
576
627
  Returns:
577
628
  str: The example context.
578
629
  """
630
+ assert not (second and third), "Cannot have both second and third example contexts."
631
+ assert not (right and (second or third)), "Right context is only used for joins; second and third contexts only use for aggregations."
632
+ text_example_context = TEXT_EXAMPLE_CONTEXT
633
+ image_example_context = IMAGE_EXAMPLE_CONTEXT
634
+ audio_example_context = AUDIO_EXAMPLE_CONTEXT
635
+ if second:
636
+ text_example_context = SECOND_TEXT_EXAMPLE_CONTEXT
637
+ image_example_context = SECOND_IMAGE_EXAMPLE_CONTEXT
638
+ audio_example_context = SECOND_AUDIO_EXAMPLE_CONTEXT
639
+ elif third:
640
+ text_example_context = THIRD_TEXT_EXAMPLE_CONTEXT
641
+ image_example_context = THIRD_IMAGE_EXAMPLE_CONTEXT
642
+ audio_example_context = THIRD_AUDIO_EXAMPLE_CONTEXT
643
+
579
644
  input_modality_to_example_context = {
580
- Modality.TEXT: RIGHT_TEXT_EXAMPLE_CONTEXT if right else TEXT_EXAMPLE_CONTEXT,
581
- Modality.IMAGE: RIGHT_IMAGE_EXAMPLE_CONTEXT if right else IMAGE_EXAMPLE_CONTEXT,
582
- Modality.AUDIO: RIGHT_AUDIO_EXAMPLE_CONTEXT if right else AUDIO_EXAMPLE_CONTEXT,
645
+ Modality.TEXT: RIGHT_TEXT_EXAMPLE_CONTEXT if right else text_example_context,
646
+ Modality.IMAGE: RIGHT_IMAGE_EXAMPLE_CONTEXT if right else image_example_context,
647
+ Modality.AUDIO: RIGHT_AUDIO_EXAMPLE_CONTEXT if right else audio_example_context,
583
648
  }
584
649
 
585
650
  example_context = ""
@@ -589,7 +654,7 @@ class PromptFactory:
589
654
 
590
655
  return example_context
591
656
 
592
def _get_image_disclaimer(self, input_modalities: set[Modality], right: bool = False, agg: bool = False) -> str:
    """
    Selects the image disclaimer to place in the prompt.

    Args:
        input_modalities (set[Modality]): The modalities present in the input.
        right (bool): Select the disclaimer for the right input of a join.
        agg (bool): Select the disclaimer for an aggregation prompt.

    Returns:
        str: The selected image disclaimer, or an empty string if the input has no image modality.

    Raises:
        AssertionError: If both `right` and `agg` are set.
    """
    assert not (right and agg), "Right image disclaimer is only used for joins; agg image disclaimer only used for aggregations."

    # prompts without image inputs must not carry a disclaimer
    if Modality.IMAGE not in input_modalities:
        return ""

    # the right (join) variant takes precedence; otherwise pick between the agg and default variants
    if right:
        return RIGHT_IMAGE_DISCLAIMER

    return AGG_IMAGE_DISCLAIMER if agg else IMAGE_DISCLAIMER
602
669
 
603
def _get_audio_disclaimer(self, input_modalities: set[Modality], right: bool = False, agg: bool = False) -> str:
    """
    Selects the audio disclaimer to place in the prompt.

    Args:
        input_modalities (set[Modality]): The modalities present in the input.
        right (bool): Select the disclaimer for the right input of a join.
        agg (bool): Select the disclaimer for an aggregation prompt.

    Returns:
        str: The selected audio disclaimer, or an empty string if the input has no audio modality.

    Raises:
        AssertionError: If both `right` and `agg` are set.
    """
    assert not (right and agg), "Right audio disclaimer is only used for joins; agg audio disclaimer only used for aggregations."

    # prompts without audio inputs must not carry a disclaimer
    if Modality.AUDIO not in input_modalities:
        return ""

    # the right (join) variant takes precedence; otherwise pick between the agg and default variants
    if right:
        return RIGHT_AUDIO_DISCLAIMER

    return AGG_AUDIO_DISCLAIMER if agg else AUDIO_DISCLAIMER
613
682
 
614
683
  def _get_example_reasoning(self, input_modalities: set[Modality]) -> str:
@@ -622,6 +691,8 @@ class PromptFactory:
622
691
  return FILTER_EXAMPLE_REASONING
623
692
  elif self.prompt_strategy.is_join_prompt():
624
693
  return JOIN_EXAMPLE_REASONING
694
+ elif self.prompt_strategy.is_agg_prompt():
695
+ return AGG_EXAMPLE_REASONING
625
696
 
626
697
  input_modality_to_example_reasoning = {
627
698
  Modality.TEXT: TEXT_EXAMPLE_REASONING,
@@ -643,6 +714,9 @@ class PromptFactory:
643
714
  Returns:
644
715
  str: The example answer.
645
716
  """
717
+ if self.prompt_strategy.is_agg_prompt():
718
+ return AGG_EXAMPLE_ANSWER
719
+
646
720
  use_sentence_answers = self.prompt_strategy.is_split_proposer_prompt() or self.prompt_strategy.is_moa_proposer_prompt()
647
721
  input_modality_to_example_answer = {
648
722
  Modality.TEXT: TEXT_SENTENCE_EXAMPLE_ANSWER if use_sentence_answers else TEXT_EXAMPLE_ANSWER,
@@ -661,7 +735,7 @@ class PromptFactory:
661
735
 
662
736
  def _get_all_format_kwargs(
663
737
  self,
664
- candidate: DataRecord,
738
+ candidate: DataRecord | list[DataRecord],
665
739
  input_fields: list[str],
666
740
  input_modalities: set[Modality],
667
741
  output_fields: list[str],
@@ -685,8 +759,9 @@ class PromptFactory:
685
759
  # get format kwargs which depend on the input data
686
760
  input_format_kwargs = {
687
761
  "context": self._get_context(candidate, input_fields),
688
- "input_fields_desc": self._get_input_fields_desc(candidate, input_fields),
762
+ "input_fields_desc": self._get_input_fields_desc(candidate[0] if isinstance(candidate, list) else candidate, input_fields),
689
763
  "output_fields_desc": self._get_output_fields_desc(output_fields, **kwargs),
764
+ "agg_instruction": self._get_agg_instruction(**kwargs),
690
765
  "filter_condition": self._get_filter_condition(**kwargs),
691
766
  "join_condition": self._get_join_condition(**kwargs),
692
767
  "original_output": self._get_original_output(**kwargs),
@@ -715,11 +790,14 @@ class PromptFactory:
715
790
  "right_example_input_fields": self._get_example_input_fields(right_input_modalities, right=True),
716
791
  "example_output_fields": self._get_example_output_fields(input_modalities),
717
792
  "example_context": self._get_example_context(input_modalities),
793
+ "second_example_context": self._get_example_context(input_modalities, second=True) if self.prompt_strategy.is_agg_prompt() else "",
794
+ "third_example_context": self._get_example_context(input_modalities, third=True) if self.prompt_strategy.is_agg_prompt() else "",
718
795
  "right_example_context": self._get_example_context(right_input_modalities, right=True),
719
- "image_disclaimer": self._get_image_disclaimer(input_modalities),
720
- "audio_disclaimer": self._get_audio_disclaimer(input_modalities),
796
+ "image_disclaimer": self._get_image_disclaimer(input_modalities, agg=self.prompt_strategy.is_agg_prompt()),
797
+ "audio_disclaimer": self._get_audio_disclaimer(input_modalities, agg=self.prompt_strategy.is_agg_prompt()),
721
798
  "right_image_disclaimer": self._get_image_disclaimer(right_input_modalities, right=True),
722
799
  "right_audio_disclaimer": self._get_audio_disclaimer(right_input_modalities, right=True),
800
+ "example_agg_instruction": EXAMPLE_AGG_INSTRUCTION,
723
801
  "example_filter_condition": EXAMPLE_FILTER_CONDITION,
724
802
  "example_join_condition": EXAMPLE_JOIN_CONDITION,
725
803
  "example_reasoning": self._get_example_reasoning(input_modalities),
@@ -729,106 +807,116 @@ class PromptFactory:
729
807
  # return all format kwargs
730
808
  return {**input_format_kwargs, **prompt_strategy_format_kwargs}
731
809
 
732
def _create_audio_messages(self, candidate: DataRecord | list[DataRecord], input_fields: list[str]) -> list[dict]:
    """
    Parses the candidate record(s) and returns the audio messages for the chat payload.

    For every input field, the value of that field is inspected in each record in turn.
    Audio filepaths are read from disk and base64-encoded; pre-encoded base64 audio is
    passed through unchanged. Fields whose annotation is not one of the recognized audio
    types are ignored. A field whose value is None (e.g. an unset optional field) is
    skipped rather than being opened as a file or forwarded as empty audio data.

    Args:
        candidate (DataRecord | list[DataRecord]): The input record(s).
        input_fields (list[str]): The list of input fields.

    Returns:
        list[dict]: A list with a single user message holding every audio part, or an
            empty list if no audio content was found.

    Raises:
        OSError: If an audio filepath cannot be opened or read.
    """

    def read_as_base64(filepath) -> str:
        # read the raw bytes from disk and encode them so they can be embedded in the payload
        with open(filepath, "rb") as f:
            return base64.b64encode(f.read()).decode("utf-8")

    def audio_part(base64_audio_str) -> dict:
        # NOTE: every recording is reported with the "wav" format
        return {"type": "input_audio", "input_audio": {"data": base64_audio_str, "format": "wav"}}

    # normalize type to be list[DataRecord]
    records = [candidate] if isinstance(candidate, DataRecord) else candidate

    # annotations recognized for each kind of audio field; these do not depend on the
    # record or field, so they are built once instead of on every loop iteration
    filepath_types = [AudioFilepath, AudioFilepath | None, AudioFilepath | Any]
    filepath_list_types = [list[AudioFilepath], list[AudioFilepath] | None, list[AudioFilepath] | Any]
    base64_types = [AudioBase64, AudioBase64 | None, AudioBase64 | Any]
    base64_list_types = [list[AudioBase64], list[AudioBase64] | None, list[AudioBase64] | Any]

    # create a message for each audio recording in an input field with an audio (or list of audio) type
    audio_content = []
    for field_name in input_fields:
        for record in records:
            field_value = record[field_name]

            # optional fields may be unset; there is nothing to encode in that case
            if field_value is None:
                continue

            annotation = record.get_field_type(field_name).annotation

            # audio filepath (or list of audio filepaths)
            if annotation in filepath_types:
                audio_content.append(audio_part(read_as_base64(field_value)))

            elif annotation in filepath_list_types:
                audio_content.extend(audio_part(read_as_base64(audio_filepath)) for audio_filepath in field_value)

            # pre-encoded audio (or list of pre-encoded audio)
            elif annotation in base64_types:
                audio_content.append(audio_part(field_value))

            elif annotation in base64_list_types:
                audio_content.extend(audio_part(base64_audio) for base64_audio in field_value)

    return [{"role": "user", "type": "input_audio", "content": audio_content}] if audio_content else []
778
861
 
779
def _create_image_messages(self, candidate: DataRecord | list[DataRecord], input_fields: list[str]) -> list[dict]:
    """
    Parses the candidate record(s) and returns the image messages for the chat payload.

    For every input field, the value of that field is inspected in each record in turn.
    Image filepaths are read from disk and embedded as base64 data URLs, image URLs are
    passed through unchanged, and pre-encoded base64 images are wrapped in a data URL.
    Fields whose annotation is not one of the recognized image types are ignored. A field
    whose value is None (e.g. an unset optional field) is skipped rather than being opened
    as a file or forwarded as an empty image.

    Args:
        candidate (DataRecord | list[DataRecord]): The input record(s).
        input_fields (list[str]): The list of input fields.

    Returns:
        list[dict]: A list with a single user message holding every image part, or an
            empty list if no image content was found.

    Raises:
        OSError: If an image filepath cannot be opened or read.
    """

    def read_as_base64(filepath) -> str:
        # read the raw bytes from disk and encode them so they can be embedded in the payload
        with open(filepath, "rb") as f:
            return base64.b64encode(f.read()).decode("utf-8")

    def data_url(base64_image_str) -> str:
        # NOTE: every embedded image is labeled with the image/jpeg media type
        return f"data:image/jpeg;base64,{base64_image_str}"

    def image_part(url) -> dict:
        return {"type": "image_url", "image_url": {"url": url}}

    # normalize type to be list[DataRecord]
    records = [candidate] if isinstance(candidate, DataRecord) else candidate

    # annotations recognized for each kind of image field; these do not depend on the
    # record or field, so they are built once instead of on every loop iteration
    filepath_types = [ImageFilepath, ImageFilepath | None, ImageFilepath | Any]
    filepath_list_types = [list[ImageFilepath], list[ImageFilepath] | None, list[ImageFilepath] | Any]
    url_types = [ImageURL, ImageURL | None, ImageURL | Any]
    url_list_types = [list[ImageURL], list[ImageURL] | None, list[ImageURL] | Any]
    base64_types = [ImageBase64, ImageBase64 | None, ImageBase64 | Any]
    base64_list_types = [list[ImageBase64], list[ImageBase64] | None, list[ImageBase64] | Any]

    # create a message for each image in an input field with an image (or list of image) type
    image_content = []
    for field_name in input_fields:
        for record in records:
            field_value = record[field_name]

            # optional fields may be unset; there is nothing to encode in that case
            if field_value is None:
                continue

            annotation = record.get_field_type(field_name).annotation

            # image filepath (or list of image filepaths)
            if annotation in filepath_types:
                image_content.append(image_part(data_url(read_as_base64(field_value))))

            elif annotation in filepath_list_types:
                image_content.extend(
                    image_part(data_url(read_as_base64(image_filepath))) for image_filepath in field_value
                )

            # image url (or list of image urls)
            elif annotation in url_types:
                image_content.append(image_part(field_value))

            elif annotation in url_list_types:
                image_content.extend(image_part(image_url) for image_url in field_value)

            # pre-encoded images (or list of pre-encoded images)
            elif annotation in base64_types:
                image_content.append(image_part(data_url(field_value)))

            elif annotation in base64_list_types:
                image_content.extend(image_part(data_url(base64_image)) for base64_image in field_value)

    return [{"role": "user", "type": "image", "content": image_content}] if image_content else []
833
921
 
834
922
  def _get_system_prompt(self, **format_kwargs) -> str | None:
@@ -848,12 +936,12 @@ class PromptFactory:
848
936
 
849
937
  return base_prompt.format(**format_kwargs)
850
938
 
851
- def _get_user_messages(self, candidate: DataRecord, input_fields: list[str], right_candidate: DataRecord | None, right_input_fields: list[str], **kwargs) -> str:
939
+ def _get_user_messages(self, candidate: DataRecord | list[DataRecord], input_fields: list[str], right_candidate: DataRecord | None, right_input_fields: list[str], **kwargs) -> str:
852
940
  """
853
941
  Returns a list of messages for the chat payload based on the prompt strategy.
854
942
 
855
943
  Args:
856
- candidate (DataRecord): The input record.
944
+ candidate (DataRecord | list[DataRecord]): The input record(s).
857
945
  input_fields (list[str]): The input fields.
858
946
  output_fields (list[str]): The output fields.
859
947
  kwargs: The formatting kwargs and some keyword arguments provided by the user.
@@ -942,7 +1030,7 @@ class PromptFactory:
942
1030
 
943
1031
  return user_messages
944
1032
 
945
- def create_messages(self, candidate: DataRecord, output_fields: list[str], right_candidate: DataRecord | None = None, **kwargs) -> list[dict]:
1033
+ def create_messages(self, candidate: DataRecord | list[DataRecord], output_fields: list[str], right_candidate: DataRecord | None = None, **kwargs) -> list[dict]:
946
1034
  """
947
1035
  Creates the messages for the chat payload based on the prompt strategy.
948
1036
 
@@ -954,7 +1042,7 @@ class PromptFactory:
954
1042
  }
955
1043
 
956
1044
  Args:
957
- candidate (DataRecord): The input record.
1045
+ candidate (DataRecord | list[DataRecord]): The input record(s).
958
1046
  output_fields (list[str]): The output fields.
959
1047
  right_candidate (DataRecord | None): The other join input record (only provided for joins).
960
1048
  kwargs: The keyword arguments provided by the user.
@@ -963,11 +1051,11 @@ class PromptFactory:
963
1051
  list[dict]: The messages for the chat payload.
964
1052
  """
965
1053
  # compute the set of input fields
966
- input_fields = self._get_input_fields(candidate, **kwargs)
1054
+ input_fields = self._get_input_fields(candidate[0] if isinstance(candidate, list) else candidate, **kwargs)
967
1055
  right_input_fields = [] if right_candidate is None else self._get_input_fields(right_candidate, **kwargs)
968
1056
 
969
1057
  # use input fields to determine the left / right input modalities
970
- input_modalities = self._get_input_modalities(candidate, input_fields)
1058
+ input_modalities = self._get_input_modalities(candidate[0] if isinstance(candidate, list) else candidate, input_fields)
971
1059
  right_input_modalities = set() if right_candidate is None else self._get_input_modalities(right_candidate, right_input_fields)
972
1060
 
973
1061
  # initialize messages
@@ -11,12 +11,14 @@ The user has additionally provided you with this description of the task you nee
11
11
  """
12
12
 
13
13
  ### JOB INSTRUCTIONS ###
14
+ AGG_JOB_INSTRUCTION = """analyze input {modalities} in order to perform an aggregation and generate a JSON object"""
14
15
  MAP_JOB_INSTRUCTION = """analyze input {modalities} in order to produce a JSON object"""
15
16
  FILTER_JOB_INSTRUCTION = """analyze input {modalities} in order to answer a TRUE / FALSE question"""
16
17
  JOIN_JOB_INSTRUCTION = """analyze input {modalities} in order to determine whether two data records satisfy a join condition"""
17
18
  PROPOSER_JOB_INSTRUCTION = """analyze input {modalities} in order to produce an answer to a question"""
18
19
 
19
- ### FILTER / JOIN CONDITIONS ###
20
+ ### AGG / FILTER / JOIN CONDITIONS ###
21
+ EXAMPLE_AGG_INSTRUCTION = "Count the distinct number of scientists in the input."
20
22
  EXAMPLE_FILTER_CONDITION = "The subject of the input is a foundational computer scientist."
21
23
  EXAMPLE_JOIN_CONDITION = "The two inputs are scientists in the same academic field."
22
24
 
@@ -48,6 +50,7 @@ TEXT_EXAMPLE_OUTPUT_FIELDS = """- name: the name of the scientist
48
50
  - birth_year: the year the scientist was born"""
49
51
  IMAGE_EXAMPLE_OUTPUT_FIELDS = """- is_bald: true if the scientist is bald and false otherwise"""
50
52
  AUDIO_EXAMPLE_OUTPUT_FIELDS = """- birthplace: the city where the scientist was born"""
53
+ AGG_EXAMPLE_OUTPUT_FIELDS = """- num_distinct_scientists: the number of distinct scientists mentioned in the input"""
51
54
 
52
55
  ### EXAMPLE CONTEXTS ###
53
56
  TEXT_EXAMPLE_CONTEXT = """
@@ -71,6 +74,30 @@ RIGHT_IMAGE_EXAMPLE_CONTEXT = """
71
74
  RIGHT_AUDIO_EXAMPLE_CONTEXT = """
72
75
  "podcast": <bytes>
73
76
  """
77
+ SECOND_TEXT_EXAMPLE_CONTEXT = """
78
+ "text": "Alan Turing was a pioneering computer scientist and mathematician. He is widely considered to be the father of theoretical computer science and artificial intelligence.",
79
+ "birthday": "June 23, 1912"
80
+ """
81
+ SECOND_IMAGE_EXAMPLE_CONTEXT = """
82
+ "image": <bytes>,
83
+ "photographer": "PhotoPro42"
84
+ """
85
+ SECOND_AUDIO_EXAMPLE_CONTEXT = """
86
+ "recording": <bytes>,
87
+ "speaker": "Barbara Walters"
88
+ """
89
+ THIRD_TEXT_EXAMPLE_CONTEXT = """
90
+ "text": "Ada Lovelace is a historically significant computer scientist.",
91
+ "birthday": "December 10, 1815"
92
+ """
93
+ THIRD_IMAGE_EXAMPLE_CONTEXT = """
94
+ "image": <bytes>,
95
+ "photographer": "PicturePerfect"
96
+ """
97
+ THIRD_AUDIO_EXAMPLE_CONTEXT = """
98
+ "recording": <bytes>,
99
+ "speaker": "Anderson Cooper"
100
+ """
74
101
 
75
102
  ### DISCLAIMERS ###
76
103
  IMAGE_DISCLAIMER = """
@@ -85,15 +112,25 @@ RIGHT_IMAGE_DISCLAIMER = """
85
112
  RIGHT_AUDIO_DISCLAIMER = """
86
113
  \n<audio content provided here; assume in this example the podcast is discussing Alan Turing's work on the Enigma code>
87
114
  """
115
+ AGG_IMAGE_DISCLAIMER = """
116
+ \n<image content provided here; assume in this example the first image shows Ada Lovelace, the second image shows Alan Turing, and the third image shows Ada Lovelace again>
117
+ """
118
+ AGG_AUDIO_DISCLAIMER = """
119
+ \n<audio content provided here; assume in this example the first recording is about Ada Lovelace, the second recording is about Alan Turing, and the third recording is about Ada Lovelace again>
120
+ """
88
121
 
89
122
  ### EXAMPLE REASONINGS ###
90
123
  TEXT_EXAMPLE_REASONING = """The text passage mentions the scientist's name as "Augusta Ada King, Countess of Lovelace, also known as Ada Lovelace" and the scientist's birthday as "December 10, 1815". Therefore, the name of the scientist is "Augusta Ada King" and the birth year is 1815."""
91
124
  IMAGE_EXAMPLE_REASONING = """The image shows hair on top of the scientist's head, so the is_bald field should be false."""
92
125
  AUDIO_EXAMPLE_REASONING = """The newscast recording discusses Ada Lovelace's upbringing in London, so the birthplace field should be "London"."""
126
+ AGG_EXAMPLE_REASONING = """The input contains two distinct scientists: "Augusta Ada King" and "Alan Turing". Although "Ada Lovelace" is mentioned twice, she should only be counted once. Therefore, the number of distinct scientists mentioned in the input is 2."""
93
127
  FILTER_EXAMPLE_REASONING = """Ada Lovelace is a foundational computer scientist, therefore the answer is TRUE."""
94
128
  JOIN_EXAMPLE_REASONING = """The subject of the left record is Ada Lovelace and the subject of the right record is Alan Turing. Since both inputs are about computer scientists, they satisfy the join condition. Therefore, the answer is TRUE."""
95
129
 
96
130
  ### EXAMPLE ANSWERS ###
131
+ AGG_EXAMPLE_ANSWER = """
132
+ "num_distinct_scientists": 2
133
+ """
97
134
  TEXT_EXAMPLE_ANSWER = """
98
135
  "name": "Augusta Ada King",
99
136
  "birth_year": 1815