plancraft 0.1.1__py3-none-any.whl → 0.1.2__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- plancraft-0.1.2.dist-info/METADATA +74 -0
- plancraft-0.1.2.dist-info/RECORD +5 -0
- {plancraft-0.1.1.dist-info → plancraft-0.1.2.dist-info}/WHEEL +1 -1
- plancraft-0.1.2.dist-info/top_level.txt +1 -0
- environments/__init__.py +0 -0
- environments/actions.py +0 -218
- environments/env_real.py +0 -316
- environments/env_symbolic.py +0 -212
- environments/items.py +0 -10
- environments/planner.py +0 -109
- environments/recipes.py +0 -542
- environments/sampler.py +0 -224
- models/__init__.py +0 -21
- models/act.py +0 -184
- models/base.py +0 -152
- models/bbox_model.py +0 -492
- models/dummy.py +0 -54
- models/few_shot_images/__init__.py +0 -16
- models/generators.py +0 -480
- models/oam.py +0 -283
- models/oracle.py +0 -265
- models/prompts.py +0 -158
- models/react.py +0 -93
- models/utils.py +0 -289
- plancraft-0.1.1.dist-info/METADATA +0 -74
- plancraft-0.1.1.dist-info/RECORD +0 -26
- plancraft-0.1.1.dist-info/top_level.txt +0 -3
- train/dataset.py +0 -187
- {plancraft-0.1.1.dist-info → plancraft-0.1.2.dist-info}/LICENSE +0 -0
models/oam.py
DELETED
@@ -1,283 +0,0 @@
|
|
1
|
-
from typing import Optional
|
2
|
-
|
3
|
-
from loguru import logger
|
4
|
-
|
5
|
-
import torch
|
6
|
-
import torch.nn as nn
|
7
|
-
import torchvision.transforms.v2 as v2
|
8
|
-
from transformers import (
|
9
|
-
AutoConfig,
|
10
|
-
AutoModelForCausalLM,
|
11
|
-
AutoTokenizer,
|
12
|
-
PretrainedConfig,
|
13
|
-
PreTrainedModel,
|
14
|
-
)
|
15
|
-
|
16
|
-
from plancraft.models.bbox_model import IntegratedBoundingBoxModel
|
17
|
-
|
18
|
-
|
19
|
-
class PlancraftOAMConfig(PretrainedConfig):
    """Configuration for the Plancraft text + bounding-box model.

    Attributes:
        from_llama: if True, the text backbone is loaded with pretrained
            Llama 3.1 8B Instruct weights; otherwise only its architecture
            config is loaded and the weights start fresh.
    """

    # NOTE(review): the registered type string reads "aom" rather than "oam";
    # renaming it would presumably break loading of already-saved configs.
    model_type = "plancraft-aom"
    is_composition = True

    def __init__(self, from_llama=False, **kwargs):
        # store our own flag first, then let the base class consume the rest
        self.from_llama = from_llama
        super().__init__(**kwargs)
