plancraft 0.1.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
models/oam.py ADDED
@@ -0,0 +1,284 @@
1
+ import logging
2
+ from typing import Optional
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torchvision.transforms.v2 as v2
7
+ from transformers import (
8
+ AutoConfig,
9
+ AutoModelForCausalLM,
10
+ AutoTokenizer,
11
+ PretrainedConfig,
12
+ PreTrainedModel,
13
+ )
14
+
15
+ from plancraft.models.bbox_model import IntegratedBoundingBoxModel
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
class PlancraftOAMConfig(PretrainedConfig):
    """Hugging Face config for ``PlancraftOAM``.

    Args:
        from_llama: if True, ``PlancraftOAM`` loads the pretrained
            Meta-Llama-3.1-8B-Instruct weights for its text model; otherwise it
            builds the same architecture from its config only.
        **kwargs: handed through untouched to ``PretrainedConfig``.
    """

    # NOTE(review): "aom" does not match the class name ("OAM"), but this string is
    # written into saved config files, so renaming it would break existing checkpoints.
    model_type = "plancraft-aom"
    is_composition = True

    def __init__(self, from_llama=False, **kwargs):
        # set our own field first, then let the base class consume the rest
        self.from_llama = from_llama
        super().__init__(**kwargs)
31
+
32
+
33
class PlancraftOAM(PreTrainedModel):
    """Llama 3.1 text model whose ``<|inventory|>`` tokens carry image features.

    Images go through ``IntegratedBoundingBoxModel``. Each predicted box's
    1024-d ``"features"`` row is projected to the text hidden size by
    ``vision_to_text_embedding`` and written over the embedding of one
    ``<|inventory|>`` placeholder token in the prompt.
    """

    config_class = PlancraftOAMConfig

    def __init__(self, config: PlancraftOAMConfig):
        super().__init__(config)

        self.config = config
        # text model: pretrained weights when from_llama, otherwise the same
        # architecture built from its config only
        if self.config.from_llama:
            self.text_model = AutoModelForCausalLM.from_pretrained(
                "meta-llama/Meta-Llama-3.1-8B-Instruct",
            )
        else:
            text_model_config = AutoConfig.from_pretrained(
                "meta-llama/Meta-Llama-3.1-8B-Instruct",
            )
            self.text_model = AutoModelForCausalLM.from_config(text_model_config)

        # vision model, used for inference only (see extract_bboxes)
        self.vision_model = IntegratedBoundingBoxModel.from_pretrained(
            "gautierdag/plancraft-maskrcnn"
        )
        self.vision_model.eval()

        # project 1024-d box features into the text embedding space
        self.vision_to_text_embedding = nn.Linear(
            1024, self.text_model.config.hidden_size
        )
        # NOTE: `trust_remote=True` used to be passed here; it is not a
        # from_pretrained option (the real flag is `trust_remote_code`), so it never
        # had any effect and has been dropped.
        self.tokenizer = AutoTokenizer.from_pretrained(
            "meta-llama/Meta-Llama-3.1-8B-Instruct",
        )
        # register the placeholder token that marks where box features go
        self.tokenizer.add_special_tokens(
            {
                "additional_special_tokens": [
                    "<|inventory|>",
                ]
            }
        )
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.inventory_idx = self.tokenizer.convert_tokens_to_ids("<|inventory|>")

        # grow the embedding matrix to cover the added special token
        self.text_model.resize_token_embeddings(len(self.tokenizer))
        # image -> float32 tensor (scale=True rescales integer pixel values)
        self.transforms = v2.Compose(
            [v2.ToImage(), v2.ToDtype(torch.float32, scale=True)]
        )

    @torch.no_grad()
    def extract_bboxes(self, images: list) -> list[dict]:
        """Run the frozen vision model on a list of images.

        Args:
            images: raw images accepted by ``self.transforms``; they are stacked,
                so they must share the same size.

        Returns:
            The vision model's per-image predictions (callers read each entry's
            ``"features"`` tensor, one row per box), or ``[]`` for no images.
        """
        if len(images) == 0:
            return []
        img_tensors = torch.stack([self.transforms(img) for img in images])
        # follow the vision model's device instead of hard-coding CUDA
        img_tensors = img_tensors.to(next(self.vision_model.parameters()).device)
        # disable gradients
        self.vision_model.freeze()
        # nn.Module.train() on this model recurses into submodules, so re-assert
        # the inference mode requested in __init__ before every call
        self.vision_model.eval()
        # get bounding box predictions
        return self.vision_model(img_tensors)

    def _render_chat(self, messages: list) -> str:
        """Apply the chat template (untokenized) and strip ``<|begin_of_text|>``.

        A generation prompt is appended only in eval mode. The BOS marker is
        removed presumably because the tokenizer call in process_inputs adds its
        own — confirm against the tokenizer's add_bos behaviour.
        """
        text = self.tokenizer.apply_chat_template(
            messages, add_generation_prompt=not self.training, tokenize=False
        )
        return text.replace("<|begin_of_text|>", "")

    def prepare_messages(self, messages: list, bboxes: list[dict]) -> str:
        """Render a conversation to text, expanding inventory placeholders.

        Each user message ending in ``<|inventory|>`` consumes the next entry of
        ``bboxes`` (in order). Every ``<|inventory|>`` in that message is repeated
        once per row of the entry's ``"features"``, so placeholder tokens line up
        one-to-one with box features.

        Args:
            messages: chat messages (dicts with ``"role"`` and ``"content"``).
            bboxes: predictions from ``extract_bboxes`` for this conversation.

        Returns:
            The chat-templated prompt text.

        Raises:
            AssertionError: if fewer inventory messages than images are found.
        """
        # no bounding boxes
        if len(bboxes) == 0:
            return self._render_chat(messages)

        new_messages = []
        i_pred = 0
        for m in messages:
            new_message = m.copy()
            if new_message["role"] == "user" and new_message["content"].endswith(
                "<|inventory|>"
            ):
                n_boxes = bboxes[i_pred]["features"].shape[0]
                new_message["content"] = new_message["content"].replace(
                    "<|inventory|>", "<|inventory|>" * n_boxes
                )
                i_pred += 1
            new_messages.append(new_message)

        assert i_pred == len(
            bboxes
        ), "Number of inventory tokens does not match number of images"
        return self._render_chat(new_messages)

    def inputs_merger(
        self,
        input_ids: torch.LongTensor,
        inputs_embeds: Optional[torch.Tensor],
        image_hidden_states: list,
    ):
        """Overwrite the embeddings of ``<|inventory|>`` tokens with box features.

        Args:
            input_ids: token ids, indexed as ``input_ids[i]`` per batch row.
            inputs_embeds: text embeddings for ``input_ids``; modified in place.
            image_hidden_states: per batch row, either the projected box features
                (one row per box) or an empty list when the row had no images.
                Truncated rows are replaced in this list.

        Returns:
            ``inputs_embeds`` with inventory positions filled in.

        Raises:
            AssertionError: a row without images still contains inventory tokens.
            ValueError: a row has more inventory tokens than box features.
        """
        for i, features in enumerate(image_hidden_states):
            inventory_mask = input_ids[i] == self.inventory_idx
            n_inventory_tokens = int(inventory_mask.sum())
            if len(features) == 0:
                assert (
                    n_inventory_tokens == 0
                ), "No images but inventory token is still present"
                continue

            n_boxes = features.shape[0]
            if n_inventory_tokens != n_boxes:
                logger.warning(
                    f"Number of inventory tokens ({n_inventory_tokens}) does not match number of bounding boxes ({n_boxes}). Possible truncation."
                )
                if n_inventory_tokens > n_boxes:
                    raise ValueError(
                        f"More inventory tokens ({n_inventory_tokens}) than bounding boxes ({n_boxes})"
                    )
                # keep the features whose placeholder tokens survived truncation:
                # right truncation (tokenizer default) drops trailing tokens, so the
                # leading features remain; left truncation keeps the trailing ones
                if self.tokenizer.truncation_side == "left":
                    features = features[n_boxes - n_inventory_tokens :]
                else:
                    features = features[:n_inventory_tokens]
                image_hidden_states[i] = features

            # index assignment requires matching dtype/device with the embeddings
            inputs_embeds[i, inventory_mask] = features.to(
                device=inputs_embeds.device, dtype=inputs_embeds.dtype
            )
        return inputs_embeds

    def process_inputs(
        self,
        batch_messages: Optional[list[list[dict]]] = None,  # list of list of messages (untokenized)
        batch_images: Optional[list[list]] = None,  # list of list of images (unprocessed)
    ) -> tuple[dict[str, torch.FloatTensor], list[torch.FloatTensor], int]:
        """Convert raw images and messages into model inputs.

        Returns:
            ``(batch, image_hidden_states, total_boxes)``:
            - ``batch`` is the tokenizer output (``input_ids``, ``attention_mask``);
            - ``image_hidden_states`` holds per-sample projected box features, or
              ``[]`` for samples without images;
            - ``total_boxes`` is the number of feature rows across the batch.

        Raises:
            AssertionError: if the two batch lists differ in length.
        """
        batch_messages = [] if batch_messages is None else batch_messages
        batch_images = [] if batch_images is None else batch_images
        assert len(batch_images) == len(
            batch_messages
        ), "Number of images and messages should match in the batch dim"

        texts_batch = []
        image_hidden_states = []
        total_boxes = 0
        for images, messages in zip(batch_images, batch_messages):
            bboxes = self.extract_bboxes(images)
            if len(bboxes) > 0:
                # stack every image's box features and project to text size
                features = torch.concat([p["features"] for p in bboxes], dim=0)
                image_hidden_states.append(self.vision_to_text_embedding(features))
                total_boxes += features.shape[0]
            else:
                image_hidden_states.append([])

            texts_batch.append(self.prepare_messages(messages, bboxes))

        # NOTE: truncation can drop inventory placeholder tokens; inputs_merger then
        # discards the matching box features and logs a warning.
        batch = self.tokenizer(
            texts_batch,
            truncation=True,
            padding=True,
            max_length=16384,
            return_tensors="pt",
        )
        return batch, image_hidden_states, total_boxes

    def forward(
        self,
        batch_messages: Optional[list[list[dict]]] = None,  # list of list of messages (untokenized)
        batch_images: Optional[list[list]] = None,  # list of list of images (unprocessed)
        **kwargs,
    ):
        """Training forward pass: next-token loss over the whole chat text.

        Inventory placeholder positions and padding positions are excluded from
        the loss. Extra keyword arguments are accepted and ignored.

        Returns:
            The text model's output object (``return_dict=True``), including the
            loss computed from ``labels``.
        """
        batch, image_hidden_states, total_boxes = self.process_inputs(
            batch_messages, batch_images
        )
        # move inputs to wherever the embedding layer lives
        device = self.text_model.get_input_embeddings().weight.device
        batch = {k: v.to(device) for k, v in batch.items()}
        attention_mask = batch["attention_mask"]
        input_ids = batch["input_ids"]

        inventory_mask = input_ids == self.inventory_idx
        # sanity check: truncation may remove placeholders but never add them
        assert (
            int(inventory_mask.sum()) <= total_boxes
        ), "More inventory tokens than bounding boxes"

        labels = input_ids.clone()
        # placeholders carry image features, not text to predict
        labels[inventory_mask] = -100
        # pad_token == eos: padded positions must not contribute to the loss
        labels[attention_mask == 0] = -100

        inputs_embeds = self.text_model.get_input_embeddings()(input_ids)
        inputs_embeds = self.inputs_merger(
            input_ids, inputs_embeds, image_hidden_states
        )
        return self.text_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            labels=labels,
            return_dict=True,
        )

    @torch.no_grad()
    def generate(
        self,
        batch_messages: list[list[dict]],
        batch_images: list[list],
        do_sample=True,
        temperature=0.6,
        max_new_tokens=32,
    ):
        """Generate a reply for each conversation in the batch.

        Returns:
            ``(text_responses, total_tokens_used)``: decoded outputs with
            ``<|eot_id|>`` removed, and the second dimension of the generated id
            tensor.
        """
        # left-pad prompts for generation, then restore the previous side so a
        # later forward() call is not silently switched to left padding
        previous_padding_side = self.tokenizer.padding_side
        self.tokenizer.padding_side = "left"
        try:
            batch, image_hidden_states, _ = self.process_inputs(
                batch_messages, batch_images
            )
        finally:
            self.tokenizer.padding_side = previous_padding_side

        device = self.text_model.get_input_embeddings().weight.device
        batch = {k: v.to(device) for k, v in batch.items()}
        attention_mask = batch["attention_mask"]
        input_ids = batch["input_ids"]

        inputs_embeds = self.text_model.get_input_embeddings()(input_ids)
        inputs_embeds = self.inputs_merger(
            input_ids, inputs_embeds, image_hidden_states
        )

        generated_sequences = self.text_model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            do_sample=do_sample,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            pad_token_id=self.tokenizer.eos_token_id,
        )

        # presumably only new tokens are returned since no input_ids were passed
        # to generate — confirm before slicing off a prompt prefix
        text_responses = self.tokenizer.batch_decode(
            generated_sequences,
            skip_special_tokens=False,
        )
        # remove <|eot_id|> tokens (also used as padding)
        text_responses = [
            text_response.replace("<|eot_id|>", "") for text_response in text_responses
        ]

        _, total_tokens_used = generated_sequences.shape
        return text_responses, total_tokens_used
