py3dcal 1.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of py3dcal might be problematic. Click here for more details.

Files changed (40)
  1. py3DCal/__init__.py +14 -0
  2. py3DCal/data_collection/Calibrator.py +300 -0
  3. py3DCal/data_collection/__init__.py +0 -0
  4. py3DCal/data_collection/printers/Ender3/Ender3.py +82 -0
  5. py3DCal/data_collection/printers/Ender3/__init__.py +0 -0
  6. py3DCal/data_collection/printers/Printer.py +63 -0
  7. py3DCal/data_collection/printers/__init__.py +0 -0
  8. py3DCal/data_collection/sensors/DIGIT/DIGIT.py +47 -0
  9. py3DCal/data_collection/sensors/DIGIT/__init__.py +0 -0
  10. py3DCal/data_collection/sensors/DIGIT/default.csv +1222 -0
  11. py3DCal/data_collection/sensors/GelsightMini/GelsightMini.py +45 -0
  12. py3DCal/data_collection/sensors/GelsightMini/__init__.py +0 -0
  13. py3DCal/data_collection/sensors/GelsightMini/default.csv +1210 -0
  14. py3DCal/data_collection/sensors/Sensor.py +44 -0
  15. py3DCal/data_collection/sensors/__init__.py +0 -0
  16. py3DCal/model_training/__init__.py +0 -0
  17. py3DCal/model_training/datasets/DIGIT_dataset.py +77 -0
  18. py3DCal/model_training/datasets/GelSightMini_dataset.py +75 -0
  19. py3DCal/model_training/datasets/__init__.py +3 -0
  20. py3DCal/model_training/datasets/split_dataset.py +38 -0
  21. py3DCal/model_training/datasets/tactile_sensor_dataset.py +83 -0
  22. py3DCal/model_training/lib/__init__.py +0 -0
  23. py3DCal/model_training/lib/add_coordinate_embeddings.py +29 -0
  24. py3DCal/model_training/lib/annotate_dataset.py +422 -0
  25. py3DCal/model_training/lib/depthmaps.py +82 -0
  26. py3DCal/model_training/lib/fast_poisson.py +51 -0
  27. py3DCal/model_training/lib/get_gradient_map.py +39 -0
  28. py3DCal/model_training/lib/precompute_gradients.py +61 -0
  29. py3DCal/model_training/lib/train_model.py +96 -0
  30. py3DCal/model_training/lib/validate_parameters.py +87 -0
  31. py3DCal/model_training/models/__init__.py +1 -0
  32. py3DCal/model_training/models/touchnet.py +211 -0
  33. py3DCal/utils/__init__.py +0 -0
  34. py3DCal/utils/utils.py +32 -0
  35. py3dcal-1.0.5.dist-info/LICENSE +21 -0
  36. py3dcal-1.0.5.dist-info/METADATA +29 -0
  37. py3dcal-1.0.5.dist-info/RECORD +40 -0
  38. py3dcal-1.0.5.dist-info/WHEEL +5 -0
  39. py3dcal-1.0.5.dist-info/entry_points.txt +3 -0
  40. py3dcal-1.0.5.dist-info/top_level.txt +1 -0
