learning3d 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (115)
  1. learning3d/__init__.py +2 -0
  2. learning3d/data_utils/__init__.py +4 -0
  3. learning3d/data_utils/dataloaders.py +454 -0
  4. learning3d/data_utils/user_data.py +119 -0
  5. learning3d/examples/test_dcp.py +139 -0
  6. learning3d/examples/test_deepgmr.py +144 -0
  7. learning3d/examples/test_flownet.py +113 -0
  8. learning3d/examples/test_masknet.py +159 -0
  9. learning3d/examples/test_masknet2.py +162 -0
  10. learning3d/examples/test_pcn.py +118 -0
  11. learning3d/examples/test_pcrnet.py +120 -0
  12. learning3d/examples/test_pnlk.py +121 -0
  13. learning3d/examples/test_pointconv.py +126 -0
  14. learning3d/examples/test_pointnet.py +121 -0
  15. learning3d/examples/test_prnet.py +126 -0
  16. learning3d/examples/test_rpmnet.py +120 -0
  17. learning3d/examples/train_PointNetLK.py +240 -0
  18. learning3d/examples/train_dcp.py +249 -0
  19. learning3d/examples/train_deepgmr.py +244 -0
  20. learning3d/examples/train_flownet.py +259 -0
  21. learning3d/examples/train_masknet.py +239 -0
  22. learning3d/examples/train_pcn.py +216 -0
  23. learning3d/examples/train_pcrnet.py +228 -0
  24. learning3d/examples/train_pointconv.py +245 -0
  25. learning3d/examples/train_pointnet.py +244 -0
  26. learning3d/examples/train_prnet.py +229 -0
  27. learning3d/examples/train_rpmnet.py +228 -0
  28. learning3d/losses/__init__.py +12 -0
  29. learning3d/losses/chamfer_distance.py +51 -0
  30. learning3d/losses/classification.py +14 -0
  31. learning3d/losses/correspondence_loss.py +10 -0
  32. learning3d/losses/cuda/chamfer_distance/__init__.py +1 -0
  33. learning3d/losses/cuda/chamfer_distance/chamfer_distance.cpp +185 -0
  34. learning3d/losses/cuda/chamfer_distance/chamfer_distance.cu +209 -0
  35. learning3d/losses/cuda/chamfer_distance/chamfer_distance.py +66 -0
  36. learning3d/losses/cuda/emd_torch/pkg/emd_loss_layer.py +41 -0
  37. learning3d/losses/cuda/emd_torch/pkg/include/cuda/emd.cuh +347 -0
  38. learning3d/losses/cuda/emd_torch/pkg/include/cuda_helper.h +18 -0
  39. learning3d/losses/cuda/emd_torch/pkg/include/emd.h +54 -0
  40. learning3d/losses/cuda/emd_torch/pkg/layer/__init__.py +1 -0
  41. learning3d/losses/cuda/emd_torch/pkg/layer/emd_loss_layer.py +40 -0
  42. learning3d/losses/cuda/emd_torch/pkg/src/cuda/emd.cu +70 -0
  43. learning3d/losses/cuda/emd_torch/pkg/src/emd.cpp +1 -0
  44. learning3d/losses/cuda/emd_torch/setup.py +29 -0
  45. learning3d/losses/emd.py +16 -0
  46. learning3d/losses/frobenius_norm.py +21 -0
  47. learning3d/losses/rmse_features.py +16 -0
  48. learning3d/models/__init__.py +23 -0
  49. learning3d/models/classifier.py +41 -0
  50. learning3d/models/dcp.py +92 -0
  51. learning3d/models/deepgmr.py +165 -0
  52. learning3d/models/dgcnn.py +92 -0
  53. learning3d/models/flownet3d.py +446 -0
  54. learning3d/models/masknet.py +84 -0
  55. learning3d/models/masknet2.py +264 -0
  56. learning3d/models/pcn.py +164 -0
  57. learning3d/models/pcrnet.py +74 -0
  58. learning3d/models/pointconv.py +108 -0
  59. learning3d/models/pointnet.py +108 -0
  60. learning3d/models/pointnetlk.py +173 -0
  61. learning3d/models/pooling.py +15 -0
  62. learning3d/models/ppfnet.py +102 -0
  63. learning3d/models/prnet.py +431 -0
  64. learning3d/models/rpmnet.py +359 -0
  65. learning3d/models/segmentation.py +38 -0
  66. learning3d/ops/__init__.py +0 -0
  67. learning3d/ops/data_utils.py +45 -0
  68. learning3d/ops/invmat.py +134 -0
  69. learning3d/ops/quaternion.py +218 -0
  70. learning3d/ops/se3.py +157 -0
  71. learning3d/ops/sinc.py +229 -0
  72. learning3d/ops/so3.py +213 -0
  73. learning3d/ops/transform_functions.py +342 -0
  74. learning3d/utils/__init__.py +9 -0
  75. learning3d/utils/lib/build/lib.linux-x86_64-3.5/pointnet2_cuda.cpython-35m-x86_64-linux-gnu.so +0 -0
  76. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/ball_query.o +0 -0
  77. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/ball_query_gpu.o +0 -0
  78. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/group_points.o +0 -0
  79. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/group_points_gpu.o +0 -0
  80. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/interpolate.o +0 -0
  81. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/interpolate_gpu.o +0 -0
  82. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/pointnet2_api.o +0 -0
  83. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/sampling.o +0 -0
  84. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/sampling_gpu.o +0 -0
  85. learning3d/utils/lib/dist/pointnet2-0.0.0-py3.5-linux-x86_64.egg +0 -0
  86. learning3d/utils/lib/pointnet2.egg-info/SOURCES.txt +14 -0
  87. learning3d/utils/lib/pointnet2.egg-info/dependency_links.txt +1 -0
  88. learning3d/utils/lib/pointnet2.egg-info/top_level.txt +1 -0
  89. learning3d/utils/lib/pointnet2_modules.py +160 -0
  90. learning3d/utils/lib/pointnet2_utils.py +318 -0
  91. learning3d/utils/lib/pytorch_utils.py +236 -0
  92. learning3d/utils/lib/setup.py +23 -0
  93. learning3d/utils/lib/src/ball_query.cpp +25 -0
  94. learning3d/utils/lib/src/ball_query_gpu.cu +67 -0
  95. learning3d/utils/lib/src/ball_query_gpu.h +15 -0
  96. learning3d/utils/lib/src/cuda_utils.h +15 -0
  97. learning3d/utils/lib/src/group_points.cpp +36 -0
  98. learning3d/utils/lib/src/group_points_gpu.cu +86 -0
  99. learning3d/utils/lib/src/group_points_gpu.h +22 -0
  100. learning3d/utils/lib/src/interpolate.cpp +65 -0
  101. learning3d/utils/lib/src/interpolate_gpu.cu +233 -0
  102. learning3d/utils/lib/src/interpolate_gpu.h +36 -0
  103. learning3d/utils/lib/src/pointnet2_api.cpp +25 -0
  104. learning3d/utils/lib/src/sampling.cpp +46 -0
  105. learning3d/utils/lib/src/sampling_gpu.cu +253 -0
  106. learning3d/utils/lib/src/sampling_gpu.h +29 -0
  107. learning3d/utils/pointconv_util.py +382 -0
  108. learning3d/utils/ppfnet_util.py +244 -0
  109. learning3d/utils/svd.py +59 -0
  110. learning3d/utils/transformer.py +243 -0
  111. learning3d-0.0.1.dist-info/LICENSE +21 -0
  112. learning3d-0.0.1.dist-info/METADATA +271 -0
  113. learning3d-0.0.1.dist-info/RECORD +115 -0
  114. learning3d-0.0.1.dist-info/WHEEL +5 -0
  115. learning3d-0.0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,249 @@