models/oracle.py ADDED
@@ -0,0 +1,268 @@
1
+ import logging
2
+ import copy
3
+ from collections import Counter
4
+
5
+ from plancraft.config import EvalConfig
6
+ from plancraft.environments.actions import (
7
+ RealActionInteraction,
8
+ SymbolicMoveAction,
9
+ SymbolicSmeltAction,
10
+ StopAction,
11
+ )
12
+ from plancraft.environments.planner import optimal_planner
13
+ from plancraft.environments.recipes import (
14
+ ShapedRecipe,
15
+ ShapelessRecipe,
16
+ SmeltingRecipe,
17
+ id_to_item,
18
+ )
19
+ from plancraft.models.base import ABCModel, History
20
+ from plancraft.environments.sampler import MAX_STACK_SIZE
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
def item_set_id_to_type(item_set_ids: set[int]):
    """Translate a collection of item ids into the set of their type names.

    Each id is mapped with ``id_to_item``; duplicates collapse in the result set.
    """
    return {id_to_item(item_id) for item_id in item_set_ids}
27
+
28
+
29
def find_free_inventory_slot(inventory: list[dict], from_slot: int) -> int:
    """Pick a destination slot for the stack currently in ``from_slot``.

    Inventory entries are dicts with ``"type"``, ``"quantity"`` and a ``"slot"``
    or ``"index"`` position. Entries with zero quantity count as empty (``"air"``).

    Preference order:
      1. another slot holding the same item type whose combined quantity stays
         within ``MAX_STACK_SIZE`` for that type;
      2. an empty slot numbered above 10.

    Args:
        inventory: the observed inventory entries.
        from_slot: position of the stack to move.

    Returns:
        The chosen destination slot.

    Raises:
        AssertionError: if no entry occupies ``from_slot``.
        ValueError: if no suitable destination exists.
    """
    from_item_type, from_item_quantity = None, None

    type_to_slot: dict[str, list] = {}
    slot_to_quantity: dict = {}
    for item in inventory:
        if ("slot" in item and item["slot"] == from_slot) or (
            "index" in item and item["index"] == from_slot
        ):
            from_item_quantity = item["quantity"]
            from_item_type = item["type"]

        item_type = item["type"]
        quantity = item["quantity"]
        if quantity == 0:
            item_type = "air"

        item_slot = item["slot"] if "slot" in item else item["index"]
        type_to_slot.setdefault(item_type, []).append(item_slot)
        slot_to_quantity[item_slot] = slot_to_quantity.get(item_slot, 0) + quantity

    assert from_item_type is not None, f"Item not found in slot {from_slot}"

    # NOTE(review): unlike the empty-slot search below, this does not exclude
    # low-numbered slots (1-9 are the crafting grid in OracleModel) — confirm intent.
    for slot in type_to_slot.get(from_item_type, []):
        if (
            slot != from_slot
            and slot_to_quantity[slot] + from_item_quantity
            <= MAX_STACK_SIZE[from_item_type]
        ):
            return slot

    # `.get` so an inventory with no empty slot reaches the ValueError below
    # instead of escaping as a KeyError.
    # NOTE(review): slots above 10 are presumably the main inventory; confirm
    # whether slot 10 should also be eligible.
    for slot in type_to_slot.get("air", []):
        if slot != from_slot and slot > 10:
            return slot

    raise ValueError("No free slot found")
