learning3d 0.0.7__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {learning3d/data_utils → data_utils}/dataloaders.py +16 -14
- examples/test_curvenet.py +118 -0
- {learning3d/examples → examples}/test_dcp.py +1 -1
- {learning3d/examples → examples}/test_deepgmr.py +1 -1
- {learning3d/examples → examples}/test_prnet.py +1 -1
- {learning3d-0.0.7.dist-info → learning3d-0.2.0.dist-info}/METADATA +56 -11
- learning3d-0.2.0.dist-info/RECORD +70 -0
- {learning3d-0.0.7.dist-info → learning3d-0.2.0.dist-info}/WHEEL +1 -1
- learning3d-0.2.0.dist-info/top_level.txt +6 -0
- {learning3d/models → models}/__init__.py +7 -1
- models/curvenet.py +130 -0
- {learning3d/models → models}/dgcnn.py +1 -35
- {learning3d/models → models}/prnet.py +5 -39
- utils/__init__.py +23 -0
- utils/curvenet_util.py +540 -0
- utils/model_common_utils.py +156 -0
- learning3d/losses/cuda/chamfer_distance/__init__.py +0 -1
- learning3d/losses/cuda/chamfer_distance/chamfer_distance.cpp +0 -185
- learning3d/losses/cuda/chamfer_distance/chamfer_distance.cu +0 -209
- learning3d/losses/cuda/chamfer_distance/chamfer_distance.py +0 -66
- learning3d/losses/cuda/emd_torch/pkg/emd_loss_layer.py +0 -41
- learning3d/losses/cuda/emd_torch/pkg/include/cuda/emd.cuh +0 -347
- learning3d/losses/cuda/emd_torch/pkg/include/cuda_helper.h +0 -18
- learning3d/losses/cuda/emd_torch/pkg/include/emd.h +0 -54
- learning3d/losses/cuda/emd_torch/pkg/layer/__init__.py +0 -1
- learning3d/losses/cuda/emd_torch/pkg/layer/emd_loss_layer.py +0 -40
- learning3d/losses/cuda/emd_torch/pkg/src/cuda/emd.cu +0 -70
- learning3d/losses/cuda/emd_torch/pkg/src/emd.cpp +0 -1
- learning3d/losses/cuda/emd_torch/setup.py +0 -29
- learning3d/ops/__init__.py +0 -0
- learning3d/utils/__init__.py +0 -4
- learning3d-0.0.7.dist-info/RECORD +0 -80
- learning3d-0.0.7.dist-info/top_level.txt +0 -1
- {learning3d/data_utils → data_utils}/__init__.py +0 -0
- {learning3d/data_utils → data_utils}/user_data.py +0 -0
- {learning3d/examples → examples}/test_masknet.py +0 -0
- {learning3d/examples → examples}/test_masknet2.py +0 -0
- {learning3d/examples → examples}/test_pcn.py +0 -0
- {learning3d/examples → examples}/test_pcrnet.py +0 -0
- {learning3d/examples → examples}/test_pnlk.py +0 -0
- {learning3d/examples → examples}/test_pointconv.py +0 -0
- {learning3d/examples → examples}/test_pointnet.py +0 -0
- {learning3d/examples → examples}/test_rpmnet.py +0 -0
- {learning3d/examples → examples}/train_PointNetLK.py +0 -0
- {learning3d/examples → examples}/train_dcp.py +0 -0
- {learning3d/examples → examples}/train_deepgmr.py +0 -0
- {learning3d/examples → examples}/train_masknet.py +0 -0
- {learning3d/examples → examples}/train_pcn.py +0 -0
- {learning3d/examples → examples}/train_pcrnet.py +0 -0
- {learning3d/examples → examples}/train_pointconv.py +0 -0
- {learning3d/examples → examples}/train_pointnet.py +0 -0
- {learning3d/examples → examples}/train_prnet.py +0 -0
- {learning3d/examples → examples}/train_rpmnet.py +0 -0
- {learning3d-0.0.7.dist-info → learning3d-0.2.0.dist-info}/LICENSE +0 -0
- {learning3d/losses → losses}/__init__.py +0 -0
- {learning3d/losses → losses}/chamfer_distance.py +0 -0
- {learning3d/losses → losses}/classification.py +0 -0
- {learning3d/losses → losses}/correspondence_loss.py +0 -0
- {learning3d/losses → losses}/emd.py +0 -0
- {learning3d/losses → losses}/frobenius_norm.py +0 -0
- {learning3d/losses → losses}/rmse_features.py +0 -0
- {learning3d/models → models}/classifier.py +0 -0
- {learning3d/models → models}/dcp.py +0 -0
- {learning3d/models → models}/deepgmr.py +0 -0
- {learning3d/models → models}/masknet.py +0 -0
- {learning3d/models → models}/masknet2.py +0 -0
- {learning3d/models → models}/pcn.py +0 -0
- {learning3d/models → models}/pcrnet.py +0 -0
- {learning3d/models → models}/pointconv.py +0 -0
- {learning3d/models → models}/pointnet.py +0 -0
- {learning3d/models → models}/pointnetlk.py +0 -0
- {learning3d/models → models}/pooling.py +0 -0
- {learning3d/models → models}/ppfnet.py +0 -0
- {learning3d/models → models}/rpmnet.py +0 -0
- {learning3d/models → models}/segmentation.py +0 -0
- {learning3d → ops}/__init__.py +0 -0
- {learning3d/ops → ops}/data_utils.py +0 -0
- {learning3d/ops → ops}/invmat.py +0 -0
- {learning3d/ops → ops}/quaternion.py +0 -0
- {learning3d/ops → ops}/se3.py +0 -0
- {learning3d/ops → ops}/sinc.py +0 -0
- {learning3d/ops → ops}/so3.py +0 -0
- {learning3d/ops → ops}/transform_functions.py +0 -0
- {learning3d/utils → utils}/pointconv_util.py +0 -0
- {learning3d/utils → utils}/ppfnet_util.py +0 -0
- {learning3d/utils → utils}/svd.py +0 -0
- {learning3d/utils → utils}/transformer.py +0 -0
@@ -16,7 +16,7 @@ import torch.nn as nn
|
|
16
16
|
import torch.nn.functional as F
|
17
17
|
|
18
18
|
from .. ops import transform_functions as transform
|
19
|
-
from .. utils import Transformer, Identity
|
19
|
+
from .. utils import Transformer, Identity, knn, get_graph_feature
|
20
20
|
|
21
21
|
from sklearn.metrics import r2_score
|
22
22
|
|
@@ -30,40 +30,6 @@ def pairwise_distance(src, tgt):
|
|
30
30
|
distances = xx.transpose(2, 1).contiguous() + inner + yy
|
31
31
|
return torch.sqrt(distances)
|
32
32
|
|
33
|
-
|
34
|
-
def knn(x, k):
|
35
|
-
inner = -2 * torch.matmul(x.transpose(2, 1).contiguous(), x)
|
36
|
-
xx = torch.sum(x ** 2, dim=1, keepdim=True)
|
37
|
-
distance = -xx - inner - xx.transpose(2, 1).contiguous()
|
38
|
-
|
39
|
-
idx = distance.topk(k=k, dim=-1)[1] # (batch_size, num_points, k)
|
40
|
-
return idx
|
41
|
-
|
42
|
-
|
43
|
-
def get_graph_feature(x, k=20):
|
44
|
-
# x = x.squeeze()
|
45
|
-
x = x.view(*x.size()[:3])
|
46
|
-
idx = knn(x, k=k) # (batch_size, num_points, k)
|
47
|
-
batch_size, num_points, _ = idx.size()
|
48
|
-
|
49
|
-
idx_base = torch.arange(0, batch_size, device=device).view(-1, 1, 1) * num_points
|
50
|
-
|
51
|
-
idx = idx + idx_base
|
52
|
-
|
53
|
-
idx = idx.view(-1)
|
54
|
-
|
55
|
-
_, num_dims, _ = x.size()
|
56
|
-
|
57
|
-
x = x.transpose(2, 1).contiguous() # (batch_size, num_points, num_dims) -> (batch_size*num_points, num_dims) # batch_size * num_points * k + range(0, batch_size*num_points)
|
58
|
-
feature = x.view(batch_size * num_points, -1)[idx, :]
|
59
|
-
feature = feature.view(batch_size, num_points, k, num_dims)
|
60
|
-
x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)
|
61
|
-
|
62
|
-
feature = torch.cat((feature, x), dim=3).permute(0, 3, 1, 2)
|
63
|
-
|
64
|
-
return feature
|
65
|
-
|
66
|
-
|
67
33
|
def cycle_consistency(rotation_ab, translation_ab, rotation_ba, translation_ba):
|
68
34
|
batch_size = rotation_ab.size(0)
|
69
35
|
identity = torch.eye(3, device=rotation_ab.device).unsqueeze(0).repeat(batch_size, 1, 1)
|
@@ -109,19 +75,19 @@ class DGCNN(nn.Module):
|
|
109
75
|
|
110
76
|
def forward(self, x):
|
111
77
|
batch_size, num_dims, num_points = x.size()
|
112
|
-
x = get_graph_feature(x)
|
78
|
+
x = get_graph_feature(x, device=device)
|
113
79
|
x = F.leaky_relu(self.bn1(self.conv1(x)), negative_slope=0.2)
|
114
80
|
x1 = x.max(dim=-1, keepdim=True)[0]
|
115
81
|
|
116
|
-
x = get_graph_feature(x1)
|
82
|
+
x = get_graph_feature(x1, device=device)
|
117
83
|
x = F.leaky_relu(self.bn2(self.conv2(x)), negative_slope=0.2)
|
118
84
|
x2 = x.max(dim=-1, keepdim=True)[0]
|
119
85
|
|
120
|
-
x = get_graph_feature(x2)
|
86
|
+
x = get_graph_feature(x2, device=device)
|
121
87
|
x = F.leaky_relu(self.bn3(self.conv3(x)), negative_slope=0.2)
|
122
88
|
x3 = x.max(dim=-1, keepdim=True)[0]
|
123
89
|
|
124
|
-
x = get_graph_feature(x3)
|
90
|
+
x = get_graph_feature(x3, device=device)
|
125
91
|
x = F.leaky_relu(self.bn4(self.conv4(x)), negative_slope=0.2)
|
126
92
|
x4 = x.max(dim=-1, keepdim=True)[0]
|
127
93
|
|
utils/__init__.py
ADDED
@@ -0,0 +1,23 @@
|
|
1
|
+
from .svd import SVDHead
|
2
|
+
from .transformer import Transformer, Identity
|
3
|
+
from .ppfnet_util import angle_difference, square_distance, index_points, farthest_point_sample, query_ball_point, sample_and_group, sample_and_group_multi
|
4
|
+
from .pointconv_util import PointConvDensitySetAbstraction
|
5
|
+
from .model_common_utils import (
|
6
|
+
knn,
|
7
|
+
pc_normalize,
|
8
|
+
square_distance,
|
9
|
+
index_points,
|
10
|
+
farthest_point_sample,
|
11
|
+
knn_point,
|
12
|
+
query_ball_point,
|
13
|
+
get_graph_feature
|
14
|
+
)
|
15
|
+
from .curvenet_util import (
|
16
|
+
LPFA,
|
17
|
+
CIC,
|
18
|
+
)
|
19
|
+
|
20
|
+
# Best-effort import of pointnet2_utils: the package must stay importable even
# when that module is unavailable (presumably a separately built extension --
# TODO confirm). `except Exception` replaces a bare `except:` so that
# KeyboardInterrupt / SystemExit are no longer swallowed here.
try:
    from .lib import pointnet2_utils
