learning3d 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (115)
  1. learning3d/__init__.py +2 -0
  2. learning3d/data_utils/__init__.py +4 -0
  3. learning3d/data_utils/dataloaders.py +454 -0
  4. learning3d/data_utils/user_data.py +119 -0
  5. learning3d/examples/test_dcp.py +139 -0
  6. learning3d/examples/test_deepgmr.py +144 -0
  7. learning3d/examples/test_flownet.py +113 -0
  8. learning3d/examples/test_masknet.py +159 -0
  9. learning3d/examples/test_masknet2.py +162 -0
  10. learning3d/examples/test_pcn.py +118 -0
  11. learning3d/examples/test_pcrnet.py +120 -0
  12. learning3d/examples/test_pnlk.py +121 -0
  13. learning3d/examples/test_pointconv.py +126 -0
  14. learning3d/examples/test_pointnet.py +121 -0
  15. learning3d/examples/test_prnet.py +126 -0
  16. learning3d/examples/test_rpmnet.py +120 -0
  17. learning3d/examples/train_PointNetLK.py +240 -0
  18. learning3d/examples/train_dcp.py +249 -0
  19. learning3d/examples/train_deepgmr.py +244 -0
  20. learning3d/examples/train_flownet.py +259 -0
  21. learning3d/examples/train_masknet.py +239 -0
  22. learning3d/examples/train_pcn.py +216 -0
  23. learning3d/examples/train_pcrnet.py +228 -0
  24. learning3d/examples/train_pointconv.py +245 -0
  25. learning3d/examples/train_pointnet.py +244 -0
  26. learning3d/examples/train_prnet.py +229 -0
  27. learning3d/examples/train_rpmnet.py +228 -0
  28. learning3d/losses/__init__.py +12 -0
  29. learning3d/losses/chamfer_distance.py +51 -0
  30. learning3d/losses/classification.py +14 -0
  31. learning3d/losses/correspondence_loss.py +10 -0
  32. learning3d/losses/cuda/chamfer_distance/__init__.py +1 -0
  33. learning3d/losses/cuda/chamfer_distance/chamfer_distance.cpp +185 -0
  34. learning3d/losses/cuda/chamfer_distance/chamfer_distance.cu +209 -0
  35. learning3d/losses/cuda/chamfer_distance/chamfer_distance.py +66 -0
  36. learning3d/losses/cuda/emd_torch/pkg/emd_loss_layer.py +41 -0
  37. learning3d/losses/cuda/emd_torch/pkg/include/cuda/emd.cuh +347 -0
  38. learning3d/losses/cuda/emd_torch/pkg/include/cuda_helper.h +18 -0
  39. learning3d/losses/cuda/emd_torch/pkg/include/emd.h +54 -0
  40. learning3d/losses/cuda/emd_torch/pkg/layer/__init__.py +1 -0
  41. learning3d/losses/cuda/emd_torch/pkg/layer/emd_loss_layer.py +40 -0
  42. learning3d/losses/cuda/emd_torch/pkg/src/cuda/emd.cu +70 -0
  43. learning3d/losses/cuda/emd_torch/pkg/src/emd.cpp +1 -0
  44. learning3d/losses/cuda/emd_torch/setup.py +29 -0
  45. learning3d/losses/emd.py +16 -0
  46. learning3d/losses/frobenius_norm.py +21 -0
  47. learning3d/losses/rmse_features.py +16 -0
  48. learning3d/models/__init__.py +23 -0
  49. learning3d/models/classifier.py +41 -0
  50. learning3d/models/dcp.py +92 -0
  51. learning3d/models/deepgmr.py +165 -0
  52. learning3d/models/dgcnn.py +92 -0
  53. learning3d/models/flownet3d.py +446 -0
  54. learning3d/models/masknet.py +84 -0
  55. learning3d/models/masknet2.py +264 -0
  56. learning3d/models/pcn.py +164 -0
  57. learning3d/models/pcrnet.py +74 -0
  58. learning3d/models/pointconv.py +108 -0
  59. learning3d/models/pointnet.py +108 -0
  60. learning3d/models/pointnetlk.py +173 -0
  61. learning3d/models/pooling.py +15 -0
  62. learning3d/models/ppfnet.py +102 -0
  63. learning3d/models/prnet.py +431 -0
  64. learning3d/models/rpmnet.py +359 -0
  65. learning3d/models/segmentation.py +38 -0
  66. learning3d/ops/__init__.py +0 -0
  67. learning3d/ops/data_utils.py +45 -0
  68. learning3d/ops/invmat.py +134 -0
  69. learning3d/ops/quaternion.py +218 -0
  70. learning3d/ops/se3.py +157 -0
  71. learning3d/ops/sinc.py +229 -0
  72. learning3d/ops/so3.py +213 -0
  73. learning3d/ops/transform_functions.py +342 -0
  74. learning3d/utils/__init__.py +9 -0
  75. learning3d/utils/lib/build/lib.linux-x86_64-3.5/pointnet2_cuda.cpython-35m-x86_64-linux-gnu.so +0 -0
  76. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/ball_query.o +0 -0
  77. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/ball_query_gpu.o +0 -0
  78. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/group_points.o +0 -0
  79. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/group_points_gpu.o +0 -0
  80. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/interpolate.o +0 -0
  81. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/interpolate_gpu.o +0 -0
  82. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/pointnet2_api.o +0 -0
  83. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/sampling.o +0 -0
  84. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/sampling_gpu.o +0 -0
  85. learning3d/utils/lib/dist/pointnet2-0.0.0-py3.5-linux-x86_64.egg +0 -0
  86. learning3d/utils/lib/pointnet2.egg-info/SOURCES.txt +14 -0
  87. learning3d/utils/lib/pointnet2.egg-info/dependency_links.txt +1 -0
  88. learning3d/utils/lib/pointnet2.egg-info/top_level.txt +1 -0
  89. learning3d/utils/lib/pointnet2_modules.py +160 -0
  90. learning3d/utils/lib/pointnet2_utils.py +318 -0
  91. learning3d/utils/lib/pytorch_utils.py +236 -0
  92. learning3d/utils/lib/setup.py +23 -0
  93. learning3d/utils/lib/src/ball_query.cpp +25 -0
  94. learning3d/utils/lib/src/ball_query_gpu.cu +67 -0
  95. learning3d/utils/lib/src/ball_query_gpu.h +15 -0
  96. learning3d/utils/lib/src/cuda_utils.h +15 -0
  97. learning3d/utils/lib/src/group_points.cpp +36 -0
  98. learning3d/utils/lib/src/group_points_gpu.cu +86 -0
  99. learning3d/utils/lib/src/group_points_gpu.h +22 -0
  100. learning3d/utils/lib/src/interpolate.cpp +65 -0
  101. learning3d/utils/lib/src/interpolate_gpu.cu +233 -0
  102. learning3d/utils/lib/src/interpolate_gpu.h +36 -0
  103. learning3d/utils/lib/src/pointnet2_api.cpp +25 -0
  104. learning3d/utils/lib/src/sampling.cpp +46 -0
  105. learning3d/utils/lib/src/sampling_gpu.cu +253 -0
  106. learning3d/utils/lib/src/sampling_gpu.h +29 -0
  107. learning3d/utils/pointconv_util.py +382 -0
  108. learning3d/utils/ppfnet_util.py +244 -0
  109. learning3d/utils/svd.py +59 -0
  110. learning3d/utils/transformer.py +243 -0
  111. learning3d-0.0.1.dist-info/LICENSE +21 -0
  112. learning3d-0.0.1.dist-info/METADATA +271 -0
  113. learning3d-0.0.1.dist-info/RECORD +115 -0
  114. learning3d-0.0.1.dist-info/WHEEL +5 -0
  115. learning3d-0.0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,108 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from .pooling import Pooling
