plancraft 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- environments/__init__.py +0 -0
- environments/actions.py +218 -0
- environments/env_real.py +315 -0
- environments/env_symbolic.py +215 -0
- environments/items.py +10 -0
- environments/planner.py +109 -0
- environments/recipes.py +542 -0
- environments/sampler.py +224 -0
- models/__init__.py +21 -0
- models/act.py +184 -0
- models/base.py +152 -0
- models/bbox_model.py +492 -0
- models/dummy.py +54 -0
- models/few_shot_images/__init__.py +16 -0
- models/generators.py +483 -0
- models/oam.py +284 -0
- models/oracle.py +268 -0
- models/prompts.py +158 -0
- models/react.py +98 -0
- models/utils.py +289 -0
- plancraft-0.1.0.dist-info/LICENSE +21 -0
- plancraft-0.1.0.dist-info/METADATA +53 -0
- plancraft-0.1.0.dist-info/RECORD +26 -0
- plancraft-0.1.0.dist-info/WHEEL +5 -0
- plancraft-0.1.0.dist-info/top_level.txt +3 -0
- train/dataset.py +187 -0
models/bbox_model.py
ADDED
@@ -0,0 +1,492 @@
|
|
1
|
+
import einops
|
2
|
+
import torch
|
3
|
+
import torch.nn as nn
|
4
|
+
import torch.nn.functional as F
|
5
|
+
import torchvision.transforms.v2 as v2
|
6
|
+
from huggingface_hub import PyTorchModelHubMixin
|
7
|
+
from plancraft.environments.items import ALL_ITEMS
|
8
|
+
from torchvision.models.detection.faster_rcnn import (
|
9
|
+
fasterrcnn_resnet50_fpn_v2,
|
10
|
+
ResNet50_Weights,
|
11
|
+
)
|
12
|
+
from torchvision.models.detection.roi_heads import (
|
13
|
+
fastrcnn_loss,
|
14
|
+
keypointrcnn_inference,
|
15
|
+
keypointrcnn_loss,
|
16
|
+
maskrcnn_inference,
|
17
|
+
maskrcnn_loss,
|
18
|
+
)
|
19
|
+
from torchvision.ops import boxes as box_ops
|
20
|
+
|
21
|
+
|
22
|
+
def slot_to_bbox(slot: int):
    """Return the pixel box ``[x1, y1, x2, y2]`` of an inventory slot index.

    Layout used by the code:

    * slot 0 (the crafting slot): one 25x25 box whose top-left is (117, 29);
    * slots 1-9 (the 3x3 crafting grid): 18x18 boxes starting at (27, 15);
    * slots 10-36 (inventory, rows of 9): 18x18 boxes starting at (5, 82);
    * slots 37+ (hotbar, one row of 9): 18x18 boxes starting at (5, 140).
    """
    if slot == 0:
        size, x, y = 25, 117, 29
    elif slot < 10:
        size = 18
        row, col = divmod(slot - 1, 3)
        x, y = 27 + size * col, 15 + size * row
    elif slot < 37:
        size = 18
        row, col = divmod(slot - 10, 9)
        x, y = 5 + size * col, 82 + size * row
    else:
        size = 18
        x, y = 5 + size * ((slot - 37) % 9), 140
    return [x, y, x + size, y + size]
|
57
|
+
|
58
|
+
|
59
|
+
def precompute_slot_bboxes():
    """Build a dict mapping every slot index 0-45 to ``slot_to_bbox(slot)``."""
    return {slot: slot_to_bbox(slot) for slot in range(46)}


