plancraft 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- environments/__init__.py +0 -0
- environments/actions.py +218 -0
- environments/env_real.py +315 -0
- environments/env_symbolic.py +215 -0
- environments/items.py +10 -0
- environments/planner.py +109 -0
- environments/recipes.py +542 -0
- environments/sampler.py +224 -0
- models/__init__.py +21 -0
- models/act.py +184 -0
- models/base.py +152 -0
- models/bbox_model.py +492 -0
- models/dummy.py +54 -0
- models/few_shot_images/__init__.py +16 -0
- models/generators.py +483 -0
- models/oam.py +284 -0
- models/oracle.py +268 -0
- models/prompts.py +158 -0
- models/react.py +98 -0
- models/utils.py +289 -0
- plancraft-0.1.0.dist-info/LICENSE +21 -0
- plancraft-0.1.0.dist-info/METADATA +53 -0
- plancraft-0.1.0.dist-info/RECORD +26 -0
- plancraft-0.1.0.dist-info/WHEEL +5 -0
- plancraft-0.1.0.dist-info/top_level.txt +3 -0
- train/dataset.py +187 -0
models/oam.py
ADDED
@@ -0,0 +1,284 @@
|
|
1
|
+
import logging
|
2
|
+
from typing import Optional
|
3
|
+
|
4
|
+
import torch
|
5
|
+
import torch.nn as nn
|
6
|
+
import torchvision.transforms.v2 as v2
|
7
|
+
from transformers import (
|
8
|
+
AutoConfig,
|
9
|
+
AutoModelForCausalLM,
|
10
|
+
AutoTokenizer,
|
11
|
+
PretrainedConfig,
|
12
|
+
PreTrainedModel,
|
13
|
+
)
|
14
|
+
|
15
|
+
from plancraft.models.bbox_model import IntegratedBoundingBoxModel
|
16
|
+
|
17
|
+
# Module-wide logger, named after this module's import path.
logger = logging.getLogger(name=__name__)
|
18
|
+
|
19
|
+
|
20
|
+
class PlancraftOAMConfig(PretrainedConfig):
    """Configuration object for ``PlancraftOAM``.

    Args:
        from_llama: if true, ``PlancraftOAM`` loads the text model with its
            pretrained weights; otherwise it builds the text model from its
            config only.
        **kwargs: forwarded unchanged to ``PretrainedConfig``.
    """

    # NOTE(review): "aom" vs the class name "OAM" looks like a typo, but it is
    # kept as-is because serialized configs presumably reference this value.
    model_type = "plancraft-aom"
    is_composition = True

    def __init__(self, from_llama=False, **kwargs):
        # set our own attribute first, then let the base class consume the rest
        self.from_llama = from_llama
        super().__init__(**kwargs)