5
+
6
+
7
class PointNet(torch.nn.Module):
    """PointNet per-point feature extractor built from shared 1x1 Conv1d layers.

    Args:
        emb_dims: Number of output channels of the last conv layer.
        input_shape: Layout of the input cloud: "bnc" (batch x num_points x 3) or
            "bcn" (batch x 3 x num_points).
        use_bn: Insert a BatchNorm1d after every conv layer.
        global_feat: If True (default), ``forward`` returns the unpooled output of
            the last layer, shape [B, emb_dims, N]; pooling is left to the caller.
            If False, that output is max-pooled over the points, tiled back to every
            point and concatenated with the activations of the first conv block
            (64 channels for the default structure) -> [B, emb_dims + 64, N].

    Raises:
        ValueError: if ``input_shape`` is neither "bnc" nor "bcn".
    """

    def __init__(self, emb_dims=1024, input_shape="bnc", use_bn=False, global_feat=True):
        super(PointNet, self).__init__()
        if input_shape not in ["bcn", "bnc"]:
            raise ValueError("Allowed shapes are 'bcn' (batch * channels * num_in_points), 'bnc' ")
        self.input_shape = input_shape
        self.emb_dims = emb_dims
        self.use_bn = use_bn
        self.global_feat = global_feat
        if not self.global_feat:
            self.pooling = Pooling('max')

        # Plain list on purpose: create_structure registers every layer as an
        # attribute (conv1, bn1, ...), so wrapping it in nn.ModuleList would add
        # duplicate state_dict keys and break existing checkpoints.
        self.layers = self.create_structure()

    def create_structure(self):
        """Create the conv/BN/ReLU modules and return them in execution order."""
        self.conv1 = torch.nn.Conv1d(3, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 64, 1)
        self.conv3 = torch.nn.Conv1d(64, 64, 1)
        self.conv4 = torch.nn.Conv1d(64, 128, 1)
        self.conv5 = torch.nn.Conv1d(128, self.emb_dims, 1)
        self.relu = torch.nn.ReLU()

        if self.use_bn:
            self.bn1 = torch.nn.BatchNorm1d(64)
            self.bn2 = torch.nn.BatchNorm1d(64)
            self.bn3 = torch.nn.BatchNorm1d(64)
            self.bn4 = torch.nn.BatchNorm1d(128)
            self.bn5 = torch.nn.BatchNorm1d(self.emb_dims)

        if self.use_bn:
            layers = [self.conv1, self.bn1, self.relu,
                      self.conv2, self.bn2, self.relu,
                      self.conv3, self.bn3, self.relu,
                      self.conv4, self.bn4, self.relu,
                      self.conv5, self.bn5, self.relu]
        else:
            layers = [self.conv1, self.relu,
                      self.conv2, self.relu,
                      self.conv3, self.relu,
                      self.conv4, self.relu,
                      self.conv5, self.relu]
        return layers

    def _local_feature_index(self):
        """Index in ``self.layers`` whose output is used as the per-point feature.

        This is the first ReLU, i.e. the end of the first conv block, whether or not
        BatchNorm is used. If there is no ReLU (a custom create_structure), it falls
        back to index 1.
        """
        for idx, layer in enumerate(self.layers):
            if isinstance(layer, torch.nn.ReLU):
                return idx
        return 1

    def forward(self, input_data):
        """Compute PointNet features.

        Args:
            input_data: Point cloud laid out as ``self.input_shape`` with 3 channels.

        Returns:
            [B, emb_dims, N] if ``global_feat`` is True, otherwise the tiled pooled
            feature concatenated with the first-block point features along dim 1.

        Raises:
            RuntimeError: if the channel axis does not have size 3.
        """
        if self.input_shape == "bnc":
            num_points = input_data.shape[1]
            input_data = input_data.permute(0, 2, 1)
        else:
            num_points = input_data.shape[2]
        if input_data.shape[1] != 3:
            raise RuntimeError("shape of x must be of [Batch x 3 x NumInPoints]")

        # Capture after the first activation. A hard-coded index of 1 would pick
        # the pre-activation BatchNorm output when use_bn=True.
        capture_idx = None if self.global_feat else self._local_feature_index()
        point_feature = None
        output = input_data
        for idx, layer in enumerate(self.layers):
            output = layer(output)
            if idx == capture_idx:
                point_feature = output

        if self.global_feat:
            return output

        output = self.pooling(output)
        output = output.view(-1, self.emb_dims, 1).repeat(1, 1, num_points)
        return torch.cat([output, point_feature], 1)
