learning3d 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- learning3d/__init__.py +2 -0
- learning3d/data_utils/__init__.py +4 -0
- learning3d/data_utils/dataloaders.py +454 -0
- learning3d/data_utils/user_data.py +119 -0
- learning3d/examples/test_dcp.py +139 -0
- learning3d/examples/test_deepgmr.py +144 -0
- learning3d/examples/test_flownet.py +113 -0
- learning3d/examples/test_masknet.py +159 -0
- learning3d/examples/test_masknet2.py +162 -0
- learning3d/examples/test_pcn.py +118 -0
- learning3d/examples/test_pcrnet.py +120 -0
- learning3d/examples/test_pnlk.py +121 -0
- learning3d/examples/test_pointconv.py +126 -0
- learning3d/examples/test_pointnet.py +121 -0
- learning3d/examples/test_prnet.py +126 -0
- learning3d/examples/test_rpmnet.py +120 -0
- learning3d/examples/train_PointNetLK.py +240 -0
- learning3d/examples/train_dcp.py +249 -0
- learning3d/examples/train_deepgmr.py +244 -0
- learning3d/examples/train_flownet.py +259 -0
- learning3d/examples/train_masknet.py +239 -0
- learning3d/examples/train_pcn.py +216 -0
- learning3d/examples/train_pcrnet.py +228 -0
- learning3d/examples/train_pointconv.py +245 -0
- learning3d/examples/train_pointnet.py +244 -0
- learning3d/examples/train_prnet.py +229 -0
- learning3d/examples/train_rpmnet.py +228 -0
- learning3d/losses/__init__.py +12 -0
- learning3d/losses/chamfer_distance.py +51 -0
- learning3d/losses/classification.py +14 -0
- learning3d/losses/correspondence_loss.py +10 -0
- learning3d/losses/cuda/chamfer_distance/__init__.py +1 -0
- learning3d/losses/cuda/chamfer_distance/chamfer_distance.cpp +185 -0
- learning3d/losses/cuda/chamfer_distance/chamfer_distance.cu +209 -0
- learning3d/losses/cuda/chamfer_distance/chamfer_distance.py +66 -0
- learning3d/losses/cuda/emd_torch/pkg/emd_loss_layer.py +41 -0
- learning3d/losses/cuda/emd_torch/pkg/include/cuda/emd.cuh +347 -0
- learning3d/losses/cuda/emd_torch/pkg/include/cuda_helper.h +18 -0
- learning3d/losses/cuda/emd_torch/pkg/include/emd.h +54 -0
- learning3d/losses/cuda/emd_torch/pkg/layer/__init__.py +1 -0
- learning3d/losses/cuda/emd_torch/pkg/layer/emd_loss_layer.py +40 -0
- learning3d/losses/cuda/emd_torch/pkg/src/cuda/emd.cu +70 -0
- learning3d/losses/cuda/emd_torch/pkg/src/emd.cpp +1 -0
- learning3d/losses/cuda/emd_torch/setup.py +29 -0
- learning3d/losses/emd.py +16 -0
- learning3d/losses/frobenius_norm.py +21 -0
- learning3d/losses/rmse_features.py +16 -0
- learning3d/models/__init__.py +23 -0
- learning3d/models/classifier.py +41 -0
- learning3d/models/dcp.py +92 -0
- learning3d/models/deepgmr.py +165 -0
- learning3d/models/dgcnn.py +92 -0
- learning3d/models/flownet3d.py +446 -0
- learning3d/models/masknet.py +84 -0
- learning3d/models/masknet2.py +264 -0
- learning3d/models/pcn.py +164 -0
- learning3d/models/pcrnet.py +74 -0
- learning3d/models/pointconv.py +108 -0
- learning3d/models/pointnet.py +108 -0
- learning3d/models/pointnetlk.py +173 -0
- learning3d/models/pooling.py +15 -0
- learning3d/models/ppfnet.py +102 -0
- learning3d/models/prnet.py +431 -0
- learning3d/models/rpmnet.py +359 -0
- learning3d/models/segmentation.py +38 -0
- learning3d/ops/__init__.py +0 -0
- learning3d/ops/data_utils.py +45 -0
- learning3d/ops/invmat.py +134 -0
- learning3d/ops/quaternion.py +218 -0
- learning3d/ops/se3.py +157 -0
- learning3d/ops/sinc.py +229 -0
- learning3d/ops/so3.py +213 -0
- learning3d/ops/transform_functions.py +342 -0
- learning3d/utils/__init__.py +9 -0
- learning3d/utils/lib/build/lib.linux-x86_64-3.5/pointnet2_cuda.cpython-35m-x86_64-linux-gnu.so +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/ball_query.o +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/ball_query_gpu.o +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/group_points.o +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/group_points_gpu.o +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/interpolate.o +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/interpolate_gpu.o +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/pointnet2_api.o +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/sampling.o +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/sampling_gpu.o +0 -0
- learning3d/utils/lib/dist/pointnet2-0.0.0-py3.5-linux-x86_64.egg +0 -0
- learning3d/utils/lib/pointnet2.egg-info/SOURCES.txt +14 -0
- learning3d/utils/lib/pointnet2.egg-info/dependency_links.txt +1 -0
- learning3d/utils/lib/pointnet2.egg-info/top_level.txt +1 -0
- learning3d/utils/lib/pointnet2_modules.py +160 -0
- learning3d/utils/lib/pointnet2_utils.py +318 -0
- learning3d/utils/lib/pytorch_utils.py +236 -0
- learning3d/utils/lib/setup.py +23 -0
- learning3d/utils/lib/src/ball_query.cpp +25 -0
- learning3d/utils/lib/src/ball_query_gpu.cu +67 -0
- learning3d/utils/lib/src/ball_query_gpu.h +15 -0
- learning3d/utils/lib/src/cuda_utils.h +15 -0
- learning3d/utils/lib/src/group_points.cpp +36 -0
- learning3d/utils/lib/src/group_points_gpu.cu +86 -0
- learning3d/utils/lib/src/group_points_gpu.h +22 -0
- learning3d/utils/lib/src/interpolate.cpp +65 -0
- learning3d/utils/lib/src/interpolate_gpu.cu +233 -0
- learning3d/utils/lib/src/interpolate_gpu.h +36 -0
- learning3d/utils/lib/src/pointnet2_api.cpp +25 -0
- learning3d/utils/lib/src/sampling.cpp +46 -0
- learning3d/utils/lib/src/sampling_gpu.cu +253 -0
- learning3d/utils/lib/src/sampling_gpu.h +29 -0
- learning3d/utils/pointconv_util.py +382 -0
- learning3d/utils/ppfnet_util.py +244 -0
- learning3d/utils/svd.py +59 -0
- learning3d/utils/transformer.py +243 -0
- learning3d-0.0.1.dist-info/LICENSE +21 -0
- learning3d-0.0.1.dist-info/METADATA +271 -0
- learning3d-0.0.1.dist-info/RECORD +115 -0
- learning3d-0.0.1.dist-info/WHEEL +5 -0
- learning3d-0.0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,259 @@
|
|
1
|
+
#!/usr/bin/env python
|
2
|
+
# -*- coding: utf-8 -*-
|
3
|
+
|
4
|
+
|
5
|
+
from __future__ import print_function
|
6
|
+
import os
|
7
|
+
import gc
|
8
|
+
import argparse
|
9
|
+
import torch
|
10
|
+
import torch.nn as nn
|
11
|
+
import torch.nn.functional as F
|
12
|
+
import torch.optim as optim
|
13
|
+
from torch.optim.lr_scheduler import MultiStepLR
|
14
|
+
from learning3d.models import FlowNet3D
|
15
|
+
from learning3d.data_utils import SceneflowDataset
|
16
|
+
import numpy as np
|
17
|
+
from torch.utils.data import DataLoader
|
18
|
+
from tensorboardX import SummaryWriter
|
19
|
+
from tqdm import tqdm
|
20
|
+
|
21
|
+
class IOStream:
    """Append-mode text log that also echoes every message to stdout.

    The file is opened once in the constructor and stays open until
    ``close`` is called.
    """

    def __init__(self, path):
        """Open ``path`` for appending and keep the handle on ``self.f``."""
        self.f = open(path, 'a')

    def cprint(self, text):
        """Print ``text`` and append it, newline-terminated, to the log file."""
        print(text)
        log_file = self.f
        log_file.write(text + '\n')
        # Flush immediately so the file on disk keeps up with the console.
        log_file.flush()

    def close(self):
        """Close the underlying file handle."""
        self.f.close()