|
31
|
+
|
32
|
+
|
33
|
+
class PlancraftOAM(PreTrainedModel):
    """Llama 3.1 text model conditioned on bounding-box features from a vision model.

    Each ``<|inventory|>`` placeholder in a user message is expanded into one
    token per detected bounding box. The embeddings at those positions are then
    replaced by the box features, projected into the text embedding space.
    """

    config_class = PlancraftOAMConfig

    # Hub checkpoint used for the text model (config/weights) and the tokenizer
    text_model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"

    def __init__(self, config: PlancraftOAMConfig):
        super().__init__(config)

        self.config = config
        # load text model: pretrained weights when ``from_llama`` is set,
        # otherwise a freshly initialised model with the same architecture
        if self.config.from_llama:
            self.text_model = AutoModelForCausalLM.from_pretrained(
                self.text_model_name,
            )
        else:
            text_model_config = AutoConfig.from_pretrained(
                self.text_model_name,
            )
            self.text_model = AutoModelForCausalLM.from_config(text_model_config)

        # load vision model
        self.vision_model = IntegratedBoundingBoxModel.from_pretrained(
            "gautierdag/plancraft-maskrcnn"
        )
        self.vision_model.eval()

        # project 1024-d bounding-box features to the text embedding size
        self.vision_to_text_embedding = nn.Linear(
            1024, self.text_model.config.hidden_size
        )
        # ``trust_remote=True`` was previously passed here. It is not a
        # recognised ``from_pretrained`` flag (that is ``trust_remote_code``),
        # and the Llama 3.1 tokenizer needs no remote code, so it is omitted.
        self.tokenizer = AutoTokenizer.from_pretrained(self.text_model_name)
        # add special tokens
        self.tokenizer.add_special_tokens(
            {"additional_special_tokens": ["<|inventory|>"]}
        )
        # EOS doubles as padding; padded positions are identified through the
        # attention mask (see ``forward``), not by token id
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.inventory_idx = self.tokenizer.convert_tokens_to_ids("<|inventory|>")

        # grow the embedding matrix to cover the added special token
        self.text_model.resize_token_embeddings(len(self.tokenizer))
        # image transforms: convert to an image tensor and scale to float32
        self.transforms = v2.Compose(
            [v2.ToImage(), v2.ToDtype(torch.float32, scale=True)]
        )

    @torch.no_grad()
    def extract_bboxes(self, images: list) -> list[dict]:
        """Run the vision model over ``images`` and return its predictions.

        Returns ``[]`` when ``images`` is empty. Each prediction is expected to
        hold a ``"features"`` tensor with one row per detected box (this is how
        ``process_inputs`` and ``prepare_messages`` consume it).
        """
        if len(images) == 0:
            return []
        img_tensors = torch.stack([self.transforms(img) for img in images])
        # follow the model's device instead of assuming CUDA
        img_tensors = img_tensors.to(self.device)
        # disable gradients
        self.vision_model.freeze()
        # get bounding box predictions
        bbox_preds = self.vision_model(img_tensors)
        return bbox_preds

    def _apply_chat_template(self, messages: list) -> str:
        """Render ``messages`` with the chat template, minus ``<|begin_of_text|>``."""
        text = self.tokenizer.apply_chat_template(
            messages, add_generation_prompt=not self.training, tokenize=False
        )
        # presumably stripped because tokenizing the text re-adds BOS — confirm
        return text.replace("<|begin_of_text|>", "")

    def prepare_messages(self, messages: list, bboxes: list[dict]) -> str:
        """Render a chat as text, expanding ``<|inventory|>`` placeholders.

        Each user message ending in ``<|inventory|>`` consumes the next entry of
        ``bboxes``. The placeholder is repeated once per row of that entry's
        ``"features"``.

        Raises:
            AssertionError: if the number of placeholder messages differs from
                ``len(bboxes)``.
        """
        # no bounding boxes
        if len(bboxes) == 0:
            return self._apply_chat_template(messages)

        # expand <|inventory|> tokens into N tokens (N = number of bounding boxes)
        new_messages = []
        i_pred = 0
        for m in messages:
            new_message = m.copy()
            if new_message["role"] == "user" and new_message["content"].endswith(
                "<|inventory|>"
            ):
                # add inventory tokens for each bounding box
                new_message["content"] = new_message["content"].replace(
                    "<|inventory|>",
                    "<|inventory|>" * (bboxes[i_pred]["features"].shape[0]),
                )
                i_pred += 1
            new_messages.append(new_message)

        assert i_pred == len(
            bboxes
        ), "Number of inventory tokens does not match number of images"
        return self._apply_chat_template(new_messages)

    def inputs_merger(
        self,
        input_ids: torch.LongTensor,
        inputs_embeds: Optional[torch.Tensor],
        image_hidden_states: Optional[list],
    ):
        """Overwrite ``<|inventory|>`` embeddings with projected box features.

        Args:
            input_ids: token ids; row ``i`` belongs to batch element ``i``.
            inputs_embeds: text embeddings of ``input_ids``; modified in place.
            image_hidden_states: one entry per batch row, either ``[]`` (no
                boxes) or a tensor with one projected feature row per box.
                Entries are realigned in place when truncation removed
                placeholder tokens.

        Returns:
            ``inputs_embeds`` with the inventory positions replaced.
        """
        # along batch dimension
        for i, row_features in enumerate(image_hidden_states):
            inventory_mask = input_ids[i] == self.inventory_idx
            if len(row_features) == 0:
                assert (
                    inventory_mask.sum() == 0
                ), "No images but inventory token is still present"
                continue

            # count the number of inventory tokens
            n_inventory_tokens = int(inventory_mask.sum())
            n_boxes = row_features.shape[0]
            if n_inventory_tokens != n_boxes:
                logger.warning(
                    f"Number of inventory tokens ({n_inventory_tokens}) does not match number of bounding boxes ({n_boxes}). Possible truncation."
                )
                # Keep the features whose placeholder tokens survived:
                # right-side truncation drops the last tokens, left-side the
                # first. Explicit slicing also handles n_inventory_tokens == 0,
                # where ``[-0:]`` would wrongly keep every row.
                if self.tokenizer.truncation_side == "left":
                    row_features = row_features[n_boxes - n_inventory_tokens :]
                else:
                    row_features = row_features[:n_inventory_tokens]
                image_hidden_states[i] = row_features

            # replace inventory tokens with bbox features; index assignment
            # requires matching dtypes (e.g. fp32 projection vs bf16 embeddings)
            inputs_embeds[i, inventory_mask] = row_features.to(inputs_embeds.dtype)
        return inputs_embeds

    def process_inputs(
        self,
        batch_messages: Optional[list[list[dict]]] = None,  # list of list of messages (untokenized)
        batch_images: Optional[list[list]] = None,  # list of list of images (unprocessed)
    ) -> tuple[dict[str, torch.Tensor], list, int]:
        """
        Converts raw images and messages into model inputs

        Returns:
            A tuple of:
            - the tokenizer output (``input_ids`` / ``attention_mask``), padded
              and truncated to at most 16384 tokens;
            - one entry per batch row: projected box features, or ``[]``;
            - the total number of boxes across the batch.
        """
        if batch_messages is None:
            batch_messages = []
        if batch_images is None:
            batch_images = []
        assert len(batch_images) == len(
            batch_messages
        ), "Number of images and messages should match in the batch dim"

        texts_batch = []
        image_hidden_states = []
        total_boxes = 0
        for images, messages in zip(batch_images, batch_messages):
            # process images
            bboxes = self.extract_bboxes(images)
            if len(bboxes) > 0:
                # get bbox features
                features = torch.concat([p["features"] for p in bboxes], dim=0)
                # upscale to text embedding size
                features_embeds = self.vision_to_text_embedding(features)
                image_hidden_states.append(features_embeds)
                # count bboxes total
                total_boxes += features.shape[0]
            else:
                image_hidden_states.append([])

            # process messages
            texts_batch.append(self.prepare_messages(messages, bboxes))

        # @NOTE: truncation can remove <|inventory|> tokens; inputs_merger then
        # keeps only the features whose tokens survived and logs a warning
        batch = self.tokenizer(
            texts_batch,
            truncation=True,
            padding=True,
            max_length=16384,
            return_tensors="pt",
        )
        return batch, image_hidden_states, total_boxes

    def forward(
        self,
        batch_messages: Optional[list[list[dict]]] = None,  # list of list of messages (untokenized)
        batch_images: Optional[list[list]] = None,  # list of list of images (unprocessed)
        **kwargs,
    ):
        """Run the text model on a batch of chats with images, with LM labels.

        Labels are the input ids, with inventory placeholders and padding set to
        -100. ``**kwargs`` is accepted and ignored.
        """
        batch, image_hidden_states, _ = self.process_inputs(
            batch_messages, batch_images
        )
        # move to the model's device
        batch = {k: v.to(self.device) for k, v in batch.items()}
        attention_mask = batch["attention_mask"]
        input_ids = batch["input_ids"]

        labels = input_ids.clone()
        # inventory placeholders are image positions, not text to predict
        labels[labels == self.inventory_idx] = -100
        # pad_token == eos_token, so padding is masked via the attention mask
        # (masking by id would also hide genuine EOS tokens)
        labels[attention_mask == 0] = -100
        # No hard "inventory tokens == boxes" check here: truncation may remove
        # placeholders, and inputs_merger realigns the features with a warning.

        # get text embeddings
        inputs_embeds = self.text_model.get_input_embeddings()(input_ids)
        inputs_embeds = self.inputs_merger(
            input_ids, inputs_embeds, image_hidden_states
        )
        # forward pass
        return self.text_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            labels=labels,
            return_dict=True,
        )

    @torch.no_grad()
    def generate(
        self,
        batch_messages: list[list[dict]],
        batch_images: list[list],
        do_sample=True,
        temperature=0.6,
        max_new_tokens=32,
    ):
        """Generate a reply for every chat in the batch.

        Returns:
            ``(text_responses, total_tokens_used)``: the decoded sequences with
            ``<|eot_id|>`` removed, and the length (dim 1) of the returned id
            tensor.
        """
        # Left-pad so every prompt ends where generation starts. The previous
        # setting is restored so later training calls keep their padding side.
        previous_padding_side = self.tokenizer.padding_side
        self.tokenizer.padding_side = "left"
        try:
            batch, image_hidden_states, _ = self.process_inputs(
                batch_messages, batch_images
            )
        finally:
            self.tokenizer.padding_side = previous_padding_side

        batch = {k: v.to(self.device) for k, v in batch.items()}
        attention_mask = batch["attention_mask"]
        input_ids = batch["input_ids"]

        inputs_embeds = self.text_model.get_input_embeddings()(input_ids)
        inputs_embeds = self.inputs_merger(
            input_ids, inputs_embeds, image_hidden_states
        )

        generated_sequences = self.text_model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            do_sample=do_sample,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            pad_token_id=self.tokenizer.eos_token_id,
        )

        # Decode the output, keeping special tokens
        text_responses = self.tokenizer.batch_decode(
            generated_sequences,
            skip_special_tokens=False,
        )
        # remove <|eot_id|> tokens
        text_responses = [
            text_response.replace("<|eot_id|>", "") for text_response in text_responses
        ]

        # NOTE(review): when only ``inputs_embeds`` is passed, HF ``generate``
        # presumably returns just the new tokens, so this may count generated
        # tokens only rather than prompt + generated — confirm.
        _, total_tokens_used = generated_sequences.shape
        return text_responses, total_tokens_used
