ocnn 2.2.1__py3-none-any.whl → 2.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ocnn/__init__.py +24 -24
- ocnn/dataset.py +160 -158
- ocnn/models/__init__.py +29 -27
- ocnn/models/autoencoder.py +155 -165
- ocnn/models/hrnet.py +192 -192
- ocnn/models/image2shape.py +128 -0
- ocnn/models/lenet.py +46 -46
- ocnn/models/ounet.py +94 -94
- ocnn/models/resnet.py +53 -53
- ocnn/models/segnet.py +72 -72
- ocnn/models/unet.py +105 -105
- ocnn/modules/__init__.py +20 -20
- ocnn/modules/modules.py +193 -231
- ocnn/modules/resblocks.py +124 -124
- ocnn/nn/__init__.py +42 -42
- ocnn/nn/octree2col.py +53 -53
- ocnn/nn/octree2vox.py +50 -50
- ocnn/nn/octree_align.py +46 -46
- ocnn/nn/octree_conv.py +411 -411
- ocnn/nn/octree_drop.py +55 -55
- ocnn/nn/octree_dwconv.py +204 -204
- ocnn/nn/octree_gconv.py +79 -0
- ocnn/nn/octree_interp.py +196 -196
- ocnn/nn/octree_norm.py +86 -86
- ocnn/nn/octree_pad.py +39 -39
- ocnn/nn/octree_pool.py +200 -200
- ocnn/octree/__init__.py +22 -21
- ocnn/octree/octree.py +639 -601
- ocnn/octree/points.py +317 -298
- ocnn/octree/shuffled_key.py +115 -115
- ocnn/utils.py +202 -153
- {ocnn-2.2.1.dist-info → ocnn-2.2.2.dist-info}/LICENSE +21 -21
- {ocnn-2.2.1.dist-info → ocnn-2.2.2.dist-info}/METADATA +67 -65
- ocnn-2.2.2.dist-info/RECORD +36 -0
- {ocnn-2.2.1.dist-info → ocnn-2.2.2.dist-info}/WHEEL +1 -1
- ocnn-2.2.1.dist-info/RECORD +0 -34
- {ocnn-2.2.1.dist-info → ocnn-2.2.2.dist-info}/top_level.txt +0 -0
ocnn/nn/octree_drop.py
CHANGED
|
@@ -1,55 +1,55 @@
|
|
|
1
|
-
# --------------------------------------------------------
|
|
2
|
-
# Octree-based Sparse Convolutional Neural Networks
|
|
3
|
-
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
|
|
4
|
-
# Licensed under The MIT License [see LICENSE for details]
|
|
5
|
-
# Written by Peng-Shuai Wang
|
|
6
|
-
# --------------------------------------------------------
|
|
7
|
-
|
|
8
|
-
import torch
|
|
9
|
-
from typing import Optional
|
|
10
|
-
|
|
11
|
-
from ocnn.octree import Octree
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
class OctreeDropPath(torch.nn.Module):
|
|
15
|
-
r'''Drop paths (Stochastic Depth) per sample when applied in main path of
|
|
16
|
-
residual blocks, following the logic of :func:`timm.models.layers.DropPath`.
|
|
17
|
-
|
|
18
|
-
Args:
|
|
19
|
-
drop_prob (int): The probability of drop paths.
|
|
20
|
-
nempty (bool): Indicate whether the input data only contains features of the
|
|
21
|
-
non-empty octree nodes or not.
|
|
22
|
-
scale_by_keep (bool): Whether to scale the kept features proportionally.
|
|
23
|
-
'''
|
|
24
|
-
|
|
25
|
-
def __init__(self, drop_prob: float = 0.0, nempty: bool = False,
|
|
26
|
-
scale_by_keep: bool = True):
|
|
27
|
-
super().__init__()
|
|
28
|
-
|
|
29
|
-
self.drop_prob = drop_prob
|
|
30
|
-
self.nempty = nempty
|
|
31
|
-
self.scale_by_keep = scale_by_keep
|
|
32
|
-
|
|
33
|
-
def forward(self, data: torch.Tensor, octree: Octree, depth: int,
|
|
34
|
-
batch_id: Optional[torch.Tensor] = None):
|
|
35
|
-
r''''''
|
|
36
|
-
|
|
37
|
-
if self.drop_prob <= 0.0 or not self.training:
|
|
38
|
-
return data
|
|
39
|
-
|
|
40
|
-
batch_size = octree.batch_size
|
|
41
|
-
keep_prob = 1 - self.drop_prob
|
|
42
|
-
rnd_tensor = torch.rand(batch_size, 1, dtype=data.dtype, device=data.device)
|
|
43
|
-
rnd_tensor = torch.floor(rnd_tensor + keep_prob)
|
|
44
|
-
if keep_prob > 0.0 and self.scale_by_keep:
|
|
45
|
-
rnd_tensor.div_(keep_prob)
|
|
46
|
-
|
|
47
|
-
if batch_id is None:
|
|
48
|
-
batch_id = octree.batch_id(depth, self.nempty)
|
|
49
|
-
drop_mask = rnd_tensor[batch_id]
|
|
50
|
-
output = data * drop_mask
|
|
51
|
-
return output
|
|
52
|
-
|
|
53
|
-
def extra_repr(self) -> str:
|
|
54
|
-
return ('drop_prob={:.4f}, nempty={}, scale_by_keep={}').format(
|
|
55
|
-
self.drop_prob, self.nempty, self.scale_by_keep) # noqa
|
|
1
|
+
# --------------------------------------------------------
|
|
2
|
+
# Octree-based Sparse Convolutional Neural Networks
|
|
3
|
+
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
|
|
4
|
+
# Licensed under The MIT License [see LICENSE for details]
|
|
5
|
+
# Written by Peng-Shuai Wang
|
|
6
|
+
# --------------------------------------------------------
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
from typing import Optional
|
|
10
|
+
|
|
11
|
+
from ocnn.octree import Octree
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class OctreeDropPath(torch.nn.Module):
  r'''Drop paths (Stochastic Depth) per sample when applied in main path of
  residual blocks, following the logic of :func:`timm.models.layers.DropPath`.

  Args:
    drop_prob (float): The probability of drop paths.
    nempty (bool): Indicate whether the input data only contains features of the
        non-empty octree nodes or not.
    scale_by_keep (bool): Whether to scale the kept features proportionally.
  '''

  def __init__(self, drop_prob: float = 0.0, nempty: bool = False,
               scale_by_keep: bool = True):
    super().__init__()

    self.drop_prob = drop_prob
    self.nempty = nempty
    self.scale_by_keep = scale_by_keep

  def forward(self, data: torch.Tensor, octree: 'Octree', depth: int,
              batch_id: Optional[torch.Tensor] = None):
    r''''''

    # Identity in eval mode or when path dropping is disabled.
    if self.drop_prob <= 0.0 or not self.training:
      return data

    # Draw one Bernoulli keep(1)/drop(0) decision per sample in the batch.
    keep_prob = 1 - self.drop_prob
    rnd_tensor = torch.rand(
        octree.batch_size, 1, dtype=data.dtype, device=data.device)
    rnd_tensor = torch.floor(rnd_tensor + keep_prob)
    if keep_prob > 0.0 and self.scale_by_keep:
      # Rescale the surviving paths so the expectation is unchanged.
      rnd_tensor = rnd_tensor / keep_prob

    # Broadcast the per-sample decision to every octree node at this depth.
    if batch_id is None:
      batch_id = octree.batch_id(depth, self.nempty)
    drop_mask = rnd_tensor[batch_id]
    return data * drop_mask

  def extra_repr(self) -> str:
    return ('drop_prob={:.4f}, nempty={}, scale_by_keep={}').format(
        self.drop_prob, self.nempty, self.scale_by_keep)  # noqa
