plancraft 0.1.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- environments/__init__.py +0 -0
- environments/actions.py +218 -0
- environments/env_real.py +315 -0
- environments/env_symbolic.py +215 -0
- environments/items.py +10 -0
- environments/planner.py +109 -0
- environments/recipes.py +542 -0
- environments/sampler.py +224 -0
- models/__init__.py +21 -0
- models/act.py +184 -0
- models/base.py +152 -0
- models/bbox_model.py +492 -0
- models/dummy.py +54 -0
- models/few_shot_images/__init__.py +16 -0
- models/generators.py +483 -0
- models/oam.py +284 -0
- models/oracle.py +268 -0
- models/prompts.py +158 -0
- models/react.py +98 -0
- models/utils.py +289 -0
- plancraft-0.1.0.dist-info/LICENSE +21 -0
- plancraft-0.1.0.dist-info/METADATA +53 -0
- plancraft-0.1.0.dist-info/RECORD +26 -0
- plancraft-0.1.0.dist-info/WHEEL +5 -0
- plancraft-0.1.0.dist-info/top_level.txt +3 -0
- train/dataset.py +187 -0
models/oam.py
ADDED
@@ -0,0 +1,284 @@
|
|
1
|
+
import logging
|
2
|
+
from typing import Optional
|
3
|
+
|
4
|
+
import torch
|
5
|
+
import torch.nn as nn
|
6
|
+
import torchvision.transforms.v2 as v2
|
7
|
+
from transformers import (
|
8
|
+
AutoConfig,
|
9
|
+
AutoModelForCausalLM,
|
10
|
+
AutoTokenizer,
|
11
|
+
PretrainedConfig,
|
12
|
+
PreTrainedModel,
|
13
|
+
)
|
14
|
+
|
15
|
+
from plancraft.models.bbox_model import IntegratedBoundingBoxModel
|
16
|
+
|
17
|
+
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
18
|
+
|
19
|
+
|
20
|
+
class PlancraftOAMConfig(PretrainedConfig):
    """Configuration object for the Plancraft OAM model.

    Attributes:
        from_llama: flag stored on the config. The model class in this file
            reads it to decide whether the text model is loaded with
            pretrained weights or built from its architecture config only.
    """

    # NOTE(review): the identifier spells "aom" while the class is named OAM.
    # It is a runtime identifier, so it is left untouched -- confirm intent.
    model_type = "plancraft-aom"
    is_composition = True

    def __init__(self, from_llama=False, **kwargs):
        """Store ``from_llama`` and forward every other option to the base class."""
        # The flag is assigned before the base initialiser runs, as originally.
        self.from_llama = from_llama
        super().__init__(**kwargs)
