palimpzest 0.7.20__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (87)
  1. palimpzest/__init__.py +37 -6
  2. palimpzest/agents/__init__.py +0 -0
  3. palimpzest/agents/compute_agents.py +0 -0
  4. palimpzest/agents/search_agents.py +637 -0
  5. palimpzest/constants.py +259 -197
  6. palimpzest/core/data/context.py +393 -0
  7. palimpzest/core/data/context_manager.py +163 -0
  8. palimpzest/core/data/dataset.py +634 -0
  9. palimpzest/core/data/{datareaders.py → iter_dataset.py} +202 -126
  10. palimpzest/core/elements/groupbysig.py +16 -13
  11. palimpzest/core/elements/records.py +166 -75
  12. palimpzest/core/lib/schemas.py +152 -390
  13. palimpzest/core/{data/dataclasses.py → models.py} +306 -170
  14. palimpzest/policy.py +2 -27
  15. palimpzest/prompts/__init__.py +35 -5
  16. palimpzest/prompts/agent_prompts.py +357 -0
  17. palimpzest/prompts/context_search.py +9 -0
  18. palimpzest/prompts/convert_prompts.py +61 -5
  19. palimpzest/prompts/filter_prompts.py +50 -5
  20. palimpzest/prompts/join_prompts.py +163 -0
  21. palimpzest/prompts/moa_proposer_convert_prompts.py +5 -5
  22. palimpzest/prompts/prompt_factory.py +358 -46
  23. palimpzest/prompts/validator.py +239 -0
  24. palimpzest/query/execution/all_sample_execution_strategy.py +134 -76
  25. palimpzest/query/execution/execution_strategy.py +210 -317
  26. palimpzest/query/execution/execution_strategy_type.py +5 -7
  27. palimpzest/query/execution/mab_execution_strategy.py +249 -136
  28. palimpzest/query/execution/parallel_execution_strategy.py +153 -244
  29. palimpzest/query/execution/single_threaded_execution_strategy.py +107 -64
  30. palimpzest/query/generators/generators.py +157 -330
  31. palimpzest/query/operators/__init__.py +15 -5
  32. palimpzest/query/operators/aggregate.py +50 -33
  33. palimpzest/query/operators/compute.py +201 -0
  34. palimpzest/query/operators/convert.py +27 -21
  35. palimpzest/query/operators/critique_and_refine_convert.py +7 -5
  36. palimpzest/query/operators/distinct.py +62 -0
  37. palimpzest/query/operators/filter.py +22 -13
  38. palimpzest/query/operators/join.py +402 -0
  39. palimpzest/query/operators/limit.py +3 -3
  40. palimpzest/query/operators/logical.py +198 -80
  41. palimpzest/query/operators/mixture_of_agents_convert.py +10 -8
  42. palimpzest/query/operators/physical.py +27 -21
  43. palimpzest/query/operators/project.py +3 -3
  44. palimpzest/query/operators/rag_convert.py +7 -7
  45. palimpzest/query/operators/retrieve.py +9 -9
  46. palimpzest/query/operators/scan.py +81 -42
  47. palimpzest/query/operators/search.py +524 -0
  48. palimpzest/query/operators/split_convert.py +10 -8
  49. palimpzest/query/optimizer/__init__.py +7 -9
  50. palimpzest/query/optimizer/cost_model.py +108 -441
  51. palimpzest/query/optimizer/optimizer.py +123 -181
  52. palimpzest/query/optimizer/optimizer_strategy.py +66 -61
  53. palimpzest/query/optimizer/plan.py +352 -67
  54. palimpzest/query/optimizer/primitives.py +43 -19
  55. palimpzest/query/optimizer/rules.py +484 -646
  56. palimpzest/query/optimizer/tasks.py +127 -58
  57. palimpzest/query/processor/config.py +41 -76
  58. palimpzest/query/processor/query_processor.py +73 -18
  59. palimpzest/query/processor/query_processor_factory.py +46 -38
  60. palimpzest/schemabuilder/schema_builder.py +15 -28
  61. palimpzest/utils/model_helpers.py +27 -77
  62. palimpzest/utils/progress.py +114 -102
  63. palimpzest/validator/__init__.py +0 -0
  64. palimpzest/validator/validator.py +306 -0
  65. {palimpzest-0.7.20.dist-info → palimpzest-0.8.0.dist-info}/METADATA +6 -1
  66. palimpzest-0.8.0.dist-info/RECORD +95 -0
  67. palimpzest/core/lib/fields.py +0 -141
  68. palimpzest/prompts/code_synthesis_prompts.py +0 -28
  69. palimpzest/query/execution/random_sampling_execution_strategy.py +0 -240
  70. palimpzest/query/generators/api_client_factory.py +0 -30
  71. palimpzest/query/operators/code_synthesis_convert.py +0 -488
  72. palimpzest/query/operators/map.py +0 -130
  73. palimpzest/query/processor/nosentinel_processor.py +0 -33
  74. palimpzest/query/processor/processing_strategy_type.py +0 -28
  75. palimpzest/query/processor/sentinel_processor.py +0 -88
  76. palimpzest/query/processor/streaming_processor.py +0 -149
  77. palimpzest/sets.py +0 -405
  78. palimpzest/utils/datareader_helpers.py +0 -61
  79. palimpzest/utils/demo_helpers.py +0 -75
  80. palimpzest/utils/field_helpers.py +0 -69
  81. palimpzest/utils/generation_helpers.py +0 -69
  82. palimpzest/utils/sandbox.py +0 -183
  83. palimpzest-0.7.20.dist-info/RECORD +0 -95
  84. /palimpzest/core/{elements/index.py → data/index_dataset.py} +0 -0
  85. {palimpzest-0.7.20.dist-info → palimpzest-0.8.0.dist-info}/WHEEL +0 -0
  86. {palimpzest-0.7.20.dist-info → palimpzest-0.8.0.dist-info}/licenses/LICENSE +0 -0
  87. {palimpzest-0.7.20.dist-info → palimpzest-0.8.0.dist-info}/top_level.txt +0 -0
@@ -4,17 +4,25 @@ import base64
4
4
  import json
5
5
  from string import Formatter
6
6
 
7
+ from pydantic import BaseModel
8
+
7
9
  from palimpzest.constants import (
8
- MIXTRAL_LLAMA_CONTEXT_TOKENS_LIMIT,
10
+ LLAMA_CONTEXT_TOKENS_LIMIT,
9
11
  TOKENS_PER_CHARACTER,
10
12
  Cardinality,
11
13
  Model,
12
14
  PromptStrategy,
13
15
  )
14
16
  from palimpzest.core.elements.records import DataRecord