|
32
|
+
|
33
|
+
|
34
|
+
def _init_(args):
|
35
|
+
if not os.path.exists('checkpoints'):
|
36
|
+
os.makedirs('checkpoints')
|
37
|
+
if not os.path.exists('checkpoints/' + args.exp_name):
|
38
|
+
os.makedirs('checkpoints/' + args.exp_name)
|
39
|
+
if not os.path.exists('checkpoints/' + args.exp_name + '/' + 'models'):
|
40
|
+
os.makedirs('checkpoints/' + args.exp_name + '/' + 'models')
|
41
|
+
|
42
|
+
def weights_init(m):
    """Apply Kaiming-normal initialisation to convolution weights.

    Modules whose class name contains ``Conv2d`` or ``Conv1d`` get
    ``m.weight.data`` re-initialised in place; any other module is left
    untouched.
    """
    layer_type = m.__class__.__name__
    # Two independent checks: each matching tag triggers its own
    # initialisation pass.
    for tag in ('Conv2d', 'Conv1d'):
        if tag in layer_type:
            nn.init.kaiming_normal_(m.weight.data)
|
48
|
+
|
49
|
+
def test_one_epoch(args, net, test_loader):
    """Evaluate ``net`` on ``test_loader`` and return the mean masked flow loss.

    Each batch must unpack into ``(pc1, pc2, color1, color2, flow, mask1)``.
    The per-batch loss is the mean of ``mask1 * sum((pred - flow)^2) / 2``
    and is weighted by batch size when averaged over the loader.

    Args:
        args: Parsed options (not used in this function).
        net: Model called as ``net(pc1, pc2, color1, color2)``.
        test_loader: Iterable of batches that supports ``len()``.

    Returns:
        float: Sum of batch losses weighted by batch size, divided by the
        number of examples seen.

    Raises:
        ZeroDivisionError: If the loader yields no batches.
    """
    net.eval()

    total_loss = 0
    num_examples = 0
    # Evaluation never calls backward(); disabling autograd avoids building
    # and holding a graph for every batch.
    with torch.no_grad():
        for pc1, pc2, color1, color2, flow, mask1 in tqdm(
                test_loader, total=len(test_loader), smoothing=0.9):
            # Inputs are moved to the GPU with axes 1 and 2 swapped
            # (presumably to channels-first for the network - TODO confirm).
            pc1 = pc1.cuda().transpose(2, 1).contiguous()
            pc2 = pc2.cuda().transpose(2, 1).contiguous()
            color1 = color1.cuda().transpose(2, 1).contiguous()
            color2 = color2.cuda().transpose(2, 1).contiguous()
            flow = flow.cuda()
            mask1 = mask1.cuda().float()

            batch_size = pc1.size(0)
            num_examples += batch_size
            # The prediction is permuted so its last axis lines up with the
            # last axis of ``flow``, over which the squared error is summed.
            flow_pred = net(pc1, pc2, color1, color2).permute(0, 2, 1)
            diff = flow_pred - flow
            loss_1 = torch.mean(mask1 * torch.sum(diff * diff, -1) / 2.0)

            total_loss += loss_1.item() * batch_size

    return total_loss * 1.0 / num_examples
