learning3d 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
  Files changed (115)
  1. learning3d/__init__.py +2 -0
  2. learning3d/data_utils/__init__.py +4 -0
  3. learning3d/data_utils/dataloaders.py +454 -0
  4. learning3d/data_utils/user_data.py +119 -0
  5. learning3d/examples/test_dcp.py +139 -0
  6. learning3d/examples/test_deepgmr.py +144 -0
  7. learning3d/examples/test_flownet.py +113 -0
  8. learning3d/examples/test_masknet.py +159 -0
  9. learning3d/examples/test_masknet2.py +162 -0
  10. learning3d/examples/test_pcn.py +118 -0
  11. learning3d/examples/test_pcrnet.py +120 -0
  12. learning3d/examples/test_pnlk.py +121 -0
  13. learning3d/examples/test_pointconv.py +126 -0
  14. learning3d/examples/test_pointnet.py +121 -0
  15. learning3d/examples/test_prnet.py +126 -0
  16. learning3d/examples/test_rpmnet.py +120 -0
  17. learning3d/examples/train_PointNetLK.py +240 -0
  18. learning3d/examples/train_dcp.py +249 -0
  19. learning3d/examples/train_deepgmr.py +244 -0
  20. learning3d/examples/train_flownet.py +259 -0
  21. learning3d/examples/train_masknet.py +239 -0
  22. learning3d/examples/train_pcn.py +216 -0
  23. learning3d/examples/train_pcrnet.py +228 -0
  24. learning3d/examples/train_pointconv.py +245 -0
  25. learning3d/examples/train_pointnet.py +244 -0
  26. learning3d/examples/train_prnet.py +229 -0
  27. learning3d/examples/train_rpmnet.py +228 -0
  28. learning3d/losses/__init__.py +12 -0
  29. learning3d/losses/chamfer_distance.py +51 -0
  30. learning3d/losses/classification.py +14 -0
  31. learning3d/losses/correspondence_loss.py +10 -0
  32. learning3d/losses/cuda/chamfer_distance/__init__.py +1 -0
  33. learning3d/losses/cuda/chamfer_distance/chamfer_distance.cpp +185 -0
  34. learning3d/losses/cuda/chamfer_distance/chamfer_distance.cu +209 -0
  35. learning3d/losses/cuda/chamfer_distance/chamfer_distance.py +66 -0
  36. learning3d/losses/cuda/emd_torch/pkg/emd_loss_layer.py +41 -0
  37. learning3d/losses/cuda/emd_torch/pkg/include/cuda/emd.cuh +347 -0
  38. learning3d/losses/cuda/emd_torch/pkg/include/cuda_helper.h +18 -0
  39. learning3d/losses/cuda/emd_torch/pkg/include/emd.h +54 -0
  40. learning3d/losses/cuda/emd_torch/pkg/layer/__init__.py +1 -0
  41. learning3d/losses/cuda/emd_torch/pkg/layer/emd_loss_layer.py +40 -0
  42. learning3d/losses/cuda/emd_torch/pkg/src/cuda/emd.cu +70 -0
  43. learning3d/losses/cuda/emd_torch/pkg/src/emd.cpp +1 -0
  44. learning3d/losses/cuda/emd_torch/setup.py +29 -0
  45. learning3d/losses/emd.py +16 -0
  46. learning3d/losses/frobenius_norm.py +21 -0
  47. learning3d/losses/rmse_features.py +16 -0
  48. learning3d/models/__init__.py +23 -0
  49. learning3d/models/classifier.py +41 -0
  50. learning3d/models/dcp.py +92 -0
  51. learning3d/models/deepgmr.py +165 -0
  52. learning3d/models/dgcnn.py +92 -0
  53. learning3d/models/flownet3d.py +446 -0
  54. learning3d/models/masknet.py +84 -0
  55. learning3d/models/masknet2.py +264 -0
  56. learning3d/models/pcn.py +164 -0
  57. learning3d/models/pcrnet.py +74 -0
  58. learning3d/models/pointconv.py +108 -0
  59. learning3d/models/pointnet.py +108 -0
  60. learning3d/models/pointnetlk.py +173 -0
  61. learning3d/models/pooling.py +15 -0
  62. learning3d/models/ppfnet.py +102 -0
  63. learning3d/models/prnet.py +431 -0
  64. learning3d/models/rpmnet.py +359 -0
  65. learning3d/models/segmentation.py +38 -0
  66. learning3d/ops/__init__.py +0 -0
  67. learning3d/ops/data_utils.py +45 -0
  68. learning3d/ops/invmat.py +134 -0
  69. learning3d/ops/quaternion.py +218 -0
  70. learning3d/ops/se3.py +157 -0
  71. learning3d/ops/sinc.py +229 -0
  72. learning3d/ops/so3.py +213 -0
  73. learning3d/ops/transform_functions.py +342 -0
  74. learning3d/utils/__init__.py +9 -0
  75. learning3d/utils/lib/build/lib.linux-x86_64-3.5/pointnet2_cuda.cpython-35m-x86_64-linux-gnu.so +0 -0
  76. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/ball_query.o +0 -0
  77. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/ball_query_gpu.o +0 -0
  78. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/group_points.o +0 -0
  79. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/group_points_gpu.o +0 -0
  80. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/interpolate.o +0 -0
  81. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/interpolate_gpu.o +0 -0
  82. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/pointnet2_api.o +0 -0
  83. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/sampling.o +0 -0
  84. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/sampling_gpu.o +0 -0
  85. learning3d/utils/lib/dist/pointnet2-0.0.0-py3.5-linux-x86_64.egg +0 -0
  86. learning3d/utils/lib/pointnet2.egg-info/SOURCES.txt +14 -0
  87. learning3d/utils/lib/pointnet2.egg-info/dependency_links.txt +1 -0
  88. learning3d/utils/lib/pointnet2.egg-info/top_level.txt +1 -0
  89. learning3d/utils/lib/pointnet2_modules.py +160 -0
  90. learning3d/utils/lib/pointnet2_utils.py +318 -0
  91. learning3d/utils/lib/pytorch_utils.py +236 -0
  92. learning3d/utils/lib/setup.py +23 -0
  93. learning3d/utils/lib/src/ball_query.cpp +25 -0
  94. learning3d/utils/lib/src/ball_query_gpu.cu +67 -0
  95. learning3d/utils/lib/src/ball_query_gpu.h +15 -0
  96. learning3d/utils/lib/src/cuda_utils.h +15 -0
  97. learning3d/utils/lib/src/group_points.cpp +36 -0
  98. learning3d/utils/lib/src/group_points_gpu.cu +86 -0
  99. learning3d/utils/lib/src/group_points_gpu.h +22 -0
  100. learning3d/utils/lib/src/interpolate.cpp +65 -0
  101. learning3d/utils/lib/src/interpolate_gpu.cu +233 -0
  102. learning3d/utils/lib/src/interpolate_gpu.h +36 -0
  103. learning3d/utils/lib/src/pointnet2_api.cpp +25 -0
  104. learning3d/utils/lib/src/sampling.cpp +46 -0
  105. learning3d/utils/lib/src/sampling_gpu.cu +253 -0
  106. learning3d/utils/lib/src/sampling_gpu.h +29 -0
  107. learning3d/utils/pointconv_util.py +382 -0
  108. learning3d/utils/ppfnet_util.py +244 -0
  109. learning3d/utils/svd.py +59 -0
  110. learning3d/utils/transformer.py +243 -0
  111. learning3d-0.0.1.dist-info/LICENSE +21 -0
  112. learning3d-0.0.1.dist-info/METADATA +271 -0
  113. learning3d-0.0.1.dist-info/RECORD +115 -0
  114. learning3d-0.0.1.dist-info/WHEEL +5 -0
  115. learning3d-0.0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,244 @@
