plancraft 0.1.1__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
models/oam.py DELETED
@@ -1,283 +0,0 @@
1
- from typing import Optional
2
-
3
- from loguru import logger
4
-
5
- import torch
6
- import torch.nn as nn
7
- import torchvision.transforms.v2 as v2
8
- from transformers import (
9
- AutoConfig,
10
- AutoModelForCausalLM,
11
- AutoTokenizer,
12
- PretrainedConfig,
13
- PreTrainedModel,
14
- )
15
-
16
- from plancraft.models.bbox_model import IntegratedBoundingBoxModel
17
-
18
-
19
class PlancraftOAMConfig(PretrainedConfig):
    """Configuration object for ``PlancraftOAM``.

    Adds a single option on top of the generic ``PretrainedConfig`` fields:

    * ``from_llama`` -- read by ``PlancraftOAM.__init__`` to choose between
      building the text model with ``from_pretrained`` (True) or from its
      config only via ``from_config`` (False, the default).
    """

    # NOTE(review): the identifier spells "aom" while the classes are named
    # "OAM". It is a runtime string, so it is left untouched -- confirm whether
    # the spelling is intentional before changing it.
    model_type = "plancraft-aom"
    is_composition = True

    def __init__(self, from_llama=False, **kwargs):
        # Record the custom flag, then hand every remaining option to the base.
        self.from_llama = from_llama
        super().__init__(**kwargs)