80
+
81
+
82
def find_item_in_inventory(target: str, inventory: list[dict]) -> int:
    """Return the position of the first non-empty stack of ``target``.

    The position is the entry's ``"slot"`` if present, otherwise its ``"index"``.

    Args:
        target: item type name to look for.
        inventory: inventory entries (dicts with ``"type"`` and ``"quantity"``).

    Returns:
        The slot/index of the first matching entry with positive quantity.

    Raises:
        ValueError: if no matching entry has a position, or the item is absent.
    """
    matched_without_position = False
    for item in inventory:
        if item["type"] == target and item["quantity"] > 0:
            if "slot" in item:
                return item["slot"]
            elif "index" in item:
                return item["index"]
            # keep scanning: a later entry may carry a position
            matched_without_position = True
    if matched_without_position:
        raise ValueError("Neither slot or index is set")
    raise ValueError(f"Item {target!r} not found in inventory")
90
+
91
+
92
def get_inventory_counter(inventory: list[dict]) -> Counter:
    """Total the quantity held of each item type.

    Entries sitting at position 0 (by ``"slot"`` or ``"index"``) and entries of
    type ``"air"`` are left out. Zero-quantity entries still create a key.
    """
    totals = Counter()
    for entry in inventory:
        at_position_zero = entry.get("slot") == 0 or entry.get("index") == 0
        if at_position_zero or entry["type"] == "air":
            continue
        totals[entry["type"]] += entry["quantity"]
    return totals