@@ -0,0 +1,422 @@
1
+ import os
2
+ import cv2
3
+ import math
4
+ import json
5
+ import numpy as np
6
+ import pandas as pd
7
+ from typing import Union
8
+ from pathlib import Path
9
+ from matplotlib import pyplot as plt
10
+ from matplotlib.patches import Circle
11
+ from .validate_parameters import validate_root
12
+
13
def annotate(dataset_path: Union[str, Path], probe_radius_mm: Union[int, float], img_idxs=None):
    """
    Interactively annotate a custom dataset with a pixel-to-millimeter calibration.

    The user first fits a circle to each of two probe images (their horizontal
    pixel distance divided by their distance in mm gives an initial pixel/mm
    ratio). The user then refines the anchor position and the ratio on a mosaic
    of probe images.

    Controls:
        - w/s: Move circle up/down
        - a/d: Move circle left/right
        - r/f: Increase/decrease circle size or pixel/mm ratio
        - q: Proceed to next step

    Args:
        dataset_path (str or pathlib.Path): Path to the dataset directory. It must
            contain ``annotations/probe_data.csv`` (with ``img_name``, ``x_mm`` and
            ``y_mm`` columns), a ``probe_images`` directory and
            ``blank_images/blank.png``.
        probe_radius_mm (int or float): Radius of the probe used to collect data (in mm).
        img_idxs (tuple or list, optional): The two row indices of ``probe_data.csv``
            to use for circle fitting. Default: None (auto-selects the images closest
            to the 25th and 75th percentile x positions of the middle y row).

    Returns:
        None. Writes ``metadata.json`` (px_per_mm and probe_radius_mm) and
        ``annotations.csv`` into the ``dataset_path/annotations`` directory.

    Raises:
        ValueError: If a parameter is invalid, or if the two selected images share
            the same x position (the pixel/mm ratio would be undefined).
    """
    validate_root(dataset_path, must_exist=True)
    validate_probe_radius(probe_radius_mm)
    validate_indices(img_idxs)

    # Open probe data
    annotations_dir = os.path.join(dataset_path, "annotations")
    probe_data = pd.read_csv(os.path.join(annotations_dir, "probe_data.csv"))

    if img_idxs is None:
        # The median and the quantiles are interpolated statistics and need not
        # coincide with an actual probe position (e.g. an even number of rows or
        # columns). Snap each target to the closest existing value instead of
        # requiring an exact match, which would leave an empty selection.
        y_mm = probe_data["y_mm"]
        middle_y = y_mm.loc[(y_mm - y_mm.median()).abs().idxmin()]
        middle_row_x = probe_data.loc[y_mm == middle_y, "x_mm"]
        idx1 = (middle_row_x - middle_row_x.quantile(0.25)).abs().idxmin()
        idx2 = (middle_row_x - middle_row_x.quantile(0.75)).abs().idxmin()
    else:
        idx1, idx2 = img_idxs[0], img_idxs[1]

    img1_x_mm = probe_data["x_mm"][idx1]
    img2_x_mm = probe_data["x_mm"][idx2]

    # The pixel/mm ratio comes from the horizontal offset between the two
    # contacts. Check it before any GUI work so the user does not fit circles
    # for nothing.
    dx_mm = abs(img2_x_mm - img1_x_mm)
    if dx_mm == 0:
        raise ValueError(
            "The two images used for circle fitting must have different x_mm positions.\n"
        )

    image1_name = os.path.join(dataset_path, "probe_images", probe_data["img_name"][idx1])
    image2_name = os.path.join(dataset_path, "probe_images", probe_data["img_name"][idx2])
    blank_image_path = os.path.join(dataset_path, "blank_images", "blank.png")

    # Fit 2 circles (only the x position of the second one is needed)
    circle1_x, circle1_y, circle1_r = _fit_circle(image1_name, blank_image_path)
    circle2_x, _, _ = _fit_circle(image2_name, blank_image_path)

    # Compute pixels/mm
    px_per_mm = abs(circle2_x - circle1_x) / dx_mm

    # Fine tune the fitting
    px_per_mm, annotations = _adjust_fitting(
        dataset_path,
        anchor_idx=idx1,
        px_per_mm=px_per_mm,
        anchor_data=(circle1_x, circle1_y, circle1_r),
    )

    print("pixels per mm:", px_per_mm)

    # Save metadata file
    metadata_path = os.path.join(annotations_dir, 'metadata.json')
    data = {"px_per_mm": px_per_mm, "probe_radius_mm": probe_radius_mm}
    with open(metadata_path, "w") as json_file:
        json.dump(data, json_file, indent=4)

    # Create CSV file with annotated data
    annotations.to_csv(os.path.join(annotations_dir, "annotations.csv"), index=False)
