graspzero 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- graspzero/__init__.py +3 -0
- graspzero/model.py +68 -0
- graspzero/predictor.py +87 -0
- graspzero-0.1.0.dist-info/METADATA +16 -0
- graspzero-0.1.0.dist-info/RECORD +7 -0
- graspzero-0.1.0.dist-info/WHEEL +5 -0
- graspzero-0.1.0.dist-info/top_level.txt +1 -0
graspzero/__init__.py
ADDED
graspzero/model.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
# model.py — exact architecture matching your checkpoint
|
|
2
|
+
import torch
|
|
3
|
+
import torch.nn as nn
|
|
4
|
+
import torch.nn.functional as F
|
|
5
|
+
from transformers import Dinov2Model
|
|
6
|
+
|
|
7
|
+
class LoRALinear(nn.Module):
    """Wrap a linear layer and add a low-rank correction to its output.

    The wrapped layer is stored as ``original``. The correction is two
    bias-free linear maps applied in sequence: ``lora_A`` projects from
    ``original.in_features`` down to ``rank`` and ``lora_B`` projects from
    ``rank`` up to ``original.out_features``.

    Args:
        original: Layer to wrap; must expose ``in_features`` and
            ``out_features`` and be callable on the input tensor.
        rank: Width of the low-rank bottleneck.
    """

    def __init__(self, original, rank=8):
        super().__init__()
        in_dim, out_dim = original.in_features, original.out_features
        # Attribute names double as state-dict key prefixes; keep them stable.
        self.original = original
        self.lora_A = nn.Linear(in_dim, rank, bias=False)
        self.lora_B = nn.Linear(rank, out_dim, bias=False)

    def forward(self, x):
        """Return ``original(x) + lora_B(lora_A(x))``."""
        base_out = self.original(x)
        bottleneck = self.lora_A(x)
        return base_out + self.lora_B(bottleneck)
