plancraft 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
models/oam.py DELETED
@@ -1,284 +0,0 @@
1
- import logging
2
- from typing import Optional
3
-
4
- import torch
5
- import torch.nn as nn
6
- import torchvision.transforms.v2 as v2
7
- from transformers import (
8
- AutoConfig,
9
- AutoModelForCausalLM,
10
- AutoTokenizer,
11
- PretrainedConfig,
12
- PreTrainedModel,
13
- )
14
-
15
- from plancraft.models.bbox_model import IntegratedBoundingBoxModel
16
-
17
- logger = logging.getLogger(__name__)
18
-
19
-
20
class PlancraftOAMConfig(PretrainedConfig):
    """Configuration object consumed by ``PlancraftOAM``.

    The only option added on top of ``PretrainedConfig`` is ``from_llama``;
    every other keyword argument is forwarded untouched to the base class.
    """

    model_type = "plancraft-aom"
    is_composition = True

    def __init__(self, from_llama=False, **kwargs):
        # Store the flag before delegating the remaining options to the base config.
        self.from_llama = from_llama
        super().__init__(**kwargs)
31
-
32
-
33
class PlancraftOAM(PreTrainedModel):
    """Causal language model whose ``<|inventory|>`` tokens carry bounding-box features.

    The model wraps three components built in ``__init__``:

    * ``text_model`` - a causal LM created from the Llama 3.1 8B Instruct
      checkpoint (weights are only loaded when ``config.from_llama`` is true,
      otherwise the architecture is instantiated from its config),
    * ``vision_model`` - an ``IntegratedBoundingBoxModel`` kept in eval mode,
    * ``vision_to_text_embedding`` - a linear layer mapping 1024-wide box
      features to the text model's hidden size.

    Each ``<|inventory|>`` placeholder in a user message is expanded to one
    token per detected bounding box, and the embeddings of those tokens are
    overwritten with the projected box features before the text model runs.
    """

    config_class = PlancraftOAMConfig

    def __init__(self, config: PlancraftOAMConfig):
        super().__init__(config)

        self.config = config
        # load text model
        if self.config.from_llama:
            self.text_model = AutoModelForCausalLM.from_pretrained(
                "meta-llama/Meta-Llama-3.1-8B-Instruct",
            )
        else:
            text_model_config = AutoConfig.from_pretrained(
                "meta-llama/Meta-Llama-3.1-8B-Instruct",
            )
            self.text_model = AutoModelForCausalLM.from_config(text_model_config)

        # load vision model
        self.vision_model = IntegratedBoundingBoxModel.from_pretrained(
            "gautierdag/plancraft-maskrcnn"
        )
        self.vision_model.eval()

        # convert vision features to text embedding
        self.vision_to_text_embedding = nn.Linear(
            1024, self.text_model.config.hidden_size
        )
        # NOTE(review): ``trust_remote`` is kept as in the original call; the
        # usual keyword is ``trust_remote_code`` - confirm which was intended.
        self.tokenizer = AutoTokenizer.from_pretrained(
            "meta-llama/Meta-Llama-3.1-8B-Instruct",
            trust_remote=True,
        )
        # add special tokens
        self.tokenizer.add_special_tokens(
            {
                "additional_special_tokens": [
                    "<|inventory|>",
                ]
            }
        )
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.inventory_idx = self.tokenizer.convert_tokens_to_ids("<|inventory|>")

        # resize token embeddings so the new special token has an embedding row
        self.text_model.resize_token_embeddings(len(self.tokenizer))
        # image transforms
        self.transforms = v2.Compose(
            [v2.ToImage(), v2.ToDtype(torch.float32, scale=True)]
        )

    @torch.no_grad()
    def extract_bboxes(self, images: list) -> list[dict]:
        """Run the vision model on ``images`` and return its predictions.

        Returns an empty list when ``images`` is empty. Images are transformed,
        stacked into one batch and moved to the device this model lives on
        (previously a hard-coded ``.cuda()`` call, which failed without a GPU).
        """
        if len(images) == 0:
            return []
        img_tensors = torch.stack([self.transforms(img) for img in images])
        img_tensors = img_tensors.to(self.device)
        # disable gradients
        self.vision_model.freeze()
        # get bounding box predictions
        bbox_preds = self.vision_model(img_tensors)
        return bbox_preds

    def prepare_messages(self, messages: list, bboxes: list[dict]) -> str:
        """Render ``messages`` with the chat template, expanding inventory tokens.

        For every user message whose content ends with ``<|inventory|>``, the
        placeholder is repeated once per row of the matching prediction's
        ``"features"`` entry. A generation prompt is appended only when the
        model is not in training mode, and the ``<|begin_of_text|>`` marker is
        stripped from the rendered text.
        """
        if len(bboxes) == 0:
            # no bounding boxes: render the messages unchanged
            chat_messages = messages
        else:
            # expand <|inventory|> tokens into N tokens (N = number of bounding boxes)
            chat_messages = []
            i_pred = 0
            for m in messages:
                new_message = m.copy()
                if new_message["role"] == "user" and new_message["content"].endswith(
                    "<|inventory|>"
                ):
                    # add inventory tokens for each bounding box
                    new_message["content"] = new_message["content"].replace(
                        "<|inventory|>",
                        "<|inventory|>" * (bboxes[i_pred]["features"].shape[0]),
                    )
                    i_pred += 1
                chat_messages.append(new_message)

            assert i_pred == len(
                bboxes
            ), "Number of inventory tokens does not match number of images"

        text = self.tokenizer.apply_chat_template(
            chat_messages, add_generation_prompt=not self.training, tokenize=False
        )
        text = text.replace("<|begin_of_text|>", "")
        return text

    def inputs_merger(
        self,
        input_ids: torch.LongTensor,
        inputs_embeds: Optional[torch.Tensor],
        image_hidden_states: Optional[torch.Tensor],
    ):
        """Overwrite inventory-token embeddings with box embeddings, per sample.

        ``image_hidden_states`` is iterated along the batch dimension; an empty
        entry means the sample has no images, in which case no inventory token
        may be present. When the number of inventory tokens differs from the
        number of box embeddings, a warning is logged and only the last
        ``n_inventory_tokens`` box embeddings are used.

        ``inputs_embeds`` is modified in place and also returned; the
        ``image_hidden_states`` argument itself is left untouched.
        """
        # along batch dimension
        for i, hidden in enumerate(image_hidden_states):
            inventory_mask = input_ids[i] == self.inventory_idx
            n_inventory_tokens = int(inventory_mask.sum())

            if len(hidden) == 0:
                assert (
                    n_inventory_tokens == 0
                ), "No images but inventory token is still present"
                continue

            if n_inventory_tokens != hidden.shape[0]:
                logger.warning(
                    f"Number of inventory tokens ({n_inventory_tokens}) does not match number of bounding boxes ({hidden.shape[0]}). Possible truncation."
                )
                if n_inventory_tokens == 0:
                    # ``hidden[-0:]`` would select every row rather than none,
                    # and assigning that to an empty mask raises; there is
                    # simply nothing to replace for this sample.
                    continue
                # keep the last n rows, i.e. drop box embeddings from the start
                # NOTE(review): this assumes the tokenizer dropped tokens from
                # the start of the sequence - confirm against its
                # ``truncation_side`` setting.
                hidden = hidden[-n_inventory_tokens:]

            # replace inventory tokens with bbox features
            inputs_embeds[i, inventory_mask] = hidden
        return inputs_embeds

    def process_inputs(
        self,
        batch_messages: Optional[list[list[dict]]] = None,  # list of list of messages (untokenized)
        batch_images: Optional[list[list]] = None,  # list of list of images (unprocessed)
    ) -> tuple[dict[str, torch.FloatTensor], list[torch.FloatTensor], int]:
        """
        Converts raw images and messages into model inputs

        Returns a tuple ``(batch, image_hidden_states, total_boxes)``: the
        tokenizer output for the rendered conversations, one entry of projected
        box embeddings per sample (an empty list for samples without boxes),
        and the total number of box feature rows across the batch.

        ``None`` (the default) is treated as an empty list; the defaults are no
        longer shared mutable lists.
        """
        batch_messages = [] if batch_messages is None else batch_messages
        batch_images = [] if batch_images is None else batch_images

        assert len(batch_images) == len(
            batch_messages
        ), "Number of images and messages should match in the batch dim"
        texts_batch = []
        image_hidden_states = []
        total_boxes = 0
        for images, messages in zip(batch_images, batch_messages):
            # process images
            bboxes = self.extract_bboxes(images)
            if len(bboxes) > 0:
                # get bbox features
                features = torch.concat([p["features"] for p in bboxes], dim=0)
                # upscale to text embedding size
                features_embeds = self.vision_to_text_embedding(features)
                image_hidden_states.append(features_embeds)
                # count bboxes total
                total_boxes += features.shape[0]
            else:
                image_hidden_states.append([])

            # process messages
            text = self.prepare_messages(messages, bboxes)
            texts_batch.append(text)

        # tokenize text
        # @NOTE: truncation can leave fewer inventory tokens than boxes; in that
        # case inputs_merger keeps the last box embeddings and issues a warning
        batch = self.tokenizer(
            texts_batch,
            truncation=True,
            padding=True,
            max_length=16384,
            return_tensors="pt",
        )
        return batch, image_hidden_states, total_boxes

    def forward(
        self,
        batch_messages: Optional[list[list[dict]]] = None,  # list of list of messages (untokenized)
        batch_images: Optional[list[list]] = None,  # list of list of images (unprocessed)
        **kwargs,
    ):
        """Compute the language-modelling output for a batch of conversations.

        Labels are a copy of the input ids with every inventory token set to
        ``-100``. Inputs are moved to this model's device before the text model
        is called with the merged embeddings.
        """
        batch, image_hidden_states, total_boxes = self.process_inputs(
            batch_messages, batch_images
        )
        # move to the device the model lives on
        batch = {k: v.to(self.device) for k, v in batch.items()}
        attention_mask = batch["attention_mask"]
        input_ids = batch["input_ids"]

        labels = input_ids.clone()
        # remove inventory tokens from labels
        labels[labels == self.inventory_idx] = -100
        # sanity check: should have same number of boxes as inventory tokens
        assert (labels == -100).sum() == total_boxes

        # get text embeddings
        inputs_embeds = self.text_model.get_input_embeddings()(input_ids)
        inputs_embeds = self.inputs_merger(
            input_ids, inputs_embeds, image_hidden_states
        )
        # forward pass
        return self.text_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            labels=labels,
            return_dict=True,
        )

    @torch.no_grad()
    def generate(
        self,
        batch_messages: list[list[dict]],
        batch_images: list[list],
        do_sample=True,
        temperature=0.6,
        max_new_tokens=32,
    ):
        """Generate a response for each conversation in the batch.

        Returns ``(text_responses, total_tokens_used)`` where the texts are the
        decoded generated sequences with ``<|eot_id|>`` removed and
        ``total_tokens_used`` is the second dimension of the generated tensor.

        Side effect: the tokenizer's ``padding_side`` is set to ``"left"`` and
        stays that way after the call.
        """
        self.tokenizer.padding_side = "left"

        batch, image_hidden_states, _ = self.process_inputs(
            batch_messages, batch_images
        )
        batch = {k: v.to(self.device) for k, v in batch.items()}
        attention_mask = batch["attention_mask"]
        input_ids = batch["input_ids"]

        inputs_embeds = self.text_model.get_input_embeddings()(input_ids)
        inputs_embeds = self.inputs_merger(
            input_ids, inputs_embeds, image_hidden_states
        )

        generated_sequences = self.text_model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            do_sample=do_sample,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            pad_token_id=self.tokenizer.eos_token_id,
        )

        # Decode the output
        text_responses = self.tokenizer.batch_decode(
            generated_sequences,
            skip_special_tokens=False,
        )

        # remove <|eot_id|> tokens
        text_responses = [
            text_response.replace("<|eot_id|>", "") for text_response in text_responses
        ]

        _, total_tokens_used = generated_sequences.shape
        return text_responses, total_tokens_used