|
75
|
+
|
76
|
+
|
77
|
+
def train_one_epoch(args, net, train_loader, opt):
    """Run one optimisation pass over ``train_loader``.

    Each batch must unpack into ``(pc1, pc2, color1, color2, flow, mask1)``.
    The per-batch loss is the mean of ``mask1 * sum((pred - flow)^2) / 2``,
    with the squared error summed over axis 1.

    Args:
        args: Parsed options (not used in this function).
        net: Model called as ``net(pc1, pc2, color1, color2)``.
        train_loader: Iterable of batches that supports ``len()``.
        opt: Optimizer stepped once per batch.

    Returns:
        float: Sum of batch losses weighted by batch size, divided by the
        number of examples seen.

    Raises:
        ZeroDivisionError: If the loader yields no batches.
    """
    net.train()
    num_examples = 0
    total_loss = 0
    for pc1, pc2, color1, color2, flow, mask1 in tqdm(
            train_loader, total=len(train_loader), smoothing=0.9):
        # Inputs and target are moved to the GPU with axes 1 and 2 swapped
        # (presumably to channels-first for the network - TODO confirm).
        pc1 = pc1.cuda().transpose(2, 1).contiguous()
        pc2 = pc2.cuda().transpose(2, 1).contiguous()
        color1 = color1.cuda().transpose(2, 1).contiguous()
        color2 = color2.cuda().transpose(2, 1).contiguous()
        flow = flow.cuda().transpose(2, 1).contiguous()
        mask1 = mask1.cuda().float()

        batch_size = pc1.size(0)
        opt.zero_grad()
        num_examples += batch_size
        flow_pred = net(pc1, pc2, color1, color2)
        loss_1 = torch.mean(mask1 * torch.sum((flow_pred - flow) ** 2, 1) / 2.0)

        loss_1.backward()
        opt.step()
        total_loss += loss_1.item() * batch_size

    return total_loss * 1.0 / num_examples
|
108
|
+
|
109
|
+
|
110
|
+
def test(args, net, test_loader, boardio, textio):
    """Run a single evaluation pass and log the mean loss via ``textio``.

    ``boardio`` is accepted but not used here.
    """
    mean_loss = test_one_epoch(args, net, test_loader)

    textio.cprint('==FINAL TEST==')
    summary = 'mean test loss: %f' % mean_loss
    textio.cprint(summary)
