plancraft 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
models/oam.py DELETED
@@ -1,284 +0,0 @@
1
- import logging
2
- from typing import Optional
3
-
4
- import torch
5
- import torch.nn as nn
6
- import torchvision.transforms.v2 as v2
7
- from transformers import (
8
- AutoConfig,
9
- AutoModelForCausalLM,
10
- AutoTokenizer,
11
- PretrainedConfig,
12
- PreTrainedModel,
13
- )
14
-
15
- from plancraft.models.bbox_model import IntegratedBoundingBoxModel
16
-
17
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
18
-
19
-
20
class PlancraftOAMConfig(PretrainedConfig):
    """Configuration object for ``PlancraftOAM``.

    Args:
        from_llama: when True, ``PlancraftOAM`` loads pretrained
            ``meta-llama/Meta-Llama-3.1-8B-Instruct`` weights for its text model;
            otherwise it only builds that architecture from its config.
        **kwargs: forwarded unchanged to ``PretrainedConfig``.
    """

    # NOTE(review): "aom" does not match the "OAM" in the class name, but this
    # string is what gets written into saved configs, so it is left untouched.
    model_type = "plancraft-aom"
    is_composition = True

    def __init__(self, from_llama=False, **kwargs):
        # set our own field first, then let the base class consume the rest
        self.from_llama = from_llama
        super().__init__(**kwargs)