|
31
|
+
|
32
|
+
|
33
|
+
class PlancraftOAM(PreTrainedModel):
    """Causal language model whose prompt embeds per-object image features.

    Images are run through a bounding-box vision model. The feature vector of
    every predicted box is projected to the text model's hidden size and
    written over the embedding of one ``<|inventory|>`` placeholder token in
    the prompt.
    """

    config_class = PlancraftOAMConfig

    def __init__(self, config: PlancraftOAMConfig):
        """Build the text model, vision model, projection layer and tokenizer.

        Args:
            config: if ``config.from_llama`` is truthy the text model is
                loaded with ``from_pretrained``; otherwise only its config is
                fetched and the model is instantiated with ``from_config``.
        """
        super().__init__(config)

        self.config = config
        # load text model
        if self.config.from_llama:
            self.text_model = AutoModelForCausalLM.from_pretrained(
                "meta-llama/Meta-Llama-3.1-8B-Instruct",
            )
        else:
            text_model_config = AutoConfig.from_pretrained(
                "meta-llama/Meta-Llama-3.1-8B-Instruct",
            )
            self.text_model = AutoModelForCausalLM.from_config(text_model_config)

        # load vision model (kept in eval mode)
        self.vision_model = IntegratedBoundingBoxModel.from_pretrained(
            "gautierdag/plancraft-maskrcnn"
        )
        self.vision_model.eval()

        # convert vision features to text embedding
        # NOTE(review): 1024 is presumably the width of the per-box feature
        # vectors returned by the vision model -- confirm against bbox_model.
        self.vision_to_text_embedding = nn.Linear(
            1024, self.text_model.config.hidden_size
        )
        # NOTE(review): `trust_remote` is not the documented
        # `trust_remote_code` keyword; left unchanged -- confirm intent.
        self.tokenizer = AutoTokenizer.from_pretrained(
            "meta-llama/Meta-Llama-3.1-8B-Instruct",
            trust_remote=True,
        )
        # add special tokens
        self.tokenizer.add_special_tokens(
            {
                "additional_special_tokens": [
                    "<|inventory|>",
                ]
            }
        )
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.inventory_idx = self.tokenizer.convert_tokens_to_ids("<|inventory|>")

        # resize token embeddings so the new placeholder token has a row
        self.text_model.resize_token_embeddings(len(self.tokenizer))
        # image transforms
        self.transforms = v2.Compose(
            [v2.ToImage(), v2.ToDtype(torch.float32, scale=True)]
        )

    @torch.no_grad()
    def extract_bboxes(self, images: list) -> list[dict]:
        """Run the vision model on ``images`` and return its predictions.

        Args:
            images: raw images accepted by ``self.transforms``.

        Returns:
            The vision model's output for the stacked batch, or ``[]`` when
            ``images`` is empty. Other methods of this class read a
            ``"features"`` entry from each prediction.
        """
        if len(images) == 0:
            return []
        img_tensors = torch.stack([self.transforms(img) for img in images])
        # Follow the model's own device instead of assuming CUDA.
        img_tensors = img_tensors.to(self.device)
        # disable gradients
        self.vision_model.freeze()
        # get bounding box predictions
        return self.vision_model(img_tensors)

    def prepare_messages(self, messages: list, bboxes: list[dict]) -> str:
        """Render ``messages`` with the chat template, expanding placeholders.

        For every user message whose content ends with ``<|inventory|>``,
        each occurrence of the placeholder in that message is repeated once
        per feature row of the next entry in ``bboxes``. When ``bboxes`` is
        empty the messages are rendered unchanged.

        Args:
            messages: chat messages (dicts with ``"role"`` and ``"content"``).
            bboxes: one prediction dict per image, each with ``"features"``.

        Returns:
            The templated text with ``<|begin_of_text|>`` removed.

        Raises:
            AssertionError: if the number of expanded messages differs from
                ``len(bboxes)``.
        """
        if len(bboxes) == 0:
            # no bounding boxes: nothing to expand
            new_messages = messages
        else:
            # expand <|inventory|> tokens into N tokens (N = number of boxes)
            new_messages = []
            i_pred = 0
            for m in messages:
                new_message = m.copy()
                if new_message["role"] == "user" and new_message[
                    "content"
                ].endswith("<|inventory|>"):
                    # add inventory tokens for each bounding box
                    new_message["content"] = new_message["content"].replace(
                        "<|inventory|>",
                        "<|inventory|>" * (bboxes[i_pred]["features"].shape[0]),
                    )
                    i_pred += 1
                new_messages.append(new_message)

            assert i_pred == len(
                bboxes
            ), "Number of inventory tokens does not match number of images"

        text = self.tokenizer.apply_chat_template(
            new_messages, add_generation_prompt=not self.training, tokenize=False
        )
        # presumably dropped because tokenizing the text later adds it again
        # -- TODO confirm
        text = text.replace("<|begin_of_text|>", "")
        return text

    def inputs_merger(
        self,
        input_ids: torch.LongTensor,
        inputs_embeds: Optional[torch.Tensor],
        image_hidden_states: Optional[torch.Tensor],
    ):
        """Overwrite placeholder embeddings with projected box features.

        Args:
            input_ids: token ids, indexed along the batch dimension.
            inputs_embeds: text embeddings of ``input_ids``; modified in place.
            image_hidden_states: per-sample feature tensors, or an empty
                sequence for samples without images. Entries may be replaced
                in place by their truncated version.

        Returns:
            ``inputs_embeds`` after the replacement.

        Raises:
            AssertionError: if a sample without features still contains
                placeholder tokens.
        """
        # Tokenizers truncate on the right unless configured otherwise, so the
        # surviving placeholders match the *first* features in that case.
        truncation_side = getattr(self.tokenizer, "truncation_side", "right")

        # along batch dimension
        for i, features in enumerate(image_hidden_states):
            inventory_mask = input_ids[i] == self.inventory_idx
            n_inventory_tokens = int(inventory_mask.sum())

            if len(features) == 0:
                assert (
                    n_inventory_tokens == 0
                ), "No images but inventory token is still present"
                continue

            if n_inventory_tokens != features.shape[0]:
                logger.warning(
                    f"Number of inventory tokens ({n_inventory_tokens}) does not match number of bounding boxes ({features.shape[0]}). Possible truncation."
                )
                if n_inventory_tokens == 0:
                    # Every placeholder was truncated away. Note `x[-0:]`
                    # would keep everything, so skip explicitly.
                    image_hidden_states[i] = features[:0]
                    continue
                if truncation_side == "left":
                    features = features[-n_inventory_tokens:]
                else:
                    features = features[:n_inventory_tokens]
                image_hidden_states[i] = features

            # replace inventory tokens with bbox features
            inputs_embeds[i, inventory_mask] = features
        return inputs_embeds

    def process_inputs(
        self,
        batch_messages: Optional[list[list[dict]]] = None,  # untokenized messages
        batch_images: Optional[list[list]] = None,  # unprocessed images
    ) -> tuple[dict[str, torch.FloatTensor], list[torch.FloatTensor], int]:
        """Convert raw images and messages into model inputs.

        Args:
            batch_messages: one list of chat messages per sample.
            batch_images: one list of images per sample.

        Returns:
            ``(batch, image_hidden_states, total_boxes)``: the tokenizer
            output, the projected features per sample (``[]`` for samples
            without images) and the total number of feature rows.

        Raises:
            AssertionError: if the two batches differ in length.
        """
        # None instead of a shared mutable default list.
        batch_messages = [] if batch_messages is None else batch_messages
        batch_images = [] if batch_images is None else batch_images

        assert len(batch_images) == len(
            batch_messages
        ), "Number of images and messages should match in the batch dim"
        texts_batch = []
        image_hidden_states = []
        total_boxes = 0
        for images, messages in zip(batch_images, batch_messages):
            # process images
            bboxes = self.extract_bboxes(images)
            if len(bboxes) > 0:
                # get bbox features
                features = torch.concat([p["features"] for p in bboxes], dim=0)
                # upscale to text embedding size
                features_embeds = self.vision_to_text_embedding(features)
                image_hidden_states.append(features_embeds)
                # count bboxes total
                total_boxes += features.shape[0]
            else:
                image_hidden_states.append([])

            # process messages
            text = self.prepare_messages(messages, bboxes)
            texts_batch.append(text)

        # tokenize text
        # @NOTE: truncation can leave fewer inventory tokens than boxes; the
        # merge step then drops the surplus features and logs a warning.
        batch = self.tokenizer(
            texts_batch,
            truncation=True,
            padding=True,
            max_length=16384,
            return_tensors="pt",
        )
        return batch, image_hidden_states, total_boxes

    def forward(
        self,
        batch_messages: Optional[list[list[dict]]] = None,  # untokenized messages
        batch_images: Optional[list[list]] = None,  # unprocessed images
        **kwargs,
    ):
        """Run a language-modelling pass over raw messages and images.

        Labels are the input ids with every placeholder position set to
        ``-100``. Extra keyword arguments are accepted and ignored.

        Returns:
            The text model's output (``return_dict=True``).

        Raises:
            AssertionError: if the number of placeholder tokens differs from
                the number of extracted box features.
        """
        batch, image_hidden_states, total_boxes = self.process_inputs(
            batch_messages, batch_images
        )
        # move to the model's device
        batch = {k: v.to(self.device) for k, v in batch.items()}
        attention_mask = batch["attention_mask"]
        input_ids = batch["input_ids"]

        labels = input_ids.clone()
        # remove inventory tokens from labels
        # NOTE(review): padded positions are not masked here -- confirm intended.
        labels[labels == self.inventory_idx] = -100
        # sanity check: should have same number of boxes as inventory tokens
        assert (labels == -100).sum() == total_boxes

        # get text embeddings
        inputs_embeds = self.text_model.get_input_embeddings()(input_ids)
        inputs_embeds = self.inputs_merger(
            input_ids, inputs_embeds, image_hidden_states
        )
        # forward pass
        return self.text_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            labels=labels,
            return_dict=True,
        )

    @torch.no_grad()
    def generate(
        self,
        batch_messages: list[list[dict]],
        batch_images: list[list],
        do_sample=True,
        temperature=0.6,
        max_new_tokens=32,
    ):
        """Generate text responses for a batch of messages and images.

        Sets the tokenizer's padding side to ``"left"`` (and leaves it so).

        Returns:
            ``(text_responses, total_tokens_used)``: the decoded sequences
            with ``<|eot_id|>`` removed, and the second dimension of the
            generated id tensor.
        """
        self.tokenizer.padding_side = "left"

        batch, image_hidden_states, _ = self.process_inputs(
            batch_messages, batch_images
        )
        batch = {k: v.to(self.device) for k, v in batch.items()}
        attention_mask = batch["attention_mask"]
        input_ids = batch["input_ids"]

        inputs_embeds = self.text_model.get_input_embeddings()(input_ids)
        inputs_embeds = self.inputs_merger(
            input_ids, inputs_embeds, image_hidden_states
        )

        generated_sequences = self.text_model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            do_sample=do_sample,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            pad_token_id=self.tokenizer.eos_token_id,
        )

        # Decode the output
        text_responses = self.tokenizer.batch_decode(
            generated_sequences,
            skip_special_tokens=False,
        )

        # remove <|eot_id|> tokens
        text_responses = [
            text_response.replace("<|eot_id|>", "") for text_response in text_responses
        ]

        _, total_tokens_used = generated_sequences.shape
        return text_responses, total_tokens_used