|
30
|
-
|
31
|
-
|
32
|
-
class PlancraftOAM(PreTrainedModel):
    """Llama 3.1 causal LM conditioned on detected bounding-box features.

    Each ``<|inventory|>`` placeholder at the end of a user message is expanded
    into one token per detected bounding box. At embedding time, those token
    positions are overwritten with the box features projected to the text
    hidden size.
    """

    config_class = PlancraftOAMConfig

    def __init__(self, config: PlancraftOAMConfig):
        super().__init__(config)

        self.config = config
        # load text model: pretrained weights, or the architecture only
        if self.config.from_llama:
            self.text_model = AutoModelForCausalLM.from_pretrained(
                "meta-llama/Meta-Llama-3.1-8B-Instruct",
            )
        else:
            text_model_config = AutoConfig.from_pretrained(
                "meta-llama/Meta-Llama-3.1-8B-Instruct",
            )
            self.text_model = AutoModelForCausalLM.from_config(text_model_config)

        # load vision model (used only to extract box features)
        self.vision_model = IntegratedBoundingBoxModel.from_pretrained(
            "gautierdag/plancraft-maskrcnn"
        )
        self.vision_model.eval()

        # convert 1024-d vision features to the text embedding size
        self.vision_to_text_embedding = nn.Linear(
            1024, self.text_model.config.hidden_size
        )
        # NOTE(review): no `trust_remote_code` is passed; presumably the stock
        # Llama tokenizer needs none. Confirm if a custom tokenizer is used.
        self.tokenizer = AutoTokenizer.from_pretrained(
            "meta-llama/Meta-Llama-3.1-8B-Instruct",
        )
        # add special tokens: placeholder that is expanded to one token per box
        self.tokenizer.add_special_tokens(
            {
                "additional_special_tokens": [
                    "<|inventory|>",
                ]
            }
        )
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.inventory_idx = self.tokenizer.convert_tokens_to_ids("<|inventory|>")

        # resize token embeddings to include the new special token
        self.text_model.resize_token_embeddings(len(self.tokenizer))
        # image transforms: to tensor, then float32 with value scaling
        self.transforms = v2.Compose(
            [v2.ToImage(), v2.ToDtype(torch.float32, scale=True)]
        )

    @torch.no_grad()
    def extract_bboxes(self, images: list) -> list[dict]:
        """Run the vision model on ``images`` and return its predictions.

        Each prediction is expected to hold a ``"features"`` tensor with one
        row per detected box. An empty list is returned when there are no
        images.
        """
        if len(images) == 0:
            return []
        img_tensors = torch.stack([self.transforms(img) for img in images])
        img_tensors = img_tensors.cuda()
        # disable gradients
        self.vision_model.freeze()
        # get bounding box predictions
        bbox_preds = self.vision_model(img_tensors)
        return bbox_preds

    def prepare_messages(self, messages: list, bboxes: list[dict]) -> str:
        """Render ``messages`` with the chat template as a single string.

        When ``bboxes`` is non-empty, the ``<|inventory|>`` placeholder of each
        user message that ends with it is repeated once per box of the
        matching image. A generation prompt is added only when not training.
        """
        if len(bboxes) > 0:
            # expand <|inventory|> tokens into N tokens (N = number of bounding boxes)
            new_messages = []
            i_pred = 0
            for m in messages:
                new_message = m.copy()
                if new_message["role"] == "user" and new_message["content"].endswith(
                    "<|inventory|>"
                ):
                    # add inventory tokens for each bounding box
                    new_message["content"] = new_message["content"].replace(
                        "<|inventory|>",
                        "<|inventory|>" * (bboxes[i_pred]["features"].shape[0]),
                    )
                    i_pred += 1
                new_messages.append(new_message)

            assert i_pred == len(
                bboxes
            ), "Number of inventory tokens does not match number of images"
            messages = new_messages

        text = self.tokenizer.apply_chat_template(
            messages, add_generation_prompt=not self.training, tokenize=False
        )
        # presumably stripped so the BOS token is not duplicated when the text
        # is tokenized again in process_inputs. Confirm against the tokenizer.
        text = text.replace("<|begin_of_text|>", "")
        return text

    def inputs_merger(
        self,
        input_ids: torch.LongTensor,
        inputs_embeds: Optional[torch.Tensor],
        image_hidden_states: Optional[torch.Tensor],
    ):
        """Overwrite ``<|inventory|>`` token embeddings with box features.

        Args:
            input_ids: token ids, one row per batch element.
            inputs_embeds: embeddings of ``input_ids``; modified in place.
            image_hidden_states: per batch element, an empty list or a tensor
                with one projected feature row per box. Entries are replaced
                in place when truncation removed some inventory tokens.

        Returns:
            ``inputs_embeds`` with the inventory positions replaced.
        """
        # along batch dimension
        for i in range(len(image_hidden_states)):
            inventory_mask = input_ids[i] == self.inventory_idx
            if len(image_hidden_states[i]) == 0:
                assert (
                    inventory_mask.sum() == 0
                ), "No images but inventory token is still present"
                continue

            # count the number of inventory tokens
            n_inventory_tokens = int(inventory_mask.sum())
            n_boxes = image_hidden_states[i].shape[0]
            if n_inventory_tokens != n_boxes:
                logger.warning(
                    f"Number of inventory tokens ({n_inventory_tokens}) does not match number of bounding boxes ({n_boxes}). Possible truncation."
                )
            if n_inventory_tokens < n_boxes:
                # The tokenizer drops tokens on its `truncation_side`, so the
                # surviving inventory tokens belong to the boxes on the other
                # side. Drop the features from the truncated side.
                if self.tokenizer.truncation_side == "left":
                    image_hidden_states[i] = image_hidden_states[i][
                        n_boxes - n_inventory_tokens :
                    ]
                else:
                    image_hidden_states[i] = image_hidden_states[i][
                        :n_inventory_tokens
                    ]

            # replace inventory tokens with bbox features; match dtype/device so
            # the indexed assignment is valid (e.g. for a half-precision text model)
            inputs_embeds[i, inventory_mask] = image_hidden_states[i].to(
                device=inputs_embeds.device, dtype=inputs_embeds.dtype
            )
        return inputs_embeds

    def process_inputs(
        self,
        batch_messages: Optional[list[list[dict]]] = None,  # list of list of messages (untokenized)
        batch_images: Optional[list[list]] = None,  # list of list of images (unprocessed)
    ) -> tuple[dict[str, torch.FloatTensor], list[torch.FloatTensor], int]:
        """
        Converts raw images and messages into model inputs.

        Returns:
            ``(batch, image_hidden_states, total_boxes)``: the tokenized batch,
            the projected box features per conversation (an empty list when a
            conversation has no images), and the total number of boxes.
        """
        if batch_messages is None:
            batch_messages = []
        if batch_images is None:
            batch_images = []
        assert len(batch_images) == len(
            batch_messages
        ), "Number of images and messages should match in the batch dim"

        texts_batch = []
        image_hidden_states = []
        total_boxes = 0
        for images, messages in zip(batch_images, batch_messages):
            # process images
            bboxes = self.extract_bboxes(images)
            if len(bboxes) > 0:
                # get bbox features
                features = torch.concat([p["features"] for p in bboxes], dim=0)
                # upscale to text embedding size
                features_embeds = self.vision_to_text_embedding(features)
                image_hidden_states.append(features_embeds)
                # count bboxes total
                total_boxes += features.shape[0]
            else:
                image_hidden_states.append([])

            # process messages
            text = self.prepare_messages(messages, bboxes)
            texts_batch.append(text)

        # tokenize text
        # @NOTE: truncation can leave fewer inventory tokens than boxes; in that
        # case inputs_merger drops the box features on the truncated side and
        # issues a warning
        batch = self.tokenizer(
            texts_batch,
            truncation=True,
            padding=True,
            max_length=16384,
            return_tensors="pt",
        )
        return batch, image_hidden_states, total_boxes

    def forward(
        self,
        batch_messages: Optional[list[list[dict]]] = None,  # list of list of messages (untokenized)
        batch_images: Optional[list[list]] = None,  # list of list of images (unprocessed)
        **kwargs,
    ):
        """Training forward pass; returns the text model output including the loss.

        Labels are the input ids, with inventory-token and padding positions
        set to -100 so they do not contribute to the loss.
        """
        batch, image_hidden_states, total_boxes = self.process_inputs(
            batch_messages, batch_images
        )
        # move to cuda
        batch = {k: v.cuda() for k, v in batch.items()}
        attention_mask = batch["attention_mask"]
        input_ids = batch["input_ids"]

        inventory_mask = input_ids == self.inventory_idx
        # sanity check: should have same number of boxes as inventory tokens
        assert inventory_mask.sum() == total_boxes

        labels = input_ids.clone()
        # remove inventory tokens from labels
        labels[inventory_mask] = -100
        # padding must not be trained on; pad == eos here, so padding is found
        # via the attention mask rather than by token id
        labels[attention_mask == 0] = -100

        # get text embeddings
        inputs_embeds = self.text_model.get_input_embeddings()(input_ids)
        inputs_embeds = self.inputs_merger(
            input_ids, inputs_embeds, image_hidden_states
        )
        # forward pass
        return self.text_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            labels=labels,
            return_dict=True,
        )

    @torch.no_grad()
    def generate(
        self,
        batch_messages: list[list[dict]],
        batch_images: list[list],
        do_sample=True,
        temperature=0.6,
        max_new_tokens=32,
    ):
        """Generate a response for each conversation in the batch.

        Returns:
            ``(text_responses, total_tokens_used)``: the decoded responses with
            ``<|eot_id|>`` removed, and the length (dim 1) of the generated
            sequences.
        """
        # tokenize with left padding for batched generation, then restore the
        # previous setting so later forward() calls are unaffected
        previous_padding_side = self.tokenizer.padding_side
        self.tokenizer.padding_side = "left"
        try:
            batch, image_hidden_states, _ = self.process_inputs(
                batch_messages, batch_images
            )
        finally:
            self.tokenizer.padding_side = previous_padding_side

        batch = {k: v.cuda() for k, v in batch.items()}
        attention_mask = batch["attention_mask"]
        input_ids = batch["input_ids"]

        inputs_embeds = self.text_model.get_input_embeddings()(input_ids)
        inputs_embeds = self.inputs_merger(
            input_ids, inputs_embeds, image_hidden_states
        )

        generated_sequences = self.text_model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            do_sample=do_sample,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            pad_token_id=self.tokenizer.eos_token_id,
        )

        # Decode the output. Presumably only the new tokens are returned, since
        # no input_ids are passed to generate(). Confirm for the transformers
        # version in use.
        text_responses = self.tokenizer.batch_decode(
            generated_sequences,
            skip_special_tokens=False,
        )

        # remove <|eot_id|> tokens
        text_responses = [
            text_response.replace("<|eot_id|>", "") for text_response in text_responses
        ]

        _, total_tokens_used = generated_sequences.shape
        return text_responses, total_tokens_used