31
-
32
-
33
class PlancraftOAM(PreTrainedModel):
    """Llama text model conditioned on bounding-box features from a vision model.

    Every user message whose content ends with the ``<|inventory|>`` special
    token is expanded so that it contains one ``<|inventory|>`` token per
    detected bounding box. The text embeddings at those positions are then
    overwritten with the projected vision features before the text model runs.
    """

    config_class = PlancraftOAMConfig

    def __init__(self, config: PlancraftOAMConfig):
        super().__init__(config)

        self.config = config
        # load text model: pretrained Llama weights, or the same architecture
        # built only from its config
        if self.config.from_llama:
            self.text_model = AutoModelForCausalLM.from_pretrained(
                "meta-llama/Meta-Llama-3.1-8B-Instruct",
            )
        else:
            text_model_config = AutoConfig.from_pretrained(
                "meta-llama/Meta-Llama-3.1-8B-Instruct",
            )
            self.text_model = AutoModelForCausalLM.from_config(text_model_config)

        # load vision model (kept in eval mode)
        self.vision_model = IntegratedBoundingBoxModel.from_pretrained(
            "gautierdag/plancraft-maskrcnn"
        )
        self.vision_model.eval()

        # project 1024-d vision features into the text embedding space
        self.vision_to_text_embedding = nn.Linear(
            1024, self.text_model.config.hidden_size
        )
        # NOTE: a previous version passed ``trust_remote=True`` here; that is
        # not a recognised loading option (the real flag is
        # ``trust_remote_code``) and the stock Llama tokenizer needs no remote
        # code, so it has been dropped.
        self.tokenizer = AutoTokenizer.from_pretrained(
            "meta-llama/Meta-Llama-3.1-8B-Instruct",
        )
        # register the inventory placeholder as a special token
        self.tokenizer.add_special_tokens(
            {"additional_special_tokens": ["<|inventory|>"]}
        )
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.inventory_idx = self.tokenizer.convert_tokens_to_ids("<|inventory|>")

        # grow the embedding matrix so the new special token has a row
        self.text_model.resize_token_embeddings(len(self.tokenizer))
        # image -> float32 tensor (scale=True rescales integer pixel values)
        self.transforms = v2.Compose(
            [v2.ToImage(), v2.ToDtype(torch.float32, scale=True)]
        )

    @torch.no_grad()
    def extract_bboxes(self, images: list) -> list[dict]:
        """Run the vision model on ``images`` and return its predictions.

        Returns ``[]`` for an empty list. Otherwise returns whatever
        ``self.vision_model`` produces for the stacked batch; this class reads
        a ``"features"`` tensor from each entry (presumably one entry per
        image, with one row per detected box — confirm against
        ``IntegratedBoundingBoxModel``).
        """
        if len(images) == 0:
            return []
        img_tensors = torch.stack([self.transforms(img) for img in images])
        img_tensors = img_tensors.cuda()
        # disable gradients on the vision model
        self.vision_model.freeze()
        # get bounding box predictions
        bbox_preds = self.vision_model(img_tensors)
        return bbox_preds

    def _render_chat(self, messages: list) -> str:
        """Apply the chat template (untokenized) and strip the BOS marker.

        A generation prompt is appended only when the module is not training.
        """
        text = self.tokenizer.apply_chat_template(
            messages, add_generation_prompt=not self.training, tokenize=False
        )
        # presumably stripped because the later tokenizer call adds its own
        # BOS token — confirm
        return text.replace("<|begin_of_text|>", "")

    def prepare_messages(self, messages: list, bboxes: list[dict]) -> str:
        """Render ``messages`` to text, expanding ``<|inventory|>`` placeholders.

        Each user message ending with ``<|inventory|>`` consumes the next entry
        of ``bboxes``; its placeholder is repeated once per row of that entry's
        ``"features"`` tensor. With no ``bboxes``, the messages are rendered
        unchanged.

        Raises:
            AssertionError: if the number of such user messages differs from
                ``len(bboxes)``.
        """
        # no bounding boxes
        if len(bboxes) == 0:
            return self._render_chat(messages)

        # expand <|inventory|> tokens into N tokens (N = number of bounding boxes)
        new_messages = []
        i_pred = 0
        for m in messages:
            new_message = m.copy()
            if new_message["role"] == "user" and new_message["content"].endswith(
                "<|inventory|>"
            ):
                # fail with a clear message instead of an IndexError when
                # there are more placeholders than images
                assert i_pred < len(
                    bboxes
                ), "Number of inventory tokens does not match number of images"
                n_boxes = bboxes[i_pred]["features"].shape[0]
                new_message["content"] = new_message["content"].replace(
                    "<|inventory|>", "<|inventory|>" * n_boxes
                )
                i_pred += 1
            new_messages.append(new_message)

        assert i_pred == len(
            bboxes
        ), "Number of inventory tokens does not match number of images"
        return self._render_chat(new_messages)

    def inputs_merger(
        self,
        input_ids: torch.LongTensor,
        inputs_embeds: Optional[torch.Tensor],
        image_hidden_states: Optional[torch.Tensor],
    ):
        """Overwrite the embeddings at ``<|inventory|>`` positions with box features.

        Args:
            input_ids: token ids, indexed as ``input_ids[i]`` per batch row.
            inputs_embeds: text embeddings, indexed as ``inputs_embeds[i, mask]``.
            image_hidden_states: per batch row, either ``[]`` or a tensor with
                one projected feature row per bounding box. Entries may be
                replaced in place when truncation removed inventory tokens.

        Returns:
            ``inputs_embeds`` (modified in place).
        """
        # along batch dimension
        for i in range(len(image_hidden_states)):
            inventory_mask = input_ids[i] == self.inventory_idx
            n_inventory_tokens = int(inventory_mask.sum())

            if len(image_hidden_states[i]) == 0:
                assert (
                    n_inventory_tokens == 0
                ), "No images but inventory token is still present"
                continue

            features = image_hidden_states[i]
            if n_inventory_tokens != features.shape[0]:
                logger.warning(
                    "Number of inventory tokens (%d) does not match number of "
                    "bounding boxes (%d). Possible truncation.",
                    n_inventory_tokens,
                    features.shape[0],
                )
                if n_inventory_tokens == 0:
                    # every inventory token was truncated away: nothing to
                    # replace (slicing with [-0:] would keep ALL features)
                    continue
                # keep the boxes whose tokens survived truncation: left-side
                # truncation drops the earliest tokens, right-side the latest
                if self.tokenizer.truncation_side == "left":
                    features = features[-n_inventory_tokens:]
                else:
                    features = features[:n_inventory_tokens]
                image_hidden_states[i] = features

            # replace inventory tokens with bbox features; boolean-mask
            # assignment requires matching dtype, so cast to the embeddings'
            inputs_embeds[i, inventory_mask] = features.to(
                device=inputs_embeds.device, dtype=inputs_embeds.dtype
            )
        return inputs_embeds

    def process_inputs(
        self,
        batch_messages: Optional[list[list[dict]]] = None,  # list of list of messages (untokenized)
        batch_images: Optional[list[list]] = None,  # list of list of images (unprocessed)
    ) -> tuple[dict[str, torch.FloatTensor], list[torch.FloatTensor], int]:
        """
        Converts raw images and messages into model inputs.

        Returns:
            ``(batch, image_hidden_states, total_boxes)`` where ``batch`` is
            the tokenizer output (``return_tensors="pt"``),
            ``image_hidden_states`` holds per batch row either ``[]`` or the
            projected box features, and ``total_boxes`` counts all feature
            rows across the batch.
        """
        batch_messages = [] if batch_messages is None else batch_messages
        batch_images = [] if batch_images is None else batch_images
        assert len(batch_images) == len(
            batch_messages
        ), "Number of images and messages should match in the batch dim"

        texts_batch = []
        image_hidden_states = []
        total_boxes = 0
        for images, messages in zip(batch_images, batch_messages):
            # process images
            bboxes = self.extract_bboxes(images)
            if len(bboxes) > 0:
                # get bbox features
                features = torch.concat([p["features"] for p in bboxes], dim=0)
                # upscale to text embedding size
                image_hidden_states.append(self.vision_to_text_embedding(features))
                # count bboxes total
                total_boxes += features.shape[0]
            else:
                image_hidden_states.append([])

            # process messages
            texts_batch.append(self.prepare_messages(messages, bboxes))

        # tokenize text
        # @NOTE: truncation can drop inventory tokens; inputs_merger then keeps
        # only the boxes whose tokens survived (per tokenizer.truncation_side)
        # and logs a warning
        batch = self.tokenizer(
            texts_batch,
            truncation=True,
            padding=True,
            max_length=16384,
            return_tensors="pt",
        )
        return batch, image_hidden_states, total_boxes

    def forward(
        self,
        batch_messages: Optional[list[list[dict]]] = None,  # list of list of messages (untokenized)
        batch_images: Optional[list[list]] = None,  # list of list of images (unprocessed)
        **kwargs,
    ):
        """Compute the causal-LM output (including loss) for a raw batch.

        Labels are the input ids with inventory and padding positions set to
        ``-100`` so they are excluded from the loss.
        """
        batch, image_hidden_states, total_boxes = self.process_inputs(
            batch_messages, batch_images
        )
        # move to cuda
        batch = {k: v.cuda() for k, v in batch.items()}
        attention_mask = batch["attention_mask"]
        input_ids = batch["input_ids"]

        inventory_mask = input_ids == self.inventory_idx
        # sanity check: truncation can only remove inventory tokens, never add
        # them (an exact match is not required; inputs_merger handles the rest)
        assert (
            int(inventory_mask.sum()) <= total_boxes
        ), "More inventory tokens than bounding boxes"

        labels = input_ids.clone()
        # remove inventory tokens from labels
        labels[inventory_mask] = -100
        # remove padding from labels (pad_token == eos_token, so padded
        # positions would otherwise be trained as real eos targets)
        labels[attention_mask == 0] = -100

        # get text embeddings
        inputs_embeds = self.text_model.get_input_embeddings()(input_ids)
        inputs_embeds = self.inputs_merger(
            input_ids, inputs_embeds, image_hidden_states
        )
        # forward pass
        return self.text_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            labels=labels,
            return_dict=True,
        )

    @torch.no_grad()
    def generate(
        self,
        batch_messages: list[list[dict]],
        batch_images: list[list],
        do_sample=True,
        temperature=0.6,
        max_new_tokens=32,
    ):
        """Generate text responses for a raw batch of messages and images.

        Returns:
            ``(text_responses, total_tokens_used)``: the decoded sequences with
            ``<|eot_id|>`` removed, and the number of columns in the generated
            id tensor.
        """
        # batched generation needs left padding; restore the previous setting
        # afterwards so later training batches keep their own padding side
        previous_padding_side = self.tokenizer.padding_side
        self.tokenizer.padding_side = "left"
        try:
            batch, image_hidden_states, _ = self.process_inputs(
                batch_messages, batch_images
            )
        finally:
            self.tokenizer.padding_side = previous_padding_side

        batch = {k: v.cuda() for k, v in batch.items()}
        attention_mask = batch["attention_mask"]
        input_ids = batch["input_ids"]

        inputs_embeds = self.text_model.get_input_embeddings()(input_ids)
        inputs_embeds = self.inputs_merger(
            input_ids, inputs_embeds, image_hidden_states
        )

        generated_sequences = self.text_model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            do_sample=do_sample,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            pad_token_id=self.tokenizer.eos_token_id,
        )

        # Decode the output
        text_responses = self.tokenizer.batch_decode(
            generated_sequences,
            skip_special_tokens=False,
        )

        # remove <|eot_id|> tokens
        text_responses = [
            text_response.replace("<|eot_id|>", "") for text_response in text_responses
        ]

        _, total_tokens_used = generated_sequences.shape
        return text_responses, total_tokens_used