# Lookup table built once at import time: slot index -> [x1, y1, x2, y2].
IDX_TO_BBOX = precompute_slot_bboxes()
|
69
|
+
|
70
|
+
|
71
|
+
def postprocess_detections_custom(
    self,
    class_logits,
    quantity_logits,
    box_features,
    box_regression,
    proposals,
    image_shapes,
):
    """Filter raw ROI-head outputs into per-image detections with quantities.

    Mirrors torchvision's ``RoIHeads.postprocess_detections``, with ``self``
    being the ROI-heads module. The steps are:

    1. Decode and clip the boxes.
    2. Treat every (proposal, non-background class) pair as a candidate.
    3. Drop candidates scoring at or below ``self.score_thresh``.
    4. Drop near-empty boxes.
    5. Run per-class NMS with ``self.nms_thresh``.
    6. Keep at most ``self.detections_per_img``.

    In addition, every candidate carries the argmax quantity of its proposal
    and that proposal's row of ``box_features``.

    Returns:
        Five lists with one tensor per image: boxes, scores, labels,
        quantities and features.
    """
    device = class_logits.device
    num_classes = class_logits.shape[-1]
    per_image = [p.shape[0] for p in proposals]

    decoded_boxes = self.box_coder.decode(box_regression, proposals)
    class_probs = F.softmax(class_logits, -1)

    # One predicted quantity per proposal, repeated across the class axis so it
    # stays aligned with the per-class boxes/scores through the filtering below.
    proposal_qty = F.softmax(quantity_logits, -1).argmax(dim=-1)
    proposal_qty = einops.repeat(
        proposal_qty, "n -> n c", c=num_classes, n=proposal_qty.shape[0]
    )

    def _take(index, *tensors):
        # Apply the same index to every tensor in lockstep.
        return tuple(t[index] for t in tensors)

    out_boxes, out_scores, out_labels, out_quantities, out_features = (
        [],
        [],
        [],
        [],
        [],
    )
    for img_boxes, img_scores, img_qty, img_feats, img_shape in zip(
        decoded_boxes.split(per_image, 0),
        class_probs.split(per_image, 0),
        proposal_qty.split(per_image, 0),
        box_features.split(per_image, 0),
        image_shapes,
    ):
        img_boxes = box_ops.clip_boxes_to_image(img_boxes, img_shape)

        class_ids = torch.arange(num_classes, device=device)
        img_labels = class_ids.view(1, -1).expand_as(img_scores)
        proposal_ids = (
            torch.arange(img_boxes.size(0), device=device)
            .view(-1, 1)
            .expand_as(img_labels)
        )

        # Drop the background column (class 0), then flatten so every
        # (proposal, class) pair becomes an independent candidate.
        candidates = (
            img_boxes[:, 1:].reshape(-1, 4),
            img_scores[:, 1:].reshape(-1),
            img_labels[:, 1:].reshape(-1),
            img_qty[:, 1:].reshape(-1),
            proposal_ids[:, 1:].reshape(-1),
        )

        # low-scoring candidates
        candidates = _take(
            torch.where(candidates[1] > self.score_thresh)[0], *candidates
        )
        # near-empty boxes
        candidates = _take(
            box_ops.remove_small_boxes(candidates[0], min_size=1e-2), *candidates
        )
        # per-class NMS, then the top-k survivors
        kept = box_ops.batched_nms(
            candidates[0], candidates[1], candidates[2], self.nms_thresh
        )
        boxes, scores, labels, quantities, src_ids = _take(
            kept[: self.detections_per_img], *candidates
        )

        out_boxes.append(boxes)
        out_scores.append(scores)
        out_labels.append(labels)
        out_quantities.append(quantities)
        out_features.append(img_feats[src_ids])

    return out_boxes, out_scores, out_labels, out_quantities, out_features