1
+ import argparse
2
+ import os
3
+ import sys
4
+ import logging
5
+ import numpy
6
+ import numpy as np
7
+ import torch
8
+ import torch.utils.data
9
+ import torchvision
10
+ from torch.utils.data import DataLoader
11
+ from tensorboardX import SummaryWriter
12
+ from tqdm import tqdm
13
+
14
# Directory that contains this script.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))

# When launched from the ``examples`` folder, make the parent directory
# importable and switch the working directory to it; every relative path used
# later (e.g. ``checkpoints/...``) is therefore resolved against that parent.
if BASE_DIR.endswith('examples'):
    sys.path.append(os.path.join(BASE_DIR, os.pardir))
    os.chdir(os.path.join(BASE_DIR, os.pardir))
19
+
20
+ from learning3d.models import DGCNN, DCP
21
+ from learning3d.data_utils import RegistrationData, ModelNet40Data
22
+
23
+ def _init_(args):
24
+ if not os.path.exists('checkpoints'):
25
+ os.makedirs('checkpoints')
26
+ if not os.path.exists('checkpoints/' + args.exp_name):
27
+ os.makedirs('checkpoints/' + args.exp_name)
28
+ if not os.path.exists('checkpoints/' + args.exp_name + '/' + 'models'):
29
+ os.makedirs('checkpoints/' + args.exp_name + '/' + 'models')
30
+ os.system('cp train_dcp.py checkpoints' + '/' + args.exp_name + '/' + 'train.py.backup')
31
+
32
class IOStream:
    """Echo text to stdout and append it to a log file.

    Usable as a context manager so the log file is closed on exit.
    """

    def __init__(self, path):
        """Open ``path`` for appending (the file is created if missing)."""
        # Explicit encoding so logging arbitrary text (paths, args) does not
        # depend on the platform's locale encoding.
        self.f = open(path, 'a', encoding='utf-8')

    def cprint(self, text):
        """Print ``text`` and append it plus a newline to the log file.

        The file is flushed after each write so the log is up to date even if
        the process dies.
        """
        print(text)
        self.f.write(text + '\n')
        self.f.flush()

    def close(self):
        """Close the underlying log file."""
        self.f.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False