|
models/oracle.py
ADDED
@@ -0,0 +1,268 @@
|
|
1
|
+
import logging
|
2
|
+
import copy
|
3
|
+
from collections import Counter
|
4
|
+
|
5
|
+
from plancraft.config import EvalConfig
|
6
|
+
from plancraft.environments.actions import (
|
7
|
+
RealActionInteraction,
|
8
|
+
SymbolicMoveAction,
|
9
|
+
SymbolicSmeltAction,
|
10
|
+
StopAction,
|
11
|
+
)
|
12
|
+
from plancraft.environments.planner import optimal_planner
|
13
|
+
from plancraft.environments.recipes import (
|
14
|
+
ShapedRecipe,
|
15
|
+
ShapelessRecipe,
|
16
|
+
SmeltingRecipe,
|
17
|
+
id_to_item,
|
18
|
+
)
|
19
|
+
from plancraft.models.base import ABCModel, History
|
20
|
+
from plancraft.environments.sampler import MAX_STACK_SIZE
|
21
|
+
|
22
|
+
# Module-wide logger, named after this module's import path.
logger = logging.getLogger(name=__name__)
|
23
|
+
|
24
|
+
|
25
|
+
def item_set_id_to_type(item_set_ids: set[int]):
    """Map item ids to the set of item types returned by ``id_to_item``.

    Duplicate types collapse because the result is a set.
    """
    return {id_to_item(item_id) for item_id in item_set_ids}