30
-
31
-
32
class PlancraftOAM(PreTrainedModel):
    """Causal language model whose prompt can contain detected-object embeddings.

    The model combines three parts that are all created in ``__init__``:

    * ``text_model`` -- a causal LM built from the
      ``meta-llama/Meta-Llama-3.1-8B-Instruct`` checkpoint or config,
    * ``vision_model`` -- an ``IntegratedBoundingBoxModel`` that returns one
      dict per image with a ``"features"`` entry (one row per bounding box),
    * ``vision_to_text_embedding`` -- a linear layer projecting those
      1024-wide feature rows to the text model's hidden size.

    Every ``<|inventory|>`` placeholder that ends a user message is expanded
    to one placeholder per detected box; after tokenisation the placeholder
    embeddings are overwritten with the projected box features.
    """

    config_class = PlancraftOAMConfig

    def __init__(self, config: PlancraftOAMConfig):
        """Build the text model, vision model, projection layer and tokenizer.

        Args:
            config: when ``config.from_llama`` is true the text model is
                created with ``from_pretrained``; otherwise it is instantiated
                from the checkpoint's config via ``from_config``.
        """
        super().__init__(config)

        self.config = config
        # load text model
        if self.config.from_llama:
            self.text_model = AutoModelForCausalLM.from_pretrained(
                "meta-llama/Meta-Llama-3.1-8B-Instruct",
            )
        else:
            text_model_config = AutoConfig.from_pretrained(
                "meta-llama/Meta-Llama-3.1-8B-Instruct",
            )
            self.text_model = AutoModelForCausalLM.from_config(text_model_config)

        # load vision model
        self.vision_model = IntegratedBoundingBoxModel.from_pretrained(
            "gautierdag/plancraft-maskrcnn"
        )
        self.vision_model.eval()

        # convert vision features to text embedding
        self.vision_to_text_embedding = nn.Linear(
            1024, self.text_model.config.hidden_size
        )
        # NOTE(review): the keyword is spelled ``trust_remote`` here, whereas the
        # transformers option is named ``trust_remote_code``. Left unchanged so
        # remote-code execution is not silently enabled -- confirm the intent.
        self.tokenizer = AutoTokenizer.from_pretrained(
            "meta-llama/Meta-Llama-3.1-8B-Instruct",
            trust_remote=True,
        )
        # add special tokens
        self.tokenizer.add_special_tokens(
            {
                "additional_special_tokens": [
                    "<|inventory|>",
                ]
            }
        )
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.inventory_idx = self.tokenizer.convert_tokens_to_ids("<|inventory|>")

        # resize token embeddings
        self.text_model.resize_token_embeddings(len(self.tokenizer))
        # image transforms
        self.transforms = v2.Compose(
            [v2.ToImage(), v2.ToDtype(torch.float32, scale=True)]
        )

    @torch.no_grad()
    def extract_bboxes(self, images: list) -> list[dict]:
        """Run the vision model on ``images`` and return its predictions.

        Args:
            images: raw images accepted by ``self.transforms``.

        Returns:
            ``[]`` for an empty input, otherwise whatever the vision model
            returns for the stacked batch (the rest of this class reads a
            ``"features"`` entry from each element).
        """
        if len(images) == 0:
            return []
        img_tensors = torch.stack([self.transforms(img) for img in images])
        # Follow the module's own device instead of assuming the default GPU.
        img_tensors = img_tensors.to(self.device)
        # disable gradients
        self.vision_model.freeze()
        # get bounding box predictions
        bbox_preds = self.vision_model(img_tensors)
        return bbox_preds

    def _render_chat(self, messages: list) -> str:
        """Apply the chat template and drop the ``<|begin_of_text|>`` marker."""
        text = self.tokenizer.apply_chat_template(
            messages, add_generation_prompt=not self.training, tokenize=False
        )
        # presumably removed because the later tokenizer call adds it again
        return text.replace("<|begin_of_text|>", "")

    def prepare_messages(self, messages: list, bboxes: list[dict]) -> str:
        """Render ``messages`` to text, expanding inventory placeholders.

        Each user message whose content ends with ``<|inventory|>`` consumes
        the next entry of ``bboxes``; its placeholder is repeated once per
        feature row of that entry.

        Args:
            messages: chat messages (dicts with ``"role"`` and ``"content"``).
            bboxes: vision predictions, one per image, in message order.

        Returns:
            The chat-templated prompt. A generation prompt is appended only
            when the module is not in training mode.

        Raises:
            AssertionError: if fewer placeholders than ``bboxes`` were found.
        """
        # no bounding boxes: placeholders (if any) are left untouched
        if len(bboxes) == 0:
            return self._render_chat(messages)

        # expand <|inventory|> tokens into N tokens (N = number of bounding boxes)
        new_messages = []
        i_pred = 0
        for m in messages:
            new_message = m.copy()
            if new_message["role"] == "user" and new_message["content"].endswith(
                "<|inventory|>"
            ):
                # add inventory tokens for each bounding box
                # (str.replace expands every occurrence in this message)
                new_message["content"] = new_message["content"].replace(
                    "<|inventory|>",
                    "<|inventory|>" * (bboxes[i_pred]["features"].shape[0]),
                )
                i_pred += 1
            new_messages.append(new_message)

        assert i_pred == len(
            bboxes
        ), "Number of inventory tokens does not match number of images"

        return self._render_chat(new_messages)

    def inputs_merger(
        self,
        input_ids: torch.LongTensor,
        inputs_embeds: Optional[torch.Tensor],
        image_hidden_states: Optional[torch.Tensor],
    ):
        """Overwrite inventory-token embeddings with projected box features.

        Args:
            input_ids: token ids, indexed by batch row.
            inputs_embeds: embeddings of ``input_ids``; modified in place.
            image_hidden_states: per batch row, either an empty sequence (no
                images) or a tensor with one row per bounding box.

        Returns:
            ``inputs_embeds`` with the placeholder positions replaced.

        Raises:
            AssertionError: if a row has no box features but still contains
                inventory tokens.
        """
        # Tokenizer truncation decides which placeholders survive, so the
        # matching boxes must be kept from the same end of the sequence.
        keep_tail = getattr(self.tokenizer, "truncation_side", "right") == "left"

        # along batch dimension
        for i, box_embeds in enumerate(image_hidden_states):
            inventory_mask = input_ids[i] == self.inventory_idx
            n_inventory_tokens = int(inventory_mask.sum())

            if len(box_embeds) == 0:
                assert (
                    n_inventory_tokens == 0
                ), "No images but inventory token is still present"
                continue

            n_boxes = box_embeds.shape[0]
            if n_inventory_tokens != n_boxes:
                logger.warning(
                    f"Number of inventory tokens ({n_inventory_tokens}) does not match number of bounding boxes ({n_boxes}). Possible truncation."
                )
                if n_inventory_tokens == 0:
                    # Every placeholder was truncated away; nothing to fill.
                    # (A ``[-0:]`` slice would have selected *all* rows.)
                    image_hidden_states[i] = box_embeds[:0]
                    continue
                if keep_tail:
                    box_embeds = box_embeds[-n_inventory_tokens:]
                else:
                    box_embeds = box_embeds[:n_inventory_tokens]
                image_hidden_states[i] = box_embeds

            # replace inventory tokens with bbox features; indexed assignment
            # needs matching dtypes, so align with the text embeddings
            inputs_embeds[i, inventory_mask] = box_embeds.to(inputs_embeds.dtype)
        return inputs_embeds

    def process_inputs(
        self,
        batch_messages: Optional[list[list[dict]]] = None,  # list of list of messages (untokenized)
        batch_images: Optional[list[list]] = None,  # list of list of images (unprocessed)
    ) -> tuple[dict[str, torch.FloatTensor], list[torch.FloatTensor], int]:
        """
        Converts raw images and messages into model inputs

        Args:
            batch_messages: one conversation per batch element; ``None`` is
                treated as an empty batch.
            batch_images: the images of each conversation; ``None`` is
                treated as an empty batch.

        Returns:
            ``(batch, image_hidden_states, total_boxes)``: the tokenizer
            output, the projected box features per batch element (``[]`` when
            an element has no boxes) and the total number of boxes.

        Raises:
            AssertionError: if the two batches differ in length.
        """
        # None instead of shared mutable defaults
        batch_messages = [] if batch_messages is None else batch_messages
        batch_images = [] if batch_images is None else batch_images

        assert len(batch_images) == len(
            batch_messages
        ), "Number of images and messages should match in the batch dim"
        # initial forward pass
        texts_batch = []
        image_hidden_states = []
        total_boxes = 0
        for images, messages in zip(batch_images, batch_messages):
            # process images
            bboxes = self.extract_bboxes(images)
            if len(bboxes) > 0:
                # get bbox features
                features = torch.concat([p["features"] for p in bboxes], dim=0)
                # upscale to text embedding size
                features_embeds = self.vision_to_text_embedding(features)
                image_hidden_states.append(features_embeds)
                # count bboxes total
                total_boxes += features.shape[0]
            else:
                image_hidden_states.append([])

            # process messages
            text = self.prepare_messages(messages, bboxes)
            texts_batch.append(text)

        # tokenize text
        # @NOTE: truncation can leave fewer inventory tokens than boxes; in that
        # case inputs_merger drops the boxes on the truncated side and warns
        batch = self.tokenizer(
            texts_batch,
            truncation=True,
            padding=True,
            max_length=16384,
            return_tensors="pt",
        )
        return batch, image_hidden_states, total_boxes

    def forward(
        self,
        batch_messages: Optional[list[list[dict]]] = None,  # list of list of messages (untokenized)
        batch_images: Optional[list[list]] = None,  # list of list of images (unprocessed)
        **kwargs,
    ):
        """Compute the language-modelling output for a batch of conversations.

        Labels are a copy of the input ids in which inventory placeholders and
        padding positions are set to ``-100`` so they do not contribute to the
        loss.

        Args:
            batch_messages: one conversation per batch element.
            batch_images: the images of each conversation.
            **kwargs: accepted for interface compatibility; not used.

        Returns:
            The text model's output (called with ``return_dict=True``).

        Raises:
            AssertionError: if the number of inventory tokens differs from the
                number of detected boxes.
        """
        batch, image_hidden_states, total_boxes = self.process_inputs(
            batch_messages, batch_images
        )
        # move to the device this module lives on
        batch = {k: v.to(self.device) for k, v in batch.items()}
        attention_mask = batch["attention_mask"]
        input_ids = batch["input_ids"]

        labels = input_ids.clone()
        # remove inventory tokens from labels
        inventory_positions = labels == self.inventory_idx
        labels[inventory_positions] = -100
        # sanity check: should have same number of boxes as inventory tokens
        assert inventory_positions.sum() == total_boxes
        # The pad token equals the EOS token, so padding cannot be recognised
        # by id; mask it through the attention mask instead.
        labels[attention_mask == 0] = -100

        # get text embeddings
        inputs_embeds = self.text_model.get_input_embeddings()(input_ids)
        inputs_embeds = self.inputs_merger(
            input_ids, inputs_embeds, image_hidden_states
        )
        # forward pass
        return self.text_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            labels=labels,
            return_dict=True,
        )

    @torch.no_grad()
    def generate(
        self,
        batch_messages: list[list[dict]],
        batch_images: list[list],
        do_sample=True,
        temperature=0.6,
        max_new_tokens=32,
    ):
        """Generate a reply for every conversation in the batch.

        Args:
            batch_messages: one conversation per batch element.
            batch_images: the images of each conversation.
            do_sample: forwarded to the text model's ``generate``.
            temperature: forwarded to the text model's ``generate``.
            max_new_tokens: forwarded to the text model's ``generate``.

        Returns:
            ``(text_responses, total_tokens_used)``: the decoded sequences with
            ``<|eot_id|>`` removed, and the second dimension of the generated
            id tensor.
        """
        # Left padding is only wanted for generation; restore the previous
        # setting so later calls to forward() are not affected.
        previous_padding_side = self.tokenizer.padding_side
        self.tokenizer.padding_side = "left"
        try:
            batch, image_hidden_states, _ = self.process_inputs(
                batch_messages, batch_images
            )
        finally:
            self.tokenizer.padding_side = previous_padding_side

        batch = {k: v.to(self.device) for k, v in batch.items()}
        attention_mask = batch["attention_mask"]
        input_ids = batch["input_ids"]

        inputs_embeds = self.text_model.get_input_embeddings()(input_ids)
        inputs_embeds = self.inputs_merger(
            input_ids, inputs_embeds, image_hidden_states
        )

        generated_sequences = self.text_model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            do_sample=do_sample,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            pad_token_id=self.tokenizer.eos_token_id,
        )

        # Decode the output
        text_responses = self.tokenizer.batch_decode(
            generated_sequences,
            skip_special_tokens=False,
        )

        # remove <|eot_id|> tokens
        text_responses = [
            text_response.replace("<|eot_id|>", "") for text_response in text_responses
        ]

        _, total_tokens_used = generated_sequences.shape
        return text_responses, total_tokens_used
