plancraft 0.1.2__py3-none-any.whl → 0.1.3__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,283 @@
1
+ from typing import Optional
2
+
3
+ from loguru import logger
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torchvision.transforms.v2 as v2
8
+ from transformers import (
9
+ AutoConfig,
10
+ AutoModelForCausalLM,
11
+ AutoTokenizer,
12
+ PretrainedConfig,
13
+ PreTrainedModel,
14
+ )
15
+
16
+ from plancraft.models.bbox_model import IntegratedBoundingBoxModel
17
+
18
+
19
class PlancraftOAMConfig(PretrainedConfig):
    """Configuration for the ``PlancraftOAM`` multimodal model.

    Args:
        from_llama: when true, ``PlancraftOAM`` loads the pretrained
            Llama 3.1 8B Instruct weights for its text backbone; otherwise it
            only loads that model's architecture config.
        **kwargs: forwarded unchanged to ``PretrainedConfig``.
    """

    # NOTE(review): reads "aom" while the class is "OAM"; left as-is because
    # serialized configs presumably record this exact string.
    model_type = "plancraft-aom"
    is_composition = True

    def __init__(self, from_llama=False, **kwargs):
        # Set the custom field first, then let the base class consume kwargs.
        self.from_llama = from_llama
        super().__init__(**kwargs)