|
175
|
+
|
176
|
+
|
177
|
+
def forward_custom(
    self,
    features,
    proposals,
    image_shapes,
    targets=None,
):
    """Replacement for torchvision's ``RoIHeads.forward`` that adds quantities.

    Intended to be bound onto a ``RoIHeads`` instance, so ``self`` is the
    ROI-heads module. It must have a ``quantity_prediction`` layer attached.
    Losses are computed whenever ``targets`` is given, regardless of
    ``self.training``.

    Differences from the stock forward:

    * When training, a ``"loss_classifier_quantity"`` cross-entropy term is
      added. It is computed from ``self.quantity_prediction`` against the
      targets' ``"quantity_labels"``.
    * At inference, detections come from ``postprocess_detections_custom``.
      Each per-image result dict also carries ``"quantities"`` and
      ``"features"`` (the box-head feature of each detection).

    Args:
        features: backbone feature maps consumed by the ROI pooling layers.
        proposals: per-image proposal boxes.
        image_shapes: per-image sizes used for ROI pooling and box clipping.
        targets: optional per-image target dicts; when given, losses are
            computed.

    Returns:
        ``(result, losses)``. ``result`` is a list of per-image detection
        dicts (empty when training). ``losses`` is a dict of loss tensors
        (empty at inference, unless mask/keypoint losses apply).

    Raises:
        TypeError: if target boxes, labels or keypoints have the wrong dtype.
        ValueError: if tensors required for training are missing.
    """
    training = targets is not None
    if training:
        floating_point_types = (torch.float, torch.double, torch.half)
        for t in targets:
            if t["boxes"].dtype not in floating_point_types:
                raise TypeError(
                    f"target boxes must of float type, instead got {t['boxes'].dtype}"
                )
            if not t["labels"].dtype == torch.int64:
                raise TypeError(
                    f"target labels must of int64 type, instead got {t['labels'].dtype}"
                )
            if self.has_keypoint():
                if not t["keypoints"].dtype == torch.float32:
                    raise TypeError(
                        f"target keypoints must of float type, instead got {t['keypoints'].dtype}"
                    )

        proposals, matched_idxs, labels, regression_targets = (
            self.select_training_samples(proposals, targets)
        )
    else:
        labels = None
        regression_targets = None
        matched_idxs = None

    box_features = self.box_roi_pool(features, proposals, image_shapes)
    box_features = self.box_head(box_features)
    class_logits, box_regression = self.box_predictor(box_features)

    result = []
    losses = {}
    if training:
        if labels is None:
            raise ValueError("labels cannot be None")
        if regression_targets is None:
            raise ValueError("regression_targets cannot be None")
        loss_classifier, loss_box_reg = fastrcnn_loss(
            class_logits, box_regression, labels, regression_targets
        )

        # Custom addition: quantity loss. The sampled proposals are matched
        # against the ground-truth boxes again, this time carrying
        # "quantity_labels", to get one quantity target per proposal.
        dtype = proposals[0].dtype
        gt_boxes = [t["boxes"].to(dtype) for t in targets]
        gt_quantities = [t["quantity_labels"] for t in targets]
        _, quantity_labels = self.assign_targets_to_proposals(
            proposals, gt_boxes, gt_quantities
        )
        quantity_labels = torch.cat(quantity_labels, dim=0)
        # requires a `quantity_prediction` layer to be attached to the module
        quantity_preds = self.quantity_prediction(box_features)
        loss_classifier_quantity = F.cross_entropy(
            quantity_preds,
            quantity_labels,
        )
        losses = {
            "loss_classifier": loss_classifier,
            "loss_box_reg": loss_box_reg,
            "loss_classifier_quantity": loss_classifier_quantity,
        }
    else:
        quantity_logits = self.quantity_prediction(box_features)

        # Use a distinct name for the per-detection features. `features` must
        # keep referring to the backbone feature maps, which the mask and
        # keypoint branches below pool from.
        boxes, scores, labels, quantities, det_features = (
            postprocess_detections_custom(
                self,
                class_logits,
                quantity_logits,
                box_features,
                box_regression,
                proposals,
                image_shapes,
            )
        )
        for img_boxes, img_labels, img_scores, img_quantities, img_feats in zip(
            boxes, labels, scores, quantities, det_features
        ):
            result.append(
                {
                    "boxes": img_boxes,
                    "labels": img_labels,
                    "scores": img_scores,
                    "quantities": img_quantities,
                    "features": img_feats,
                }
            )

    if self.has_mask():
        mask_proposals = [p["boxes"] for p in result]
        if training:
            if matched_idxs is None:
                raise ValueError("if in training, matched_idxs should not be None")

            # during training, only focus on positive boxes
            num_images = len(proposals)
            mask_proposals = []
            pos_matched_idxs = []
            for img_id in range(num_images):
                pos = torch.where(labels[img_id] > 0)[0]
                mask_proposals.append(proposals[img_id][pos])
                pos_matched_idxs.append(matched_idxs[img_id][pos])
        else:
            pos_matched_idxs = None

        if self.mask_roi_pool is not None:
            mask_features = self.mask_roi_pool(features, mask_proposals, image_shapes)
            mask_features = self.mask_head(mask_features)
            mask_logits = self.mask_predictor(mask_features)
        else:
            raise Exception("Expected mask_roi_pool to be not None")

        loss_mask = {}
        if training:
            if targets is None or pos_matched_idxs is None or mask_logits is None:
                raise ValueError(
                    "targets, pos_matched_idxs, mask_logits cannot be None when training"
                )

            gt_masks = [t["masks"] for t in targets]
            gt_labels = [t["labels"] for t in targets]
            rcnn_loss_mask = maskrcnn_loss(
                mask_logits, mask_proposals, gt_masks, gt_labels, pos_matched_idxs
            )
            loss_mask = {"loss_mask": rcnn_loss_mask}
        else:
            labels = [r["labels"] for r in result]
            masks_probs = maskrcnn_inference(mask_logits, labels)
            for mask_prob, r in zip(masks_probs, result):
                r["masks"] = mask_prob

        losses.update(loss_mask)

    # keep none checks in if conditional so torchscript will conditionally
    # compile each branch
    if (
        self.keypoint_roi_pool is not None
        and self.keypoint_head is not None
        and self.keypoint_predictor is not None
    ):
        keypoint_proposals = [p["boxes"] for p in result]
        if training:
            # during training, only focus on positive boxes
            num_images = len(proposals)
            keypoint_proposals = []
            pos_matched_idxs = []
            if matched_idxs is None:
                raise ValueError("if in trainning, matched_idxs should not be None")

            for img_id in range(num_images):
                pos = torch.where(labels[img_id] > 0)[0]
                keypoint_proposals.append(proposals[img_id][pos])
                pos_matched_idxs.append(matched_idxs[img_id][pos])
        else:
            pos_matched_idxs = None

        keypoint_features = self.keypoint_roi_pool(
            features, keypoint_proposals, image_shapes
        )
        keypoint_features = self.keypoint_head(keypoint_features)
        keypoint_logits = self.keypoint_predictor(keypoint_features)

        loss_keypoint = {}
        if training:
            if targets is None or pos_matched_idxs is None:
                raise ValueError(
                    "both targets and pos_matched_idxs should not be None when in training mode"
                )

            gt_keypoints = [t["keypoints"] for t in targets]
            rcnn_loss_keypoint = keypointrcnn_loss(
                keypoint_logits, keypoint_proposals, gt_keypoints, pos_matched_idxs
            )
            loss_keypoint = {"loss_keypoint": rcnn_loss_keypoint}
        else:
            if keypoint_logits is None or keypoint_proposals is None:
                raise ValueError(
                    "both keypoint_logits and keypoint_proposals should not be None when not in training mode"
                )

            keypoints_probs, kp_scores = keypointrcnn_inference(
                keypoint_logits, keypoint_proposals
            )
            for keypoint_prob, kps, r in zip(keypoints_probs, kp_scores, result):
                r["keypoints"] = keypoint_prob
                r["keypoints_scores"] = kps
        losses.update(loss_keypoint)

    return result, losses
