py3dcal 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- py3DCal/__init__.py +12 -0
- py3DCal/data_collection/Calibrator.py +298 -0
- py3DCal/data_collection/Printers/Ender3/Ender3.py +82 -0
- py3DCal/data_collection/Printers/Ender3/__init__.py +0 -0
- py3DCal/data_collection/Printers/Printer.py +63 -0
- py3DCal/data_collection/Printers/__init__.py +0 -0
- py3DCal/data_collection/Sensors/DIGIT/DIGIT.py +47 -0
- py3DCal/data_collection/Sensors/DIGIT/__init__.py +0 -0
- py3DCal/data_collection/Sensors/DIGIT/default.csv +1222 -0
- py3DCal/data_collection/Sensors/GelsightMini/GelsightMini.py +45 -0
- py3DCal/data_collection/Sensors/GelsightMini/__init__.py +0 -0
- py3DCal/data_collection/Sensors/GelsightMini/default.csv +1210 -0
- py3DCal/data_collection/Sensors/Sensor.py +35 -0
- py3DCal/data_collection/Sensors/__init__.py +0 -0
- py3DCal/data_collection/__init__.py +0 -0
- py3DCal/model_training/__init__.py +0 -0
- py3DCal/model_training/datasets/DIGIT_dataset.py +75 -0
- py3DCal/model_training/datasets/GelSightMini_dataset.py +73 -0
- py3DCal/model_training/datasets/__init__.py +3 -0
- py3DCal/model_training/datasets/split_dataset.py +38 -0
- py3DCal/model_training/datasets/tactile_sensor_dataset.py +82 -0
- py3DCal/model_training/lib/__init__.py +0 -0
- py3DCal/model_training/lib/add_coordinate_embeddings.py +29 -0
- py3DCal/model_training/lib/depthmaps.py +74 -0
- py3DCal/model_training/lib/fast_poisson.py +51 -0
- py3DCal/model_training/lib/get_gradient_map.py +39 -0
- py3DCal/model_training/lib/precompute_gradients.py +61 -0
- py3DCal/model_training/lib/train_model.py +96 -0
- py3DCal/model_training/lib/validate_device.py +22 -0
- py3DCal/model_training/lib/validate_parameters.py +45 -0
- py3DCal/model_training/models/__init__.py +1 -0
- py3DCal/model_training/models/touchnet.py +211 -0
- py3DCal/model_training/touchnet/__init__.py +0 -0
- py3DCal/model_training/touchnet/dataset.py +78 -0
- py3DCal/model_training/touchnet/touchnet.py +736 -0
- py3DCal/model_training/touchnet/touchnet_architecture.py +72 -0
- py3DCal/utils/__init__.py +0 -0
- py3DCal/utils/utils.py +32 -0
- py3dcal-1.0.0.dist-info/LICENSE +21 -0
- py3dcal-1.0.0.dist-info/METADATA +29 -0
- py3dcal-1.0.0.dist-info/RECORD +44 -0
- py3dcal-1.0.0.dist-info/WHEEL +5 -0
- py3dcal-1.0.0.dist-info/entry_points.txt +3 -0
- py3dcal-1.0.0.dist-info/top_level.txt +1 -0
py3DCal/model_training/touchnet/touchnet.py
@@ -0,0 +1,736 @@
# touchnet.py
from enum import Enum
import os
from PIL import Image
import numpy as np
import torch
import torch.nn as nn
from torch.hub import load_state_dict_from_url
import torch.optim as optim
from torchvision import transforms
import cv2
import pandas as pd
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset, DataLoader
import math
import copy
import matplotlib.pyplot as plt
import requests
import tarfile
from tqdm import tqdm
from pathlib import Path
from typing import Union
from .dataset import TactileSensorDataset
from .touchnet_architecture import TouchNetArchitecture
from ..lib.fast_poisson import fast_poisson

class SensorType(Enum):
    """
    SensorType: Available sensor types with pretrained weights and compiled datasets
    """
    DIGIT = "DIGIT"
    GELSIGHTMINI = "GelSightMini"
    CUSTOM = "Custom"