1
+ import argparse
2
+ import os
3
+ import sys
4
+ import logging
5
+ import numpy
6
+ import numpy as np
7
+ import torch
8
+ import torch.utils.data
9
+ import torchvision
10
+ from torch.utils.data import DataLoader
11
+ from tensorboardX import SummaryWriter
12
+ from tqdm import tqdm
13
+
14
# When the script lives in the ``examples`` folder, put that folder's parent on
# sys.path and make it the working directory, so relative paths used later
# (e.g. 'checkpoints/...') resolve against the parent directory.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
_IN_EXAMPLES_DIR = BASE_DIR.endswith('examples')
if _IN_EXAMPLES_DIR:
    _PARENT_DIR = os.path.join(BASE_DIR, os.pardir)
    sys.path.append(_PARENT_DIR)
    os.chdir(_PARENT_DIR)
19
+
20
+ from learning3d.models import PointNet
21
+ from learning3d.models import Classifier
22
+ from learning3d.data_utils import ClassificationData, ModelNet40Data
23
+
24
+ def _init_(args):
25
+ if not os.path.exists('checkpoints'):
26
+ os.makedirs('checkpoints')
27
+ if not os.path.exists('checkpoints/' + args.exp_name):
28
+ os.makedirs('checkpoints/' + args.exp_name)
29
+ if not os.path.exists('checkpoints/' + args.exp_name + '/' + 'models'):
30
+ os.makedirs('checkpoints/' + args.exp_name + '/' + 'models')
31
+ os.system('cp main.py checkpoints' + '/' + args.exp_name + '/' + 'main.py.backup')
32
+ os.system('cp model.py checkpoints' + '/' + args.exp_name + '/' + 'model.py.backup')
33
+
34
+
35
class IOStream:
    """Echo each message to stdout and append it to a log file.

    The file is opened in append mode, so repeated runs with the same path
    accumulate in one log.
    """

    def __init__(self, path):
        # 'a' keeps whatever earlier runs already wrote to this path.
        self.f = open(path, 'a')

    def cprint(self, text):
        """Print ``text`` and append it, newline-terminated, to the log file."""
        print(text)
        line = text + '\n'
        self.f.write(line)
        # Flush right away so the log can be followed while training runs.
        self.f.flush()

    def close(self):
        """Close the underlying log file."""
        self.f.close()