models/oracle.py DELETED
@@ -1,268 +0,0 @@
1
- import logging
2
- import copy
3
- from collections import Counter
4
-
5
- from plancraft.config import EvalConfig
6
- from plancraft.environments.actions import (
7
- RealActionInteraction,
8
- SymbolicMoveAction,
9
- SymbolicSmeltAction,
10
- StopAction,
11
- )
12
- from plancraft.environments.planner import optimal_planner
13
- from plancraft.environments.recipes import (
14
- ShapedRecipe,
15
- ShapelessRecipe,
16
- SmeltingRecipe,
17
- id_to_item,
18
- )
19
- from plancraft.models.base import ABCModel, History
20
- from plancraft.environments.sampler import MAX_STACK_SIZE
21
-
22
- logger = logging.getLogger(__name__)
23
-
24
-
25
def item_set_id_to_type(item_set_ids: set[int]):
    """Return the set obtained by passing every id in ``item_set_ids`` to ``id_to_item``."""
    return {id_to_item(item_id) for item_id in item_set_ids}
27
-
28
-
29
def find_free_inventory_slot(inventory: list[dict], from_slot: int) -> int:
    """Find a slot that can receive the stack currently held in ``from_slot``.

    Each inventory entry is a dict with ``"type"``, ``"quantity"`` and a
    position stored under either ``"slot"`` or ``"index"``. Entries whose
    quantity is 0 are treated as ``"air"`` (empty).

    Preference order:
      1. another slot already holding the same item type, provided the merged
         quantity stays within ``MAX_STACK_SIZE`` for that type;
      2. an empty slot whose position is greater than 10.

    Raises:
        AssertionError: if no entry is located at ``from_slot``.
        ValueError: if no suitable slot exists. Previously an inventory with
            no empty entry at all raised ``KeyError`` from the ``"air"``
            lookup instead of this error.
    """
    from_item_type, from_item_quantity = None, None

    type_to_slot = {}
    slot_to_quantity = {}
    for item in inventory:
        if ("slot" in item and item["slot"] == from_slot) or (
            "index" in item and item["index"] == from_slot
        ):
            from_item_quantity = item["quantity"]
            from_item_type = item["type"]

        quantity = item["quantity"]
        # an entry with nothing in it counts as an empty ("air") slot
        item_type = "air" if quantity == 0 else item["type"]
        item_slot = item["slot"] if "slot" in item else item["index"]

        type_to_slot.setdefault(item_type, []).append(item_slot)
        slot_to_quantity[item_slot] = slot_to_quantity.get(item_slot, 0) + quantity

    assert from_item_type is not None, f"Item not found in slot {from_slot}"

    # prefer stacking onto an existing slot of the same item type
    for slot in type_to_slot.get(from_item_type, []):
        if (
            slot != from_slot
            and slot_to_quantity[slot] + from_item_quantity
            <= MAX_STACK_SIZE[from_item_type]
        ):
            return slot

    # otherwise fall back to an empty slot; ``.get`` keeps a full inventory
    # from failing with KeyError before the ValueError below is reached
    # NOTE(review): ``slot > 10`` presumably skips the crafting area - confirm
    # whether slot 10 itself is meant to be excluded.
    for slot in type_to_slot.get("air", []):
        if slot != from_slot and slot > 10:
            return slot

    raise ValueError("No free slot found")
