py3dcal-1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. py3DCal/__init__.py +12 -0
  2. py3DCal/data_collection/Calibrator.py +298 -0
  3. py3DCal/data_collection/Printers/Ender3/Ender3.py +82 -0
  4. py3DCal/data_collection/Printers/Ender3/__init__.py +0 -0
  5. py3DCal/data_collection/Printers/Printer.py +63 -0
  6. py3DCal/data_collection/Printers/__init__.py +0 -0
  7. py3DCal/data_collection/Sensors/DIGIT/DIGIT.py +47 -0
  8. py3DCal/data_collection/Sensors/DIGIT/__init__.py +0 -0
  9. py3DCal/data_collection/Sensors/DIGIT/default.csv +1222 -0
  10. py3DCal/data_collection/Sensors/GelsightMini/GelsightMini.py +45 -0
  11. py3DCal/data_collection/Sensors/GelsightMini/__init__.py +0 -0
  12. py3DCal/data_collection/Sensors/GelsightMini/default.csv +1210 -0
  13. py3DCal/data_collection/Sensors/Sensor.py +35 -0
  14. py3DCal/data_collection/Sensors/__init__.py +0 -0
  15. py3DCal/data_collection/__init__.py +0 -0
  16. py3DCal/model_training/__init__.py +0 -0
  17. py3DCal/model_training/datasets/DIGIT_dataset.py +75 -0
  18. py3DCal/model_training/datasets/GelSightMini_dataset.py +73 -0
  19. py3DCal/model_training/datasets/__init__.py +3 -0
  20. py3DCal/model_training/datasets/split_dataset.py +38 -0
  21. py3DCal/model_training/datasets/tactile_sensor_dataset.py +82 -0
  22. py3DCal/model_training/lib/__init__.py +0 -0
  23. py3DCal/model_training/lib/add_coordinate_embeddings.py +29 -0
  24. py3DCal/model_training/lib/depthmaps.py +74 -0
  25. py3DCal/model_training/lib/fast_poisson.py +51 -0
  26. py3DCal/model_training/lib/get_gradient_map.py +39 -0
  27. py3DCal/model_training/lib/precompute_gradients.py +61 -0
  28. py3DCal/model_training/lib/train_model.py +96 -0
  29. py3DCal/model_training/lib/validate_device.py +22 -0
  30. py3DCal/model_training/lib/validate_parameters.py +45 -0
  31. py3DCal/model_training/models/__init__.py +1 -0
  32. py3DCal/model_training/models/touchnet.py +211 -0
  33. py3DCal/model_training/touchnet/__init__.py +0 -0
  34. py3DCal/model_training/touchnet/dataset.py +78 -0
  35. py3DCal/model_training/touchnet/touchnet.py +736 -0
  36. py3DCal/model_training/touchnet/touchnet_architecture.py +72 -0
  37. py3DCal/utils/__init__.py +0 -0
  38. py3DCal/utils/utils.py +32 -0
  39. py3dcal-1.0.0.dist-info/LICENSE +21 -0
  40. py3dcal-1.0.0.dist-info/METADATA +29 -0
  41. py3dcal-1.0.0.dist-info/RECORD +44 -0
  42. py3dcal-1.0.0.dist-info/WHEEL +5 -0
  43. py3dcal-1.0.0.dist-info/entry_points.txt +3 -0
  44. py3dcal-1.0.0.dist-info/top_level.txt +1 -0
py3DCal/model_training/touchnet/touchnet.py
@@ -0,0 +1,736 @@
+# touchnet.py
+from enum import Enum
+import os
+from PIL import Image
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.hub import load_state_dict_from_url
+import torch.optim as optim
+from torchvision import transforms
+import cv2
+import pandas as pd
+from sklearn.model_selection import train_test_split
+from torch.utils.data import Subset, DataLoader
+import math
+import copy
+import matplotlib.pyplot as plt
+import requests
+import tarfile
+from tqdm import tqdm
+from pathlib import Path
+from typing import Union
+from .dataset import TactileSensorDataset
+from .touchnet_architecture import TouchNetArchitecture
+from ..lib.fast_poisson import fast_poisson
+
+class SensorType(Enum):
+    """
+    SensorType: Available sensor types with pretrained weights and compiled datasets
+    """
+    DIGIT = "DIGIT"
+    GELSIGHTMINI = "GelSightMini"
+    CUSTOM = "Custom"
+
+class TouchNet:
+    """
+    TouchNet: A Deep Learning Model for Enhanced Calibration and Sensing of Vision-Based Tactile Sensors
+    Args:
+        root (str or pathlib.Path, optional): Root directory for datasets and models. Defaults to current directory.
+        sensor_type (py3DCal.SensorType, optional): Type of tactile sensor. Defaults to py3DCal.SensorType.CUSTOM.
+        load_pretrained_model (bool, optional): If True, loads the pretrained model for the specified sensor type. Defaults to False.
+        download_dataset (bool, optional): If True, downloads the dataset for the specified sensor type. Defaults to False.
+        device (str, optional): Device to run the model on. Defaults to "cpu".
+    """
+    def __init__(self, root: Union[str, Path] = Path("."), sensor_type: SensorType = SensorType.CUSTOM, load_pretrained_model: bool = False, download_dataset: bool = False, device: str = "cpu"):
+
+        self._validate_parameters(sensor_type, load_pretrained_model, download_dataset, device)
+
+        self.model = TouchNetArchitecture()
+        self.sensor_type = sensor_type
+        self.load_pretrained_model = load_pretrained_model
+        self.download_dataset = download_dataset
+        self.transform = transforms.Compose([transforms.ToTensor()])
+        self.root = root
+        if sensor_type == SensorType.DIGIT:
+            self.dataset_path = os.path.join(root, "digit_calibration_data")
+        elif sensor_type == SensorType.GELSIGHTMINI:
+            self.dataset_path = os.path.join(root, "gsmini_calibration_data")
+        else:
+            self.dataset_path = "."
+        self.blank_image_path = os.path.join(self.dataset_path, "blank_images", "blank.png")
+        self.annotation_path = os.path.join(self.dataset_path, "annotations", "annotations.csv")
+        self.device = device
+
+        if self.load_pretrained_model:
+            self._load_pretrained_model()
+
+        self.model.to(self.device)
+
+        if self.download_dataset:
+            self._download_dataset()
+
+
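For orientation, a minimal constructor sketch based on the signature and docstring above; the root directory and device string are illustrative assumptions, and the import path simply mirrors this module's location in the wheel:

    from py3DCal.model_training.touchnet.touchnet import TouchNet, SensorType

    # Pretrained DIGIT model on CPU; the weights file is fetched into root on first use.
    touchnet = TouchNet(
        root=".",
        sensor_type=SensorType.DIGIT,
        load_pretrained_model=True,
        device="cpu",
    )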
+    def _validate_parameters(self, sensor_type: SensorType, load_pretrained_model: bool, download_dataset: bool, device: str):
+        """
+        Validates the parameters.
+        Args:
+            sensor_type (py3DCal.SensorType): Type of tactile sensor.
+            load_pretrained_model (bool): If True, loads the pretrained model for the specified sensor type.
+            download_dataset (bool): If True, downloads the dataset for the specified sensor type.
+            device (str): Device to run the model on.
+        Returns:
+            None.
+        """
+        self._validate_sensor_type(sensor_type)
+        self._validate_device(device)
+        self._validate_load_pretrained_download_dataset(sensor_type, load_pretrained_model, download_dataset)
+
+
+    def _validate_device(self, device: str):
+        """
+        Validates the device by converting it to a torch.device object.
+        Args:
+            device (str): Device to run the model on.
+        Returns:
+            None.
+        Raises:
+            ValueError: If the device is not specified or invalid.
+        """
+        try:
+            device = torch.device(device)
+        except Exception as e:
+            raise ValueError(
+                f"Invalid device '{device}'. Valid options include:\n"
+                " - 'cpu': CPU processing\n"
+                " - 'cuda' or 'cuda:0': NVIDIA GPU\n"
+                " - 'mps': Apple Silicon GPU\n"
+                "See: https://pytorch.org/docs/stable/tensor_attributes.html#torch.device"
+            ) from e
+
+    def _validate_sensor_type(self, sensor_type):
+        """
+        Validates the sensor type.
+        Args:
+            sensor_type (py3DCal.SensorType): Type of tactile sensor.
+        Returns:
+            None.
+        Raises:
+            ValueError: If the sensor type is not specified or invalid.
+        """
+        if sensor_type is None or sensor_type not in [SensorType.DIGIT, SensorType.GELSIGHTMINI, SensorType.CUSTOM]:
+            raise ValueError(f"Invalid sensor type: {sensor_type}. Sensor type must be either {SensorType.DIGIT}, {SensorType.GELSIGHTMINI}, or {SensorType.CUSTOM}.")
+
+    def _validate_load_pretrained_download_dataset(self, sensor_type: SensorType, load_pretrained_model: bool, download_dataset: bool):
+        """
+        Validates that load_pretrained_model and download_dataset are only used with either SensorType.DIGIT or SensorType.GELSIGHTMINI.
+        Args:
+            sensor_type (py3DCal.SensorType): Type of tactile sensor.
+            load_pretrained_model (bool): If True, loads the pretrained model for the specified sensor type.
+            download_dataset (bool): If True, downloads the dataset for the specified sensor type.
+        Returns:
+            None.
+        Raises:
+            ValueError: If load_pretrained_model or download_dataset is requested with SensorType.CUSTOM.
+        """
+        if (load_pretrained_model is True or download_dataset is True) and (sensor_type is not SensorType.DIGIT and sensor_type is not SensorType.GELSIGHTMINI):
+            raise ValueError("Cannot load pretrained model or download dataset for custom sensor type. Sensor type must be either SensorType.DIGIT or SensorType.GELSIGHTMINI.")
+
+    def set_dataset_path(self, dataset_path: Union[str, Path]):
+        """
+        Set the dataset path for custom datasets.
+        Args:
+            dataset_path (str or pathlib.Path): Path to the dataset.
+        Returns:
+            None.
+        """
+        self.dataset_path = dataset_path
+        self.annotation_path = os.path.join(dataset_path, "annotated_data.csv")
+        self.blank_image_path = os.path.join(dataset_path, "blank.png")
+
+    def set_blank_image_path(self, blank_image_path: Union[str, Path]):
+        """
+        Set the blank image path for custom datasets.
+        Args:
+            blank_image_path (str or pathlib.Path): Path to the blank image.
+        Returns:
+            None.
+        """
+        self.blank_image_path = blank_image_path
+
+    def train(self, num_epochs: int = 60, batch_size: int = 64, learning_rate: float = 1e-4, train_split: float = 0.8, loss_fn: nn.Module = nn.MSELoss()):
+        """
+        Train the TouchNet model on a dataset. By default, trains for 60 epochs
+        with a batch size of 64 using the AdamW optimizer with learning rate 1e-4.
+
+        Args:
+            num_epochs (int): Number of epochs to train for. Defaults to 60.
+            batch_size (int): Batch size. Defaults to 64.
+            learning_rate (float): Learning rate. Defaults to 1e-4.
+            train_split (float): Proportion of data to use for training. Defaults to 0.8.
+            loss_fn (nn.Module): Loss function. Defaults to nn.MSELoss().
+
+        Outputs:
+            weights.pth: Trained model weights.
+            losses.csv: Training and testing losses.
+        """
+        optimizer = optim.AdamW(self.model.parameters(), lr=learning_rate, weight_decay=1e-4)
+        df = pd.read_csv(self.annotation_path, comment='#')
+        unique_coords = df[['x_mm', 'y_mm']].drop_duplicates()
+        coord_tuples = [(row['x_mm'], row['y_mm']) for _, row in unique_coords.iterrows()]
+        train_coords, test_coords = train_test_split(coord_tuples, train_size=train_split, random_state=42)
+        train_coords_set = set(train_coords)
+        test_coords_set = set(test_coords)
+        train_idx = []
+        test_idx = []
+        for i in range(len(df)):
+            coord = (df.loc[i, 'x_mm'], df.loc[i, 'y_mm'])
+            if coord in train_coords_set:
+                train_idx.append(i)
+            elif coord in test_coords_set:
+                test_idx.append(i)
+        transform = transforms.Compose([transforms.ToTensor()])
+        dataset = TactileSensorDataset(dataset_path=os.path.join(self.dataset_path, "probe_images"), annotation_path=self.annotation_path, blank_image_path=self.blank_image_path, transform=transform)
+        train_dataset = Subset(dataset, train_idx)
+        test_dataset = Subset(dataset, test_idx)
+        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=16, pin_memory=True, persistent_workers=True)
+        test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=16, pin_memory=True, persistent_workers=True)
+        epoch_train_losses = []
+        epoch_test_losses = []
+        for epoch in range(num_epochs):
+            self.model.train()
+            train_loss = 0.0
+            for batch_idx, (inputs, targets) in enumerate(train_loader):
+                inputs = inputs.to(torch.float32).to(self.device)
+                targets = targets.to(torch.float32).to(self.device)
+                optimizer.zero_grad()
+                outputs = self.model(inputs)
+
+                loss = loss_fn(outputs, targets)
+                loss.backward()
+                optimizer.step()
+
+                train_loss += loss.item()
+
+                print(f" [Batch {batch_idx}/{len(train_loader)}] - Loss: {loss.item():.4f}")
+            avg_train_loss = train_loss / len(train_loader)
+            epoch_train_losses.append(avg_train_loss)
+            print(f"Epoch {epoch+1}/{num_epochs} | Train Loss: {avg_train_loss:.4f}")
+            self.model.eval()
+            test_loss = 0.0
+            with torch.no_grad():
+                for inputs, targets in test_loader:
+                    inputs = inputs.to(torch.float32).to(self.device)
+                    targets = targets.to(torch.float32).to(self.device)
+                    outputs = self.model(inputs)
+                    loss = loss_fn(outputs, targets)
+                    test_loss += loss.item()
+
+            avg_test_loss = test_loss / len(test_loader)
+            epoch_test_losses.append(avg_test_loss)
+            print(f"TEST LOSS: {avg_test_loss:.4f}")
+
+        with open("losses.csv", "w") as f:
+            f.write("epoch,train_loss,test_loss\n")
+            for i in range(len(epoch_train_losses)):
+                f.write(f"{i+1},{epoch_train_losses[i]},{epoch_test_losses[i]}\n")
+        torch.save(self.model.state_dict(), "weights.pth")
+
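A short training sketch for the DIGIT configuration, following the method above; the reduced epoch count, batch size, and CUDA device are illustrative assumptions, and weights.pth / losses.csv are written to the working directory as described:

    from py3DCal.model_training.touchnet.touchnet import TouchNet, SensorType

    # Download the compiled DIGIT calibration dataset on first use, then train briefly.
    touchnet = TouchNet(root=".", sensor_type=SensorType.DIGIT, download_dataset=True, device="cuda")
    touchnet.train(num_epochs=5, batch_size=32, learning_rate=1e-4)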
+    def _add_coordinate_channels(self, image: torch.Tensor) -> torch.Tensor:
+        """
+        Adds positional embedding to the input image by appending x and y coordinate channels.
+
+        Args:
+            image (torch.Tensor): Input image tensor of shape (C, H, W).
+
+        Returns:
+            torch.Tensor: Image tensor with added coordinate channels of shape (C+2, H, W).
+                - X channel: column indices (1s in first column, 2s in second column, etc.)
+                - Y channel: row indices (1s in first row, 2s in second row, etc.)
+        """
+        # Get image dimensions
+        _, height, width = image.shape
+
+        # Create x coordinate channel (column indices)
+        x_coords = torch.arange(1, width + 1, dtype=torch.float32).unsqueeze(0).repeat(height, 1)
+        x_channel = x_coords.unsqueeze(0)  # Add channel dimension
+
+        # Create y coordinate channel (row indices)
+        y_coords = torch.arange(1, height + 1, dtype=torch.float32).unsqueeze(1).repeat(1, width)
+        y_channel = y_coords.unsqueeze(0)  # Add channel dimension
+
+        # Concatenate original image with coordinate channels
+        image_with_coords = torch.cat([image, x_channel, y_channel], dim=0)
+
+        return image_with_coords
+
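For intuition about the coordinate channels, a standalone sketch on a dummy tensor that mirrors the helper above (shapes chosen arbitrarily):

    import torch

    # Dummy 3x4x5 image: 3 channels, height 4, width 5.
    image = torch.zeros(3, 4, 5)
    x_channel = torch.arange(1, 6, dtype=torch.float32).unsqueeze(0).repeat(4, 1).unsqueeze(0)
    y_channel = torch.arange(1, 5, dtype=torch.float32).unsqueeze(1).repeat(1, 5).unsqueeze(0)
    with_coords = torch.cat([image, x_channel, y_channel], dim=0)

    print(with_coords.shape)     # torch.Size([5, 4, 5]) -> C+2 channels
    print(with_coords[3, 0])     # tensor([1., 2., 3., 4., 5.]) -> column indices
    print(with_coords[4, :, 0])  # tensor([1., 2., 3., 4.]) -> row indices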
+    def get_depthmap(self, image_path: str) -> np.ndarray:
+        """
+        Returns the depthmap for a given input image.
+        Args:
+            image_path (str): Path to the input image.
+        Returns:
+            depthmap (numpy.ndarray): The computed depthmap.
+        """
+        model = self.model.to(self.device)
+        model.eval()
+        image = Image.open(image_path)
+        image_pil = image.convert("RGB")
+        transformed_image = self.transform(image_pil)
+        blank_image = Image.open(self.blank_image_path)
+        blank_image_pil = blank_image.convert("RGB")
+        transformed_blank_image = self.transform(blank_image_pil)
+        augmented_image = transformed_image - transformed_blank_image
+        augmented_image = self._add_coordinate_channels(augmented_image)
+        augmented_image = augmented_image.unsqueeze(0).to(self.device)
+
+        with torch.no_grad():
+            output = model(augmented_image)
+
+        output = output.squeeze(0).permute(1, 2, 0).cpu().numpy()
+
+        depthmap = fast_poisson(output[:,:,0], output[:,:,1])
+
+        depthmap = np.clip(-depthmap, a_min=0, a_max=None)
+
+        return depthmap
+
+    def save_depthmap_image(self, image_path: str, save_path: Union[str, Path] = Path("depthmap.png")):
+        """
+        Save an image of the depthmap for a given input image.
+        Args:
+            image_path (str): Path to the input image.
+            save_path (str or pathlib.Path): Path to save the depthmap image.
+        Returns:
+            None.
+        """
+        depthmap = self.get_depthmap(image_path)
+
+        plt.imsave(save_path, depthmap, cmap='viridis')
+
+    def show_depthmap(self, image_path: str):
+        """
+        Show the depthmap for a given input image.
+        Args:
+            image_path (str): Path to the input image.
+        Returns:
+            None.
+        """
+        depthmap = self.get_depthmap(image_path)
+
+        plt.imshow(depthmap)
+        plt.show()
+
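A minimal inference sketch tying the pretrained weights to the depthmap helpers above; the probe image filename is an illustrative assumption, and the dataset download supplies the blank reference frame that get_depthmap subtracts:

    from py3DCal.model_training.touchnet.touchnet import TouchNet, SensorType

    touchnet = TouchNet(root=".", sensor_type=SensorType.DIGIT, load_pretrained_model=True,
                        download_dataset=True, device="cpu")

    # Reconstruct and inspect a depthmap from one probe image (illustrative path).
    depth = touchnet.get_depthmap("digit_calibration_data/probe_images/sample.png")
    print(depth.shape, depth.max())

    # Or render it straight to disk as a colormapped PNG.
    touchnet.save_depthmap_image("digit_calibration_data/probe_images/sample.png", save_path="depthmap.png")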
+    def _load_pretrained_model(self):
+        """
+        Loads a pretrained model for either the DIGIT or GelSightMini sensor.
+        Args:
+            None.
+        Returns:
+            None.
+        """
+        if self.sensor_type == SensorType.DIGIT:
+            file_path = os.path.join(self.root, "digit_pretrained_weights.pth")
+
+            # Check if DIGIT pretrained weights exist locally, if not download them
+            if not os.path.exists(file_path):
+
+                print(f"Downloading DIGIT pretrained weights ...")
+                response = requests.get('https://zenodo.org/records/17487330/files/digit_pretrained_weights.pth?download=1', stream=True)
+                response.raise_for_status()
+
+                total_size = int(response.headers.get('content-length', 0))
+                block_size = 1024
+
+                # Save file in chunks to handle large datasets
+                with open(file_path, 'wb') as f, tqdm(
+                    total=total_size,
+                    unit='B',
+                    unit_scale=True,
+                    desc="Downloading",
+                    ncols=80
+                ) as progress_bar:
+                    for chunk in response.iter_content(chunk_size=block_size):
+                        if chunk:
+                            f.write(chunk)
+                            progress_bar.update(len(chunk))
+
+                print(f"Download complete!")
+            else:
+                print(f"DIGIT pretrained weights already exist at: {file_path}")
+
+        elif self.sensor_type == SensorType.GELSIGHTMINI:
+            file_path = os.path.join(self.root, "gsmini_pretrained_weights.pth")
+
+            # Check if GelSight Mini pretrained weights exist locally, if not download them
+            if not os.path.exists(file_path):
+
+                print(f"Downloading GelSight Mini pretrained weights ...")
+                response = requests.get('https://zenodo.org/records/17487330/files/gsmini_pretrained_weights.pth?download=1', stream=True)
+                response.raise_for_status()
+
+                total_size = int(response.headers.get('content-length', 0))
+                block_size = 1024
+
+                # Save file in chunks to handle large datasets
+                with open(file_path, 'wb') as f, tqdm(
+                    total=total_size,
+                    unit='B',
+                    unit_scale=True,
+                    desc="Downloading",
+                    ncols=80
+                ) as progress_bar:
+                    for chunk in response.iter_content(chunk_size=block_size):
+                        if chunk:
+                            f.write(chunk)
+                            progress_bar.update(len(chunk))
+
+                print(f"Download complete!")
+            else:
+                print(f"GelSight Mini pretrained weights already exist at: {file_path}")
+
+        state_dict = torch.load(file_path, map_location="cpu")
+        self.model.load_state_dict(state_dict)
+
+    def load_model_weights(self, model_path: Union[str, Path]):
+        """
+        Loads custom model weights from a .pth file.
+        Args:
+            model_path (str or pathlib.Path): Path to the model weights file.
+        Returns:
+            None.
+        """
+        self.model.load_state_dict(torch.load(model_path, map_location=self.device))
+
+    def _download_dataset(self):
+        """
+        Downloads the dataset for either the SensorType.DIGIT or SensorType.GELSIGHTMINI sensor. Used for model training.
+        Args:
+            None.
+        Returns:
+            None.
+        """
+        if self.sensor_type == SensorType.DIGIT:
+            self.dataset_path = os.path.join(self.root, "digit_calibration_data")
+
+            # Check if self.dataset_path exists
+            if not os.path.exists(self.dataset_path):
+                os.makedirs(self.root, exist_ok=True)
+
+                tar_path = os.path.join(self.root, "digit_calibration_data.tar.gz")
+
+                print(f"Downloading DIGIT dataset ...")
+                response = requests.get('https://zenodo.org/records/17487330/files/digit_calibration_data.tar.gz?download=1', stream=True)
+                response.raise_for_status()
+
+                total_size = int(response.headers.get('content-length', 0))
+                block_size = 1024
+
+                # Save file in chunks to handle large datasets
+                with open(tar_path, 'wb') as f, tqdm(
+                    total=total_size,
+                    unit='B',
+                    unit_scale=True,
+                    desc="Downloading",
+                    ncols=80
+                ) as progress_bar:
+                    for chunk in response.iter_content(chunk_size=block_size):
+                        if chunk:
+                            f.write(chunk)
+                            progress_bar.update(len(chunk))
+
+                print(f"Download complete!")
+
+                # Extract .tar.gz file
+                print("Extracting files ...")
+                with tarfile.open(tar_path, "r:gz") as tar:
+                    tar.extractall(path=self.root)
+
+                os.remove(tar_path)
+
+                print(f"Extraction complete! Files are in: {self.root}/")
+
+            else:
+                print(f"DIGIT dataset already exists at: {self.dataset_path}/")
+
+
+        elif self.sensor_type == SensorType.GELSIGHTMINI:
+            self.dataset_path = os.path.join(self.root, "gsmini_calibration_data")
+
+            # Check if self.dataset_path exists
+            if not os.path.exists(self.dataset_path):
+                os.makedirs(self.root, exist_ok=True)
+
+                tar_path = os.path.join(self.root, "gsmini_calibration_data.tar.gz")
+
+                print(f"Downloading GelSight Mini dataset ...")
+                response = requests.get('https://zenodo.org/records/17487330/files/gsmini_calibration_data.tar.gz?download=1', stream=True)
+                response.raise_for_status()
+
+                total_size = int(response.headers.get('content-length', 0))
+                block_size = 1024
+
+                # Save file in chunks to handle large datasets
+                with open(tar_path, 'wb') as f, tqdm(
+                    total=total_size,
+                    unit='B',
+                    unit_scale=True,
+                    desc="Downloading",
+                    ncols=80
+                ) as progress_bar:
+                    for chunk in response.iter_content(chunk_size=block_size):
+                        if chunk:
+                            f.write(chunk)
+                            progress_bar.update(len(chunk))
+
+                print(f"Download complete!")
+
+                # Extract .tar.gz file
+                print("Extracting files ...")
+                with tarfile.open(tar_path, "r:gz") as tar:
+                    tar.extractall(path=self.root)
+
+                os.remove(tar_path)
+
+                print(f"Extraction complete! Files are in: {self.root}/")
+
+            else:
+                print(f"GelSight Mini dataset already exists at: {self.dataset_path}/")
+
+        self.annotation_path = os.path.join(self.dataset_path, "annotations", "annotations.csv")
+        self.blank_image_path = os.path.join(self.dataset_path, "blank_images", "blank.png")
+
+    def _fit_circle(self, img_path, instructions, initial_pos=(100,100,30)):
+        """
+        Fits a circle to an image.
+        Args:
+            img_path: Path to the image.
+            instructions: Instructions to display on the image.
+            initial_pos: Initial position of the circle.
+        Returns:
+            x: x-coordinate of the circle.
+            y: y-coordinate of the circle.
+            r: radius of the circle.
+        """
+        x, y, r = initial_pos
+
+        # Read the image
+        image = cv2.imread(img_path)
+
+        # Define rectangle parameters
+        rectangle_height = 250  # Adjust as needed
+        rectangle_color = (255, 255, 255)  # White color (BGR format)
+
+        # Create a new image with the rectangle
+        new_image = np.zeros((image.shape[0] + rectangle_height, image.shape[1], 3), dtype=np.uint8)
+        new_image[:image.shape[0], :, :] = image
+        cv2.rectangle(new_image, (0, image.shape[0]), (image.shape[1], image.shape[0] + rectangle_height), rectangle_color, -1)
+
+        # Add text inside the rectangle
+        for i, line in enumerate(instructions.split('\n')):
+            str_y = image.shape[0] + 40 + i*30
+            cv2.putText(new_image, line, (10, str_y), cv2.FONT_HERSHEY_SIMPLEX, 1, 2)
+
+        while True:
+            # Create a copy of the original image
+            annotated_image = new_image.copy()
+
+            # Draw the circle
+            cv2.circle(annotated_image, (x, y), r, (0, 0, 255), 1)
+
+            # Display the image
+            cv2.imshow('Circle Fitting', annotated_image)
+
+            # Update circle position based on key presses
+            key = cv2.waitKey(1) & 0xFF
+            if key == ord('q'):
+                break
+            elif key == ord('w'):
+                y -= 1
+            elif key == ord('s'):
+                y += 1
+            elif key == ord('a'):
+                x -= 1
+            elif key == ord('d'):
+                x += 1
+            elif key == ord('r'):
+                r += 1
+            elif key == ord('f'):
+                r -= 1
+
+        return x, y, r
+
+    def _scale_px_to_mm(self, root_dir, csv_path, instructions, initial_val, anchor_idx, circle_vals):
+        """
+        Scales the pixel-to-millimeter calibration.
+        Args:
+            root_dir: Path to the root directory.
+            csv_path: Path to the CSV file.
+            instructions: Instructions to display on the image.
+            initial_val: Initial value of the pixel/mm ratio.
+            anchor_idx: Index of the anchor image.
+            circle_vals: Values of the circles.
+        Returns:
+            px_per_mm: Pixel/millimeter ratio.
+            calibration_data: Calibration data.
+        """
+        # Convert data labels into dataframe
+        calibration_data = pd.read_csv(csv_path)
+
+        anchor_image_path = os.path.join(root_dir, calibration_data["img_name"][anchor_idx])
+        anchor_image = cv2.imread(anchor_image_path)
+        image_list = [anchor_idx]
+
+        height, width, _ = anchor_image.shape
+
+        # Define rectangle parameters
+        rectangle_height = 250  # Adjust as needed
+        rectangle_color = (255, 255, 255)  # White color (BGR format)
+
+        # Create a new image with the rectangle
+        new_image = np.zeros((height * 3 + rectangle_height, width * 3, 3), dtype=np.uint8)
+        new_image[:height, :width, :] = anchor_image
+        cv2.rectangle(new_image, (0, height * 3), (width * 3, height * 3 + rectangle_height), rectangle_color, -1)
+
+        px_per_mm = initial_val
+        circle1_x, circle1_y, circle1_r = circle_vals
+        img1_x_mm = calibration_data["x_mm"][anchor_idx]
+        img1_y_mm = calibration_data["y_mm"][anchor_idx]
+
+        calibration_data['x_px'] = circle1_x + (calibration_data['x_mm'] - img1_x_mm) * px_per_mm
+        calibration_data['y_px'] = circle1_y + (img1_y_mm - calibration_data['y_mm']) * px_per_mm
+
+        idx = 1
+
+        while len(image_list) < 9:
+            random_row = calibration_data.sample(n=1)
+
+            if random_row.iloc[0,4] < width and random_row.iloc[0,5] < height:
+                image_path = os.path.join(root_dir, random_row.iloc[0,0])
+                image = cv2.imread(image_path)
+
+                image_list += [random_row.index[0]]
+
+                row = math.floor(idx / 3)
+                column = idx % 3
+
+                new_image[(height * row):(height * (row+1)), (width * column):(width * (column+1)), :] = image
+
+                idx += 1
+
+        # Add text inside the rectangle
+        for i, line in enumerate(instructions.split('\n')):
+            str_y = height * 3 + 40 + i*30
+            cv2.putText(new_image, line, (10, str_y), cv2.FONT_HERSHEY_SIMPLEX, 1, 2)
+
+        # Display the image
+        cv2.imshow('Circle Fitting', new_image)
+
+        while True:
+            # Create a copy of the original image
+            annotated_image = new_image.copy()
+
+            calibration_data['x_px'] = circle1_x + (calibration_data['x_mm'] - img1_x_mm) * px_per_mm
+            calibration_data['y_px'] = circle1_y + (img1_y_mm - calibration_data['y_mm']) * px_per_mm
+
+            # Draw the circles
+            for i in range(9):
+                row = math.floor(i / 3)
+                column = i % 3
+
+                idx = image_list[i]
+                x = int(calibration_data['x_px'][idx]) + column * width
+                y = int(calibration_data['y_px'][idx]) + row * height
+
+                cv2.circle(annotated_image, (x, y), circle1_r, (0, 0, 255), 1)
+
+            # Display the image
+            cv2.imshow('Circle Fitting', annotated_image)
+
+            # Update circle position based on key presses
+            key = cv2.waitKey(1) & 0xFF
+            if key == ord('q'):
+                break
+            elif key == ord('w'):
+                circle1_y -= 1
+            elif key == ord('s'):
+                circle1_y += 1
+            elif key == ord('a'):
+                circle1_x -= 1
+            elif key == ord('d'):
+                circle1_x += 1
+            elif key == ord('r'):
+                px_per_mm += 1
+            elif key == ord('f'):
+                px_per_mm -= 1
+
+        return px_per_mm, calibration_data
+
+    def annotate_data(self, dataset_path, csv_file="sensor_data.csv", img_idxs=None):
+        """
+        Tool to annotate a custom dataset with pixel-to-millimeter calibration.
+        Creates the annotated_data.csv file required for training.
+
+        Controls:
+            - w/s: Move circle up/down
+            - a/d: Move circle left/right
+            - r/f: Increase/decrease circle size or pixel/mm ratio
+            - q: Proceed to next step
+
+        Args:
+            dataset_path: Path to the dataset directory containing images and CSV file.
+            csv_file: Name of the CSV file containing sensor data. Default: "sensor_data.csv".
+            img_idxs: Tuple of two image indices to use for calibration. Default: None (auto-selected at 25th and 75th percentiles).
+
+        Returns:
+            Saves annotated_data.csv in the dataset_path directory with pixel coordinates.
+        """
+        # Convert data labels into dataframe
+        csv_path = os.path.join(dataset_path, csv_file)
+        calibration_data = pd.read_csv(csv_path)
+
+        # Extract data from the middle of the sensor (median Y value)
+        middle_row = calibration_data.loc[calibration_data['y_mm'] == calibration_data["y_mm"].median()]
+
+        if img_idxs is None:
+            # Get the indices of the 25th percentile and 75th percentile X values
+            img1 = middle_row.loc[middle_row['x_mm'] == middle_row['x_mm'].quantile(0.25)]
+            img2 = middle_row.loc[middle_row['x_mm'] == middle_row['x_mm'].quantile(0.75)]
+            idx_1 = img1.index[0]
+            idx_2 = img2.index[0]
+        else:
+            idx_1 = img_idxs[0]
+            idx_2 = img_idxs[1]
+
+        # Get the image names and XY probe coordinates
+        image1_name = os.path.join(dataset_path, calibration_data.iloc[idx_1, 0])
+        img1_x_mm = calibration_data.iloc[idx_1, 1]
+        img1_y_mm = calibration_data.iloc[idx_1, 2]
+
+        image2_name = os.path.join(dataset_path, calibration_data.iloc[idx_2, 0])
+        img2_x_mm = calibration_data.iloc[idx_2, 1]
+        img2_y_mm = calibration_data.iloc[idx_2, 2]
+
+        # Present the images to the user and have them fit the circles
+        circle1_x, circle1_y, circle1_r = self._fit_circle(image1_name, instructions="w: Up\ns: Down\na: Left\nd: Right\nr: Bigger\nf: Smaller\nq: Next")
+        circle2_x, circle2_y, circle2_r = self._fit_circle(image2_name, instructions="w: Up\ns: Down\na: Left\nd: Right\nr: Bigger\nf: Smaller\nq: Next", initial_pos=(circle1_x, circle1_y, circle1_r))
+
+        print(circle1_x, circle1_y, circle1_r)
+        print(circle2_x, circle2_y, circle2_r)
+
+        # Close all opencv windows
+        cv2.destroyAllWindows()
+
+        # Calculate pixels/mm
+        known_x_distance = abs(img2_x_mm - img1_x_mm)  # mm
+        px_per_mm = abs(circle2_x - circle1_x) / known_x_distance
+
+        px_per_mm, calibration_data = self._scale_px_to_mm(root_dir=dataset_path, csv_path=csv_path, instructions="w: Up\ns: Down\na: Left\nd: Right\nr: Increase pixel/mm value\nf: Decrease pixel/mm value\nq: Quit", initial_val=px_per_mm, anchor_idx=idx_1, circle_vals=(circle1_x, circle1_y, circle1_r))
+
+        # Print out the pixels/mm value
+        print("pixels per mm:", px_per_mm)
+
+        # Create CSV file with annotated data
+        annotated_file_path = os.path.join(dataset_path, "annotated_data.csv")
+        calibration_data.to_csv(annotated_file_path, index=False)
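
Finally, a hedged end-to-end sketch for a custom sensor: annotate the captured data, then train on it. The capture directory, its sensor_data.csv columns (img_name, x_mm, y_mm), the probe_images/ subfolder, and the blank.png filename are assumptions about the expected layout inferred from the code above:

    from py3DCal.model_training.touchnet.touchnet import TouchNet, SensorType

    touchnet = TouchNet(sensor_type=SensorType.CUSTOM, device="cpu")

    # Interactive circle-fitting pass; writes annotated_data.csv into the capture directory.
    touchnet.annotate_data("./my_sensor_capture", csv_file="sensor_data.csv")

    # Point TouchNet at the annotated capture (probe_images/ + annotated_data.csv) and train.
    touchnet.set_dataset_path("./my_sensor_capture")
    touchnet.set_blank_image_path("./my_sensor_capture/blank.png")
    touchnet.train(num_epochs=10)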