46
+
47
def test_one_epoch(device, model, test_loader):
    """Evaluate the classifier for one pass over ``test_loader``.

    Args:
        device: torch device every batch is moved to.
        model: classifier; its output is reduced over ``dim=1``, so it is
            assumed to be (batch, num_classes) scores -- confirm against Classifier.
        test_loader: iterable of ``(points, target)`` batches. ``target`` is
            indexed as ``target[:, 0]``, so it is expected to carry one label
            per row in its first column.

    Returns:
        ``(mean loss per sample, accuracy)`` over the whole loader.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    model.eval()
    test_loss = 0.0
    correct = 0
    count = 0
    # Evaluation never calls backward(); no_grad avoids building and holding
    # an autograd graph for every batch.
    with torch.no_grad():
        for points, target in tqdm(test_loader):
            target = target[:, 0]

            points = points.to(device)
            target = target.to(device)

            output = model(points)
            # reduction='sum' is the supported spelling of the deprecated
            # size_average=False (summed, not averaged, per batch).
            loss_val = torch.nn.functional.nll_loss(
                torch.nn.functional.log_softmax(output, dim=1), target, reduction='sum')

            test_loss += loss_val.item()
            count += output.size(0)
            correct += (output.argmax(dim=1) == target).sum().item()

    if count == 0:
        raise ValueError('test_loader yielded no samples')
    test_loss = float(test_loss) / count
    accuracy = float(correct) / count
    return test_loss, accuracy
74
+
75
def test(args, model, test_loader, textio):
    """Evaluate ``model`` once on ``test_loader`` and log loss and accuracy.

    Args:
        args: parsed options; only ``args.device`` is read.
        model: classifier forwarded to ``test_one_epoch``.
        test_loader: evaluation batches.
        textio: logger exposing ``cprint``.
    """
    val_loss, val_acc = test_one_epoch(args.device, model, test_loader)
    message = 'Validation Loss: %f & Validation Accuracy: %f' % (val_loss, val_acc)
    textio.cprint(message)
78
+
79
def train_one_epoch(device, model, train_loader, optimizer):
    """Run one optimisation pass of the classifier over ``train_loader``.

    Args:
        device: torch device every batch is moved to.
        model: classifier; its output is reduced over ``dim=1``, so it is
            assumed to be (batch, num_classes) scores -- confirm against Classifier.
        train_loader: iterable of ``(points, target)`` batches; the label is
            read from ``target[:, 0]``.
        optimizer: torch optimizer stepped once per batch.

    Returns:
        ``(mean loss per sample, accuracy)`` over the epoch.

    Raises:
        ValueError: if ``train_loader`` yields no samples (e.g. ``drop_last=True``
            with fewer samples than one batch).
    """
    model.train()
    train_loss = 0.0
    correct = 0
    count = 0
    for points, target in tqdm(train_loader):
        target = target[:, 0]

        points = points.to(device)
        target = target.to(device)

        output = model(points)
        # Summed (not averaged) batch loss, as before; reduction='sum' replaces
        # the deprecated size_average=False.
        loss_val = torch.nn.functional.nll_loss(
            torch.nn.functional.log_softmax(output, dim=1), target, reduction='sum')

        # forward + backward + optimize
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        train_loss += loss_val.item()
        count += output.size(0)
        correct += (output.argmax(dim=1) == target).sum().item()

    if count == 0:
        raise ValueError('train_loader yielded no samples')
    train_loss = float(train_loss) / count
    accuracy = float(correct) / count
    return train_loss, accuracy
112
+
113
def train(args, model, train_loader, test_loader, boardio, textio, checkpoint):
    """Train the classifier, checkpointing the latest and the best epochs.

    Each epoch runs one training pass and one evaluation pass, then:
      * writes ``model_snap.t7`` / ``model.t7`` / ``ptnet_model.t7`` with the
        latest state;
      * writes the ``best_*`` variants when the test loss improves on the best
        seen so far;
      * logs losses and accuracies to TensorBoard (``boardio``) and ``textio``.

    Args:
        args: parsed options (``optimizer``, ``start_epoch``, ``epochs``,
            ``device``, ``exp_name`` are read).
        model: classifier; ``model.feature_model`` is saved separately.
        train_loader, test_loader: data for the two passes.
        boardio: TensorBoard SummaryWriter.
        textio: logger exposing ``cprint``.
        checkpoint: dict loaded from a ``*_snap.t7`` file when resuming, else None.
    """
    learnable_params = filter(lambda p: p.requires_grad, model.parameters())
    if args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(learnable_params)
    else:
        optimizer = torch.optim.SGD(learnable_params, lr=0.1)

    best_test_loss = np.inf
    if checkpoint is not None:
        # Restore the best-so-far threshold as well. Otherwise the first epoch
        # after a restart overwrites the best_* files even when it is worse.
        best_test_loss = checkpoint['min_loss']
        optimizer.load_state_dict(checkpoint['optimizer'])

    models_dir = 'checkpoints/%s/models' % (args.exp_name)

    for epoch in range(args.start_epoch, args.epochs):
        train_loss, train_accuracy = train_one_epoch(args.device, model, train_loader, optimizer)
        test_loss, test_accuracy = test_one_epoch(args.device, model, test_loader)

        is_best = test_loss < best_test_loss
        if is_best:
            best_test_loss = test_loss

        # Build the snapshot every epoch so model_snap.t7 holds the latest
        # state. It must exist even when this epoch is not a new best.
        snap = {'epoch': epoch + 1,
                'model': model.state_dict(),
                'min_loss': best_test_loss,
                'optimizer': optimizer.state_dict(), }

        if is_best:
            torch.save(snap, models_dir + '/best_model_snap.t7')
            torch.save(model.state_dict(), models_dir + '/best_model.t7')
            torch.save(model.feature_model.state_dict(), models_dir + '/best_ptnet_model.t7')

        torch.save(snap, models_dir + '/model_snap.t7')
        torch.save(model.state_dict(), models_dir + '/model.t7')
        torch.save(model.feature_model.state_dict(), models_dir + '/ptnet_model.t7')

        boardio.add_scalar('Train Loss', train_loss, epoch+1)
        boardio.add_scalar('Test Loss', test_loss, epoch+1)
        boardio.add_scalar('Best Test Loss', best_test_loss, epoch+1)
        boardio.add_scalar('Train Accuracy', train_accuracy, epoch+1)
        boardio.add_scalar('Test Accuracy', test_accuracy, epoch+1)

        textio.cprint('EPOCH:: %d, Traininig Loss: %f, Testing Loss: %f, Best Loss: %f'%(epoch+1, train_loss, test_loss, best_test_loss))
        textio.cprint('EPOCH:: %d, Traininig Accuracy: %f, Testing Accuracy: %f'%(epoch+1, train_accuracy, test_accuracy))
152
+
153
def _str2bool(value):
    """Convert a command-line string such as 'True', 'false', '1' or 'no' to a bool.

    Used instead of ``type=bool``. Because ``bool('False')`` is True, with
    ``type=bool`` the flag ``--eval False`` would switch evaluation on.

    Raises:
        argparse.ArgumentTypeError: for an unrecognised spelling; argparse
            reports it as a usage error.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