43
+
44
def get_transformations(igt):
    """Split a batch of rigid transforms into the forward and inverse parts.

    Following the original comments, ``igt`` maps template to source points
    (Ps = R_ba * Pt + t_ba); the returned ``ab`` pair is its inverse
    (Pt = R_ab * Ps + t_ab). Assumes ``igt`` is (B, 4, 4) — TODO confirm.

    Returns:
        ``(R_ab, translation_ab, R_ba, translation_ba)`` with rotations of
        shape (B, 3, 3) and translations of shape (B, 3, 1).
    """
    rot_ba = igt[:, :3, :3]
    trans_ba = igt[:, :3, 3:4]  # keep a trailing singleton dim -> (B, 3, 1)

    # The inverse of a rotation is its transpose; the inverse translation is
    # -R^T t.
    rot_ab = rot_ba.transpose(1, 2)
    trans_ab = -(rot_ab @ trans_ba)

    return rot_ab, trans_ab, rot_ba, trans_ba
50
+
51
def test_one_epoch(device, model, test_loader):
    """Evaluate DCP on ``test_loader`` and return the mean per-batch loss.

    The per-batch loss matches training: MSE(est_R^T @ R_ab, I) +
    MSE(est_t, t_ab), plus a 0.1-weighted cycle term computed the same way on
    the inverse (``ba``) transform.

    Args:
        device: torch device the batches are moved to.
        model: DCP model whose output dict contains ``est_R``, ``est_t``,
            ``est_R_`` and ``est_t_``.
        test_loader: iterable of ``(template, source, igt)`` batches.

    Returns:
        Mean loss over all batches.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    mse = torch.nn.functional.mse_loss
    model.eval()
    test_loss = 0.0
    count = 0
    # No gradients are needed during evaluation; skipping autograd saves memory.
    with torch.no_grad():
        for template, source, igt in tqdm(test_loader):
            R_ab, translation_ab, R_ba, translation_ba = [
                t.to(device) for t in get_transformations(igt)]
            template = template.to(device)
            source = source.to(device)

            output = model(template, source)
            # Allocate the identity on ``device``: a hard-coded ``.cuda()``
            # broke the CPU fallback selected in main().
            identity = torch.eye(3, device=device).unsqueeze(0).repeat(template.shape[0], 1, 1)
            loss_val = (mse(torch.matmul(output['est_R'].transpose(2, 1), R_ab), identity)
                        + mse(output['est_t'], translation_ab[:, :, 0]))
            cycle_loss = (mse(torch.matmul(output['est_R_'].transpose(2, 1), R_ba), identity)
                          + mse(output['est_t_'], translation_ba[:, :, 0]))
            loss_val = loss_val + cycle_loss * 0.1

            test_loss += loss_val.item()
            count += 1

    if count == 0:
        raise ValueError('test_loader yielded no batches')
    return test_loss / count
80
+
81
def test(args, model, test_loader, textio):
    """Run one evaluation pass and log the mean validation loss to ``textio``.

    ``test_one_epoch`` returns a single float; unpacking it into a
    (loss, accuracy) pair raised TypeError, so only the loss is reported.
    """
    test_loss = test_one_epoch(args.device, model, test_loader)
    textio.cprint('Validation Loss: %f' % test_loss)
84
+
85
def train_one_epoch(device, model, train_loader, optimizer):
    """Run one DCP training epoch and return the mean per-batch loss.

    Per batch the loss is MSE(est_R^T @ R_ab, I) + MSE(est_t, t_ab) plus a
    0.1-weighted cycle term on the inverse (``ba``) transform; one optimizer
    step is taken per batch.

    Args:
        device: torch device the batches are moved to.
        model: DCP model returning ``est_R``, ``est_t``, ``est_R_``, ``est_t_``.
        train_loader: iterable of ``(template, source, igt)`` batches.
        optimizer: optimizer over the model's learnable parameters.

    Returns:
        Mean training loss over the epoch.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    mse = torch.nn.functional.mse_loss
    model.train()
    train_loss = 0.0
    count = 0
    for template, source, igt in tqdm(train_loader):
        R_ab, translation_ab, R_ba, translation_ba = [
            t.to(device) for t in get_transformations(igt)]
        template = template.to(device)
        source = source.to(device)

        output = model(template, source)
        # Allocate the identity on ``device``: a hard-coded ``.cuda()`` broke
        # the CPU fallback selected in main().
        identity = torch.eye(3, device=device).unsqueeze(0).repeat(template.shape[0], 1, 1)
        loss_val = (mse(torch.matmul(output['est_R'].transpose(2, 1), R_ab), identity)
                    + mse(output['est_t'], translation_ab[:, :, 0]))
        cycle_loss = (mse(torch.matmul(output['est_R_'].transpose(2, 1), R_ba), identity)
                      + mse(output['est_t_'], translation_ba[:, :, 0]))
        loss_val = loss_val + cycle_loss * 0.1

        # forward + backward + optimize
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        train_loss += loss_val.item()
        count += 1

    if count == 0:
        raise ValueError('train_loader yielded no batches')
    return train_loss / count
