learning3d-0.0.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- learning3d/__init__.py +2 -0
- learning3d/data_utils/__init__.py +4 -0
- learning3d/data_utils/dataloaders.py +454 -0
- learning3d/data_utils/user_data.py +119 -0
- learning3d/examples/test_dcp.py +139 -0
- learning3d/examples/test_deepgmr.py +144 -0
- learning3d/examples/test_flownet.py +113 -0
- learning3d/examples/test_masknet.py +159 -0
- learning3d/examples/test_masknet2.py +162 -0
- learning3d/examples/test_pcn.py +118 -0
- learning3d/examples/test_pcrnet.py +120 -0
- learning3d/examples/test_pnlk.py +121 -0
- learning3d/examples/test_pointconv.py +126 -0
- learning3d/examples/test_pointnet.py +121 -0
- learning3d/examples/test_prnet.py +126 -0
- learning3d/examples/test_rpmnet.py +120 -0
- learning3d/examples/train_PointNetLK.py +240 -0
- learning3d/examples/train_dcp.py +249 -0
- learning3d/examples/train_deepgmr.py +244 -0
- learning3d/examples/train_flownet.py +259 -0
- learning3d/examples/train_masknet.py +239 -0
- learning3d/examples/train_pcn.py +216 -0
- learning3d/examples/train_pcrnet.py +228 -0
- learning3d/examples/train_pointconv.py +245 -0
- learning3d/examples/train_pointnet.py +244 -0
- learning3d/examples/train_prnet.py +229 -0
- learning3d/examples/train_rpmnet.py +228 -0
- learning3d/losses/__init__.py +12 -0
- learning3d/losses/chamfer_distance.py +51 -0
- learning3d/losses/classification.py +14 -0
- learning3d/losses/correspondence_loss.py +10 -0
- learning3d/losses/cuda/chamfer_distance/__init__.py +1 -0
- learning3d/losses/cuda/chamfer_distance/chamfer_distance.cpp +185 -0
- learning3d/losses/cuda/chamfer_distance/chamfer_distance.cu +209 -0
- learning3d/losses/cuda/chamfer_distance/chamfer_distance.py +66 -0
- learning3d/losses/cuda/emd_torch/pkg/emd_loss_layer.py +41 -0
- learning3d/losses/cuda/emd_torch/pkg/include/cuda/emd.cuh +347 -0
- learning3d/losses/cuda/emd_torch/pkg/include/cuda_helper.h +18 -0
- learning3d/losses/cuda/emd_torch/pkg/include/emd.h +54 -0
- learning3d/losses/cuda/emd_torch/pkg/layer/__init__.py +1 -0
- learning3d/losses/cuda/emd_torch/pkg/layer/emd_loss_layer.py +40 -0
- learning3d/losses/cuda/emd_torch/pkg/src/cuda/emd.cu +70 -0
- learning3d/losses/cuda/emd_torch/pkg/src/emd.cpp +1 -0
- learning3d/losses/cuda/emd_torch/setup.py +29 -0
- learning3d/losses/emd.py +16 -0
- learning3d/losses/frobenius_norm.py +21 -0
- learning3d/losses/rmse_features.py +16 -0
- learning3d/models/__init__.py +23 -0
- learning3d/models/classifier.py +41 -0
- learning3d/models/dcp.py +92 -0
- learning3d/models/deepgmr.py +165 -0
- learning3d/models/dgcnn.py +92 -0
- learning3d/models/flownet3d.py +446 -0
- learning3d/models/masknet.py +84 -0
- learning3d/models/masknet2.py +264 -0
- learning3d/models/pcn.py +164 -0
- learning3d/models/pcrnet.py +74 -0
- learning3d/models/pointconv.py +108 -0
- learning3d/models/pointnet.py +108 -0
- learning3d/models/pointnetlk.py +173 -0
- learning3d/models/pooling.py +15 -0
- learning3d/models/ppfnet.py +102 -0
- learning3d/models/prnet.py +431 -0
- learning3d/models/rpmnet.py +359 -0
- learning3d/models/segmentation.py +38 -0
- learning3d/ops/__init__.py +0 -0
- learning3d/ops/data_utils.py +45 -0
- learning3d/ops/invmat.py +134 -0
- learning3d/ops/quaternion.py +218 -0
- learning3d/ops/se3.py +157 -0
- learning3d/ops/sinc.py +229 -0
- learning3d/ops/so3.py +213 -0
- learning3d/ops/transform_functions.py +342 -0
- learning3d/utils/__init__.py +9 -0
- learning3d/utils/lib/build/lib.linux-x86_64-3.5/pointnet2_cuda.cpython-35m-x86_64-linux-gnu.so +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/ball_query.o +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/ball_query_gpu.o +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/group_points.o +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/group_points_gpu.o +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/interpolate.o +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/interpolate_gpu.o +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/pointnet2_api.o +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/sampling.o +0 -0
- learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/sampling_gpu.o +0 -0
- learning3d/utils/lib/dist/pointnet2-0.0.0-py3.5-linux-x86_64.egg +0 -0
- learning3d/utils/lib/pointnet2.egg-info/SOURCES.txt +14 -0
- learning3d/utils/lib/pointnet2.egg-info/dependency_links.txt +1 -0
- learning3d/utils/lib/pointnet2.egg-info/top_level.txt +1 -0
- learning3d/utils/lib/pointnet2_modules.py +160 -0
- learning3d/utils/lib/pointnet2_utils.py +318 -0
- learning3d/utils/lib/pytorch_utils.py +236 -0
- learning3d/utils/lib/setup.py +23 -0
- learning3d/utils/lib/src/ball_query.cpp +25 -0
- learning3d/utils/lib/src/ball_query_gpu.cu +67 -0
- learning3d/utils/lib/src/ball_query_gpu.h +15 -0
- learning3d/utils/lib/src/cuda_utils.h +15 -0
- learning3d/utils/lib/src/group_points.cpp +36 -0
- learning3d/utils/lib/src/group_points_gpu.cu +86 -0
- learning3d/utils/lib/src/group_points_gpu.h +22 -0
- learning3d/utils/lib/src/interpolate.cpp +65 -0
- learning3d/utils/lib/src/interpolate_gpu.cu +233 -0
- learning3d/utils/lib/src/interpolate_gpu.h +36 -0
- learning3d/utils/lib/src/pointnet2_api.cpp +25 -0
- learning3d/utils/lib/src/sampling.cpp +46 -0
- learning3d/utils/lib/src/sampling_gpu.cu +253 -0
- learning3d/utils/lib/src/sampling_gpu.h +29 -0
- learning3d/utils/pointconv_util.py +382 -0
- learning3d/utils/ppfnet_util.py +244 -0
- learning3d/utils/svd.py +59 -0
- learning3d/utils/transformer.py +243 -0
- learning3d-0.0.1.dist-info/LICENSE +21 -0
- learning3d-0.0.1.dist-info/METADATA +271 -0
- learning3d-0.0.1.dist-info/RECORD +115 -0
- learning3d-0.0.1.dist-info/WHEEL +5 -0
- learning3d-0.0.1.dist-info/top_level.txt +1 -0
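The module layout above maps one-to-one onto import paths once the wheel is installed. Note that learning3d/utils/lib/ ships CUDA build artifacts compiled against CPython 3.5 (pointnet2_cuda.cpython-35m-x86_64-linux-gnu.so plus the .o files), so any import chain that reaches those utilities may fail on other interpreters unless the extension is rebuilt from the bundled src/ and setup.py. A minimal, hypothetical smoke test, assuming the distribution installs under the name learning3d and that the package __init__ modules import cleanly:

    # Hypothetical smoke test after installing the wheel; paths follow the file
    # listing above. Modules that pull in the prebuilt CUDA extension (utils/lib)
    # may not load on interpreters other than CPython 3.5.
    from learning3d.models.pointnet import PointNet
    from learning3d.models.pooling import Pooling

    print(PointNet, Pooling)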
learning3d/models/pointnet.py
@@ -0,0 +1,108 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from .pooling import Pooling
+
+
+class PointNet(torch.nn.Module):
+    def __init__(self, emb_dims=1024, input_shape="bnc", use_bn=False, global_feat=True):
+        # emb_dims:     Embedding Dimensions for PointNet.
+        # input_shape:  Shape of Input Point Cloud (b: batch, n: no of points, c: channels)
+        super(PointNet, self).__init__()
+        if input_shape not in ["bcn", "bnc"]:
+            raise ValueError("Allowed shapes are 'bcn' (batch * channels * num_in_points), 'bnc' ")
+        self.input_shape = input_shape
+        self.emb_dims = emb_dims
+        self.use_bn = use_bn
+        self.global_feat = global_feat
+        if not self.global_feat: self.pooling = Pooling('max')
+
+        self.layers = self.create_structure()
+
+    def create_structure(self):
+        self.conv1 = torch.nn.Conv1d(3, 64, 1)
+        self.conv2 = torch.nn.Conv1d(64, 64, 1)
+        self.conv3 = torch.nn.Conv1d(64, 64, 1)
+        self.conv4 = torch.nn.Conv1d(64, 128, 1)
+        self.conv5 = torch.nn.Conv1d(128, self.emb_dims, 1)
+        self.relu = torch.nn.ReLU()
+
+        if self.use_bn:
+            self.bn1 = torch.nn.BatchNorm1d(64)
+            self.bn2 = torch.nn.BatchNorm1d(64)
+            self.bn3 = torch.nn.BatchNorm1d(64)
+            self.bn4 = torch.nn.BatchNorm1d(128)
+            self.bn5 = torch.nn.BatchNorm1d(self.emb_dims)
+
+        if self.use_bn:
+            layers = [self.conv1, self.bn1, self.relu,
+                      self.conv2, self.bn2, self.relu,
+                      self.conv3, self.bn3, self.relu,
+                      self.conv4, self.bn4, self.relu,
+                      self.conv5, self.bn5, self.relu]
+        else:
+            layers = [self.conv1, self.relu,
+                      self.conv2, self.relu,
+                      self.conv3, self.relu,
+                      self.conv4, self.relu,
+                      self.conv5, self.relu]
+        return layers
+
+
+    def forward(self, input_data):
+        # input_data:   Point Cloud having shape input_shape.
+        # output:       PointNet features (Batch x emb_dims)
+        if self.input_shape == "bnc":
+            num_points = input_data.shape[1]
+            input_data = input_data.permute(0, 2, 1)
+        else:
+            num_points = input_data.shape[2]
+        if input_data.shape[1] != 3:
+            raise RuntimeError("shape of x must be of [Batch x 3 x NumInPoints]")
+
+        output = input_data
+        for idx, layer in enumerate(self.layers):
+            output = layer(output)
+            if idx == 1 and not self.global_feat: point_feature = output
+
+        if self.global_feat:
+            return output
+        else:
+            output = self.pooling(output)
+            output = output.view(-1, self.emb_dims, 1).repeat(1, 1, num_points)
+            return torch.cat([output, point_feature], 1)
+
+
+if __name__ == '__main__':
+    # Test the code.
+    x = torch.rand((10,1024,3))
+
+    pn = PointNet(use_bn=True)
+    y = pn(x)
+    print("Network Architecture: ")
+    print(pn)
+    print("Input Shape of PointNet: ", x.shape, "\nOutput Shape of PointNet: ", y.shape)
+
+    class PointNet_modified(PointNet):
+        def __init__(self):
+            super().__init__()
+
+        def create_structure(self):
+            self.conv1 = torch.nn.Conv1d(3, 64, 1)
+            self.conv2 = torch.nn.Conv1d(64, 128, 1)
+            self.conv3 = torch.nn.Conv1d(128, self.emb_dims, 1)
+            self.relu = torch.nn.ReLU()
+
+            layers = [self.conv1, self.relu,
+                      self.conv2, self.relu,
+                      self.conv3, self.relu]
+            return layers
+
+    pn = PointNet_modified()
+    y = pn(x)
+    print("\n\n\nModified Network Architecture: ")
+    print(pn)
+    print("Input Shape of PointNet: ", x.shape, "\nOutput Shape of PointNet: ", y.shape)
+
+
+
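As written, PointNet.forward with the default global_feat=True returns the per-point feature map of shape [B, emb_dims, N]; reducing it to a single global descriptor is left to the caller, which is exactly how PointNetLK below uses it via Pooling. A minimal sketch of that pattern, assuming the wheel installs and these modules import cleanly:

    import torch
    from learning3d.models.pointnet import PointNet
    from learning3d.models.pooling import Pooling

    x = torch.rand(10, 1024, 3)        # batch of point clouds in "bnc" layout
    encoder = PointNet(emb_dims=1024)  # defaults: no BatchNorm, global_feat=True
    pool = Pooling('max')              # max over the points dimension

    per_point = encoder(x)             # [10, 1024, 1024] = [B, emb_dims, N]
    global_feat = pool(per_point)      # [10, 1024] global descriptor
    print(per_point.shape, global_feat.shape)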
learning3d/models/pointnetlk.py
@@ -0,0 +1,173 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from .pointnet import PointNet
+from .pooling import Pooling
+from .. ops import data_utils
+from .. ops import se3, so3, invmat
+
+
+class PointNetLK(nn.Module):
+    def __init__(self, feature_model=PointNet(), delta=1.0e-2, learn_delta=False, xtol=1.0e-7, p0_zero_mean=True, p1_zero_mean=True, pooling='max'):
+        super().__init__()
+        self.feature_model = feature_model
+        self.pooling = Pooling(pooling)
+        self.inverse = invmat.InvMatrix.apply
+        self.exp = se3.Exp           # [B, 6] -> [B, 4, 4]
+        self.transform = se3.transform  # [B, 1, 4, 4] x [B, N, 3] -> [B, N, 3]
+
+        w1, w2, w3, v1, v2, v3 = delta, delta, delta, delta, delta, delta
+        twist = torch.Tensor([w1, w2, w3, v1, v2, v3])
+        self.dt = torch.nn.Parameter(twist.view(1, 6), requires_grad=learn_delta)
+
+        # results
+        self.last_err = None
+        self.g_series = None  # for debug purpose
+        self.prev_r = None
+        self.g = None         # estimation result
+        self.itr = 0
+        self.xtol = xtol
+        self.p0_zero_mean = p0_zero_mean
+        self.p1_zero_mean = p1_zero_mean
+
+    def forward(self, template, source, maxiter=10):
+        template, source, template_mean, source_mean = data_utils.mean_shift(template, source,
+                                                                             self.p0_zero_mean, self.p1_zero_mean)
+
+        result = self.iclk(template, source, maxiter)
+        result = data_utils.postprocess_data(result, template, source, template_mean, source_mean,
+                                             self.p0_zero_mean, self.p1_zero_mean)
+        return result
+
+    def iclk(self, template, source, maxiter):
+        batch_size = template.size(0)
+
+        est_T0 = torch.eye(4).to(template).view(1, 4, 4).expand(template.size(0), 4, 4).contiguous()
+        est_T = est_T0
+        self.est_T_series = torch.zeros(maxiter+1, *est_T0.size(), dtype=est_T0.dtype)
+        self.est_T_series[0] = est_T0.clone()
+
+        training = self.handle_batchNorm(template, source)
+
+        # re-calc. with current modules
+        template_features = self.pooling(self.feature_model(template))  # [B, N, 3] -> [B, K]
+
+        # approx. J by finite difference
+        dt = self.dt.to(template).expand(batch_size, 6)
+        J = self.approx_Jic(template, template_features, dt)
+
+        self.last_err = None
+        pinv = self.compute_inverse_jacobian(J, template_features, source)
+        if pinv == {}:
+            result = {'est_R': est_T[:,0:3,0:3],
+                      'est_t': est_T[:,0:3,3],
+                      'est_T': est_T,
+                      'r': None,
+                      'transformed_source': self.transform(est_T.unsqueeze(1), source),
+                      'itr': 1,
+                      'est_T_series': self.est_T_series}
+            return result
+
+        itr = 0
+        r = None
+        for itr in range(maxiter):
+            self.prev_r = r
+            transformed_source = self.transform(est_T.unsqueeze(1), source)  # [B, 1, 4, 4] x [B, N, 3] -> [B, N, 3]
+            source_features = self.pooling(self.feature_model(transformed_source))  # [B, N, 3] -> [B, K]
+            r = source_features - template_features
+
+            pose = -pinv.bmm(r.unsqueeze(-1)).view(batch_size, 6)
+
+            check = pose.norm(p=2, dim=1, keepdim=True).max()
+            if float(check) < self.xtol:
+                if itr == 0:
+                    self.last_err = 0  # no update.
+                break
+
+            est_T = self.update(est_T, pose)
+            self.est_T_series[itr+1] = est_T.clone()
+
+        rep = len(range(itr, maxiter))
+        self.est_T_series[(itr+1):] = est_T.clone().unsqueeze(0).repeat(rep, 1, 1, 1)
+
+        self.feature_model.train(training)
+        self.est_T = est_T
+
+        result = {'est_R': est_T[:,0:3,0:3],
+                  'est_t': est_T[:,0:3,3],
+                  'est_T': est_T,
+                  'r': r,
+                  'transformed_source': self.transform(est_T.unsqueeze(1), source),
+                  'itr': itr+1,
+                  'est_T_series': self.est_T_series}
+
+        return result
+
+    def update(self, g, dx):
+        # [B, 4, 4] x [B, 6] -> [B, 4, 4]
+        dg = self.exp(dx)
+        return dg.matmul(g)
+
+    def approx_Jic(self, template, template_features, dt):
+        # p0: [B, N, 3], Variable
+        # f0: [B, K], corresponding feature vector
+        # dt: [B, 6], Variable
+        # Jk = (feature_model(p(-delta[k], p0)) - f0) / delta[k]
+
+        batch_size = template.size(0)
+        num_points = template.size(1)
+
+        # compute transforms
+        transf = torch.zeros(batch_size, 6, 4, 4).to(template)
+        for b in range(template.size(0)):
+            d = torch.diag(dt[b, :])  # [6, 6]
+            D = self.exp(-d)          # [6, 4, 4]
+            transf[b, :, :, :] = D[:, :, :]
+        transf = transf.unsqueeze(2).contiguous()  # [B, 6, 1, 4, 4]
+        p = self.transform(transf, template.unsqueeze(1))  # x [B, 1, N, 3] -> [B, 6, N, 3]
+
+        #f0 = self.feature_model(p0).unsqueeze(-1) # [B, K, 1]
+        template_features = template_features.unsqueeze(-1)  # [B, K, 1]
+        f = self.pooling(self.feature_model(p.view(-1, num_points, 3))).view(batch_size, 6, -1).transpose(1, 2)  # [B, K, 6]
+
+        df = template_features - f  # [B, K, 6]
+        J = df / dt.unsqueeze(1)
+
+        return J
+
+    def compute_inverse_jacobian(self, J, template_features, source):
+        # compute pinv(J) to solve J*x = -r
+        try:
+            Jt = J.transpose(1, 2)  # [B, 6, K]
+            H = Jt.bmm(J)           # [B, 6, 6]
+            B = self.inverse(H)
+            pinv = B.bmm(Jt)        # [B, 6, K]
+            return pinv
+        except RuntimeError as err:
+            # singular...?
+            self.last_err = err
+            g = torch.eye(4).to(source).view(1, 4, 4).expand(source.size(0), 4, 4).contiguous()
+            #print(err)
+            # Perhaps we can use MP-inverse, but,...
+            # probably, self.dt is way too small...
+            source_features = self.pooling(self.feature_model(source))  # [B, N, 3] -> [B, K]
+            r = source_features - template_features
+            self.feature_model.train(self.feature_model.training)
+            return {}
+
+    def handle_batchNorm(self, template, source):
+        training = self.feature_model.training
+        if training:
+            # first, update BatchNorm modules
+            template_features, source_features = self.pooling(self.feature_model(template)), self.pooling(self.feature_model(source))
+        self.feature_model.eval()  # and fix them.
+        return training
+
+
+if __name__ == '__main__':
+    template, source = torch.rand(10,1024,3), torch.rand(10,1024,3)
+    pn = PointNet()
+
+    net = PointNetLK(pn)
+    result = net(template, source)
+    import ipdb; ipdb.set_trace()
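The __main__ block above stops at a debugger call; in practice the dictionary returned by PointNetLK.forward carries the estimate directly. A sketch of reading it (key names taken from the iclk result above; data_utils.postprocess_data in ops/data_utils.py, not shown in this diff, is assumed to preserve them):

    import torch
    from learning3d.models.pointnet import PointNet
    from learning3d.models.pointnetlk import PointNetLK

    template = torch.rand(10, 1024, 3)
    source = torch.rand(10, 1024, 3)

    net = PointNetLK(feature_model=PointNet()).eval()
    with torch.no_grad():
        result = net(template, source, maxiter=10)

    est_T = result['est_T']                  # [10, 4, 4] estimated rigid transforms
    aligned = result['transformed_source']   # [10, 1024, 3] source after alignment
    print(est_T.shape, aligned.shape, result['itr'])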
learning3d/models/pooling.py
@@ -0,0 +1,15 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class Pooling(torch.nn.Module):
+    def __init__(self, pool_type='max'):
+        self.pool_type = pool_type
+        super(Pooling, self).__init__()
+
+    def forward(self, input):
+        if self.pool_type == 'max':
+            return torch.max(input, 2)[0].contiguous()
+        elif self.pool_type == 'avg' or self.pool_type == 'average':
+            return torch.mean(input, 2).contiguous()
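Pooling reduces over dimension 2, i.e. the points dimension of a [B, C, N] feature map, which is why PointNetLK applies it to the raw [B, emb_dims, N] output of PointNet. A quick check, assuming the module imports as above:

    import torch
    from learning3d.models.pooling import Pooling

    feat = torch.rand(8, 1024, 2048)    # [B, C, N] per-point features
    print(Pooling('max')(feat).shape)   # torch.Size([8, 1024])
    print(Pooling('avg')(feat).shape)   # torch.Size([8, 1024])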
learning3d/models/ppfnet.py
@@ -0,0 +1,102 @@
+"""Feature Extraction and Parameter Prediction networks
+"""
+import logging
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .. utils import sample_and_group_multi
+
+_raw_features_sizes = {'xyz': 3, 'dxyz': 3, 'ppf': 4}
+_raw_features_order = {'xyz': 0, 'dxyz': 1, 'ppf': 2}
+
+
+def get_prepool(in_dim, out_dim):
+    """Shared FC part in PointNet before max pooling"""
+    net = nn.Sequential(
+        nn.Conv2d(in_dim, out_dim // 2, 1),
+        nn.GroupNorm(8, out_dim // 2),
+        nn.ReLU(),
+        nn.Conv2d(out_dim // 2, out_dim // 2, 1),
+        nn.GroupNorm(8, out_dim // 2),
+        nn.ReLU(),
+        nn.Conv2d(out_dim // 2, out_dim, 1),
+        nn.GroupNorm(8, out_dim),
+        nn.ReLU(),
+    )
+    return net
+
+
+def get_postpool(in_dim, out_dim):
+    """Linear layers in PointNet after max pooling
+
+    Args:
+        in_dim: Number of input channels
+        out_dim: Number of output channels. Typically smaller than in_dim
+
+    """
+    net = nn.Sequential(
+        nn.Conv1d(in_dim, in_dim, 1),
+        nn.GroupNorm(8, in_dim),
+        nn.ReLU(),
+        nn.Conv1d(in_dim, out_dim, 1),
+        nn.GroupNorm(8, out_dim),
+        nn.ReLU(),
+        nn.Conv1d(out_dim, out_dim, 1),
+    )
+
+    return net
+
+
+class PPFNet(nn.Module):
+    """Feature extraction Module that extracts hybrid features"""
+    def __init__(self, features=['ppf', 'dxyz', 'xyz'], emb_dims=96, radius=0.3, num_neighbors=64):
+        super().__init__()
+
+        self._logger = logging.getLogger(self.__class__.__name__)
+        self._logger.info('Using early fusion, feature dim = {}'.format(emb_dims))
+        self.radius = radius
+        self.n_sample = num_neighbors
+
+        self.features = sorted(features, key=lambda f: _raw_features_order[f])
+        self._logger.info('Feature extraction using features {}'.format(', '.join(self.features)))
+
+        # Layers
+        raw_dim = np.sum([_raw_features_sizes[f] for f in self.features])  # number of channels after concat
+        self.prepool = get_prepool(raw_dim, emb_dims * 2)
+        self.postpool = get_postpool(emb_dims * 2, emb_dims)
+
+    def forward(self, xyz, normals):
+        """Forward pass of the feature extraction network
+
+        Args:
+            xyz: (B, N, 3)
+            normals: (B, N, 3)
+
+        Returns:
+            cluster features (B, N, C)
+
+        """
+        features = sample_and_group_multi(-1, self.radius, self.n_sample, xyz, normals)
+        features['xyz'] = features['xyz'][:, :, None, :]
+
+        # Gate and concat
+        concat = []
+        for i in range(len(self.features)):
+            f = self.features[i]
+            expanded = (features[f]).expand(-1, -1, self.n_sample, -1)
+            concat.append(expanded)
+        fused_input_feat = torch.cat(concat, -1)
+
+        # Prepool_FC, pool, postpool-FC
+        new_feat = fused_input_feat.permute(0, 3, 2, 1)  # [B, 10, n_sample, N]
+        new_feat = self.prepool(new_feat)
+
+        pooled_feat = torch.max(new_feat, 2)[0]  # Max pooling (B, C, N)
+
+        post_feat = self.postpool(pooled_feat)  # Post pooling dense layers
+        cluster_feat = post_feat.permute(0, 2, 1)
+        cluster_feat = cluster_feat / torch.norm(cluster_feat, dim=-1, keepdim=True)
+
+        return cluster_feat  # (B, N, C)
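PPFNet consumes points together with per-point normals and returns one L2-normalized descriptor per point. A sketch of a forward pass, assuming sample_and_group_multi (re-exported from learning3d.utils, whose sources appear in the listing but not in this diff) loads in the installed environment:

    import torch
    import torch.nn.functional as F
    from learning3d.models.ppfnet import PPFNet

    xyz = torch.rand(2, 1024, 3)                            # (B, N, 3) points
    normals = F.normalize(torch.rand(2, 1024, 3), dim=-1)   # (B, N, 3) unit normals

    net = PPFNet(features=['ppf', 'dxyz', 'xyz'], emb_dims=96, radius=0.3, num_neighbors=64)
    desc = net(xyz, normals)   # (B, N, 96); each per-point descriptor has unit norm
    print(desc.shape)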