86
+
87
def _fit_circle(image_path: Union[str, Path], blank_image_path: Union[str, Path]):
    """
    Let the user fit a circle to the probe contact in an image via a Matplotlib GUI.

    Keys:
        - w/s (or up/down): move the circle up/down
        - a/d (or left/right): move the circle left/right
        - r/f: grow/shrink the circle (the radius never drops below 1 px)
        - 1/2/3: RGB view / toggle difference view / toggle bitwise-not overlay
        - q: accept the current fit

    Closing the window also accepts the current fit; without this the
    interactive loop would never terminate.

    Args:
        image_path: Path to the probe image.
        blank_image_path: Path to the blank image, used by the difference and
            bitwise-not views.

    Returns:
        Tuple ``(x, y, r)``: circle center and radius in pixels.

    Raises:
        FileNotFoundError: If either image cannot be read.
    """
    def load_rgb(path):
        bgr = cv2.imread(str(path))
        # cv2.imread signals failure by returning None instead of raising.
        if bgr is None:
            raise FileNotFoundError(f"Could not read image: {path}")
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    # Load original image (default view)
    image = load_rgb(image_path)
    blank_image = load_rgb(blank_image_path)
    bitwise_not_blank = cv2.bitwise_not(blank_image)

    # Initial circle: centered in the frame with a 30 px radius.
    x = image.shape[1] // 2
    y = image.shape[0] // 2
    r = 30

    # Flags for image display modes
    subtract_blank = False
    bitwise_not = False
    done = False

    # Temporarily disable Matplotlib's save/fullscreen keymaps, which conflict
    # with the controls below; rc_context restores the user's settings on exit.
    with plt.rc_context({'keymap.save': [], 'keymap.fullscreen': []}):
        # Prepare figure with two subplots: text (left), image (right)
        fig, (ax_text, ax_img) = plt.subplots(1, 2, figsize=(14, 8), gridspec_kw={'width_ratios': [1, 3]})
        plt.subplots_adjust(wspace=0.4, bottom=0, top=1, left=0, right=1)

        fig.canvas.manager.set_window_title('Fit Circle to Generated Annotations')

        # Right: Image panel
        img_artist = ax_img.imshow(image)
        ax_img.set_axis_off()
        circle_artist = plt.Circle((x, y), r, color='red', fill=False, linewidth=1)
        ax_img.add_patch(circle_artist)
        center_artist, = ax_img.plot(x, y, marker='*', color='lime', markersize=6)

        # Left: Instruction panel
        ax_text.set_axis_off()
        ax_text.text(
            0.30, 0.75,
            "Commands:\n\nw: Up\ns: Down\na: Left\nd: Right\nr: Bigger\nf: Smaller\nq: Next\n\n\n1: View 1 (RGB image)\n2: View 2 (Difference image)\n3: View 3 (Bitwise not image)",
            fontsize=20, color='black', va='top', ha='left', wrap=True
        )

        plt.ion()
        plt.show(block=False)

        def on_key(event):
            nonlocal x, y, r, done, subtract_blank, bitwise_not

            if event.key == 'q':
                done = True
            elif event.key in ('w', 'up'):
                y -= 1
            elif event.key in ('s', 'down'):
                y += 1
            elif event.key in ('a', 'left'):
                x -= 1
            elif event.key in ('d', 'right'):
                x += 1
            elif event.key == 'r':
                r += 1
            elif event.key == 'f':
                # A non-positive radius is meaningless; keep at least 1 px.
                r = max(1, r - 1)
            elif event.key == '1':  # Normal image
                subtract_blank = False
                bitwise_not = False
                img_artist.set_data(image)
            elif event.key == '2':  # Toggle difference image
                subtract_blank = not subtract_blank
                bitwise_not = False
                img_artist.set_data(cv2.absdiff(image, blank_image) if subtract_blank else image)
            elif event.key == '3':  # Toggle bitwise-not overlay
                bitwise_not = not bitwise_not
                subtract_blank = False
                if bitwise_not:
                    img_artist.set_data(cv2.addWeighted(image, 0.5, bitwise_not_blank, 0.5, 0.0))
                else:
                    img_artist.set_data(image)

        def on_close(event):
            # Closing the window ends the session with the current values.
            nonlocal done
            done = True

        fig.canvas.mpl_connect('key_press_event', on_key)
        fig.canvas.mpl_connect('close_event', on_close)

        # Interactive update loop
        while not done:
            circle_artist.center = (x, y)
            circle_artist.set_radius(r)
            center_artist.set_data([x], [y])
            fig.canvas.draw_idle()
            plt.pause(0.01)

        plt.close(fig)
        plt.ioff()  # Turn off interactive mode
        fig.canvas.flush_events()  # Flush any pending events

    return x, y, r
