graspzero 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
graspzero/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .predictor import GraspPredictor
2
# Package version string and explicit public API: GraspPredictor is the only exported name.
+ __version__ = "0.1.0"
3
+ __all__ = ["GraspPredictor"]
graspzero/model.py ADDED
@@ -0,0 +1,68 @@
1
+ # model.py — exact architecture matching your checkpoint
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from transformers import Dinov2Model
6
+
7
class LoRALinear(nn.Module):
    """Wrap a linear layer with an additive low-rank (LoRA) branch.

    The forward pass returns ``original(x) + lora_B(lora_A(x))``. The
    parameter names (``original``, ``lora_A``, ``lora_B``) and shapes are
    part of the checkpoint layout and are kept as-is.

    Args:
        original: Layer to wrap; must expose ``in_features`` and
            ``out_features`` (e.g. ``nn.Linear``).
        rank: Inner width of the low-rank branch.
    """

    def __init__(self, original, rank=8):
        super().__init__()
        self.original = original
        self.lora_A = nn.Linear(original.in_features, rank, bias=False)
        self.lora_B = nn.Linear(rank, original.out_features, bias=False)
        # Start the up-projection at zero so that a freshly wrapped layer
        # reproduces ``original`` exactly until trained or checkpointed
        # weights are loaded; a random ``lora_B`` would add noise to the
        # wrapped layer's output from the first call.
        nn.init.zeros_(self.lora_B.weight)

    def forward(self, x):
        """Return the wrapped layer's output plus the low-rank correction."""
        return self.original(x) + self.lora_B(self.lora_A(x))
15
+
16
class AffordanceDecoder(nn.Module):
    """Convolutional decoder mapping a 384-channel feature map to one channel.

    Four conv -> batch-norm -> ReLU stages narrow the channels
    384 -> 256 -> 128 -> 64 -> 32. Each stage is followed by a 2x bilinear
    upsample, so the output is 16x the input's spatial size. A final 1x1
    convolution yields a single-channel map; no activation is applied here.
    """

    def __init__(self):
        super().__init__()
        widths = (384, 256, 128, 64, 32)
        # Register block1..block4 in order so parameter names and the
        # initialisation order stay the same as a hand-written layout.
        for stage, (c_in, c_out) in enumerate(zip(widths, widths[1:]), start=1):
            setattr(
                self,
                f"block{stage}",
                nn.Sequential(
                    nn.Conv2d(c_in, c_out, 3, padding=1),
                    nn.BatchNorm2d(c_out),
                    nn.ReLU(),
                ),
            )
        self.output_head = nn.Conv2d(widths[-1], 1, 1)

    def forward(self, x):
        """Run the four conv stages, upsampling 2x after each, then the head."""
        for stage in (self.block1, self.block2, self.block3, self.block4):
            x = F.interpolate(
                stage(x), scale_factor=2, mode='bilinear', align_corners=False
            )
        return self.output_head(x)
47
+
48
class GraspZeroModel(nn.Module):
    """DINOv2-small backbone with LoRA-wrapped attention and a conv decoder.

    The query/key/value projections of encoder layers 8-11 are wrapped in
    ``LoRALinear``. The backbone's token sequence (minus its first token)
    is folded into a square feature map and passed to ``AffordanceDecoder``.
    """

    def __init__(self):
        super().__init__()
        self.backbone = Dinov2Model.from_pretrained("facebook/dinov2-small")
        for idx in (8, 9, 10, 11):
            self_attn = self.backbone.encoder.layer[idx].attention.attention
            # Wrap in query, key, value order to keep parameter creation
            # order unchanged.
            for proj_name in ("query", "key", "value"):
                setattr(self_attn, proj_name, LoRALinear(getattr(self_attn, proj_name)))
        self.decoder = AffordanceDecoder()

    def forward(self, x):
        """Encode ``x`` with the backbone and decode its patch-token grid."""
        hidden = self.backbone(x).last_hidden_state
        # Skip the first token (presumably the CLS token -- confirm against
        # the backbone's output layout); the rest are treated as patches.
        tokens = hidden[:, 1:, :]
        batch, count, channels = tokens.shape
        # Assumes the remaining tokens form a square grid.
        side = int(count ** 0.5)
        grid = tokens.permute(0, 2, 1).reshape(batch, channels, side, side)
        return self.decoder(grid)