120
+
121
def train(args, model, train_loader, test_loader, boardio, textio, checkpoint):
    """Train DCP from ``args.start_epoch`` to ``args.epochs`` with checkpointing.

    Every epoch writes the latest snapshot, model weights and feature-model
    weights under ``checkpoints/<exp_name>/models``. The ``best_*`` files are
    overwritten whenever the test loss improves. Losses are logged to
    TensorBoard (``boardio``) and to ``textio``.

    Args:
        args: namespace with ``optimizer``, ``start_epoch``, ``epochs``,
            ``device`` and ``exp_name``.
        model: model to train; must expose ``feature_model``.
        train_loader, test_loader: data loaders for the two phases.
        boardio: TensorBoard ``SummaryWriter``.
        textio: logger with a ``cprint`` method.
        checkpoint: dict loaded from ``--resume`` (with ``min_loss`` and
            ``optimizer`` entries), or ``None``.
    """
    learnable_params = filter(lambda p: p.requires_grad, model.parameters())
    if args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(learnable_params)
    else:
        optimizer = torch.optim.SGD(learnable_params, lr=0.1)

    best_test_loss = np.inf
    if checkpoint is not None:
        # Restore the best-loss bookkeeping too. Previously ``min_loss`` was
        # read and discarded, so the first resumed epoch always replaced the
        # best model.
        best_test_loss = checkpoint['min_loss']
        optimizer.load_state_dict(checkpoint['optimizer'])

    for epoch in range(args.start_epoch, args.epochs):
        train_loss = train_one_epoch(args.device, model, train_loader, optimizer)
        test_loss = test_one_epoch(args.device, model, test_loader)

        is_best = test_loss < best_test_loss
        if is_best:
            best_test_loss = test_loss

        # Build the snapshot every epoch so the "latest" files hold this
        # epoch. It used to be built only on improvement, which left it stale
        # or undefined.
        snap = {'epoch': epoch + 1,
                'model': model.state_dict(),
                'min_loss': best_test_loss,
                'optimizer': optimizer.state_dict()}
        if is_best:
            torch.save(snap, 'checkpoints/%s/models/best_model_snap.t7' % (args.exp_name))
            torch.save(model.state_dict(), 'checkpoints/%s/models/best_model.t7' % (args.exp_name))
            torch.save(model.feature_model.state_dict(),
                       'checkpoints/%s/models/best_ptnet_model.t7' % (args.exp_name))

        torch.save(snap, 'checkpoints/%s/models/model_snap.t7' % (args.exp_name))
        torch.save(model.state_dict(), 'checkpoints/%s/models/model.t7' % (args.exp_name))
        torch.save(model.feature_model.state_dict(),
                   'checkpoints/%s/models/ptnet_model.t7' % (args.exp_name))

        boardio.add_scalar('Train Loss', train_loss, epoch+1)
        boardio.add_scalar('Test Loss', test_loss, epoch+1)
        boardio.add_scalar('Best Test Loss', best_test_loss, epoch+1)

        textio.cprint('EPOCH:: %d, Traininig Loss: %f, Testing Loss: %f, Best Loss: %f'%(epoch+1, train_loss, test_loss, best_test_loss))
157
+
158
def options():
    """Parse command-line options for DCP training / evaluation.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    def str2bool(value):
        """Parse a boolean option value.

        ``type=bool`` is wrong for argparse: ``bool('False')`` is True.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError('expected a boolean value, got %r' % value)

    parser = argparse.ArgumentParser(description='Point Cloud Registration')
    parser.add_argument('--exp_name', type=str, default='exp_dcp', metavar='N',
                        help='Name of the experiment')
    parser.add_argument('--dataset_path', type=str, default='ModelNet40',
                        metavar='PATH', help='path to the input dataset') # like '/path/to/ModelNet40'
    # ``--eval`` alone, ``--eval True`` and ``--eval False`` all behave as expected.
    parser.add_argument('--eval', type=str2bool, nargs='?', const=True, default=False,
                        help='Train or Evaluate the network.')

    # settings for input data
    parser.add_argument('--dataset_type', default='modelnet', choices=['modelnet', 'shapenet2'],
                        metavar='DATASET', help='dataset type (default: modelnet)')
    parser.add_argument('--num_points', default=1024, type=int,
                        metavar='N', help='points in point-cloud (default: 1024)')

    # settings for PointNet
    parser.add_argument('--pointnet', default='tune', type=str, choices=['fixed', 'tune'],
                        help='train pointnet (default: tune)')
    parser.add_argument('--emb_dims', default=1024, type=int,
                        metavar='K', help='dim. of the feature vector (default: 1024)')
    parser.add_argument('--symfn', default='max', choices=['max', 'avg'],
                        help='symmetric function (default: max)')

    # settings for on training
    parser.add_argument('--seed', type=int, default=1234)
    parser.add_argument('-j', '--workers', default=4, type=int,
                        metavar='N', help='number of data loading workers (default: 4)')
    parser.add_argument('-b', '--batch_size', default=32, type=int,
                        metavar='N', help='mini-batch size (default: 32)')
    parser.add_argument('--epochs', default=200, type=int,
                        metavar='N', help='number of total epochs to run')
    parser.add_argument('--start_epoch', default=0, type=int,
                        metavar='N', help='manual epoch number (useful on restarts)')
    parser.add_argument('--optimizer', default='Adam', choices=['Adam', 'SGD'],
                        metavar='METHOD', help='name of an optimizer (default: Adam)')
    parser.add_argument('--resume', default='', type=str,
                        metavar='PATH', help='path to latest checkpoint (default: null (no-use))')
    parser.add_argument('--pretrained', default='', type=str,
                        metavar='PATH', help='path to pretrained model file (default: null (no-use))')
    parser.add_argument('--device', default='cuda:0', type=str,
                        metavar='DEVICE', help='use CUDA if available')

    args = parser.parse_args()
    return args
