palimpzest 0.8.7__py3-none-any.whl → 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- palimpzest/constants.py +12 -4
- palimpzest/core/data/dataset.py +42 -0
- palimpzest/core/lib/schemas.py +6 -0
- palimpzest/prompts/aggregate_prompts.py +99 -0
- palimpzest/prompts/prompt_factory.py +162 -75
- palimpzest/prompts/utils.py +38 -1
- palimpzest/prompts/validator.py +24 -24
- palimpzest/query/generators/generators.py +9 -7
- palimpzest/query/operators/__init__.py +4 -1
- palimpzest/query/operators/aggregate.py +285 -6
- palimpzest/query/operators/logical.py +17 -4
- palimpzest/query/optimizer/__init__.py +4 -0
- palimpzest/query/optimizer/rules.py +42 -2
- {palimpzest-0.8.7.dist-info → palimpzest-0.9.0.dist-info}/METADATA +1 -1
- {palimpzest-0.8.7.dist-info → palimpzest-0.9.0.dist-info}/RECORD +18 -17
- {palimpzest-0.8.7.dist-info → palimpzest-0.9.0.dist-info}/WHEEL +0 -0
- {palimpzest-0.8.7.dist-info → palimpzest-0.9.0.dist-info}/licenses/LICENSE +0 -0
- {palimpzest-0.8.7.dist-info → palimpzest-0.9.0.dist-info}/top_level.txt +0 -0
palimpzest/prompts/prompt_factory.py CHANGED

@@ -24,6 +24,12 @@ from palimpzest.core.lib.schemas import (
     ImageFilepath,
     ImageURL,
 )
+from palimpzest.prompts.aggregate_prompts import (
+    AGG_BASE_SYSTEM_PROMPT,
+    AGG_BASE_USER_PROMPT,
+    AGG_NO_REASONING_BASE_SYSTEM_PROMPT,
+    AGG_NO_REASONING_BASE_USER_PROMPT,
+)
 from palimpzest.prompts.convert_prompts import (
     MAP_BASE_SYSTEM_PROMPT,
     MAP_BASE_USER_PROMPT,
@@ -79,6 +85,12 @@ from palimpzest.prompts.split_proposer_prompts import (
     MAP_SPLIT_PROPOSER_BASE_USER_PROMPT,
 )
 from palimpzest.prompts.utils import (
+    AGG_AUDIO_DISCLAIMER,
+    AGG_EXAMPLE_ANSWER,
+    AGG_EXAMPLE_OUTPUT_FIELDS,
+    AGG_EXAMPLE_REASONING,
+    AGG_IMAGE_DISCLAIMER,
+    AGG_JOB_INSTRUCTION,
     AUDIO_DISCLAIMER,
     AUDIO_EXAMPLE_ANSWER,
     AUDIO_EXAMPLE_CONTEXT,
@@ -87,6 +99,7 @@ from palimpzest.prompts.utils import (
     AUDIO_EXAMPLE_REASONING,
     AUDIO_SENTENCE_EXAMPLE_ANSWER,
     DESC_SECTION,
+    EXAMPLE_AGG_INSTRUCTION,
     EXAMPLE_FILTER_CONDITION,
     EXAMPLE_JOIN_CONDITION,
     FILTER_EXAMPLE_REASONING,
@@ -112,12 +125,18 @@ from palimpzest.prompts.utils import (
     RIGHT_IMAGE_EXAMPLE_INPUT_FIELDS,
     RIGHT_TEXT_EXAMPLE_CONTEXT,
     RIGHT_TEXT_EXAMPLE_INPUT_FIELDS,
+    SECOND_AUDIO_EXAMPLE_CONTEXT,
+    SECOND_IMAGE_EXAMPLE_CONTEXT,
+    SECOND_TEXT_EXAMPLE_CONTEXT,
     TEXT_EXAMPLE_ANSWER,
     TEXT_EXAMPLE_CONTEXT,
     TEXT_EXAMPLE_INPUT_FIELDS,
     TEXT_EXAMPLE_OUTPUT_FIELDS,
     TEXT_EXAMPLE_REASONING,
     TEXT_SENTENCE_EXAMPLE_ANSWER,
+    THIRD_AUDIO_EXAMPLE_CONTEXT,
+    THIRD_IMAGE_EXAMPLE_CONTEXT,
+    THIRD_TEXT_EXAMPLE_CONTEXT,
 )


@@ -125,6 +144,10 @@ class PromptFactory:
     """Factory class for generating prompts for the Generator given the input(s)."""

     BASE_SYSTEM_PROMPT_MAP = {
+        # agg user prompts
+        PromptStrategy.AGG: AGG_BASE_SYSTEM_PROMPT,
+        PromptStrategy.AGG_NO_REASONING: AGG_NO_REASONING_BASE_SYSTEM_PROMPT,
+
         # filter system prompts
         PromptStrategy.FILTER: FILTER_BASE_SYSTEM_PROMPT,
         PromptStrategy.FILTER_NO_REASONING: FILTER_NO_REASONING_BASE_SYSTEM_PROMPT,
@@ -150,6 +173,10 @@ class PromptFactory:
         PromptStrategy.MAP_SPLIT_MERGER: MAP_SPLIT_MERGER_BASE_SYSTEM_PROMPT,
     }
     BASE_USER_PROMPT_MAP = {
+        # agg user prompts
+        PromptStrategy.AGG: AGG_BASE_USER_PROMPT,
+        PromptStrategy.AGG_NO_REASONING: AGG_NO_REASONING_BASE_USER_PROMPT,
+
         # filter user prompts
         PromptStrategy.FILTER: FILTER_BASE_USER_PROMPT,
         PromptStrategy.FILTER_NO_REASONING: FILTER_NO_REASONING_BASE_USER_PROMPT,
@@ -181,7 +208,7 @@ class PromptFactory:
         self.cardinality = cardinality
         self.desc = desc

-    def _get_context(self, candidate: DataRecord, input_fields: list[str]) -> str:
+    def _get_context(self, candidate: DataRecord | list[DataRecord], input_fields: list[str]) -> str:
         """
         Returns the context for the prompt.

@@ -194,7 +221,10 @@ class PromptFactory:
         """
         # TODO: remove mask_filepaths=True after SemBench evaluation
         # get context from input record (project_cols will be None if not provided in kwargs)
-        context: dict = candidate.to_dict(include_bytes=False, project_cols=input_fields, mask_filepaths=True)
+        if isinstance(candidate, list):
+            context: list[dict] = [record.to_dict(include_bytes=False, project_cols=input_fields, mask_filepaths=True) for record in candidate]
+        else:
+            context: dict = candidate.to_dict(include_bytes=False, project_cols=input_fields, mask_filepaths=True)

         # TODO: MOVE THIS LOGIC INTO A CHUNKING / CONTEXT MANAGEMENT CLASS
         # - this class should be able to:
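Note on the hunk above: aggregations now hand `_get_context` a list of records, so the context becomes a list of projected dicts instead of a single dict. A minimal, self-contained sketch of that shape (the `Record` class and `get_context` helper below are illustrative stand-ins, not palimpzest's actual `DataRecord` API):

```python
# Illustrative sketch only: `Record` is a hypothetical stand-in for DataRecord.
from dataclasses import dataclass


@dataclass
class Record:
    fields: dict

    def to_dict(self, project_cols=None):
        # mimic project_cols: keep only the requested fields
        return {k: v for k, v in self.fields.items() if project_cols is None or k in project_cols}


def get_context(candidate, input_fields):
    # aggregations pass a list of records; all other operators still pass one record
    if isinstance(candidate, list):
        return [r.to_dict(project_cols=input_fields) for r in candidate]
    return candidate.to_dict(project_cols=input_fields)


records = [Record({"text": "Ada Lovelace ...", "id": 1}), Record({"text": "Alan Turing ...", "id": 2})]
print(get_context(records, ["text"]))     # -> list[dict], one entry per record
print(get_context(records[0], ["text"]))  # -> a single dict
```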
@@ -203,8 +233,10 @@ class PromptFactory:
         # - handle the issue with `original_messages` (ask Matt if this is not clear)
         # TODO: this does not work for image prompts
         # TODO: this ignores the size of the `orignal_messages` in critique and refine prompts
+        # NOTE: llama models are disallowed for aggregation so we can assume context is a dict here
         # cut down on context based on window length
         if self.model.is_llama_model():
+            assert isinstance(context, dict), "Llama models are not allowed for aggregation operations."
            total_context_len = len(json.dumps(context, indent=2))

            # sort fields by length and progressively strip from the longest field until it is short enough;
@@ -323,7 +355,7 @@ class PromptFactory:
         """
         output_fields_desc = ""
         output_schema: type[BaseModel] = kwargs.get("output_schema")
-        if self.prompt_strategy.is_map_prompt():
+        if self.prompt_strategy.is_map_prompt() or self.prompt_strategy.is_agg_prompt():
             assert output_schema is not None, "Output schema must be provided for convert prompts."

             for field_name in sorted(output_fields):
@@ -333,6 +365,19 @@ class PromptFactory:
         # strip the last newline characters from the field descriptions and return
         return output_fields_desc[:-1]

+    def _get_agg_instruction(self, **kwargs) -> str | None:
+        """
+        Returns the aggregation instruction for the aggregation operation.
+
+        Returns:
+            str | None: The aggregation instruction (if applicable).
+        """
+        agg_instruction = kwargs.get("agg_instruction")
+        if self.prompt_strategy.is_agg_prompt():
+            assert agg_instruction is not None, "Aggregation instruction must be provided for aggregation operations."
+
+        return agg_instruction
+
     def _get_filter_condition(self, **kwargs) -> str | None:
         """
         Returns the filter condition for the filter operation.
@@ -464,6 +509,8 @@ class PromptFactory:
             job_instruction = FILTER_JOB_INSTRUCTION
         elif self.prompt_strategy.is_join_prompt():
             job_instruction = JOIN_JOB_INSTRUCTION
+        elif self.prompt_strategy.is_agg_prompt():
+            job_instruction = AGG_JOB_INSTRUCTION

         # format the job instruction based on the input modalities
         modalities = self._get_modalities_str(input_modalities)
@@ -557,6 +604,9 @@ class PromptFactory:
         Returns:
             str: The example output fields.
         """
+        if self.prompt_strategy.is_agg_prompt():
+            return AGG_EXAMPLE_OUTPUT_FIELDS
+
         input_modality_to_example_output_fields = {
             Modality.TEXT: TEXT_EXAMPLE_OUTPUT_FIELDS,
             Modality.IMAGE: IMAGE_EXAMPLE_OUTPUT_FIELDS,
@@ -570,17 +620,31 @@ class PromptFactory:

         return example_output_fields

-    def _get_example_context(self, input_modalities: set[Modality], right: bool = False) -> str:
+    def _get_example_context(self, input_modalities: set[Modality], right: bool = False, second: bool = False, third: bool = False) -> str:
         """
         Returns the example context for the prompt.

         Returns:
             str: The example context.
         """
+        assert not (second and third), "Cannot have both second and third example contexts."
+        assert not (right and (second or third)), "Right context is only used for joins; second and third contexts only use for aggregations."
+        text_example_context = TEXT_EXAMPLE_CONTEXT
+        image_example_context = IMAGE_EXAMPLE_CONTEXT
+        audio_example_context = AUDIO_EXAMPLE_CONTEXT
+        if second:
+            text_example_context = SECOND_TEXT_EXAMPLE_CONTEXT
+            image_example_context = SECOND_IMAGE_EXAMPLE_CONTEXT
+            audio_example_context = SECOND_AUDIO_EXAMPLE_CONTEXT
+        elif third:
+            text_example_context = THIRD_TEXT_EXAMPLE_CONTEXT
+            image_example_context = THIRD_IMAGE_EXAMPLE_CONTEXT
+            audio_example_context = THIRD_AUDIO_EXAMPLE_CONTEXT
+
         input_modality_to_example_context = {
-            Modality.TEXT: RIGHT_TEXT_EXAMPLE_CONTEXT if right else TEXT_EXAMPLE_CONTEXT,
-            Modality.IMAGE: RIGHT_IMAGE_EXAMPLE_CONTEXT if right else IMAGE_EXAMPLE_CONTEXT,
-            Modality.AUDIO: RIGHT_AUDIO_EXAMPLE_CONTEXT if right else AUDIO_EXAMPLE_CONTEXT,
+            Modality.TEXT: RIGHT_TEXT_EXAMPLE_CONTEXT if right else text_example_context,
+            Modality.IMAGE: RIGHT_IMAGE_EXAMPLE_CONTEXT if right else image_example_context,
+            Modality.AUDIO: RIGHT_AUDIO_EXAMPLE_CONTEXT if right else audio_example_context,
         }

         example_context = ""
@@ -590,7 +654,7 @@ class PromptFactory:

         return example_context

-    def _get_image_disclaimer(self, input_modalities: set[Modality], right: bool = False) -> str:
+    def _get_image_disclaimer(self, input_modalities: set[Modality], right: bool = False, agg: bool = False) -> str:
         """
         Returns the image disclaimer for the prompt. The disclaimer must be an empty string
         for non-image prompts.
@@ -598,10 +662,12 @@ class PromptFactory:
         Returns:
             str: The image disclaimer. If this is a text prompt then it is an empty string.
         """
-        image_disclaimer = RIGHT_IMAGE_DISCLAIMER if right else IMAGE_DISCLAIMER
+        assert not (right and agg), "Right image disclaimer is only used for joins; agg image disclaimer only used for aggregations."
+        image_disclaimer = AGG_IMAGE_DISCLAIMER if agg else IMAGE_DISCLAIMER
+        image_disclaimer = RIGHT_IMAGE_DISCLAIMER if right else image_disclaimer
         return image_disclaimer if Modality.IMAGE in input_modalities else ""

-    def _get_audio_disclaimer(self, input_modalities: set[Modality], right: bool = False) -> str:
+    def _get_audio_disclaimer(self, input_modalities: set[Modality], right: bool = False, agg: bool = False) -> str:
         """
         Returns the audio disclaimer for the prompt. The disclaimer must be an empty string
         for non-audio prompts.
@@ -609,7 +675,9 @@ class PromptFactory:
         Returns:
             str: The audio disclaimer. If this is a text prompt then it is an empty string.
         """
-        audio_disclaimer = RIGHT_AUDIO_DISCLAIMER if right else AUDIO_DISCLAIMER
+        assert not (right and agg), "Right audio disclaimer is only used for joins; agg audio disclaimer only used for aggregations."
+        audio_disclaimer = AGG_AUDIO_DISCLAIMER if agg else AUDIO_DISCLAIMER
+        audio_disclaimer = RIGHT_AUDIO_DISCLAIMER if right else audio_disclaimer
         return audio_disclaimer if Modality.AUDIO in input_modalities else ""

     def _get_example_reasoning(self, input_modalities: set[Modality]) -> str:
@@ -623,6 +691,8 @@ class PromptFactory:
             return FILTER_EXAMPLE_REASONING
         elif self.prompt_strategy.is_join_prompt():
             return JOIN_EXAMPLE_REASONING
+        elif self.prompt_strategy.is_agg_prompt():
+            return AGG_EXAMPLE_REASONING

         input_modality_to_example_reasoning = {
             Modality.TEXT: TEXT_EXAMPLE_REASONING,
@@ -644,6 +714,9 @@ class PromptFactory:
         Returns:
             str: The example answer.
         """
+        if self.prompt_strategy.is_agg_prompt():
+            return AGG_EXAMPLE_ANSWER
+
         use_sentence_answers = self.prompt_strategy.is_split_proposer_prompt() or self.prompt_strategy.is_moa_proposer_prompt()
         input_modality_to_example_answer = {
             Modality.TEXT: TEXT_SENTENCE_EXAMPLE_ANSWER if use_sentence_answers else TEXT_EXAMPLE_ANSWER,
@@ -662,7 +735,7 @@ class PromptFactory:

     def _get_all_format_kwargs(
         self,
-        candidate: DataRecord,
+        candidate: DataRecord | list[DataRecord],
         input_fields: list[str],
         input_modalities: set[Modality],
         output_fields: list[str],
@@ -686,8 +759,9 @@ class PromptFactory:
         # get format kwargs which depend on the input data
         input_format_kwargs = {
             "context": self._get_context(candidate, input_fields),
-            "input_fields_desc": self._get_input_fields_desc(candidate, input_fields),
+            "input_fields_desc": self._get_input_fields_desc(candidate[0] if isinstance(candidate, list) else candidate, input_fields),
             "output_fields_desc": self._get_output_fields_desc(output_fields, **kwargs),
+            "agg_instruction": self._get_agg_instruction(**kwargs),
             "filter_condition": self._get_filter_condition(**kwargs),
             "join_condition": self._get_join_condition(**kwargs),
             "original_output": self._get_original_output(**kwargs),
@@ -716,11 +790,14 @@ class PromptFactory:
             "right_example_input_fields": self._get_example_input_fields(right_input_modalities, right=True),
             "example_output_fields": self._get_example_output_fields(input_modalities),
             "example_context": self._get_example_context(input_modalities),
+            "second_example_context": self._get_example_context(input_modalities, second=True) if self.prompt_strategy.is_agg_prompt() else "",
+            "third_example_context": self._get_example_context(input_modalities, third=True) if self.prompt_strategy.is_agg_prompt() else "",
             "right_example_context": self._get_example_context(right_input_modalities, right=True),
-            "image_disclaimer": self._get_image_disclaimer(input_modalities),
-            "audio_disclaimer": self._get_audio_disclaimer(input_modalities),
+            "image_disclaimer": self._get_image_disclaimer(input_modalities, agg=self.prompt_strategy.is_agg_prompt()),
+            "audio_disclaimer": self._get_audio_disclaimer(input_modalities, agg=self.prompt_strategy.is_agg_prompt()),
             "right_image_disclaimer": self._get_image_disclaimer(right_input_modalities, right=True),
             "right_audio_disclaimer": self._get_audio_disclaimer(right_input_modalities, right=True),
+            "example_agg_instruction": EXAMPLE_AGG_INSTRUCTION,
             "example_filter_condition": EXAMPLE_FILTER_CONDITION,
             "example_join_condition": EXAMPLE_JOIN_CONDITION,
             "example_reasoning": self._get_example_reasoning(input_modalities),
@@ -730,106 +807,116 @@ class PromptFactory:
         # return all format kwargs
         return {**input_format_kwargs, **prompt_strategy_format_kwargs}

-    def _create_audio_messages(self, candidate: DataRecord, input_fields: list[str]) -> list[dict]:
+    def _create_audio_messages(self, candidate: DataRecord | list[DataRecord], input_fields: list[str]) -> list[dict]:
         """
-        Parses the candidate record and returns the audio messages for the chat payload.
+        Parses the candidate record(s) and returns the audio messages for the chat payload.

         Args:
-            candidate (DataRecord): The input record.
+            candidate (DataRecord | list[DataRecord]): The input record(s).
             input_fields (list[str]): The list of input fields.

         Returns:
             list[dict]: The audio messages for the chat payload.
         """
+        # normalize type to be list[DataRecord]
+        if isinstance(candidate, DataRecord):
+            candidate = [candidate]
+
         # create a message for each audio recording in an input field with an audio (or list of audio) type
         audio_content = []
         for field_name in input_fields:
-            field_value = candidate[field_name]
-            field_type = candidate.get_field_type(field_name)
+            for dr in candidate:
+                field_value = dr[field_name]
+                field_type = dr.get_field_type(field_name)

-            # audio filepath (or list of audio filepaths)
-            if field_type.annotation in [AudioFilepath, AudioFilepath | None, AudioFilepath | Any]:
-                with open(field_value, "rb") as f:
-                    base64_audio_str = base64.b64encode(f.read()).decode("utf-8")
-                audio_content.append(
-                    {"type": "input_audio", "input_audio": {"data": base64_audio_str, "format": "wav"}}
-                )
-
-            elif field_type.annotation in [list[AudioFilepath], list[AudioFilepath] | None, list[AudioFilepath] | Any]:
-                for audio_filepath in field_value:
-                    with open(audio_filepath, "rb") as f:
+                # audio filepath (or list of audio filepaths)
+                if field_type.annotation in [AudioFilepath, AudioFilepath | None, AudioFilepath | Any]:
+                    with open(field_value, "rb") as f:
                         base64_audio_str = base64.b64encode(f.read()).decode("utf-8")
                     audio_content.append(
                         {"type": "input_audio", "input_audio": {"data": base64_audio_str, "format": "wav"}}
                     )

-            # pre-encoded images (or list of pre-encoded images)
-            elif field_type.annotation in [AudioBase64, AudioBase64 | None, AudioBase64 | Any]:
-                audio_content.append(
-                    {"type": "input_audio", "input_audio": {"data": field_value, "format": "wav"}}
-                )
+                elif field_type.annotation in [list[AudioFilepath], list[AudioFilepath] | None, list[AudioFilepath] | Any]:
+                    for audio_filepath in field_value:
+                        with open(audio_filepath, "rb") as f:
+                            base64_audio_str = base64.b64encode(f.read()).decode("utf-8")
+                        audio_content.append(
+                            {"type": "input_audio", "input_audio": {"data": base64_audio_str, "format": "wav"}}
+                        )

-            elif field_type.annotation in [list[AudioBase64], list[AudioBase64] | None, list[AudioBase64] | Any]:
-                for base64_audio in field_value:
+                # pre-encoded images (or list of pre-encoded images)
+                elif field_type.annotation in [AudioBase64, AudioBase64 | None, AudioBase64 | Any]:
                     audio_content.append(
-                        {"type": "input_audio", "input_audio": {"data": base64_audio, "format": "wav"}}
+                        {"type": "input_audio", "input_audio": {"data": field_value, "format": "wav"}}
                     )

+                elif field_type.annotation in [list[AudioBase64], list[AudioBase64] | None, list[AudioBase64] | Any]:
+                    for base64_audio in field_value:
+                        audio_content.append(
+                            {"type": "input_audio", "input_audio": {"data": base64_audio, "format": "wav"}}
+                        )
+
         return [{"role": "user", "type": "input_audio", "content": audio_content}] if len(audio_content) > 0 else []

-    def _create_image_messages(self, candidate: DataRecord, input_fields: list[str]) -> list[dict]:
+    def _create_image_messages(self, candidate: DataRecord | list[DataRecord], input_fields: list[str]) -> list[dict]:
         """
-        Parses the candidate record and returns the image messages for the chat payload.
+        Parses the candidate record(s) and returns the image messages for the chat payload.

         Args:
-            candidate (DataRecord): The input record.
+            candidate (DataRecord | list[DataRecord]): The input record(s).
             input_fields (list[str]): The list of input fields.

         Returns:
             list[dict]: The image messages for the chat payload.
         """
+        # normalize type to be list[DataRecord]
+        if isinstance(candidate, DataRecord):
+            candidate = [candidate]
+
         # create a message for each image in an input field with an image (or list of image) type
         image_content = []
         for field_name in input_fields:
-            field_value = candidate[field_name]
-            field_type = candidate.get_field_type(field_name)
+            for dr in candidate:
+                field_value = dr[field_name]
+                field_type = dr.get_field_type(field_name)

-            # image filepath (or list of image filepaths)
-            if field_type.annotation in [ImageFilepath, ImageFilepath | None, ImageFilepath | Any]:
-                with open(field_value, "rb") as f:
-                    base64_image_str = base64.b64encode(f.read()).decode("utf-8")
-                image_content.append(
-                    {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image_str}"}}
-                )
-
-            elif field_type.annotation in [list[ImageFilepath], list[ImageFilepath] | None, list[ImageFilepath] | Any]:
-                for image_filepath in field_value:
-                    with open(image_filepath, "rb") as f:
+                # image filepath (or list of image filepaths)
+                if field_type.annotation in [ImageFilepath, ImageFilepath | None, ImageFilepath | Any]:
+                    with open(field_value, "rb") as f:
                         base64_image_str = base64.b64encode(f.read()).decode("utf-8")
                     image_content.append(
                         {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image_str}"}}
                     )

-            # image url (or list of image urls)
-            elif field_type.annotation in [ImageURL, ImageURL | None, ImageURL | Any]:
-                image_content.append({"type": "image_url", "image_url": {"url": field_value}})
+                elif field_type.annotation in [list[ImageFilepath], list[ImageFilepath] | None, list[ImageFilepath] | Any]:
+                    for image_filepath in field_value:
+                        with open(image_filepath, "rb") as f:
+                            base64_image_str = base64.b64encode(f.read()).decode("utf-8")
+                        image_content.append(
+                            {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image_str}"}}
+                        )

-            elif field_type.annotation in [list[ImageURL], list[ImageURL] | None, list[ImageURL] | Any]:
-                for image_url in field_value:
-                    image_content.append({"type": "image_url", "image_url": {"url": image_url}})
+                # image url (or list of image urls)
+                elif field_type.annotation in [ImageURL, ImageURL | None, ImageURL | Any]:
+                    image_content.append({"type": "image_url", "image_url": {"url": field_value}})

-            # pre-encoded images (or list of pre-encoded images)
-            elif field_type.annotation in [ImageBase64, ImageBase64 | None, ImageBase64 | Any]:
-                image_content.append(
-                    {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{field_value}"}}
-                )
+                elif field_type.annotation in [list[ImageURL], list[ImageURL] | None, list[ImageURL] | Any]:
+                    for image_url in field_value:
+                        image_content.append({"type": "image_url", "image_url": {"url": image_url}})

-            elif field_type.annotation in [list[ImageBase64], list[ImageBase64] | None, list[ImageBase64] | Any]:
-                for base64_image in field_value:
+                # pre-encoded images (or list of pre-encoded images)
+                elif field_type.annotation in [ImageBase64, ImageBase64 | None, ImageBase64 | Any]:
                     image_content.append(
-                        {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}}
+                        {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{field_value}"}}
                     )

+                elif field_type.annotation in [list[ImageBase64], list[ImageBase64] | None, list[ImageBase64] | Any]:
+                    for base64_image in field_value:
+                        image_content.append(
+                            {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}}
+                        )
+
         return [{"role": "user", "type": "image", "content": image_content}] if len(image_content) > 0 else []

     def _get_system_prompt(self, **format_kwargs) -> str | None:
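The message-builder changes above follow one pattern in both `_create_audio_messages` and `_create_image_messages`: normalize the input to a list of records, then emit one content part per (input field, record) pair, all carried in a single user message. A standalone sketch of the image case, using plain dicts of file paths instead of palimpzest's typed `DataRecord` fields (the function and argument names here are illustrative):

```python
# Illustrative sketch only: mirrors the loop structure of the new _create_image_messages,
# but takes plain dicts mapping field names to image file paths.
import base64


def build_image_message(records: list[dict], input_fields: list[str]) -> list[dict]:
    image_content = []
    for field_name in input_fields:
        for record in records:  # one content part per (field, record) pair
            with open(record[field_name], "rb") as f:
                b64 = base64.b64encode(f.read()).decode("utf-8")
            image_content.append(
                {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{b64}"}}
            )
    # all images travel in one user message, mirroring the return shape in the diff above
    return [{"role": "user", "type": "image", "content": image_content}] if image_content else []
```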
@@ -849,12 +936,12 @@ class PromptFactory:

         return base_prompt.format(**format_kwargs)

-    def _get_user_messages(self, candidate: DataRecord, input_fields: list[str], right_candidate: DataRecord | None, right_input_fields: list[str], **kwargs) -> str:
+    def _get_user_messages(self, candidate: DataRecord | list[DataRecord], input_fields: list[str], right_candidate: DataRecord | None, right_input_fields: list[str], **kwargs) -> str:
         """
         Returns a list of messages for the chat payload based on the prompt strategy.

         Args:
-            candidate (DataRecord): The input record.
+            candidate (DataRecord | list[DataRecord]): The input record(s).
             input_fields (list[str]): The input fields.
             output_fields (list[str]): The output fields.
             kwargs: The formatting kwargs and some keyword arguments provided by the user.
@@ -943,7 +1030,7 @@ class PromptFactory:

         return user_messages

-    def create_messages(self, candidate: DataRecord, output_fields: list[str], right_candidate: DataRecord | None = None, **kwargs) -> list[dict]:
+    def create_messages(self, candidate: DataRecord | list[DataRecord], output_fields: list[str], right_candidate: DataRecord | None = None, **kwargs) -> list[dict]:
         """
         Creates the messages for the chat payload based on the prompt strategy.

@@ -955,7 +1042,7 @@ class PromptFactory:
         }

         Args:
-            candidate (DataRecord): The input record.
+            candidate (DataRecord | list[DataRecord]): The input record(s).
             output_fields (list[str]): The output fields.
             right_candidate (DataRecord | None): The other join input record (only provided for joins).
             kwargs: The keyword arguments provided by the user.
@@ -964,11 +1051,11 @@ class PromptFactory:
             list[dict]: The messages for the chat payload.
         """
         # compute the set of input fields
-        input_fields = self._get_input_fields(candidate, **kwargs)
+        input_fields = self._get_input_fields(candidate[0] if isinstance(candidate, list) else candidate, **kwargs)
         right_input_fields = [] if right_candidate is None else self._get_input_fields(right_candidate, **kwargs)

         # use input fields to determine the left / right input modalities
-        input_modalities = self._get_input_modalities(candidate, input_fields)
+        input_modalities = self._get_input_modalities(candidate[0] if isinstance(candidate, list) else candidate, input_fields)
         right_input_modalities = set() if right_candidate is None else self._get_input_modalities(right_candidate, right_input_fields)

         # initialize messages

palimpzest/prompts/utils.py CHANGED
@@ -11,12 +11,14 @@ The user has additionally provided you with this description of the task you need
 """

 ### JOB INSTRUCTIONS ###
+AGG_JOB_INSTRUCTION = """analyze input {modalities} in order to perform an aggregation and generate a JSON object"""
 MAP_JOB_INSTRUCTION = """analyze input {modalities} in order to produce a JSON object"""
 FILTER_JOB_INSTRUCTION = """analyze input {modalities} in order to answer a TRUE / FALSE question"""
 JOIN_JOB_INSTRUCTION = """analyze input {modalities} in order to determine whether two data records satisfy a join condition"""
 PROPOSER_JOB_INSTRUCTION = """analyze input {modalities} in order to produce an answer to a question"""

-### FILTER / JOIN CONDITIONS ###
+### AGG / FILTER / JOIN CONDITIONS ###
+EXAMPLE_AGG_INSTRUCTION = "Count the distinct number of scientists in the input."
 EXAMPLE_FILTER_CONDITION = "The subject of the input is a foundational computer scientist."
 EXAMPLE_JOIN_CONDITION = "The two inputs are scientists in the same academic field."

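Like the other job instructions, the new `AGG_JOB_INSTRUCTION` carries a `{modalities}` placeholder that the factory fills in when it builds the prompt; a quick illustration (assuming the constant is imported the same way the prompt_factory.py hunks above import it):

```python
from palimpzest.prompts.utils import AGG_JOB_INSTRUCTION

print(AGG_JOB_INSTRUCTION.format(modalities="text"))
# analyze input text in order to perform an aggregation and generate a JSON object
```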
@@ -48,6 +50,7 @@ TEXT_EXAMPLE_OUTPUT_FIELDS = """- name: the name of the scientist
 - birth_year: the year the scientist was born"""
 IMAGE_EXAMPLE_OUTPUT_FIELDS = """- is_bald: true if the scientist is bald and false otherwise"""
 AUDIO_EXAMPLE_OUTPUT_FIELDS = """- birthplace: the city where the scientist was born"""
+AGG_EXAMPLE_OUTPUT_FIELDS = """- num_distinct_scientists: the number of distinct scientists mentioned in the input"""

 ### EXAMPLE CONTEXTS ###
 TEXT_EXAMPLE_CONTEXT = """
@@ -71,6 +74,30 @@ RIGHT_IMAGE_EXAMPLE_CONTEXT = """
 RIGHT_AUDIO_EXAMPLE_CONTEXT = """
 "podcast": <bytes>
 """
+SECOND_TEXT_EXAMPLE_CONTEXT = """
+"text": "Alan Turing was a pioneering computer scientist and mathematician. He is widely considered to be the father of theoretical computer science and artificial intelligence.",
+"birthday": "June 23, 1912"
+"""
+SECOND_IMAGE_EXAMPLE_CONTEXT = """
+"image": <bytes>,
+"photographer": "PhotoPro42"
+"""
+SECOND_AUDIO_EXAMPLE_CONTEXT = """
+"recording": <bytes>,
+"speaker": "Barbara Walters"
+"""
+THIRD_TEXT_EXAMPLE_CONTEXT = """
+"text": "Ada Lovelace is a historically significant computer scientist.",
+"birthday": "December 10, 1815"
+"""
+THIRD_IMAGE_EXAMPLE_CONTEXT = """
+"image": <bytes>,
+"photographer": "PicturePerfect"
+"""
+THIRD_AUDIO_EXAMPLE_CONTEXT = """
+"recording": <bytes>,
+"speaker": "Anderson Cooper"
+"""

 ### DISCLAIMERS ###
 IMAGE_DISCLAIMER = """
@@ -85,15 +112,25 @@ RIGHT_IMAGE_DISCLAIMER = """
 RIGHT_AUDIO_DISCLAIMER = """
 \n<audio content provided here; assume in this example the podcast is discussing Alan Turing's work on the Enigma code>
 """
+AGG_IMAGE_DISCLAIMER = """
+\n<image content provided here; assume in this example the first image shows Ada Lovelace, the second image shows Alan Turing, and the third image shows Ada Lovelace again>
+"""
+AGG_AUDIO_DISCLAIMER = """
+\n<audio content provided here; assume in this example the first recording is about Ada Lovelace, the second recording is about Alan Turing, and the third recording is about Ada Lovelace again>
+"""

 ### EXAMPLE REASONINGS ###
 TEXT_EXAMPLE_REASONING = """The text passage mentions the scientist's name as "Augusta Ada King, Countess of Lovelace, also known as Ada Lovelace" and the scientist's birthday as "December 10, 1815". Therefore, the name of the scientist is "Augusta Ada King" and the birth year is 1815."""
 IMAGE_EXAMPLE_REASONING = """The image shows hair on top of the scientist's head, so the is_bald field should be false."""
 AUDIO_EXAMPLE_REASONING = """The newscast recording discusses Ada Lovelace's upbringing in London, so the birthplace field should be "London"."""
+AGG_EXAMPLE_REASONING = """The input contains two distinct scientists: "Augusta Ada King" and "Alan Turing". Although "Ada Lovelace" is mentioned twice, she should only be counted once. Therefore, the number of distinct scientists mentioned in the input is 2."""
 FILTER_EXAMPLE_REASONING = """Ada Lovelace is a foundational computer scientist, therefore the answer is TRUE."""
 JOIN_EXAMPLE_REASONING = """The subject of the left record is Ada Lovelace and the subject of the right record is Alan Turing. Since both inputs are about computer scientists, they satisfy the join condition. Therefore, the answer is TRUE."""

 ### EXAMPLE ANSWERS ###
+AGG_EXAMPLE_ANSWER = """
+"num_distinct_scientists": 2
+"""
 TEXT_EXAMPLE_ANSWER = """
 "name": "Augusta Ada King",
 "birth_year": 1815