74
+
75
+
76
if __name__ == '__main__':
    # Smoke test: run PointNet and a slimmer subclass on one random batch.

    def _describe(title, net, points):
        """Run ``net`` on ``points``, print its architecture and I/O shapes, return the output."""
        out = net(points)
        print(title)
        print(net)
        print("Input Shape of PointNet: ", points.shape, "\nOutput Shape of PointNet: ", out.shape)
        return out

    x = torch.rand((10, 1024, 3))  # 10 clouds x 1024 points x 3 coords ("bnc")

    pn = PointNet(use_bn=True)
    y = _describe("Network Architecture: ", pn, x)

    class PointNet_modified(PointNet):
        """PointNet with a shorter MLP (3 -> 64 -> 128 -> emb_dims) and no BatchNorm."""

        def __init__(self):
            super().__init__()

        def create_structure(self):
            self.conv1 = torch.nn.Conv1d(3, 64, 1)
            self.conv2 = torch.nn.Conv1d(64, 128, 1)
            self.conv3 = torch.nn.Conv1d(128, self.emb_dims, 1)
            self.relu = torch.nn.ReLU()
            return [self.conv1, self.relu, self.conv2, self.relu, self.conv3, self.relu]

    pn = PointNet_modified()
    y = _describe("\n\n\nModified Network Architecture: ", pn, x)
