ocnn-2.2.5-py3-none-any.whl → ocnn-2.2.7-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ocnn/__init__.py +24 -24
- ocnn/dataset.py +160 -160
- ocnn/models/__init__.py +29 -29
- ocnn/models/autoencoder.py +155 -155
- ocnn/models/hrnet.py +192 -192
- ocnn/models/image2shape.py +128 -128
- ocnn/models/lenet.py +46 -46
- ocnn/models/ounet.py +94 -94
- ocnn/models/resnet.py +53 -53
- ocnn/models/segnet.py +72 -72
- ocnn/models/unet.py +105 -105
- ocnn/modules/__init__.py +26 -26
- ocnn/modules/modules.py +303 -303
- ocnn/modules/resblocks.py +158 -158
- ocnn/nn/__init__.py +44 -44
- ocnn/nn/octree2col.py +53 -53
- ocnn/nn/octree2vox.py +50 -50
- ocnn/nn/octree_align.py +46 -46
- ocnn/nn/octree_conv.py +429 -429
- ocnn/nn/octree_drop.py +55 -55
- ocnn/nn/octree_dwconv.py +222 -222
- ocnn/nn/octree_gconv.py +79 -79
- ocnn/nn/octree_interp.py +196 -196
- ocnn/nn/octree_norm.py +126 -126
- ocnn/nn/octree_pad.py +39 -39
- ocnn/nn/octree_pool.py +200 -200
- ocnn/octree/__init__.py +22 -22
- ocnn/octree/octree.py +661 -659
- ocnn/octree/points.py +323 -322
- ocnn/octree/shuffled_key.py +115 -115
- ocnn/utils.py +205 -205
- ocnn-2.2.7.dist-info/METADATA +112 -0
- ocnn-2.2.7.dist-info/RECORD +36 -0
- {ocnn-2.2.5.dist-info → ocnn-2.2.7.dist-info}/WHEEL +1 -1
- {ocnn-2.2.5.dist-info → ocnn-2.2.7.dist-info/licenses}/LICENSE +21 -21
- ocnn-2.2.5.dist-info/METADATA +0 -80
- ocnn-2.2.5.dist-info/RECORD +0 -36
- {ocnn-2.2.5.dist-info → ocnn-2.2.7.dist-info}/top_level.txt +0 -0
ocnn/nn/octree_conv.py
CHANGED
@@ -1,429 +1,429 @@
(The diff replaces the whole file: all 429 lines are removed and re-added. The extracted old and new contents are identical, so the file is shown once below.)
# --------------------------------------------------------
# Octree-based Sparse Convolutional Neural Networks
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
# Licensed under The MIT License [see LICENSE for details]
# Written by Peng-Shuai Wang
# --------------------------------------------------------

import torch
import torch.nn
from torch.autograd import Function
from typing import List

from ocnn.octree import Octree
from ocnn.utils import scatter_add, xavier_uniform_, resize_with_last_val, list2str
from .octree2col import octree2col, col2octree
from .octree_pad import octree_pad, octree_depad


class OctreeConvBase:

  def __init__(self, in_channels: int, out_channels: int,
               kernel_size: List[int] = [3], stride: int = 1,
               nempty: bool = False, max_buffer: int = int(2e8)):
    super().__init__()
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.kernel_size = resize_with_last_val(kernel_size)
    self.kernel = list2str(self.kernel_size)
    self.stride = stride
    self.nempty = nempty
    self.max_buffer = max_buffer  # about 200M

    self.kdim = self.kernel_size[0] * self.kernel_size[1] * self.kernel_size[2]
    self.in_conv = in_channels if self.is_conv_layer() else out_channels
    self.out_conv = out_channels if self.is_conv_layer() else in_channels
    self.weights_shape = (self.kdim, self.in_conv, self.out_conv)

  def is_conv_layer(self):
    r''' Returns :obj:`True` to indicate this is a convolution layer.
    '''

    raise NotImplementedError

  def setup(self, octree: Octree, depth: int):
    r''' Sets up the shapes of each tensor.
    This function MUST be called before :obj:`forward_gemm`, :obj:`backward_gemm`
    and :obj:`weight_gemm`.
    '''

    # The depth of tensors:
    # in_depth and out_depth are the octree depths of the input and output
    # data; neigh_depth is the octree depth of the neighborhood information,
    # as well as of the `col` data. neigh_depth is always the same as the
    # depth of the larger data when doing octree2col or col2octree.
    self.in_depth = depth
    self.out_depth = depth
    self.neigh_depth = depth
    if self.stride == 2:
      if self.is_conv_layer():
        self.out_depth = depth - 1
      else:
        self.out_depth = depth + 1
        self.neigh_depth = depth + 1

    # The height of tensors
    if self.nempty:
      self.in_h = octree.nnum_nempty[self.in_depth]
      self.out_h = octree.nnum_nempty[self.out_depth]
    else:
      self.in_h = octree.nnum[self.in_depth]
      self.out_h = octree.nnum[self.out_depth]
      if self.stride == 2:
        if self.is_conv_layer():
          self.out_h = octree.nnum_nempty[self.out_depth]
        else:
          self.in_h = octree.nnum_nempty[self.in_depth]
    self.in_shape = (self.in_h, self.in_channels)
    self.out_shape = (self.out_h, self.out_channels)

    # The neighborhood indices
    self.neigh = octree.get_neigh(
        self.neigh_depth, self.kernel, self.stride, self.nempty)

    # The height and number of the temporary buffer
    self.buffer_n = 1
    self.buffer_h = self.neigh.shape[0]
    ideal_size = self.buffer_h * self.kdim * self.in_conv
    if ideal_size > self.max_buffer:
      kc = self.kdim * self.in_conv            # make `max_buffer` divisible
      max_buffer = self.max_buffer // kc * kc  # by `kc` with no remainder
      self.buffer_n = (ideal_size + max_buffer - 1) // max_buffer
      self.buffer_h = (self.buffer_h + self.buffer_n - 1) // self.buffer_n
    self.buffer_shape = (self.buffer_h, self.kdim, self.in_conv)

  def check_and_init(self, data: torch.Tensor):
    r''' Checks the input data and initializes the shape of output data.
    '''

    # Check the shape of input data
    check = tuple(data.shape) == self.in_shape
    assert check, 'The shape of input data is wrong.'

    # Init the output data
    out = data.new_zeros(self.out_shape)
    return out

  def forward_gemm(self, out: torch.Tensor, data: torch.Tensor,
                   weights: torch.Tensor):
    r''' Performs the forward pass of octree-based convolution.
    '''

    # Type check
    if data.dtype != out.dtype:
      data = data.to(out.dtype)
    if weights.dtype != out.dtype:
      weights = weights.to(out.dtype)

    # Initialize the buffer
    buffer = data.new_empty(self.buffer_shape)

    # Loop over each sub-matrix
    for i in range(self.buffer_n):
      start = i * self.buffer_h
      end = (i + 1) * self.buffer_h

      # The boundary case in the last iteration
      if end > self.neigh.shape[0]:
        dis = end - self.neigh.shape[0]
        end = self.neigh.shape[0]
        buffer, _ = buffer.split([self.buffer_h - dis, dis])

      # Perform octree2col
      neigh_i = self.neigh[start:end]
      valid = neigh_i >= 0
      buffer.fill_(0)
      buffer[valid] = data[neigh_i[valid]]

      # The sub-matrix gemm
      out[start:end] = torch.mm(buffer.flatten(1, 2), weights.flatten(0, 1))

    return out

  def backward_gemm(self, out: torch.Tensor, grad: torch.Tensor,
                    weights: torch.Tensor):
    r''' Performs the backward pass of octree-based convolution.
    '''

    # Type check
    if grad.dtype != out.dtype:
      grad = grad.to(out.dtype)
    if weights.dtype != out.dtype:
      weights = weights.to(out.dtype)

    # Loop over each sub-matrix
    for i in range(self.buffer_n):
      start = i * self.buffer_h
      end = (i + 1) * self.buffer_h

      # The boundary case in the last iteration
      if end > self.neigh.shape[0]:
        end = self.neigh.shape[0]

      # The sub-matrix gemm
      buffer = torch.mm(grad[start:end], weights.flatten(0, 1).t())
      buffer = buffer.view(-1, self.buffer_shape[1], self.buffer_shape[2])
      buffer = buffer.to(out.dtype)  # for pytorch.amp

      # Perform col2octree
      neigh_i = self.neigh[start:end]
      valid = neigh_i >= 0
      out = scatter_add(buffer[valid], neigh_i[valid], dim=0, out=out)

    return out

  def weight_gemm(
          self, out: torch.Tensor, data: torch.Tensor, grad: torch.Tensor):
    r''' Computes the gradient of the weight matrix.
    '''

    # Type check
    if data.dtype != out.dtype:
      data = data.to(out.dtype)
    if grad.dtype != out.dtype:
      grad = grad.to(out.dtype)

    # Record the shape of out
    out_shape = out.shape
    out = out.flatten(0, 1)

    # Initialize the buffer
    buffer = data.new_empty(self.buffer_shape)

    # Loop over each sub-matrix
    for i in range(self.buffer_n):
      start = i * self.buffer_h
      end = (i + 1) * self.buffer_h

      # The boundary case in the last iteration
      if end > self.neigh.shape[0]:
        d = end - self.neigh.shape[0]
        end = self.neigh.shape[0]
        buffer, _ = buffer.split([self.buffer_h - d, d])

      # Perform octree2col
      neigh_i = self.neigh[start:end]
      valid = neigh_i >= 0
      buffer.fill_(0)
      buffer[valid] = data[neigh_i[valid]]

      # Accumulate the gradient via gemm
      out.addmm_(buffer.flatten(1, 2).t(), grad[start:end])

    return out.view(out_shape)


class _OctreeConv(OctreeConvBase):
  r''' Instantiates :class:`OctreeConvBase` by overriding `is_conv_layer`.
  '''

  def is_conv_layer(self): return True


class _OctreeDeconv(OctreeConvBase):
  r''' Instantiates :class:`OctreeConvBase` by overriding `is_conv_layer`.
  '''

  def is_conv_layer(self): return False


class OctreeConvFunction(Function):
  r''' Wraps the octree convolution for auto-diff.
  '''

  @staticmethod
  def forward(ctx, data: torch.Tensor, weights: torch.Tensor, octree: Octree,
              depth: int, in_channels: int, out_channels: int,
              kernel_size: List[int] = [3, 3, 3], stride: int = 1,
              nempty: bool = False, max_buffer: int = int(2e8)):
    octree_conv = _OctreeConv(
        in_channels, out_channels, kernel_size, stride, nempty, max_buffer)
    octree_conv.setup(octree, depth)
    out = octree_conv.check_and_init(data)
    out = octree_conv.forward_gemm(out, data, weights)

    ctx.save_for_backward(data, weights)
    ctx.octree_conv = octree_conv
    return out

  @staticmethod
  def backward(ctx, grad):
    data, weights = ctx.saved_tensors
    octree_conv = ctx.octree_conv

    grad_out = None
    if ctx.needs_input_grad[0]:
      grad_out = torch.zeros_like(data)
      grad_out = octree_conv.backward_gemm(grad_out, grad, weights)

    grad_w = None
    if ctx.needs_input_grad[1]:
      grad_w = torch.zeros_like(weights)
      grad_w = octree_conv.weight_gemm(grad_w, data, grad)

    return (grad_out, grad_w) + (None,) * 8


class OctreeDeconvFunction(Function):
  r''' Wraps the octree deconvolution for auto-diff.
  '''

  @staticmethod
  def forward(ctx, data: torch.Tensor, weights: torch.Tensor, octree: Octree,
              depth: int, in_channels: int, out_channels: int,
              kernel_size: List[int] = [3, 3, 3], stride: int = 1,
              nempty: bool = False, max_buffer: int = int(2e8)):
    octree_deconv = _OctreeDeconv(
        in_channels, out_channels, kernel_size, stride, nempty, max_buffer)
    octree_deconv.setup(octree, depth)
    out = octree_deconv.check_and_init(data)
    out = octree_deconv.backward_gemm(out, data, weights)

    ctx.save_for_backward(data, weights)
    ctx.octree_deconv = octree_deconv
    return out

  @staticmethod
  def backward(ctx, grad):
    data, weights = ctx.saved_tensors
    octree_deconv = ctx.octree_deconv

    grad_out = None
    if ctx.needs_input_grad[0]:
      grad_out = torch.zeros_like(data)
      grad_out = octree_deconv.forward_gemm(grad_out, grad, weights)

    grad_w = None
    if ctx.needs_input_grad[1]:
      grad_w = torch.zeros_like(weights)
      grad_w = octree_deconv.weight_gemm(grad_w, grad, data)

    return (grad_out, grad_w) + (None,) * 8


# alias
octree_conv = OctreeConvFunction.apply
octree_deconv = OctreeDeconvFunction.apply


class OctreeConv(OctreeConvBase, torch.nn.Module):
  r''' Performs octree convolution.

  Args:
    in_channels (int): Number of input channels.
    out_channels (int): Number of output channels.
    kernel_size (List(int)): The kernel shape, chosen from :obj:`[3]`, :obj:`[2]`,
        :obj:`[3,3,3]`, :obj:`[3,1,1]`, :obj:`[1,3,1]`, :obj:`[1,1,3]`,
        :obj:`[2,2,2]`, :obj:`[3,3,1]`, :obj:`[1,3,3]`, and :obj:`[3,1,3]`.
    stride (int): The stride of the convolution (:obj:`1` or :obj:`2`).
    nempty (bool): If True, only performs the convolution on non-empty
        octree nodes.
    direct_method (bool): If True, directly performs the convolution via
        gemm and octree2col/col2octree. The octree2col/col2octree needs to
        construct a large matrix, which may consume a lot of memory. If False,
        performs the convolution in a sub-matrix manner, which can save the
        required runtime memory.
    use_bias (bool): If True, adds a bias term to the convolution.
    max_buffer (int): The maximum number of elements in the buffer, used when
        :attr:`direct_method` is False.

  .. note::
    Each non-empty octree node has exactly 8 children nodes, among which some
    are non-empty and some are empty. If :attr:`nempty` is true, the
    convolution is performed on non-empty octree nodes only, which is exactly
    the same as SparseConvNet and MinkowskiNet; if :attr:`nempty` is false,
    the convolution is performed on all octree nodes, which is essential for
    shape reconstruction tasks and can also be used in classification and
    segmentation (with slightly better performance and larger memory cost).
  '''

  def __init__(self, in_channels: int, out_channels: int,
               kernel_size: List[int] = [3], stride: int = 1,
               nempty: bool = False, direct_method: bool = False,
               use_bias: bool = False, max_buffer: int = int(2e8)):
    super().__init__(
        in_channels, out_channels, kernel_size, stride, nempty, max_buffer)

    self.direct_method = direct_method
    self.use_bias = use_bias
    self.weights = torch.nn.Parameter(torch.Tensor(*self.weights_shape))
    if self.use_bias:
      self.bias = torch.nn.Parameter(torch.Tensor(out_channels))
    self.reset_parameters()

  def reset_parameters(self):
    xavier_uniform_(self.weights)
    if self.use_bias:
      torch.nn.init.zeros_(self.bias)

  def is_conv_layer(self): return True

  def forward(self, data: torch.Tensor, octree: Octree, depth: int):
    r''' Defines the octree convolution.

    Args:
      data (torch.Tensor): The input data.
      octree (Octree): The corresponding octree.
      depth (int): The depth of the current octree.
    '''

    if self.direct_method:
      col = octree2col(
          data, octree, depth, self.kernel, self.stride, self.nempty)
      out = torch.mm(col.flatten(1), self.weights.flatten(0, 1))
    else:
      out = octree_conv(
          data, self.weights, octree, depth, self.in_channels,
          self.out_channels, self.kernel_size, self.stride, self.nempty,
          self.max_buffer)

    if self.use_bias:
      out += self.bias

    if self.stride == 2 and not self.nempty:
      out = octree_pad(out, octree, depth - 1)
    return out

  def extra_repr(self) -> str:
    r''' Sets the extra representation of the module.
    '''

    return ('in_channels={}, out_channels={}, kernel_size={}, stride={}, '
            'nempty={}, bias={}').format(self.in_channels, self.out_channels,
            self.kernel_size, self.stride, self.nempty, self.use_bias)  # noqa


class OctreeDeconv(OctreeConv):
  r''' Performs octree deconvolution.

  Please refer to :class:`OctreeConv` for the meaning of the arguments.
  '''

  def is_conv_layer(self): return False

  def forward(self, data: torch.Tensor, octree: Octree, depth: int):
    r''' Defines the octree deconvolution.

    Please refer to :meth:`OctreeConv.forward` for the meaning of the arguments.
    '''

    depth_col = depth
    if self.stride == 2:
      depth_col = depth + 1
      if not self.nempty:
        data = octree_depad(data, octree, depth)

    if self.direct_method:
      col = torch.mm(data, self.weights.flatten(0, 1).t())
      col = col.view(col.shape[0], self.kdim, -1)
      out = col2octree(
          col, octree, depth_col, self.kernel, self.stride, self.nempty)
    else:
      out = octree_deconv(
          data, self.weights, octree, depth, self.in_channels,
          self.out_channels, self.kernel_size, self.stride, self.nempty,
          self.max_buffer)

    if self.use_bias:
      out += self.bias
    return out
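
To make the sub-matrix buffering in `OctreeConvBase.setup` concrete, here is a small worked example of its arithmetic. The numbers are hypothetical (a 3×3×3 kernel, 64 input channels, 200,000 neighborhood rows), chosen only to trigger the `max_buffer` split; the snippet simply re-runs the formulas from `setup` outside the class.

# Hypothetical sizes; this reproduces the buffer-splitting arithmetic
# from OctreeConvBase.setup and is not part of the package.
kdim, in_conv = 27, 64     # 3x3x3 kernel, 64 input channels
neigh_rows = 200_000       # rows of the neighborhood index matrix
max_buffer = int(2e8)      # default cap on buffer elements

ideal_size = neigh_rows * kdim * in_conv            # 345,600,000 > max_buffer
kc = kdim * in_conv                                 # 1,728
aligned = max_buffer // kc * kc                     # 199,998,720; divisible by kc
buffer_n = (ideal_size + aligned - 1) // aligned    # 2 sub-matrices
buffer_h = (neigh_rows + buffer_n - 1) // buffer_n  # 100,000 rows each
print(buffer_n, buffer_h)  # -> 2 100000

Each GEMM then works on a (100000, 27, 64) buffer of about 1.7e8 elements instead of a single (200000, 27, 64) block, keeping the octree2col scratch memory under the cap.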
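
For orientation, a minimal usage sketch of the `OctreeConv` module defined above. The octree construction assumes the `Points`/`Octree` API from `ocnn/octree` elsewhere in this package (`build_octree`, `construct_all_neigh`, `nnum`); the point cloud and channel sizes are made up.

# A minimal sketch, assuming the ocnn.octree API listed in this diff.
import torch
import ocnn
from ocnn.octree import Octree, Points

# A toy point cloud with coordinates in [-1, 1].
points = Points(points=torch.rand(10_000, 3) * 2 - 1)
octree = Octree(depth=5)
octree.build_octree(points)
octree.construct_all_neigh()  # precompute neighborhood indices

depth = octree.depth
conv = ocnn.nn.OctreeConv(in_channels=3, out_channels=32,
                          kernel_size=[3], stride=1)

# One feature row per octree node at `depth`.
num = int(octree.nnum[depth])
data = torch.rand(num, 3)
out = conv(data, octree, depth)  # shape: (num, 32)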