|
116
|
+
|
117
|
+
|
118
|
+
def train(args, net, train_loader, test_loader, boardio, textio):
    """Train ``net`` for ``args.epochs`` epochs, checkpointing every epoch.

    After each epoch the model is evaluated on ``test_loader``. The weights
    are saved to ``checkpoints/<exp_name>/models/model.<epoch>.t7``, and also
    to ``model.best.t7`` whenever the test loss is no worse than the best
    seen so far.

    Args:
        args: Options providing ``use_sgd``, ``lr``, ``momentum``, ``epochs``
            and ``exp_name``.
        net: Model to optimise; may be wrapped in ``nn.DataParallel``.
        train_loader: Training batches.
        test_loader: Evaluation batches.
        boardio: Scalar logger exposing ``add_scalar(tag, value, step)``.
        textio: Text logger exposing ``cprint(str)``.
    """
    if args.use_sgd:
        print("Use SGD")
        opt = optim.SGD(net.parameters(), lr=args.lr * 100, momentum=args.momentum, weight_decay=1e-4)
    else:
        print("Use Adam")
        opt = optim.Adam(net.parameters(), lr=args.lr, weight_decay=1e-4)
    scheduler = MultiStepLR(opt, milestones=[75, 150, 200], gamma=0.1)

    model_dir = 'checkpoints/%s/models' % args.exp_name
    best_test_loss = np.inf
    for epoch in range(args.epochs):
        textio.cprint('==epoch: %d==' % epoch)
        train_loss = train_one_epoch(args, net, train_loader, opt)
        textio.cprint('mean train EPE loss: %f' % train_loss)

        test_loss = test_one_epoch(args, net, test_loader)
        textio.cprint('mean test EPE loss: %f' % test_loss)

        # Save the bare module's weights so checkpoints load without the
        # DataParallel "module." prefix. Decide by wrapper type rather than
        # by GPU count so the choice always matches how ``net`` was built.
        if isinstance(net, nn.DataParallel):
            weights = net.module.state_dict()
        else:
            weights = net.state_dict()

        if best_test_loss >= test_loss:
            best_test_loss = test_loss
            textio.cprint('best test loss till now: %f' % test_loss)
            torch.save(weights, '%s/model.best.t7' % model_dir)

        boardio.add_scalar('Train Loss', train_loss, epoch + 1)
        boardio.add_scalar('Test Loss', test_loss, epoch + 1)
        boardio.add_scalar('Best Test Loss', best_test_loss, epoch + 1)

        torch.save(weights, '%s/model.%d.t7' % (model_dir, epoch))

        # Step the LR schedule only after the epoch's optimizer steps.
        # Stepping it first (as before) skips the initial learning rate
        # value on PyTorch >= 1.1 and triggers a UserWarning.
        scheduler.step()
        gc.collect()
