plancraft-0.1.0-py3-none-any.whl → plancraft-0.1.2-py3-none-any.whl
- plancraft-0.1.2.dist-info/METADATA +74 -0
- plancraft-0.1.2.dist-info/RECORD +5 -0
- {plancraft-0.1.0.dist-info → plancraft-0.1.2.dist-info}/WHEEL +1 -1
- plancraft-0.1.2.dist-info/top_level.txt +1 -0
- environments/__init__.py +0 -0
- environments/actions.py +0 -218
- environments/env_real.py +0 -315
- environments/env_symbolic.py +0 -215
- environments/items.py +0 -10
- environments/planner.py +0 -109
- environments/recipes.py +0 -542
- environments/sampler.py +0 -224
- models/__init__.py +0 -21
- models/act.py +0 -184
- models/base.py +0 -152
- models/bbox_model.py +0 -492
- models/dummy.py +0 -54
- models/few_shot_images/__init__.py +0 -16
- models/generators.py +0 -483
- models/oam.py +0 -284
- models/oracle.py +0 -268
- models/prompts.py +0 -158
- models/react.py +0 -98
- models/utils.py +0 -289
- plancraft-0.1.0.dist-info/METADATA +0 -53
- plancraft-0.1.0.dist-info/RECORD +0 -26
- plancraft-0.1.0.dist-info/top_level.txt +0 -3
- train/dataset.py +0 -187
- {plancraft-0.1.0.dist-info → plancraft-0.1.2.dist-info}/LICENSE +0 -0
models/bbox_model.py
DELETED
@@ -1,492 +0,0 @@
```python
import einops
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.v2 as v2
from huggingface_hub import PyTorchModelHubMixin
from plancraft.environments.items import ALL_ITEMS
from torchvision.models.detection.faster_rcnn import (
    fasterrcnn_resnet50_fpn_v2,
    ResNet50_Weights,
)
from torchvision.models.detection.roi_heads import (
    fastrcnn_loss,
    keypointrcnn_inference,
    keypointrcnn_loss,
    maskrcnn_inference,
    maskrcnn_loss,
)
from torchvision.ops import boxes as box_ops


def slot_to_bbox(slot: int):
    # crafting slot
    if slot == 0:
        # slot size: 25x25
        # top left corner: (x= 118, y=30)
        box_size = 25
        left_x = 117
        top_y = 29
    # crafting grid
    elif slot < 10:
        # slot size: 18x18
        # top left corner: (x = 28 + 18 * col, y = 16 + 18 * row)
        box_size = 18
        row = (slot - 1) // 3
        col = (slot - 1) % 3
        left_x = 27 + (box_size * col)
        top_y = 15 + (box_size * row)
    # inventory
    elif slot < 37:
        # slot size: 18x18
        # top left corner: (x= 6 + 18 * col, y=83 + 18 * row)
        box_size = 18
        row = (slot - 10) // 9
        col = (slot - 10) % 9
        left_x = 5 + (box_size * col)
        top_y = 82 + (box_size * row)
    # hotbar
    else:
        # slot size: 18x18
        # top left corner: (x= 6 + 18 * col, y=141)
        box_size = 18
        col = (slot - 37) % 9
        left_x = 5 + (box_size * col)
        top_y = 140
    return [left_x, top_y, left_x + box_size, top_y + box_size]


def precompute_slot_bboxes():
    """Precompute the bounding boxes for all slots."""
    slot_bboxes = {}
    for slot in range(46):  # Assuming slot indices range from 0 to 45
        slot_bboxes[slot] = slot_to_bbox(slot)
    return slot_bboxes


# Precompute all slot bounding boxes
IDX_TO_BBOX = precompute_slot_bboxes()


def postprocess_detections_custom(
    self,
    class_logits,
    quantity_logits,
    box_features,
    box_regression,
    proposals,
    image_shapes,
):
    device = class_logits.device
    num_classes = class_logits.shape[-1]

    boxes_per_image = [boxes_in_image.shape[0] for boxes_in_image in proposals]
    pred_boxes = self.box_coder.decode(box_regression, proposals)

    pred_scores = F.softmax(class_logits, -1)

    pred_quantity = F.softmax(quantity_logits, -1).argmax(dim=-1)
    # repeat the quantities, once for each class
    pred_quantity = einops.repeat(
        pred_quantity, "n -> n c", c=num_classes, n=pred_quantity.shape[0]
    )

    pred_boxes_list = pred_boxes.split(boxes_per_image, 0)
    pred_scores_list = pred_scores.split(boxes_per_image, 0)
    pred_quantity_list = pred_quantity.split(boxes_per_image, 0)
    pred_features_list = box_features.split(boxes_per_image, 0)

    all_boxes = []
    all_scores = []
    all_labels = []
    all_quantity_labels = []
    all_features = []

    for boxes, scores, quantities, features, image_shape in zip(
        pred_boxes_list,
        pred_scores_list,
        pred_quantity_list,
        pred_features_list,
        image_shapes,
    ):
        boxes = box_ops.clip_boxes_to_image(boxes, image_shape)

        # create labels for each prediction
        labels = torch.arange(num_classes, device=device)
        labels = labels.view(1, -1).expand_as(scores)

        box_idxs = (
            torch.arange(boxes.size(0), device=device).view(-1, 1).expand_as(labels)
        )

        # remove predictions with the background label
        boxes = boxes[:, 1:]
        scores = scores[:, 1:]
        labels = labels[:, 1:]
        quantities = quantities[:, 1:]
        box_idxs = box_idxs[:, 1:]

        # batch everything, by making every class prediction be a separate instance
        boxes = boxes.reshape(-1, 4)
        scores = scores.reshape(-1)
        labels = labels.reshape(-1)
        quantities = quantities.reshape(-1)
        box_idxs = box_idxs.reshape(-1)

        # remove low scoring boxes
        inds = torch.where(scores > self.score_thresh)[0]
        boxes, scores, labels, quantities, box_idxs = (
            boxes[inds],
            scores[inds],
            labels[inds],
            quantities[inds],
            box_idxs[inds],
        )

        # remove empty boxes
        keep = box_ops.remove_small_boxes(boxes, min_size=1e-2)
        boxes, scores, labels, quantities, box_idxs = (
            boxes[keep],
            scores[keep],
            labels[keep],
            quantities[keep],
            box_idxs[keep],
        )

        # non-maximum suppression, independently done per class
        keep = box_ops.batched_nms(boxes, scores, labels, self.nms_thresh)
        # keep only topk scoring predictions
        keep = keep[: self.detections_per_img]
        boxes, scores, labels, quantities, box_idxs = (
            boxes[keep],
            scores[keep],
            labels[keep],
            quantities[keep],
            box_idxs[keep],
        )

        all_boxes.append(boxes)
        all_scores.append(scores)
        all_labels.append(labels)
        all_quantity_labels.append(quantities)
        all_features.append(features[box_idxs])

    return all_boxes, all_scores, all_labels, all_quantity_labels, all_features


def forward_custom(
    self,
    features,
    proposals,
    image_shapes,
    targets=None,
):
    training = False
    if targets is not None:
        training = True
        for t in targets:
            floating_point_types = (torch.float, torch.double, torch.half)
            if t["boxes"].dtype not in floating_point_types:
                raise TypeError(
                    f"target boxes must of float type, instead got {t['boxes'].dtype}"
                )
            if not t["labels"].dtype == torch.int64:
                raise TypeError(
                    f"target labels must of int64 type, instead got {t['labels'].dtype}"
                )
            if self.has_keypoint():
                if not t["keypoints"].dtype == torch.float32:
                    raise TypeError(
                        f"target keypoints must of float type, instead got {t['keypoints'].dtype}"
                    )

    if training:
        proposals, matched_idxs, labels, regression_targets = (
            self.select_training_samples(proposals, targets)
        )
    else:
        labels = None
        regression_targets = None
        matched_idxs = None

    box_features = self.box_roi_pool(features, proposals, image_shapes)
    box_features = self.box_head(box_features)
    class_logits, box_regression = self.box_predictor(box_features)

    result = []
    losses = {}
    if training:
        if labels is None:
            raise ValueError("labels cannot be None")
        if regression_targets is None:
            raise ValueError("regression_targets cannot be None")
        loss_classifier, loss_box_reg = fastrcnn_loss(
            class_logits, box_regression, labels, regression_targets
        )

        # custom addition to calculate quantity loss
        dtype = proposals[0].dtype
        gt_boxes = [t["boxes"].to(dtype) for t in targets]
        gt_labels = [t["quantity_labels"] for t in targets]
        _, quantity_labels = self.assign_targets_to_proposals(
            proposals, gt_boxes, gt_labels
        )
        quantity_labels = torch.cat(quantity_labels, dim=0)
        # needs quantity_prediction layer to be added to class
        quantity_preds = self.quantity_prediction(box_features)
        loss_classsifier_quantity = F.cross_entropy(
            quantity_preds,
            quantity_labels,
        )
        losses = {
            "loss_classifier": loss_classifier,
            "loss_box_reg": loss_box_reg,
            "loss_classifier_quantity": loss_classsifier_quantity,
        }
    else:
        quantity_logits = self.quantity_prediction(box_features)

        boxes, scores, labels, quantities, features = postprocess_detections_custom(
            self,
            class_logits,
            quantity_logits,
            box_features,
            box_regression,
            proposals,
            image_shapes,
        )
        num_images = len(boxes)
        for i in range(num_images):
            result.append(
                {
                    "boxes": boxes[i],
                    "labels": labels[i],
                    "scores": scores[i],
                    "quantities": quantities[i],
                    "features": features[i],
                }
            )

    if self.has_mask():
        mask_proposals = [p["boxes"] for p in result]
        if training:
            if matched_idxs is None:
                raise ValueError("if in training, matched_idxs should not be None")

            # during training, only focus on positive boxes
            num_images = len(proposals)
            mask_proposals = []
            pos_matched_idxs = []
            for img_id in range(num_images):
                pos = torch.where(labels[img_id] > 0)[0]
                mask_proposals.append(proposals[img_id][pos])
                pos_matched_idxs.append(matched_idxs[img_id][pos])
        else:
            pos_matched_idxs = None

        if self.mask_roi_pool is not None:
            mask_features = self.mask_roi_pool(features, mask_proposals, image_shapes)
            mask_features = self.mask_head(mask_features)
            mask_logits = self.mask_predictor(mask_features)
        else:
            raise Exception("Expected mask_roi_pool to be not None")

        loss_mask = {}
        if training:
            if targets is None or pos_matched_idxs is None or mask_logits is None:
                raise ValueError(
                    "targets, pos_matched_idxs, mask_logits cannot be None when training"
                )

            gt_masks = [t["masks"] for t in targets]
            gt_labels = [t["labels"] for t in targets]
            rcnn_loss_mask = maskrcnn_loss(
                mask_logits, mask_proposals, gt_masks, gt_labels, pos_matched_idxs
            )
            loss_mask = {"loss_mask": rcnn_loss_mask}
        else:
            labels = [r["labels"] for r in result]
            masks_probs = maskrcnn_inference(mask_logits, labels)
            for mask_prob, r in zip(masks_probs, result):
                r["masks"] = mask_prob

        losses.update(loss_mask)

    # keep none checks in if conditional so torchscript will conditionally
    # compile each branch
    if (
        self.keypoint_roi_pool is not None
        and self.keypoint_head is not None
        and self.keypoint_predictor is not None
    ):
        keypoint_proposals = [p["boxes"] for p in result]
        if training:
            # during training, only focus on positive boxes
            num_images = len(proposals)
            keypoint_proposals = []
            pos_matched_idxs = []
            if matched_idxs is None:
                raise ValueError("if in trainning, matched_idxs should not be None")

            for img_id in range(num_images):
                pos = torch.where(labels[img_id] > 0)[0]
                keypoint_proposals.append(proposals[img_id][pos])
                pos_matched_idxs.append(matched_idxs[img_id][pos])
        else:
            pos_matched_idxs = None

        keypoint_features = self.keypoint_roi_pool(
            features, keypoint_proposals, image_shapes
        )
        keypoint_features = self.keypoint_head(keypoint_features)
        keypoint_logits = self.keypoint_predictor(keypoint_features)

        loss_keypoint = {}
        if training:
            if targets is None or pos_matched_idxs is None:
                raise ValueError(
                    "both targets and pos_matched_idxs should not be None when in training mode"
                )

            gt_keypoints = [t["keypoints"] for t in targets]
            rcnn_loss_keypoint = keypointrcnn_loss(
                keypoint_logits, keypoint_proposals, gt_keypoints, pos_matched_idxs
            )
            loss_keypoint = {"loss_keypoint": rcnn_loss_keypoint}
        else:
            if keypoint_logits is None or keypoint_proposals is None:
                raise ValueError(
                    "both keypoint_logits and keypoint_proposals should not be None when not in training mode"
                )

            keypoints_probs, kp_scores = keypointrcnn_inference(
                keypoint_logits, keypoint_proposals
            )
            for keypoint_prob, kps, r in zip(keypoints_probs, kp_scores, result):
                r["keypoints"] = keypoint_prob
                r["keypoints_scores"] = kps
        losses.update(loss_keypoint)

    return result, losses


def calculate_iou(boxA, boxB):
    """Calculate Intersection over Union (IoU) between two bounding boxes."""
    # Determine the coordinates of the intersection rectangle
    xA = max(boxA[0], boxB[0])
    yA = max(boxA[1], boxB[1])
    xB = min(boxA[2], boxB[2])
    yB = min(boxA[3], boxB[3])

    # Compute the area of intersection
    interArea = max(0, xB - xA) * max(0, yB - yA)

    # Compute the area of both the bounding boxes
    boxAArea = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])
    boxBArea = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])

    # Compute the IoU
    iou = interArea / float(boxAArea + boxBArea - interArea)

    return iou


def bbox_to_slot_index_iou(bbox: tuple[int, int, int, int]) -> int:
    """Assign the given bounding box to the slot with the highest IoU."""
    best_slot = None
    best_iou = -1
    # Iterate through all precomputed slot bounding boxes
    for slot, slot_bbox in IDX_TO_BBOX.items():
        iou = calculate_iou(bbox, slot_bbox)
        if iou > best_iou:
            best_iou = iou
            best_slot = slot
    return best_slot


class IntegratedBoundingBoxModel(nn.Module, PyTorchModelHubMixin):
    """
    Custom mask rcnn model with quantity prediction

    Also returns the feature vectors of the detected boxes
    """

    def __init__(self, load_resnet_weights=False):
        super(IntegratedBoundingBoxModel, self).__init__()
        weights = None
        if load_resnet_weights:
            weights = ResNet50_Weights.DEFAULT

        self.model = fasterrcnn_resnet50_fpn_v2(
            weights_backbone=weights,
            image_mean=[0.63, 0.63, 0.63],
            image_std=[0.21, 0.21, 0.21],
            min_size=128,
            max_size=256,
            num_classes=len(ALL_ITEMS),
            box_score_thresh=0.05,
            rpn_batch_size_per_image=64,
            box_detections_per_img=64,
            box_batch_size_per_image=128,
        )
        self.model.roi_heads.quantity_prediction = nn.Linear(1024, 65)

        # replace the head with leaky activations
        self.model.roi_heads.forward = forward_custom.__get__(
            self.model.roi_heads, type(self.model.roi_heads)
        )

        self.transform = v2.Compose(
            [v2.ToImage(), v2.ToDtype(torch.float32, scale=True)]
        )

    def forward(self, x, targets=None):
        if self.training:
            # normal forward pass
            loss_dict = self.model(x, targets)
            return loss_dict
        else:
            preds = self.model(x)
            return preds

    def get_inventory(self, pil_image):
        """
        Predict boxes and quantities
        """
        img_tensor = self.transform(pil_image)
        if next(self.model.parameters()).is_cuda:
            img_tensor = img_tensor.cuda()
        with torch.no_grad():
            predictions = self.model(img_tensor.unsqueeze(0))
        return self.prediction_to_inventory(predictions[0])

    @staticmethod
    def prediction_to_inventory(prediction, threshold=0.9) -> list[dict]:
        inventory = []
        seen_slots = set()
        for i in range(len(prediction["boxes"])):
            slot = bbox_to_slot_index_iou(prediction["boxes"][i])
            score = prediction["scores"][i]
            label_idx = prediction["labels"][i].item()
            label = ALL_ITEMS[label_idx]
            quantity = prediction["quantities"][i].item()
            if score > threshold:
                if slot in seen_slots:
                    continue
                inventory.append({"slot": slot, "type": label, "quantity": quantity})
        return inventory

    def freeze(self):
        # NOTE: this might seem excessive
        # but transformers trainer is really good at enabling gradients against my will
        self.eval()
        self.model.eval()
        self.training = False
        for param in self.model.parameters():
            param.requires_grad = False
        self.model.training = False
        self.model.roi_heads.training = False
        self.model.rpn.training = False

    def save(self, path: str):
        torch.save(self.state_dict(), path)
```
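Note: the removed `IntegratedBoundingBoxModel` mixes in `PyTorchModelHubMixin`, so in the 0.1.0 release it could be loaded from a Hugging Face Hub checkpoint and queried through `get_inventory`. A minimal usage sketch follows; the checkpoint id and image path are hypothetical placeholders, and the import path assumes the package-qualified module name used inside the code itself (the 0.1.0 wheel actually shipped it as a top-level `models.bbox_model` module):

```python
from PIL import Image

# Import path as referenced by the package's own imports (plancraft.*);
# in the 0.1.0 wheel the module was packaged as top-level `models.bbox_model`.
from plancraft.models.bbox_model import IntegratedBoundingBoxModel

# Hypothetical checkpoint id and screenshot path -- placeholders, not from this diff.
model = IntegratedBoundingBoxModel.from_pretrained("your-org/plancraft-bbox")
model.freeze()  # eval mode with gradients disabled

screenshot = Image.open("crafting_gui.png").convert("RGB")
inventory = model.get_inventory(screenshot)
# Each detection is snapped to an inventory slot by IoU against the precomputed
# slot grid, e.g. slot 10 (first inventory slot) corresponds to the box
# slot_to_bbox(10) == [5, 82, 23, 100].
print(inventory)
```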
models/dummy.py
DELETED
@@ -1,54 +0,0 @@
```python
import random

from plancraft.config import EvalConfig
from plancraft.environments.actions import (
    RealActionInteraction,
    SymbolicMoveAction,
    SymbolicSmeltAction,
)
from plancraft.models.base import ABCModel, History


class DummyModel(ABCModel):
    """
    Dummy model returns actions that do random action
    """

    def __init__(self, cfg: EvalConfig):
        self.symbolic_move_action = cfg.plancraft.environment.symbolic_action_space
        self.history = History(objective="")

    def random_select(self, observation):
        if observation is None or "inventory" not in observation:
            return SymbolicMoveAction(slot_from=0, slot_to=0, quantity=1)
        # randomly pick an item from the inventory
        item_indices = set()
        for item in observation["inventory"]:
            if item["quantity"] > 0:
                item_indices.add(item["index"])
        all_slots_to = set(range(1, 46))
        empty_slots = all_slots_to - item_indices

        random_slot_from = random.choice(list(item_indices))
        random_slot_to = random.choice(list(empty_slots))

        return SymbolicMoveAction(
            slot_from=random_slot_from, slot_to=random_slot_to, quantity=1
        )

    def step(
        self, observation: dict
    ) -> list[SymbolicMoveAction | RealActionInteraction | SymbolicSmeltAction]:
        # add observation to history
        self.history.add_observation_to_history(observation)

        # get action
        if self.symbolic_move_action:
            action = self.random_select(observation)
        else:
            action = RealActionInteraction()

        # add action to history
        self.history.add_action_to_history(action)

        return action
```
models/few_shot_images/__init__.py
DELETED
@@ -1,16 +0,0 @@
```python
import os
import glob

import numpy as np
import imageio


def get_few_shot_images_path():
    return os.path.dirname(__file__)


def load_prompt_images() -> list[np.ndarray]:
    current_dir = get_few_shot_images_path()
    files = glob.glob(os.path.join(current_dir, "*.png"))
    images = [imageio.imread(file) for file in files]
    return images
```
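The removed helper above simply globs for `*.png` files stored next to the module, so its whole surface was two functions. A minimal usage sketch (import path assumes the package-qualified name; the 0.1.0 wheel shipped it as top-level `models.few_shot_images`):

```python
from plancraft.models.few_shot_images import load_prompt_images

# Returns one numpy array per bundled *.png, in glob order.
images = load_prompt_images()
for img in images:
    print(img.shape)
```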