201
+
202
def _adjust_fitting(dataset_path: Union[str, Path], anchor_idx, px_per_mm, anchor_data):
    """
    Refine the pixel-to-millimeter calibration with an interactive Matplotlib GUI.

    The anchor image and up to eight other randomly chosen probe images are shown
    as a 3x3 mosaic, with a circle drawn at each image's predicted contact. Only
    images whose predicted contact lies strictly inside the central 70% of the
    frame are used. Tiles stay black when fewer images qualify. The user shifts
    the anchor position (w/a/s/d) and the pixel/mm ratio (r/f, in steps of 1)
    until the circles line up with the contacts. Press q, or close the window,
    to finish.

    Args:
        dataset_path: Path to the dataset.
        anchor_idx: Row index of the anchor image in ``annotations/probe_data.csv``.
        px_per_mm: Initial pixel/mm ratio.
        anchor_data: ``(x, y, r)`` of the circle fitted to the anchor image, in pixels.

    Returns:
        Tuple ``(px_per_mm, calibration_data)``: the refined pixel/mm ratio and the
        probe data frame with added ``x_px`` / ``y_px`` columns.

    Raises:
        FileNotFoundError: If an image cannot be read.
        ValueError: If the blank image and the anchor image differ in shape.
    """
    def load_rgb(path):
        bgr = cv2.imread(str(path))
        # cv2.imread signals failure by returning None instead of raising.
        if bgr is None:
            raise FileNotFoundError(f"Could not read image: {path}")
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    def image_path_for(label):
        return os.path.join(dataset_path, "probe_images", calibration_data["img_name"][label])

    # Load calibration data
    calibration_data_path = os.path.join(dataset_path, "annotations", "probe_data.csv")
    calibration_data = pd.read_csv(calibration_data_path)

    # Load anchor image and its probe position
    anchor_image = load_rgb(image_path_for(anchor_idx))
    anchor_x_mm = calibration_data["x_mm"][anchor_idx]
    anchor_y_mm = calibration_data["y_mm"][anchor_idx]
    anchor_x_px, anchor_y_px, anchor_r_px = anchor_data
    height, width, _ = anchor_image.shape

    def update_pixel_coordinates():
        # Project every probe position into pixels relative to the anchor. The
        # y term is reversed, i.e. the mm and pixel y axes point in opposite
        # directions.
        calibration_data['x_px'] = anchor_x_px + (calibration_data['x_mm'] - anchor_x_mm) * px_per_mm
        calibration_data['y_px'] = anchor_y_px + (anchor_y_mm - calibration_data['y_mm']) * px_per_mm

    # Add initial annotations (pixel coordinates)
    update_pixel_coordinates()

    # Load blank image and tile it into a 3x3 mosaic matching the image mosaic.
    blank_image = load_rgb(os.path.join(dataset_path, "blank_images", "blank.png"))
    if blank_image.shape != anchor_image.shape:
        raise ValueError("Blank image and probe images must have the same shape.")
    blank_mosaic = np.tile(blank_image, (3, 3, 1))
    bitwise_not_blank = cv2.bitwise_not(blank_mosaic)

    # Pick up to 8 distinct non-anchor images whose predicted contact lies inside
    # the camera's FOV. Sampling without replacement from the pre-filtered rows
    # always terminates, even if fewer than 8 rows qualify.
    x_px = calibration_data["x_px"]
    y_px = calibration_data["y_px"]
    in_fov = (
        (x_px > width * 0.15) & (x_px < width * 0.85)
        & (y_px > height * 0.15) & (y_px < height * 0.85)
    )
    candidates = calibration_data.index[in_fov.to_numpy()].drop(anchor_idx, errors="ignore")
    sampled = calibration_data.loc[candidates].sample(n=min(8, len(candidates))).index
    image_list = [anchor_idx, *sampled]

    # Generate the 3x3 mosaic; tile 0 (top-left) is the anchor image.
    mosaic = np.zeros((height * 3, width * 3, 3), dtype=np.uint8)
    for tile, label in enumerate(image_list):
        tile_row, tile_col = divmod(tile, 3)
        tile_image = anchor_image if tile == 0 else load_rgb(image_path_for(label))
        mosaic[(height * tile_row):(height * (tile_row + 1)),
               (width * tile_col):(width * (tile_col + 1)), :] = tile_image

    def tile_center(tile):
        # Predicted contact of the tile's image, offset into mosaic coordinates.
        tile_row, tile_col = divmod(tile, 3)
        label = image_list[tile]
        return (int(calibration_data.loc[label, 'x_px']) + tile_col * width,
                int(calibration_data.loc[label, 'y_px']) + tile_row * height)

    # Flags for image display modes
    subtract_blank = False
    bitwise_not = False
    done = False

    # Temporarily disable Matplotlib's save/fullscreen keymaps, which conflict
    # with the controls below; rc_context restores the user's settings on exit.
    with plt.rc_context({'keymap.save': [], 'keymap.fullscreen': []}):
        fig, (ax_text, ax_img) = plt.subplots(1, 2, figsize=(14, 8), gridspec_kw={'width_ratios': [1, 3]})
        plt.subplots_adjust(wspace=0.4, bottom=0, top=1, left=0, right=1)
        fig.canvas.manager.set_window_title('Validate and Refine Calibration Annotations')

        # Right panel: image grid
        img_artist = ax_img.imshow(mosaic)
        ax_img.set_axis_off()

        ax_img.text(
            width * 0.19,
            height * 0.1,
            "Anchor Image",
            color='yellow',
            fontsize=13,
            bbox=dict(facecolor='black', alpha=0.1, boxstyle='round,pad=0.3')
        )

        # Overlay circles (one per displayed image)
        circle_artists = []
        for tile in range(len(image_list)):
            circ = Circle(tile_center(tile), anchor_r_px, color='red', fill=False, lw=1)
            ax_img.add_patch(circ)
            circle_artists.append(circ)

        # Left panel: instructions
        ax_text.set_axis_off()
        ax_text.text(
            0.30, 0.75,
            "Commands:\n\nw: Up\ns: Down\na: Left\nd: Right\nr: Increase pixel/mm value\nf: Decrease pixel/mm value\nq: Quit\n\n\n1: View 1 (RGB image)\n2: View 2 (Difference image)\n3: View 3 (Bitwise not image)",
            fontsize=20, color='black', va='top', ha='left', wrap=True
        )

        plt.ion()
        plt.show(block=False)

        # Keyboard event handler
        def on_key(event):
            nonlocal anchor_x_px, anchor_y_px, px_per_mm, done, subtract_blank, bitwise_not

            if event.key == 'q':
                done = True
            elif event.key in ('w', 'up'):
                anchor_y_px -= 1
            elif event.key in ('s', 'down'):
                anchor_y_px += 1
            elif event.key in ('a', 'left'):
                anchor_x_px -= 1
            elif event.key in ('d', 'right'):
                anchor_x_px += 1
            elif event.key == 'r':
                px_per_mm += 1
            elif event.key == 'f':
                # Keep the ratio strictly positive; zero or negative values would
                # collapse or mirror the annotations.
                if px_per_mm > 1:
                    px_per_mm -= 1
            elif event.key == '1':
                subtract_blank = False
                bitwise_not = False
                img_artist.set_data(mosaic)
            elif event.key == '2':
                subtract_blank = not subtract_blank
                bitwise_not = False
                img_artist.set_data(cv2.absdiff(mosaic, blank_mosaic) if subtract_blank else mosaic)
            elif event.key == '3':
                bitwise_not = not bitwise_not
                subtract_blank = False
                if bitwise_not:
                    img_artist.set_data(cv2.addWeighted(mosaic, 0.5, bitwise_not_blank, 0.5, 0.0))
                else:
                    img_artist.set_data(mosaic)

            # Recompute coordinates and move the circles accordingly
            update_pixel_coordinates()
            for tile, circle in enumerate(circle_artists):
                circle.center = tile_center(tile)

        def on_close(event):
            # Closing the window ends the session with the current values.
            nonlocal done
            done = True

        fig.canvas.mpl_connect('key_press_event', on_key)
        fig.canvas.mpl_connect('close_event', on_close)

        # Main interactive loop
        while not done:
            fig.canvas.draw_idle()
            plt.pause(0.01)

        plt.close(fig)
        plt.ioff()  # Turn off interactive mode, mirroring _fit_circle

    return px_per_mm, calibration_data
