plancraft 0.1.1__py3-none-any.whl → 0.1.2__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
models/oam.py DELETED
@@ -1,283 +0,0 @@
1
- from typing import Optional
2
-
3
- from loguru import logger
4
-
5
- import torch
6
- import torch.nn as nn
7
- import torchvision.transforms.v2 as v2
8
- from transformers import (
9
- AutoConfig,
10
- AutoModelForCausalLM,
11
- AutoTokenizer,
12
- PretrainedConfig,
13
- PreTrainedModel,
14
- )
15
-
16
- from plancraft.models.bbox_model import IntegratedBoundingBoxModel
17
-
18
-
19
class PlancraftOAMConfig(PretrainedConfig):
    """Configuration object for the Plancraft OAM model.

    Args:
        from_llama: Stored on the config as ``from_llama``.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "plancraft-aom"
    is_composition = True

    def __init__(self, from_llama: bool = False, **kwargs):
        # The flag is assigned before delegating to the base initialiser.
        self.from_llama = from_llama
        super().__init__(**kwargs)
30
-
31
-
32
class PlancraftOAM(PreTrainedModel):
    """Causal language model that reads bounding-box features as input embeddings.

    A Llama text model is combined with an ``IntegratedBoundingBoxModel``. Every
    ``<|inventory|>`` placeholder in a user message is expanded to one token per
    detected bounding box, and the embeddings of those tokens are overwritten
    with the (linearly projected) bounding-box features.
    """

    config_class = PlancraftOAMConfig

    def __init__(self, config: PlancraftOAMConfig):
        super().__init__(config)

        self.config = config
        # Text model: pretrained weights when ``from_llama`` is set, otherwise
        # the architecture is built from the config only.
        if self.config.from_llama:
            self.text_model = AutoModelForCausalLM.from_pretrained(
                "meta-llama/Meta-Llama-3.1-8B-Instruct",
            )
        else:
            text_model_config = AutoConfig.from_pretrained(
                "meta-llama/Meta-Llama-3.1-8B-Instruct",
            )
            self.text_model = AutoModelForCausalLM.from_config(text_model_config)

        # Vision model, kept in eval mode.
        self.vision_model = IntegratedBoundingBoxModel.from_pretrained(
            "gautierdag/plancraft-maskrcnn"
        )
        self.vision_model.eval()

        # Projects vision features to the text embedding width.
        # Assumes the vision features are 1024-dimensional — TODO confirm.
        self.vision_to_text_embedding = nn.Linear(
            1024, self.text_model.config.hidden_size
        )
        # NOTE(review): the documented flag is ``trust_remote_code``; the
        # ``trust_remote`` keyword is kept as-is so that loading behaviour is
        # unchanged — confirm the intent.
        self.tokenizer = AutoTokenizer.from_pretrained(
            "meta-llama/Meta-Llama-3.1-8B-Instruct",
            trust_remote=True,
        )
        # Register the placeholder token that stands in for bounding boxes.
        self.tokenizer.add_special_tokens(
            {
                "additional_special_tokens": [
                    "<|inventory|>",
                ]
            }
        )
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.inventory_idx = self.tokenizer.convert_tokens_to_ids("<|inventory|>")

        # The vocabulary grew by one token, so the embedding matrix must too.
        self.text_model.resize_token_embeddings(len(self.tokenizer))
        # Image preprocessing applied before the vision model.
        self.transforms = v2.Compose(
            [v2.ToImage(), v2.ToDtype(torch.float32, scale=True)]
        )

    @staticmethod
    def _module_device(module: nn.Module) -> torch.device:
        """Return the device of ``module``'s first parameter (CPU if it has none)."""
        first_param = next(module.parameters(), None)
        return torch.device("cpu") if first_param is None else first_param.device

    def _render_chat(self, messages: list) -> str:
        """Apply the chat template and drop the ``<|begin_of_text|>`` marker."""
        text = self.tokenizer.apply_chat_template(
            messages, add_generation_prompt=not self.training, tokenize=False
        )
        # Presumably the later tokenizer call adds this marker again — TODO confirm.
        return text.replace("<|begin_of_text|>", "")

    def _prepare_batch(self, batch_messages, batch_images):
        """Run ``process_inputs`` and move the token tensors next to the embeddings."""
        batch, image_hidden_states, total_boxes = self.process_inputs(
            batch_messages, batch_images
        )
        # Follow the model's own device instead of assuming CUDA.
        device = self.text_model.get_input_embeddings().weight.device
        batch = {k: v.to(device) for k, v in batch.items()}
        return batch, image_hidden_states, total_boxes

    def _embed_with_boxes(self, input_ids, image_hidden_states):
        """Embed ``input_ids`` and splice bounding-box features into the result."""
        inputs_embeds = self.text_model.get_input_embeddings()(input_ids)
        return self.inputs_merger(input_ids, inputs_embeds, image_hidden_states)

    @torch.no_grad()
    def extract_bboxes(self, images: list) -> list[dict]:
        """Run the vision model on ``images`` and return its predictions.

        Returns an empty list when ``images`` is empty.
        """
        if len(images) == 0:
            return []
        img_tensors = torch.stack([self.transforms(img) for img in images])
        # Follow the vision model's device instead of assuming CUDA.
        img_tensors = img_tensors.to(self._module_device(self.vision_model))
        # disable gradients
        self.vision_model.freeze()
        # get bounding box predictions
        return self.vision_model(img_tensors)

    def prepare_messages(self, messages: list, bboxes: list[dict]) -> str:
        """Render ``messages`` to text, expanding ``<|inventory|>`` placeholders.

        For every user message whose content ends with ``<|inventory|>``, each
        placeholder occurrence is repeated once per bounding box of the
        corresponding entry in ``bboxes``. With no bounding boxes the messages
        are rendered untouched.

        Raises:
            AssertionError: if the number of expanded messages differs from
                ``len(bboxes)``.
        """
        # no bounding boxes
        if len(bboxes) == 0:
            return self._render_chat(messages)

        # expand <|inventory|> tokens into N tokens (N = number of bounding boxes)
        new_messages = []
        i_pred = 0
        for m in messages:
            new_message = m.copy()
            if new_message["role"] == "user" and new_message["content"].endswith(
                "<|inventory|>"
            ):
                n_boxes = bboxes[i_pred]["features"].shape[0]
                new_message["content"] = new_message["content"].replace(
                    "<|inventory|>", "<|inventory|>" * n_boxes
                )
                i_pred += 1
            new_messages.append(new_message)

        assert i_pred == len(
            bboxes
        ), "Number of inventory tokens does not match number of images"

        return self._render_chat(new_messages)

    def inputs_merger(
        self,
        input_ids: torch.LongTensor,
        inputs_embeds: Optional[torch.Tensor],
        image_hidden_states: Optional[torch.Tensor],
    ):
        """Overwrite the embeddings of inventory tokens with bounding-box features.

        ``image_hidden_states`` is indexed along the batch dimension; an entry
        of length zero means "no images for this sample". When the number of
        inventory tokens differs from the number of boxes, a warning is logged
        and only the last ``n_inventory_tokens`` boxes are kept.

        Returns:
            ``inputs_embeds``, modified in place.
        """
        # along batch dimension
        for i in range(len(image_hidden_states)):
            inventory_mask = input_ids[i] == self.inventory_idx
            n_inventory_tokens = int(inventory_mask.sum())

            if len(image_hidden_states[i]) == 0:
                assert (
                    n_inventory_tokens == 0
                ), "No images but inventory token is still present"
                continue

            n_boxes = image_hidden_states[i].shape[0]
            if n_inventory_tokens != n_boxes:
                logger.warning(
                    f"Number of inventory tokens ({n_inventory_tokens}) does not match number of bounding boxes ({n_boxes}). Possible truncation."
                )
                # Keep the last ``n_inventory_tokens`` boxes. An explicit start
                # index is used because ``[-0:]`` would keep every box.
                # NOTE(review): if the tokenizer truncates on the right, the
                # surviving tokens belong to the first boxes, not the last —
                # confirm ``truncation_side``.
                start = max(n_boxes - n_inventory_tokens, 0)
                image_hidden_states[i] = image_hidden_states[i][start:]

            if n_inventory_tokens == 0:
                # Every placeholder was truncated away: nothing to replace.
                continue

            # replace inventory tokens with bbox features
            inputs_embeds[i, inventory_mask] = image_hidden_states[i].to(
                device=inputs_embeds.device, dtype=inputs_embeds.dtype
            )
        return inputs_embeds

    def process_inputs(
        self,
        batch_messages: Optional[list[list[dict]]] = None,  # list of list of messages (untokenized)
        batch_images: Optional[list[list]] = None,  # list of list of images (unprocessed)
    ) -> tuple[dict[str, torch.FloatTensor], list[torch.FloatTensor], int]:
        """
        Converts raw images and messages into model inputs

        Returns:
            A tuple ``(batch, image_hidden_states, total_boxes)``: the tokenizer
            output, one projected feature tensor (or ``[]``) per sample, and
            the total number of bounding boxes over the batch.
        """
        # ``None`` replaces the former shared mutable ``[]`` defaults.
        batch_messages = [] if batch_messages is None else batch_messages
        batch_images = [] if batch_images is None else batch_images
        assert len(batch_images) == len(
            batch_messages
        ), "Number of images and messages should match in the batch dim"

        texts_batch = []
        image_hidden_states = []
        total_boxes = 0
        for images, messages in zip(batch_images, batch_messages):
            # process images
            bboxes = self.extract_bboxes(images)
            if len(bboxes) > 0:
                # get bbox features
                features = torch.concat([p["features"] for p in bboxes], dim=0)
                # upscale to text embedding size
                image_hidden_states.append(self.vision_to_text_embedding(features))
                # count bboxes total
                total_boxes += features.shape[0]
            else:
                image_hidden_states.append([])

            # process messages
            texts_batch.append(self.prepare_messages(messages, bboxes))

        # tokenize text
        # @NOTE: truncation could leave fewer inventory tokens than boxes; in
        # that case inputs_merger keeps the last boxes and issues a warning
        batch = self.tokenizer(
            texts_batch,
            truncation=True,
            padding=True,
            max_length=16384,
            return_tensors="pt",
        )
        return batch, image_hidden_states, total_boxes

    def forward(
        self,
        batch_messages: Optional[list[list[dict]]] = None,  # list of list of messages (untokenized)
        batch_images: Optional[list[list]] = None,  # list of list of images (unprocessed)
        **kwargs,
    ):
        """Compute the language-model output (with loss) for a raw batch.

        Labels are a copy of the input ids with inventory tokens set to -100.
        Extra keyword arguments are accepted and ignored.
        """
        batch, image_hidden_states, total_boxes = self._prepare_batch(
            batch_messages, batch_images
        )
        attention_mask = batch["attention_mask"]
        input_ids = batch["input_ids"]

        labels = input_ids.clone()
        # remove inventory tokens from labels
        # NOTE(review): padding positions are not masked here — confirm that
        # including them in the loss is intended.
        labels[labels == self.inventory_idx] = -100
        # sanity check: should have same number of boxes as inventory tokens
        assert (
            labels == -100
        ).sum() == total_boxes, "Number of inventory tokens does not match number of boxes"

        inputs_embeds = self._embed_with_boxes(input_ids, image_hidden_states)
        # forward pass
        return self.text_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            labels=labels,
            return_dict=True,
        )

    @torch.no_grad()
    def generate(
        self,
        batch_messages: list[list[dict]],
        batch_images: list[list],
        do_sample=True,
        temperature=0.6,
        max_new_tokens=32,
    ):
        """Generate text for a raw batch of messages and images.

        Returns:
            ``(text_responses, total_tokens_used)`` where ``text_responses`` is
            the decoded output with ``<|eot_id|>`` removed and
            ``total_tokens_used`` is the second dimension of the generated
            tensor.
        """
        # Left padding is needed only while tokenizing for generation; the
        # previous setting is restored so later forward() calls are unaffected.
        previous_padding_side = self.tokenizer.padding_side
        self.tokenizer.padding_side = "left"
        try:
            batch, image_hidden_states, _ = self._prepare_batch(
                batch_messages, batch_images
            )
        finally:
            self.tokenizer.padding_side = previous_padding_side

        attention_mask = batch["attention_mask"]
        input_ids = batch["input_ids"]
        inputs_embeds = self._embed_with_boxes(input_ids, image_hidden_states)

        generated_sequences = self.text_model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            do_sample=do_sample,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            pad_token_id=self.tokenizer.eos_token_id,
        )

        # Decode the output
        text_responses = self.tokenizer.batch_decode(
            generated_sequences,
            skip_special_tokens=False,
        )

        # remove <|eot_id|> tokens
        text_responses = [
            text_response.replace("<|eot_id|>", "") for text_response in text_responses
        ]

        _, total_tokens_used = generated_sequences.shape
        return text_responses, total_tokens_used