106
+
107
+
108
+
@@ -0,0 +1,173 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from .pointnet import PointNet
5
+ from .pooling import Pooling
6
+ from .. ops import data_utils
7
+ from .. ops import se3, so3, invmat
8
+
9
+
10
class PointNetLK(nn.Module):
    """PointNetLK registration: inverse-compositional Lucas-Kanade on pooled features.

    The template's feature Jacobian with respect to a 6-D twist is approximated once
    by finite differences. The source is then repeatedly warped by the current
    estimate, and ``se3.Exp`` updates are composed until the step becomes smaller
    than ``xtol`` or ``maxiter`` is reached.

    Args:
        feature_model: Network producing features that ``pooling`` reduces to a
            [B, K] vector. When None (the default), a fresh ``PointNet()`` is created
            for this instance.
        delta: Finite-difference step used for all six twist components.
        learn_delta: Whether the step (``self.dt``) is a trainable parameter.
        xtol: Iteration stops once the largest per-sample update norm is below this.
        p0_zero_mean, p1_zero_mean: Forwarded to ``data_utils.mean_shift`` and
            ``data_utils.postprocess_data`` (presumably for template / source
            respectively, given the argument order there).
        pooling: Pooling type name passed to ``Pooling``.
    """

    def __init__(self, feature_model=None, delta=1.0e-2, learn_delta=False, xtol=1.0e-7,
                 p0_zero_mean=True, p1_zero_mean=True, pooling='max'):
        super().__init__()
        # Build the default per instance. A `PointNet()` default argument would be
        # created once at import time and its weights shared by every PointNetLK.
        self.feature_model = PointNet() if feature_model is None else feature_model
        self.pooling = Pooling(pooling)
        self.inverse = invmat.InvMatrix.apply
        self.exp = se3.Exp  # [B, 6] -> [B, 4, 4]
        self.transform = se3.transform  # [B, 1, 4, 4] x [B, N, 3] -> [B, N, 3]

        # Same step for every twist component (w1, w2, w3, v1, v2, v3).
        twist = torch.Tensor([delta] * 6)
        self.dt = torch.nn.Parameter(twist.view(1, 6), requires_grad=learn_delta)

        # results
        self.last_err = None
        self.g_series = None  # for debug purpose
        self.prev_r = None
        self.g = None  # estimation result
        self.itr = 0
        self.xtol = xtol
        self.p0_zero_mean = p0_zero_mean
        self.p1_zero_mean = p1_zero_mean

    def forward(self, template, source, maxiter=10):
        """Register ``source`` onto ``template``.

        The clouds are mean-shifted with data_utils.mean_shift, registered by
        ``iclk`` and the result is passed through data_utils.postprocess_data.

        Args:
            template, source: point clouds, presumably [B, N, 3] as in the shape
                comments of ``iclk`` (TODO confirm against data_utils).
            maxiter: Maximum number of LK iterations.
        """
        template, source, template_mean, source_mean = data_utils.mean_shift(
            template, source, self.p0_zero_mean, self.p1_zero_mean)

        result = self.iclk(template, source, maxiter)
        result = data_utils.postprocess_data(result, template, source, template_mean, source_mean,
                                             self.p0_zero_mean, self.p1_zero_mean)
        return result

    def iclk(self, template, source, maxiter):
        """Run the inverse-compositional LK iterations.

        Returns:
            dict with 'est_R' [B, 3, 3], 'est_t' [B, 3], 'est_T' [B, 4, 4], the last
            feature residual 'r', 'transformed_source', the iteration count 'itr'
            and 'est_T_series' [maxiter + 1, B, 4, 4] (allocated on the CPU).
            If J^T J cannot be inverted, the identity estimate is returned with
            r=None and itr=1.

        The feature model's train/eval mode is restored before returning, on every path.
        """
        batch_size = template.size(0)

        est_T0 = torch.eye(4).to(template).view(1, 4, 4).expand(batch_size, 4, 4).contiguous()
        est_T = est_T0
        # History of estimates, slot 0 = identity (created on the default device).
        self.est_T_series = torch.zeros(maxiter + 1, *est_T0.size(), dtype=est_T0.dtype)
        self.est_T_series[0] = est_T0.clone()

        # May switch the feature model to eval(); the `finally` below restores it.
        training = self.handle_batchNorm(template, source)
        try:
            # re-calc. with current modules
            template_features = self.pooling(self.feature_model(template))  # [B, N, 3] -> [B, K]

            # approx. J by finite difference
            dt = self.dt.to(template).expand(batch_size, 6)
            J = self.approx_Jic(template, template_features, dt)

            self.last_err = None
            pinv = self.compute_inverse_jacobian(J, template_features, source)
            if not torch.is_tensor(pinv):
                # J^T J was not invertible: report the identity transform.
                return {'est_R': est_T[:, 0:3, 0:3],
                        'est_t': est_T[:, 0:3, 3],
                        'est_T': est_T,
                        'r': None,
                        'transformed_source': self.transform(est_T.unsqueeze(1), source),
                        'itr': 1,
                        'est_T_series': self.est_T_series}

            itr = 0
            r = None
            for itr in range(maxiter):
                self.prev_r = r
                transformed_source = self.transform(est_T.unsqueeze(1), source)  # [B, N, 3]
                source_features = self.pooling(self.feature_model(transformed_source))  # [B, K]
                r = source_features - template_features

                # Least-squares step: pose = -pinv(J) r, one 6-D twist per sample.
                pose = -pinv.bmm(r.unsqueeze(-1)).view(batch_size, 6)

                check = pose.norm(p=2, dim=1, keepdim=True).max()
                if float(check) < self.xtol:
                    if itr == 0:
                        self.last_err = 0  # no update.
                    break

                est_T = self.update(est_T, pose)
                self.est_T_series[itr + 1] = est_T.clone()

            # Fill the unused tail of the history with the final estimate.
            rep = len(range(itr, maxiter))
            self.est_T_series[(itr + 1):] = est_T.clone().unsqueeze(0).repeat(rep, 1, 1, 1)
        finally:
            # Also runs on the early return above and if an exception escapes,
            # so BatchNorm layers are never left frozen in eval mode.
            self.feature_model.train(training)

        self.est_T = est_T

        return {'est_R': est_T[:, 0:3, 0:3],
                'est_t': est_T[:, 0:3, 3],
                'est_T': est_T,
                'r': r,
                'transformed_source': self.transform(est_T.unsqueeze(1), source),
                'itr': itr + 1,
                'est_T_series': self.est_T_series}

    def update(self, g, dx):
        """Left-compose the twist update: exp(dx) @ g. [B, 4, 4] x [B, 6] -> [B, 4, 4]."""
        dg = self.exp(dx)
        return dg.matmul(g)

    def approx_Jic(self, template, template_features, dt):
        """Approximate the feature Jacobian of the template by finite differences.

        J[:, :, k] = (f(template) - f(exp(-dt_k) * template)) / dt_k

        Args:
            template: [B, N, 3] points.
            template_features: [B, K] pooled features of ``template``.
            dt: [B, 6] step per twist component.

        Returns:
            J: [B, K, 6]
        """
        batch_size = template.size(0)
        num_points = template.size(1)

        # compute transforms: one perturbation per twist component
        transf = torch.zeros(batch_size, 6, 4, 4).to(template)
        for b in range(batch_size):
            d = torch.diag(dt[b, :])  # [6, 6]
            transf[b] = self.exp(-d)  # [6, 4, 4]
        transf = transf.unsqueeze(2).contiguous()  # [B, 6, 1, 4, 4]
        p = self.transform(transf, template.unsqueeze(1))  # x [B, 1, N, 3] -> [B, 6, N, 3]

        f = self.pooling(self.feature_model(p.view(-1, num_points, 3)))
        f = f.view(batch_size, 6, -1).transpose(1, 2)  # [B, K, 6]

        df = template_features.unsqueeze(-1) - f  # [B, K, 6]
        return df / dt.unsqueeze(1)

    def compute_inverse_jacobian(self, J, template_features, source):
        """Return pinv(J) = (J^T J)^-1 J^T with shape [B, 6, K], or {} on failure.

        ``template_features`` and ``source`` are unused; they are kept for
        interface compatibility. A RuntimeError raised by the inversion is stored
        in ``self.last_err``.
        """
        try:
            Jt = J.transpose(1, 2)  # [B, 6, K]
            H = Jt.bmm(J)  # [B, 6, 6]
            B = self.inverse(H)
            return B.bmm(Jt)  # [B, 6, K]
        except RuntimeError as err:
            # singular...? probably self.dt is way too small. The caller falls back
            # to the identity estimate.
            self.last_err = err
            return {}

    def handle_batchNorm(self, template, source):
        """Refresh BatchNorm statistics once, then freeze them.

        If the feature model is training, it is run on ``template`` and ``source``
        (updating any BatchNorm running statistics) and switched to eval() so
        that the LK iterations see fixed normalisation.

        Returns:
            The feature model's original ``training`` flag, for the caller to restore.
        """
        training = self.feature_model.training
        if training:
            self.feature_model(template)
            self.feature_model(source)
            self.feature_model.eval()  # and fix them.
        return training
