palimpzest 0.7.20__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- palimpzest/__init__.py +37 -6
- palimpzest/agents/__init__.py +0 -0
- palimpzest/agents/compute_agents.py +0 -0
- palimpzest/agents/search_agents.py +637 -0
- palimpzest/constants.py +259 -197
- palimpzest/core/data/context.py +393 -0
- palimpzest/core/data/context_manager.py +163 -0
- palimpzest/core/data/dataset.py +634 -0
- palimpzest/core/data/{datareaders.py → iter_dataset.py} +202 -126
- palimpzest/core/elements/groupbysig.py +16 -13
- palimpzest/core/elements/records.py +166 -75
- palimpzest/core/lib/schemas.py +152 -390
- palimpzest/core/{data/dataclasses.py → models.py} +306 -170
- palimpzest/policy.py +2 -27
- palimpzest/prompts/__init__.py +35 -5
- palimpzest/prompts/agent_prompts.py +357 -0
- palimpzest/prompts/context_search.py +9 -0
- palimpzest/prompts/convert_prompts.py +61 -5
- palimpzest/prompts/filter_prompts.py +50 -5
- palimpzest/prompts/join_prompts.py +163 -0
- palimpzest/prompts/moa_proposer_convert_prompts.py +5 -5
- palimpzest/prompts/prompt_factory.py +358 -46
- palimpzest/prompts/validator.py +239 -0
- palimpzest/query/execution/all_sample_execution_strategy.py +134 -76
- palimpzest/query/execution/execution_strategy.py +210 -317
- palimpzest/query/execution/execution_strategy_type.py +5 -7
- palimpzest/query/execution/mab_execution_strategy.py +249 -136
- palimpzest/query/execution/parallel_execution_strategy.py +153 -244
- palimpzest/query/execution/single_threaded_execution_strategy.py +107 -64
- palimpzest/query/generators/generators.py +157 -330
- palimpzest/query/operators/__init__.py +15 -5
- palimpzest/query/operators/aggregate.py +50 -33
- palimpzest/query/operators/compute.py +201 -0
- palimpzest/query/operators/convert.py +27 -21
- palimpzest/query/operators/critique_and_refine_convert.py +7 -5
- palimpzest/query/operators/distinct.py +62 -0
- palimpzest/query/operators/filter.py +22 -13
- palimpzest/query/operators/join.py +402 -0
- palimpzest/query/operators/limit.py +3 -3
- palimpzest/query/operators/logical.py +198 -80
- palimpzest/query/operators/mixture_of_agents_convert.py +10 -8
- palimpzest/query/operators/physical.py +27 -21
- palimpzest/query/operators/project.py +3 -3
- palimpzest/query/operators/rag_convert.py +7 -7
- palimpzest/query/operators/retrieve.py +9 -9
- palimpzest/query/operators/scan.py +81 -42
- palimpzest/query/operators/search.py +524 -0
- palimpzest/query/operators/split_convert.py +10 -8
- palimpzest/query/optimizer/__init__.py +7 -9
- palimpzest/query/optimizer/cost_model.py +108 -441
- palimpzest/query/optimizer/optimizer.py +123 -181
- palimpzest/query/optimizer/optimizer_strategy.py +66 -61
- palimpzest/query/optimizer/plan.py +352 -67
- palimpzest/query/optimizer/primitives.py +43 -19
- palimpzest/query/optimizer/rules.py +484 -646
- palimpzest/query/optimizer/tasks.py +127 -58
- palimpzest/query/processor/config.py +41 -76
- palimpzest/query/processor/query_processor.py +73 -18
- palimpzest/query/processor/query_processor_factory.py +46 -38
- palimpzest/schemabuilder/schema_builder.py +15 -28
- palimpzest/utils/model_helpers.py +27 -77
- palimpzest/utils/progress.py +114 -102
- palimpzest/validator/__init__.py +0 -0
- palimpzest/validator/validator.py +306 -0
- {palimpzest-0.7.20.dist-info → palimpzest-0.8.0.dist-info}/METADATA +6 -1
- palimpzest-0.8.0.dist-info/RECORD +95 -0
- palimpzest/core/lib/fields.py +0 -141
- palimpzest/prompts/code_synthesis_prompts.py +0 -28
- palimpzest/query/execution/random_sampling_execution_strategy.py +0 -240
- palimpzest/query/generators/api_client_factory.py +0 -30
- palimpzest/query/operators/code_synthesis_convert.py +0 -488
- palimpzest/query/operators/map.py +0 -130
- palimpzest/query/processor/nosentinel_processor.py +0 -33
- palimpzest/query/processor/processing_strategy_type.py +0 -28
- palimpzest/query/processor/sentinel_processor.py +0 -88
- palimpzest/query/processor/streaming_processor.py +0 -149
- palimpzest/sets.py +0 -405
- palimpzest/utils/datareader_helpers.py +0 -61
- palimpzest/utils/demo_helpers.py +0 -75
- palimpzest/utils/field_helpers.py +0 -69
- palimpzest/utils/generation_helpers.py +0 -69
- palimpzest/utils/sandbox.py +0 -183
- palimpzest-0.7.20.dist-info/RECORD +0 -95
- /palimpzest/core/{elements/index.py → data/index_dataset.py} +0 -0
- {palimpzest-0.7.20.dist-info → palimpzest-0.8.0.dist-info}/WHEEL +0 -0
- {palimpzest-0.7.20.dist-info → palimpzest-0.8.0.dist-info}/licenses/LICENSE +0 -0
- {palimpzest-0.7.20.dist-info → palimpzest-0.8.0.dist-info}/top_level.txt +0 -0
|
@@ -4,17 +4,25 @@ import base64
|
|
|
4
4
|
import json
|
|
5
5
|
from string import Formatter
|
|
6
6
|
|
|
7
|
+
from pydantic import BaseModel
|
|
8
|
+
|
|
7
9
|
from palimpzest.constants import (
|
|
8
|
-
|
|
10
|
+
LLAMA_CONTEXT_TOKENS_LIMIT,
|
|
9
11
|
TOKENS_PER_CHARACTER,
|
|
10
12
|
Cardinality,
|
|
11
13
|
Model,
|
|
12
14
|
PromptStrategy,
|
|
13
15
|
)
|
|
14
16
|
from palimpzest.core.elements.records import DataRecord
|
|
15
|
-
from palimpzest.core.lib.
|
|
16
|
-
from palimpzest.core.lib.schemas import Schema
|
|
17
|
+
from palimpzest.core.lib.schemas import AudioBase64, AudioFilepath, ImageBase64, ImageFilepath, ImageURL
|
|
17
18
|
from palimpzest.prompts.convert_prompts import (
|
|
19
|
+
COT_QA_AUDIO_DISCLAIMER,
|
|
20
|
+
COT_QA_AUDIO_EXAMPLE_ANSWER,
|
|
21
|
+
COT_QA_AUDIO_EXAMPLE_CONTEXT,
|
|
22
|
+
COT_QA_AUDIO_EXAMPLE_INPUT_FIELDS,
|
|
23
|
+
COT_QA_AUDIO_EXAMPLE_OUTPUT_FIELDS,
|
|
24
|
+
COT_QA_AUDIO_EXAMPLE_REASONING,
|
|
25
|
+
COT_QA_AUDIO_JOB_INSTRUCTION,
|
|
18
26
|
COT_QA_BASE_SYSTEM_PROMPT,
|
|
19
27
|
COT_QA_BASE_USER_PROMPT,
|
|
20
28
|
COT_QA_EXAMPLE_ANSWER,
|
|
@@ -30,6 +38,8 @@ from palimpzest.prompts.convert_prompts import (
|
|
|
30
38
|
COT_QA_IMAGE_EXAMPLE_REASONING,
|
|
31
39
|
COT_QA_IMAGE_JOB_INSTRUCTION,
|
|
32
40
|
COT_QA_JOB_INSTRUCTION,
|
|
41
|
+
COT_QA_NO_REASONING_BASE_SYSTEM_PROMPT,
|
|
42
|
+
COT_QA_NO_REASONING_BASE_USER_PROMPT,
|
|
33
43
|
)
|
|
34
44
|
from palimpzest.prompts.critique_and_refine_convert_prompts import (
|
|
35
45
|
BASE_CRITIQUE_PROMPT,
|
|
@@ -42,6 +52,12 @@ from palimpzest.prompts.critique_and_refine_convert_prompts import (
|
|
|
42
52
|
COT_QA_REFINEMENT_FINISH_INSTRUCTION,
|
|
43
53
|
)
|
|
44
54
|
from palimpzest.prompts.filter_prompts import (
|
|
55
|
+
COT_BOOL_AUDIO_DISCLAIMER,
|
|
56
|
+
COT_BOOL_AUDIO_EXAMPLE_CONTEXT,
|
|
57
|
+
COT_BOOL_AUDIO_EXAMPLE_FILTER_CONDITION,
|
|
58
|
+
COT_BOOL_AUDIO_EXAMPLE_INPUT_FIELDS,
|
|
59
|
+
COT_BOOL_AUDIO_EXAMPLE_REASONING,
|
|
60
|
+
COT_BOOL_AUDIO_JOB_INSTRUCTION,
|
|
45
61
|
COT_BOOL_BASE_SYSTEM_PROMPT,
|
|
46
62
|
COT_BOOL_BASE_USER_PROMPT,
|
|
47
63
|
COT_BOOL_EXAMPLE_CONTEXT,
|
|
@@ -55,6 +71,39 @@ from palimpzest.prompts.filter_prompts import (
|
|
|
55
71
|
COT_BOOL_IMAGE_EXAMPLE_REASONING,
|
|
56
72
|
COT_BOOL_IMAGE_JOB_INSTRUCTION,
|
|
57
73
|
COT_BOOL_JOB_INSTRUCTION,
|
|
74
|
+
COT_BOOL_NO_REASONING_BASE_SYSTEM_PROMPT,
|
|
75
|
+
COT_BOOL_NO_REASONING_BASE_USER_PROMPT,
|
|
76
|
+
)
|
|
77
|
+
from palimpzest.prompts.join_prompts import (
|
|
78
|
+
COT_JOIN_AUDIO_DISCLAIMER,
|
|
79
|
+
COT_JOIN_AUDIO_EXAMPLE_CONTEXT,
|
|
80
|
+
COT_JOIN_AUDIO_EXAMPLE_INPUT_FIELDS,
|
|
81
|
+
COT_JOIN_AUDIO_EXAMPLE_JOIN_CONDITION,
|
|
82
|
+
COT_JOIN_AUDIO_EXAMPLE_REASONING,
|
|
83
|
+
COT_JOIN_AUDIO_JOB_INSTRUCTION,
|
|
84
|
+
COT_JOIN_AUDIO_RIGHT_EXAMPLE_CONTEXT,
|
|
85
|
+
COT_JOIN_AUDIO_RIGHT_EXAMPLE_INPUT_FIELDS,
|
|
86
|
+
COT_JOIN_BASE_SYSTEM_PROMPT,
|
|
87
|
+
COT_JOIN_BASE_USER_PROMPT,
|
|
88
|
+
COT_JOIN_EXAMPLE_CONTEXT,
|
|
89
|
+
COT_JOIN_EXAMPLE_INPUT_FIELDS,
|
|
90
|
+
COT_JOIN_EXAMPLE_JOIN_CONDITION,
|
|
91
|
+
COT_JOIN_EXAMPLE_REASONING,
|
|
92
|
+
COT_JOIN_IMAGE_DISCLAIMER,
|
|
93
|
+
COT_JOIN_IMAGE_EXAMPLE_CONTEXT,
|
|
94
|
+
COT_JOIN_IMAGE_EXAMPLE_INPUT_FIELDS,
|
|
95
|
+
COT_JOIN_IMAGE_EXAMPLE_JOIN_CONDITION,
|
|
96
|
+
COT_JOIN_IMAGE_EXAMPLE_REASONING,
|
|
97
|
+
COT_JOIN_IMAGE_JOB_INSTRUCTION,
|
|
98
|
+
COT_JOIN_IMAGE_RIGHT_EXAMPLE_CONTEXT,
|
|
99
|
+
COT_JOIN_IMAGE_RIGHT_EXAMPLE_INPUT_FIELDS,
|
|
100
|
+
COT_JOIN_JOB_INSTRUCTION,
|
|
101
|
+
COT_JOIN_NO_REASONING_BASE_SYSTEM_PROMPT,
|
|
102
|
+
COT_JOIN_NO_REASONING_BASE_USER_PROMPT,
|
|
103
|
+
COT_JOIN_RIGHT_AUDIO_DISCLAIMER,
|
|
104
|
+
COT_JOIN_RIGHT_EXAMPLE_CONTEXT,
|
|
105
|
+
COT_JOIN_RIGHT_EXAMPLE_INPUT_FIELDS,
|
|
106
|
+
COT_JOIN_RIGHT_IMAGE_DISCLAIMER,
|
|
58
107
|
)
|
|
59
108
|
from palimpzest.prompts.moa_aggregator_convert_prompts import (
|
|
60
109
|
COT_MOA_AGG_BASE_SYSTEM_PROMPT,
|
|
@@ -99,11 +148,25 @@ class PromptFactory:
|
|
|
99
148
|
|
|
100
149
|
BASE_SYSTEM_PROMPT_MAP = {
|
|
101
150
|
PromptStrategy.COT_BOOL: COT_BOOL_BASE_SYSTEM_PROMPT,
|
|
151
|
+
PromptStrategy.COT_BOOL_NO_REASONING: COT_BOOL_NO_REASONING_BASE_SYSTEM_PROMPT,
|
|
152
|
+
PromptStrategy.COT_BOOL_AUDIO: COT_BOOL_BASE_SYSTEM_PROMPT,
|
|
153
|
+
PromptStrategy.COT_BOOL_AUDIO_NO_REASONING: COT_BOOL_NO_REASONING_BASE_SYSTEM_PROMPT,
|
|
102
154
|
PromptStrategy.COT_BOOL_IMAGE: COT_BOOL_BASE_SYSTEM_PROMPT,
|
|
155
|
+
PromptStrategy.COT_BOOL_IMAGE_NO_REASONING: COT_BOOL_NO_REASONING_BASE_SYSTEM_PROMPT,
|
|
156
|
+
PromptStrategy.COT_JOIN: COT_JOIN_BASE_SYSTEM_PROMPT,
|
|
157
|
+
PromptStrategy.COT_JOIN_NO_REASONING: COT_JOIN_NO_REASONING_BASE_SYSTEM_PROMPT,
|
|
158
|
+
PromptStrategy.COT_JOIN_AUDIO: COT_JOIN_BASE_SYSTEM_PROMPT,
|
|
159
|
+
PromptStrategy.COT_JOIN_AUDIO_NO_REASONING: COT_JOIN_NO_REASONING_BASE_SYSTEM_PROMPT,
|
|
160
|
+
PromptStrategy.COT_JOIN_IMAGE: COT_JOIN_BASE_SYSTEM_PROMPT,
|
|
161
|
+
PromptStrategy.COT_JOIN_IMAGE_NO_REASONING: COT_JOIN_NO_REASONING_BASE_SYSTEM_PROMPT,
|
|
103
162
|
PromptStrategy.COT_QA: COT_QA_BASE_SYSTEM_PROMPT,
|
|
163
|
+
PromptStrategy.COT_QA_NO_REASONING: COT_QA_NO_REASONING_BASE_SYSTEM_PROMPT,
|
|
164
|
+
PromptStrategy.COT_QA_AUDIO: COT_QA_BASE_SYSTEM_PROMPT,
|
|
165
|
+
PromptStrategy.COT_QA_AUDIO_NO_REASONING: COT_QA_NO_REASONING_BASE_SYSTEM_PROMPT,
|
|
104
166
|
PromptStrategy.COT_QA_CRITIC: None,
|
|
105
167
|
PromptStrategy.COT_QA_REFINE: None,
|
|
106
168
|
PromptStrategy.COT_QA_IMAGE: COT_QA_BASE_SYSTEM_PROMPT,
|
|
169
|
+
PromptStrategy.COT_QA_IMAGE_NO_REASONING: COT_QA_NO_REASONING_BASE_SYSTEM_PROMPT,
|
|
107
170
|
PromptStrategy.COT_QA_IMAGE_CRITIC: None,
|
|
108
171
|
PromptStrategy.COT_QA_IMAGE_REFINE: None,
|
|
109
172
|
PromptStrategy.COT_MOA_PROPOSER: COT_MOA_PROPOSER_BASE_SYSTEM_PROMPT,
|
|
@@ -114,11 +177,25 @@ class PromptFactory:
|
|
|
114
177
|
}
|
|
115
178
|
BASE_USER_PROMPT_MAP = {
|
|
116
179
|
PromptStrategy.COT_BOOL: COT_BOOL_BASE_USER_PROMPT,
|
|
180
|
+
PromptStrategy.COT_BOOL_NO_REASONING: COT_BOOL_NO_REASONING_BASE_USER_PROMPT,
|
|
181
|
+
PromptStrategy.COT_BOOL_AUDIO: COT_BOOL_BASE_USER_PROMPT,
|
|
182
|
+
PromptStrategy.COT_BOOL_AUDIO_NO_REASONING: COT_BOOL_NO_REASONING_BASE_USER_PROMPT,
|
|
117
183
|
PromptStrategy.COT_BOOL_IMAGE: COT_BOOL_BASE_USER_PROMPT,
|
|
184
|
+
PromptStrategy.COT_BOOL_IMAGE_NO_REASONING: COT_BOOL_NO_REASONING_BASE_USER_PROMPT,
|
|
185
|
+
PromptStrategy.COT_JOIN: COT_JOIN_BASE_USER_PROMPT,
|
|
186
|
+
PromptStrategy.COT_JOIN_NO_REASONING: COT_JOIN_NO_REASONING_BASE_USER_PROMPT,
|
|
187
|
+
PromptStrategy.COT_JOIN_AUDIO: COT_JOIN_BASE_USER_PROMPT,
|
|
188
|
+
PromptStrategy.COT_JOIN_AUDIO_NO_REASONING: COT_JOIN_NO_REASONING_BASE_USER_PROMPT,
|
|
189
|
+
PromptStrategy.COT_JOIN_IMAGE: COT_JOIN_BASE_USER_PROMPT,
|
|
190
|
+
PromptStrategy.COT_JOIN_IMAGE_NO_REASONING: COT_JOIN_NO_REASONING_BASE_USER_PROMPT,
|
|
118
191
|
PromptStrategy.COT_QA: COT_QA_BASE_USER_PROMPT,
|
|
192
|
+
PromptStrategy.COT_QA_NO_REASONING: COT_QA_NO_REASONING_BASE_USER_PROMPT,
|
|
193
|
+
PromptStrategy.COT_QA_AUDIO: COT_QA_BASE_USER_PROMPT,
|
|
194
|
+
PromptStrategy.COT_QA_AUDIO_NO_REASONING: COT_QA_NO_REASONING_BASE_USER_PROMPT,
|
|
119
195
|
PromptStrategy.COT_QA_CRITIC: BASE_CRITIQUE_PROMPT,
|
|
120
196
|
PromptStrategy.COT_QA_REFINE: BASE_REFINEMENT_PROMPT,
|
|
121
197
|
PromptStrategy.COT_QA_IMAGE: COT_QA_BASE_USER_PROMPT,
|
|
198
|
+
PromptStrategy.COT_QA_IMAGE_NO_REASONING: COT_QA_NO_REASONING_BASE_USER_PROMPT,
|
|
122
199
|
PromptStrategy.COT_QA_IMAGE_CRITIC: BASE_CRITIQUE_PROMPT,
|
|
123
200
|
PromptStrategy.COT_QA_IMAGE_REFINE: BASE_REFINEMENT_PROMPT,
|
|
124
201
|
PromptStrategy.COT_MOA_PROPOSER: COT_MOA_PROPOSER_BASE_USER_PROMPT,
|
|
@@ -144,8 +221,9 @@ class PromptFactory:
|
|
|
144
221
|
Returns:
|
|
145
222
|
str: The context.
|
|
146
223
|
"""
|
|
224
|
+
# TODO: remove mask_filepaths=True after SemBench evaluation
|
|
147
225
|
# get context from input record (project_cols will be None if not provided in kwargs)
|
|
148
|
-
context: dict = candidate.to_dict(include_bytes=False, project_cols=input_fields)
|
|
226
|
+
context: dict = candidate.to_dict(include_bytes=False, project_cols=input_fields, mask_filepaths=True)
|
|
149
227
|
|
|
150
228
|
# TODO: MOVE THIS LOGIC INTO A CHUNKING / CONTEXT MANAGEMENT CLASS
|
|
151
229
|
# - this class should be able to:
|
|
@@ -155,12 +233,12 @@ class PromptFactory:
|
|
|
155
233
|
# TODO: this does not work for image prompts
|
|
156
234
|
# TODO: this ignores the size of the `original_messages` in critique and refine prompts
|
|
157
235
|
# cut down on context based on window length
|
|
158
|
-
if self.model.is_llama_model()
|
|
236
|
+
if self.model.is_llama_model():
|
|
159
237
|
total_context_len = len(json.dumps(context, indent=2))
|
|
160
238
|
|
|
161
239
|
# sort fields by length and progressively strip from the longest field until it is short enough;
|
|
162
|
-
# NOTE:
|
|
163
|
-
while total_context_len * TOKENS_PER_CHARACTER >
|
|
240
|
+
# NOTE: LLAMA_CONTEXT_TOKENS_LIMIT is a rough estimate which leaves room for the rest of the prompt text
|
|
241
|
+
while total_context_len * TOKENS_PER_CHARACTER > LLAMA_CONTEXT_TOKENS_LIMIT:
|
|
164
242
|
# sort fields by length
|
|
165
243
|
field_lengths = [(field, len(value) if value is not None else 0) for field, value in context.items()]
|
|
166
244
|
sorted_fields = sorted(field_lengths, key=lambda item: item[1], reverse=True)
|
|
@@ -169,7 +247,7 @@ class PromptFactory:
|
|
|
169
247
|
longest_field_name, longest_field_length = sorted_fields[0]
|
|
170
248
|
|
|
171
249
|
# trim the field
|
|
172
|
-
context_factor =
|
|
250
|
+
context_factor = LLAMA_CONTEXT_TOKENS_LIMIT / (total_context_len * TOKENS_PER_CHARACTER)
|
|
173
251
|
keep_frac_idx = int(longest_field_length * context_factor)
|
|
174
252
|
context[longest_field_name] = context[longest_field_name][:keep_frac_idx]
|
|
175
253
|
|
|
@@ -191,7 +269,11 @@ class PromptFactory:
|
|
|
191
269
|
Returns:
|
|
192
270
|
list[str]: The list of input field names.
|
|
193
271
|
"""
|
|
194
|
-
|
|
272
|
+
# NOTE: joins will include left and right input fields in project_cols, so we have to check
|
|
273
|
+
# if the field is in the candidate record
|
|
274
|
+
input_fields = kwargs.get("project_cols", candidate.get_field_names())
|
|
275
|
+
input_fields = [field for field in input_fields if field in candidate.get_field_names()]
|
|
276
|
+
return input_fields
|
|
195
277
|
|
|
196
278
|
def _get_input_fields_desc(self, candidate: DataRecord, input_fields: list[str]) -> str:
|
|
197
279
|
"""
|
|
@@ -205,7 +287,7 @@ class PromptFactory:
|
|
|
205
287
|
"""
|
|
206
288
|
input_fields_desc = ""
|
|
207
289
|
for field_name in input_fields:
|
|
208
|
-
input_fields_desc += f"- {field_name}: {candidate.get_field_type(field_name).
|
|
290
|
+
input_fields_desc += f"- {field_name}: {candidate.get_field_type(field_name).description}\n"
|
|
209
291
|
|
|
210
292
|
return input_fields_desc[:-1]
|
|
211
293
|
|
|
@@ -221,13 +303,13 @@ class PromptFactory:
|
|
|
221
303
|
str: The output fields description.
|
|
222
304
|
"""
|
|
223
305
|
output_fields_desc = ""
|
|
224
|
-
output_schema:
|
|
306
|
+
output_schema: BaseModel = kwargs.get("output_schema")
|
|
225
307
|
if self.prompt_strategy.is_convert_prompt():
|
|
226
308
|
assert output_schema is not None, "Output schema must be provided for convert prompts."
|
|
227
309
|
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
output_fields_desc += f"- {field_name}: {
|
|
310
|
+
for field_name in sorted(output_fields):
|
|
311
|
+
desc = output_schema.model_fields[field_name].description
|
|
312
|
+
output_fields_desc += f"- {field_name}: {'no description available' if desc is None else desc}\n"
|
|
231
313
|
|
|
232
314
|
# strip the last newline characters from the field descriptions and return
|
|
233
315
|
return output_fields_desc[:-1]
|
|
@@ -245,6 +327,19 @@ class PromptFactory:
|
|
|
245
327
|
|
|
246
328
|
return filter_condition
|
|
247
329
|
|
|
330
|
+
def _get_join_condition(self, **kwargs) -> str | None:
|
|
331
|
+
"""
|
|
332
|
+
Returns the join condition for the join operation.
|
|
333
|
+
|
|
334
|
+
Returns:
|
|
335
|
+
str | None: The join condition (if applicable).
|
|
336
|
+
"""
|
|
337
|
+
join_condition = kwargs.get("join_condition")
|
|
338
|
+
if self.prompt_strategy.is_join_prompt():
|
|
339
|
+
assert join_condition is not None, "Join condition must be provided for join operations."
|
|
340
|
+
|
|
341
|
+
return join_condition
|
|
342
|
+
|
|
248
343
|
def _get_original_output(self, **kwargs) -> str | None:
|
|
249
344
|
"""
|
|
250
345
|
Returns the original output from a previous model generation for the critique and refinement operations.
|
|
@@ -337,8 +432,13 @@ class PromptFactory:
|
|
|
337
432
|
"""
|
|
338
433
|
prompt_strategy_to_job_instruction = {
|
|
339
434
|
PromptStrategy.COT_BOOL: COT_BOOL_JOB_INSTRUCTION,
|
|
435
|
+
PromptStrategy.COT_BOOL_AUDIO: COT_BOOL_AUDIO_JOB_INSTRUCTION,
|
|
340
436
|
PromptStrategy.COT_BOOL_IMAGE: COT_BOOL_IMAGE_JOB_INSTRUCTION,
|
|
437
|
+
PromptStrategy.COT_JOIN: COT_JOIN_JOB_INSTRUCTION,
|
|
438
|
+
PromptStrategy.COT_JOIN_AUDIO: COT_JOIN_AUDIO_JOB_INSTRUCTION,
|
|
439
|
+
PromptStrategy.COT_JOIN_IMAGE: COT_JOIN_IMAGE_JOB_INSTRUCTION,
|
|
341
440
|
PromptStrategy.COT_QA: COT_QA_JOB_INSTRUCTION,
|
|
441
|
+
PromptStrategy.COT_QA_AUDIO: COT_QA_AUDIO_JOB_INSTRUCTION,
|
|
342
442
|
PromptStrategy.COT_QA_IMAGE: COT_QA_IMAGE_JOB_INSTRUCTION,
|
|
343
443
|
PromptStrategy.COT_MOA_PROPOSER: COT_MOA_PROPOSER_JOB_INSTRUCTION,
|
|
344
444
|
PromptStrategy.COT_MOA_PROPOSER_IMAGE: COT_MOA_PROPOSER_IMAGE_JOB_INSTRUCTION,
|
|
@@ -402,8 +502,13 @@ class PromptFactory:
|
|
|
402
502
|
"""
|
|
403
503
|
prompt_strategy_to_example_input_fields = {
|
|
404
504
|
PromptStrategy.COT_BOOL: COT_BOOL_EXAMPLE_INPUT_FIELDS,
|
|
505
|
+
PromptStrategy.COT_BOOL_AUDIO: COT_BOOL_AUDIO_EXAMPLE_INPUT_FIELDS,
|
|
405
506
|
PromptStrategy.COT_BOOL_IMAGE: COT_BOOL_IMAGE_EXAMPLE_INPUT_FIELDS,
|
|
507
|
+
PromptStrategy.COT_JOIN: COT_JOIN_EXAMPLE_INPUT_FIELDS,
|
|
508
|
+
PromptStrategy.COT_JOIN_AUDIO: COT_JOIN_AUDIO_EXAMPLE_INPUT_FIELDS,
|
|
509
|
+
PromptStrategy.COT_JOIN_IMAGE: COT_JOIN_IMAGE_EXAMPLE_INPUT_FIELDS,
|
|
406
510
|
PromptStrategy.COT_QA: COT_QA_EXAMPLE_INPUT_FIELDS,
|
|
511
|
+
PromptStrategy.COT_QA_AUDIO: COT_QA_AUDIO_EXAMPLE_INPUT_FIELDS,
|
|
407
512
|
PromptStrategy.COT_QA_IMAGE: COT_QA_IMAGE_EXAMPLE_INPUT_FIELDS,
|
|
408
513
|
PromptStrategy.COT_MOA_PROPOSER: COT_MOA_PROPOSER_EXAMPLE_INPUT_FIELDS,
|
|
409
514
|
PromptStrategy.COT_MOA_PROPOSER_IMAGE: COT_MOA_PROPOSER_IMAGE_EXAMPLE_INPUT_FIELDS,
|
|
@@ -412,6 +517,21 @@ class PromptFactory:
|
|
|
412
517
|
|
|
413
518
|
return prompt_strategy_to_example_input_fields.get(self.prompt_strategy)
|
|
414
519
|
|
|
520
|
+
def _get_right_example_input_fields(self) -> str | None:
|
|
521
|
+
"""
|
|
522
|
+
Returns the example right input fields for the join prompt.
|
|
523
|
+
|
|
524
|
+
Returns:
|
|
525
|
+
str | None: The example right input fields (if applicable).
|
|
526
|
+
"""
|
|
527
|
+
prompt_strategy_to_right_example_input_fields = {
|
|
528
|
+
PromptStrategy.COT_JOIN: COT_JOIN_RIGHT_EXAMPLE_INPUT_FIELDS,
|
|
529
|
+
PromptStrategy.COT_JOIN_AUDIO: COT_JOIN_AUDIO_RIGHT_EXAMPLE_INPUT_FIELDS,
|
|
530
|
+
PromptStrategy.COT_JOIN_IMAGE: COT_JOIN_IMAGE_RIGHT_EXAMPLE_INPUT_FIELDS,
|
|
531
|
+
}
|
|
532
|
+
|
|
533
|
+
return prompt_strategy_to_right_example_input_fields.get(self.prompt_strategy)
|
|
534
|
+
|
|
415
535
|
def _get_example_output_fields(self) -> str | None:
|
|
416
536
|
"""
|
|
417
537
|
Returns the example output fields for the prompt.
|
|
@@ -421,6 +541,7 @@ class PromptFactory:
|
|
|
421
541
|
"""
|
|
422
542
|
prompt_strategy_to_example_output_fields = {
|
|
423
543
|
PromptStrategy.COT_QA: COT_QA_EXAMPLE_OUTPUT_FIELDS,
|
|
544
|
+
PromptStrategy.COT_QA_AUDIO: COT_QA_AUDIO_EXAMPLE_OUTPUT_FIELDS,
|
|
424
545
|
PromptStrategy.COT_QA_IMAGE: COT_QA_IMAGE_EXAMPLE_OUTPUT_FIELDS,
|
|
425
546
|
PromptStrategy.COT_MOA_PROPOSER: COT_MOA_PROPOSER_EXAMPLE_OUTPUT_FIELDS,
|
|
426
547
|
PromptStrategy.COT_MOA_PROPOSER_IMAGE: COT_MOA_PROPOSER_IMAGE_EXAMPLE_OUTPUT_FIELDS,
|
|
@@ -438,8 +559,13 @@ class PromptFactory:
|
|
|
438
559
|
"""
|
|
439
560
|
prompt_strategy_to_example_context = {
|
|
440
561
|
PromptStrategy.COT_BOOL: COT_BOOL_EXAMPLE_CONTEXT,
|
|
562
|
+
PromptStrategy.COT_BOOL_AUDIO: COT_BOOL_AUDIO_EXAMPLE_CONTEXT,
|
|
441
563
|
PromptStrategy.COT_BOOL_IMAGE: COT_BOOL_IMAGE_EXAMPLE_CONTEXT,
|
|
564
|
+
PromptStrategy.COT_JOIN: COT_JOIN_EXAMPLE_CONTEXT,
|
|
565
|
+
PromptStrategy.COT_JOIN_AUDIO: COT_JOIN_AUDIO_EXAMPLE_CONTEXT,
|
|
566
|
+
PromptStrategy.COT_JOIN_IMAGE: COT_JOIN_IMAGE_EXAMPLE_CONTEXT,
|
|
442
567
|
PromptStrategy.COT_QA: COT_QA_EXAMPLE_CONTEXT,
|
|
568
|
+
PromptStrategy.COT_QA_AUDIO: COT_QA_AUDIO_EXAMPLE_CONTEXT,
|
|
443
569
|
PromptStrategy.COT_QA_IMAGE: COT_QA_IMAGE_EXAMPLE_CONTEXT,
|
|
444
570
|
PromptStrategy.COT_MOA_PROPOSER: COT_MOA_PROPOSER_EXAMPLE_CONTEXT,
|
|
445
571
|
PromptStrategy.COT_MOA_PROPOSER_IMAGE: COT_MOA_PROPOSER_IMAGE_EXAMPLE_CONTEXT,
|
|
@@ -448,6 +574,21 @@ class PromptFactory:
|
|
|
448
574
|
|
|
449
575
|
return prompt_strategy_to_example_context.get(self.prompt_strategy)
|
|
450
576
|
|
|
577
|
+
def _get_right_example_context(self) -> str | None:
|
|
578
|
+
"""
|
|
579
|
+
Returns the right example context for the join prompt.
|
|
580
|
+
|
|
581
|
+
Returns:
|
|
582
|
+
str | None: The right example context (if applicable).
|
|
583
|
+
"""
|
|
584
|
+
prompt_strategy_to_right_example_context = {
|
|
585
|
+
PromptStrategy.COT_JOIN: COT_JOIN_RIGHT_EXAMPLE_CONTEXT,
|
|
586
|
+
PromptStrategy.COT_JOIN_AUDIO: COT_JOIN_AUDIO_RIGHT_EXAMPLE_CONTEXT,
|
|
587
|
+
PromptStrategy.COT_JOIN_IMAGE: COT_JOIN_IMAGE_RIGHT_EXAMPLE_CONTEXT,
|
|
588
|
+
}
|
|
589
|
+
|
|
590
|
+
return prompt_strategy_to_right_example_context.get(self.prompt_strategy)
|
|
591
|
+
|
|
451
592
|
def _get_image_disclaimer(self) -> str:
|
|
452
593
|
"""
|
|
453
594
|
Returns the image disclaimer for the prompt. The disclaimer must be an empty string
|
|
@@ -458,12 +599,57 @@ class PromptFactory:
|
|
|
458
599
|
"""
|
|
459
600
|
prompt_strategy_to_image_disclaimer = {
|
|
460
601
|
PromptStrategy.COT_BOOL_IMAGE: COT_BOOL_IMAGE_DISCLAIMER,
|
|
602
|
+
PromptStrategy.COT_JOIN_IMAGE: COT_JOIN_IMAGE_DISCLAIMER,
|
|
461
603
|
PromptStrategy.COT_QA_IMAGE: COT_QA_IMAGE_DISCLAIMER,
|
|
462
604
|
PromptStrategy.COT_MOA_PROPOSER_IMAGE: COT_MOA_PROPOSER_IMAGE_DISCLAIMER,
|
|
463
605
|
}
|
|
464
606
|
|
|
465
607
|
return prompt_strategy_to_image_disclaimer.get(self.prompt_strategy, "")
|
|
466
608
|
|
|
609
|
+
def _get_audio_disclaimer(self) -> str:
|
|
610
|
+
"""
|
|
611
|
+
Returns the audio disclaimer for the prompt. The disclaimer must be an empty string
|
|
612
|
+
for text prompts.
|
|
613
|
+
|
|
614
|
+
Returns:
|
|
615
|
+
str: The audio disclaimer. If this is a text prompt then it is an empty string.
|
|
616
|
+
"""
|
|
617
|
+
prompt_strategy_to_audio_disclaimer = {
|
|
618
|
+
PromptStrategy.COT_BOOL_AUDIO: COT_BOOL_AUDIO_DISCLAIMER,
|
|
619
|
+
PromptStrategy.COT_JOIN_AUDIO: COT_JOIN_AUDIO_DISCLAIMER,
|
|
620
|
+
PromptStrategy.COT_QA_AUDIO: COT_QA_AUDIO_DISCLAIMER,
|
|
621
|
+
}
|
|
622
|
+
|
|
623
|
+
return prompt_strategy_to_audio_disclaimer.get(self.prompt_strategy, "")
|
|
624
|
+
|
|
625
|
+
def _get_right_image_disclaimer(self) -> str:
|
|
626
|
+
"""
|
|
627
|
+
Returns the right image disclaimer for the prompt. The disclaimer must be an empty string
|
|
628
|
+
for text prompts.
|
|
629
|
+
|
|
630
|
+
Returns:
|
|
631
|
+
str: The right image disclaimer. If this is a text prompt then it is an empty string.
|
|
632
|
+
"""
|
|
633
|
+
prompt_strategy_to_image_disclaimer = {
|
|
634
|
+
PromptStrategy.COT_JOIN_IMAGE: COT_JOIN_RIGHT_IMAGE_DISCLAIMER,
|
|
635
|
+
}
|
|
636
|
+
|
|
637
|
+
return prompt_strategy_to_image_disclaimer.get(self.prompt_strategy, "")
|
|
638
|
+
|
|
639
|
+
def _get_right_audio_disclaimer(self) -> str:
|
|
640
|
+
"""
|
|
641
|
+
Returns the right audio disclaimer for the prompt. The disclaimer must be an empty string
|
|
642
|
+
for text prompts.
|
|
643
|
+
|
|
644
|
+
Returns:
|
|
645
|
+
str: The right audio disclaimer. If this is a text prompt then it is an empty string.
|
|
646
|
+
"""
|
|
647
|
+
prompt_strategy_to_audio_disclaimer = {
|
|
648
|
+
PromptStrategy.COT_JOIN_AUDIO: COT_JOIN_RIGHT_AUDIO_DISCLAIMER,
|
|
649
|
+
}
|
|
650
|
+
|
|
651
|
+
return prompt_strategy_to_audio_disclaimer.get(self.prompt_strategy, "")
|
|
652
|
+
|
|
467
653
|
def _get_example_filter_condition(self) -> str | None:
|
|
468
654
|
"""
|
|
469
655
|
Returns the example filter condition for the prompt.
|
|
@@ -473,11 +659,27 @@ class PromptFactory:
|
|
|
473
659
|
"""
|
|
474
660
|
prompt_strategy_to_example_filter_condition = {
|
|
475
661
|
PromptStrategy.COT_BOOL: COT_BOOL_EXAMPLE_FILTER_CONDITION,
|
|
662
|
+
PromptStrategy.COT_BOOL_AUDIO: COT_BOOL_AUDIO_EXAMPLE_FILTER_CONDITION,
|
|
476
663
|
PromptStrategy.COT_BOOL_IMAGE: COT_BOOL_IMAGE_EXAMPLE_FILTER_CONDITION,
|
|
477
664
|
}
|
|
478
665
|
|
|
479
666
|
return prompt_strategy_to_example_filter_condition.get(self.prompt_strategy)
|
|
480
667
|
|
|
668
|
+
def _get_example_join_condition(self) -> str | None:
|
|
669
|
+
"""
|
|
670
|
+
Returns the example join condition for the prompt.
|
|
671
|
+
|
|
672
|
+
Returns:
|
|
673
|
+
str | None: The example join condition (if applicable).
|
|
674
|
+
"""
|
|
675
|
+
prompt_strategy_to_example_join_condition = {
|
|
676
|
+
PromptStrategy.COT_JOIN: COT_JOIN_EXAMPLE_JOIN_CONDITION,
|
|
677
|
+
PromptStrategy.COT_JOIN_AUDIO: COT_JOIN_AUDIO_EXAMPLE_JOIN_CONDITION,
|
|
678
|
+
PromptStrategy.COT_JOIN_IMAGE: COT_JOIN_IMAGE_EXAMPLE_JOIN_CONDITION,
|
|
679
|
+
}
|
|
680
|
+
|
|
681
|
+
return prompt_strategy_to_example_join_condition.get(self.prompt_strategy)
|
|
682
|
+
|
|
481
683
|
def _get_example_reasoning(self) -> str | None:
|
|
482
684
|
"""
|
|
483
685
|
Returns the example reasoning for the prompt.
|
|
@@ -487,8 +689,13 @@ class PromptFactory:
|
|
|
487
689
|
"""
|
|
488
690
|
prompt_strategy_to_example_reasoning = {
|
|
489
691
|
PromptStrategy.COT_BOOL: COT_BOOL_EXAMPLE_REASONING,
|
|
692
|
+
PromptStrategy.COT_BOOL_AUDIO: COT_BOOL_AUDIO_EXAMPLE_REASONING,
|
|
490
693
|
PromptStrategy.COT_BOOL_IMAGE: COT_BOOL_IMAGE_EXAMPLE_REASONING,
|
|
694
|
+
PromptStrategy.COT_JOIN: COT_JOIN_EXAMPLE_REASONING,
|
|
695
|
+
PromptStrategy.COT_JOIN_AUDIO: COT_JOIN_AUDIO_EXAMPLE_REASONING,
|
|
696
|
+
PromptStrategy.COT_JOIN_IMAGE: COT_JOIN_IMAGE_EXAMPLE_REASONING,
|
|
491
697
|
PromptStrategy.COT_QA: COT_QA_EXAMPLE_REASONING,
|
|
698
|
+
PromptStrategy.COT_QA_AUDIO: COT_QA_AUDIO_EXAMPLE_REASONING,
|
|
492
699
|
PromptStrategy.COT_QA_IMAGE: COT_QA_IMAGE_EXAMPLE_REASONING,
|
|
493
700
|
}
|
|
494
701
|
|
|
@@ -503,6 +710,7 @@ class PromptFactory:
|
|
|
503
710
|
"""
|
|
504
711
|
prompt_strategy_to_example_answer = {
|
|
505
712
|
PromptStrategy.COT_QA: COT_QA_EXAMPLE_ANSWER,
|
|
713
|
+
PromptStrategy.COT_QA_AUDIO: COT_QA_AUDIO_EXAMPLE_ANSWER,
|
|
506
714
|
PromptStrategy.COT_QA_IMAGE: COT_QA_IMAGE_EXAMPLE_ANSWER,
|
|
507
715
|
PromptStrategy.COT_MOA_PROPOSER: COT_MOA_PROPOSER_EXAMPLE_ANSWER,
|
|
508
716
|
PromptStrategy.COT_MOA_PROPOSER_IMAGE: COT_MOA_PROPOSER_IMAGE_EXAMPLE_ANSWER,
|
|
@@ -512,7 +720,7 @@ class PromptFactory:
|
|
|
512
720
|
return prompt_strategy_to_example_answer.get(self.prompt_strategy)
|
|
513
721
|
|
|
514
722
|
def _get_all_format_kwargs(
|
|
515
|
-
self, candidate: DataRecord, input_fields: list[str], output_fields: list[str], **kwargs
|
|
723
|
+
self, candidate: DataRecord, input_fields: list[str], output_fields: list[str], right_candidate: DataRecord | None, right_input_fields: list[str], **kwargs
|
|
516
724
|
) -> dict:
|
|
517
725
|
"""
|
|
518
726
|
Returns a dictionary containing all the format kwargs for templating the prompts.
|
|
@@ -532,12 +740,20 @@ class PromptFactory:
|
|
|
532
740
|
"input_fields_desc": self._get_input_fields_desc(candidate, input_fields),
|
|
533
741
|
"output_fields_desc": self._get_output_fields_desc(output_fields, **kwargs),
|
|
534
742
|
"filter_condition": self._get_filter_condition(**kwargs),
|
|
743
|
+
"join_condition": self._get_join_condition(**kwargs),
|
|
535
744
|
"original_output": self._get_original_output(**kwargs),
|
|
536
745
|
"critique_output": self._get_critique_output(**kwargs),
|
|
537
746
|
"model_responses": self._get_model_responses(**kwargs),
|
|
538
747
|
"chunk_outputs": self._get_chunk_outputs(**kwargs),
|
|
539
748
|
}
|
|
540
749
|
|
|
750
|
+
# if a right candidate is provided, we also get the context and input field descriptions for the right candidate
|
|
751
|
+
if right_candidate is not None:
|
|
752
|
+
input_format_kwargs.update({
|
|
753
|
+
"right_context": self._get_context(right_candidate, right_input_fields),
|
|
754
|
+
"right_input_fields_desc": self._get_input_fields_desc(right_candidate, right_input_fields),
|
|
755
|
+
})
|
|
756
|
+
|
|
541
757
|
# get format kwargs which depend on the prompt strategy
|
|
542
758
|
prompt_strategy_format_kwargs = {
|
|
543
759
|
"output_format_instruction": self._get_output_format_instruction(),
|
|
@@ -546,10 +762,16 @@ class PromptFactory:
|
|
|
546
762
|
"refinement_criteria": self._get_refinement_criteria(),
|
|
547
763
|
"finish_instruction": self._get_finish_instruction(),
|
|
548
764
|
"example_input_fields": self._get_example_input_fields(),
|
|
765
|
+
"right_example_input_fields": self._get_right_example_input_fields(),
|
|
549
766
|
"example_output_fields": self._get_example_output_fields(),
|
|
550
767
|
"example_context": self._get_example_context(),
|
|
768
|
+
"right_example_context": self._get_right_example_context(),
|
|
551
769
|
"image_disclaimer": self._get_image_disclaimer(),
|
|
770
|
+
"audio_disclaimer": self._get_audio_disclaimer(),
|
|
771
|
+
"right_image_disclaimer": self._get_right_image_disclaimer(),
|
|
772
|
+
"right_audio_disclaimer": self._get_right_audio_disclaimer(),
|
|
552
773
|
"example_filter_condition": self._get_example_filter_condition(),
|
|
774
|
+
"example_join_condition": self._get_example_join_condition(),
|
|
553
775
|
"example_reasoning": self._get_example_reasoning(),
|
|
554
776
|
"example_answer": self._get_example_answer(),
|
|
555
777
|
}
|
|
@@ -557,6 +779,53 @@ class PromptFactory:
|
|
|
557
779
|
# return all format kwargs
|
|
558
780
|
return {**input_format_kwargs, **prompt_strategy_format_kwargs}
|
|
559
781
|
|
|
782
|
+
def _create_audio_messages(self, candidate: DataRecord, input_fields: list[str]) -> list[dict]:
|
|
783
|
+
"""
|
|
784
|
+
Parses the candidate record and returns the audio messages for the chat payload.
|
|
785
|
+
|
|
786
|
+
Args:
|
|
787
|
+
candidate (DataRecord): The input record.
|
|
788
|
+
input_fields (list[str]): The list of input fields.
|
|
789
|
+
|
|
790
|
+
Returns:
|
|
791
|
+
list[dict]: The audio messages for the chat payload.
|
|
792
|
+
"""
|
|
793
|
+
# create a message for each audio recording in an input field with an audio (or list of audio) type
|
|
794
|
+
audio_content = []
|
|
795
|
+
for field_name in input_fields:
|
|
796
|
+
field_value = candidate[field_name]
|
|
797
|
+
field_type = candidate.get_field_type(field_name)
|
|
798
|
+
|
|
799
|
+
# audio filepath (or list of audio filepaths)
|
|
800
|
+
if field_type.annotation in [AudioFilepath, AudioFilepath | None]:
|
|
801
|
+
with open(field_value, "rb") as f:
|
|
802
|
+
base64_audio_str = base64.b64encode(f.read()).decode("utf-8")
|
|
803
|
+
audio_content.append(
|
|
804
|
+
{"type": "input_audio", "input_audio": {"data": base64_audio_str, "format": "wav"}}
|
|
805
|
+
)
|
|
806
|
+
|
|
807
|
+
elif field_type.annotation in [list[AudioFilepath], list[AudioFilepath] | None]:
|
|
808
|
+
for audio_filepath in field_value:
|
|
809
|
+
with open(audio_filepath, "rb") as f:
|
|
810
|
+
base64_audio_str = base64.b64encode(f.read()).decode("utf-8")
|
|
811
|
+
audio_content.append(
|
|
812
|
+
{"type": "input_audio", "input_audio": {"data": base64_audio_str, "format": "wav"}}
|
|
813
|
+
)
|
|
814
|
+
|
|
815
|
+
# pre-encoded audio (or list of pre-encoded audio recordings)
|
|
816
|
+
elif field_type.annotation in [AudioBase64, AudioBase64 | None]:
|
|
817
|
+
audio_content.append(
|
|
818
|
+
{"type": "input_audio", "input_audio": {"data": field_value, "format": "wav"}}
|
|
819
|
+
)
|
|
820
|
+
|
|
821
|
+
elif field_type.annotation in [list[AudioBase64], list[AudioBase64] | None]:
|
|
822
|
+
for base64_audio in field_value:
|
|
823
|
+
audio_content.append(
|
|
824
|
+
{"type": "input_audio", "input_audio": {"data": base64_audio, "format": "wav"}}
|
|
825
|
+
)
|
|
826
|
+
|
|
827
|
+
return [{"role": "user", "type": "input_audio", "content": audio_content}] if len(audio_content) > 0 else []
|
|
828
|
+
|
|
560
829
|
def _create_image_messages(self, candidate: DataRecord, input_fields: list[str]) -> list[dict]:
|
|
561
830
|
"""
|
|
562
831
|
Parses the candidate record and returns the image messages for the chat payload.
|
|
@@ -569,50 +838,48 @@ class PromptFactory:
|
|
|
569
838
|
list[dict]: The image messages for the chat payload.
|
|
570
839
|
"""
|
|
571
840
|
# create a message for each image in an input field with an image (or list of image) type
|
|
572
|
-
|
|
841
|
+
image_content = []
|
|
573
842
|
for field_name in input_fields:
|
|
574
843
|
field_value = candidate[field_name]
|
|
575
844
|
field_type = candidate.get_field_type(field_name)
|
|
576
845
|
|
|
577
846
|
# image filepath (or list of image filepaths)
|
|
578
|
-
if
|
|
847
|
+
if field_type.annotation in [ImageFilepath, ImageFilepath | None]:
|
|
579
848
|
with open(field_value, "rb") as f:
|
|
580
849
|
base64_image_str = base64.b64encode(f.read()).decode("utf-8")
|
|
581
|
-
|
|
582
|
-
{"
|
|
850
|
+
image_content.append(
|
|
851
|
+
{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image_str}"}}
|
|
583
852
|
)
|
|
584
853
|
|
|
585
|
-
elif
|
|
854
|
+
elif field_type.annotation in [list[ImageFilepath], list[ImageFilepath] | None]:
|
|
586
855
|
for image_filepath in field_value:
|
|
587
856
|
with open(image_filepath, "rb") as f:
|
|
588
857
|
base64_image_str = base64.b64encode(f.read()).decode("utf-8")
|
|
589
|
-
|
|
590
|
-
{"
|
|
858
|
+
image_content.append(
|
|
859
|
+
{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image_str}"}}
|
|
591
860
|
)
|
|
592
861
|
|
|
593
862
|
# image url (or list of image urls)
|
|
594
|
-
elif
|
|
595
|
-
|
|
863
|
+
elif field_type.annotation in [ImageURL, ImageURL | None]:
|
|
864
|
+
image_content.append({"type": "image_url", "image_url": {"url": field_value}})
|
|
596
865
|
|
|
597
|
-
elif
|
|
866
|
+
elif field_type.annotation in [list[ImageURL], list[ImageURL] | None]:
|
|
598
867
|
for image_url in field_value:
|
|
599
|
-
|
|
868
|
+
image_content.append({"type": "image_url", "image_url": {"url": image_url}})
|
|
600
869
|
|
|
601
870
|
# pre-encoded images (or list of pre-encoded images)
|
|
602
|
-
elif
|
|
603
|
-
|
|
604
|
-
|
|
605
|
-
{"role": "user", "type": "image", "content": f"data:image/jpeg;base64,{base64_image_str}"}
|
|
871
|
+
elif field_type.annotation in [ImageBase64, ImageBase64 | None]:
|
|
872
|
+
image_content.append(
|
|
873
|
+
{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{field_value}"}}
|
|
606
874
|
)
|
|
607
875
|
|
|
608
|
-
elif
|
|
876
|
+
elif field_type.annotation in [list[ImageBase64], list[ImageBase64] | None]:
|
|
609
877
|
for base64_image in field_value:
|
|
610
|
-
|
|
611
|
-
|
|
612
|
-
{"role": "user", "type": "image", "content": f"data:image/jpeg;base64,{base64_image_str}"}
|
|
878
|
+
image_content.append(
|
|
879
|
+
{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}}
|
|
613
880
|
)
|
|
614
881
|
|
|
615
|
-
return
|
|
882
|
+
return [{"role": "user", "type": "image", "content": image_content}] if len(image_content) > 0 else []
|
|
616
883
|
|
|
617
884
|
def _get_system_prompt(self, **format_kwargs) -> str | None:
|
|
618
885
|
"""
|
|
@@ -631,7 +898,7 @@ class PromptFactory:
|
|
|
631
898
|
|
|
632
899
|
return base_prompt.format(**format_kwargs)
|
|
633
900
|
|
|
634
|
-
def _get_user_messages(self, candidate: DataRecord, input_fields: list[str], **kwargs) -> str:
|
|
901
|
+
def _get_user_messages(self, candidate: DataRecord, input_fields: list[str], right_candidate: DataRecord | None, right_input_fields: list[str], **kwargs) -> str:
|
|
635
902
|
"""
|
|
636
903
|
Returns a list of messages for the chat payload based on the prompt strategy.
|
|
637
904
|
|
|
@@ -648,10 +915,18 @@ class PromptFactory:
|
|
|
648
915
|
# get the base prompt template
|
|
649
916
|
base_prompt = self.BASE_USER_PROMPT_MAP.get(self.prompt_strategy)
|
|
650
917
|
|
|
651
|
-
# get any image messages for the chat payload (will be an empty list if
|
|
652
|
-
image_messages = (
|
|
653
|
-
|
|
654
|
-
)
|
|
918
|
+
# get any image messages for the chat payload (will be an empty list if no image fields exist)
|
|
919
|
+
image_messages = self._create_image_messages(candidate, input_fields)
|
|
920
|
+
|
|
921
|
+
# get any audio messages for the chat payload (will be an empty list if no audio fields exist)
|
|
922
|
+
audio_messages = self._create_audio_messages(candidate, input_fields)
|
|
923
|
+
|
|
924
|
+
# get any right image and audio messages for the chat payload (both will be empty lists if this is not a join prompt)
|
|
925
|
+
right_image_messages, right_audio_messages = [], []
|
|
926
|
+
if self.prompt_strategy.is_join_prompt():
|
|
927
|
+
assert right_candidate is not None, "Right candidate must be provided for join prompts."
|
|
928
|
+
right_image_messages = self._create_image_messages(right_candidate, right_input_fields)
|
|
929
|
+
right_audio_messages = self._create_audio_messages(right_candidate, right_input_fields)
|
|
655
930
|
|
|
656
931
|
# get any original messages for critique and refinement operations
|
|
657
932
|
original_messages = kwargs.get("original_messages")
|
|
@@ -660,6 +935,8 @@ class PromptFactory:
|
|
|
660
935
|
"Original messages must be provided for critique and refinement operations."
|
|
661
936
|
)
|
|
662
937
|
|
|
938
|
+
# TODO: in the future if we support many modalities (e.g. images and audio) in the same prompt,
|
|
939
|
+
# then we will need to streamline this logic to handle the many different cases
|
|
663
940
|
# construct the user messages based on the prompt strategy
|
|
664
941
|
user_messages = []
|
|
665
942
|
if self.prompt_strategy.is_critic_prompt() or self.prompt_strategy.is_refine_prompt():
|
|
@@ -670,14 +947,47 @@ class PromptFactory:
|
|
|
670
947
|
user_messages.extend(original_messages)
|
|
671
948
|
user_messages.append({"role": "user", "type": "text", "content": base_prompt_end.format(**kwargs)})
|
|
672
949
|
|
|
673
|
-
|
|
674
|
-
|
|
950
|
+
# image not join
|
|
951
|
+
elif self.prompt_strategy.is_image_prompt() and not self.prompt_strategy.is_join_prompt():
|
|
952
|
+
base_prompt = base_prompt.replace("<<audio-placeholder>>", "")
|
|
953
|
+
base_prompt_start, base_prompt_end = base_prompt.split("<<image-placeholder>>")
|
|
675
954
|
user_messages.append({"role": "user", "type": "text", "content": base_prompt_start.format(**kwargs)})
|
|
676
955
|
user_messages.extend(image_messages)
|
|
677
956
|
user_messages.append({"role": "user", "type": "text", "content": base_prompt_end.format(**kwargs)})
|
|
678
957
|
|
|
958
|
+
# image join
|
|
959
|
+
elif self.prompt_strategy.is_image_prompt() and self.prompt_strategy.is_join_prompt():
|
|
960
|
+
# for join image prompts, we may have two sets of images (one from the left candidate and one from the right candidate)
|
|
961
|
+
base_prompt = base_prompt.replace("<<audio-placeholder>>", "")
|
|
962
|
+
base_prompt_start, base_prompt_mid, base_prompt_end = base_prompt.split("<<image-placeholder>>")
|
|
963
|
+
user_messages.append({"role": "user", "type": "text", "content": base_prompt_start.format(**kwargs)})
|
|
964
|
+
user_messages.extend(image_messages)
|
|
965
|
+
user_messages.append({"role": "user", "type": "text", "content": base_prompt_mid.format(**kwargs)})
|
|
966
|
+
user_messages.extend(right_image_messages)
|
|
967
|
+
user_messages.append({"role": "user", "type": "text", "content": base_prompt_end.format(**kwargs)})
|
|
968
|
+
|
|
969
|
+
# audio not join
|
|
970
|
+
elif self.prompt_strategy.is_audio_prompt() and not self.prompt_strategy.is_join_prompt():
|
|
971
|
+
base_prompt = base_prompt.replace("<<image-placeholder>>", "")
|
|
972
|
+
base_prompt_start, base_prompt_end = base_prompt.split("<<audio-placeholder>>")
|
|
973
|
+
user_messages.append({"role": "user", "type": "text", "content": base_prompt_start.format(**kwargs)})
|
|
974
|
+
user_messages.extend(audio_messages)
|
|
975
|
+
user_messages.append({"role": "user", "type": "text", "content": base_prompt_end.format(**kwargs)})
|
|
976
|
+
|
|
977
|
+
# audio join
|
|
978
|
+
elif self.prompt_strategy.is_audio_prompt() and self.prompt_strategy.is_join_prompt():
|
|
979
|
+
# for join audio prompts, we may have two sets of audio recordings (one from the left candidate and one from the right candidate)
|
|
980
|
+
base_prompt = base_prompt.replace("<<image-placeholder>>", "")
|
|
981
|
+
base_prompt_start, base_prompt_mid, base_prompt_end = base_prompt.split("<<audio-placeholder>>")
|
|
982
|
+
user_messages.append({"role": "user", "type": "text", "content": base_prompt_start.format(**kwargs)})
|
|
983
|
+
user_messages.extend(audio_messages)
|
|
984
|
+
user_messages.append({"role": "user", "type": "text", "content": base_prompt_mid.format(**kwargs)})
|
|
985
|
+
user_messages.extend(right_audio_messages)
|
|
986
|
+
user_messages.append({"role": "user", "type": "text", "content": base_prompt_end.format(**kwargs)})
|
|
987
|
+
|
|
679
988
|
else:
|
|
680
989
|
base_prompt = base_prompt.replace("<<image-placeholder>>", "")
|
|
990
|
+
base_prompt = base_prompt.replace("<<audio-placeholder>>", "")
|
|
681
991
|
user_messages.append({"role": "user", "type": "text", "content": base_prompt.format(**kwargs)})
|
|
682
992
|
|
|
683
993
|
return user_messages
|
|
@@ -720,7 +1030,7 @@ class PromptFactory:
|
|
|
720
1030
|
# build set of format kwargs
|
|
721
1031
|
format_kwargs = {
|
|
722
1032
|
field_name: "<bytes>"
|
|
723
|
-
if
|
|
1033
|
+
if candidate.get_field_type(field_name).annotation in [bytes, bytes | None]
|
|
724
1034
|
else candidate[field_name]
|
|
725
1035
|
for field_name in input_fields
|
|
726
1036
|
}
|
|
@@ -740,7 +1050,7 @@ class PromptFactory:
|
|
|
740
1050
|
|
|
741
1051
|
return messages
|
|
742
1052
|
|
|
743
|
-
def create_messages(self, candidate: DataRecord, output_fields: list[str], **kwargs) -> list[dict]:
|
|
1053
|
+
def create_messages(self, candidate: DataRecord, output_fields: list[str], right_candidate: DataRecord | None = None, **kwargs) -> list[dict]:
|
|
744
1054
|
"""
|
|
745
1055
|
Creates the messages for the chat payload based on the prompt strategy.
|
|
746
1056
|
|
|
@@ -754,6 +1064,7 @@ class PromptFactory:
|
|
|
754
1064
|
Args:
|
|
755
1065
|
candidate (DataRecord): The input record.
|
|
756
1066
|
output_fields (list[str]): The output fields.
|
|
1067
|
+
right_candidate (DataRecord | None): The other join input record (only provided for joins).
|
|
757
1068
|
kwargs: The keyword arguments provided by the user.
|
|
758
1069
|
|
|
759
1070
|
Returns:
|
|
@@ -761,6 +1072,7 @@ class PromptFactory:
|
|
|
761
1072
|
"""
|
|
762
1073
|
# compute the set of input fields
|
|
763
1074
|
input_fields = self._get_input_fields(candidate, **kwargs)
|
|
1075
|
+
right_input_fields = [] if right_candidate is None else self._get_input_fields(right_candidate, **kwargs)
|
|
764
1076
|
|
|
765
1077
|
# if the user provides a prompt, we process that prompt into messages and return them
|
|
766
1078
|
if "prompt" in kwargs:
|
|
@@ -774,7 +1086,7 @@ class PromptFactory:
|
|
|
774
1086
|
messages = []
|
|
775
1087
|
|
|
776
1088
|
# compute the full dictionary of format kwargs and add to kwargs
|
|
777
|
-
format_kwargs = self._get_all_format_kwargs(candidate, input_fields, output_fields, **kwargs)
|
|
1089
|
+
format_kwargs = self._get_all_format_kwargs(candidate, input_fields, output_fields, right_candidate, right_input_fields, **kwargs)
|
|
778
1090
|
kwargs = {**kwargs, **format_kwargs}
|
|
779
1091
|
|
|
780
1092
|
# generate system message (if applicable)
|
|
@@ -783,7 +1095,7 @@ class PromptFactory:
|
|
|
783
1095
|
messages.append({"role": "system", "type": "text", "content": system_prompt})
|
|
784
1096
|
|
|
785
1097
|
# generate user messages and add to messages
|
|
786
|
-
user_messages = self._get_user_messages(candidate, input_fields, **kwargs)
|
|
1098
|
+
user_messages = self._get_user_messages(candidate, input_fields, right_candidate, right_input_fields, **kwargs)
|
|
787
1099
|
messages.extend(user_messages)
|
|
788
1100
|
|
|
789
1101
|
return messages
|