except Exception:
    print("Error raised in pointnet2 module in utils!\nEither don't use pointnet2_utils or retry it's setup.")
|
utils/curvenet_util.py
ADDED
@@ -0,0 +1,540 @@
|
|
1
|
+
"""
|
2
|
+
@Author: Yue Wang
|
3
|
+
@Contact: yuewangx@mit.edu
|
4
|
+
@File: pointnet_util.py
|
5
|
+
@Time: 2018/10/13 10:39 PM
|
6
|
+
|
7
|
+
Modified by
|
8
|
+
@Author: Tiange Xiang
|
9
|
+
@Contact: txia7609@uni.sydney.edu.au
|
10
|
+
@Time: 2021/01/21 3:10 PM
|
11
|
+
"""
|
12
|
+
|
13
|
+
import torch
|
14
|
+
import torch.nn as nn
|
15
|
+
import torch.nn.functional as F
|
16
|
+
from time import time
|
17
|
+
import numpy as np
|
18
|
+
from .model_common_utils import (
|
19
|
+
knn,
|
20
|
+
square_distance,
|
21
|
+
index_points,
|
22
|
+
farthest_point_sample,
|
23
|
+
query_ball_point,
|
24
|
+
)
|
25
|
+
|
26
|
+
def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False):
    """
    Pick centroids by farthest point sampling and gather the features of the
    points returned by a ball query around each centroid.

    Input:
        npoint: number of centroids requested from farthest_point_sample
        radius: radius forwarded to query_ball_point
        nsample: neighbour count forwarded to query_ball_point
        xyz: input points position data, [B, N, 3]
        points: input points data, [B, N, D]
        returnfps: when True, the ball-query indices are returned as well
    Return:
        new_xyz: ``xyz`` indexed by the sampled centroid indices
        new_points: ``points`` indexed by the ball-query indices
        idx: ball-query indices (only when ``returnfps`` is True)
    """
    centroid_idx = farthest_point_sample(xyz, npoint, start_with_first_point=True)
    new_xyz = index_points(xyz, centroid_idx)
    torch.cuda.empty_cache()

    idx = query_ball_point(radius, nsample, xyz, new_xyz, get_cnt=False)
    torch.cuda.empty_cache()

    new_points = index_points(points, idx)
    torch.cuda.empty_cache()

    if not returnfps:
        return new_xyz, new_points
    return new_xyz, new_points, idx