models/oracle.py DELETED
@@ -1,265 +0,0 @@
1
- import copy
2
- from collections import Counter
3
-
4
- from plancraft.config import EvalConfig
5
- from plancraft.environments.actions import (
6
- RealActionInteraction,
7
- StopAction,
8
- SymbolicMoveAction,
9
- SymbolicSmeltAction,
10
- )
11
- from plancraft.environments.planner import optimal_planner
12
- from plancraft.environments.recipes import (
13
- ShapedRecipe,
14
- ShapelessRecipe,
15
- SmeltingRecipe,
16
- id_to_item,
17
- )
18
- from plancraft.environments.sampler import MAX_STACK_SIZE
19
- from plancraft.models.base import ABCModel, History
20
-
21
-
22
def item_set_id_to_type(item_set_ids: set[int]):
    """Map every id in ``item_set_ids`` through ``id_to_item`` and return the results as a set."""
    return {id_to_item(item_id) for item_id in item_set_ids}
24
-
25
-
26
def find_free_inventory_slot(inventory: list[dict], from_slot: int) -> int:
    """Find a destination slot for the item currently held in ``from_slot``.

    Each inventory entry is a dict with ``"type"``, ``"quantity"`` and either a
    ``"slot"`` or an ``"index"`` key. Entries with quantity 0 are treated as
    ``"air"`` (empty).

    Preference order:
      1. another slot holding the same item type whose quantity plus the moved
         quantity stays within ``MAX_STACK_SIZE`` for that type;
      2. an empty slot with a number greater than 10.

    Args:
        inventory: the inventory entries to search.
        from_slot: slot (or index) of the item that needs a new home.

    Returns:
        The chosen destination slot.

    Raises:
        AssertionError: if no entry occupies ``from_slot``.
        ValueError: if no suitable slot exists (including when the inventory
            has no empty slot at all).
    """
    from_item_type, from_item_quantity = None, None

    type_to_slot: dict = {}
    slot_to_quantity: dict = {}
    for item in inventory:
        if ("slot" in item and item["slot"] == from_slot) or (
            "index" in item and item["index"] == from_slot
        ):
            from_item_quantity = item["quantity"]
            from_item_type = item["type"]

        item_type = item["type"]
        quantity = item["quantity"]
        if quantity == 0:
            # An emptied slot counts as free regardless of its recorded type.
            item_type = "air"

        item_slot = item["slot"] if "slot" in item else item["index"]

        type_to_slot.setdefault(item_type, []).append(item_slot)
        slot_to_quantity[item_slot] = slot_to_quantity.get(item_slot, 0) + quantity

    assert from_item_type is not None, f"Item not found in slot {from_slot}"

    # Prefer stacking onto an existing slot of the same type.
    for slot in type_to_slot.get(from_item_type, []):
        if (
            slot != from_slot
            and slot_to_quantity[slot] + from_item_quantity
            <= MAX_STACK_SIZE[from_item_type]
        ):
            return slot

    # Otherwise fall back to an empty slot. ``.get`` avoids a KeyError when the
    # inventory has no empty slot, so the ValueError below is raised instead.
    # Presumably slots up to 10 are reserved for crafting — TODO confirm.
    for slot in type_to_slot.get("air", []):
        if slot != from_slot and slot > 10:
            return slot

    raise ValueError("No free slot found")