|
models/oracle.py
ADDED
@@ -0,0 +1,268 @@
|
|
1
|
+
import logging
|
2
|
+
import copy
|
3
|
+
from collections import Counter
|
4
|
+
|
5
|
+
from plancraft.config import EvalConfig
|
6
|
+
from plancraft.environments.actions import (
|
7
|
+
RealActionInteraction,
|
8
|
+
SymbolicMoveAction,
|
9
|
+
SymbolicSmeltAction,
|
10
|
+
StopAction,
|
11
|
+
)
|
12
|
+
from plancraft.environments.planner import optimal_planner
|
13
|
+
from plancraft.environments.recipes import (
|
14
|
+
ShapedRecipe,
|
15
|
+
ShapelessRecipe,
|
16
|
+
SmeltingRecipe,
|
17
|
+
id_to_item,
|
18
|
+
)
|
19
|
+
from plancraft.models.base import ABCModel, History
|
20
|
+
from plancraft.environments.sampler import MAX_STACK_SIZE
|
21
|
+
|
22
|
+
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
23
|
+
|
24
|
+
|
25
|
+
def item_set_id_to_type(item_set_ids: set[int]):
    """Map every id in ``item_set_ids`` through ``id_to_item`` into a new set."""
    return {id_to_item(item_id) for item_id in item_set_ids}
