learning3d 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (115)
  1. learning3d/__init__.py +2 -0
  2. learning3d/data_utils/__init__.py +4 -0
  3. learning3d/data_utils/dataloaders.py +454 -0
  4. learning3d/data_utils/user_data.py +119 -0
  5. learning3d/examples/test_dcp.py +139 -0
  6. learning3d/examples/test_deepgmr.py +144 -0
  7. learning3d/examples/test_flownet.py +113 -0
  8. learning3d/examples/test_masknet.py +159 -0
  9. learning3d/examples/test_masknet2.py +162 -0
  10. learning3d/examples/test_pcn.py +118 -0
  11. learning3d/examples/test_pcrnet.py +120 -0
  12. learning3d/examples/test_pnlk.py +121 -0
  13. learning3d/examples/test_pointconv.py +126 -0
  14. learning3d/examples/test_pointnet.py +121 -0
  15. learning3d/examples/test_prnet.py +126 -0
  16. learning3d/examples/test_rpmnet.py +120 -0
  17. learning3d/examples/train_PointNetLK.py +240 -0
  18. learning3d/examples/train_dcp.py +249 -0
  19. learning3d/examples/train_deepgmr.py +244 -0
  20. learning3d/examples/train_flownet.py +259 -0
  21. learning3d/examples/train_masknet.py +239 -0
  22. learning3d/examples/train_pcn.py +216 -0
  23. learning3d/examples/train_pcrnet.py +228 -0
  24. learning3d/examples/train_pointconv.py +245 -0
  25. learning3d/examples/train_pointnet.py +244 -0
  26. learning3d/examples/train_prnet.py +229 -0
  27. learning3d/examples/train_rpmnet.py +228 -0
  28. learning3d/losses/__init__.py +12 -0
  29. learning3d/losses/chamfer_distance.py +51 -0
  30. learning3d/losses/classification.py +14 -0
  31. learning3d/losses/correspondence_loss.py +10 -0
  32. learning3d/losses/cuda/chamfer_distance/__init__.py +1 -0
  33. learning3d/losses/cuda/chamfer_distance/chamfer_distance.cpp +185 -0
  34. learning3d/losses/cuda/chamfer_distance/chamfer_distance.cu +209 -0
  35. learning3d/losses/cuda/chamfer_distance/chamfer_distance.py +66 -0
  36. learning3d/losses/cuda/emd_torch/pkg/emd_loss_layer.py +41 -0
  37. learning3d/losses/cuda/emd_torch/pkg/include/cuda/emd.cuh +347 -0
  38. learning3d/losses/cuda/emd_torch/pkg/include/cuda_helper.h +18 -0
  39. learning3d/losses/cuda/emd_torch/pkg/include/emd.h +54 -0
  40. learning3d/losses/cuda/emd_torch/pkg/layer/__init__.py +1 -0
  41. learning3d/losses/cuda/emd_torch/pkg/layer/emd_loss_layer.py +40 -0
  42. learning3d/losses/cuda/emd_torch/pkg/src/cuda/emd.cu +70 -0
  43. learning3d/losses/cuda/emd_torch/pkg/src/emd.cpp +1 -0
  44. learning3d/losses/cuda/emd_torch/setup.py +29 -0
  45. learning3d/losses/emd.py +16 -0
  46. learning3d/losses/frobenius_norm.py +21 -0
  47. learning3d/losses/rmse_features.py +16 -0
  48. learning3d/models/__init__.py +23 -0
  49. learning3d/models/classifier.py +41 -0
  50. learning3d/models/dcp.py +92 -0
  51. learning3d/models/deepgmr.py +165 -0
  52. learning3d/models/dgcnn.py +92 -0
  53. learning3d/models/flownet3d.py +446 -0
  54. learning3d/models/masknet.py +84 -0
  55. learning3d/models/masknet2.py +264 -0
  56. learning3d/models/pcn.py +164 -0
  57. learning3d/models/pcrnet.py +74 -0
  58. learning3d/models/pointconv.py +108 -0
  59. learning3d/models/pointnet.py +108 -0
  60. learning3d/models/pointnetlk.py +173 -0
  61. learning3d/models/pooling.py +15 -0
  62. learning3d/models/ppfnet.py +102 -0
  63. learning3d/models/prnet.py +431 -0
  64. learning3d/models/rpmnet.py +359 -0
  65. learning3d/models/segmentation.py +38 -0
  66. learning3d/ops/__init__.py +0 -0
  67. learning3d/ops/data_utils.py +45 -0
  68. learning3d/ops/invmat.py +134 -0
  69. learning3d/ops/quaternion.py +218 -0
  70. learning3d/ops/se3.py +157 -0
  71. learning3d/ops/sinc.py +229 -0
  72. learning3d/ops/so3.py +213 -0
  73. learning3d/ops/transform_functions.py +342 -0
  74. learning3d/utils/__init__.py +9 -0
  75. learning3d/utils/lib/build/lib.linux-x86_64-3.5/pointnet2_cuda.cpython-35m-x86_64-linux-gnu.so +0 -0
  76. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/ball_query.o +0 -0
  77. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/ball_query_gpu.o +0 -0
  78. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/group_points.o +0 -0
  79. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/group_points_gpu.o +0 -0
  80. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/interpolate.o +0 -0
  81. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/interpolate_gpu.o +0 -0
  82. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/pointnet2_api.o +0 -0
  83. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/sampling.o +0 -0
  84. learning3d/utils/lib/build/temp.linux-x86_64-3.5/src/sampling_gpu.o +0 -0
  85. learning3d/utils/lib/dist/pointnet2-0.0.0-py3.5-linux-x86_64.egg +0 -0
  86. learning3d/utils/lib/pointnet2.egg-info/SOURCES.txt +14 -0
  87. learning3d/utils/lib/pointnet2.egg-info/dependency_links.txt +1 -0
  88. learning3d/utils/lib/pointnet2.egg-info/top_level.txt +1 -0
  89. learning3d/utils/lib/pointnet2_modules.py +160 -0
  90. learning3d/utils/lib/pointnet2_utils.py +318 -0
  91. learning3d/utils/lib/pytorch_utils.py +236 -0
  92. learning3d/utils/lib/setup.py +23 -0
  93. learning3d/utils/lib/src/ball_query.cpp +25 -0
  94. learning3d/utils/lib/src/ball_query_gpu.cu +67 -0
  95. learning3d/utils/lib/src/ball_query_gpu.h +15 -0
  96. learning3d/utils/lib/src/cuda_utils.h +15 -0
  97. learning3d/utils/lib/src/group_points.cpp +36 -0
  98. learning3d/utils/lib/src/group_points_gpu.cu +86 -0
  99. learning3d/utils/lib/src/group_points_gpu.h +22 -0
  100. learning3d/utils/lib/src/interpolate.cpp +65 -0
  101. learning3d/utils/lib/src/interpolate_gpu.cu +233 -0
  102. learning3d/utils/lib/src/interpolate_gpu.h +36 -0
  103. learning3d/utils/lib/src/pointnet2_api.cpp +25 -0
  104. learning3d/utils/lib/src/sampling.cpp +46 -0
  105. learning3d/utils/lib/src/sampling_gpu.cu +253 -0
  106. learning3d/utils/lib/src/sampling_gpu.h +29 -0
  107. learning3d/utils/pointconv_util.py +382 -0
  108. learning3d/utils/ppfnet_util.py +244 -0
  109. learning3d/utils/svd.py +59 -0
  110. learning3d/utils/transformer.py +243 -0
  111. learning3d-0.0.1.dist-info/LICENSE +21 -0
  112. learning3d-0.0.1.dist-info/METADATA +271 -0
  113. learning3d-0.0.1.dist-info/RECORD +115 -0
  114. learning3d-0.0.1.dist-info/WHEEL +5 -0
  115. learning3d-0.0.1.dist-info/top_level.txt +1 -0
learning3d/models/masknet2.py
@@ -0,0 +1,264 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from .pooling import Pooling
+
+
+ # Mish activation function: x * tanh(softplus(x)).
+ class Mish(nn.Module):
+     def __init__(self):
+         super(Mish, self).__init__()
+
+     def forward(self, x):
+         return x * torch.tanh(F.softplus(x))
+
+
+ # Basic convolution block: Conv1d -> BatchNorm1d -> (optional) Mish.
+ class BasicConv1D(nn.Module):
+     def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, active=True):
+         super(BasicConv1D, self).__init__()
+         self.active = active
+         self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, bias=False)
+         self.bn = nn.BatchNorm1d(out_channels)
+         if self.active:
+             self.activation = Mish()
+         # self.dropout = nn.Dropout(0.5)
+
+     def forward(self, x):
+         x = self.conv(x)
+         x = self.bn(x)
+         if self.active:
+             x = self.activation(x)
+         return x
+
+
+ class Self_Attn(nn.Module):
+     """Self-attention layer over per-point features."""
+     def __init__(self, in_dim, out_dim):
+         super(Self_Attn, self).__init__()
+
+         self.in_dim = in_dim
+         self.out_dim = out_dim
+
+         # Shared projection used for both query and key.
+         self.query_conv = BasicConv1D(in_dim, out_dim)
+
+         # Learnable residual weight, initialised to zero.
+         self.beta = nn.Parameter(torch.zeros(1))
+
+         self.softmax = nn.Softmax(dim=-1)
+
+     def forward(self, x):
+         """
+         inputs:
+             x: input feature maps (B x C x N), N = number of points
+         returns:
+             out: self-attention value + projected input feature (B x out_dim x N)
+         """
+         proj_query = self.query_conv(x).permute(0, 2, 1)  # B x in_dim x N -> B x out_dim x N -> B x N x out_dim
+         proj_key = proj_query.permute(0, 2, 1)            # B x out_dim x N
+
+         energy = torch.bmm(proj_query, proj_key)          # B x N x N
+         attention = self.softmax(energy)                  # B x N x N
+
+         out_x = torch.bmm(proj_key, attention.permute(0, 2, 1))  # B x out_dim x N
+
+         out = self.beta * out_x + proj_key                # residual connection
+
+         return out
+
+
+ class PointNet(torch.nn.Module):
+     def __init__(self, emb_dims=224, input_shape="bnc", use_bn=False, global_feat=True):
+         # emb_dims: embedding dimensions for PointNet.
+         # input_shape: shape of the input point cloud (b: batch, n: number of points, c: channels).
+         super(PointNet, self).__init__()
+         if input_shape not in ["bcn", "bnc"]:
+             raise ValueError("Allowed shapes are 'bcn' (batch * channels * num_in_points) and 'bnc'.")
+         self.input_shape = input_shape
+         self.emb_dims = emb_dims
+         self.use_bn = use_bn
+         self.global_feat = global_feat
+         if not self.global_feat:
+             self.pooling = Pooling('max')
+
+         self.conv1 = Self_Attn(3, 32)
+         self.conv2 = Self_Attn(32, 64)
+         self.conv3 = Self_Attn(64, 64)
+         self.conv4 = Self_Attn(64, 128)
+         self.conv5 = Self_Attn(128, self.emb_dims)
+
+     def forward(self, input_data):
+         # input_data: point cloud of shape input_shape.
+         # output: concatenated per-point features (B x (32+64+64+128+emb_dims) x N).
+         if self.input_shape == "bnc":
+             num_points = input_data.shape[1]
+             input_data = input_data.permute(0, 2, 1)
+         else:
+             num_points = input_data.shape[2]
+         if input_data.shape[1] != 3:
+             raise RuntimeError("shape of x must be of [Batch x 3 x NumInPoints]")
+
+         output = input_data
+
+         x1 = self.conv1(output)   # 32 channels
+         x2 = self.conv2(x1)       # 64 channels
+         x3 = self.conv3(x2)       # 64 channels
+         x4 = self.conv4(x3 + x2)  # 128 channels, with a skip connection
+         x5 = self.conv5(x4)       # emb_dims channels
+
+         output = torch.cat([x1, x2, x3, x4, x5], dim=1)
+         point_feature = output
+
+         if self.global_feat:
+             return output
+         else:
+             output = self.pooling(output)
+             output = output.view(-1, self.emb_dims, 1).repeat(1, 1, num_points)
+             return torch.cat([output, point_feature], 1)
+
+
+ # Cross attention between two global feature vectors.
+ class self_attention_fc(nn.Module):
+     """Self-attention layer over global features."""
+     def __init__(self, in_dim, out_dim):
+         super(self_attention_fc, self).__init__()
+
+         self.in_dim = in_dim
+         self.out_dim = out_dim
+
+         self.query_conv = BasicConv1D(in_dim, out_dim)
+
+         self.beta = nn.Parameter(torch.zeros(1))
+         self.softmax = nn.Softmax(dim=-1)
+
+     def forward(self, x, y):
+         """
+         inputs:
+             x, y: global feature maps (B x C x 1)
+         returns:
+             out_x, out_y: cross-attended features + projected inputs (B x out_dim x 1)
+         """
+         proj_query_x = self.query_conv(x)                 # B x in_dim x 1 -> B x out_dim x 1
+         proj_key_y = self.query_conv(y).permute(0, 2, 1)  # B x 1 x out_dim
+
+         energy_xy = torch.bmm(proj_query_x, proj_key_y)   # attention scores between x and y
+
+         attention_xy = self.softmax(energy_xy)
+         attention_yx = self.softmax(energy_xy.permute(0, 2, 1))
+
+         proj_value_x = proj_query_x                       # B x out_dim x 1
+         proj_value_y = proj_key_y.permute(0, 2, 1)        # B x out_dim x 1
+
+         out_x = torch.bmm(attention_xy, proj_value_x)
+         out_x = self.beta * out_x + proj_value_x
+
+         out_y = torch.bmm(attention_yx, proj_value_y)
+         out_y = self.beta * out_y + proj_value_y
+
+         return out_x, out_y
+
+
+ class PointNetMask(nn.Module):
+     def __init__(self, template_feature_size=1024, source_feature_size=1024, feature_model=PointNet()):
+         super().__init__()
+         self.feature_model = feature_model
+         self.pooling_max = Pooling(pool_type='max')
+         self.pooling_avg = Pooling(pool_type='avg')
+
+         self.global_feat_1 = self_attention_fc(1024, 512)
+         self.global_feat_2 = self_attention_fc(512, 256)
+         self.global_feat_3 = self_attention_fc(256, 512)
+
+         self.h3 = nn.Sequential(BasicConv1D(1024, 512),
+                                 BasicConv1D(512, 256),
+                                 BasicConv1D(256, 128),
+                                 nn.Conv1d(128, 1, 1), nn.Sigmoid())
+
+     def find_mask(self, source_features, template_features):
+         # Pool per-point features into global descriptors (max + avg).
+         global_source_features_max = self.pooling_max(source_features)
+         global_template_features_max = self.pooling_max(template_features)
+         global_source_features_avg = self.pooling_avg(source_features)
+         global_template_features_avg = self.pooling_avg(template_features)
+         global_source_features = torch.cat([global_source_features_max, global_source_features_avg], dim=1)
+         global_template_features = torch.cat([global_template_features_max, global_template_features_avg], dim=1)
+
+         # Exchange information between the two global descriptors.
+         shared_feat_1, shared_feat_2 = self.global_feat_1(global_source_features.unsqueeze(2), global_template_features.unsqueeze(2))
+         shared_feat_1, shared_feat_2 = self.global_feat_2(shared_feat_1, shared_feat_2)
+         shared_feat_1, shared_feat_2 = self.global_feat_3(shared_feat_1, shared_feat_2)
+
+         batch_size, _, num_points = source_features.size()
+         global_source_features = shared_feat_1.repeat(1, 1, num_points)
+         x = torch.cat([template_features, global_source_features], dim=1)
+         x = self.h3(x)
+
+         batch_size, _, num_points = template_features.size()
+         global_template_features = shared_feat_2.repeat(1, 1, num_points)
+         y = torch.cat([source_features, global_template_features], dim=1)
+         y = self.h3(y)
+
+         return x.view(batch_size, -1), y.view(batch_size, -1)
+
+     def forward(self, template, source):
+         source_features = self.feature_model(source)      # B x C x N
+         template_features = self.feature_model(template)  # B x C x N
+
+         template_mask, source_mask = self.find_mask(source_features, template_features)
+         return template_mask, source_mask
+
+
+ class MaskNet2(nn.Module):
+     def __init__(self, feature_model=PointNet(use_bn=True), is_training=True):
+         super().__init__()
+         self.maskNet = PointNetMask(feature_model=feature_model)
+         self.is_training = is_training
+
+     @staticmethod
+     def index_points(points, idx):
+         """
+         Input:
+             points: input points data, [B, N, C]
+             idx: sample index data, [B, S]
+         Return:
+             new_points: indexed points data, [B, S, C]
+         """
+         device = points.device
+         B = points.shape[0]
+         view_shape = list(idx.shape)
+         view_shape[1:] = [1] * (len(view_shape) - 1)
+         repeat_shape = list(idx.shape)
+         repeat_shape[0] = 1
+         batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
+         new_points = points[batch_indices, idx, :]
+
+         return new_points
+
+     def forward(self, template, source, point_selection='threshold', mask_threshold=0.5):
+         template_mask, source_mask = self.maskNet(template, source)  # B x N
+
+         # Threshold the predicted per-point probabilities into binary masks.
+         source_binary_mask = (source_mask > mask_threshold).float()
+         template_binary_mask = (template_mask > mask_threshold).float()
+
+         # Note: the boolean indexing below assumes batch size 1.
+         masked_template = template[:, template_binary_mask.bool().squeeze(0), 0:3]
+         masked_source = source[:, source_binary_mask.bool().squeeze(0), 0:3]
+
+         return masked_template, masked_source, template_mask, source_mask
+
+
+ if __name__ == '__main__':
+     template, source = torch.rand(10, 1024, 3), torch.rand(10, 1024, 3)
+     net = MaskNet2()
+     result = net(template, source)
+     import ipdb; ipdb.set_trace()
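
For context, a minimal usage sketch of the masking interface defined above (not part of the packaged diff; it assumes MaskNet2 is importable from learning3d.models, and it uses batch size 1 because the boolean indexing in forward squeezes the batch dimension):

import torch
from learning3d.models import MaskNet2  # assumed export path

template = torch.rand(1, 1024, 3)   # B x N x 3 point cloud
source = template[:, :768, :]       # a partial view of the template

net = MaskNet2().eval()
with torch.no_grad():
    masked_template, masked_source, template_mask, source_mask = net(
        template, source, mask_threshold=0.5)

# template_mask / source_mask hold per-point inlier probabilities in [0, 1];
# masked_template / masked_source keep only the points above the threshold.
print(template_mask.shape, masked_template.shape)  # (1, 1024), (1, K, 3)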
learning3d/models/pcn.py
@@ -0,0 +1,164 @@
+ # author: Vinit Sarode (vinitsarode5@gmail.com) 03/23/2020
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from .pooling import Pooling
+
+ class PCN(torch.nn.Module):
+     def __init__(self, emb_dims=1024, input_shape="bnc", num_coarse=1024, grid_size=4, detailed_output=False):
+         # emb_dims: embedding dimensions for PCN.
+         # input_shape: shape of the input point cloud (b: batch, n: number of points, c: channels).
+         # num_coarse: number of points in the coarse output.
+         # grid_size: folding grid resolution; the fine output has grid_size^2 * num_coarse points.
+         super(PCN, self).__init__()
+         if input_shape not in ["bcn", "bnc"]:
+             raise ValueError("Allowed shapes are 'bcn' (batch * channels * num_in_points) and 'bnc'.")
+         self.input_shape = input_shape
+         self.emb_dims = emb_dims
+         self.num_coarse = num_coarse
+         self.detailed_output = detailed_output
+         self.grid_size = grid_size
+         self.num_fine = self.grid_size ** 2 * self.num_coarse
+         self.pooling = Pooling('max')
+
+         self.encoder()
+         self.decoder_layers = self.decoder()
+         if detailed_output:
+             self.folding_layers = self.folding()
+
+     def encoder_1(self):
+         self.conv1 = torch.nn.Conv1d(3, 128, 1)
+         self.conv2 = torch.nn.Conv1d(128, 256, 1)
+         self.relu = torch.nn.ReLU()
+
+         # self.bn1 = torch.nn.BatchNorm1d(128)
+         # self.bn2 = torch.nn.BatchNorm1d(256)
+
+         layers = [self.conv1, self.relu,
+                   self.conv2]
+         return layers
+
+     def encoder_2(self):
+         self.conv3 = torch.nn.Conv1d(2 * 256, 512, 1)
+         self.conv4 = torch.nn.Conv1d(512, self.emb_dims, 1)
+
+         # self.bn3 = torch.nn.BatchNorm1d(512)
+         # self.bn4 = torch.nn.BatchNorm1d(self.emb_dims)
+         self.relu = torch.nn.ReLU()
+
+         layers = [self.conv3, self.relu,
+                   self.conv4]
+         return layers
+
+     def encoder(self):
+         self.encoder_layers1 = self.encoder_1()
+         self.encoder_layers2 = self.encoder_2()
+
+     def decoder(self):
+         self.linear1 = torch.nn.Linear(self.emb_dims, 1024)
+         self.linear2 = torch.nn.Linear(1024, 1024)
+         self.linear3 = torch.nn.Linear(1024, self.num_coarse * 3)
+
+         # self.bn1 = torch.nn.BatchNorm1d(1024)
+         # self.bn2 = torch.nn.BatchNorm1d(1024)
+         # self.bn3 = torch.nn.BatchNorm1d(self.num_coarse * 3)
+         self.relu = torch.nn.ReLU()
+
+         layers = [self.linear1, self.relu,
+                   self.linear2, self.relu,
+                   self.linear3]
+         return layers
+
+     def folding(self):
+         self.conv5 = torch.nn.Conv1d(1029, 512, 1)
+         self.conv6 = torch.nn.Conv1d(512, 512, 1)
+         self.conv7 = torch.nn.Conv1d(512, 3, 1)
+
+         # self.bn5 = torch.nn.BatchNorm1d(512)
+         # self.bn6 = torch.nn.BatchNorm1d(512)
+         self.relu = torch.nn.ReLU()
+
+         layers = [self.conv5, self.relu,
+                   self.conv6, self.relu,
+                   self.conv7]
+         return layers
+
+     def fine_decoder(self):
+         # Fine output: fold a small 2D grid around every coarse point.
+         # Shape comments assume the defaults (num_coarse=1024, grid_size=4).
+         linspace = torch.linspace(-0.05, 0.05, steps=self.grid_size)
+         grid = torch.meshgrid(linspace, linspace)
+         grid = torch.reshape(torch.stack(grid, dim=2), (-1, 2))                        # 16x2
+         grid = torch.unsqueeze(grid, dim=0)                                            # 1x16x2
+         grid_feature = grid.repeat([self.coarse_output.shape[0], self.num_coarse, 1])  # Bx16384x2
+
+         point_feature = torch.unsqueeze(self.coarse_output, dim=2)                     # Bx1024x1x3
+         point_feature = point_feature.repeat([1, 1, self.grid_size ** 2, 1])           # Bx1024x16x3
+         point_feature = torch.reshape(point_feature, (-1, self.num_fine, 3))           # Bx16384x3
+
+         global_feature = torch.unsqueeze(self.global_feature_v, dim=1)                 # Bx1x1024
+         global_feature = global_feature.repeat([1, self.num_fine, 1])                  # Bx16384x1024
+
+         feature = torch.cat([grid_feature, point_feature, global_feature], dim=2)      # Bx16384x1029
+
+         center = torch.unsqueeze(self.coarse_output, dim=2)                            # Bx1024x1x3
+         center = center.repeat([1, 1, self.grid_size ** 2, 1])                         # Bx1024x16x3
+         center = torch.reshape(center, [-1, self.num_fine, 3])                         # Bx16384x3
+
+         output = feature.permute(0, 2, 1)
+         for layer in self.folding_layers:
+             output = layer(output)
+         fine_output = output.permute(0, 2, 1) + center
+         return fine_output
+
+     def encode(self, input_data):
+         output = input_data
+         for layer in self.encoder_layers1:
+             output = layer(output)
+
+         # Concatenate the first-stage global feature to every point feature.
+         global_feature_g = self.pooling(output)
+         global_feature_g = global_feature_g.unsqueeze(2)
+         global_feature_g = global_feature_g.repeat(1, 1, self.num_points)
+         output = torch.cat([output, global_feature_g], dim=1)
+
+         for layer in self.encoder_layers2:
+             output = layer(output)
+
+         self.global_feature_v = self.pooling(output)
+
+     def decode(self):
+         output = self.global_feature_v
+         for layer in self.decoder_layers:
+             output = layer(output)
+         self.coarse_output = output.view(self.global_feature_v.shape[0], self.num_coarse, 3)
+
+     def forward(self, input_data):
+         # input_data: point cloud of shape input_shape.
+         # output: dict with the coarse (and optionally fine) completions.
+         if self.input_shape == "bnc":
+             self.num_points = input_data.shape[1]
+             input_data = input_data.permute(0, 2, 1)
+         else:
+             self.num_points = input_data.shape[2]
+         if input_data.shape[1] != 3:
+             raise RuntimeError("shape of x must be of [Batch x 3 x NumInPoints]")
+
+         self.encode(input_data)
+         self.decode()
+
+         result = {'coarse_output': self.coarse_output}
+
+         if self.detailed_output:
+             fine_output = self.fine_decoder()
+             result['fine_output'] = fine_output
+
+         return result
+
+
+ if __name__ == '__main__':
+     # Test the code.
+     x = torch.rand((10, 1024, 3))
+
+     pcn = PCN()
+     y = pcn(x)
+     print("Network Architecture: ")
+     print(pcn)
+     print("Input Shape of PCN: ", x.shape, "\nOutput Shape of PCN: ", y['coarse_output'].shape)
learning3d/models/pcrnet.py
@@ -0,0 +1,74 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from .pointnet import PointNet
+ from .pooling import Pooling
+ from ..ops.transform_functions import PCRNetTransform as transform
+
+
+ class iPCRNet(nn.Module):
+     def __init__(self, feature_model=PointNet(), dropout=0.0, pooling='max'):
+         super().__init__()
+         self.feature_model = feature_model
+         self.pooling = Pooling(pooling)
+
+         self.linear = [nn.Linear(self.feature_model.emb_dims * 2, 1024), nn.ReLU(),
+                        nn.Linear(1024, 1024), nn.ReLU(),
+                        nn.Linear(1024, 512), nn.ReLU(),
+                        nn.Linear(512, 512), nn.ReLU(),
+                        nn.Linear(512, 256), nn.ReLU()]
+
+         if dropout > 0.0:
+             self.linear.append(nn.Dropout(dropout))
+         self.linear.append(nn.Linear(256, 7))
+
+         self.linear = nn.Sequential(*self.linear)
+
+     # Single Pass Alignment Module (SPAM)
+     def spam(self, template_features, source, est_R, est_t):
+         batch_size = source.size(0)
+
+         self.source_features = self.pooling(self.feature_model(source))
+         y = torch.cat([template_features, self.source_features], dim=1)
+         pose_7d = self.linear(y)                     # quaternion (4) + translation (3)
+         pose_7d = transform.create_pose_7d(pose_7d)
+
+         # Find the current rotation and translation.
+         identity = torch.eye(3).to(source).view(1, 3, 3).expand(batch_size, 3, 3).contiguous()
+         est_R_temp = transform.quaternion_rotate(identity, pose_7d).permute(0, 2, 1)
+         est_t_temp = transform.get_translation(pose_7d).view(-1, 1, 3)
+
+         # Update the translation vector.
+         est_t = torch.bmm(est_R_temp, est_t.permute(0, 2, 1)).permute(0, 2, 1) + est_t_temp
+         # Update the rotation matrix.
+         est_R = torch.bmm(est_R_temp, est_R)
+
+         source = transform.quaternion_transform(source, pose_7d)  # Ps' = est_R * Ps + est_t
+         return est_R, est_t, source
+
+     def forward(self, template, source, max_iteration=8):
+         est_R = torch.eye(3).to(template).view(1, 3, 3).expand(template.size(0), 3, 3).contiguous()      # B x 3 x 3
+         est_t = torch.zeros(1, 3).to(template).view(1, 1, 3).expand(template.size(0), 1, 3).contiguous() # B x 1 x 3
+         template_features = self.pooling(self.feature_model(template))
+
+         # Refine the estimate iteratively; the template features are computed only once.
+         for i in range(max_iteration):
+             est_R, est_t, source = self.spam(template_features, source, est_R, est_t)
+
+         result = {'est_R': est_R,                                           # source -> template
+                   'est_t': est_t,                                           # source -> template
+                   'est_T': transform.convert2transformation(est_R, est_t),  # source -> template
+                   'r': template_features - self.source_features,
+                   'transformed_source': source}
+         return result
+
+
+ if __name__ == '__main__':
+     template, source = torch.rand(10, 1024, 3), torch.rand(10, 1024, 3)
+     pn = PointNet()
+
+     net = iPCRNet(pn)
+     result = net(template, source)
+     import ipdb; ipdb.set_trace()
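
A brief sketch of driving the iterative alignment loop (not part of the diff; the import paths are assumptions, and with untrained weights the estimated pose is arbitrary, so the point here is the interface and the output shapes):

import torch
from learning3d.models import iPCRNet, PointNet  # assumed export paths

net = iPCRNet(feature_model=PointNet()).eval()
template, source = torch.rand(4, 1024, 3), torch.rand(4, 1024, 3)

with torch.no_grad():
    result = net(template, source, max_iteration=8)

print(result['est_R'].shape)                # torch.Size([4, 3, 3])
print(result['est_t'].shape)                # torch.Size([4, 1, 3])
print(result['est_T'].shape)                # expected B x 4 x 4 from convert2transformation
print(result['transformed_source'].shape)   # torch.Size([4, 1024, 3])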
learning3d/models/pointconv.py
@@ -0,0 +1,108 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from ..utils import PointConvDensitySetAbstraction
+
+ class PointConvDensityClsSsg(torch.nn.Module):
+     def __init__(self, emb_dims=1024, input_shape="bnc", input_channel_dim=3, classifier=False, num_classes=40, pretrained=None):
+         super(PointConvDensityClsSsg, self).__init__()
+         if input_shape not in ["bnc", "bcn"]:
+             raise ValueError("Allowed shapes are 'bcn' (batch * channels * num_in_points) and 'bnc'.")
+         self.input_shape = input_shape
+         self.emb_dims = emb_dims
+         self.classifier = classifier
+         self.input_channel_dim = input_channel_dim
+         self.create_structure()
+         if self.classifier:
+             self.create_classifier(num_classes)
+
+     def create_structure(self):
+         # Arguments used to define the PointConv network via PointConvDensitySetAbstraction:
+         # npoint: number of points sampled from the input.
+         # nsample: number of neighbours chosen for each point in the sampled point cloud.
+         # in_channel: number of channels in the input.
+         # mlp: sizes of the multi-layer perceptrons.
+         # bandwidth: used to compute the Gaussian density.
+         # group_all: group all input points into a single point if set to True.
+         self.sa1 = PointConvDensitySetAbstraction(npoint=512, nsample=32, in_channel=self.input_channel_dim,
+                                                   mlp=[64, 64, 128], bandwidth=0.1, group_all=False)
+         self.sa2 = PointConvDensitySetAbstraction(npoint=128, nsample=64, in_channel=128 + 3,
+                                                   mlp=[128, 128, 256], bandwidth=0.2, group_all=False)
+         self.sa3 = PointConvDensitySetAbstraction(npoint=1, nsample=None, in_channel=256 + 3,
+                                                   mlp=[256, 512, self.emb_dims], bandwidth=0.4, group_all=True)
+
+     def create_classifier(self, num_classes):
+         # Simple fully-connected layers with batch-norm and dropout.
+         # This architecture follows the PointConv paper, hence it is used here as the default.
+         # It can easily be modified by overriding this function or by using the classifier.py class.
+         self.fc1 = nn.Linear(self.emb_dims, 512)
+         self.bn1 = nn.BatchNorm1d(512)
+         self.drop1 = nn.Dropout(0.7)
+         self.fc2 = nn.Linear(512, 256)
+         self.bn2 = nn.BatchNorm1d(256)
+         self.drop2 = nn.Dropout(0.7)
+         self.fc3 = nn.Linear(256, num_classes)
+
+     def forward(self, input_data):
+         if self.input_shape == "bnc":
+             input_data = input_data.permute(0, 2, 1)
+         batch_size = input_data.shape[0]
+
+         # Convert the point cloud to latent features using the PointConv network.
+         l1_points, l1_features = self.sa1(input_data[:, :3, :], input_data[:, 3:, :])
+         l2_points, l2_features = self.sa2(l1_points, l1_features)
+         l3_points, l3_features = self.sa3(l2_points, l2_features)
+         features = l3_features.view(batch_size, self.emb_dims)
+
+         if self.classifier:
+             # Use the features to classify the input point cloud.
+             features = self.drop1(F.relu(self.bn1(self.fc1(features))))
+             features = self.drop2(F.relu(self.bn2(self.fc2(features))))
+             features = self.fc3(features)
+             output = F.log_softmax(features, -1)
+         else:
+             # Return the PointConv features for use in higher-level tasks.
+             output = features
+
+         return output
+
+ def create_pointconv(classifier=False, pretrained=None):
+     if classifier and pretrained is not None:
+         class Network(torch.nn.Module):
+             def __init__(self, emb_dims=1024, input_shape="bnc", input_channel_dim=3, classifier=False, num_classes=40, pretrained=None):
+                 # Arguments:
+                 # emb_dims: size of the embeddings.
+                 # input_shape: shape of the input point cloud.
+                 # input_channel_dim: number of channels in the point cloud [e.g. Nx3 (points only) or Nx6 (points + normals)].
+                 # classifier: whether to use the default classifier layers or only the embedding layers.
+                 # num_classes: number of classes in your dataset, if the classifier is used.
+                 # pretrained: path to a pretrained classification checkpoint.
+                 super().__init__()
+                 self.pointconv = PointConvDensityClsSsg(emb_dims, input_shape, input_channel_dim, classifier, num_classes)
+                 if classifier and pretrained is not None:
+                     self.use_pretrained(pretrained)
+
+             def use_pretrained(self, pretrained):
+                 checkpoint = torch.load(pretrained, map_location='cpu')
+                 self.pointconv.load_state_dict(checkpoint['model_state_dict'])
+
+             def forward(self, input_data):
+                 return self.pointconv(input_data)
+         return Network
+     else:
+         class Network(PointConvDensityClsSsg):
+             def __init__(self, emb_dims=1024, input_shape="bnc", input_channel_dim=3, classifier=False, num_classes=40, pretrained=None):
+                 super().__init__(emb_dims=emb_dims, input_shape=input_shape, input_channel_dim=input_channel_dim, classifier=classifier, num_classes=num_classes, pretrained=pretrained)
+         return Network
+
+
+ if __name__ == '__main__':
+     # Test the code.
+     x = torch.rand((2, 1024, 3))
+
+     PointConv = create_pointconv(classifier=False, pretrained='checkpoint.pth')
+     pc = PointConv(input_channel_dim=3, classifier=False, pretrained='checkpoint.pth')
+     y = pc(x)
+     print("Network Architecture: ")
+     print(pc)
+     print("Input Shape of PointConv: ", x.shape, "\nOutput Shape of PointConv: ", y.shape)