201
+
202
def main():
    """Entry point: seed RNGs, build data and the DGCNN+DCP model, then train or evaluate."""
    args = options()

    torch.backends.cudnn.deterministic = True
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)

    boardio = SummaryWriter(log_dir='checkpoints/' + args.exp_name)
    _init_(args)

    textio = IOStream('checkpoints/' + args.exp_name + '/run.log')
    textio.cprint(str(args))

    trainset = RegistrationData('DCP', ModelNet40Data(train=True))
    testset = RegistrationData('DCP', ModelNet40Data(train=False))
    train_loader = DataLoader(trainset, batch_size=args.batch_size, shuffle=True,
                              drop_last=True, num_workers=args.workers)
    test_loader = DataLoader(testset, batch_size=args.batch_size, shuffle=False,
                             drop_last=False, num_workers=args.workers)

    if not torch.cuda.is_available():
        args.device = 'cpu'
    args.device = torch.device(args.device)

    # Create the DGCNN feature extractor and the DCP model on top of it.
    dgcnn = DGCNN(emb_dims=args.emb_dims)
    model = DCP(feature_model=dgcnn, cycle=True)
    model = model.to(args.device)

    checkpoint = None
    if args.resume:
        if not os.path.isfile(args.resume):
            raise FileNotFoundError('resume checkpoint not found: %s' % args.resume)
        # map_location lets a GPU-saved checkpoint load on a CPU-only machine;
        # load_state_dict copies the weights onto the model's device.
        checkpoint = torch.load(args.resume, map_location='cpu')
        args.start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['model'])

    if args.pretrained:
        if not os.path.isfile(args.pretrained):
            raise FileNotFoundError('pretrained model not found: %s' % args.pretrained)
        model.load_state_dict(torch.load(args.pretrained, map_location='cpu'))
    model.to(args.device)

    try:
        if args.eval:
            test(args, model, test_loader, textio)
        else:
            train(args, model, train_loader, test_loader, boardio, textio, checkpoint)
    finally:
        # Flush and release the TensorBoard writer and the text log.
        boardio.close()
        textio.close()


if __name__ == '__main__':
    main()
@@ -0,0 +1,244 @@
1
+ import open3d as o3d
2
+ import argparse
3
+ import os
4
+ import sys
5
+ import logging
6
+ import numpy
7
+ import numpy as np
8
+ import torch
9
+ import torch.utils.data
10
+ import torchvision
11
+ from torch.utils.data import DataLoader
12
+ from tensorboardX import SummaryWriter
13
+ from tqdm import tqdm
14
+
15
+ # Only if the files are in example folder.
16
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
17
+ if BASE_DIR[-8:] == 'examples':
18
+ sys.path.append(os.path.join(BASE_DIR, os.pardir))
19
+ os.chdir(os.path.join(BASE_DIR, os.pardir))
20
+
21
+ from learning3d.models import DeepGMR
22
+ from learning3d.data_utils import RegistrationData, ModelNet40Data
23
+
24
def display_open3d(template, source, transformed_source):
    """Open an Open3D window with three point clouds.

    Colors: template red, source green, transformed source blue. Each argument
    is passed to ``o3d.utility.Vector3dVector``, so it is presumably an (N, 3)
    array — TODO confirm.
    """
    # The source gets a zero offset (a no-op shift; presumably a placeholder
    # for separating it from the other clouds).
    layers = (
        (template, [1, 0, 0]),
        (source + np.array([0, 0, 0]), [0, 1, 0]),
        (transformed_source, [0, 0, 1]),
    )
    clouds = []
    for points, color in layers:
        cloud = o3d.geometry.PointCloud()
        cloud.points = o3d.utility.Vector3dVector(points)
        cloud.paint_uniform_color(color)
        clouds.append(cloud)
    o3d.visualization.draw_geometries(clouds)