models/oracle.py DELETED
@@ -1,265 +0,0 @@
1
- import copy
2
- from collections import Counter
3
-
4
- from plancraft.config import EvalConfig
5
- from plancraft.environments.actions import (
6
- RealActionInteraction,
7
- StopAction,
8
- SymbolicMoveAction,
9
- SymbolicSmeltAction,
10
- )
11
- from plancraft.environments.planner import optimal_planner
12
- from plancraft.environments.recipes import (
13
- ShapedRecipe,
14
- ShapelessRecipe,
15
- SmeltingRecipe,
16
- id_to_item,
17
- )
18
- from plancraft.environments.sampler import MAX_STACK_SIZE
19
- from plancraft.models.base import ABCModel, History
20
-
21
-
22
def item_set_id_to_type(item_set_ids: set[int]):
    """Return the set obtained by applying ``id_to_item`` to every given id."""
    return {id_to_item(item_id) for item_id in item_set_ids}
24
-
25
-
26
def find_free_inventory_slot(inventory: list[dict], from_slot: int) -> int:
    """Pick a destination slot for the item currently held in ``from_slot``.

    Preference order:

    1. another slot holding the same item type whose quantity plus the moved
       quantity does not exceed ``MAX_STACK_SIZE`` for that type;
    2. an empty slot (type ``"air"`` or quantity 0) with a number above 10.

    Args:
        inventory: item dicts with ``"type"``, ``"quantity"`` and either a
            ``"slot"`` or an ``"index"`` key.
        from_slot: slot number of the item to move.

    Returns:
        The chosen destination slot number.

    Raises:
        AssertionError: if no item occupies ``from_slot``.
        ValueError: if no suitable destination exists.
    """
    from_item_type, from_item_quantity = None, None

    type_to_slot = {}
    slot_to_quantity = {}
    for item in inventory:
        if ("slot" in item and item["slot"] == from_slot) or (
            "index" in item and item["index"] == from_slot
        ):
            from_item_quantity = item["quantity"]
            from_item_type = item["type"]

        quantity = item["quantity"]
        # an entry with nothing in it counts as empty whatever its type says
        item_type = "air" if quantity == 0 else item["type"]
        item_slot = item["slot"] if "slot" in item else item["index"]

        type_to_slot.setdefault(item_type, []).append(item_slot)
        slot_to_quantity[item_slot] = slot_to_quantity.get(item_slot, 0) + quantity

    assert from_item_type is not None, f"Item not found in slot {from_slot}"

    # if there is a free slot with the same item type
    for slot in type_to_slot.get(from_item_type, []):
        if (
            slot != from_slot
            and slot_to_quantity[slot] + from_item_quantity
            <= MAX_STACK_SIZE[from_item_type]
        ):
            return slot

    # if there is a free slot with air; a missing "air" group simply means no
    # candidates, so the ValueError below is raised rather than a KeyError
    # NOTE(review): "> 10" also excludes slot 10 itself -- confirm intended.
    for slot in type_to_slot.get("air", []):
        if slot != from_slot and slot > 10:
            return slot

    raise ValueError("No free slot found")