30
+
31
+
32
class PlancraftOAM(PreTrainedModel):
    """Llama 3.1 text decoder conditioned on bounding-box features.

    Each user message ending in ``<|inventory|>`` has that placeholder
    repeated once per detected bounding box. Before the text model runs, the
    input embeddings at those placeholder positions are overwritten with the
    projected per-box features produced by the Mask R-CNN vision model.
    """

    config_class = PlancraftOAMConfig

    def __init__(self, config: PlancraftOAMConfig):
        super().__init__(config)

        self.config = config
        # Text model: pretrained Llama weights when ``from_llama`` is set,
        # otherwise the same architecture built from its config only
        # (presumably weights are restored later via ``from_pretrained``).
        if self.config.from_llama:
            self.text_model = AutoModelForCausalLM.from_pretrained(
                "meta-llama/Meta-Llama-3.1-8B-Instruct",
            )
        else:
            text_model_config = AutoConfig.from_pretrained(
                "meta-llama/Meta-Llama-3.1-8B-Instruct",
            )
            self.text_model = AutoModelForCausalLM.from_config(text_model_config)

        # Vision model: bounding-box detector, kept in eval mode.
        self.vision_model = IntegratedBoundingBoxModel.from_pretrained(
            "gautierdag/plancraft-maskrcnn"
        )
        self.vision_model.eval()

        # Project per-box vision features into the text embedding space.
        # NOTE(review): assumes the detector emits 1024-dim box features.
        self.vision_to_text_embedding = nn.Linear(
            1024, self.text_model.config.hidden_size
        )
        # The Llama 3.1 tokenizer ships no custom code, so no
        # ``trust_remote_code`` flag is needed here.
        self.tokenizer = AutoTokenizer.from_pretrained(
            "meta-llama/Meta-Llama-3.1-8B-Instruct",
        )
        # Register the placeholder token that marks where box features go.
        self.tokenizer.add_special_tokens(
            {
                "additional_special_tokens": [
                    "<|inventory|>",
                ]
            }
        )
        # Pad with EOS; padded positions are excluded from the loss in forward().
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.inventory_idx = self.tokenizer.convert_tokens_to_ids("<|inventory|>")

        # Grow the embedding matrix to cover the newly added special token.
        self.text_model.resize_token_embeddings(len(self.tokenizer))
        # Image -> float32 tensor (scale=True rescales integer images to [0, 1]).
        self.transforms = v2.Compose(
            [v2.ToImage(), v2.ToDtype(torch.float32, scale=True)]
        )

    @torch.no_grad()
    def extract_bboxes(self, images: list) -> list[dict]:
        """Run the vision model on ``images`` and return its predictions.

        Returns ``[]`` for an empty ``images`` list. Each prediction dict is
        expected to carry a ``"features"`` tensor with one row per box.
        """
        if len(images) == 0:
            return []
        img_tensors = torch.stack([self.transforms(img) for img in images])
        img_tensors = img_tensors.cuda()
        # disable gradients on the vision model
        self.vision_model.freeze()
        # get bounding box predictions
        bbox_preds = self.vision_model(img_tensors)
        return bbox_preds

    def _render_chat(self, messages: list) -> str:
        """Apply the chat template and strip the BOS text.

        BOS is removed presumably because the later tokenizer call in
        ``process_inputs`` adds it again.
        """
        text = self.tokenizer.apply_chat_template(
            messages, add_generation_prompt=not self.training, tokenize=False
        )
        return text.replace("<|begin_of_text|>", "")

    def prepare_messages(self, messages: list, bboxes: list[dict]) -> str:
        """Render one conversation to text, expanding inventory placeholders.

        Every user message whose content ends with ``<|inventory|>`` consumes
        the next entry of ``bboxes``; its placeholder is repeated once per box
        (``features.shape[0]``) so token count matches feature rows.

        Raises:
            AssertionError: if the number of such messages differs from
                ``len(bboxes)``.
        """
        # no bounding boxes: render as-is
        if len(bboxes) == 0:
            return self._render_chat(messages)

        # expand <|inventory|> tokens into N tokens (N = number of bounding boxes)
        new_messages = []
        i_pred = 0
        for m in messages:
            new_message = m.copy()
            if new_message["role"] == "user" and new_message["content"].endswith(
                "<|inventory|>"
            ):
                assert i_pred < len(
                    bboxes
                ), "Number of inventory tokens does not match number of images"
                n_boxes = bboxes[i_pred]["features"].shape[0]
                new_message["content"] = new_message["content"].replace(
                    "<|inventory|>",
                    "<|inventory|>" * n_boxes,
                )
                i_pred += 1
            new_messages.append(new_message)

        assert i_pred == len(
            bboxes
        ), "Number of inventory tokens does not match number of images"
        return self._render_chat(new_messages)

    def inputs_merger(
        self,
        input_ids: torch.LongTensor,
        inputs_embeds: Optional[torch.Tensor],
        image_hidden_states: list,
    ):
        """Overwrite ``<|inventory|>`` token embeddings with box features.

        Args:
            input_ids: token ids, one row per conversation.
            inputs_embeds: text embeddings of ``input_ids``; modified in place.
            image_hidden_states: per-row projected box features (one row per
                box), or ``[]`` for conversations without images.

        If tokenization truncated away some placeholders, only the features
        whose tokens survived are used: the leading boxes under right-side
        truncation, the trailing boxes under left-side truncation. The
        trimmed features are written back into ``image_hidden_states``.

        Returns:
            The updated ``inputs_embeds``.
        """
        inventory_mask = input_ids == self.inventory_idx
        # along batch dimension
        for i, hidden in enumerate(image_hidden_states):
            row_mask = inventory_mask[i]
            n_inventory_tokens = int(row_mask.sum().item())

            if len(hidden) == 0:
                assert (
                    n_inventory_tokens == 0
                ), "No images but inventory token is still present"
                continue

            n_boxes = hidden.shape[0]
            if n_inventory_tokens != n_boxes:
                logger.warning(
                    f"Number of inventory tokens ({n_inventory_tokens}) does not match number of bounding boxes ({n_boxes}). Possible truncation."
                )
                if n_inventory_tokens < n_boxes:
                    # keep the boxes whose placeholder tokens survived truncation
                    if getattr(self.tokenizer, "truncation_side", "right") == "left":
                        hidden = hidden[n_boxes - n_inventory_tokens :]
                    else:
                        hidden = hidden[:n_inventory_tokens]
                    image_hidden_states[i] = hidden

            if n_inventory_tokens == 0:
                # every placeholder was truncated away; nothing to merge
                continue

            # replace inventory tokens with bbox features
            inputs_embeds[i, row_mask] = hidden.to(dtype=inputs_embeds.dtype)
        return inputs_embeds

    def process_inputs(
        self,
        batch_messages: Optional[list[list[dict]]] = None,  # list of list of messages (untokenized)
        batch_images: Optional[list[list]] = None,  # list of list of images (unprocessed)
    ) -> tuple[dict[str, torch.Tensor], list, int]:
        """Convert raw conversations and images into tokenized model inputs.

        Returns:
            ``(batch, image_hidden_states, total_boxes)``: ``batch`` is the
            tokenizer output (padded, truncated to 16384 tokens, PyTorch
            tensors); ``image_hidden_states`` holds each conversation's
            projected box features (``[]`` when it has no images);
            ``total_boxes`` is the number of boxes across the batch.
        """
        if batch_messages is None:
            batch_messages = []
        if batch_images is None:
            batch_images = []
        assert len(batch_images) == len(
            batch_messages
        ), "Number of images and messages should match in the batch dim"

        texts_batch = []
        image_hidden_states = []
        total_boxes = 0
        for images, messages in zip(batch_images, batch_messages):
            # process images
            bboxes = self.extract_bboxes(images)
            if len(bboxes) > 0:
                # get bbox features and upscale to text embedding size
                features = torch.concat([p["features"] for p in bboxes], dim=0)
                features_embeds = self.vision_to_text_embedding(features)
                image_hidden_states.append(features_embeds)
                total_boxes += features.shape[0]
            else:
                image_hidden_states.append([])

            # process messages
            texts_batch.append(self.prepare_messages(messages, bboxes))

        # NOTE: truncation can drop <|inventory|> tokens (trailing ones with
        # right-side truncation, leading ones with left-side truncation);
        # inputs_merger then drops the matching box features with a warning.
        batch = self.tokenizer(
            texts_batch,
            truncation=True,
            padding=True,
            max_length=16384,
            return_tensors="pt",
        )
        return batch, image_hidden_states, total_boxes

    def forward(
        self,
        batch_messages: Optional[list[list[dict]]] = None,  # list of list of messages (untokenized)
        batch_images: Optional[list[list]] = None,  # list of list of images (unprocessed)
        **kwargs,
    ):
        """Training forward pass; returns the text model's output (with loss).

        Labels are the input ids with ``<|inventory|>`` positions and padding
        positions set to -100 so neither contributes to the loss.
        """
        batch, image_hidden_states, _ = self.process_inputs(
            batch_messages, batch_images
        )
        # move to cuda
        batch = {k: v.cuda() for k, v in batch.items()}
        attention_mask = batch["attention_mask"]
        input_ids = batch["input_ids"]

        labels = input_ids.clone()
        # inventory placeholders carry box features, not text to predict
        labels[input_ids == self.inventory_idx] = -100
        # padding (pad == EOS here) must not be trained on
        labels[attention_mask == 0] = -100

        # get text embeddings and splice in bbox features
        inputs_embeds = self.text_model.get_input_embeddings()(input_ids)
        inputs_embeds = self.inputs_merger(
            input_ids, inputs_embeds, image_hidden_states
        )
        # forward pass
        return self.text_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            labels=labels,
            return_dict=True,
        )

    @torch.no_grad()
    def generate(
        self,
        batch_messages: list[list[dict]],
        batch_images: list[list],
        do_sample=True,
        temperature=0.6,
        max_new_tokens=32,
    ):
        """Generate a text response for each conversation in the batch.

        Returns:
            ``(text_responses, total_tokens_used)`` where the responses have
            ``<|eot_id|>`` removed and ``total_tokens_used`` is the length of
            the generated sequence tensor. NOTE(review): when generating from
            ``inputs_embeds`` only, recent transformers versions return just
            the new tokens, so this likely excludes the prompt — confirm.
        """
        # Batched generation needs left padding; restore the previous side
        # afterwards so later forward() calls are unaffected.
        previous_padding_side = self.tokenizer.padding_side
        self.tokenizer.padding_side = "left"
        try:
            batch, image_hidden_states, _ = self.process_inputs(
                batch_messages, batch_images
            )
        finally:
            self.tokenizer.padding_side = previous_padding_side

        batch = {k: v.cuda() for k, v in batch.items()}
        attention_mask = batch["attention_mask"]
        input_ids = batch["input_ids"]

        inputs_embeds = self.text_model.get_input_embeddings()(input_ids)
        inputs_embeds = self.inputs_merger(
            input_ids, inputs_embeds, image_hidden_states
        )

        generated_sequences = self.text_model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            do_sample=do_sample,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            pad_token_id=self.tokenizer.eos_token_id,
        )

        # Decode the output
        text_responses = self.tokenizer.batch_decode(
            generated_sequences,
            skip_special_tokens=False,
        )

        # remove <|eot_id|> tokens
        text_responses = [
            text_response.replace("<|eot_id|>", "") for text_response in text_responses
        ]

        _, total_tokens_used = generated_sequences.shape
        return text_responses, total_tokens_used