def options():
    """Parse ``sys.argv`` into the options for PointNet classifier training.

    Returns:
        argparse.Namespace with experiment, data, model and training settings.
    """
    parser = argparse.ArgumentParser(description='Point Cloud Classification')
    parser.add_argument('--exp_name', type=str, default='exp_classifier', metavar='N',
                        help='Name of the experiment')
    parser.add_argument('--dataset_path', type=str, default='ModelNet40',
                        metavar='PATH', help='path to the input dataset')  # like '/path/to/ModelNet40'
    # nargs='?' + const=True: a bare "--eval" enables evaluation, and explicit
    # "--eval True" / "--eval False" are parsed correctly.
    parser.add_argument('--eval', type=_str2bool, nargs='?', const=True, default=False,
                        help='Evaluate instead of training (e.g. --eval, --eval True, --eval False).')

    # settings for input data
    parser.add_argument('--dataset_type', default='modelnet', choices=['modelnet', 'shapenet2'],
                        metavar='DATASET', help='dataset type (default: modelnet)')
    parser.add_argument('--num_points', default=1024, type=int,
                        metavar='N', help='points in point-cloud (default: 1024)')

    # settings for PointNet
    parser.add_argument('--pointnet', default='tune', type=str, choices=['fixed', 'tune'],
                        help='train pointnet (default: tune)')
    parser.add_argument('--emb_dims', default=1024, type=int,
                        metavar='K', help='dim. of the feature vector (default: 1024)')
    parser.add_argument('--symfn', default='max', choices=['max', 'avg'],
                        help='symmetric function (default: max)')

    # settings for on training
    parser.add_argument('--seed', type=int, default=1234)
    parser.add_argument('-j', '--workers', default=4, type=int,
                        metavar='N', help='number of data loading workers (default: 4)')
    parser.add_argument('-b', '--batch_size', default=32, type=int,
                        metavar='N', help='mini-batch size (default: 32)')
    parser.add_argument('--epochs', default=200, type=int,
                        metavar='N', help='number of total epochs to run')
    parser.add_argument('--start_epoch', default=0, type=int,
                        metavar='N', help='manual epoch number (useful on restarts)')
    parser.add_argument('--optimizer', default='Adam', choices=['Adam', 'SGD'],
                        metavar='METHOD', help='name of an optimizer (default: Adam)')
    parser.add_argument('--resume', default='', type=str,
                        metavar='PATH', help='path to latest checkpoint (default: null (no-use))')
    parser.add_argument('--pretrained', default='', type=str,
                        metavar='PATH', help='path to pretrained model file (default: null (no-use))')
    parser.add_argument('--device', default='cuda:0', type=str,
                        metavar='DEVICE', help='use CUDA if available')

    args = parser.parse_args()
    return args