385
+
386
def validate_probe_radius(probe_radius_mm):
    """
    Validates the probe radius specified by the user.

    Args:
        probe_radius_mm: Probe radius specified by the user (in mm). Must be a
            finite, positive int or float; bool is rejected even though it is an
            int subclass.
    Returns:
        None.
    Raises:
        ValueError: If the probe radius is not specified or invalid.
    """
    if probe_radius_mm is None:
        raise ValueError(
            "Probe radius cannot be None.\n"
        )
    # bool must be excluded explicitly: isinstance(True, int) is True.
    if (
        isinstance(probe_radius_mm, bool)
        or not isinstance(probe_radius_mm, (int, float))
        or probe_radius_mm <= 0
    ):
        raise ValueError(
            "Probe radius must be a positive number (int or float).\n"
        )
    # NaN passes the "<= 0" test (all NaN comparisons are False), and inf is
    # positive, so reject non-finite values separately.
    if not math.isfinite(probe_radius_mm):
        raise ValueError(
            "Probe radius must be a finite number.\n"
        )
405
+
406
+
407
def validate_indices(idxs):
    """
    Validates the image indices specified by the user.

    Args:
        idxs: None, or a tuple/list of two distinct, non-negative integer row
            indices. NumPy integer scalars are accepted; bool is not.
    Returns:
        None.
    Raises:
        ValueError: If the indices are invalid.
    """
    import numbers

    if idxs is None:
        return

    # numbers.Integral also covers NumPy integer scalars (e.g. values taken from
    # a DataFrame index); bool is excluded because it is an int subclass.
    def is_integer(value):
        return isinstance(value, numbers.Integral) and not isinstance(value, bool)

    if not (isinstance(idxs, (tuple, list)) and len(idxs) == 2 and all(is_integer(i) for i in idxs)):
        raise ValueError(
            "Image indices must be a tuple or list of two integers.\n"
        )
    # The indices are used as row labels; negative values do not exist there.
    if idxs[0] < 0 or idxs[1] < 0:
        raise ValueError(
            "Image indices must be non-negative.\n"
        )
    # Two identical images give a zero mm offset and an undefined pixel/mm ratio.
    if idxs[0] == idxs[1]:
        raise ValueError(
            "Image indices must refer to two different images.\n"
        )