|
51
|
+
|
52
|
+
def batched_index_select(input, dim, index):
    """
    Per-batch gather along ``dim``.

    ``index`` is reshaped so that its entries lie along ``dim`` (singleton on
    every other non-batch axis), broadcast across the remaining axes of
    ``input`` and handed to ``torch.gather``.
    """
    # Batch axis first, then -1 on the selected axis and 1 everywhere else.
    view_shape = [input.shape[0]]
    for axis in range(1, input.dim()):
        view_shape.append(-1 if axis == dim else 1)

    # Keep the index extent on the batch axis and on `dim`; copy the rest.
    expand_shape = list(input.shape)
    expand_shape[0] = -1
    expand_shape[dim] = -1

    gather_index = index.view(view_shape).expand(expand_shape)
    return torch.gather(input, dim, gather_index)
|
60
|
+
|
61
|
+
def gumbel_softmax(logits, dim, temperature=1):
    """
    Straight-through Gumbel-softmax without the random Gumbel sampling.

    input:
        logits: scores, [*, n_class] when ``dim`` is the last axis
        dim: axis holding the classes; used for both the softmax and the arg-max
        temperature: divisor applied to the logits before the softmax
    return:
        tensor shaped like ``logits`` whose values form a one-hot vector along
        ``dim`` while its gradient is that of the softmax output
    """
    y = F.softmax(logits / temperature, dim=dim)

    # Take the arg-max along the same axis as the softmax. It used to be
    # hard-coded to the last axis, which disagreed with `dim` otherwise.
    _, ind = y.max(dim=dim, keepdim=True)
    y_hard = torch.zeros_like(y).scatter_(dim, ind, 1.0)

    # Straight-through estimator: forward value is y_hard, backward path is y.
    y_hard = (y_hard - y).detach() + y
    return y_hard