|
27
|
+
|
28
|
+
|
29
|
+
def find_free_inventory_slot(inventory: list[dict], from_slot: int) -> int:
    """Find a destination slot for the item currently held in ``from_slot``.

    Preference order:
      1. another slot holding the same item type whose quantity plus the
         moved quantity stays within ``MAX_STACK_SIZE`` for that type;
      2. the first empty slot (type ``"air"`` or quantity 0) with a slot
         number greater than 10.

    Args:
        inventory: item dicts with ``"type"``, ``"quantity"`` and either a
            ``"slot"`` or an ``"index"`` key.
        from_slot: slot number of the item to move.

    Returns:
        The chosen destination slot number.

    Raises:
        AssertionError: if no item is recorded at ``from_slot``.
        ValueError: if no suitable slot exists.
    """
    from_item_type, from_item_quantity = None, None

    type_to_slot = {}
    slot_to_quantity = {}
    for item in inventory:
        if ("slot" in item and item["slot"] == from_slot) or (
            "index" in item and item["index"] == from_slot
        ):
            from_item_quantity = item["quantity"]
            from_item_type = item["type"]

        quantity = item["quantity"]
        # A slot with zero quantity is empty regardless of its recorded type.
        item_type = "air" if quantity == 0 else item["type"]
        item_slot = item["slot"] if "slot" in item else item["index"]

        type_to_slot.setdefault(item_type, []).append(item_slot)
        slot_to_quantity[item_slot] = slot_to_quantity.get(item_slot, 0) + quantity

    assert from_item_type is not None, f"Item not found in slot {from_slot}"

    # if there is a slot with the same item type that still has room
    for slot in type_to_slot.get(from_item_type, []):
        if (
            slot != from_slot
            and slot_to_quantity[slot] + from_item_quantity
            <= MAX_STACK_SIZE[from_item_type]
        ):
            return slot

    # if there is a free slot with air; `.get` so an inventory without any
    # empty slot reaches the ValueError below instead of raising KeyError
    # NOTE(review): `slot > 10` also excludes slot 10 itself -- confirm intended.
    for slot in type_to_slot.get("air", []):
        if slot != from_slot and slot > 10:
            return slot

    raise ValueError("No free slot found")