|
154
|
+
|
155
|
+
|
156
|
+
def _str2bool(value):
    """Parse a command-line boolean such as ``true``/``false``/``1``/``0``.

    ``type=bool`` would turn every non-empty string, including ``"False"``,
    into ``True``.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % value)


def main():
    """Parse options, build data loaders and the FlowNet3D model, then run.

    Runs evaluation when ``--eval`` is given, otherwise training. Logs go to
    ``checkpoints/<exp_name>/run.log`` and to a ``SummaryWriter`` in
    ``checkpoints/<exp_name>``; both are closed before returning.
    """
    parser = argparse.ArgumentParser(description='Point Cloud Registration')
    parser.add_argument('--exp_name', type=str, default='exp_flownet', metavar='N',
                        help='Name of the experiment')
    parser.add_argument('--model', type=str, default='flownet', metavar='N',
                        choices=['flownet'],
                        help='Model to use, [flownet]')
    parser.add_argument('--emb_dims', type=int, default=512, metavar='N',
                        help='Dimension of embeddings')
    parser.add_argument('--num_points', type=int, default=2048,
                        help='Point Number [default: 2048]')
    parser.add_argument('--dropout', type=float, default=0.5, metavar='N',
                        help='Dropout ratio in transformer')
    parser.add_argument('--batch_size', type=int, default=16, metavar='batch_size',
                        help='Size of batch)')
    parser.add_argument('--test_batch_size', type=int, default=10, metavar='batch_size',
                        help='Size of batch)')
    parser.add_argument('--epochs', type=int, default=250, metavar='N',
                        help='number of episode to train ')
    # NOTE(review): store_true with default=True means this flag is always
    # True, so the Adam branch in train() cannot be selected from the CLI.
    parser.add_argument('--use_sgd', action='store_true', default=True,
                        help='Use SGD')
    parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
                        help='learning rate (default: 0.001, 0.1 if using sgd)')
    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--no_cuda', action='store_true', default=False,
                        help='enables CUDA training')
    parser.add_argument('--seed', type=int, default=1234, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--eval', action='store_true', default=False,
                        help='evaluate the model')
    parser.add_argument('--cycle', type=_str2bool, default=False, metavar='N',
                        help='Whether to use cycle consistency')
    parser.add_argument('--gaussian_noise', type=_str2bool, default=False, metavar='N',
                        help='Wheter to add gaussian noise')
    parser.add_argument('--unseen', type=_str2bool, default=False, metavar='N',
                        help='Whether to test on unseen category')
    parser.add_argument('--dataset', type=str, default='SceneflowDataset',
                        choices=['SceneflowDataset'], metavar='N',
                        help='dataset to use')
    parser.add_argument('--dataset_path', type=str, default='data_processed_maxcut_35_20k_2k_8192', metavar='N',
                        help='dataset to use')
    parser.add_argument('--model_path', type=str, default='', metavar='N',
                        help='Pretrained model path')
    parser.add_argument('--pretrained', type=str, default='', metavar='N',
                        help='Pretrained model path')

    args = parser.parse_args()
    # Overrides any CUDA_VISIBLE_DEVICES value already set in the environment.
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    # CUDA settings
    torch.backends.cudnn.deterministic = True
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)

    boardio = SummaryWriter(log_dir='checkpoints/' + args.exp_name)
    _init_(args)

    textio = IOStream('checkpoints/' + args.exp_name + '/run.log')
    try:
        textio.cprint(str(args))

        if args.dataset == 'SceneflowDataset':
            train_loader = DataLoader(
                SceneflowDataset(npoints=args.num_points, partition='train'),
                batch_size=args.batch_size, shuffle=True, drop_last=True)
            test_loader = DataLoader(
                SceneflowDataset(npoints=args.num_points, partition='test'),
                batch_size=args.test_batch_size, shuffle=False, drop_last=False)
        else:
            raise Exception("not implemented")

        if args.model == 'flownet':
            net = FlowNet3D().cuda()
            net.apply(weights_init)
            if args.pretrained:
                net.load_state_dict(torch.load(args.pretrained), strict=False)
                print("Pretrained Model Loaded Successfully!")
            if args.eval:
                # Truthiness test instead of ``is ''``: identity comparison
                # with a literal depends on string interning.
                if not args.model_path:
                    model_path = 'checkpoints' + '/' + args.exp_name + '/models/model.best.t7'
                else:
                    model_path = args.model_path
                print(model_path)
                if not os.path.exists(model_path):
                    print("can't find pretrained model")
                    return
                net.load_state_dict(torch.load(model_path), strict=False)
            if torch.cuda.device_count() > 1:
                net = nn.DataParallel(net)
                print("Let's use", torch.cuda.device_count(), "GPUs!")
        else:
            raise Exception('Not implemented')

        if args.eval:
            test(args, net, test_loader, boardio, textio)
        else:
            train(args, net, train_loader, test_loader, boardio, textio)

        print('FINISH')
    finally:
        # Release the log file and the event writer on every exit path,
        # including the early return above.
        textio.close()
        boardio.close()
|
256
|
+
|
257
|
+
|
258
|
+
# Script entry point: run main() only when executed directly, not on import.
if __name__ == '__main__':
    main()
|
@@ -0,0 +1,239 @@
|
|
1
|
+
import argparse
|
2
|
+
import os
|
3
|
+
import sys
|
4
|
+
import logging
|
5
|
+
import numpy
|
6
|
+
import numpy as np
|
7
|
+
import torch
|
8
|
+
import torch.utils.data
|
9
|
+
import torchvision
|
10
|
+
from torch.utils.data import DataLoader
|
11
|
+
from tensorboardX import SummaryWriter
|
12
|
+
from tqdm import tqdm
|
13
|
+
|
14
|
+
# Only if the files are in example folder.
|
15
|
+
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
if BASE_DIR.endswith('examples'):
    # When this script sits in a directory whose path ends in "examples",
    # make its parent directory importable and use it as the working
    # directory.
    sys.path.append(os.path.join(BASE_DIR, os.pardir))
    os.chdir(os.path.join(BASE_DIR, os.pardir))
|
19
|
+
|
20
|
+
from learning3d.models import MaskNet
|
21
|
+
from learning3d.data_utils import RegistrationData, ModelNet40Data
|
22
|
+
|
23
|
+
def _init_(args):
|
24
|
+
if not os.path.exists('checkpoints'):
|
25
|
+
os.makedirs('checkpoints')
|
26
|
+
if not os.path.exists('checkpoints/' + args.exp_name):
|
27
|
+
os.makedirs('checkpoints/' + args.exp_name)
|
28
|
+
if not os.path.exists('checkpoints/' + args.exp_name + '/' + 'models'):
|
29
|
+
os.makedirs('checkpoints/' + args.exp_name + '/' + 'models')
|
30
|
+
os.system('cp train.py checkpoints' + '/' + args.exp_name + '/' + 'train.py.backup')
|
31
|
+
os.system('cp learning3d/models/masknet.py checkpoints' + '/' + args.exp_name + '/' + 'masknet.py.backup')
|
32
|
+
os.system('cp learning3d/data_utils/dataloaders.py checkpoints' + '/' + args.exp_name + '/' + 'dataloaders.py.backup')
|
33
|
+
|
34
|
+
|
35
|
+
class IOStream:
    """Append-mode text log that also echoes every message to stdout.

    The file is opened once in the constructor and stays open until
    ``close`` is called.
    """

    def __init__(self, path):
        """Open ``path`` for appending and keep the handle on ``self.f``."""
        self.f = open(path, 'a')

    def cprint(self, text):
        """Print ``text`` and append it, newline-terminated, to the log file."""
        print(text)
        log_file = self.f
        log_file.write(text + '\n')
        # Flush immediately so the file on disk keeps up with the console.
        log_file.flush()

    def close(self):
        """Close the underlying file handle."""
        self.f.close()
|
46
|
+
|
47
|
+
def test_one_epoch(args, model, test_loader):
    """Evaluate ``model`` on ``test_loader`` and return the mean mask loss.

    Each batch must unpack into ``(template, source, igt, gt_mask)``; ``igt``
    is not needed for the mask loss and is ignored. The model is called as
    ``model(template, source)`` and must return
    ``(masked_template, predicted_mask)``.

    Args:
        args: Options providing ``device`` and ``loss_fn`` (``'mse'`` or
            ``'bce'``).
        model: Network to evaluate.
        test_loader: Iterable of batches.

    Returns:
        float: Unweighted mean of the per-batch losses.

    Raises:
        ValueError: If ``args.loss_fn`` is neither ``'mse'`` nor ``'bce'``.
        ZeroDivisionError: If the loader yields no batches.
    """
    model.eval()
    test_loss = 0.0
    count = 0
    # Evaluation never calls backward(); disabling autograd avoids building
    # a graph for every batch.
    with torch.no_grad():
        for template, source, _, gt_mask in tqdm(test_loader):
            template = template.to(args.device)
            source = source.to(args.device)
            gt_mask = gt_mask.to(args.device)

            masked_template, predicted_mask = model(template, source)

            if args.loss_fn == 'mse':
                loss_mask = torch.nn.functional.mse_loss(predicted_mask, gt_mask)
            elif args.loss_fn == 'bce':
                loss_mask = torch.nn.BCELoss()(predicted_mask, gt_mask)
            else:
                # Previously this fell through and failed on an unbound name.
                raise ValueError('unsupported loss_fn: %r' % (args.loss_fn,))

            test_loss += loss_mask.item()
            count += 1

    test_loss = float(test_loss) / count
    return test_loss
|
72
|
+
|
73
|
+
def test(args, model, test_loader, textio):
    """Evaluate ``model`` once and log the validation loss via ``textio``.

    The previous body passed an undefined name (``pnlk``) and the wrong
    argument list to ``test_one_epoch``, and tried to unpack an accuracy
    that function never returns, so it could not run. ``test_one_epoch``
    returns a single float loss, which is what is reported here.

    Args:
        args: Options forwarded to ``test_one_epoch``.
        model: Network to evaluate.
        test_loader: Evaluation batches.
        textio: Text logger exposing ``cprint(str)``.
    """
    test_loss = test_one_epoch(args, model, test_loader)
    textio.cprint('Validation Loss: %f' % test_loss)
|
76
|
+
|
77
|
+
def train_one_epoch(args, model, train_loader, optimizer):
    """Run one optimisation pass over ``train_loader``.

    Each batch must unpack into ``(template, source, igt, gt_mask)``; ``igt``
    is not needed for the mask loss and is ignored (no device copy is made
    for it). The model is called as ``model(template, source)`` and must
    return ``(masked_template, predicted_mask)``.

    Args:
        args: Options providing ``device`` and ``loss_fn`` (``'mse'`` or
            ``'bce'``).
        model: Network to train.
        train_loader: Iterable of batches.
        optimizer: Optimizer stepped once per batch.

    Returns:
        float: Unweighted mean of the per-batch losses.

    Raises:
        ValueError: If ``args.loss_fn`` is neither ``'mse'`` nor ``'bce'``.
        ZeroDivisionError: If the loader yields no batches.
    """
    model.train()
    train_loss = 0.0
    count = 0
    for template, source, _, gt_mask in tqdm(train_loader):
        template = template.to(args.device)
        source = source.to(args.device)
        gt_mask = gt_mask.to(args.device)

        masked_template, predicted_mask = model(template, source)

        if args.loss_fn == 'mse':
            loss_mask = torch.nn.functional.mse_loss(predicted_mask, gt_mask)
        elif args.loss_fn == 'bce':
            loss_mask = torch.nn.BCELoss()(predicted_mask, gt_mask)
        else:
            # Previously this fell through and failed on an unbound name.
            raise ValueError('unsupported loss_fn: %r' % (args.loss_fn,))

        # backward + optimize
        optimizer.zero_grad()
        loss_mask.backward()
        optimizer.step()

        train_loss += loss_mask.item()
        count += 1

    train_loss = float(train_loss) / count
    return train_loss
|
107
|
+
|
108
|
+
def train(args, model, train_loader, test_loader, boardio, textio, checkpoint):
    """Train ``model`` from ``args.start_epoch`` up to ``args.epochs``.

    After every epoch the model is evaluated, and both a resumable snapshot
    (``model_snap.t7``) and the raw weights (``model.t7``) are written under
    ``checkpoints/<exp_name>/models``. When the test loss improves on the
    best so far, ``best_model_snap.t7`` and ``best_model.t7`` are written
    as well.

    Args:
        args: Options providing ``optimizer``, ``start_epoch``, ``epochs``
            and ``exp_name``.
        model: Network to train.
        train_loader: Training batches.
        test_loader: Evaluation batches.
        boardio: Scalar logger exposing ``add_scalar(tag, value, step)``.
        textio: Text logger exposing ``cprint(str)``.
        checkpoint: ``None``, or a snapshot dict with ``'min_loss'`` and
            ``'optimizer'`` entries as written by this function.
    """
    learnable_params = filter(lambda p: p.requires_grad, model.parameters())
    if args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(learnable_params, lr=0.0001)
    else:
        optimizer = torch.optim.SGD(learnable_params, lr=0.1)

    best_test_loss = np.inf
    if checkpoint is not None:
        # Carry the best loss over from the snapshot. It used to be read and
        # then discarded, so a resumed run could overwrite the best model
        # with a worse one.
        best_test_loss = checkpoint['min_loss']
        optimizer.load_state_dict(checkpoint['optimizer'])

    model_dir = 'checkpoints/%s/models' % (args.exp_name)

    for epoch in range(args.start_epoch, args.epochs):
        train_loss = train_one_epoch(args, model, train_loader, optimizer)
        test_loss = test_one_epoch(args, model, test_loader)

        improved = test_loss < best_test_loss
        if improved:
            best_test_loss = test_loss

        # 'epoch' stores the next epoch to run, so a resume picks up there.
        snap = {'epoch': epoch + 1,
                'model': model.state_dict(),
                'min_loss': best_test_loss,
                'optimizer': optimizer.state_dict()}

        if improved:
            torch.save(snap, '%s/best_model_snap.t7' % model_dir)
            torch.save(model.state_dict(), '%s/best_model.t7' % model_dir)

        torch.save(snap, '%s/model_snap.t7' % model_dir)
        torch.save(model.state_dict(), '%s/model.t7' % model_dir)

        boardio.add_scalar('Train_Loss', train_loss, epoch + 1)
        boardio.add_scalar('Test_Loss', test_loss, epoch + 1)
        boardio.add_scalar('Best_Test_Loss', best_test_loss, epoch + 1)

        textio.cprint('EPOCH:: %d, Traininig Loss: %f, Testing Loss: %f, Best Loss: %f'%(epoch+1, train_loss, test_loss, best_test_loss))
|
147
|
+
|
148
|
+
def _str2bool(value):
    """Parse a command-line boolean such as ``true``/``false``/``1``/``0``.

    ``type=bool`` would turn every non-empty string, including ``"False"``,
    into ``True``, which made it impossible to switch off a flag whose
    default is ``True``.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % value)