models/oracle.py DELETED
@@ -1,268 +0,0 @@
1
- import logging
2
- import copy
3
- from collections import Counter
4
-
5
- from plancraft.config import EvalConfig
6
- from plancraft.environments.actions import (
7
- RealActionInteraction,
8
- SymbolicMoveAction,
9
- SymbolicSmeltAction,
10
- StopAction,
11
- )
12
- from plancraft.environments.planner import optimal_planner
13
- from plancraft.environments.recipes import (
14
- ShapedRecipe,
15
- ShapelessRecipe,
16
- SmeltingRecipe,
17
- id_to_item,
18
- )
19
- from plancraft.models.base import ABCModel, History
20
- from plancraft.environments.sampler import MAX_STACK_SIZE
21
-
22
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
23
-
24
-
25
def item_set_id_to_type(item_set_ids: set[int]):
    """Translate a collection of item ids into the set of their type names via ``id_to_item``."""
    return {id_to_item(item_id) for item_id in item_set_ids}
27
-
28
-
29
def find_free_inventory_slot(inventory: list[dict], from_slot: int) -> int:
    """Choose a destination slot for the item currently in ``from_slot``.

    Each inventory entry locates itself with a ``"slot"`` or ``"index"`` key.
    Entries with quantity 0 are treated as empty (``"air"``).

    Search order:
      1. another slot of the same item type whose combined quantity stays
         within ``MAX_STACK_SIZE`` for that type;
      2. an empty slot with position greater than 10.

    Raises:
        AssertionError: if no entry occupies ``from_slot``.
        ValueError: if no suitable destination exists (including the case
            where the inventory has no empty slots at all).
    """
    from_item_type, from_item_quantity = None, None

    type_to_slot: dict[str, list[int]] = {}
    slot_to_quantity: dict[int, int] = {}
    for item in inventory:
        if ("slot" in item and item["slot"] == from_slot) or (
            "index" in item and item["index"] == from_slot
        ):
            from_item_quantity = item["quantity"]
            from_item_type = item["type"]

        quantity = item["quantity"]
        item_type = "air" if quantity == 0 else item["type"]
        item_slot = item["slot"] if "slot" in item else item["index"]

        type_to_slot.setdefault(item_type, []).append(item_slot)
        slot_to_quantity[item_slot] = slot_to_quantity.get(item_slot, 0) + quantity

    assert from_item_type is not None, f"Item not found in slot {from_slot}"

    # if there is a slot with the same item type that can absorb the stack
    for slot in type_to_slot.get(from_item_type, []):
        if (
            slot != from_slot
            and slot_to_quantity[slot] + from_item_quantity
            <= MAX_STACK_SIZE[from_item_type]
        ):
            return slot

    # if there is a free slot with air (a missing "air" key means no empty
    # slots, which previously surfaced as a KeyError)
    # NOTE(review): `> 10` also excludes slot 10 itself — confirm intended.
    for slot in type_to_slot.get("air", []):
        if slot != from_slot and slot > 10:
            return slot

    raise ValueError("No free slot found")
