palimpzest 0.8.6__py3-none-any.whl → 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- palimpzest/constants.py +12 -4
- palimpzest/core/data/dataset.py +42 -0
- palimpzest/core/elements/records.py +5 -1
- palimpzest/core/lib/schemas.py +13 -0
- palimpzest/prompts/aggregate_prompts.py +99 -0
- palimpzest/prompts/prompt_factory.py +163 -75
- palimpzest/prompts/utils.py +38 -1
- palimpzest/prompts/validator.py +24 -24
- palimpzest/query/generators/generators.py +9 -7
- palimpzest/query/operators/__init__.py +4 -1
- palimpzest/query/operators/aggregate.py +285 -6
- palimpzest/query/operators/logical.py +17 -4
- palimpzest/query/optimizer/__init__.py +4 -0
- palimpzest/query/optimizer/rules.py +42 -2
- palimpzest/validator/validator.py +7 -7
- {palimpzest-0.8.6.dist-info → palimpzest-0.9.0.dist-info}/METADATA +1 -1
- {palimpzest-0.8.6.dist-info → palimpzest-0.9.0.dist-info}/RECORD +20 -19
- {palimpzest-0.8.6.dist-info → palimpzest-0.9.0.dist-info}/WHEEL +0 -0
- {palimpzest-0.8.6.dist-info → palimpzest-0.9.0.dist-info}/licenses/LICENSE +0 -0
- {palimpzest-0.8.6.dist-info → palimpzest-0.9.0.dist-info}/top_level.txt +0 -0
|
@@ -2,6 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
import base64
|
|
4
4
|
import json
|
|
5
|
+
from typing import Any
|
|
5
6
|
|
|
6
7
|
from pydantic import BaseModel
|
|
7
8
|
|
|
@@ -23,6 +24,12 @@ from palimpzest.core.lib.schemas import (
|
|
|
23
24
|
ImageFilepath,
|
|
24
25
|
ImageURL,
|
|
25
26
|
)
|
|
27
|
+
from palimpzest.prompts.aggregate_prompts import (
|
|
28
|
+
AGG_BASE_SYSTEM_PROMPT,
|
|
29
|
+
AGG_BASE_USER_PROMPT,
|
|
30
|
+
AGG_NO_REASONING_BASE_SYSTEM_PROMPT,
|
|
31
|
+
AGG_NO_REASONING_BASE_USER_PROMPT,
|
|
32
|
+
)
|
|
26
33
|
from palimpzest.prompts.convert_prompts import (
|
|
27
34
|
MAP_BASE_SYSTEM_PROMPT,
|
|
28
35
|
MAP_BASE_USER_PROMPT,
|
|
@@ -78,6 +85,12 @@ from palimpzest.prompts.split_proposer_prompts import (
|
|
|
78
85
|
MAP_SPLIT_PROPOSER_BASE_USER_PROMPT,
|
|
79
86
|
)
|
|
80
87
|
from palimpzest.prompts.utils import (
|
|
88
|
+
AGG_AUDIO_DISCLAIMER,
|
|
89
|
+
AGG_EXAMPLE_ANSWER,
|
|
90
|
+
AGG_EXAMPLE_OUTPUT_FIELDS,
|
|
91
|
+
AGG_EXAMPLE_REASONING,
|
|
92
|
+
AGG_IMAGE_DISCLAIMER,
|
|
93
|
+
AGG_JOB_INSTRUCTION,
|
|
81
94
|
AUDIO_DISCLAIMER,
|
|
82
95
|
AUDIO_EXAMPLE_ANSWER,
|
|
83
96
|
AUDIO_EXAMPLE_CONTEXT,
|
|
@@ -86,6 +99,7 @@ from palimpzest.prompts.utils import (
|
|
|
86
99
|
AUDIO_EXAMPLE_REASONING,
|
|
87
100
|
AUDIO_SENTENCE_EXAMPLE_ANSWER,
|
|
88
101
|
DESC_SECTION,
|
|
102
|
+
EXAMPLE_AGG_INSTRUCTION,
|
|
89
103
|
EXAMPLE_FILTER_CONDITION,
|
|
90
104
|
EXAMPLE_JOIN_CONDITION,
|
|
91
105
|
FILTER_EXAMPLE_REASONING,
|
|
@@ -111,12 +125,18 @@ from palimpzest.prompts.utils import (
|
|
|
111
125
|
RIGHT_IMAGE_EXAMPLE_INPUT_FIELDS,
|
|
112
126
|
RIGHT_TEXT_EXAMPLE_CONTEXT,
|
|
113
127
|
RIGHT_TEXT_EXAMPLE_INPUT_FIELDS,
|
|
128
|
+
SECOND_AUDIO_EXAMPLE_CONTEXT,
|
|
129
|
+
SECOND_IMAGE_EXAMPLE_CONTEXT,
|
|
130
|
+
SECOND_TEXT_EXAMPLE_CONTEXT,
|
|
114
131
|
TEXT_EXAMPLE_ANSWER,
|
|
115
132
|
TEXT_EXAMPLE_CONTEXT,
|
|
116
133
|
TEXT_EXAMPLE_INPUT_FIELDS,
|
|
117
134
|
TEXT_EXAMPLE_OUTPUT_FIELDS,
|
|
118
135
|
TEXT_EXAMPLE_REASONING,
|
|
119
136
|
TEXT_SENTENCE_EXAMPLE_ANSWER,
|
|
137
|
+
THIRD_AUDIO_EXAMPLE_CONTEXT,
|
|
138
|
+
THIRD_IMAGE_EXAMPLE_CONTEXT,
|
|
139
|
+
THIRD_TEXT_EXAMPLE_CONTEXT,
|
|
120
140
|
)
|
|
121
141
|
|
|
122
142
|
|
|
@@ -124,6 +144,10 @@ class PromptFactory:
|
|
|
124
144
|
"""Factory class for generating prompts for the Generator given the input(s)."""
|
|
125
145
|
|
|
126
146
|
BASE_SYSTEM_PROMPT_MAP = {
|
|
147
|
+
# agg user prompts
|
|
148
|
+
PromptStrategy.AGG: AGG_BASE_SYSTEM_PROMPT,
|
|
149
|
+
PromptStrategy.AGG_NO_REASONING: AGG_NO_REASONING_BASE_SYSTEM_PROMPT,
|
|
150
|
+
|
|
127
151
|
# filter system prompts
|
|
128
152
|
PromptStrategy.FILTER: FILTER_BASE_SYSTEM_PROMPT,
|
|
129
153
|
PromptStrategy.FILTER_NO_REASONING: FILTER_NO_REASONING_BASE_SYSTEM_PROMPT,
|
|
@@ -149,6 +173,10 @@ class PromptFactory:
|
|
|
149
173
|
PromptStrategy.MAP_SPLIT_MERGER: MAP_SPLIT_MERGER_BASE_SYSTEM_PROMPT,
|
|
150
174
|
}
|
|
151
175
|
BASE_USER_PROMPT_MAP = {
|
|
176
|
+
# agg user prompts
|
|
177
|
+
PromptStrategy.AGG: AGG_BASE_USER_PROMPT,
|
|
178
|
+
PromptStrategy.AGG_NO_REASONING: AGG_NO_REASONING_BASE_USER_PROMPT,
|
|
179
|
+
|
|
152
180
|
# filter user prompts
|
|
153
181
|
PromptStrategy.FILTER: FILTER_BASE_USER_PROMPT,
|
|
154
182
|
PromptStrategy.FILTER_NO_REASONING: FILTER_NO_REASONING_BASE_USER_PROMPT,
|
|
@@ -180,7 +208,7 @@ class PromptFactory:
|
|
|
180
208
|
self.cardinality = cardinality
|
|
181
209
|
self.desc = desc
|
|
182
210
|
|
|
183
|
-
def _get_context(self, candidate: DataRecord, input_fields: list[str]) -> str:
|
|
211
|
+
def _get_context(self, candidate: DataRecord | list[DataRecord], input_fields: list[str]) -> str:
|
|
184
212
|
"""
|
|
185
213
|
Returns the context for the prompt.
|
|
186
214
|
|
|
@@ -193,7 +221,10 @@ class PromptFactory:
|
|
|
193
221
|
"""
|
|
194
222
|
# TODO: remove mask_filepaths=True after SemBench evaluation
|
|
195
223
|
# get context from input record (project_cols will be None if not provided in kwargs)
|
|
196
|
-
|
|
224
|
+
if isinstance(candidate, list):
|
|
225
|
+
context: list[dict] = [record.to_dict(include_bytes=False, project_cols=input_fields, mask_filepaths=True) for record in candidate]
|
|
226
|
+
else:
|
|
227
|
+
context: dict = candidate.to_dict(include_bytes=False, project_cols=input_fields, mask_filepaths=True)
|
|
197
228
|
|
|
198
229
|
# TODO: MOVE THIS LOGIC INTO A CHUNKING / CONTEXT MANAGEMENT CLASS
|
|
199
230
|
# - this class should be able to:
|
|
@@ -202,8 +233,10 @@ class PromptFactory:
|
|
|
202
233
|
# - handle the issue with `original_messages` (ask Matt if this is not clear)
|
|
203
234
|
# TODO: this does not work for image prompts
|
|
204
235
|
# TODO: this ignores the size of the `orignal_messages` in critique and refine prompts
|
|
236
|
+
# NOTE: llama models are disallowed for aggregation so we can assume context is a dict here
|
|
205
237
|
# cut down on context based on window length
|
|
206
238
|
if self.model.is_llama_model():
|
|
239
|
+
assert isinstance(context, dict), "Llama models are not allowed for aggregation operations."
|
|
207
240
|
total_context_len = len(json.dumps(context, indent=2))
|
|
208
241
|
|
|
209
242
|
# sort fields by length and progressively strip from the longest field until it is short enough;
|
|
@@ -322,7 +355,7 @@ class PromptFactory:
|
|
|
322
355
|
"""
|
|
323
356
|
output_fields_desc = ""
|
|
324
357
|
output_schema: type[BaseModel] = kwargs.get("output_schema")
|
|
325
|
-
if self.prompt_strategy.is_map_prompt():
|
|
358
|
+
if self.prompt_strategy.is_map_prompt() or self.prompt_strategy.is_agg_prompt():
|
|
326
359
|
assert output_schema is not None, "Output schema must be provided for convert prompts."
|
|
327
360
|
|
|
328
361
|
for field_name in sorted(output_fields):
|
|
@@ -332,6 +365,19 @@ class PromptFactory:
|
|
|
332
365
|
# strip the last newline characters from the field descriptions and return
|
|
333
366
|
return output_fields_desc[:-1]
|
|
334
367
|
|
|
368
|
+
def _get_agg_instruction(self, **kwargs) -> str | None:
|
|
369
|
+
"""
|
|
370
|
+
Returns the aggregation instruction for the aggregation operation.
|
|
371
|
+
|
|
372
|
+
Returns:
|
|
373
|
+
str | None: The aggregation instruction (if applicable).
|
|
374
|
+
"""
|
|
375
|
+
agg_instruction = kwargs.get("agg_instruction")
|
|
376
|
+
if self.prompt_strategy.is_agg_prompt():
|
|
377
|
+
assert agg_instruction is not None, "Aggregation instruction must be provided for aggregation operations."
|
|
378
|
+
|
|
379
|
+
return agg_instruction
|
|
380
|
+
|
|
335
381
|
def _get_filter_condition(self, **kwargs) -> str | None:
|
|
336
382
|
"""
|
|
337
383
|
Returns the filter condition for the filter operation.
|
|
@@ -463,6 +509,8 @@ class PromptFactory:
|
|
|
463
509
|
job_instruction = FILTER_JOB_INSTRUCTION
|
|
464
510
|
elif self.prompt_strategy.is_join_prompt():
|
|
465
511
|
job_instruction = JOIN_JOB_INSTRUCTION
|
|
512
|
+
elif self.prompt_strategy.is_agg_prompt():
|
|
513
|
+
job_instruction = AGG_JOB_INSTRUCTION
|
|
466
514
|
|
|
467
515
|
# format the job instruction based on the input modalities
|
|
468
516
|
modalities = self._get_modalities_str(input_modalities)
|
|
@@ -556,6 +604,9 @@ class PromptFactory:
|
|
|
556
604
|
Returns:
|
|
557
605
|
str: The example output fields.
|
|
558
606
|
"""
|
|
607
|
+
if self.prompt_strategy.is_agg_prompt():
|
|
608
|
+
return AGG_EXAMPLE_OUTPUT_FIELDS
|
|
609
|
+
|
|
559
610
|
input_modality_to_example_output_fields = {
|
|
560
611
|
Modality.TEXT: TEXT_EXAMPLE_OUTPUT_FIELDS,
|
|
561
612
|
Modality.IMAGE: IMAGE_EXAMPLE_OUTPUT_FIELDS,
|
|
@@ -569,17 +620,31 @@ class PromptFactory:
|
|
|
569
620
|
|
|
570
621
|
return example_output_fields
|
|
571
622
|
|
|
572
|
-
def _get_example_context(self, input_modalities: set[Modality], right: bool = False) -> str:
|
|
623
|
+
def _get_example_context(self, input_modalities: set[Modality], right: bool = False, second: bool = False, third: bool = False) -> str:
|
|
573
624
|
"""
|
|
574
625
|
Returns the example context for the prompt.
|
|
575
626
|
|
|
576
627
|
Returns:
|
|
577
628
|
str: The example context.
|
|
578
629
|
"""
|
|
630
|
+
assert not (second and third), "Cannot have both second and third example contexts."
|
|
631
|
+
assert not (right and (second or third)), "Right context is only used for joins; second and third contexts only use for aggregations."
|
|
632
|
+
text_example_context = TEXT_EXAMPLE_CONTEXT
|
|
633
|
+
image_example_context = IMAGE_EXAMPLE_CONTEXT
|
|
634
|
+
audio_example_context = AUDIO_EXAMPLE_CONTEXT
|
|
635
|
+
if second:
|
|
636
|
+
text_example_context = SECOND_TEXT_EXAMPLE_CONTEXT
|
|
637
|
+
image_example_context = SECOND_IMAGE_EXAMPLE_CONTEXT
|
|
638
|
+
audio_example_context = SECOND_AUDIO_EXAMPLE_CONTEXT
|
|
639
|
+
elif third:
|
|
640
|
+
text_example_context = THIRD_TEXT_EXAMPLE_CONTEXT
|
|
641
|
+
image_example_context = THIRD_IMAGE_EXAMPLE_CONTEXT
|
|
642
|
+
audio_example_context = THIRD_AUDIO_EXAMPLE_CONTEXT
|
|
643
|
+
|
|
579
644
|
input_modality_to_example_context = {
|
|
580
|
-
Modality.TEXT: RIGHT_TEXT_EXAMPLE_CONTEXT if right else
|
|
581
|
-
Modality.IMAGE: RIGHT_IMAGE_EXAMPLE_CONTEXT if right else
|
|
582
|
-
Modality.AUDIO: RIGHT_AUDIO_EXAMPLE_CONTEXT if right else
|
|
645
|
+
Modality.TEXT: RIGHT_TEXT_EXAMPLE_CONTEXT if right else text_example_context,
|
|
646
|
+
Modality.IMAGE: RIGHT_IMAGE_EXAMPLE_CONTEXT if right else image_example_context,
|
|
647
|
+
Modality.AUDIO: RIGHT_AUDIO_EXAMPLE_CONTEXT if right else audio_example_context,
|
|
583
648
|
}
|
|
584
649
|
|
|
585
650
|
example_context = ""
|
|
@@ -589,7 +654,7 @@ class PromptFactory:
|
|
|
589
654
|
|
|
590
655
|
return example_context
|
|
591
656
|
|
|
592
|
-
def _get_image_disclaimer(self, input_modalities: set[Modality], right: bool = False) -> str:
|
|
657
|
+
def _get_image_disclaimer(self, input_modalities: set[Modality], right: bool = False, agg: bool = False) -> str:
|
|
593
658
|
"""
|
|
594
659
|
Returns the image disclaimer for the prompt. The disclaimer must be an empty string
|
|
595
660
|
for non-image prompts.
|
|
@@ -597,10 +662,12 @@ class PromptFactory:
|
|
|
597
662
|
Returns:
|
|
598
663
|
str: The image disclaimer. If this is a text prompt then it is an empty string.
|
|
599
664
|
"""
|
|
600
|
-
|
|
665
|
+
assert not (right and agg), "Right image disclaimer is only used for joins; agg image disclaimer only used for aggregations."
|
|
666
|
+
image_disclaimer = AGG_IMAGE_DISCLAIMER if agg else IMAGE_DISCLAIMER
|
|
667
|
+
image_disclaimer = RIGHT_IMAGE_DISCLAIMER if right else image_disclaimer
|
|
601
668
|
return image_disclaimer if Modality.IMAGE in input_modalities else ""
|
|
602
669
|
|
|
603
|
-
def _get_audio_disclaimer(self, input_modalities: set[Modality], right: bool = False) -> str:
|
|
670
|
+
def _get_audio_disclaimer(self, input_modalities: set[Modality], right: bool = False, agg: bool = False) -> str:
|
|
604
671
|
"""
|
|
605
672
|
Returns the audio disclaimer for the prompt. The disclaimer must be an empty string
|
|
606
673
|
for non-audio prompts.
|
|
@@ -608,7 +675,9 @@ class PromptFactory:
|
|
|
608
675
|
Returns:
|
|
609
676
|
str: The audio disclaimer. If this is a text prompt then it is an empty string.
|
|
610
677
|
"""
|
|
611
|
-
|
|
678
|
+
assert not (right and agg), "Right audio disclaimer is only used for joins; agg audio disclaimer only used for aggregations."
|
|
679
|
+
audio_disclaimer = AGG_AUDIO_DISCLAIMER if agg else AUDIO_DISCLAIMER
|
|
680
|
+
audio_disclaimer = RIGHT_AUDIO_DISCLAIMER if right else audio_disclaimer
|
|
612
681
|
return audio_disclaimer if Modality.AUDIO in input_modalities else ""
|
|
613
682
|
|
|
614
683
|
def _get_example_reasoning(self, input_modalities: set[Modality]) -> str:
|
|
@@ -622,6 +691,8 @@ class PromptFactory:
|
|
|
622
691
|
return FILTER_EXAMPLE_REASONING
|
|
623
692
|
elif self.prompt_strategy.is_join_prompt():
|
|
624
693
|
return JOIN_EXAMPLE_REASONING
|
|
694
|
+
elif self.prompt_strategy.is_agg_prompt():
|
|
695
|
+
return AGG_EXAMPLE_REASONING
|
|
625
696
|
|
|
626
697
|
input_modality_to_example_reasoning = {
|
|
627
698
|
Modality.TEXT: TEXT_EXAMPLE_REASONING,
|
|
@@ -643,6 +714,9 @@ class PromptFactory:
|
|
|
643
714
|
Returns:
|
|
644
715
|
str: The example answer.
|
|
645
716
|
"""
|
|
717
|
+
if self.prompt_strategy.is_agg_prompt():
|
|
718
|
+
return AGG_EXAMPLE_ANSWER
|
|
719
|
+
|
|
646
720
|
use_sentence_answers = self.prompt_strategy.is_split_proposer_prompt() or self.prompt_strategy.is_moa_proposer_prompt()
|
|
647
721
|
input_modality_to_example_answer = {
|
|
648
722
|
Modality.TEXT: TEXT_SENTENCE_EXAMPLE_ANSWER if use_sentence_answers else TEXT_EXAMPLE_ANSWER,
|
|
@@ -661,7 +735,7 @@ class PromptFactory:
|
|
|
661
735
|
|
|
662
736
|
def _get_all_format_kwargs(
|
|
663
737
|
self,
|
|
664
|
-
candidate: DataRecord,
|
|
738
|
+
candidate: DataRecord | list[DataRecord],
|
|
665
739
|
input_fields: list[str],
|
|
666
740
|
input_modalities: set[Modality],
|
|
667
741
|
output_fields: list[str],
|
|
@@ -685,8 +759,9 @@ class PromptFactory:
|
|
|
685
759
|
# get format kwargs which depend on the input data
|
|
686
760
|
input_format_kwargs = {
|
|
687
761
|
"context": self._get_context(candidate, input_fields),
|
|
688
|
-
"input_fields_desc": self._get_input_fields_desc(candidate, input_fields),
|
|
762
|
+
"input_fields_desc": self._get_input_fields_desc(candidate[0] if isinstance(candidate, list) else candidate, input_fields),
|
|
689
763
|
"output_fields_desc": self._get_output_fields_desc(output_fields, **kwargs),
|
|
764
|
+
"agg_instruction": self._get_agg_instruction(**kwargs),
|
|
690
765
|
"filter_condition": self._get_filter_condition(**kwargs),
|
|
691
766
|
"join_condition": self._get_join_condition(**kwargs),
|
|
692
767
|
"original_output": self._get_original_output(**kwargs),
|
|
@@ -715,11 +790,14 @@ class PromptFactory:
|
|
|
715
790
|
"right_example_input_fields": self._get_example_input_fields(right_input_modalities, right=True),
|
|
716
791
|
"example_output_fields": self._get_example_output_fields(input_modalities),
|
|
717
792
|
"example_context": self._get_example_context(input_modalities),
|
|
793
|
+
"second_example_context": self._get_example_context(input_modalities, second=True) if self.prompt_strategy.is_agg_prompt() else "",
|
|
794
|
+
"third_example_context": self._get_example_context(input_modalities, third=True) if self.prompt_strategy.is_agg_prompt() else "",
|
|
718
795
|
"right_example_context": self._get_example_context(right_input_modalities, right=True),
|
|
719
|
-
"image_disclaimer": self._get_image_disclaimer(input_modalities),
|
|
720
|
-
"audio_disclaimer": self._get_audio_disclaimer(input_modalities),
|
|
796
|
+
"image_disclaimer": self._get_image_disclaimer(input_modalities, agg=self.prompt_strategy.is_agg_prompt()),
|
|
797
|
+
"audio_disclaimer": self._get_audio_disclaimer(input_modalities, agg=self.prompt_strategy.is_agg_prompt()),
|
|
721
798
|
"right_image_disclaimer": self._get_image_disclaimer(right_input_modalities, right=True),
|
|
722
799
|
"right_audio_disclaimer": self._get_audio_disclaimer(right_input_modalities, right=True),
|
|
800
|
+
"example_agg_instruction": EXAMPLE_AGG_INSTRUCTION,
|
|
723
801
|
"example_filter_condition": EXAMPLE_FILTER_CONDITION,
|
|
724
802
|
"example_join_condition": EXAMPLE_JOIN_CONDITION,
|
|
725
803
|
"example_reasoning": self._get_example_reasoning(input_modalities),
|
|
@@ -729,106 +807,116 @@ class PromptFactory:
|
|
|
729
807
|
# return all format kwargs
|
|
730
808
|
return {**input_format_kwargs, **prompt_strategy_format_kwargs}
|
|
731
809
|
|
|
732
|
-
def _create_audio_messages(self, candidate: DataRecord, input_fields: list[str]) -> list[dict]:
|
|
810
|
+
def _create_audio_messages(self, candidate: DataRecord | list[DataRecord], input_fields: list[str]) -> list[dict]:
|
|
733
811
|
"""
|
|
734
|
-
Parses the candidate record and returns the audio messages for the chat payload.
|
|
812
|
+
Parses the candidate record(s) and returns the audio messages for the chat payload.
|
|
735
813
|
|
|
736
814
|
Args:
|
|
737
|
-
candidate (DataRecord): The input record.
|
|
815
|
+
candidate (DataRecord | list[DataRecord]): The input record(s).
|
|
738
816
|
input_fields (list[str]): The list of input fields.
|
|
739
817
|
|
|
740
818
|
Returns:
|
|
741
819
|
list[dict]: The audio messages for the chat payload.
|
|
742
820
|
"""
|
|
821
|
+
# normalize type to be list[DataRecord]
|
|
822
|
+
if isinstance(candidate, DataRecord):
|
|
823
|
+
candidate = [candidate]
|
|
824
|
+
|
|
743
825
|
# create a message for each audio recording in an input field with an audio (or list of audio) type
|
|
744
826
|
audio_content = []
|
|
745
827
|
for field_name in input_fields:
|
|
746
|
-
|
|
747
|
-
|
|
828
|
+
for dr in candidate:
|
|
829
|
+
field_value = dr[field_name]
|
|
830
|
+
field_type = dr.get_field_type(field_name)
|
|
748
831
|
|
|
749
|
-
|
|
750
|
-
|
|
751
|
-
|
|
752
|
-
base64_audio_str = base64.b64encode(f.read()).decode("utf-8")
|
|
753
|
-
audio_content.append(
|
|
754
|
-
{"type": "input_audio", "input_audio": {"data": base64_audio_str, "format": "wav"}}
|
|
755
|
-
)
|
|
756
|
-
|
|
757
|
-
elif field_type.annotation in [list[AudioFilepath], list[AudioFilepath] | None]:
|
|
758
|
-
for audio_filepath in field_value:
|
|
759
|
-
with open(audio_filepath, "rb") as f:
|
|
832
|
+
# audio filepath (or list of audio filepaths)
|
|
833
|
+
if field_type.annotation in [AudioFilepath, AudioFilepath | None, AudioFilepath | Any]:
|
|
834
|
+
with open(field_value, "rb") as f:
|
|
760
835
|
base64_audio_str = base64.b64encode(f.read()).decode("utf-8")
|
|
761
836
|
audio_content.append(
|
|
762
837
|
{"type": "input_audio", "input_audio": {"data": base64_audio_str, "format": "wav"}}
|
|
763
838
|
)
|
|
764
839
|
|
|
765
|
-
|
|
766
|
-
|
|
767
|
-
|
|
768
|
-
|
|
769
|
-
|
|
840
|
+
elif field_type.annotation in [list[AudioFilepath], list[AudioFilepath] | None, list[AudioFilepath] | Any]:
|
|
841
|
+
for audio_filepath in field_value:
|
|
842
|
+
with open(audio_filepath, "rb") as f:
|
|
843
|
+
base64_audio_str = base64.b64encode(f.read()).decode("utf-8")
|
|
844
|
+
audio_content.append(
|
|
845
|
+
{"type": "input_audio", "input_audio": {"data": base64_audio_str, "format": "wav"}}
|
|
846
|
+
)
|
|
770
847
|
|
|
771
|
-
|
|
772
|
-
|
|
848
|
+
# pre-encoded images (or list of pre-encoded images)
|
|
849
|
+
elif field_type.annotation in [AudioBase64, AudioBase64 | None, AudioBase64 | Any]:
|
|
773
850
|
audio_content.append(
|
|
774
|
-
{"type": "input_audio", "input_audio": {"data":
|
|
851
|
+
{"type": "input_audio", "input_audio": {"data": field_value, "format": "wav"}}
|
|
775
852
|
)
|
|
776
853
|
|
|
854
|
+
elif field_type.annotation in [list[AudioBase64], list[AudioBase64] | None, list[AudioBase64] | Any]:
|
|
855
|
+
for base64_audio in field_value:
|
|
856
|
+
audio_content.append(
|
|
857
|
+
{"type": "input_audio", "input_audio": {"data": base64_audio, "format": "wav"}}
|
|
858
|
+
)
|
|
859
|
+
|
|
777
860
|
return [{"role": "user", "type": "input_audio", "content": audio_content}] if len(audio_content) > 0 else []
|
|
778
861
|
|
|
779
|
-
def _create_image_messages(self, candidate: DataRecord, input_fields: list[str]) -> list[dict]:
|
|
862
|
+
def _create_image_messages(self, candidate: DataRecord | list[DataRecord], input_fields: list[str]) -> list[dict]:
|
|
780
863
|
"""
|
|
781
|
-
Parses the candidate record and returns the image messages for the chat payload.
|
|
864
|
+
Parses the candidate record(s) and returns the image messages for the chat payload.
|
|
782
865
|
|
|
783
866
|
Args:
|
|
784
|
-
candidate (DataRecord): The input record.
|
|
867
|
+
candidate (DataRecord | list[DataRecord]): The input record(s).
|
|
785
868
|
input_fields (list[str]): The list of input fields.
|
|
786
869
|
|
|
787
870
|
Returns:
|
|
788
871
|
list[dict]: The image messages for the chat payload.
|
|
789
872
|
"""
|
|
873
|
+
# normalize type to be list[DataRecord]
|
|
874
|
+
if isinstance(candidate, DataRecord):
|
|
875
|
+
candidate = [candidate]
|
|
876
|
+
|
|
790
877
|
# create a message for each image in an input field with an image (or list of image) type
|
|
791
878
|
image_content = []
|
|
792
879
|
for field_name in input_fields:
|
|
793
|
-
|
|
794
|
-
|
|
880
|
+
for dr in candidate:
|
|
881
|
+
field_value = dr[field_name]
|
|
882
|
+
field_type = dr.get_field_type(field_name)
|
|
795
883
|
|
|
796
|
-
|
|
797
|
-
|
|
798
|
-
|
|
799
|
-
base64_image_str = base64.b64encode(f.read()).decode("utf-8")
|
|
800
|
-
image_content.append(
|
|
801
|
-
{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image_str}"}}
|
|
802
|
-
)
|
|
803
|
-
|
|
804
|
-
elif field_type.annotation in [list[ImageFilepath], list[ImageFilepath] | None]:
|
|
805
|
-
for image_filepath in field_value:
|
|
806
|
-
with open(image_filepath, "rb") as f:
|
|
884
|
+
# image filepath (or list of image filepaths)
|
|
885
|
+
if field_type.annotation in [ImageFilepath, ImageFilepath | None, ImageFilepath | Any]:
|
|
886
|
+
with open(field_value, "rb") as f:
|
|
807
887
|
base64_image_str = base64.b64encode(f.read()).decode("utf-8")
|
|
808
888
|
image_content.append(
|
|
809
889
|
{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image_str}"}}
|
|
810
890
|
)
|
|
811
891
|
|
|
812
|
-
|
|
813
|
-
|
|
814
|
-
|
|
892
|
+
elif field_type.annotation in [list[ImageFilepath], list[ImageFilepath] | None, list[ImageFilepath] | Any]:
|
|
893
|
+
for image_filepath in field_value:
|
|
894
|
+
with open(image_filepath, "rb") as f:
|
|
895
|
+
base64_image_str = base64.b64encode(f.read()).decode("utf-8")
|
|
896
|
+
image_content.append(
|
|
897
|
+
{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image_str}"}}
|
|
898
|
+
)
|
|
815
899
|
|
|
816
|
-
|
|
817
|
-
|
|
818
|
-
image_content.append({"type": "image_url", "image_url": {"url":
|
|
900
|
+
# image url (or list of image urls)
|
|
901
|
+
elif field_type.annotation in [ImageURL, ImageURL | None, ImageURL | Any]:
|
|
902
|
+
image_content.append({"type": "image_url", "image_url": {"url": field_value}})
|
|
819
903
|
|
|
820
|
-
|
|
821
|
-
|
|
822
|
-
|
|
823
|
-
{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{field_value}"}}
|
|
824
|
-
)
|
|
904
|
+
elif field_type.annotation in [list[ImageURL], list[ImageURL] | None, list[ImageURL] | Any]:
|
|
905
|
+
for image_url in field_value:
|
|
906
|
+
image_content.append({"type": "image_url", "image_url": {"url": image_url}})
|
|
825
907
|
|
|
826
|
-
|
|
827
|
-
|
|
908
|
+
# pre-encoded images (or list of pre-encoded images)
|
|
909
|
+
elif field_type.annotation in [ImageBase64, ImageBase64 | None, ImageBase64 | Any]:
|
|
828
910
|
image_content.append(
|
|
829
|
-
{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{
|
|
911
|
+
{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{field_value}"}}
|
|
830
912
|
)
|
|
831
913
|
|
|
914
|
+
elif field_type.annotation in [list[ImageBase64], list[ImageBase64] | None, list[ImageBase64] | Any]:
|
|
915
|
+
for base64_image in field_value:
|
|
916
|
+
image_content.append(
|
|
917
|
+
{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}}
|
|
918
|
+
)
|
|
919
|
+
|
|
832
920
|
return [{"role": "user", "type": "image", "content": image_content}] if len(image_content) > 0 else []
|
|
833
921
|
|
|
834
922
|
def _get_system_prompt(self, **format_kwargs) -> str | None:
|
|
@@ -848,12 +936,12 @@ class PromptFactory:
|
|
|
848
936
|
|
|
849
937
|
return base_prompt.format(**format_kwargs)
|
|
850
938
|
|
|
851
|
-
def _get_user_messages(self, candidate: DataRecord, input_fields: list[str], right_candidate: DataRecord | None, right_input_fields: list[str], **kwargs) -> str:
|
|
939
|
+
def _get_user_messages(self, candidate: DataRecord | list[DataRecord], input_fields: list[str], right_candidate: DataRecord | None, right_input_fields: list[str], **kwargs) -> str:
|
|
852
940
|
"""
|
|
853
941
|
Returns a list of messages for the chat payload based on the prompt strategy.
|
|
854
942
|
|
|
855
943
|
Args:
|
|
856
|
-
candidate (DataRecord): The input record.
|
|
944
|
+
candidate (DataRecord | list[DataRecord]): The input record(s).
|
|
857
945
|
input_fields (list[str]): The input fields.
|
|
858
946
|
output_fields (list[str]): The output fields.
|
|
859
947
|
kwargs: The formatting kwargs and some keyword arguments provided by the user.
|
|
@@ -942,7 +1030,7 @@ class PromptFactory:
|
|
|
942
1030
|
|
|
943
1031
|
return user_messages
|
|
944
1032
|
|
|
945
|
-
def create_messages(self, candidate: DataRecord, output_fields: list[str], right_candidate: DataRecord | None = None, **kwargs) -> list[dict]:
|
|
1033
|
+
def create_messages(self, candidate: DataRecord | list[DataRecord], output_fields: list[str], right_candidate: DataRecord | None = None, **kwargs) -> list[dict]:
|
|
946
1034
|
"""
|
|
947
1035
|
Creates the messages for the chat payload based on the prompt strategy.
|
|
948
1036
|
|
|
@@ -954,7 +1042,7 @@ class PromptFactory:
|
|
|
954
1042
|
}
|
|
955
1043
|
|
|
956
1044
|
Args:
|
|
957
|
-
candidate (DataRecord): The input record.
|
|
1045
|
+
candidate (DataRecord | list[DataRecord]): The input record(s).
|
|
958
1046
|
output_fields (list[str]): The output fields.
|
|
959
1047
|
right_candidate (DataRecord | None): The other join input record (only provided for joins).
|
|
960
1048
|
kwargs: The keyword arguments provided by the user.
|
|
@@ -963,11 +1051,11 @@ class PromptFactory:
|
|
|
963
1051
|
list[dict]: The messages for the chat payload.
|
|
964
1052
|
"""
|
|
965
1053
|
# compute the set of input fields
|
|
966
|
-
input_fields = self._get_input_fields(candidate, **kwargs)
|
|
1054
|
+
input_fields = self._get_input_fields(candidate[0] if isinstance(candidate, list) else candidate, **kwargs)
|
|
967
1055
|
right_input_fields = [] if right_candidate is None else self._get_input_fields(right_candidate, **kwargs)
|
|
968
1056
|
|
|
969
1057
|
# use input fields to determine the left / right input modalities
|
|
970
|
-
input_modalities = self._get_input_modalities(candidate, input_fields)
|
|
1058
|
+
input_modalities = self._get_input_modalities(candidate[0] if isinstance(candidate, list) else candidate, input_fields)
|
|
971
1059
|
right_input_modalities = set() if right_candidate is None else self._get_input_modalities(right_candidate, right_input_fields)
|
|
972
1060
|
|
|
973
1061
|
# initialize messages
|
palimpzest/prompts/utils.py
CHANGED
|
@@ -11,12 +11,14 @@ The user has additionally provided you with this description of the task you nee
|
|
|
11
11
|
"""
|
|
12
12
|
|
|
13
13
|
### JOB INSTRUCTIONS ###
|
|
14
|
+
AGG_JOB_INSTRUCTION = """analyze input {modalities} in order to perform an aggregation and generate a JSON object"""
|
|
14
15
|
MAP_JOB_INSTRUCTION = """analyze input {modalities} in order to produce a JSON object"""
|
|
15
16
|
FILTER_JOB_INSTRUCTION = """analyze input {modalities} in order to answer a TRUE / FALSE question"""
|
|
16
17
|
JOIN_JOB_INSTRUCTION = """analyze input {modalities} in order to determine whether two data records satisfy a join condition"""
|
|
17
18
|
PROPOSER_JOB_INSTRUCTION = """analyze input {modalities} in order to produce an answer to a question"""
|
|
18
19
|
|
|
19
|
-
### FILTER / JOIN CONDITIONS ###
|
|
20
|
+
### AGG / FILTER / JOIN CONDITIONS ###
|
|
21
|
+
EXAMPLE_AGG_INSTRUCTION = "Count the distinct number of scientists in the input."
|
|
20
22
|
EXAMPLE_FILTER_CONDITION = "The subject of the input is a foundational computer scientist."
|
|
21
23
|
EXAMPLE_JOIN_CONDITION = "The two inputs are scientists in the same academic field."
|
|
22
24
|
|
|
@@ -48,6 +50,7 @@ TEXT_EXAMPLE_OUTPUT_FIELDS = """- name: the name of the scientist
|
|
|
48
50
|
- birth_year: the year the scientist was born"""
|
|
49
51
|
IMAGE_EXAMPLE_OUTPUT_FIELDS = """- is_bald: true if the scientist is bald and false otherwise"""
|
|
50
52
|
AUDIO_EXAMPLE_OUTPUT_FIELDS = """- birthplace: the city where the scientist was born"""
|
|
53
|
+
AGG_EXAMPLE_OUTPUT_FIELDS = """- num_distinct_scientists: the number of distinct scientists mentioned in the input"""
|
|
51
54
|
|
|
52
55
|
### EXAMPLE CONTEXTS ###
|
|
53
56
|
TEXT_EXAMPLE_CONTEXT = """
|
|
@@ -71,6 +74,30 @@ RIGHT_IMAGE_EXAMPLE_CONTEXT = """
|
|
|
71
74
|
RIGHT_AUDIO_EXAMPLE_CONTEXT = """
|
|
72
75
|
"podcast": <bytes>
|
|
73
76
|
"""
|
|
77
|
+
SECOND_TEXT_EXAMPLE_CONTEXT = """
|
|
78
|
+
"text": "Alan Turing was a pioneering computer scientist and mathematician. He is widely considered to be the father of theoretical computer science and artificial intelligence.",
|
|
79
|
+
"birthday": "June 23, 1912"
|
|
80
|
+
"""
|
|
81
|
+
SECOND_IMAGE_EXAMPLE_CONTEXT = """
|
|
82
|
+
"image": <bytes>,
|
|
83
|
+
"photographer": "PhotoPro42"
|
|
84
|
+
"""
|
|
85
|
+
SECOND_AUDIO_EXAMPLE_CONTEXT = """
|
|
86
|
+
"recording": <bytes>,
|
|
87
|
+
"speaker": "Barbara Walters"
|
|
88
|
+
"""
|
|
89
|
+
THIRD_TEXT_EXAMPLE_CONTEXT = """
|
|
90
|
+
"text": "Ada Lovelace is a historically significant computer scientist.",
|
|
91
|
+
"birthday": "December 10, 1815"
|
|
92
|
+
"""
|
|
93
|
+
THIRD_IMAGE_EXAMPLE_CONTEXT = """
|
|
94
|
+
"image": <bytes>,
|
|
95
|
+
"photographer": "PicturePerfect"
|
|
96
|
+
"""
|
|
97
|
+
THIRD_AUDIO_EXAMPLE_CONTEXT = """
|
|
98
|
+
"recording": <bytes>,
|
|
99
|
+
"speaker": "Anderson Cooper"
|
|
100
|
+
"""
|
|
74
101
|
|
|
75
102
|
### DISCLAIMERS ###
|
|
76
103
|
IMAGE_DISCLAIMER = """
|
|
@@ -85,15 +112,25 @@ RIGHT_IMAGE_DISCLAIMER = """
|
|
|
85
112
|
RIGHT_AUDIO_DISCLAIMER = """
|
|
86
113
|
\n<audio content provided here; assume in this example the podcast is discussing Alan Turing's work on the Enigma code>
|
|
87
114
|
"""
|
|
115
|
+
AGG_IMAGE_DISCLAIMER = """
|
|
116
|
+
\n<image content provided here; assume in this example the first image shows Ada Lovelace, the second image shows Alan Turing, and the third image shows Ada Lovelace again>
|
|
117
|
+
"""
|
|
118
|
+
AGG_AUDIO_DISCLAIMER = """
|
|
119
|
+
\n<audio content provided here; assume in this example the first recording is about Ada Lovelace, the second recording is about Alan Turing, and the third recording is about Ada Lovelace again>
|
|
120
|
+
"""
|
|
88
121
|
|
|
89
122
|
### EXAMPLE REASONINGS ###
|
|
90
123
|
TEXT_EXAMPLE_REASONING = """The text passage mentions the scientist's name as "Augusta Ada King, Countess of Lovelace, also known as Ada Lovelace" and the scientist's birthday as "December 10, 1815". Therefore, the name of the scientist is "Augusta Ada King" and the birth year is 1815."""
|
|
91
124
|
IMAGE_EXAMPLE_REASONING = """The image shows hair on top of the scientist's head, so the is_bald field should be false."""
|
|
92
125
|
AUDIO_EXAMPLE_REASONING = """The newscast recording discusses Ada Lovelace's upbringing in London, so the birthplace field should be "London"."""
|
|
126
|
+
AGG_EXAMPLE_REASONING = """The input contains two distinct scientists: "Augusta Ada King" and "Alan Turing". Although "Ada Lovelace" is mentioned twice, she should only be counted once. Therefore, the number of distinct scientists mentioned in the input is 2."""
|
|
93
127
|
FILTER_EXAMPLE_REASONING = """Ada Lovelace is a foundational computer scientist, therefore the answer is TRUE."""
|
|
94
128
|
JOIN_EXAMPLE_REASONING = """The subject of the left record is Ada Lovelace and the subject of the right record is Alan Turing. Since both inputs are about computer scientists, they satisfy the join condition. Therefore, the answer is TRUE."""
|
|
95
129
|
|
|
96
130
|
### EXAMPLE ANSWERS ###
|
|
131
|
+
AGG_EXAMPLE_ANSWER = """
|
|
132
|
+
"num_distinct_scientists": 2
|
|
133
|
+
"""
|
|
97
134
|
TEXT_EXAMPLE_ANSWER = """
|
|
98
135
|
"name": "Augusta Ada King",
|
|
99
136
|
"birth_year": 1815
|