ocnn 2.2.1__py3-none-any.whl → 2.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ocnn/nn/octree_conv.py CHANGED
@@ -1,411 +1,429 @@
1
- # --------------------------------------------------------
2
- # Octree-based Sparse Convolutional Neural Networks
3
- # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
- # Licensed under The MIT License [see LICENSE for details]
5
- # Written by Peng-Shuai Wang
6
- # --------------------------------------------------------
7
-
8
- import torch
9
- import torch.nn
10
- from torch.autograd import Function
11
- from typing import List
12
-
13
- from ocnn.octree import Octree
14
- from ocnn.utils import scatter_add, xavier_uniform_, resize_with_last_val, list2str
15
- from .octree2col import octree2col, col2octree
16
- from .octree_pad import octree_pad, octree_depad
17
-
18
-
19
- class OctreeConvBase:
20
-
21
- def __init__(self, in_channels: int, out_channels: int,
22
- kernel_size: List[int] = [3], stride: int = 1,
23
- nempty: bool = False, max_buffer: int = int(2e8)):
24
- super().__init__()
25
- self.in_channels = in_channels
26
- self.out_channels = out_channels
27
- self.kernel_size = resize_with_last_val(kernel_size)
28
- self.kernel = list2str(self.kernel_size)
29
- self.stride = stride
30
- self.nempty = nempty
31
- self.max_buffer = max_buffer # about 200M
32
-
33
- self.kdim = self.kernel_size[0] * self.kernel_size[1] * self.kernel_size[2]
34
- self.in_conv = in_channels if self.is_conv_layer() else out_channels
35
- self.out_conv = out_channels if self.is_conv_layer() else in_channels
36
- self.weights_shape = (self.kdim, self.in_conv, self.out_conv)
37
-
38
- def is_conv_layer(self):
39
- r''' Returns :obj:`True` to indicate this is a convolution layer.
40
- '''
41
-
42
- raise NotImplementedError
43
-
44
- def setup(self, octree: Octree, depth: int):
45
- r''' Setup the shapes of each tensor.
46
- This function MUST be called before :obj:`forward_gemm`, :obj:`backward_gemm`
47
- and :obj:`weight_gemm`.
48
- '''
49
-
50
- # The depth of tensors:
51
- # The in_depth and out_depth are the octree depth of the input and output
52
- # data; neigh_depth is the octree depth of the neighborhood information, as
53
- # well as `col` data, neigh_depth is always the same as the depth of larger
54
- # data when doing octree2col or col2octree.
55
- self.in_depth = depth
56
- self.out_depth = depth
57
- self.neigh_depth = depth
58
- if self.stride == 2:
59
- if self.is_conv_layer():
60
- self.out_depth = depth - 1
61
- else:
62
- self.out_depth = depth + 1
63
- self.neigh_depth = depth + 1
64
-
65
- # The height of tensors
66
- if self.nempty:
67
- self.in_h = octree.nnum_nempty[self.in_depth]
68
- self.out_h = octree.nnum_nempty[self.out_depth]
69
- else:
70
- self.in_h = octree.nnum[self.in_depth]
71
- self.out_h = octree.nnum[self.out_depth]
72
- if self.stride == 2:
73
- if self.is_conv_layer():
74
- self.out_h = octree.nnum_nempty[self.out_depth]
75
- else:
76
- self.in_h = octree.nnum_nempty[self.in_depth]
77
- self.in_shape = (self.in_h, self.in_channels)
78
- self.out_shape = (self.out_h, self.out_channels)
79
-
80
- # The neighborhood indices
81
- self.neigh = octree.get_neigh(
82
- self.neigh_depth, self.kernel, self.stride, self.nempty)
83
-
84
- # The height and number of the temporary buffer
85
- self.buffer_n = 1
86
- self.buffer_h = self.neigh.shape[0]
87
- ideal_size = self.buffer_h * self.kdim * self.in_conv
88
- if ideal_size > self.max_buffer:
89
- kc = self.kdim * self.in_conv # make `max_buffer` be divided
90
- max_buffer = self.max_buffer // kc * kc # by `kc` with no remainder
91
- self.buffer_n = (ideal_size + max_buffer - 1) // max_buffer
92
- self.buffer_h = (self.buffer_h + self.buffer_n - 1) // self.buffer_n
93
- self.buffer_shape = (self.buffer_h, self.kdim, self.in_conv)
94
-
95
- def check_and_init(self, data: torch.Tensor):
96
- r''' Checks the input data and initializes the shape of output data.
97
- '''
98
-
99
- # Check the shape of input data
100
- check = tuple(data.shape) == self.in_shape
101
- assert check, 'The shape of input data is wrong.'
102
-
103
- # Init the output data
104
- out = data.new_zeros(self.out_shape)
105
- return out
106
-
107
- def forward_gemm(self, out: torch.Tensor, data: torch.Tensor,
108
- weights: torch.Tensor):
109
- r''' Performs the forward pass of octree-based convolution.
110
- '''
111
-
112
- # Initialize the buffer
113
- buffer = data.new_empty(self.buffer_shape)
114
-
115
- # Loop over each sub-matrix
116
- for i in range(self.buffer_n):
117
- start = i * self.buffer_h
118
- end = (i + 1) * self.buffer_h
119
-
120
- # The boundary case in the last iteration
121
- if end > self.neigh.shape[0]:
122
- dis = end - self.neigh.shape[0]
123
- end = self.neigh.shape[0]
124
- buffer, _ = buffer.split([self.buffer_h-dis, dis])
125
-
126
- # Perform octree2col
127
- neigh_i = self.neigh[start:end]
128
- valid = neigh_i >= 0
129
- buffer.fill_(0)
130
- buffer[valid] = data[neigh_i[valid]]
131
-
132
- # The sub-matrix gemm
133
- out[start:end] = torch.mm(buffer.flatten(1, 2), weights.flatten(0, 1))
134
-
135
- return out
136
-
137
- def backward_gemm(self, out: torch.Tensor, grad: torch.Tensor,
138
- weights: torch.Tensor):
139
- r''' Performs the backward pass of octree-based convolution.
140
- '''
141
-
142
- # Loop over each sub-matrix
143
- for i in range(self.buffer_n):
144
- start = i * self.buffer_h
145
- end = (i + 1) * self.buffer_h
146
-
147
- # The boundary case in the last iteration
148
- if end > self.neigh.shape[0]:
149
- end = self.neigh.shape[0]
150
-
151
- # The sub-matrix gemm
152
- buffer = torch.mm(grad[start:end], weights.flatten(0, 1).t())
153
- buffer = buffer.view(-1, self.buffer_shape[1], self.buffer_shape[2])
154
- buffer = buffer.to(out.dtype) # for pytorch.amp
155
-
156
- # Performs col2octree
157
- neigh_i = self.neigh[start:end]
158
- valid = neigh_i >= 0
159
- out = scatter_add(buffer[valid], neigh_i[valid], dim=0, out=out)
160
-
161
- return out
162
-
163
- def weight_gemm(
164
- self, out: torch.Tensor, data: torch.Tensor, grad: torch.Tensor):
165
- r''' Computes the gradient of the weight matrix.
166
- '''
167
-
168
- # Record the shape of out
169
- out_shape = out.shape
170
- out = out.flatten(0, 1)
171
-
172
- # Initialize the buffer
173
- buffer = data.new_empty(self.buffer_shape)
174
-
175
- # Loop over each sub-matrix
176
- for i in range(self.buffer_n):
177
- start = i * self.buffer_h
178
- end = (i + 1) * self.buffer_h
179
-
180
- # The boundary case in the last iteration
181
- if end > self.neigh.shape[0]:
182
- d = end - self.neigh.shape[0]
183
- end = self.neigh.shape[0]
184
- buffer, _ = buffer.split([self.buffer_h-d, d])
185
-
186
- # Perform octree2col
187
- neigh_i = self.neigh[start:end]
188
- valid = neigh_i >= 0
189
- buffer.fill_(0)
190
- buffer[valid] = data[neigh_i[valid]]
191
-
192
- # Accumulate the gradient via gemm
193
- out.addmm_(buffer.flatten(1, 2).t(), grad[start:end])
194
-
195
- return out.view(out_shape)
196
-
197
-
198
- class _OctreeConv(OctreeConvBase):
199
- r''' Instantiates OctreeConvBase by overriding `is_conv_layer`
200
- '''
201
-
202
- def is_conv_layer(self): return True
203
-
204
-
205
- class _OctreeDeconv(OctreeConvBase):
206
- r''' Instantiates OctreeConvBase by overriding `is_conv_layer`
207
- '''
208
-
209
- def is_conv_layer(self): return False
210
-
211
-
212
- class OctreeConvFunction(Function):
213
- r''' Wrap the octree convolution for auto-diff.
214
- '''
215
-
216
- @staticmethod
217
- def forward(ctx, data: torch.Tensor, weights: torch.Tensor, octree: Octree,
218
- depth: int, in_channels: int, out_channels: int,
219
- kernel_size: List[int] = [3, 3, 3], stride: int = 1,
220
- nempty: bool = False, max_buffer: int = int(2e8)):
221
- octree_conv = _OctreeConv(
222
- in_channels, out_channels, kernel_size, stride, nempty, max_buffer)
223
- octree_conv.setup(octree, depth)
224
- out = octree_conv.check_and_init(data)
225
- out = octree_conv.forward_gemm(out, data, weights)
226
-
227
- ctx.save_for_backward(data, weights)
228
- ctx.octree_conv = octree_conv
229
- return out
230
-
231
- @staticmethod
232
- def backward(ctx, grad):
233
- data, weights = ctx.saved_tensors
234
- octree_conv = ctx.octree_conv
235
-
236
- grad_out = None
237
- if ctx.needs_input_grad[0]:
238
- grad_out = torch.zeros_like(data)
239
- grad_out = octree_conv.backward_gemm(grad_out, grad, weights)
240
-
241
- grad_w = None
242
- if ctx.needs_input_grad[1]:
243
- grad_w = torch.zeros_like(weights)
244
- grad_w = octree_conv.weight_gemm(grad_w, data, grad)
245
-
246
- return (grad_out, grad_w) + (None,) * 8
247
-
248
-
249
- class OctreeDeconvFunction(Function):
250
- r''' Wrap the octree deconvolution for auto-diff.
251
- '''
252
-
253
- @staticmethod
254
- def forward(ctx, data: torch.Tensor, weights: torch.Tensor, octree: Octree,
255
- depth: int, in_channels: int, out_channels: int,
256
- kernel_size: List[int] = [3, 3, 3], stride: int = 1,
257
- nempty: bool = False, max_buffer: int = int(2e8)):
258
- octree_deconv = _OctreeDeconv(
259
- in_channels, out_channels, kernel_size, stride, nempty, max_buffer)
260
- octree_deconv.setup(octree, depth)
261
- out = octree_deconv.check_and_init(data)
262
- out = octree_deconv.backward_gemm(out, data, weights)
263
-
264
- ctx.save_for_backward(data, weights)
265
- ctx.octree_deconv = octree_deconv
266
- return out
267
-
268
- @staticmethod
269
- def backward(ctx, grad):
270
- data, weights = ctx.saved_tensors
271
- octree_deconv = ctx.octree_deconv
272
-
273
- grad_out = None
274
- if ctx.needs_input_grad[0]:
275
- grad_out = torch.zeros_like(data)
276
- grad_out = octree_deconv.forward_gemm(grad_out, grad, weights)
277
-
278
- grad_w = None
279
- if ctx.needs_input_grad[1]:
280
- grad_w = torch.zeros_like(weights)
281
- grad_w = octree_deconv.weight_gemm(grad_w, grad, data)
282
-
283
- return (grad_out, grad_w) + (None,) * 8
284
-
285
-
286
- # alias
287
- octree_conv = OctreeConvFunction.apply
288
- octree_deconv = OctreeDeconvFunction.apply
289
-
290
-
291
- class OctreeConv(OctreeConvBase, torch.nn.Module):
292
- r''' Performs octree convolution.
293
-
294
- Args:
295
- in_channels (int): Number of input channels.
296
- out_channels (int): Number of output channels.
297
- kernel_size (List(int)): The kernel shape, choose from :obj:`[3]`, :obj:`[2]`,
298
- :obj:`[3,3,3]`, :obj:`[3,1,1]`, :obj:`[1,3,1]`, :obj:`[1,1,3]`,
299
- :obj:`[2,2,2]`, :obj:`[3,3,1]`, :obj:`[1,3,3]`, and :obj:`[3,1,3]`.
300
- stride (int): The stride of the convolution (:obj:`1` or :obj:`2`).
301
- nempty (bool): If True, only performs the convolution on non-empty
302
- octree nodes.
303
- direct_method (bool): If True, directly performs the convolution via using
304
- gemm and octree2col/col2octree. The octree2col/col2octree needs to
305
- construct a large matrix, which may consume a lot of memory. If False,
306
- performs the convolution in a sub-matrix manner, which can save the
307
- required runtime memory.
308
- use_bias (bool): If True, add a bias term to the convolution.
309
- max_buffer (int): The maximum number of elements in the buffer, used when
310
- :attr:`direct_method` is False.
311
-
312
- .. note::
313
- Each non-empty octree node has exactly 8 children nodes, among which some
314
- children nodes are non-empty and some are empty. If :attr:`nempty` is true,
315
- the convolution is performed on non-empty octree nodes only, which is exactly
316
- the same as SparseConvNet and MinkowskiNet; if :attr:`nempty` is false, the
317
- convolution is performed on all octree nodes, which is essential for shape
318
- reconstruction tasks and can also be used in classification and segmentation
319
- (with slightly better performance and larger memory cost).
320
- '''
321
-
322
- def __init__(self, in_channels: int, out_channels: int,
323
- kernel_size: List[int] = [3], stride: int = 1,
324
- nempty: bool = False, direct_method: bool = False,
325
- use_bias: bool = False, max_buffer: int = int(2e8)):
326
- super().__init__(
327
- in_channels, out_channels, kernel_size, stride, nempty, max_buffer)
328
-
329
- self.direct_method = direct_method
330
- self.use_bias = use_bias
331
- self.weights = torch.nn.Parameter(torch.Tensor(*self.weights_shape))
332
- if self.use_bias:
333
- self.bias = torch.nn.Parameter(torch.Tensor(out_channels))
334
- self.reset_parameters()
335
-
336
- def reset_parameters(self):
337
- xavier_uniform_(self.weights)
338
- if self.use_bias:
339
- torch.nn.init.zeros_(self.bias)
340
-
341
- def is_conv_layer(self): return True
342
-
343
- def forward(self, data: torch.Tensor, octree: Octree, depth: int):
344
- r''' Defines the octree convolution.
345
-
346
- Args:
347
- data (torch.Tensor): The input data.
348
- octree (Octree): The corresponding octree.
349
- depth (int): The depth of current octree.
350
- '''
351
-
352
- if self.direct_method:
353
- col = octree2col(
354
- data, octree, depth, self.kernel, self.stride, self.nempty)
355
- out = torch.mm(col.flatten(1), self.weights.flatten(0, 1))
356
- else:
357
- out = octree_conv(
358
- data, self.weights, octree, depth, self.in_channels,
359
- self.out_channels, self.kernel_size, self.stride, self.nempty,
360
- self.max_buffer)
361
-
362
- if self.use_bias:
363
- out += self.bias
364
-
365
- if self.stride == 2 and not self.nempty:
366
- out = octree_pad(out, octree, depth-1)
367
- return out
368
-
369
- def extra_repr(self) -> str:
370
- r''' Sets the extra representation of the module.
371
- '''
372
-
373
- return ('in_channels={}, out_channels={}, kernel_size={}, stride={}, '
374
- 'nempty={}, bias={}').format(self.in_channels, self.out_channels,
375
- self.kernel_size, self.stride, self.nempty, self.use_bias) # noqa
376
-
377
-
378
- class OctreeDeconv(OctreeConv):
379
- r''' Performs octree deconvolution.
380
-
381
- Please refer to :class:`OctreeConv` for the meaning of the arguments.
382
- '''
383
-
384
- def is_conv_layer(self): return False
385
-
386
- def forward(self, data: torch.Tensor, octree: Octree, depth: int):
387
- r''' Defines the octree deconvolution.
388
-
389
- Please refer to :meth:`OctreeConv.forward` for the meaning of the arguments.
390
- '''
391
-
392
- depth_col = depth
393
- if self.stride == 2:
394
- depth_col = depth + 1
395
- if not self.nempty:
396
- data = octree_depad(data, octree, depth)
397
-
398
- if self.direct_method:
399
- col = torch.mm(data, self.weights.flatten(0, 1).t())
400
- col = col.view(col.shape[0], self.kdim, -1)
401
- out = col2octree(
402
- col, octree, depth_col, self.kernel, self.stride, self.nempty)
403
- else:
404
- out = octree_deconv(
405
- data, self.weights, octree, depth, self.in_channels,
406
- self.out_channels, self.kernel_size, self.stride, self.nempty,
407
- self.max_buffer)
408
-
409
- if self.use_bias:
410
- out += self.bias
411
- return out
1
+ # --------------------------------------------------------
2
+ # Octree-based Sparse Convolutional Neural Networks
3
+ # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Peng-Shuai Wang
6
+ # --------------------------------------------------------
7
+
8
+ import torch
9
+ import torch.nn
10
+ from torch.autograd import Function
11
+ from typing import List
12
+
13
+ from ocnn.octree import Octree
14
+ from ocnn.utils import scatter_add, xavier_uniform_, resize_with_last_val, list2str
15
+ from .octree2col import octree2col, col2octree
16
+ from .octree_pad import octree_pad, octree_depad
17
+
18
+
19
+ class OctreeConvBase:
20
+
21
+ def __init__(self, in_channels: int, out_channels: int,
22
+ kernel_size: List[int] = [3], stride: int = 1,
23
+ nempty: bool = False, max_buffer: int = int(2e8)):
24
+ super().__init__()
25
+ self.in_channels = in_channels
26
+ self.out_channels = out_channels
27
+ self.kernel_size = resize_with_last_val(kernel_size)
28
+ self.kernel = list2str(self.kernel_size)
29
+ self.stride = stride
30
+ self.nempty = nempty
31
+ self.max_buffer = max_buffer # about 200M
32
+
33
+ self.kdim = self.kernel_size[0] * self.kernel_size[1] * self.kernel_size[2]
34
+ self.in_conv = in_channels if self.is_conv_layer() else out_channels
35
+ self.out_conv = out_channels if self.is_conv_layer() else in_channels
36
+ self.weights_shape = (self.kdim, self.in_conv, self.out_conv)
37
+
38
+ def is_conv_layer(self):
39
+ r''' Returns :obj:`True` to indicate this is a convolution layer.
40
+ '''
41
+
42
+ raise NotImplementedError
43
+
44
+ def setup(self, octree: Octree, depth: int):
45
+ r''' Setup the shapes of each tensor.
46
+ This function MUST be called before :obj:`forward_gemm`, :obj:`backward_gemm`
47
+ and :obj:`weight_gemm`.
48
+ '''
49
+
50
+ # The depth of tensors:
51
+ # The in_depth and out_depth are the octree depth of the input and output
52
+ # data; neigh_depth is the octree depth of the neighborhood information, as
53
+ # well as `col` data, neigh_depth is always the same as the depth of larger
54
+ # data when doing octree2col or col2octree.
55
+ self.in_depth = depth
56
+ self.out_depth = depth
57
+ self.neigh_depth = depth
58
+ if self.stride == 2:
59
+ if self.is_conv_layer():
60
+ self.out_depth = depth - 1
61
+ else:
62
+ self.out_depth = depth + 1
63
+ self.neigh_depth = depth + 1
64
+
65
+ # The height of tensors
66
+ if self.nempty:
67
+ self.in_h = octree.nnum_nempty[self.in_depth]
68
+ self.out_h = octree.nnum_nempty[self.out_depth]
69
+ else:
70
+ self.in_h = octree.nnum[self.in_depth]
71
+ self.out_h = octree.nnum[self.out_depth]
72
+ if self.stride == 2:
73
+ if self.is_conv_layer():
74
+ self.out_h = octree.nnum_nempty[self.out_depth]
75
+ else:
76
+ self.in_h = octree.nnum_nempty[self.in_depth]
77
+ self.in_shape = (self.in_h, self.in_channels)
78
+ self.out_shape = (self.out_h, self.out_channels)
79
+
80
+ # The neighborhood indices
81
+ self.neigh = octree.get_neigh(
82
+ self.neigh_depth, self.kernel, self.stride, self.nempty)
83
+
84
+ # The height and number of the temporary buffer
85
+ self.buffer_n = 1
86
+ self.buffer_h = self.neigh.shape[0]
87
+ ideal_size = self.buffer_h * self.kdim * self.in_conv
88
+ if ideal_size > self.max_buffer:
89
+ kc = self.kdim * self.in_conv # make `max_buffer` be divided
90
+ max_buffer = self.max_buffer // kc * kc # by `kc` with no remainder
91
+ self.buffer_n = (ideal_size + max_buffer - 1) // max_buffer
92
+ self.buffer_h = (self.buffer_h + self.buffer_n - 1) // self.buffer_n
93
+ self.buffer_shape = (self.buffer_h, self.kdim, self.in_conv)
94
+
95
+ def check_and_init(self, data: torch.Tensor):
96
+ r''' Checks the input data and initializes the shape of output data.
97
+ '''
98
+
99
+ # Check the shape of input data
100
+ check = tuple(data.shape) == self.in_shape
101
+ assert check, 'The shape of input data is wrong.'
102
+
103
+ # Init the output data
104
+ out = data.new_zeros(self.out_shape)
105
+ return out
106
+
107
+ def forward_gemm(self, out: torch.Tensor, data: torch.Tensor,
108
+ weights: torch.Tensor):
109
+ r''' Performs the forward pass of octree-based convolution.
110
+ '''
111
+
112
+ # Type check
113
+ if data.dtype != out.dtype:
114
+ data = data.to(out.dtype)
115
+ if weights.dtype != out.dtype:
116
+ weights = weights.to(out.dtype)
117
+
118
+ # Initialize the buffer
119
+ buffer = data.new_empty(self.buffer_shape)
120
+
121
+ # Loop over each sub-matrix
122
+ for i in range(self.buffer_n):
123
+ start = i * self.buffer_h
124
+ end = (i + 1) * self.buffer_h
125
+
126
+ # The boundary case in the last iteration
127
+ if end > self.neigh.shape[0]:
128
+ dis = end - self.neigh.shape[0]
129
+ end = self.neigh.shape[0]
130
+ buffer, _ = buffer.split([self.buffer_h-dis, dis])
131
+
132
+ # Perform octree2col
133
+ neigh_i = self.neigh[start:end]
134
+ valid = neigh_i >= 0
135
+ buffer.fill_(0)
136
+ buffer[valid] = data[neigh_i[valid]]
137
+
138
+ # The sub-matrix gemm
139
+ out[start:end] = torch.mm(buffer.flatten(1, 2), weights.flatten(0, 1))
140
+
141
+ return out
142
+
143
+ def backward_gemm(self, out: torch.Tensor, grad: torch.Tensor,
144
+ weights: torch.Tensor):
145
+ r''' Performs the backward pass of octree-based convolution.
146
+ '''
147
+
148
+ # Type check
149
+ if grad.dtype != out.dtype:
150
+ grad = grad.to(out.dtype)
151
+ if weights.dtype != out.dtype:
152
+ weights = weights.to(out.dtype)
153
+
154
+ # Loop over each sub-matrix
155
+ for i in range(self.buffer_n):
156
+ start = i * self.buffer_h
157
+ end = (i + 1) * self.buffer_h
158
+
159
+ # The boundary case in the last iteration
160
+ if end > self.neigh.shape[0]:
161
+ end = self.neigh.shape[0]
162
+
163
+ # The sub-matrix gemm
164
+ buffer = torch.mm(grad[start:end], weights.flatten(0, 1).t())
165
+ buffer = buffer.view(-1, self.buffer_shape[1], self.buffer_shape[2])
166
+ buffer = buffer.to(out.dtype) # for pytorch.amp
167
+
168
+ # Performs col2octree
169
+ neigh_i = self.neigh[start:end]
170
+ valid = neigh_i >= 0
171
+ out = scatter_add(buffer[valid], neigh_i[valid], dim=0, out=out)
172
+
173
+ return out
174
+
175
+ def weight_gemm(
176
+ self, out: torch.Tensor, data: torch.Tensor, grad: torch.Tensor):
177
+ r''' Computes the gradient of the weight matrix.
178
+ '''
179
+
180
+ # Type check
181
+ if data.dtype != out.dtype:
182
+ data = data.to(out.dtype)
183
+ if grad.dtype != out.dtype:
184
+ grad = grad.to(out.dtype)
185
+
186
+ # Record the shape of out
187
+ out_shape = out.shape
188
+ out = out.flatten(0, 1)
189
+
190
+ # Initialize the buffer
191
+ buffer = data.new_empty(self.buffer_shape)
192
+
193
+ # Loop over each sub-matrix
194
+ for i in range(self.buffer_n):
195
+ start = i * self.buffer_h
196
+ end = (i + 1) * self.buffer_h
197
+
198
+ # The boundary case in the last iteration
199
+ if end > self.neigh.shape[0]:
200
+ d = end - self.neigh.shape[0]
201
+ end = self.neigh.shape[0]
202
+ buffer, _ = buffer.split([self.buffer_h-d, d])
203
+
204
+ # Perform octree2col
205
+ neigh_i = self.neigh[start:end]
206
+ valid = neigh_i >= 0
207
+ buffer.fill_(0)
208
+ buffer[valid] = data[neigh_i[valid]]
209
+
210
+ # Accumulate the gradient via gemm
211
+ out.addmm_(buffer.flatten(1, 2).t(), grad[start:end])
212
+
213
+ return out.view(out_shape)
214
+
215
+
216
+ class _OctreeConv(OctreeConvBase):
217
+ r''' Instantiates OctreeConvBase by overriding `is_conv_layer`
218
+ '''
219
+
220
+ def is_conv_layer(self): return True
221
+
222
+
223
+ class _OctreeDeconv(OctreeConvBase):
224
+ r''' Instantiates OctreeConvBase by overriding `is_conv_layer`
225
+ '''
226
+
227
+ def is_conv_layer(self): return False
228
+
229
+
230
+ class OctreeConvFunction(Function):
231
+ r''' Wrap the octree convolution for auto-diff.
232
+ '''
233
+
234
+ @staticmethod
235
+ def forward(ctx, data: torch.Tensor, weights: torch.Tensor, octree: Octree,
236
+ depth: int, in_channels: int, out_channels: int,
237
+ kernel_size: List[int] = [3, 3, 3], stride: int = 1,
238
+ nempty: bool = False, max_buffer: int = int(2e8)):
239
+ octree_conv = _OctreeConv(
240
+ in_channels, out_channels, kernel_size, stride, nempty, max_buffer)
241
+ octree_conv.setup(octree, depth)
242
+ out = octree_conv.check_and_init(data)
243
+ out = octree_conv.forward_gemm(out, data, weights)
244
+
245
+ ctx.save_for_backward(data, weights)
246
+ ctx.octree_conv = octree_conv
247
+ return out
248
+
249
+ @staticmethod
250
+ def backward(ctx, grad):
251
+ data, weights = ctx.saved_tensors
252
+ octree_conv = ctx.octree_conv
253
+
254
+ grad_out = None
255
+ if ctx.needs_input_grad[0]:
256
+ grad_out = torch.zeros_like(data)
257
+ grad_out = octree_conv.backward_gemm(grad_out, grad, weights)
258
+
259
+ grad_w = None
260
+ if ctx.needs_input_grad[1]:
261
+ grad_w = torch.zeros_like(weights)
262
+ grad_w = octree_conv.weight_gemm(grad_w, data, grad)
263
+
264
+ return (grad_out, grad_w) + (None,) * 8
265
+
266
+
267
+ class OctreeDeconvFunction(Function):
268
+ r''' Wrap the octree deconvolution for auto-diff.
269
+ '''
270
+
271
+ @staticmethod
272
+ def forward(ctx, data: torch.Tensor, weights: torch.Tensor, octree: Octree,
273
+ depth: int, in_channels: int, out_channels: int,
274
+ kernel_size: List[int] = [3, 3, 3], stride: int = 1,
275
+ nempty: bool = False, max_buffer: int = int(2e8)):
276
+ octree_deconv = _OctreeDeconv(
277
+ in_channels, out_channels, kernel_size, stride, nempty, max_buffer)
278
+ octree_deconv.setup(octree, depth)
279
+ out = octree_deconv.check_and_init(data)
280
+ out = octree_deconv.backward_gemm(out, data, weights)
281
+
282
+ ctx.save_for_backward(data, weights)
283
+ ctx.octree_deconv = octree_deconv
284
+ return out
285
+
286
+ @staticmethod
287
+ def backward(ctx, grad):
288
+ data, weights = ctx.saved_tensors
289
+ octree_deconv = ctx.octree_deconv
290
+
291
+ grad_out = None
292
+ if ctx.needs_input_grad[0]:
293
+ grad_out = torch.zeros_like(data)
294
+ grad_out = octree_deconv.forward_gemm(grad_out, grad, weights)
295
+
296
+ grad_w = None
297
+ if ctx.needs_input_grad[1]:
298
+ grad_w = torch.zeros_like(weights)
299
+ grad_w = octree_deconv.weight_gemm(grad_w, grad, data)
300
+
301
+ return (grad_out, grad_w) + (None,) * 8
302
+
303
+
304
+ # alias
305
+ octree_conv = OctreeConvFunction.apply
306
+ octree_deconv = OctreeDeconvFunction.apply
307
+
308
+
309
+ class OctreeConv(OctreeConvBase, torch.nn.Module):
310
+ r''' Performs octree convolution.
311
+
312
+ Args:
313
+ in_channels (int): Number of input channels.
314
+ out_channels (int): Number of output channels.
315
+ kernel_size (List(int)): The kernel shape, choose from :obj:`[3]`, :obj:`[2]`,
316
+ :obj:`[3,3,3]`, :obj:`[3,1,1]`, :obj:`[1,3,1]`, :obj:`[1,1,3]`,
317
+ :obj:`[2,2,2]`, :obj:`[3,3,1]`, :obj:`[1,3,3]`, and :obj:`[3,1,3]`.
318
+ stride (int): The stride of the convolution (:obj:`1` or :obj:`2`).
319
+ nempty (bool): If True, only performs the convolution on non-empty
320
+ octree nodes.
321
+ direct_method (bool): If True, directly performs the convolution via using
322
+ gemm and octree2col/col2octree. The octree2col/col2octree needs to
323
+ construct a large matrix, which may consume a lot of memory. If False,
324
+ performs the convolution in a sub-matrix manner, which can save the
325
+ required runtime memory.
326
+ use_bias (bool): If True, add a bias term to the convolution.
327
+ max_buffer (int): The maximum number of elements in the buffer, used when
328
+ :attr:`direct_method` is False.
329
+
330
+ .. note::
331
+ Each non-empty octree node has exactly 8 children nodes, among which some
332
+ children nodes are non-empty and some are empty. If :attr:`nempty` is true,
333
+ the convolution is performed on non-empty octree nodes only, which is exactly
334
+ the same as SparseConvNet and MinkowskiNet; if :attr:`nempty` is false, the
335
+ convolution is performed on all octree nodes, which is essential for shape
336
+ reconstruction tasks and can also be used in classification and segmentation
337
+ (with slightly better performance and larger memory cost).
338
+ '''
339
+
340
+ def __init__(self, in_channels: int, out_channels: int,
341
+ kernel_size: List[int] = [3], stride: int = 1,
342
+ nempty: bool = False, direct_method: bool = False,
343
+ use_bias: bool = False, max_buffer: int = int(2e8)):
344
+ super().__init__(
345
+ in_channels, out_channels, kernel_size, stride, nempty, max_buffer)
346
+
347
+ self.direct_method = direct_method
348
+ self.use_bias = use_bias
349
+ self.weights = torch.nn.Parameter(torch.Tensor(*self.weights_shape))
350
+ if self.use_bias:
351
+ self.bias = torch.nn.Parameter(torch.Tensor(out_channels))
352
+ self.reset_parameters()
353
+
354
+ def reset_parameters(self):
355
+ xavier_uniform_(self.weights)
356
+ if self.use_bias:
357
+ torch.nn.init.zeros_(self.bias)
358
+
359
+ def is_conv_layer(self): return True
360
+
361
+ def forward(self, data: torch.Tensor, octree: Octree, depth: int):
362
+ r''' Defines the octree convolution.
363
+
364
+ Args:
365
+ data (torch.Tensor): The input data.
366
+ octree (Octree): The corresponding octree.
367
+ depth (int): The depth of current octree.
368
+ '''
369
+
370
+ if self.direct_method:
371
+ col = octree2col(
372
+ data, octree, depth, self.kernel, self.stride, self.nempty)
373
+ out = torch.mm(col.flatten(1), self.weights.flatten(0, 1))
374
+ else:
375
+ out = octree_conv(
376
+ data, self.weights, octree, depth, self.in_channels,
377
+ self.out_channels, self.kernel_size, self.stride, self.nempty,
378
+ self.max_buffer)
379
+
380
+ if self.use_bias:
381
+ out += self.bias
382
+
383
+ if self.stride == 2 and not self.nempty:
384
+ out = octree_pad(out, octree, depth-1)
385
+ return out
386
+
387
+ def extra_repr(self) -> str:
388
+ r''' Sets the extra representation of the module.
389
+ '''
390
+
391
+ return ('in_channels={}, out_channels={}, kernel_size={}, stride={}, '
392
+ 'nempty={}, bias={}').format(self.in_channels, self.out_channels,
393
+ self.kernel_size, self.stride, self.nempty, self.use_bias) # noqa
394
+
395
+
396
+ class OctreeDeconv(OctreeConv):
397
+ r''' Performs octree deconvolution.
398
+
399
+ Please refer to :class:`OctreeConv` for the meaning of the arguments.
400
+ '''
401
+
402
+ def is_conv_layer(self): return False
403
+
404
+ def forward(self, data: torch.Tensor, octree: Octree, depth: int):
405
+ r''' Defines the octree deconvolution.
406
+
407
+ Please refer to :meth:`OctreeConv.forward` for the meaning of the arguments.
408
+ '''
409
+
410
+ depth_col = depth
411
+ if self.stride == 2:
412
+ depth_col = depth + 1
413
+ if not self.nempty:
414
+ data = octree_depad(data, octree, depth)
415
+
416
+ if self.direct_method:
417
+ col = torch.mm(data, self.weights.flatten(0, 1).t())
418
+ col = col.view(col.shape[0], self.kdim, -1)
419
+ out = col2octree(
420
+ col, octree, depth_col, self.kernel, self.stride, self.nempty)
421
+ else:
422
+ out = octree_deconv(
423
+ data, self.weights, octree, depth, self.in_channels,
424
+ self.out_channels, self.kernel_size, self.stride, self.nempty,
425
+ self.max_buffer)
426
+
427
+ if self.use_bias:
428
+ out += self.bias
429
+ return out