|
ocnn/nn/octree_dwconv.py
CHANGED
|
@@ -1,204 +1,204 @@
|
|
|
1
|
-
# --------------------------------------------------------
|
|
2
|
-
# Octree-based Sparse Convolutional Neural Networks
|
|
3
|
-
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
|
|
4
|
-
# Licensed under The MIT License [see LICENSE for details]
|
|
5
|
-
# Written by Peng-Shuai Wang
|
|
6
|
-
# --------------------------------------------------------
|
|
7
|
-
|
|
8
|
-
import torch
|
|
9
|
-
import torch.nn
|
|
10
|
-
from torch.autograd import Function
|
|
11
|
-
from typing import List
|
|
12
|
-
|
|
13
|
-
from ocnn.octree import Octree
|
|
14
|
-
from ocnn.utils import scatter_add, xavier_uniform_
|
|
15
|
-
from .octree_pad import octree_pad
|
|
16
|
-
from .octree_conv import OctreeConvBase
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
class OctreeDWConvBase(OctreeConvBase):
|
|
20
|
-
|
|
21
|
-
def __init__(self, in_channels: int, kernel_size: List[int] = [3],
|
|
22
|
-
stride: int = 1, nempty: bool = False,
|
|
23
|
-
max_buffer: int = int(2e8)):
|
|
24
|
-
super().__init__(
|
|
25
|
-
in_channels, in_channels, kernel_size, stride, nempty, max_buffer)
|
|
26
|
-
self.weights_shape = (self.kdim, 1, self.out_channels)
|
|
27
|
-
|
|
28
|
-
def is_conv_layer(self): return True
|
|
29
|
-
|
|
30
|
-
def forward_gemm(self, out: torch.Tensor, data: torch.Tensor,
|
|
31
|
-
weights: torch.Tensor):
|
|
32
|
-
r''' Peforms the forward pass of octree-based convolution.
|
|
33
|
-
'''
|
|
34
|
-
|
|
35
|
-
# Initialize the buffer
|
|
36
|
-
buffer = data.new_empty(self.buffer_shape)
|
|
37
|
-
|
|
38
|
-
# Loop over each sub-matrix
|
|
39
|
-
for i in range(self.buffer_n):
|
|
40
|
-
start = i * self.buffer_h
|
|
41
|
-
end = (i + 1) * self.buffer_h
|
|
42
|
-
|
|
43
|
-
# The boundary case in the last iteration
|
|
44
|
-
if end > self.neigh.shape[0]:
|
|
45
|
-
dis = end - self.neigh.shape[0]
|
|
46
|
-
end = self.neigh.shape[0]
|
|
47
|
-
buffer, _ = buffer.split([self.buffer_h-dis, dis])
|
|
48
|
-
|
|
49
|
-
# Perform octree2col
|
|
50
|
-
neigh_i = self.neigh[start:end]
|
|
51
|
-
valid = neigh_i >= 0
|
|
52
|
-
buffer.fill_(0)
|
|
53
|
-
buffer[valid] = data[neigh_i[valid]]
|
|
54
|
-
|
|
55
|
-
# The sub-matrix gemm
|
|
56
|
-
# out[start:end] = torch.mm(buffer.flatten(1, 2), weights.flatten(0, 1))
|
|
57
|
-
out[start:end] = torch.einsum('ikc,kc->ic', buffer, weights.flatten(0, 1))
|
|
58
|
-
return out
|
|
59
|
-
|
|
60
|
-
def backward_gemm(self, out: torch.Tensor, grad: torch.Tensor,
|
|
61
|
-
weights: torch.Tensor):
|
|
62
|
-
r''' Performs the backward pass of octree-based convolution.
|
|
63
|
-
'''
|
|
64
|
-
|
|
65
|
-
# Loop over each sub-matrix
|
|
66
|
-
for i in range(self.buffer_n):
|
|
67
|
-
start = i * self.buffer_h
|
|
68
|
-
end = (i + 1) * self.buffer_h
|
|
69
|
-
|
|
70
|
-
# The boundary case in the last iteration
|
|
71
|
-
if end > self.neigh.shape[0]:
|
|
72
|
-
end = self.neigh.shape[0]
|
|
73
|
-
|
|
74
|
-
# The sub-matrix gemm
|
|
75
|
-
# buffer = torch.mm(grad[start:end], weights.flatten(0, 1).t())
|
|
76
|
-
# buffer = buffer.view(-1, self.buffer_shape[1], self.buffer_shape[2])
|
|
77
|
-
buffer = torch.einsum(
|
|
78
|
-
'ic,kc->ikc', grad[start:end], weights.flatten(0, 1))
|
|
79
|
-
|
|
80
|
-
# Performs col2octree
|
|
81
|
-
neigh_i = self.neigh[start:end]
|
|
82
|
-
valid = neigh_i >= 0
|
|
83
|
-
out = scatter_add(buffer[valid], neigh_i[valid], dim=0, out=out)
|
|
84
|
-
|
|
85
|
-
return out
|
|
86
|
-
|
|
87
|
-
def weight_gemm(self, out: torch.Tensor, data: torch.Tensor, grad: torch.Tensor):
|
|
88
|
-
r''' Computes the gradient of the weight matrix.
|
|
89
|
-
'''
|
|
90
|
-
|
|
91
|
-
# Record the shape of out
|
|
92
|
-
out_shape = out.shape
|
|
93
|
-
out = out.flatten(0, 1)
|
|
94
|
-
|
|
95
|
-
# Initialize the buffer
|
|
96
|
-
buffer = data.new_empty(self.buffer_shape)
|
|
97
|
-
|
|
98
|
-
# Loop over each sub-matrix
|
|
99
|
-
for i in range(self.buffer_n):
|
|
100
|
-
start = i * self.buffer_h
|
|
101
|
-
end = (i + 1) * self.buffer_h
|
|
102
|
-
|
|
103
|
-
# The boundary case in the last iteration
|
|
104
|
-
if end > self.neigh.shape[0]:
|
|
105
|
-
d = end - self.neigh.shape[0]
|
|
106
|
-
end = self.neigh.shape[0]
|
|
107
|
-
buffer, _ = buffer.split([self.buffer_h-d, d])
|
|
108
|
-
|
|
109
|
-
# Perform octree2col
|
|
110
|
-
neigh_i = self.neigh[start:end]
|
|
111
|
-
valid = neigh_i >= 0
|
|
112
|
-
buffer.fill_(0)
|
|
113
|
-
buffer[valid] = data[neigh_i[valid]]
|
|
114
|
-
|
|
115
|
-
# Accumulate the gradient via gemm
|
|
116
|
-
# out.addmm_(buffer.flatten(1, 2).t(), grad[start:end])
|
|
117
|
-
out += torch.einsum('ikc,ic->kc', buffer, grad[start:end])
|
|
118
|
-
return out.view(out_shape)
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
class OctreeDWConvFunction(Function):
|
|
122
|
-
r''' Wrap the octree convolution for auto-diff.
|
|
123
|
-
'''
|
|
124
|
-
|
|
125
|
-
@staticmethod
|
|
126
|
-
def forward(ctx, data: torch.Tensor, weights: torch.Tensor, octree: Octree,
|
|
127
|
-
depth: int, in_channels: int, kernel_size: List[int] = [3, 3, 3],
|
|
128
|
-
stride: int = 1, nempty: bool = False, max_buffer: int = int(2e8)):
|
|
129
|
-
octree_conv = OctreeDWConvBase(
|
|
130
|
-
in_channels, kernel_size, stride, nempty, max_buffer)
|
|
131
|
-
octree_conv.setup(octree, depth)
|
|
132
|
-
out = octree_conv.check_and_init(data)
|
|
133
|
-
out = octree_conv.forward_gemm(out, data, weights)
|
|
134
|
-
|
|
135
|
-
ctx.save_for_backward(data, weights)
|
|
136
|
-
ctx.octree_conv = octree_conv
|
|
137
|
-
return out
|
|
138
|
-
|
|
139
|
-
@staticmethod
|
|
140
|
-
def backward(ctx, grad):
|
|
141
|
-
data, weights = ctx.saved_tensors
|
|
142
|
-
octree_conv = ctx.octree_conv
|
|
143
|
-
|
|
144
|
-
grad_out = None
|
|
145
|
-
if ctx.needs_input_grad[0]:
|
|
146
|
-
grad_out = torch.zeros_like(data)
|
|
147
|
-
grad_out = octree_conv.backward_gemm(grad_out, grad, weights)
|
|
148
|
-
|
|
149
|
-
grad_w = None
|
|
150
|
-
if ctx.needs_input_grad[1]:
|
|
151
|
-
grad_w = torch.zeros_like(weights)
|
|
152
|
-
grad_w = octree_conv.weight_gemm(grad_w, data, grad)
|
|
153
|
-
|
|
154
|
-
return (grad_out, grad_w) + (None,) * 7
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
# alias
|
|
158
|
-
octree_dwconv = OctreeDWConvFunction.apply
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
class OctreeDWConv(OctreeDWConvBase, torch.nn.Module):
|
|
162
|
-
r''' Performs octree-based depth-wise convolution.
|
|
163
|
-
|
|
164
|
-
Please refer to :class:`ocnn.nn.OctreeConv` for the meaning of the arguments.
|
|
165
|
-
|
|
166
|
-
.. note::
|
|
167
|
-
This implementation uses the :func:`torch.einsum` and I find that the speed
|
|
168
|
-
is relatively slow. Further optimization is needed to speed it up.
|
|
169
|
-
'''
|
|
170
|
-
|
|
171
|
-
def __init__(self, in_channels: int, kernel_size: List[int] = [3],
|
|
172
|
-
stride: int = 1, nempty: bool = False, use_bias: bool = False,
|
|
173
|
-
max_buffer: int = int(2e8)):
|
|
174
|
-
super().__init__(in_channels, kernel_size, stride, nempty, max_buffer)
|
|
175
|
-
|
|
176
|
-
self.use_bias = use_bias
|
|
177
|
-
self.weights = torch.nn.Parameter(torch.Tensor(*self.weights_shape))
|
|
178
|
-
if self.use_bias:
|
|
179
|
-
self.bias = torch.nn.Parameter(torch.Tensor(in_channels))
|
|
180
|
-
self.reset_parameters()
|
|
181
|
-
|
|
182
|
-
def reset_parameters(self):
|
|
183
|
-
xavier_uniform_(self.weights)
|
|
184
|
-
if self.use_bias:
|
|
185
|
-
torch.nn.init.zeros_(self.bias)
|
|
186
|
-
|
|
187
|
-
def forward(self, data: torch.Tensor, octree: Octree, depth: int):
|
|
188
|
-
r''''''
|
|
189
|
-
|
|
190
|
-
out = octree_dwconv(
|
|
191
|
-
data, self.weights, octree, depth, self.in_channels,
|
|
192
|
-
self.kernel_size, self.stride, self.nempty, self.max_buffer)
|
|
193
|
-
|
|
194
|
-
if self.use_bias:
|
|
195
|
-
out += self.bias
|
|
196
|
-
|
|
197
|
-
if self.stride == 2 and not self.nempty:
|
|
198
|
-
out = octree_pad(out, octree, depth-1)
|
|
199
|
-
return out
|
|
200
|
-
|
|
201
|
-
def extra_repr(self) -> str:
|
|
202
|
-
return ('in_channels={}, out_channels={}, kernel_size={}, stride={}, '
|
|
203
|
-
'nempty={}, bias={}').format(self.in_channels, self.out_channels,
|
|
204
|
-
self.kernel_size, self.stride, self.nempty, self.use_bias) # noqa
|
|
1
|
+
# --------------------------------------------------------
|
|
2
|
+
# Octree-based Sparse Convolutional Neural Networks
|
|
3
|
+
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
|
|
4
|
+
# Licensed under The MIT License [see LICENSE for details]
|
|
5
|
+
# Written by Peng-Shuai Wang
|
|
6
|
+
# --------------------------------------------------------
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
import torch.nn
|
|
10
|
+
from torch.autograd import Function
|
|
11
|
+
from typing import List
|
|
12
|
+
|
|
13
|
+
from ocnn.octree import Octree
|
|
14
|
+
from ocnn.utils import scatter_add, xavier_uniform_
|
|
15
|
+
from .octree_pad import octree_pad
|
|
16
|
+
from .octree_conv import OctreeConvBase
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class OctreeDWConvBase(OctreeConvBase):
  r'''Gemm-based computation of octree depth-wise convolution.

  The input and output channel numbers are identical, and each channel owns a
  single :obj:`kdim x 1` filter, so the weight shape is
  :obj:`(kdim, 1, channels)`. The buffer attributes (:obj:`buffer_shape`,
  :obj:`buffer_n`, :obj:`buffer_h`) and the neighborhood table :obj:`neigh`
  are presumably set up by :class:`OctreeConvBase` via :func:`setup` —
  confirm against `octree_conv.py`.
  '''

  def __init__(self, in_channels: int, kernel_size: List[int] = [3],
               stride: int = 1, nempty: bool = False,
               max_buffer: int = int(2e8)):
    # Depth-wise: out_channels equals in_channels.
    super().__init__(
        in_channels, in_channels, kernel_size, stride, nempty, max_buffer)
    # One (kdim x 1) filter per output channel.
    self.weights_shape = (self.kdim, 1, self.out_channels)

  def is_conv_layer(self): return True

  def forward_gemm(self, out: torch.Tensor, data: torch.Tensor,
                   weights: torch.Tensor):
    r''' Performs the forward pass of octree-based convolution.
    '''

    # Initialize the buffer
    buffer = data.new_empty(self.buffer_shape)

    # Loop over each sub-matrix
    for i in range(self.buffer_n):
      start = i * self.buffer_h
      end = (i + 1) * self.buffer_h

      # The boundary case in the last iteration: shrink the buffer to the
      # remaining number of rows.
      if end > self.neigh.shape[0]:
        dis = end - self.neigh.shape[0]
        end = self.neigh.shape[0]
        buffer, _ = buffer.split([self.buffer_h-dis, dis])

      # Perform octree2col: gather the neighboring features; entries with a
      # negative neighbor index correspond to empty nodes and stay zero.
      neigh_i = self.neigh[start:end]
      valid = neigh_i >= 0
      buffer.fill_(0)
      buffer[valid] = data[neigh_i[valid]]

      # The sub-matrix gemm, applied per-channel (depth-wise)
      # out[start:end] = torch.mm(buffer.flatten(1, 2), weights.flatten(0, 1))
      out[start:end] = torch.einsum('ikc,kc->ic', buffer, weights.flatten(0, 1))
    return out

  def backward_gemm(self, out: torch.Tensor, grad: torch.Tensor,
                    weights: torch.Tensor):
    r''' Performs the backward pass of octree-based convolution, i.e. the
    gradient w.r.t. the input data.
    '''

    # Loop over each sub-matrix
    for i in range(self.buffer_n):
      start = i * self.buffer_h
      end = (i + 1) * self.buffer_h

      # The boundary case in the last iteration
      if end > self.neigh.shape[0]:
        end = self.neigh.shape[0]

      # The sub-matrix gemm: outer product of the gradient rows with the
      # per-channel weights.
      # buffer = torch.mm(grad[start:end], weights.flatten(0, 1).t())
      # buffer = buffer.view(-1, self.buffer_shape[1], self.buffer_shape[2])
      buffer = torch.einsum(
          'ic,kc->ikc', grad[start:end], weights.flatten(0, 1))

      # Performs col2octree: scatter-add each contribution back to its node.
      neigh_i = self.neigh[start:end]
      valid = neigh_i >= 0
      out = scatter_add(buffer[valid], neigh_i[valid], dim=0, out=out)

    return out

  def weight_gemm(self, out: torch.Tensor, data: torch.Tensor, grad: torch.Tensor):
    r''' Computes the gradient of the weight matrix.
    '''

    # Record the shape of out, since the accumulation below works on the
    # flattened (kdim, channels) view.
    out_shape = out.shape
    out = out.flatten(0, 1)

    # Initialize the buffer
    buffer = data.new_empty(self.buffer_shape)

    # Loop over each sub-matrix
    for i in range(self.buffer_n):
      start = i * self.buffer_h
      end = (i + 1) * self.buffer_h

      # The boundary case in the last iteration
      if end > self.neigh.shape[0]:
        d = end - self.neigh.shape[0]
        end = self.neigh.shape[0]
        buffer, _ = buffer.split([self.buffer_h-d, d])

      # Perform octree2col
      neigh_i = self.neigh[start:end]
      valid = neigh_i >= 0
      buffer.fill_(0)
      buffer[valid] = data[neigh_i[valid]]

      # Accumulate the gradient via gemm
      # out.addmm_(buffer.flatten(1, 2).t(), grad[start:end])
      out += torch.einsum('ikc,ic->kc', buffer, grad[start:end])
    return out.view(out_shape)
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
class OctreeDWConvFunction(Function):
  r''' Wrap the octree depth-wise convolution for auto-diff.
  '''

  @staticmethod
  def forward(ctx, data: torch.Tensor, weights: torch.Tensor, octree: Octree,
              depth: int, in_channels: int, kernel_size: List[int] = [3, 3, 3],
              stride: int = 1, nempty: bool = False, max_buffer: int = int(2e8)):
    # Build the gemm helper, bind it to the octree, and run the forward gemm.
    conv = OctreeDWConvBase(
        in_channels, kernel_size, stride, nempty, max_buffer)
    conv.setup(octree, depth)
    out = conv.forward_gemm(conv.check_and_init(data), data, weights)

    # Keep what the backward pass needs.
    ctx.save_for_backward(data, weights)
    ctx.octree_conv = conv
    return out

  @staticmethod
  def backward(ctx, grad):
    data, weights = ctx.saved_tensors
    conv = ctx.octree_conv

    grad_data, grad_weights = None, None
    if ctx.needs_input_grad[0]:
      grad_data = conv.backward_gemm(torch.zeros_like(data), grad, weights)
    if ctx.needs_input_grad[1]:
      grad_weights = conv.weight_gemm(torch.zeros_like(weights), data, grad)

    # The remaining 7 forward arguments are non-differentiable.
    return (grad_data, grad_weights) + (None,) * 7