35
+
36
def rotation_error(R, R_gt):
    """Return the angle in degrees between batched rotation matrices.

    ``einsum('bij,bij->b')`` computes trace(R^T R_gt) for each batch element,
    and trace = 1 + 2 cos(theta).

    Args:
        R: predicted rotations, shape (B, 3, 3).
        R_gt: ground-truth rotations, shape (B, 3, 3).

    Returns:
        Tensor of shape (B,) with angles in [0, 180].
    """
    # ``math`` is not among this script's top-level imports.
    import math

    cos_theta = (torch.einsum('bij,bij->b', R, R_gt) - 1) / 2
    # Rounding can push the cosine slightly outside [-1, 1], where acos is NaN.
    cos_theta = torch.clamp(cos_theta, -1, 1)
    return torch.acos(cos_theta) * 180 / math.pi
40
+
41
def translation_error(t, t_gt):
    """Return the row-wise Euclidean distance between predicted and true translations."""
    difference = t - t_gt
    return difference.norm(dim=1)
43
+
44
def rmse(pts, T, T_gt):
    """Return the mean distance between ``pts`` moved by ``T`` and by ``T_gt``.

    Despite the name, this averages per-point Euclidean distances; it does
    not take a root of mean squares.

    Args:
        pts: batched points, shape (B, N, 3).
        T, T_gt: batched homogeneous transforms whose ``[:, :3, :3]`` is the
            rotation and ``[:, :3, 3]`` the translation.

    Returns:
        Tensor of shape (B,).
    """
    def moved_by(transform):
        # Row-vector convention: p' = p R^T + t.
        rotation = transform[:, :3, :3]
        translation = transform[:, :3, 3]
        return torch.matmul(pts, rotation.transpose(1, 2)) + translation.unsqueeze(1)

    gap = moved_by(T) - moved_by(T_gt)
    return gap.norm(dim=2).mean(dim=1)
48
+
49
def test_one_epoch(device, model, test_loader):
    """Evaluate DeepGMR on ``test_loader``, print error metrics and return the mean loss.

    Per batch the loss is MSE(est_T_inverse @ igt^-1, I) + MSE(est_T @ igt, I).
    Rotation/translation errors and RMSE compare ``est_T_inverse`` with
    ``igt``; the RMSE uses the first 100 template points.

    Args:
        device: torch device the batches are moved to.
        model: DeepGMR model whose output dict contains ``est_T`` and
            ``est_T_inverse``.
        test_loader: iterable of ``(template, source, igt)`` batches.

    Returns:
        Mean loss over all batches.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    mse = torch.nn.functional.mse_loss
    model.eval()
    test_loss = 0.0
    count = 0
    rotation_errors, translation_errors, rmses = [], [], []

    # No gradients are needed during evaluation; skipping autograd saves memory.
    with torch.no_grad():
        for template, source, igt in tqdm(test_loader):
            template = template.to(device)
            source = source.to(device)
            igt = igt.to(device)

            output = model(template, source)
            est_T_inverse = output['est_T_inverse']

            eye = torch.eye(4, device=igt.device).expand_as(igt)
            mse1 = mse(est_T_inverse @ torch.inverse(igt), eye)
            mse2 = mse(output['est_T'] @ igt, eye)
            loss_val = mse1 + mse2

            rotation_errors.append(rotation_error(est_T_inverse[:, :3, :3], igt[:, :3, :3]))
            translation_errors.append(translation_error(est_T_inverse[:, :3, 3], igt[:, :3, 3]))
            rmses.append(rmse(template[:, :100], est_T_inverse, igt))

            test_loss += loss_val.item()
            count += 1

    if count == 0:
        raise ValueError('test_loader yielded no batches')

    # Concatenate the per-sample errors (the last batch may be smaller) and
    # average in torch; np.mean cannot consume a list of CUDA tensors.
    mean_rotation = torch.cat(rotation_errors).mean().item()
    mean_translation = torch.cat(translation_errors).mean().item()
    mean_rmse = torch.cat(rmses).mean().item()
    print("Mean rotation error: {}, Mean translation error: {} and Mean RMSE: {}".format(
        mean_rotation, mean_translation, mean_rmse))
    return test_loss / count
83
+
84
def test(args, model, test_loader, textio):
    """Evaluate the model once and write the mean validation loss to ``textio``."""
    loss = test_one_epoch(args.device, model, test_loader)
    message = 'Validation Loss: %f' % loss
    textio.cprint(message)
87
+
88
def train_one_epoch(device, model, train_loader, optimizer):
    """Run one DeepGMR training epoch and return the mean per-batch loss.

    Per batch the loss is MSE(est_T_inverse @ igt^-1, I) + MSE(est_T @ igt, I),
    followed by one optimizer step.

    Args:
        device: torch device the batches are moved to.
        model: DeepGMR model whose output dict contains ``est_T`` and
            ``est_T_inverse``.
        train_loader: iterable of ``(template, source, igt)`` batches.
        optimizer: optimizer over the model's learnable parameters.

    Returns:
        Mean training loss over the epoch.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    # ``F`` was used here but never imported; call torch.nn.functional directly.
    mse = torch.nn.functional.mse_loss
    model.train()
    train_loss = 0.0
    count = 0
    for template, source, igt in tqdm(train_loader):
        template = template.to(device)
        source = source.to(device)
        igt = igt.to(device)

        output = model(template, source)

        eye = torch.eye(4, device=igt.device).expand_as(igt)
        mse1 = mse(output['est_T_inverse'] @ torch.inverse(igt), eye)
        mse2 = mse(output['est_T'] @ igt, eye)
        loss_val = mse1 + mse2

        # forward + backward + optimize
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        train_loss += loss_val.item()
        count += 1

    if count == 0:
        raise ValueError('train_loader yielded no batches')
    return train_loss / count