def options():
    """Build the MaskNet training argument parser and parse ``sys.argv``.

    Boolean options (``--eval``, ``--partial_source``, ``--noise``,
    ``--outliers``, ``--unseen``) take an explicit value such as ``True`` or
    ``False``.

    Returns:
        argparse.Namespace: The parsed options.
    """
    parser = argparse.ArgumentParser(description='MaskNet: A Fully-Convolutional Network For Inlier Estimation (Training)')
    parser.add_argument('--exp_name', type=str, default='exp_masknet', metavar='N',
                        help='Name of the experiment')
    parser.add_argument('--eval', type=_str2bool, default=False, help='Train or Evaluate the network.')

    # settings for input data
    parser.add_argument('--num_points', default=1024, type=int,
                        metavar='N', help='points in point-cloud (default: 1024)')
    parser.add_argument('--partial_source', default=True, type=_str2bool,
                        help='create partial source point cloud in dataset.')
    parser.add_argument('--noise', default=False, type=_str2bool,
                        help='Add noise in source point clouds.')
    parser.add_argument('--outliers', default=False, type=_str2bool,
                        help='Add outliers to template point cloud.')

    # settings for on training
    parser.add_argument('--seed', type=int, default=1234)
    parser.add_argument('-j', '--workers', default=4, type=int,
                        metavar='N', help='number of data loading workers (default: 4)')
    parser.add_argument('-b', '--batch_size', default=32, type=int,
                        metavar='N', help='mini-batch size (default: 32)')
    parser.add_argument('--test_batch_size', default=8, type=int,
                        metavar='N', help='test-mini-batch size (default: 8)')
    parser.add_argument('--unseen', default=False, type=_str2bool,
                        help='Use first 20 categories for training and last 20 for testing')
    parser.add_argument('--epochs', default=500, type=int,
                        metavar='N', help='number of total epochs to run')
    parser.add_argument('--start_epoch', default=0, type=int,
                        metavar='N', help='manual epoch number (useful on restarts)')
    parser.add_argument('--optimizer', default='Adam', choices=['Adam', 'SGD'],
                        metavar='METHOD', help='name of an optimizer (default: Adam)')
    parser.add_argument('--resume', default='', type=str,
                        metavar='PATH', help='path to latest checkpoint (default: null (no-use))')
    parser.add_argument('--pretrained', default='', type=str,
                        metavar='PATH', help='path to pretrained model file (default: null (no-use))')
    parser.add_argument('--device', default='cuda:0', type=str,
                        metavar='DEVICE', help='use CUDA if available')
    parser.add_argument('--loss_fn', default='mse', type=str, choices=['mse', 'bce'])

    args = parser.parse_args()
    return args