|
27
|
+
|
28
|
+
|
29
|
+
def find_free_inventory_slot(inventory: list[dict], from_slot: int) -> int:
    """Find a destination slot for the whole stack currently in ``from_slot``.

    A slot that already holds the same item type is preferred, provided the
    combined quantity stays within ``MAX_STACK_SIZE``. Otherwise an empty slot
    numbered above 10 is used.

    Args:
        inventory: item dicts with ``type``, ``quantity`` and a ``slot`` or
            ``index`` key. Entries with quantity 0 are treated as empty ("air").
        from_slot: slot whose contents should be moved.

    Returns:
        The destination slot number.

    Raises:
        AssertionError: if no entry matches ``from_slot``.
        ValueError: if no suitable destination slot exists, including when
            the inventory has no empty slot at all.
    """
    from_item_type, from_item_quantity = None, None

    type_to_slot: dict[str, list[int]] = {}
    slot_to_quantity: dict[int, int] = {}
    for item in inventory:
        if ("slot" in item and item["slot"] == from_slot) or (
            "index" in item and item["index"] == from_slot
        ):
            # no break: keep scanning so the slot/type maps are complete
            from_item_quantity = item["quantity"]
            from_item_type = item["type"]

        item_type = item["type"]
        quantity = item["quantity"]
        if quantity == 0:
            item_type = "air"

        item_slot = item["slot"] if "slot" in item else item["index"]

        type_to_slot.setdefault(item_type, []).append(item_slot)
        slot_to_quantity[item_slot] = slot_to_quantity.get(item_slot, 0) + quantity

    assert from_item_type is not None, f"Item not found in slot {from_slot}"

    # if there is a free slot with the same item type
    for slot in type_to_slot.get(from_item_type, []):
        if (
            slot != from_slot
            and slot_to_quantity[slot] + from_item_quantity
            <= MAX_STACK_SIZE[from_item_type]
        ):
            return slot

    # If there is an empty slot. ``.get`` avoids a KeyError when the inventory
    # has no empty slot at all; that case falls through to the ValueError.
    # NOTE(review): ``> 10`` also excludes slot 10 — confirm whether slot 10
    # is meant to be usable here.
    for slot in type_to_slot.get("air", []):
        if slot != from_slot and slot > 10:
            return slot

    raise ValueError("No free slot found")
