learning3d 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (115)
  1. learning3d/__init__.py +2 -0
  2. learning3d/data_utils/__init__.py +4 -0
  3. learning3d/data_utils/dataloaders.py +454 -0
  4. learning3d/data_utils/user_data.py +119 -0
  5. learning3d/examples/test_dcp.py +139 -0
  6. learning3d/examples/test_deepgmr.py +144 -0
  7. learning3d/examples/test_flownet.py +113 -0
  8. learning3d/examples/test_masknet.py +159 -0
  9. learning3d/examples/test_masknet2.py +162 -0
  10. learning3d/examples/test_pcn.py +118 -0
  11. learning3d/examples/test_pcrnet.py +120 -0
  12. learning3d/examples/test_pnlk.py +121 -0
  13. learning3d/examples/test_pointconv.py +126 -0
  14. learning3d/examples/test_pointnet.py +121 -0
  15. learning3d/examples/test_prnet.py +126 -0
  16. learning3d/examples/test_rpmnet.py +120 -0
  17. learning3d/examples/train_PointNetLK.py +240 -0
  18. learning3d/examples/train_dcp.py +249 -0
  19. learning3d/examples/train_deepgmr.py +244 -0
  20. learning3d/examples/train_flownet.py +259 -0
  21. learning3d/examples/train_masknet.py +239 -0
  22. learning3d/examples/train_pcn.py +216 -0
  23. learning3d/examples/train_pcrnet.py +228 -0
  24. learning3d/examples/train_pointconv.py +245 -0
  25. learning3d/examples/train_pointnet.py +244 -0
  26. learning3d/examples/train_prnet.py +229 -0
  27. learning3d/examples/train_rpmnet.py +228 -0
  28. learning3d/losses/__init__.py +12 -0
  29. learning3d/losses/chamfer_distance.py +51 -0
  30. learning3d/losses/classification.py +14 -0
  31. learning3d/losses/correspondence_loss.py +10 -0
  32. learning3d/losses/cuda/chamfer_distance/__init__.py +1 -0
  33. learning3d/losses/cuda/chamfer_distance/chamfer_distance.cpp +185 -0
  34. learning3d/losses/cuda/chamfer_distance/chamfer_distance.cu +209 -0
  35. learning3d/losses/cuda/chamfer_distance/chamfer_distance.py +66 -0
  36. learning3d/losses/cuda/emd_torch/pkg/emd_loss_layer.py +41 -0
  37. learning3d/losses/cuda/emd_torch/pkg/include/cuda/emd.cuh +347 -0
  38. learning3d/losses/cuda/emd_torch/pkg/include/cuda_helper.h +18 -0
  39. learning3d/losses/cuda/emd_torch/pkg/include/emd.h +54 -0
  40. learning3d/losses/cuda/emd_torch/pkg/layer/__init__.py +1 -0
  41. learning3d/losses/cuda/emd_torch/pkg/layer/emd_loss_layer.py +40 -0
  42. learning3d/losses/cuda/emd_torch/pkg/src/cuda/emd.cu +70 -0
  43. learning3d/losses/cuda/emd_torch/pkg/src/emd.cpp +1 -0
  44. learning3d/losses/cuda/emd_torch/setup.py +29 -0
  45. learning3d/losses/emd.py +16 -0
  46. learning3d/losses/frobenius_norm.py +21 -0
  47. learning3d/losses/rmse_features.py +16 -0
  48. learning3d/models/__init__.py +23 -0
  49. learning3d/models/classifier.py +41 -0
  50. learning3d/models/dcp.py +92 -0
  51. learning3d/models/deepgmr.py +165 -0
  52. learning3d/models/dgcnn.py +92 -0
  53. learning3d/models/flownet3d.py +446 -0
  54. learning3d/models/masknet.py +84 -0
  55. learning3d/models/masknet2.py +264 -0
  56. learning3d/models/pcn.py +164 -0
  57. learning3d/models/pcrnet.py +74 -0
  58. learning3d/models/pointconv.py +108 -0
  59. learning3d/models/pointnet.py +108 -0
  60. learning3d/models/pointnetlk.py +173 -0
  61. learning3d/models/pooling.py +15 -0
  62. learning3d/models/ppfnet.py +102 -0
  63. learning3d/models/prnet.py +431 -0
  64. learning3d/models/rpmnet.py +359 -0
  65. learning3d/models/segmentation.py +38 -0
  66. learning3d/ops/__init__.py +0 -0
  67. learning3d/ops/data_utils.py +45 -0
  68. learning3d/ops/invmat.py +134 -0
  69. learning3d/ops/quaternion.py +218 -0
  70. learning3d/ops/se3.py +157 -0
  71. learning3d/ops/sinc.py +229 -0
  72. learning3d/ops/so3.py +213 -0
  73. learning3d/ops/transform_functions.py +342 -0
  74. learning3d/utils/__init__.py +9 -0
  75. learning3d/utils/lib/build/lib.linux-x86_64-3.5/pointnet2_cuda.cpython-35m-x86_64-linux-gnu.so +0 -0
  76. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/ball_query.o +0 -0
  77. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/ball_query_gpu.o +0 -0
  78. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/group_points.o +0 -0
  79. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/group_points_gpu.o +0 -0
  80. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/interpolate.o +0 -0
  81. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/interpolate_gpu.o +0 -0
  82. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/pointnet2_api.o +0 -0
  83. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/sampling.o +0 -0
  84. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/sampling_gpu.o +0 -0
  85. learning3d/utils/lib/dist/pointnet2-0.0.0-py3.5-linux-x86_64.egg +0 -0
  86. learning3d/utils/lib/pointnet2.egg-info/SOURCES.txt +14 -0
  87. learning3d/utils/lib/pointnet2.egg-info/dependency_links.txt +1 -0
  88. learning3d/utils/lib/pointnet2.egg-info/top_level.txt +1 -0
  89. learning3d/utils/lib/pointnet2_modules.py +160 -0
  90. learning3d/utils/lib/pointnet2_utils.py +318 -0
  91. learning3d/utils/lib/pytorch_utils.py +236 -0
  92. learning3d/utils/lib/setup.py +23 -0
  93. learning3d/utils/lib/src/ball_query.cpp +25 -0
  94. learning3d/utils/lib/src/ball_query_gpu.cu +67 -0
  95. learning3d/utils/lib/src/ball_query_gpu.h +15 -0
  96. learning3d/utils/lib/src/cuda_utils.h +15 -0
  97. learning3d/utils/lib/src/group_points.cpp +36 -0
  98. learning3d/utils/lib/src/group_points_gpu.cu +86 -0
  99. learning3d/utils/lib/src/group_points_gpu.h +22 -0
  100. learning3d/utils/lib/src/interpolate.cpp +65 -0
  101. learning3d/utils/lib/src/interpolate_gpu.cu +233 -0
  102. learning3d/utils/lib/src/interpolate_gpu.h +36 -0
  103. learning3d/utils/lib/src/pointnet2_api.cpp +25 -0
  104. learning3d/utils/lib/src/sampling.cpp +46 -0
  105. learning3d/utils/lib/src/sampling_gpu.cu +253 -0
  106. learning3d/utils/lib/src/sampling_gpu.h +29 -0
  107. learning3d/utils/pointconv_util.py +382 -0
  108. learning3d/utils/ppfnet_util.py +244 -0
  109. learning3d/utils/svd.py +59 -0
  110. learning3d/utils/transformer.py +243 -0
  111. learning3d-0.0.1.dist-info/LICENSE +21 -0
  112. learning3d-0.0.1.dist-info/METADATA +271 -0
  113. learning3d-0.0.1.dist-info/RECORD +115 -0
  114. learning3d-0.0.1.dist-info/WHEEL +5 -0
  115. learning3d-0.0.1.dist-info/top_level.txt +1 -0
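
The two training scripts reproduced below exercise the packaged models and data loaders. For orientation, a minimal sketch of the MaskNet call pattern as it is used in learning3d/examples/train_masknet.py; the tensor shapes are assumed from that script and may differ in other releases:

    import torch
    from learning3d.models import MaskNet

    # Batches of N x 3 point clouds, as produced by RegistrationData in train_masknet.py.
    template = torch.rand(1, 1024, 3)
    source = torch.rand(1, 1024, 3)

    model = MaskNet().eval()
    with torch.no_grad():
        # MaskNet returns the template restricted to predicted inliers plus a per-point mask.
        masked_template, predicted_mask = model(template, source)
    print(masked_template.shape, predicted_mask.shape)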
learning3d/examples/train_flownet.py
@@ -0,0 +1,259 @@
+ #!/usr/bin/env python
+ # -*- coding: utf-8 -*-
+
+
+ from __future__ import print_function
+ import os
+ import gc
+ import argparse
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.optim as optim
+ from torch.optim.lr_scheduler import MultiStepLR
+ from learning3d.models import FlowNet3D
+ from learning3d.data_utils import SceneflowDataset
+ import numpy as np
+ from torch.utils.data import DataLoader
+ from tensorboardX import SummaryWriter
+ from tqdm import tqdm
+
+ class IOStream:
+     def __init__(self, path):
+         self.f = open(path, 'a')
+
+     def cprint(self, text):
+         print(text)
+         self.f.write(text + '\n')
+         self.f.flush()
+
+     def close(self):
+         self.f.close()
+
+
+ def _init_(args):
+     if not os.path.exists('checkpoints'):
+         os.makedirs('checkpoints')
+     if not os.path.exists('checkpoints/' + args.exp_name):
+         os.makedirs('checkpoints/' + args.exp_name)
+     if not os.path.exists('checkpoints/' + args.exp_name + '/' + 'models'):
+         os.makedirs('checkpoints/' + args.exp_name + '/' + 'models')
+
+ def weights_init(m):
+     classname = m.__class__.__name__
+     if classname.find('Conv2d') != -1:
+         nn.init.kaiming_normal_(m.weight.data)
+     if classname.find('Conv1d') != -1:
+         nn.init.kaiming_normal_(m.weight.data)
+
+ def test_one_epoch(args, net, test_loader):
+     net.eval()
+
+     total_loss = 0
+     num_examples = 0
+     for i, data in tqdm(enumerate(test_loader), total=len(test_loader), smoothing=0.9):
+         pc1, pc2, color1, color2, flow, mask1 = data
+         pc1 = pc1.cuda().transpose(2, 1).contiguous()
+         pc2 = pc2.cuda().transpose(2, 1).contiguous()
+         color1 = color1.cuda().transpose(2, 1).contiguous()
+         color2 = color2.cuda().transpose(2, 1).contiguous()
+         flow = flow.cuda()
+         mask1 = mask1.cuda().float()
+
+         batch_size = pc1.size(0)
+         num_examples += batch_size
+         flow_pred = net(pc1, pc2, color1, color2).permute(0, 2, 1)
+         loss_1 = torch.mean(mask1 * torch.sum((flow_pred - flow) * (flow_pred - flow), -1) / 2.0)
+
+         pc1, pc2 = pc1.permute(0, 2, 1), pc2.permute(0, 2, 1)
+         pc1_ = pc1 + flow_pred
+
+         total_loss += loss_1.item() * batch_size
+
+
+     return total_loss * 1.0 / num_examples
+
+
+ def train_one_epoch(args, net, train_loader, opt):
+     net.train()
+     num_examples = 0
+     total_loss = 0
+     for i, data in tqdm(enumerate(train_loader), total=len(train_loader), smoothing=0.9):
+         pc1, pc2, color1, color2, flow, mask1 = data
+         pc1 = pc1.cuda().transpose(2, 1).contiguous()
+         pc2 = pc2.cuda().transpose(2, 1).contiguous()
+         color1 = color1.cuda().transpose(2, 1).contiguous()
+         color2 = color2.cuda().transpose(2, 1).contiguous()
+         flow = flow.cuda().transpose(2, 1).contiguous()
+         mask1 = mask1.cuda().float()
+
+         batch_size = pc1.size(0)
+         opt.zero_grad()
+         num_examples += batch_size
+         flow_pred = net(pc1, pc2, color1, color2)
+         loss_1 = torch.mean(mask1 * torch.sum((flow_pred - flow) ** 2, 1) / 2.0)
+
+         pc1, pc2, flow_pred = pc1.permute(0, 2, 1), pc2.permute(0, 2, 1), flow_pred.permute(0, 2, 1)
+         pc1_ = pc1 + flow_pred
+
+         loss_1.backward()
+
+         opt.step()
+         total_loss += loss_1.item() * batch_size
+
+         # if (i+1) % 100 == 0:
+         #     print("batch: %d, mean loss: %f" % (i, total_loss / 100 / batch_size))
+         #     total_loss = 0
+     return total_loss * 1.0 / num_examples
+
+
+ def test(args, net, test_loader, boardio, textio):
+
+     test_loss = test_one_epoch(args, net, test_loader)
+
+     textio.cprint('==FINAL TEST==')
+     textio.cprint('mean test loss: %f' % test_loss)
+
+
+ def train(args, net, train_loader, test_loader, boardio, textio):
+     if args.use_sgd:
+         print("Use SGD")
+         opt = optim.SGD(net.parameters(), lr=args.lr * 100, momentum=args.momentum, weight_decay=1e-4)
+     else:
+         print("Use Adam")
+         opt = optim.Adam(net.parameters(), lr=args.lr, weight_decay=1e-4)
+     scheduler = MultiStepLR(opt, milestones=[75, 150, 200], gamma=0.1)
+
+     best_test_loss = np.inf
+     for epoch in range(args.epochs):
+         textio.cprint('==epoch: %d==' % epoch)
+         train_loss = train_one_epoch(args, net, train_loader, opt)
+         scheduler.step()  # step the LR schedule after the epoch's optimizer updates (PyTorch >= 1.1)
+         textio.cprint('mean train EPE loss: %f' % train_loss)
+
+         test_loss = test_one_epoch(args, net, test_loader)
+         textio.cprint('mean test EPE loss: %f' % test_loss)
+
+         if best_test_loss >= test_loss:
+             best_test_loss = test_loss
+             textio.cprint('best test loss till now: %f' % test_loss)
+             if torch.cuda.device_count() > 1:
+                 torch.save(net.module.state_dict(), 'checkpoints/%s/models/model.best.t7' % args.exp_name)
+             else:
+                 torch.save(net.state_dict(), 'checkpoints/%s/models/model.best.t7' % args.exp_name)
+
+         boardio.add_scalar('Train Loss', train_loss, epoch + 1)
+         boardio.add_scalar('Test Loss', test_loss, epoch + 1)
+         boardio.add_scalar('Best Test Loss', best_test_loss, epoch + 1)
+
+         if torch.cuda.device_count() > 1:
+             torch.save(net.module.state_dict(), 'checkpoints/%s/models/model.%d.t7' % (args.exp_name, epoch))
+         else:
+             torch.save(net.state_dict(), 'checkpoints/%s/models/model.%d.t7' % (args.exp_name, epoch))
+         gc.collect()
+
+
+
+ def main():
+     parser = argparse.ArgumentParser(description='Point Cloud Registration')
+     parser.add_argument('--exp_name', type=str, default='exp_flownet', metavar='N',
+                         help='Name of the experiment')
+     parser.add_argument('--model', type=str, default='flownet', metavar='N',
+                         choices=['flownet'],
+                         help='Model to use, [flownet]')
+     parser.add_argument('--emb_dims', type=int, default=512, metavar='N',
+                         help='Dimension of embeddings')
+     parser.add_argument('--num_points', type=int, default=2048,
+                         help='Point Number [default: 2048]')
+     parser.add_argument('--dropout', type=float, default=0.5, metavar='N',
+                         help='Dropout ratio in transformer')
+     parser.add_argument('--batch_size', type=int, default=16, metavar='batch_size',
+                         help='Size of training batch')
+     parser.add_argument('--test_batch_size', type=int, default=10, metavar='batch_size',
+                         help='Size of test batch')
+     parser.add_argument('--epochs', type=int, default=250, metavar='N',
+                         help='number of epochs to train')
+     parser.add_argument('--use_sgd', action='store_true', default=True,
+                         help='Use SGD')
+     parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
+                         help='learning rate (default: 0.001, 0.1 if using sgd)')
+     parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
+                         help='SGD momentum (default: 0.9)')
+     parser.add_argument('--no_cuda', action='store_true', default=False,
+                         help='disables CUDA training')
+     parser.add_argument('--seed', type=int, default=1234, metavar='S',
+                         help='random seed (default: 1234)')
+     parser.add_argument('--eval', action='store_true', default=False,
+                         help='evaluate the model')
+     parser.add_argument('--cycle', type=bool, default=False, metavar='N',
+                         help='Whether to use cycle consistency')
+     parser.add_argument('--gaussian_noise', type=bool, default=False, metavar='N',
+                         help='Whether to add gaussian noise')
+     parser.add_argument('--unseen', type=bool, default=False, metavar='N',
+                         help='Whether to test on unseen category')
+     parser.add_argument('--dataset', type=str, default='SceneflowDataset',
+                         choices=['SceneflowDataset'], metavar='N',
+                         help='dataset to use')
+     parser.add_argument('--dataset_path', type=str, default='data_processed_maxcut_35_20k_2k_8192', metavar='N',
+                         help='path to the dataset')
+     parser.add_argument('--model_path', type=str, default='', metavar='N',
+                         help='Pretrained model path')
+     parser.add_argument('--pretrained', type=str, default='', metavar='N',
+                         help='Pretrained model path')
+
+     args = parser.parse_args()
+     os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+     # CUDA settings
+     torch.backends.cudnn.deterministic = True
+     torch.manual_seed(args.seed)
+     torch.cuda.manual_seed_all(args.seed)
+     np.random.seed(args.seed)
+
+     boardio = SummaryWriter(log_dir='checkpoints/' + args.exp_name)
+     _init_(args)
+
+     textio = IOStream('checkpoints/' + args.exp_name + '/run.log')
+     textio.cprint(str(args))
+
+     if args.dataset == 'SceneflowDataset':
+         train_loader = DataLoader(
+             SceneflowDataset(npoints=args.num_points, partition='train'),
+             batch_size=args.batch_size, shuffle=True, drop_last=True)
+         test_loader = DataLoader(
+             SceneflowDataset(npoints=args.num_points, partition='test'),
+             batch_size=args.test_batch_size, shuffle=False, drop_last=False)
+     else:
+         raise Exception("not implemented")
+
+     if args.model == 'flownet':
+         net = FlowNet3D().cuda()
+         net.apply(weights_init)
+         if args.pretrained:
+             net.load_state_dict(torch.load(args.pretrained), strict=False)
+             print("Pretrained Model Loaded Successfully!")
+         if args.eval:
+             if args.model_path == '':
+                 model_path = 'checkpoints' + '/' + args.exp_name + '/models/model.best.t7'
+             else:
+                 model_path = args.model_path
+             print(model_path)
+             if not os.path.exists(model_path):
+                 print("can't find pretrained model")
+                 return
+             net.load_state_dict(torch.load(model_path), strict=False)
+         if torch.cuda.device_count() > 1:
+             net = nn.DataParallel(net)
+             print("Let's use", torch.cuda.device_count(), "GPUs!")
+     else:
+         raise Exception('Not implemented')
+     if args.eval:
+         test(args, net, test_loader, boardio, textio)
+     else:
+         train(args, net, train_loader, test_loader, boardio, textio)
+
+
+     print('FINISH')
+     # boardio.close()
+
+
+ if __name__ == '__main__':
+     main()
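
For reference, the quantity optimized by train_one_epoch above is a masked, halved squared end-point error (the log messages call it an EPE loss, but no square root is taken). A standalone sketch of the same computation; the helper name is chosen here for illustration:

    import torch

    def masked_flow_loss(flow_pred, flow_gt, mask):
        # flow_pred, flow_gt: (B, 3, N) scene-flow tensors; mask: (B, N) validity mask.
        # Mirrors loss_1 in train_one_epoch: mean over points of mask * ||pred - gt||^2 / 2.
        per_point = torch.sum((flow_pred - flow_gt) ** 2, dim=1) / 2.0
        return torch.mean(mask * per_point)

    flow_pred = torch.rand(2, 3, 2048)
    flow_gt = torch.rand(2, 3, 2048)
    mask = torch.ones(2, 2048)
    print(masked_flow_loss(flow_pred, flow_gt, mask))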
learning3d/examples/train_masknet.py
@@ -0,0 +1,239 @@
+ import argparse
+ import os
+ import sys
+ import logging
+ import numpy
+ import numpy as np
+ import torch
+ import torch.utils.data
+ import torchvision
+ from torch.utils.data import DataLoader
+ from tensorboardX import SummaryWriter
+ from tqdm import tqdm
+
+ # Only needed if this script is run from the examples folder.
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
+ if BASE_DIR[-8:] == 'examples':
+     sys.path.append(os.path.join(BASE_DIR, os.pardir))
+     os.chdir(os.path.join(BASE_DIR, os.pardir))
+
+ from learning3d.models import MaskNet
+ from learning3d.data_utils import RegistrationData, ModelNet40Data
+
+ def _init_(args):
+     if not os.path.exists('checkpoints'):
+         os.makedirs('checkpoints')
+     if not os.path.exists('checkpoints/' + args.exp_name):
+         os.makedirs('checkpoints/' + args.exp_name)
+     if not os.path.exists('checkpoints/' + args.exp_name + '/' + 'models'):
+         os.makedirs('checkpoints/' + args.exp_name + '/' + 'models')
+     os.system('cp train.py checkpoints' + '/' + args.exp_name + '/' + 'train.py.backup')
+     os.system('cp learning3d/models/masknet.py checkpoints' + '/' + args.exp_name + '/' + 'masknet.py.backup')
+     os.system('cp learning3d/data_utils/dataloaders.py checkpoints' + '/' + args.exp_name + '/' + 'dataloaders.py.backup')
+
+
+ class IOStream:
+     def __init__(self, path):
+         self.f = open(path, 'a')
+
+     def cprint(self, text):
+         print(text)
+         self.f.write(text + '\n')
+         self.f.flush()
+
+     def close(self):
+         self.f.close()
+
+ def test_one_epoch(args, model, test_loader):
+     model.eval()
+     test_loss = 0.0
+     pred = 0.0
+     count = 0
+     for i, data in enumerate(tqdm(test_loader)):
+         template, source, igt, gt_mask = data
+
+         template = template.to(args.device)
+         source = source.to(args.device)
+         igt = igt.to(args.device)  # [source] = [igt]*[template]
+         gt_mask = gt_mask.to(args.device)
+
+         masked_template, predicted_mask = model(template, source)
+
+         if args.loss_fn == 'mse':
+             loss_mask = torch.nn.functional.mse_loss(predicted_mask, gt_mask)
+         elif args.loss_fn == 'bce':
+             loss_mask = torch.nn.BCELoss()(predicted_mask, gt_mask)
+
+         test_loss += loss_mask.item()
+         count += 1
+
+     test_loss = float(test_loss) / count
+     return test_loss
+
+ def test(args, model, test_loader, textio):
+     test_loss = test_one_epoch(args, model, test_loader)
+     textio.cprint('Validation Loss: %f' % test_loss)
+
+ def train_one_epoch(args, model, train_loader, optimizer):
+     model.train()
+     train_loss = 0.0
+     pred = 0.0
+     count = 0
+     for i, data in enumerate(tqdm(train_loader)):
+         template, source, igt, gt_mask = data
+
+         template = template.to(args.device)
+         source = source.to(args.device)
+         igt = igt.to(args.device)  # [source] = [igt]*[template]
+         gt_mask = gt_mask.to(args.device)
+
+         masked_template, predicted_mask = model(template, source)
+
+         if args.loss_fn == 'mse':
+             loss_mask = torch.nn.functional.mse_loss(predicted_mask, gt_mask)
+         elif args.loss_fn == 'bce':
+             loss_mask = torch.nn.BCELoss()(predicted_mask, gt_mask)
+
+         # forward + backward + optimize
+         optimizer.zero_grad()
+         loss_mask.backward()
+         optimizer.step()
+
+         train_loss += loss_mask.item()
+         count += 1
+
+     train_loss = float(train_loss) / count
+     return train_loss
+
+ def train(args, model, train_loader, test_loader, boardio, textio, checkpoint):
+     learnable_params = filter(lambda p: p.requires_grad, model.parameters())
+     if args.optimizer == 'Adam':
+         optimizer = torch.optim.Adam(learnable_params, lr=0.0001)
+     else:
+         optimizer = torch.optim.SGD(learnable_params, lr=0.1)
+
+     if checkpoint is not None:
+         best_test_loss = checkpoint['min_loss']  # resume the best loss recorded in the snapshot
+         optimizer.load_state_dict(checkpoint['optimizer'])
+     else:
+         best_test_loss = np.inf
+
+     for epoch in range(args.start_epoch, args.epochs):
+         train_loss = train_one_epoch(args, model, train_loader, optimizer)
+         test_loss = test_one_epoch(args, model, test_loader)
+
+         if test_loss < best_test_loss:
+             best_test_loss = test_loss
+
+             snap = {'epoch': epoch + 1,
+                     'model': model.state_dict(),
+                     'min_loss': best_test_loss,
+                     'optimizer': optimizer.state_dict()}
+             torch.save(snap, 'checkpoints/%s/models/best_model_snap.t7' % (args.exp_name))
+             torch.save(model.state_dict(), 'checkpoints/%s/models/best_model.t7' % (args.exp_name))
+
+         snap = {'epoch': epoch + 1,
+                 'model': model.state_dict(),
+                 'min_loss': best_test_loss,
+                 'optimizer': optimizer.state_dict()}
+         torch.save(snap, 'checkpoints/%s/models/model_snap.t7' % (args.exp_name))
+         torch.save(model.state_dict(), 'checkpoints/%s/models/model.t7' % (args.exp_name))
+
+         boardio.add_scalar('Train_Loss', train_loss, epoch + 1)
+         boardio.add_scalar('Test_Loss', test_loss, epoch + 1)
+         boardio.add_scalar('Best_Test_Loss', best_test_loss, epoch + 1)
+
+         textio.cprint('EPOCH:: %d, Training Loss: %f, Testing Loss: %f, Best Loss: %f' % (epoch + 1, train_loss, test_loss, best_test_loss))
+
+ def options():
+     parser = argparse.ArgumentParser(description='MaskNet: A Fully-Convolutional Network For Inlier Estimation (Training)')
+     parser.add_argument('--exp_name', type=str, default='exp_masknet', metavar='N',
+                         help='Name of the experiment')
+     parser.add_argument('--eval', type=bool, default=False, help='Train or Evaluate the network.')
+
+     # settings for input data
+     parser.add_argument('--num_points', default=1024, type=int,
+                         metavar='N', help='points in point-cloud (default: 1024)')
+     parser.add_argument('--partial_source', default=True, type=bool,
+                         help='create partial source point cloud in dataset.')
+     parser.add_argument('--noise', default=False, type=bool,
+                         help='Add noise in source point clouds.')
+     parser.add_argument('--outliers', default=False, type=bool,
+                         help='Add outliers to template point cloud.')
+
+     # settings for training
+     parser.add_argument('--seed', type=int, default=1234)
+     parser.add_argument('-j', '--workers', default=4, type=int,
+                         metavar='N', help='number of data loading workers (default: 4)')
+     parser.add_argument('-b', '--batch_size', default=32, type=int,
+                         metavar='N', help='mini-batch size (default: 32)')
+     parser.add_argument('--test_batch_size', default=8, type=int,
+                         metavar='N', help='test-mini-batch size (default: 8)')
+     parser.add_argument('--unseen', default=False, type=bool,
+                         help='Use first 20 categories for training and last 20 for testing')
+     parser.add_argument('--epochs', default=500, type=int,
+                         metavar='N', help='number of total epochs to run')
+     parser.add_argument('--start_epoch', default=0, type=int,
+                         metavar='N', help='manual epoch number (useful on restarts)')
+     parser.add_argument('--optimizer', default='Adam', choices=['Adam', 'SGD'],
+                         metavar='METHOD', help='name of an optimizer (default: Adam)')
+     parser.add_argument('--resume', default='', type=str,
+                         metavar='PATH', help='path to latest checkpoint (default: null (no-use))')
+     parser.add_argument('--pretrained', default='', type=str,
+                         metavar='PATH', help='path to pretrained model file (default: null (no-use))')
+     parser.add_argument('--device', default='cuda:0', type=str,
+                         metavar='DEVICE', help='use CUDA if available')
+     parser.add_argument('--loss_fn', default='mse', type=str, choices=['mse', 'bce'])
+
+     args = parser.parse_args()
+     return args
+
+ def main():
+     args = options()
+
+     torch.backends.cudnn.deterministic = True
+     torch.manual_seed(args.seed)
+     torch.cuda.manual_seed_all(args.seed)
+     np.random.seed(args.seed)
+
+     boardio = SummaryWriter(log_dir='checkpoints/' + args.exp_name)
+     _init_(args)
+
+     textio = IOStream('checkpoints/' + args.exp_name + '/run.log')
+     textio.cprint(str(args))
+
+     trainset = RegistrationData(ModelNet40Data(train=True, num_points=args.num_points, unseen=args.unseen),
+                                 partial_source=args.partial_source, noise=args.noise, outliers=args.outliers,
+                                 additional_params={'use_masknet': True})
+     testset = RegistrationData(ModelNet40Data(train=False, num_points=args.num_points, unseen=args.unseen),
+                                partial_source=args.partial_source, noise=args.noise, outliers=args.outliers,
+                                additional_params={'use_masknet': True})
+     train_loader = DataLoader(trainset, batch_size=args.batch_size, shuffle=True, drop_last=True, num_workers=args.workers)
+     test_loader = DataLoader(testset, batch_size=args.test_batch_size, shuffle=False, drop_last=False, num_workers=args.workers)
+
+     if not torch.cuda.is_available():
+         args.device = 'cpu'
+     args.device = torch.device(args.device)
+
+     model = MaskNet()
+     model = model.to(args.device)
+
+     checkpoint = None
+     if args.resume:
+         assert os.path.isfile(args.resume)
+         checkpoint = torch.load(args.resume)
+         args.start_epoch = checkpoint['epoch']
+         model.load_state_dict(checkpoint['model'])
+
+     if args.pretrained:
+         assert os.path.isfile(args.pretrained)
+         model.load_state_dict(torch.load(args.pretrained, map_location='cpu'))
+     model.to(args.device)
+
+     if args.eval:
+         test(args, model, test_loader, textio)
+     else:
+         train(args, model, train_loader, test_loader, boardio, textio, checkpoint)
+
+ if __name__ == '__main__':
+     main()
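
The snapshots written by train() above store the epoch, model weights, best loss, and optimizer state under fixed keys, so a run can be inspected or resumed outside the script. A minimal sketch, assuming the default --exp_name (exp_masknet) and a snapshot already produced by this script:

    import torch
    from learning3d.models import MaskNet

    # 'best_model_snap.t7' is saved by train() with keys: epoch, model, min_loss, optimizer.
    snap = torch.load('checkpoints/exp_masknet/models/best_model_snap.t7', map_location='cpu')

    model = MaskNet()
    model.load_state_dict(snap['model'])
    model.eval()
    print('epoch:', snap['epoch'], 'best validation loss:', snap['min_loss'])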