|
371
|
+
|
372
|
+
|
373
|
+
def calculate_iou(boxA, boxB):
    """Return the Intersection over Union (IoU) of two ``(x1, y1, x2, y2)`` boxes.

    Coordinates may be plain numbers or 0-d tensors. Boxes are expected to
    satisfy ``x1 <= x2`` and ``y1 <= y2``. When the union has no area (e.g.
    both boxes are degenerate), 0.0 is returned instead of dividing by zero.
    """
    # corners of the intersection rectangle
    xA = max(boxA[0], boxB[0])
    yA = max(boxA[1], boxB[1])
    xB = min(boxA[2], boxB[2])
    yB = min(boxA[3], boxB[3])

    # clamp at 0 so disjoint boxes contribute no intersection
    interArea = max(0, xB - xA) * max(0, yB - yA)

    boxAArea = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])
    boxBArea = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])

    union = float(boxAArea + boxBArea - interArea)
    if union <= 0:
        return 0.0
    return interArea / union
|
392
|
+
|
393
|
+
|
394
|
+
def bbox_to_slot_index_iou(bbox: tuple[int, int, int, int]) -> int:
    """Return the slot whose precomputed box has the highest IoU with ``bbox``.

    Slots are scanned in ``IDX_TO_BBOX`` order, and the first slot reaching
    the maximum IoU wins. If ``bbox`` overlaps no slot, every IoU is 0 and the
    first slot in the table is returned.
    """
    winner = None
    winner_iou = -1
    for candidate_slot in IDX_TO_BBOX:
        overlap = calculate_iou(bbox, IDX_TO_BBOX[candidate_slot])
        if not overlap > winner_iou:
            continue
        winner, winner_iou = candidate_slot, overlap
    return winner