77
-
78
-
79
def find_item_in_inventory(target: str, inventory: list[dict]) -> int:
    """Return the slot of the first entry holding at least one ``target``.

    Args:
        target: item type to look for.
        inventory: item dicts with ``"type"``, ``"quantity"`` and either a
            ``"slot"`` or an ``"index"`` key.

    Returns:
        The ``"slot"`` value of the first matching entry, or its ``"index"``
        value when it has no ``"slot"`` key.

    Raises:
        ValueError: if the matching entry has neither key, or if no entry
            with a positive quantity of ``target`` exists (previously this
            case fell through and returned ``None`` despite the ``int``
            annotation).
    """
    for item in inventory:
        if item["type"] != target or not (item["quantity"] > 0):
            continue
        if "slot" in item:
            return item["slot"]
        if "index" in item:
            return item["index"]
        raise ValueError("Neither slot or index is set")
    raise ValueError(f"Item {target} not found in inventory")
87
-
88
-
89
def get_inventory_counter(inventory: list[dict]) -> Counter:
    """Sum the quantities of every item type held outside slot 0.

    Entries whose ``"slot"`` or ``"index"`` equals 0 and entries of type
    ``"air"`` are ignored.
    """
    totals = Counter()
    for entry in inventory:
        in_slot_zero = entry.get("slot") == 0 or entry.get("index") == 0
        if in_slot_zero or entry["type"] == "air":
            continue
        totals[entry["type"]] += entry["quantity"]
    return totals