77
-
78
-
79
def find_item_in_inventory(target: str, inventory: list[dict]) -> int:
    """Return the slot of the first entry holding ``target`` with quantity > 0.

    Args:
        target: item type to look for.
        inventory: entries with ``"type"``, ``"quantity"`` and either a
            ``"slot"`` or an ``"index"`` key.

    Returns:
        The entry's ``"slot"`` value, or its ``"index"`` value when it has no
        ``"slot"`` key.

    Raises:
        ValueError: if the matching entry has neither ``"slot"`` nor
            ``"index"``, or if no entry holds ``target`` (previously this case
            fell through and returned ``None`` despite the ``int`` annotation).
    """
    for item in inventory:
        if item["type"] == target and item["quantity"] > 0:
            if "slot" in item:
                return item["slot"]
            if "index" in item:
                return item["index"]
            raise ValueError("Neither slot or index is set")
    raise ValueError(f"Item {target!r} not found in inventory")
87
-
88
-
89
def get_inventory_counter(inventory: list[dict]) -> Counter:
    """Sum item quantities per type.

    Entries located at slot/index 0 and entries of type ``"air"`` are ignored.
    """
    totals = Counter()
    for entry in inventory:
        at_position_zero = any(
            key in entry and entry[key] == 0 for key in ("slot", "index")
        )
        if at_position_zero or entry["type"] == "air":
            continue
        totals[entry["type"]] += entry["quantity"]
    return totals