|
|
15
|
+
|
|
16
|
+
class AffordanceDecoder(nn.Module):
    """Convolutional decoder from a 384-channel feature map to one channel.

    Four stages of 3x3 conv -> batch norm -> ReLU narrow the channels
    384 -> 256 -> 128 -> 64 -> 32. Every stage is followed by a 2x bilinear
    upsample, so the output is 16x the input's spatial size. A final 1x1
    convolution produces a single channel; no activation is applied to it.
    """

    def __init__(self):
        super().__init__()
        widths = (384, 256, 128, 64, 32)
        # Stages are registered as block1..block4 so the state-dict keys
        # (e.g. "block1.0.weight") stay the same.
        for idx, (c_in, c_out) in enumerate(zip(widths, widths[1:]), start=1):
            stage = nn.Sequential(
                nn.Conv2d(c_in, c_out, 3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
            )
            setattr(self, f"block{idx}", stage)
        self.output_head = nn.Conv2d(widths[-1], 1, 1)

    def forward(self, x):
        """Run the four conv stages, doubling resolution after each one.

        Args:
            x: Tensor fed to the first 384-input-channel convolution.

        Returns:
            Output of the 1x1 head applied to the upsampled features.
        """
        for stage in (self.block1, self.block2, self.block3, self.block4):
            x = F.interpolate(
                stage(x),
                scale_factor=2,
                mode='bilinear',
                align_corners=False,
            )
        return self.output_head(x)
|
|
47
|
+
|
|
48
|
+
class GraspZeroModel(nn.Module):
    """DINOv2 backbone with LoRA-wrapped attention feeding AffordanceDecoder.

    The backbone is loaded from the ``facebook/dinov2-small`` checkpoint.
    In encoder layers 8 through 11 the query, key and value projections of
    the self-attention module are each replaced by a ``LoRALinear`` wrapper
    around the existing projection. The attribute names ``backbone`` and
    ``decoder`` are part of the state-dict key layout.
    """

    def __init__(self):
        super().__init__()
        self.backbone = Dinov2Model.from_pretrained("facebook/dinov2-small")

        encoder_layers = self.backbone.encoder.layer
        for idx in range(8, 12):
            self_attn = encoder_layers[idx].attention.attention
            # Same order as the checkpoint was built with: query, key, value.
            for proj in ("query", "key", "value"):
                setattr(self_attn, proj, LoRALinear(getattr(self_attn, proj)))

        self.decoder = AffordanceDecoder()

    def forward(self, x):
        """Encode ``x`` with the backbone and decode the token grid.

        The first token of ``last_hidden_state`` is dropped (presumably the
        CLS token -- confirm against the backbone config). The remaining
        ``N`` tokens are laid out as a square ``sqrt(N) x sqrt(N)`` grid,
        so ``N`` must be a perfect square for the reshape to succeed.

        Args:
            x: Batched input accepted by the DINOv2 backbone.

        Returns:
            The decoder output for the ``(B, C, H, W)`` token grid.
        """
        tokens = self.backbone(x).last_hidden_state[:, 1:, :]
        batch, num_tokens, channels = tokens.shape
        side = int(num_tokens ** 0.5)
        # (B, N, C) -> (B, C, N) -> (B, C, side, side)
        grid = tokens.permute(0, 2, 1).reshape(batch, channels, side, side)
        return self.decoder(grid)
|
graspzero/predictor.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
# predictor.py — main user-facing class
|
|
2
|
+
import torch
|
|
3
|
+
import numpy as np
|
|
4
|
+
from PIL import Image
|
|
5
|
+
from torchvision import transforms
|
|
6
|
+
from huggingface_hub import hf_hub_download
|
|
7
|
+
from .model import GraspZeroModel
|
|
8
|
+
|
|
9
|
+
# Repository id passed to hf_hub_download as `repo_id` when fetching weights.
WEIGHTS_REPO = "Jignesh2619/graspzero"
|
|
10
|
+
# Checkpoint file name passed to hf_hub_download as `filename`.
WEIGHTS_FILE = "graspzero_weights.pt"
|
|
11
|
+
|
|
12
|
+
class GraspPredictor:
    """User-facing wrapper: load GraspZero weights and predict a grasp point.

    On construction the checkpoint ``WEIGHTS_FILE`` is fetched from the
    ``WEIGHTS_REPO`` repository via ``hf_hub_download``, loaded into a
    ``GraspZeroModel``, and the model is put in eval mode on ``device``.

    Args:
        device: Device to run on. When ``None``, "cuda" is used if
            available, otherwise "cpu".
    """

    def __init__(self, device=None):
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device

        print("Loading GraspZero model...")
        weights_path = hf_hub_download(
            repo_id=WEIGHTS_REPO,
            filename=WEIGHTS_FILE,
        )
        self.model = GraspZeroModel()
        # SECURITY: torch.load is pickle-based and the file comes from the
        # network. weights_only=True restricts unpickling to tensors and
        # plain containers, so a tampered checkpoint cannot execute code.
        # A checkpoint holding other object types will fail loudly here.
        ckpt = torch.load(weights_path, map_location="cpu", weights_only=True)
        # Accept either a bare state dict or one nested under "model".
        state = ckpt["model"] if "model" in ckpt else ckpt
        self.model.load_state_dict(state)
        self.model.eval().to(self.device)
        print(f"Ready on {self.device}")

        # Fixed 518x518 input with the usual ImageNet mean/std
        # (presumably what the backbone was trained with -- confirm).
        self.transform = transforms.Compose([
            transforms.Resize((518, 518)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])

    @staticmethod
    def _load_rgb(image):
        """Return ``image`` as an RGB PIL image.

        Accepts a file path (``str`` or any ``os.PathLike`` such as
        ``pathlib.Path``), a numpy array, or an object with a PIL-style
        ``convert`` method.
        """
        import os  # local import: keeps this block self-contained

        if isinstance(image, (str, os.PathLike)):
            return Image.open(image).convert("RGB")
        if isinstance(image, np.ndarray):
            return Image.fromarray(image).convert("RGB")
        return image.convert("RGB")

    def predict(self, image):
        """Predict a grasp point for a single image.

        Args:
            image: File path (``str`` or path-like), PIL Image, or numpy
                array.

        Returns:
            dict with keys:
                grasp_x, grasp_y: int pixel coordinates in the original
                    image space.
                confidence: float, the maximum of the returned mask.
                mask: numpy array of shape (orig_h, orig_w) with values
                    in [0, 1].
        """
        img = self._load_rgb(image)
        orig_w, orig_h = img.size

        # Inference
        x = self.transform(img).unsqueeze(0).to(self.device)
        with torch.no_grad():
            pred = self.model(x)
        mask = torch.sigmoid(pred).squeeze().cpu().numpy()

        # Resize the mask back to the original image size.
        # NOTE(review): the round-trip through uint8 quantises the
        # probabilities to steps of 1/255; kept as-is so outputs match
        # previously released behaviour.
        mask = np.array(
            Image.fromarray((mask * 255).astype(np.uint8))
            .resize((orig_w, orig_h))
        ) / 255.0

        # Grasp point = centre of mass of the high-confidence region.
        confidence = float(mask.max())
        if confidence > 0.3:
            # The threshold is relative to the peak, so the region always
            # contains at least the peak pixel and the means are defined.
            high = mask > (confidence * 0.7)
            ys, xs = np.where(high)
            grasp_x = int(xs.mean())
            grasp_y = int(ys.mean())
        else:
            # Low confidence everywhere: fall back to the image centre.
            grasp_x = orig_w // 2
            grasp_y = orig_h // 2

        return {
            "grasp_x": grasp_x,
            "grasp_y": grasp_y,
            "confidence": confidence,
            "mask": mask,
        }
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: graspzero
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Zero-shot robotic grasping — no demos, no training
|
|
5
|
+
Author: Jignesh
|
|
6
|
+
Requires-Python: >=3.9
|
|
7
|
+
Requires-Dist: torch>=2.0.0
|
|
8
|
+
Requires-Dist: torchvision>=0.15.0
|
|
9
|
+
Requires-Dist: transformers>=4.35.0
|
|
10
|
+
Requires-Dist: huggingface_hub>=0.19.0
|
|
11
|
+
Requires-Dist: Pillow>=9.0.0
|
|
12
|
+
Requires-Dist: numpy>=1.24.0
|
|
13
|
+
Dynamic: author
|
|
14
|
+
Dynamic: requires-dist
|
|
15
|
+
Dynamic: requires-python
|
|
16
|
+
Dynamic: summary
|
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
graspzero/__init__.py,sha256=njuImq59Uz6tO2oxlSuZ9E7FbRoLtpI6NZqaPIhLxc8,90
|
|
2
|
+
graspzero/model.py,sha256=RGlDBVH8FVM3i-XnhNlatN5_rlfEU0PaEeoSFqQGyVk,2638
|
|
3
|
+
graspzero/predictor.py,sha256=5H3CmSvcvD7ZC41R7foFuC0i63qYizH13NVl0ryn908,2872
|
|
4
|
+
graspzero-0.1.0.dist-info/METADATA,sha256=yamyw-Y3xIJWkenXroLzpumcIZUDzkZHblnV5LUScog,447
|
|
5
|
+
graspzero-0.1.0.dist-info/WHEEL,sha256=aeYiig01lYGDzBgS8HxWXOg3uV61G9ijOsup-k9o1sk,91
|
|
6
|
+
graspzero-0.1.0.dist-info/top_level.txt,sha256=szyqy2z25Yylo3zO66bWx_jlUUIrpzjmDFkcxGw07BU,10
|
|
7
|
+
graspzero-0.1.0.dist-info/RECORD,,
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
graspzero
|