learning3d 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- learning3d/__init__.py +2 -0
- learning3d/data_utils/__init__.py +4 -0
- learning3d/data_utils/dataloaders.py +454 -0
- learning3d/data_utils/user_data.py +119 -0
- learning3d/examples/test_dcp.py +139 -0
- learning3d/examples/test_deepgmr.py +144 -0
- learning3d/examples/test_flownet.py +113 -0
- learning3d/examples/test_masknet.py +159 -0
- learning3d/examples/test_masknet2.py +162 -0
- learning3d/examples/test_pcn.py +118 -0
- learning3d/examples/test_pcrnet.py +120 -0
- learning3d/examples/test_pnlk.py +121 -0
- learning3d/examples/test_pointconv.py +126 -0
- learning3d/examples/test_pointnet.py +121 -0
- learning3d/examples/test_prnet.py +126 -0
- learning3d/examples/test_rpmnet.py +120 -0
- learning3d/examples/train_PointNetLK.py +240 -0
- learning3d/examples/train_dcp.py +249 -0
- learning3d/examples/train_deepgmr.py +244 -0
- learning3d/examples/train_flownet.py +259 -0
- learning3d/examples/train_masknet.py +239 -0
- learning3d/examples/train_pcn.py +216 -0
- learning3d/examples/train_pcrnet.py +228 -0
- learning3d/examples/train_pointconv.py +245 -0
- learning3d/examples/train_pointnet.py +244 -0
- learning3d/examples/train_prnet.py +229 -0
- learning3d/examples/train_rpmnet.py +228 -0
- learning3d/losses/__init__.py +12 -0
- learning3d/losses/chamfer_distance.py +51 -0
- learning3d/losses/classification.py +14 -0
- learning3d/losses/correspondence_loss.py +10 -0
- learning3d/losses/cuda/chamfer_distance/__init__.py +1 -0
- learning3d/losses/cuda/chamfer_distance/chamfer_distance.cpp +185 -0
- learning3d/losses/cuda/chamfer_distance/chamfer_distance.cu +209 -0
- learning3d/losses/cuda/chamfer_distance/chamfer_distance.py +66 -0
- learning3d/losses/cuda/emd_torch/pkg/emd_loss_layer.py +41 -0
- learning3d/losses/cuda/emd_torch/pkg/include/cuda/emd.cuh +347 -0
- learning3d/losses/cuda/emd_torch/pkg/include/cuda_helper.h +18 -0
- learning3d/losses/cuda/emd_torch/pkg/include/emd.h +54 -0
- learning3d/losses/cuda/emd_torch/pkg/layer/__init__.py +1 -0
- learning3d/losses/cuda/emd_torch/pkg/layer/emd_loss_layer.py +40 -0
- learning3d/losses/cuda/emd_torch/pkg/src/cuda/emd.cu +70 -0
- learning3d/losses/cuda/emd_torch/pkg/src/emd.cpp +1 -0
- learning3d/losses/cuda/emd_torch/setup.py +29 -0
- learning3d/losses/emd.py +16 -0
- learning3d/losses/frobenius_norm.py +21 -0
- learning3d/losses/rmse_features.py +16 -0
- learning3d/models/__init__.py +23 -0
- learning3d/models/classifier.py +41 -0
- learning3d/models/dcp.py +92 -0
- learning3d/models/deepgmr.py +165 -0
- learning3d/models/dgcnn.py +92 -0
- learning3d/models/flownet3d.py +446 -0
- learning3d/models/masknet.py +84 -0
- learning3d/models/masknet2.py +264 -0
- learning3d/models/pcn.py +164 -0
- learning3d/models/pcrnet.py +74 -0
- learning3d/models/pointconv.py +108 -0
- learning3d/models/pointnet.py +108 -0
- learning3d/models/pointnetlk.py +173 -0
- learning3d/models/pooling.py +15 -0
- learning3d/models/ppfnet.py +102 -0
- learning3d/models/prnet.py +431 -0
- learning3d/models/rpmnet.py +359 -0
- learning3d/models/segmentation.py +38 -0
- learning3d/ops/__init__.py +0 -0
- learning3d/ops/data_utils.py +45 -0
- learning3d/ops/invmat.py +134 -0
- learning3d/ops/quaternion.py +218 -0
- learning3d/ops/se3.py +157 -0
- learning3d/ops/sinc.py +229 -0
- learning3d/ops/so3.py +213 -0
- learning3d/ops/transform_functions.py +342 -0
- learning3d/utils/__init__.py +9 -0
- learning3d/utils/lib/build/lib.linux-x86_64-3.5/pointnet2_cuda.cpython-35m-x86_64-linux-gnu.so +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/ball_query.o +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/ball_query_gpu.o +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/group_points.o +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/group_points_gpu.o +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/interpolate.o +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/interpolate_gpu.o +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/pointnet2_api.o +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/sampling.o +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/sampling_gpu.o +0 -0
- learning3d/utils/lib/dist/pointnet2-0.0.0-py3.5-linux-x86_64.egg +0 -0
- learning3d/utils/lib/pointnet2.egg-info/SOURCES.txt +14 -0
- learning3d/utils/lib/pointnet2.egg-info/dependency_links.txt +1 -0
- learning3d/utils/lib/pointnet2.egg-info/top_level.txt +1 -0
- learning3d/utils/lib/pointnet2_modules.py +160 -0
- learning3d/utils/lib/pointnet2_utils.py +318 -0
- learning3d/utils/lib/pytorch_utils.py +236 -0
- learning3d/utils/lib/setup.py +23 -0
- learning3d/utils/lib/src/ball_query.cpp +25 -0
- learning3d/utils/lib/src/ball_query_gpu.cu +67 -0
- learning3d/utils/lib/src/ball_query_gpu.h +15 -0
- learning3d/utils/lib/src/cuda_utils.h +15 -0
- learning3d/utils/lib/src/group_points.cpp +36 -0
- learning3d/utils/lib/src/group_points_gpu.cu +86 -0
- learning3d/utils/lib/src/group_points_gpu.h +22 -0
- learning3d/utils/lib/src/interpolate.cpp +65 -0
- learning3d/utils/lib/src/interpolate_gpu.cu +233 -0
- learning3d/utils/lib/src/interpolate_gpu.h +36 -0
- learning3d/utils/lib/src/pointnet2_api.cpp +25 -0
- learning3d/utils/lib/src/sampling.cpp +46 -0
- learning3d/utils/lib/src/sampling_gpu.cu +253 -0
- learning3d/utils/lib/src/sampling_gpu.h +29 -0
- learning3d/utils/pointconv_util.py +382 -0
- learning3d/utils/ppfnet_util.py +244 -0
- learning3d/utils/svd.py +59 -0
- learning3d/utils/transformer.py +243 -0
- learning3d-0.0.1.dist-info/LICENSE +21 -0
- learning3d-0.0.1.dist-info/METADATA +271 -0
- learning3d-0.0.1.dist-info/RECORD +115 -0
- learning3d-0.0.1.dist-info/WHEEL +5 -0
- learning3d-0.0.1.dist-info/top_level.txt +1 -0
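For orientation, here is a minimal usage sketch based on the layout above. It is a hypothetical example, assuming that learning3d/models/__init__.py (included in this wheel) re-exports PointNet and iPCRNet; it mirrors the __main__ test blocks in the model files shown below.

    import torch
    from learning3d.models import PointNet, iPCRNet  # assumed re-exports from models/__init__.py

    # Two batches of 1024-point clouds in 'bnc' layout (batch x num_points x channels).
    template = torch.rand(10, 1024, 3)
    source = torch.rand(10, 1024, 3)

    # PointNet extracts features; iPCRNet iteratively regresses the rigid transform
    # aligning source to template (see pcrnet.py below).
    net = iPCRNet(feature_model=PointNet())
    result = net(template, source)
    print(result['est_T'].shape, result['transformed_source'].shape)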
learning3d/models/masknet2.py
ADDED
@@ -0,0 +1,264 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .pooling import Pooling


# Mish activation function
class Mish(nn.Module):
    def __init__(self):
        super(Mish, self).__init__()

    def forward(self, x):
        return x * torch.tanh(F.softplus(x))


# Basic 1D convolution block: Conv1d -> BatchNorm1d -> (optional) Mish
class BasicConv1D(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, active=True):
        super(BasicConv1D, self).__init__()
        self.active = active
        self.bn = nn.BatchNorm1d(out_channels)
        if self.active:
            self.activation = Mish()
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, bias=False)
        # self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        if self.active:
            x = self.activation(x)
        return x


class Self_Attn(nn.Module):
    """ Self attention Layer """
    def __init__(self, in_dim, out_dim):
        super(Self_Attn, self).__init__()

        self.in_dim = in_dim
        self.out_dim = out_dim

        # Query Convolution
        self.query_conv = BasicConv1D(in_dim, out_dim)

        self.beta = nn.Parameter(torch.zeros(1))

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """
        inputs :
            x : input feature maps (B x C x N)
        returns :
            out : self attention value + input feature (B x out_dim x N)
        """
        proj_query = self.query_conv(x).permute(0, 2, 1)    # (B, in_dim, N) -> (B, out_dim, N) -> (B, N, out_dim)
        proj_key = proj_query.permute(0, 2, 1)              # (B, out_dim, N)

        energy = torch.bmm(proj_query, proj_key)            # (B, N, N)

        attention = self.softmax(energy)                    # (B, N, N)

        out_x = torch.bmm(proj_key, attention.permute(0, 2, 1))    # (B, out_dim, N)

        out = self.beta * out_x + proj_key

        return out


class PointNet(torch.nn.Module):
    def __init__(self, emb_dims=224, input_shape="bnc", use_bn=False, global_feat=True):
        # emb_dims: Embedding Dimensions for PointNet.
        # input_shape: Shape of Input Point Cloud (b: batch, n: no of points, c: channels)
        super(PointNet, self).__init__()
        if input_shape not in ["bcn", "bnc"]:
            raise ValueError("Allowed shapes are 'bcn' (batch * channels * num_in_points), 'bnc' ")
        self.input_shape = input_shape
        self.emb_dims = emb_dims
        self.use_bn = use_bn
        self.global_feat = global_feat
        if not self.global_feat: self.pooling = Pooling('max')

        self.conv1 = Self_Attn(3, 32)
        self.conv2 = Self_Attn(32, 64)
        self.conv3 = Self_Attn(64, 64)
        self.conv4 = Self_Attn(64, 128)
        self.conv5 = Self_Attn(128, self.emb_dims)

    def forward(self, input_data):
        # input_data: Point Cloud having shape input_shape.
        # output: PointNet features (Batch x emb_dims)
        if self.input_shape == "bnc":
            num_points = input_data.shape[1]
            input_data = input_data.permute(0, 2, 1)
        else:
            num_points = input_data.shape[2]
        if input_data.shape[1] != 3:
            raise RuntimeError("shape of x must be of [Batch x 3 x NumInPoints]")

        output = input_data

        x1 = self.conv1(output)     # 32 channels
        x2 = self.conv2(x1)         # 64 channels
        x3 = self.conv3(x2)         # 64 channels
        x4 = self.conv4(x3 + x2)    # 128 channels
        x5 = self.conv5(x4)         # emb_dims channels

        output = torch.cat([x1, x2, x3, x4, x5], dim=1)     # concatenated multi-scale per-point features
        point_feature = output

        if self.global_feat:
            return output
        else:
            output = self.pooling(output)
            output = output.view(-1, self.emb_dims, 1).repeat(1, 1, num_points)
            return torch.cat([output, point_feature], 1)


# Self-attention between two pooled global feature vectors
class self_attention_fc(nn.Module):
    """ Self attention Layer """
    def __init__(self, in_dim, out_dim):
        super(self_attention_fc, self).__init__()

        self.in_dim = in_dim
        self.out_dim = out_dim

        self.query_conv = BasicConv1D(in_dim, out_dim)

        self.beta = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, y):
        """
        inputs :
            x, y : global feature maps (B x C x 1)
        returns :
            out_x, out_y : cross-attended features + input features
        """
        proj_query_x = self.query_conv(x)                    # (B, in_dim, 1) -> (B, out_dim, 1)

        proj_key_y = self.query_conv(y).permute(0, 2, 1)     # (B, 1, out_dim)

        energy_xy = torch.bmm(proj_query_x, proj_key_y)      # attention scores between x and y

        attention_xy = self.softmax(energy_xy)
        attention_yx = self.softmax(energy_xy.permute(0, 2, 1))

        proj_value_x = proj_query_x
        proj_value_y = proj_key_y.permute(0, 2, 1)

        out_x = torch.bmm(attention_xy, proj_value_x)
        out_x = self.beta * out_x + proj_value_x

        out_y = torch.bmm(attention_yx, proj_value_y)
        out_y = self.beta * out_y + proj_value_y

        return out_x, out_y


class PointNetMask(nn.Module):
    def __init__(self, template_feature_size=1024, source_feature_size=1024, feature_model=PointNet()):
        super().__init__()
        self.feature_model = feature_model
        self.pooling_max = Pooling(pool_type='max')
        self.pooling_avg = Pooling(pool_type='avg')

        input_size = template_feature_size + source_feature_size

        self.global_feat_1 = self_attention_fc(1024, 512)
        self.global_feat_2 = self_attention_fc(512, 256)
        self.global_feat_3 = self_attention_fc(256, 512)

        self.h3 = nn.Sequential(BasicConv1D(1024, 512),
                                BasicConv1D(512, 256),
                                BasicConv1D(256, 128),
                                nn.Conv1d(128, 1, 1), nn.Sigmoid())

    def find_mask(self, source_features, template_features):
        global_source_features_max = self.pooling_max(source_features)
        global_template_features_max = self.pooling_max(template_features)
        global_source_features_avg = self.pooling_avg(source_features)
        global_template_features_avg = self.pooling_avg(template_features)
        global_source_features = torch.cat([global_source_features_max, global_source_features_avg], dim=1)
        global_template_features = torch.cat([global_template_features_max, global_template_features_avg], dim=1)

        shared_feat_1, shared_feat_2 = self.global_feat_1(global_source_features.unsqueeze(2), global_template_features.unsqueeze(2))
        shared_feat_1, shared_feat_2 = self.global_feat_2(shared_feat_1, shared_feat_2)
        shared_feat_1, shared_feat_2 = self.global_feat_3(shared_feat_1, shared_feat_2)

        batch_size, _, num_points = source_features.size()
        global_source_features = shared_feat_1
        global_source_features = global_source_features.repeat(1, 1, num_points)
        x = torch.cat([template_features, global_source_features], dim=1)
        x = self.h3(x)

        batch_size, _, num_points = template_features.size()
        global_template_features = shared_feat_2
        global_template_features = global_template_features.repeat(1, 1, num_points)
        y = torch.cat([source_features, global_template_features], dim=1)
        y = self.h3(y)

        return x.view(batch_size, -1), y.view(batch_size, -1)

    def forward(self, template, source):
        source_features = self.feature_model(source)        # [B x C x N]
        template_features = self.feature_model(template)    # [B x C x N]

        template_mask, source_mask = self.find_mask(source_features, template_features)
        return template_mask, source_mask


class MaskNet2(nn.Module):
    def __init__(self, feature_model=PointNet(use_bn=True), is_training=True):
        super().__init__()
        self.maskNet = PointNetMask(feature_model=feature_model)
        self.is_training = is_training

    @staticmethod
    def index_points(points, idx):
        """
        Input:
            points: input points data, [B, N, C]
            idx: sample index data, [B, S]
        Return:
            new_points: indexed points data, [B, S, C]
        """
        device = points.device
        B = points.shape[0]
        view_shape = list(idx.shape)
        view_shape[1:] = [1] * (len(view_shape) - 1)
        repeat_shape = list(idx.shape)
        repeat_shape[0] = 1
        batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
        new_points = points[batch_indices, idx, :]

        return new_points

    def forward(self, template, source, point_selection='threshold', mask_threshold=0.5):
        template_mask, source_mask = self.maskNet(template, source)    # [B x N]
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        source_binary_mask = torch.where(source_mask > mask_threshold, torch.ones(source_mask.size()).to(device), torch.zeros(source_mask.size()).to(device))
        template_binary_mask = torch.where(template_mask > mask_threshold, torch.ones(template_mask.size()).to(device), torch.zeros(template_mask.size()).to(device))

        masked_template = template[:, template_binary_mask.bool().squeeze(0), 0:3]
        masked_source = source[:, source_binary_mask.bool().squeeze(0), 0:3]

        return masked_template, masked_source, template_mask, source_mask


if __name__ == '__main__':
    template, source = torch.rand(10, 1024, 3), torch.rand(10, 1024, 3)
    net = MaskNet2()
    result = net(template, source)
    import ipdb; ipdb.set_trace()
learning3d/models/pcn.py
ADDED
@@ -0,0 +1,164 @@
# author: Vinit Sarode (vinitsarode5@gmail.com) 03/23/2020

import torch
import torch.nn as nn
import torch.nn.functional as F
from .pooling import Pooling

class PCN(torch.nn.Module):
    def __init__(self, emb_dims=1024, input_shape="bnc", num_coarse=1024, grid_size=4, detailed_output=False):
        # emb_dims: Embedding Dimensions for PCN.
        # input_shape: Shape of Input Point Cloud (b: batch, n: no of points, c: channels)
        super(PCN, self).__init__()
        if input_shape not in ["bcn", "bnc"]:
            raise ValueError("Allowed shapes are 'bcn' (batch * channels * num_in_points), 'bnc' ")
        self.input_shape = input_shape
        self.emb_dims = emb_dims
        self.num_coarse = num_coarse
        self.detailed_output = detailed_output
        self.grid_size = grid_size
        self.num_fine = self.grid_size ** 2 * self.num_coarse
        self.pooling = Pooling('max')

        self.encoder()
        self.decoder_layers = self.decoder()
        if detailed_output: self.folding_layers = self.folding()

    def encoder_1(self):
        self.conv1 = torch.nn.Conv1d(3, 128, 1)
        self.conv2 = torch.nn.Conv1d(128, 256, 1)
        self.relu = torch.nn.ReLU()

        # self.bn1 = torch.nn.BatchNorm1d(128)
        # self.bn2 = torch.nn.BatchNorm1d(256)

        layers = [self.conv1, self.relu,
                  self.conv2]
        return layers

    def encoder_2(self):
        self.conv3 = torch.nn.Conv1d(2*256, 512, 1)
        self.conv4 = torch.nn.Conv1d(512, self.emb_dims, 1)

        # self.bn3 = torch.nn.BatchNorm1d(512)
        # self.bn4 = torch.nn.BatchNorm1d(self.emb_dims)
        self.relu = torch.nn.ReLU()

        layers = [self.conv3, self.relu,
                  self.conv4]
        return layers

    def encoder(self):
        self.encoder_layers1 = self.encoder_1()
        self.encoder_layers2 = self.encoder_2()

    def decoder(self):
        self.linear1 = torch.nn.Linear(self.emb_dims, 1024)
        self.linear2 = torch.nn.Linear(1024, 1024)
        self.linear3 = torch.nn.Linear(1024, self.num_coarse*3)

        # self.bn1 = torch.nn.BatchNorm1d(1024)
        # self.bn2 = torch.nn.BatchNorm1d(1024)
        # self.bn3 = torch.nn.BatchNorm1d(self.num_coarse*3)
        self.relu = torch.nn.ReLU()

        layers = [self.linear1, self.relu,
                  self.linear2, self.relu,
                  self.linear3]
        return layers

    def folding(self):
        self.conv5 = torch.nn.Conv1d(1029, 512, 1)
        self.conv6 = torch.nn.Conv1d(512, 512, 1)
        self.conv7 = torch.nn.Conv1d(512, 3, 1)

        # self.bn5 = torch.nn.BatchNorm1d(512)
        # self.bn6 = torch.nn.BatchNorm1d(512)
        self.relu = torch.nn.ReLU()

        layers = [self.conv5, self.relu,
                  self.conv6, self.relu,
                  self.conv7]
        return layers

    def fine_decoder(self):
        # Fine output: fold a small 2D grid around each coarse point.
        linspace = torch.linspace(-0.05, 0.05, steps=self.grid_size)
        grid = torch.meshgrid(linspace, linspace)
        grid = torch.reshape(torch.stack(grid, dim=2), (-1, 2))                          # 16x2
        grid = torch.unsqueeze(grid, dim=0)                                              # 1x16x2
        grid_feature = grid.repeat([self.coarse_output.shape[0], self.num_coarse, 1])    # Bx16384x2

        point_feature = torch.unsqueeze(self.coarse_output, dim=2)                       # Bx1024x1x3
        point_feature = point_feature.repeat([1, 1, self.grid_size ** 2, 1])             # Bx1024x16x3
        point_feature = torch.reshape(point_feature, (-1, self.num_fine, 3))             # Bx16384x3

        global_feature = torch.unsqueeze(self.global_feature_v, dim=1)                   # Bx1x1024
        global_feature = global_feature.repeat([1, self.num_fine, 1])                    # Bx16384x1024

        feature = torch.cat([grid_feature, point_feature, global_feature], dim=2)        # Bx16384x1029

        center = torch.unsqueeze(self.coarse_output, dim=2)                              # Bx1024x1x3
        center = center.repeat([1, 1, self.grid_size ** 2, 1])                           # Bx1024x16x3
        center = torch.reshape(center, [-1, self.num_fine, 3])                           # Bx16384x3

        output = feature.permute(0, 2, 1)
        for idx, layer in enumerate(self.folding_layers):
            output = layer(output)
        fine_output = output.permute(0, 2, 1) + center
        return fine_output

    def encode(self, input_data):
        output = input_data
        for idx, layer in enumerate(self.encoder_layers1):
            output = layer(output)

        global_feature_g = self.pooling(output)

        global_feature_g = global_feature_g.unsqueeze(2)
        global_feature_g = global_feature_g.repeat(1, 1, self.num_points)
        output = torch.cat([output, global_feature_g], dim=1)

        for idx, layer in enumerate(self.encoder_layers2):
            output = layer(output)

        self.global_feature_v = self.pooling(output)

    def decode(self):
        output = self.global_feature_v
        for idx, layer in enumerate(self.decoder_layers):
            output = layer(output)
        self.coarse_output = output.view(self.global_feature_v.shape[0], self.num_coarse, 3)

    def forward(self, input_data):
        # input_data: Point Cloud having shape input_shape.
        # output: dictionary with coarse (and optionally fine) completions.
        if self.input_shape == "bnc":
            self.num_points = input_data.shape[1]
            input_data = input_data.permute(0, 2, 1)
        else:
            self.num_points = input_data.shape[2]
        if input_data.shape[1] != 3:
            raise RuntimeError("shape of x must be of [Batch x 3 x NumInPoints]")

        self.encode(input_data)
        self.decode()

        result = {'coarse_output': self.coarse_output}

        if self.detailed_output:
            fine_output = self.fine_decoder()
            result['fine_output'] = fine_output

        return result


if __name__ == '__main__':
    # Test the code.
    x = torch.rand((10, 1024, 3))

    pcn = PCN()
    y = pcn(x)
    print("Network Architecture: ")
    print(pcn)
    print("Input Shape of PCN: ", x.shape, "\nOutput Shape of PCN: ", y['coarse_output'].shape)
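The test block above only exercises the coarse decoder. A hypothetical sketch of the detailed path (not part of the wheel, assuming PCN is re-exported by models/__init__.py) shows the folding layers refining the 1024 coarse points into grid_size² × num_coarse = 16384 fine points:

    import torch
    from learning3d.models import PCN   # assumed re-export from models/__init__.py

    x = torch.rand((10, 1024, 3))
    pcn = PCN(detailed_output=True)     # enables the folding-based fine decoder
    y = pcn(x)
    print(y['coarse_output'].shape)     # torch.Size([10, 1024, 3])
    print(y['fine_output'].shape)       # torch.Size([10, 16384, 3])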
learning3d/models/pcrnet.py
ADDED
@@ -0,0 +1,74 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .pointnet import PointNet
from .pooling import Pooling
from ..ops.transform_functions import PCRNetTransform as transform


class iPCRNet(nn.Module):
    def __init__(self, feature_model=PointNet(), droput=0.0, pooling='max'):
        super().__init__()
        self.feature_model = feature_model
        self.pooling = Pooling(pooling)

        self.linear = [nn.Linear(self.feature_model.emb_dims * 2, 1024), nn.ReLU(),
                       nn.Linear(1024, 1024), nn.ReLU(),
                       nn.Linear(1024, 512), nn.ReLU(),
                       nn.Linear(512, 512), nn.ReLU(),
                       nn.Linear(512, 256), nn.ReLU()]

        if droput > 0.0:
            self.linear.append(nn.Dropout(droput))
        self.linear.append(nn.Linear(256, 7))

        self.linear = nn.Sequential(*self.linear)

    # Single Pass Alignment Module (SPAM)
    def spam(self, template_features, source, est_R, est_t):
        batch_size = source.size(0)

        self.source_features = self.pooling(self.feature_model(source))
        y = torch.cat([template_features, self.source_features], dim=1)
        pose_7d = self.linear(y)
        pose_7d = transform.create_pose_7d(pose_7d)

        # Find current rotation and translation.
        identity = torch.eye(3).to(source).view(1, 3, 3).expand(batch_size, 3, 3).contiguous()
        est_R_temp = transform.quaternion_rotate(identity, pose_7d).permute(0, 2, 1)
        est_t_temp = transform.get_translation(pose_7d).view(-1, 1, 3)

        # Update translation matrix.
        est_t = torch.bmm(est_R_temp, est_t.permute(0, 2, 1)).permute(0, 2, 1) + est_t_temp
        # Update rotation matrix.
        est_R = torch.bmm(est_R_temp, est_R)

        source = transform.quaternion_transform(source, pose_7d)    # Ps' = est_R*Ps + est_t
        return est_R, est_t, source

    def forward(self, template, source, max_iteration=8):
        est_R = torch.eye(3).to(template).view(1, 3, 3).expand(template.size(0), 3, 3).contiguous()        # (Bx3x3)
        est_t = torch.zeros(1, 3).to(template).view(1, 1, 3).expand(template.size(0), 1, 3).contiguous()   # (Bx1x3)
        template_features = self.pooling(self.feature_model(template))

        if max_iteration == 1:
            est_R, est_t, source = self.spam(template_features, source, est_R, est_t)
        else:
            for i in range(max_iteration):
                est_R, est_t, source = self.spam(template_features, source, est_R, est_t)

        result = {'est_R': est_R,    # source -> template
                  'est_t': est_t,    # source -> template
                  'est_T': transform.convert2transformation(est_R, est_t),    # source -> template
                  'r': template_features - self.source_features,
                  'transformed_source': source}
        return result


if __name__ == '__main__':
    template, source = torch.rand(10, 1024, 3), torch.rand(10, 1024, 3)
    pn = PointNet()

    net = iPCRNet(pn)
    result = net(template, source)
    import ipdb; ipdb.set_trace()
learning3d/models/pointconv.py
ADDED
@@ -0,0 +1,108 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..utils import PointConvDensitySetAbstraction

class PointConvDensityClsSsg(torch.nn.Module):
    def __init__(self, emb_dims=1024, input_shape="bnc", input_channel_dim=3, classifier=False, num_classes=40, pretrained=None):
        super(PointConvDensityClsSsg, self).__init__()
        if input_shape not in ["bnc", "bcn"]:
            raise ValueError("Allowed shapes are 'bcn' (batch * channels * num_in_points), 'bnc' ")
        self.input_shape = input_shape
        self.emb_dims = emb_dims
        self.classifier = classifier
        self.input_channel_dim = input_channel_dim
        self.create_structure()
        if self.classifier: self.create_classifier(num_classes)

    def create_structure(self):
        # Arguments to define the PointConv network using the PointConvDensitySetAbstraction class.
        # npoint: number of points sampled from input.
        # nsample: number of neighbours chosen for each point in sampled point cloud.
        # in_channel: number of channels in input.
        # mlp: sizes of multi-layer perceptrons.
        # bandwidth: used to compute gaussian density.
        # group_all: group all points from input to a single point if set to True.
        self.sa1 = PointConvDensitySetAbstraction(npoint=512, nsample=32, in_channel=self.input_channel_dim,
                                                  mlp=[64, 64, 128], bandwidth=0.1, group_all=False)
        self.sa2 = PointConvDensitySetAbstraction(npoint=128, nsample=64, in_channel=128 + 3,
                                                  mlp=[128, 128, 256], bandwidth=0.2, group_all=False)
        self.sa3 = PointConvDensitySetAbstraction(npoint=1, nsample=None, in_channel=256 + 3,
                                                  mlp=[256, 512, self.emb_dims], bandwidth=0.4, group_all=True)

    def create_classifier(self, num_classes):
        # These are simple fully-connected layers with batch-norm and dropouts.
        # This architecture is given by the PointConv paper, hence it is used here as the default version.
        # It can be modified by overriding this function or by using the classifier.py class.
        self.fc1 = nn.Linear(self.emb_dims, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.drop1 = nn.Dropout(0.7)
        self.fc2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.drop2 = nn.Dropout(0.7)
        self.fc3 = nn.Linear(256, num_classes)

    def forward(self, input_data):
        if self.input_shape == "bnc":
            input_data = input_data.permute(0, 2, 1)
        batch_size = input_data.shape[0]

        # Convert point clouds to latent features using the PointConv network.
        l1_points, l1_features = self.sa1(input_data[:, :3, :], input_data[:, 3:, :])
        l2_points, l2_features = self.sa2(l1_points, l1_features)
        l3_points, l3_features = self.sa3(l2_points, l2_features)
        features = l3_features.view(batch_size, self.emb_dims)

        if self.classifier:
            # Use these features to classify the input point cloud.
            features = self.drop1(F.relu(self.bn1(self.fc1(features))))
            features = self.drop2(F.relu(self.bn2(self.fc2(features))))
            features = self.fc3(features)
            output = F.log_softmax(features, -1)
        else:
            # Return the PointConv features for use by other higher-level tasks.
            output = features

        return output

def create_pointconv(classifier=False, pretrained=None):
    if classifier and pretrained is not None:
        class Network(torch.nn.Module):
            def __init__(self, emb_dims=1024, input_shape="bnc", input_channel_dim=3, classifier=False, num_classes=40, pretrained=None):
                # Arguments:
                # emb_dims: Size of embeddings.
                # input_shape: Shape of input point cloud.
                # input_channel_dim: Number of channels in point cloud. [eg. Nx3 (only points) or Nx6 (points + normals)]
                # classifier: Whether to use the default classifier layers or just the embedding layers.
                # num_classes: If you use the classifier, the number of classes in your dataset.
                # pretrained: Path to a pretrained classification network checkpoint.
                super(Network, self).__init__()
                self.pointconv = PointConvDensityClsSsg(emb_dims, input_shape, input_channel_dim, classifier, num_classes)
                if classifier and pretrained is not None:
                    self.use_pretrained(pretrained)

            def use_pretrained(self, pretrained):
                checkpoint = torch.load(pretrained, map_location='cpu')
                self.pointconv.load_state_dict(checkpoint['model_state_dict'])

            def forward(self, input_data):
                return self.pointconv(input_data)
        return Network
    else:
        class Network(PointConvDensityClsSsg):
            def __init__(self, emb_dims=1024, input_shape="bnc", input_channel_dim=3, classifier=False, num_classes=40, pretrained=None):
                super().__init__(emb_dims=emb_dims, input_shape=input_shape, input_channel_dim=input_channel_dim, classifier=classifier, num_classes=num_classes, pretrained=pretrained)
        return Network


if __name__ == '__main__':
    # Test the code.
    x = torch.rand((2, 1024, 3))

    PointConv = create_pointconv(classifier=False, pretrained='checkpoint.pth')
    pc = PointConv(input_channel_dim=3, classifier=False, pretrained='checkpoint.pth')
    y = pc(x)
    print("Network Architecture: ")
    print(pc)
    print("Input Shape of PointConv: ", x.shape, "\nOutput Shape of PointConv: ", y.shape)
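The test block above only extracts embeddings. A hedged sketch of the classification path (a hypothetical usage, assuming create_pointconv is re-exported by models/__init__.py and PointConvDensitySetAbstraction from learning3d/utils imports cleanly, with no pretrained checkpoint) would be:

    import torch
    from learning3d.models import create_pointconv  # assumed re-export from models/__init__.py

    PointConv = create_pointconv(classifier=True, pretrained=None)
    clf = PointConv(input_channel_dim=3, classifier=True, num_classes=40)

    x = torch.rand((8, 1024, 3))          # batch of point clouds, 'bnc' layout
    log_probs = clf(x)                    # log-softmax over the 40 classes
    print(log_probs.shape)                # torch.Size([8, 40])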