@@ -0,0 +1,82 @@
1
+ from pyexpat import model
2
+ import numpy as np
3
+ import torch
4
+ from torch import nn
5
+ from pathlib import Path
6
+ from typing import Union
7
+ from PIL import Image
8
+ import matplotlib.pyplot as plt
9
+ from torchvision import transforms
10
+ from .validate_parameters import validate_device
11
+ from .add_coordinate_embeddings import add_coordinate_embeddings
12
+ from .fast_poisson import fast_poisson
13
+
14
def get_depthmap(model: nn.Module, image_path: Union[str, Path], blank_image_path: Union[str, Path], device='cpu') -> np.ndarray:
    """
    Predict a depthmap for a single tactile image.

    The blank image is subtracted from the input image, coordinate embeddings
    are appended, and the model's output is treated as a gradient field:
    channel 0 holds x-gradients and channel 1 holds y-gradients. This field is
    integrated with ``fast_poisson``.

    Args:
        model: A model which takes in an image and outputs gradient maps. It is
            moved to ``device`` and switched to eval mode.
        image_path (str or pathlib.Path): Path to the input image.
        blank_image_path (str or pathlib.Path): Path to the blank image.
        device (str, optional): Device to run the model on. Defaults to 'cpu'.

    Returns:
        depthmap (numpy.ndarray): The negated Poisson solution, with values
        below zero clipped to 0.
    """
    validate_device(device)

    to_tensor = transforms.ToTensor()

    def load_rgb_tensor(path):
        # The context manager releases the file handle as soon as the pixels are converted.
        with Image.open(path) as picture:
            return to_tensor(picture.convert("RGB"))

    model.to(device)
    model.eval()

    difference = load_rgb_tensor(image_path) - load_rgb_tensor(blank_image_path)
    batch = add_coordinate_embeddings(difference).unsqueeze(0).to(device)

    with torch.no_grad():
        prediction = model(batch)

    # (1, C, H, W) -> (H, W, C) NumPy array on the host.
    gradients = prediction.squeeze(0).permute(1, 2, 0).cpu().numpy()

    heights = fast_poisson(gradients[:, :, 0], gradients[:, :, 1])
    return np.clip(-heights, a_min=0, a_max=None)