@@ -0,0 +1,265 @@
1
+ import copy
2
+ from collections import Counter
3
+
4
+ from plancraft.config import EvalConfig
5
+ from plancraft.environments.actions import (
6
+ RealActionInteraction,
7
+ StopAction,
8
+ SymbolicMoveAction,
9
+ SymbolicSmeltAction,
10
+ )
11
+ from plancraft.environments.planner import optimal_planner
12
+ from plancraft.environments.recipes import (
13
+ ShapedRecipe,
14
+ ShapelessRecipe,
15
+ SmeltingRecipe,
16
+ id_to_item,
17
+ )
18
+ from plancraft.environments.sampler import MAX_STACK_SIZE
19
+ from plancraft.models.base import ABCModel, History
20
+
21
+
22
def item_set_id_to_type(item_set_ids: set[int]):
    """Map a collection of item ids to the set of their item types via ``id_to_item``."""
    return {id_to_item(item_id) for item_id in item_set_ids}
24
+
25
+
26
def find_free_inventory_slot(inventory: list[dict], from_slot: int) -> int:
    """Pick a destination slot for the whole stack currently at ``from_slot``.

    Items are addressed by their ``"slot"`` key, falling back to ``"index"``.
    Entries with quantity 0 are treated as empty (``"air"``). Preference:

    1. another slot holding the same item type whose combined quantity stays
       within ``MAX_STACK_SIZE`` for that type;
    2. an empty slot whose number is greater than 10.

    Raises:
        AssertionError: if no item is located at ``from_slot``.
        ValueError: if no suitable slot exists, including when the inventory
            contains no empty slots at all.
    """
    from_item_type, from_item_quantity = None, None

    type_to_slot: dict[str, list[int]] = {}
    slot_to_quantity: dict[int, int] = {}
    for item in inventory:
        if ("slot" in item and item["slot"] == from_slot) or (
            "index" in item and item["index"] == from_slot
        ):
            from_item_quantity = item["quantity"]
            from_item_type = item["type"]

        quantity = item["quantity"]
        # empty stacks count as free ("air") slots
        item_type = "air" if quantity == 0 else item["type"]
        item_slot = item["slot"] if "slot" in item else item["index"]

        type_to_slot.setdefault(item_type, []).append(item_slot)
        slot_to_quantity[item_slot] = slot_to_quantity.get(item_slot, 0) + quantity

    assert from_item_type is not None, f"Item not found in slot {from_slot}"

    # prefer stacking onto another slot with the same item type
    for slot in type_to_slot.get(from_item_type, []):
        if (
            slot != from_slot
            and slot_to_quantity[slot] + from_item_quantity
            <= MAX_STACK_SIZE[from_item_type]
        ):
            return slot

    # otherwise use an empty slot; slots <= 10 are never chosen
    # NOTE(review): presumably 0 is the crafting output and 1-9 the crafting
    # grid — confirm whether slot 10 should be allowed.
    for slot in type_to_slot.get("air", []):
        if slot != from_slot and slot > 10:
            return slot

    raise ValueError("No free slot found")