80
-
81
-
82
def find_item_in_inventory(target: str, inventory: list[dict]) -> int:
    """Return the position of the first non-empty stack of ``target``.

    Entries are scanned in order; the first one whose ``"type"`` equals
    ``target`` and whose ``"quantity"`` is positive wins. Its position is read
    from ``"slot"`` if present, otherwise from ``"index"``.

    Raises:
        ValueError: if the matching entry has neither ``"slot"`` nor
            ``"index"``, or if no stack of ``target`` exists. The not-found
            case is made explicit so the function never falls through and
            returns ``None`` despite its ``int`` return annotation.
    """
    for item in inventory:
        if item["type"] != target or not item["quantity"] > 0:
            continue
        if "slot" in item:
            return item["slot"]
        if "index" in item:
            return item["index"]
        raise ValueError("Neither slot or index is set")
    raise ValueError(f"Item {target} not found in inventory")
90
-
91
-
92
def get_inventory_counter(inventory: list[dict]) -> Counter:
    """Tally item quantities per type across the inventory.

    Entries positioned at slot/index 0 and entries of type ``"air"`` are left
    out of the tally.
    """
    totals = Counter()
    for entry in inventory:
        at_position_zero = ("slot" in entry and entry["slot"] == 0) or (
            "index" in entry and entry["index"] == 0
        )
        if at_position_zero or entry["type"] == "air":
            continue
        totals[entry["type"]] += entry["quantity"]
    return totals
