palimpzest 0.7.21__py3-none-any.whl → 0.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89)
  1. palimpzest/__init__.py +37 -6
  2. palimpzest/agents/__init__.py +0 -0
  3. palimpzest/agents/compute_agents.py +0 -0
  4. palimpzest/agents/search_agents.py +637 -0
  5. palimpzest/constants.py +343 -209
  6. palimpzest/core/data/context.py +393 -0
  7. palimpzest/core/data/context_manager.py +163 -0
  8. palimpzest/core/data/dataset.py +639 -0
  9. palimpzest/core/data/{datareaders.py → iter_dataset.py} +202 -126
  10. palimpzest/core/elements/groupbysig.py +16 -13
  11. palimpzest/core/elements/records.py +166 -75
  12. palimpzest/core/lib/schemas.py +152 -390
  13. palimpzest/core/{data/dataclasses.py → models.py} +306 -170
  14. palimpzest/policy.py +2 -27
  15. palimpzest/prompts/__init__.py +35 -5
  16. palimpzest/prompts/agent_prompts.py +357 -0
  17. palimpzest/prompts/context_search.py +9 -0
  18. palimpzest/prompts/convert_prompts.py +62 -6
  19. palimpzest/prompts/filter_prompts.py +51 -6
  20. palimpzest/prompts/join_prompts.py +163 -0
  21. palimpzest/prompts/moa_proposer_convert_prompts.py +6 -6
  22. palimpzest/prompts/prompt_factory.py +375 -47
  23. palimpzest/prompts/split_proposer_prompts.py +1 -1
  24. palimpzest/prompts/util_phrases.py +5 -0
  25. palimpzest/prompts/validator.py +239 -0
  26. palimpzest/query/execution/all_sample_execution_strategy.py +134 -76
  27. palimpzest/query/execution/execution_strategy.py +210 -317
  28. palimpzest/query/execution/execution_strategy_type.py +5 -7
  29. palimpzest/query/execution/mab_execution_strategy.py +249 -136
  30. palimpzest/query/execution/parallel_execution_strategy.py +153 -244
  31. palimpzest/query/execution/single_threaded_execution_strategy.py +107 -64
  32. palimpzest/query/generators/generators.py +160 -331
  33. palimpzest/query/operators/__init__.py +15 -5
  34. palimpzest/query/operators/aggregate.py +50 -33
  35. palimpzest/query/operators/compute.py +201 -0
  36. palimpzest/query/operators/convert.py +33 -19
  37. palimpzest/query/operators/critique_and_refine_convert.py +7 -5
  38. palimpzest/query/operators/distinct.py +62 -0
  39. palimpzest/query/operators/filter.py +26 -16
  40. palimpzest/query/operators/join.py +403 -0
  41. palimpzest/query/operators/limit.py +3 -3
  42. palimpzest/query/operators/logical.py +205 -77
  43. palimpzest/query/operators/mixture_of_agents_convert.py +10 -8
  44. palimpzest/query/operators/physical.py +27 -21
  45. palimpzest/query/operators/project.py +3 -3
  46. palimpzest/query/operators/rag_convert.py +7 -7
  47. palimpzest/query/operators/retrieve.py +9 -9
  48. palimpzest/query/operators/scan.py +81 -42
  49. palimpzest/query/operators/search.py +524 -0
  50. palimpzest/query/operators/split_convert.py +10 -8
  51. palimpzest/query/optimizer/__init__.py +7 -9
  52. palimpzest/query/optimizer/cost_model.py +108 -441
  53. palimpzest/query/optimizer/optimizer.py +123 -181
  54. palimpzest/query/optimizer/optimizer_strategy.py +66 -61
  55. palimpzest/query/optimizer/plan.py +352 -67
  56. palimpzest/query/optimizer/primitives.py +43 -19
  57. palimpzest/query/optimizer/rules.py +484 -646
  58. palimpzest/query/optimizer/tasks.py +127 -58
  59. palimpzest/query/processor/config.py +42 -76
  60. palimpzest/query/processor/query_processor.py +73 -18
  61. palimpzest/query/processor/query_processor_factory.py +46 -38
  62. palimpzest/schemabuilder/schema_builder.py +15 -28
  63. palimpzest/utils/model_helpers.py +32 -77
  64. palimpzest/utils/progress.py +114 -102
  65. palimpzest/validator/__init__.py +0 -0
  66. palimpzest/validator/validator.py +306 -0
  67. {palimpzest-0.7.21.dist-info → palimpzest-0.8.1.dist-info}/METADATA +6 -1
  68. palimpzest-0.8.1.dist-info/RECORD +95 -0
  69. palimpzest/core/lib/fields.py +0 -141
  70. palimpzest/prompts/code_synthesis_prompts.py +0 -28
  71. palimpzest/query/execution/random_sampling_execution_strategy.py +0 -240
  72. palimpzest/query/generators/api_client_factory.py +0 -30
  73. palimpzest/query/operators/code_synthesis_convert.py +0 -488
  74. palimpzest/query/operators/map.py +0 -130
  75. palimpzest/query/processor/nosentinel_processor.py +0 -33
  76. palimpzest/query/processor/processing_strategy_type.py +0 -28
  77. palimpzest/query/processor/sentinel_processor.py +0 -88
  78. palimpzest/query/processor/streaming_processor.py +0 -149
  79. palimpzest/sets.py +0 -405
  80. palimpzest/utils/datareader_helpers.py +0 -61
  81. palimpzest/utils/demo_helpers.py +0 -75
  82. palimpzest/utils/field_helpers.py +0 -69
  83. palimpzest/utils/generation_helpers.py +0 -69
  84. palimpzest/utils/sandbox.py +0 -183
  85. palimpzest-0.7.21.dist-info/RECORD +0 -95
  86. palimpzest/core/{elements/index.py → data/index_dataset.py} +0 -0
  87. {palimpzest-0.7.21.dist-info → palimpzest-0.8.1.dist-info}/WHEEL +0 -0
  88. {palimpzest-0.7.21.dist-info → palimpzest-0.8.1.dist-info}/licenses/LICENSE +0 -0
  89. {palimpzest-0.7.21.dist-info → palimpzest-0.8.1.dist-info}/top_level.txt +0 -0
@@ -4,17 +4,25 @@ import base64
4
4
  import json
5
5
  from string import Formatter
6
6
 
7
+ from pydantic import BaseModel
8
+
7
9
  from palimpzest.constants import (
8
- MIXTRAL_LLAMA_CONTEXT_TOKENS_LIMIT,
10
+ LLAMA_CONTEXT_TOKENS_LIMIT,
9
11
  TOKENS_PER_CHARACTER,
10
12
  Cardinality,
11
13
  Model,
12
14
  PromptStrategy,
13
15
  )
14
16
  from palimpzest.core.elements.records import DataRecord