196
+
197
def main():
    """Entry point: parse options, build data and model, then train or evaluate.

    Creates ``checkpoints/<exp_name>``, logs to ``run.log`` and TensorBoard in
    that directory, and closes both loggers when done (also on error).
    """
    args = options()
    # NOTE(review): --dataset_path, --dataset_type, --pointnet and --symfn are
    # parsed but not used below; ModelNet40Data is constructed without a path.
    # Confirm how it locates the data before wiring these flags in.

    torch.backends.cudnn.deterministic = True
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)

    boardio = SummaryWriter(log_dir='checkpoints/' + args.exp_name)
    _init_(args)

    textio = IOStream('checkpoints/' + args.exp_name + '/run.log')
    try:
        textio.cprint(str(args))

        trainset = ClassificationData(ModelNet40Data(train=True))
        testset = ClassificationData(ModelNet40Data(train=False))
        train_loader = DataLoader(trainset, batch_size=args.batch_size, shuffle=True, drop_last=True, num_workers=args.workers)
        test_loader = DataLoader(testset, batch_size=args.batch_size, shuffle=False, drop_last=False, num_workers=args.workers)

        if not torch.cuda.is_available():
            args.device = 'cpu'
        args.device = torch.device(args.device)

        # Create PointNet Model.
        ptnet = PointNet(emb_dims=args.emb_dims, use_bn=True)
        model = Classifier(feature_model=ptnet)

        checkpoint = None
        if args.resume:
            if not os.path.isfile(args.resume):
                raise FileNotFoundError('resume checkpoint not found: %s' % args.resume)
            # map_location='cpu' lets GPU-saved checkpoints load on CPU-only
            # hosts; the model is moved to args.device afterwards.
            checkpoint = torch.load(args.resume, map_location='cpu')
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['model'])

        if args.pretrained:
            if not os.path.isfile(args.pretrained):
                raise FileNotFoundError('pretrained model not found: %s' % args.pretrained)
            model.load_state_dict(torch.load(args.pretrained, map_location='cpu'))
        model.to(args.device)

        if args.eval:
            test(args, model, test_loader, textio)
        else:
            train(args, model, train_loader, test_loader, boardio, textio, checkpoint)
    finally:
        textio.close()
        boardio.close()


if __name__ == '__main__':
    main()
@@ -0,0 +1,229 @@
1
+ import argparse
2
+ import os
3
+ import sys
4
+ import logging
5
+ import numpy
6
+ import numpy as np
7
+ import torch
8
+ import torch.utils.data
9
+ import torchvision
10
+ from torch.utils.data import DataLoader
11
+ from tensorboardX import SummaryWriter
12
+ from tqdm import tqdm
13
+
14
# When the script lives in the ``examples`` folder, put that folder's parent on
# sys.path and make it the working directory, so relative paths used later
# (e.g. 'checkpoints/...') resolve against the parent directory.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
_IN_EXAMPLES_DIR = BASE_DIR.endswith('examples')
if _IN_EXAMPLES_DIR:
    _PARENT_DIR = os.path.join(BASE_DIR, os.pardir)
    sys.path.append(_PARENT_DIR)
    os.chdir(_PARENT_DIR)