15
- from palimpzest.core.lib.fields import BytesField, ImageBase64Field, ImageFilepathField, ImageURLField
16
- from palimpzest.core.lib.schemas import Schema
17
+ from palimpzest.core.lib.schemas import AudioBase64, AudioFilepath, ImageBase64, ImageFilepath, ImageURL
17
18
  from palimpzest.prompts.convert_prompts import (
19
+ COT_QA_AUDIO_DISCLAIMER,
20
+ COT_QA_AUDIO_EXAMPLE_ANSWER,
21
+ COT_QA_AUDIO_EXAMPLE_CONTEXT,
22
+ COT_QA_AUDIO_EXAMPLE_INPUT_FIELDS,
23
+ COT_QA_AUDIO_EXAMPLE_OUTPUT_FIELDS,
24
+ COT_QA_AUDIO_EXAMPLE_REASONING,
25
+ COT_QA_AUDIO_JOB_INSTRUCTION,
18
26
  COT_QA_BASE_SYSTEM_PROMPT,
19
27
  COT_QA_BASE_USER_PROMPT,
20
28
  COT_QA_EXAMPLE_ANSWER,
@@ -30,6 +38,8 @@ from palimpzest.prompts.convert_prompts import (
30
38
  COT_QA_IMAGE_EXAMPLE_REASONING,
31
39
  COT_QA_IMAGE_JOB_INSTRUCTION,
32
40
  COT_QA_JOB_INSTRUCTION,
41
+ COT_QA_NO_REASONING_BASE_SYSTEM_PROMPT,
42
+ COT_QA_NO_REASONING_BASE_USER_PROMPT,
33
43
  )
34
44
  from palimpzest.prompts.critique_and_refine_convert_prompts import (
35
45
  BASE_CRITIQUE_PROMPT,
@@ -42,6 +52,12 @@ from palimpzest.prompts.critique_and_refine_convert_prompts import (
42
52
  COT_QA_REFINEMENT_FINISH_INSTRUCTION,
43
53
  )
44
54
  from palimpzest.prompts.filter_prompts import (
55
+ COT_BOOL_AUDIO_DISCLAIMER,
56
+ COT_BOOL_AUDIO_EXAMPLE_CONTEXT,
57
+ COT_BOOL_AUDIO_EXAMPLE_FILTER_CONDITION,
58
+ COT_BOOL_AUDIO_EXAMPLE_INPUT_FIELDS,
59
+ COT_BOOL_AUDIO_EXAMPLE_REASONING,
60
+ COT_BOOL_AUDIO_JOB_INSTRUCTION,
45
61
  COT_BOOL_BASE_SYSTEM_PROMPT,
46
62
  COT_BOOL_BASE_USER_PROMPT,
47
63
  COT_BOOL_EXAMPLE_CONTEXT,
@@ -55,6 +71,39 @@ from palimpzest.prompts.filter_prompts import (
55
71
  COT_BOOL_IMAGE_EXAMPLE_REASONING,
56
72
  COT_BOOL_IMAGE_JOB_INSTRUCTION,
57
73
  COT_BOOL_JOB_INSTRUCTION,
74
+ COT_BOOL_NO_REASONING_BASE_SYSTEM_PROMPT,
75
+ COT_BOOL_NO_REASONING_BASE_USER_PROMPT,
76
+ )
77
+ from palimpzest.prompts.join_prompts import (
78
+ COT_JOIN_AUDIO_DISCLAIMER,
79
+ COT_JOIN_AUDIO_EXAMPLE_CONTEXT,
80
+ COT_JOIN_AUDIO_EXAMPLE_INPUT_FIELDS,
81
+ COT_JOIN_AUDIO_EXAMPLE_JOIN_CONDITION,
82
+ COT_JOIN_AUDIO_EXAMPLE_REASONING,
83
+ COT_JOIN_AUDIO_JOB_INSTRUCTION,
84
+ COT_JOIN_AUDIO_RIGHT_EXAMPLE_CONTEXT,
85
+ COT_JOIN_AUDIO_RIGHT_EXAMPLE_INPUT_FIELDS,
86
+ COT_JOIN_BASE_SYSTEM_PROMPT,
87
+ COT_JOIN_BASE_USER_PROMPT,
88
+ COT_JOIN_EXAMPLE_CONTEXT,
89
+ COT_JOIN_EXAMPLE_INPUT_FIELDS,
90
+ COT_JOIN_EXAMPLE_JOIN_CONDITION,
91
+ COT_JOIN_EXAMPLE_REASONING,
92
+ COT_JOIN_IMAGE_DISCLAIMER,
93
+ COT_JOIN_IMAGE_EXAMPLE_CONTEXT,
94
+ COT_JOIN_IMAGE_EXAMPLE_INPUT_FIELDS,
95
+ COT_JOIN_IMAGE_EXAMPLE_JOIN_CONDITION,
96
+ COT_JOIN_IMAGE_EXAMPLE_REASONING,
97
+ COT_JOIN_IMAGE_JOB_INSTRUCTION,
98
+ COT_JOIN_IMAGE_RIGHT_EXAMPLE_CONTEXT,
99
+ COT_JOIN_IMAGE_RIGHT_EXAMPLE_INPUT_FIELDS,
100
+ COT_JOIN_JOB_INSTRUCTION,
101
+ COT_JOIN_NO_REASONING_BASE_SYSTEM_PROMPT,
102
+ COT_JOIN_NO_REASONING_BASE_USER_PROMPT,
103
+ COT_JOIN_RIGHT_AUDIO_DISCLAIMER,
104
+ COT_JOIN_RIGHT_EXAMPLE_CONTEXT,
105
+ COT_JOIN_RIGHT_EXAMPLE_INPUT_FIELDS,
106
+ COT_JOIN_RIGHT_IMAGE_DISCLAIMER,
58
107
  )
59
108
  from palimpzest.prompts.moa_aggregator_convert_prompts import (
60
109
  COT_MOA_AGG_BASE_SYSTEM_PROMPT,
@@ -99,11 +148,25 @@ class PromptFactory:
99
148
 
100
149
  BASE_SYSTEM_PROMPT_MAP = {
101
150
  PromptStrategy.COT_BOOL: COT_BOOL_BASE_SYSTEM_PROMPT,
151
+ PromptStrategy.COT_BOOL_NO_REASONING: COT_BOOL_NO_REASONING_BASE_SYSTEM_PROMPT,
152
+ PromptStrategy.COT_BOOL_AUDIO: COT_BOOL_BASE_SYSTEM_PROMPT,
153
+ PromptStrategy.COT_BOOL_AUDIO_NO_REASONING: COT_BOOL_NO_REASONING_BASE_SYSTEM_PROMPT,
102
154
  PromptStrategy.COT_BOOL_IMAGE: COT_BOOL_BASE_SYSTEM_PROMPT,
155
+ PromptStrategy.COT_BOOL_IMAGE_NO_REASONING: COT_BOOL_NO_REASONING_BASE_SYSTEM_PROMPT,
156
+ PromptStrategy.COT_JOIN: COT_JOIN_BASE_SYSTEM_PROMPT,
157
+ PromptStrategy.COT_JOIN_NO_REASONING: COT_JOIN_NO_REASONING_BASE_SYSTEM_PROMPT,
158
+ PromptStrategy.COT_JOIN_AUDIO: COT_JOIN_BASE_SYSTEM_PROMPT,
159
+ PromptStrategy.COT_JOIN_AUDIO_NO_REASONING: COT_JOIN_NO_REASONING_BASE_SYSTEM_PROMPT,
160
+ PromptStrategy.COT_JOIN_IMAGE: COT_JOIN_BASE_SYSTEM_PROMPT,
161
+ PromptStrategy.COT_JOIN_IMAGE_NO_REASONING: COT_JOIN_NO_REASONING_BASE_SYSTEM_PROMPT,
103
162
  PromptStrategy.COT_QA: COT_QA_BASE_SYSTEM_PROMPT,
163
+ PromptStrategy.COT_QA_NO_REASONING: COT_QA_NO_REASONING_BASE_SYSTEM_PROMPT,
164
+ PromptStrategy.COT_QA_AUDIO: COT_QA_BASE_SYSTEM_PROMPT,
165
+ PromptStrategy.COT_QA_AUDIO_NO_REASONING: COT_QA_NO_REASONING_BASE_SYSTEM_PROMPT,
104
166
  PromptStrategy.COT_QA_CRITIC: None,
105
167
  PromptStrategy.COT_QA_REFINE: None,
106
168
  PromptStrategy.COT_QA_IMAGE: COT_QA_BASE_SYSTEM_PROMPT,
169
+ PromptStrategy.COT_QA_IMAGE_NO_REASONING: COT_QA_NO_REASONING_BASE_SYSTEM_PROMPT,
107
170
  PromptStrategy.COT_QA_IMAGE_CRITIC: None,
108
171
  PromptStrategy.COT_QA_IMAGE_REFINE: None,
109
172
  PromptStrategy.COT_MOA_PROPOSER: COT_MOA_PROPOSER_BASE_SYSTEM_PROMPT,
@@ -114,11 +177,25 @@ class PromptFactory:
114
177
  }
115
178
  BASE_USER_PROMPT_MAP = {
116
179
  PromptStrategy.COT_BOOL: COT_BOOL_BASE_USER_PROMPT,
180
+ PromptStrategy.COT_BOOL_NO_REASONING: COT_BOOL_NO_REASONING_BASE_USER_PROMPT,
181
+ PromptStrategy.COT_BOOL_AUDIO: COT_BOOL_BASE_USER_PROMPT,
182
+ PromptStrategy.COT_BOOL_AUDIO_NO_REASONING: COT_BOOL_NO_REASONING_BASE_USER_PROMPT,
117
183
  PromptStrategy.COT_BOOL_IMAGE: COT_BOOL_BASE_USER_PROMPT,
184
+ PromptStrategy.COT_BOOL_IMAGE_NO_REASONING: COT_BOOL_NO_REASONING_BASE_USER_PROMPT,
185
+ PromptStrategy.COT_JOIN: COT_JOIN_BASE_USER_PROMPT,
186
+ PromptStrategy.COT_JOIN_NO_REASONING: COT_JOIN_NO_REASONING_BASE_USER_PROMPT,
187
+ PromptStrategy.COT_JOIN_AUDIO: COT_JOIN_BASE_USER_PROMPT,
188
+ PromptStrategy.COT_JOIN_AUDIO_NO_REASONING: COT_JOIN_NO_REASONING_BASE_USER_PROMPT,
189
+ PromptStrategy.COT_JOIN_IMAGE: COT_JOIN_BASE_USER_PROMPT,
190
+ PromptStrategy.COT_JOIN_IMAGE_NO_REASONING: COT_JOIN_NO_REASONING_BASE_USER_PROMPT,
118
191
  PromptStrategy.COT_QA: COT_QA_BASE_USER_PROMPT,
192
+ PromptStrategy.COT_QA_NO_REASONING: COT_QA_NO_REASONING_BASE_USER_PROMPT,
193
+ PromptStrategy.COT_QA_AUDIO: COT_QA_BASE_USER_PROMPT,
194
+ PromptStrategy.COT_QA_AUDIO_NO_REASONING: COT_QA_NO_REASONING_BASE_USER_PROMPT,
119
195
  PromptStrategy.COT_QA_CRITIC: BASE_CRITIQUE_PROMPT,
120
196
  PromptStrategy.COT_QA_REFINE: BASE_REFINEMENT_PROMPT,
121
197
  PromptStrategy.COT_QA_IMAGE: COT_QA_BASE_USER_PROMPT,
198
+ PromptStrategy.COT_QA_IMAGE_NO_REASONING: COT_QA_NO_REASONING_BASE_USER_PROMPT,
122
199
  PromptStrategy.COT_QA_IMAGE_CRITIC: BASE_CRITIQUE_PROMPT,
123
200
  PromptStrategy.COT_QA_IMAGE_REFINE: BASE_REFINEMENT_PROMPT,
124
201
  PromptStrategy.COT_MOA_PROPOSER: COT_MOA_PROPOSER_BASE_USER_PROMPT,
@@ -144,8 +221,9 @@ class PromptFactory:
144
221
  Returns:
145
222
  str: The context.
146
223
  """
224
+ # TODO: remove mask_filepaths=True after SemBench evaluation
147
225
  # get context from input record (project_cols will be None if not provided in kwargs)
148
- context: dict = candidate.to_dict(include_bytes=False, project_cols=input_fields)
226
+ context: dict = candidate.to_dict(include_bytes=False, project_cols=input_fields, mask_filepaths=True)
149
227
 
150
228
  # TODO: MOVE THIS LOGIC INTO A CHUNKING / CONTEXT MANAGEMENT CLASS
151
229
  # - this class should be able to:
@@ -155,12 +233,12 @@ class PromptFactory:
155
233
  # TODO: this does not work for image prompts
156
234
  # TODO: this ignores the size of the `orignal_messages` in critique and refine prompts
157
235
  # cut down on context based on window length
158
- if self.model.is_llama_model() or self.model.is_mixtral_model():
236
+ if self.model.is_llama_model():
159
237
  total_context_len = len(json.dumps(context, indent=2))
160
238
 
161
239
  # sort fields by length and progressively strip from the longest field until it is short enough;
162
- # NOTE: MIXTRAL_LLAMA_CONTEXT_TOKENS_LIMIT is a rough estimate which leaves room for the rest of the prompt text
163
- while total_context_len * TOKENS_PER_CHARACTER > MIXTRAL_LLAMA_CONTEXT_TOKENS_LIMIT:
240
+ # NOTE: LLAMA_CONTEXT_TOKENS_LIMIT is a rough estimate which leaves room for the rest of the prompt text
241
+ while total_context_len * TOKENS_PER_CHARACTER > LLAMA_CONTEXT_TOKENS_LIMIT:
164
242
  # sort fields by length
165
243
  field_lengths = [(field, len(value) if value is not None else 0) for field, value in context.items()]
166
244
  sorted_fields = sorted(field_lengths, key=lambda item: item[1], reverse=True)
@@ -169,7 +247,7 @@ class PromptFactory:
169
247
  longest_field_name, longest_field_length = sorted_fields[0]
170
248
 
171
249
  # trim the field
172
- context_factor = MIXTRAL_LLAMA_CONTEXT_TOKENS_LIMIT / (total_context_len * TOKENS_PER_CHARACTER)
250
+ context_factor = LLAMA_CONTEXT_TOKENS_LIMIT / (total_context_len * TOKENS_PER_CHARACTER)
173
251
  keep_frac_idx = int(longest_field_length * context_factor)
174
252
  context[longest_field_name] = context[longest_field_name][:keep_frac_idx]
175
253
 
@@ -191,7 +269,11 @@ class PromptFactory:
191
269
  Returns:
192
270
  list[str]: The list of input field names.
193
271
  """
194
- return kwargs.get("project_cols", candidate.get_field_names())
272
+ # NOTE: joins will include left and right input fields in project_cols, so we have to check
273
+ # if the field is in the candidate record
274
+ input_fields = kwargs.get("project_cols", candidate.get_field_names())
275
+ input_fields = [field for field in input_fields if field in candidate.get_field_names()]
276
+ return input_fields
195
277
 
196
278
  def _get_input_fields_desc(self, candidate: DataRecord, input_fields: list[str]) -> str:
197
279
  """
@@ -205,7 +287,7 @@ class PromptFactory:
205
287
  """
206
288
  input_fields_desc = ""
207
289
  for field_name in input_fields:
208
- input_fields_desc += f"- {field_name}: {candidate.get_field_type(field_name)._desc}\n"
290
+ input_fields_desc += f"- {field_name}: {candidate.get_field_type(field_name).description}\n"
209
291
 
210
292
  return input_fields_desc[:-1]
211
293
 
@@ -221,13 +303,13 @@ class PromptFactory:
221
303
  str: The output fields description.
222
304
  """
223
305
  output_fields_desc = ""
224
- output_schema: Schema = kwargs.get("output_schema")
306
+ output_schema: BaseModel = kwargs.get("output_schema")
225
307
  if self.prompt_strategy.is_convert_prompt():
226
308
  assert output_schema is not None, "Output schema must be provided for convert prompts."
227
309
 
228
- field_desc_map = output_schema.field_desc_map()
229
- for field_name in output_fields:
230
- output_fields_desc += f"- {field_name}: {field_desc_map[field_name]}\n"
310
+ for field_name in sorted(output_fields):
311
+ desc = output_schema.model_fields[field_name].description
312
+ output_fields_desc += f"- {field_name}: {'no description available' if desc is None else desc}\n"
231
313
 
232
314
  # strip the last newline characters from the field descriptions and return
233
315
  return output_fields_desc[:-1]
@@ -245,6 +327,19 @@ class PromptFactory:
245
327
 
246
328
  return filter_condition
247
329
 
330
+ def _get_join_condition(self, **kwargs) -> str | None:
331
+ """
332
+ Returns the join condition for the join operation.
333
+
334
+ Returns:
335
+ str | None: The join condition (if applicable).
336
+ """
337
+ join_condition = kwargs.get("join_condition")
338
+ if self.prompt_strategy.is_join_prompt():
339
+ assert join_condition is not None, "Join condition must be provided for join operations."
340
+
341
+ return join_condition
342
+
248
343
  def _get_original_output(self, **kwargs) -> str | None:
249
344
  """
250
345
  Returns the original output from a previous model generation for the critique and refinement operations.
@@ -337,8 +432,13 @@ class PromptFactory:
337
432
  """
338
433
  prompt_strategy_to_job_instruction = {
339
434
  PromptStrategy.COT_BOOL: COT_BOOL_JOB_INSTRUCTION,
435
+ PromptStrategy.COT_BOOL_AUDIO: COT_BOOL_AUDIO_JOB_INSTRUCTION,
340
436
  PromptStrategy.COT_BOOL_IMAGE: COT_BOOL_IMAGE_JOB_INSTRUCTION,
437
+ PromptStrategy.COT_JOIN: COT_JOIN_JOB_INSTRUCTION,
438
+ PromptStrategy.COT_JOIN_AUDIO: COT_JOIN_AUDIO_JOB_INSTRUCTION,
439
+ PromptStrategy.COT_JOIN_IMAGE: COT_JOIN_IMAGE_JOB_INSTRUCTION,
341
440
  PromptStrategy.COT_QA: COT_QA_JOB_INSTRUCTION,
441
+ PromptStrategy.COT_QA_AUDIO: COT_QA_AUDIO_JOB_INSTRUCTION,
342
442
  PromptStrategy.COT_QA_IMAGE: COT_QA_IMAGE_JOB_INSTRUCTION,
343
443
  PromptStrategy.COT_MOA_PROPOSER: COT_MOA_PROPOSER_JOB_INSTRUCTION,
344
444
  PromptStrategy.COT_MOA_PROPOSER_IMAGE: COT_MOA_PROPOSER_IMAGE_JOB_INSTRUCTION,
@@ -402,8 +502,13 @@ class PromptFactory:
402
502
  """
403
503
  prompt_strategy_to_example_input_fields = {
404
504
  PromptStrategy.COT_BOOL: COT_BOOL_EXAMPLE_INPUT_FIELDS,
505
+ PromptStrategy.COT_BOOL_AUDIO: COT_BOOL_AUDIO_EXAMPLE_INPUT_FIELDS,
405
506
  PromptStrategy.COT_BOOL_IMAGE: COT_BOOL_IMAGE_EXAMPLE_INPUT_FIELDS,
507
+ PromptStrategy.COT_JOIN: COT_JOIN_EXAMPLE_INPUT_FIELDS,
508
+ PromptStrategy.COT_JOIN_AUDIO: COT_JOIN_AUDIO_EXAMPLE_INPUT_FIELDS,
509
+ PromptStrategy.COT_JOIN_IMAGE: COT_JOIN_IMAGE_EXAMPLE_INPUT_FIELDS,
406
510
  PromptStrategy.COT_QA: COT_QA_EXAMPLE_INPUT_FIELDS,
511
+ PromptStrategy.COT_QA_AUDIO: COT_QA_AUDIO_EXAMPLE_INPUT_FIELDS,
407
512
  PromptStrategy.COT_QA_IMAGE: COT_QA_IMAGE_EXAMPLE_INPUT_FIELDS,
408
513
  PromptStrategy.COT_MOA_PROPOSER: COT_MOA_PROPOSER_EXAMPLE_INPUT_FIELDS,
409
514
  PromptStrategy.COT_MOA_PROPOSER_IMAGE: COT_MOA_PROPOSER_IMAGE_EXAMPLE_INPUT_FIELDS,
@@ -412,6 +517,21 @@ class PromptFactory:
412
517
 
413
518
  return prompt_strategy_to_example_input_fields.get(self.prompt_strategy)
414
519
 
520
+ def _get_right_example_input_fields(self) -> str | None:
521
+ """
522
+ Returns the example right input fields for the join prompt.
523
+
524
+ Returns:
525
+ str | None: The example right input fields (if applicable).
526
+ """
527
+ prompt_strategy_to_right_example_input_fields = {
528
+ PromptStrategy.COT_JOIN: COT_JOIN_RIGHT_EXAMPLE_INPUT_FIELDS,
529
+ PromptStrategy.COT_JOIN_AUDIO: COT_JOIN_AUDIO_RIGHT_EXAMPLE_INPUT_FIELDS,
530
+ PromptStrategy.COT_JOIN_IMAGE: COT_JOIN_IMAGE_RIGHT_EXAMPLE_INPUT_FIELDS,
531
+ }
532
+
533
+ return prompt_strategy_to_right_example_input_fields.get(self.prompt_strategy)
534
+
415
535
  def _get_example_output_fields(self) -> str | None:
416
536
  """
417
537
  Returns the example output fields for the prompt.
@@ -421,6 +541,7 @@ class PromptFactory:
421
541
  """
422
542
  prompt_strategy_to_example_output_fields = {
423
543
  PromptStrategy.COT_QA: COT_QA_EXAMPLE_OUTPUT_FIELDS,
544
+ PromptStrategy.COT_QA_AUDIO: COT_QA_AUDIO_EXAMPLE_OUTPUT_FIELDS,
424
545
  PromptStrategy.COT_QA_IMAGE: COT_QA_IMAGE_EXAMPLE_OUTPUT_FIELDS,
425
546
  PromptStrategy.COT_MOA_PROPOSER: COT_MOA_PROPOSER_EXAMPLE_OUTPUT_FIELDS,
426
547
  PromptStrategy.COT_MOA_PROPOSER_IMAGE: COT_MOA_PROPOSER_IMAGE_EXAMPLE_OUTPUT_FIELDS,
@@ -438,8 +559,13 @@ class PromptFactory:
438
559
  """
439
560
  prompt_strategy_to_example_context = {
440
561
  PromptStrategy.COT_BOOL: COT_BOOL_EXAMPLE_CONTEXT,
562
+ PromptStrategy.COT_BOOL_AUDIO: COT_BOOL_AUDIO_EXAMPLE_CONTEXT,
441
563
  PromptStrategy.COT_BOOL_IMAGE: COT_BOOL_IMAGE_EXAMPLE_CONTEXT,
564
+ PromptStrategy.COT_JOIN: COT_JOIN_EXAMPLE_CONTEXT,
565
+ PromptStrategy.COT_JOIN_AUDIO: COT_JOIN_AUDIO_EXAMPLE_CONTEXT,
566
+ PromptStrategy.COT_JOIN_IMAGE: COT_JOIN_IMAGE_EXAMPLE_CONTEXT,
442
567
  PromptStrategy.COT_QA: COT_QA_EXAMPLE_CONTEXT,
568
+ PromptStrategy.COT_QA_AUDIO: COT_QA_AUDIO_EXAMPLE_CONTEXT,
443
569
  PromptStrategy.COT_QA_IMAGE: COT_QA_IMAGE_EXAMPLE_CONTEXT,
444
570
  PromptStrategy.COT_MOA_PROPOSER: COT_MOA_PROPOSER_EXAMPLE_CONTEXT,
445
571
  PromptStrategy.COT_MOA_PROPOSER_IMAGE: COT_MOA_PROPOSER_IMAGE_EXAMPLE_CONTEXT,
@@ -448,6 +574,21 @@ class PromptFactory:
448
574
 
449
575
  return prompt_strategy_to_example_context.get(self.prompt_strategy)
450
576
 
577
+ def _get_right_example_context(self) -> str | None:
578
+ """
579
+ Returns the right example context for the join prompt.
580
+
581
+ Returns:
582
+ str | None: The right example context (if applicable).
583
+ """
584
+ prompt_strategy_to_right_example_context = {
585
+ PromptStrategy.COT_JOIN: COT_JOIN_RIGHT_EXAMPLE_CONTEXT,
586
+ PromptStrategy.COT_JOIN_AUDIO: COT_JOIN_AUDIO_RIGHT_EXAMPLE_CONTEXT,
587
+ PromptStrategy.COT_JOIN_IMAGE: COT_JOIN_IMAGE_RIGHT_EXAMPLE_CONTEXT,
588
+ }
589
+
590
+ return prompt_strategy_to_right_example_context.get(self.prompt_strategy)
591
+
451
592
  def _get_image_disclaimer(self) -> str:
452
593
  """
453
594
  Returns the image disclaimer for the prompt. The disclaimer must be an empty string
@@ -458,12 +599,57 @@ class PromptFactory:
458
599
  """
459
600
  prompt_strategy_to_image_disclaimer = {
460
601
  PromptStrategy.COT_BOOL_IMAGE: COT_BOOL_IMAGE_DISCLAIMER,
602
+ PromptStrategy.COT_JOIN_IMAGE: COT_JOIN_IMAGE_DISCLAIMER,
461
603
  PromptStrategy.COT_QA_IMAGE: COT_QA_IMAGE_DISCLAIMER,
462
604
  PromptStrategy.COT_MOA_PROPOSER_IMAGE: COT_MOA_PROPOSER_IMAGE_DISCLAIMER,
463
605
  }
464
606
 
465
607
  return prompt_strategy_to_image_disclaimer.get(self.prompt_strategy, "")
466
608
 
609
+ def _get_audio_disclaimer(self) -> str:
610
+ """
611
+ Returns the audio disclaimer for the prompt. The disclaimer must be an empty string
612
+ for text prompts.
613
+
614
+ Returns:
615
+ str: The audio disclaimer. If this is a text prompt then it is an empty string.
616
+ """
617
+ prompt_strategy_to_audio_disclaimer = {
618
+ PromptStrategy.COT_BOOL_AUDIO: COT_BOOL_AUDIO_DISCLAIMER,
619
+ PromptStrategy.COT_JOIN_AUDIO: COT_JOIN_AUDIO_DISCLAIMER,
620
+ PromptStrategy.COT_QA_AUDIO: COT_QA_AUDIO_DISCLAIMER,
621
+ }
622
+
623
+ return prompt_strategy_to_audio_disclaimer.get(self.prompt_strategy, "")
624
+
625
+ def _get_right_image_disclaimer(self) -> str:
626
+ """
627
+ Returns the right image disclaimer for the prompt. The disclaimer must be an empty string
628
+ for text prompts.
629
+
630
+ Returns:
631
+ str: The right image disclaimer. If this is a text prompt then it is an empty string.
632
+ """
633
+ prompt_strategy_to_image_disclaimer = {
634
+ PromptStrategy.COT_JOIN_IMAGE: COT_JOIN_RIGHT_IMAGE_DISCLAIMER,
635
+ }
636
+
637
+ return prompt_strategy_to_image_disclaimer.get(self.prompt_strategy, "")
638
+
639
+ def _get_right_audio_disclaimer(self) -> str:
640
+ """
641
+ Returns the right audio disclaimer for the prompt. The disclaimer must be an empty string
642
+ for text prompts.
643
+
644
+ Returns:
645
+ str: The right audio disclaimer. If this is a text prompt then it is an empty string.
646
+ """
647
+ prompt_strategy_to_audio_disclaimer = {
648
+ PromptStrategy.COT_JOIN_AUDIO: COT_JOIN_RIGHT_AUDIO_DISCLAIMER,
649
+ }
650
+
651
+ return prompt_strategy_to_audio_disclaimer.get(self.prompt_strategy, "")
652
+
467
653
  def _get_example_filter_condition(self) -> str | None:
468
654
  """
469
655
  Returns the example filter condition for the prompt.
@@ -473,11 +659,27 @@ class PromptFactory:
473
659
  """
474
660
  prompt_strategy_to_example_filter_condition = {
475
661
  PromptStrategy.COT_BOOL: COT_BOOL_EXAMPLE_FILTER_CONDITION,
662
+ PromptStrategy.COT_BOOL_AUDIO: COT_BOOL_AUDIO_EXAMPLE_FILTER_CONDITION,
476
663
  PromptStrategy.COT_BOOL_IMAGE: COT_BOOL_IMAGE_EXAMPLE_FILTER_CONDITION,
477
664
  }
478
665
 
479
666
  return prompt_strategy_to_example_filter_condition.get(self.prompt_strategy)
480
667
 
668
+ def _get_example_join_condition(self) -> str | None:
669
+ """
670
+ Returns the example join condition for the prompt.
671
+
672
+ Returns:
673
+ str | None: The example join condition (if applicable).
674
+ """
675
+ prompt_strategy_to_example_join_condition = {
676
+ PromptStrategy.COT_JOIN: COT_JOIN_EXAMPLE_JOIN_CONDITION,
677
+ PromptStrategy.COT_JOIN_AUDIO: COT_JOIN_AUDIO_EXAMPLE_JOIN_CONDITION,
678
+ PromptStrategy.COT_JOIN_IMAGE: COT_JOIN_IMAGE_EXAMPLE_JOIN_CONDITION,
679
+ }
680
+
681
+ return prompt_strategy_to_example_join_condition.get(self.prompt_strategy)
682
+
481
683
  def _get_example_reasoning(self) -> str | None:
482
684
  """
483
685
  Returns the example reasoning for the prompt.
@@ -487,8 +689,13 @@ class PromptFactory:
487
689
  """
488
690
  prompt_strategy_to_example_reasoning = {
489
691
  PromptStrategy.COT_BOOL: COT_BOOL_EXAMPLE_REASONING,
692
+ PromptStrategy.COT_BOOL_AUDIO: COT_BOOL_AUDIO_EXAMPLE_REASONING,
490
693
  PromptStrategy.COT_BOOL_IMAGE: COT_BOOL_IMAGE_EXAMPLE_REASONING,
694
+ PromptStrategy.COT_JOIN: COT_JOIN_EXAMPLE_REASONING,
695
+ PromptStrategy.COT_JOIN_AUDIO: COT_JOIN_AUDIO_EXAMPLE_REASONING,
696
+ PromptStrategy.COT_JOIN_IMAGE: COT_JOIN_IMAGE_EXAMPLE_REASONING,
491
697
  PromptStrategy.COT_QA: COT_QA_EXAMPLE_REASONING,
698
+ PromptStrategy.COT_QA_AUDIO: COT_QA_AUDIO_EXAMPLE_REASONING,
492
699
  PromptStrategy.COT_QA_IMAGE: COT_QA_IMAGE_EXAMPLE_REASONING,
493
700
  }
494
701
 
@@ -503,6 +710,7 @@ class PromptFactory:
503
710
  """
504
711
  prompt_strategy_to_example_answer = {
505
712
  PromptStrategy.COT_QA: COT_QA_EXAMPLE_ANSWER,
713
+ PromptStrategy.COT_QA_AUDIO: COT_QA_AUDIO_EXAMPLE_ANSWER,
506
714
  PromptStrategy.COT_QA_IMAGE: COT_QA_IMAGE_EXAMPLE_ANSWER,
507
715
  PromptStrategy.COT_MOA_PROPOSER: COT_MOA_PROPOSER_EXAMPLE_ANSWER,
508
716
  PromptStrategy.COT_MOA_PROPOSER_IMAGE: COT_MOA_PROPOSER_IMAGE_EXAMPLE_ANSWER,
@@ -512,7 +720,7 @@ class PromptFactory:
512
720
  return prompt_strategy_to_example_answer.get(self.prompt_strategy)
513
721
 
514
722
  def _get_all_format_kwargs(
515
- self, candidate: DataRecord, input_fields: list[str], output_fields: list[str], **kwargs
723
+ self, candidate: DataRecord, input_fields: list[str], output_fields: list[str], right_candidate: DataRecord | None, right_input_fields: list[str], **kwargs
516
724
  ) -> dict:
517
725
  """
518
726
  Returns a dictionary containing all the format kwargs for templating the prompts.
@@ -532,12 +740,20 @@ class PromptFactory:
532
740
  "input_fields_desc": self._get_input_fields_desc(candidate, input_fields),
533
741
  "output_fields_desc": self._get_output_fields_desc(output_fields, **kwargs),
534
742
  "filter_condition": self._get_filter_condition(**kwargs),
743
+ "join_condition": self._get_join_condition(**kwargs),
535
744
  "original_output": self._get_original_output(**kwargs),
536
745
  "critique_output": self._get_critique_output(**kwargs),
537
746
  "model_responses": self._get_model_responses(**kwargs),
538
747
  "chunk_outputs": self._get_chunk_outputs(**kwargs),
539
748
  }
540
749
 
750
+ # if a right candidate is provided, we also get the context and input field descriptions for the right candidate
751
+ if right_candidate is not None:
752
+ input_format_kwargs.update({
753
+ "right_context": self._get_context(right_candidate, right_input_fields),
754
+ "right_input_fields_desc": self._get_input_fields_desc(right_candidate, right_input_fields),
755
+ })
756
+
541
757
  # get format kwargs which depend on the prompt strategy
542
758
  prompt_strategy_format_kwargs = {
543
759
  "output_format_instruction": self._get_output_format_instruction(),
@@ -546,10 +762,16 @@ class PromptFactory:
546
762
  "refinement_criteria": self._get_refinement_criteria(),
547
763
  "finish_instruction": self._get_finish_instruction(),
548
764
  "example_input_fields": self._get_example_input_fields(),
765
+ "right_example_input_fields": self._get_right_example_input_fields(),
549
766
  "example_output_fields": self._get_example_output_fields(),
550
767
  "example_context": self._get_example_context(),
768
+ "right_example_context": self._get_right_example_context(),
551
769
  "image_disclaimer": self._get_image_disclaimer(),
770
+ "audio_disclaimer": self._get_audio_disclaimer(),
771
+ "right_image_disclaimer": self._get_right_image_disclaimer(),
772
+ "right_audio_disclaimer": self._get_right_audio_disclaimer(),
552
773
  "example_filter_condition": self._get_example_filter_condition(),
774
+ "example_join_condition": self._get_example_join_condition(),
553
775
  "example_reasoning": self._get_example_reasoning(),
554
776
  "example_answer": self._get_example_answer(),
555
777
  }
@@ -557,6 +779,53 @@ class PromptFactory:
557
779
  # return all format kwargs
558
780
  return {**input_format_kwargs, **prompt_strategy_format_kwargs}
559
781
 
782
+ def _create_audio_messages(self, candidate: DataRecord, input_fields: list[str]) -> list[dict]:
783
+ """
784
+ Parses the candidate record and returns the audio messages for the chat payload.
785
+
786
+ Args:
787
+ candidate (DataRecord): The input record.
788
+ input_fields (list[str]): The list of input fields.
789
+
790
+ Returns:
791
+ list[dict]: The audio messages for the chat payload.
792
+ """
793
+ # create a message for each audio recording in an input field with an audio (or list of audio) type
794
+ audio_content = []
795
+ for field_name in input_fields:
796
+ field_value = candidate[field_name]
797
+ field_type = candidate.get_field_type(field_name)
798
+
799
+ # audio filepath (or list of audio filepaths)
800
+ if field_type.annotation in [AudioFilepath, AudioFilepath | None]:
801
+ with open(field_value, "rb") as f:
802
+ base64_audio_str = base64.b64encode(f.read()).decode("utf-8")
803
+ audio_content.append(
804
+ {"type": "input_audio", "input_audio": {"data": base64_audio_str, "format": "wav"}}
805
+ )
806
+
807
+ elif field_type.annotation in [list[AudioFilepath], list[AudioFilepath] | None]:
808
+ for audio_filepath in field_value:
809
+ with open(audio_filepath, "rb") as f:
810
+ base64_audio_str = base64.b64encode(f.read()).decode("utf-8")
811
+ audio_content.append(
812
+ {"type": "input_audio", "input_audio": {"data": base64_audio_str, "format": "wav"}}
813
+ )
814
+
815
+ # pre-encoded images (or list of pre-encoded images)
816
+ elif field_type.annotation in [AudioBase64, AudioBase64 | None]:
817
+ audio_content.append(
818
+ {"type": "input_audio", "input_audio": {"data": field_value, "format": "wav"}}
819
+ )
820
+
821
+ elif field_type.annotation in [list[AudioBase64], list[AudioBase64] | None]:
822
+ for base64_audio in field_value:
823
+ audio_content.append(
824
+ {"type": "input_audio", "input_audio": {"data": base64_audio, "format": "wav"}}
825
+ )
826
+
827
+ return [{"role": "user", "type": "input_audio", "content": audio_content}] if len(audio_content) > 0 else []
828
+
560
829
  def _create_image_messages(self, candidate: DataRecord, input_fields: list[str]) -> list[dict]:
561
830
  """
562
831
  Parses the candidate record and returns the image messages for the chat payload.
@@ -569,50 +838,48 @@ class PromptFactory:
569
838
  list[dict]: The image messages for the chat payload.
570
839
  """
571
840
  # create a message for each image in an input field with an image (or list of image) type
572
- image_messages = []
841
+ image_content = []
573
842
  for field_name in input_fields:
574
843
  field_value = candidate[field_name]
575
844
  field_type = candidate.get_field_type(field_name)
576
845
 
577
846
  # image filepath (or list of image filepaths)
578
- if isinstance(field_type, ImageFilepathField):
847
+ if field_type.annotation in [ImageFilepath, ImageFilepath | None]:
579
848
  with open(field_value, "rb") as f:
580
849
  base64_image_str = base64.b64encode(f.read()).decode("utf-8")
581
- image_messages.append(
582
- {"role": "user", "type": "image", "content": f"data:image/jpeg;base64,{base64_image_str}"}
850
+ image_content.append(
851
+ {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image_str}"}}
583
852
  )
584
853
 
585
- elif hasattr(field_type, "element_type") and issubclass(field_type.element_type, ImageFilepathField):
854
+ elif field_type.annotation in [list[ImageFilepath], list[ImageFilepath] | None]:
586
855
  for image_filepath in field_value:
587
856
  with open(image_filepath, "rb") as f:
588
857
  base64_image_str = base64.b64encode(f.read()).decode("utf-8")
589
- image_messages.append(
590
- {"role": "user", "type": "image", "content": f"data:image/jpeg;base64,{base64_image_str}"}
858
+ image_content.append(
859
+ {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image_str}"}}
591
860
  )
592
861
 
593
862
  # image url (or list of image urls)
594
- elif isinstance(field_type, ImageURLField):
595
- image_messages.append({"role": "user", "type": "image", "content": field_value})
863
+ elif field_type.annotation in [ImageURL, ImageURL | None]:
864
+ image_content.append({"type": "image_url", "image_url": {"url": field_value}})
596
865
 
597
- elif hasattr(field_type, "element_type") and issubclass(field_type.element_type, ImageURLField):
866
+ elif field_type.annotation in [list[ImageURL], list[ImageURL] | None]:
598
867
  for image_url in field_value:
599
- image_messages.append({"role": "user", "type": "image", "content": image_url})
868
+ image_content.append({"type": "image_url", "image_url": {"url": image_url}})
600
869
 
601
870
  # pre-encoded images (or list of pre-encoded images)
602
- elif isinstance(field_type, ImageBase64Field):
603
- base64_image_str = field_value.decode("utf-8")
604
- image_messages.append(
605
- {"role": "user", "type": "image", "content": f"data:image/jpeg;base64,{base64_image_str}"}
871
+ elif field_type.annotation in [ImageBase64, ImageBase64 | None]:
872
+ image_content.append(
873
+ {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{field_value}"}}
606
874
  )
607
875
 
608
- elif hasattr(field_type, "element_type") and issubclass(field_type.element_type, ImageBase64Field):
876
+ elif field_type.annotation in [list[ImageBase64], list[ImageBase64] | None]:
609
877
  for base64_image in field_value:
610
- base64_image_str = base64_image.decode("utf-8")
611
- image_messages.append(
612
- {"role": "user", "type": "image", "content": f"data:image/jpeg;base64,{base64_image_str}"}
878
+ image_content.append(
879
+ {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}}
613
880
  )
614
881
 
615
- return image_messages
882
+ return [{"role": "user", "type": "image", "content": image_content}] if len(image_content) > 0 else []
616
883
 
617
884
  def _get_system_prompt(self, **format_kwargs) -> str | None:
618
885
  """
@@ -631,7 +898,7 @@ class PromptFactory:
631
898
 
632
899
  return base_prompt.format(**format_kwargs)
633
900
 
634
- def _get_user_messages(self, candidate: DataRecord, input_fields: list[str], **kwargs) -> str:
901
+ def _get_user_messages(self, candidate: DataRecord, input_fields: list[str], right_candidate: DataRecord | None, right_input_fields: list[str], **kwargs) -> str:
635
902
  """
636
903
  Returns a list of messages for the chat payload based on the prompt strategy.
637
904
 
@@ -648,10 +915,18 @@ class PromptFactory:
648
915
  # get the base prompt template
649
916
  base_prompt = self.BASE_USER_PROMPT_MAP.get(self.prompt_strategy)
650
917
 
651
- # get any image messages for the chat payload (will be an empty list if this is not an image prompt)
652
- image_messages = (
653
- self._create_image_messages(candidate, input_fields) if self.prompt_strategy.is_image_prompt() else []
654
- )
918
+ # get any image messages for the chat payload (will be an empty list if no image fields exist)
919
+ image_messages = self._create_image_messages(candidate, input_fields)
920
+
921
+ # get any audio messages for the chat payload (will be an empty list if no audio fields exist)
922
+ audio_messages = self._create_audio_messages(candidate, input_fields)
923
+
924
+ # get any right image messages for the chat payload (will be an empty list if this is not a join image prompt)
925
+ right_image_messages, right_audio_messages = [], []
926
+ if self.prompt_strategy.is_join_prompt():
927
+ assert right_candidate is not None, "Right candidate must be provided for join prompts."
928
+ right_image_messages = self._create_image_messages(right_candidate, right_input_fields)
929
+ right_audio_messages = self._create_audio_messages(right_candidate, right_input_fields)
655
930
 
656
931
  # get any original messages for critique and refinement operations
657
932
  original_messages = kwargs.get("original_messages")
@@ -660,6 +935,8 @@ class PromptFactory:
660
935
  "Original messages must be provided for critique and refinement operations."
661
936
  )
662
937
 
938
+ # TODO: in the future if we support many modalities (e.g. images and audio) in the same prompt,
939
+ # then we will need to streamline this logic to handle the many different cases
663
940
  # construct the user messages based on the prompt strategy
664
941
  user_messages = []
665
942
  if self.prompt_strategy.is_critic_prompt() or self.prompt_strategy.is_refine_prompt():
@@ -670,14 +947,47 @@ class PromptFactory:
670
947
  user_messages.extend(original_messages)
671
948
  user_messages.append({"role": "user", "type": "text", "content": base_prompt_end.format(**kwargs)})
672
949
 
673
- elif self.prompt_strategy.is_image_prompt():
674
- base_prompt_start, base_prompt_end = base_prompt.split("<<image-placeholder>>\n")
950
+ # image not join
951
+ elif self.prompt_strategy.is_image_prompt() and not self.prompt_strategy.is_join_prompt():
952
+ base_prompt = base_prompt.replace("<<audio-placeholder>>", "")
953
+ base_prompt_start, base_prompt_end = base_prompt.split("<<image-placeholder>>")
675
954
  user_messages.append({"role": "user", "type": "text", "content": base_prompt_start.format(**kwargs)})
676
955
  user_messages.extend(image_messages)
677
956
  user_messages.append({"role": "user", "type": "text", "content": base_prompt_end.format(**kwargs)})
678
957
 
958
+ # image join
959
+ elif self.prompt_strategy.is_image_prompt() and self.prompt_strategy.is_join_prompt():
960
+ # for join image prompts, we may have two sets of images (one from the left candidate and one from the right candidate)
961
+ base_prompt = base_prompt.replace("<<audio-placeholder>>", "")
962
+ base_prompt_start, base_prompt_mid, base_prompt_end = base_prompt.split("<<image-placeholder>>")
963
+ user_messages.append({"role": "user", "type": "text", "content": base_prompt_start.format(**kwargs)})
964
+ user_messages.extend(image_messages)
965
+ user_messages.append({"role": "user", "type": "text", "content": base_prompt_mid.format(**kwargs)})
966
+ user_messages.extend(right_image_messages)
967
+ user_messages.append({"role": "user", "type": "text", "content": base_prompt_end.format(**kwargs)})
968
+
969
+ # audio not join
970
+ elif self.prompt_strategy.is_audio_prompt() and not self.prompt_strategy.is_join_prompt():
971
+ base_prompt = base_prompt.replace("<<image-placeholder>>", "")
972
+ base_prompt_start, base_prompt_end = base_prompt.split("<<audio-placeholder>>")
973
+ user_messages.append({"role": "user", "type": "text", "content": base_prompt_start.format(**kwargs)})
974
+ user_messages.extend(audio_messages)
975
+ user_messages.append({"role": "user", "type": "text", "content": base_prompt_end.format(**kwargs)})
976
+
977
+ # audio join
978
+ elif self.prompt_strategy.is_audio_prompt() and self.prompt_strategy.is_join_prompt():
979
+ # for join image prompts, we may have two sets of images (one from the left candidate and one from the right candidate)
980
+ base_prompt = base_prompt.replace("<<image-placeholder>>", "")
981
+ base_prompt_start, base_prompt_mid, base_prompt_end = base_prompt.split("<<audio-placeholder>>")
982
+ user_messages.append({"role": "user", "type": "text", "content": base_prompt_start.format(**kwargs)})
983
+ user_messages.extend(audio_messages)
984
+ user_messages.append({"role": "user", "type": "text", "content": base_prompt_mid.format(**kwargs)})
985
+ user_messages.extend(right_audio_messages)
986
+ user_messages.append({"role": "user", "type": "text", "content": base_prompt_end.format(**kwargs)})
987
+
679
988
  else:
680
989
  base_prompt = base_prompt.replace("<<image-placeholder>>", "")
990
+ base_prompt = base_prompt.replace("<<audio-placeholder>>", "")
681
991
  user_messages.append({"role": "user", "type": "text", "content": base_prompt.format(**kwargs)})
682
992
 
683
993
  return user_messages
@@ -720,7 +1030,7 @@ class PromptFactory:
720
1030
  # build set of format kwargs
721
1031
  format_kwargs = {
722
1032
  field_name: "<bytes>"
723
- if isinstance(candidate.get_field_type(field_name), BytesField)
1033
+ if candidate.get_field_type(field_name).annotation in [bytes, bytes | None]
724
1034
  else candidate[field_name]
725
1035
  for field_name in input_fields
726
1036
  }
@@ -740,7 +1050,7 @@ class PromptFactory:
740
1050
 
741
1051
  return messages
742
1052
 
743
- def create_messages(self, candidate: DataRecord, output_fields: list[str], **kwargs) -> list[dict]:
1053
+ def create_messages(self, candidate: DataRecord, output_fields: list[str], right_candidate: DataRecord | None = None, **kwargs) -> list[dict]:
744
1054
  """
745
1055
  Creates the messages for the chat payload based on the prompt strategy.
746
1056
 
@@ -754,6 +1064,7 @@ class PromptFactory:
754
1064
  Args:
755
1065
  candidate (DataRecord): The input record.
756
1066
  output_fields (list[str]): The output fields.
1067
+ right_candidate (DataRecord | None): The other join input record (only provided for joins).
757
1068
  kwargs: The keyword arguments provided by the user.
758
1069
 
759
1070
  Returns:
@@ -761,6 +1072,7 @@ class PromptFactory:
761
1072
  """
762
1073
  # compute the set of input fields
763
1074
  input_fields = self._get_input_fields(candidate, **kwargs)
1075
+ right_input_fields = [] if right_candidate is None else self._get_input_fields(right_candidate, **kwargs)
764
1076
 
765
1077
  # if the user provides a prompt, we process that prompt into messages and return them
766
1078
  if "prompt" in kwargs:
@@ -774,7 +1086,7 @@ class PromptFactory:
774
1086
  messages = []
775
1087
 
776
1088
  # compute the full dictionary of format kwargs and add to kwargs
777
- format_kwargs = self._get_all_format_kwargs(candidate, input_fields, output_fields, **kwargs)
1089
+ format_kwargs = self._get_all_format_kwargs(candidate, input_fields, output_fields, right_candidate, right_input_fields, **kwargs)
778
1090
  kwargs = {**kwargs, **format_kwargs}
779
1091
 
780
1092
  # generate system message (if applicable)
@@ -783,7 +1095,7 @@ class PromptFactory:
783
1095
  messages.append({"role": "system", "type": "text", "content": system_prompt})
784
1096
 
785
1097
  # generate user messages and add to messages
786
- user_messages = self._get_user_messages(candidate, input_fields, **kwargs)
1098
+ user_messages = self._get_user_messages(candidate, input_fields, right_candidate, right_input_fields, **kwargs)
787
1099
  messages.extend(user_messages)
788
1100
 
789
1101
  return messages