103
+
104
+
105
def get_crafting_slot_item(inventory: list[dict]) -> dict:
    """Return the first non-empty entry at position 0 (by ``"slot"`` or ``"index"``).

    Returns ``None`` when no such entry has a positive quantity.
    """
    for entry in inventory:
        at_output = entry.get("slot") == 0 or entry.get("index") == 0
        if at_output and entry["quantity"] > 0:
            return entry
    return None
112
+
113
+
114
def update_inventory(
    inventory: list[dict], slot_from: int, slot_to: int, quantity: int
) -> list[dict]:
    """Subtract ``quantity`` from the entry at ``slot_from``.

    The matching entry dicts are mutated in place; the returned list is a new
    list holding the same entry objects. ``slot_to`` is accepted but unused:
    the destination stack is deliberately not incremented.
    """

    def _is_source(entry: dict) -> bool:
        # an entry matches by "slot" first, otherwise by "index"
        if "slot" in entry and entry["slot"] == slot_from:
            return True
        return "index" in entry and entry["index"] == slot_from

    updated = []
    for entry in inventory:
        if _is_source(entry):
            entry["quantity"] -= quantity
        updated.append(entry)
    return updated
130
+
131
+
132
class OracleModel(ABCModel):
    """
    Oracle model returns actions that solve the task optimally

    It asks ``optimal_planner`` for a sequence of ``(recipe, inventory)`` steps
    and expands each step into low-level symbolic move/smelt actions.
    Only the symbolic action space is supported.
    """

    def __init__(self, cfg: EvalConfig):
        assert (
            cfg.plancraft.environment.symbolic_action_space
        ), "Only symbolic actions are supported for oracle"
        self.history = History(objective="")
        # remaining high-level steps: (recipe, inventory counts after that recipe)
        self.plans = []
        # queued low-level actions for the recipe currently being executed
        self.subplans = []

    def reset_history(self, objective: str = ""):
        """Clear history and any pending plan for a new objective."""
        self.history.reset(objective=objective)
        self.plans = []
        self.subplans = []

    def get_plan(self, observation: dict):
        """Compute a plan for the objective's target from the observed inventory.

        The target is the text after the last ": " of the objective, e.g.
        "Craft an item of type: <target>". ``self.plans`` may be set to ``None``
        by the planner (``step`` treats that as "no plan").
        """
        target = self.history.objective.split(": ")[-1]
        inventory_counter = get_inventory_counter(observation["inventory"])
        self.plans = optimal_planner(target=target, inventory=inventory_counter)

    def get_next_action(
        self, observation: dict
    ) -> SymbolicMoveAction | SymbolicSmeltAction:
        """Return the next low-level action.

        Precedence:
          1. a queued sub-action of the recipe being executed;
          2. moving a finished item out of the crafting output slot (slot 0);
          3. expanding the next plan step into sub-actions and returning the first.

        Raises:
            ValueError: no queued sub-actions and no plan steps left.
            NotImplementedError: unknown recipe type.
        """
        if len(self.subplans) > 0:
            return self.subplans.pop(0)
        if len(self.plans) == 0:
            raise ValueError("No more steps in plan")

        # work on a copy: update_inventory decrements entry quantities in place
        observed_inventory = copy.deepcopy(observation["inventory"])

        # take item from crafting slot
        if slot_item := get_crafting_slot_item(observed_inventory):
            free_slot = find_free_inventory_slot(observed_inventory, from_slot=0)
            return SymbolicMoveAction(
                slot_from=0, slot_to=free_slot, quantity=slot_item["quantity"]
            )

        plan_recipe, new_inventory = self.plans.pop(0)
        self.subplans = []
        new_inventory_counter = Counter(new_inventory)
        current_inventory_counter = get_inventory_counter(observed_inventory)
        # ingredients consumed by this step, and the single item it produces
        items_to_use_counter = current_inventory_counter - new_inventory_counter
        new_items = new_inventory_counter - current_inventory_counter
        assert len(new_items) == 1

        if isinstance(plan_recipe, ShapelessRecipe):
            self._queue_shapeless(items_to_use_counter, observed_inventory)
        elif isinstance(plan_recipe, SmeltingRecipe):
            self._queue_smelting(items_to_use_counter, observed_inventory)
        elif isinstance(plan_recipe, ShapedRecipe):
            self._queue_shaped(plan_recipe, items_to_use_counter, observed_inventory)
        else:
            raise NotImplementedError(f"Recipe type {type(plan_recipe)} not supported")

        return self.subplans.pop(0)

    def _queue_shapeless(
        self, items_to_use_counter: Counter, current_inventory: list[dict]
    ) -> None:
        """Queue one-unit moves filling crafting slots 1, 2, ... with the ingredients."""
        crafting_slot = 1
        for item, quantity in items_to_use_counter.items():
            n = 0
            while n < quantity:
                from_slot = find_item_in_inventory(item, current_inventory)
                # the unit already sits in the target crafting slot: no move needed
                if from_slot != crafting_slot:
                    self.subplans.append(
                        SymbolicMoveAction(
                            slot_from=from_slot, slot_to=crafting_slot, quantity=1
                        )
                    )
                    # track the removal so the next lookup finds remaining stock
                    current_inventory = update_inventory(
                        current_inventory, from_slot, crafting_slot, 1
                    )
                crafting_slot += 1
                n += 1

    def _queue_smelting(
        self, items_to_use_counter: Counter, current_inventory: list[dict]
    ) -> None:
        """Queue a single smelt action consuming the step's only ingredient."""
        assert len(items_to_use_counter) == 1, "smelting only supports one item"
        for item, quantity in items_to_use_counter.items():
            from_slot = find_item_in_inventory(item, current_inventory)
            # NOTE(review): the destination is chosen using the stacking rules of
            # the *input* item type found at from_slot — confirm the environment
            # accepts the smelted output there.
            free_slot = find_free_inventory_slot(
                current_inventory, from_slot=from_slot
            )
            self.subplans.append(
                SymbolicSmeltAction(
                    slot_from=from_slot, slot_to=free_slot, quantity=quantity
                )
            )

    def _queue_shaped(
        self,
        plan_recipe: ShapedRecipe,
        items_to_use_counter: Counter,
        current_inventory: list[dict],
    ) -> None:
        """Queue one-unit moves placing ingredients per the recipe kernel.

        Kernel cell (i, j) maps to crafting slot ``i * 3 + j + 1``; the first
        still-needed item type accepted by the cell is moved there.
        """
        for i, row in enumerate(plan_recipe.kernel):
            for j, item_set in enumerate(row):
                inventory_position = (i * 3) + j + 1
                valid_items = item_set_id_to_type(item_set)
                for item in valid_items:
                    if items_to_use_counter[item] > 0:
                        from_slot = find_item_in_inventory(item, current_inventory)
                        self.subplans.append(
                            SymbolicMoveAction(
                                slot_from=from_slot,
                                slot_to=inventory_position,
                                quantity=1,
                            )
                        )
                        items_to_use_counter[item] -= 1
                        current_inventory = update_inventory(
                            current_inventory, from_slot, inventory_position, 1
                        )
                        break

    def step(
        self, observation: dict
    ) -> SymbolicMoveAction | RealActionInteraction | SymbolicSmeltAction | StopAction:
        """Record the observation and return the next single action.

        Returns ``StopAction()`` when the planner finds no plan.
        """
        self.history.add_observation_to_history(observation)

        # NOTE(review): this re-plans whenever no high-level steps remain, even while
        # sub-actions are still queued; get_next_action drains those first.
        if len(self.plans) == 0:
            self.get_plan(observation)
            if self.plans is None:
                self.plans = []
                return StopAction()

        action = self.get_next_action(observation)

        self.history.add_action_to_history(action)
        return action