palimpzest 0.8.2__py3-none-any.whl → 0.8.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. palimpzest/constants.py +38 -62
  2. palimpzest/core/data/iter_dataset.py +5 -5
  3. palimpzest/core/elements/groupbysig.py +1 -1
  4. palimpzest/core/elements/records.py +91 -109
  5. palimpzest/core/lib/schemas.py +23 -0
  6. palimpzest/core/models.py +3 -3
  7. palimpzest/prompts/__init__.py +2 -6
  8. palimpzest/prompts/convert_prompts.py +10 -66
  9. palimpzest/prompts/critique_and_refine_prompts.py +66 -0
  10. palimpzest/prompts/filter_prompts.py +8 -46
  11. palimpzest/prompts/join_prompts.py +12 -75
  12. palimpzest/prompts/{moa_aggregator_convert_prompts.py → moa_aggregator_prompts.py} +51 -2
  13. palimpzest/prompts/moa_proposer_prompts.py +87 -0
  14. palimpzest/prompts/prompt_factory.py +351 -479
  15. palimpzest/prompts/split_merge_prompts.py +51 -2
  16. palimpzest/prompts/split_proposer_prompts.py +48 -16
  17. palimpzest/prompts/utils.py +109 -0
  18. palimpzest/query/execution/all_sample_execution_strategy.py +1 -1
  19. palimpzest/query/execution/execution_strategy.py +4 -4
  20. palimpzest/query/execution/mab_execution_strategy.py +1 -2
  21. palimpzest/query/execution/parallel_execution_strategy.py +3 -3
  22. palimpzest/query/execution/single_threaded_execution_strategy.py +8 -8
  23. palimpzest/query/generators/generators.py +31 -17
  24. palimpzest/query/operators/__init__.py +15 -2
  25. palimpzest/query/operators/aggregate.py +21 -19
  26. palimpzest/query/operators/compute.py +6 -8
  27. palimpzest/query/operators/convert.py +12 -37
  28. palimpzest/query/operators/critique_and_refine.py +194 -0
  29. palimpzest/query/operators/distinct.py +7 -7
  30. palimpzest/query/operators/filter.py +13 -25
  31. palimpzest/query/operators/join.py +321 -192
  32. palimpzest/query/operators/limit.py +4 -4
  33. palimpzest/query/operators/mixture_of_agents.py +246 -0
  34. palimpzest/query/operators/physical.py +25 -2
  35. palimpzest/query/operators/project.py +4 -4
  36. palimpzest/query/operators/{rag_convert.py → rag.py} +202 -5
  37. palimpzest/query/operators/retrieve.py +10 -9
  38. palimpzest/query/operators/scan.py +9 -10
  39. palimpzest/query/operators/search.py +18 -24
  40. palimpzest/query/operators/split.py +321 -0
  41. palimpzest/query/optimizer/__init__.py +12 -8
  42. palimpzest/query/optimizer/optimizer.py +12 -10
  43. palimpzest/query/optimizer/rules.py +201 -108
  44. palimpzest/query/optimizer/tasks.py +18 -6
  45. palimpzest/validator/validator.py +7 -9
  46. {palimpzest-0.8.2.dist-info → palimpzest-0.8.4.dist-info}/METADATA +3 -8
  47. palimpzest-0.8.4.dist-info/RECORD +95 -0
  48. palimpzest/prompts/critique_and_refine_convert_prompts.py +0 -216
  49. palimpzest/prompts/moa_proposer_convert_prompts.py +0 -75
  50. palimpzest/prompts/util_phrases.py +0 -19
  51. palimpzest/query/operators/critique_and_refine_convert.py +0 -113
  52. palimpzest/query/operators/mixture_of_agents_convert.py +0 -140
  53. palimpzest/query/operators/split_convert.py +0 -170
  54. palimpzest-0.8.2.dist-info/RECORD +0 -95
  55. {palimpzest-0.8.2.dist-info → palimpzest-0.8.4.dist-info}/WHEEL +0 -0
  56. {palimpzest-0.8.2.dist-info → palimpzest-0.8.4.dist-info}/licenses/LICENSE +0 -0
  57. {palimpzest-0.8.2.dist-info → palimpzest-0.8.4.dist-info}/top_level.txt +0 -0
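The bulk of this release is the rewrite of palimpzest/prompts/prompt_factory.py shown below: the per-modality chain-of-thought strategies (COT_QA, COT_QA_IMAGE, COT_BOOL_AUDIO, and so on) collapse into operation-level strategies (MAP, FILTER, JOIN), and the factory now infers the input modalities from the candidate record's field types at runtime. A minimal, self-contained sketch of that detection step follows; the Modality enum and the field-type sets here are simplified stand-ins for the real definitions in palimpzest.constants and palimpzest.core.lib.schemas, so treat it as an illustration rather than the library's API.

from enum import Enum

class Modality(Enum):
    # stand-in for palimpzest.constants.Modality
    TEXT = "text"
    IMAGE = "image"
    AUDIO = "audio"

# stand-ins for IMAGE_FIELD_TYPES / AUDIO_FIELD_TYPES in palimpzest.core.lib.schemas
IMAGE_FIELD_TYPES = {"ImageBase64", "ImageFilepath", "ImageURL"}
AUDIO_FIELD_TYPES = {"AudioBase64", "AudioFilepath"}

def get_input_modalities(field_types: dict[str, str]) -> set[Modality]:
    """Classify each input field by its declared type, mirroring PromptFactory._get_input_modalities."""
    modalities: set[Modality] = set()
    for field_type in field_types.values():
        if field_type in IMAGE_FIELD_TYPES:
            modalities.add(Modality.IMAGE)
        elif field_type in AUDIO_FIELD_TYPES:
            modalities.add(Modality.AUDIO)
        else:
            modalities.add(Modality.TEXT)
    return modalities

# get_input_modalities({"caption": "str", "photo": "ImageURL"}) yields {Modality.TEXT, Modality.IMAGE},
# which the factory then renders as a phrase such as "text and/or image(s)" inside the shared job instruction.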
@@ -2,7 +2,6 @@

  import base64
  import json
- from string import Formatter

  from pydantic import BaseModel

@@ -10,137 +9,114 @@ from palimpzest.constants import (
  LLAMA_CONTEXT_TOKENS_LIMIT,
  TOKENS_PER_CHARACTER,
  Cardinality,
+ Modality,
  Model,
  PromptStrategy,
  )
  from palimpzest.core.elements.records import DataRecord
- from palimpzest.core.lib.schemas import AudioBase64, AudioFilepath, ImageBase64, ImageFilepath, ImageURL
+ from palimpzest.core.lib.schemas import (
+ AUDIO_FIELD_TYPES,
+ IMAGE_FIELD_TYPES,
+ AudioBase64,
+ AudioFilepath,
+ ImageBase64,
+ ImageFilepath,
+ ImageURL,
+ )
  from palimpzest.prompts.convert_prompts import (
- COT_QA_AUDIO_DISCLAIMER,
- COT_QA_AUDIO_EXAMPLE_ANSWER,
- COT_QA_AUDIO_EXAMPLE_CONTEXT,
- COT_QA_AUDIO_EXAMPLE_INPUT_FIELDS,
- COT_QA_AUDIO_EXAMPLE_OUTPUT_FIELDS,
- COT_QA_AUDIO_EXAMPLE_REASONING,
- COT_QA_AUDIO_JOB_INSTRUCTION,
- COT_QA_BASE_SYSTEM_PROMPT,
- COT_QA_BASE_USER_PROMPT,
- COT_QA_EXAMPLE_ANSWER,
- COT_QA_EXAMPLE_CONTEXT,
- COT_QA_EXAMPLE_INPUT_FIELDS,
- COT_QA_EXAMPLE_OUTPUT_FIELDS,
- COT_QA_EXAMPLE_REASONING,
- COT_QA_IMAGE_DISCLAIMER,
- COT_QA_IMAGE_EXAMPLE_ANSWER,
- COT_QA_IMAGE_EXAMPLE_CONTEXT,
- COT_QA_IMAGE_EXAMPLE_INPUT_FIELDS,
- COT_QA_IMAGE_EXAMPLE_OUTPUT_FIELDS,
- COT_QA_IMAGE_EXAMPLE_REASONING,
- COT_QA_IMAGE_JOB_INSTRUCTION,
- COT_QA_JOB_INSTRUCTION,
- COT_QA_NO_REASONING_BASE_SYSTEM_PROMPT,
- COT_QA_NO_REASONING_BASE_USER_PROMPT,
+ MAP_BASE_SYSTEM_PROMPT,
+ MAP_BASE_USER_PROMPT,
+ MAP_NO_REASONING_BASE_SYSTEM_PROMPT,
+ MAP_NO_REASONING_BASE_USER_PROMPT,
  )
- from palimpzest.prompts.critique_and_refine_convert_prompts import (
+ from palimpzest.prompts.critique_and_refine_prompts import (
  BASE_CRITIQUE_PROMPT,
  BASE_REFINEMENT_PROMPT,
- COT_QA_CRITIQUE_CRITERIA,
- COT_QA_CRITIQUE_FINISH_INSTRUCTION,
- COT_QA_IMAGE_CRITIQUE_CRITERIA,
- COT_QA_IMAGE_REFINEMENT_CRITERIA,
- COT_QA_REFINEMENT_CRITERIA,
- COT_QA_REFINEMENT_FINISH_INSTRUCTION,
+ FILTER_CRITIQUE_CRITERIA,
+ FILTER_CRITIQUE_FINISH_INSTRUCTION,
+ FILTER_REFINEMENT_CRITERIA,
+ FILTER_REFINEMENT_FINISH_INSTRUCTION,
+ MAP_CRITIQUE_CRITERIA,
+ MAP_CRITIQUE_FINISH_INSTRUCTION,
+ MAP_REFINEMENT_CRITERIA,
+ MAP_REFINEMENT_FINISH_INSTRUCTION,
  )
  from palimpzest.prompts.filter_prompts import (
- COT_BOOL_AUDIO_DISCLAIMER,
- COT_BOOL_AUDIO_EXAMPLE_CONTEXT,
- COT_BOOL_AUDIO_EXAMPLE_FILTER_CONDITION,
- COT_BOOL_AUDIO_EXAMPLE_INPUT_FIELDS,
- COT_BOOL_AUDIO_EXAMPLE_REASONING,
- COT_BOOL_AUDIO_JOB_INSTRUCTION,
- COT_BOOL_BASE_SYSTEM_PROMPT,
- COT_BOOL_BASE_USER_PROMPT,
- COT_BOOL_EXAMPLE_CONTEXT,
- COT_BOOL_EXAMPLE_FILTER_CONDITION,
- COT_BOOL_EXAMPLE_INPUT_FIELDS,
- COT_BOOL_EXAMPLE_REASONING,
- COT_BOOL_IMAGE_DISCLAIMER,
- COT_BOOL_IMAGE_EXAMPLE_CONTEXT,
- COT_BOOL_IMAGE_EXAMPLE_FILTER_CONDITION,
- COT_BOOL_IMAGE_EXAMPLE_INPUT_FIELDS,
- COT_BOOL_IMAGE_EXAMPLE_REASONING,
- COT_BOOL_IMAGE_JOB_INSTRUCTION,
- COT_BOOL_JOB_INSTRUCTION,
- COT_BOOL_NO_REASONING_BASE_SYSTEM_PROMPT,
- COT_BOOL_NO_REASONING_BASE_USER_PROMPT,
+ FILTER_BASE_SYSTEM_PROMPT,
+ FILTER_BASE_USER_PROMPT,
+ FILTER_NO_REASONING_BASE_SYSTEM_PROMPT,
+ FILTER_NO_REASONING_BASE_USER_PROMPT,
  )
  from palimpzest.prompts.join_prompts import (
- COT_JOIN_AUDIO_DISCLAIMER,
- COT_JOIN_AUDIO_EXAMPLE_CONTEXT,
- COT_JOIN_AUDIO_EXAMPLE_INPUT_FIELDS,
- COT_JOIN_AUDIO_EXAMPLE_JOIN_CONDITION,
- COT_JOIN_AUDIO_EXAMPLE_REASONING,
- COT_JOIN_AUDIO_JOB_INSTRUCTION,
- COT_JOIN_AUDIO_RIGHT_EXAMPLE_CONTEXT,
- COT_JOIN_AUDIO_RIGHT_EXAMPLE_INPUT_FIELDS,
- COT_JOIN_BASE_SYSTEM_PROMPT,
- COT_JOIN_BASE_USER_PROMPT,
- COT_JOIN_EXAMPLE_CONTEXT,
- COT_JOIN_EXAMPLE_INPUT_FIELDS,
- COT_JOIN_EXAMPLE_JOIN_CONDITION,
- COT_JOIN_EXAMPLE_REASONING,
- COT_JOIN_IMAGE_DISCLAIMER,
- COT_JOIN_IMAGE_EXAMPLE_CONTEXT,
- COT_JOIN_IMAGE_EXAMPLE_INPUT_FIELDS,
- COT_JOIN_IMAGE_EXAMPLE_JOIN_CONDITION,
- COT_JOIN_IMAGE_EXAMPLE_REASONING,
- COT_JOIN_IMAGE_JOB_INSTRUCTION,
- COT_JOIN_IMAGE_RIGHT_EXAMPLE_CONTEXT,
- COT_JOIN_IMAGE_RIGHT_EXAMPLE_INPUT_FIELDS,
- COT_JOIN_JOB_INSTRUCTION,
- COT_JOIN_NO_REASONING_BASE_SYSTEM_PROMPT,
- COT_JOIN_NO_REASONING_BASE_USER_PROMPT,
- COT_JOIN_RIGHT_AUDIO_DISCLAIMER,
- COT_JOIN_RIGHT_EXAMPLE_CONTEXT,
- COT_JOIN_RIGHT_EXAMPLE_INPUT_FIELDS,
- COT_JOIN_RIGHT_IMAGE_DISCLAIMER,
+ JOIN_BASE_SYSTEM_PROMPT,
+ JOIN_BASE_USER_PROMPT,
+ JOIN_NO_REASONING_BASE_SYSTEM_PROMPT,
+ JOIN_NO_REASONING_BASE_USER_PROMPT,
  )
- from palimpzest.prompts.moa_aggregator_convert_prompts import (
- COT_MOA_AGG_BASE_SYSTEM_PROMPT,
- COT_MOA_AGG_BASE_USER_PROMPT,
+ from palimpzest.prompts.moa_aggregator_prompts import (
+ FILTER_MOA_AGG_BASE_SYSTEM_PROMPT,
+ FILTER_MOA_AGG_BASE_USER_PROMPT,
+ MAP_MOA_AGG_BASE_SYSTEM_PROMPT,
+ MAP_MOA_AGG_BASE_USER_PROMPT,
  )
- from palimpzest.prompts.moa_proposer_convert_prompts import (
- COT_MOA_PROPOSER_BASE_SYSTEM_PROMPT,
- COT_MOA_PROPOSER_BASE_USER_PROMPT,
- COT_MOA_PROPOSER_EXAMPLE_ANSWER,
- COT_MOA_PROPOSER_EXAMPLE_CONTEXT,
- COT_MOA_PROPOSER_EXAMPLE_INPUT_FIELDS,
- COT_MOA_PROPOSER_EXAMPLE_OUTPUT_FIELDS,
- COT_MOA_PROPOSER_IMAGE_DISCLAIMER,
- COT_MOA_PROPOSER_IMAGE_EXAMPLE_ANSWER,
- COT_MOA_PROPOSER_IMAGE_EXAMPLE_CONTEXT,
- COT_MOA_PROPOSER_IMAGE_EXAMPLE_INPUT_FIELDS,
- COT_MOA_PROPOSER_IMAGE_EXAMPLE_OUTPUT_FIELDS,
- COT_MOA_PROPOSER_IMAGE_JOB_INSTRUCTION,
- COT_MOA_PROPOSER_JOB_INSTRUCTION,
+ from palimpzest.prompts.moa_proposer_prompts import (
+ FILTER_MOA_PROPOSER_BASE_SYSTEM_PROMPT,
+ FILTER_MOA_PROPOSER_BASE_USER_PROMPT,
+ MAP_MOA_PROPOSER_BASE_SYSTEM_PROMPT,
+ MAP_MOA_PROPOSER_BASE_USER_PROMPT,
  )
  from palimpzest.prompts.split_merge_prompts import (
- COT_SPLIT_MERGER_BASE_SYSTEM_PROMPT,
- COT_SPLIT_MERGER_BASE_USER_PROMPT,
+ FILTER_SPLIT_MERGER_BASE_SYSTEM_PROMPT,
+ FILTER_SPLIT_MERGER_BASE_USER_PROMPT,
+ MAP_SPLIT_MERGER_BASE_SYSTEM_PROMPT,
+ MAP_SPLIT_MERGER_BASE_USER_PROMPT,
  )
  from palimpzest.prompts.split_proposer_prompts import (
- COT_SPLIT_PROPOSER_BASE_SYSTEM_PROMPT,
- COT_SPLIT_PROPOSER_BASE_USER_PROMPT,
- SPLIT_PROPOSER_EXAMPLE_ANSWER,
- SPLIT_PROPOSER_EXAMPLE_CONTEXT,
- SPLIT_PROPOSER_EXAMPLE_INPUT_FIELDS,
- SPLIT_PROPOSER_EXAMPLE_OUTPUT_FIELDS,
- SPLIT_PROPOSER_JOB_INSTRUCTION,
+ FILTER_SPLIT_PROPOSER_BASE_SYSTEM_PROMPT,
+ FILTER_SPLIT_PROPOSER_BASE_USER_PROMPT,
+ MAP_SPLIT_PROPOSER_BASE_SYSTEM_PROMPT,
+ MAP_SPLIT_PROPOSER_BASE_USER_PROMPT,
  )
- from palimpzest.prompts.util_phrases import (
+ from palimpzest.prompts.utils import (
+ AUDIO_DISCLAIMER,
+ AUDIO_EXAMPLE_ANSWER,
+ AUDIO_EXAMPLE_CONTEXT,
+ AUDIO_EXAMPLE_INPUT_FIELDS,
+ AUDIO_EXAMPLE_OUTPUT_FIELDS,
+ AUDIO_EXAMPLE_REASONING,
+ AUDIO_SENTENCE_EXAMPLE_ANSWER,
  DESC_SECTION,
+ EXAMPLE_FILTER_CONDITION,
+ EXAMPLE_JOIN_CONDITION,
+ FILTER_EXAMPLE_REASONING,
+ FILTER_JOB_INSTRUCTION,
+ IMAGE_DISCLAIMER,
+ IMAGE_EXAMPLE_ANSWER,
+ IMAGE_EXAMPLE_CONTEXT,
+ IMAGE_EXAMPLE_INPUT_FIELDS,
+ IMAGE_EXAMPLE_OUTPUT_FIELDS,
+ IMAGE_EXAMPLE_REASONING,
+ IMAGE_SENTENCE_EXAMPLE_ANSWER,
+ JOIN_EXAMPLE_REASONING,
+ JOIN_JOB_INSTRUCTION,
+ MAP_JOB_INSTRUCTION,
  ONE_TO_MANY_OUTPUT_FORMAT_INSTRUCTION,
  ONE_TO_ONE_OUTPUT_FORMAT_INSTRUCTION,
+ PROPOSER_JOB_INSTRUCTION,
+ RIGHT_AUDIO_DISCLAIMER,
+ RIGHT_AUDIO_EXAMPLE_CONTEXT,
+ RIGHT_AUDIO_EXAMPLE_INPUT_FIELDS,
+ RIGHT_IMAGE_DISCLAIMER,
+ RIGHT_IMAGE_EXAMPLE_CONTEXT,
+ RIGHT_IMAGE_EXAMPLE_INPUT_FIELDS,
+ RIGHT_TEXT_EXAMPLE_CONTEXT,
+ RIGHT_TEXT_EXAMPLE_INPUT_FIELDS,
+ TEXT_EXAMPLE_ANSWER,
+ TEXT_EXAMPLE_CONTEXT,
+ TEXT_EXAMPLE_INPUT_FIELDS,
+ TEXT_EXAMPLE_OUTPUT_FIELDS,
+ TEXT_EXAMPLE_REASONING,
+ TEXT_SENTENCE_EXAMPLE_ANSWER,
  )


@@ -148,62 +124,54 @@ class PromptFactory:
  """Factory class for generating prompts for the Generator given the input(s)."""

  BASE_SYSTEM_PROMPT_MAP = {
- PromptStrategy.COT_BOOL: COT_BOOL_BASE_SYSTEM_PROMPT,
- PromptStrategy.COT_BOOL_NO_REASONING: COT_BOOL_NO_REASONING_BASE_SYSTEM_PROMPT,
- PromptStrategy.COT_BOOL_AUDIO: COT_BOOL_BASE_SYSTEM_PROMPT,
- PromptStrategy.COT_BOOL_AUDIO_NO_REASONING: COT_BOOL_NO_REASONING_BASE_SYSTEM_PROMPT,
- PromptStrategy.COT_BOOL_IMAGE: COT_BOOL_BASE_SYSTEM_PROMPT,
- PromptStrategy.COT_BOOL_IMAGE_NO_REASONING: COT_BOOL_NO_REASONING_BASE_SYSTEM_PROMPT,
- PromptStrategy.COT_JOIN: COT_JOIN_BASE_SYSTEM_PROMPT,
- PromptStrategy.COT_JOIN_NO_REASONING: COT_JOIN_NO_REASONING_BASE_SYSTEM_PROMPT,
- PromptStrategy.COT_JOIN_AUDIO: COT_JOIN_BASE_SYSTEM_PROMPT,
- PromptStrategy.COT_JOIN_AUDIO_NO_REASONING: COT_JOIN_NO_REASONING_BASE_SYSTEM_PROMPT,
- PromptStrategy.COT_JOIN_IMAGE: COT_JOIN_BASE_SYSTEM_PROMPT,
- PromptStrategy.COT_JOIN_IMAGE_NO_REASONING: COT_JOIN_NO_REASONING_BASE_SYSTEM_PROMPT,
- PromptStrategy.COT_QA: COT_QA_BASE_SYSTEM_PROMPT,
- PromptStrategy.COT_QA_NO_REASONING: COT_QA_NO_REASONING_BASE_SYSTEM_PROMPT,
- PromptStrategy.COT_QA_AUDIO: COT_QA_BASE_SYSTEM_PROMPT,
- PromptStrategy.COT_QA_AUDIO_NO_REASONING: COT_QA_NO_REASONING_BASE_SYSTEM_PROMPT,
- PromptStrategy.COT_QA_CRITIC: None,
- PromptStrategy.COT_QA_REFINE: None,
- PromptStrategy.COT_QA_IMAGE: COT_QA_BASE_SYSTEM_PROMPT,
- PromptStrategy.COT_QA_IMAGE_NO_REASONING: COT_QA_NO_REASONING_BASE_SYSTEM_PROMPT,
- PromptStrategy.COT_QA_IMAGE_CRITIC: None,
- PromptStrategy.COT_QA_IMAGE_REFINE: None,
- PromptStrategy.COT_MOA_PROPOSER: COT_MOA_PROPOSER_BASE_SYSTEM_PROMPT,
- PromptStrategy.COT_MOA_PROPOSER_IMAGE: COT_MOA_PROPOSER_BASE_SYSTEM_PROMPT,
- PromptStrategy.COT_MOA_AGG: COT_MOA_AGG_BASE_SYSTEM_PROMPT,
- PromptStrategy.SPLIT_PROPOSER: COT_SPLIT_PROPOSER_BASE_SYSTEM_PROMPT,
- PromptStrategy.SPLIT_MERGER: COT_SPLIT_MERGER_BASE_SYSTEM_PROMPT,
+ # filter system prompts
+ PromptStrategy.FILTER: FILTER_BASE_SYSTEM_PROMPT,
+ PromptStrategy.FILTER_NO_REASONING: FILTER_NO_REASONING_BASE_SYSTEM_PROMPT,
+ PromptStrategy.FILTER_CRITIC: None,
+ PromptStrategy.FILTER_REFINE: None,
+ PromptStrategy.FILTER_MOA_PROPOSER: FILTER_MOA_PROPOSER_BASE_SYSTEM_PROMPT,
+ PromptStrategy.FILTER_MOA_AGG: FILTER_MOA_AGG_BASE_SYSTEM_PROMPT,
+ PromptStrategy.FILTER_SPLIT_PROPOSER: FILTER_SPLIT_PROPOSER_BASE_SYSTEM_PROMPT,
+ PromptStrategy.FILTER_SPLIT_MERGER: FILTER_SPLIT_MERGER_BASE_SYSTEM_PROMPT,
+
+ # join system prompts
+ PromptStrategy.JOIN: JOIN_BASE_SYSTEM_PROMPT,
+ PromptStrategy.JOIN_NO_REASONING: JOIN_NO_REASONING_BASE_SYSTEM_PROMPT,
+
+ # map system prompts
+ PromptStrategy.MAP: MAP_BASE_SYSTEM_PROMPT,
+ PromptStrategy.MAP_NO_REASONING: MAP_NO_REASONING_BASE_SYSTEM_PROMPT,
+ PromptStrategy.MAP_CRITIC: None,
+ PromptStrategy.MAP_REFINE: None,
+ PromptStrategy.MAP_MOA_PROPOSER: MAP_MOA_PROPOSER_BASE_SYSTEM_PROMPT,
+ PromptStrategy.MAP_MOA_AGG: MAP_MOA_AGG_BASE_SYSTEM_PROMPT,
+ PromptStrategy.MAP_SPLIT_PROPOSER: MAP_SPLIT_PROPOSER_BASE_SYSTEM_PROMPT,
+ PromptStrategy.MAP_SPLIT_MERGER: MAP_SPLIT_MERGER_BASE_SYSTEM_PROMPT,
  }
  BASE_USER_PROMPT_MAP = {
- PromptStrategy.COT_BOOL: COT_BOOL_BASE_USER_PROMPT,
- PromptStrategy.COT_BOOL_NO_REASONING: COT_BOOL_NO_REASONING_BASE_USER_PROMPT,
- PromptStrategy.COT_BOOL_AUDIO: COT_BOOL_BASE_USER_PROMPT,
- PromptStrategy.COT_BOOL_AUDIO_NO_REASONING: COT_BOOL_NO_REASONING_BASE_USER_PROMPT,
- PromptStrategy.COT_BOOL_IMAGE: COT_BOOL_BASE_USER_PROMPT,
- PromptStrategy.COT_BOOL_IMAGE_NO_REASONING: COT_BOOL_NO_REASONING_BASE_USER_PROMPT,
- PromptStrategy.COT_JOIN: COT_JOIN_BASE_USER_PROMPT,
- PromptStrategy.COT_JOIN_NO_REASONING: COT_JOIN_NO_REASONING_BASE_USER_PROMPT,
- PromptStrategy.COT_JOIN_AUDIO: COT_JOIN_BASE_USER_PROMPT,
- PromptStrategy.COT_JOIN_AUDIO_NO_REASONING: COT_JOIN_NO_REASONING_BASE_USER_PROMPT,
- PromptStrategy.COT_JOIN_IMAGE: COT_JOIN_BASE_USER_PROMPT,
- PromptStrategy.COT_JOIN_IMAGE_NO_REASONING: COT_JOIN_NO_REASONING_BASE_USER_PROMPT,
- PromptStrategy.COT_QA: COT_QA_BASE_USER_PROMPT,
- PromptStrategy.COT_QA_NO_REASONING: COT_QA_NO_REASONING_BASE_USER_PROMPT,
- PromptStrategy.COT_QA_AUDIO: COT_QA_BASE_USER_PROMPT,
- PromptStrategy.COT_QA_AUDIO_NO_REASONING: COT_QA_NO_REASONING_BASE_USER_PROMPT,
- PromptStrategy.COT_QA_CRITIC: BASE_CRITIQUE_PROMPT,
- PromptStrategy.COT_QA_REFINE: BASE_REFINEMENT_PROMPT,
- PromptStrategy.COT_QA_IMAGE: COT_QA_BASE_USER_PROMPT,
- PromptStrategy.COT_QA_IMAGE_NO_REASONING: COT_QA_NO_REASONING_BASE_USER_PROMPT,
- PromptStrategy.COT_QA_IMAGE_CRITIC: BASE_CRITIQUE_PROMPT,
- PromptStrategy.COT_QA_IMAGE_REFINE: BASE_REFINEMENT_PROMPT,
- PromptStrategy.COT_MOA_PROPOSER: COT_MOA_PROPOSER_BASE_USER_PROMPT,
- PromptStrategy.COT_MOA_PROPOSER_IMAGE: COT_MOA_PROPOSER_BASE_USER_PROMPT,
- PromptStrategy.COT_MOA_AGG: COT_MOA_AGG_BASE_USER_PROMPT,
- PromptStrategy.SPLIT_PROPOSER: COT_SPLIT_PROPOSER_BASE_USER_PROMPT,
- PromptStrategy.SPLIT_MERGER: COT_SPLIT_MERGER_BASE_USER_PROMPT,
+ # filter user prompts
+ PromptStrategy.FILTER: FILTER_BASE_USER_PROMPT,
+ PromptStrategy.FILTER_NO_REASONING: FILTER_NO_REASONING_BASE_USER_PROMPT,
+ PromptStrategy.FILTER_CRITIC: BASE_CRITIQUE_PROMPT,
+ PromptStrategy.FILTER_REFINE: BASE_REFINEMENT_PROMPT,
+ PromptStrategy.FILTER_MOA_PROPOSER: FILTER_MOA_PROPOSER_BASE_USER_PROMPT,
+ PromptStrategy.FILTER_MOA_AGG: FILTER_MOA_AGG_BASE_USER_PROMPT,
+ PromptStrategy.FILTER_SPLIT_PROPOSER: FILTER_SPLIT_PROPOSER_BASE_USER_PROMPT,
+ PromptStrategy.FILTER_SPLIT_MERGER: FILTER_SPLIT_MERGER_BASE_USER_PROMPT,
+
+ # join user prompts
+ PromptStrategy.JOIN: JOIN_BASE_USER_PROMPT,
+ PromptStrategy.JOIN_NO_REASONING: JOIN_NO_REASONING_BASE_USER_PROMPT,
+
+ # map user prompts
+ PromptStrategy.MAP: MAP_BASE_USER_PROMPT,
+ PromptStrategy.MAP_NO_REASONING: MAP_NO_REASONING_BASE_USER_PROMPT,
+ PromptStrategy.MAP_CRITIC: BASE_CRITIQUE_PROMPT,
+ PromptStrategy.MAP_REFINE: BASE_REFINEMENT_PROMPT,
+ PromptStrategy.MAP_MOA_PROPOSER: MAP_MOA_PROPOSER_BASE_USER_PROMPT,
+ PromptStrategy.MAP_MOA_AGG: MAP_MOA_AGG_BASE_USER_PROMPT,
+ PromptStrategy.MAP_SPLIT_PROPOSER: MAP_SPLIT_PROPOSER_BASE_USER_PROMPT,
+ PromptStrategy.MAP_SPLIT_MERGER: MAP_SPLIT_MERGER_BASE_USER_PROMPT,
  }

  def __init__(self, prompt_strategy: PromptStrategy, model: Model, cardinality: Cardinality, desc: str | None = None) -> None:
@@ -277,6 +245,54 @@ class PromptFactory:
  input_fields = [field for field in input_fields if field in candidate.get_field_names()]
  return input_fields

+ def _get_input_modalities(self, candidate: DataRecord, input_fields: list[str]) -> set[Modality]:
+ """
+ The list of input modalities for the given input fields.
+
+ Args:
+ candidate (DataRecord): The input record.
+ input_fields (list[str]): The input fields.
+
+ Returns:
+ set[Modality]: The list of input modalities.
+ """
+ input_modalities = []
+ for field_name in input_fields:
+ field_type = candidate.get_field_type(field_name)
+ if field_type.annotation in IMAGE_FIELD_TYPES:
+ input_modalities.append(Modality.IMAGE)
+ elif field_type.annotation in AUDIO_FIELD_TYPES:
+ input_modalities.append(Modality.AUDIO)
+ else:
+ input_modalities.append(Modality.TEXT)
+
+ return set(input_modalities)
+
+ def _get_modalities_str(self, input_modalities: set[Modality]) -> str:
+ """
+ Returns a format string to reflect the input modalities.
+
+ Args:
+ input_modalities (set[Modality]): The input modalities.
+
+ Returns:
+ str: The string to reflect the input modalities.
+ """
+ if input_modalities == {Modality.TEXT}:
+ return "text"
+ elif input_modalities == {Modality.IMAGE}:
+ return "image(s)"
+ elif input_modalities == {Modality.AUDIO}:
+ return "audio"
+ elif input_modalities == {Modality.TEXT, Modality.IMAGE}:
+ return "text and/or image(s)"
+ elif input_modalities == {Modality.TEXT, Modality.AUDIO}:
+ return "text and/or audio"
+ elif input_modalities == {Modality.IMAGE, Modality.AUDIO}:
+ return "image(s) and/or audio"
+ elif input_modalities == {Modality.TEXT, Modality.IMAGE, Modality.AUDIO}:
+ return "text, image(s), and/or audio"
+
  def _get_input_fields_desc(self, candidate: DataRecord, input_fields: list[str]) -> str:
  """
  Returns a multi-line description of each input field for the prompt.
@@ -305,8 +321,8 @@ class PromptFactory:
  str: The output fields description.
  """
  output_fields_desc = ""
- output_schema: BaseModel = kwargs.get("output_schema")
- if self.prompt_strategy.is_convert_prompt():
+ output_schema: type[BaseModel] = kwargs.get("output_schema")
+ if self.prompt_strategy.is_map_prompt():
  assert output_schema is not None, "Output schema must be provided for convert prompts."

  for field_name in sorted(output_fields):
@@ -324,7 +340,7 @@ class PromptFactory:
  str | None: The filter condition (if applicable).
  """
  filter_condition = kwargs.get("filter_condition")
- if self.prompt_strategy.is_bool_prompt():
+ if self.prompt_strategy.is_filter_prompt():
  assert filter_condition is not None, "Filter condition must be provided for filter operations."

  return filter_condition
@@ -390,7 +406,8 @@ class PromptFactory:
  if self.prompt_strategy.is_moa_aggregator_prompt():
  model_responses = ""
  for idx, model_response in enumerate(kwargs.get("model_responses")):
- model_responses += f"MODEL RESPONSE {idx + 1}: {model_response}\n"
+ model_responses += f"MODEL RESPONSE {idx + 1}: {model_response.rstrip()}\n\n"
+ model_responses = model_responses.rstrip() if model_responses is not None else None

  return model_responses

@@ -408,7 +425,8 @@ class PromptFactory:
  if self.prompt_strategy.is_split_merger_prompt():
  chunk_outputs = ""
  for idx, chunk_output in enumerate(kwargs.get("chunk_outputs")):
- chunk_outputs += f"CHUNK OUTPUT {idx + 1}: {chunk_output}\n"
+ chunk_outputs += f"CHUNK OUTPUT {idx + 1}: {chunk_output.rstrip()}\n\n"
+ chunk_outputs = chunk_outputs.rstrip() if chunk_outputs is not None else None

  return chunk_outputs

@@ -425,28 +443,33 @@ class PromptFactory:
  else ONE_TO_MANY_OUTPUT_FORMAT_INSTRUCTION
  )

- def _get_job_instruction(self) -> str | None:
+ def _get_job_instruction(self, input_modalities: set[Modality]) -> str | None:
  """
  Returns the job instruction based on the prompt strategy.

+ Args:
+ input_modalities (set[Modality]): The modalities of the input fields.
+
  Returns:
- str | None: The job instruction (if applicable).
+ str | None: The job instruction.
  """
- prompt_strategy_to_job_instruction = {
- PromptStrategy.COT_BOOL: COT_BOOL_JOB_INSTRUCTION,
- PromptStrategy.COT_BOOL_AUDIO: COT_BOOL_AUDIO_JOB_INSTRUCTION,
- PromptStrategy.COT_BOOL_IMAGE: COT_BOOL_IMAGE_JOB_INSTRUCTION,
- PromptStrategy.COT_JOIN: COT_JOIN_JOB_INSTRUCTION,
- PromptStrategy.COT_JOIN_AUDIO: COT_JOIN_AUDIO_JOB_INSTRUCTION,
- PromptStrategy.COT_JOIN_IMAGE: COT_JOIN_IMAGE_JOB_INSTRUCTION,
- PromptStrategy.COT_QA: COT_QA_JOB_INSTRUCTION,
- PromptStrategy.COT_QA_AUDIO: COT_QA_AUDIO_JOB_INSTRUCTION,
- PromptStrategy.COT_QA_IMAGE: COT_QA_IMAGE_JOB_INSTRUCTION,
- PromptStrategy.COT_MOA_PROPOSER: COT_MOA_PROPOSER_JOB_INSTRUCTION,
- PromptStrategy.COT_MOA_PROPOSER_IMAGE: COT_MOA_PROPOSER_IMAGE_JOB_INSTRUCTION,
- PromptStrategy.SPLIT_PROPOSER: SPLIT_PROPOSER_JOB_INSTRUCTION,
- }
- return prompt_strategy_to_job_instruction.get(self.prompt_strategy)
+ # get the job instruction based on the prompt strategy
+ job_instruction = None
+ if self.prompt_strategy.is_moa_proposer_prompt() or self.prompt_strategy.is_split_proposer_prompt():
+ job_instruction = PROPOSER_JOB_INSTRUCTION
+ elif self.prompt_strategy.is_map_prompt():
+ job_instruction = MAP_JOB_INSTRUCTION
+ elif self.prompt_strategy.is_filter_prompt():
+ job_instruction = FILTER_JOB_INSTRUCTION
+ elif self.prompt_strategy.is_join_prompt():
+ job_instruction = JOIN_JOB_INSTRUCTION
+
+ # format the job instruction based on the input modalities
+ modalities = self._get_modalities_str(input_modalities)
+ if job_instruction is not None:
+ job_instruction = job_instruction.format(modalities=modalities)
+
+ return job_instruction

  def _get_desc_section(self) -> str:
  """
@@ -470,9 +493,7 @@ class PromptFactory:
  """
  critique_criteria = None
  if self.prompt_strategy.is_critic_prompt():
- critique_criteria = (
- COT_QA_IMAGE_CRITIQUE_CRITERIA if self.prompt_strategy.is_image_prompt() else COT_QA_CRITIQUE_CRITERIA
- )
+ critique_criteria = MAP_CRITIQUE_CRITERIA if self.prompt_strategy.is_map_prompt() else FILTER_CRITIQUE_CRITERIA

  return critique_criteria

@@ -485,11 +506,7 @@ class PromptFactory:
  """
  refinement_criteria = None
  if self.prompt_strategy.is_refine_prompt():
- refinement_criteria = (
- COT_QA_IMAGE_REFINEMENT_CRITERIA
- if self.prompt_strategy.is_image_prompt()
- else COT_QA_REFINEMENT_CRITERIA
- )
+ refinement_criteria = MAP_REFINEMENT_CRITERIA if self.prompt_strategy.is_map_prompt() else FILTER_REFINEMENT_CRITERIA

  return refinement_criteria

@@ -502,240 +519,156 @@ class PromptFactory:
  """
  finish_instruction = None
  if self.prompt_strategy.is_critic_prompt():
- finish_instruction = COT_QA_CRITIQUE_FINISH_INSTRUCTION
+ finish_instruction = MAP_CRITIQUE_FINISH_INSTRUCTION if self.prompt_strategy.is_map_prompt() else FILTER_CRITIQUE_FINISH_INSTRUCTION
  elif self.prompt_strategy.is_refine_prompt():
- finish_instruction = COT_QA_REFINEMENT_FINISH_INSTRUCTION
+ finish_instruction = MAP_REFINEMENT_FINISH_INSTRUCTION if self.prompt_strategy.is_map_prompt() else FILTER_REFINEMENT_FINISH_INSTRUCTION

  return finish_instruction

- def _get_example_input_fields(self) -> str | None:
+ def _get_example_input_fields(self, input_modalities: set[Modality], right: bool = False) -> str:
  """
  Returns the example input fields for the prompt.

- Returns:
- str | None: The example input fields (if applicable).
- """
- prompt_strategy_to_example_input_fields = {
- PromptStrategy.COT_BOOL: COT_BOOL_EXAMPLE_INPUT_FIELDS,
- PromptStrategy.COT_BOOL_AUDIO: COT_BOOL_AUDIO_EXAMPLE_INPUT_FIELDS,
- PromptStrategy.COT_BOOL_IMAGE: COT_BOOL_IMAGE_EXAMPLE_INPUT_FIELDS,
- PromptStrategy.COT_JOIN: COT_JOIN_EXAMPLE_INPUT_FIELDS,
- PromptStrategy.COT_JOIN_AUDIO: COT_JOIN_AUDIO_EXAMPLE_INPUT_FIELDS,
- PromptStrategy.COT_JOIN_IMAGE: COT_JOIN_IMAGE_EXAMPLE_INPUT_FIELDS,
- PromptStrategy.COT_QA: COT_QA_EXAMPLE_INPUT_FIELDS,
- PromptStrategy.COT_QA_AUDIO: COT_QA_AUDIO_EXAMPLE_INPUT_FIELDS,
- PromptStrategy.COT_QA_IMAGE: COT_QA_IMAGE_EXAMPLE_INPUT_FIELDS,
- PromptStrategy.COT_MOA_PROPOSER: COT_MOA_PROPOSER_EXAMPLE_INPUT_FIELDS,
- PromptStrategy.COT_MOA_PROPOSER_IMAGE: COT_MOA_PROPOSER_IMAGE_EXAMPLE_INPUT_FIELDS,
- PromptStrategy.SPLIT_PROPOSER: SPLIT_PROPOSER_EXAMPLE_INPUT_FIELDS,
- }
-
- return prompt_strategy_to_example_input_fields.get(self.prompt_strategy)
-
- def _get_right_example_input_fields(self) -> str | None:
- """
- Returns the example right input fields for the join prompt.
+ Args:
+ input_modalities (set[Modality]): The modalities of the input fields.
+ right (bool): Whether to return the right input fields for the join prompt.

  Returns:
- str | None: The example right input fields (if applicable).
+ str: The example input fields.
  """
- prompt_strategy_to_right_example_input_fields = {
- PromptStrategy.COT_JOIN: COT_JOIN_RIGHT_EXAMPLE_INPUT_FIELDS,
- PromptStrategy.COT_JOIN_AUDIO: COT_JOIN_AUDIO_RIGHT_EXAMPLE_INPUT_FIELDS,
- PromptStrategy.COT_JOIN_IMAGE: COT_JOIN_IMAGE_RIGHT_EXAMPLE_INPUT_FIELDS,
+ input_modality_to_example_input_fields = {
+ Modality.TEXT: RIGHT_TEXT_EXAMPLE_INPUT_FIELDS if right else TEXT_EXAMPLE_INPUT_FIELDS,
+ Modality.IMAGE: RIGHT_IMAGE_EXAMPLE_INPUT_FIELDS if right else IMAGE_EXAMPLE_INPUT_FIELDS,
+ Modality.AUDIO: RIGHT_AUDIO_EXAMPLE_INPUT_FIELDS if right else AUDIO_EXAMPLE_INPUT_FIELDS,
  }

- return prompt_strategy_to_right_example_input_fields.get(self.prompt_strategy)
+ example_input_fields = ""
+ for input_modality in input_modalities:
+ example_input_fields += input_modality_to_example_input_fields[input_modality].rstrip()
+ example_input_fields = example_input_fields.lstrip() + "\n"
+
+ return example_input_fields

- def _get_example_output_fields(self) -> str | None:
+ def _get_example_output_fields(self, input_modalities: set[Modality]) -> str:
  """
  Returns the example output fields for the prompt.

  Returns:
- str | None: The example output fields (if applicable).
+ str: The example output fields.
  """
- prompt_strategy_to_example_output_fields = {
- PromptStrategy.COT_QA: COT_QA_EXAMPLE_OUTPUT_FIELDS,
- PromptStrategy.COT_QA_AUDIO: COT_QA_AUDIO_EXAMPLE_OUTPUT_FIELDS,
- PromptStrategy.COT_QA_IMAGE: COT_QA_IMAGE_EXAMPLE_OUTPUT_FIELDS,
- PromptStrategy.COT_MOA_PROPOSER: COT_MOA_PROPOSER_EXAMPLE_OUTPUT_FIELDS,
- PromptStrategy.COT_MOA_PROPOSER_IMAGE: COT_MOA_PROPOSER_IMAGE_EXAMPLE_OUTPUT_FIELDS,
- PromptStrategy.SPLIT_PROPOSER: SPLIT_PROPOSER_EXAMPLE_OUTPUT_FIELDS,
+ input_modality_to_example_output_fields = {
+ Modality.TEXT: TEXT_EXAMPLE_OUTPUT_FIELDS,
+ Modality.IMAGE: IMAGE_EXAMPLE_OUTPUT_FIELDS,
+ Modality.AUDIO: AUDIO_EXAMPLE_OUTPUT_FIELDS,
  }

- return prompt_strategy_to_example_output_fields.get(self.prompt_strategy)
+ example_output_fields = ""
+ for input_modality in input_modalities:
+ example_output_fields += input_modality_to_example_output_fields[input_modality].rstrip()
+ example_output_fields = example_output_fields.lstrip() + "\n"

- def _get_example_context(self) -> str | None:
+ return example_output_fields
+
+ def _get_example_context(self, input_modalities: set[Modality], right: bool = False) -> str:
  """
  Returns the example context for the prompt.

  Returns:
- str | None: The example context (if applicable).
+ str: The example context.
  """
- prompt_strategy_to_example_context = {
- PromptStrategy.COT_BOOL: COT_BOOL_EXAMPLE_CONTEXT,
- PromptStrategy.COT_BOOL_AUDIO: COT_BOOL_AUDIO_EXAMPLE_CONTEXT,
- PromptStrategy.COT_BOOL_IMAGE: COT_BOOL_IMAGE_EXAMPLE_CONTEXT,
- PromptStrategy.COT_JOIN: COT_JOIN_EXAMPLE_CONTEXT,
- PromptStrategy.COT_JOIN_AUDIO: COT_JOIN_AUDIO_EXAMPLE_CONTEXT,
- PromptStrategy.COT_JOIN_IMAGE: COT_JOIN_IMAGE_EXAMPLE_CONTEXT,
- PromptStrategy.COT_QA: COT_QA_EXAMPLE_CONTEXT,
- PromptStrategy.COT_QA_AUDIO: COT_QA_AUDIO_EXAMPLE_CONTEXT,
- PromptStrategy.COT_QA_IMAGE: COT_QA_IMAGE_EXAMPLE_CONTEXT,
- PromptStrategy.COT_MOA_PROPOSER: COT_MOA_PROPOSER_EXAMPLE_CONTEXT,
- PromptStrategy.COT_MOA_PROPOSER_IMAGE: COT_MOA_PROPOSER_IMAGE_EXAMPLE_CONTEXT,
- PromptStrategy.SPLIT_PROPOSER: SPLIT_PROPOSER_EXAMPLE_CONTEXT,
+ input_modality_to_example_context = {
+ Modality.TEXT: RIGHT_TEXT_EXAMPLE_CONTEXT if right else TEXT_EXAMPLE_CONTEXT,
+ Modality.IMAGE: RIGHT_IMAGE_EXAMPLE_CONTEXT if right else IMAGE_EXAMPLE_CONTEXT,
+ Modality.AUDIO: RIGHT_AUDIO_EXAMPLE_CONTEXT if right else AUDIO_EXAMPLE_CONTEXT,
  }

- return prompt_strategy_to_example_context.get(self.prompt_strategy)
-
- def _get_right_example_context(self) -> str | None:
- """
- Returns the right example context for the join prompt.
-
- Returns:
- str | None: The right example context (if applicable).
- """
- prompt_strategy_to_right_example_context = {
- PromptStrategy.COT_JOIN: COT_JOIN_RIGHT_EXAMPLE_CONTEXT,
- PromptStrategy.COT_JOIN_AUDIO: COT_JOIN_AUDIO_RIGHT_EXAMPLE_CONTEXT,
- PromptStrategy.COT_JOIN_IMAGE: COT_JOIN_IMAGE_RIGHT_EXAMPLE_CONTEXT,
- }
+ example_context = ""
+ for input_modality in input_modalities:
+ example_context += input_modality_to_example_context[input_modality].rstrip() + ","
+ example_context = example_context[:-1] + "\n"

- return prompt_strategy_to_right_example_context.get(self.prompt_strategy)
+ return example_context

- def _get_image_disclaimer(self) -> str:
+ def _get_image_disclaimer(self, input_modalities: set[Modality], right: bool = False) -> str:
  """
  Returns the image disclaimer for the prompt. The disclaimer must be an empty string
- for text prompts.
+ for non-image prompts.

  Returns:
  str: The image disclaimer. If this is a text prompt then it is an empty string.
  """
- prompt_strategy_to_image_disclaimer = {
- PromptStrategy.COT_BOOL_IMAGE: COT_BOOL_IMAGE_DISCLAIMER,
- PromptStrategy.COT_JOIN_IMAGE: COT_JOIN_IMAGE_DISCLAIMER,
- PromptStrategy.COT_QA_IMAGE: COT_QA_IMAGE_DISCLAIMER,
- PromptStrategy.COT_MOA_PROPOSER_IMAGE: COT_MOA_PROPOSER_IMAGE_DISCLAIMER,
- }
-
- return prompt_strategy_to_image_disclaimer.get(self.prompt_strategy, "")
+ image_disclaimer = RIGHT_IMAGE_DISCLAIMER if right else IMAGE_DISCLAIMER
+ return image_disclaimer if Modality.IMAGE in input_modalities else ""

- def _get_audio_disclaimer(self) -> str:
+ def _get_audio_disclaimer(self, input_modalities: set[Modality], right: bool = False) -> str:
  """
  Returns the audio disclaimer for the prompt. The disclaimer must be an empty string
- for text prompts.
+ for non-audio prompts.

  Returns:
  str: The audio disclaimer. If this is a text prompt then it is an empty string.
  """
- prompt_strategy_to_audio_disclaimer = {
- PromptStrategy.COT_BOOL_AUDIO: COT_BOOL_AUDIO_DISCLAIMER,
- PromptStrategy.COT_JOIN_AUDIO: COT_JOIN_AUDIO_DISCLAIMER,
- PromptStrategy.COT_QA_AUDIO: COT_QA_AUDIO_DISCLAIMER,
- }
-
- return prompt_strategy_to_audio_disclaimer.get(self.prompt_strategy, "")
-
- def _get_right_image_disclaimer(self) -> str:
- """
- Returns the right image disclaimer for the prompt. The disclaimer must be an empty string
- for text prompts.
-
- Returns:
- str: The right image disclaimer. If this is a text prompt then it is an empty string.
- """
- prompt_strategy_to_image_disclaimer = {
- PromptStrategy.COT_JOIN_IMAGE: COT_JOIN_RIGHT_IMAGE_DISCLAIMER,
- }
-
- return prompt_strategy_to_image_disclaimer.get(self.prompt_strategy, "")
-
- def _get_right_audio_disclaimer(self) -> str:
- """
- Returns the right audio disclaimer for the prompt. The disclaimer must be an empty string
- for text prompts.
-
- Returns:
- str: The right audio disclaimer. If this is a text prompt then it is an empty string.
- """
- prompt_strategy_to_audio_disclaimer = {
- PromptStrategy.COT_JOIN_AUDIO: COT_JOIN_RIGHT_AUDIO_DISCLAIMER,
- }
+ audio_disclaimer = RIGHT_AUDIO_DISCLAIMER if right else AUDIO_DISCLAIMER
+ return audio_disclaimer if Modality.AUDIO in input_modalities else ""

- return prompt_strategy_to_audio_disclaimer.get(self.prompt_strategy, "")
-
- def _get_example_filter_condition(self) -> str | None:
+ def _get_example_reasoning(self, input_modalities: set[Modality]) -> str:
  """
- Returns the example filter condition for the prompt.
+ Returns the example reasoning for the prompt.

  Returns:
- str | None: The example filter condition (if applicable).
- """
- prompt_strategy_to_example_filter_condition = {
- PromptStrategy.COT_BOOL: COT_BOOL_EXAMPLE_FILTER_CONDITION,
- PromptStrategy.COT_BOOL_AUDIO: COT_BOOL_AUDIO_EXAMPLE_FILTER_CONDITION,
- PromptStrategy.COT_BOOL_IMAGE: COT_BOOL_IMAGE_EXAMPLE_FILTER_CONDITION,
- }
-
- return prompt_strategy_to_example_filter_condition.get(self.prompt_strategy)
-
- def _get_example_join_condition(self) -> str | None:
+ str: The example reasoning.
  """
- Returns the example join condition for the prompt.
+ if self.prompt_strategy.is_filter_prompt():
+ return FILTER_EXAMPLE_REASONING
+ elif self.prompt_strategy.is_join_prompt():
+ return JOIN_EXAMPLE_REASONING

- Returns:
- str | None: The example join condition (if applicable).
- """
- prompt_strategy_to_example_join_condition = {
- PromptStrategy.COT_JOIN: COT_JOIN_EXAMPLE_JOIN_CONDITION,
- PromptStrategy.COT_JOIN_AUDIO: COT_JOIN_AUDIO_EXAMPLE_JOIN_CONDITION,
- PromptStrategy.COT_JOIN_IMAGE: COT_JOIN_IMAGE_EXAMPLE_JOIN_CONDITION,
+ input_modality_to_example_reasoning = {
+ Modality.TEXT: TEXT_EXAMPLE_REASONING,
+ Modality.IMAGE: IMAGE_EXAMPLE_REASONING,
+ Modality.AUDIO: AUDIO_EXAMPLE_REASONING,
  }

- return prompt_strategy_to_example_join_condition.get(self.prompt_strategy)
-
- def _get_example_reasoning(self) -> str | None:
- """
- Returns the example reasoning for the prompt.
-
- Returns:
- str | None: The example reasoning (if applicable).
- """
- prompt_strategy_to_example_reasoning = {
- PromptStrategy.COT_BOOL: COT_BOOL_EXAMPLE_REASONING,
- PromptStrategy.COT_BOOL_AUDIO: COT_BOOL_AUDIO_EXAMPLE_REASONING,
- PromptStrategy.COT_BOOL_IMAGE: COT_BOOL_IMAGE_EXAMPLE_REASONING,
- PromptStrategy.COT_JOIN: COT_JOIN_EXAMPLE_REASONING,
- PromptStrategy.COT_JOIN_AUDIO: COT_JOIN_AUDIO_EXAMPLE_REASONING,
- PromptStrategy.COT_JOIN_IMAGE: COT_JOIN_IMAGE_EXAMPLE_REASONING,
- PromptStrategy.COT_QA: COT_QA_EXAMPLE_REASONING,
- PromptStrategy.COT_QA_AUDIO: COT_QA_AUDIO_EXAMPLE_REASONING,
- PromptStrategy.COT_QA_IMAGE: COT_QA_IMAGE_EXAMPLE_REASONING,
- }
+ example_reasoning = ""
+ for input_modality in input_modalities:
+ example_reasoning += input_modality_to_example_reasoning[input_modality] + " "
+ example_reasoning = example_reasoning.rstrip()

- return prompt_strategy_to_example_reasoning.get(self.prompt_strategy)
+ return example_reasoning

- def _get_example_answer(self) -> str | None:
+ def _get_example_answer(self, input_modalities: set[Modality]) -> str:
  """
  Returns the example answer for the prompt.

  Returns:
- str | None: The example answer (if applicable).
+ str: The example answer.
  """
- prompt_strategy_to_example_answer = {
- PromptStrategy.COT_QA: COT_QA_EXAMPLE_ANSWER,
- PromptStrategy.COT_QA_AUDIO: COT_QA_AUDIO_EXAMPLE_ANSWER,
- PromptStrategy.COT_QA_IMAGE: COT_QA_IMAGE_EXAMPLE_ANSWER,
- PromptStrategy.COT_MOA_PROPOSER: COT_MOA_PROPOSER_EXAMPLE_ANSWER,
- PromptStrategy.COT_MOA_PROPOSER_IMAGE: COT_MOA_PROPOSER_IMAGE_EXAMPLE_ANSWER,
- PromptStrategy.SPLIT_PROPOSER: SPLIT_PROPOSER_EXAMPLE_ANSWER,
+ use_sentence_answers = self.prompt_strategy.is_split_proposer_prompt() or self.prompt_strategy.is_moa_proposer_prompt()
+ input_modality_to_example_answer = {
+ Modality.TEXT: TEXT_SENTENCE_EXAMPLE_ANSWER if use_sentence_answers else TEXT_EXAMPLE_ANSWER,
+ Modality.IMAGE: IMAGE_SENTENCE_EXAMPLE_ANSWER if use_sentence_answers else IMAGE_EXAMPLE_ANSWER,
+ Modality.AUDIO: AUDIO_SENTENCE_EXAMPLE_ANSWER if use_sentence_answers else AUDIO_EXAMPLE_ANSWER,
  }

- return prompt_strategy_to_example_answer.get(self.prompt_strategy)
+ example_answer = ""
+ for input_modality in input_modalities:
+ example_answer += input_modality_to_example_answer[input_modality].rstrip()
+ if use_sentence_answers:
+ example_answer += " "
+ example_answer = example_answer + "\n"
+
+ return example_answer

  def _get_all_format_kwargs(
- self, candidate: DataRecord, input_fields: list[str], output_fields: list[str], right_candidate: DataRecord | None, right_input_fields: list[str], **kwargs
+ self,
+ candidate: DataRecord,
+ input_fields: list[str],
+ input_modalities: set[Modality],
+ output_fields: list[str],
+ right_candidate: DataRecord | None,
+ right_input_fields: list[str],
+ right_input_modalities: set[Modality],
+ **kwargs,
  ) -> dict:
  """
  Returns a dictionary containing all the format kwargs for templating the prompts.
@@ -770,26 +703,27 @@ class PromptFactory:
  })

  # get format kwargs which depend on the prompt strategy
+ full_input_modalities = input_modalities.union(right_input_modalities)
  prompt_strategy_format_kwargs = {
  "output_format_instruction": self._get_output_format_instruction(),
- "job_instruction": self._get_job_instruction(),
+ "job_instruction": self._get_job_instruction(full_input_modalities),
  "desc_section": self._get_desc_section(),
  "critique_criteria": self._get_critique_criteria(),
  "refinement_criteria": self._get_refinement_criteria(),
  "finish_instruction": self._get_finish_instruction(),
- "example_input_fields": self._get_example_input_fields(),
- "right_example_input_fields": self._get_right_example_input_fields(),
- "example_output_fields": self._get_example_output_fields(),
- "example_context": self._get_example_context(),
- "right_example_context": self._get_right_example_context(),
- "image_disclaimer": self._get_image_disclaimer(),
- "audio_disclaimer": self._get_audio_disclaimer(),
- "right_image_disclaimer": self._get_right_image_disclaimer(),
- "right_audio_disclaimer": self._get_right_audio_disclaimer(),
- "example_filter_condition": self._get_example_filter_condition(),
- "example_join_condition": self._get_example_join_condition(),
- "example_reasoning": self._get_example_reasoning(),
- "example_answer": self._get_example_answer(),
+ "example_input_fields": self._get_example_input_fields(input_modalities),
+ "right_example_input_fields": self._get_example_input_fields(right_input_modalities, right=True),
+ "example_output_fields": self._get_example_output_fields(input_modalities),
+ "example_context": self._get_example_context(input_modalities),
+ "right_example_context": self._get_example_context(right_input_modalities, right=True),
+ "image_disclaimer": self._get_image_disclaimer(input_modalities),
+ "audio_disclaimer": self._get_audio_disclaimer(input_modalities),
+ "right_image_disclaimer": self._get_image_disclaimer(right_input_modalities, right=True),
+ "right_audio_disclaimer": self._get_audio_disclaimer(right_input_modalities, right=True),
+ "example_filter_condition": EXAMPLE_FILTER_CONDITION,
+ "example_join_condition": EXAMPLE_JOIN_CONDITION,
+ "example_reasoning": self._get_example_reasoning(input_modalities),
+ "example_answer": self._get_example_answer(input_modalities),
  }

  # return all format kwargs
@@ -937,7 +871,7 @@ class PromptFactory:
  # get any audio messages for the chat payload (will be an empty list if no audio fields exist)
  audio_messages = self._create_audio_messages(candidate, input_fields)

- # get any right image messages for the chat payload (will be an empty list if this is not a join image prompt)
+ # get any right image / audio messages for the chat payload (will be an empty list if image / audio not present)
  right_image_messages, right_audio_messages = [], []
  if self.prompt_strategy.is_join_prompt():
  assert right_candidate is not None, "Right candidate must be provided for join prompts."
@@ -951,121 +885,63 @@ class PromptFactory:
  "Original messages must be provided for critique and refinement operations."
  )

- # TODO: in the future if we support many modalities (e.g. images and audio) in the same prompt,
- # then we will need to streamline this logic to handle the many different cases
+ # combine image and audio messages
+ image_audio_messages = image_messages + audio_messages
+ right_image_audio_messages = right_image_messages + right_audio_messages
+ has_image_audio = len(image_audio_messages) > 0
+ has_right_image_audio = len(right_image_audio_messages) > 0
+
  # construct the user messages based on the prompt strategy
  user_messages = []
  if self.prompt_strategy.is_critic_prompt() or self.prompt_strategy.is_refine_prompt():
- # NOTE: if this critic / refinement prompt is processing images, those images will
- # be part of the `original_messages` and will show up in the final chat payload
+ # NOTE: if this critic / refinement prompt is processing images / audio, those images / audio
+ # will be part of the `original_messages` and will show up in the final chat payload
  base_prompt_start, base_prompt_end = base_prompt.split("<<original-prompt-placeholder>>\n")
  user_messages.append({"role": "user", "type": "text", "content": base_prompt_start.format(**kwargs)})
  user_messages.extend(original_messages)
  user_messages.append({"role": "user", "type": "text", "content": base_prompt_end.format(**kwargs)})

- # image not join
- elif self.prompt_strategy.is_image_prompt() and not self.prompt_strategy.is_join_prompt():
- base_prompt = base_prompt.replace("<<audio-placeholder>>", "")
- base_prompt_start, base_prompt_end = base_prompt.split("<<image-placeholder>>")
+ # handle joins with left and right images / audio
+ elif self.prompt_strategy.is_join_prompt() and has_image_audio and has_right_image_audio:
+ base_prompt_start, base_prompt_rest = base_prompt.split("<<image-audio-placeholder>>")
+ base_prompt_mid, base_prompt_end = base_prompt_rest.split("<<right-image-audio-placeholder>>")
  user_messages.append({"role": "user", "type": "text", "content": base_prompt_start.format(**kwargs)})
- user_messages.extend(image_messages)
+ user_messages.extend(image_audio_messages)
+ user_messages.append({"role": "user", "type": "text", "content": base_prompt_mid.format(**kwargs)})
+ user_messages.extend(right_image_audio_messages)
  user_messages.append({"role": "user", "type": "text", "content": base_prompt_end.format(**kwargs)})

- # image join
- elif self.prompt_strategy.is_image_prompt() and self.prompt_strategy.is_join_prompt():
- # for join image prompts, we may have two sets of images (one from the left candidate and one from the right candidate)
- base_prompt = base_prompt.replace("<<audio-placeholder>>", "")
- base_prompt_start, base_prompt_mid, base_prompt_end = base_prompt.split("<<image-placeholder>>")
+ # handle joins with only left images / audio
+ elif self.prompt_strategy.is_join_prompt() and has_image_audio and not has_right_image_audio:
+ base_prompt = base_prompt.replace("<<right-image-audio-placeholder>>", "")
+ base_prompt_start, base_prompt_end = base_prompt.split("<<image-audio-placeholder>>")
  user_messages.append({"role": "user", "type": "text", "content": base_prompt_start.format(**kwargs)})
- user_messages.extend(image_messages)
- user_messages.append({"role": "user", "type": "text", "content": base_prompt_mid.format(**kwargs)})
- user_messages.extend(right_image_messages)
+ user_messages.extend(image_audio_messages)
  user_messages.append({"role": "user", "type": "text", "content": base_prompt_end.format(**kwargs)})

- # audio not join
- elif self.prompt_strategy.is_audio_prompt() and not self.prompt_strategy.is_join_prompt():
- base_prompt = base_prompt.replace("<<image-placeholder>>", "")
- base_prompt_start, base_prompt_end = base_prompt.split("<<audio-placeholder>>")
+ # handle joins with only right images / audio
+ elif self.prompt_strategy.is_join_prompt() and not has_image_audio and has_right_image_audio:
+ base_prompt = base_prompt.replace("<<image-audio-placeholder>>", "")
+ base_prompt_start, base_prompt_end = base_prompt.split("<<right-image-audio-placeholder>>")
  user_messages.append({"role": "user", "type": "text", "content": base_prompt_start.format(**kwargs)})
- user_messages.extend(audio_messages)
+ user_messages.extend(right_image_audio_messages)
  user_messages.append({"role": "user", "type": "text", "content": base_prompt_end.format(**kwargs)})

- # audio join
- elif self.prompt_strategy.is_audio_prompt() and self.prompt_strategy.is_join_prompt():
- # for join image prompts, we may have two sets of images (one from the left candidate and one from the right candidate)
- base_prompt = base_prompt.replace("<<image-placeholder>>", "")
- base_prompt_start, base_prompt_mid, base_prompt_end = base_prompt.split("<<audio-placeholder>>")
+ # handle non-joins with images / audio
+ elif not self.prompt_strategy.is_join_prompt() and has_image_audio and not self.prompt_strategy.is_moa_aggregator_prompt():
+ base_prompt_start, base_prompt_end = base_prompt.split("<<image-audio-placeholder>>")
  user_messages.append({"role": "user", "type": "text", "content": base_prompt_start.format(**kwargs)})
- user_messages.extend(audio_messages)
- user_messages.append({"role": "user", "type": "text", "content": base_prompt_mid.format(**kwargs)})
- user_messages.extend(right_audio_messages)
+ user_messages.extend(image_audio_messages)
  user_messages.append({"role": "user", "type": "text", "content": base_prompt_end.format(**kwargs)})

+ # handle prompts w/no images or audio
  else:
- base_prompt = base_prompt.replace("<<image-placeholder>>", "")
- base_prompt = base_prompt.replace("<<audio-placeholder>>", "")
+ base_prompt = base_prompt.replace("<<image-audio-placeholder>>", "")
+ base_prompt = base_prompt.replace("<<right-image-audio-placeholder>>", "")
  user_messages.append({"role": "user", "type": "text", "content": base_prompt.format(**kwargs)})

  return user_messages

- def _process_custom_user_prompt(self, candidate: DataRecord, input_fields: list[str], **kwargs) -> list[dict]:
- """
- Processes a custom user prompt provided by the user.
-
- Args:
- candidate (DataRecord): The input record.
- kwargs: The keyword arguments provided by the user.
-
- Returns:
- list[dict]: The messages for the chat payload.
- """
- # get the user prompt
- user_prompt: str = kwargs["prompt"]
-
- # sanity check that we have all the inputs for the user's prompt template
- prompt_field_names = [fname for _, fname, _, _ in Formatter().parse(user_prompt) if fname]
- fields_check = all([field in input_fields for field in prompt_field_names])
- if not fields_check:
- if sorted(candidate.get_field_names()) != (input_fields):
- err_msg = (
- f"Prompt string has fields which are not in input fields.\n"
- f"Prompt fields: {prompt_field_names}\n"
- f"Computed fields: {candidate.get_field_names()}\n"
- f"Input fields: {input_fields}\n"
- f"Be careful that you are not projecting out computed fields. "
- f"If you use `depends_on` in your program, make sure it includes the fields you need."
- )
- else:
- err_msg = (
- f"Prompt string has fields which are not in input fields.\n"
- f"Prompt fields: {prompt_field_names}\n"
- f"Input fields: {input_fields}\n"
- )
- assert fields_check, err_msg
-
- # build set of format kwargs
- format_kwargs = {
- field_name: "<bytes>"
- if candidate.get_field_type(field_name).annotation in [bytes, bytes | None]
- else candidate[field_name]
- for field_name in input_fields
- }
-
- # split prompt on <<image-placeholder>> if it exists
- if "<<image-placeholder>>" in user_prompt:
- raise NotImplementedError("Image prompts are not yet supported.")
-
- prompt_sections = user_prompt.split("<<image-placeholder>>")
- messages = [{"role": "user", "type": "text", "content": prompt_sections[0].format(**format_kwargs)}]
-
- # NOTE: this currently assumes that the user can only provide a single <<image-placeholder>>
- if len(prompt_sections) > 1:
- image_messages = self._create_image_messages(candidate, input_fields)
- messages.extend(image_messages)
- messages.append({"role": "user", "type": "text", "content": prompt_sections[1].format(**format_kwargs)})
-
- return messages
-
  def create_messages(self, candidate: DataRecord, output_fields: list[str], right_candidate: DataRecord | None = None, **kwargs) -> list[dict]:
  """
  Creates the messages for the chat payload based on the prompt strategy.
@@ -1090,19 +966,15 @@ class PromptFactory:
  input_fields = self._get_input_fields(candidate, **kwargs)
  right_input_fields = [] if right_candidate is None else self._get_input_fields(right_candidate, **kwargs)

- # if the user provides a prompt, we process that prompt into messages and return them
- if "prompt" in kwargs:
- messages = []
- if "system_prompt" in kwargs:
- messages.append({"role": "system", "type": "text", "content": kwargs["system_prompt"]})
- messages.extend(self._process_custom_user_prompt(candidate, input_fields, **kwargs))
- return messages
+ # use input fields to determine the left / right input modalities
+ input_modalities = self._get_input_modalities(candidate, input_fields)
+ right_input_modalities = set() if right_candidate is None else self._get_input_modalities(right_candidate, right_input_fields)

  # initialize messages
  messages = []

  # compute the full dictionary of format kwargs and add to kwargs
- format_kwargs = self._get_all_format_kwargs(candidate, input_fields, output_fields, right_candidate, right_input_fields, **kwargs)
+ format_kwargs = self._get_all_format_kwargs(candidate, input_fields, input_modalities, output_fields, right_candidate, right_input_fields, right_input_modalities, **kwargs)
  kwargs = {**kwargs, **format_kwargs}

  # generate system message (if applicable)
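A small behavioral change worth noting from the mixture-of-agents and split-merger hunks above: each upstream model response is now stripped of trailing whitespace and separated by a blank line before being templated into the aggregator prompt. A self-contained illustration of the new formatting, with invented sample responses:

model_responses = ""
for idx, model_response in enumerate(["answer A \n", "answer B\n\n"]):
    # strip trailing whitespace from each response and separate responses with a blank line
    model_responses += f"MODEL RESPONSE {idx + 1}: {model_response.rstrip()}\n\n"
model_responses = model_responses.rstrip()
# model_responses == "MODEL RESPONSE 1: answer A\n\nMODEL RESPONSE 2: answer B"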