palimpzest 0.8.2__py3-none-any.whl → 0.8.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- palimpzest/constants.py +38 -62
- palimpzest/core/data/iter_dataset.py +5 -5
- palimpzest/core/elements/groupbysig.py +1 -1
- palimpzest/core/elements/records.py +91 -109
- palimpzest/core/lib/schemas.py +23 -0
- palimpzest/core/models.py +3 -3
- palimpzest/prompts/__init__.py +2 -6
- palimpzest/prompts/convert_prompts.py +10 -66
- palimpzest/prompts/critique_and_refine_prompts.py +66 -0
- palimpzest/prompts/filter_prompts.py +8 -46
- palimpzest/prompts/join_prompts.py +12 -75
- palimpzest/prompts/{moa_aggregator_convert_prompts.py → moa_aggregator_prompts.py} +51 -2
- palimpzest/prompts/moa_proposer_prompts.py +87 -0
- palimpzest/prompts/prompt_factory.py +351 -479
- palimpzest/prompts/split_merge_prompts.py +51 -2
- palimpzest/prompts/split_proposer_prompts.py +48 -16
- palimpzest/prompts/utils.py +109 -0
- palimpzest/query/execution/all_sample_execution_strategy.py +1 -1
- palimpzest/query/execution/execution_strategy.py +4 -4
- palimpzest/query/execution/mab_execution_strategy.py +1 -2
- palimpzest/query/execution/parallel_execution_strategy.py +3 -3
- palimpzest/query/execution/single_threaded_execution_strategy.py +8 -8
- palimpzest/query/generators/generators.py +31 -17
- palimpzest/query/operators/__init__.py +15 -2
- palimpzest/query/operators/aggregate.py +21 -19
- palimpzest/query/operators/compute.py +6 -8
- palimpzest/query/operators/convert.py +12 -37
- palimpzest/query/operators/critique_and_refine.py +194 -0
- palimpzest/query/operators/distinct.py +7 -7
- palimpzest/query/operators/filter.py +13 -25
- palimpzest/query/operators/join.py +321 -192
- palimpzest/query/operators/limit.py +4 -4
- palimpzest/query/operators/mixture_of_agents.py +246 -0
- palimpzest/query/operators/physical.py +25 -2
- palimpzest/query/operators/project.py +4 -4
- palimpzest/query/operators/{rag_convert.py → rag.py} +202 -5
- palimpzest/query/operators/retrieve.py +10 -9
- palimpzest/query/operators/scan.py +9 -10
- palimpzest/query/operators/search.py +18 -24
- palimpzest/query/operators/split.py +321 -0
- palimpzest/query/optimizer/__init__.py +12 -8
- palimpzest/query/optimizer/optimizer.py +12 -10
- palimpzest/query/optimizer/rules.py +201 -108
- palimpzest/query/optimizer/tasks.py +18 -6
- palimpzest/validator/validator.py +7 -9
- {palimpzest-0.8.2.dist-info → palimpzest-0.8.4.dist-info}/METADATA +3 -8
- palimpzest-0.8.4.dist-info/RECORD +95 -0
- palimpzest/prompts/critique_and_refine_convert_prompts.py +0 -216
- palimpzest/prompts/moa_proposer_convert_prompts.py +0 -75
- palimpzest/prompts/util_phrases.py +0 -19
- palimpzest/query/operators/critique_and_refine_convert.py +0 -113
- palimpzest/query/operators/mixture_of_agents_convert.py +0 -140
- palimpzest/query/operators/split_convert.py +0 -170
- palimpzest-0.8.2.dist-info/RECORD +0 -95
- {palimpzest-0.8.2.dist-info → palimpzest-0.8.4.dist-info}/WHEEL +0 -0
- {palimpzest-0.8.2.dist-info → palimpzest-0.8.4.dist-info}/licenses/LICENSE +0 -0
- {palimpzest-0.8.2.dist-info → palimpzest-0.8.4.dist-info}/top_level.txt +0 -0
|
@@ -2,7 +2,6 @@
|
|
|
2
2
|
|
|
3
3
|
import base64
|
|
4
4
|
import json
|
|
5
|
-
from string import Formatter
|
|
6
5
|
|
|
7
6
|
from pydantic import BaseModel
|
|
8
7
|
|
|
@@ -10,137 +9,114 @@ from palimpzest.constants import (
|
|
|
10
9
|
LLAMA_CONTEXT_TOKENS_LIMIT,
|
|
11
10
|
TOKENS_PER_CHARACTER,
|
|
12
11
|
Cardinality,
|
|
12
|
+
Modality,
|
|
13
13
|
Model,
|
|
14
14
|
PromptStrategy,
|
|
15
15
|
)
|
|
16
16
|
from palimpzest.core.elements.records import DataRecord
|
|
17
|
-
from palimpzest.core.lib.schemas import
|
|
17
|
+
from palimpzest.core.lib.schemas import (
|
|
18
|
+
AUDIO_FIELD_TYPES,
|
|
19
|
+
IMAGE_FIELD_TYPES,
|
|
20
|
+
AudioBase64,
|
|
21
|
+
AudioFilepath,
|
|
22
|
+
ImageBase64,
|
|
23
|
+
ImageFilepath,
|
|
24
|
+
ImageURL,
|
|
25
|
+
)
|
|
18
26
|
from palimpzest.prompts.convert_prompts import (
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
COT_QA_AUDIO_EXAMPLE_OUTPUT_FIELDS,
|
|
24
|
-
COT_QA_AUDIO_EXAMPLE_REASONING,
|
|
25
|
-
COT_QA_AUDIO_JOB_INSTRUCTION,
|
|
26
|
-
COT_QA_BASE_SYSTEM_PROMPT,
|
|
27
|
-
COT_QA_BASE_USER_PROMPT,
|
|
28
|
-
COT_QA_EXAMPLE_ANSWER,
|
|
29
|
-
COT_QA_EXAMPLE_CONTEXT,
|
|
30
|
-
COT_QA_EXAMPLE_INPUT_FIELDS,
|
|
31
|
-
COT_QA_EXAMPLE_OUTPUT_FIELDS,
|
|
32
|
-
COT_QA_EXAMPLE_REASONING,
|
|
33
|
-
COT_QA_IMAGE_DISCLAIMER,
|
|
34
|
-
COT_QA_IMAGE_EXAMPLE_ANSWER,
|
|
35
|
-
COT_QA_IMAGE_EXAMPLE_CONTEXT,
|
|
36
|
-
COT_QA_IMAGE_EXAMPLE_INPUT_FIELDS,
|
|
37
|
-
COT_QA_IMAGE_EXAMPLE_OUTPUT_FIELDS,
|
|
38
|
-
COT_QA_IMAGE_EXAMPLE_REASONING,
|
|
39
|
-
COT_QA_IMAGE_JOB_INSTRUCTION,
|
|
40
|
-
COT_QA_JOB_INSTRUCTION,
|
|
41
|
-
COT_QA_NO_REASONING_BASE_SYSTEM_PROMPT,
|
|
42
|
-
COT_QA_NO_REASONING_BASE_USER_PROMPT,
|
|
27
|
+
MAP_BASE_SYSTEM_PROMPT,
|
|
28
|
+
MAP_BASE_USER_PROMPT,
|
|
29
|
+
MAP_NO_REASONING_BASE_SYSTEM_PROMPT,
|
|
30
|
+
MAP_NO_REASONING_BASE_USER_PROMPT,
|
|
43
31
|
)
|
|
44
|
-
from palimpzest.prompts.
|
|
32
|
+
from palimpzest.prompts.critique_and_refine_prompts import (
|
|
45
33
|
BASE_CRITIQUE_PROMPT,
|
|
46
34
|
BASE_REFINEMENT_PROMPT,
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
35
|
+
FILTER_CRITIQUE_CRITERIA,
|
|
36
|
+
FILTER_CRITIQUE_FINISH_INSTRUCTION,
|
|
37
|
+
FILTER_REFINEMENT_CRITERIA,
|
|
38
|
+
FILTER_REFINEMENT_FINISH_INSTRUCTION,
|
|
39
|
+
MAP_CRITIQUE_CRITERIA,
|
|
40
|
+
MAP_CRITIQUE_FINISH_INSTRUCTION,
|
|
41
|
+
MAP_REFINEMENT_CRITERIA,
|
|
42
|
+
MAP_REFINEMENT_FINISH_INSTRUCTION,
|
|
53
43
|
)
|
|
54
44
|
from palimpzest.prompts.filter_prompts import (
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
COT_BOOL_AUDIO_EXAMPLE_REASONING,
|
|
60
|
-
COT_BOOL_AUDIO_JOB_INSTRUCTION,
|
|
61
|
-
COT_BOOL_BASE_SYSTEM_PROMPT,
|
|
62
|
-
COT_BOOL_BASE_USER_PROMPT,
|
|
63
|
-
COT_BOOL_EXAMPLE_CONTEXT,
|
|
64
|
-
COT_BOOL_EXAMPLE_FILTER_CONDITION,
|
|
65
|
-
COT_BOOL_EXAMPLE_INPUT_FIELDS,
|
|
66
|
-
COT_BOOL_EXAMPLE_REASONING,
|
|
67
|
-
COT_BOOL_IMAGE_DISCLAIMER,
|
|
68
|
-
COT_BOOL_IMAGE_EXAMPLE_CONTEXT,
|
|
69
|
-
COT_BOOL_IMAGE_EXAMPLE_FILTER_CONDITION,
|
|
70
|
-
COT_BOOL_IMAGE_EXAMPLE_INPUT_FIELDS,
|
|
71
|
-
COT_BOOL_IMAGE_EXAMPLE_REASONING,
|
|
72
|
-
COT_BOOL_IMAGE_JOB_INSTRUCTION,
|
|
73
|
-
COT_BOOL_JOB_INSTRUCTION,
|
|
74
|
-
COT_BOOL_NO_REASONING_BASE_SYSTEM_PROMPT,
|
|
75
|
-
COT_BOOL_NO_REASONING_BASE_USER_PROMPT,
|
|
45
|
+
FILTER_BASE_SYSTEM_PROMPT,
|
|
46
|
+
FILTER_BASE_USER_PROMPT,
|
|
47
|
+
FILTER_NO_REASONING_BASE_SYSTEM_PROMPT,
|
|
48
|
+
FILTER_NO_REASONING_BASE_USER_PROMPT,
|
|
76
49
|
)
|
|
77
50
|
from palimpzest.prompts.join_prompts import (
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
COT_JOIN_AUDIO_EXAMPLE_REASONING,
|
|
83
|
-
COT_JOIN_AUDIO_JOB_INSTRUCTION,
|
|
84
|
-
COT_JOIN_AUDIO_RIGHT_EXAMPLE_CONTEXT,
|
|
85
|
-
COT_JOIN_AUDIO_RIGHT_EXAMPLE_INPUT_FIELDS,
|
|
86
|
-
COT_JOIN_BASE_SYSTEM_PROMPT,
|
|
87
|
-
COT_JOIN_BASE_USER_PROMPT,
|
|
88
|
-
COT_JOIN_EXAMPLE_CONTEXT,
|
|
89
|
-
COT_JOIN_EXAMPLE_INPUT_FIELDS,
|
|
90
|
-
COT_JOIN_EXAMPLE_JOIN_CONDITION,
|
|
91
|
-
COT_JOIN_EXAMPLE_REASONING,
|
|
92
|
-
COT_JOIN_IMAGE_DISCLAIMER,
|
|
93
|
-
COT_JOIN_IMAGE_EXAMPLE_CONTEXT,
|
|
94
|
-
COT_JOIN_IMAGE_EXAMPLE_INPUT_FIELDS,
|
|
95
|
-
COT_JOIN_IMAGE_EXAMPLE_JOIN_CONDITION,
|
|
96
|
-
COT_JOIN_IMAGE_EXAMPLE_REASONING,
|
|
97
|
-
COT_JOIN_IMAGE_JOB_INSTRUCTION,
|
|
98
|
-
COT_JOIN_IMAGE_RIGHT_EXAMPLE_CONTEXT,
|
|
99
|
-
COT_JOIN_IMAGE_RIGHT_EXAMPLE_INPUT_FIELDS,
|
|
100
|
-
COT_JOIN_JOB_INSTRUCTION,
|
|
101
|
-
COT_JOIN_NO_REASONING_BASE_SYSTEM_PROMPT,
|
|
102
|
-
COT_JOIN_NO_REASONING_BASE_USER_PROMPT,
|
|
103
|
-
COT_JOIN_RIGHT_AUDIO_DISCLAIMER,
|
|
104
|
-
COT_JOIN_RIGHT_EXAMPLE_CONTEXT,
|
|
105
|
-
COT_JOIN_RIGHT_EXAMPLE_INPUT_FIELDS,
|
|
106
|
-
COT_JOIN_RIGHT_IMAGE_DISCLAIMER,
|
|
51
|
+
JOIN_BASE_SYSTEM_PROMPT,
|
|
52
|
+
JOIN_BASE_USER_PROMPT,
|
|
53
|
+
JOIN_NO_REASONING_BASE_SYSTEM_PROMPT,
|
|
54
|
+
JOIN_NO_REASONING_BASE_USER_PROMPT,
|
|
107
55
|
)
|
|
108
|
-
from palimpzest.prompts.
|
|
109
|
-
|
|
110
|
-
|
|
56
|
+
from palimpzest.prompts.moa_aggregator_prompts import (
|
|
57
|
+
FILTER_MOA_AGG_BASE_SYSTEM_PROMPT,
|
|
58
|
+
FILTER_MOA_AGG_BASE_USER_PROMPT,
|
|
59
|
+
MAP_MOA_AGG_BASE_SYSTEM_PROMPT,
|
|
60
|
+
MAP_MOA_AGG_BASE_USER_PROMPT,
|
|
111
61
|
)
|
|
112
|
-
from palimpzest.prompts.
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
COT_MOA_PROPOSER_EXAMPLE_INPUT_FIELDS,
|
|
118
|
-
COT_MOA_PROPOSER_EXAMPLE_OUTPUT_FIELDS,
|
|
119
|
-
COT_MOA_PROPOSER_IMAGE_DISCLAIMER,
|
|
120
|
-
COT_MOA_PROPOSER_IMAGE_EXAMPLE_ANSWER,
|
|
121
|
-
COT_MOA_PROPOSER_IMAGE_EXAMPLE_CONTEXT,
|
|
122
|
-
COT_MOA_PROPOSER_IMAGE_EXAMPLE_INPUT_FIELDS,
|
|
123
|
-
COT_MOA_PROPOSER_IMAGE_EXAMPLE_OUTPUT_FIELDS,
|
|
124
|
-
COT_MOA_PROPOSER_IMAGE_JOB_INSTRUCTION,
|
|
125
|
-
COT_MOA_PROPOSER_JOB_INSTRUCTION,
|
|
62
|
+
from palimpzest.prompts.moa_proposer_prompts import (
|
|
63
|
+
FILTER_MOA_PROPOSER_BASE_SYSTEM_PROMPT,
|
|
64
|
+
FILTER_MOA_PROPOSER_BASE_USER_PROMPT,
|
|
65
|
+
MAP_MOA_PROPOSER_BASE_SYSTEM_PROMPT,
|
|
66
|
+
MAP_MOA_PROPOSER_BASE_USER_PROMPT,
|
|
126
67
|
)
|
|
127
68
|
from palimpzest.prompts.split_merge_prompts import (
|
|
128
|
-
|
|
129
|
-
|
|
69
|
+
FILTER_SPLIT_MERGER_BASE_SYSTEM_PROMPT,
|
|
70
|
+
FILTER_SPLIT_MERGER_BASE_USER_PROMPT,
|
|
71
|
+
MAP_SPLIT_MERGER_BASE_SYSTEM_PROMPT,
|
|
72
|
+
MAP_SPLIT_MERGER_BASE_USER_PROMPT,
|
|
130
73
|
)
|
|
131
74
|
from palimpzest.prompts.split_proposer_prompts import (
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
SPLIT_PROPOSER_EXAMPLE_INPUT_FIELDS,
|
|
137
|
-
SPLIT_PROPOSER_EXAMPLE_OUTPUT_FIELDS,
|
|
138
|
-
SPLIT_PROPOSER_JOB_INSTRUCTION,
|
|
75
|
+
FILTER_SPLIT_PROPOSER_BASE_SYSTEM_PROMPT,
|
|
76
|
+
FILTER_SPLIT_PROPOSER_BASE_USER_PROMPT,
|
|
77
|
+
MAP_SPLIT_PROPOSER_BASE_SYSTEM_PROMPT,
|
|
78
|
+
MAP_SPLIT_PROPOSER_BASE_USER_PROMPT,
|
|
139
79
|
)
|
|
140
|
-
from palimpzest.prompts.
|
|
80
|
+
from palimpzest.prompts.utils import (
|
|
81
|
+
AUDIO_DISCLAIMER,
|
|
82
|
+
AUDIO_EXAMPLE_ANSWER,
|
|
83
|
+
AUDIO_EXAMPLE_CONTEXT,
|
|
84
|
+
AUDIO_EXAMPLE_INPUT_FIELDS,
|
|
85
|
+
AUDIO_EXAMPLE_OUTPUT_FIELDS,
|
|
86
|
+
AUDIO_EXAMPLE_REASONING,
|
|
87
|
+
AUDIO_SENTENCE_EXAMPLE_ANSWER,
|
|
141
88
|
DESC_SECTION,
|
|
89
|
+
EXAMPLE_FILTER_CONDITION,
|
|
90
|
+
EXAMPLE_JOIN_CONDITION,
|
|
91
|
+
FILTER_EXAMPLE_REASONING,
|
|
92
|
+
FILTER_JOB_INSTRUCTION,
|
|
93
|
+
IMAGE_DISCLAIMER,
|
|
94
|
+
IMAGE_EXAMPLE_ANSWER,
|
|
95
|
+
IMAGE_EXAMPLE_CONTEXT,
|
|
96
|
+
IMAGE_EXAMPLE_INPUT_FIELDS,
|
|
97
|
+
IMAGE_EXAMPLE_OUTPUT_FIELDS,
|
|
98
|
+
IMAGE_EXAMPLE_REASONING,
|
|
99
|
+
IMAGE_SENTENCE_EXAMPLE_ANSWER,
|
|
100
|
+
JOIN_EXAMPLE_REASONING,
|
|
101
|
+
JOIN_JOB_INSTRUCTION,
|
|
102
|
+
MAP_JOB_INSTRUCTION,
|
|
142
103
|
ONE_TO_MANY_OUTPUT_FORMAT_INSTRUCTION,
|
|
143
104
|
ONE_TO_ONE_OUTPUT_FORMAT_INSTRUCTION,
|
|
105
|
+
PROPOSER_JOB_INSTRUCTION,
|
|
106
|
+
RIGHT_AUDIO_DISCLAIMER,
|
|
107
|
+
RIGHT_AUDIO_EXAMPLE_CONTEXT,
|
|
108
|
+
RIGHT_AUDIO_EXAMPLE_INPUT_FIELDS,
|
|
109
|
+
RIGHT_IMAGE_DISCLAIMER,
|
|
110
|
+
RIGHT_IMAGE_EXAMPLE_CONTEXT,
|
|
111
|
+
RIGHT_IMAGE_EXAMPLE_INPUT_FIELDS,
|
|
112
|
+
RIGHT_TEXT_EXAMPLE_CONTEXT,
|
|
113
|
+
RIGHT_TEXT_EXAMPLE_INPUT_FIELDS,
|
|
114
|
+
TEXT_EXAMPLE_ANSWER,
|
|
115
|
+
TEXT_EXAMPLE_CONTEXT,
|
|
116
|
+
TEXT_EXAMPLE_INPUT_FIELDS,
|
|
117
|
+
TEXT_EXAMPLE_OUTPUT_FIELDS,
|
|
118
|
+
TEXT_EXAMPLE_REASONING,
|
|
119
|
+
TEXT_SENTENCE_EXAMPLE_ANSWER,
|
|
144
120
|
)
|
|
145
121
|
|
|
146
122
|
|
|
@@ -148,62 +124,54 @@ class PromptFactory:
|
|
|
148
124
|
"""Factory class for generating prompts for the Generator given the input(s)."""
|
|
149
125
|
|
|
150
126
|
BASE_SYSTEM_PROMPT_MAP = {
|
|
151
|
-
|
|
152
|
-
PromptStrategy.
|
|
153
|
-
PromptStrategy.
|
|
154
|
-
PromptStrategy.
|
|
155
|
-
PromptStrategy.
|
|
156
|
-
PromptStrategy.
|
|
157
|
-
PromptStrategy.
|
|
158
|
-
PromptStrategy.
|
|
159
|
-
PromptStrategy.
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
PromptStrategy.
|
|
163
|
-
PromptStrategy.
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
PromptStrategy.
|
|
167
|
-
PromptStrategy.
|
|
168
|
-
PromptStrategy.
|
|
169
|
-
PromptStrategy.
|
|
170
|
-
PromptStrategy.
|
|
171
|
-
PromptStrategy.
|
|
172
|
-
PromptStrategy.
|
|
173
|
-
PromptStrategy.
|
|
174
|
-
PromptStrategy.COT_MOA_PROPOSER_IMAGE: COT_MOA_PROPOSER_BASE_SYSTEM_PROMPT,
|
|
175
|
-
PromptStrategy.COT_MOA_AGG: COT_MOA_AGG_BASE_SYSTEM_PROMPT,
|
|
176
|
-
PromptStrategy.SPLIT_PROPOSER: COT_SPLIT_PROPOSER_BASE_SYSTEM_PROMPT,
|
|
177
|
-
PromptStrategy.SPLIT_MERGER: COT_SPLIT_MERGER_BASE_SYSTEM_PROMPT,
|
|
127
|
+
# filter system prompts
|
|
128
|
+
PromptStrategy.FILTER: FILTER_BASE_SYSTEM_PROMPT,
|
|
129
|
+
PromptStrategy.FILTER_NO_REASONING: FILTER_NO_REASONING_BASE_SYSTEM_PROMPT,
|
|
130
|
+
PromptStrategy.FILTER_CRITIC: None,
|
|
131
|
+
PromptStrategy.FILTER_REFINE: None,
|
|
132
|
+
PromptStrategy.FILTER_MOA_PROPOSER: FILTER_MOA_PROPOSER_BASE_SYSTEM_PROMPT,
|
|
133
|
+
PromptStrategy.FILTER_MOA_AGG: FILTER_MOA_AGG_BASE_SYSTEM_PROMPT,
|
|
134
|
+
PromptStrategy.FILTER_SPLIT_PROPOSER: FILTER_SPLIT_PROPOSER_BASE_SYSTEM_PROMPT,
|
|
135
|
+
PromptStrategy.FILTER_SPLIT_MERGER: FILTER_SPLIT_MERGER_BASE_SYSTEM_PROMPT,
|
|
136
|
+
|
|
137
|
+
# join system prompts
|
|
138
|
+
PromptStrategy.JOIN: JOIN_BASE_SYSTEM_PROMPT,
|
|
139
|
+
PromptStrategy.JOIN_NO_REASONING: JOIN_NO_REASONING_BASE_SYSTEM_PROMPT,
|
|
140
|
+
|
|
141
|
+
# map system prompts
|
|
142
|
+
PromptStrategy.MAP: MAP_BASE_SYSTEM_PROMPT,
|
|
143
|
+
PromptStrategy.MAP_NO_REASONING: MAP_NO_REASONING_BASE_SYSTEM_PROMPT,
|
|
144
|
+
PromptStrategy.MAP_CRITIC: None,
|
|
145
|
+
PromptStrategy.MAP_REFINE: None,
|
|
146
|
+
PromptStrategy.MAP_MOA_PROPOSER: MAP_MOA_PROPOSER_BASE_SYSTEM_PROMPT,
|
|
147
|
+
PromptStrategy.MAP_MOA_AGG: MAP_MOA_AGG_BASE_SYSTEM_PROMPT,
|
|
148
|
+
PromptStrategy.MAP_SPLIT_PROPOSER: MAP_SPLIT_PROPOSER_BASE_SYSTEM_PROMPT,
|
|
149
|
+
PromptStrategy.MAP_SPLIT_MERGER: MAP_SPLIT_MERGER_BASE_SYSTEM_PROMPT,
|
|
178
150
|
}
|
|
179
151
|
BASE_USER_PROMPT_MAP = {
|
|
180
|
-
|
|
181
|
-
PromptStrategy.
|
|
182
|
-
PromptStrategy.
|
|
183
|
-
PromptStrategy.
|
|
184
|
-
PromptStrategy.
|
|
185
|
-
PromptStrategy.
|
|
186
|
-
PromptStrategy.
|
|
187
|
-
PromptStrategy.
|
|
188
|
-
PromptStrategy.
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
PromptStrategy.
|
|
192
|
-
PromptStrategy.
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
PromptStrategy.
|
|
196
|
-
PromptStrategy.
|
|
197
|
-
PromptStrategy.
|
|
198
|
-
PromptStrategy.
|
|
199
|
-
PromptStrategy.
|
|
200
|
-
PromptStrategy.
|
|
201
|
-
PromptStrategy.
|
|
202
|
-
PromptStrategy.
|
|
203
|
-
PromptStrategy.COT_MOA_PROPOSER_IMAGE: COT_MOA_PROPOSER_BASE_USER_PROMPT,
|
|
204
|
-
PromptStrategy.COT_MOA_AGG: COT_MOA_AGG_BASE_USER_PROMPT,
|
|
205
|
-
PromptStrategy.SPLIT_PROPOSER: COT_SPLIT_PROPOSER_BASE_USER_PROMPT,
|
|
206
|
-
PromptStrategy.SPLIT_MERGER: COT_SPLIT_MERGER_BASE_USER_PROMPT,
|
|
152
|
+
# filter user prompts
|
|
153
|
+
PromptStrategy.FILTER: FILTER_BASE_USER_PROMPT,
|
|
154
|
+
PromptStrategy.FILTER_NO_REASONING: FILTER_NO_REASONING_BASE_USER_PROMPT,
|
|
155
|
+
PromptStrategy.FILTER_CRITIC: BASE_CRITIQUE_PROMPT,
|
|
156
|
+
PromptStrategy.FILTER_REFINE: BASE_REFINEMENT_PROMPT,
|
|
157
|
+
PromptStrategy.FILTER_MOA_PROPOSER: FILTER_MOA_PROPOSER_BASE_USER_PROMPT,
|
|
158
|
+
PromptStrategy.FILTER_MOA_AGG: FILTER_MOA_AGG_BASE_USER_PROMPT,
|
|
159
|
+
PromptStrategy.FILTER_SPLIT_PROPOSER: FILTER_SPLIT_PROPOSER_BASE_USER_PROMPT,
|
|
160
|
+
PromptStrategy.FILTER_SPLIT_MERGER: FILTER_SPLIT_MERGER_BASE_USER_PROMPT,
|
|
161
|
+
|
|
162
|
+
# join user prompts
|
|
163
|
+
PromptStrategy.JOIN: JOIN_BASE_USER_PROMPT,
|
|
164
|
+
PromptStrategy.JOIN_NO_REASONING: JOIN_NO_REASONING_BASE_USER_PROMPT,
|
|
165
|
+
|
|
166
|
+
# map user prompts
|
|
167
|
+
PromptStrategy.MAP: MAP_BASE_USER_PROMPT,
|
|
168
|
+
PromptStrategy.MAP_NO_REASONING: MAP_NO_REASONING_BASE_USER_PROMPT,
|
|
169
|
+
PromptStrategy.MAP_CRITIC: BASE_CRITIQUE_PROMPT,
|
|
170
|
+
PromptStrategy.MAP_REFINE: BASE_REFINEMENT_PROMPT,
|
|
171
|
+
PromptStrategy.MAP_MOA_PROPOSER: MAP_MOA_PROPOSER_BASE_USER_PROMPT,
|
|
172
|
+
PromptStrategy.MAP_MOA_AGG: MAP_MOA_AGG_BASE_USER_PROMPT,
|
|
173
|
+
PromptStrategy.MAP_SPLIT_PROPOSER: MAP_SPLIT_PROPOSER_BASE_USER_PROMPT,
|
|
174
|
+
PromptStrategy.MAP_SPLIT_MERGER: MAP_SPLIT_MERGER_BASE_USER_PROMPT,
|
|
207
175
|
}
|
|
208
176
|
|
|
209
177
|
def __init__(self, prompt_strategy: PromptStrategy, model: Model, cardinality: Cardinality, desc: str | None = None) -> None:
|
|
@@ -277,6 +245,54 @@ class PromptFactory:
|
|
|
277
245
|
input_fields = [field for field in input_fields if field in candidate.get_field_names()]
|
|
278
246
|
return input_fields
|
|
279
247
|
|
|
248
|
+
def _get_input_modalities(self, candidate: DataRecord, input_fields: list[str]) -> set[Modality]:
|
|
249
|
+
"""
|
|
250
|
+
The list of input modalities for the given input fields.
|
|
251
|
+
|
|
252
|
+
Args:
|
|
253
|
+
candidate (DataRecord): The input record.
|
|
254
|
+
input_fields (list[str]): The input fields.
|
|
255
|
+
|
|
256
|
+
Returns:
|
|
257
|
+
set[Modality]: The list of input modalities.
|
|
258
|
+
"""
|
|
259
|
+
input_modalities = []
|
|
260
|
+
for field_name in input_fields:
|
|
261
|
+
field_type = candidate.get_field_type(field_name)
|
|
262
|
+
if field_type.annotation in IMAGE_FIELD_TYPES:
|
|
263
|
+
input_modalities.append(Modality.IMAGE)
|
|
264
|
+
elif field_type.annotation in AUDIO_FIELD_TYPES:
|
|
265
|
+
input_modalities.append(Modality.AUDIO)
|
|
266
|
+
else:
|
|
267
|
+
input_modalities.append(Modality.TEXT)
|
|
268
|
+
|
|
269
|
+
return set(input_modalities)
|
|
270
|
+
|
|
271
|
+
def _get_modalities_str(self, input_modalities: set[Modality]) -> str:
|
|
272
|
+
"""
|
|
273
|
+
Returns a format string to reflect the input modalities.
|
|
274
|
+
|
|
275
|
+
Args:
|
|
276
|
+
input_modalities (set[Modality]): The input modalities.
|
|
277
|
+
|
|
278
|
+
Returns:
|
|
279
|
+
str: The string to reflect the input modalities.
|
|
280
|
+
"""
|
|
281
|
+
if input_modalities == {Modality.TEXT}:
|
|
282
|
+
return "text"
|
|
283
|
+
elif input_modalities == {Modality.IMAGE}:
|
|
284
|
+
return "image(s)"
|
|
285
|
+
elif input_modalities == {Modality.AUDIO}:
|
|
286
|
+
return "audio"
|
|
287
|
+
elif input_modalities == {Modality.TEXT, Modality.IMAGE}:
|
|
288
|
+
return "text and/or image(s)"
|
|
289
|
+
elif input_modalities == {Modality.TEXT, Modality.AUDIO}:
|
|
290
|
+
return "text and/or audio"
|
|
291
|
+
elif input_modalities == {Modality.IMAGE, Modality.AUDIO}:
|
|
292
|
+
return "image(s) and/or audio"
|
|
293
|
+
elif input_modalities == {Modality.TEXT, Modality.IMAGE, Modality.AUDIO}:
|
|
294
|
+
return "text, image(s), and/or audio"
|
|
295
|
+
|
|
280
296
|
def _get_input_fields_desc(self, candidate: DataRecord, input_fields: list[str]) -> str:
|
|
281
297
|
"""
|
|
282
298
|
Returns a multi-line description of each input field for the prompt.
|
|
@@ -305,8 +321,8 @@ class PromptFactory:
|
|
|
305
321
|
str: The output fields description.
|
|
306
322
|
"""
|
|
307
323
|
output_fields_desc = ""
|
|
308
|
-
output_schema: BaseModel = kwargs.get("output_schema")
|
|
309
|
-
if self.prompt_strategy.
|
|
324
|
+
output_schema: type[BaseModel] = kwargs.get("output_schema")
|
|
325
|
+
if self.prompt_strategy.is_map_prompt():
|
|
310
326
|
assert output_schema is not None, "Output schema must be provided for convert prompts."
|
|
311
327
|
|
|
312
328
|
for field_name in sorted(output_fields):
|
|
@@ -324,7 +340,7 @@ class PromptFactory:
|
|
|
324
340
|
str | None: The filter condition (if applicable).
|
|
325
341
|
"""
|
|
326
342
|
filter_condition = kwargs.get("filter_condition")
|
|
327
|
-
if self.prompt_strategy.
|
|
343
|
+
if self.prompt_strategy.is_filter_prompt():
|
|
328
344
|
assert filter_condition is not None, "Filter condition must be provided for filter operations."
|
|
329
345
|
|
|
330
346
|
return filter_condition
|
|
@@ -390,7 +406,8 @@ class PromptFactory:
|
|
|
390
406
|
if self.prompt_strategy.is_moa_aggregator_prompt():
|
|
391
407
|
model_responses = ""
|
|
392
408
|
for idx, model_response in enumerate(kwargs.get("model_responses")):
|
|
393
|
-
model_responses += f"MODEL RESPONSE {idx + 1}: {model_response}\n"
|
|
409
|
+
model_responses += f"MODEL RESPONSE {idx + 1}: {model_response.rstrip()}\n\n"
|
|
410
|
+
model_responses = model_responses.rstrip() if model_responses is not None else None
|
|
394
411
|
|
|
395
412
|
return model_responses
|
|
396
413
|
|
|
@@ -408,7 +425,8 @@ class PromptFactory:
|
|
|
408
425
|
if self.prompt_strategy.is_split_merger_prompt():
|
|
409
426
|
chunk_outputs = ""
|
|
410
427
|
for idx, chunk_output in enumerate(kwargs.get("chunk_outputs")):
|
|
411
|
-
chunk_outputs += f"CHUNK OUTPUT {idx + 1}: {chunk_output}\n"
|
|
428
|
+
chunk_outputs += f"CHUNK OUTPUT {idx + 1}: {chunk_output.rstrip()}\n\n"
|
|
429
|
+
chunk_outputs = chunk_outputs.rstrip() if chunk_outputs is not None else None
|
|
412
430
|
|
|
413
431
|
return chunk_outputs
|
|
414
432
|
|
|
@@ -425,28 +443,33 @@ class PromptFactory:
|
|
|
425
443
|
else ONE_TO_MANY_OUTPUT_FORMAT_INSTRUCTION
|
|
426
444
|
)
|
|
427
445
|
|
|
428
|
-
def _get_job_instruction(self) -> str | None:
|
|
446
|
+
def _get_job_instruction(self, input_modalities: set[Modality]) -> str | None:
|
|
429
447
|
"""
|
|
430
448
|
Returns the job instruction based on the prompt strategy.
|
|
431
449
|
|
|
450
|
+
Args:
|
|
451
|
+
input_modalities (set[Modality]): The modalities of the input fields.
|
|
452
|
+
|
|
432
453
|
Returns:
|
|
433
|
-
str | None: The job instruction
|
|
454
|
+
str | None: The job instruction.
|
|
434
455
|
"""
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
|
|
442
|
-
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
456
|
+
# get the job instruction based on the prompt strategy
|
|
457
|
+
job_instruction = None
|
|
458
|
+
if self.prompt_strategy.is_moa_proposer_prompt() or self.prompt_strategy.is_split_proposer_prompt():
|
|
459
|
+
job_instruction = PROPOSER_JOB_INSTRUCTION
|
|
460
|
+
elif self.prompt_strategy.is_map_prompt():
|
|
461
|
+
job_instruction = MAP_JOB_INSTRUCTION
|
|
462
|
+
elif self.prompt_strategy.is_filter_prompt():
|
|
463
|
+
job_instruction = FILTER_JOB_INSTRUCTION
|
|
464
|
+
elif self.prompt_strategy.is_join_prompt():
|
|
465
|
+
job_instruction = JOIN_JOB_INSTRUCTION
|
|
466
|
+
|
|
467
|
+
# format the job instruction based on the input modalities
|
|
468
|
+
modalities = self._get_modalities_str(input_modalities)
|
|
469
|
+
if job_instruction is not None:
|
|
470
|
+
job_instruction = job_instruction.format(modalities=modalities)
|
|
471
|
+
|
|
472
|
+
return job_instruction
|
|
450
473
|
|
|
451
474
|
def _get_desc_section(self) -> str:
|
|
452
475
|
"""
|
|
@@ -470,9 +493,7 @@ class PromptFactory:
|
|
|
470
493
|
"""
|
|
471
494
|
critique_criteria = None
|
|
472
495
|
if self.prompt_strategy.is_critic_prompt():
|
|
473
|
-
critique_criteria = (
|
|
474
|
-
COT_QA_IMAGE_CRITIQUE_CRITERIA if self.prompt_strategy.is_image_prompt() else COT_QA_CRITIQUE_CRITERIA
|
|
475
|
-
)
|
|
496
|
+
critique_criteria = MAP_CRITIQUE_CRITERIA if self.prompt_strategy.is_map_prompt() else FILTER_CRITIQUE_CRITERIA
|
|
476
497
|
|
|
477
498
|
return critique_criteria
|
|
478
499
|
|
|
@@ -485,11 +506,7 @@ class PromptFactory:
|
|
|
485
506
|
"""
|
|
486
507
|
refinement_criteria = None
|
|
487
508
|
if self.prompt_strategy.is_refine_prompt():
|
|
488
|
-
refinement_criteria = (
|
|
489
|
-
COT_QA_IMAGE_REFINEMENT_CRITERIA
|
|
490
|
-
if self.prompt_strategy.is_image_prompt()
|
|
491
|
-
else COT_QA_REFINEMENT_CRITERIA
|
|
492
|
-
)
|
|
509
|
+
refinement_criteria = MAP_REFINEMENT_CRITERIA if self.prompt_strategy.is_map_prompt() else FILTER_REFINEMENT_CRITERIA
|
|
493
510
|
|
|
494
511
|
return refinement_criteria
|
|
495
512
|
|
|
@@ -502,240 +519,156 @@ class PromptFactory:
|
|
|
502
519
|
"""
|
|
503
520
|
finish_instruction = None
|
|
504
521
|
if self.prompt_strategy.is_critic_prompt():
|
|
505
|
-
finish_instruction =
|
|
522
|
+
finish_instruction = MAP_CRITIQUE_FINISH_INSTRUCTION if self.prompt_strategy.is_map_prompt() else FILTER_CRITIQUE_FINISH_INSTRUCTION
|
|
506
523
|
elif self.prompt_strategy.is_refine_prompt():
|
|
507
|
-
finish_instruction =
|
|
524
|
+
finish_instruction = MAP_REFINEMENT_FINISH_INSTRUCTION if self.prompt_strategy.is_map_prompt() else FILTER_REFINEMENT_FINISH_INSTRUCTION
|
|
508
525
|
|
|
509
526
|
return finish_instruction
|
|
510
527
|
|
|
511
|
-
def _get_example_input_fields(self) -> str
|
|
528
|
+
def _get_example_input_fields(self, input_modalities: set[Modality], right: bool = False) -> str:
|
|
512
529
|
"""
|
|
513
530
|
Returns the example input fields for the prompt.
|
|
514
531
|
|
|
515
|
-
|
|
516
|
-
|
|
517
|
-
|
|
518
|
-
prompt_strategy_to_example_input_fields = {
|
|
519
|
-
PromptStrategy.COT_BOOL: COT_BOOL_EXAMPLE_INPUT_FIELDS,
|
|
520
|
-
PromptStrategy.COT_BOOL_AUDIO: COT_BOOL_AUDIO_EXAMPLE_INPUT_FIELDS,
|
|
521
|
-
PromptStrategy.COT_BOOL_IMAGE: COT_BOOL_IMAGE_EXAMPLE_INPUT_FIELDS,
|
|
522
|
-
PromptStrategy.COT_JOIN: COT_JOIN_EXAMPLE_INPUT_FIELDS,
|
|
523
|
-
PromptStrategy.COT_JOIN_AUDIO: COT_JOIN_AUDIO_EXAMPLE_INPUT_FIELDS,
|
|
524
|
-
PromptStrategy.COT_JOIN_IMAGE: COT_JOIN_IMAGE_EXAMPLE_INPUT_FIELDS,
|
|
525
|
-
PromptStrategy.COT_QA: COT_QA_EXAMPLE_INPUT_FIELDS,
|
|
526
|
-
PromptStrategy.COT_QA_AUDIO: COT_QA_AUDIO_EXAMPLE_INPUT_FIELDS,
|
|
527
|
-
PromptStrategy.COT_QA_IMAGE: COT_QA_IMAGE_EXAMPLE_INPUT_FIELDS,
|
|
528
|
-
PromptStrategy.COT_MOA_PROPOSER: COT_MOA_PROPOSER_EXAMPLE_INPUT_FIELDS,
|
|
529
|
-
PromptStrategy.COT_MOA_PROPOSER_IMAGE: COT_MOA_PROPOSER_IMAGE_EXAMPLE_INPUT_FIELDS,
|
|
530
|
-
PromptStrategy.SPLIT_PROPOSER: SPLIT_PROPOSER_EXAMPLE_INPUT_FIELDS,
|
|
531
|
-
}
|
|
532
|
-
|
|
533
|
-
return prompt_strategy_to_example_input_fields.get(self.prompt_strategy)
|
|
534
|
-
|
|
535
|
-
def _get_right_example_input_fields(self) -> str | None:
|
|
536
|
-
"""
|
|
537
|
-
Returns the example right input fields for the join prompt.
|
|
532
|
+
Args:
|
|
533
|
+
input_modalities (set[Modality]): The modalities of the input fields.
|
|
534
|
+
right (bool): Whether to return the right input fields for the join prompt.
|
|
538
535
|
|
|
539
536
|
Returns:
|
|
540
|
-
str
|
|
537
|
+
str: The example input fields.
|
|
541
538
|
"""
|
|
542
|
-
|
|
543
|
-
|
|
544
|
-
|
|
545
|
-
|
|
539
|
+
input_modality_to_example_input_fields = {
|
|
540
|
+
Modality.TEXT: RIGHT_TEXT_EXAMPLE_INPUT_FIELDS if right else TEXT_EXAMPLE_INPUT_FIELDS,
|
|
541
|
+
Modality.IMAGE: RIGHT_IMAGE_EXAMPLE_INPUT_FIELDS if right else IMAGE_EXAMPLE_INPUT_FIELDS,
|
|
542
|
+
Modality.AUDIO: RIGHT_AUDIO_EXAMPLE_INPUT_FIELDS if right else AUDIO_EXAMPLE_INPUT_FIELDS,
|
|
546
543
|
}
|
|
547
544
|
|
|
548
|
-
|
|
545
|
+
example_input_fields = ""
|
|
546
|
+
for input_modality in input_modalities:
|
|
547
|
+
example_input_fields += input_modality_to_example_input_fields[input_modality].rstrip()
|
|
548
|
+
example_input_fields = example_input_fields.lstrip() + "\n"
|
|
549
|
+
|
|
550
|
+
return example_input_fields
|
|
549
551
|
|
|
550
|
-
def _get_example_output_fields(self) -> str
|
|
552
|
+
def _get_example_output_fields(self, input_modalities: set[Modality]) -> str:
|
|
551
553
|
"""
|
|
552
554
|
Returns the example output fields for the prompt.
|
|
553
555
|
|
|
554
556
|
Returns:
|
|
555
|
-
str
|
|
557
|
+
str: The example output fields.
|
|
556
558
|
"""
|
|
557
|
-
|
|
558
|
-
|
|
559
|
-
|
|
560
|
-
|
|
561
|
-
PromptStrategy.COT_MOA_PROPOSER: COT_MOA_PROPOSER_EXAMPLE_OUTPUT_FIELDS,
|
|
562
|
-
PromptStrategy.COT_MOA_PROPOSER_IMAGE: COT_MOA_PROPOSER_IMAGE_EXAMPLE_OUTPUT_FIELDS,
|
|
563
|
-
PromptStrategy.SPLIT_PROPOSER: SPLIT_PROPOSER_EXAMPLE_OUTPUT_FIELDS,
|
|
559
|
+
input_modality_to_example_output_fields = {
|
|
560
|
+
Modality.TEXT: TEXT_EXAMPLE_OUTPUT_FIELDS,
|
|
561
|
+
Modality.IMAGE: IMAGE_EXAMPLE_OUTPUT_FIELDS,
|
|
562
|
+
Modality.AUDIO: AUDIO_EXAMPLE_OUTPUT_FIELDS,
|
|
564
563
|
}
|
|
565
564
|
|
|
566
|
-
|
|
565
|
+
example_output_fields = ""
|
|
566
|
+
for input_modality in input_modalities:
|
|
567
|
+
example_output_fields += input_modality_to_example_output_fields[input_modality].rstrip()
|
|
568
|
+
example_output_fields = example_output_fields.lstrip() + "\n"
|
|
567
569
|
|
|
568
|
-
|
|
570
|
+
return example_output_fields
|
|
571
|
+
|
|
572
|
+
def _get_example_context(self, input_modalities: set[Modality], right: bool = False) -> str:
|
|
569
573
|
"""
|
|
570
574
|
Returns the example context for the prompt.
|
|
571
575
|
|
|
572
576
|
Returns:
|
|
573
|
-
str
|
|
577
|
+
str: The example context.
|
|
574
578
|
"""
|
|
575
|
-
|
|
576
|
-
|
|
577
|
-
|
|
578
|
-
|
|
579
|
-
PromptStrategy.COT_JOIN: COT_JOIN_EXAMPLE_CONTEXT,
|
|
580
|
-
PromptStrategy.COT_JOIN_AUDIO: COT_JOIN_AUDIO_EXAMPLE_CONTEXT,
|
|
581
|
-
PromptStrategy.COT_JOIN_IMAGE: COT_JOIN_IMAGE_EXAMPLE_CONTEXT,
|
|
582
|
-
PromptStrategy.COT_QA: COT_QA_EXAMPLE_CONTEXT,
|
|
583
|
-
PromptStrategy.COT_QA_AUDIO: COT_QA_AUDIO_EXAMPLE_CONTEXT,
|
|
584
|
-
PromptStrategy.COT_QA_IMAGE: COT_QA_IMAGE_EXAMPLE_CONTEXT,
|
|
585
|
-
PromptStrategy.COT_MOA_PROPOSER: COT_MOA_PROPOSER_EXAMPLE_CONTEXT,
|
|
586
|
-
PromptStrategy.COT_MOA_PROPOSER_IMAGE: COT_MOA_PROPOSER_IMAGE_EXAMPLE_CONTEXT,
|
|
587
|
-
PromptStrategy.SPLIT_PROPOSER: SPLIT_PROPOSER_EXAMPLE_CONTEXT,
|
|
579
|
+
input_modality_to_example_context = {
|
|
580
|
+
Modality.TEXT: RIGHT_TEXT_EXAMPLE_CONTEXT if right else TEXT_EXAMPLE_CONTEXT,
|
|
581
|
+
Modality.IMAGE: RIGHT_IMAGE_EXAMPLE_CONTEXT if right else IMAGE_EXAMPLE_CONTEXT,
|
|
582
|
+
Modality.AUDIO: RIGHT_AUDIO_EXAMPLE_CONTEXT if right else AUDIO_EXAMPLE_CONTEXT,
|
|
588
583
|
}
|
|
589
584
|
|
|
590
|
-
|
|
591
|
-
|
|
592
|
-
|
|
593
|
-
""
|
|
594
|
-
Returns the right example context for the join prompt.
|
|
595
|
-
|
|
596
|
-
Returns:
|
|
597
|
-
str | None: The right example context (if applicable).
|
|
598
|
-
"""
|
|
599
|
-
prompt_strategy_to_right_example_context = {
|
|
600
|
-
PromptStrategy.COT_JOIN: COT_JOIN_RIGHT_EXAMPLE_CONTEXT,
|
|
601
|
-
PromptStrategy.COT_JOIN_AUDIO: COT_JOIN_AUDIO_RIGHT_EXAMPLE_CONTEXT,
|
|
602
|
-
PromptStrategy.COT_JOIN_IMAGE: COT_JOIN_IMAGE_RIGHT_EXAMPLE_CONTEXT,
|
|
603
|
-
}
|
|
585
|
+
example_context = ""
|
|
586
|
+
for input_modality in input_modalities:
|
|
587
|
+
example_context += input_modality_to_example_context[input_modality].rstrip() + ","
|
|
588
|
+
example_context = example_context[:-1] + "\n"
|
|
604
589
|
|
|
605
|
-
return
|
|
590
|
+
return example_context
|
|
606
591
|
|
|
607
|
-
def _get_image_disclaimer(self) -> str:
|
|
592
|
+
def _get_image_disclaimer(self, input_modalities: set[Modality], right: bool = False) -> str:
|
|
608
593
|
"""
|
|
609
594
|
Returns the image disclaimer for the prompt. The disclaimer must be an empty string
|
|
610
|
-
for
|
|
595
|
+
for non-image prompts.
|
|
611
596
|
|
|
612
597
|
Returns:
|
|
613
598
|
str: The image disclaimer. If this is a text prompt then it is an empty string.
|
|
614
599
|
"""
|
|
615
|
-
|
|
616
|
-
|
|
617
|
-
PromptStrategy.COT_JOIN_IMAGE: COT_JOIN_IMAGE_DISCLAIMER,
|
|
618
|
-
PromptStrategy.COT_QA_IMAGE: COT_QA_IMAGE_DISCLAIMER,
|
|
619
|
-
PromptStrategy.COT_MOA_PROPOSER_IMAGE: COT_MOA_PROPOSER_IMAGE_DISCLAIMER,
|
|
620
|
-
}
|
|
621
|
-
|
|
622
|
-
return prompt_strategy_to_image_disclaimer.get(self.prompt_strategy, "")
|
|
600
|
+
image_disclaimer = RIGHT_IMAGE_DISCLAIMER if right else IMAGE_DISCLAIMER
|
|
601
|
+
return image_disclaimer if Modality.IMAGE in input_modalities else ""
|
|
623
602
|
|
|
624
|
-
def _get_audio_disclaimer(self) -> str:
|
|
603
|
+
def _get_audio_disclaimer(self, input_modalities: set[Modality], right: bool = False) -> str:
|
|
625
604
|
"""
|
|
626
605
|
Returns the audio disclaimer for the prompt. The disclaimer must be an empty string
|
|
627
|
-
for
|
|
606
|
+
for non-audio prompts.
|
|
628
607
|
|
|
629
608
|
Returns:
|
|
630
609
|
str: The audio disclaimer. If this is a text prompt then it is an empty string.
|
|
631
610
|
"""
|
|
632
|
-
|
|
633
|
-
|
|
634
|
-
PromptStrategy.COT_JOIN_AUDIO: COT_JOIN_AUDIO_DISCLAIMER,
|
|
635
|
-
PromptStrategy.COT_QA_AUDIO: COT_QA_AUDIO_DISCLAIMER,
|
|
636
|
-
}
|
|
637
|
-
|
|
638
|
-
return prompt_strategy_to_audio_disclaimer.get(self.prompt_strategy, "")
|
|
639
|
-
|
|
640
|
-
def _get_right_image_disclaimer(self) -> str:
|
|
641
|
-
"""
|
|
642
|
-
Returns the right image disclaimer for the prompt. The disclaimer must be an empty string
|
|
643
|
-
for text prompts.
|
|
644
|
-
|
|
645
|
-
Returns:
|
|
646
|
-
str: The right image disclaimer. If this is a text prompt then it is an empty string.
|
|
647
|
-
"""
|
|
648
|
-
prompt_strategy_to_image_disclaimer = {
|
|
649
|
-
PromptStrategy.COT_JOIN_IMAGE: COT_JOIN_RIGHT_IMAGE_DISCLAIMER,
|
|
650
|
-
}
|
|
651
|
-
|
|
652
|
-
return prompt_strategy_to_image_disclaimer.get(self.prompt_strategy, "")
|
|
653
|
-
|
|
654
|
-
def _get_right_audio_disclaimer(self) -> str:
|
|
655
|
-
"""
|
|
656
|
-
Returns the right audio disclaimer for the prompt. The disclaimer must be an empty string
|
|
657
|
-
for text prompts.
|
|
658
|
-
|
|
659
|
-
Returns:
|
|
660
|
-
str: The right audio disclaimer. If this is a text prompt then it is an empty string.
|
|
661
|
-
"""
|
|
662
|
-
prompt_strategy_to_audio_disclaimer = {
|
|
663
|
-
PromptStrategy.COT_JOIN_AUDIO: COT_JOIN_RIGHT_AUDIO_DISCLAIMER,
|
|
664
|
-
}
|
|
611
|
+
audio_disclaimer = RIGHT_AUDIO_DISCLAIMER if right else AUDIO_DISCLAIMER
|
|
612
|
+
return audio_disclaimer if Modality.AUDIO in input_modalities else ""
|
|
665
613
|
|
|
666
|
-
|
|
667
|
-
|
|
668
|
-
def _get_example_filter_condition(self) -> str | None:
|
|
614
|
+
def _get_example_reasoning(self, input_modalities: set[Modality]) -> str:
|
|
669
615
|
"""
|
|
670
|
-
Returns the example
|
|
616
|
+
Returns the example reasoning for the prompt.
|
|
671
617
|
|
|
672
618
|
Returns:
|
|
673
|
-
str
|
|
674
|
-
"""
|
|
675
|
-
prompt_strategy_to_example_filter_condition = {
|
|
676
|
-
PromptStrategy.COT_BOOL: COT_BOOL_EXAMPLE_FILTER_CONDITION,
|
|
677
|
-
PromptStrategy.COT_BOOL_AUDIO: COT_BOOL_AUDIO_EXAMPLE_FILTER_CONDITION,
|
|
678
|
-
PromptStrategy.COT_BOOL_IMAGE: COT_BOOL_IMAGE_EXAMPLE_FILTER_CONDITION,
|
|
679
|
-
}
|
|
680
|
-
|
|
681
|
-
return prompt_strategy_to_example_filter_condition.get(self.prompt_strategy)
|
|
682
|
-
|
|
683
|
-
def _get_example_join_condition(self) -> str | None:
|
|
619
|
+
str: The example reasoning.
|
|
684
620
|
"""
|
|
685
|
-
|
|
621
|
+
if self.prompt_strategy.is_filter_prompt():
|
|
622
|
+
return FILTER_EXAMPLE_REASONING
|
|
623
|
+
elif self.prompt_strategy.is_join_prompt():
|
|
624
|
+
return JOIN_EXAMPLE_REASONING
|
|
686
625
|
|
|
687
|
-
|
|
688
|
-
|
|
689
|
-
|
|
690
|
-
|
|
691
|
-
PromptStrategy.COT_JOIN: COT_JOIN_EXAMPLE_JOIN_CONDITION,
|
|
692
|
-
PromptStrategy.COT_JOIN_AUDIO: COT_JOIN_AUDIO_EXAMPLE_JOIN_CONDITION,
|
|
693
|
-
PromptStrategy.COT_JOIN_IMAGE: COT_JOIN_IMAGE_EXAMPLE_JOIN_CONDITION,
|
|
626
|
+
input_modality_to_example_reasoning = {
|
|
627
|
+
Modality.TEXT: TEXT_EXAMPLE_REASONING,
|
|
628
|
+
Modality.IMAGE: IMAGE_EXAMPLE_REASONING,
|
|
629
|
+
Modality.AUDIO: AUDIO_EXAMPLE_REASONING,
|
|
694
630
|
}
|
|
695
631
|
|
|
696
|
-
|
|
697
|
-
|
|
698
|
-
|
|
699
|
-
|
|
700
|
-
Returns the example reasoning for the prompt.
|
|
701
|
-
|
|
702
|
-
Returns:
|
|
703
|
-
str | None: The example reasoning (if applicable).
|
|
704
|
-
"""
|
|
705
|
-
prompt_strategy_to_example_reasoning = {
|
|
706
|
-
PromptStrategy.COT_BOOL: COT_BOOL_EXAMPLE_REASONING,
|
|
707
|
-
PromptStrategy.COT_BOOL_AUDIO: COT_BOOL_AUDIO_EXAMPLE_REASONING,
|
|
708
|
-
PromptStrategy.COT_BOOL_IMAGE: COT_BOOL_IMAGE_EXAMPLE_REASONING,
|
|
709
|
-
PromptStrategy.COT_JOIN: COT_JOIN_EXAMPLE_REASONING,
|
|
710
|
-
PromptStrategy.COT_JOIN_AUDIO: COT_JOIN_AUDIO_EXAMPLE_REASONING,
|
|
711
|
-
PromptStrategy.COT_JOIN_IMAGE: COT_JOIN_IMAGE_EXAMPLE_REASONING,
|
|
712
|
-
PromptStrategy.COT_QA: COT_QA_EXAMPLE_REASONING,
|
|
713
|
-
PromptStrategy.COT_QA_AUDIO: COT_QA_AUDIO_EXAMPLE_REASONING,
|
|
714
|
-
PromptStrategy.COT_QA_IMAGE: COT_QA_IMAGE_EXAMPLE_REASONING,
|
|
715
|
-
}
|
|
632
|
+
example_reasoning = ""
|
|
633
|
+
for input_modality in input_modalities:
|
|
634
|
+
example_reasoning += input_modality_to_example_reasoning[input_modality] + " "
|
|
635
|
+
example_reasoning = example_reasoning.rstrip()
|
|
716
636
|
|
|
717
|
-
return
|
|
637
|
+
return example_reasoning
|
|
718
638
|
|
|
719
|
-
def _get_example_answer(self) -> str
|
|
639
|
+
def _get_example_answer(self, input_modalities: set[Modality]) -> str:
|
|
720
640
|
"""
|
|
721
641
|
Returns the example answer for the prompt.
|
|
722
642
|
|
|
723
643
|
Returns:
|
|
724
|
-
str
|
|
644
|
+
str: The example answer.
|
|
725
645
|
"""
|
|
726
|
-
|
|
727
|
-
|
|
728
|
-
|
|
729
|
-
|
|
730
|
-
|
|
731
|
-
PromptStrategy.COT_MOA_PROPOSER_IMAGE: COT_MOA_PROPOSER_IMAGE_EXAMPLE_ANSWER,
|
|
732
|
-
PromptStrategy.SPLIT_PROPOSER: SPLIT_PROPOSER_EXAMPLE_ANSWER,
|
|
646
|
+
use_sentence_answers = self.prompt_strategy.is_split_proposer_prompt() or self.prompt_strategy.is_moa_proposer_prompt()
|
|
647
|
+
input_modality_to_example_answer = {
|
|
648
|
+
Modality.TEXT: TEXT_SENTENCE_EXAMPLE_ANSWER if use_sentence_answers else TEXT_EXAMPLE_ANSWER,
|
|
649
|
+
Modality.IMAGE: IMAGE_SENTENCE_EXAMPLE_ANSWER if use_sentence_answers else IMAGE_EXAMPLE_ANSWER,
|
|
650
|
+
Modality.AUDIO: AUDIO_SENTENCE_EXAMPLE_ANSWER if use_sentence_answers else AUDIO_EXAMPLE_ANSWER,
|
|
733
651
|
}
|
|
734
652
|
|
|
735
|
-
|
|
653
|
+
example_answer = ""
|
|
654
|
+
for input_modality in input_modalities:
|
|
655
|
+
example_answer += input_modality_to_example_answer[input_modality].rstrip()
|
|
656
|
+
if use_sentence_answers:
|
|
657
|
+
example_answer += " "
|
|
658
|
+
example_answer = example_answer + "\n"
|
|
659
|
+
|
|
660
|
+
return example_answer
|
|
736
661
|
|
|
737
662
|
def _get_all_format_kwargs(
|
|
738
|
-
self,
|
|
663
|
+
self,
|
|
664
|
+
candidate: DataRecord,
|
|
665
|
+
input_fields: list[str],
|
|
666
|
+
input_modalities: set[Modality],
|
|
667
|
+
output_fields: list[str],
|
|
668
|
+
right_candidate: DataRecord | None,
|
|
669
|
+
right_input_fields: list[str],
|
|
670
|
+
right_input_modalities: set[Modality],
|
|
671
|
+
**kwargs,
|
|
739
672
|
) -> dict:
|
|
740
673
|
"""
|
|
741
674
|
Returns a dictionary containing all the format kwargs for templating the prompts.
|
|
@@ -770,26 +703,27 @@ class PromptFactory:
|
|
|
770
703
|
})
|
|
771
704
|
|
|
772
705
|
# get format kwargs which depend on the prompt strategy
|
|
706
|
+
full_input_modalities = input_modalities.union(right_input_modalities)
|
|
773
707
|
prompt_strategy_format_kwargs = {
|
|
774
708
|
"output_format_instruction": self._get_output_format_instruction(),
|
|
775
|
-
"job_instruction": self._get_job_instruction(),
|
|
709
|
+
"job_instruction": self._get_job_instruction(full_input_modalities),
|
|
776
710
|
"desc_section": self._get_desc_section(),
|
|
777
711
|
"critique_criteria": self._get_critique_criteria(),
|
|
778
712
|
"refinement_criteria": self._get_refinement_criteria(),
|
|
779
713
|
"finish_instruction": self._get_finish_instruction(),
|
|
780
|
-
"example_input_fields": self._get_example_input_fields(),
|
|
781
|
-
"right_example_input_fields": self.
|
|
782
|
-
"example_output_fields": self._get_example_output_fields(),
|
|
783
|
-
"example_context": self._get_example_context(),
|
|
784
|
-
"right_example_context": self.
|
|
785
|
-
"image_disclaimer": self._get_image_disclaimer(),
|
|
786
|
-
"audio_disclaimer": self._get_audio_disclaimer(),
|
|
787
|
-
"right_image_disclaimer": self.
|
|
788
|
-
"right_audio_disclaimer": self.
|
|
789
|
-
"example_filter_condition":
|
|
790
|
-
"example_join_condition":
|
|
791
|
-
"example_reasoning": self._get_example_reasoning(),
|
|
792
|
-
"example_answer": self._get_example_answer(),
|
|
714
|
+
"example_input_fields": self._get_example_input_fields(input_modalities),
|
|
715
|
+
"right_example_input_fields": self._get_example_input_fields(right_input_modalities, right=True),
|
|
716
|
+
"example_output_fields": self._get_example_output_fields(input_modalities),
|
|
717
|
+
"example_context": self._get_example_context(input_modalities),
|
|
718
|
+
"right_example_context": self._get_example_context(right_input_modalities, right=True),
|
|
719
|
+
"image_disclaimer": self._get_image_disclaimer(input_modalities),
|
|
720
|
+
"audio_disclaimer": self._get_audio_disclaimer(input_modalities),
|
|
721
|
+
"right_image_disclaimer": self._get_image_disclaimer(right_input_modalities, right=True),
|
|
722
|
+
"right_audio_disclaimer": self._get_audio_disclaimer(right_input_modalities, right=True),
|
|
723
|
+
"example_filter_condition": EXAMPLE_FILTER_CONDITION,
|
|
724
|
+
"example_join_condition": EXAMPLE_JOIN_CONDITION,
|
|
725
|
+
"example_reasoning": self._get_example_reasoning(input_modalities),
|
|
726
|
+
"example_answer": self._get_example_answer(input_modalities),
|
|
793
727
|
}
|
|
794
728
|
|
|
795
729
|
# return all format kwargs
|
|
@@ -937,7 +871,7 @@ class PromptFactory:
|
|
|
937
871
|
# get any audio messages for the chat payload (will be an empty list if no audio fields exist)
|
|
938
872
|
audio_messages = self._create_audio_messages(candidate, input_fields)
|
|
939
873
|
|
|
940
|
-
# get any right image messages for the chat payload (will be an empty list if
|
|
874
|
+
# get any right image / audio messages for the chat payload (will be an empty list if image / audio not present)
|
|
941
875
|
right_image_messages, right_audio_messages = [], []
|
|
942
876
|
if self.prompt_strategy.is_join_prompt():
|
|
943
877
|
assert right_candidate is not None, "Right candidate must be provided for join prompts."
|
|
@@ -951,121 +885,63 @@ class PromptFactory:
|
|
|
951
885
|
"Original messages must be provided for critique and refinement operations."
|
|
952
886
|
)
|
|
953
887
|
|
|
954
|
-
#
|
|
955
|
-
|
|
888
|
+
# combine image and audio messages
|
|
889
|
+
image_audio_messages = image_messages + audio_messages
|
|
890
|
+
right_image_audio_messages = right_image_messages + right_audio_messages
|
|
891
|
+
has_image_audio = len(image_audio_messages) > 0
|
|
892
|
+
has_right_image_audio = len(right_image_audio_messages) > 0
|
|
893
|
+
|
|
956
894
|
# construct the user messages based on the prompt strategy
|
|
957
895
|
user_messages = []
|
|
958
896
|
if self.prompt_strategy.is_critic_prompt() or self.prompt_strategy.is_refine_prompt():
|
|
959
|
-
# NOTE: if this critic / refinement prompt is processing images, those images
|
|
960
|
-
#
|
|
897
|
+
# NOTE: if this critic / refinement prompt is processing images / audio, those images / audio
|
|
898
|
+
# will be part of the `original_messages` and will show up in the final chat payload
|
|
961
899
|
base_prompt_start, base_prompt_end = base_prompt.split("<<original-prompt-placeholder>>\n")
|
|
962
900
|
user_messages.append({"role": "user", "type": "text", "content": base_prompt_start.format(**kwargs)})
|
|
963
901
|
user_messages.extend(original_messages)
|
|
964
902
|
user_messages.append({"role": "user", "type": "text", "content": base_prompt_end.format(**kwargs)})
|
|
965
903
|
|
|
966
|
-
#
|
|
967
|
-
elif self.prompt_strategy.
|
|
968
|
-
|
|
969
|
-
|
|
904
|
+
# handle joins with left and right images / audio
|
|
905
|
+
elif self.prompt_strategy.is_join_prompt() and has_image_audio and has_right_image_audio:
|
|
906
|
+
base_prompt_start, base_prompt_rest = base_prompt.split("<<image-audio-placeholder>>")
|
|
907
|
+
base_prompt_mid, base_prompt_end = base_prompt_rest.split("<<right-image-audio-placeholder>>")
|
|
970
908
|
user_messages.append({"role": "user", "type": "text", "content": base_prompt_start.format(**kwargs)})
|
|
971
|
-
user_messages.extend(
|
|
909
|
+
user_messages.extend(image_audio_messages)
|
|
910
|
+
user_messages.append({"role": "user", "type": "text", "content": base_prompt_mid.format(**kwargs)})
|
|
911
|
+
user_messages.extend(right_image_audio_messages)
|
|
972
912
|
user_messages.append({"role": "user", "type": "text", "content": base_prompt_end.format(**kwargs)})
|
|
973
913
|
|
|
974
|
-
#
|
|
975
|
-
elif self.prompt_strategy.
|
|
976
|
-
|
|
977
|
-
|
|
978
|
-
base_prompt_start, base_prompt_mid, base_prompt_end = base_prompt.split("<<image-placeholder>>")
|
|
914
|
+
# handle joins with only left images / audio
|
|
915
|
+
elif self.prompt_strategy.is_join_prompt() and has_image_audio and not has_right_image_audio:
|
|
916
|
+
base_prompt = base_prompt.replace("<<right-image-audio-placeholder>>", "")
|
|
917
|
+
base_prompt_start, base_prompt_end = base_prompt.split("<<image-audio-placeholder>>")
|
|
979
918
|
user_messages.append({"role": "user", "type": "text", "content": base_prompt_start.format(**kwargs)})
|
|
980
|
-
user_messages.extend(
|
|
981
|
-
user_messages.append({"role": "user", "type": "text", "content": base_prompt_mid.format(**kwargs)})
|
|
982
|
-
user_messages.extend(right_image_messages)
|
|
919
|
+
user_messages.extend(image_audio_messages)
|
|
983
920
|
user_messages.append({"role": "user", "type": "text", "content": base_prompt_end.format(**kwargs)})
|
|
984
921
|
|
|
985
|
-
#
|
|
986
|
-
elif self.prompt_strategy.
|
|
987
|
-
base_prompt = base_prompt.replace("<<image-placeholder>>", "")
|
|
988
|
-
base_prompt_start, base_prompt_end = base_prompt.split("<<audio-placeholder>>")
|
|
922
|
+
# handle joins with only right images / audio
|
|
923
|
+
elif self.prompt_strategy.is_join_prompt() and not has_image_audio and has_right_image_audio:
|
|
924
|
+
base_prompt = base_prompt.replace("<<image-audio-placeholder>>", "")
|
|
925
|
+
base_prompt_start, base_prompt_end = base_prompt.split("<<right-image-audio-placeholder>>")
|
|
989
926
|
user_messages.append({"role": "user", "type": "text", "content": base_prompt_start.format(**kwargs)})
|
|
990
|
-
user_messages.extend(
|
|
927
|
+
user_messages.extend(right_image_audio_messages)
|
|
991
928
|
user_messages.append({"role": "user", "type": "text", "content": base_prompt_end.format(**kwargs)})
|
|
992
929
|
|
|
993
|
-
# audio
|
|
994
|
-
elif self.prompt_strategy.
|
|
995
|
-
|
|
996
|
-
base_prompt = base_prompt.replace("<<image-placeholder>>", "")
|
|
997
|
-
base_prompt_start, base_prompt_mid, base_prompt_end = base_prompt.split("<<audio-placeholder>>")
|
|
930
|
+
# handle non-joins with images / audio
|
|
931
|
+
elif not self.prompt_strategy.is_join_prompt() and has_image_audio and not self.prompt_strategy.is_moa_aggregator_prompt():
|
|
932
|
+
base_prompt_start, base_prompt_end = base_prompt.split("<<image-audio-placeholder>>")
|
|
998
933
|
user_messages.append({"role": "user", "type": "text", "content": base_prompt_start.format(**kwargs)})
|
|
999
|
-
user_messages.extend(
|
|
1000
|
-
user_messages.append({"role": "user", "type": "text", "content": base_prompt_mid.format(**kwargs)})
|
|
1001
|
-
user_messages.extend(right_audio_messages)
|
|
934
|
+
user_messages.extend(image_audio_messages)
|
|
1002
935
|
user_messages.append({"role": "user", "type": "text", "content": base_prompt_end.format(**kwargs)})
|
|
1003
936
|
|
|
937
|
+
# handle prompts w/no images or audio
|
|
1004
938
|
else:
|
|
1005
|
-
base_prompt = base_prompt.replace("<<image-placeholder>>", "")
|
|
1006
|
-
base_prompt = base_prompt.replace("<<audio-placeholder>>", "")
|
|
939
|
+
base_prompt = base_prompt.replace("<<image-audio-placeholder>>", "")
|
|
940
|
+
base_prompt = base_prompt.replace("<<right-image-audio-placeholder>>", "")
|
|
1007
941
|
user_messages.append({"role": "user", "type": "text", "content": base_prompt.format(**kwargs)})
|
|
1008
942
|
|
|
1009
943
|
return user_messages
|
|
1010
944
|
|
|
1011
|
-
def _process_custom_user_prompt(self, candidate: DataRecord, input_fields: list[str], **kwargs) -> list[dict]:
|
|
1012
|
-
"""
|
|
1013
|
-
Processes a custom user prompt provided by the user.
|
|
1014
|
-
|
|
1015
|
-
Args:
|
|
1016
|
-
candidate (DataRecord): The input record.
|
|
1017
|
-
kwargs: The keyword arguments provided by the user.
|
|
1018
|
-
|
|
1019
|
-
Returns:
|
|
1020
|
-
list[dict]: The messages for the chat payload.
|
|
1021
|
-
"""
|
|
1022
|
-
# get the user prompt
|
|
1023
|
-
user_prompt: str = kwargs["prompt"]
|
|
1024
|
-
|
|
1025
|
-
# sanity check that we have all the inputs for the user's prompt template
|
|
1026
|
-
prompt_field_names = [fname for _, fname, _, _ in Formatter().parse(user_prompt) if fname]
|
|
1027
|
-
fields_check = all([field in input_fields for field in prompt_field_names])
|
|
1028
|
-
if not fields_check:
|
|
1029
|
-
if sorted(candidate.get_field_names()) != (input_fields):
|
|
1030
|
-
err_msg = (
|
|
1031
|
-
f"Prompt string has fields which are not in input fields.\n"
|
|
1032
|
-
f"Prompt fields: {prompt_field_names}\n"
|
|
1033
|
-
f"Computed fields: {candidate.get_field_names()}\n"
|
|
1034
|
-
f"Input fields: {input_fields}\n"
|
|
1035
|
-
f"Be careful that you are not projecting out computed fields. "
|
|
1036
|
-
f"If you use `depends_on` in your program, make sure it includes the fields you need."
|
|
1037
|
-
)
|
|
1038
|
-
else:
|
|
1039
|
-
err_msg = (
|
|
1040
|
-
f"Prompt string has fields which are not in input fields.\n"
|
|
1041
|
-
f"Prompt fields: {prompt_field_names}\n"
|
|
1042
|
-
f"Input fields: {input_fields}\n"
|
|
1043
|
-
)
|
|
1044
|
-
assert fields_check, err_msg
|
|
1045
|
-
|
|
1046
|
-
# build set of format kwargs
|
|
1047
|
-
format_kwargs = {
|
|
1048
|
-
field_name: "<bytes>"
|
|
1049
|
-
if candidate.get_field_type(field_name).annotation in [bytes, bytes | None]
|
|
1050
|
-
else candidate[field_name]
|
|
1051
|
-
for field_name in input_fields
|
|
1052
|
-
}
|
|
1053
|
-
|
|
1054
|
-
# split prompt on <<image-placeholder>> if it exists
|
|
1055
|
-
if "<<image-placeholder>>" in user_prompt:
|
|
1056
|
-
raise NotImplementedError("Image prompts are not yet supported.")
|
|
1057
|
-
|
|
1058
|
-
prompt_sections = user_prompt.split("<<image-placeholder>>")
|
|
1059
|
-
messages = [{"role": "user", "type": "text", "content": prompt_sections[0].format(**format_kwargs)}]
|
|
1060
|
-
|
|
1061
|
-
# NOTE: this currently assumes that the user can only provide a single <<image-placeholder>>
|
|
1062
|
-
if len(prompt_sections) > 1:
|
|
1063
|
-
image_messages = self._create_image_messages(candidate, input_fields)
|
|
1064
|
-
messages.extend(image_messages)
|
|
1065
|
-
messages.append({"role": "user", "type": "text", "content": prompt_sections[1].format(**format_kwargs)})
|
|
1066
|
-
|
|
1067
|
-
return messages
|
|
1068
|
-
|
|
1069
945
|
def create_messages(self, candidate: DataRecord, output_fields: list[str], right_candidate: DataRecord | None = None, **kwargs) -> list[dict]:
|
|
1070
946
|
"""
|
|
1071
947
|
Creates the messages for the chat payload based on the prompt strategy.
|
|
@@ -1090,19 +966,15 @@ class PromptFactory:
|
|
|
1090
966
|
input_fields = self._get_input_fields(candidate, **kwargs)
|
|
1091
967
|
right_input_fields = [] if right_candidate is None else self._get_input_fields(right_candidate, **kwargs)
|
|
1092
968
|
|
|
1093
|
-
#
|
|
1094
|
-
|
|
1095
|
-
|
|
1096
|
-
if "system_prompt" in kwargs:
|
|
1097
|
-
messages.append({"role": "system", "type": "text", "content": kwargs["system_prompt"]})
|
|
1098
|
-
messages.extend(self._process_custom_user_prompt(candidate, input_fields, **kwargs))
|
|
1099
|
-
return messages
|
|
969
|
+
# use input fields to determine the left / right input modalities
|
|
970
|
+
input_modalities = self._get_input_modalities(candidate, input_fields)
|
|
971
|
+
right_input_modalities = set() if right_candidate is None else self._get_input_modalities(right_candidate, right_input_fields)
|
|
1100
972
|
|
|
1101
973
|
# initialize messages
|
|
1102
974
|
messages = []
|
|
1103
975
|
|
|
1104
976
|
# compute the full dictionary of format kwargs and add to kwargs
|
|
1105
|
-
format_kwargs = self._get_all_format_kwargs(candidate, input_fields, output_fields, right_candidate, right_input_fields, **kwargs)
|
|
977
|
+
format_kwargs = self._get_all_format_kwargs(candidate, input_fields, input_modalities, output_fields, right_candidate, right_input_fields, right_input_modalities, **kwargs)
|
|
1106
978
|
kwargs = {**kwargs, **format_kwargs}
|
|
1107
979
|
|
|
1108
980
|
# generate system message (if applicable)
|