graspzero/predictor.py ADDED
@@ -0,0 +1,87 @@
1
+ # predictor.py — main user-facing class
2
+ import torch
3
+ import numpy as np
4
+ from PIL import Image
5
+ from torchvision import transforms
6
+ from huggingface_hub import hf_hub_download
7
+ from .model import GraspZeroModel
8
+
9
# Hugging Face Hub repository id and file name passed to hf_hub_download when GraspPredictor loads its checkpoint.
+ WEIGHTS_REPO = "Jignesh2619/graspzero"
10
+ WEIGHTS_FILE = "graspzero_weights.pt"
11
+
12
class GraspPredictor:
    """Download the GraspZero checkpoint and predict a grasp point per image.

    Args:
        device: Torch device to run on. When ``None``, ``"cuda"`` is used if
            ``torch.cuda.is_available()`` and ``"cpu"`` otherwise.
    """

    def __init__(self, device=None):
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device

        print("Loading GraspZero model...")
        weights_path = hf_hub_download(
            repo_id=WEIGHTS_REPO,
            filename=WEIGHTS_FILE,
        )
        self.model = GraspZeroModel()
        # SECURITY NOTE(review): torch.load unpickles the downloaded file.
        # Depending on the installed torch version's ``weights_only``
        # default this can execute arbitrary code from the checkpoint.
        # Consider passing ``weights_only=True`` once the published
        # checkpoint is confirmed to contain only tensors and plain
        # containers.
        ckpt = torch.load(weights_path, map_location="cpu")
        # Accept either a bare state dict or one nested under "model".
        state = ckpt["model"] if "model" in ckpt else ckpt
        self.model.load_state_dict(state)
        self.model.eval().to(self.device)
        print(f"Ready on {self.device}")

        # Preprocessing: fixed 518x518 resize, tensor conversion, then
        # per-channel normalisation with the constants below.
        self.transform = transforms.Compose([
            transforms.Resize((518, 518)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])

    def predict(self, image):
        """Predict a grasp point for a single image.

        Args:
            image: File path (``str`` or ``os.PathLike``), numpy array, or
                any object with a PIL-style ``convert("RGB")`` method.

        Returns:
            dict with keys:
                grasp_x, grasp_y -- integer pixel coords in the original
                    image space. Falls back to the image centre when
                    ``confidence`` is not above 0.3.
                confidence -- float, the maximum of ``mask``.
                mask -- numpy array (H, W) at the original image size,
                    values in [0, 1]. The resize goes through an 8-bit
                    image, so values are quantised to steps of 1/255.
        """
        import os  # local import: only needed for the path-type check

        # Load image
        if isinstance(image, (str, os.PathLike)):
            # Image.open is lazy and keeps the file open; convert() loads
            # the pixels, so the handle can be closed right after.
            with Image.open(image) as handle:
                img = handle.convert("RGB")
        elif isinstance(image, np.ndarray):
            img = Image.fromarray(image).convert("RGB")
        else:
            img = image.convert("RGB")

        orig_w, orig_h = img.size

        # Inference
        x = self.transform(img).unsqueeze(0).to(self.device)
        with torch.no_grad():
            pred = self.model(x)
            mask = torch.sigmoid(pred).squeeze().cpu().numpy()

        # Resize mask back to original image size
        mask = np.array(
            Image.fromarray((mask * 255).astype(np.uint8))
            .resize((orig_w, orig_h))
        ) / 255.0

        # Grasp point = center of mass of high-confidence region
        confidence = float(mask.max())
        if confidence > 0.3:
            # The region keeps every pixel above 70% of the peak, so it
            # always contains at least the peak pixel itself.
            high = mask > (confidence * 0.7)
            ys, xs = np.where(high)
            grasp_x = int(xs.mean())
            grasp_y = int(ys.mean())
        else:
            grasp_x = orig_w // 2
            grasp_y = orig_h // 2

        return {
            "grasp_x": grasp_x,
            "grasp_y": grasp_y,
            "confidence": confidence,
            "mask": mask,
        }
@@ -0,0 +1,16 @@
1
+ Metadata-Version: 2.4
2
+ Name: graspzero
3
+ Version: 0.1.0
4
+ Summary: Zero-shot robotic grasping — no demos, no training
5
+ Author: Jignesh
6
+ Requires-Python: >=3.9
7
+ Requires-Dist: torch>=2.0.0
8
+ Requires-Dist: torchvision>=0.15.0
9
+ Requires-Dist: transformers>=4.35.0
10
+ Requires-Dist: huggingface_hub>=0.19.0
11
+ Requires-Dist: Pillow>=9.0.0
12
+ Requires-Dist: numpy>=1.24.0
13
+ Dynamic: author
14
+ Dynamic: requires-dist
15
+ Dynamic: requires-python
16
+ Dynamic: summary
@@ -0,0 +1,7 @@
1
+ graspzero/__init__.py,sha256=njuImq59Uz6tO2oxlSuZ9E7FbRoLtpI6NZqaPIhLxc8,90
2
+ graspzero/model.py,sha256=RGlDBVH8FVM3i-XnhNlatN5_rlfEU0PaEeoSFqQGyVk,2638
3
+ graspzero/predictor.py,sha256=5H3CmSvcvD7ZC41R7foFuC0i63qYizH13NVl0ryn908,2872
4
+ graspzero-0.1.0.dist-info/METADATA,sha256=yamyw-Y3xIJWkenXroLzpumcIZUDzkZHblnV5LUScog,447
5
+ graspzero-0.1.0.dist-info/WHEEL,sha256=aeYiig01lYGDzBgS8HxWXOg3uV61G9ijOsup-k9o1sk,91
6
+ graspzero-0.1.0.dist-info/top_level.txt,sha256=szyqy2z25Yylo3zO66bWx_jlUUIrpzjmDFkcxGw07BU,10
7
+ graspzero-0.1.0.dist-info/RECORD,,
@@ -0,0 +1,5 @@
1
+ Wheel-Version: 1.0
2
+ Generator: setuptools (82.0.1)
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
5
+
@@ -0,0 +1 @@
1
+ graspzero