|
77
|
+
|
78
|
+
class Walk(nn.Module):
    '''
    Walk in the cloud

    Starting from `curve_num` points per sample, repeatedly steps to one of the
    `k` neighbours of the current point for `curve_length` steps and records
    the visited features and flattened point indices.
    '''
    def __init__(self, in_channel, k, curve_num, curve_length):
        super(Walk, self).__init__()
        self.curve_num = curve_num
        self.curve_length = curve_length
        self.k = k

        # Scores every candidate neighbour from [neighbour feature, curve descriptor].
        self.agent_mlp = nn.Sequential(
            nn.Conv2d(in_channel * 2,
                      1,
                      kernel_size=1,
                      bias=False), nn.BatchNorm2d(1))
        # Produces the two mixing weights used to update the curve descriptor.
        self.momentum_mlp = nn.Sequential(
            nn.Conv1d(in_channel * 2,
                      2,
                      kernel_size=1,
                      bias=False), nn.BatchNorm1d(2))

    def crossover_suppression(self, cur, neighbor, bn, n, k):
        '''
        Cosine-based weight in [0, 1] for every candidate direction.

        cur: bs*n, c        current walking direction
        neighbor: bs*n, c, k  directions towards the k candidates
        bn, n, k: unused; kept so existing callers keep working
        return: bs*n, k     detached weights, clamp(1 + cosine, 0, 1)
        '''
        neighbor = neighbor.detach()
        cur = cur.unsqueeze(-1).detach()
        dot = torch.bmm(cur.transpose(1,2), neighbor) # bs*n, 1, k
        norm1 = torch.norm(cur, dim=1, keepdim=True)
        norm2 = torch.norm(neighbor, dim=1, keepdim=True)
        divider = torch.clamp(norm1 * norm2, min=1e-8)
        # squeeze(1) rather than squeeze(): only the singleton produced by the
        # bmm is dropped, so the result stays 2-D even if bs*n == 1 or k == 1.
        ans = torch.div(dot, divider).squeeze(1) # bs*n, k

        # normalize to [0, 1]
        ans = 1. + ans
        ans = torch.clamp(ans, 0., 1.0)

        return ans.detach()

    def forward(self, xyz, x, adj, cur):
        '''
        xyz: raw point coordinates; accepted for interface compatibility, not used
        x: point features, bs, c, n
        adj: neighbour indices of every point, bs, n, k
        cur: indices of the starting points (added to a per-batch offset)
        return: (curve features concatenated over steps along the last axis,
                 flattened point indices of every visited point, one column
                 per step plus the starting column)
        '''
        bn, c, tot_points = x.size()
        device = x.device

        # The former `xyz = xyz.transpose(1,2).contiguous` (missing call
        # parentheses) bound a method object that was never read; it is removed.

        # point features
        x = x.transpose(1,2).contiguous() # bs, n, c

        flatten_x = x.view(bn * tot_points, -1)
        batch_offset = torch.arange(0, bn, device=device).detach() * tot_points

        # indices of neighbors for the starting points
        tmp_adj = (adj + batch_offset.view(-1,1,1)).view(adj.size(0)*adj.size(1),-1) #bs, n, k

        # batch flattened indices for the starting points
        flatten_cur = (cur + batch_offset.view(-1,1,1)).view(-1)

        curves = []
        flatten_curve_idxs = [flatten_cur.unsqueeze(1)]

        # one step at a time
        for step in range(self.curve_length):

            if step == 0:
                # get starting point features using flattened indices
                starting_points = flatten_x[flatten_cur, :].contiguous()
                pre_feature = starting_points.view(bn, self.curve_num, -1, 1).transpose(1,2) # bs, c, n, 1
            else:
                # dynamic momentum
                cat_feature = torch.cat((cur_feature.squeeze(-1), pre_feature.squeeze(-1)),dim=1)
                att_feature = F.softmax(self.momentum_mlp(cat_feature),dim=1).view(bn, 1, self.curve_num, 2) # bs, 1, n, 2
                cat_feature = torch.cat((cur_feature, pre_feature),dim=-1) # bs, c, n, 2

                # update curve descriptor
                pre_feature = torch.sum(cat_feature * att_feature, dim=-1, keepdim=True) # bs, c, n, 1
                pre_feature_cos = pre_feature.transpose(1,2).contiguous().view(bn * self.curve_num, -1)

            pick_idx = tmp_adj[flatten_cur] # bs*n, k

            # get the neighbors of current points
            pick_values = flatten_x[pick_idx.view(-1),:]

            # reshape to fit crossover suppression below
            pick_values_cos = pick_values.view(bn * self.curve_num, self.k, c)
            pick_values = pick_values_cos.view(bn, self.curve_num, self.k, c)
            pick_values_cos = pick_values_cos.transpose(1,2).contiguous()

            pick_values = pick_values.permute(0,3,1,2) # bs, c, n, k

            pre_feature_expand = pre_feature.expand_as(pick_values)

            # concat current point features with curve descriptors
            pre_feature_expand = torch.cat((pick_values, pre_feature_expand),dim=1)

            # which node to pick next?
            pre_feature_expand = self.agent_mlp(pre_feature_expand) # bs, 1, n, k

            if step != 0:
                # crossover suppression
                d = self.crossover_suppression(cur_feature_cos - pre_feature_cos,
                                               pick_values_cos - cur_feature_cos.unsqueeze(-1),
                                               bn, self.curve_num, self.k)
                d = d.view(bn, self.curve_num, self.k).unsqueeze(1) # bs, 1, n, k
                pre_feature_expand = torch.mul(pre_feature_expand, d)

            pre_feature_expand = gumbel_softmax(pre_feature_expand, -1) #bs, 1, n, k

            cur_feature = torch.sum(pick_values * pre_feature_expand, dim=-1, keepdim=True) # bs, c, n, 1

            cur_feature_cos = cur_feature.transpose(1,2).contiguous().view(bn * self.curve_num, c)

            cur = torch.argmax(pre_feature_expand, dim=-1).view(-1, 1) # bs * n, 1

            # squeeze(1) keeps a 1-D result even when bs * n == 1; a bare
            # squeeze() gave a 0-dim tensor and broke unsqueeze(1) below.
            flatten_cur = batched_index_select(pick_idx, 1, cur).squeeze(1) # bs * n

            # collect curve progress
            curves.append(cur_feature)
            flatten_curve_idxs.append(flatten_cur.unsqueeze(1))

        return torch.cat(curves,dim=-1), torch.cat(flatten_curve_idxs, dim=1)