165
+
166
+
167
if __name__ == '__main__':
    # Smoke test: register two random clouds with an untrained model.
    template, source = torch.rand(10, 1024, 3), torch.rand(10, 1024, 3)
    pn = PointNet()

    net = PointNetLK(pn)
    result = net(template, source)

    # Print a summary instead of opening an interactive debugger (ipdb is not a
    # declared dependency and would block unattended runs).
    if isinstance(result, dict):
        for key, value in result.items():
            print(key, tuple(value.shape) if torch.is_tensor(value) else value)
    else:
        print(result)
@@ -0,0 +1,15 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
class Pooling(torch.nn.Module):
    """Reduce a feature tensor over dimension 2 (the point axis of [B, C, N] features).

    Args:
        pool_type: 'max' for max pooling, 'avg' or 'average' for mean pooling.

    Raises:
        ValueError: if ``pool_type`` is not one of the supported names.
    """

    _SUPPORTED = ('max', 'avg', 'average')

    def __init__(self, pool_type='max'):
        super(Pooling, self).__init__()
        # Fail fast: an unknown type previously made forward() silently return None.
        if pool_type not in self._SUPPORTED:
            raise ValueError("Unsupported pool_type {!r}; expected one of {}".format(
                pool_type, self._SUPPORTED))
        self.pool_type = pool_type

    def forward(self, input):
        """Return ``input`` pooled over dim 2 (that dim is removed), as a contiguous tensor."""
        if self.pool_type == 'max':
            return torch.max(input, 2)[0].contiguous()
        if self.pool_type in ('avg', 'average'):
            return torch.mean(input, 2).contiguous()
        # Only reachable if pool_type was reassigned after construction.
        raise ValueError("Unsupported pool_type {!r}; expected one of {}".format(
            self.pool_type, self._SUPPORTED))