48
+
49
def save_2d_depthmap(model: nn.Module, image_path: Union[str, Path], blank_image_path: Union[str, Path], device='cpu', save_path: Union[str, Path] = Path("depthmap.png")):
    """
    Render the depthmap of an input image to an image file.

    Args:
        model (nn.Module): A model which takes in an image and outputs gradient maps.
        image_path (str or pathlib.Path): Path to the input image.
        blank_image_path (str or pathlib.Path): Path to the blank image.
        device (str, optional): Device to run the model on. Defaults to 'cpu'.
        save_path (str or pathlib.Path, optional): Destination of the rendered
            image (viridis colormap). Defaults to ``depthmap.png``.

    Returns:
        None.
    """
    plt.imsave(
        save_path,
        get_depthmap(
            model=model,
            image_path=image_path,
            blank_image_path=blank_image_path,
            device=device,
        ),
        cmap='viridis',
    )
65
+
66
def show_2d_depthmap(model: nn.Module, image_path: Union[str, Path], blank_image_path: Union[str, Path], device='cpu'):
    """
    Display the depthmap of an input image in a Matplotlib window.

    Args:
        model (nn.Module): A model which takes in an image and outputs gradient maps.
        image_path (str or pathlib.Path): Path to the input image.
        blank_image_path (str or pathlib.Path): Path to the blank image.
        device (str, optional): Device to run the model on. Defaults to 'cpu'.

    Returns:
        None.
    """
    plt.imshow(
        get_depthmap(
            model=model,
            image_path=image_path,
            blank_image_path=blank_image_path,
            device=device,
        )
    )
    plt.show()
@@ -0,0 +1,51 @@
1
+ import numpy as np
2
+ from scipy.fftpack import dst
3
+ from scipy.fftpack import idst
4
+
5
def fast_poisson(Gx, Gy):
    """
    Integrate a 2D gradient field into a height map with a sine-transform Poisson solver.

    The divergence of (Gx, Gy) is formed with backward differences on the interior
    pixels. It is divided by a frequency-dependent denominator in the 2D DST
    domain and transformed back. The border is fixed at zero.

    Args:
        Gx (np.ndarray): 2D array of x-derivatives (varying along columns).
        Gy (np.ndarray): 2D array of y-derivatives (varying along rows); expected
            to have the same shape as ``Gx``.

    Returns:
        depthmap (np.ndarray): Array with the shape of ``Gx``. The interior holds
        the solution and the one-pixel border is zero.
    """
    rows, cols = Gx.shape

    # Divergence on the interior: backward difference of Gx along columns plus
    # backward difference of Gy along rows.
    divergence = np.diff(Gx[1:-1, :-1], axis=1) + np.diff(Gy[:-1, 1:-1], axis=0)

    # Forward orthonormal 2D DST: first along each row, then along each column.
    spectrum = dst(dst(divergence, axis=1, norm='ortho'), axis=0, norm='ortho')

    # Frequency-dependent denominator (Laplacian eigenvalues in the sine basis).
    # A row vector broadcast against a column vector builds the full grid; every
    # entry is strictly negative, so the division below is safe.
    k_col = np.arange(1, cols - 1)[np.newaxis, :]
    k_row = np.arange(1, rows - 1)[:, np.newaxis]
    denominator = (2 * np.cos(np.pi * k_col / (cols - 1)) - 2) + (2 * np.cos(np.pi * k_row / (rows - 1)) - 2)

    # Solve in the frequency domain, then invert along rows and then columns.
    interior = idst(idst(spectrum / denominator, axis=1, norm='ortho'), axis=0, norm='ortho')

    # Fixed (zero) boundary: surround the interior solution with a 1-pixel zero border.
    return np.pad(interior, pad_width=1, mode='constant')