|
198
|
+
|
199
|
+
|
200
|
+
class Attention_block(nn.Module):
    '''
    Used in attention U-Net.

    Projects the gating signal and the skip features to F_int channels, sums
    them and squashes the result to a single-channel gate with a sigmoid.
    '''
    def __init__(self, F_g, F_l, F_int):
        super().__init__()

        def pointwise_bn(in_ch, out_ch):
            # 1x1 convolution (with bias) followed by batch normalisation.
            return [
                nn.Conv1d(in_ch, out_ch, kernel_size=1, stride=1, padding=0, bias=True),
                nn.BatchNorm1d(out_ch),
            ]

        self.W_g = nn.Sequential(*pointwise_bn(F_g, F_int))
        self.W_x = nn.Sequential(*pointwise_bn(F_l, F_int))
        self.psi = nn.Sequential(*pointwise_bn(F_int, 1), nn.Sigmoid())

    def forward(self, g, x):
        '''
        g: gating features with F_g channels
        x: skip features with F_l channels
        return: (gate, 1 - gate), each with a single channel
        '''
        fused = F.leaky_relu(self.W_g(g) + self.W_x(x), negative_slope=0.2)
        gate = self.psi(fused)
        return gate, 1. - gate
|
229
|
+
|
230
|
+
|
231
|
+
class LPFA(nn.Module):
    '''
    Groups every point with its k neighbours, runs a stack of 1x1 Conv2d
    blocks over the grouped tensor and pools over the neighbour axis.
    '''
    def __init__(self, in_channel, out_channel, k, mlp_num=2, initial=False):
        super().__init__()
        self.k = k
        self.initial = initial

        # Only the non-initial variant lifts the 9 coordinate channels to the
        # feature width so they can be added to the point features.
        if not initial:
            self.xyz2feature = nn.Sequential(
                nn.Conv2d(9, in_channel, kernel_size=1, bias=False),
                nn.BatchNorm2d(in_channel))

        layers = []
        width = in_channel
        for _ in range(mlp_num):
            layers.append(nn.Sequential(nn.Conv2d(width, out_channel, 1, bias=False),
                                        nn.BatchNorm2d(out_channel),
                                        nn.LeakyReLU(0.2)))
            width = out_channel
        self.mlp = nn.Sequential(*layers)

    def forward(self, x, xyz, idx=None):
        '''
        x: point features, bs, c, n
        xyz: point coordinates, bs, 3, n
        idx: optional neighbour indices, bs, n, k (computed with knn if None)
        return: pooled features, bs, out_channel, n
        '''
        grouped = self.mlp(self.group_feature(x, xyz, idx))

        # max over neighbours for the initial layer, mean otherwise
        if self.initial:
            return grouped.max(dim=-1, keepdim=False)[0]
        return grouped.mean(dim=-1, keepdim=False)

    def group_feature(self, x, xyz, idx):
        '''
        Build the grouped tensor fed to the MLP: bs, 9, n, k of coordinates
        when `initial`, otherwise bs, c, n, k of fused features.
        '''
        batch_size, num_dims, num_points = x.size()
        device = x.device

        if idx is None:
            idx = knn(xyz, k=self.k, add_one_to_k=True)[:,:,:self.k]  # (batch_size, num_points, k)

        # Offset per-sample indices so they address the batch-flattened points.
        batch_shift = torch.arange(0, batch_size, device=device).view(-1, 1, 1) * num_points
        flat_idx = (idx + batch_shift).view(-1)

        xyz = xyz.transpose(2, 1).contiguous()  # bs, n, 3
        neighbours = xyz.view(batch_size * num_points, -1)[flat_idx, :]
        neighbours = neighbours.view(batch_size, num_points, self.k, -1)  # bs, n, k, 3
        centers = xyz.view(batch_size, num_points, 1, 3).expand(-1, -1, self.k, -1)  # bs, n, k, 3

        # [center, neighbour, neighbour - center] -> 9 channels
        point_feature = torch.cat((centers, neighbours, neighbours - centers), dim=3)
        point_feature = point_feature.permute(0, 3, 1, 2).contiguous()

        if self.initial:
            return point_feature

        x = x.transpose(2, 1).contiguous()  # bs, n, c
        feature = x.view(batch_size * num_points, -1)[flat_idx, :]
        feature = feature.view(batch_size, num_points, self.k, num_dims)  # bs, n, k, c
        # neighbour feature relative to the center feature
        feature = feature - x.view(batch_size, num_points, 1, num_dims)
        feature = feature.permute(0, 3, 1, 2).contiguous()

        point_feature = self.xyz2feature(point_feature)  # bs, c, n, k
        return F.leaky_relu(feature + point_feature, 0.2)  # bs, c, n, k