@@ -0,0 +1,102 @@
1
+ """Feature Extraction and Parameter Prediction networks
2
+ """
3
+ import logging
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ from .. utils import sample_and_group_multi
10
+
11
# Number of channels contributed by each raw feature type.
_raw_features_sizes = dict(xyz=3, dxyz=3, ppf=4)
# Fixed concatenation order of the raw feature types (lower value comes first).
_raw_features_order = {name: rank for rank, name in enumerate(('xyz', 'dxyz', 'ppf'))}
13
+
14
+
15
def get_prepool(in_dim, out_dim):
    """Shared FC part in PointNet before max pooling.

    Three 1x1 Conv2d layers (in_dim -> out_dim // 2 -> out_dim // 2 -> out_dim),
    each followed by GroupNorm with 8 groups and a ReLU.
    """
    hidden = out_dim // 2
    stages = ((in_dim, hidden), (hidden, hidden), (hidden, out_dim))
    modules = []
    for c_in, c_out in stages:
        modules += [nn.Conv2d(c_in, c_out, 1), nn.GroupNorm(8, c_out), nn.ReLU()]
    return nn.Sequential(*modules)
29
+
30
+
31
def get_postpool(in_dim, out_dim):
    """Linear layers in PointNet after max pooling.

    Two 1x1 Conv1d + GroupNorm(8) + ReLU stages (in_dim -> in_dim -> out_dim),
    then a final 1x1 Conv1d (out_dim -> out_dim) with no normalisation or activation.

    Args:
        in_dim: Number of input channels
        out_dim: Number of output channels. Typically smaller than in_dim
    """
    modules = []
    for c_in, c_out in ((in_dim, in_dim), (in_dim, out_dim)):
        modules += [nn.Conv1d(c_in, c_out, 1), nn.GroupNorm(8, c_out), nn.ReLU()]
    modules.append(nn.Conv1d(out_dim, out_dim, 1))
    return nn.Sequential(*modules)