|
models/oracle.py
DELETED
@@ -1,265 +0,0 @@
|
|
1
|
-
import copy
|
2
|
-
from collections import Counter
|
3
|
-
|
4
|
-
from plancraft.config import EvalConfig
|
5
|
-
from plancraft.environments.actions import (
|
6
|
-
RealActionInteraction,
|
7
|
-
StopAction,
|
8
|
-
SymbolicMoveAction,
|
9
|
-
SymbolicSmeltAction,
|
10
|
-
)
|
11
|
-
from plancraft.environments.planner import optimal_planner
|
12
|
-
from plancraft.environments.recipes import (
|
13
|
-
ShapedRecipe,
|
14
|
-
ShapelessRecipe,
|
15
|
-
SmeltingRecipe,
|
16
|
-
id_to_item,
|
17
|
-
)
|
18
|
-
from plancraft.environments.sampler import MAX_STACK_SIZE
|
19
|
-
from plancraft.models.base import ABCModel, History
|
20
|
-
|
21
|
-
|
22
|
-
def item_set_id_to_type(item_set_ids: set[int]):
    """Translate item ids into the set of their item types via ``id_to_item``.

    Ids that map to the same type collapse into a single entry.
    """
    return {id_to_item(item_id) for item_id in item_set_ids}
|
24
|
-
|
25
|
-
|
26
|
-
def find_free_inventory_slot(inventory: list[dict], from_slot: int) -> int:
    """Find a slot that can receive the whole stack currently in ``from_slot``.

    Inventory entries are dicts with ``"type"`` and ``"quantity"`` plus a
    position under ``"slot"`` or ``"index"``. An entry with quantity 0 counts
    as air.

    Preference order:
      1. another slot of the same item type whose stack stays within
         ``MAX_STACK_SIZE`` after merging;
      2. an air slot whose number is greater than 10.

    Raises:
        AssertionError: if no entry occupies ``from_slot``.
        ValueError: if no suitable slot exists (including when there are no
            air slots at all).
    """
    from_item_type, from_item_quantity = None, None

    type_to_slot: dict[str, list[int]] = {}
    slot_to_quantity: dict[int, int] = {}
    for item in inventory:
        # remember what is being moved out of from_slot
        if ("slot" in item and item["slot"] == from_slot) or (
            "index" in item and item["index"] == from_slot
        ):
            from_item_quantity = item["quantity"]
            from_item_type = item["type"]

        quantity = item["quantity"]
        item_type = "air" if quantity == 0 else item["type"]
        item_slot = item["slot"] if "slot" in item else item["index"]

        type_to_slot.setdefault(item_type, []).append(item_slot)
        slot_to_quantity[item_slot] = slot_to_quantity.get(item_slot, 0) + quantity

    assert from_item_type is not None, f"Item not found in slot {from_slot}"

    # if there is a slot with the same item type that can absorb the stack
    for slot in type_to_slot.get(from_item_type, []):
        if (
            slot != from_slot
            and slot_to_quantity[slot] + from_item_quantity
            <= MAX_STACK_SIZE[from_item_type]
        ):
            return slot

    # if there is a free slot with air (slots <= 10 are never chosen here);
    # .get avoids a KeyError when the inventory has no air slot at all
    for slot in type_to_slot.get("air", []):
        if slot != from_slot and slot > 10:
            return slot

    raise ValueError("No free slot found")