|
293
|
+
|
294
|
+
|
295
|
+
class PointNetFeaturePropagation(nn.Module):
    '''
    Upsamples features from a sparse point set onto a dense one by
    inverse-distance interpolation, optionally gates the skip features with an
    Attention_block, and refines the result with a stack of 1x1 Conv1d layers.
    '''
    def __init__(self, in_channel, mlp, att=None):
        '''
        in_channel: channel count of the tensor fed to the first conv
        mlp: iterable of output widths, one per Conv1d/BatchNorm1d pair
        att: optional (F_g, F_l, F_int) triple used to build an Attention_block
        '''
        super(PointNetFeaturePropagation, self).__init__()
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        last_channel = in_channel
        self.att = None
        if att is not None:
            self.att = Attention_block(F_g=att[0],F_l=att[1],F_int=att[2])

        for out_channel in mlp:
            self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1))
            self.mlp_bns.append(nn.BatchNorm1d(out_channel))
            last_channel = out_channel

    def forward(self, xyz1, xyz2, points1, points2):
        """
        Input:
            xyz1: input points position data, [B, C, N]
            xyz2: sampled input points position data, [B, C, S], skipped xyz
            points1: input points data, [B, D, N] (may be None when no
                attention block is configured)
            points2: input points data, [B, D, S], skipped features
        Return:
            new_points: upsampled points data, [B, D', N]
        """
        xyz1 = xyz1.permute(0, 2, 1)
        xyz2 = xyz2.permute(0, 2, 1)

        points2 = points2.permute(0, 2, 1)
        B, N, C = xyz1.shape
        _, S, _ = xyz2.shape

        if S == 1:
            # A single source point: every target point receives its features.
            interpolated_points = points2.repeat(1, N, 1)
        else:
            # Use the 3 nearest source points, or fewer when S < 3. The
            # neighbour count used to be hard-coded to 3, which failed for S == 2.
            num_neighbors = min(3, S)
            dists = square_distance(xyz1, xyz2)
            dists, idx = dists.sort(dim=-1)
            dists, idx = dists[:, :, :num_neighbors], idx[:, :, :num_neighbors]  # [B, N, num_neighbors]

            # inverse-distance weights, normalised to sum to one per target point
            dist_recip = 1.0 / (dists + 1e-8)
            norm = torch.sum(dist_recip, dim=2, keepdim=True)
            weight = dist_recip / norm
            interpolated_points = torch.sum(
                index_points(points2, idx) * weight.view(B, N, num_neighbors, 1), dim=2)

        # skip attention
        if self.att is not None:
            psix, _ = self.att(interpolated_points.permute(0, 2, 1), points1)
            points1 = points1 * psix

        if points1 is not None:
            points1 = points1.permute(0, 2, 1)
            new_points = torch.cat([points1, interpolated_points], dim=-1)
        else:
            new_points = interpolated_points

        new_points = new_points.permute(0, 2, 1)

        for conv, bn in zip(self.mlp_convs, self.mlp_bns):
            new_points = F.leaky_relu(bn(conv(new_points)), 0.2)

        return new_points