class TouchNet:
    """
    TouchNet: A Deep Learning Model for Enhanced Calibration and Sensing of Vision-Based Tactile Sensors
    Args:
        root (str or pathlib.Path, optional): Root directory for datasets and models. Defaults to the current directory.
        sensor_type (py3DCal.SensorType, optional): Type of tactile sensor. Defaults to py3DCal.SensorType.CUSTOM.
        load_pretrained_model (bool, optional): If True, loads the pretrained model for the specified sensor type. Defaults to False.
        download_dataset (bool, optional): If True, downloads the dataset for the specified sensor type. Defaults to False.
        device (str, optional): Device to run the model on. Defaults to "cpu".
    """
    def __init__(self, root: Union[str, Path] = Path("."), sensor_type: SensorType = SensorType.CUSTOM, load_pretrained_model: bool = False, download_dataset: bool = False, device: str = "cpu"):

        self._validate_parameters(sensor_type, load_pretrained_model, download_dataset, device)

        self.model = TouchNetArchitecture()
        self.sensor_type = sensor_type
        self.load_pretrained_model = load_pretrained_model
        self.download_dataset = download_dataset
        self.transform = transforms.Compose([transforms.ToTensor()])
        self.root = root
        if sensor_type == SensorType.DIGIT:
            self.dataset_path = os.path.join(root, "digit_calibration_data")
        elif sensor_type == SensorType.GELSIGHTMINI:
            self.dataset_path = os.path.join(root, "gsmini_calibration_data")
        else:
            self.dataset_path = "."
        self.blank_image_path = os.path.join(self.dataset_path, "blank_images", "blank.png")
        self.annotation_path = os.path.join(self.dataset_path, "annotations", "annotations.csv")
        self.device = device

        if self.load_pretrained_model:
            self._load_pretrained_model()

        self.model.to(self.device)

        if self.download_dataset:
            self._download_dataset()


    def _validate_parameters(self, sensor_type: SensorType, load_pretrained_model: bool, download_dataset: bool, device: str):
        """
        Validates the parameters.
        Args:
            sensor_type (py3DCal.SensorType): Type of tactile sensor.
            load_pretrained_model (bool): If True, loads the pretrained model for the specified sensor type.
            download_dataset (bool): If True, downloads the dataset for the specified sensor type.
            device (str): Device to run the model on.
        Returns:
            None.
        """
        self._validate_sensor_type(sensor_type)
        self._validate_device(device)
        self._validate_load_pretrained_download_dataset(sensor_type, load_pretrained_model, download_dataset)


    def _validate_device(self, device: str):
        """
        Validates the device by converting it to a torch.device object.
        Args:
            device (str): Device to run the model on.
        Returns:
            None.
        Raises:
            ValueError: If the device is not specified or invalid.
        """
        try:
            device = torch.device(device)
        except Exception as e:
            raise ValueError(
                f"Invalid device '{device}'. Valid options include:\n"
                " - 'cpu': CPU processing\n"
                " - 'cuda' or 'cuda:0': NVIDIA GPU\n"
                " - 'mps': Apple Silicon GPU\n"
                "See: https://pytorch.org/docs/stable/tensor_attributes.html#torch.device"
            ) from e

    def _validate_sensor_type(self, sensor_type):
        """
        Validates the sensor type.
        Args:
            sensor_type (py3DCal.SensorType): Type of tactile sensor.
        Returns:
            None.
        Raises:
            ValueError: If the sensor type is not specified or invalid.
        """
        if sensor_type is None or sensor_type not in [SensorType.DIGIT, SensorType.GELSIGHTMINI, SensorType.CUSTOM]:
            raise ValueError(f"Invalid sensor type: {sensor_type}. Sensor type must be either {SensorType.DIGIT}, {SensorType.GELSIGHTMINI}, or {SensorType.CUSTOM}.")

    def _validate_load_pretrained_download_dataset(self, sensor_type: SensorType, load_pretrained_model: bool, download_dataset: bool):
        """
        Validates that load_pretrained_model and download_dataset are only requested with SensorType.DIGIT or SensorType.GELSIGHTMINI.
        Args:
            sensor_type (py3DCal.SensorType): Type of tactile sensor.
            load_pretrained_model (bool): If True, loads the pretrained model for the specified sensor type.
            download_dataset (bool): If True, downloads the dataset for the specified sensor type.
        Returns:
            None.
        Raises:
            ValueError: If load_pretrained_model or download_dataset is requested with SensorType.CUSTOM.
        """
        if (load_pretrained_model is True or download_dataset is True) and (sensor_type is not SensorType.DIGIT and sensor_type is not SensorType.GELSIGHTMINI):
            raise ValueError("Cannot load pretrained model or download dataset for custom sensor type. Sensor type must be either SensorType.DIGIT or SensorType.GELSIGHTMINI.")

    def set_dataset_path(self, dataset_path: Union[str, Path]):
        """
        Set the dataset path for custom datasets.
        Args:
            dataset_path (str or pathlib.Path): Path to the dataset.
        Returns:
            None.
        """
        self.dataset_path = dataset_path
        self.annotation_path = os.path.join(dataset_path, "annotated_data.csv")
        # self.blank_image_path = os.path.join(dataset_path, "blank.png")
        self.blank_image_path = "../data/sensors/DIGIT/Images/blank.png"

    def set_blank_image_path(self, blank_image_path: Union[str, Path]):
        """
        Set the blank image path for custom datasets.
        Args:
            blank_image_path (str or pathlib.Path): Path to the blank image.
        Returns:
            None.
        """
        self.blank_image_path = blank_image_path

    def train(self, num_epochs: int = 60, batch_size: int = 64, learning_rate: float = 1e-4, train_split: float = 0.8, loss_fn: nn.Module = nn.MSELoss()):
        """
        Train the TouchNet model on a dataset. By default, trains for 60 epochs with
        a batch size of 64 and an AdamW optimizer with learning rate 1e-4.

        Args:
            num_epochs (int): Number of epochs to train for. Defaults to 60.
            batch_size (int): Batch size. Defaults to 64.
            learning_rate (float): Learning rate. Defaults to 1e-4.
            train_split (float): Proportion of data to use for training. Defaults to 0.8.
            loss_fn (nn.Module): Loss function. Defaults to nn.MSELoss().

        Outputs:
            weights.pth: Trained model weights.
            losses.csv: Training and testing losses.
        """
        optimizer = optim.AdamW(self.model.parameters(), lr=learning_rate, weight_decay=1e-4)
        df = pd.read_csv(self.annotation_path, comment='#')
        unique_coords = df[['x_mm', 'y_mm']].drop_duplicates()
        coord_tuples = [(row['x_mm'], row['y_mm']) for _, row in unique_coords.iterrows()]
        train_coords, test_coords = train_test_split(coord_tuples, train_size=train_split, random_state=42)
        train_coords_set = set(train_coords)
        test_coords_set = set(test_coords)
        train_idx = []
        test_idx = []
        for i in range(len(df)):
            coord = (df.loc[i, 'x_mm'], df.loc[i, 'y_mm'])
            if coord in train_coords_set:
                train_idx.append(i)
            elif coord in test_coords_set:
                test_idx.append(i)
        transform = transforms.Compose([transforms.ToTensor()])
        dataset = TactileSensorDataset(dataset_path=os.path.join(self.dataset_path, "probe_images"), annotation_path=self.annotation_path, blank_image_path=self.blank_image_path, transform=transform)
        train_dataset = Subset(dataset, train_idx)
        test_dataset = Subset(dataset, test_idx)
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=16, pin_memory=True, persistent_workers=True)
        test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=16, pin_memory=True, persistent_workers=True)
        epoch_train_losses = []
        epoch_test_losses = []
        for epoch in range(num_epochs):
            self.model.train()
            train_loss = 0.0
            for batch_idx, (inputs, targets) in enumerate(train_loader):
                inputs = inputs.to(torch.float32).to(self.device)
                targets = targets.to(torch.float32).to(self.device)
                optimizer.zero_grad()
                outputs = self.model(inputs)

                loss = loss_fn(outputs, targets)
                loss.backward()
                optimizer.step()

                train_loss += loss.item()

                print(f" [Batch {batch_idx}/{len(train_loader)}] - Loss: {loss.item():.4f}")
            avg_train_loss = train_loss / len(train_loader)
            epoch_train_losses.append(avg_train_loss)
            print(f"Epoch {epoch+1}/{num_epochs} | Train Loss: {avg_train_loss:.4f}")
            self.model.eval()
            test_loss = 0.0
            with torch.no_grad():
                for inputs, targets in test_loader:
                    inputs = inputs.to(torch.float32).to(self.device)
                    targets = targets.to(torch.float32).to(self.device)
                    outputs = self.model(inputs)
                    loss = loss_fn(outputs, targets)
                    test_loss += loss.item()

            avg_test_loss = test_loss / len(test_loader)
            epoch_test_losses.append(avg_test_loss)
            print(f"TEST LOSS: {avg_test_loss:.4f}")

        with open("losses.csv", "w") as f:
            f.write("epoch,train_loss,test_loss\n")
            for i in range(len(epoch_train_losses)):
                f.write(f"{i+1},{epoch_train_losses[i]},{epoch_test_losses[i]}\n")
        torch.save(self.model.state_dict(), "weights.pth")

    def _add_coordinate_channels(self, image: torch.Tensor) -> torch.Tensor:
        """
        Adds positional embedding to the input image by appending x and y coordinate channels.

        Args:
            image (torch.Tensor): Input image tensor of shape (C, H, W).

        Returns:
            torch.Tensor: Image tensor with added coordinate channels of shape (C+2, H, W).
                - X channel: column indices (1s in first column, 2s in second column, etc.)
                - Y channel: row indices (1s in first row, 2s in second row, etc.)
        """
        # Get image dimensions
        _, height, width = image.shape

        # Create x coordinate channel (column indices)
        x_coords = torch.arange(1, width + 1, dtype=torch.float32).unsqueeze(0).repeat(height, 1)
        x_channel = x_coords.unsqueeze(0)  # Add channel dimension

        # Create y coordinate channel (row indices)
        y_coords = torch.arange(1, height + 1, dtype=torch.float32).unsqueeze(1).repeat(1, width)
        y_channel = y_coords.unsqueeze(0)  # Add channel dimension

        # Concatenate original image with coordinate channels
        image_with_coords = torch.cat([image, x_channel, y_channel], dim=0)

        return image_with_coords

    def get_depthmap(self, image_path: str) -> np.ndarray:
        """
        Returns the depthmap for a given input image.
        Args:
            image_path (str): Path to the input image.
        Returns:
            depthmap (numpy.ndarray): The computed depthmap.
        """
        model = self.model.to(self.device)
        model.eval()
        image = Image.open(image_path)
        image_pil = image.convert("RGB")
        transformed_image = self.transform(image_pil)
        blank_image = Image.open(self.blank_image_path)
        blank_image_pil = blank_image.convert("RGB")
        transformed_blank_image = self.transform(blank_image_pil)
        augmented_image = transformed_image - transformed_blank_image
        augmented_image = self._add_coordinate_channels(augmented_image)
        augmented_image = augmented_image.unsqueeze(0).to(self.device)

        with torch.no_grad():
            output = model(augmented_image)

        output = output.squeeze(0).permute(1, 2, 0).cpu().numpy()

        depthmap = fast_poisson(output[:,:,0], output[:,:,1])

        depthmap = np.clip(-depthmap, a_min=0, a_max=None)

        return depthmap

    def save_depthmap_image(self, image_path: str, save_path: Union[str, Path] = Path("depthmap.png")):
        """
        Save an image of the depthmap for a given input image.
        Args:
            image_path (str): Path to the input image.
            save_path (str or pathlib.Path): Path to save the depthmap image.
        Returns:
            None.
        """
        depthmap = self.get_depthmap(image_path)

        plt.imsave(save_path, depthmap, cmap='viridis')

    def show_depthmap(self, image_path: str):
        """
        Show the depthmap for a given input image.
        Args:
            image_path (str): Path to the input image.
        Returns:
            None.
        """
        depthmap = self.get_depthmap(image_path)

        plt.imshow(depthmap)
        plt.show()

    def _load_pretrained_model(self):
        """
        Loads a pretrained model for either the DIGIT or GelSightMini sensor.
        Args:
            None.
        Returns:
            None.
        """
        if self.sensor_type == SensorType.DIGIT:
            file_path = os.path.join(self.root, "digit_pretrained_weights.pth")

            # Check if DIGIT pretrained weights exist locally, if not download them
            if not os.path.exists(file_path):

                print(f"Downloading DIGIT pretrained weights ...")
                response = requests.get('https://zenodo.org/records/17487330/files/digit_pretrained_weights.pth?download=1', stream=True)
                response.raise_for_status()

                total_size = int(response.headers.get('content-length', 0))
                block_size = 1024

                # Save file in chunks to handle large datasets
                with open(file_path, 'wb') as f, tqdm(
                    total=total_size,
                    unit='B',
                    unit_scale=True,
                    desc="Downloading",
                    ncols=80
                ) as progress_bar:
                    for chunk in response.iter_content(chunk_size=block_size):
                        if chunk:
                            f.write(chunk)
                            progress_bar.update(len(chunk))

                print(f"Download complete!")
            else:
                print(f"DIGIT pretrained weights already exist at: {file_path}")

        elif self.sensor_type == SensorType.GELSIGHTMINI:
            file_path = os.path.join(self.root, "gsmini_pretrained_weights.pth")

            # Check if GelSight Mini pretrained weights exist locally, if not download them
            if not os.path.exists(file_path):

                print(f"Downloading GelSight Mini pretrained weights ...")
                response = requests.get('https://zenodo.org/records/17487330/files/gsmini_pretrained_weights.pth?download=1', stream=True)
                response.raise_for_status()

                total_size = int(response.headers.get('content-length', 0))
                block_size = 1024

                # Save file in chunks to handle large datasets
                with open(file_path, 'wb') as f, tqdm(
                    total=total_size,
                    unit='B',
                    unit_scale=True,
                    desc="Downloading",
                    ncols=80
                ) as progress_bar:
                    for chunk in response.iter_content(chunk_size=block_size):
                        if chunk:
                            f.write(chunk)
                            progress_bar.update(len(chunk))

                print(f"Download complete!")
            else:
                print(f"GelSight Mini pretrained weights already exist at: {file_path}")

        state_dict = torch.load(file_path, map_location="cpu")
        self.model.load_state_dict(state_dict)

    def load_model_weights(self, model_path: Union[str, Path]):
        """
        Loads custom model weights from a .pth file.
        Args:
            model_path (str or pathlib.Path): Path to the model weights file.
        Returns:
            None.
        """
        self.model.load_state_dict(torch.load(model_path, map_location=self.device))

    def _download_dataset(self):
        """
        Downloads the dataset for either the SensorType.DIGIT or SensorType.GELSIGHTMINI sensor. Used for model training.
        Args:
            None.
        Returns:
            None.
        """
        if self.sensor_type == SensorType.DIGIT:
            self.dataset_path = os.path.join(self.root, "digit_calibration_data")

            # Check if self.dataset_path exists
            if not os.path.exists(self.dataset_path):
                os.makedirs(self.root, exist_ok=True)

                tar_path = os.path.join(self.root, "digit_calibration_data.tar.gz")

                print(f"Downloading DIGIT dataset ...")
                response = requests.get('https://zenodo.org/records/17487330/files/digit_calibration_data.tar.gz?download=1', stream=True)
                response.raise_for_status()

                total_size = int(response.headers.get('content-length', 0))
                block_size = 1024

                # Save file in chunks to handle large datasets
                with open(tar_path, 'wb') as f, tqdm(
                    total=total_size,
                    unit='B',
                    unit_scale=True,
                    desc="Downloading",
                    ncols=80
                ) as progress_bar:
                    for chunk in response.iter_content(chunk_size=block_size):
                        if chunk:
                            f.write(chunk)
                            progress_bar.update(len(chunk))

                print(f"Download complete!")

                # Extract .tar.gz file
                print("Extracting files ...")
                with tarfile.open(tar_path, "r:gz") as tar:
                    tar.extractall(path=self.root)

                os.remove(tar_path)

                print(f"Extraction complete! Files are in: {self.root}/")

            else:
                print(f"DIGIT dataset already exists at: {self.dataset_path}/")


        elif self.sensor_type == SensorType.GELSIGHTMINI:
            self.dataset_path = os.path.join(self.root, "gsmini_calibration_data")

            # Check if self.dataset_path exists
            if not os.path.exists(self.dataset_path):
                os.makedirs(self.root, exist_ok=True)

                tar_path = os.path.join(self.root, "gsmini_calibration_data.tar.gz")

                print(f"Downloading GelSight Mini dataset ...")
                response = requests.get('https://zenodo.org/records/17487330/files/gsmini_calibration_data.tar.gz?download=1', stream=True)
                response.raise_for_status()

                total_size = int(response.headers.get('content-length', 0))
                block_size = 1024

                # Save file in chunks to handle large datasets
                with open(tar_path, 'wb') as f, tqdm(
                    total=total_size,
                    unit='B',
                    unit_scale=True,
                    desc="Downloading",
                    ncols=80
                ) as progress_bar:
                    for chunk in response.iter_content(chunk_size=block_size):
                        if chunk:
                            f.write(chunk)
                            progress_bar.update(len(chunk))

                print(f"Download complete!")

                # Extract .tar.gz file
                print("Extracting files ...")
                with tarfile.open(tar_path, "r:gz") as tar:
                    tar.extractall(path=self.root)

                os.remove(tar_path)

                print(f"Extraction complete! Files are in: {self.root}/")

            else:
                print(f"GelSight Mini dataset already exists at: {self.dataset_path}/")

        self.annotation_path = os.path.join(self.dataset_path, "annotations", "annotations.csv")
        self.blank_image_path = os.path.join(self.dataset_path, "blank_images", "blank.png")

    def _fit_circle(self, img_path, instructions, initial_pos=(100,100,30)):
        """
        Interactively fits a circle to an image using keyboard controls.
        Args:
            img_path: Path to the image.
            instructions: Instructions to display on the image.
            initial_pos: Initial position of the circle.
        Returns:
            x: x-coordinate of the circle.
            y: y-coordinate of the circle.
            r: radius of the circle.
        """
        x, y, r = initial_pos

        # Read the image
        image = cv2.imread(img_path)

        # Define rectangle parameters
        rectangle_height = 250  # Adjust as needed
        rectangle_color = (255, 255, 255)  # White color (BGR format)

        # Create a new image with the rectangle
        new_image = np.zeros((image.shape[0] + rectangle_height, image.shape[1], 3), dtype=np.uint8)
        new_image[:image.shape[0], :, :] = image
        cv2.rectangle(new_image, (0, image.shape[0]), (image.shape[1], image.shape[0] + rectangle_height), rectangle_color, -1)

        # Add text inside the rectangle
        for i, line in enumerate(instructions.split('\n')):
            str_y = image.shape[0] + 40 + i*30
            cv2.putText(new_image, line, (10, str_y), cv2.FONT_HERSHEY_SIMPLEX, 1, 2)

        while True:
            # Create a copy of the original image
            annotated_image = new_image.copy()

            # Draw the circle
            cv2.circle(annotated_image, (x, y), r, (0, 0, 255), 1)

            # Display the image
            cv2.imshow('Circle Fitting', annotated_image)

            # Update circle position based on key presses
            key = cv2.waitKey(1) & 0xFF
            if key == ord('q'):
                break
            elif key == ord('w'):
                y -= 1
            elif key == ord('s'):
                y += 1
            elif key == ord('a'):
                x -= 1
            elif key == ord('d'):
                x += 1
            elif key == ord('r'):
                r += 1
            elif key == ord('f'):
                r -= 1

        return x, y, r

    def _scale_px_to_mm(self, root_dir, csv_path, instructions, initial_val, anchor_idx, circle_vals):
        """
        Interactively refines the pixel-to-millimeter ratio across a 3x3 grid of sample images.
        Args:
            root_dir: Path to the root directory.
            csv_path: Path to the CSV file.
            instructions: Instructions to display on the image.
            initial_val: Initial value of the pixel/mm ratio.
            anchor_idx: Index of the anchor image.
            circle_vals: Values of the circles.
        Returns:
            px_per_mm: Pixel/millimeter ratio.
            calibration_data: Calibration data.
        """
        # Convert data labels into dataframe
        calibration_data = pd.read_csv(csv_path)

        anchor_image_path = os.path.join(root_dir, calibration_data["img_name"][anchor_idx])
        anchor_image = cv2.imread(anchor_image_path)
        image_list = [anchor_idx]

        height, width, _ = anchor_image.shape

        # Define rectangle parameters
        rectangle_height = 250  # Adjust as needed
        rectangle_color = (255, 255, 255)  # White color (BGR format)

        # Create a new image with the rectangle
        new_image = np.zeros((height * 3 + rectangle_height, width * 3, 3), dtype=np.uint8)
        new_image[:height, :width, :] = anchor_image
        cv2.rectangle(new_image, (0, height * 3), (width * 3, height * 3 + rectangle_height), rectangle_color, -1)

        px_per_mm = initial_val
        circle1_x, circle1_y, circle1_r = circle_vals
        img1_x_mm = calibration_data["x_mm"][anchor_idx]
        img1_y_mm = calibration_data["y_mm"][anchor_idx]

        calibration_data['x_px'] = circle1_x + (calibration_data['x_mm'] - img1_x_mm) * px_per_mm
        calibration_data['y_px'] = circle1_y + (img1_y_mm - calibration_data['y_mm']) * px_per_mm

        idx = 1

        while len(image_list) < 9:
            random_row = calibration_data.sample(n=1)

            if random_row.iloc[0,4] < width and random_row.iloc[0,5] < height:
                image_path = os.path.join(root_dir, random_row.iloc[0,0])
                image = cv2.imread(image_path)

                image_list += [random_row.index[0]]

                row = math.floor(idx / 3)
                column = idx % 3

                new_image[(height * row):(height * (row+1)), (width * column):(width * (column+1)), :] = image

                idx += 1

        # Add text inside the rectangle
        for i, line in enumerate(instructions.split('\n')):
            str_y = height * 3 + 40 + i*30
            cv2.putText(new_image, line, (10, str_y), cv2.FONT_HERSHEY_SIMPLEX, 1, 2)

        # Display the image
        cv2.imshow('Circle Fitting', new_image)

        while True:
            # Create a copy of the original image
            annotated_image = new_image.copy()

            calibration_data['x_px'] = circle1_x + (calibration_data['x_mm'] - img1_x_mm) * px_per_mm
            calibration_data['y_px'] = circle1_y + (img1_y_mm - calibration_data['y_mm']) * px_per_mm

            # Draw the circles
            for i in range(9):
                row = math.floor(i / 3)
                column = i % 3

                idx = image_list[i]
                x = int(calibration_data['x_px'][idx]) + column * width
                y = int(calibration_data['y_px'][idx]) + row * height

                cv2.circle(annotated_image, (x, y), circle1_r, (0, 0, 255), 1)

            # Display the image
            cv2.imshow('Circle Fitting', annotated_image)

            # Update circle position based on key presses
            key = cv2.waitKey(1) & 0xFF
            if key == ord('q'):
                break
            elif key == ord('w'):
                circle1_y -= 1
            elif key == ord('s'):
                circle1_y += 1
            elif key == ord('a'):
                circle1_x -= 1
            elif key == ord('d'):
                circle1_x += 1
            elif key == ord('r'):
                px_per_mm += 1
            elif key == ord('f'):
                px_per_mm -= 1

        return px_per_mm, calibration_data

    def annotate_data(self, dataset_path, csv_file="sensor_data.csv", img_idxs=None):
        """
        Tool to annotate a custom dataset with pixel-to-millimeter calibration.
        Creates the annotated_data.csv file required for training.

        Controls:
            - w/s: Move circle up/down
            - a/d: Move circle left/right
            - r/f: Increase/decrease circle size or pixel/mm ratio
            - q: Proceed to next step

        Args:
            dataset_path: Path to the dataset directory containing images and CSV file.
            csv_file: Name of the CSV file containing sensor data. Default: "sensor_data.csv".
            img_idxs: Tuple of two image indices to use for calibration. Default: None (auto-selected at 25th and 75th percentiles).

        Returns:
            None. Saves annotated_data.csv with pixel coordinates in the dataset_path directory.
        """
        # Convert data labels into dataframe
        csv_path = os.path.join(dataset_path, csv_file)
        calibration_data = pd.read_csv(csv_path)

        # Extract data from the middle of the sensor (median y value)
        middle_row = calibration_data.loc[calibration_data['y_mm'] == calibration_data["y_mm"].median()]

        if img_idxs is None:
            # Get the indices of the 25th percentile and 75th percentile X values
            img1 = middle_row.loc[middle_row['x_mm'] == middle_row['x_mm'].quantile(0.25)]
            img2 = middle_row.loc[middle_row['x_mm'] == middle_row['x_mm'].quantile(0.75)]
            idx_1 = img1.index[0]
            idx_2 = img2.index[0]
        else:
            idx_1 = img_idxs[0]
            idx_2 = img_idxs[1]

        # Get the image names and XY probe coordinates
        image1_name = os.path.join(dataset_path, calibration_data.iloc[idx_1, 0])
        img1_x_mm = calibration_data.iloc[idx_1, 1]
        img1_y_mm = calibration_data.iloc[idx_1, 2]

        image2_name = os.path.join(dataset_path, calibration_data.iloc[idx_2, 0])
        img2_x_mm = calibration_data.iloc[idx_2, 1]
        img2_y_mm = calibration_data.iloc[idx_2, 2]

        # Present the images to the user and have them fit the circles
        circle1_x, circle1_y, circle1_r = self._fit_circle(image1_name, instructions="w: Up\ns: Down\na: Left\nd: Right\nr: Bigger\nf: Smaller\nq: Next")
        circle2_x, circle2_y, circle2_r = self._fit_circle(image2_name, instructions="w: Up\ns: Down\na: Left\nd: Right\nr: Bigger\nf: Smaller\nq: Next", initial_pos=(circle1_x, circle1_y, circle1_r))

        print(circle1_x, circle1_y, circle1_r)
        print(circle2_x, circle2_y, circle2_r)

        # Close all opencv windows
        cv2.destroyAllWindows()

        # Calculate pixels/mm
        known_x_distance = abs(img2_x_mm - img1_x_mm)  # mm
        px_per_mm = abs(circle2_x - circle1_x) / known_x_distance

        px_per_mm, calibration_data = self._scale_px_to_mm(root_dir=dataset_path, csv_path=csv_path, instructions="w: Up\ns: Down\na: Left\nd: Right\nr: Increase pixel/mm value\nf: Decrease pixel/mm value\nq: Quit", initial_val=px_per_mm, anchor_idx=idx_1, circle_vals=(circle1_x, circle1_y, circle1_r))

        # Print out the pixels/mm value
        print("pixels per mm:", px_per_mm)

        # Create CSV file with annotated data
        annotated_file_path = os.path.join(dataset_path, "annotated_data.csv")
        calibration_data.to_csv(annotated_file_path, index=False)