|
77
|
-
|
78
|
-
|
79
|
-
def find_item_in_inventory(target: str, inventory: list[dict]) -> int:
    """Return the position of the first non-empty stack of ``target``.

    The position is read from the entry's ``"slot"`` key, or else from
    ``"index"``.

    Raises:
        ValueError: if the matching entry has neither ``"slot"`` nor
            ``"index"``, or if no non-empty stack of ``target`` exists (rather
            than silently returning ``None``).
    """
    for item in inventory:
        if item["type"] == target and item["quantity"] > 0:
            if "slot" in item:
                return item["slot"]
            elif "index" in item:
                return item["index"]
            raise ValueError("Neither slot or index is set")
    raise ValueError(f"Item {target} not found in inventory")
|
87
|
-
|
88
|
-
|
89
|
-
def get_inventory_counter(inventory: list[dict]) -> Counter:
    """Sum item quantities per type.

    The crafting output position (``slot``/``index`` 0) and ``"air"`` entries
    are skipped. Zero-quantity entries of other types still create a key with
    count 0.
    """
    totals = Counter()
    for entry in inventory:
        in_output_slot = entry.get("slot") == 0 or entry.get("index") == 0
        if in_output_slot or entry["type"] == "air":
            continue
        totals[entry["type"]] += entry["quantity"]
    return totals