|
357
|
+
|
358
|
+
|
359
|
+
class CIC(nn.Module):
    '''
    Bottleneck block: optional max-pool downsampling, 1x1 reduction, optional
    curve grouping/aggregation, LPFA, 1x1 expansion and a residual shortcut.
    '''
    def __init__(self, npoint, radius, k, in_channels, output_channels, bottleneck_ratio=2, mlp_num=2, curve_config=None):
        '''
        npoint, radius, k: forwarded to MaskedMaxPool; k is also the
            neighbourhood size for knn, curve grouping and LPFA
        in_channels / output_channels: block input and output widths
        bottleneck_ratio: divisor giving the inner width
        mlp_num: number of MLP layers inside LPFA
        curve_config: optional pair forwarded to CurveGrouping as
            (curve_num, curve_length); None disables the curve branch
        '''
        super().__init__()
        self.in_channels = in_channels
        self.output_channels = output_channels
        self.bottleneck_ratio = bottleneck_ratio
        self.radius = radius
        self.k = k
        self.npoint = npoint

        planes = in_channels // bottleneck_ratio

        self.use_curve = curve_config is not None
        if self.use_curve:
            curve_num, curve_length = curve_config[0], curve_config[1]
            self.curveaggregation = CurveAggregation(planes)
            self.curvegrouping = CurveGrouping(planes, k, curve_num, curve_length)

        # 1x1 reduction to the bottleneck width
        self.conv1 = nn.Sequential(
            nn.Conv1d(in_channels, planes, kernel_size=1, bias=False),
            nn.BatchNorm1d(in_channels // bottleneck_ratio),
            nn.LeakyReLU(negative_slope=0.2, inplace=True))

        # 1x1 expansion to the output width
        self.conv2 = nn.Sequential(
            nn.Conv1d(planes, output_channels, kernel_size=1, bias=False),
            nn.BatchNorm1d(output_channels))

        # A projection shortcut is only needed when the widths differ.
        if in_channels != output_channels:
            self.shortcut = nn.Sequential(
                nn.Conv1d(in_channels, output_channels, kernel_size=1, bias=False),
                nn.BatchNorm1d(output_channels))

        self.relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        self.maxpool = MaskedMaxPool(npoint, radius, k)

        self.lpfa = LPFA(planes, planes, k, mlp_num=mlp_num, initial=False)

    def forward(self, xyz, x):
        '''
        xyz: point coordinates with the point axis last
        x: point features, bs, c, n
        return: (xyz, features, flattened curve indices or None)
        '''
        # max pool whenever the incoming point count differs from npoint
        if xyz.size(-1) != self.npoint:
            pooled_xyz, x = self.maxpool(xyz.transpose(1, 2).contiguous(), x)
            xyz = pooled_xyz.transpose(1, 2)

        residual = x
        x = self.conv1(x)  # bs, c', n

        idx = knn(xyz, self.k, add_one_to_k=True)

        flatten_curve_idxs = None
        if self.use_curve:
            # curve grouping; the first neighbour column is dropped to avoid self-loops
            curves, flatten_curve_idxs = self.curvegrouping(x, xyz, idx[:, :, 1:])

            # curve aggregation
            x = self.curveaggregation(x, curves)

        x = self.lpfa(x, xyz, idx=idx[:, :, :self.k])

        x = self.conv2(x)  # bs, c, n

        if self.in_channels != self.output_channels:
            residual = self.shortcut(residual)

        x = self.relu(x + residual)
        return xyz, x, flatten_curve_idxs
|
433
|
+
|
434
|
+
|
435
|
+
class CurveAggregation(nn.Module):
    '''
    Fuses curve features back into the point features through two attention
    paths (across curves and along curves) and a residual connection.
    '''
    def __init__(self, in_channel):
        super().__init__()
        self.in_channel = in_channel
        mid_feature = in_channel // 2

        def pointwise(n_in, n_out):
            # bias-free 1x1 Conv1d
            return nn.Conv1d(n_in, n_out, kernel_size=1, bias=False)

        # Creation order is kept: it fixes parameter order and initialisation.
        self.conva = pointwise(in_channel, mid_feature)
        self.convb = pointwise(in_channel, mid_feature)
        self.convc = pointwise(in_channel, mid_feature)
        self.convn = pointwise(mid_feature, mid_feature)
        self.convl = pointwise(mid_feature, mid_feature)
        self.convd = nn.Sequential(
            pointwise(mid_feature * 2, in_channel),
            nn.BatchNorm1d(in_channel))
        self.line_conv_att = nn.Conv2d(in_channel, 1, kernel_size=1, bias=False)

    def forward(self, x, curves):
        '''
        x: point features, bs, c, n
        curves: curve features, bs, c, c_n, c_l
        return: leaky_relu(x + fused curve features), bs, c, n
        '''
        att = self.line_conv_att(curves)  # bs, 1, c_n, c_l

        # attention-weighted sums over the curve-length and curve-number axes
        inter = torch.sum(curves * F.softmax(att, dim=-1), dim=-1)  # bs, c, c_n
        intra = torch.sum(curves * F.softmax(att, dim=-2), dim=-2)  # bs, c, c_l

        inter = self.conva(inter)  # bs, mid, c_n
        intra = self.convb(intra)  # bs, mid, c_l

        # point-to-curve affinities
        queries = self.convc(x).transpose(1, 2).contiguous()  # bs, n, mid
        inter_weights = F.softmax(torch.bmm(queries, inter), dim=-1)  # bs, n, c_n
        intra_weights = F.softmax(torch.bmm(queries, intra), dim=-1)  # bs, n, c_l

        inter_values = self.convn(inter).transpose(1, 2).contiguous()  # bs, c_n, mid
        intra_values = self.convl(intra).transpose(1, 2).contiguous()  # bs, c_l, mid

        from_inter = torch.bmm(inter_weights, inter_values)  # bs, n, mid
        from_intra = torch.bmm(intra_weights, intra_values)  # bs, n, mid

        curve_features = torch.cat((from_inter, from_intra), dim=-1).transpose(1, 2).contiguous()
        x = x + self.convd(curve_features)

        return F.leaky_relu(x, negative_slope=0.2)
|
495
|
+
|
496
|
+
|
497
|
+
class CurveGrouping(nn.Module):
    '''
    Selects `curve_num` starting points per sample from a learned attention
    score and lets a Walk module grow one curve from each of them.
    '''
    def __init__(self, in_channel, k, curve_num, curve_length):
        super().__init__()
        self.curve_num = curve_num
        self.curve_length = curve_length
        self.in_channel = in_channel
        self.k = k

        # single-channel score used to rank candidate starting points
        self.att = nn.Conv1d(in_channel, 1, kernel_size=1, bias=False)

        self.walk = Walk(in_channel, k, curve_num, curve_length)

    def forward(self, x, xyz, idx):
        '''
        x: point features, bs, c, n
        xyz: point coordinates, forwarded to Walk
        idx: neighbour indices, forwarded to Walk
        return: (curves, flatten_curve_idxs) exactly as produced by Walk
        '''
        # starting point selection in self attention style
        score = torch.sigmoid(self.att(x))
        gated = x * score

        _, start_index = torch.topk(score, self.curve_num, dim=2, sorted=False)
        # bs, 1, curve_num -> bs, curve_num, 1
        start_index = start_index.squeeze(1).unsqueeze(2)

        curves, flatten_curve_idxs = self.walk(xyz, gated, idx, start_index)  # bs, c, c_n, c_l

        return curves, flatten_curve_idxs
|
523
|
+
|
524
|
+
|
525
|
+
class MaskedMaxPool(nn.Module):
    '''
    Downsamples a point set with sample_and_group and max-pools the features
    gathered around every sampled point.
    '''
    def __init__(self, npoint, radius, k):
        super().__init__()
        self.npoint = npoint
        self.radius = radius
        self.k = k

    def forward(self, xyz, features):
        '''
        xyz: point coordinates, passed to sample_and_group unchanged
        features: point features; axes 1 and 2 are swapped before grouping
        return: (sampled coordinates, pooled features with the neighbour axis removed)
        '''
        sub_xyz, grouped = sample_and_group(
            self.npoint, self.radius, self.k, xyz, features.transpose(1, 2))

        # move the channel axis to position 1 so the neighbour axis is last
        grouped = grouped.permute(0, 3, 1, 2).contiguous()
        window = [1, grouped.shape[3]]
        pooled = F.max_pool2d(grouped, kernel_size=window)  # bs, c, n, 1
        pooled = torch.squeeze(pooled, -1)  # bs, c, n
        return sub_xyz, pooled
|