77
+
78
+
79
def find_item_in_inventory(target: str, inventory: list[dict]) -> int:
    """Return the position of the first non-empty stack of ``target``.

    The position is the entry's ``"slot"`` key, or its ``"index"`` key when
    ``"slot"`` is absent. Entries with quantity <= 0 are ignored.

    Raises:
        ValueError: if no matching entry with a position is found.
    """
    for item in inventory:
        if item["type"] != target or item["quantity"] <= 0:
            continue
        if "slot" in item:
            return item["slot"]
        if "index" in item:
            return item["index"]
    raise ValueError(
        f"No item of type {target!r} with positive quantity and a slot/index found in inventory"
    )
87
+
88
+
89
def get_inventory_counter(inventory: list[dict]) -> Counter:
    """Sum quantities per item type.

    Entries at position 0 (``"slot"`` or ``"index"`` equal to 0) and entries
    of type ``"air"`` are skipped. Zero-quantity entries of other types still
    create a key with count 0.
    """

    def _at_position_zero(entry: dict) -> bool:
        return ("slot" in entry and entry["slot"] == 0) or (
            "index" in entry and entry["index"] == 0
        )

    totals = Counter()
    counted = (
        entry
        for entry in inventory
        if not _at_position_zero(entry) and entry["type"] != "air"
    )
    for entry in counted:
        totals[entry["type"]] += entry["quantity"]
    return totals