19
+
20
+ from learning3d.models import PRNet
21
+ from learning3d.data_utils import RegistrationData, ModelNet40Data
22
+
23
+ def _init_(args):
24
+ if not os.path.exists('checkpoints'):
25
+ os.makedirs('checkpoints')
26
+ if not os.path.exists('checkpoints/' + args.exp_name):
27
+ os.makedirs('checkpoints/' + args.exp_name)
28
+ if not os.path.exists('checkpoints/' + args.exp_name + '/' + 'models'):
29
+ os.makedirs('checkpoints/' + args.exp_name + '/' + 'models')
30
+ os.system('cp train_dcp.py checkpoints' + '/' + args.exp_name + '/' + 'train.py.backup')
31
+
32
class IOStream:
    """Echo each message to stdout and append it to a log file.

    The file is opened in append mode, so repeated runs with the same path
    accumulate in one log.
    """

    def __init__(self, path):
        # 'a' keeps whatever earlier runs already wrote to this path.
        self.f = open(path, 'a')

    def cprint(self, text):
        """Print ``text`` and append it, newline-terminated, to the log file."""
        print(text)
        line = text + '\n'
        self.f.write(line)
        # Flush right away so the log can be followed while training runs.
        self.f.flush()

    def close(self):
        """Close the underlying log file."""
        self.f.close()
43
+
44
def get_transformations(igt):
    """Split a batch of rigid transforms into both registration directions.

    Args:
        igt: 3-D tensor read as ``igt[:, :3, :3]`` (rotation) and
            ``igt[:, :3, 3]`` (translation), i.e. a batch of at least 3x4
            transforms -- presumably (B, 4, 4); confirm against RegistrationData.

    Returns:
        ``(R_ab, translation_ab, R_ba, translation_ba)``:
          * ``R_ba`` (B, 3, 3) and ``translation_ba`` (B, 3, 1) are read
            straight from ``igt`` (per the original notes, Ps = R_ba * Pt + t_ba);
          * ``R_ab`` is the transpose of ``R_ba`` (its inverse for a proper
            rotation), and ``translation_ab = -R_ab @ translation_ba`` (B, 3, 1).
    """
    rot_ba = igt[:, :3, :3]
    # Slicing 3:4 keeps a trailing singleton axis, giving a (B, 3, 1) column.
    trans_ba = igt[:, :3, 3:4]
    rot_ab = rot_ba.transpose(1, 2)
    trans_ab = -torch.bmm(rot_ab, trans_ba)
    return rot_ab, trans_ab, rot_ba, trans_ba