# alias
octree_dwconv = OctreeDWConvFunction.apply
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
class OctreeDWConv(OctreeDWConvBase, torch.nn.Module):
  r''' Performs octree-based depth-wise convolution.

  Please refer to :class:`ocnn.nn.OctreeConv` for the meaning of the arguments.

  .. note::
    This implementation relies on :func:`torch.einsum`, whose speed is
    relatively slow here; further optimization is needed to speed it up.
  '''

  def __init__(self, in_channels: int, kernel_size: List[int] = [3],
               stride: int = 1, nempty: bool = False, use_bias: bool = False,
               max_buffer: int = int(2e8)):
    super().__init__(in_channels, kernel_size, stride, nempty, max_buffer)
    self.use_bias = use_bias

    # The depth-wise weight of shape (kdim, 1, channels), plus an optional
    # per-channel bias.
    self.weights = torch.nn.Parameter(torch.Tensor(*self.weights_shape))
    if use_bias:
      self.bias = torch.nn.Parameter(torch.Tensor(in_channels))
    self.reset_parameters()

  def reset_parameters(self):
    # Xavier initialization for the weights; zeros for the optional bias.
    xavier_uniform_(self.weights)
    if self.use_bias:
      torch.nn.init.zeros_(self.bias)

  def forward(self, data: torch.Tensor, octree: Octree, depth: int):
    r''''''

    out = octree_dwconv(
        data, self.weights, octree, depth, self.in_channels,
        self.kernel_size, self.stride, self.nempty, self.max_buffer)
    if self.use_bias:
      out = out + self.bias

    # A stride-2 convolution without `nempty` shrinks the node number; pad the
    # result so that it aligns with the full octree layer at `depth - 1`.
    if self.stride == 2 and not self.nempty:
      out = octree_pad(out, octree, depth - 1)
    return out

  def extra_repr(self) -> str:
    return ('in_channels={}, out_channels={}, kernel_size={}, stride={}, '
            'nempty={}, bias={}').format(self.in_channels, self.out_channels,
             self.kernel_size, self.stride, self.nempty, self.use_bias)  # noqa