100
+
101
+
102
def get_crafting_slot_item(inventory: list[dict]) -> dict:
    """Return the first entry at position 0 with a positive quantity, else ``None``.

    Position 0 means ``"slot"`` == 0 or ``"index"`` == 0.
    """

    def _at_position_zero(entry: dict) -> bool:
        return ("slot" in entry and entry["slot"] == 0) or (
            "index" in entry and entry["index"] == 0
        )

    return next(
        (
            entry
            for entry in inventory
            if _at_position_zero(entry) and entry["quantity"] > 0
        ),
        None,
    )
109
+
110
+
111
def update_inventory(
    inventory: list[dict], slot_from: int, slot_to: int, quantity: int
) -> list[dict]:
    """Subtract ``quantity`` from the entry at ``slot_from``.

    An entry matches on its ``"slot"`` key, or failing that on its
    ``"index"`` key. Matching dicts are mutated in place, and the returned
    list is a new list holding the same dict objects. ``slot_to`` is accepted
    but unused: the destination is deliberately not incremented.
    """

    def _is_source(entry: dict) -> bool:
        if "slot" in entry and entry["slot"] == slot_from:
            return True
        return "index" in entry and entry["index"] == slot_from

    updated = list(inventory)
    for entry in updated:
        if _is_source(entry):
            entry["quantity"] -= quantity
    return updated