|
80
|
+
|
81
|
+
|
82
|
+
def find_item_in_inventory(target: str, inventory: list[dict]) -> int:
    """Return the slot of the first non-empty stack of ``target``.

    Args:
        target: item type to look for.
        inventory: item dicts with ``type``, ``quantity`` and a ``slot`` or
            ``index`` key. They are scanned in list order.

    Returns:
        The ``slot`` (or, failing that, ``index``) of the first matching entry.

    Raises:
        ValueError: if the matching entry has neither ``slot`` nor ``index``,
            or if no stack of ``target`` with positive quantity exists. The
            latter previously fell through and returned ``None``.
    """
    for item in inventory:
        if item["type"] == target and item["quantity"] > 0:
            if "slot" in item:
                return item["slot"]
            elif "index" in item:
                return item["index"]
            raise ValueError("Neither slot or index is set")
    raise ValueError(f"Item {target} not found in inventory")
|
90
|
+
|
91
|
+
|
92
|
+
def get_inventory_counter(inventory: list[dict]) -> Counter:
    """Total item quantities per type, excluding slot 0 and "air" entries.

    An entry is excluded when its ``slot`` or its ``index`` equals 0. Other
    entries add their ``quantity`` to their ``type``, so a zero-quantity
    non-air entry still creates a key with count 0.
    """
    totals = Counter()
    for entry in inventory:
        in_slot_zero = entry.get("slot") == 0 or entry.get("index") == 0
        if in_slot_zero or entry["type"] == "air":
            continue
        totals[entry["type"]] += entry["quantity"]
    return totals