117
+
118
def train(args, model, train_loader, test_loader, boardio, textio, checkpoint):
    """Train DeepGMR from ``args.start_epoch`` to ``args.epochs`` with checkpointing.

    Every epoch writes the latest snapshot and model weights under
    ``checkpoints/<exp_name>/models``. The ``best_*`` files are overwritten
    whenever the test loss improves. Losses are logged to TensorBoard
    (``boardio``) and to ``textio``.

    Args:
        args: namespace with ``optimizer``, ``start_epoch``, ``epochs``,
            ``device`` and ``exp_name``.
        model: model to train.
        train_loader, test_loader: data loaders for the two phases.
        boardio: TensorBoard ``SummaryWriter``.
        textio: logger with a ``cprint`` method.
        checkpoint: dict loaded from ``--resume`` (with ``min_loss`` and
            ``optimizer`` entries), or ``None``.
    """
    learnable_params = filter(lambda p: p.requires_grad, model.parameters())
    if args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(learnable_params)
    else:
        optimizer = torch.optim.SGD(learnable_params, lr=0.1)

    best_test_loss = np.inf
    if checkpoint is not None:
        # Restore the best-loss bookkeeping too. Previously ``min_loss`` was
        # read and discarded, so the first resumed epoch always replaced the
        # best model.
        best_test_loss = checkpoint['min_loss']
        optimizer.load_state_dict(checkpoint['optimizer'])

    # NOTE(review): DeepGMR is constructed without a feature_model argument in
    # main(); only save one if the model actually exposes it.
    feature_model = getattr(model, 'feature_model', None)

    for epoch in range(args.start_epoch, args.epochs):
        train_loss = train_one_epoch(args.device, model, train_loader, optimizer)
        test_loss = test_one_epoch(args.device, model, test_loader)

        is_best = test_loss < best_test_loss
        if is_best:
            best_test_loss = test_loss

        # Build the snapshot every epoch so the "latest" files hold this
        # epoch. It used to be built only on improvement, which left it stale
        # or undefined.
        snap = {'epoch': epoch + 1,
                'model': model.state_dict(),
                'min_loss': best_test_loss,
                'optimizer': optimizer.state_dict()}
        if is_best:
            torch.save(snap, 'checkpoints/%s/models/best_model_snap.t7' % (args.exp_name))
            torch.save(model.state_dict(), 'checkpoints/%s/models/best_model.t7' % (args.exp_name))
            if feature_model is not None:
                torch.save(feature_model.state_dict(),
                           'checkpoints/%s/models/best_ptnet_model.t7' % (args.exp_name))

        torch.save(snap, 'checkpoints/%s/models/model_snap.t7' % (args.exp_name))
        torch.save(model.state_dict(), 'checkpoints/%s/models/model.t7' % (args.exp_name))
        if feature_model is not None:
            torch.save(feature_model.state_dict(),
                       'checkpoints/%s/models/ptnet_model.t7' % (args.exp_name))

        boardio.add_scalar('Train Loss', train_loss, epoch+1)
        boardio.add_scalar('Test Loss', test_loss, epoch+1)
        boardio.add_scalar('Best Test Loss', best_test_loss, epoch+1)

        textio.cprint('EPOCH:: %d, Traininig Loss: %f, Testing Loss: %f, Best Loss: %f'%(epoch+1, train_loss, test_loss, best_test_loss))