|
190
|
+
|
191
|
+
def main():
    """Set up data, model and logging for MaskNet, then train or evaluate.

    Seeds the RNGs from ``args.seed``, builds ModelNet40-based registration
    loaders, and falls back to CPU when CUDA is unavailable. Optionally
    restores a resumable snapshot (``--resume``) or plain weights
    (``--pretrained``), then runs ``test`` when ``args.eval`` is set and
    ``train`` otherwise. The text log and the summary writer are closed on
    every exit path.

    Raises:
        FileNotFoundError: If ``--resume`` or ``--pretrained`` names a path
            that is not an existing file.
    """
    args = options()

    torch.backends.cudnn.deterministic = True
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)

    boardio = SummaryWriter(log_dir='checkpoints/' + args.exp_name)
    _init_(args)

    textio = IOStream('checkpoints/' + args.exp_name + '/run.log')
    try:
        textio.cprint(str(args))

        trainset = RegistrationData(ModelNet40Data(train=True, num_points=args.num_points, unseen=args.unseen),
                                    partial_source=args.partial_source, noise=args.noise, outliers=args.outliers,
                                    additional_params={'use_masknet': True})
        testset = RegistrationData(ModelNet40Data(train=False, num_points=args.num_points, unseen=args.unseen),
                                   partial_source=args.partial_source, noise=args.noise, outliers=args.outliers,
                                   additional_params={'use_masknet': True})
        train_loader = DataLoader(trainset, batch_size=args.batch_size, shuffle=True, drop_last=True, num_workers=args.workers)
        test_loader = DataLoader(testset, batch_size=args.test_batch_size, shuffle=False, drop_last=False, num_workers=args.workers)

        if not torch.cuda.is_available():
            args.device = 'cpu'
        args.device = torch.device(args.device)

        model = MaskNet()
        model = model.to(args.device)

        checkpoint = None
        if args.resume:
            # Explicit check instead of ``assert``, which is stripped under -O.
            if not os.path.isfile(args.resume):
                raise FileNotFoundError('resume checkpoint not found: %s' % args.resume)
            # map_location lets a snapshot saved on a GPU be restored on the
            # device actually selected above (e.g. the CPU fallback).
            checkpoint = torch.load(args.resume, map_location=args.device)
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['model'])

        if args.pretrained:
            if not os.path.isfile(args.pretrained):
                raise FileNotFoundError('pretrained model not found: %s' % args.pretrained)
            model.load_state_dict(torch.load(args.pretrained, map_location='cpu'))
        model.to(args.device)

        if args.eval:
            test(args, model, test_loader, textio)
        else:
            train(args, model, train_loader, test_loader, boardio, textio, checkpoint)
    finally:
        # Release the log file and flush the event writer on every exit path.
        textio.close()
        boardio.close()
|
237
|
+
|
238
|
+
# Script entry point: run main() only when executed directly, not on import.
if __name__ == '__main__':
    main()
|