80
-
81
-
82
def find_item_in_inventory(target: str, inventory: list[dict]) -> int:
    """Return the position of the first non-empty entry of type ``target``.

    The position is read from the entry's ``"slot"`` key, falling back to
    ``"index"``.

    Raises:
        ValueError: if the matching entry has neither key, or if no entry of
            type ``target`` with a positive quantity exists (previously this
            fell through and returned ``None`` despite the ``int`` contract).
    """
    for item in inventory:
        if item["type"] == target and item["quantity"] > 0:
            if "slot" in item:
                return item["slot"]
            if "index" in item:
                return item["index"]
            raise ValueError("Neither slot or index is set")
    raise ValueError(f"Item {target} not found in inventory")
90
-
91
-
92
def get_inventory_counter(inventory: list[dict]) -> Counter:
    """Sum item quantities per type.

    Entries at position 0 (by ``"slot"`` or ``"index"``) and entries of type
    ``"air"`` are left out. Zero-quantity entries of other types still create
    a key with count 0.
    """
    totals = Counter()
    for entry in inventory:
        in_position_zero = entry.get("slot") == 0 or entry.get("index") == 0
        if in_position_zero or entry["type"] == "air":
            continue
        totals[entry["type"]] += entry["quantity"]
    return totals
103
-
104
-
105
def get_crafting_slot_item(inventory: list[dict]) -> dict:
    """Return the first non-empty entry located at position 0, or ``None``.

    Position is matched on either the ``"slot"`` or the ``"index"`` key.
    """
    for entry in inventory:
        at_position_zero = entry.get("slot") == 0 or entry.get("index") == 0
        if at_position_zero and entry["quantity"] > 0:
            return entry
    return None
