plancraft 0.3.3__py3-none-any.whl → 0.3.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- plancraft/environment/prompts.py +13 -2
- plancraft/evaluator.py +11 -2
- plancraft/utils.py +21 -18
- {plancraft-0.3.3.dist-info → plancraft-0.3.4.dist-info}/METADATA +1 -1
- {plancraft-0.3.3.dist-info → plancraft-0.3.4.dist-info}/RECORD +7 -7
- {plancraft-0.3.3.dist-info → plancraft-0.3.4.dist-info}/WHEEL +0 -0
- {plancraft-0.3.3.dist-info → plancraft-0.3.4.dist-info}/licenses/LICENSE +0 -0
plancraft/environment/prompts.py
CHANGED
@@ -59,7 +59,8 @@ SEARCH_STEPS = [
|
|
59
59
|
|
60
60
|
def get_system_prompt(
|
61
61
|
handlers: list[ActionHandlerBase] = [MoveActionHandler(), SmeltActionHandler()],
|
62
|
-
|
62
|
+
use_multimodal_content_format=False,
|
63
|
+
) -> dict:
|
63
64
|
action_names = [handler.action_name for handler in handlers]
|
64
65
|
assert "move" in action_names, "MoveActionHandler should be one of the handlers"
|
65
66
|
assert "smelt" in action_names, "SmeltActionHandler should be one of the handlers"
|
@@ -72,7 +73,17 @@ def get_system_prompt(
|
|
72
73
|
for handler in handlers:
|
73
74
|
output_format += f"\n\t- {handler.prompt_format_example}"
|
74
75
|
|
75
|
-
|
76
|
+
system_prompt_text = f"{BASE_SYSTEM_PROMPT}\n\nActions:{descriptions}\n\nFormat{output_format}\n\n{BASE_SYSTEM_PROMPT_EXAMPLE}"
|
77
|
+
|
78
|
+
if use_multimodal_content_format:
|
79
|
+
return {
|
80
|
+
"role": "system",
|
81
|
+
"content": [{"text": system_prompt_text, "type": "text"}],
|
82
|
+
}
|
83
|
+
return {
|
84
|
+
"role": "system",
|
85
|
+
"content": system_prompt_text,
|
86
|
+
}
|
76
87
|
|
77
88
|
|
78
89
|
def get_prompt_example(
|
plancraft/evaluator.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1
1
|
import json
|
2
2
|
import os
|
3
|
+
from typing import Optional
|
3
4
|
|
4
5
|
import imageio
|
5
6
|
from loguru import logger
|
@@ -8,18 +9,18 @@ from tqdm import tqdm
|
|
8
9
|
import wandb
|
9
10
|
from plancraft.config import PlancraftExample
|
10
11
|
from plancraft.environment.actions import (
|
11
|
-
StopAction,
|
12
12
|
ActionHandlerBase,
|
13
13
|
MoveActionHandler,
|
14
14
|
SmeltActionHandler,
|
15
|
+
StopAction,
|
15
16
|
)
|
16
17
|
from plancraft.environment.env import (
|
17
18
|
PlancraftEnvironment,
|
18
19
|
get_objective_str,
|
19
20
|
target_and_inventory_to_text_obs,
|
20
21
|
)
|
21
|
-
from plancraft.utils import History
|
22
22
|
from plancraft.models.base import PlancraftBaseModel
|
23
|
+
from plancraft.utils import History
|
23
24
|
|
24
25
|
|
25
26
|
class Evaluator:
|
@@ -48,6 +49,10 @@ class Evaluator:
|
|
48
49
|
use_images: bool = False,
|
49
50
|
use_text_inventory: bool = False,
|
50
51
|
use_fasterrcnn: bool = False,
|
52
|
+
system_prompt: Optional[dict] = None,
|
53
|
+
prompt_examples: list[dict] = [],
|
54
|
+
prompt_images: list[str] = [],
|
55
|
+
few_shot: bool = True,
|
51
56
|
):
|
52
57
|
self.run_name = run_name
|
53
58
|
self.use_multimodal_content_format = use_multimodal_content_format
|
@@ -77,6 +82,10 @@ class Evaluator:
|
|
77
82
|
use_images=use_images,
|
78
83
|
use_text_inventory=use_text_inventory,
|
79
84
|
resolution=resolution,
|
85
|
+
few_shot=few_shot,
|
86
|
+
system_prompt=system_prompt,
|
87
|
+
prompt_examples=prompt_examples,
|
88
|
+
prompt_images=prompt_images,
|
80
89
|
)
|
81
90
|
|
82
91
|
# load model
|
plancraft/utils.py
CHANGED
@@ -2,6 +2,7 @@ import glob
|
|
2
2
|
import pathlib
|
3
3
|
from collections import Counter
|
4
4
|
from copy import copy
|
5
|
+
from typing import Optional
|
5
6
|
|
6
7
|
import torch
|
7
8
|
from loguru import logger
|
@@ -12,8 +13,8 @@ from plancraft.environment.actions import (
|
|
12
13
|
SmeltAction,
|
13
14
|
)
|
14
15
|
from plancraft.environment.prompts import (
|
15
|
-
get_system_prompt,
|
16
16
|
get_prompt_example,
|
17
|
+
get_system_prompt,
|
17
18
|
load_prompt_images,
|
18
19
|
)
|
19
20
|
|
@@ -35,6 +36,9 @@ class History:
|
|
35
36
|
use_images=False,
|
36
37
|
use_text_inventory=False,
|
37
38
|
resolution="high",
|
39
|
+
system_prompt: Optional[dict] = None,
|
40
|
+
prompt_examples: list[dict] = [],
|
41
|
+
prompt_images: list[str] = [],
|
38
42
|
):
|
39
43
|
self.action_handlers = actions
|
40
44
|
self.use_multimodal_content_format = use_multimodal_content_format
|
@@ -49,31 +53,30 @@ class History:
|
|
49
53
|
|
50
54
|
self.tokens_used = 0
|
51
55
|
|
56
|
+
# use system prompt if provided
|
57
|
+
if system_prompt:
|
58
|
+
self.system_prompt_dialogue = system_prompt
|
59
|
+
else:
|
60
|
+
# generate system prompt
|
61
|
+
self.system_prompt_dialogue = get_system_prompt(
|
62
|
+
handlers=self.action_handlers,
|
63
|
+
use_multimodal_content_format=self.use_multimodal_content_format,
|
64
|
+
)
|
65
|
+
|
52
66
|
# set up dialogue history with few-shot prompt
|
67
|
+
self.prompt_examples = prompt_examples
|
68
|
+
self.prompt_images = prompt_images
|
53
69
|
self.set_up_few_shot_prompt()
|
54
|
-
self.system_prompt_dialogue = self.system_prompt()
|
55
70
|
|
56
71
|
self.dialogue_history = copy(self.prompt_examples)
|
57
72
|
self.images = copy(self.prompt_images)
|
58
73
|
self.initial_dialogue_length = len(self.dialogue_history)
|
59
74
|
|
60
|
-
def system_prompt(self):
|
61
|
-
# kept separate from dialogue history because certain models deal with system prompt differently
|
62
|
-
system_prompt_text = get_system_prompt(handlers=self.action_handlers)
|
63
|
-
if self.use_multimodal_content_format:
|
64
|
-
return {
|
65
|
-
"role": "system",
|
66
|
-
"content": [{"text": system_prompt_text, "type": "text"}],
|
67
|
-
}
|
68
|
-
return {
|
69
|
-
"role": "system",
|
70
|
-
"content": system_prompt_text,
|
71
|
-
}
|
72
|
-
|
73
75
|
def set_up_few_shot_prompt(self):
|
74
|
-
|
75
|
-
self.
|
76
|
-
|
76
|
+
# if either prompt_examples or prompt_images are provided, skip
|
77
|
+
if self.prompt_examples or self.prompt_images:
|
78
|
+
return
|
79
|
+
# if few-shot is not enabled, skip
|
77
80
|
if self.few_shot:
|
78
81
|
self.prompt_examples = get_prompt_example(
|
79
82
|
self.action_handlers,
|
@@ -1,8 +1,8 @@
|
|
1
1
|
plancraft/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
2
|
plancraft/config.py,sha256=Ppkps-E8xDNYEP9prOVxW2zEG9MpWVzcLJi4tmGLjuQ,4285
|
3
|
-
plancraft/evaluator.py,sha256=
|
3
|
+
plancraft/evaluator.py,sha256=adGmrn3GMQd5KSfFGQZxHjisQbvoxvEv1W1CPxZnFi8,11061
|
4
4
|
plancraft/generate_dataset.py,sha256=DlrU-PmvWqSNJD1g1-8Lpb8n3N-Ogw3rje1nrRzjGKs,2382
|
5
|
-
plancraft/utils.py,sha256=
|
5
|
+
plancraft/utils.py,sha256=rYiqLUaEqjdUG-nqeHmeVG3PaExAlYiBGXH5qzLZPhs,7224
|
6
6
|
plancraft/data/test.json,sha256=4jWfYMAVuZCFmGB4iZJAjlh9_8jXECdaGp8xn7_tAM4,1317131
|
7
7
|
plancraft/data/test.small.easy.json,sha256=5NZEJ2PqIgmHQecJOIVQyM1D6GFKyJq7GVmgRudaqQk,189304
|
8
8
|
plancraft/data/test.small.json,sha256=eULAG1rdolRMXPrecV-7YoDIheKGyIT5MVpWdISV0wg,270089
|
@@ -15,7 +15,7 @@ plancraft/environment/actions.py,sha256=D9QqBW7yWsbWCjxNyWp61Xtb0c6EtyXk3PZ1I8SR
|
|
15
15
|
plancraft/environment/env.py,sha256=F5xo1eAJ9MeuoE2IpG_LtbaE0BGd66URPB_rehAWIiU,16372
|
16
16
|
plancraft/environment/items.py,sha256=Z9rhSyVDEoHF1pxRvhyiT94tyQJaWHi3wUHVcamz82o,221
|
17
17
|
plancraft/environment/planner.py,sha256=eJExz3OxSzurIEdH9LOtMwFH9ApqMQ3CokVhmbV6Px0,3953
|
18
|
-
plancraft/environment/prompts.py,sha256=
|
18
|
+
plancraft/environment/prompts.py,sha256=8QXclX0ygpL02uZichE1AVkbdn_0HGteD5bzo0FZGOU,6947
|
19
19
|
plancraft/environment/recipes.py,sha256=0vwzOU86eZmGN2EpZVSIvzxpx0AOBWNPxTtAOFBN2A0,19570
|
20
20
|
plancraft/environment/sampler.py,sha256=IZT-XjmWSZrs0zDyRTMjYytXxewdwYf5YGGdKsR5ll4,7643
|
21
21
|
plancraft/environment/search.py,sha256=uFHpLvW40rMKOxDabcyWrpOrhKLDZqAJOF_jew4_WXk,1837
|
@@ -1920,7 +1920,7 @@ plancraft/models/generators.py,sha256=F76_iPiqxUjDIrQwF58tzM0bLM91OkZJ0sBqBuki5w
|
|
1920
1920
|
plancraft/models/oracle.py,sha256=jDCE6zVFvbwFpDzQZTkHIlRwMud1yMJ4LVIdfpt5ddU,8449
|
1921
1921
|
plancraft/models/utils.py,sha256=E-sZohvolWgGbpHQKgAgkgIfUJoVnT5pMt6JP8xLHKg,4034
|
1922
1922
|
plancraft/train/dataset.py,sha256=oFqEd4LG9oEQ-71teh0Wf7-jJbtybT2ZibfM2bBdBkM,5474
|
1923
|
-
plancraft-0.3.
|
1924
|
-
plancraft-0.3.
|
1925
|
-
plancraft-0.3.
|
1926
|
-
plancraft-0.3.
|
1923
|
+
plancraft-0.3.4.dist-info/METADATA,sha256=W14g4fJ1y6zALGre8NKFRZXu9cVCrQS9i-24akOIWSw,11306
|
1924
|
+
plancraft-0.3.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
1925
|
+
plancraft-0.3.4.dist-info/licenses/LICENSE,sha256=YGR8ehDB4t-T-lOQKMfKNR-2zsOU7E3E5NA8t25HKE0,1070
|
1926
|
+
plancraft-0.3.4.dist-info/RECORD,,
|
File without changes
|
File without changes
|