|
100
|
-
|
101
|
-
|
102
|
-
def get_crafting_slot_item(inventory: list[dict]) -> dict:
    """Return the first non-empty entry at position 0 (``slot`` or ``index``), else None."""
    return next(
        (
            entry
            for entry in inventory
            if (entry.get("slot") == 0 or entry.get("index") == 0)
            and entry["quantity"] > 0
        ),
        None,
    )
|
109
|
-
|
110
|
-
|
111
|
-
def update_inventory(
    inventory: list[dict], slot_from: int, slot_to: int, quantity: int
) -> list[dict]:
    """Decrement the quantity held at ``slot_from`` by ``quantity``.

    Matching entries (by ``"slot"``, or else by ``"index"``) are modified in
    place. A new list containing the same entry objects is returned.
    ``slot_to`` is accepted but deliberately left untouched.
    """
    entries = list(inventory)
    for entry in entries:
        matches_slot = "slot" in entry and entry["slot"] == slot_from
        matches_index = "index" in entry and entry["index"] == slot_from
        if matches_slot or matches_index:
            entry["quantity"] -= quantity
    return entries
|
127
|
-
|
128
|
-
|
129
|
-
class OracleModel(ABCModel):
    """
    Oracle model returns actions that solve the task optimally.

    A high-level plan is obtained from ``optimal_planner`` as a list of
    ``(recipe, inventory)`` pairs. Each recipe is expanded on demand into a
    queue of low-level symbolic actions (``self.subplans``), which are emitted
    one per ``step`` call. Only the symbolic action space is supported.
    """

    def __init__(self, cfg: EvalConfig):
        assert (
            cfg.plancraft.environment.symbolic_action_space
        ), "Only symbolic actions are supported for oracle"
        self.history = History(objective="")
        # high-level plan: (recipe, inventory) pairs still to execute
        self.plans = []
        # low-level actions still to emit for the most recently popped recipe
        self.subplans = []

    def reset_history(self, objective: str = ""):
        """Reset the history to ``objective`` and discard any pending plan."""
        self.history.reset(objective=objective)
        self.plans = []
        self.subplans = []

    def get_plan(self, observation: dict):
        """Compute the high-level plan for the current objective and inventory.

        Stores the ``optimal_planner`` result in ``self.plans``. ``step``
        treats a ``None`` result as "stop".
        """
        # objective="Craft an item of type: ...."
        # this simply recovers the target item to craft
        target = self.history.objective.split(": ")[-1]
        inventory_counter = get_inventory_counter(observation["inventory"])
        self.plans = optimal_planner(target=target, inventory=inventory_counter)

    def get_next_action(
        self, observation: dict
    ) -> SymbolicMoveAction | SymbolicSmeltAction:
        """Return the next low-level action.

        Pending subplan actions are emitted first. Otherwise, any item in the
        crafting output slot (0) is moved to a free inventory slot. Failing
        that, the next recipe is popped from ``self.plans``, expanded into
        ``self.subplans``, and the first resulting action is returned.

        Raises:
            ValueError: if there are no subplans and no plan steps left.
            NotImplementedError: for an unsupported recipe type.
        """
        if len(self.subplans) > 0:
            return self.subplans.pop(0)
        # NOTE(review): this check runs before the crafting-output pickup
        # below, so with an empty plan a crafted item is not collected here.
        # step() re-plans whenever self.plans is empty before calling this;
        # confirm that coupling is intended.
        if len(self.plans) == 0:
            raise ValueError("No more steps in plan")

        # simulate on a copy so the observation itself is never mutated
        observed_inventory = copy.deepcopy(observation["inventory"])

        # take item from crafting slot
        if slot_item := get_crafting_slot_item(observed_inventory):
            # move item from crafting slot to inventory
            free_slot = find_free_inventory_slot(observed_inventory, from_slot=0)
            return SymbolicMoveAction(
                slot_from=0, slot_to=free_slot, quantity=slot_item["quantity"]
            )

        plan_recipe, new_inventory = self.plans.pop(0)
        self.subplans = []
        new_inventory_counter = Counter(new_inventory)
        current_inventory = observed_inventory
        current_inventory_counter = get_inventory_counter(current_inventory)
        # items consumed by this recipe (Counter subtraction keeps positive counts only)
        items_to_use_counter = current_inventory_counter - new_inventory_counter
        # items produced by this recipe; exactly one new item type is expected
        new_items = new_inventory_counter - current_inventory_counter
        assert len(new_items) == 1

        if isinstance(plan_recipe, ShapelessRecipe):
            # fill crafting slots 1, 2, 3, ... in order
            crafting_slot = 1

            # add each item to crafting slots
            for item, quantity in items_to_use_counter.items():
                n = 0
                while n < quantity:
                    from_slot = find_item_in_inventory(item, current_inventory)

                    # skip if from_slot is the crafting slot
                    # NOTE(review): this branch does not call update_inventory,
                    # so a later lookup may return this same slot again; confirm.
                    if from_slot == crafting_slot:
                        crafting_slot += 1
                        n += 1
                        continue

                    action = SymbolicMoveAction(
                        slot_from=from_slot, slot_to=crafting_slot, quantity=1
                    )
                    # update state of inventory
                    current_inventory = update_inventory(
                        current_inventory, from_slot, crafting_slot, 1
                    )
                    self.subplans.append(action)

                    crafting_slot += 1
                    n += 1

        # if plan_recipe is a smelting recipe
        elif isinstance(plan_recipe, SmeltingRecipe):
            assert len(items_to_use_counter) == 1, "smelting only supports one item"
            for item, quantity in items_to_use_counter.items():
                from_slot = find_item_in_inventory(item, current_inventory)
                free_slot = find_free_inventory_slot(
                    current_inventory, from_slot=from_slot
                )
                # one smelt action covering the whole consumed quantity
                action = SymbolicSmeltAction(
                    slot_from=from_slot, slot_to=free_slot, quantity=quantity
                )
                self.subplans.append(action)

        # if plan_recipe is a shaped recipe
        elif isinstance(plan_recipe, ShapedRecipe):
            # kernel cell (i, j) maps to crafting slot i * 3 + j + 1 (3-wide, row-major)
            for i, row in enumerate(plan_recipe.kernel):
                for j, item_set in enumerate(row):
                    inventory_position = (i * 3) + j + 1
                    # any of these item types is acceptable at this position
                    valid_items = item_set_id_to_type(item_set)
                    for item in valid_items:
                        if items_to_use_counter[item] > 0:
                            from_slot = find_item_in_inventory(item, current_inventory)
                            action = SymbolicMoveAction(
                                slot_from=from_slot,
                                slot_to=inventory_position,
                                quantity=1,
                            )
                            items_to_use_counter[item] -= 1
                            # update state of inventory
                            current_inventory = update_inventory(
                                current_inventory, from_slot, inventory_position, 1
                            )
                            self.subplans.append(action)
                            # place only one item per position
                            break
        else:
            raise NotImplementedError(f"Recipe type {type(plan_recipe)} not supported")

        # NOTE(review): raises IndexError if the recipe produced no actions
        return self.subplans.pop(0)

    def step(
        self, observation: dict
    ) -> SymbolicMoveAction | SymbolicSmeltAction | StopAction:
        """Record the observation and return a single action.

        Re-plans whenever ``self.plans`` is empty. Returns ``StopAction()``
        when the planner yields ``None``; in that case nothing is added to the
        action history.
        """
        # add observation to history
        self.history.add_observation_to_history(observation)

        # get action
        if len(self.plans) == 0:
            self.get_plan(observation)
            # planner returned None (presumably no plan exists): stop
            if self.plans is None:
                self.plans = []
                return StopAction()

        action = self.get_next_action(observation)

        # add action to history
        self.history.add_action_to_history(action)
        return action
|