|
103
|
+
|
104
|
+
|
105
|
+
def get_crafting_slot_item(inventory: list[dict]) -> dict:
    """Return the first non-empty entry whose ``slot`` or ``index`` is 0.

    The matching dict itself (not a copy) is returned, or ``None`` when no
    such entry exists.
    """
    return next(
        (
            entry
            for entry in inventory
            if (entry.get("slot") == 0 or entry.get("index") == 0)
            and entry["quantity"] > 0
        ),
        None,
    )
|
112
|
+
|
113
|
+
|
114
|
+
def update_inventory(
    inventory: list[dict], slot_from: int, slot_to: int, quantity: int
) -> list[dict]:
    """
    Return a new inventory with ``quantity`` removed from ``slot_from``.

    NOTE: we don't care about incrementing the items in slot_to; ``slot_to`` is
    accepted for call-site symmetry but is not used.

    An entry matches ``slot_from`` via its ``slot`` or its ``index`` key. The
    input list and its dicts are left untouched; the decremented entry is
    replaced by a copy, and untouched entries are shared with the input.
    """
    new_inventory = []
    for item in inventory:
        if ("slot" in item and item["slot"] == slot_from) or (
            "index" in item and item["index"] == slot_from
        ):
            item = {**item, "quantity": item["quantity"] - quantity}
        new_inventory.append(item)
    return new_inventory