50
+
51
def test_one_epoch(device, model, test_loader):
    """Evaluate the registration model for one pass over ``test_loader``.

    Args:
        device: torch device the inputs are moved to.
        model: called as ``model(template, source, R_ab, t_ab)`` and expected
            to return a dict with a scalar tensor under ``'loss'``.
        test_loader: iterable of ``(template, source, igt)`` batches.

    Returns:
        The mean of the per-batch losses.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    model.eval()
    test_loss = 0.0
    count = 0
    # Evaluation never calls backward(); no_grad avoids building an autograd
    # graph per batch. NOTE(review): assumes PRNet's forward does not itself
    # rely on autograd -- confirm against learning3d/models/prnet.py.
    with torch.no_grad():
        for template, source, igt in tqdm(test_loader):
            # Only the ab-direction transform is fed to the model; igt itself
            # is not needed on the device.
            R_ab, translation_ab, _, _ = get_transformations(igt)
            R_ab = R_ab.to(device)
            translation_ab = translation_ab.to(device)

            template = template.to(device)
            source = source.to(device)

            output = model(template, source, R_ab, translation_ab.squeeze(2))
            test_loss += output['loss'].item()
            count += 1

    if count == 0:
        raise ValueError('test_loader yielded no batches')
    return float(test_loss) / count
74
+
75
def test(args, model, test_loader, textio):
    """Run one evaluation pass and log the validation loss.

    ``test_one_epoch`` in this script returns only the loss, so only the loss
    is reported. Referencing an accuracy here would raise NameError.

    Args:
        args: parsed options; only ``args.device`` is read.
        model: registration model forwarded to ``test_one_epoch``.
        test_loader: evaluation batches.
        textio: logger exposing ``cprint``.
    """
    test_loss = test_one_epoch(args.device, model, test_loader)
    textio.cprint('Validation Loss: %f' % test_loss)
78
+
79
def train_one_epoch(device, model, train_loader, optimizer):
    """Run one optimisation pass of the registration model over ``train_loader``.

    Args:
        device: torch device the inputs are moved to.
        model: called as ``model(template, source, R_ab, t_ab)`` and expected
            to return a dict with a scalar tensor under ``'loss'``.
        train_loader: iterable of ``(template, source, igt)`` batches.
        optimizer: torch optimizer stepped once per batch.

    Returns:
        The mean of the per-batch losses.

    Raises:
        ValueError: if ``train_loader`` yields no batches (e.g. ``drop_last=True``
            with fewer samples than one batch).
    """
    model.train()
    train_loss = 0.0
    count = 0
    for template, source, igt in tqdm(train_loader):
        # Only the ab-direction transform is fed to the model; igt itself is
        # not needed on the device.
        R_ab, translation_ab, _, _ = get_transformations(igt)
        R_ab = R_ab.to(device)
        translation_ab = translation_ab.to(device)

        template = template.to(device)
        source = source.to(device)

        output = model(template, source, R_ab, translation_ab.squeeze(2))
        loss_val = output['loss']

        # forward + backward + optimize
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        train_loss += loss_val.item()
        count += 1

    if count == 0:
        raise ValueError('train_loader yielded no batches')
    return float(train_loss) / count
107
+
108
def train(args, model, train_loader, test_loader, boardio, textio, checkpoint):
    """Train PRNet, checkpointing the latest and the best epochs.

    Each epoch runs one training pass and one evaluation pass, then:
      * writes ``model_snap.t7`` / ``model.t7`` with the latest state;
      * writes the ``best_*`` variants when the test loss improves on the best
        seen so far;
      * logs the losses to TensorBoard (``boardio``) and ``textio``.

    Args:
        args: parsed options (``optimizer``, ``start_epoch``, ``epochs``,
            ``device``, ``exp_name`` are read).
        model: registration model being trained.
        train_loader, test_loader: data for the two passes.
        boardio: TensorBoard SummaryWriter.
        textio: logger exposing ``cprint``.
        checkpoint: dict loaded from a ``*_snap.t7`` file when resuming, else None.
    """
    learnable_params = filter(lambda p: p.requires_grad, model.parameters())
    if args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(learnable_params)
    else:
        optimizer = torch.optim.SGD(learnable_params, lr=0.1)

    best_test_loss = np.inf
    if checkpoint is not None:
        # Restore the best-so-far threshold as well. Otherwise the first epoch
        # after a restart overwrites the best_* files even when it is worse.
        best_test_loss = checkpoint['min_loss']
        optimizer.load_state_dict(checkpoint['optimizer'])

    models_dir = 'checkpoints/%s/models' % (args.exp_name)
    # NOTE(review): nothing here shows that PRNet defines `feature_model`, so
    # its separate state_dict is only saved when the attribute exists.
    feature_model = getattr(model, 'feature_model', None)

    for epoch in range(args.start_epoch, args.epochs):
        train_loss = train_one_epoch(args.device, model, train_loader, optimizer)
        test_loss = test_one_epoch(args.device, model, test_loader)

        is_best = test_loss < best_test_loss
        if is_best:
            best_test_loss = test_loss

        # Build the snapshot every epoch so model_snap.t7 holds the latest
        # state. It must exist even when this epoch is not a new best.
        snap = {'epoch': epoch + 1,
                'model': model.state_dict(),
                'min_loss': best_test_loss,
                'optimizer': optimizer.state_dict(), }

        if is_best:
            torch.save(snap, models_dir + '/best_model_snap.t7')
            torch.save(model.state_dict(), models_dir + '/best_model.t7')
            if feature_model is not None:
                torch.save(feature_model.state_dict(), models_dir + '/best_ptnet_model.t7')

        torch.save(snap, models_dir + '/model_snap.t7')
        torch.save(model.state_dict(), models_dir + '/model.t7')
        if feature_model is not None:
            torch.save(feature_model.state_dict(), models_dir + '/ptnet_model.t7')

        boardio.add_scalar('Train Loss', train_loss, epoch+1)
        boardio.add_scalar('Test Loss', test_loss, epoch+1)
        boardio.add_scalar('Best Test Loss', best_test_loss, epoch+1)

        textio.cprint('EPOCH:: %d, Traininig Loss: %f, Testing Loss: %f, Best Loss: %f'%(epoch+1, train_loss, test_loss, best_test_loss))
144
+
145
def _str2bool(value):
    """Convert a command-line string such as 'True', 'false', '1' or 'no' to a bool.

    Used instead of ``type=bool``. Because ``bool('False')`` is True, with
    ``type=bool`` the flag ``--eval False`` would switch evaluation on.

    Raises:
        argparse.ArgumentTypeError: for an unrecognised spelling; argparse
            reports it as a usage error.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