50
+
51
+
52
class PPFNet(nn.Module):
    """Feature extraction Module that extracts hybrid features.

    Neighbourhoods from ``sample_and_group_multi`` (``radius``, ``num_neighbors``)
    are described by the selected raw features. These are concatenated in the
    fixed order xyz, dxyz, ppf, passed through a shared pre-pool MLP, max-pooled
    over the neighbours, refined by a post-pool MLP and L2-normalised per point.

    Args:
        features: Raw feature names to fuse, any subset of 'xyz' (3 channels),
            'dxyz' (3) and 'ppf' (4). The input order does not matter.
        emb_dims: Output feature dimension C.
        radius: Neighbourhood radius passed to sample_and_group_multi.
        num_neighbors: Neighbours sampled per point.
    """

    def __init__(self, features=('ppf', 'dxyz', 'xyz'), emb_dims=96, radius=0.3, num_neighbors=64):
        super().__init__()

        self._logger = logging.getLogger(self.__class__.__name__)
        self._logger.info('Using early fusion, feature dim = {}'.format(emb_dims))
        self.radius = radius
        self.n_sample = num_neighbors

        self.features = sorted(features, key=lambda f: _raw_features_order[f])
        self._logger.info('Feature extraction using features {}'.format(', '.join(self.features)))

        # Layers. Builtin sum keeps raw_dim a Python int (np.sum returned a numpy scalar).
        raw_dim = sum(_raw_features_sizes[f] for f in self.features)  # number of channels after concat
        self.prepool = get_prepool(raw_dim, emb_dims * 2)
        self.postpool = get_postpool(emb_dims * 2, emb_dims)

    def forward(self, xyz, normals):
        """Forward pass of the feature extraction network

        Args:
            xyz: (B, N, 3)
            normals: (B, N, 3)

        Returns:
            cluster features (B, N, C), each row of unit L2 norm (zero rows stay zero)
        """
        features = sample_and_group_multi(-1, self.radius, self.n_sample, xyz, normals)
        # 'xyz' has no neighbour axis; add a singleton one so it can be expanded.
        features['xyz'] = features['xyz'][:, :, None, :]

        # Gate and concat along the channel (last) axis
        fused_input_feat = torch.cat(
            [features[f].expand(-1, -1, self.n_sample, -1) for f in self.features], -1)

        # Prepool_FC, pool, postpool-FC
        new_feat = fused_input_feat.permute(0, 3, 2, 1)  # [B, raw_dim, n_sample, N]
        new_feat = self.prepool(new_feat)

        pooled_feat = torch.max(new_feat, 2)[0]  # Max pooling (B, C, N)

        post_feat = self.postpool(pooled_feat)  # Post pooling dense layers
        cluster_feat = post_feat.permute(0, 2, 1)
        # Same as dividing by the L2 norm, but an eps clamp avoids NaN for zero rows.
        cluster_feat = F.normalize(cluster_feat, p=2, dim=-1)

        return cluster_feat  # (B, N, C)