plancraft 0.1.2__py3-none-any.whl → 0.1.3__py3-none-any.whl

@@ -0,0 +1,158 @@
+ from plancraft.models.utils import gold_search_recipe
+
+ VALID_ACTIONS = ["move", "smelt", "think", "search", "impossible"]
+
+ ACTIONS_DESCRIPTIONS = {
+     "move": {
+         "description": "Transfer a specific quantity of an item from one slot to another",
+         "format": "`move: from [Source] to [Target] with quantity N`",
+     },
+     "smelt": {
+         "description": "Smelt an item in a furnace and move the output to a specific slot",
+         "format": "`smelt: from [Source] to [Target] with quantity N`",
+     },
+     "think": {
+         "description": "Generate thoughts to help you decide on the next action",
+         "format": "`think: <thought message>`",
+     },
+     "search": {
+         "description": "Search for a recipe to craft a specific item",
+         "format": "`search: <recipe name>`",
+     },
+     "impossible": {
+         "description": "Stop the task if it is certain that it is impossible with the given inventory",
+         "format": "`impossible: <reason>`",
+     },
+ }
+
+ BASE_SYSTEM_PROMPT = """You are crafting in Minecraft. You need to decide on the next action.
+
+ Crafting Grid: The crafting table is organized into a 3x3 grid. Each slot in the grid has a unique identifier:
+ - Top row: [A1] [A2] [A3]
+ - Middle row: [B1] [B2] [B3]
+ - Bottom row: [C1] [C2] [C3]
+
+ The output of the crafting process is placed in a designated output slot labeled [0]. You cannot move or smelt items directly into slot [0].
+
+ Inventory Slots: The remaining inventory slots (outside of the crafting grid) are used for storing items. These slots are labeled as [I1] to [I36]."""
+
+ BASE_SYSTEM_PROMPT_EXAMPLE = """Example:
+ - `move: from [I2] to [A1] with quantity 3`
+ - `smelt: from [I5] to [I6] with quantity 1`
+
+ Constraints:
+ - You cannot move or smelt items into [0]
+ - If an item is not in slot [0] then the recipe is incorrect
+ - You need to move items from [0] to a free inventory slot to complete the crafting process"""
+
+
+ def get_system_prompt(actions: list[str]):
+     assert set(actions).issubset(VALID_ACTIONS), f"Invalid actions: {actions}"
+     assert "move" in actions, "move should be one of the actions"
+     assert "smelt" in actions, "smelt should be one of the actions"
+
+     descriptions = ""
+     for action in actions:
+         descriptions += f"\n\t- {action}: {ACTIONS_DESCRIPTIONS[action]['description']}"
+
+     output_format = ""
+     for action in actions:
+         output_format += f"\n\t- {ACTIONS_DESCRIPTIONS[action]['format']}"
+
+     return f"{BASE_SYSTEM_PROMPT}\n\nActions:{descriptions}\n\nFormat:{output_format}\n\n{BASE_SYSTEM_PROMPT_EXAMPLE}"
+
+
+ CRAFTING_STEPS = [
+     "Craft an item of type: andesite\ninventory:\n - diorite [I18] quantity 1\n - cobblestone [I30] quantity 1",
+     "Craft an item of type: andesite\ninventory:\n - diorite [B1] quantity 1\n - cobblestone [I30] quantity 1",
+     "Craft an item of type: andesite\ninventory:\n - andesite [0] quantity 1\n - diorite [B1] quantity 1\n - cobblestone [B2] quantity 1",
+     "Craft an item of type: iron_ingot\ninventory:\n - iron_ore [I36] quantity 1\n - cobblestone [I30] quantity 1",
+ ]
+
+ BASE_ACTION_STEPS = [
+     "move: from [I18] to [B1] with quantity 1",
+     "move: from [I30] to [B2] with quantity 1",
+     "move: from [0] to [I6] with quantity 1",
+     "smelt: from [I36] to [I35] with quantity 1",
+ ]
+
+ THINK_STEPS = [
+     "think: To solve this task I need to craft andesite using 1 diorite and 1 cobblestone side by side.",
+     "think: Now I need to move the cobblestone into position [B2] to be to the right of the diorite.",
+     "think: Now I can craft the andesite by moving it from the craft slot [0] to a free inventory slot.",
+     "think: To craft an iron_ingot, I need to smelt iron_ore into an empty slot.",
+ ]
+
+ SEARCH_STEPS = [
+     "search: andesite",
+     None,
+     None,
+     "search: iron_ingot",
+ ]
+
+
+ def get_prompt_example(
+     actions: list[str],
+     use_text_inventory=True,
+     use_multimodal_content_format=False,
+     use_images=False,
+ ) -> list[dict]:
+     assert set(actions).issubset(VALID_ACTIONS), f"Invalid actions: {actions}"
+     assert "move" in actions, "move should be one of the actions"
+     assert "smelt" in actions, "smelt should be one of the actions"
+
+     if use_images:
+         assert (
+             use_multimodal_content_format
+         ), "use_images requires use_multimodal_content_format"
+
+     example_dialogue = []
+     for i, step in enumerate(CRAFTING_STEPS):
+         text = step
+         if not use_text_inventory:
+             text = text.split("\ninventory:\n")[0]
+
+         example_dialogue.append({"role": "user", "content": text})
+         if "search" in actions and SEARCH_STEPS[i]:
+             example_dialogue.append({"role": "assistant", "content": SEARCH_STEPS[i]})
+             search_target = SEARCH_STEPS[i].split("search: ")[-1].strip()
+             search_response = gold_search_recipe(search_target)
+             example_dialogue.append({"role": "user", "content": search_response})
+         if "think" in actions:
+             example_dialogue.append({"role": "assistant", "content": THINK_STEPS[i]})
+             example_dialogue.append({"role": "user", "content": "Ok"})
+         example_dialogue.append({"role": "assistant", "content": BASE_ACTION_STEPS[i]})
+
+     if not use_multimodal_content_format:
+         return example_dialogue
+
+     # convert to multimodal dialogue
+     multimodal_dialogue = []
+     for message in example_dialogue:
+         if "Craft an item" in message["content"]:
+             content_list = [
+                 {
+                     "type": "text",
+                     "text": message["content"],
+                 }
+             ]
+             if use_images:
+                 content_list.append(
+                     {
+                         "type": "image",
+                     }
+                 )
+
+             multimodal_dialogue.append(
+                 {"role": message["role"], "content": content_list}
+             )
+         else:
+             multimodal_dialogue.append(
+                 {
+                     "role": message["role"],
+                     "content": [
+                         {"type": "text", "text": message["content"]},
+                     ],
+                 }
+             )
+     return multimodal_dialogue
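
Example usage (an illustrative sketch, not part of the package: any subset of VALID_ACTIONS that includes "move" and "smelt" passes the asserts):

    actions = ["move", "smelt", "think", "search"]
    dialogue = [{"role": "system", "content": get_system_prompt(actions)}]
    dialogue += get_prompt_example(actions, use_text_inventory=True)
    # dialogue now alternates user observations with assistant search/think/action turns
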
@@ -0,0 +1,93 @@
+ from dotenv import load_dotenv
+
+ from plancraft.config import EvalConfig
+ from plancraft.environments.actions import (
+     NoOp,
+     SymbolicAction,
+ )
+ from plancraft.models.act import ActModel
+ from plancraft.models.utils import (
+     convert_observation_to_message,
+     parse_content_response,
+ )
+
+ load_dotenv()
+
+
+ class ReactModel(ActModel):
+     """
+     Model that acts with an interleaved thinking step
+     """
+
+     def __init__(self, cfg: EvalConfig):
+         super().__init__(cfg)
+         self.max_invalid_actions = 3
+
+     def step(self, observation: dict) -> SymbolicAction:
+         # override the step method in ActModel to force a thinking step
+
+         self.history.add_observation_to_history(observation)
+         observation_message = convert_observation_to_message(
+             observation,
+             objective=self.history.objective,
+             bbox_model=self.bbox_model,
+             oam_model="oam" in self.llm.model_name,
+             use_text_inventory=self.use_text_inventory,
+             use_multimodal_content_format=self.use_multimodal_content_format,
+             use_images=self.use_images,
+         )
+         # add observation to history
+         self.history.add_message_to_history(content=observation_message, role="user")
+
+         i = 0
+         while i < self.max_invalid_actions:
+             message_window, image_window = self.llm.prepare_messages(
+                 history=self.history,
+                 max_messages_window=self.max_messages_window,
+                 system_prompt=self.system_prompt,
+                 prompt_images=self.prompt_images,
+             )
+             think_messages, think_token_used = self.llm.generate_unconstrained(
+                 batch_messages=[message_window],
+                 images=[image_window],
+                 start_messages_generation="think:",
+             )
+             self.history.tokens_used += think_token_used
+             think_message = "think: " + think_messages[0].split("\n")[0].strip()
+             self.history.add_message_to_history(content=think_message, role="assistant")
+
+             # retrieve a new message window (now including the thinking step)
+             message_window, image_window = self.llm.prepare_messages(
+                 history=self.history,
+                 max_messages_window=self.max_messages_window,
+                 system_prompt=self.system_prompt,
+                 prompt_images=self.prompt_images,
+             )
+             action_messages, action_token_used = self.llm.generate_unconstrained(
+                 batch_messages=[message_window],
+                 images=[image_window],
+                 start_messages_generation="",
+             )
+             self.history.tokens_used += action_token_used
+
+             action_message = action_messages[0].split("\n")[0].strip()
+
+             self.history.add_message_to_history(
+                 content=action_message, role="assistant"
+             )
+
+             response = parse_content_response(
+                 action_message, valid_actions=self.valid_actions
+             )
+             if not isinstance(response, str):
+                 # valid action
+                 self.history.add_action_to_history(response)
+                 return response
+             # string response: feed the parser feedback back to the model
+             self.history.add_message_to_history(
+                 content=response, role="user"
+             )
+             i += 1
+
+         # too many invalid actions: fall back to a no-op
+         return NoOp()
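
Example usage (an illustrative sketch, not part of the package: `cfg` and `observation` are assumed to be provided by the surrounding evaluation harness):

    model = ReactModel(cfg)  # cfg: EvalConfig
    action = model.step(observation)
    # step() first generates a "think: ..." turn, then an action turn;
    # it returns a parsed SymbolicAction, or NoOp() after
    # max_invalid_actions (3) unparseable responses.
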
@@ -0,0 +1,289 @@
+ import base64
+ import glob
+ import io
+ import pathlib
+ import re
+
+ import numpy as np
+ import torch
+ from PIL import Image
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ from plancraft.environments.actions import (
+     StopAction,
+     SymbolicAction,
+     SymbolicMoveAction,
+     SymbolicSmeltAction,
+     convert_from_slot_index,
+ )
+ from plancraft.environments.recipes import RECIPES
+
+
+ def numpy_to_base64(img_array: np.ndarray, image_format: str = "PNG") -> str:
+     """
+     Convert a NumPy array to a base64 encoded string.
+
+     Parameters:
+     - img_array: np.ndarray - Input image array.
+     - image_format: str - The format to save the image in (e.g., "PNG", "JPEG").
+
+     Returns:
+     - str - Base64 encoded string of the image.
+     """
+     # Convert NumPy array to image
+     image = Image.fromarray(img_array)
+
+     # Save the image to a bytes buffer
+     buffered = io.BytesIO()
+     image.save(buffered, format=image_format)
+
+     # Encode the bytes to a base64 string
+     img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
+
+     return img_str
+
+
+ def get_downloaded_models() -> dict:
+     """
+     Get the mapping of model names to local paths for models already downloaded to the NFS partition (EIDF).
+     """
+     downloaded_models = {}
+     # known models on NFS partition
+     if pathlib.Path("/nfs").exists():
+         local_models = glob.glob("/nfs/public/hf/models/*/*")
+         downloaded_models = {
+             model.replace("/nfs/public/hf/models/", ""): model for model in local_models
+         }
+     return downloaded_models
+
+
+ class TrieNode:
+     def __init__(self):
+         self.children = {}
+         self.is_end_of_sequence = False
+
+
+ class Trie:
+     def __init__(self):
+         self.root = TrieNode()
+         self.longest_sequence_length = 0
+
+     def insert(self, sequence: list):
+         node = self.root
+         for num in sequence:
+             if num not in node.children:
+                 node.children[num] = TrieNode()
+             node = node.children[num]
+         node.is_end_of_sequence = True
+
+         if len(sequence) > self.longest_sequence_length:
+             self.longest_sequence_length = len(sequence)
+
+     def starts_with(self, prefix: list) -> bool:
+         node = self.root
+         for num in prefix:
+             if num not in node.children:
+                 return False
+             node = node.children[num]
+         return True
+
+     def get_next(self, prefix: list) -> list:
+         node = self.root
+         for num in prefix:
+             if num not in node.children:
+                 return []
+             node = node.children[num]
+         return list(node.children.keys())
+
+
+ def tokenize(
+     model: AutoModelForCausalLM,
+     tokenizer: AutoTokenizer,
+     batch_messages: list[list[dict]],
+     start_messages_generation: list[str],
+     max_tokens=256,
+     images=None,
+ ) -> dict[str, torch.Tensor]:
+     """
+     Tokenize a batch of messages and seed the start of each response message.
+     """
+     assert len(start_messages_generation) == len(
+         batch_messages
+     ), "Length of start_messages_generation should be equal to batch_messages"
+
+     message_texts = tokenizer.apply_chat_template(
+         batch_messages,
+         add_generation_prompt=True,
+         tokenize=False,
+     )
+     # add the start of the response message for each message
+     message_texts = [
+         messages_text + new_message_start
+         for (messages_text, new_message_start) in zip(
+             message_texts, start_messages_generation
+         )
+     ]
+
+     max_prompt_length = None
+     # truncate the prompt to leave room for generation if max_length is set
+     if model.generation_config.max_length > max_tokens:
+         max_prompt_length = model.generation_config.max_length - max_tokens
+
+     if images:
+         assert len(images) == len(
+             batch_messages
+         ), "Length of images should be equal to batch_messages"
+         tokenized_messages = tokenizer(
+             message_texts,
+             return_tensors="pt",
+             truncation=True,
+             max_length=max_prompt_length,
+             padding=True,
+             images=images,
+         )
+     else:
+         tokenized_messages = tokenizer(
+             message_texts,
+             return_tensors="pt",
+             truncation=True,
+             max_length=max_prompt_length,
+             padding=True,
+         )
+     return tokenized_messages
+
+
+ def objective_and_inventory_to_str(objective: str, inventory: list[dict]) -> str:
+     inventory_str = ""
+     for item in inventory:
+         if item["quantity"] > 0:
+             if "index" in item:
+                 slot = item["index"]
+             else:
+                 slot = item["slot"]
+
+             if isinstance(slot, int):
+                 slot = convert_from_slot_index(slot)
+
+             inventory_str += f"\n - {item['type']} {slot} quantity {item['quantity']}"
+
+     return f"{objective}\ninventory:{inventory_str}"
+
+
+ def convert_observation_to_message(
+     observation: dict,
+     objective: str,
+     bbox_model=None,
+     oam_model=False,
+     use_text_inventory=True,
+     use_multimodal_content_format=False,
+     use_images=False,
+ ) -> str | dict:
+     """
+     Convert an observation to a message format.
+
+     Parameters:
+     - observation: dict - The observation to convert.
+     - objective: str - The objective of the observation.
+     - bbox_model: Optional - The bounding box model used to extract the inventory from the image.
+     - oam_model: bool - Whether to use the OAM model.
+     - use_text_inventory: bool - Whether to include the inventory as text.
+     - use_multimodal_content_format: bool - Whether to use the multimodal content format.
+     - use_images: bool - Whether to append an image to the message content; must be used with use_multimodal_content_format.
+     """
+     if bbox_model is not None:
+         # predict the inventory from the image observation
+         inventory = bbox_model.get_inventory(observation["pov"].copy())
+         text_content = objective_and_inventory_to_str(
+             objective, sorted(inventory, key=lambda x: x["slot"])
+         )
+     elif oam_model:
+         text_content = f"{objective}\ninventory:\n"
+     elif not use_text_inventory:
+         text_content = objective
+     else:
+         # default: build the text inventory from the symbolic observation
+         inventory = []
+         for o in observation["inventory"]:
+             if o["quantity"] > 0:
+                 inventory.append(
+                     {
+                         "type": o["type"],
+                         "slot": convert_from_slot_index(o["index"]),
+                         "quantity": o["quantity"],
+                     }
+                 )
+         text_content = objective_and_inventory_to_str(objective, inventory)
+
+     if not use_multimodal_content_format:
+         return text_content
+
+     content_list = [{"type": "text", "text": text_content}]
+     if use_images:
+         content_list.append({"type": "image"})
+     return {"content": content_list}
+
+
+ def gold_search_recipe(recipe_name: str) -> str:
+     """
+     Return the gold recipes for the given recipe name.
+     """
+     if recipe_name not in RECIPES:
+         return "Could not find a recipe by that name."
+
+     out_string = f"Recipes to craft {recipe_name}:\n"
+     for i, r in enumerate(RECIPES[recipe_name]):
+         if r.recipe_type != "smelting":
+             # sample a valid input grid (note that this is not guaranteed to be the only valid grid)
+             input_crafting_grid = r.sample_input_crafting_grid()
+             recipe_instructions = ""
+             for item in input_crafting_grid:
+                 recipe_instructions += (
+                     f"{item['type']} at {convert_from_slot_index(item['slot'])}\n"
+                 )
+         else:
+             # smelting recipe
+             recipe_instructions = f"smelt {r.ingredient}\n"
+         out_string += f"recipe {i+1}:\n{recipe_instructions}"
+     return out_string
+
+
+ def parse_content_response(
+     content: str, valid_actions: list[str] = ["smelt", "move"]
+ ) -> str | SymbolicAction | StopAction:
+     """
+     Given a message and a set of valid actions, parse the content and return the action,
+     or a string message if the action is invalid or requires a message response.
+     """
+
+     action_match = re.search(f"({'|'.join(valid_actions)}):", content)
+     if action_match:
+         action = action_match.group(1)
+         if action == "think":
+             return "Ok"
+         elif action == "impossible":
+             reason = re.search(r"impossible: (.*)", content).group(1)
+             return StopAction(reason=reason)
+         elif action == "search":
+             search_target = re.search(r"search: (\w+)", content).group(1)
+             return gold_search_recipe(search_target)
+         else:
+             try:
+                 slot_from = re.search(r" from (\[[ABCI]?\d+\])", content).group(1)
+                 slot_to = re.search(r" to (\[[ABCI]?\d+\])", content).group(1)
+                 quantity = re.search(r"with quantity (\d+)", content).group(1)
+                 if action == "move":
+                     action = SymbolicMoveAction(
+                         slot_from=slot_from,
+                         slot_to=slot_to,
+                         quantity=quantity,
+                     )
+                 else:
+                     action = SymbolicSmeltAction(
+                         slot_from=slot_from,
+                         slot_to=slot_to,
+                         quantity=quantity,
+                     )
+                 return action
+             except AttributeError as e:
+                 return f"Format Error: {e}"
+     return f"Only select actions from the following: {', '.join(valid_actions)}"