def options():
    """Parse ``sys.argv`` into the options for PRNet registration training.

    Returns:
        argparse.Namespace with experiment, data, model and training settings.
    """
    parser = argparse.ArgumentParser(description='Point Cloud Registration')
    parser.add_argument('--exp_name', type=str, default='exp_prnet', metavar='N',
                        help='Name of the experiment')
    parser.add_argument('--dataset_path', type=str, default='ModelNet40',
                        metavar='PATH', help='path to the input dataset')  # like '/path/to/ModelNet40'
    # nargs='?' + const=True: a bare "--eval" enables evaluation, and explicit
    # "--eval True" / "--eval False" are parsed correctly.
    parser.add_argument('--eval', type=_str2bool, nargs='?', const=True, default=False,
                        help='Evaluate instead of training (e.g. --eval, --eval True, --eval False).')

    # settings for input data
    parser.add_argument('--dataset_type', default='modelnet', choices=['modelnet', 'shapenet2'],
                        metavar='DATASET', help='dataset type (default: modelnet)')
    parser.add_argument('--emb_dims', default=512, type=int,
                        metavar='K', help='dim. of the feature vector (default: 512)')
    parser.add_argument('--num_iterations', default=3, type=int,
                        help='Number of Iterations')

    # settings for on training
    parser.add_argument('--seed', type=int, default=1234)
    parser.add_argument('-j', '--workers', default=4, type=int,
                        metavar='N', help='number of data loading workers (default: 4)')
    parser.add_argument('-b', '--batch_size', default=32, type=int,
                        metavar='N', help='mini-batch size (default: 32)')
    parser.add_argument('--epochs', default=200, type=int,
                        metavar='N', help='number of total epochs to run')
    parser.add_argument('--start_epoch', default=0, type=int,
                        metavar='N', help='manual epoch number (useful on restarts)')
    parser.add_argument('--optimizer', default='Adam', choices=['Adam', 'SGD'],
                        metavar='METHOD', help='name of an optimizer (default: Adam)')
    parser.add_argument('--resume', default='', type=str,
                        metavar='PATH', help='path to latest checkpoint (default: null (no-use))')
    parser.add_argument('--pretrained', default='', type=str,
                        metavar='PATH', help='path to pretrained model file (default: null (no-use))')
    parser.add_argument('--device', default='cuda:0', type=str,
                        metavar='DEVICE', help='use CUDA if available')

    args = parser.parse_args()
    return args
182
+
183
def main():
    """Entry point: parse options, build data and PRNet, then train or evaluate.

    Creates ``checkpoints/<exp_name>``, logs to ``run.log`` and TensorBoard in
    that directory, and closes both loggers when done (also on error).
    """
    args = options()

    torch.backends.cudnn.deterministic = True
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)

    boardio = SummaryWriter(log_dir='checkpoints/' + args.exp_name)
    _init_(args)

    textio = IOStream('checkpoints/' + args.exp_name + '/run.log')
    try:
        textio.cprint(str(args))

        trainset = RegistrationData('PRNet', ModelNet40Data(train=True), partial_source=True, partial_template=True)
        testset = RegistrationData('PRNet', ModelNet40Data(train=False), partial_source=True, partial_template=True)
        train_loader = DataLoader(trainset, batch_size=args.batch_size, shuffle=True, drop_last=True, num_workers=args.workers)
        test_loader = DataLoader(testset, batch_size=args.batch_size, shuffle=False, drop_last=False, num_workers=args.workers)

        if not torch.cuda.is_available():
            args.device = 'cpu'
        args.device = torch.device(args.device)

        # Create the PRNet model.
        model = PRNet(emb_dims=args.emb_dims, num_iters=args.num_iterations)
        model = model.to(args.device)

        checkpoint = None
        if args.resume:
            if not os.path.isfile(args.resume):
                raise FileNotFoundError('resume checkpoint not found: %s' % args.resume)
            # map_location='cpu' lets GPU-saved checkpoints load on CPU-only
            # hosts; load_state_dict copies the values into the model's
            # existing (already device-placed) parameters.
            checkpoint = torch.load(args.resume, map_location='cpu')
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['model'])

        if args.pretrained:
            if not os.path.isfile(args.pretrained):
                raise FileNotFoundError('pretrained model not found: %s' % args.pretrained)
            model.load_state_dict(torch.load(args.pretrained, map_location='cpu'))
        model.to(args.device)

        if args.eval:
            test(args, model, test_loader, textio)
        else:
            train(args, model, train_loader, test_loader, boardio, textio, checkpoint)
    finally:
        textio.close()
        boardio.close()


if __name__ == '__main__':
    main()