plancraft-0.1.0-py3-none-any.whl

models/bbox_model.py ADDED
@@ -0,0 +1,492 @@
import einops
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.v2 as v2
from huggingface_hub import PyTorchModelHubMixin
from plancraft.environments.items import ALL_ITEMS
from torchvision.models.detection.faster_rcnn import (
    fasterrcnn_resnet50_fpn_v2,
    ResNet50_Weights,
)
from torchvision.models.detection.roi_heads import (
    fastrcnn_loss,
    keypointrcnn_inference,
    keypointrcnn_loss,
    maskrcnn_inference,
    maskrcnn_loss,
)
from torchvision.ops import boxes as box_ops


def slot_to_bbox(slot: int):
    # crafting slot
    if slot == 0:
        # slot size: 25x25
        # top left corner: (x=117, y=29)
        box_size = 25
        left_x = 117
        top_y = 29
    # crafting grid
    elif slot < 10:
        # slot size: 18x18
        # top left corner: (x = 27 + 18 * col, y = 15 + 18 * row)
        box_size = 18
        row = (slot - 1) // 3
        col = (slot - 1) % 3
        left_x = 27 + (box_size * col)
        top_y = 15 + (box_size * row)
    # inventory
    elif slot < 37:
        # slot size: 18x18
        # top left corner: (x = 5 + 18 * col, y = 82 + 18 * row)
        box_size = 18
        row = (slot - 10) // 9
        col = (slot - 10) % 9
        left_x = 5 + (box_size * col)
        top_y = 82 + (box_size * row)
    # hotbar
    else:
        # slot size: 18x18
        # top left corner: (x = 5 + 18 * col, y = 140)
        box_size = 18
        col = (slot - 37) % 9
        left_x = 5 + (box_size * col)
        top_y = 140
    return [left_x, top_y, left_x + box_size, top_y + box_size]


def precompute_slot_bboxes():
    """Precompute the bounding boxes for all slots."""
    slot_bboxes = {}
    for slot in range(46):  # slot indices range from 0 to 45
        slot_bboxes[slot] = slot_to_bbox(slot)
    return slot_bboxes


# Precompute all slot bounding boxes
IDX_TO_BBOX = precompute_slot_bboxes()
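
# Illustrative worked example (not part of the original file): for inventory
# slot 15 the `slot < 37` branch applies, so row = (15 - 10) // 9 = 0 and
# col = (15 - 10) % 9 = 5, giving left_x = 5 + 18 * 5 = 95 and top_y = 82;
# slot_to_bbox(15) therefore returns [95, 82, 113, 100], and IDX_TO_BBOX[15]
# caches exactly that box.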


def postprocess_detections_custom(
    self,
    class_logits,
    quantity_logits,
    box_features,
    box_regression,
    proposals,
    image_shapes,
):
    device = class_logits.device
    num_classes = class_logits.shape[-1]

    boxes_per_image = [boxes_in_image.shape[0] for boxes_in_image in proposals]
    pred_boxes = self.box_coder.decode(box_regression, proposals)

    pred_scores = F.softmax(class_logits, -1)

    pred_quantity = F.softmax(quantity_logits, -1).argmax(dim=-1)
    # repeat the quantities, once for each class
    pred_quantity = einops.repeat(
        pred_quantity, "n -> n c", c=num_classes, n=pred_quantity.shape[0]
    )

    pred_boxes_list = pred_boxes.split(boxes_per_image, 0)
    pred_scores_list = pred_scores.split(boxes_per_image, 0)
    pred_quantity_list = pred_quantity.split(boxes_per_image, 0)
    pred_features_list = box_features.split(boxes_per_image, 0)

    all_boxes = []
    all_scores = []
    all_labels = []
    all_quantity_labels = []
    all_features = []

    for boxes, scores, quantities, features, image_shape in zip(
        pred_boxes_list,
        pred_scores_list,
        pred_quantity_list,
        pred_features_list,
        image_shapes,
    ):
        boxes = box_ops.clip_boxes_to_image(boxes, image_shape)

        # create labels for each prediction
        labels = torch.arange(num_classes, device=device)
        labels = labels.view(1, -1).expand_as(scores)

        box_idxs = (
            torch.arange(boxes.size(0), device=device).view(-1, 1).expand_as(labels)
        )

        # remove predictions with the background label
        boxes = boxes[:, 1:]
        scores = scores[:, 1:]
        labels = labels[:, 1:]
        quantities = quantities[:, 1:]
        box_idxs = box_idxs[:, 1:]

        # batch everything, by making every class prediction be a separate instance
        boxes = boxes.reshape(-1, 4)
        scores = scores.reshape(-1)
        labels = labels.reshape(-1)
        quantities = quantities.reshape(-1)
        box_idxs = box_idxs.reshape(-1)

        # remove low scoring boxes
        inds = torch.where(scores > self.score_thresh)[0]
        boxes, scores, labels, quantities, box_idxs = (
            boxes[inds],
            scores[inds],
            labels[inds],
            quantities[inds],
            box_idxs[inds],
        )

        # remove empty boxes
        keep = box_ops.remove_small_boxes(boxes, min_size=1e-2)
        boxes, scores, labels, quantities, box_idxs = (
            boxes[keep],
            scores[keep],
            labels[keep],
            quantities[keep],
            box_idxs[keep],
        )

        # non-maximum suppression, independently done per class
        keep = box_ops.batched_nms(boxes, scores, labels, self.nms_thresh)
        # keep only topk scoring predictions
        keep = keep[: self.detections_per_img]
        boxes, scores, labels, quantities, box_idxs = (
            boxes[keep],
            scores[keep],
            labels[keep],
            quantities[keep],
            box_idxs[keep],
        )

        all_boxes.append(boxes)
        all_scores.append(scores)
        all_labels.append(labels)
        all_quantity_labels.append(quantities)
        all_features.append(features[box_idxs])

    return all_boxes, all_scores, all_labels, all_quantity_labels, all_features
175
+
176
+
177
+ def forward_custom(
178
+ self,
179
+ features,
180
+ proposals,
181
+ image_shapes,
182
+ targets=None,
183
+ ):
184
+ training = False
185
+ if targets is not None:
186
+ training = True
187
+ for t in targets:
188
+ floating_point_types = (torch.float, torch.double, torch.half)
189
+ if t["boxes"].dtype not in floating_point_types:
190
+ raise TypeError(
191
+ f"target boxes must of float type, instead got {t['boxes'].dtype}"
192
+ )
193
+ if not t["labels"].dtype == torch.int64:
194
+ raise TypeError(
195
+ f"target labels must of int64 type, instead got {t['labels'].dtype}"
196
+ )
197
+ if self.has_keypoint():
198
+ if not t["keypoints"].dtype == torch.float32:
199
+ raise TypeError(
200
+ f"target keypoints must of float type, instead got {t['keypoints'].dtype}"
201
+ )
202
+
203
+ if training:
204
+ proposals, matched_idxs, labels, regression_targets = (
205
+ self.select_training_samples(proposals, targets)
206
+ )
207
+ else:
208
+ labels = None
209
+ regression_targets = None
210
+ matched_idxs = None
211
+
212
+ box_features = self.box_roi_pool(features, proposals, image_shapes)
213
+ box_features = self.box_head(box_features)
214
+ class_logits, box_regression = self.box_predictor(box_features)
215
+
216
+ result = []
217
+ losses = {}
218
+ if training:
219
+ if labels is None:
220
+ raise ValueError("labels cannot be None")
221
+ if regression_targets is None:
222
+ raise ValueError("regression_targets cannot be None")
223
+ loss_classifier, loss_box_reg = fastrcnn_loss(
224
+ class_logits, box_regression, labels, regression_targets
225
+ )
226
+
227
+ # custom addition to calculate quantity loss
228
+ dtype = proposals[0].dtype
229
+ gt_boxes = [t["boxes"].to(dtype) for t in targets]
230
+ gt_labels = [t["quantity_labels"] for t in targets]
231
+ _, quantity_labels = self.assign_targets_to_proposals(
232
+ proposals, gt_boxes, gt_labels
233
+ )
234
+ quantity_labels = torch.cat(quantity_labels, dim=0)
235
+ # needs quantity_prediction layer to be added to class
236
+ quantity_preds = self.quantity_prediction(box_features)
237
+ loss_classsifier_quantity = F.cross_entropy(
238
+ quantity_preds,
239
+ quantity_labels,
240
+ )
241
+ losses = {
242
+ "loss_classifier": loss_classifier,
243
+ "loss_box_reg": loss_box_reg,
244
+ "loss_classifier_quantity": loss_classsifier_quantity,
245
+ }
246
+ else:
247
+ quantity_logits = self.quantity_prediction(box_features)
248
+
249
+ boxes, scores, labels, quantities, features = postprocess_detections_custom(
250
+ self,
251
+ class_logits,
252
+ quantity_logits,
253
+ box_features,
254
+ box_regression,
255
+ proposals,
256
+ image_shapes,
257
+ )
258
+ num_images = len(boxes)
259
+ for i in range(num_images):
260
+ result.append(
261
+ {
262
+ "boxes": boxes[i],
263
+ "labels": labels[i],
264
+ "scores": scores[i],
265
+ "quantities": quantities[i],
266
+ "features": features[i],
267
+ }
268
+ )
269
+
270
+ if self.has_mask():
271
+ mask_proposals = [p["boxes"] for p in result]
272
+ if training:
273
+ if matched_idxs is None:
274
+ raise ValueError("if in training, matched_idxs should not be None")
275
+
276
+ # during training, only focus on positive boxes
277
+ num_images = len(proposals)
278
+ mask_proposals = []
279
+ pos_matched_idxs = []
280
+ for img_id in range(num_images):
281
+ pos = torch.where(labels[img_id] > 0)[0]
282
+ mask_proposals.append(proposals[img_id][pos])
283
+ pos_matched_idxs.append(matched_idxs[img_id][pos])
284
+ else:
285
+ pos_matched_idxs = None
286
+
287
+ if self.mask_roi_pool is not None:
288
+ mask_features = self.mask_roi_pool(features, mask_proposals, image_shapes)
289
+ mask_features = self.mask_head(mask_features)
290
+ mask_logits = self.mask_predictor(mask_features)
291
+ else:
292
+ raise Exception("Expected mask_roi_pool to be not None")
293
+
294
+ loss_mask = {}
295
+ if training:
296
+ if targets is None or pos_matched_idxs is None or mask_logits is None:
297
+ raise ValueError(
298
+ "targets, pos_matched_idxs, mask_logits cannot be None when training"
299
+ )
300
+
301
+ gt_masks = [t["masks"] for t in targets]
302
+ gt_labels = [t["labels"] for t in targets]
303
+ rcnn_loss_mask = maskrcnn_loss(
304
+ mask_logits, mask_proposals, gt_masks, gt_labels, pos_matched_idxs
305
+ )
306
+ loss_mask = {"loss_mask": rcnn_loss_mask}
307
+ else:
308
+ labels = [r["labels"] for r in result]
309
+ masks_probs = maskrcnn_inference(mask_logits, labels)
310
+ for mask_prob, r in zip(masks_probs, result):
311
+ r["masks"] = mask_prob
312
+
313
+ losses.update(loss_mask)
314
+
315
+ # keep none checks in if conditional so torchscript will conditionally
316
+ # compile each branch
317
+ if (
318
+ self.keypoint_roi_pool is not None
319
+ and self.keypoint_head is not None
320
+ and self.keypoint_predictor is not None
321
+ ):
322
+ keypoint_proposals = [p["boxes"] for p in result]
323
+ if training:
324
+ # during training, only focus on positive boxes
325
+ num_images = len(proposals)
326
+ keypoint_proposals = []
327
+ pos_matched_idxs = []
328
+ if matched_idxs is None:
329
+ raise ValueError("if in trainning, matched_idxs should not be None")
330
+
331
+ for img_id in range(num_images):
332
+ pos = torch.where(labels[img_id] > 0)[0]
333
+ keypoint_proposals.append(proposals[img_id][pos])
334
+ pos_matched_idxs.append(matched_idxs[img_id][pos])
335
+ else:
336
+ pos_matched_idxs = None
337
+
338
+ keypoint_features = self.keypoint_roi_pool(
339
+ features, keypoint_proposals, image_shapes
340
+ )
341
+ keypoint_features = self.keypoint_head(keypoint_features)
342
+ keypoint_logits = self.keypoint_predictor(keypoint_features)
343
+
344
+ loss_keypoint = {}
345
+ if training:
346
+ if targets is None or pos_matched_idxs is None:
347
+ raise ValueError(
348
+ "both targets and pos_matched_idxs should not be None when in training mode"
349
+ )
350
+
351
+ gt_keypoints = [t["keypoints"] for t in targets]
352
+ rcnn_loss_keypoint = keypointrcnn_loss(
353
+ keypoint_logits, keypoint_proposals, gt_keypoints, pos_matched_idxs
354
+ )
355
+ loss_keypoint = {"loss_keypoint": rcnn_loss_keypoint}
356
+ else:
357
+ if keypoint_logits is None or keypoint_proposals is None:
358
+ raise ValueError(
359
+ "both keypoint_logits and keypoint_proposals should not be None when not in training mode"
360
+ )
361
+
362
+ keypoints_probs, kp_scores = keypointrcnn_inference(
363
+ keypoint_logits, keypoint_proposals
364
+ )
365
+ for keypoint_prob, kps, r in zip(keypoints_probs, kp_scores, result):
366
+ r["keypoints"] = keypoint_prob
367
+ r["keypoints_scores"] = kps
368
+ losses.update(loss_keypoint)
369
+
370
+ return result, losses


def calculate_iou(boxA, boxB):
    """Calculate Intersection over Union (IoU) between two bounding boxes."""
    # Determine the coordinates of the intersection rectangle
    xA = max(boxA[0], boxB[0])
    yA = max(boxA[1], boxB[1])
    xB = min(boxA[2], boxB[2])
    yB = min(boxA[3], boxB[3])

    # Compute the area of intersection
    interArea = max(0, xB - xA) * max(0, yB - yA)

    # Compute the area of both bounding boxes
    boxAArea = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])
    boxBArea = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])

    # Compute the IoU
    iou = interArea / float(boxAArea + boxBArea - interArea)

    return iou


def bbox_to_slot_index_iou(bbox: tuple[int, int, int, int]) -> int:
    """Assign the given bounding box to the slot with the highest IoU."""
    best_slot = None
    best_iou = -1
    # Iterate through all precomputed slot bounding boxes
    for slot, slot_bbox in IDX_TO_BBOX.items():
        iou = calculate_iou(bbox, slot_bbox)
        if iou > best_iou:
            best_iou = iou
            best_slot = slot
    return best_slot


class IntegratedBoundingBoxModel(nn.Module, PyTorchModelHubMixin):
    """
    Custom Faster R-CNN model with an additional quantity prediction head.

    Also returns the feature vectors of the detected boxes.
    """

    def __init__(self, load_resnet_weights=False):
        super().__init__()
        weights = None
        if load_resnet_weights:
            weights = ResNet50_Weights.DEFAULT

        self.model = fasterrcnn_resnet50_fpn_v2(
            weights_backbone=weights,
            image_mean=[0.63, 0.63, 0.63],
            image_std=[0.21, 0.21, 0.21],
            min_size=128,
            max_size=256,
            num_classes=len(ALL_ITEMS),
            box_score_thresh=0.05,
            rpn_batch_size_per_image=64,
            box_detections_per_img=64,
            box_batch_size_per_image=128,
        )
        # extra head that classifies the item quantity (0-64) from the box features
        self.model.roi_heads.quantity_prediction = nn.Linear(1024, 65)

        # override the roi_heads forward pass with the custom forward defined above
        self.model.roi_heads.forward = forward_custom.__get__(
            self.model.roi_heads, type(self.model.roi_heads)
        )

        self.transform = v2.Compose(
            [v2.ToImage(), v2.ToDtype(torch.float32, scale=True)]
        )

    def forward(self, x, targets=None):
        if self.training:
            # normal forward pass, returns the loss dict
            loss_dict = self.model(x, targets)
            return loss_dict
        else:
            preds = self.model(x)
            return preds

    def get_inventory(self, pil_image):
        """
        Predict boxes and quantities for a single image.
        """
        img_tensor = self.transform(pil_image)
        if next(self.model.parameters()).is_cuda:
            img_tensor = img_tensor.cuda()
        with torch.no_grad():
            predictions = self.model(img_tensor.unsqueeze(0))
        return self.prediction_to_inventory(predictions[0])

    @staticmethod
    def prediction_to_inventory(prediction, threshold=0.9) -> list[dict]:
        inventory = []
        seen_slots = set()
        for i in range(len(prediction["boxes"])):
            slot = bbox_to_slot_index_iou(prediction["boxes"][i])
            score = prediction["scores"][i]
            label_idx = prediction["labels"][i].item()
            label = ALL_ITEMS[label_idx]
            quantity = prediction["quantities"][i].item()
            if score > threshold:
                if slot in seen_slots:
                    continue
                inventory.append({"slot": slot, "type": label, "quantity": quantity})
                # only keep the highest scoring detection per slot
                seen_slots.add(slot)
        return inventory

    def freeze(self):
        # NOTE: this might seem excessive,
        # but the transformers Trainer is really good at enabling gradients against my will
        self.eval()
        self.model.eval()
        self.training = False
        for param in self.model.parameters():
            param.requires_grad = False
        self.model.training = False
        self.model.roi_heads.training = False
        self.model.rpn.training = False

    def save(self, path: str):
        torch.save(self.state_dict(), path)
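
A minimal inference sketch for the model above (illustrative, not part of the package contents). It assumes the file is importable as plancraft.models.bbox_model, that a PIL screenshot of the crafting GUI exists at a hypothetical path, and that trained weights would be loaded separately; the class also inherits from_pretrained/save_pretrained from PyTorchModelHubMixin, but no repository id appears in this diff.

from PIL import Image

from plancraft.models.bbox_model import IntegratedBoundingBoxModel  # assumed import path

# backbone weights only; a trained detection checkpoint would normally be
# loaded on top of this via load_state_dict or the hub mixin helpers
model = IntegratedBoundingBoxModel(load_resnet_weights=True)
model.freeze()  # eval mode everywhere, gradients disabled

screenshot = Image.open("crafting_gui.png")  # hypothetical example image
inventory = model.get_inventory(screenshot)
# each entry looks like {"slot": int, "type": <name from ALL_ITEMS>, "quantity": int}
for entry in inventory:
    print(entry["slot"], entry["type"], entry["quantity"])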
models/dummy.py ADDED
@@ -0,0 +1,54 @@
import random

from plancraft.config import EvalConfig
from plancraft.environments.actions import (
    RealActionInteraction,
    SymbolicMoveAction,
    SymbolicSmeltAction,
)
from plancraft.models.base import ABCModel, History


class DummyModel(ABCModel):
    """
    Dummy model that returns a random valid action at each step.
    """

    def __init__(self, cfg: EvalConfig):
        self.symbolic_move_action = cfg.plancraft.environment.symbolic_action_space
        self.history = History(objective="")

    def random_select(self, observation):
        if observation is None or "inventory" not in observation:
            return SymbolicMoveAction(slot_from=0, slot_to=0, quantity=1)
        # randomly pick an item from the inventory
        item_indices = set()
        for item in observation["inventory"]:
            if item["quantity"] > 0:
                item_indices.add(item["index"])
        # fall back to a no-op move if the inventory is empty
        if not item_indices:
            return SymbolicMoveAction(slot_from=0, slot_to=0, quantity=1)
        all_slots_to = set(range(1, 46))
        empty_slots = all_slots_to - item_indices

        random_slot_from = random.choice(list(item_indices))
        random_slot_to = random.choice(list(empty_slots))

        return SymbolicMoveAction(
            slot_from=random_slot_from, slot_to=random_slot_to, quantity=1
        )

    def step(
        self, observation: dict
    ) -> SymbolicMoveAction | RealActionInteraction | SymbolicSmeltAction:
        # add observation to history
        self.history.add_observation_to_history(observation)

        # get action
        if self.symbolic_move_action:
            action = self.random_select(observation)
        else:
            action = RealActionInteraction()

        # add action to history
        self.history.add_action_to_history(action)

        return action
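
A small illustrative sketch of the dummy policy (not part of the package contents). Building a full EvalConfig is outside this diff, and random_select never reads self, so the sketch calls it as an unbound function on a toy observation; the inventory format (entries with "index" and "quantity") mirrors what the method expects.

from plancraft.models.dummy import DummyModel  # assumed import path

observation = {
    "inventory": [
        {"index": 10, "quantity": 3},  # stack of 3 in inventory slot 10
        {"index": 12, "quantity": 1},
    ]
}
# self is unused inside random_select, so None is passed purely for illustration
action = DummyModel.random_select(None, observation)
print(action)  # SymbolicMoveAction from slot 10 or 12 into a random empty slot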
@@ -0,0 +1,16 @@
import glob
import os

import imageio
import numpy as np


def get_few_shot_images_path():
    return os.path.dirname(__file__)


def load_prompt_images() -> list[np.ndarray]:
    current_dir = get_few_shot_images_path()
    files = glob.glob(os.path.join(current_dir, "*.png"))
    images = [imageio.imread(file) for file in files]
    return images