plancraft-0.1.0-py3-none-any.whl
- environments/__init__.py +0 -0
- environments/actions.py +218 -0
- environments/env_real.py +315 -0
- environments/env_symbolic.py +215 -0
- environments/items.py +10 -0
- environments/planner.py +109 -0
- environments/recipes.py +542 -0
- environments/sampler.py +224 -0
- models/__init__.py +21 -0
- models/act.py +184 -0
- models/base.py +152 -0
- models/bbox_model.py +492 -0
- models/dummy.py +54 -0
- models/few_shot_images/__init__.py +16 -0
- models/generators.py +483 -0
- models/oam.py +284 -0
- models/oracle.py +268 -0
- models/prompts.py +158 -0
- models/react.py +98 -0
- models/utils.py +289 -0
- plancraft-0.1.0.dist-info/LICENSE +21 -0
- plancraft-0.1.0.dist-info/METADATA +53 -0
- plancraft-0.1.0.dist-info/RECORD +26 -0
- plancraft-0.1.0.dist-info/WHEEL +5 -0
- plancraft-0.1.0.dist-info/top_level.txt +3 -0
- train/dataset.py +187 -0
models/bbox_model.py
ADDED
@@ -0,0 +1,492 @@
import einops
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.v2 as v2
from huggingface_hub import PyTorchModelHubMixin
from plancraft.environments.items import ALL_ITEMS
from torchvision.models.detection.faster_rcnn import (
    fasterrcnn_resnet50_fpn_v2,
    ResNet50_Weights,
)
from torchvision.models.detection.roi_heads import (
    fastrcnn_loss,
    keypointrcnn_inference,
    keypointrcnn_loss,
    maskrcnn_inference,
    maskrcnn_loss,
)
from torchvision.ops import boxes as box_ops


def slot_to_bbox(slot: int):
    # crafting slot
    if slot == 0:
        # slot size: 25x25
        # top left corner: (x=118, y=30)
        box_size = 25
        left_x = 117
        top_y = 29
    # crafting grid
    elif slot < 10:
        # slot size: 18x18
        # top left corner: (x = 28 + 18 * col, y = 16 + 18 * row)
        box_size = 18
        row = (slot - 1) // 3
        col = (slot - 1) % 3
        left_x = 27 + (box_size * col)
        top_y = 15 + (box_size * row)
    # inventory
    elif slot < 37:
        # slot size: 18x18
        # top left corner: (x = 6 + 18 * col, y = 83 + 18 * row)
        box_size = 18
        row = (slot - 10) // 9
        col = (slot - 10) % 9
        left_x = 5 + (box_size * col)
        top_y = 82 + (box_size * row)
    # hotbar
    else:
        # slot size: 18x18
        # top left corner: (x = 6 + 18 * col, y = 141)
        box_size = 18
        col = (slot - 37) % 9
        left_x = 5 + (box_size * col)
        top_y = 140
    return [left_x, top_y, left_x + box_size, top_y + box_size]


def precompute_slot_bboxes():
    """Precompute the bounding boxes for all slots."""
    slot_bboxes = {}
    for slot in range(46):  # slot indices range from 0 to 45
        slot_bboxes[slot] = slot_to_bbox(slot)
    return slot_bboxes


# Precompute all slot bounding boxes
IDX_TO_BBOX = precompute_slot_bboxes()


def postprocess_detections_custom(
    self,
    class_logits,
    quantity_logits,
    box_features,
    box_regression,
    proposals,
    image_shapes,
):
    device = class_logits.device
    num_classes = class_logits.shape[-1]

    boxes_per_image = [boxes_in_image.shape[0] for boxes_in_image in proposals]
    pred_boxes = self.box_coder.decode(box_regression, proposals)

    pred_scores = F.softmax(class_logits, -1)

    pred_quantity = F.softmax(quantity_logits, -1).argmax(dim=-1)
    # repeat the quantities, once for each class
    pred_quantity = einops.repeat(
        pred_quantity, "n -> n c", c=num_classes, n=pred_quantity.shape[0]
    )

    pred_boxes_list = pred_boxes.split(boxes_per_image, 0)
    pred_scores_list = pred_scores.split(boxes_per_image, 0)
    pred_quantity_list = pred_quantity.split(boxes_per_image, 0)
    pred_features_list = box_features.split(boxes_per_image, 0)

    all_boxes = []
    all_scores = []
    all_labels = []
    all_quantity_labels = []
    all_features = []

    for boxes, scores, quantities, features, image_shape in zip(
        pred_boxes_list,
        pred_scores_list,
        pred_quantity_list,
        pred_features_list,
        image_shapes,
    ):
        boxes = box_ops.clip_boxes_to_image(boxes, image_shape)

        # create labels for each prediction
        labels = torch.arange(num_classes, device=device)
        labels = labels.view(1, -1).expand_as(scores)

        box_idxs = (
            torch.arange(boxes.size(0), device=device).view(-1, 1).expand_as(labels)
        )

        # remove predictions with the background label
        boxes = boxes[:, 1:]
        scores = scores[:, 1:]
        labels = labels[:, 1:]
        quantities = quantities[:, 1:]
        box_idxs = box_idxs[:, 1:]

        # batch everything, by making every class prediction be a separate instance
        boxes = boxes.reshape(-1, 4)
        scores = scores.reshape(-1)
        labels = labels.reshape(-1)
        quantities = quantities.reshape(-1)
        box_idxs = box_idxs.reshape(-1)

        # remove low scoring boxes
        inds = torch.where(scores > self.score_thresh)[0]
        boxes, scores, labels, quantities, box_idxs = (
            boxes[inds],
            scores[inds],
            labels[inds],
            quantities[inds],
            box_idxs[inds],
        )

        # remove empty boxes
        keep = box_ops.remove_small_boxes(boxes, min_size=1e-2)
        boxes, scores, labels, quantities, box_idxs = (
            boxes[keep],
            scores[keep],
            labels[keep],
            quantities[keep],
            box_idxs[keep],
        )

        # non-maximum suppression, done independently per class
        keep = box_ops.batched_nms(boxes, scores, labels, self.nms_thresh)
        # keep only topk scoring predictions
        keep = keep[: self.detections_per_img]
        boxes, scores, labels, quantities, box_idxs = (
            boxes[keep],
            scores[keep],
            labels[keep],
            quantities[keep],
            box_idxs[keep],
        )

        all_boxes.append(boxes)
        all_scores.append(scores)
        all_labels.append(labels)
        all_quantity_labels.append(quantities)
        all_features.append(features[box_idxs])

    return all_boxes, all_scores, all_labels, all_quantity_labels, all_features


def forward_custom(
    self,
    features,
    proposals,
    image_shapes,
    targets=None,
):
    training = False
    if targets is not None:
        training = True
        for t in targets:
            floating_point_types = (torch.float, torch.double, torch.half)
            if t["boxes"].dtype not in floating_point_types:
                raise TypeError(
                    f"target boxes must be of float type, instead got {t['boxes'].dtype}"
                )
            if not t["labels"].dtype == torch.int64:
                raise TypeError(
                    f"target labels must be of int64 type, instead got {t['labels'].dtype}"
                )
            if self.has_keypoint():
                if not t["keypoints"].dtype == torch.float32:
                    raise TypeError(
                        f"target keypoints must be of float type, instead got {t['keypoints'].dtype}"
                    )

    if training:
        proposals, matched_idxs, labels, regression_targets = (
            self.select_training_samples(proposals, targets)
        )
    else:
        labels = None
        regression_targets = None
        matched_idxs = None

    box_features = self.box_roi_pool(features, proposals, image_shapes)
    box_features = self.box_head(box_features)
    class_logits, box_regression = self.box_predictor(box_features)

    result = []
    losses = {}
    if training:
        if labels is None:
            raise ValueError("labels cannot be None")
        if regression_targets is None:
            raise ValueError("regression_targets cannot be None")
        loss_classifier, loss_box_reg = fastrcnn_loss(
            class_logits, box_regression, labels, regression_targets
        )

        # custom addition to calculate quantity loss
        dtype = proposals[0].dtype
        gt_boxes = [t["boxes"].to(dtype) for t in targets]
        gt_labels = [t["quantity_labels"] for t in targets]
        _, quantity_labels = self.assign_targets_to_proposals(
            proposals, gt_boxes, gt_labels
        )
        quantity_labels = torch.cat(quantity_labels, dim=0)
        # needs the quantity_prediction layer to be added to the roi_heads
        quantity_preds = self.quantity_prediction(box_features)
        loss_classifier_quantity = F.cross_entropy(
            quantity_preds,
            quantity_labels,
        )
        losses = {
            "loss_classifier": loss_classifier,
            "loss_box_reg": loss_box_reg,
            "loss_classifier_quantity": loss_classifier_quantity,
        }
    else:
        quantity_logits = self.quantity_prediction(box_features)

        boxes, scores, labels, quantities, features = postprocess_detections_custom(
            self,
            class_logits,
            quantity_logits,
            box_features,
            box_regression,
            proposals,
            image_shapes,
        )
        num_images = len(boxes)
        for i in range(num_images):
            result.append(
                {
                    "boxes": boxes[i],
                    "labels": labels[i],
                    "scores": scores[i],
                    "quantities": quantities[i],
                    "features": features[i],
                }
            )

    if self.has_mask():
        mask_proposals = [p["boxes"] for p in result]
        if training:
            if matched_idxs is None:
                raise ValueError("if in training, matched_idxs should not be None")

            # during training, only focus on positive boxes
            num_images = len(proposals)
            mask_proposals = []
            pos_matched_idxs = []
            for img_id in range(num_images):
                pos = torch.where(labels[img_id] > 0)[0]
                mask_proposals.append(proposals[img_id][pos])
                pos_matched_idxs.append(matched_idxs[img_id][pos])
        else:
            pos_matched_idxs = None

        if self.mask_roi_pool is not None:
            mask_features = self.mask_roi_pool(features, mask_proposals, image_shapes)
            mask_features = self.mask_head(mask_features)
            mask_logits = self.mask_predictor(mask_features)
        else:
            raise Exception("Expected mask_roi_pool to be not None")

        loss_mask = {}
        if training:
            if targets is None or pos_matched_idxs is None or mask_logits is None:
                raise ValueError(
                    "targets, pos_matched_idxs, mask_logits cannot be None when training"
                )

            gt_masks = [t["masks"] for t in targets]
            gt_labels = [t["labels"] for t in targets]
            rcnn_loss_mask = maskrcnn_loss(
                mask_logits, mask_proposals, gt_masks, gt_labels, pos_matched_idxs
            )
            loss_mask = {"loss_mask": rcnn_loss_mask}
        else:
            labels = [r["labels"] for r in result]
            masks_probs = maskrcnn_inference(mask_logits, labels)
            for mask_prob, r in zip(masks_probs, result):
                r["masks"] = mask_prob

        losses.update(loss_mask)

    # keep none checks in if conditional so torchscript will conditionally
    # compile each branch
    if (
        self.keypoint_roi_pool is not None
        and self.keypoint_head is not None
        and self.keypoint_predictor is not None
    ):
        keypoint_proposals = [p["boxes"] for p in result]
        if training:
            # during training, only focus on positive boxes
            num_images = len(proposals)
            keypoint_proposals = []
            pos_matched_idxs = []
            if matched_idxs is None:
                raise ValueError("if in training, matched_idxs should not be None")

            for img_id in range(num_images):
                pos = torch.where(labels[img_id] > 0)[0]
                keypoint_proposals.append(proposals[img_id][pos])
                pos_matched_idxs.append(matched_idxs[img_id][pos])
        else:
            pos_matched_idxs = None

        keypoint_features = self.keypoint_roi_pool(
            features, keypoint_proposals, image_shapes
        )
        keypoint_features = self.keypoint_head(keypoint_features)
        keypoint_logits = self.keypoint_predictor(keypoint_features)

        loss_keypoint = {}
        if training:
            if targets is None or pos_matched_idxs is None:
                raise ValueError(
                    "both targets and pos_matched_idxs should not be None when in training mode"
                )

            gt_keypoints = [t["keypoints"] for t in targets]
            rcnn_loss_keypoint = keypointrcnn_loss(
                keypoint_logits, keypoint_proposals, gt_keypoints, pos_matched_idxs
            )
            loss_keypoint = {"loss_keypoint": rcnn_loss_keypoint}
        else:
            if keypoint_logits is None or keypoint_proposals is None:
                raise ValueError(
                    "both keypoint_logits and keypoint_proposals should not be None when not in training mode"
                )

            keypoints_probs, kp_scores = keypointrcnn_inference(
                keypoint_logits, keypoint_proposals
            )
            for keypoint_prob, kps, r in zip(keypoints_probs, kp_scores, result):
                r["keypoints"] = keypoint_prob
                r["keypoints_scores"] = kps
        losses.update(loss_keypoint)

    return result, losses


def calculate_iou(boxA, boxB):
    """Calculate Intersection over Union (IoU) between two bounding boxes."""
    # Determine the coordinates of the intersection rectangle
    xA = max(boxA[0], boxB[0])
    yA = max(boxA[1], boxB[1])
    xB = min(boxA[2], boxB[2])
    yB = min(boxA[3], boxB[3])

    # Compute the area of intersection
    interArea = max(0, xB - xA) * max(0, yB - yA)

    # Compute the area of both bounding boxes
    boxAArea = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])
    boxBArea = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])

    # Compute the IoU
    iou = interArea / float(boxAArea + boxBArea - interArea)

    return iou


def bbox_to_slot_index_iou(bbox: tuple[int, int, int, int]) -> int:
    """Assign the given bounding box to the slot with the highest IoU."""
    best_slot = None
    best_iou = -1
    # Iterate through all precomputed slot bounding boxes
    for slot, slot_bbox in IDX_TO_BBOX.items():
        iou = calculate_iou(bbox, slot_bbox)
        if iou > best_iou:
            best_iou = iou
            best_slot = slot
    return best_slot


class IntegratedBoundingBoxModel(nn.Module, PyTorchModelHubMixin):
    """
    Custom Faster R-CNN based model with quantity prediction.

    Also returns the feature vectors of the detected boxes.
    """

    def __init__(self, load_resnet_weights=False):
        super(IntegratedBoundingBoxModel, self).__init__()
        weights = None
        if load_resnet_weights:
            weights = ResNet50_Weights.DEFAULT

        self.model = fasterrcnn_resnet50_fpn_v2(
            weights_backbone=weights,
            image_mean=[0.63, 0.63, 0.63],
            image_std=[0.21, 0.21, 0.21],
            min_size=128,
            max_size=256,
            num_classes=len(ALL_ITEMS),
            box_score_thresh=0.05,
            rpn_batch_size_per_image=64,
            box_detections_per_img=64,
            box_batch_size_per_image=128,
        )
        self.model.roi_heads.quantity_prediction = nn.Linear(1024, 65)

        # replace the roi_heads forward with the custom forward defined above
        self.model.roi_heads.forward = forward_custom.__get__(
            self.model.roi_heads, type(self.model.roi_heads)
        )

        self.transform = v2.Compose(
            [v2.ToImage(), v2.ToDtype(torch.float32, scale=True)]
        )

    def forward(self, x, targets=None):
        if self.training:
            # normal forward pass
            loss_dict = self.model(x, targets)
            return loss_dict
        else:
            preds = self.model(x)
            return preds

    def get_inventory(self, pil_image):
        """
        Predict boxes and quantities
        """
        img_tensor = self.transform(pil_image)
        if next(self.model.parameters()).is_cuda:
            img_tensor = img_tensor.cuda()
        with torch.no_grad():
            predictions = self.model(img_tensor.unsqueeze(0))
        return self.prediction_to_inventory(predictions[0])

    @staticmethod
    def prediction_to_inventory(prediction, threshold=0.9) -> list[dict]:
        inventory = []
        seen_slots = set()
        for i in range(len(prediction["boxes"])):
            slot = bbox_to_slot_index_iou(prediction["boxes"][i])
            score = prediction["scores"][i]
            label_idx = prediction["labels"][i].item()
            label = ALL_ITEMS[label_idx]
            quantity = prediction["quantities"][i].item()
            if score > threshold:
                if slot in seen_slots:
                    continue
                inventory.append({"slot": slot, "type": label, "quantity": quantity})
                # record the slot so lower-scoring detections cannot claim it again
                seen_slots.add(slot)
        return inventory

    def freeze(self):
        # NOTE: this might seem excessive,
        # but the transformers trainer is very good at re-enabling gradients
        self.eval()
        self.model.eval()
        self.training = False
        for param in self.model.parameters():
            param.requires_grad = False
        self.model.training = False
        self.model.roi_heads.training = False
        self.model.rpn.training = False

    def save(self, path: str):
        torch.save(self.state_dict(), path)
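Since IntegratedBoundingBoxModel mixes in PyTorchModelHubMixin, a trained checkpoint pushed to the Hugging Face Hub can be restored with from_pretrained. A minimal inference sketch follows; the repository id and screenshot path are placeholders, not names defined by this package:

from PIL import Image

from plancraft.models.bbox_model import IntegratedBoundingBoxModel

# "example-user/plancraft-bbox" is a hypothetical repo id, shown only for illustration
model = IntegratedBoundingBoxModel.from_pretrained("example-user/plancraft-bbox")
model.freeze()  # eval mode with gradients disabled

image = Image.open("crafting_screenshot.png").convert("RGB")  # placeholder path
inventory = model.get_inventory(image)
# each entry has the form {"slot": <int>, "type": <item name>, "quantity": <int>}
for entry in inventory:
    print(entry["slot"], entry["type"], entry["quantity"])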
models/dummy.py
ADDED
@@ -0,0 +1,54 @@
import random

from plancraft.config import EvalConfig
from plancraft.environments.actions import (
    RealActionInteraction,
    SymbolicMoveAction,
    SymbolicSmeltAction,
)
from plancraft.models.base import ABCModel, History


class DummyModel(ABCModel):
    """
    Dummy model that returns random actions.
    """

    def __init__(self, cfg: EvalConfig):
        self.symbolic_move_action = cfg.plancraft.environment.symbolic_action_space
        self.history = History(objective="")

    def random_select(self, observation):
        if observation is None or "inventory" not in observation:
            return SymbolicMoveAction(slot_from=0, slot_to=0, quantity=1)
        # randomly pick an item from the inventory
        item_indices = set()
        for item in observation["inventory"]:
            if item["quantity"] > 0:
                item_indices.add(item["index"])
        all_slots_to = set(range(1, 46))
        empty_slots = all_slots_to - item_indices

        random_slot_from = random.choice(list(item_indices))
        random_slot_to = random.choice(list(empty_slots))

        return SymbolicMoveAction(
            slot_from=random_slot_from, slot_to=random_slot_to, quantity=1
        )

    def step(
        self, observation: dict
    ) -> SymbolicMoveAction | RealActionInteraction | SymbolicSmeltAction:
        # add observation to history
        self.history.add_observation_to_history(observation)

        # get action
        if self.symbolic_move_action:
            action = self.random_select(observation)
        else:
            action = RealActionInteraction()

        # add action to history
        self.history.add_action_to_history(action)

        return action
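For reference, DummyModel.step expects the observation dict used by the plancraft environments: an "inventory" list whose entries carry at least "index" and "quantity" keys, with slots 1-45 treated as valid move targets. A minimal sketch of driving the model by hand, assuming cfg is an already-loaded EvalConfig with the symbolic action space enabled; the hand-built observation only illustrates the expected schema:

from plancraft.models.dummy import DummyModel

model = DummyModel(cfg)  # cfg: EvalConfig, assumed to come from the usual config loader

observation = {
    "inventory": [
        {"index": 10, "type": "iron_ingot", "quantity": 3},
        {"index": 12, "type": "stick", "quantity": 2},
    ]
}
action = model.step(observation)
# e.g. SymbolicMoveAction(slot_from=10, slot_to=33, quantity=1)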
models/few_shot_images/__init__.py
ADDED
@@ -0,0 +1,16 @@
import os
import glob

import numpy as np
import imageio


def get_few_shot_images_path():
    return os.path.dirname(__file__)


def load_prompt_images() -> list[np.ndarray]:
    current_dir = get_few_shot_images_path()
    files = glob.glob(os.path.join(current_dir, "*.png"))
    images = [imageio.imread(file) for file in files]
    return images
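load_prompt_images simply globs whatever .png files ship alongside this module and reads them with imageio. A short usage sketch; the shape in the comment is illustrative, not guaranteed:

from plancraft.models.few_shot_images import load_prompt_images

images = load_prompt_images()
print(len(images))      # number of bundled few-shot .png images
print(images[0].shape)  # e.g. (height, width, 3) uint8 array from imageio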