154
+
155
def options():
    """Parse command-line options for DeepGMR training / evaluation.

    ``use_rri`` is forced to True whenever ``nearest_neighbors > 0``.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    def str2bool(value):
        """Parse a boolean option value.

        ``type=bool`` is wrong for argparse: ``bool('False')`` is True.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError('expected a boolean value, got %r' % value)

    parser = argparse.ArgumentParser(description='Point Cloud Registration')
    parser.add_argument('--exp_name', type=str, default='exp_deepgmr', metavar='N',
                        help='Name of the experiment')
    parser.add_argument('--dataset_path', type=str, default='ModelNet40',
                        metavar='PATH', help='path to the input dataset') # like '/path/to/ModelNet40'
    parser.add_argument('--eval', type=str2bool, nargs='?', const=True, default=False,
                        help='Train or Evaluate the network.')

    # settings for input data
    parser.add_argument('--dataset_type', default='modelnet', choices=['modelnet', 'shapenet2'],
                        metavar='DATASET', help='dataset type (default: modelnet)')
    parser.add_argument('--num_points', default=1024, type=int,
                        metavar='N', help='points in point-cloud (default: 1024)')

    parser.add_argument('--nearest_neighbors', default=20, type=int,
                        metavar='K', help='No of nearest neighbors to be estimated.')
    parser.add_argument('--use_rri', type=str2bool, nargs='?', const=True, default=True,
                        help='Find nearest neighbors to estimate features from PointNet.')

    # settings for on training
    # main() seeds the RNGs from args.seed, so the option must exist.
    parser.add_argument('--seed', type=int, default=1234)
    parser.add_argument('-j', '--workers', default=4, type=int,
                        metavar='N', help='number of data loading workers (default: 4)')
    parser.add_argument('-b', '--batch_size', default=32, type=int,
                        metavar='N', help='mini-batch size (default: 32)')
    parser.add_argument('--epochs', default=200, type=int,
                        metavar='N', help='number of total epochs to run')
    parser.add_argument('--start_epoch', default=0, type=int,
                        metavar='N', help='manual epoch number (useful on restarts)')
    parser.add_argument('--optimizer', default='Adam', choices=['Adam', 'SGD'],
                        metavar='METHOD', help='name of an optimizer (default: Adam)')
    parser.add_argument('--resume', default='', type=str,
                        metavar='PATH', help='path to latest checkpoint (default: null (no-use))')
    # --pretrained and --device are registered once: argparse raises a
    # "conflicting option string" error on duplicates.
    parser.add_argument('--pretrained', default='', type=str,
                        metavar='PATH', help='path to pretrained model file (default: null (no-use))')
    parser.add_argument('--device', default='cuda:0', type=str,
                        metavar='DEVICE', help='use CUDA if available')

    args = parser.parse_args()
    if args.nearest_neighbors > 0:
        args.use_rri = True
    return args
200
+
201
def _init_(args):
    """Create ``checkpoints/<exp_name>/models`` and back up this script there.

    main() relies on this helper. The script's own source is copied to
    ``checkpoints/<exp_name>/train.py.backup``; a failed copy is reported but
    does not abort the run.

    Args:
        args: parsed command-line namespace; only ``args.exp_name`` is read.
    """
    import shutil

    exp_dir = os.path.join('checkpoints', args.exp_name)
    os.makedirs(os.path.join(exp_dir, 'models'), exist_ok=True)
    script_path = os.path.abspath(__file__)
    try:
        shutil.copyfile(script_path, os.path.join(exp_dir, 'train.py.backup'))
    except OSError as err:
        print('Could not back up %s: %s' % (script_path, err))


class IOStream:
    """Echo text to stdout and append it to a log file (used by main())."""

    def __init__(self, path):
        """Open ``path`` for appending (the file is created if missing)."""
        self.f = open(path, 'a', encoding='utf-8')

    def cprint(self, text):
        """Print ``text`` and append it plus a newline to the log file, flushing immediately."""
        print(text)
        self.f.write(text + '\n')
        self.f.flush()

    def close(self):
        """Close the underlying log file."""
        self.f.close()


def main():
    """Entry point: seed RNGs, build data and the DeepGMR model, then train or evaluate."""
    args = options()

    # Fall back to a fixed seed if the parser defines no --seed option.
    seed = getattr(args, 'seed', 1234)
    torch.backends.cudnn.deterministic = True
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)

    boardio = SummaryWriter(log_dir='checkpoints/' + args.exp_name)
    _init_(args)

    textio = IOStream('checkpoints/' + args.exp_name + '/run.log')
    textio.cprint(str(args))

    extra = {'nearest_neighbors': args.nearest_neighbors}
    trainset = RegistrationData('DeepGMR', ModelNet40Data(train=True), additional_params=extra)
    testset = RegistrationData('DeepGMR', ModelNet40Data(train=False), additional_params=extra)
    train_loader = DataLoader(trainset, batch_size=args.batch_size, shuffle=True,
                              drop_last=True, num_workers=args.workers)
    test_loader = DataLoader(testset, batch_size=args.batch_size, shuffle=False,
                             drop_last=False, num_workers=args.workers)

    if not torch.cuda.is_available():
        args.device = 'cpu'
    args.device = torch.device(args.device)

    model = DeepGMR(use_rri=args.use_rri, nearest_neighbors=args.nearest_neighbors)
    model = model.to(args.device)

    checkpoint = None
    if args.resume:
        if not os.path.isfile(args.resume):
            raise FileNotFoundError('resume checkpoint not found: %s' % args.resume)
        # map_location lets a GPU-saved checkpoint load on a CPU-only machine;
        # load_state_dict copies the weights onto the model's device.
        checkpoint = torch.load(args.resume, map_location='cpu')
        args.start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['model'])

    if args.pretrained:
        if not os.path.isfile(args.pretrained):
            raise FileNotFoundError('pretrained model not found: %s' % args.pretrained)
        model.load_state_dict(torch.load(args.pretrained, map_location='cpu'), strict=False)
    model.to(args.device)

    try:
        if args.eval:
            test(args, model, test_loader, textio)
        else:
            train(args, model, train_loader, test_loader, boardio, textio, checkpoint)
    finally:
        # Flush and release the TensorBoard writer and the text log.
        boardio.close()
        textio.close()


if __name__ == '__main__':
    main()