100
-
101
-
102
def get_crafting_slot_item(inventory: list[dict]) -> dict:
    """Return the first non-empty entry located in slot 0, if there is one.

    An entry qualifies when its ``"slot"`` or ``"index"`` equals 0 and its
    quantity is positive. ``None`` is returned when nothing qualifies.
    """
    candidates = (
        entry
        for entry in inventory
        if (entry.get("slot") == 0 or entry.get("index") == 0)
        and entry["quantity"] > 0
    )
    return next(candidates, None)
109
-
110
-
111
def update_inventory(
    inventory: list[dict], slot_from: int, slot_to: int, quantity: int
) -> list[dict]:
    """
    Subtract ``quantity`` from every entry located in ``slot_from``.

    The entry dicts are modified in place; the returned value is a new list
    holding those same dicts in their original order.

    NOTE: ``slot_to`` is accepted but unused -- the destination's quantity is
    deliberately not incremented.
    """
    entries = list(inventory)
    for entry in entries:
        located_here = any(
            key in entry and entry[key] == slot_from for key in ("slot", "index")
        )
        if located_here:
            entry["quantity"] -= quantity
    return entries
127
-
128
-
129
class OracleModel(ABCModel):
    """
    Oracle model returns actions that solve the task optimally

    A high-level plan (a list of ``(recipe, resulting inventory)`` steps) is
    obtained from ``optimal_planner``. Each step is expanded into a queue of
    low-level move/smelt actions (``self.subplans``) that are handed out one
    per call to ``step``.
    """

    def __init__(self, cfg: EvalConfig):
        """Create an oracle with an empty history and no pending plan.

        Raises:
            AssertionError: if the configured environment does not use the
                symbolic action space.
        """
        assert (
            cfg.plancraft.environment.symbolic_action_space
        ), "Only symbolic actions are supported for oracle"
        self.history = History(objective="")
        self.plans = []
        self.subplans = []

    def reset_history(self, objective: str = ""):
        """Reset the history for a new objective and drop any pending plan."""
        self.history.reset(objective=objective)
        self.plans = []
        self.subplans = []

    def get_plan(self, observation: dict):
        """Compute ``self.plans`` for the current objective and inventory."""
        # objective="Craft an item of type: ...."
        # this is simply recovering the target item to craft
        target = self.history.objective.split(": ")[-1]
        inventory_counter = get_inventory_counter(observation["inventory"])
        self.plans = optimal_planner(target=target, inventory=inventory_counter)

    def get_next_action(
        self, observation: dict
    ) -> SymbolicMoveAction | SymbolicSmeltAction:
        """Return the next low-level action towards the objective.

        Order of precedence:

        1. a queued action from ``self.subplans``;
        2. emptying slot 0 into a free inventory slot when it holds an item;
        3. expanding the next plan step into new sub-actions and returning
           the first of them.

        Args:
            observation: dict with an ``"inventory"`` list of item dicts.

        Raises:
            ValueError: if neither queued actions nor plan steps remain.
            NotImplementedError: if the recipe type is not handled.
        """
        if len(self.subplans) > 0:
            return self.subplans.pop(0)
        if len(self.plans) == 0:
            raise ValueError("No more steps in plan")

        # work on a copy: update_inventory mutates the item dicts
        observed_inventory = copy.deepcopy(observation["inventory"])

        # take item from crafting slot
        if slot_item := get_crafting_slot_item(observed_inventory):
            # move item from crafting slot to inventory
            free_slot = find_free_inventory_slot(observed_inventory, from_slot=0)
            return SymbolicMoveAction(
                slot_from=0, slot_to=free_slot, quantity=slot_item["quantity"]
            )

        plan_recipe, new_inventory = self.plans.pop(0)
        self.subplans = []
        new_inventory_counter = Counter(new_inventory)
        current_inventory = observed_inventory
        current_inventory_counter = get_inventory_counter(current_inventory)
        # items that disappear / appear when this plan step is executed
        items_to_use_counter = current_inventory_counter - new_inventory_counter
        new_items = new_inventory_counter - current_inventory_counter
        assert len(new_items) == 1

        if isinstance(plan_recipe, ShapelessRecipe):
            crafting_slot = 1

            # add each item to crafting slots
            for item, quantity in items_to_use_counter.items():
                n = 0
                while n < quantity:
                    from_slot = find_item_in_inventory(item, current_inventory)

                    # The item already sits in its target slot, so no move is
                    # needed. It must still be marked as used; otherwise the
                    # next lookup returns this same stack and moves it away.
                    if from_slot == crafting_slot:
                        current_inventory = update_inventory(
                            current_inventory, from_slot, crafting_slot, 1
                        )
                        crafting_slot += 1
                        n += 1
                        continue

                    action = SymbolicMoveAction(
                        slot_from=from_slot, slot_to=crafting_slot, quantity=1
                    )
                    # update state of inventory
                    current_inventory = update_inventory(
                        current_inventory, from_slot, crafting_slot, 1
                    )
                    self.subplans.append(action)

                    crafting_slot += 1
                    n += 1

        # if plan_recipe is a smelting recipe
        elif isinstance(plan_recipe, SmeltingRecipe):
            assert len(items_to_use_counter) == 1, "smelting only supports one item"
            for item, quantity in items_to_use_counter.items():
                from_slot = find_item_in_inventory(item, current_inventory)
                free_slot = find_free_inventory_slot(
                    current_inventory, from_slot=from_slot
                )
                action = SymbolicSmeltAction(
                    slot_from=from_slot, slot_to=free_slot, quantity=quantity
                )
                self.subplans.append(action)

        # if plan_recipe is a shaped recipe
        elif isinstance(plan_recipe, ShapedRecipe):
            for i, row in enumerate(plan_recipe.kernel):
                for j, item_set in enumerate(row):
                    # kernel cell (i, j) maps to slots 1..9, row by row
                    inventory_position = (i * 3) + j + 1
                    valid_items = item_set_id_to_type(item_set)
                    # use the first acceptable item that is still to be consumed
                    for item in valid_items:
                        if items_to_use_counter[item] > 0:
                            from_slot = find_item_in_inventory(item, current_inventory)
                            action = SymbolicMoveAction(
                                slot_from=from_slot,
                                slot_to=inventory_position,
                                quantity=1,
                            )
                            items_to_use_counter[item] -= 1
                            # update state of inventory
                            current_inventory = update_inventory(
                                current_inventory, from_slot, inventory_position, 1
                            )
                            self.subplans.append(action)
                            break
        else:
            raise NotImplementedError(f"Recipe type {type(plan_recipe)} not supported")

        return self.subplans.pop(0)

    def step(
        self, observation: dict
    ) -> SymbolicMoveAction | RealActionInteraction | SymbolicSmeltAction | StopAction:
        """Record ``observation`` and return the oracle's next action.

        A new plan is requested whenever ``self.plans`` is empty. If the
        planner yields ``None``, a ``StopAction`` is returned (and not added
        to the history).
        """
        # add observation to history
        self.history.add_observation_to_history(observation)

        # get action
        if len(self.plans) == 0:
            self.get_plan(observation)
            if self.plans is None:
                self.plans = []
                return StopAction()

        action = self.get_next_action(observation)

        # add action to history
        self.history.add_action_to_history(action)
        return action