|
405
|
+
|
406
|
+
|
407
|
+
class IntegratedBoundingBoxModel(nn.Module, PyTorchModelHubMixin):
    """
    Custom Faster R-CNN model with quantity prediction.

    Wraps torchvision's ``fasterrcnn_resnet50_fpn_v2`` with one class per entry
    of ``ALL_ITEMS``. It adds a ``quantity_prediction`` linear head (65 outputs)
    to the ROI heads and replaces the ROI-heads forward with
    ``forward_custom``, so detections also return the box feature vectors and
    a predicted quantity.
    """

    def __init__(self, load_resnet_weights=False):
        super().__init__()
        weights = None
        if load_resnet_weights:
            weights = ResNet50_Weights.DEFAULT

        self.model = fasterrcnn_resnet50_fpn_v2(
            weights_backbone=weights,
            image_mean=[0.63, 0.63, 0.63],
            image_std=[0.21, 0.21, 0.21],
            min_size=128,
            max_size=256,
            num_classes=len(ALL_ITEMS),
            box_score_thresh=0.05,
            rpn_batch_size_per_image=64,
            box_detections_per_img=64,
            box_batch_size_per_image=128,
        )
        # 1024 is presumably the box-head representation size of this
        # torchvision model; 65 quantity classes.
        self.model.roi_heads.quantity_prediction = nn.Linear(1024, 65)

        # bind forward_custom as the ROI heads' forward so it also predicts
        # quantities and returns per-detection features
        self.model.roi_heads.forward = forward_custom.__get__(
            self.model.roi_heads, type(self.model.roi_heads)
        )

        self.transform = v2.Compose(
            [v2.ToImage(), v2.ToDtype(torch.float32, scale=True)]
        )

    def forward(self, x, targets=None):
        """Return the loss dict in training mode, otherwise the model predictions."""
        if self.training:
            # normal forward pass
            return self.model(x, targets)
        return self.model(x)

    def get_inventory(self, pil_image):
        """
        Predict boxes and quantities for a single image and convert them to
        inventory entries (see ``prediction_to_inventory``).

        NOTE(review): expects the model to be in eval mode (e.g. after
        ``freeze()``), since ``self.model`` is called without targets.
        """
        img_tensor = self.transform(pil_image)
        # move the input to whatever device the model's parameters live on
        device = next(self.model.parameters()).device
        img_tensor = img_tensor.to(device)
        with torch.no_grad():
            predictions = self.model(img_tensor.unsqueeze(0))
        return self.prediction_to_inventory(predictions[0])

    @staticmethod
    def prediction_to_inventory(prediction, threshold=0.9) -> list[dict]:
        """Convert one image's detections into a list of inventory entries.

        Detections whose score is not above ``threshold`` are skipped. Each
        remaining box is assigned to the slot with the highest IoU, and at
        most one entry is kept per slot: the first detection encountered for
        it. Detections that pass through NMS come out in decreasing score
        order, so this is normally the most confident one.

        Returns:
            A list of ``{"slot": int, "type": ALL_ITEMS[label], "quantity": ...}``.
        """
        inventory = []
        seen_slots = set()
        for box, score, label, quantity in zip(
            prediction["boxes"],
            prediction["scores"],
            prediction["labels"],
            prediction["quantities"],
        ):
            # threshold first: the slot lookup below scans every slot
            if not score > threshold:
                continue
            slot = bbox_to_slot_index_iou(box)
            if slot in seen_slots:
                continue
            seen_slots.add(slot)
            inventory.append(
                {
                    "slot": slot,
                    "type": ALL_ITEMS[label.item()],
                    "quantity": quantity.item(),
                }
            )
        return inventory

    def freeze(self):
        # NOTE: this might seem excessive
        # but transformers trainer is really good at enabling gradients against my will
        self.eval()
        self.model.eval()
        self.training = False
        for param in self.model.parameters():
            param.requires_grad = False
        self.model.training = False
        self.model.roi_heads.training = False
        self.model.rpn.training = False

    def save(self, path: str):
        """Save the model's ``state_dict`` to ``path`` with ``torch.save``."""
        torch.save(self.state_dict(), path)