|
80
|
+
|
81
|
+
|
82
|
+
def find_item_in_inventory(target: str, inventory: list[dict]) -> int:
    """Return the slot number of the first stack of ``target`` in ``inventory``.

    Only entries with a positive quantity are considered. The ``"slot"`` key
    is preferred over ``"index"``; a matching entry with neither key is
    skipped.

    Args:
        target: item type to look for.
        inventory: item dicts with ``"type"`` and ``"quantity"`` keys.

    Returns:
        The ``"slot"`` (or ``"index"``) value of the first match.

    Raises:
        ValueError: if no matching entry with a slot or index exists.
    """
    for item in inventory:
        if item["type"] == target and item["quantity"] > 0:
            if "slot" in item:
                return item["slot"]
            elif "index" in item:
                return item["index"]
    # The old message ("Neither slot or index is set") was misleading when the
    # item was simply absent from the inventory.
    raise ValueError(
        f"No item of type {target!r} with a slot or index found in inventory"
    )
|
90
|
+
|
91
|
+
|
92
|
+
def get_inventory_counter(inventory: list[dict]) -> Counter:
    """Sum item quantities per type, ignoring slot 0 and ``"air"`` entries.

    Args:
        inventory: item dicts with ``"type"``, ``"quantity"`` and optionally
            ``"slot"`` / ``"index"`` keys.

    Returns:
        A ``Counter`` mapping item type to total quantity.
    """
    totals = Counter()
    for entry in inventory:
        # Slot 0 is excluded whether it is recorded under "slot" or "index".
        in_slot_zero = entry.get("slot") == 0 or entry.get("index") == 0
        if in_slot_zero or entry["type"] == "air":
            continue
        totals[entry["type"]] += entry["quantity"]
    return totals
