palimpzest 0.7.21__py3-none-any.whl → 0.8.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- palimpzest/__init__.py +37 -6
- palimpzest/agents/__init__.py +0 -0
- palimpzest/agents/compute_agents.py +0 -0
- palimpzest/agents/search_agents.py +637 -0
- palimpzest/constants.py +343 -209
- palimpzest/core/data/context.py +393 -0
- palimpzest/core/data/context_manager.py +163 -0
- palimpzest/core/data/dataset.py +639 -0
- palimpzest/core/data/{datareaders.py → iter_dataset.py} +202 -126
- palimpzest/core/elements/groupbysig.py +16 -13
- palimpzest/core/elements/records.py +166 -75
- palimpzest/core/lib/schemas.py +152 -390
- palimpzest/core/{data/dataclasses.py → models.py} +306 -170
- palimpzest/policy.py +2 -27
- palimpzest/prompts/__init__.py +35 -5
- palimpzest/prompts/agent_prompts.py +357 -0
- palimpzest/prompts/context_search.py +9 -0
- palimpzest/prompts/convert_prompts.py +62 -6
- palimpzest/prompts/filter_prompts.py +51 -6
- palimpzest/prompts/join_prompts.py +163 -0
- palimpzest/prompts/moa_proposer_convert_prompts.py +6 -6
- palimpzest/prompts/prompt_factory.py +375 -47
- palimpzest/prompts/split_proposer_prompts.py +1 -1
- palimpzest/prompts/util_phrases.py +5 -0
- palimpzest/prompts/validator.py +239 -0
- palimpzest/query/execution/all_sample_execution_strategy.py +134 -76
- palimpzest/query/execution/execution_strategy.py +210 -317
- palimpzest/query/execution/execution_strategy_type.py +5 -7
- palimpzest/query/execution/mab_execution_strategy.py +249 -136
- palimpzest/query/execution/parallel_execution_strategy.py +153 -244
- palimpzest/query/execution/single_threaded_execution_strategy.py +107 -64
- palimpzest/query/generators/generators.py +160 -331
- palimpzest/query/operators/__init__.py +15 -5
- palimpzest/query/operators/aggregate.py +50 -33
- palimpzest/query/operators/compute.py +201 -0
- palimpzest/query/operators/convert.py +33 -19
- palimpzest/query/operators/critique_and_refine_convert.py +7 -5
- palimpzest/query/operators/distinct.py +62 -0
- palimpzest/query/operators/filter.py +26 -16
- palimpzest/query/operators/join.py +403 -0
- palimpzest/query/operators/limit.py +3 -3
- palimpzest/query/operators/logical.py +205 -77
- palimpzest/query/operators/mixture_of_agents_convert.py +10 -8
- palimpzest/query/operators/physical.py +27 -21
- palimpzest/query/operators/project.py +3 -3
- palimpzest/query/operators/rag_convert.py +7 -7
- palimpzest/query/operators/retrieve.py +9 -9
- palimpzest/query/operators/scan.py +81 -42
- palimpzest/query/operators/search.py +524 -0
- palimpzest/query/operators/split_convert.py +10 -8
- palimpzest/query/optimizer/__init__.py +7 -9
- palimpzest/query/optimizer/cost_model.py +108 -441
- palimpzest/query/optimizer/optimizer.py +123 -181
- palimpzest/query/optimizer/optimizer_strategy.py +66 -61
- palimpzest/query/optimizer/plan.py +352 -67
- palimpzest/query/optimizer/primitives.py +43 -19
- palimpzest/query/optimizer/rules.py +484 -646
- palimpzest/query/optimizer/tasks.py +127 -58
- palimpzest/query/processor/config.py +42 -76
- palimpzest/query/processor/query_processor.py +73 -18
- palimpzest/query/processor/query_processor_factory.py +46 -38
- palimpzest/schemabuilder/schema_builder.py +15 -28
- palimpzest/utils/model_helpers.py +32 -77
- palimpzest/utils/progress.py +114 -102
- palimpzest/validator/__init__.py +0 -0
- palimpzest/validator/validator.py +306 -0
- {palimpzest-0.7.21.dist-info → palimpzest-0.8.1.dist-info}/METADATA +6 -1
- palimpzest-0.8.1.dist-info/RECORD +95 -0
- palimpzest/core/lib/fields.py +0 -141
- palimpzest/prompts/code_synthesis_prompts.py +0 -28
- palimpzest/query/execution/random_sampling_execution_strategy.py +0 -240
- palimpzest/query/generators/api_client_factory.py +0 -30
- palimpzest/query/operators/code_synthesis_convert.py +0 -488
- palimpzest/query/operators/map.py +0 -130
- palimpzest/query/processor/nosentinel_processor.py +0 -33
- palimpzest/query/processor/processing_strategy_type.py +0 -28
- palimpzest/query/processor/sentinel_processor.py +0 -88
- palimpzest/query/processor/streaming_processor.py +0 -149
- palimpzest/sets.py +0 -405
- palimpzest/utils/datareader_helpers.py +0 -61
- palimpzest/utils/demo_helpers.py +0 -75
- palimpzest/utils/field_helpers.py +0 -69
- palimpzest/utils/generation_helpers.py +0 -69
- palimpzest/utils/sandbox.py +0 -183
- palimpzest-0.7.21.dist-info/RECORD +0 -95
- /palimpzest/core/{elements/index.py → data/index_dataset.py} +0 -0
- {palimpzest-0.7.21.dist-info → palimpzest-0.8.1.dist-info}/WHEEL +0 -0
- {palimpzest-0.7.21.dist-info → palimpzest-0.8.1.dist-info}/licenses/LICENSE +0 -0
- {palimpzest-0.7.21.dist-info → palimpzest-0.8.1.dist-info}/top_level.txt +0 -0
|
@@ -4,17 +4,25 @@ import base64
|
|
|
4
4
|
import json
|
|
5
5
|
from string import Formatter
|
|
6
6
|
|
|
7
|
+
from pydantic import BaseModel
|
|
8
|
+
|
|
7
9
|
from palimpzest.constants import (
|
|
8
|
-
|
|
10
|
+
LLAMA_CONTEXT_TOKENS_LIMIT,
|
|
9
11
|
TOKENS_PER_CHARACTER,
|
|
10
12
|
Cardinality,
|
|
11
13
|
Model,
|
|
12
14
|
PromptStrategy,
|
|
13
15
|
)
|
|
14
16
|
from palimpzest.core.elements.records import DataRecord
|
|
15
|
-
from palimpzest.core.lib.
|
|
16
|
-
from palimpzest.core.lib.schemas import Schema
|
|
17
|
+
from palimpzest.core.lib.schemas import AudioBase64, AudioFilepath, ImageBase64, ImageFilepath, ImageURL
|
|
17
18
|
from palimpzest.prompts.convert_prompts import (
|
|
19
|
+
COT_QA_AUDIO_DISCLAIMER,
|
|
20
|
+
COT_QA_AUDIO_EXAMPLE_ANSWER,
|
|
21
|
+
COT_QA_AUDIO_EXAMPLE_CONTEXT,
|
|
22
|
+
COT_QA_AUDIO_EXAMPLE_INPUT_FIELDS,
|
|
23
|
+
COT_QA_AUDIO_EXAMPLE_OUTPUT_FIELDS,
|
|
24
|
+
COT_QA_AUDIO_EXAMPLE_REASONING,
|
|
25
|
+
COT_QA_AUDIO_JOB_INSTRUCTION,
|
|
18
26
|
COT_QA_BASE_SYSTEM_PROMPT,
|
|
19
27
|
COT_QA_BASE_USER_PROMPT,
|
|
20
28
|
COT_QA_EXAMPLE_ANSWER,
|
|
@@ -30,6 +38,8 @@ from palimpzest.prompts.convert_prompts import (
|
|
|
30
38
|
COT_QA_IMAGE_EXAMPLE_REASONING,
|
|
31
39
|
COT_QA_IMAGE_JOB_INSTRUCTION,
|
|
32
40
|
COT_QA_JOB_INSTRUCTION,
|
|
41
|
+
COT_QA_NO_REASONING_BASE_SYSTEM_PROMPT,
|
|
42
|
+
COT_QA_NO_REASONING_BASE_USER_PROMPT,
|
|
33
43
|
)
|
|
34
44
|
from palimpzest.prompts.critique_and_refine_convert_prompts import (
|
|
35
45
|
BASE_CRITIQUE_PROMPT,
|
|
@@ -42,6 +52,12 @@ from palimpzest.prompts.critique_and_refine_convert_prompts import (
|
|
|
42
52
|
COT_QA_REFINEMENT_FINISH_INSTRUCTION,
|
|
43
53
|
)
|
|
44
54
|
from palimpzest.prompts.filter_prompts import (
|
|
55
|
+
COT_BOOL_AUDIO_DISCLAIMER,
|
|
56
|
+
COT_BOOL_AUDIO_EXAMPLE_CONTEXT,
|
|
57
|
+
COT_BOOL_AUDIO_EXAMPLE_FILTER_CONDITION,
|
|
58
|
+
COT_BOOL_AUDIO_EXAMPLE_INPUT_FIELDS,
|
|
59
|
+
COT_BOOL_AUDIO_EXAMPLE_REASONING,
|
|
60
|
+
COT_BOOL_AUDIO_JOB_INSTRUCTION,
|
|
45
61
|
COT_BOOL_BASE_SYSTEM_PROMPT,
|
|
46
62
|
COT_BOOL_BASE_USER_PROMPT,
|
|
47
63
|
COT_BOOL_EXAMPLE_CONTEXT,
|
|
@@ -55,6 +71,39 @@ from palimpzest.prompts.filter_prompts import (
|
|
|
55
71
|
COT_BOOL_IMAGE_EXAMPLE_REASONING,
|
|
56
72
|
COT_BOOL_IMAGE_JOB_INSTRUCTION,
|
|
57
73
|
COT_BOOL_JOB_INSTRUCTION,
|
|
74
|
+
COT_BOOL_NO_REASONING_BASE_SYSTEM_PROMPT,
|
|
75
|
+
COT_BOOL_NO_REASONING_BASE_USER_PROMPT,
|
|
76
|
+
)
|
|
77
|
+
from palimpzest.prompts.join_prompts import (
|
|
78
|
+
COT_JOIN_AUDIO_DISCLAIMER,
|
|
79
|
+
COT_JOIN_AUDIO_EXAMPLE_CONTEXT,
|
|
80
|
+
COT_JOIN_AUDIO_EXAMPLE_INPUT_FIELDS,
|
|
81
|
+
COT_JOIN_AUDIO_EXAMPLE_JOIN_CONDITION,
|
|
82
|
+
COT_JOIN_AUDIO_EXAMPLE_REASONING,
|
|
83
|
+
COT_JOIN_AUDIO_JOB_INSTRUCTION,
|
|
84
|
+
COT_JOIN_AUDIO_RIGHT_EXAMPLE_CONTEXT,
|
|
85
|
+
COT_JOIN_AUDIO_RIGHT_EXAMPLE_INPUT_FIELDS,
|
|
86
|
+
COT_JOIN_BASE_SYSTEM_PROMPT,
|
|
87
|
+
COT_JOIN_BASE_USER_PROMPT,
|
|
88
|
+
COT_JOIN_EXAMPLE_CONTEXT,
|
|
89
|
+
COT_JOIN_EXAMPLE_INPUT_FIELDS,
|
|
90
|
+
COT_JOIN_EXAMPLE_JOIN_CONDITION,
|
|
91
|
+
COT_JOIN_EXAMPLE_REASONING,
|
|
92
|
+
COT_JOIN_IMAGE_DISCLAIMER,
|
|
93
|
+
COT_JOIN_IMAGE_EXAMPLE_CONTEXT,
|
|
94
|
+
COT_JOIN_IMAGE_EXAMPLE_INPUT_FIELDS,
|
|
95
|
+
COT_JOIN_IMAGE_EXAMPLE_JOIN_CONDITION,
|
|
96
|
+
COT_JOIN_IMAGE_EXAMPLE_REASONING,
|
|
97
|
+
COT_JOIN_IMAGE_JOB_INSTRUCTION,
|
|
98
|
+
COT_JOIN_IMAGE_RIGHT_EXAMPLE_CONTEXT,
|
|
99
|
+
COT_JOIN_IMAGE_RIGHT_EXAMPLE_INPUT_FIELDS,
|
|
100
|
+
COT_JOIN_JOB_INSTRUCTION,
|
|
101
|
+
COT_JOIN_NO_REASONING_BASE_SYSTEM_PROMPT,
|
|
102
|
+
COT_JOIN_NO_REASONING_BASE_USER_PROMPT,
|
|
103
|
+
COT_JOIN_RIGHT_AUDIO_DISCLAIMER,
|
|
104
|
+
COT_JOIN_RIGHT_EXAMPLE_CONTEXT,
|
|
105
|
+
COT_JOIN_RIGHT_EXAMPLE_INPUT_FIELDS,
|
|
106
|
+
COT_JOIN_RIGHT_IMAGE_DISCLAIMER,
|
|
58
107
|
)
|
|
59
108
|
from palimpzest.prompts.moa_aggregator_convert_prompts import (
|
|
60
109
|
COT_MOA_AGG_BASE_SYSTEM_PROMPT,
|
|
@@ -89,6 +138,7 @@ from palimpzest.prompts.split_proposer_prompts import (
|
|
|
89
138
|
SPLIT_PROPOSER_JOB_INSTRUCTION,
|
|
90
139
|
)
|
|
91
140
|
from palimpzest.prompts.util_phrases import (
|
|
141
|
+
DESC_SECTION,
|
|
92
142
|
ONE_TO_MANY_OUTPUT_FORMAT_INSTRUCTION,
|
|
93
143
|
ONE_TO_ONE_OUTPUT_FORMAT_INSTRUCTION,
|
|
94
144
|
)
|
|
@@ -99,11 +149,25 @@ class PromptFactory:
|
|
|
99
149
|
|
|
100
150
|
BASE_SYSTEM_PROMPT_MAP = {
|
|
101
151
|
PromptStrategy.COT_BOOL: COT_BOOL_BASE_SYSTEM_PROMPT,
|
|
152
|
+
PromptStrategy.COT_BOOL_NO_REASONING: COT_BOOL_NO_REASONING_BASE_SYSTEM_PROMPT,
|
|
153
|
+
PromptStrategy.COT_BOOL_AUDIO: COT_BOOL_BASE_SYSTEM_PROMPT,
|
|
154
|
+
PromptStrategy.COT_BOOL_AUDIO_NO_REASONING: COT_BOOL_NO_REASONING_BASE_SYSTEM_PROMPT,
|
|
102
155
|
PromptStrategy.COT_BOOL_IMAGE: COT_BOOL_BASE_SYSTEM_PROMPT,
|
|
156
|
+
PromptStrategy.COT_BOOL_IMAGE_NO_REASONING: COT_BOOL_NO_REASONING_BASE_SYSTEM_PROMPT,
|
|
157
|
+
PromptStrategy.COT_JOIN: COT_JOIN_BASE_SYSTEM_PROMPT,
|
|
158
|
+
PromptStrategy.COT_JOIN_NO_REASONING: COT_JOIN_NO_REASONING_BASE_SYSTEM_PROMPT,
|
|
159
|
+
PromptStrategy.COT_JOIN_AUDIO: COT_JOIN_BASE_SYSTEM_PROMPT,
|
|
160
|
+
PromptStrategy.COT_JOIN_AUDIO_NO_REASONING: COT_JOIN_NO_REASONING_BASE_SYSTEM_PROMPT,
|
|
161
|
+
PromptStrategy.COT_JOIN_IMAGE: COT_JOIN_BASE_SYSTEM_PROMPT,
|
|
162
|
+
PromptStrategy.COT_JOIN_IMAGE_NO_REASONING: COT_JOIN_NO_REASONING_BASE_SYSTEM_PROMPT,
|
|
103
163
|
PromptStrategy.COT_QA: COT_QA_BASE_SYSTEM_PROMPT,
|
|
164
|
+
PromptStrategy.COT_QA_NO_REASONING: COT_QA_NO_REASONING_BASE_SYSTEM_PROMPT,
|
|
165
|
+
PromptStrategy.COT_QA_AUDIO: COT_QA_BASE_SYSTEM_PROMPT,
|
|
166
|
+
PromptStrategy.COT_QA_AUDIO_NO_REASONING: COT_QA_NO_REASONING_BASE_SYSTEM_PROMPT,
|
|
104
167
|
PromptStrategy.COT_QA_CRITIC: None,
|
|
105
168
|
PromptStrategy.COT_QA_REFINE: None,
|
|
106
169
|
PromptStrategy.COT_QA_IMAGE: COT_QA_BASE_SYSTEM_PROMPT,
|
|
170
|
+
PromptStrategy.COT_QA_IMAGE_NO_REASONING: COT_QA_NO_REASONING_BASE_SYSTEM_PROMPT,
|
|
107
171
|
PromptStrategy.COT_QA_IMAGE_CRITIC: None,
|
|
108
172
|
PromptStrategy.COT_QA_IMAGE_REFINE: None,
|
|
109
173
|
PromptStrategy.COT_MOA_PROPOSER: COT_MOA_PROPOSER_BASE_SYSTEM_PROMPT,
|
|
@@ -114,11 +178,25 @@ class PromptFactory:
|
|
|
114
178
|
}
|
|
115
179
|
BASE_USER_PROMPT_MAP = {
|
|
116
180
|
PromptStrategy.COT_BOOL: COT_BOOL_BASE_USER_PROMPT,
|
|
181
|
+
PromptStrategy.COT_BOOL_NO_REASONING: COT_BOOL_NO_REASONING_BASE_USER_PROMPT,
|
|
182
|
+
PromptStrategy.COT_BOOL_AUDIO: COT_BOOL_BASE_USER_PROMPT,
|
|
183
|
+
PromptStrategy.COT_BOOL_AUDIO_NO_REASONING: COT_BOOL_NO_REASONING_BASE_USER_PROMPT,
|
|
117
184
|
PromptStrategy.COT_BOOL_IMAGE: COT_BOOL_BASE_USER_PROMPT,
|
|
185
|
+
PromptStrategy.COT_BOOL_IMAGE_NO_REASONING: COT_BOOL_NO_REASONING_BASE_USER_PROMPT,
|
|
186
|
+
PromptStrategy.COT_JOIN: COT_JOIN_BASE_USER_PROMPT,
|
|
187
|
+
PromptStrategy.COT_JOIN_NO_REASONING: COT_JOIN_NO_REASONING_BASE_USER_PROMPT,
|
|
188
|
+
PromptStrategy.COT_JOIN_AUDIO: COT_JOIN_BASE_USER_PROMPT,
|
|
189
|
+
PromptStrategy.COT_JOIN_AUDIO_NO_REASONING: COT_JOIN_NO_REASONING_BASE_USER_PROMPT,
|
|
190
|
+
PromptStrategy.COT_JOIN_IMAGE: COT_JOIN_BASE_USER_PROMPT,
|
|
191
|
+
PromptStrategy.COT_JOIN_IMAGE_NO_REASONING: COT_JOIN_NO_REASONING_BASE_USER_PROMPT,
|
|
118
192
|
PromptStrategy.COT_QA: COT_QA_BASE_USER_PROMPT,
|
|
193
|
+
PromptStrategy.COT_QA_NO_REASONING: COT_QA_NO_REASONING_BASE_USER_PROMPT,
|
|
194
|
+
PromptStrategy.COT_QA_AUDIO: COT_QA_BASE_USER_PROMPT,
|
|
195
|
+
PromptStrategy.COT_QA_AUDIO_NO_REASONING: COT_QA_NO_REASONING_BASE_USER_PROMPT,
|
|
119
196
|
PromptStrategy.COT_QA_CRITIC: BASE_CRITIQUE_PROMPT,
|
|
120
197
|
PromptStrategy.COT_QA_REFINE: BASE_REFINEMENT_PROMPT,
|
|
121
198
|
PromptStrategy.COT_QA_IMAGE: COT_QA_BASE_USER_PROMPT,
|
|
199
|
+
PromptStrategy.COT_QA_IMAGE_NO_REASONING: COT_QA_NO_REASONING_BASE_USER_PROMPT,
|
|
122
200
|
PromptStrategy.COT_QA_IMAGE_CRITIC: BASE_CRITIQUE_PROMPT,
|
|
123
201
|
PromptStrategy.COT_QA_IMAGE_REFINE: BASE_REFINEMENT_PROMPT,
|
|
124
202
|
PromptStrategy.COT_MOA_PROPOSER: COT_MOA_PROPOSER_BASE_USER_PROMPT,
|
|
@@ -128,10 +206,11 @@ class PromptFactory:
|
|
|
128
206
|
PromptStrategy.SPLIT_MERGER: COT_SPLIT_MERGER_BASE_USER_PROMPT,
|
|
129
207
|
}
|
|
130
208
|
|
|
131
|
-
def __init__(self, prompt_strategy: PromptStrategy, model: Model, cardinality: Cardinality) -> None:
|
|
209
|
+
def __init__(self, prompt_strategy: PromptStrategy, model: Model, cardinality: Cardinality, desc: str | None = None) -> None:
|
|
132
210
|
self.prompt_strategy = prompt_strategy
|
|
133
211
|
self.model = model
|
|
134
212
|
self.cardinality = cardinality
|
|
213
|
+
self.desc = desc
|
|
135
214
|
|
|
136
215
|
def _get_context(self, candidate: DataRecord, input_fields: list[str]) -> str:
|
|
137
216
|
"""
|
|
@@ -144,8 +223,9 @@ class PromptFactory:
|
|
|
144
223
|
Returns:
|
|
145
224
|
str: The context.
|
|
146
225
|
"""
|
|
226
|
+
# TODO: remove mask_filepaths=True after SemBench evaluation
|
|
147
227
|
# get context from input record (project_cols will be None if not provided in kwargs)
|
|
148
|
-
context: dict = candidate.to_dict(include_bytes=False, project_cols=input_fields)
|
|
228
|
+
context: dict = candidate.to_dict(include_bytes=False, project_cols=input_fields, mask_filepaths=True)
|
|
149
229
|
|
|
150
230
|
# TODO: MOVE THIS LOGIC INTO A CHUNKING / CONTEXT MANAGEMENT CLASS
|
|
151
231
|
# - this class should be able to:
|
|
@@ -155,12 +235,12 @@ class PromptFactory:
|
|
|
155
235
|
# TODO: this does not work for image prompts
|
|
156
236
|
# TODO: this ignores the size of the `orignal_messages` in critique and refine prompts
|
|
157
237
|
# cut down on context based on window length
|
|
158
|
-
if self.model.is_llama_model()
|
|
238
|
+
if self.model.is_llama_model():
|
|
159
239
|
total_context_len = len(json.dumps(context, indent=2))
|
|
160
240
|
|
|
161
241
|
# sort fields by length and progressively strip from the longest field until it is short enough;
|
|
162
|
-
# NOTE:
|
|
163
|
-
while total_context_len * TOKENS_PER_CHARACTER >
|
|
242
|
+
# NOTE: LLAMA_CONTEXT_TOKENS_LIMIT is a rough estimate which leaves room for the rest of the prompt text
|
|
243
|
+
while total_context_len * TOKENS_PER_CHARACTER > LLAMA_CONTEXT_TOKENS_LIMIT:
|
|
164
244
|
# sort fields by length
|
|
165
245
|
field_lengths = [(field, len(value) if value is not None else 0) for field, value in context.items()]
|
|
166
246
|
sorted_fields = sorted(field_lengths, key=lambda item: item[1], reverse=True)
|
|
@@ -169,7 +249,7 @@ class PromptFactory:
|
|
|
169
249
|
longest_field_name, longest_field_length = sorted_fields[0]
|
|
170
250
|
|
|
171
251
|
# trim the field
|
|
172
|
-
context_factor =
|
|
252
|
+
context_factor = LLAMA_CONTEXT_TOKENS_LIMIT / (total_context_len * TOKENS_PER_CHARACTER)
|
|
173
253
|
keep_frac_idx = int(longest_field_length * context_factor)
|
|
174
254
|
context[longest_field_name] = context[longest_field_name][:keep_frac_idx]
|
|
175
255
|
|
|
@@ -191,7 +271,11 @@ class PromptFactory:
|
|
|
191
271
|
Returns:
|
|
192
272
|
list[str]: The list of input field names.
|
|
193
273
|
"""
|
|
194
|
-
|
|
274
|
+
# NOTE: joins will include left and right input fields in project_cols, so we have to check
|
|
275
|
+
# if the field is in the candidate record
|
|
276
|
+
input_fields = kwargs.get("project_cols", candidate.get_field_names())
|
|
277
|
+
input_fields = [field for field in input_fields if field in candidate.get_field_names()]
|
|
278
|
+
return input_fields
|
|
195
279
|
|
|
196
280
|
def _get_input_fields_desc(self, candidate: DataRecord, input_fields: list[str]) -> str:
|
|
197
281
|
"""
|
|
@@ -205,7 +289,7 @@ class PromptFactory:
|
|
|
205
289
|
"""
|
|
206
290
|
input_fields_desc = ""
|
|
207
291
|
for field_name in input_fields:
|
|
208
|
-
input_fields_desc += f"- {field_name}: {candidate.get_field_type(field_name).
|
|
292
|
+
input_fields_desc += f"- {field_name}: {candidate.get_field_type(field_name).description}\n"
|
|
209
293
|
|
|
210
294
|
return input_fields_desc[:-1]
|
|
211
295
|
|
|
@@ -221,13 +305,13 @@ class PromptFactory:
|
|
|
221
305
|
str: The output fields description.
|
|
222
306
|
"""
|
|
223
307
|
output_fields_desc = ""
|
|
224
|
-
output_schema:
|
|
308
|
+
output_schema: BaseModel = kwargs.get("output_schema")
|
|
225
309
|
if self.prompt_strategy.is_convert_prompt():
|
|
226
310
|
assert output_schema is not None, "Output schema must be provided for convert prompts."
|
|
227
311
|
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
output_fields_desc += f"- {field_name}: {
|
|
312
|
+
for field_name in sorted(output_fields):
|
|
313
|
+
desc = output_schema.model_fields[field_name].description
|
|
314
|
+
output_fields_desc += f"- {field_name}: {'no description available' if desc is None else desc}\n"
|
|
231
315
|
|
|
232
316
|
# strip the last newline characters from the field descriptions and return
|
|
233
317
|
return output_fields_desc[:-1]
|
|
@@ -245,6 +329,19 @@ class PromptFactory:
|
|
|
245
329
|
|
|
246
330
|
return filter_condition
|
|
247
331
|
|
|
332
|
+
def _get_join_condition(self, **kwargs) -> str | None:
|
|
333
|
+
"""
|
|
334
|
+
Returns the join condition for the join operation.
|
|
335
|
+
|
|
336
|
+
Returns:
|
|
337
|
+
str | None: The join condition (if applicable).
|
|
338
|
+
"""
|
|
339
|
+
join_condition = kwargs.get("join_condition")
|
|
340
|
+
if self.prompt_strategy.is_join_prompt():
|
|
341
|
+
assert join_condition is not None, "Join condition must be provided for join operations."
|
|
342
|
+
|
|
343
|
+
return join_condition
|
|
344
|
+
|
|
248
345
|
def _get_original_output(self, **kwargs) -> str | None:
|
|
249
346
|
"""
|
|
250
347
|
Returns the original output from a previous model generation for the critique and refinement operations.
|
|
@@ -337,8 +434,13 @@ class PromptFactory:
|
|
|
337
434
|
"""
|
|
338
435
|
prompt_strategy_to_job_instruction = {
|
|
339
436
|
PromptStrategy.COT_BOOL: COT_BOOL_JOB_INSTRUCTION,
|
|
437
|
+
PromptStrategy.COT_BOOL_AUDIO: COT_BOOL_AUDIO_JOB_INSTRUCTION,
|
|
340
438
|
PromptStrategy.COT_BOOL_IMAGE: COT_BOOL_IMAGE_JOB_INSTRUCTION,
|
|
439
|
+
PromptStrategy.COT_JOIN: COT_JOIN_JOB_INSTRUCTION,
|
|
440
|
+
PromptStrategy.COT_JOIN_AUDIO: COT_JOIN_AUDIO_JOB_INSTRUCTION,
|
|
441
|
+
PromptStrategy.COT_JOIN_IMAGE: COT_JOIN_IMAGE_JOB_INSTRUCTION,
|
|
341
442
|
PromptStrategy.COT_QA: COT_QA_JOB_INSTRUCTION,
|
|
443
|
+
PromptStrategy.COT_QA_AUDIO: COT_QA_AUDIO_JOB_INSTRUCTION,
|
|
342
444
|
PromptStrategy.COT_QA_IMAGE: COT_QA_IMAGE_JOB_INSTRUCTION,
|
|
343
445
|
PromptStrategy.COT_MOA_PROPOSER: COT_MOA_PROPOSER_JOB_INSTRUCTION,
|
|
344
446
|
PromptStrategy.COT_MOA_PROPOSER_IMAGE: COT_MOA_PROPOSER_IMAGE_JOB_INSTRUCTION,
|
|
@@ -346,6 +448,19 @@ class PromptFactory:
|
|
|
346
448
|
}
|
|
347
449
|
return prompt_strategy_to_job_instruction.get(self.prompt_strategy)
|
|
348
450
|
|
|
451
|
+
def _get_desc_section(self) -> str:
|
|
452
|
+
"""
|
|
453
|
+
Returns the description section for the prompt.
|
|
454
|
+
|
|
455
|
+
Returns:
|
|
456
|
+
str: The description section (if applicable).
|
|
457
|
+
"""
|
|
458
|
+
desc_section = ""
|
|
459
|
+
if self.desc is not None:
|
|
460
|
+
desc_section = DESC_SECTION.format(desc=self.desc)
|
|
461
|
+
|
|
462
|
+
return desc_section
|
|
463
|
+
|
|
349
464
|
def _get_critique_criteria(self) -> str | None:
|
|
350
465
|
"""
|
|
351
466
|
Returns the critique criteria for the critique operation.
|
|
@@ -402,8 +517,13 @@ class PromptFactory:
|
|
|
402
517
|
"""
|
|
403
518
|
prompt_strategy_to_example_input_fields = {
|
|
404
519
|
PromptStrategy.COT_BOOL: COT_BOOL_EXAMPLE_INPUT_FIELDS,
|
|
520
|
+
PromptStrategy.COT_BOOL_AUDIO: COT_BOOL_AUDIO_EXAMPLE_INPUT_FIELDS,
|
|
405
521
|
PromptStrategy.COT_BOOL_IMAGE: COT_BOOL_IMAGE_EXAMPLE_INPUT_FIELDS,
|
|
522
|
+
PromptStrategy.COT_JOIN: COT_JOIN_EXAMPLE_INPUT_FIELDS,
|
|
523
|
+
PromptStrategy.COT_JOIN_AUDIO: COT_JOIN_AUDIO_EXAMPLE_INPUT_FIELDS,
|
|
524
|
+
PromptStrategy.COT_JOIN_IMAGE: COT_JOIN_IMAGE_EXAMPLE_INPUT_FIELDS,
|
|
406
525
|
PromptStrategy.COT_QA: COT_QA_EXAMPLE_INPUT_FIELDS,
|
|
526
|
+
PromptStrategy.COT_QA_AUDIO: COT_QA_AUDIO_EXAMPLE_INPUT_FIELDS,
|
|
407
527
|
PromptStrategy.COT_QA_IMAGE: COT_QA_IMAGE_EXAMPLE_INPUT_FIELDS,
|
|
408
528
|
PromptStrategy.COT_MOA_PROPOSER: COT_MOA_PROPOSER_EXAMPLE_INPUT_FIELDS,
|
|
409
529
|
PromptStrategy.COT_MOA_PROPOSER_IMAGE: COT_MOA_PROPOSER_IMAGE_EXAMPLE_INPUT_FIELDS,
|
|
@@ -412,6 +532,21 @@ class PromptFactory:
|
|
|
412
532
|
|
|
413
533
|
return prompt_strategy_to_example_input_fields.get(self.prompt_strategy)
|
|
414
534
|
|
|
535
|
+
def _get_right_example_input_fields(self) -> str | None:
|
|
536
|
+
"""
|
|
537
|
+
Returns the example right input fields for the join prompt.
|
|
538
|
+
|
|
539
|
+
Returns:
|
|
540
|
+
str | None: The example right input fields (if applicable).
|
|
541
|
+
"""
|
|
542
|
+
prompt_strategy_to_right_example_input_fields = {
|
|
543
|
+
PromptStrategy.COT_JOIN: COT_JOIN_RIGHT_EXAMPLE_INPUT_FIELDS,
|
|
544
|
+
PromptStrategy.COT_JOIN_AUDIO: COT_JOIN_AUDIO_RIGHT_EXAMPLE_INPUT_FIELDS,
|
|
545
|
+
PromptStrategy.COT_JOIN_IMAGE: COT_JOIN_IMAGE_RIGHT_EXAMPLE_INPUT_FIELDS,
|
|
546
|
+
}
|
|
547
|
+
|
|
548
|
+
return prompt_strategy_to_right_example_input_fields.get(self.prompt_strategy)
|
|
549
|
+
|
|
415
550
|
def _get_example_output_fields(self) -> str | None:
|
|
416
551
|
"""
|
|
417
552
|
Returns the example output fields for the prompt.
|
|
@@ -421,6 +556,7 @@ class PromptFactory:
|
|
|
421
556
|
"""
|
|
422
557
|
prompt_strategy_to_example_output_fields = {
|
|
423
558
|
PromptStrategy.COT_QA: COT_QA_EXAMPLE_OUTPUT_FIELDS,
|
|
559
|
+
PromptStrategy.COT_QA_AUDIO: COT_QA_AUDIO_EXAMPLE_OUTPUT_FIELDS,
|
|
424
560
|
PromptStrategy.COT_QA_IMAGE: COT_QA_IMAGE_EXAMPLE_OUTPUT_FIELDS,
|
|
425
561
|
PromptStrategy.COT_MOA_PROPOSER: COT_MOA_PROPOSER_EXAMPLE_OUTPUT_FIELDS,
|
|
426
562
|
PromptStrategy.COT_MOA_PROPOSER_IMAGE: COT_MOA_PROPOSER_IMAGE_EXAMPLE_OUTPUT_FIELDS,
|
|
@@ -438,8 +574,13 @@ class PromptFactory:
|
|
|
438
574
|
"""
|
|
439
575
|
prompt_strategy_to_example_context = {
|
|
440
576
|
PromptStrategy.COT_BOOL: COT_BOOL_EXAMPLE_CONTEXT,
|
|
577
|
+
PromptStrategy.COT_BOOL_AUDIO: COT_BOOL_AUDIO_EXAMPLE_CONTEXT,
|
|
441
578
|
PromptStrategy.COT_BOOL_IMAGE: COT_BOOL_IMAGE_EXAMPLE_CONTEXT,
|
|
579
|
+
PromptStrategy.COT_JOIN: COT_JOIN_EXAMPLE_CONTEXT,
|
|
580
|
+
PromptStrategy.COT_JOIN_AUDIO: COT_JOIN_AUDIO_EXAMPLE_CONTEXT,
|
|
581
|
+
PromptStrategy.COT_JOIN_IMAGE: COT_JOIN_IMAGE_EXAMPLE_CONTEXT,
|
|
442
582
|
PromptStrategy.COT_QA: COT_QA_EXAMPLE_CONTEXT,
|
|
583
|
+
PromptStrategy.COT_QA_AUDIO: COT_QA_AUDIO_EXAMPLE_CONTEXT,
|
|
443
584
|
PromptStrategy.COT_QA_IMAGE: COT_QA_IMAGE_EXAMPLE_CONTEXT,
|
|
444
585
|
PromptStrategy.COT_MOA_PROPOSER: COT_MOA_PROPOSER_EXAMPLE_CONTEXT,
|
|
445
586
|
PromptStrategy.COT_MOA_PROPOSER_IMAGE: COT_MOA_PROPOSER_IMAGE_EXAMPLE_CONTEXT,
|
|
@@ -448,6 +589,21 @@ class PromptFactory:
|
|
|
448
589
|
|
|
449
590
|
return prompt_strategy_to_example_context.get(self.prompt_strategy)
|
|
450
591
|
|
|
592
|
+
def _get_right_example_context(self) -> str | None:
|
|
593
|
+
"""
|
|
594
|
+
Returns the right example context for the join prompt.
|
|
595
|
+
|
|
596
|
+
Returns:
|
|
597
|
+
str | None: The right example context (if applicable).
|
|
598
|
+
"""
|
|
599
|
+
prompt_strategy_to_right_example_context = {
|
|
600
|
+
PromptStrategy.COT_JOIN: COT_JOIN_RIGHT_EXAMPLE_CONTEXT,
|
|
601
|
+
PromptStrategy.COT_JOIN_AUDIO: COT_JOIN_AUDIO_RIGHT_EXAMPLE_CONTEXT,
|
|
602
|
+
PromptStrategy.COT_JOIN_IMAGE: COT_JOIN_IMAGE_RIGHT_EXAMPLE_CONTEXT,
|
|
603
|
+
}
|
|
604
|
+
|
|
605
|
+
return prompt_strategy_to_right_example_context.get(self.prompt_strategy)
|
|
606
|
+
|
|
451
607
|
def _get_image_disclaimer(self) -> str:
|
|
452
608
|
"""
|
|
453
609
|
Returns the image disclaimer for the prompt. The disclaimer must be an empty string
|
|
@@ -458,12 +614,57 @@ class PromptFactory:
|
|
|
458
614
|
"""
|
|
459
615
|
prompt_strategy_to_image_disclaimer = {
|
|
460
616
|
PromptStrategy.COT_BOOL_IMAGE: COT_BOOL_IMAGE_DISCLAIMER,
|
|
617
|
+
PromptStrategy.COT_JOIN_IMAGE: COT_JOIN_IMAGE_DISCLAIMER,
|
|
461
618
|
PromptStrategy.COT_QA_IMAGE: COT_QA_IMAGE_DISCLAIMER,
|
|
462
619
|
PromptStrategy.COT_MOA_PROPOSER_IMAGE: COT_MOA_PROPOSER_IMAGE_DISCLAIMER,
|
|
463
620
|
}
|
|
464
621
|
|
|
465
622
|
return prompt_strategy_to_image_disclaimer.get(self.prompt_strategy, "")
|
|
466
623
|
|
|
624
|
+
def _get_audio_disclaimer(self) -> str:
|
|
625
|
+
"""
|
|
626
|
+
Returns the audio disclaimer for the prompt. The disclaimer must be an empty string
|
|
627
|
+
for text prompts.
|
|
628
|
+
|
|
629
|
+
Returns:
|
|
630
|
+
str: The audio disclaimer. If this is a text prompt then it is an empty string.
|
|
631
|
+
"""
|
|
632
|
+
prompt_strategy_to_audio_disclaimer = {
|
|
633
|
+
PromptStrategy.COT_BOOL_AUDIO: COT_BOOL_AUDIO_DISCLAIMER,
|
|
634
|
+
PromptStrategy.COT_JOIN_AUDIO: COT_JOIN_AUDIO_DISCLAIMER,
|
|
635
|
+
PromptStrategy.COT_QA_AUDIO: COT_QA_AUDIO_DISCLAIMER,
|
|
636
|
+
}
|
|
637
|
+
|
|
638
|
+
return prompt_strategy_to_audio_disclaimer.get(self.prompt_strategy, "")
|
|
639
|
+
|
|
640
|
+
def _get_right_image_disclaimer(self) -> str:
|
|
641
|
+
"""
|
|
642
|
+
Returns the right image disclaimer for the prompt. The disclaimer must be an empty string
|
|
643
|
+
for text prompts.
|
|
644
|
+
|
|
645
|
+
Returns:
|
|
646
|
+
str: The right image disclaimer. If this is a text prompt then it is an empty string.
|
|
647
|
+
"""
|
|
648
|
+
prompt_strategy_to_image_disclaimer = {
|
|
649
|
+
PromptStrategy.COT_JOIN_IMAGE: COT_JOIN_RIGHT_IMAGE_DISCLAIMER,
|
|
650
|
+
}
|
|
651
|
+
|
|
652
|
+
return prompt_strategy_to_image_disclaimer.get(self.prompt_strategy, "")
|
|
653
|
+
|
|
654
|
+
def _get_right_audio_disclaimer(self) -> str:
|
|
655
|
+
"""
|
|
656
|
+
Returns the right audio disclaimer for the prompt. The disclaimer must be an empty string
|
|
657
|
+
for text prompts.
|
|
658
|
+
|
|
659
|
+
Returns:
|
|
660
|
+
str: The right audio disclaimer. If this is a text prompt then it is an empty string.
|
|
661
|
+
"""
|
|
662
|
+
prompt_strategy_to_audio_disclaimer = {
|
|
663
|
+
PromptStrategy.COT_JOIN_AUDIO: COT_JOIN_RIGHT_AUDIO_DISCLAIMER,
|
|
664
|
+
}
|
|
665
|
+
|
|
666
|
+
return prompt_strategy_to_audio_disclaimer.get(self.prompt_strategy, "")
|
|
667
|
+
|
|
467
668
|
def _get_example_filter_condition(self) -> str | None:
|
|
468
669
|
"""
|
|
469
670
|
Returns the example filter condition for the prompt.
|
|
@@ -473,11 +674,27 @@ class PromptFactory:
|
|
|
473
674
|
"""
|
|
474
675
|
prompt_strategy_to_example_filter_condition = {
|
|
475
676
|
PromptStrategy.COT_BOOL: COT_BOOL_EXAMPLE_FILTER_CONDITION,
|
|
677
|
+
PromptStrategy.COT_BOOL_AUDIO: COT_BOOL_AUDIO_EXAMPLE_FILTER_CONDITION,
|
|
476
678
|
PromptStrategy.COT_BOOL_IMAGE: COT_BOOL_IMAGE_EXAMPLE_FILTER_CONDITION,
|
|
477
679
|
}
|
|
478
680
|
|
|
479
681
|
return prompt_strategy_to_example_filter_condition.get(self.prompt_strategy)
|
|
480
682
|
|
|
683
|
+
def _get_example_join_condition(self) -> str | None:
|
|
684
|
+
"""
|
|
685
|
+
Returns the example join condition for the prompt.
|
|
686
|
+
|
|
687
|
+
Returns:
|
|
688
|
+
str | None: The example join condition (if applicable).
|
|
689
|
+
"""
|
|
690
|
+
prompt_strategy_to_example_join_condition = {
|
|
691
|
+
PromptStrategy.COT_JOIN: COT_JOIN_EXAMPLE_JOIN_CONDITION,
|
|
692
|
+
PromptStrategy.COT_JOIN_AUDIO: COT_JOIN_AUDIO_EXAMPLE_JOIN_CONDITION,
|
|
693
|
+
PromptStrategy.COT_JOIN_IMAGE: COT_JOIN_IMAGE_EXAMPLE_JOIN_CONDITION,
|
|
694
|
+
}
|
|
695
|
+
|
|
696
|
+
return prompt_strategy_to_example_join_condition.get(self.prompt_strategy)
|
|
697
|
+
|
|
481
698
|
def _get_example_reasoning(self) -> str | None:
|
|
482
699
|
"""
|
|
483
700
|
Returns the example reasoning for the prompt.
|
|
@@ -487,8 +704,13 @@ class PromptFactory:
|
|
|
487
704
|
"""
|
|
488
705
|
prompt_strategy_to_example_reasoning = {
|
|
489
706
|
PromptStrategy.COT_BOOL: COT_BOOL_EXAMPLE_REASONING,
|
|
707
|
+
PromptStrategy.COT_BOOL_AUDIO: COT_BOOL_AUDIO_EXAMPLE_REASONING,
|
|
490
708
|
PromptStrategy.COT_BOOL_IMAGE: COT_BOOL_IMAGE_EXAMPLE_REASONING,
|
|
709
|
+
PromptStrategy.COT_JOIN: COT_JOIN_EXAMPLE_REASONING,
|
|
710
|
+
PromptStrategy.COT_JOIN_AUDIO: COT_JOIN_AUDIO_EXAMPLE_REASONING,
|
|
711
|
+
PromptStrategy.COT_JOIN_IMAGE: COT_JOIN_IMAGE_EXAMPLE_REASONING,
|
|
491
712
|
PromptStrategy.COT_QA: COT_QA_EXAMPLE_REASONING,
|
|
713
|
+
PromptStrategy.COT_QA_AUDIO: COT_QA_AUDIO_EXAMPLE_REASONING,
|
|
492
714
|
PromptStrategy.COT_QA_IMAGE: COT_QA_IMAGE_EXAMPLE_REASONING,
|
|
493
715
|
}
|
|
494
716
|
|
|
@@ -503,6 +725,7 @@ class PromptFactory:
|
|
|
503
725
|
"""
|
|
504
726
|
prompt_strategy_to_example_answer = {
|
|
505
727
|
PromptStrategy.COT_QA: COT_QA_EXAMPLE_ANSWER,
|
|
728
|
+
PromptStrategy.COT_QA_AUDIO: COT_QA_AUDIO_EXAMPLE_ANSWER,
|
|
506
729
|
PromptStrategy.COT_QA_IMAGE: COT_QA_IMAGE_EXAMPLE_ANSWER,
|
|
507
730
|
PromptStrategy.COT_MOA_PROPOSER: COT_MOA_PROPOSER_EXAMPLE_ANSWER,
|
|
508
731
|
PromptStrategy.COT_MOA_PROPOSER_IMAGE: COT_MOA_PROPOSER_IMAGE_EXAMPLE_ANSWER,
|
|
@@ -512,7 +735,7 @@ class PromptFactory:
|
|
|
512
735
|
return prompt_strategy_to_example_answer.get(self.prompt_strategy)
|
|
513
736
|
|
|
514
737
|
def _get_all_format_kwargs(
|
|
515
|
-
self, candidate: DataRecord, input_fields: list[str], output_fields: list[str], **kwargs
|
|
738
|
+
self, candidate: DataRecord, input_fields: list[str], output_fields: list[str], right_candidate: DataRecord | None, right_input_fields: list[str], **kwargs
|
|
516
739
|
) -> dict:
|
|
517
740
|
"""
|
|
518
741
|
Returns a dictionary containing all the format kwargs for templating the prompts.
|
|
@@ -532,24 +755,39 @@ class PromptFactory:
|
|
|
532
755
|
"input_fields_desc": self._get_input_fields_desc(candidate, input_fields),
|
|
533
756
|
"output_fields_desc": self._get_output_fields_desc(output_fields, **kwargs),
|
|
534
757
|
"filter_condition": self._get_filter_condition(**kwargs),
|
|
758
|
+
"join_condition": self._get_join_condition(**kwargs),
|
|
535
759
|
"original_output": self._get_original_output(**kwargs),
|
|
536
760
|
"critique_output": self._get_critique_output(**kwargs),
|
|
537
761
|
"model_responses": self._get_model_responses(**kwargs),
|
|
538
762
|
"chunk_outputs": self._get_chunk_outputs(**kwargs),
|
|
539
763
|
}
|
|
540
764
|
|
|
765
|
+
# if a right candidate is provided, we also get the context and input field descriptions for the right candidate
|
|
766
|
+
if right_candidate is not None:
|
|
767
|
+
input_format_kwargs.update({
|
|
768
|
+
"right_context": self._get_context(right_candidate, right_input_fields),
|
|
769
|
+
"right_input_fields_desc": self._get_input_fields_desc(right_candidate, right_input_fields),
|
|
770
|
+
})
|
|
771
|
+
|
|
541
772
|
# get format kwargs which depend on the prompt strategy
|
|
542
773
|
prompt_strategy_format_kwargs = {
|
|
543
774
|
"output_format_instruction": self._get_output_format_instruction(),
|
|
544
775
|
"job_instruction": self._get_job_instruction(),
|
|
776
|
+
"desc_section": self._get_desc_section(),
|
|
545
777
|
"critique_criteria": self._get_critique_criteria(),
|
|
546
778
|
"refinement_criteria": self._get_refinement_criteria(),
|
|
547
779
|
"finish_instruction": self._get_finish_instruction(),
|
|
548
780
|
"example_input_fields": self._get_example_input_fields(),
|
|
781
|
+
"right_example_input_fields": self._get_right_example_input_fields(),
|
|
549
782
|
"example_output_fields": self._get_example_output_fields(),
|
|
550
783
|
"example_context": self._get_example_context(),
|
|
784
|
+
"right_example_context": self._get_right_example_context(),
|
|
551
785
|
"image_disclaimer": self._get_image_disclaimer(),
|
|
786
|
+
"audio_disclaimer": self._get_audio_disclaimer(),
|
|
787
|
+
"right_image_disclaimer": self._get_right_image_disclaimer(),
|
|
788
|
+
"right_audio_disclaimer": self._get_right_audio_disclaimer(),
|
|
552
789
|
"example_filter_condition": self._get_example_filter_condition(),
|
|
790
|
+
"example_join_condition": self._get_example_join_condition(),
|
|
553
791
|
"example_reasoning": self._get_example_reasoning(),
|
|
554
792
|
"example_answer": self._get_example_answer(),
|
|
555
793
|
}
|
|
@@ -557,6 +795,53 @@ class PromptFactory:
|
|
|
557
795
|
# return all format kwargs
|
|
558
796
|
return {**input_format_kwargs, **prompt_strategy_format_kwargs}
|
|
559
797
|
|
|
798
|
+
def _create_audio_messages(self, candidate: DataRecord, input_fields: list[str]) -> list[dict]:
    """
    Parses the candidate record and returns the audio messages for the chat payload.

    Every input field whose annotation is an audio type (an audio filepath, a
    pre-encoded base64 audio string, a list of either, or the optional form of any
    of these) contributes one ``input_audio`` content part per recording. Fields
    with any other annotation are ignored. Optional audio fields whose value is
    None are skipped instead of being opened / iterated / sent as null data.

    Args:
        candidate (DataRecord): The input record.
        input_fields (list[str]): The list of input fields.

    Returns:
        list[dict]: A single user message wrapping all audio content parts, or an
            empty list if the record contributes no audio.
    """
    # create a content part for each audio recording in an input field with an audio (or list of audio) type
    audio_content = []
    for field_name in input_fields:
        field_value = candidate[field_name]
        annotation = candidate.get_field_type(field_name).annotation

        # normalize the field into (list of recordings, whether each recording is a filepath)
        if annotation in [AudioFilepath, AudioFilepath | None]:
            # audio filepath
            recordings, is_filepath = [field_value], True
        elif annotation in [list[AudioFilepath], list[AudioFilepath] | None]:
            # list of audio filepaths
            recordings, is_filepath = field_value, True
        elif annotation in [AudioBase64, AudioBase64 | None]:
            # pre-encoded audio
            recordings, is_filepath = [field_value], False
        elif annotation in [list[AudioBase64], list[AudioBase64] | None]:
            # list of pre-encoded audio
            recordings, is_filepath = field_value, False
        else:
            # not an audio field
            continue

        # the optional annotations above allow the field to be unset; there is nothing to attach in that case
        if field_value is None:
            continue

        for recording in recordings:
            if is_filepath:
                # read the file and base64-encode its raw bytes
                with open(recording, "rb") as f:
                    base64_audio_str = base64.b64encode(f.read()).decode("utf-8")
            else:
                base64_audio_str = recording

            # NOTE: the format is always reported as "wav", regardless of the actual file contents
            audio_content.append(
                {"type": "input_audio", "input_audio": {"data": base64_audio_str, "format": "wav"}}
            )

    return [{"role": "user", "type": "input_audio", "content": audio_content}] if len(audio_content) > 0 else []
|
|
844
|
+
|
|
560
845
|
def _create_image_messages(self, candidate: DataRecord, input_fields: list[str]) -> list[dict]:
|
|
561
846
|
"""
|
|
562
847
|
Parses the candidate record and returns the image messages for the chat payload.
|
|
@@ -569,50 +854,48 @@ class PromptFactory:
|
|
|
569
854
|
list[dict]: The image messages for the chat payload.
|
|
570
855
|
"""
|
|
571
856
|
# create a message for each image in an input field with an image (or list of image) type
|
|
572
|
-
|
|
857
|
+
image_content = []
|
|
573
858
|
for field_name in input_fields:
|
|
574
859
|
field_value = candidate[field_name]
|
|
575
860
|
field_type = candidate.get_field_type(field_name)
|
|
576
861
|
|
|
577
862
|
# image filepath (or list of image filepaths)
|
|
578
|
-
if
|
|
863
|
+
if field_type.annotation in [ImageFilepath, ImageFilepath | None]:
|
|
579
864
|
with open(field_value, "rb") as f:
|
|
580
865
|
base64_image_str = base64.b64encode(f.read()).decode("utf-8")
|
|
581
|
-
|
|
582
|
-
{"
|
|
866
|
+
image_content.append(
|
|
867
|
+
{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image_str}"}}
|
|
583
868
|
)
|
|
584
869
|
|
|
585
|
-
elif
|
|
870
|
+
elif field_type.annotation in [list[ImageFilepath], list[ImageFilepath] | None]:
|
|
586
871
|
for image_filepath in field_value:
|
|
587
872
|
with open(image_filepath, "rb") as f:
|
|
588
873
|
base64_image_str = base64.b64encode(f.read()).decode("utf-8")
|
|
589
|
-
|
|
590
|
-
{"
|
|
874
|
+
image_content.append(
|
|
875
|
+
{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image_str}"}}
|
|
591
876
|
)
|
|
592
877
|
|
|
593
878
|
# image url (or list of image urls)
|
|
594
|
-
elif
|
|
595
|
-
|
|
879
|
+
elif field_type.annotation in [ImageURL, ImageURL | None]:
|
|
880
|
+
image_content.append({"type": "image_url", "image_url": {"url": field_value}})
|
|
596
881
|
|
|
597
|
-
elif
|
|
882
|
+
elif field_type.annotation in [list[ImageURL], list[ImageURL] | None]:
|
|
598
883
|
for image_url in field_value:
|
|
599
|
-
|
|
884
|
+
image_content.append({"type": "image_url", "image_url": {"url": image_url}})
|
|
600
885
|
|
|
601
886
|
# pre-encoded images (or list of pre-encoded images)
|
|
602
|
-
elif
|
|
603
|
-
|
|
604
|
-
|
|
605
|
-
{"role": "user", "type": "image", "content": f"data:image/jpeg;base64,{base64_image_str}"}
|
|
887
|
+
elif field_type.annotation in [ImageBase64, ImageBase64 | None]:
|
|
888
|
+
image_content.append(
|
|
889
|
+
{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{field_value}"}}
|
|
606
890
|
)
|
|
607
891
|
|
|
608
|
-
elif
|
|
892
|
+
elif field_type.annotation in [list[ImageBase64], list[ImageBase64] | None]:
|
|
609
893
|
for base64_image in field_value:
|
|
610
|
-
|
|
611
|
-
|
|
612
|
-
{"role": "user", "type": "image", "content": f"data:image/jpeg;base64,{base64_image_str}"}
|
|
894
|
+
image_content.append(
|
|
895
|
+
{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}}
|
|
613
896
|
)
|
|
614
897
|
|
|
615
|
-
return
|
|
898
|
+
return [{"role": "user", "type": "image", "content": image_content}] if len(image_content) > 0 else []
|
|
616
899
|
|
|
617
900
|
def _get_system_prompt(self, **format_kwargs) -> str | None:
|
|
618
901
|
"""
|
|
@@ -631,7 +914,7 @@ class PromptFactory:
|
|
|
631
914
|
|
|
632
915
|
return base_prompt.format(**format_kwargs)
|
|
633
916
|
|
|
634
|
-
def _get_user_messages(self, candidate: DataRecord, input_fields: list[str], **kwargs) -> str:
|
|
917
|
+
def _get_user_messages(self, candidate: DataRecord, input_fields: list[str], right_candidate: DataRecord | None, right_input_fields: list[str], **kwargs) -> str:
|
|
635
918
|
"""
|
|
636
919
|
Returns a list of messages for the chat payload based on the prompt strategy.
|
|
637
920
|
|
|
@@ -648,10 +931,18 @@ class PromptFactory:
|
|
|
648
931
|
# get the base prompt template
|
|
649
932
|
base_prompt = self.BASE_USER_PROMPT_MAP.get(self.prompt_strategy)
|
|
650
933
|
|
|
651
|
-
# get any image messages for the chat payload (will be an empty list if
|
|
652
|
-
image_messages = (
|
|
653
|
-
|
|
654
|
-
)
|
|
934
|
+
# get any image messages for the chat payload (will be an empty list if no image fields exist)
|
|
935
|
+
image_messages = self._create_image_messages(candidate, input_fields)
|
|
936
|
+
|
|
937
|
+
# get any audio messages for the chat payload (will be an empty list if no audio fields exist)
|
|
938
|
+
audio_messages = self._create_audio_messages(candidate, input_fields)
|
|
939
|
+
|
|
940
|
+
# get any right image and audio messages for the chat payload (both will be empty lists if this is not a join prompt)
|
|
941
|
+
right_image_messages, right_audio_messages = [], []
|
|
942
|
+
if self.prompt_strategy.is_join_prompt():
|
|
943
|
+
assert right_candidate is not None, "Right candidate must be provided for join prompts."
|
|
944
|
+
right_image_messages = self._create_image_messages(right_candidate, right_input_fields)
|
|
945
|
+
right_audio_messages = self._create_audio_messages(right_candidate, right_input_fields)
|
|
655
946
|
|
|
656
947
|
# get any original messages for critique and refinement operations
|
|
657
948
|
original_messages = kwargs.get("original_messages")
|
|
@@ -660,6 +951,8 @@ class PromptFactory:
|
|
|
660
951
|
"Original messages must be provided for critique and refinement operations."
|
|
661
952
|
)
|
|
662
953
|
|
|
954
|
+
# TODO: in the future if we support many modalities (e.g. images and audio) in the same prompt,
|
|
955
|
+
# then we will need to streamline this logic to handle the many different cases
|
|
663
956
|
# construct the user messages based on the prompt strategy
|
|
664
957
|
user_messages = []
|
|
665
958
|
if self.prompt_strategy.is_critic_prompt() or self.prompt_strategy.is_refine_prompt():
|
|
@@ -670,14 +963,47 @@ class PromptFactory:
|
|
|
670
963
|
user_messages.extend(original_messages)
|
|
671
964
|
user_messages.append({"role": "user", "type": "text", "content": base_prompt_end.format(**kwargs)})
|
|
672
965
|
|
|
673
|
-
|
|
674
|
-
|
|
966
|
+
# image not join
|
|
967
|
+
elif self.prompt_strategy.is_image_prompt() and not self.prompt_strategy.is_join_prompt():
|
|
968
|
+
base_prompt = base_prompt.replace("<<audio-placeholder>>", "")
|
|
969
|
+
base_prompt_start, base_prompt_end = base_prompt.split("<<image-placeholder>>")
|
|
675
970
|
user_messages.append({"role": "user", "type": "text", "content": base_prompt_start.format(**kwargs)})
|
|
676
971
|
user_messages.extend(image_messages)
|
|
677
972
|
user_messages.append({"role": "user", "type": "text", "content": base_prompt_end.format(**kwargs)})
|
|
678
973
|
|
|
974
|
+
# image join
|
|
975
|
+
elif self.prompt_strategy.is_image_prompt() and self.prompt_strategy.is_join_prompt():
|
|
976
|
+
# for join image prompts, we may have two sets of images (one from the left candidate and one from the right candidate)
|
|
977
|
+
base_prompt = base_prompt.replace("<<audio-placeholder>>", "")
|
|
978
|
+
base_prompt_start, base_prompt_mid, base_prompt_end = base_prompt.split("<<image-placeholder>>")
|
|
979
|
+
user_messages.append({"role": "user", "type": "text", "content": base_prompt_start.format(**kwargs)})
|
|
980
|
+
user_messages.extend(image_messages)
|
|
981
|
+
user_messages.append({"role": "user", "type": "text", "content": base_prompt_mid.format(**kwargs)})
|
|
982
|
+
user_messages.extend(right_image_messages)
|
|
983
|
+
user_messages.append({"role": "user", "type": "text", "content": base_prompt_end.format(**kwargs)})
|
|
984
|
+
|
|
985
|
+
# audio not join
|
|
986
|
+
elif self.prompt_strategy.is_audio_prompt() and not self.prompt_strategy.is_join_prompt():
|
|
987
|
+
base_prompt = base_prompt.replace("<<image-placeholder>>", "")
|
|
988
|
+
base_prompt_start, base_prompt_end = base_prompt.split("<<audio-placeholder>>")
|
|
989
|
+
user_messages.append({"role": "user", "type": "text", "content": base_prompt_start.format(**kwargs)})
|
|
990
|
+
user_messages.extend(audio_messages)
|
|
991
|
+
user_messages.append({"role": "user", "type": "text", "content": base_prompt_end.format(**kwargs)})
|
|
992
|
+
|
|
993
|
+
# audio join
|
|
994
|
+
elif self.prompt_strategy.is_audio_prompt() and self.prompt_strategy.is_join_prompt():
|
|
995
|
+
# for join audio prompts, we may have two sets of audio recordings (one from the left candidate and one from the right candidate)
|
|
996
|
+
base_prompt = base_prompt.replace("<<image-placeholder>>", "")
|
|
997
|
+
base_prompt_start, base_prompt_mid, base_prompt_end = base_prompt.split("<<audio-placeholder>>")
|
|
998
|
+
user_messages.append({"role": "user", "type": "text", "content": base_prompt_start.format(**kwargs)})
|
|
999
|
+
user_messages.extend(audio_messages)
|
|
1000
|
+
user_messages.append({"role": "user", "type": "text", "content": base_prompt_mid.format(**kwargs)})
|
|
1001
|
+
user_messages.extend(right_audio_messages)
|
|
1002
|
+
user_messages.append({"role": "user", "type": "text", "content": base_prompt_end.format(**kwargs)})
|
|
1003
|
+
|
|
679
1004
|
else:
|
|
680
1005
|
base_prompt = base_prompt.replace("<<image-placeholder>>", "")
|
|
1006
|
+
base_prompt = base_prompt.replace("<<audio-placeholder>>", "")
|
|
681
1007
|
user_messages.append({"role": "user", "type": "text", "content": base_prompt.format(**kwargs)})
|
|
682
1008
|
|
|
683
1009
|
return user_messages
|
|
@@ -720,7 +1046,7 @@ class PromptFactory:
|
|
|
720
1046
|
# build set of format kwargs
|
|
721
1047
|
format_kwargs = {
|
|
722
1048
|
field_name: "<bytes>"
|
|
723
|
-
if
|
|
1049
|
+
if candidate.get_field_type(field_name).annotation in [bytes, bytes | None]
|
|
724
1050
|
else candidate[field_name]
|
|
725
1051
|
for field_name in input_fields
|
|
726
1052
|
}
|
|
@@ -740,7 +1066,7 @@ class PromptFactory:
|
|
|
740
1066
|
|
|
741
1067
|
return messages
|
|
742
1068
|
|
|
743
|
-
def create_messages(self, candidate: DataRecord, output_fields: list[str], **kwargs) -> list[dict]:
|
|
1069
|
+
def create_messages(self, candidate: DataRecord, output_fields: list[str], right_candidate: DataRecord | None = None, **kwargs) -> list[dict]:
|
|
744
1070
|
"""
|
|
745
1071
|
Creates the messages for the chat payload based on the prompt strategy.
|
|
746
1072
|
|
|
@@ -754,6 +1080,7 @@ class PromptFactory:
|
|
|
754
1080
|
Args:
|
|
755
1081
|
candidate (DataRecord): The input record.
|
|
756
1082
|
output_fields (list[str]): The output fields.
|
|
1083
|
+
right_candidate (DataRecord | None): The other join input record (only provided for joins).
|
|
757
1084
|
kwargs: The keyword arguments provided by the user.
|
|
758
1085
|
|
|
759
1086
|
Returns:
|
|
@@ -761,6 +1088,7 @@ class PromptFactory:
|
|
|
761
1088
|
"""
|
|
762
1089
|
# compute the set of input fields
|
|
763
1090
|
input_fields = self._get_input_fields(candidate, **kwargs)
|
|
1091
|
+
right_input_fields = [] if right_candidate is None else self._get_input_fields(right_candidate, **kwargs)
|
|
764
1092
|
|
|
765
1093
|
# if the user provides a prompt, we process that prompt into messages and return them
|
|
766
1094
|
if "prompt" in kwargs:
|
|
@@ -774,7 +1102,7 @@ class PromptFactory:
|
|
|
774
1102
|
messages = []
|
|
775
1103
|
|
|
776
1104
|
# compute the full dictionary of format kwargs and add to kwargs
|
|
777
|
-
format_kwargs = self._get_all_format_kwargs(candidate, input_fields, output_fields, **kwargs)
|
|
1105
|
+
format_kwargs = self._get_all_format_kwargs(candidate, input_fields, output_fields, right_candidate, right_input_fields, **kwargs)
|
|
778
1106
|
kwargs = {**kwargs, **format_kwargs}
|
|
779
1107
|
|
|
780
1108
|
# generate system message (if applicable)
|
|
@@ -783,7 +1111,7 @@ class PromptFactory:
|
|
|
783
1111
|
messages.append({"role": "system", "type": "text", "content": system_prompt})
|
|
784
1112
|
|
|
785
1113
|
# generate user messages and add to messages
|
|
786
|
-
user_messages = self._get_user_messages(candidate, input_fields, **kwargs)
|
|
1114
|
+
user_messages = self._get_user_messages(candidate, input_fields, right_candidate, right_input_fields, **kwargs)
|
|
787
1115
|
messages.extend(user_messages)
|
|
788
1116
|
|
|
789
1117
|
return messages
|