127
+
128
+
129
class OracleModel(ABCModel):
    """
    Oracle model returns actions that solve the task optimally

    It asks ``optimal_planner`` for a list of ``(recipe, inventory)`` steps,
    then expands each step into low-level symbolic move/smelt actions
    ("subplans") which are emitted one at a time from ``step``.
    """

    def __init__(self, cfg: EvalConfig):
        # Only the symbolic action space is supported; cfg is not used further.
        assert (
            cfg.plancraft.environment.symbolic_action_space
        ), "Only symbolic actions are supported for oracle"
        self.history = History(objective="")
        # high-level plan: (recipe, inventory) pairs popped in order
        self.plans = []
        # pending low-level actions for the recipe currently being executed
        self.subplans = []

    def reset_history(self, objective: str = ""):
        """Reset the history with a new objective and drop any pending plan."""
        self.history.reset(objective=objective)
        self.plans = []
        self.subplans = []

    def get_plan(self, observation: dict):
        """Compute ``self.plans`` for the objective's target from the observed inventory.

        ``optimal_planner`` may return ``None`` (``step`` checks for this).
        """
        # objective="Craft an item of type: ...."
        # this simply recovers the target item to craft
        target = self.history.objective.split(": ")[-1]
        inventory_counter = get_inventory_counter(observation["inventory"])
        self.plans = optimal_planner(target=target, inventory=inventory_counter)

    def get_next_action(
        self, observation: dict
    ) -> SymbolicMoveAction | SymbolicSmeltAction:
        """Return the next low-level action.

        Order of precedence: a pending subplan action; otherwise, if position 0
        holds an item, move it to a free inventory slot; otherwise pop the next
        plan step and expand it into subplan actions.

        Raises:
            ValueError: if there are no subplans and no plan steps left.
            NotImplementedError: for an unknown recipe type.
        """
        if len(self.subplans) > 0:
            return self.subplans.pop(0)
        if len(self.plans) == 0:
            raise ValueError("No more steps in plan")

        # work on a copy so simulated inventory updates don't touch the observation
        observed_inventory = copy.deepcopy(observation["inventory"])

        # take item from crafting slot
        if slot_item := get_crafting_slot_item(observed_inventory):
            # move item from crafting slot to inventory
            free_slot = find_free_inventory_slot(observed_inventory, from_slot=0)
            return SymbolicMoveAction(
                slot_from=0, slot_to=free_slot, quantity=slot_item["quantity"]
            )

        plan_recipe, new_inventory = self.plans.pop(0)
        self.subplans = []
        new_inventory_counter = Counter(new_inventory)
        current_inventory = observed_inventory
        current_inventory_counter = get_inventory_counter(current_inventory)
        # items that disappear in the planned inventory are the recipe inputs
        items_to_use_counter = current_inventory_counter - new_inventory_counter
        # exactly one item type is expected to appear: the recipe output
        new_items = new_inventory_counter - current_inventory_counter
        assert len(new_items) == 1

        if isinstance(plan_recipe, ShapelessRecipe):
            # fill crafting grid slots 1, 2, 3, ... in order
            crafting_slot = 1

            # add each item to crafting slots
            for item, quantity in items_to_use_counter.items():
                n = 0
                while n < quantity:
                    from_slot = find_item_in_inventory(item, current_inventory)

                    # skip if from_slot is the crafting slot
                    # NOTE(review): this path does not decrement the simulated
                    # inventory, so the next lookup can return the same slot
                    # again — confirm this cannot move an item out of the grid.
                    if from_slot == crafting_slot:
                        crafting_slot += 1
                        n += 1
                        continue

                    action = SymbolicMoveAction(
                        slot_from=from_slot, slot_to=crafting_slot, quantity=1
                    )
                    # update state of inventory
                    current_inventory = update_inventory(
                        current_inventory, from_slot, crafting_slot, 1
                    )
                    self.subplans.append(action)

                    crafting_slot += 1
                    n += 1

        # if plan_recipe is a smelting recipe
        elif isinstance(plan_recipe, SmeltingRecipe):
            assert len(items_to_use_counter) == 1, "smelting only supports one item"
            for item, quantity in items_to_use_counter.items():
                from_slot = find_item_in_inventory(item, current_inventory)
                # NOTE(review): the destination is chosen for the *input* item
                # type, so it may be another stack of the input rather than a
                # slot suited to the smelted output — confirm intended.
                free_slot = find_free_inventory_slot(
                    current_inventory, from_slot=from_slot
                )
                action = SymbolicSmeltAction(
                    slot_from=from_slot, slot_to=free_slot, quantity=quantity
                )
                self.subplans.append(action)

        # if plan_recipe is a shaped recipe
        elif isinstance(plan_recipe, ShapedRecipe):
            for i, row in enumerate(plan_recipe.kernel):
                for j, item_set in enumerate(row):
                    # kernel cell (i, j) maps to grid slot 1 + 3*i + j
                    inventory_position = (i * 3) + j + 1
                    valid_items = item_set_id_to_type(item_set)
                    # place one unit of the first still-needed valid item
                    for item in valid_items:
                        if items_to_use_counter[item] > 0:
                            from_slot = find_item_in_inventory(item, current_inventory)
                            action = SymbolicMoveAction(
                                slot_from=from_slot,
                                slot_to=inventory_position,
                                quantity=1,
                            )
                            items_to_use_counter[item] -= 1
                            # update state of inventory
                            current_inventory = update_inventory(
                                current_inventory, from_slot, inventory_position, 1
                            )
                            self.subplans.append(action)
                            break
        else:
            raise NotImplementedError(f"Recipe type {type(plan_recipe)} not supported")

        # NOTE(review): raises IndexError if the expansion produced no actions.
        return self.subplans.pop(0)

    def step(
        self, observation: dict
    ) -> SymbolicMoveAction | RealActionInteraction | SymbolicSmeltAction | StopAction:
        """Record the observation, pick the next action, record it, and return it.

        Returns ``StopAction()`` (without recording it) when the planner
        returns ``None``.
        """
        # add observation to history
        self.history.add_observation_to_history(observation)

        # get action
        # NOTE(review): this re-plans whenever self.plans is empty, even if
        # self.subplans still holds pending actions for the last recipe.
        if len(self.plans) == 0:
            self.get_plan(observation)
            if self.plans is None:
                self.plans = []
                return StopAction()

        action = self.get_next_action(observation)

        # add action to history
        self.history.add_action_to_history(action)
        return action