py3dcal-1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. py3DCal/__init__.py +12 -0
  2. py3DCal/data_collection/Calibrator.py +298 -0
  3. py3DCal/data_collection/Printers/Ender3/Ender3.py +82 -0
  4. py3DCal/data_collection/Printers/Ender3/__init__.py +0 -0
  5. py3DCal/data_collection/Printers/Printer.py +63 -0
  6. py3DCal/data_collection/Printers/__init__.py +0 -0
  7. py3DCal/data_collection/Sensors/DIGIT/DIGIT.py +47 -0
  8. py3DCal/data_collection/Sensors/DIGIT/__init__.py +0 -0
  9. py3DCal/data_collection/Sensors/DIGIT/default.csv +1222 -0
  10. py3DCal/data_collection/Sensors/GelsightMini/GelsightMini.py +45 -0
  11. py3DCal/data_collection/Sensors/GelsightMini/__init__.py +0 -0
  12. py3DCal/data_collection/Sensors/GelsightMini/default.csv +1210 -0
  13. py3DCal/data_collection/Sensors/Sensor.py +35 -0
  14. py3DCal/data_collection/Sensors/__init__.py +0 -0
  15. py3DCal/data_collection/__init__.py +0 -0
  16. py3DCal/model_training/__init__.py +0 -0
  17. py3DCal/model_training/datasets/DIGIT_dataset.py +75 -0
  18. py3DCal/model_training/datasets/GelSightMini_dataset.py +73 -0
  19. py3DCal/model_training/datasets/__init__.py +3 -0
  20. py3DCal/model_training/datasets/split_dataset.py +38 -0
  21. py3DCal/model_training/datasets/tactile_sensor_dataset.py +82 -0
  22. py3DCal/model_training/lib/__init__.py +0 -0
  23. py3DCal/model_training/lib/add_coordinate_embeddings.py +29 -0
  24. py3DCal/model_training/lib/depthmaps.py +74 -0
  25. py3DCal/model_training/lib/fast_poisson.py +51 -0
  26. py3DCal/model_training/lib/get_gradient_map.py +39 -0
  27. py3DCal/model_training/lib/precompute_gradients.py +61 -0
  28. py3DCal/model_training/lib/train_model.py +96 -0
  29. py3DCal/model_training/lib/validate_device.py +22 -0
  30. py3DCal/model_training/lib/validate_parameters.py +45 -0
  31. py3DCal/model_training/models/__init__.py +1 -0
  32. py3DCal/model_training/models/touchnet.py +211 -0
  33. py3DCal/model_training/touchnet/__init__.py +0 -0
  34. py3DCal/model_training/touchnet/dataset.py +78 -0
  35. py3DCal/model_training/touchnet/touchnet.py +736 -0
  36. py3DCal/model_training/touchnet/touchnet_architecture.py +72 -0
  37. py3DCal/utils/__init__.py +0 -0
  38. py3DCal/utils/utils.py +32 -0
  39. py3dcal-1.0.0.dist-info/LICENSE +21 -0
  40. py3dcal-1.0.0.dist-info/METADATA +29 -0
  41. py3dcal-1.0.0.dist-info/RECORD +44 -0
  42. py3dcal-1.0.0.dist-info/WHEEL +5 -0
  43. py3dcal-1.0.0.dist-info/entry_points.txt +3 -0
  44. py3dcal-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,35 @@
+ from abc import ABC, abstractmethod
+
+ class Sensor(ABC):
+     """
+     Sensor: An abstract base class for tactile sensors.
+     """
+     def __init__(self):
+         self.name = ""
+         self.x_offset = 5
+         self.y_offset = 5
+         self.z_offset = 5
+         self.z_clearance = 2
+         self.max_penetration = 0
+         self.default_calibration_file = "calibration_procs/digit/default.csv"
+
+     @abstractmethod
+     def connect(self):
+         """Connects to the sensor."""
+         pass
+
+     @abstractmethod
+     def disconnect(self):
+         """Disconnects from the sensor."""
+         pass
+
+     @abstractmethod
+     def capture_image(self):
+         """Captures an image from the sensor.
+
+         Returns:
+             numpy.ndarray: The image from the sensor.
+         """
+         pass
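
For context, a minimal sketch of how a new sensor could implement this interface. The WebcamSensor name and the cv2/OpenCV capture calls are illustrative assumptions, not part of py3DCal:

    import cv2  # assumed dependency, for this sketch only
    import numpy as np
    from py3DCal.data_collection.Sensors.Sensor import Sensor

    class WebcamSensor(Sensor):  # hypothetical subclass for illustration
        def __init__(self, camera_index=0):
            super().__init__()
            self.name = "webcam"
            self._camera_index = camera_index
            self._capture = None

        def connect(self):
            # Open the video device
            self._capture = cv2.VideoCapture(self._camera_index)

        def disconnect(self):
            # Release the video device if it was opened
            if self._capture is not None:
                self._capture.release()
                self._capture = None

        def capture_image(self):
            # Read one frame and return it as a numpy array
            ok, frame = self._capture.read()
            if not ok:
                raise RuntimeError("Failed to capture frame")
            return np.asarray(frame)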
File without changes
File without changes
File without changes
@@ -0,0 +1,75 @@
+ import os
+ import requests
+ import tarfile
+ from tqdm import tqdm
+ from typing import Union
+ from pathlib import Path
+ from .tactile_sensor_dataset import TactileSensorDataset
+ from ..lib.validate_parameters import validate_root
+
+
+ class DIGIT(TactileSensorDataset):
+     """
+     DIGIT: A dataset class for the DIGIT sensor.
+     Args:
+         root (str or pathlib.Path): The root directory containing digit_calibration_data.
+         download (bool, optional): If True, downloads the DIGIT calibration dataset. Defaults to False.
+         add_coordinate_embeddings (bool, optional): If True, adds xy coordinate embeddings to each image. Defaults to True.
+         subtract_blank (bool, optional): If True, subtracts a blank image from each input image. Defaults to True.
+         transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed version. Default: ``transforms.ToTensor()``
+     """
+     def __init__(self, root: Union[str, Path] = Path("."), download=False, add_coordinate_embeddings=True, subtract_blank=True, transform=None):
+         validate_root(root)
+
+         self.dataset_path = os.path.join(root, "digit_calibration_data")
+
+         # Download before the parent class loads the annotations, so the
+         # files exist when TactileSensorDataset.__init__ reads them
+         if download:
+             self._download_dataset()
+
+         super().__init__(root=self.dataset_path, add_coordinate_embeddings=add_coordinate_embeddings, subtract_blank=subtract_blank, transform=transform)
+
+     def _download_dataset(self):
+         """
+         Downloads the dataset for the DIGIT sensor.
+         """
+         # Skip the download if the dataset directory already exists
+         if not os.path.exists(self.dataset_path):
+             os.makedirs(self.dataset_path, exist_ok=True)
+
+             tar_path = os.path.join(self.dataset_path, "digit_calibration_data.tar.gz")
+
+             print("Downloading DIGIT dataset ...")
+             response = requests.get('https://zenodo.org/records/17517028/files/digit_calibration_data.tar.gz?download=1', stream=True)
+             response.raise_for_status()
+
+             total_size = int(response.headers.get('content-length', 0))
+             block_size = 1024
+
+             # Save the file in chunks to handle large datasets
+             with open(tar_path, 'wb') as f, tqdm(
+                 total=total_size,
+                 unit='B',
+                 unit_scale=True,
+                 desc="Downloading",
+                 ncols=80
+             ) as progress_bar:
+                 for chunk in response.iter_content(chunk_size=block_size):
+                     if chunk:
+                         f.write(chunk)
+                         progress_bar.update(len(chunk))
+
+             print("Download complete!")
+
+             # Extract the .tar.gz file
+             print("Extracting files ...")
+             with tarfile.open(tar_path, "r:gz") as tar:
+                 tar.extractall(path=self.dataset_path)
+
+             os.remove(tar_path)
+
+             print(f"Extraction complete! Files are in: {self.dataset_path}/")
+
+         else:
+             print(f"DIGIT dataset already exists at: {self.dataset_path}/")
+
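A brief usage sketch. The class and its import path follow from this diff; the root directory and printed shapes are illustrative:

    from py3DCal.model_training.datasets import DIGIT

    # Downloads the Zenodo archive on first use, then loads annotations
    dataset = DIGIT(root="./data", download=True)

    image, target = dataset[0]
    print(image.shape)   # (C+2, H, W) when add_coordinate_embeddings=True
    print(target.shape)  # gradient map target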
@@ -0,0 +1,73 @@
+ import os
+ import requests
+ import tarfile
+ from tqdm import tqdm
+ from typing import Union
+ from pathlib import Path
+ from .tactile_sensor_dataset import TactileSensorDataset
+ from ..lib.validate_parameters import validate_root
+
+
+ class GelSightMini(TactileSensorDataset):
+     """
+     GelSight Mini: A dataset class for the GelSight Mini sensor.
+     Args:
+         root (str or pathlib.Path): The root directory containing gsmini_calibration_data.
+         download (bool, optional): If True, downloads the GelSight Mini calibration dataset. Defaults to False.
+         add_coordinate_embeddings (bool, optional): If True, adds xy coordinate embeddings to each image. Defaults to True.
+         subtract_blank (bool, optional): If True, subtracts a blank image from each input image. Defaults to True.
+         transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed version. Default: ``transforms.ToTensor()``
+     """
+     def __init__(self, root: Union[str, Path] = Path("."), download=False, add_coordinate_embeddings=True, subtract_blank=True, transform=None):
+         validate_root(root)
+
+         self.dataset_path = os.path.join(root, "gsmini_calibration_data")
+
+         # Download before the parent class loads the annotations, so the
+         # files exist when TactileSensorDataset.__init__ reads them
+         if download:
+             self._download_dataset()
+
+         super().__init__(root=self.dataset_path, add_coordinate_embeddings=add_coordinate_embeddings, subtract_blank=subtract_blank, transform=transform)
+
+     def _download_dataset(self):
+         """
+         Downloads the dataset for the GelSight Mini sensor.
+         """
+         # Skip the download if the dataset directory already exists
+         if not os.path.exists(self.dataset_path):
+             os.makedirs(self.dataset_path, exist_ok=True)
+
+             tar_path = os.path.join(self.dataset_path, "gsmini_calibration_data.tar.gz")
+
+             print("Downloading GelSight Mini dataset ...")
+             response = requests.get('https://zenodo.org/records/17517028/files/gsmini_calibration_data.tar.gz?download=1', stream=True)
+             response.raise_for_status()
+
+             total_size = int(response.headers.get('content-length', 0))
+             block_size = 1024
+
+             # Save the file in chunks to handle large datasets
+             with open(tar_path, 'wb') as f, tqdm(
+                 total=total_size,
+                 unit='B',
+                 unit_scale=True,
+                 desc="Downloading",
+                 ncols=80
+             ) as progress_bar:
+                 for chunk in response.iter_content(chunk_size=block_size):
+                     if chunk:
+                         f.write(chunk)
+                         progress_bar.update(len(chunk))
+
+             print("Download complete!")
+
+             # Extract the .tar.gz file
+             print("Extracting files ...")
+             with tarfile.open(tar_path, "r:gz") as tar:
+                 tar.extractall(path=self.dataset_path)
+
+             os.remove(tar_path)
+
+             print(f"Extraction complete! Files are in: {self.dataset_path}/")
+
+         else:
+             print(f"GelSight Mini dataset already exists at: {self.dataset_path}/")
@@ -0,0 +1,3 @@
+ from .tactile_sensor_dataset import TactileSensorDataset
+ from .DIGIT_dataset import DIGIT
+ from .GelSightMini_dataset import GelSightMini
@@ -0,0 +1,38 @@
+ import copy
+ import pandas as pd
+ from sklearn.model_selection import train_test_split
+ from .tactile_sensor_dataset import TactileSensorDataset
+
+ def split_dataset(dataset, train_ratio=0.8):
+     """
+     Splits a dataset into training and validation sets.
+
+     Args:
+         dataset (py3DCal.datasets.TactileSensorDataset): The dataset to split.
+         train_ratio (float): The proportion of the dataset to include in the training set. Default is 0.8.
+
+     Returns:
+         tuple: A tuple containing the training and validation datasets.
+     """
+     if not isinstance(dataset, TactileSensorDataset):
+         raise TypeError("Expected dataset to be an instance of py3DCal.datasets.TactileSensorDataset")
+
+     df = dataset.data.copy()
+
+     # Split on unique probe locations so the same (x, y) coordinate never
+     # appears in both the training and validation sets
+     unique_coords = df[['x_mm', 'y_mm']].drop_duplicates().reset_index(drop=True)
+
+     train_df, val_df = train_test_split(unique_coords, train_size=train_ratio, random_state=42)
+
+     # Merge back to get full rows
+     train_data = pd.merge(df, train_df, on=['x_mm', 'y_mm'])
+     val_data = pd.merge(df, val_df, on=['x_mm', 'y_mm'])
+
+     # Create two copies of the original dataset
+     train_dataset = copy.deepcopy(dataset)
+     val_dataset = copy.deepcopy(dataset)
+
+     # Update the data attribute of each dataset
+     train_dataset.data = train_data
+     val_dataset.data = val_data
+
+     return train_dataset, val_dataset
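
A short usage sketch of the split; the import path for split_dataset is taken from its module location, and the root directory is illustrative:

    from py3DCal.model_training.datasets import DIGIT
    from py3DCal.model_training.datasets.split_dataset import split_dataset

    dataset = DIGIT(root="./data")
    train_set, val_set = split_dataset(dataset, train_ratio=0.8)
    print(len(train_set), len(val_set))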
@@ -0,0 +1,82 @@
+ import os
+ import torch
+ import pandas as pd
+ from PIL import Image
+ from typing import Union
+ from pathlib import Path
+ from torch.utils.data import Dataset
+ from torchvision import transforms
+ from ..lib.precompute_gradients import precompute_gradients
+ from ..lib.get_gradient_map import get_gradient_map
+ from ..lib.add_coordinate_embeddings import add_coordinate_embeddings
+ from ..lib.validate_parameters import validate_root
+
+ class TactileSensorDataset(Dataset):
+     """
+     Tactile Sensor Dataset.
+
+     Args:
+         root (str or pathlib.Path): The root directory that contains the dataset folder.
+         add_coordinate_embeddings (bool, optional): If True, adds xy coordinate embeddings to each image. Defaults to True.
+         subtract_blank (bool, optional): If True, subtracts a blank image from each input image. Defaults to True.
+         transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed version. Default: ``transforms.ToTensor()``
+     """
+     def __init__(self, root: Union[str, Path], add_coordinate_embeddings=True, subtract_blank=True, transform=None):
+         validate_root(root)
+
+         self.root = root
+         self.annotation_path = os.path.join(root, "annotations", "annotations.csv")
+         self.metadata_path = os.path.join(root, "annotations", "metadata.json")
+         self.blank_image_path = os.path.join(root, "blank_images", "blank.png")
+         self.add_coordinate_embeddings = add_coordinate_embeddings
+         self.subtract_blank = subtract_blank
+
+         if transform is None:
+             self.transform = transforms.ToTensor()
+         else:
+             self.transform = transform
+
+         # Load the CSV data
+         self.data = pd.read_csv(self.annotation_path)
+
+         # Get probe radius (in px) from metadata
+         metadata = pd.read_json(self.metadata_path, typ="series")
+         radius = metadata["probe_radius_mm"] * metadata["px_per_mm"]
+
+         # Precompute gradients
+         self.precomputed_gradients = precompute_gradients(dataset_path=self.root, annotation_path=self.annotation_path, r=radius)
+
+         # Load and transform blank image
+         if subtract_blank:
+             self.blank_image = self.transform(Image.open(self.blank_image_path).convert("RGB"))
+
+     def __len__(self):
+         return len(self.data)  # Use the DataFrame length
+
+     def __getitem__(self, idx):
+         if torch.is_tensor(idx):
+             idx = idx.tolist()
+
+         # Check if index is valid
+         if idx < 0 or idx >= len(self.data):
+             raise IndexError("Index out of range")
+
+         image_name = os.path.join(self.root, "probe_images", self.data.iloc[idx, 0])
+         image = Image.open(image_name).convert("RGB")
+
+         target = get_gradient_map(idx, annotation_path=self.annotation_path, precomputed_gradients=self.precomputed_gradients)
+
+         image = self.transform(image)
+         target = self.transform(target)
+
+         if self.subtract_blank:
+             # Subtract pre-transformed blank tensor
+             image = image - self.blank_image
+
+         if self.add_coordinate_embeddings:
+             # Add coordinate embeddings
+             image = add_coordinate_embeddings(image)
+
+         return image, target
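
From the paths above, a dataset directory is expected to look like the following. The layout is inferred from the constructor; file contents beyond the columns and keys actually read in this diff are not specified here:

    <root>/
        annotations/
            annotations.csv    # columns used: img_name, x_mm, y_mm, x_px, y_px
            metadata.json      # keys used: probe_radius_mm, px_per_mm
        blank_images/
            blank.png
        probe_images/
            <img_name entries from annotations.csv>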
File without changes
@@ -0,0 +1,29 @@
+ import torch
+
+ def add_coordinate_embeddings(image):
+     """
+     Add coordinate embeddings to the input image.
+     - X channel: column indices (1s in first column, 2s in second column, etc.)
+     - Y channel: row indices (1s in first row, 2s in second row, etc.)
+
+     Args:
+         image (torch.Tensor): Input image tensor of shape (C, H, W).
+
+     Returns:
+         torch.Tensor: Image tensor with added coordinate embeddings of shape (C+2, H, W).
+     """
+     # Get image dimensions
+     _, height, width = image.shape
+
+     # Create x coordinate channel (column indices)
+     x_embedding = torch.arange(1, width + 1, dtype=torch.float32).unsqueeze(0).repeat(height, 1)
+     x_channel = x_embedding.unsqueeze(0)  # Add channel dimension
+
+     # Create y coordinate channel (row indices)
+     y_embedding = torch.arange(1, height + 1, dtype=torch.float32).unsqueeze(1).repeat(1, width)
+     y_channel = y_embedding.unsqueeze(0)  # Add channel dimension
+
+     # Concatenate original image with position embeddings
+     image_with_embeddings = torch.cat([image, x_channel, y_channel], dim=0)
+
+     return image_with_embeddings
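
A quick shape check, as a sketch (tensor sizes are arbitrary):

    import torch
    from py3DCal.model_training.lib.add_coordinate_embeddings import add_coordinate_embeddings

    image = torch.rand(3, 240, 320)          # (C, H, W)
    embedded = add_coordinate_embeddings(image)
    print(embedded.shape)                    # torch.Size([5, 240, 320])
    print(embedded[3, 0, :3])                # x channel: tensor([1., 2., 3.])
    print(embedded[4, :3, 0])                # y channel: tensor([1., 2., 3.])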
@@ -0,0 +1,74 @@
+ import numpy as np
+ import torch
+ from pathlib import Path
+ from typing import Union
+ from PIL import Image
+ import matplotlib.pyplot as plt
+ from torchvision import transforms
+ from .validate_parameters import validate_device
+ from .add_coordinate_embeddings import add_coordinate_embeddings
+ from .fast_poisson import fast_poisson
+
+ def get_depthmap(model, image_path: Union[str, Path], blank_image_path: Union[str, Path], device='cpu') -> np.ndarray:
+     """
+     Returns the depthmap for a given input image.
+     Args:
+         model: A model which takes in an image and outputs gradient maps.
+         image_path (str or pathlib.Path): Path to the input image.
+         blank_image_path (str or pathlib.Path): Path to the blank image.
+         device (str, optional): Device to run the model on. Defaults to 'cpu'.
+     Returns:
+         depthmap (numpy.ndarray): The computed depthmap.
+     """
+     validate_device(device)
+
+     transform = transforms.ToTensor()
+
+     model.to(device)
+     model.eval()
+
+     # Preprocess the image the same way the training pipeline does:
+     # subtract the blank image, then append coordinate embeddings
+     image = transform(Image.open(image_path).convert("RGB"))
+     blank_image = transform(Image.open(blank_image_path).convert("RGB"))
+     augmented_image = image - blank_image
+     augmented_image = add_coordinate_embeddings(augmented_image)
+     augmented_image = augmented_image.unsqueeze(0).to(device)
+
+     with torch.no_grad():
+         output = model(augmented_image)
+
+     output = output.squeeze(0).permute(1, 2, 0).cpu().numpy()
+
+     # Integrate the predicted x/y gradient maps into a depthmap
+     depthmap = fast_poisson(output[:, :, 0], output[:, :, 1])
+
+     depthmap = np.clip(-depthmap, a_min=0, a_max=None)
+
+     return depthmap
+
+ def save_2d_depthmap(model, image_path: Union[str, Path], blank_image_path: Union[str, Path], device='cpu', save_path: Union[str, Path] = Path("depthmap.png")):
+     """
+     Save an image of the depthmap for a given input image.
+     Args:
+         model: A model which takes in an image and outputs gradient maps.
+         image_path (str or pathlib.Path): Path to the input image.
+         blank_image_path (str or pathlib.Path): Path to the blank image.
+         device (str, optional): Device to run the model on. Defaults to 'cpu'.
+         save_path (str or pathlib.Path): Path to save the depthmap image.
+
+     Returns:
+         None.
+     """
+     depthmap = get_depthmap(model=model, image_path=image_path, blank_image_path=blank_image_path, device=device)
+
+     plt.imsave(save_path, depthmap, cmap='viridis')
+
+ def show_2d_depthmap(model, image_path: Union[str, Path], blank_image_path: Union[str, Path], device='cpu'):
+     """
+     Show the depthmap for a given input image.
+
+     Args:
+         model: A model which takes in an image and outputs gradient maps.
+         image_path (str or pathlib.Path): Path to the input image.
+         blank_image_path (str or pathlib.Path): Path to the blank image.
+         device (str, optional): Device to run the model on. Defaults to 'cpu'.
+
+     Returns:
+         None.
+     """
+     depthmap = get_depthmap(model=model, image_path=image_path, blank_image_path=blank_image_path, device=device)
+
+     plt.imshow(depthmap)
+     plt.show()
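
A usage sketch. The TouchNet class name is an assumption based on models/touchnet.py, the weights file matches what train_model below saves, and the image paths are placeholders:

    import torch
    from py3DCal.model_training.models import TouchNet  # class name assumed
    from py3DCal.model_training.lib.depthmaps import get_depthmap, save_2d_depthmap

    model = TouchNet()
    model.load_state_dict(torch.load("weights.pth", map_location="cpu"))

    depth = get_depthmap(model, "probe.png", "blank.png", device="cpu")
    save_2d_depthmap(model, "probe.png", "blank.png", save_path="depthmap.png")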
@@ -0,0 +1,51 @@
+ import numpy as np
+ from scipy.fftpack import dst
+ from scipy.fftpack import idst
+
+ def fast_poisson(Fx, Fy):
+     """
+     Fast Poisson solver for 2D images.
+     Args:
+         Fx: 2D array of x-derivatives
+         Fy: 2D array of y-derivatives
+     Returns:
+         img: 2D array of the solution to the Poisson equation
+     """
+
+     height, width = Fx.shape
+
+     # Difference the Fx array in the x-direction to approximate the second derivative in x (interior points only)
+     Fxx = Fx[1:-1, 1:-1] - Fx[1:-1, :-2]
+     # Difference the Fy array in the y-direction to approximate the second derivative in y (interior points only)
+     Fyy = Fy[1:-1, 1:-1] - Fy[:-2, 1:-1]
+
+     # Combine the two second derivatives to form the source term g for the Poisson equation
+     g = Fxx + Fyy
+
+     # Apply the Discrete Sine Transform (DST) to the 2D array g (row-wise transform)
+     g_sinx = dst(g, norm='ortho')
+
+     # Apply the DST again (column-wise on the transposed array) to complete the 2D DST
+     g_sinxy = dst(g_sinx.T, norm='ortho').T
+
+     # Create a mesh grid of indices corresponding to the interior points (excluding the boundaries)
+     x_mesh, y_mesh = np.meshgrid(range(1, width - 1), range(1, height - 1))
+
+     # Construct the denominator for the Poisson solution based on the 2D frequency space
+     denom = (2 * np.cos(np.pi * x_mesh / (width - 1)) - 2) + (2 * np.cos(np.pi * y_mesh / (height - 1)) - 2)
+
+     # Divide the 2D DST coefficients by the frequency-dependent denominator to solve the Poisson equation in the frequency domain
+     out = g_sinxy / denom
+
+     # Apply the inverse DST (IDST) to the result in the x-direction
+     g_x = idst(out, norm='ortho')
+
+     # Apply the inverse DST again in the y-direction to obtain the solution in the spatial domain
+     g_xy = idst(g_x.T, norm='ortho').T
+
+     # Note: norm='ortho' makes the DST and IDST orthonormal, preserving energy across the transforms
+
+     # Pad the interior-only result with 0's at the border, since we assume fixed (zero) boundary conditions
+     img = np.pad(g_xy, pad_width=1, mode='constant')
+
+     return img
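
A small self-check sketch: build a smooth surface that is zero at the border (matching the solver's boundary condition), differentiate it numerically, and confirm the solver approximately recovers it. The grid size and error expectation are illustrative:

    import numpy as np
    from py3DCal.model_training.lib.fast_poisson import fast_poisson

    # Synthetic dome, zero at the border
    h, w = 64, 64
    y, x = np.mgrid[0:h, 0:w]
    surface = np.sin(np.pi * y / (h - 1)) * np.sin(np.pi * x / (w - 1))

    # Numerical gradients standing in for the model's predicted gradient maps
    Fy, Fx = np.gradient(surface)

    recovered = fast_poisson(Fx, Fy)
    print(np.abs(recovered - surface).max())  # should be small (discretization error)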
@@ -0,0 +1,39 @@
+ import numpy as np
+ import pandas as pd
+
+ def get_gradient_map(idx, annotation_path, precomputed_gradients):
+     """
+     Returns a gradient map for an image using precomputed gradients.
+
+     Args:
+         idx: Index of the image to build the gradient map for.
+         annotation_path: Path to the csv file containing the sensor data.
+         precomputed_gradients: Precomputed h x w x 2 gradient map centered in the image.
+
+     Returns:
+         numpy.ndarray: The gradient map shifted to the probe location for this image.
+     """
+     # Read data file
+     sensor_data = pd.read_csv(annotation_path, comment='#')
+
+     height, width, _ = precomputed_gradients.shape
+
+     x = int(float(sensor_data['x_px'][idx]))
+     y = int(float(sensor_data['y_px'][idx]))
+
+     right_shift = x - width // 2
+     down_shift = y - height // 2
+
+     # Pad enough that the shifted circle cannot wrap around the edges
+     offset = max(abs(right_shift), abs(down_shift))
+
+     gradient_map = np.zeros((height + offset * 2, width + offset * 2, 2))
+     gradient_map[:, :, 0] = np.pad(precomputed_gradients[:, :, 0], pad_width=offset, mode='constant')
+     gradient_map[:, :, 1] = np.pad(precomputed_gradients[:, :, 1], pad_width=offset, mode='constant')
+
+     # Shift the array right_shift positions along the horizontal axis (axis=1)
+     gradient_map = np.roll(gradient_map, right_shift, axis=1)
+
+     # Shift the array down_shift positions along the vertical axis (axis=0)
+     gradient_map = np.roll(gradient_map, down_shift, axis=0)
+
+     # Crop back to the original image size
+     gradient_map = gradient_map[offset:offset + height, offset:offset + width]
+
+     return gradient_map
@@ -0,0 +1,61 @@
+ import os
+ import numpy as np
+ import pandas as pd
+ from PIL import Image
+
+ def precompute_gradients(dataset_path, annotation_path, r=36):
+     """
+     Computes the gradient map for a probe image. The gradients are computed once and reused for every image in the dataset, for faster computation.
+
+     Args:
+         dataset_path (str): The path of the data folder.
+         annotation_path (str): The path of the csv data file.
+         r (float): The probe radius in pixels. Defaults to 36.
+
+     Returns:
+         numpy.ndarray: An h x w x 2 numpy array with x and y gradient values for a circle located at the center.
+     """
+     # Read data file
+     calibration_data = pd.read_csv(annotation_path)
+
+     # Read the first image to get the sensor resolution
+     image_path = os.path.join(dataset_path, "probe_images", calibration_data['img_name'][0])
+     image = Image.open(image_path)
+     image = np.asarray(image)
+
+     # Get image height and width
+     height, width, _ = image.shape
+
+     # Place the circle at the image center
+     x = width // 2
+     y = height // 2
+
+     # Create gradient map
+     gradient_map = np.zeros((height, width, 2))
+
+     for i in range(height):
+         for j in range(width):
+             # Distance from pixel to center of circle
+             d_center = np.sqrt((y - i) ** 2 + (x - j) ** 2)
+
+             # If pixel is outside circle, set gradients to 0
+             if d_center > r:
+                 Gx = 0
+                 Gy = 0
+
+             # Otherwise, model the probe as a sphere and take the surface gradients
+             else:
+                 normX = (j - x) / r
+                 normY = (i - y) / r
+                 normZ = np.sqrt(1 - normX ** 2 - normY ** 2)
+
+                 # Avoid division by zero at the circle's edge
+                 if normZ == 0:
+                     normZ = 0.1
+
+                 Gx = normX / normZ
+                 Gy = normY / normZ
+
+             # Update values in gradient map
+             gradient_map[i, j] = np.array([Gx, Gy])
+
+     return gradient_map
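
The per-pixel loop above is easy to read but slow at sensor resolution. A behavior-equivalent vectorized sketch (not part of the package; takes the resolution directly instead of reading an image):

    import numpy as np

    def precompute_gradients_vectorized(height, width, r=36):
        # Normalized offsets of every pixel from the image center
        y0, x0 = height // 2, width // 2
        i, j = np.mgrid[0:height, 0:width]
        normX = (j - x0) / r
        normY = (i - y0) / r

        inside = normX ** 2 + normY ** 2 <= 1
        normZ = np.sqrt(np.clip(1 - normX ** 2 - normY ** 2, 0, None))
        normZ = np.where(normZ == 0, 0.1, normZ)  # same edge guard as above

        gradient_map = np.zeros((height, width, 2))
        gradient_map[..., 0] = np.where(inside, normX / normZ, 0)
        gradient_map[..., 1] = np.where(inside, normY / normZ, 0)
        return gradient_map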
@@ -0,0 +1,96 @@
+ import torch
+ import torch.optim as optim
+ import torch.nn as nn
+ from torch.utils.data import DataLoader
+ from ..datasets.tactile_sensor_dataset import TactileSensorDataset
+ from ..datasets.split_dataset import split_dataset
+ from .validate_parameters import validate_device
+
+
+ def train_model(model: nn.Module, dataset: TactileSensorDataset, num_epochs: int = 60, batch_size: int = 64, learning_rate: float = 1e-4, train_ratio: float = 0.8, loss_fn: nn.Module = nn.MSELoss(), device='cpu'):
+     """
+     Train a model on a tactile sensor dataset. By default, trains for 60
+     epochs with a batch size of 64, using the AdamW optimizer with a
+     learning rate of 1e-4 and weight decay of 1e-4.
+
+     Args:
+         model (nn.Module): The PyTorch model to be trained.
+         dataset (py3DCal.datasets.TactileSensorDataset): The dataset to train the model on.
+         num_epochs (int): Number of epochs to train for. Defaults to 60.
+         batch_size (int): Batch size. Defaults to 64.
+         learning_rate (float): Learning rate. Defaults to 1e-4.
+         train_ratio (float): Proportion of data to use for training. Defaults to 0.8.
+         loss_fn (nn.Module): Loss function. Defaults to nn.MSELoss().
+         device (str): Device to run the training on. Defaults to 'cpu'.
+
+     Outputs:
+         weights.pth: Trained model weights.
+         losses.csv: Training and validation losses.
+     """
+     validate_device(device)
+     _validate_model_and_dataset(model, dataset)
+
+     optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-4)
+     train_dataset, val_dataset = split_dataset(dataset, train_ratio=train_ratio)
+     train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=(device == "cuda"), persistent_workers=True)
+     val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=(device == "cuda"), persistent_workers=True)
+
+     model.to(device)
+
+     epoch_train_losses = []
+     epoch_val_losses = []
+
+     print("Starting training...\n")
+
+     # Training loop
+     for epoch in range(num_epochs):
+         print(f"Epoch {epoch+1}/{num_epochs}")
+
+         model.train()
+         train_loss = 0.0
+         for batch_idx, (inputs, targets) in enumerate(train_loader):
+             inputs = inputs.to(torch.float32).to(device)
+             targets = targets.to(torch.float32).to(device)
+             optimizer.zero_grad()
+             outputs = model(inputs)
+
+             loss = loss_fn(outputs, targets)
+             loss.backward()
+             optimizer.step()
+
+             train_loss += loss.item()
+
+             print(f"  [Batch {batch_idx}/{len(train_loader)}] - Loss: {loss.item():.4f}")
+
+         avg_train_loss = train_loss / len(train_loader)
+         epoch_train_losses.append(avg_train_loss)
+         print(f"Epoch {epoch+1}/{num_epochs} | Train Loss: {avg_train_loss:.4f}")
+
+         # Validation loop
+         model.eval()
+         val_loss = 0.0
+
+         with torch.no_grad():
+             for inputs, targets in val_loader:
+                 inputs = inputs.to(torch.float32).to(device)
+                 targets = targets.to(torch.float32).to(device)
+                 outputs = model(inputs)
+                 loss = loss_fn(outputs, targets)
+                 val_loss += loss.item()
+
+         avg_val_loss = val_loss / len(val_loader)
+         epoch_val_losses.append(avg_val_loss)
+         print(f"VAL LOSS: {avg_val_loss:.4f}")
+
+     # Save per-epoch losses and the final weights
+     with open("losses.csv", "w") as f:
+         f.write("epoch,train_loss,val_loss\n")
+         for i in range(len(epoch_train_losses)):
+             f.write(f"{i+1},{epoch_train_losses[i]},{epoch_val_losses[i]}\n")
+
+     torch.save(model.state_dict(), "weights.pth")
+
+ def _validate_model_and_dataset(model: nn.Module, dataset: TactileSensorDataset):
+     if not isinstance(model, nn.Module):
+         raise ValueError("Model must be an instance of torch.nn.Module.")
+
+     if not isinstance(dataset, TactileSensorDataset):
+         raise ValueError("Dataset must be an instance of py3DCal.datasets.TactileSensorDataset.")
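
An end-to-end sketch tying the pieces together. The TouchNet class name is assumed from models/touchnet.py; the other imports and the saved output files follow from this diff:

    import torch
    from py3DCal.model_training.datasets import DIGIT
    from py3DCal.model_training.lib.train_model import train_model
    from py3DCal.model_training.models import TouchNet  # class name assumed

    dataset = DIGIT(root="./data", download=True)
    model = TouchNet()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    train_model(model, dataset, num_epochs=60, batch_size=64, device=device)
    # Produces weights.pth and losses.csv in the working directory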