100
-
101
-
102
def get_crafting_slot_item(inventory: list[dict]) -> dict:
    """Return the first entry at slot/index 0 with a positive quantity.

    Returns ``None`` when no such entry exists.
    """
    occupied = (
        entry
        for entry in inventory
        if any(key in entry and entry[key] == 0 for key in ("slot", "index"))
        and entry["quantity"] > 0
    )
    return next(occupied, None)
109
-
110
-
111
def update_inventory(
    inventory: list[dict], slot_from: int, slot_to: int, quantity: int
) -> list[dict]:
    """
    Subtract ``quantity`` from every entry located at ``slot_from``.

    The entries are modified in place; the returned value is a new list
    holding those same entry objects.
    NOTE: ``slot_to`` is not used, the destination entry is never incremented.
    """

    def _take_from_source(entry: dict) -> dict:
        if any(key in entry and entry[key] == slot_from for key in ("slot", "index")):
            entry["quantity"] -= quantity
        return entry

    return [_take_from_source(entry) for entry in inventory]
127
-
128
-
129
class OracleModel(ABCModel):
    """
    Oracle model: follows the plan produced by ``optimal_planner`` and emits
    the symbolic actions that carry it out, one action per ``step`` call.
    """

    def __init__(self, cfg: EvalConfig):
        assert (
            cfg.plancraft.environment.symbolic_action_space
        ), "Only symbolic actions are supported for oracle"
        self.history = History(objective="")
        # High-level plan steps still to execute.
        self.plans = []
        # Low-level actions queued for the plan step in progress.
        self.subplans = []

    def reset_history(self, objective: str = ""):
        """Reset the history with ``objective`` and drop all pending plans."""
        self.history.reset(objective=objective)
        self.plans = []
        self.subplans = []

    def get_plan(self, observation: dict):
        """Ask the planner for a plan towards the objective's target item."""
        # objective="Craft an item of type: ...."
        # the target is whatever follows the last ": "
        target = self.history.objective.rsplit(": ", 1)[-1]
        available = get_inventory_counter(observation["inventory"])
        self.plans = optimal_planner(target=target, inventory=available)

    def _queue_shapeless(self, to_consume: Counter, inventory: list[dict]) -> None:
        """Queue one single-item move per ingredient into consecutive slots from 1."""
        grid_slot = 1
        for item, quantity in to_consume.items():
            for _ in range(quantity):
                source = find_item_in_inventory(item, inventory)
                # No move is queued when the item already sits in the target slot.
                if source != grid_slot:
                    move = SymbolicMoveAction(
                        slot_from=source, slot_to=grid_slot, quantity=1
                    )
                    # keep the local inventory view in sync with the queued move
                    inventory = update_inventory(inventory, source, grid_slot, 1)
                    self.subplans.append(move)
                grid_slot += 1

    def _queue_smelting(self, to_consume: Counter, inventory: list[dict]) -> None:
        """Queue a single smelt action for the one ingredient type."""
        assert len(to_consume) == 1, "smelting only supports one item"
        for item, quantity in to_consume.items():
            source = find_item_in_inventory(item, inventory)
            destination = find_free_inventory_slot(inventory, from_slot=source)
            self.subplans.append(
                SymbolicSmeltAction(
                    slot_from=source, slot_to=destination, quantity=quantity
                )
            )

    def _queue_shaped(
        self, recipe: ShapedRecipe, to_consume: Counter, inventory: list[dict]
    ) -> None:
        """Queue one move per kernel cell, placing an item that is still to be used."""
        for row_idx, row in enumerate(recipe.kernel):
            for col_idx, item_set in enumerate(row):
                # Kernel cell (row, col) maps to slot row * 3 + col + 1.
                grid_slot = (row_idx * 3) + col_idx + 1
                for item in item_set_id_to_type(item_set):
                    if to_consume[item] <= 0:
                        continue
                    source = find_item_in_inventory(item, inventory)
                    move = SymbolicMoveAction(
                        slot_from=source,
                        slot_to=grid_slot,
                        quantity=1,
                    )
                    to_consume[item] -= 1
                    # keep the local inventory view in sync with the queued move
                    inventory = update_inventory(inventory, source, grid_slot, 1)
                    self.subplans.append(move)
                    # only the first usable item is placed in this cell
                    break

    def get_next_action(
        self, observation: dict
    ) -> SymbolicMoveAction | SymbolicSmeltAction:
        """Return the next action to execute.

        Queued sub-actions are served first. Otherwise, an item waiting in
        slot 0 is moved to a free slot; failing that, the next plan step is
        expanded into sub-actions and the first of them is returned.

        Raises:
            ValueError: if both the sub-action queue and the plan are empty.
            NotImplementedError: if the plan step's recipe type is unknown.
        """
        if len(self.subplans) > 0:
            return self.subplans.pop(0)
        if len(self.plans) == 0:
            raise ValueError("No more steps in plan")

        # Work on a copy so that bookkeeping never touches the observation.
        inventory = copy.deepcopy(observation["inventory"])

        # take item from crafting slot
        pending_item = get_crafting_slot_item(inventory)
        if pending_item:
            destination = find_free_inventory_slot(inventory, from_slot=0)
            return SymbolicMoveAction(
                slot_from=0, slot_to=destination, quantity=pending_item["quantity"]
            )

        recipe, planned_inventory = self.plans.pop(0)
        self.subplans = []
        planned_counts = Counter(planned_inventory)
        current_counts = get_inventory_counter(inventory)
        # Counter subtraction keeps positive differences only.
        to_consume = current_counts - planned_counts
        produced = planned_counts - current_counts
        assert len(produced) == 1

        if isinstance(recipe, ShapelessRecipe):
            self._queue_shapeless(to_consume, inventory)
        elif isinstance(recipe, SmeltingRecipe):
            self._queue_smelting(to_consume, inventory)
        elif isinstance(recipe, ShapedRecipe):
            self._queue_shaped(recipe, to_consume, inventory)
        else:
            raise NotImplementedError(f"Recipe type {type(recipe)} not supported")

        return self.subplans.pop(0)

    def step(
        self, observation: dict
    ) -> list[SymbolicMoveAction | RealActionInteraction | SymbolicSmeltAction]:
        """Record ``observation``, pick the next action, record and return it.

        A plan is requested whenever none is pending; if the planner returns
        ``None`` a ``StopAction`` is returned (and not added to the history).
        Note that a single action object is returned.
        """
        # add observation to history
        self.history.add_observation_to_history(observation)

        if len(self.plans) == 0:
            self.get_plan(observation)
            if self.plans is None:
                self.plans = []
                return StopAction()

        next_action = self.get_next_action(observation)

        # add action to history
        self.history.add_action_to_history(next_action)
        return next_action