|
models/dummy.py
ADDED
@@ -0,0 +1,54 @@
|
|
1
|
+
import random
|
2
|
+
|
3
|
+
from plancraft.config import EvalConfig
|
4
|
+
from plancraft.environments.actions import (
|
5
|
+
RealActionInteraction,
|
6
|
+
SymbolicMoveAction,
|
7
|
+
SymbolicSmeltAction,
|
8
|
+
)
|
9
|
+
from plancraft.models.base import ABCModel, History
|
10
|
+
|
11
|
+
|
12
|
+
class DummyModel(ABCModel):
    """
    Dummy model that takes random actions.

    With a symbolic action space it moves one item from a random occupied slot
    to a random free slot. Otherwise it returns a default
    ``RealActionInteraction``.
    """

    def __init__(self, cfg: EvalConfig):
        self.symbolic_move_action = cfg.plancraft.environment.symbolic_action_space
        self.history = History(objective="")

    def random_select(self, observation):
        """Return a random single-item ``SymbolicMoveAction``.

        The source is a slot holding at least one item and the destination is
        a slot in 1-45 that holds nothing. Returns a no-op move (slot 0 to
        slot 0) when the observation has no inventory, or when no source or no
        destination slot is available.
        """
        if observation is None or "inventory" not in observation:
            return SymbolicMoveAction(slot_from=0, slot_to=0, quantity=1)

        # slots that currently hold at least one item
        occupied = {
            item["index"] for item in observation["inventory"] if item["quantity"] > 0
        }
        free = set(range(1, 46)) - occupied
        if not occupied or not free:
            # random.choice would raise IndexError on an empty sequence
            return SymbolicMoveAction(slot_from=0, slot_to=0, quantity=1)

        random_slot_from = random.choice(list(occupied))
        random_slot_to = random.choice(list(free))

        return SymbolicMoveAction(
            slot_from=random_slot_from, slot_to=random_slot_to, quantity=1
        )

    def step(
        self, observation: dict
    ) -> SymbolicMoveAction | RealActionInteraction | SymbolicSmeltAction:
        """Record ``observation``, pick a random action, record and return it."""
        # add observation to history
        self.history.add_observation_to_history(observation)

        # get action
        if self.symbolic_move_action:
            action = self.random_select(observation)
        else:
            action = RealActionInteraction()

        # add action to history
        self.history.add_action_to_history(action)

        return action
|
models/few_shot_images/__init__.py
ADDED
@@ -0,0 +1,16 @@
|
|
1
|
+
import os
|
2
|
+
import glob
|
3
|
+
|
4
|
+
import numpy as np
|
5
|
+
import imageio
|
6
|
+
|
7
|
+
|
8
|
+
def get_few_shot_images_path():
    """Return the directory containing this module (searched for ``*.png`` prompts)."""
    module_dir, _ = os.path.split(__file__)
    return module_dir
|
10
|
+
|
11
|
+
|
12
|
+
def load_prompt_images() -> list[np.ndarray]:
    """Load every ``*.png`` next to this module, in sorted filename order.

    ``glob`` returns files in filesystem-dependent order, so the list is
    sorted to keep the few-shot image order deterministic across machines.
    """
    current_dir = get_few_shot_images_path()
    files = sorted(glob.glob(os.path.join(current_dir, "*.png")))
    return [imageio.imread(file) for file in files]
|