15
- from palimpzest.core.lib.fields import BytesField, ImageBase64Field, ImageFilepathField, ImageURLField
16
- from palimpzest.core.lib.schemas import Schema
17
+ from palimpzest.core.lib.schemas import AudioBase64, AudioFilepath, ImageBase64, ImageFilepath, ImageURL
17
18
  from palimpzest.prompts.convert_prompts import (
19
+ COT_QA_AUDIO_DISCLAIMER,
20
+ COT_QA_AUDIO_EXAMPLE_ANSWER,
21
+ COT_QA_AUDIO_EXAMPLE_CONTEXT,
22
+ COT_QA_AUDIO_EXAMPLE_INPUT_FIELDS,
23
+ COT_QA_AUDIO_EXAMPLE_OUTPUT_FIELDS,
24
+ COT_QA_AUDIO_EXAMPLE_REASONING,
25
+ COT_QA_AUDIO_JOB_INSTRUCTION,
18
26
  COT_QA_BASE_SYSTEM_PROMPT,
19
27
  COT_QA_BASE_USER_PROMPT,
20
28
  COT_QA_EXAMPLE_ANSWER,
@@ -30,6 +38,8 @@ from palimpzest.prompts.convert_prompts import (
30
38
  COT_QA_IMAGE_EXAMPLE_REASONING,
31
39
  COT_QA_IMAGE_JOB_INSTRUCTION,
32
40
  COT_QA_JOB_INSTRUCTION,
41
+ COT_QA_NO_REASONING_BASE_SYSTEM_PROMPT,
42
+ COT_QA_NO_REASONING_BASE_USER_PROMPT,
33
43
  )
34
44
  from palimpzest.prompts.critique_and_refine_convert_prompts import (
35
45
  BASE_CRITIQUE_PROMPT,
@@ -42,6 +52,12 @@ from palimpzest.prompts.critique_and_refine_convert_prompts import (
42
52
  COT_QA_REFINEMENT_FINISH_INSTRUCTION,
43
53
  )
44
54
  from palimpzest.prompts.filter_prompts import (
55
+ COT_BOOL_AUDIO_DISCLAIMER,
56
+ COT_BOOL_AUDIO_EXAMPLE_CONTEXT,
57
+ COT_BOOL_AUDIO_EXAMPLE_FILTER_CONDITION,
58
+ COT_BOOL_AUDIO_EXAMPLE_INPUT_FIELDS,
59
+ COT_BOOL_AUDIO_EXAMPLE_REASONING,
60
+ COT_BOOL_AUDIO_JOB_INSTRUCTION,
45
61
  COT_BOOL_BASE_SYSTEM_PROMPT,
46
62
  COT_BOOL_BASE_USER_PROMPT,
47
63
  COT_BOOL_EXAMPLE_CONTEXT,
@@ -55,6 +71,39 @@ from palimpzest.prompts.filter_prompts import (
55
71
  COT_BOOL_IMAGE_EXAMPLE_REASONING,
56
72
  COT_BOOL_IMAGE_JOB_INSTRUCTION,
57
73
  COT_BOOL_JOB_INSTRUCTION,
74
+ COT_BOOL_NO_REASONING_BASE_SYSTEM_PROMPT,
75
+ COT_BOOL_NO_REASONING_BASE_USER_PROMPT,
76
+ )
77
+ from palimpzest.prompts.join_prompts import (
78
+ COT_JOIN_AUDIO_DISCLAIMER,
79
+ COT_JOIN_AUDIO_EXAMPLE_CONTEXT,
80
+ COT_JOIN_AUDIO_EXAMPLE_INPUT_FIELDS,
81
+ COT_JOIN_AUDIO_EXAMPLE_JOIN_CONDITION,
82
+ COT_JOIN_AUDIO_EXAMPLE_REASONING,
83
+ COT_JOIN_AUDIO_JOB_INSTRUCTION,
84
+ COT_JOIN_AUDIO_RIGHT_EXAMPLE_CONTEXT,
85
+ COT_JOIN_AUDIO_RIGHT_EXAMPLE_INPUT_FIELDS,
86
+ COT_JOIN_BASE_SYSTEM_PROMPT,
87
+ COT_JOIN_BASE_USER_PROMPT,
88
+ COT_JOIN_EXAMPLE_CONTEXT,
89
+ COT_JOIN_EXAMPLE_INPUT_FIELDS,
90
+ COT_JOIN_EXAMPLE_JOIN_CONDITION,
91
+ COT_JOIN_EXAMPLE_REASONING,
92
+ COT_JOIN_IMAGE_DISCLAIMER,
93
+ COT_JOIN_IMAGE_EXAMPLE_CONTEXT,
94
+ COT_JOIN_IMAGE_EXAMPLE_INPUT_FIELDS,
95
+ COT_JOIN_IMAGE_EXAMPLE_JOIN_CONDITION,
96
+ COT_JOIN_IMAGE_EXAMPLE_REASONING,
97
+ COT_JOIN_IMAGE_JOB_INSTRUCTION,
98
+ COT_JOIN_IMAGE_RIGHT_EXAMPLE_CONTEXT,
99
+ COT_JOIN_IMAGE_RIGHT_EXAMPLE_INPUT_FIELDS,
100
+ COT_JOIN_JOB_INSTRUCTION,
101
+ COT_JOIN_NO_REASONING_BASE_SYSTEM_PROMPT,
102
+ COT_JOIN_NO_REASONING_BASE_USER_PROMPT,
103
+ COT_JOIN_RIGHT_AUDIO_DISCLAIMER,
104
+ COT_JOIN_RIGHT_EXAMPLE_CONTEXT,
105
+ COT_JOIN_RIGHT_EXAMPLE_INPUT_FIELDS,
106
+ COT_JOIN_RIGHT_IMAGE_DISCLAIMER,
58
107
  )
59
108
  from palimpzest.prompts.moa_aggregator_convert_prompts import (
60
109
  COT_MOA_AGG_BASE_SYSTEM_PROMPT,
@@ -89,6 +138,7 @@ from palimpzest.prompts.split_proposer_prompts import (
89
138
  SPLIT_PROPOSER_JOB_INSTRUCTION,
90
139
  )
91
140
  from palimpzest.prompts.util_phrases import (
141
+ DESC_SECTION,
92
142
  ONE_TO_MANY_OUTPUT_FORMAT_INSTRUCTION,
93
143
  ONE_TO_ONE_OUTPUT_FORMAT_INSTRUCTION,
94
144
  )
@@ -99,11 +149,25 @@ class PromptFactory:
99
149
 
100
150
  BASE_SYSTEM_PROMPT_MAP = {
101
151
  PromptStrategy.COT_BOOL: COT_BOOL_BASE_SYSTEM_PROMPT,
152
+ PromptStrategy.COT_BOOL_NO_REASONING: COT_BOOL_NO_REASONING_BASE_SYSTEM_PROMPT,
153
+ PromptStrategy.COT_BOOL_AUDIO: COT_BOOL_BASE_SYSTEM_PROMPT,
154
+ PromptStrategy.COT_BOOL_AUDIO_NO_REASONING: COT_BOOL_NO_REASONING_BASE_SYSTEM_PROMPT,
102
155
  PromptStrategy.COT_BOOL_IMAGE: COT_BOOL_BASE_SYSTEM_PROMPT,
156
+ PromptStrategy.COT_BOOL_IMAGE_NO_REASONING: COT_BOOL_NO_REASONING_BASE_SYSTEM_PROMPT,
157
+ PromptStrategy.COT_JOIN: COT_JOIN_BASE_SYSTEM_PROMPT,
158
+ PromptStrategy.COT_JOIN_NO_REASONING: COT_JOIN_NO_REASONING_BASE_SYSTEM_PROMPT,
159
+ PromptStrategy.COT_JOIN_AUDIO: COT_JOIN_BASE_SYSTEM_PROMPT,
160
+ PromptStrategy.COT_JOIN_AUDIO_NO_REASONING: COT_JOIN_NO_REASONING_BASE_SYSTEM_PROMPT,
161
+ PromptStrategy.COT_JOIN_IMAGE: COT_JOIN_BASE_SYSTEM_PROMPT,
162
+ PromptStrategy.COT_JOIN_IMAGE_NO_REASONING: COT_JOIN_NO_REASONING_BASE_SYSTEM_PROMPT,
103
163
  PromptStrategy.COT_QA: COT_QA_BASE_SYSTEM_PROMPT,
164
+ PromptStrategy.COT_QA_NO_REASONING: COT_QA_NO_REASONING_BASE_SYSTEM_PROMPT,
165
+ PromptStrategy.COT_QA_AUDIO: COT_QA_BASE_SYSTEM_PROMPT,
166
+ PromptStrategy.COT_QA_AUDIO_NO_REASONING: COT_QA_NO_REASONING_BASE_SYSTEM_PROMPT,
104
167
  PromptStrategy.COT_QA_CRITIC: None,
105
168
  PromptStrategy.COT_QA_REFINE: None,
106
169
  PromptStrategy.COT_QA_IMAGE: COT_QA_BASE_SYSTEM_PROMPT,
170
+ PromptStrategy.COT_QA_IMAGE_NO_REASONING: COT_QA_NO_REASONING_BASE_SYSTEM_PROMPT,
107
171
  PromptStrategy.COT_QA_IMAGE_CRITIC: None,
108
172
  PromptStrategy.COT_QA_IMAGE_REFINE: None,
109
173
  PromptStrategy.COT_MOA_PROPOSER: COT_MOA_PROPOSER_BASE_SYSTEM_PROMPT,
@@ -114,11 +178,25 @@ class PromptFactory:
114
178
  }
115
179
  BASE_USER_PROMPT_MAP = {
116
180
  PromptStrategy.COT_BOOL: COT_BOOL_BASE_USER_PROMPT,
181
+ PromptStrategy.COT_BOOL_NO_REASONING: COT_BOOL_NO_REASONING_BASE_USER_PROMPT,
182
+ PromptStrategy.COT_BOOL_AUDIO: COT_BOOL_BASE_USER_PROMPT,
183
+ PromptStrategy.COT_BOOL_AUDIO_NO_REASONING: COT_BOOL_NO_REASONING_BASE_USER_PROMPT,
117
184
  PromptStrategy.COT_BOOL_IMAGE: COT_BOOL_BASE_USER_PROMPT,
185
+ PromptStrategy.COT_BOOL_IMAGE_NO_REASONING: COT_BOOL_NO_REASONING_BASE_USER_PROMPT,
186
+ PromptStrategy.COT_JOIN: COT_JOIN_BASE_USER_PROMPT,
187
+ PromptStrategy.COT_JOIN_NO_REASONING: COT_JOIN_NO_REASONING_BASE_USER_PROMPT,
188
+ PromptStrategy.COT_JOIN_AUDIO: COT_JOIN_BASE_USER_PROMPT,
189
+ PromptStrategy.COT_JOIN_AUDIO_NO_REASONING: COT_JOIN_NO_REASONING_BASE_USER_PROMPT,
190
+ PromptStrategy.COT_JOIN_IMAGE: COT_JOIN_BASE_USER_PROMPT,
191
+ PromptStrategy.COT_JOIN_IMAGE_NO_REASONING: COT_JOIN_NO_REASONING_BASE_USER_PROMPT,
118
192
  PromptStrategy.COT_QA: COT_QA_BASE_USER_PROMPT,
193
+ PromptStrategy.COT_QA_NO_REASONING: COT_QA_NO_REASONING_BASE_USER_PROMPT,
194
+ PromptStrategy.COT_QA_AUDIO: COT_QA_BASE_USER_PROMPT,
195
+ PromptStrategy.COT_QA_AUDIO_NO_REASONING: COT_QA_NO_REASONING_BASE_USER_PROMPT,
119
196
  PromptStrategy.COT_QA_CRITIC: BASE_CRITIQUE_PROMPT,
120
197
  PromptStrategy.COT_QA_REFINE: BASE_REFINEMENT_PROMPT,
121
198
  PromptStrategy.COT_QA_IMAGE: COT_QA_BASE_USER_PROMPT,
199
+ PromptStrategy.COT_QA_IMAGE_NO_REASONING: COT_QA_NO_REASONING_BASE_USER_PROMPT,
122
200
  PromptStrategy.COT_QA_IMAGE_CRITIC: BASE_CRITIQUE_PROMPT,
123
201
  PromptStrategy.COT_QA_IMAGE_REFINE: BASE_REFINEMENT_PROMPT,
124
202
  PromptStrategy.COT_MOA_PROPOSER: COT_MOA_PROPOSER_BASE_USER_PROMPT,
@@ -128,10 +206,11 @@ class PromptFactory:
128
206
  PromptStrategy.SPLIT_MERGER: COT_SPLIT_MERGER_BASE_USER_PROMPT,
129
207
  }
130
208
 
131
- def __init__(self, prompt_strategy: PromptStrategy, model: Model, cardinality: Cardinality) -> None:
209
+ def __init__(self, prompt_strategy: PromptStrategy, model: Model, cardinality: Cardinality, desc: str | None = None) -> None:
132
210
  self.prompt_strategy = prompt_strategy
133
211
  self.model = model
134
212
  self.cardinality = cardinality
213
+ self.desc = desc
135
214
 
136
215
  def _get_context(self, candidate: DataRecord, input_fields: list[str]) -> str:
137
216
  """
@@ -144,8 +223,9 @@ class PromptFactory:
144
223
  Returns:
145
224
  str: The context.
146
225
  """
226
+ # TODO: remove mask_filepaths=True after SemBench evaluation
147
227
  # get context from input record (project_cols will be None if not provided in kwargs)
148
- context: dict = candidate.to_dict(include_bytes=False, project_cols=input_fields)
228
+ context: dict = candidate.to_dict(include_bytes=False, project_cols=input_fields, mask_filepaths=True)
149
229
 
150
230
  # TODO: MOVE THIS LOGIC INTO A CHUNKING / CONTEXT MANAGEMENT CLASS
151
231
  # - this class should be able to:
@@ -155,12 +235,12 @@ class PromptFactory:
155
235
  # TODO: this does not work for image prompts
156
236
  # TODO: this ignores the size of the `orignal_messages` in critique and refine prompts
157
237
  # cut down on context based on window length
158
- if self.model.is_llama_model() or self.model.is_mixtral_model():
238
+ if self.model.is_llama_model():
159
239
  total_context_len = len(json.dumps(context, indent=2))
160
240
 
161
241
  # sort fields by length and progressively strip from the longest field until it is short enough;
162
- # NOTE: MIXTRAL_LLAMA_CONTEXT_TOKENS_LIMIT is a rough estimate which leaves room for the rest of the prompt text
163
- while total_context_len * TOKENS_PER_CHARACTER > MIXTRAL_LLAMA_CONTEXT_TOKENS_LIMIT:
242
+ # NOTE: LLAMA_CONTEXT_TOKENS_LIMIT is a rough estimate which leaves room for the rest of the prompt text
243
+ while total_context_len * TOKENS_PER_CHARACTER > LLAMA_CONTEXT_TOKENS_LIMIT:
164
244
  # sort fields by length
165
245
  field_lengths = [(field, len(value) if value is not None else 0) for field, value in context.items()]
166
246
  sorted_fields = sorted(field_lengths, key=lambda item: item[1], reverse=True)
@@ -169,7 +249,7 @@ class PromptFactory:
169
249
  longest_field_name, longest_field_length = sorted_fields[0]
170
250
 
171
251
  # trim the field
172
- context_factor = MIXTRAL_LLAMA_CONTEXT_TOKENS_LIMIT / (total_context_len * TOKENS_PER_CHARACTER)
252
+ context_factor = LLAMA_CONTEXT_TOKENS_LIMIT / (total_context_len * TOKENS_PER_CHARACTER)
173
253
  keep_frac_idx = int(longest_field_length * context_factor)
174
254
  context[longest_field_name] = context[longest_field_name][:keep_frac_idx]
175
255
 
@@ -191,7 +271,11 @@ class PromptFactory:
191
271
  Returns:
192
272
  list[str]: The list of input field names.
193
273
  """
194
- return kwargs.get("project_cols", candidate.get_field_names())
274
+ # NOTE: joins will include left and right input fields in project_cols, so we have to check
275
+ # if the field is in the candidate record
276
+ input_fields = kwargs.get("project_cols", candidate.get_field_names())
277
+ input_fields = [field for field in input_fields if field in candidate.get_field_names()]
278
+ return input_fields
195
279
 
196
280
  def _get_input_fields_desc(self, candidate: DataRecord, input_fields: list[str]) -> str:
197
281
  """
@@ -205,7 +289,7 @@ class PromptFactory:
205
289
  """
206
290
  input_fields_desc = ""
207
291
  for field_name in input_fields:
208
- input_fields_desc += f"- {field_name}: {candidate.get_field_type(field_name)._desc}\n"
292
+ input_fields_desc += f"- {field_name}: {candidate.get_field_type(field_name).description}\n"
209
293
 
210
294
  return input_fields_desc[:-1]
211
295
 
@@ -221,13 +305,13 @@ class PromptFactory:
221
305
  str: The output fields description.
222
306
  """
223
307
  output_fields_desc = ""
224
- output_schema: Schema = kwargs.get("output_schema")
308
+ output_schema: BaseModel = kwargs.get("output_schema")
225
309
  if self.prompt_strategy.is_convert_prompt():
226
310
  assert output_schema is not None, "Output schema must be provided for convert prompts."
227
311
 
228
- field_desc_map = output_schema.field_desc_map()
229
- for field_name in output_fields:
230
- output_fields_desc += f"- {field_name}: {field_desc_map[field_name]}\n"
312
+ for field_name in sorted(output_fields):
313
+ desc = output_schema.model_fields[field_name].description
314
+ output_fields_desc += f"- {field_name}: {'no description available' if desc is None else desc}\n"
231
315
 
232
316
  # strip the last newline characters from the field descriptions and return
233
317
  return output_fields_desc[:-1]
@@ -245,6 +329,19 @@ class PromptFactory:
245
329
 
246
330
  return filter_condition
247
331
 
332
+ def _get_join_condition(self, **kwargs) -> str | None:
333
+ """
334
+ Returns the join condition for the join operation.
335
+
336
+ Returns:
337
+ str | None: The join condition (if applicable).
338
+ """
339
+ join_condition = kwargs.get("join_condition")
340
+ if self.prompt_strategy.is_join_prompt():
341
+ assert join_condition is not None, "Join condition must be provided for join operations."
342
+
343
+ return join_condition
344
+
248
345
  def _get_original_output(self, **kwargs) -> str | None:
249
346
  """
250
347
  Returns the original output from a previous model generation for the critique and refinement operations.
@@ -337,8 +434,13 @@ class PromptFactory:
337
434
  """
338
435
  prompt_strategy_to_job_instruction = {
339
436
  PromptStrategy.COT_BOOL: COT_BOOL_JOB_INSTRUCTION,
437
+ PromptStrategy.COT_BOOL_AUDIO: COT_BOOL_AUDIO_JOB_INSTRUCTION,
340
438
  PromptStrategy.COT_BOOL_IMAGE: COT_BOOL_IMAGE_JOB_INSTRUCTION,
439
+ PromptStrategy.COT_JOIN: COT_JOIN_JOB_INSTRUCTION,
440
+ PromptStrategy.COT_JOIN_AUDIO: COT_JOIN_AUDIO_JOB_INSTRUCTION,
441
+ PromptStrategy.COT_JOIN_IMAGE: COT_JOIN_IMAGE_JOB_INSTRUCTION,
341
442
  PromptStrategy.COT_QA: COT_QA_JOB_INSTRUCTION,
443
+ PromptStrategy.COT_QA_AUDIO: COT_QA_AUDIO_JOB_INSTRUCTION,
342
444
  PromptStrategy.COT_QA_IMAGE: COT_QA_IMAGE_JOB_INSTRUCTION,
343
445
  PromptStrategy.COT_MOA_PROPOSER: COT_MOA_PROPOSER_JOB_INSTRUCTION,
344
446
  PromptStrategy.COT_MOA_PROPOSER_IMAGE: COT_MOA_PROPOSER_IMAGE_JOB_INSTRUCTION,
@@ -346,6 +448,19 @@ class PromptFactory:
346
448
  }
347
449
  return prompt_strategy_to_job_instruction.get(self.prompt_strategy)
348
450
 
451
+ def _get_desc_section(self) -> str:
452
+ """
453
+ Returns the description section for the prompt.
454
+
455
+ Returns:
456
+ str: The description section (if applicable).
457
+ """
458
+ desc_section = ""
459
+ if self.desc is not None:
460
+ desc_section = DESC_SECTION.format(desc=self.desc)
461
+
462
+ return desc_section
463
+
349
464
  def _get_critique_criteria(self) -> str | None:
350
465
  """
351
466
  Returns the critique criteria for the critique operation.
@@ -402,8 +517,13 @@ class PromptFactory:
402
517
  """
403
518
  prompt_strategy_to_example_input_fields = {
404
519
  PromptStrategy.COT_BOOL: COT_BOOL_EXAMPLE_INPUT_FIELDS,
520
+ PromptStrategy.COT_BOOL_AUDIO: COT_BOOL_AUDIO_EXAMPLE_INPUT_FIELDS,
405
521
  PromptStrategy.COT_BOOL_IMAGE: COT_BOOL_IMAGE_EXAMPLE_INPUT_FIELDS,
522
+ PromptStrategy.COT_JOIN: COT_JOIN_EXAMPLE_INPUT_FIELDS,
523
+ PromptStrategy.COT_JOIN_AUDIO: COT_JOIN_AUDIO_EXAMPLE_INPUT_FIELDS,
524
+ PromptStrategy.COT_JOIN_IMAGE: COT_JOIN_IMAGE_EXAMPLE_INPUT_FIELDS,
406
525
  PromptStrategy.COT_QA: COT_QA_EXAMPLE_INPUT_FIELDS,
526
+ PromptStrategy.COT_QA_AUDIO: COT_QA_AUDIO_EXAMPLE_INPUT_FIELDS,
407
527
  PromptStrategy.COT_QA_IMAGE: COT_QA_IMAGE_EXAMPLE_INPUT_FIELDS,
408
528
  PromptStrategy.COT_MOA_PROPOSER: COT_MOA_PROPOSER_EXAMPLE_INPUT_FIELDS,
409
529
  PromptStrategy.COT_MOA_PROPOSER_IMAGE: COT_MOA_PROPOSER_IMAGE_EXAMPLE_INPUT_FIELDS,
@@ -412,6 +532,21 @@ class PromptFactory:
412
532
 
413
533
  return prompt_strategy_to_example_input_fields.get(self.prompt_strategy)
414
534
 
535
+ def _get_right_example_input_fields(self) -> str | None:
536
+ """
537
+ Returns the example right input fields for the join prompt.
538
+
539
+ Returns:
540
+ str | None: The example right input fields (if applicable).
541
+ """
542
+ prompt_strategy_to_right_example_input_fields = {
543
+ PromptStrategy.COT_JOIN: COT_JOIN_RIGHT_EXAMPLE_INPUT_FIELDS,
544
+ PromptStrategy.COT_JOIN_AUDIO: COT_JOIN_AUDIO_RIGHT_EXAMPLE_INPUT_FIELDS,
545
+ PromptStrategy.COT_JOIN_IMAGE: COT_JOIN_IMAGE_RIGHT_EXAMPLE_INPUT_FIELDS,
546
+ }
547
+
548
+ return prompt_strategy_to_right_example_input_fields.get(self.prompt_strategy)
549
+
415
550
  def _get_example_output_fields(self) -> str | None:
416
551
  """
417
552
  Returns the example output fields for the prompt.
@@ -421,6 +556,7 @@ class PromptFactory:
421
556
  """
422
557
  prompt_strategy_to_example_output_fields = {
423
558
  PromptStrategy.COT_QA: COT_QA_EXAMPLE_OUTPUT_FIELDS,
559
+ PromptStrategy.COT_QA_AUDIO: COT_QA_AUDIO_EXAMPLE_OUTPUT_FIELDS,
424
560
  PromptStrategy.COT_QA_IMAGE: COT_QA_IMAGE_EXAMPLE_OUTPUT_FIELDS,
425
561
  PromptStrategy.COT_MOA_PROPOSER: COT_MOA_PROPOSER_EXAMPLE_OUTPUT_FIELDS,
426
562
  PromptStrategy.COT_MOA_PROPOSER_IMAGE: COT_MOA_PROPOSER_IMAGE_EXAMPLE_OUTPUT_FIELDS,
@@ -438,8 +574,13 @@ class PromptFactory:
438
574
  """
439
575
  prompt_strategy_to_example_context = {
440
576
  PromptStrategy.COT_BOOL: COT_BOOL_EXAMPLE_CONTEXT,
577
+ PromptStrategy.COT_BOOL_AUDIO: COT_BOOL_AUDIO_EXAMPLE_CONTEXT,
441
578
  PromptStrategy.COT_BOOL_IMAGE: COT_BOOL_IMAGE_EXAMPLE_CONTEXT,
579
+ PromptStrategy.COT_JOIN: COT_JOIN_EXAMPLE_CONTEXT,
580
+ PromptStrategy.COT_JOIN_AUDIO: COT_JOIN_AUDIO_EXAMPLE_CONTEXT,
581
+ PromptStrategy.COT_JOIN_IMAGE: COT_JOIN_IMAGE_EXAMPLE_CONTEXT,
442
582
  PromptStrategy.COT_QA: COT_QA_EXAMPLE_CONTEXT,
583
+ PromptStrategy.COT_QA_AUDIO: COT_QA_AUDIO_EXAMPLE_CONTEXT,
443
584
  PromptStrategy.COT_QA_IMAGE: COT_QA_IMAGE_EXAMPLE_CONTEXT,
444
585
  PromptStrategy.COT_MOA_PROPOSER: COT_MOA_PROPOSER_EXAMPLE_CONTEXT,
445
586
  PromptStrategy.COT_MOA_PROPOSER_IMAGE: COT_MOA_PROPOSER_IMAGE_EXAMPLE_CONTEXT,
@@ -448,6 +589,21 @@ class PromptFactory:
448
589
 
449
590
  return prompt_strategy_to_example_context.get(self.prompt_strategy)
450
591
 
592
+ def _get_right_example_context(self) -> str | None:
593
+ """
594
+ Returns the right example context for the join prompt.
595
+
596
+ Returns:
597
+ str | None: The right example context (if applicable).
598
+ """
599
+ prompt_strategy_to_right_example_context = {
600
+ PromptStrategy.COT_JOIN: COT_JOIN_RIGHT_EXAMPLE_CONTEXT,
601
+ PromptStrategy.COT_JOIN_AUDIO: COT_JOIN_AUDIO_RIGHT_EXAMPLE_CONTEXT,
602
+ PromptStrategy.COT_JOIN_IMAGE: COT_JOIN_IMAGE_RIGHT_EXAMPLE_CONTEXT,
603
+ }
604
+
605
+ return prompt_strategy_to_right_example_context.get(self.prompt_strategy)
606
+
451
607
  def _get_image_disclaimer(self) -> str:
452
608
  """
453
609
  Returns the image disclaimer for the prompt. The disclaimer must be an empty string
@@ -458,12 +614,57 @@ class PromptFactory:
458
614
  """
459
615
  prompt_strategy_to_image_disclaimer = {
460
616
  PromptStrategy.COT_BOOL_IMAGE: COT_BOOL_IMAGE_DISCLAIMER,
617
+ PromptStrategy.COT_JOIN_IMAGE: COT_JOIN_IMAGE_DISCLAIMER,
461
618
  PromptStrategy.COT_QA_IMAGE: COT_QA_IMAGE_DISCLAIMER,
462
619
  PromptStrategy.COT_MOA_PROPOSER_IMAGE: COT_MOA_PROPOSER_IMAGE_DISCLAIMER,
463
620
  }
464
621
 
465
622
  return prompt_strategy_to_image_disclaimer.get(self.prompt_strategy, "")
466
623
 
624
+ def _get_audio_disclaimer(self) -> str:
625
+ """
626
+ Returns the audio disclaimer for the prompt. The disclaimer must be an empty string
627
+ for text prompts.
628
+
629
+ Returns:
630
+ str: The audio disclaimer. If this is a text prompt then it is an empty string.
631
+ """
632
+ prompt_strategy_to_audio_disclaimer = {
633
+ PromptStrategy.COT_BOOL_AUDIO: COT_BOOL_AUDIO_DISCLAIMER,
634
+ PromptStrategy.COT_JOIN_AUDIO: COT_JOIN_AUDIO_DISCLAIMER,
635
+ PromptStrategy.COT_QA_AUDIO: COT_QA_AUDIO_DISCLAIMER,
636
+ }
637
+
638
+ return prompt_strategy_to_audio_disclaimer.get(self.prompt_strategy, "")
639
+
640
+ def _get_right_image_disclaimer(self) -> str:
641
+ """
642
+ Returns the right image disclaimer for the prompt. The disclaimer must be an empty string
643
+ for text prompts.
644
+
645
+ Returns:
646
+ str: The right image disclaimer. If this is a text prompt then it is an empty string.
647
+ """
648
+ prompt_strategy_to_image_disclaimer = {
649
+ PromptStrategy.COT_JOIN_IMAGE: COT_JOIN_RIGHT_IMAGE_DISCLAIMER,
650
+ }
651
+
652
+ return prompt_strategy_to_image_disclaimer.get(self.prompt_strategy, "")
653
+
654
+ def _get_right_audio_disclaimer(self) -> str:
655
+ """
656
+ Returns the right audio disclaimer for the prompt. The disclaimer must be an empty string
657
+ for text prompts.
658
+
659
+ Returns:
660
+ str: The right audio disclaimer. If this is a text prompt then it is an empty string.
661
+ """
662
+ prompt_strategy_to_audio_disclaimer = {
663
+ PromptStrategy.COT_JOIN_AUDIO: COT_JOIN_RIGHT_AUDIO_DISCLAIMER,
664
+ }
665
+
666
+ return prompt_strategy_to_audio_disclaimer.get(self.prompt_strategy, "")
667
+
467
668
  def _get_example_filter_condition(self) -> str | None:
468
669
  """
469
670
  Returns the example filter condition for the prompt.
@@ -473,11 +674,27 @@ class PromptFactory:
473
674
  """
474
675
  prompt_strategy_to_example_filter_condition = {
475
676
  PromptStrategy.COT_BOOL: COT_BOOL_EXAMPLE_FILTER_CONDITION,
677
+ PromptStrategy.COT_BOOL_AUDIO: COT_BOOL_AUDIO_EXAMPLE_FILTER_CONDITION,
476
678
  PromptStrategy.COT_BOOL_IMAGE: COT_BOOL_IMAGE_EXAMPLE_FILTER_CONDITION,
477
679
  }
478
680
 
479
681
  return prompt_strategy_to_example_filter_condition.get(self.prompt_strategy)
480
682
 
683
+ def _get_example_join_condition(self) -> str | None:
684
+ """
685
+ Returns the example join condition for the prompt.
686
+
687
+ Returns:
688
+ str | None: The example join condition (if applicable).
689
+ """
690
+ prompt_strategy_to_example_join_condition = {
691
+ PromptStrategy.COT_JOIN: COT_JOIN_EXAMPLE_JOIN_CONDITION,
692
+ PromptStrategy.COT_JOIN_AUDIO: COT_JOIN_AUDIO_EXAMPLE_JOIN_CONDITION,
693
+ PromptStrategy.COT_JOIN_IMAGE: COT_JOIN_IMAGE_EXAMPLE_JOIN_CONDITION,
694
+ }
695
+
696
+ return prompt_strategy_to_example_join_condition.get(self.prompt_strategy)
697
+
481
698
  def _get_example_reasoning(self) -> str | None:
482
699
  """
483
700
  Returns the example reasoning for the prompt.
@@ -487,8 +704,13 @@ class PromptFactory:
487
704
  """
488
705
  prompt_strategy_to_example_reasoning = {
489
706
  PromptStrategy.COT_BOOL: COT_BOOL_EXAMPLE_REASONING,
707
+ PromptStrategy.COT_BOOL_AUDIO: COT_BOOL_AUDIO_EXAMPLE_REASONING,
490
708
  PromptStrategy.COT_BOOL_IMAGE: COT_BOOL_IMAGE_EXAMPLE_REASONING,
709
+ PromptStrategy.COT_JOIN: COT_JOIN_EXAMPLE_REASONING,
710
+ PromptStrategy.COT_JOIN_AUDIO: COT_JOIN_AUDIO_EXAMPLE_REASONING,
711
+ PromptStrategy.COT_JOIN_IMAGE: COT_JOIN_IMAGE_EXAMPLE_REASONING,
491
712
  PromptStrategy.COT_QA: COT_QA_EXAMPLE_REASONING,
713
+ PromptStrategy.COT_QA_AUDIO: COT_QA_AUDIO_EXAMPLE_REASONING,
492
714
  PromptStrategy.COT_QA_IMAGE: COT_QA_IMAGE_EXAMPLE_REASONING,
493
715
  }
494
716
 
@@ -503,6 +725,7 @@ class PromptFactory:
503
725
  """
504
726
  prompt_strategy_to_example_answer = {
505
727
  PromptStrategy.COT_QA: COT_QA_EXAMPLE_ANSWER,
728
+ PromptStrategy.COT_QA_AUDIO: COT_QA_AUDIO_EXAMPLE_ANSWER,
506
729
  PromptStrategy.COT_QA_IMAGE: COT_QA_IMAGE_EXAMPLE_ANSWER,
507
730
  PromptStrategy.COT_MOA_PROPOSER: COT_MOA_PROPOSER_EXAMPLE_ANSWER,
508
731
  PromptStrategy.COT_MOA_PROPOSER_IMAGE: COT_MOA_PROPOSER_IMAGE_EXAMPLE_ANSWER,
@@ -512,7 +735,7 @@ class PromptFactory:
512
735
  return prompt_strategy_to_example_answer.get(self.prompt_strategy)
513
736
 
514
737
  def _get_all_format_kwargs(
515
- self, candidate: DataRecord, input_fields: list[str], output_fields: list[str], **kwargs
738
+ self, candidate: DataRecord, input_fields: list[str], output_fields: list[str], right_candidate: DataRecord | None, right_input_fields: list[str], **kwargs
516
739
  ) -> dict:
517
740
  """
518
741
  Returns a dictionary containing all the format kwargs for templating the prompts.
@@ -532,24 +755,39 @@ class PromptFactory:
532
755
  "input_fields_desc": self._get_input_fields_desc(candidate, input_fields),
533
756
  "output_fields_desc": self._get_output_fields_desc(output_fields, **kwargs),
534
757
  "filter_condition": self._get_filter_condition(**kwargs),
758
+ "join_condition": self._get_join_condition(**kwargs),
535
759
  "original_output": self._get_original_output(**kwargs),
536
760
  "critique_output": self._get_critique_output(**kwargs),
537
761
  "model_responses": self._get_model_responses(**kwargs),
538
762
  "chunk_outputs": self._get_chunk_outputs(**kwargs),
539
763
  }
540
764
 
765
+ # if a right candidate is provided, we also get the context and input field descriptions for the right candidate
766
+ if right_candidate is not None:
767
+ input_format_kwargs.update({
768
+ "right_context": self._get_context(right_candidate, right_input_fields),
769
+ "right_input_fields_desc": self._get_input_fields_desc(right_candidate, right_input_fields),
770
+ })
771
+
541
772
  # get format kwargs which depend on the prompt strategy
542
773
  prompt_strategy_format_kwargs = {
543
774
  "output_format_instruction": self._get_output_format_instruction(),
544
775
  "job_instruction": self._get_job_instruction(),
776
+ "desc_section": self._get_desc_section(),
545
777
  "critique_criteria": self._get_critique_criteria(),
546
778
  "refinement_criteria": self._get_refinement_criteria(),
547
779
  "finish_instruction": self._get_finish_instruction(),
548
780
  "example_input_fields": self._get_example_input_fields(),
781
+ "right_example_input_fields": self._get_right_example_input_fields(),
549
782
  "example_output_fields": self._get_example_output_fields(),
550
783
  "example_context": self._get_example_context(),
784
+ "right_example_context": self._get_right_example_context(),
551
785
  "image_disclaimer": self._get_image_disclaimer(),
786
+ "audio_disclaimer": self._get_audio_disclaimer(),
787
+ "right_image_disclaimer": self._get_right_image_disclaimer(),
788
+ "right_audio_disclaimer": self._get_right_audio_disclaimer(),
552
789
  "example_filter_condition": self._get_example_filter_condition(),
790
+ "example_join_condition": self._get_example_join_condition(),
553
791
  "example_reasoning": self._get_example_reasoning(),
554
792
  "example_answer": self._get_example_answer(),
555
793
  }
@@ -557,6 +795,53 @@ class PromptFactory:
557
795
  # return all format kwargs
558
796
  return {**input_format_kwargs, **prompt_strategy_format_kwargs}
559
797
 
798
def _create_audio_messages(self, candidate: DataRecord, input_fields: list[str]) -> list[dict]:
    """
    Parses the candidate record and returns the audio messages for the chat payload.

    Each input field whose type annotation is an audio filepath, a pre-encoded
    base64 audio string, or a list of either (optionally `| None`) contributes
    one `input_audio` content part per recording. Fields whose value is None
    are skipped, since there is no recording to attach.

    Args:
        candidate (DataRecord): The input record.
        input_fields (list[str]): The list of input fields.

    Returns:
        list[dict]: A list with a single user message holding every audio content
            part, or an empty list if no audio content was found.
    """

    def _audio_part(base64_audio_str) -> dict:
        # builds one content part for the chat payload
        # NOTE(review): the format is hard-coded to "wav" for every recording;
        # confirm that non-WAV inputs are not expected here
        return {"type": "input_audio", "input_audio": {"data": base64_audio_str, "format": "wav"}}

    def _encode_file(audio_filepath) -> str:
        # reads the file as raw bytes and returns its base64 encoding as text
        with open(audio_filepath, "rb") as f:
            return base64.b64encode(f.read()).decode("utf-8")

    # create a content part for each audio recording in an input field with an audio (or list of audio) type
    audio_content = []
    for field_name in input_fields:
        field_value = candidate[field_name]
        field_type = candidate.get_field_type(field_name)

        # optional audio fields may legitimately be None; opening or iterating None
        # would raise a TypeError, so there is simply nothing to attach for them
        if field_value is None:
            continue

        # audio filepath (or list of audio filepaths)
        if field_type.annotation in [AudioFilepath, AudioFilepath | None]:
            audio_content.append(_audio_part(_encode_file(field_value)))

        elif field_type.annotation in [list[AudioFilepath], list[AudioFilepath] | None]:
            audio_content.extend(_audio_part(_encode_file(audio_filepath)) for audio_filepath in field_value)

        # pre-encoded audio (or list of pre-encoded audio)
        elif field_type.annotation in [AudioBase64, AudioBase64 | None]:
            audio_content.append(_audio_part(field_value))

        elif field_type.annotation in [list[AudioBase64], list[AudioBase64] | None]:
            audio_content.extend(_audio_part(base64_audio) for base64_audio in field_value)

    # all audio parts are wrapped in a single user message; no message at all when there is no audio
    return [{"role": "user", "type": "input_audio", "content": audio_content}] if len(audio_content) > 0 else []
844
+
560
845
  def _create_image_messages(self, candidate: DataRecord, input_fields: list[str]) -> list[dict]:
561
846
  """
562
847
  Parses the candidate record and returns the image messages for the chat payload.
@@ -569,50 +854,48 @@ class PromptFactory:
569
854
  list[dict]: The image messages for the chat payload.
570
855
  """
571
856
  # create a message for each image in an input field with an image (or list of image) type
572
- image_messages = []
857
+ image_content = []
573
858
  for field_name in input_fields:
574
859
  field_value = candidate[field_name]
575
860
  field_type = candidate.get_field_type(field_name)
576
861
 
577
862
  # image filepath (or list of image filepaths)
578
- if isinstance(field_type, ImageFilepathField):
863
+ if field_type.annotation in [ImageFilepath, ImageFilepath | None]:
579
864
  with open(field_value, "rb") as f:
580
865
  base64_image_str = base64.b64encode(f.read()).decode("utf-8")
581
- image_messages.append(
582
- {"role": "user", "type": "image", "content": f"data:image/jpeg;base64,{base64_image_str}"}
866
+ image_content.append(
867
+ {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image_str}"}}
583
868
  )
584
869
 
585
- elif hasattr(field_type, "element_type") and issubclass(field_type.element_type, ImageFilepathField):
870
+ elif field_type.annotation in [list[ImageFilepath], list[ImageFilepath] | None]:
586
871
  for image_filepath in field_value:
587
872
  with open(image_filepath, "rb") as f:
588
873
  base64_image_str = base64.b64encode(f.read()).decode("utf-8")
589
- image_messages.append(
590
- {"role": "user", "type": "image", "content": f"data:image/jpeg;base64,{base64_image_str}"}
874
+ image_content.append(
875
+ {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image_str}"}}
591
876
  )
592
877
 
593
878
  # image url (or list of image urls)
594
- elif isinstance(field_type, ImageURLField):
595
- image_messages.append({"role": "user", "type": "image", "content": field_value})
879
+ elif field_type.annotation in [ImageURL, ImageURL | None]:
880
+ image_content.append({"type": "image_url", "image_url": {"url": field_value}})
596
881
 
597
- elif hasattr(field_type, "element_type") and issubclass(field_type.element_type, ImageURLField):
882
+ elif field_type.annotation in [list[ImageURL], list[ImageURL] | None]:
598
883
  for image_url in field_value:
599
- image_messages.append({"role": "user", "type": "image", "content": image_url})
884
+ image_content.append({"type": "image_url", "image_url": {"url": image_url}})
600
885
 
601
886
  # pre-encoded images (or list of pre-encoded images)
602
- elif isinstance(field_type, ImageBase64Field):
603
- base64_image_str = field_value.decode("utf-8")
604
- image_messages.append(
605
- {"role": "user", "type": "image", "content": f"data:image/jpeg;base64,{base64_image_str}"}
887
+ elif field_type.annotation in [ImageBase64, ImageBase64 | None]:
888
+ image_content.append(
889
+ {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{field_value}"}}
606
890
  )
607
891
 
608
- elif hasattr(field_type, "element_type") and issubclass(field_type.element_type, ImageBase64Field):
892
+ elif field_type.annotation in [list[ImageBase64], list[ImageBase64] | None]:
609
893
  for base64_image in field_value:
610
- base64_image_str = base64_image.decode("utf-8")
611
- image_messages.append(
612
- {"role": "user", "type": "image", "content": f"data:image/jpeg;base64,{base64_image_str}"}
894
+ image_content.append(
895
+ {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}}
613
896
  )
614
897
 
615
- return image_messages
898
+ return [{"role": "user", "type": "image", "content": image_content}] if len(image_content) > 0 else []
616
899
 
617
900
  def _get_system_prompt(self, **format_kwargs) -> str | None:
618
901
  """
@@ -631,7 +914,7 @@ class PromptFactory:
631
914
 
632
915
  return base_prompt.format(**format_kwargs)
633
916
 
634
- def _get_user_messages(self, candidate: DataRecord, input_fields: list[str], **kwargs) -> str:
917
+ def _get_user_messages(self, candidate: DataRecord, input_fields: list[str], right_candidate: DataRecord | None, right_input_fields: list[str], **kwargs) -> str:
635
918
  """
636
919
  Returns a list of messages for the chat payload based on the prompt strategy.
637
920
 
@@ -648,10 +931,18 @@ class PromptFactory:
648
931
  # get the base prompt template
649
932
  base_prompt = self.BASE_USER_PROMPT_MAP.get(self.prompt_strategy)
650
933
 
651
- # get any image messages for the chat payload (will be an empty list if this is not an image prompt)
652
- image_messages = (
653
- self._create_image_messages(candidate, input_fields) if self.prompt_strategy.is_image_prompt() else []
654
- )
934
+ # get any image messages for the chat payload (will be an empty list if no image fields exist)
935
+ image_messages = self._create_image_messages(candidate, input_fields)
936
+
937
+ # get any audio messages for the chat payload (will be an empty list if no audio fields exist)
938
+ audio_messages = self._create_audio_messages(candidate, input_fields)
939
+
940
+ # get any right image and audio messages for the chat payload (both will be empty lists if this is not a join prompt)
941
+ right_image_messages, right_audio_messages = [], []
942
+ if self.prompt_strategy.is_join_prompt():
943
+ assert right_candidate is not None, "Right candidate must be provided for join prompts."
944
+ right_image_messages = self._create_image_messages(right_candidate, right_input_fields)
945
+ right_audio_messages = self._create_audio_messages(right_candidate, right_input_fields)
655
946
 
656
947
  # get any original messages for critique and refinement operations
657
948
  original_messages = kwargs.get("original_messages")
@@ -660,6 +951,8 @@ class PromptFactory:
660
951
  "Original messages must be provided for critique and refinement operations."
661
952
  )
662
953
 
954
+ # TODO: in the future if we support many modalities (e.g. images and audio) in the same prompt,
955
+ # then we will need to streamline this logic to handle the many different cases
663
956
  # construct the user messages based on the prompt strategy
664
957
  user_messages = []
665
958
  if self.prompt_strategy.is_critic_prompt() or self.prompt_strategy.is_refine_prompt():
@@ -670,14 +963,47 @@ class PromptFactory:
670
963
  user_messages.extend(original_messages)
671
964
  user_messages.append({"role": "user", "type": "text", "content": base_prompt_end.format(**kwargs)})
672
965
 
673
- elif self.prompt_strategy.is_image_prompt():
674
- base_prompt_start, base_prompt_end = base_prompt.split("<<image-placeholder>>\n")
966
+ # image not join
967
+ elif self.prompt_strategy.is_image_prompt() and not self.prompt_strategy.is_join_prompt():
968
+ base_prompt = base_prompt.replace("<<audio-placeholder>>", "")
969
+ base_prompt_start, base_prompt_end = base_prompt.split("<<image-placeholder>>")
675
970
  user_messages.append({"role": "user", "type": "text", "content": base_prompt_start.format(**kwargs)})
676
971
  user_messages.extend(image_messages)
677
972
  user_messages.append({"role": "user", "type": "text", "content": base_prompt_end.format(**kwargs)})
678
973
 
974
+ # image join
975
+ elif self.prompt_strategy.is_image_prompt() and self.prompt_strategy.is_join_prompt():
976
+ # for join image prompts, we may have two sets of images (one from the left candidate and one from the right candidate)
977
+ base_prompt = base_prompt.replace("<<audio-placeholder>>", "")
978
+ base_prompt_start, base_prompt_mid, base_prompt_end = base_prompt.split("<<image-placeholder>>")
979
+ user_messages.append({"role": "user", "type": "text", "content": base_prompt_start.format(**kwargs)})
980
+ user_messages.extend(image_messages)
981
+ user_messages.append({"role": "user", "type": "text", "content": base_prompt_mid.format(**kwargs)})
982
+ user_messages.extend(right_image_messages)
983
+ user_messages.append({"role": "user", "type": "text", "content": base_prompt_end.format(**kwargs)})
984
+
985
+ # audio not join
986
+ elif self.prompt_strategy.is_audio_prompt() and not self.prompt_strategy.is_join_prompt():
987
+ base_prompt = base_prompt.replace("<<image-placeholder>>", "")
988
+ base_prompt_start, base_prompt_end = base_prompt.split("<<audio-placeholder>>")
989
+ user_messages.append({"role": "user", "type": "text", "content": base_prompt_start.format(**kwargs)})
990
+ user_messages.extend(audio_messages)
991
+ user_messages.append({"role": "user", "type": "text", "content": base_prompt_end.format(**kwargs)})
992
+
993
+ # audio join
994
+ elif self.prompt_strategy.is_audio_prompt() and self.prompt_strategy.is_join_prompt():
995
+ # for join audio prompts, we may have two sets of audio recordings (one from the left candidate and one from the right candidate)
996
+ base_prompt = base_prompt.replace("<<image-placeholder>>", "")
997
+ base_prompt_start, base_prompt_mid, base_prompt_end = base_prompt.split("<<audio-placeholder>>")
998
+ user_messages.append({"role": "user", "type": "text", "content": base_prompt_start.format(**kwargs)})
999
+ user_messages.extend(audio_messages)
1000
+ user_messages.append({"role": "user", "type": "text", "content": base_prompt_mid.format(**kwargs)})
1001
+ user_messages.extend(right_audio_messages)
1002
+ user_messages.append({"role": "user", "type": "text", "content": base_prompt_end.format(**kwargs)})
1003
+
679
1004
  else:
680
1005
  base_prompt = base_prompt.replace("<<image-placeholder>>", "")
1006
+ base_prompt = base_prompt.replace("<<audio-placeholder>>", "")
681
1007
  user_messages.append({"role": "user", "type": "text", "content": base_prompt.format(**kwargs)})
682
1008
 
683
1009
  return user_messages
@@ -720,7 +1046,7 @@ class PromptFactory:
720
1046
  # build set of format kwargs
721
1047
  format_kwargs = {
722
1048
  field_name: "<bytes>"
723
- if isinstance(candidate.get_field_type(field_name), BytesField)
1049
+ if candidate.get_field_type(field_name).annotation in [bytes, bytes | None]
724
1050
  else candidate[field_name]
725
1051
  for field_name in input_fields
726
1052
  }
@@ -740,7 +1066,7 @@ class PromptFactory:
740
1066
 
741
1067
  return messages
742
1068
 
743
- def create_messages(self, candidate: DataRecord, output_fields: list[str], **kwargs) -> list[dict]:
1069
+ def create_messages(self, candidate: DataRecord, output_fields: list[str], right_candidate: DataRecord | None = None, **kwargs) -> list[dict]:
744
1070
  """
745
1071
  Creates the messages for the chat payload based on the prompt strategy.
746
1072
 
@@ -754,6 +1080,7 @@ class PromptFactory:
754
1080
  Args:
755
1081
  candidate (DataRecord): The input record.
756
1082
  output_fields (list[str]): The output fields.
1083
+ right_candidate (DataRecord | None): The other join input record (only provided for joins).
757
1084
  kwargs: The keyword arguments provided by the user.
758
1085
 
759
1086
  Returns:
@@ -761,6 +1088,7 @@ class PromptFactory:
761
1088
  """
762
1089
  # compute the set of input fields
763
1090
  input_fields = self._get_input_fields(candidate, **kwargs)
1091
+ right_input_fields = [] if right_candidate is None else self._get_input_fields(right_candidate, **kwargs)
764
1092
 
765
1093
  # if the user provides a prompt, we process that prompt into messages and return them
766
1094
  if "prompt" in kwargs:
@@ -774,7 +1102,7 @@ class PromptFactory:
774
1102
  messages = []
775
1103
 
776
1104
  # compute the full dictionary of format kwargs and add to kwargs
777
- format_kwargs = self._get_all_format_kwargs(candidate, input_fields, output_fields, **kwargs)
1105
+ format_kwargs = self._get_all_format_kwargs(candidate, input_fields, output_fields, right_candidate, right_input_fields, **kwargs)
778
1106
  kwargs = {**kwargs, **format_kwargs}
779
1107
 
780
1108
  # generate system message (if applicable)
@@ -783,7 +1111,7 @@ class PromptFactory:
783
1111
  messages.append({"role": "system", "type": "text", "content": system_prompt})
784
1112
 
785
1113
  # generate user messages and add to messages
786
- user_messages = self._get_user_messages(candidate, input_fields, **kwargs)
1114
+ user_messages = self._get_user_messages(candidate, input_fields, right_candidate, right_input_fields, **kwargs)
787
1115
  messages.extend(user_messages)
788
1116
 
789
1117
  return messages