|
130
|
+
|
131
|
+
|
132
|
+
class OracleModel(ABCModel):
    """
    Oracle model returns actions that solve the task optimally

    A recipe-level plan is obtained from ``optimal_planner``. Each recipe is
    expanded into symbolic move/smelt sub-actions, which are returned one per
    ``step`` call.
    """

    def __init__(self, cfg: EvalConfig):
        assert (
            cfg.plancraft.environment.symbolic_action_space
        ), "Only symbolic actions are supported for oracle"
        self.history = History(objective="")
        # queue of (recipe, expected inventory after the recipe) plan steps
        self.plans = []
        # queue of low-level actions realising the current recipe
        self.subplans = []

    def reset_history(self, objective: str = ""):
        """Reset the history to ``objective`` and drop any pending plan."""
        self.history.reset(objective=objective)
        self.plans = []
        self.subplans = []

    def get_plan(self, observation: dict):
        """Compute ``self.plans`` for the target named in the objective.

        The target is the text after the last ": " of the objective
        (e.g. "Craft an item of type: ..."). The planner result may be ``None``
        (``step`` handles that case).
        """
        # objective="Craft an item of type: ...."
        # this simply recovering the target item to craft
        target = self.history.objective.split(": ")[-1]
        inventory_counter = get_inventory_counter(observation["inventory"])
        self.plans = optimal_planner(target=target, inventory=inventory_counter)

    def get_next_action(
        self, observation: dict
    ) -> SymbolicMoveAction | SymbolicSmeltAction:
        """Return the next low-level action.

        The action is chosen by the first case that applies:

        1. A pending sub-action is returned.
        2. An item in slot 0 is moved to a free inventory slot.
        3. Otherwise the next plan step is expanded into sub-actions, and the
           first of them is returned.

        Raises:
            ValueError: if no sub-actions are pending and the plan is empty.
            NotImplementedError: for unsupported recipe types.
        """
        if len(self.subplans) > 0:
            return self.subplans.pop(0)
        if len(self.plans) == 0:
            raise ValueError("No more steps in plan")

        observed_inventory = copy.deepcopy(observation["inventory"])

        # take item from crafting slot
        if slot_item := get_crafting_slot_item(observed_inventory):
            # move item from crafting slot to inventory
            free_slot = find_free_inventory_slot(observed_inventory, from_slot=0)
            return SymbolicMoveAction(
                slot_from=0, slot_to=free_slot, quantity=slot_item["quantity"]
            )

        plan_recipe, new_inventory = self.plans.pop(0)
        self.subplans = []
        new_inventory_counter = Counter(new_inventory)
        current_inventory = observed_inventory
        current_inventory_counter = get_inventory_counter(current_inventory)
        # items this recipe consumes, and the single item type it adds
        items_to_use_counter = current_inventory_counter - new_inventory_counter
        new_items = new_inventory_counter - current_inventory_counter
        assert len(new_items) == 1

        if isinstance(plan_recipe, ShapelessRecipe):
            crafting_slot = 1

            # add each item to crafting slots
            for item, quantity in items_to_use_counter.items():
                for _ in range(quantity):
                    from_slot = find_item_in_inventory(item, current_inventory)

                    # no move needed if the unit already sits in this crafting slot
                    if from_slot != crafting_slot:
                        self.subplans.append(
                            SymbolicMoveAction(
                                slot_from=from_slot, slot_to=crafting_slot, quantity=1
                            )
                        )
                    # Mark the unit as used even when no move was needed.
                    # Otherwise the next lookup can return the same unit and
                    # move it out of the crafting slot it already fills.
                    current_inventory = update_inventory(
                        current_inventory, from_slot, crafting_slot, 1
                    )
                    crafting_slot += 1

        # if plan_recipe is a smelting recipe
        elif isinstance(plan_recipe, SmeltingRecipe):
            assert len(items_to_use_counter) == 1, "smelting only supports one item"
            for item, quantity in items_to_use_counter.items():
                from_slot = find_item_in_inventory(item, current_inventory)
                # NOTE(review): the destination is searched for the *input*
                # item's type, so it may be a slot holding more of the input
                # rather than one suited to the smelted output — confirm.
                free_slot = find_free_inventory_slot(
                    current_inventory, from_slot=from_slot
                )
                action = SymbolicSmeltAction(
                    slot_from=from_slot, slot_to=free_slot, quantity=quantity
                )
                self.subplans.append(action)

        # if plan_recipe is a shaped recipe
        elif isinstance(plan_recipe, ShapedRecipe):
            for i, row in enumerate(plan_recipe.kernel):
                for j, item_set in enumerate(row):
                    # kernel cell (i, j) -> grid slot, row-major, 3 slots per
                    # row, offset by 1 because slot 0 is the crafting output
                    inventory_position = (i * 3) + j + 1
                    valid_items = item_set_id_to_type(item_set)
                    for item in valid_items:
                        if items_to_use_counter[item] > 0:
                            from_slot = find_item_in_inventory(item, current_inventory)
                            action = SymbolicMoveAction(
                                slot_from=from_slot,
                                slot_to=inventory_position,
                                quantity=1,
                            )
                            items_to_use_counter[item] -= 1
                            # update state of inventory
                            current_inventory = update_inventory(
                                current_inventory, from_slot, inventory_position, 1
                            )
                            self.subplans.append(action)
                            break
        else:
            raise NotImplementedError(f"Recipe type {type(plan_recipe)} not supported")

        return self.subplans.pop(0)

    def step(
        self, observation: dict
    ) -> SymbolicMoveAction | RealActionInteraction | SymbolicSmeltAction | StopAction:
        """Record ``observation`` and return the next (single) action.

        A plan is computed whenever ``self.plans`` is empty. If the planner
        returns ``None``, a ``StopAction`` is returned; that action is not added
        to the history.
        """
        # add observation to history
        self.history.add_observation_to_history(observation)

        # get action
        # NOTE(review): this also re-plans while sub-actions of the last recipe
        # are still pending (plans is already empty then); get_next_action still
        # returns those pending sub-actions first.
        if len(self.plans) == 0:
            self.get_plan(observation)
            if self.plans is None:
                self.plans = []
                return StopAction()

        action = self.get_next_action(observation)

        # add action to history
        self.history.add_action_to_history(action)
        return action
|