|
103
|
+
|
104
|
+
|
105
|
+
def get_crafting_slot_item(inventory: list[dict]) -> dict:
    """Return the first entry at slot 0 that has a positive quantity.

    Slot 0 may be recorded under either the ``"slot"`` or the ``"index"``
    key. Returns ``None`` when no such entry exists.
    """
    for entry in inventory:
        at_slot_zero = entry.get("slot") == 0 or entry.get("index") == 0
        if at_slot_zero and entry["quantity"] > 0:
            return entry
    return None
|
112
|
+
|
113
|
+
|
114
|
+
def update_inventory(
    inventory: list[dict], slot_from: int, slot_to: int, quantity: int
) -> list[dict]:
    """Decrement the quantity of the item(s) recorded at ``slot_from``.

    The matching item dicts are modified in place; the returned list is a new
    list holding the same dict objects in the same order.

    NOTE: ``slot_to`` is not used -- the destination slot is deliberately
    left unchanged.
    """
    updated = []
    for entry in inventory:
        is_source = ("slot" in entry and entry["slot"] == slot_from) or (
            "index" in entry and entry["index"] == slot_from
        )
        if is_source:
            entry["quantity"] -= quantity
        updated.append(entry)
    return updated
|
130
|
+
|
131
|
+
|
132
|
+
class OracleModel(ABCModel):
    """
    Oracle model returns actions that solve the task optimally

    ``self.plans`` holds the remaining ``(recipe, resulting inventory)`` steps
    produced by ``optimal_planner``; ``self.subplans`` holds the low-level
    actions still queued for the recipe currently being executed.
    """

    def __init__(self, cfg: EvalConfig):
        """Create an oracle; only the symbolic action space is supported.

        Raises:
            AssertionError: if ``cfg.plancraft.environment.symbolic_action_space``
                is falsy.
        """
        assert (
            cfg.plancraft.environment.symbolic_action_space
        ), "Only symbolic actions are supported for oracle"
        self.history = History(objective="")
        self.plans = []
        self.subplans = []

    def reset_history(self, objective: str = ""):
        """Reset the history with a new objective and drop any queued plan."""
        self.history.reset(objective=objective)
        self.plans = []
        self.subplans = []

    def get_plan(self, observation: dict):
        """Compute ``self.plans`` for the current objective and inventory."""
        # objective="Craft an item of type: ...."
        # this simply recovers the target item to craft
        target = self.history.objective.split(": ")[-1]
        inventory_counter = get_inventory_counter(observation["inventory"])
        self.plans = optimal_planner(target=target, inventory=inventory_counter)

    def get_next_action(
        self, observation: dict
    ) -> SymbolicMoveAction | SymbolicSmeltAction:
        """Return the next low-level action towards the objective.

        Order of precedence:
          1. a queued sub-action of the recipe in progress;
          2. moving a non-empty slot-0 item to a free inventory slot;
          3. expanding the next plan step into sub-actions and returning the
             first of them.

        Raises:
            ValueError: if both the sub-action queue and the plan are empty.
            NotImplementedError: if the recipe type is not supported.
        """
        if len(self.subplans) > 0:
            return self.subplans.pop(0)
        if len(self.plans) == 0:
            raise ValueError("No more steps in plan")

        # Work on a copy: the bookkeeping below mutates item quantities.
        observed_inventory = copy.deepcopy(observation["inventory"])

        # take item from crafting slot
        if slot_item := get_crafting_slot_item(observed_inventory):
            # move item from crafting slot to inventory
            free_slot = find_free_inventory_slot(observed_inventory, from_slot=0)
            return SymbolicMoveAction(
                slot_from=0, slot_to=free_slot, quantity=slot_item["quantity"]
            )

        plan_recipe, new_inventory = self.plans.pop(0)
        self.subplans = []
        new_inventory_counter = Counter(new_inventory)
        current_inventory = observed_inventory
        current_inventory_counter = get_inventory_counter(current_inventory)
        # Counter subtraction keeps only positive differences: what the step
        # consumes, and what it produces.
        items_to_use_counter = current_inventory_counter - new_inventory_counter
        new_items = new_inventory_counter - current_inventory_counter
        assert len(new_items) == 1

        if isinstance(plan_recipe, ShapelessRecipe):
            crafting_slot = 1

            # add each item to crafting slots, one unit per slot
            for item, quantity in items_to_use_counter.items():
                n = 0
                while n < quantity:
                    from_slot = find_item_in_inventory(item, current_inventory)

                    # skip if from_slot is the crafting slot
                    if from_slot == crafting_slot:
                        crafting_slot += 1
                        n += 1
                        continue

                    action = SymbolicMoveAction(
                        slot_from=from_slot, slot_to=crafting_slot, quantity=1
                    )
                    # update state of inventory
                    current_inventory = update_inventory(
                        current_inventory, from_slot, crafting_slot, 1
                    )
                    self.subplans.append(action)

                    crafting_slot += 1
                    n += 1

        # if plan_recipe is a smelting recipe
        elif isinstance(plan_recipe, SmeltingRecipe):
            assert len(items_to_use_counter) == 1, "smelting only supports one item"
            for item, quantity in items_to_use_counter.items():
                from_slot = find_item_in_inventory(item, current_inventory)
                free_slot = find_free_inventory_slot(
                    current_inventory, from_slot=from_slot
                )
                action = SymbolicSmeltAction(
                    slot_from=from_slot, slot_to=free_slot, quantity=quantity
                )
                self.subplans.append(action)

        # if plan_recipe is a shaped recipe
        elif isinstance(plan_recipe, ShapedRecipe):
            for i, row in enumerate(plan_recipe.kernel):
                for j, item_set in enumerate(row):
                    # kernel cell (i, j) maps to slot 1 + 3 * i + j
                    inventory_position = (i * 3) + j + 1
                    valid_items = item_set_id_to_type(item_set)
                    for item in valid_items:
                        if items_to_use_counter[item] > 0:
                            from_slot = find_item_in_inventory(item, current_inventory)
                            action = SymbolicMoveAction(
                                slot_from=from_slot,
                                slot_to=inventory_position,
                                quantity=1,
                            )
                            items_to_use_counter[item] -= 1
                            # update state of inventory
                            current_inventory = update_inventory(
                                current_inventory, from_slot, inventory_position, 1
                            )
                            self.subplans.append(action)
                            # one item per kernel cell
                            break
        else:
            raise NotImplementedError(f"Recipe type {type(plan_recipe)} not supported")

        return self.subplans.pop(0)

    def step(
        self, observation: dict
    ) -> list[SymbolicMoveAction | RealActionInteraction | SymbolicSmeltAction]:
        """Record ``observation`` and return the next action.

        A new plan is requested only when neither plan steps nor queued
        sub-actions remain. If the planner returns ``None`` a ``StopAction``
        is returned (and not added to the history).

        NOTE(review): the annotation says ``list[...]`` but a single action is
        returned -- confirm against the base class contract.
        """
        # add observation to history
        self.history.add_observation_to_history(observation)

        # Replan only when nothing is pending: sub-actions of the last recipe
        # may still be queued while `self.plans` is already empty.
        if len(self.plans) == 0 and len(self.subplans) == 0:
            self.get_plan(observation)
            if self.plans is None:
                self.plans = []
                return StopAction()

        action = self.get_next_action(observation)

        # add action to history
        self.history.add_action_to_history(action)
        return action
|