palimpzest 0.6.1__py3-none-any.whl → 0.6.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- palimpzest/prompts/prompt_factory.py +50 -32
- {palimpzest-0.6.1.dist-info → palimpzest-0.6.2.dist-info}/METADATA +1 -1
- {palimpzest-0.6.1.dist-info → palimpzest-0.6.2.dist-info}/RECORD +6 -6
- {palimpzest-0.6.1.dist-info → palimpzest-0.6.2.dist-info}/LICENSE +0 -0
- {palimpzest-0.6.1.dist-info → palimpzest-0.6.2.dist-info}/WHEEL +0 -0
- {palimpzest-0.6.1.dist-info → palimpzest-0.6.2.dist-info}/top_level.txt +0 -0
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
"""This file contains factory methods which return template prompts and return messages for chat payloads."""
|
|
2
|
+
|
|
2
3
|
import base64
|
|
3
4
|
import json
|
|
4
5
|
from string import Formatter
|
|
@@ -82,6 +83,7 @@ from palimpzest.prompts.util_phrases import (
|
|
|
82
83
|
|
|
83
84
|
class PromptFactory:
|
|
84
85
|
"""Factory class for generating prompts for the Generator given the input(s)."""
|
|
86
|
+
|
|
85
87
|
BASE_SYSTEM_PROMPT_MAP = {
|
|
86
88
|
PromptStrategy.COT_BOOL: COT_BOOL_BASE_SYSTEM_PROMPT,
|
|
87
89
|
PromptStrategy.COT_BOOL_IMAGE: COT_BOOL_BASE_SYSTEM_PROMPT,
|
|
@@ -148,13 +150,13 @@ class PromptFactory:
|
|
|
148
150
|
longest_field_name, longest_field_length = sorted_fields[0]
|
|
149
151
|
|
|
150
152
|
# trim the field
|
|
151
|
-
context_factor =
|
|
153
|
+
context_factor = MIXTRAL_LLAMA_CONTEXT_TOKENS_LIMIT / (total_context_len * TOKENS_PER_CHARACTER)
|
|
152
154
|
keep_frac_idx = int(longest_field_length * context_factor)
|
|
153
155
|
context[longest_field_name] = context[longest_field_name][:keep_frac_idx]
|
|
154
156
|
|
|
155
157
|
# update total context length
|
|
156
158
|
total_context_len = len(json.dumps(context, indent=2))
|
|
157
|
-
|
|
159
|
+
|
|
158
160
|
return json.dumps(context, indent=2)
|
|
159
161
|
|
|
160
162
|
def _get_input_fields(self, candidate: DataRecord, **kwargs) -> list[str]:
|
|
@@ -201,7 +203,11 @@ class PromptFactory:
|
|
|
201
203
|
"""
|
|
202
204
|
output_fields_desc = ""
|
|
203
205
|
output_schema: Schema = kwargs.get("output_schema")
|
|
204
|
-
if
|
|
206
|
+
if (
|
|
207
|
+
self.prompt_strategy.is_cot_qa_prompt()
|
|
208
|
+
or self.prompt_strategy.is_moa_proposer_prompt()
|
|
209
|
+
or self.prompt_strategy.is_moa_aggregator_prompt()
|
|
210
|
+
):
|
|
205
211
|
assert output_schema is not None, "Output schema must be provided for convert prompts."
|
|
206
212
|
|
|
207
213
|
field_desc_map = output_schema.field_desc_map()
|
|
@@ -230,14 +236,16 @@ class PromptFactory:
|
|
|
230
236
|
|
|
231
237
|
Args:
|
|
232
238
|
kwargs: The keyword arguments provided by the user.
|
|
233
|
-
|
|
239
|
+
|
|
234
240
|
Returns:
|
|
235
241
|
str | None: The original output.
|
|
236
242
|
"""
|
|
237
243
|
original_output = kwargs.get("original_output")
|
|
238
244
|
if self.prompt_strategy.is_critic_prompt() or self.prompt_strategy.is_refine_prompt():
|
|
239
|
-
assert original_output is not None,
|
|
240
|
-
|
|
245
|
+
assert original_output is not None, (
|
|
246
|
+
"Original output must be provided for critique and refinement operations."
|
|
247
|
+
)
|
|
248
|
+
|
|
241
249
|
return original_output
|
|
242
250
|
|
|
243
251
|
def _get_critique_output(self, **kwargs) -> str | None:
|
|
@@ -246,7 +254,7 @@ class PromptFactory:
|
|
|
246
254
|
|
|
247
255
|
Args:
|
|
248
256
|
kwargs: The keyword arguments provided by the user.
|
|
249
|
-
|
|
257
|
+
|
|
250
258
|
Returns:
|
|
251
259
|
str | None: The critique output.
|
|
252
260
|
"""
|
|
@@ -259,10 +267,10 @@ class PromptFactory:
|
|
|
259
267
|
def _get_model_responses(self, **kwargs) -> str | None:
|
|
260
268
|
"""
|
|
261
269
|
Returns the model responses for the mixture-of-agents aggregation operation.
|
|
262
|
-
|
|
270
|
+
|
|
263
271
|
Args:
|
|
264
272
|
kwargs: The keyword arguments provided by the user.
|
|
265
|
-
|
|
273
|
+
|
|
266
274
|
Returns:
|
|
267
275
|
str | None: The model responses.
|
|
268
276
|
"""
|
|
@@ -314,9 +322,7 @@ class PromptFactory:
|
|
|
314
322
|
critique_criteria = None
|
|
315
323
|
if self.prompt_strategy.is_critic_prompt():
|
|
316
324
|
critique_criteria = (
|
|
317
|
-
COT_QA_IMAGE_CRITIQUE_CRITERIA
|
|
318
|
-
if self.prompt_strategy.is_image_prompt()
|
|
319
|
-
else COT_QA_CRITIQUE_CRITERIA
|
|
325
|
+
COT_QA_IMAGE_CRITIQUE_CRITERIA if self.prompt_strategy.is_image_prompt() else COT_QA_CRITIQUE_CRITERIA
|
|
320
326
|
)
|
|
321
327
|
|
|
322
328
|
return critique_criteria
|
|
@@ -467,16 +473,18 @@ class PromptFactory:
|
|
|
467
473
|
|
|
468
474
|
return prompt_strategy_to_example_answer.get(self.prompt_strategy)
|
|
469
475
|
|
|
470
|
-
def _get_all_format_kwargs(
|
|
476
|
+
def _get_all_format_kwargs(
|
|
477
|
+
self, candidate: DataRecord, input_fields: list[str], output_fields: list[str], **kwargs
|
|
478
|
+
) -> dict:
|
|
471
479
|
"""
|
|
472
480
|
Returns a dictionary containing all the format kwargs for templating the prompts.
|
|
473
|
-
|
|
481
|
+
|
|
474
482
|
Args:
|
|
475
483
|
candidate (DataRecord): The input record.
|
|
476
484
|
input_fields (list[str]): The input fields.
|
|
477
485
|
output_fields (list[str]): The output fields.
|
|
478
486
|
kwargs: The keyword arguments provided by the user.
|
|
479
|
-
|
|
487
|
+
|
|
480
488
|
Returns:
|
|
481
489
|
dict: The dictionary containing all the format kwargs.
|
|
482
490
|
"""
|
|
@@ -517,7 +525,7 @@ class PromptFactory:
|
|
|
517
525
|
Args:
|
|
518
526
|
candidate (DataRecord): The input record.
|
|
519
527
|
input_fields (list[str]): The list of input fields.
|
|
520
|
-
|
|
528
|
+
|
|
521
529
|
Returns:
|
|
522
530
|
list[dict]: The image messages for the chat payload.
|
|
523
531
|
"""
|
|
@@ -529,15 +537,19 @@ class PromptFactory:
|
|
|
529
537
|
|
|
530
538
|
# image filepath (or list of image filepaths)
|
|
531
539
|
if isinstance(field_type, ImageFilepathField):
|
|
532
|
-
with open(field_value,
|
|
533
|
-
base64_image_str = base64.b64encode(f.read()).decode(
|
|
534
|
-
image_messages.append(
|
|
540
|
+
with open(field_value, "rb") as f:
|
|
541
|
+
base64_image_str = base64.b64encode(f.read()).decode("utf-8")
|
|
542
|
+
image_messages.append(
|
|
543
|
+
{"role": "user", "type": "image", "content": f"data:image/jpeg;base64,{base64_image_str}"}
|
|
544
|
+
)
|
|
535
545
|
|
|
536
546
|
elif hasattr(field_type, "element_type") and issubclass(field_type.element_type, ImageFilepathField):
|
|
537
547
|
for image_filepath in field_value:
|
|
538
|
-
with open(image_filepath,
|
|
539
|
-
base64_image_str = base64.b64encode(f.read()).decode(
|
|
540
|
-
image_messages.append(
|
|
548
|
+
with open(image_filepath, "rb") as f:
|
|
549
|
+
base64_image_str = base64.b64encode(f.read()).decode("utf-8")
|
|
550
|
+
image_messages.append(
|
|
551
|
+
{"role": "user", "type": "image", "content": f"data:image/jpeg;base64,{base64_image_str}"}
|
|
552
|
+
)
|
|
541
553
|
|
|
542
554
|
# image url (or list of image urls)
|
|
543
555
|
elif isinstance(field_type, ImageURLField):
|
|
@@ -550,12 +562,16 @@ class PromptFactory:
|
|
|
550
562
|
# pre-encoded images (or list of pre-encoded images)
|
|
551
563
|
elif isinstance(field_type, ImageBase64Field):
|
|
552
564
|
base64_image_str = field_value.decode("utf-8")
|
|
553
|
-
image_messages.append(
|
|
565
|
+
image_messages.append(
|
|
566
|
+
{"role": "user", "type": "image", "content": f"data:image/jpeg;base64,{base64_image_str}"}
|
|
567
|
+
)
|
|
554
568
|
|
|
555
569
|
elif hasattr(field_type, "element_type") and issubclass(field_type.element_type, ImageBase64Field):
|
|
556
570
|
for base64_image in field_value:
|
|
557
571
|
base64_image_str = base64_image.decode("utf-8")
|
|
558
|
-
image_messages.append(
|
|
572
|
+
image_messages.append(
|
|
573
|
+
{"role": "user", "type": "image", "content": f"data:image/jpeg;base64,{base64_image_str}"}
|
|
574
|
+
)
|
|
559
575
|
|
|
560
576
|
return image_messages
|
|
561
577
|
|
|
@@ -595,15 +611,15 @@ class PromptFactory:
|
|
|
595
611
|
|
|
596
612
|
# get any image messages for the chat payload (will be an empty list if this is not an image prompt)
|
|
597
613
|
image_messages = (
|
|
598
|
-
self._create_image_messages(candidate, input_fields)
|
|
599
|
-
if self.prompt_strategy.is_image_prompt()
|
|
600
|
-
else []
|
|
614
|
+
self._create_image_messages(candidate, input_fields) if self.prompt_strategy.is_image_prompt() else []
|
|
601
615
|
)
|
|
602
616
|
|
|
603
617
|
# get any original messages for critique and refinement operations
|
|
604
618
|
original_messages = kwargs.get("original_messages")
|
|
605
619
|
if self.prompt_strategy.is_critic_prompt() or self.prompt_strategy.is_refine_prompt():
|
|
606
|
-
assert original_messages is not None,
|
|
620
|
+
assert original_messages is not None, (
|
|
621
|
+
"Original messages must be provided for critique and refinement operations."
|
|
622
|
+
)
|
|
607
623
|
|
|
608
624
|
# construct the user messages based on the prompt strategy
|
|
609
625
|
user_messages = []
|
|
@@ -661,17 +677,19 @@ class PromptFactory:
|
|
|
661
677
|
f"Input fields: {input_fields}\n"
|
|
662
678
|
)
|
|
663
679
|
assert fields_check, err_msg
|
|
664
|
-
|
|
680
|
+
|
|
665
681
|
# build set of format kwargs
|
|
666
682
|
format_kwargs = {
|
|
667
|
-
field_name: "<bytes>"
|
|
683
|
+
field_name: "<bytes>"
|
|
684
|
+
if isinstance(candidate.get_field_type(field_name), BytesField)
|
|
685
|
+
else candidate[field_name]
|
|
668
686
|
for field_name in input_fields
|
|
669
687
|
}
|
|
670
688
|
|
|
671
689
|
# split prompt on <<image-placeholder>> if it exists
|
|
672
690
|
if "<<image-placeholder>>" in user_prompt:
|
|
673
691
|
raise NotImplementedError("Image prompts are not yet supported.")
|
|
674
|
-
|
|
692
|
+
|
|
675
693
|
prompt_sections = user_prompt.split("<<image-placeholder>>")
|
|
676
694
|
messages = [{"role": "user", "type": "text", "content": prompt_sections[0].format(**format_kwargs)}]
|
|
677
695
|
|
|
@@ -686,7 +704,7 @@ class PromptFactory:
|
|
|
686
704
|
def create_messages(self, candidate: DataRecord, output_fields: list[str], **kwargs) -> list[dict]:
|
|
687
705
|
"""
|
|
688
706
|
Creates the messages for the chat payload based on the prompt strategy.
|
|
689
|
-
|
|
707
|
+
|
|
690
708
|
Each message will be a dictionary with the following format:
|
|
691
709
|
{
|
|
692
710
|
"role": "user" | "system",
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.2
|
|
2
2
|
Name: palimpzest
|
|
3
|
-
Version: 0.6.1
|
|
3
|
+
Version: 0.6.2
|
|
4
4
|
Summary: Palimpzest is a system which enables anyone to process AI-powered analytical queries simply by defining them in a declarative language
|
|
5
5
|
Author-email: MIT DSG Semantic Management Lab <michjc@csail.mit.edu>
|
|
6
6
|
Project-URL: homepage, https://palimpzest.org
|
|
@@ -20,7 +20,7 @@ palimpzest/prompts/critique_and_refine_convert_prompts.py,sha256=WoXExBxQ7twswd9
|
|
|
20
20
|
palimpzest/prompts/filter_prompts.py,sha256=iQjn-39h3L0E5wng_UPgAXRHrP1ok329TXpOgZ6Wn1w,2372
|
|
21
21
|
palimpzest/prompts/moa_aggregator_convert_prompts.py,sha256=BQRrtGdr53PTqvXzmFh8kfQ_w9KoKw-zTtmdo-8RFjo,2887
|
|
22
22
|
palimpzest/prompts/moa_proposer_convert_prompts.py,sha256=d_hOh0-0m6HWBDAxUu7W3WyQtSTlUvqio3nzpnX2bxM,3642
|
|
23
|
-
palimpzest/prompts/prompt_factory.py,sha256=
|
|
23
|
+
palimpzest/prompts/prompt_factory.py,sha256=Y1R3sRLoeQt6YUbw-4Tv5oj57Hu39IA9shKkgTtoMks,32184
|
|
24
24
|
palimpzest/prompts/util_phrases.py,sha256=NWrcHfjJyiOY16Jyt7R50moVnlJDyvSBZ9kBqyX2WQo,751
|
|
25
25
|
palimpzest/query/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
26
26
|
palimpzest/query/execution/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
@@ -80,8 +80,8 @@ palimpzest/utils/progress.py,sha256=GYmPUBdG7xmqbqj1UiSNP-pWZKmRMLX797MBgrOPugM,
|
|
|
80
80
|
palimpzest/utils/sandbox.py,sha256=Ge96gmzqeOGlNkMCG9A95_PB8wRQbvTFua136of8FcA,6465
|
|
81
81
|
palimpzest/utils/token_reduction_helpers.py,sha256=Ob95PcqCsbGLiBdQ-4YQsWGWRppb2hvQyt0gi1fzL-Y,3855
|
|
82
82
|
palimpzest/utils/udfs.py,sha256=LjHic54B1az-rKgNLur0wOpaz2ko_UodjLEJrazkxvY,1854
|
|
83
|
-
palimpzest-0.6.1.dist-info/LICENSE,sha256=5GUlHy9lr-Py9kvV38FF1m3yy3NqM18fefuE9wkWumo,1079
|
|
84
|
-
palimpzest-0.6.
|
|
85
|
-
palimpzest-0.6.1.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
|
|
86
|
-
palimpzest-0.6.1.dist-info/top_level.txt,sha256=raV06dJUgohefUn3ZyJS2uqp_Y76EOLA9Y2e_fxt8Ew,11
|
|
87
|
-
palimpzest-0.6.1.dist-info/RECORD,,
|
|
83
|
+
palimpzest-0.6.2.dist-info/LICENSE,sha256=5GUlHy9lr-Py9kvV38FF1m3yy3NqM18fefuE9wkWumo,1079
|
|
84
|
+
palimpzest-0.6.2.dist-info/METADATA,sha256=yBFEceRsylGwr8q8W0LqugpBmnqsey5SgKiQhPZhYgI,7837
|
|
85
|
+
palimpzest-0.6.2.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
|
|
86
|
+
palimpzest-0.6.2.dist-info/top_level.txt,sha256=raV06dJUgohefUn3ZyJS2uqp_Y76EOLA9Y2e_fxt8Ew,11
|
|
87
|
+
palimpzest-0.6.2.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|