|
ocnn/nn/octree_gconv.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
# --------------------------------------------------------
|
|
2
|
+
# Octree-based Sparse Convolutional Neural Networks
|
|
3
|
+
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
|
|
4
|
+
# Licensed under The MIT License [see LICENSE for details]
|
|
5
|
+
# Written by Peng-Shuai Wang
|
|
6
|
+
# --------------------------------------------------------
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
import torch.nn
|
|
10
|
+
from typing import List
|
|
11
|
+
|
|
12
|
+
import ocnn
|
|
13
|
+
from ocnn.octree import Octree
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class OctreeGroupConv(torch.nn.Module):
  r''' Performs octree-based group convolution.

  Args:
    in_channels (int): Number of input channels.
    out_channels (int): Number of output channels.
    kernel_size (List(int)): The kernel shape, choose from :obj:`[3]`, :obj:`[2]`,
        :obj:`[3,3,3]`, :obj:`[3,1,1]`, :obj:`[1,3,1]`, :obj:`[1,1,3]`,
        :obj:`[2,2,2]`, :obj:`[3,3,1]`, :obj:`[1,3,3]`, and :obj:`[3,1,3]`.
    stride (int): The stride of the convolution (:obj:`1` or :obj:`2`).
    nempty (bool): If True, only performs the convolution on non-empty
        octree nodes.
    use_bias (bool): If True, add a bias term to the convolution.
    group (int): The number of groups.

  .. note::
    Perform octree-based group convolution with a for-loop. The performance is
    not optimal. Use this module only when the group number is small, otherwise
    it may be slow.
  '''

  def __init__(self, in_channels: int, out_channels: int,
               kernel_size: List[int] = [3], stride: int = 1,
               nempty: bool = False, use_bias: bool = False,
               group: int = 1):
    super().__init__()

    self.group = group
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.in_channels_per_group = in_channels // group
    self.out_channels_per_group = out_channels // group
    assert in_channels % group == 0 and out_channels % group == 0, \
        'in_channels and out_channels must both be divisible by group'

    # One independent OctreeConv per group; each one only sees its own slice
    # of the input channels.
    self.convs = torch.nn.ModuleList([ocnn.nn.OctreeConv(
        self.in_channels_per_group, self.out_channels_per_group,
        kernel_size, stride, nempty, use_bias=use_bias)
        for _ in range(group)])

  def forward(self, data: torch.Tensor, octree: Octree, depth: int):
    r''' Defines the octree-based group convolution.

    Args:
      data (torch.Tensor): The input data.
      octree (Octree): The corresponding octree.
      depth (int): The depth of current octree.
    '''

    assert data.shape[1] == self.in_channels, 'Unexpected channel number'

    # Split the input channel-wise, convolve each slice with its own conv,
    # then concatenate the per-group results along the channel dimension.
    slices = torch.split(data, self.in_channels_per_group, dim=1)
    outs = [conv(x, octree, depth) for conv, x in zip(self.convs, slices)]
    return torch.cat(outs, dim=1)

  def extra_repr(self) -> str:
    r''' Sets the extra representation of the module.
    '''

    return ('in_channels={}, out_channels={}, group={}').format(
        self.in_channels, self.out_channels, self.group)  # noqa
|