@@ -0,0 +1,39 @@
1
+ import os
2
+ import numpy as np
3
+ import pandas as pd
4
+
5
def get_gradient_map(idx, annotation_path, precomputed_gradients):
    """
    Return the precomputed gradient map translated to one image's contact location.

    Args:
        idx: Row label, in the annotation CSV, of the image to build the map for.
        annotation_path: Path of the annotation CSV. It must contain ``x_px`` and
            ``y_px`` columns; lines starting with '#' are skipped as comments.
        precomputed_gradients: ``(height, width, C)`` array (C >= 2) whose first
            two channels are the x/y gradients of a contact centered at
            ``(width // 2, height // 2)``.

    Returns:
        numpy.ndarray: Float array of shape ``(height, width, 2)`` holding
        channels 0 and 1 of ``precomputed_gradients``, translated so that the
        center lands on ``(x_px, y_px)``. Pixels shifted in from outside the
        frame are zero.
    """
    # Read data file
    sensor_data = pd.read_csv(annotation_path, comment='#')

    height, width, _ = precomputed_gradients.shape

    # int(float(...)) accepts numeric or numeric-string cells and truncates toward zero.
    x = int(float(sensor_data['x_px'][idx]))
    y = int(float(sensor_data['y_px'][idx]))

    right_shift = x - width // 2
    down_shift = y - height // 2

    def overlap(shift, size):
        # (source, destination) slices that translate an axis of length `size`
        # by `shift`, or None when the shifted content leaves the frame entirely.
        start = max(0, -shift)
        stop = min(size, size - shift)
        if start >= stop:
            return None
        return slice(start, stop), slice(start + shift, stop + shift)

    # Translate by direct slicing. Unlike padding and rolling, this needs no
    # scratch array that grows with the size of the shift.
    gradient_map = np.zeros((height, width, 2))
    rows = overlap(down_shift, height)
    cols = overlap(right_shift, width)
    if rows is not None and cols is not None:
        gradient_map[rows[1], cols[1]] = precomputed_gradients[rows[0], cols[0], :2]

    return gradient_map
@@ -0,0 +1,61 @@
1
+ import os
2
+ import numpy as np
3
+ import pandas as pd
4
+ from PIL import Image
5
+
6
def precompute_gradients(dataset_path, annotation_path, r=36):
    """
    Compute the gradient map of a spherical probe contact centered in the frame.

    This is computed once and reused for all images in the dataset, for faster
    computation.

    Args:
        dataset_path (str): Path of the dataset directory. The first image listed
            in the annotation file is opened from ``dataset_path/probe_images``
            to obtain the frame size.
        annotation_path (str): Path of the annotation CSV; it must contain an
            ``img_name`` column.
        r (int or float): Radius of the contact circle, in pixels. Default: 36.

    Returns:
        numpy.ndarray: A ``(height, width, 2)`` array with the x and y gradient
        values of a sphere of radius ``r`` centered at ``(width // 2, height // 2)``.
        Values are zero outside the circle. Where the surface normal's z component
        is zero (the rim), it is replaced by 0.1 to keep the gradient finite.
    """
    # Read data file
    calibration_data = pd.read_csv(annotation_path)

    # Only the frame size is needed; PIL exposes it without decoding the pixels.
    image_path = os.path.join(dataset_path, "probe_images", calibration_data['img_name'][0])
    with Image.open(image_path) as image:
        width, height = image.size

    # Circle center
    x = width // 2
    y = height // 2

    # Per-pixel row (i) and column (j) indices.
    i, j = np.mgrid[0:height, 0:width]

    # Pixels farther than r from the center get zero gradients.
    inside = np.sqrt((y - i) ** 2 + (x - j) ** 2) <= r

    norm_x = (j - x) / r
    norm_y = (i - y) / r
    # Clamp the radicand at 0: on the rim, rounding can make 1 - nx² - ny² a tiny
    # negative number (e.g. r=5 at offset (3, 4)), which would yield NaN gradients.
    # Outside the circle the clamped value is masked out below anyway.
    norm_z = np.sqrt(np.maximum(1 - norm_x ** 2 - norm_y ** 2, 0.0))
    norm_z[norm_z == 0] = 0.1

    # Create gradient map
    gradient_map = np.zeros((height, width, 2))
    gradient_map[:, :, 0] = np.where(inside, norm_x / norm_z, 0.0)
    gradient_map[:, :, 1] = np.where(inside, norm_y / norm_z, 0.0)

    return gradient_map