112
-
113
-
114
def update_inventory(
    inventory: list[dict], slot_from: int, slot_to: int, quantity: int
) -> list[dict]:
    """
    Return a new inventory list in which the entry at ``slot_from`` has its
    quantity decreased by ``quantity``.

    The entry is matched on its ``"slot"`` key, falling back to ``"index"``.
    The caller's dicts are not modified; the updated entry is a shallow copy.
    NOTE: we don't care about incrementing the items in ``slot_to`` (unused).
    """
    new_inventory = []
    for item in inventory:
        if ("slot" in item and item["slot"] == slot_from) or (
            "index" in item and item["index"] == slot_from
        ):
            # copy instead of mutating the caller's dict in place
            item = {**item, "quantity": item["quantity"] - quantity}
        new_inventory.append(item)
    return new_inventory
130
-
131
-
132
class OracleModel(ABCModel):
    """
    Oracle model returns actions that solve the task optimally.

    It asks ``optimal_planner`` for a recipe-level plan, then expands each
    recipe into symbolic move/smelt actions that are emitted one per step.
    """

    def __init__(self, cfg: EvalConfig):
        assert (
            cfg.plancraft.environment.symbolic_action_space
        ), "Only symbolic actions are supported for oracle"
        self.history = History(objective="")
        # recipe-level plan: (recipe, inventory) pairs, the inventory being
        # presumably the expected state after the recipe — see get_next_action
        self.plans = []
        # low-level actions still to emit for the current recipe
        self.subplans = []

    def reset_history(self, objective: str = ""):
        """Reset history with a new objective and discard any pending plan."""
        self.history.reset(objective=objective)
        self.plans = []
        self.subplans = []

    def get_plan(self, observation: dict):
        """Compute a fresh recipe-level plan from the observed inventory.

        The target is taken from the objective text after the last ": ",
        e.g. "Craft an item of type: <target>". ``self.plans`` is set to the
        planner's result, which may be ``None`` (handled in ``step``).
        """
        target = self.history.objective.split(": ")[-1]
        inventory_counter = get_inventory_counter(observation["inventory"])
        self.plans = optimal_planner(target=target, inventory=inventory_counter)

    def get_next_action(
        self, observation: dict
    ) -> SymbolicMoveAction | SymbolicSmeltAction:
        """Return the next low-level action.

        Pending subplan actions are returned first. Otherwise, an occupied
        crafting output slot (position 0) is emptied into a free slot. If
        neither applies, the next recipe is popped from ``self.plans`` and
        expanded into subplan actions.

        Raises:
            ValueError: if there is nothing left to do.
            NotImplementedError: for an unsupported recipe type.
        """
        if len(self.subplans) > 0:
            return self.subplans.pop(0)
        if len(self.plans) == 0:
            raise ValueError("No more steps in plan")

        # work on a copy so the simulated bookkeeping below never touches
        # the caller's observation
        observed_inventory = copy.deepcopy(observation["inventory"])

        # take item from crafting slot
        if slot_item := get_crafting_slot_item(observed_inventory):
            # move item from crafting slot to inventory
            free_slot = find_free_inventory_slot(observed_inventory, from_slot=0)
            return SymbolicMoveAction(
                slot_from=0, slot_to=free_slot, quantity=slot_item["quantity"]
            )

        plan_recipe, new_inventory = self.plans.pop(0)
        self.subplans = []
        new_inventory_counter = Counter(new_inventory)
        current_inventory = observed_inventory
        current_inventory_counter = get_inventory_counter(current_inventory)
        # ingredients consumed by the recipe, and the item(s) it produces
        items_to_use_counter = current_inventory_counter - new_inventory_counter
        new_items = new_inventory_counter - current_inventory_counter
        assert len(new_items) == 1

        if isinstance(plan_recipe, ShapelessRecipe):
            crafting_slot = 1

            # add each item to crafting slots
            for item, quantity in items_to_use_counter.items():
                n = 0
                while n < quantity:
                    from_slot = find_item_in_inventory(item, current_inventory)

                    # skip if from_slot is the crafting slot
                    if from_slot == crafting_slot:
                        # the unit already sits where it is needed: reserve it
                        # in the simulated inventory so the next lookup cannot
                        # pick this same unit and move it away again
                        current_inventory = update_inventory(
                            current_inventory, from_slot, crafting_slot, 1
                        )
                        crafting_slot += 1
                        n += 1
                        continue

                    action = SymbolicMoveAction(
                        slot_from=from_slot, slot_to=crafting_slot, quantity=1
                    )
                    # update state of inventory
                    current_inventory = update_inventory(
                        current_inventory, from_slot, crafting_slot, 1
                    )
                    self.subplans.append(action)

                    crafting_slot += 1
                    n += 1

        # if plan_recipe is a smelting recipe
        elif isinstance(plan_recipe, SmeltingRecipe):
            assert len(items_to_use_counter) == 1, "smelting only supports one item"
            for item, quantity in items_to_use_counter.items():
                from_slot = find_item_in_inventory(item, current_inventory)
                free_slot = find_free_inventory_slot(
                    current_inventory, from_slot=from_slot
                )
                action = SymbolicSmeltAction(
                    slot_from=from_slot, slot_to=free_slot, quantity=quantity
                )
                self.subplans.append(action)

        # if plan_recipe is a shaped recipe
        elif isinstance(plan_recipe, ShapedRecipe):
            for i, row in enumerate(plan_recipe.kernel):
                for j, item_set in enumerate(row):
                    # crafting grid positions are numbered 1..9 row-major
                    inventory_position = (i * 3) + j + 1
                    valid_items = item_set_id_to_type(item_set)
                    for item in valid_items:
                        if items_to_use_counter[item] > 0:
                            from_slot = find_item_in_inventory(item, current_inventory)
                            action = SymbolicMoveAction(
                                slot_from=from_slot,
                                slot_to=inventory_position,
                                quantity=1,
                            )
                            items_to_use_counter[item] -= 1
                            # update state of inventory
                            current_inventory = update_inventory(
                                current_inventory, from_slot, inventory_position, 1
                            )
                            self.subplans.append(action)
                            break
        else:
            raise NotImplementedError(f"Recipe type {type(plan_recipe)} not supported")

        return self.subplans.pop(0)

    def step(
        self, observation: dict
    ) -> SymbolicMoveAction | SymbolicSmeltAction | StopAction:
        """Record the observation and return a single next action.

        Re-plans whenever the recipe-level plan is empty; returns
        ``StopAction()`` (without recording it) if the planner finds no plan.
        """
        # add observation to history
        self.history.add_observation_to_history(observation)

        # get action
        if len(self.plans) == 0:
            self.get_plan(observation)
            if self.plans is None:
                self.plans = []
                return StopAction()

        action = self.get_next_action(observation)

        # add action to history
        self.history.add_action_to_history(action)
        return action