103
-
104
-
105
def get_crafting_slot_item(inventory: list[dict]) -> dict:
    """Return the first entry at slot/index 0 with a positive quantity, else ``None``."""

    def _is_filled_slot_zero(entry: dict) -> bool:
        at_zero = ("slot" in entry and entry["slot"] == 0) or (
            "index" in entry and entry["index"] == 0
        )
        return at_zero and entry["quantity"] > 0

    return next((entry for entry in inventory if _is_filled_slot_zero(entry)), None)
112
-
113
-
114
def update_inventory(
    inventory: list[dict], slot_from: int, slot_to: int, quantity: int
) -> list[dict]:
    """
    Subtract ``quantity`` from every entry located at ``slot_from``.

    The entry dicts are modified in place; the returned list is a new list
    holding those same dicts in their original order.
    NOTE: ``slot_to`` is not used - the receiving slot is never incremented.
    """
    updated = list(inventory)
    for entry in updated:
        is_source = ("slot" in entry and entry["slot"] == slot_from) or (
            "index" in entry and entry["index"] == slot_from
        )
        if is_source:
            entry["quantity"] -= quantity
    return updated
130
-
131
-
132
class OracleModel(ABCModel):
    """
    Oracle model returns actions that solve the task optimally

    A plan (``self.plans``) is obtained from ``optimal_planner``; each plan
    step is expanded into a queue of low-level actions (``self.subplans``)
    that is drained one action per call.
    """

    def __init__(self, cfg: EvalConfig):
        assert (
            cfg.plancraft.environment.symbolic_action_space
        ), "Only symbolic actions are supported for oracle"
        self.history = History(objective="")
        self.plans = []
        self.subplans = []

    def reset_history(self, objective: str = ""):
        """Reset the history with ``objective`` and drop any pending plan."""
        self.history.reset(objective=objective)
        self.plans = []
        self.subplans = []

    def get_plan(self, observation: dict):
        """Ask the planner for a plan from the observed inventory; store it in ``self.plans``."""
        # objective="Craft an item of type: ...."
        # the target item is whatever follows the last ": "
        goal_item = self.history.objective.split(": ")[-1]
        available = get_inventory_counter(observation["inventory"])
        self.plans = optimal_planner(target=goal_item, inventory=available)

    def _queue_shapeless(self, inventory: list[dict], to_use: Counter) -> None:
        """Queue one single-item move per ingredient into consecutive slots starting at 1."""
        grid_slot = 1
        for ingredient, amount in to_use.items():
            for _ in range(amount):
                source = find_item_in_inventory(ingredient, inventory)
                # the ingredient already sits in the target slot: nothing to move
                if source != grid_slot:
                    self.subplans.append(
                        SymbolicMoveAction(
                            slot_from=source, slot_to=grid_slot, quantity=1
                        )
                    )
                    # update state of inventory
                    inventory = update_inventory(inventory, source, grid_slot, 1)
                grid_slot += 1

    def _queue_smelting(self, inventory: list[dict], to_use: Counter) -> None:
        """Queue a single smelt action for the one ingredient in ``to_use``."""
        assert len(to_use) == 1, "smelting only supports one item"
        for ingredient, amount in to_use.items():
            source = find_item_in_inventory(ingredient, inventory)
            destination = find_free_inventory_slot(inventory, from_slot=source)
            self.subplans.append(
                SymbolicSmeltAction(
                    slot_from=source, slot_to=destination, quantity=amount
                )
            )

    def _queue_shaped(self, recipe, inventory: list[dict], to_use: Counter) -> None:
        """Queue one move per kernel cell, using the first valid item still to be used."""
        for row_idx, row in enumerate(recipe.kernel):
            for col_idx, item_set in enumerate(row):
                grid_slot = (row_idx * 3) + col_idx + 1
                for candidate in item_set_id_to_type(item_set):
                    if not to_use[candidate] > 0:
                        continue
                    source = find_item_in_inventory(candidate, inventory)
                    move = SymbolicMoveAction(
                        slot_from=source,
                        slot_to=grid_slot,
                        quantity=1,
                    )
                    to_use[candidate] -= 1
                    # update state of inventory
                    inventory = update_inventory(inventory, source, grid_slot, 1)
                    self.subplans.append(move)
                    break

    def get_next_action(
        self, observation: dict
    ) -> SymbolicMoveAction | SymbolicSmeltAction:
        """Return the next low-level action.

        Order of precedence: a queued sub-action; otherwise, if slot 0 holds
        something, a move of that stack into a free inventory slot; otherwise
        the next plan step is expanded into sub-actions and the first is
        returned.

        Raises:
            ValueError: if nothing is queued and the plan is empty.
            NotImplementedError: if the plan step's recipe type is not handled.
        """
        if len(self.subplans) > 0:
            return self.subplans.pop(0)
        if len(self.plans) == 0:
            raise ValueError("No more steps in plan")

        # work on a copy: the bookkeeping below mutates inventory entries
        working_inventory = copy.deepcopy(observation["inventory"])

        # take item from crafting slot
        if slot_item := get_crafting_slot_item(working_inventory):
            # move item from crafting slot to inventory
            destination = find_free_inventory_slot(working_inventory, from_slot=0)
            return SymbolicMoveAction(
                slot_from=0, slot_to=destination, quantity=slot_item["quantity"]
            )

        plan_recipe, new_inventory = self.plans.pop(0)
        self.subplans = []
        target_counts = Counter(new_inventory)
        current_counts = get_inventory_counter(working_inventory)
        # what the step consumes vs. what it produces
        items_to_use_counter = current_counts - target_counts
        new_items = target_counts - current_counts
        assert len(new_items) == 1

        if isinstance(plan_recipe, ShapelessRecipe):
            self._queue_shapeless(working_inventory, items_to_use_counter)
        elif isinstance(plan_recipe, SmeltingRecipe):
            self._queue_smelting(working_inventory, items_to_use_counter)
        elif isinstance(plan_recipe, ShapedRecipe):
            self._queue_shaped(plan_recipe, working_inventory, items_to_use_counter)
        else:
            raise NotImplementedError(f"Recipe type {type(plan_recipe)} not supported")

        return self.subplans.pop(0)

    def step(
        self, observation: dict
    ) -> list[SymbolicMoveAction | RealActionInteraction | SymbolicSmeltAction]:
        """Record the observation, (re)plan when the plan is empty, and return one action.

        Returns ``StopAction()`` when the planner yields ``None``; that action
        is not added to the history.
        """
        # add observation to history
        self.history.add_observation_to_history(observation)

        if len(self.plans) == 0:
            self.get_plan(observation)
            if self.plans is None:
                self.plans = []
                return StopAction()

        next_action = self.get_next_action(observation)

        # add action to history
        self.history.add_action_to_history(next_action)
        return next_action