ocnn 2.2.8__py3-none-any.whl → 2.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46) hide show
  1. ocnn/__init__.py +24 -24
  2. ocnn/dataset.py +160 -160
  3. ocnn/models/__init__.py +29 -29
  4. ocnn/models/autoencoder.py +155 -155
  5. ocnn/models/hrnet.py +192 -192
  6. ocnn/models/image2shape.py +128 -128
  7. ocnn/models/lenet.py +46 -46
  8. ocnn/models/ounet.py +94 -94
  9. ocnn/models/resnet.py +53 -53
  10. ocnn/models/segnet.py +72 -72
  11. ocnn/models/unet.py +105 -105
  12. ocnn/modules/__init__.py +26 -26
  13. ocnn/modules/modules.py +303 -303
  14. ocnn/modules/resblocks.py +158 -158
  15. ocnn/nn/__init__.py +45 -44
  16. ocnn/nn/kernels/__init__.py +14 -0
  17. ocnn/nn/kernels/autotuner.py +416 -0
  18. ocnn/nn/kernels/config.py +67 -0
  19. ocnn/nn/kernels/conv_bwd_implicit_gemm.py +229 -0
  20. ocnn/nn/kernels/conv_bwd_implicit_gemm_splitk.py +347 -0
  21. ocnn/nn/kernels/conv_fwd_implicit_gemm.py +109 -0
  22. ocnn/nn/kernels/conv_fwd_implicit_gemm_splitk.py +150 -0
  23. ocnn/nn/kernels/utils.py +44 -0
  24. ocnn/nn/octree2col.py +53 -53
  25. ocnn/nn/octree2vox.py +50 -50
  26. ocnn/nn/octree_align.py +46 -46
  27. ocnn/nn/octree_conv.py +430 -429
  28. ocnn/nn/octree_conv_t.py +148 -0
  29. ocnn/nn/octree_drop.py +55 -55
  30. ocnn/nn/octree_dwconv.py +222 -222
  31. ocnn/nn/octree_gconv.py +79 -79
  32. ocnn/nn/octree_interp.py +196 -196
  33. ocnn/nn/octree_norm.py +126 -126
  34. ocnn/nn/octree_pad.py +39 -39
  35. ocnn/nn/octree_pool.py +200 -200
  36. ocnn/octree/__init__.py +22 -22
  37. ocnn/octree/octree.py +770 -770
  38. ocnn/octree/points.py +384 -323
  39. ocnn/octree/shuffled_key.py +115 -115
  40. ocnn/utils.py +205 -205
  41. {ocnn-2.2.8.dist-info → ocnn-2.3.0.dist-info}/METADATA +117 -111
  42. ocnn-2.3.0.dist-info/RECORD +45 -0
  43. {ocnn-2.2.8.dist-info → ocnn-2.3.0.dist-info}/WHEEL +1 -1
  44. {ocnn-2.2.8.dist-info → ocnn-2.3.0.dist-info}/licenses/LICENSE +21 -21
  45. ocnn-2.2.8.dist-info/RECORD +0 -36
  46. {ocnn-2.2.8.dist-info → ocnn-2.3.0.dist-info}/top_level.txt +0 -0
ocnn/nn/octree_conv.py CHANGED
@@ -1,429 +1,430 @@
1
- # --------------------------------------------------------
2
- # Octree-based Sparse Convolutional Neural Networks
3
- # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
- # Licensed under The MIT License [see LICENSE for details]
5
- # Written by Peng-Shuai Wang
6
- # --------------------------------------------------------
7
-
8
- import torch
9
- import torch.nn
10
- from torch.autograd import Function
11
- from typing import List
12
-
13
- from ocnn.octree import Octree
14
- from ocnn.utils import scatter_add, xavier_uniform_, resize_with_last_val, list2str
15
- from .octree2col import octree2col, col2octree
16
- from .octree_pad import octree_pad, octree_depad
17
-
18
-
19
- class OctreeConvBase:
20
-
21
- def __init__(self, in_channels: int, out_channels: int,
22
- kernel_size: List[int] = [3], stride: int = 1,
23
- nempty: bool = False, max_buffer: int = int(2e8)):
24
- super().__init__()
25
- self.in_channels = in_channels
26
- self.out_channels = out_channels
27
- self.kernel_size = resize_with_last_val(kernel_size)
28
- self.kernel = list2str(self.kernel_size)
29
- self.stride = stride
30
- self.nempty = nempty
31
- self.max_buffer = max_buffer # about 200M
32
-
33
- self.kdim = self.kernel_size[0] * self.kernel_size[1] * self.kernel_size[2]
34
- self.in_conv = in_channels if self.is_conv_layer() else out_channels
35
- self.out_conv = out_channels if self.is_conv_layer() else in_channels
36
- self.weights_shape = (self.kdim, self.in_conv, self.out_conv)
37
-
38
- def is_conv_layer(self):
39
- r''' Returns :obj:`True` to indicate this is a convolution layer.
40
- '''
41
-
42
- raise NotImplementedError
43
-
44
- def setup(self, octree: Octree, depth: int):
45
- r''' Setup the shapes of each tensor.
46
- This function MUST be called before :obj:`forward_gemm`, :obj:`backward_gemm`
47
- and :obj:`weight_gemm`.
48
- '''
49
-
50
- # The depth of tensors:
51
- # The in_depth and out_depth are the octree depth of the input and output
52
- # data; neigh_depth is the octree depth of the neighborhood information, as
53
- # well as `col` data, neigh_depth is always the same as the depth of larger
54
- # data when doing octree2col or col2octree.
55
- self.in_depth = depth
56
- self.out_depth = depth
57
- self.neigh_depth = depth
58
- if self.stride == 2:
59
- if self.is_conv_layer():
60
- self.out_depth = depth - 1
61
- else:
62
- self.out_depth = depth + 1
63
- self.neigh_depth = depth + 1
64
-
65
- # The height of tensors
66
- if self.nempty:
67
- self.in_h = octree.nnum_nempty[self.in_depth]
68
- self.out_h = octree.nnum_nempty[self.out_depth]
69
- else:
70
- self.in_h = octree.nnum[self.in_depth]
71
- self.out_h = octree.nnum[self.out_depth]
72
- if self.stride == 2:
73
- if self.is_conv_layer():
74
- self.out_h = octree.nnum_nempty[self.out_depth]
75
- else:
76
- self.in_h = octree.nnum_nempty[self.in_depth]
77
- self.in_shape = (self.in_h, self.in_channels)
78
- self.out_shape = (self.out_h, self.out_channels)
79
-
80
- # The neighborhood indices
81
- self.neigh = octree.get_neigh(
82
- self.neigh_depth, self.kernel, self.stride, self.nempty)
83
-
84
- # The heigh and number of the temporary buffer
85
- self.buffer_n = 1
86
- self.buffer_h = self.neigh.shape[0]
87
- ideal_size = self.buffer_h * self.kdim * self.in_conv
88
- if ideal_size > self.max_buffer:
89
- kc = self.kdim * self.in_conv # make `max_buffer` be divided
90
- max_buffer = self.max_buffer // kc * kc # by `kc` with no remainder
91
- self.buffer_n = (ideal_size + max_buffer - 1) // max_buffer
92
- self.buffer_h = (self.buffer_h + self.buffer_n - 1) // self.buffer_n
93
- self.buffer_shape = (self.buffer_h, self.kdim, self.in_conv)
94
-
95
- def check_and_init(self, data: torch.Tensor):
96
- r''' Checks the input data and initializes the shape of output data.
97
- '''
98
-
99
- # Check the shape of input data
100
- check = tuple(data.shape) == self.in_shape
101
- assert check, 'The shape of input data is wrong.'
102
-
103
- # Init the output data
104
- out = data.new_zeros(self.out_shape)
105
- return out
106
-
107
- def forward_gemm(self, out: torch.Tensor, data: torch.Tensor,
108
- weights: torch.Tensor):
109
- r''' Peforms the forward pass of octree-based convolution.
110
- '''
111
-
112
- # Type check
113
- if data.dtype != out.dtype:
114
- data = data.to(out.dtype)
115
- if weights.dtype != out.dtype:
116
- weights = weights.to(out.dtype)
117
-
118
- # Initialize the buffer
119
- buffer = data.new_empty(self.buffer_shape)
120
-
121
- # Loop over each sub-matrix
122
- for i in range(self.buffer_n):
123
- start = i * self.buffer_h
124
- end = (i + 1) * self.buffer_h
125
-
126
- # The boundary case in the last iteration
127
- if end > self.neigh.shape[0]:
128
- dis = end - self.neigh.shape[0]
129
- end = self.neigh.shape[0]
130
- buffer, _ = buffer.split([self.buffer_h-dis, dis])
131
-
132
- # Perform octree2col
133
- neigh_i = self.neigh[start:end]
134
- valid = neigh_i >= 0
135
- buffer.fill_(0)
136
- buffer[valid] = data[neigh_i[valid]]
137
-
138
- # The sub-matrix gemm
139
- out[start:end] = torch.mm(buffer.flatten(1, 2), weights.flatten(0, 1))
140
-
141
- return out
142
-
143
- def backward_gemm(self, out: torch.Tensor, grad: torch.Tensor,
144
- weights: torch.Tensor):
145
- r''' Performs the backward pass of octree-based convolution.
146
- '''
147
-
148
- # Type check
149
- if grad.dtype != out.dtype:
150
- grad = grad.to(out.dtype)
151
- if weights.dtype != out.dtype:
152
- weights = weights.to(out.dtype)
153
-
154
- # Loop over each sub-matrix
155
- for i in range(self.buffer_n):
156
- start = i * self.buffer_h
157
- end = (i + 1) * self.buffer_h
158
-
159
- # The boundary case in the last iteration
160
- if end > self.neigh.shape[0]:
161
- end = self.neigh.shape[0]
162
-
163
- # The sub-matrix gemm
164
- buffer = torch.mm(grad[start:end], weights.flatten(0, 1).t())
165
- buffer = buffer.view(-1, self.buffer_shape[1], self.buffer_shape[2])
166
- buffer = buffer.to(out.dtype) # for pytorch.amp
167
-
168
- # Performs col2octree
169
- neigh_i = self.neigh[start:end]
170
- valid = neigh_i >= 0
171
- out = scatter_add(buffer[valid], neigh_i[valid], dim=0, out=out)
172
-
173
- return out
174
-
175
- def weight_gemm(
176
- self, out: torch.Tensor, data: torch.Tensor, grad: torch.Tensor):
177
- r''' Computes the gradient of the weight matrix.
178
- '''
179
-
180
- # Type check
181
- if data.dtype != out.dtype:
182
- data = data.to(out.dtype)
183
- if grad.dtype != out.dtype:
184
- grad = grad.to(out.dtype)
185
-
186
- # Record the shape of out
187
- out_shape = out.shape
188
- out = out.flatten(0, 1)
189
-
190
- # Initialize the buffer
191
- buffer = data.new_empty(self.buffer_shape)
192
-
193
- # Loop over each sub-matrix
194
- for i in range(self.buffer_n):
195
- start = i * self.buffer_h
196
- end = (i + 1) * self.buffer_h
197
-
198
- # The boundary case in the last iteration
199
- if end > self.neigh.shape[0]:
200
- d = end - self.neigh.shape[0]
201
- end = self.neigh.shape[0]
202
- buffer, _ = buffer.split([self.buffer_h-d, d])
203
-
204
- # Perform octree2col
205
- neigh_i = self.neigh[start:end]
206
- valid = neigh_i >= 0
207
- buffer.fill_(0)
208
- buffer[valid] = data[neigh_i[valid]]
209
-
210
- # Accumulate the gradient via gemm
211
- out.addmm_(buffer.flatten(1, 2).t(), grad[start:end])
212
-
213
- return out.view(out_shape)
214
-
215
-
216
- class _OctreeConv(OctreeConvBase):
217
- r''' Instantiates _OctreeConvBase by overriding `is_conv_layer`
218
- '''
219
-
220
- def is_conv_layer(self): return True
221
-
222
-
223
- class _OctreeDeconv(OctreeConvBase):
224
- r''' Instantiates _OctreeConvBase by overriding `is_conv_layer`
225
- '''
226
-
227
- def is_conv_layer(self): return False
228
-
229
-
230
- class OctreeConvFunction(Function):
231
- r''' Wrap the octree convolution for auto-diff.
232
- '''
233
-
234
- @staticmethod
235
- def forward(ctx, data: torch.Tensor, weights: torch.Tensor, octree: Octree,
236
- depth: int, in_channels: int, out_channels: int,
237
- kernel_size: List[int] = [3, 3, 3], stride: int = 1,
238
- nempty: bool = False, max_buffer: int = int(2e8)):
239
- octree_conv = _OctreeConv(
240
- in_channels, out_channels, kernel_size, stride, nempty, max_buffer)
241
- octree_conv.setup(octree, depth)
242
- out = octree_conv.check_and_init(data)
243
- out = octree_conv.forward_gemm(out, data, weights)
244
-
245
- ctx.save_for_backward(data, weights)
246
- ctx.octree_conv = octree_conv
247
- return out
248
-
249
- @staticmethod
250
- def backward(ctx, grad):
251
- data, weights = ctx.saved_tensors
252
- octree_conv = ctx.octree_conv
253
-
254
- grad_out = None
255
- if ctx.needs_input_grad[0]:
256
- grad_out = torch.zeros_like(data)
257
- grad_out = octree_conv.backward_gemm(grad_out, grad, weights)
258
-
259
- grad_w = None
260
- if ctx.needs_input_grad[1]:
261
- grad_w = torch.zeros_like(weights)
262
- grad_w = octree_conv.weight_gemm(grad_w, data, grad)
263
-
264
- return (grad_out, grad_w) + (None,) * 8
265
-
266
-
267
- class OctreeDeconvFunction(Function):
268
- r''' Wrap the octree deconvolution for auto-diff.
269
- '''
270
-
271
- @staticmethod
272
- def forward(ctx, data: torch.Tensor, weights: torch.Tensor, octree: Octree,
273
- depth: int, in_channels: int, out_channels: int,
274
- kernel_size: List[int] = [3, 3, 3], stride: int = 1,
275
- nempty: bool = False, max_buffer: int = int(2e8)):
276
- octree_deconv = _OctreeDeconv(
277
- in_channels, out_channels, kernel_size, stride, nempty, max_buffer)
278
- octree_deconv.setup(octree, depth)
279
- out = octree_deconv.check_and_init(data)
280
- out = octree_deconv.backward_gemm(out, data, weights)
281
-
282
- ctx.save_for_backward(data, weights)
283
- ctx.octree_deconv = octree_deconv
284
- return out
285
-
286
- @staticmethod
287
- def backward(ctx, grad):
288
- data, weights = ctx.saved_tensors
289
- octree_deconv = ctx.octree_deconv
290
-
291
- grad_out = None
292
- if ctx.needs_input_grad[0]:
293
- grad_out = torch.zeros_like(data)
294
- grad_out = octree_deconv.forward_gemm(grad_out, grad, weights)
295
-
296
- grad_w = None
297
- if ctx.needs_input_grad[1]:
298
- grad_w = torch.zeros_like(weights)
299
- grad_w = octree_deconv.weight_gemm(grad_w, grad, data)
300
-
301
- return (grad_out, grad_w) + (None,) * 8
302
-
303
-
304
- # alias
305
- octree_conv = OctreeConvFunction.apply
306
- octree_deconv = OctreeDeconvFunction.apply
307
-
308
-
309
- class OctreeConv(OctreeConvBase, torch.nn.Module):
310
- r''' Performs octree convolution.
311
-
312
- Args:
313
- in_channels (int): Number of input channels.
314
- out_channels (int): Number of output channels.
315
- kernel_size (List(int)): The kernel shape, choose from :obj:`[3]`, :obj:`[2]`,
316
- :obj:`[3,3,3]`, :obj:`[3,1,1]`, :obj:`[1,3,1]`, :obj:`[1,1,3]`,
317
- :obj:`[2,2,2]`, :obj:`[3,3,1]`, :obj:`[1,3,3]`, and :obj:`[3,1,3]`.
318
- stride (int): The stride of the convolution (:obj:`1` or :obj:`2`).
319
- nempty (bool): If True, only performs the convolution on non-empty
320
- octree nodes.
321
- direct_method (bool): If True, directly performs the convolution via using
322
- gemm and octree2col/col2octree. The octree2col/col2octree needs to
323
- construct a large matrix, which may consume a lot of memory. If False,
324
- performs the convolution in a sub-matrix manner, which can save the
325
- requied runtime memory.
326
- use_bias (bool): If True, add a bias term to the convolution.
327
- max_buffer (int): The maximum number of elements in the buffer, used when
328
- :attr:`direct_method` is False.
329
-
330
- .. note::
331
- Each non-empty octree node has exactly 8 children nodes, among which some
332
- children nodes are non-empty and some are empty. If :attr:`nempty` is true,
333
- the convolution is performed on non-empty octree nodes only, which is exactly
334
- the same as SparseConvNet and MinkowsiNet; if :attr:`nempty` is false, the
335
- convolution is performed on all octree nodes, which is essential for shape
336
- reconstruction tasks and can also be used in classification and segmentation
337
- (with slightly better performance and larger memory cost).
338
- '''
339
-
340
- def __init__(self, in_channels: int, out_channels: int,
341
- kernel_size: List[int] = [3], stride: int = 1,
342
- nempty: bool = False, direct_method: bool = False,
343
- use_bias: bool = False, max_buffer: int = int(2e8)):
344
- super().__init__(
345
- in_channels, out_channels, kernel_size, stride, nempty, max_buffer)
346
-
347
- self.direct_method = direct_method
348
- self.use_bias = use_bias
349
- self.weights = torch.nn.Parameter(torch.Tensor(*self.weights_shape))
350
- if self.use_bias:
351
- self.bias = torch.nn.Parameter(torch.Tensor(out_channels))
352
- self.reset_parameters()
353
-
354
- def reset_parameters(self):
355
- xavier_uniform_(self.weights)
356
- if self.use_bias:
357
- torch.nn.init.zeros_(self.bias)
358
-
359
- def is_conv_layer(self): return True
360
-
361
- def forward(self, data: torch.Tensor, octree: Octree, depth: int):
362
- r''' Defines the octree convolution.
363
-
364
- Args:
365
- data (torch.Tensor): The input data.
366
- octree (Octree): The corresponding octree.
367
- depth (int): The depth of current octree.
368
- '''
369
-
370
- if self.direct_method:
371
- col = octree2col(
372
- data, octree, depth, self.kernel, self.stride, self.nempty)
373
- out = torch.mm(col.flatten(1), self.weights.flatten(0, 1))
374
- else:
375
- out = octree_conv(
376
- data, self.weights, octree, depth, self.in_channels,
377
- self.out_channels, self.kernel_size, self.stride, self.nempty,
378
- self.max_buffer)
379
-
380
- if self.use_bias:
381
- out += self.bias
382
-
383
- if self.stride == 2 and not self.nempty:
384
- out = octree_pad(out, octree, depth-1)
385
- return out
386
-
387
- def extra_repr(self) -> str:
388
- r''' Sets the extra representation of the module.
389
- '''
390
-
391
- return ('in_channels={}, out_channels={}, kernel_size={}, stride={}, '
392
- 'nempty={}, bias={}').format(self.in_channels, self.out_channels,
393
- self.kernel_size, self.stride, self.nempty, self.use_bias) # noqa
394
-
395
-
396
- class OctreeDeconv(OctreeConv):
397
- r''' Performs octree deconvolution.
398
-
399
- Please refer to :class:`OctreeConv` for the meaning of the arguments.
400
- '''
401
-
402
- def is_conv_layer(self): return False
403
-
404
- def forward(self, data: torch.Tensor, octree: Octree, depth: int):
405
- r''' Defines the octree deconvolution.
406
-
407
- Please refer to :meth:`OctreeConv.forward` for the meaning of the arguments.
408
- '''
409
-
410
- depth_col = depth
411
- if self.stride == 2:
412
- depth_col = depth + 1
413
- if not self.nempty:
414
- data = octree_depad(data, octree, depth)
415
-
416
- if self.direct_method:
417
- col = torch.mm(data, self.weights.flatten(0, 1).t())
418
- col = col.view(col.shape[0], self.kdim, -1)
419
- out = col2octree(
420
- col, octree, depth_col, self.kernel, self.stride, self.nempty)
421
- else:
422
- out = octree_deconv(
423
- data, self.weights, octree, depth, self.in_channels,
424
- self.out_channels, self.kernel_size, self.stride, self.nempty,
425
- self.max_buffer)
426
-
427
- if self.use_bias:
428
- out += self.bias
429
- return out
1
+ # --------------------------------------------------------
2
+ # Octree-based Sparse Convolutional Neural Networks
3
+ # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Peng-Shuai Wang
6
+ # --------------------------------------------------------
7
+
8
+ import torch
9
+ import torch.nn
10
+ from torch.autograd import Function
11
+ from typing import List
12
+
13
+ from ocnn.octree import Octree
14
+ from ocnn.utils import scatter_add, xavier_uniform_, resize_with_last_val, list2str
15
+ from .octree2col import octree2col, col2octree
16
+ from .octree_pad import octree_pad, octree_depad
17
+
18
+
19
+ class OctreeConvBase:
20
+
21
+ def __init__(self, in_channels: int, out_channels: int,
22
+ kernel_size: List[int] = [3], stride: int = 1,
23
+ nempty: bool = False, max_buffer: int = int(2e8)):
24
+ super().__init__()
25
+ self.in_channels = in_channels
26
+ self.out_channels = out_channels
27
+ self.kernel_size = resize_with_last_val(kernel_size)
28
+ self.kernel = list2str(self.kernel_size)
29
+ self.stride = stride
30
+ self.nempty = nempty
31
+ self.max_buffer = max_buffer # about 200M
32
+
33
+ self.kdim = self.kernel_size[0] * self.kernel_size[1] * self.kernel_size[2]
34
+ self.in_conv = in_channels if self.is_conv_layer() else out_channels
35
+ self.out_conv = out_channels if self.is_conv_layer() else in_channels
36
+ self.weights_shape = (self.kdim, self.in_conv, self.out_conv)
37
+
38
+ def is_conv_layer(self):
39
+ r''' Returns :obj:`True` to indicate this is a convolution layer.
40
+ '''
41
+
42
+ raise NotImplementedError
43
+
44
+ def setup(self, octree: Octree, depth: int):
45
+ r''' Setup the shapes of each tensor.
46
+ This function MUST be called before :obj:`forward_gemm`, :obj:`backward_gemm`
47
+ and :obj:`weight_gemm`.
48
+ '''
49
+
50
+ # The depth of tensors:
51
+ # The in_depth and out_depth are the octree depth of the input and output
52
+ # data; neigh_depth is the octree depth of the neighborhood information, as
53
+ # well as `col` data, neigh_depth is always the same as the depth of larger
54
+ # data when doing octree2col or col2octree.
55
+ self.in_depth = depth
56
+ self.out_depth = depth
57
+ self.neigh_depth = depth
58
+ if self.stride == 2:
59
+ if self.is_conv_layer():
60
+ self.out_depth = depth - 1
61
+ else:
62
+ self.out_depth = depth + 1
63
+ self.neigh_depth = depth + 1
64
+
65
+ # The height of tensors
66
+ if self.nempty:
67
+ self.in_h = octree.nnum_nempty[self.in_depth]
68
+ self.out_h = octree.nnum_nempty[self.out_depth]
69
+ else:
70
+ self.in_h = octree.nnum[self.in_depth]
71
+ self.out_h = octree.nnum[self.out_depth]
72
+ if self.stride == 2:
73
+ if self.is_conv_layer():
74
+ self.out_h = octree.nnum_nempty[self.out_depth]
75
+ else:
76
+ self.in_h = octree.nnum_nempty[self.in_depth]
77
+ self.in_shape = (self.in_h, self.in_channels)
78
+ self.out_shape = (self.out_h, self.out_channels)
79
+
80
+ # The neighborhood indices
81
+ self.neigh = octree.get_neigh(
82
+ self.neigh_depth, self.kernel, self.stride, self.nempty)
83
+
84
+ # The height and number of the temporary buffer
85
+ self.buffer_n = 1
86
+ self.buffer_h = self.neigh.shape[0]
87
+ ideal_size = self.buffer_h * self.kdim * self.in_conv
88
+ if ideal_size > self.max_buffer:
89
+ kc = self.kdim * self.in_conv # make `max_buffer` be divided
90
+ max_buffer = self.max_buffer // kc * kc # by `kc` with no remainder
91
+ self.buffer_n = (ideal_size + max_buffer - 1) // max_buffer
92
+ self.buffer_h = (self.buffer_h + self.buffer_n - 1) // self.buffer_n
93
+ self.buffer_shape = (self.buffer_h, self.kdim, self.in_conv)
94
+
95
+ def check_and_init(self, data: torch.Tensor):
96
+ r''' Checks the input data and initializes the shape of output data.
97
+ '''
98
+
99
+ # Check the shape of input data
100
+ check = tuple(data.shape) == self.in_shape
101
+ assert check, ('The shape of input data is wrong: ' +
102
+ 'expected {}, got {}.'.format(self.in_shape, data.shape))
103
+
104
+ # Init the output data
105
+ out = data.new_zeros(self.out_shape)
106
+ return out
107
+
108
+ def forward_gemm(self, out: torch.Tensor, data: torch.Tensor,
109
+ weights: torch.Tensor):
110
+ r''' Performs the forward pass of octree-based convolution.
111
+ '''
112
+
113
+ # Type check
114
+ if data.dtype != out.dtype:
115
+ data = data.to(out.dtype)
116
+ if weights.dtype != out.dtype:
117
+ weights = weights.to(out.dtype)
118
+
119
+ # Initialize the buffer
120
+ buffer = data.new_empty(self.buffer_shape)
121
+
122
+ # Loop over each sub-matrix
123
+ for i in range(self.buffer_n):
124
+ start = i * self.buffer_h
125
+ end = (i + 1) * self.buffer_h
126
+
127
+ # The boundary case in the last iteration
128
+ if end > self.neigh.shape[0]:
129
+ dis = end - self.neigh.shape[0]
130
+ end = self.neigh.shape[0]
131
+ buffer, _ = buffer.split([self.buffer_h-dis, dis])
132
+
133
+ # Perform octree2col
134
+ neigh_i = self.neigh[start:end]
135
+ valid = neigh_i >= 0
136
+ buffer.fill_(0)
137
+ buffer[valid] = data[neigh_i[valid]]
138
+
139
+ # The sub-matrix gemm
140
+ out[start:end] = torch.mm(buffer.flatten(1, 2), weights.flatten(0, 1))
141
+
142
+ return out
143
+
144
+ def backward_gemm(self, out: torch.Tensor, grad: torch.Tensor,
145
+ weights: torch.Tensor):
146
+ r''' Performs the backward pass of octree-based convolution.
147
+ '''
148
+
149
+ # Type check
150
+ if grad.dtype != out.dtype:
151
+ grad = grad.to(out.dtype)
152
+ if weights.dtype != out.dtype:
153
+ weights = weights.to(out.dtype)
154
+
155
+ # Loop over each sub-matrix
156
+ for i in range(self.buffer_n):
157
+ start = i * self.buffer_h
158
+ end = (i + 1) * self.buffer_h
159
+
160
+ # The boundary case in the last iteration
161
+ if end > self.neigh.shape[0]:
162
+ end = self.neigh.shape[0]
163
+
164
+ # The sub-matrix gemm
165
+ buffer = torch.mm(grad[start:end], weights.flatten(0, 1).t())
166
+ buffer = buffer.view(-1, self.buffer_shape[1], self.buffer_shape[2])
167
+ buffer = buffer.to(out.dtype) # for pytorch.amp
168
+
169
+ # Performs col2octree
170
+ neigh_i = self.neigh[start:end]
171
+ valid = neigh_i >= 0
172
+ out = scatter_add(buffer[valid], neigh_i[valid], dim=0, out=out)
173
+
174
+ return out
175
+
176
+ def weight_gemm(
177
+ self, out: torch.Tensor, data: torch.Tensor, grad: torch.Tensor):
178
+ r''' Computes the gradient of the weight matrix.
179
+ '''
180
+
181
+ # Type check
182
+ if data.dtype != out.dtype:
183
+ data = data.to(out.dtype)
184
+ if grad.dtype != out.dtype:
185
+ grad = grad.to(out.dtype)
186
+
187
+ # Record the shape of out
188
+ out_shape = out.shape
189
+ out = out.flatten(0, 1)
190
+
191
+ # Initialize the buffer
192
+ buffer = data.new_empty(self.buffer_shape)
193
+
194
+ # Loop over each sub-matrix
195
+ for i in range(self.buffer_n):
196
+ start = i * self.buffer_h
197
+ end = (i + 1) * self.buffer_h
198
+
199
+ # The boundary case in the last iteration
200
+ if end > self.neigh.shape[0]:
201
+ d = end - self.neigh.shape[0]
202
+ end = self.neigh.shape[0]
203
+ buffer, _ = buffer.split([self.buffer_h-d, d])
204
+
205
+ # Perform octree2col
206
+ neigh_i = self.neigh[start:end]
207
+ valid = neigh_i >= 0
208
+ buffer.fill_(0)
209
+ buffer[valid] = data[neigh_i[valid]]
210
+
211
+ # Accumulate the gradient via gemm
212
+ out.addmm_(buffer.flatten(1, 2).t(), grad[start:end])
213
+
214
+ return out.view(out_shape)
215
+
216
+
217
+ class _OctreeConv(OctreeConvBase):
218
+ r''' Instantiates _OctreeConvBase by overriding `is_conv_layer`
219
+ '''
220
+
221
+ def is_conv_layer(self): return True
222
+
223
+
224
+ class _OctreeDeconv(OctreeConvBase):
225
+ r''' Instantiates _OctreeConvBase by overriding `is_conv_layer`
226
+ '''
227
+
228
+ def is_conv_layer(self): return False
229
+
230
+
231
+ class OctreeConvFunction(Function):
232
+ r''' Wrap the octree convolution for auto-diff.
233
+ '''
234
+
235
+ @staticmethod
236
+ def forward(ctx, data: torch.Tensor, weights: torch.Tensor, octree: Octree,
237
+ depth: int, in_channels: int, out_channels: int,
238
+ kernel_size: List[int] = [3, 3, 3], stride: int = 1,
239
+ nempty: bool = False, max_buffer: int = int(2e8)):
240
+ octree_conv = _OctreeConv(
241
+ in_channels, out_channels, kernel_size, stride, nempty, max_buffer)
242
+ octree_conv.setup(octree, depth)
243
+ out = octree_conv.check_and_init(data)
244
+ out = octree_conv.forward_gemm(out, data, weights)
245
+
246
+ ctx.save_for_backward(data, weights)
247
+ ctx.octree_conv = octree_conv
248
+ return out
249
+
250
+ @staticmethod
251
+ def backward(ctx, grad):
252
+ data, weights = ctx.saved_tensors
253
+ octree_conv = ctx.octree_conv
254
+
255
+ grad_out = None
256
+ if ctx.needs_input_grad[0]:
257
+ grad_out = torch.zeros_like(data)
258
+ grad_out = octree_conv.backward_gemm(grad_out, grad, weights)
259
+
260
+ grad_w = None
261
+ if ctx.needs_input_grad[1]:
262
+ grad_w = torch.zeros_like(weights)
263
+ grad_w = octree_conv.weight_gemm(grad_w, data, grad)
264
+
265
+ return (grad_out, grad_w) + (None,) * 8
266
+
267
+
268
+ class OctreeDeconvFunction(Function):
269
+ r''' Wrap the octree deconvolution for auto-diff.
270
+ '''
271
+
272
+ @staticmethod
273
+ def forward(ctx, data: torch.Tensor, weights: torch.Tensor, octree: Octree,
274
+ depth: int, in_channels: int, out_channels: int,
275
+ kernel_size: List[int] = [3, 3, 3], stride: int = 1,
276
+ nempty: bool = False, max_buffer: int = int(2e8)):
277
+ octree_deconv = _OctreeDeconv(
278
+ in_channels, out_channels, kernel_size, stride, nempty, max_buffer)
279
+ octree_deconv.setup(octree, depth)
280
+ out = octree_deconv.check_and_init(data)
281
+ out = octree_deconv.backward_gemm(out, data, weights)
282
+
283
+ ctx.save_for_backward(data, weights)
284
+ ctx.octree_deconv = octree_deconv
285
+ return out
286
+
287
+ @staticmethod
288
+ def backward(ctx, grad):
289
+ data, weights = ctx.saved_tensors
290
+ octree_deconv = ctx.octree_deconv
291
+
292
+ grad_out = None
293
+ if ctx.needs_input_grad[0]:
294
+ grad_out = torch.zeros_like(data)
295
+ grad_out = octree_deconv.forward_gemm(grad_out, grad, weights)
296
+
297
+ grad_w = None
298
+ if ctx.needs_input_grad[1]:
299
+ grad_w = torch.zeros_like(weights)
300
+ grad_w = octree_deconv.weight_gemm(grad_w, grad, data)
301
+
302
+ return (grad_out, grad_w) + (None,) * 8
303
+
304
+
305
+ # alias
306
+ octree_conv = OctreeConvFunction.apply
307
+ octree_deconv = OctreeDeconvFunction.apply
308
+
309
+
310
+ class OctreeConv(OctreeConvBase, torch.nn.Module):
311
+ r''' Performs octree convolution.
312
+
313
+ Args:
314
+ in_channels (int): Number of input channels.
315
+ out_channels (int): Number of output channels.
316
+ kernel_size (List(int)): The kernel shape, choose from :obj:`[3]`, :obj:`[2]`,
317
+ :obj:`[3,3,3]`, :obj:`[3,1,1]`, :obj:`[1,3,1]`, :obj:`[1,1,3]`,
318
+ :obj:`[2,2,2]`, :obj:`[3,3,1]`, :obj:`[1,3,3]`, and :obj:`[3,1,3]`.
319
+ stride (int): The stride of the convolution (:obj:`1` or :obj:`2`).
320
+ nempty (bool): If True, only performs the convolution on non-empty
321
+ octree nodes.
322
+ direct_method (bool): If True, directly performs the convolution via using
323
+ gemm and octree2col/col2octree. The octree2col/col2octree needs to
324
+ construct a large matrix, which may consume a lot of memory. If False,
325
+ performs the convolution in a sub-matrix manner, which can save the
326
+ required runtime memory.
327
+ use_bias (bool): If True, add a bias term to the convolution.
328
+ max_buffer (int): The maximum number of elements in the buffer, used when
329
+ :attr:`direct_method` is False.
330
+
331
+ .. note::
332
+ Each non-empty octree node has exactly 8 children nodes, among which some
333
+ children nodes are non-empty and some are empty. If :attr:`nempty` is true,
334
+ the convolution is performed on non-empty octree nodes only, which is exactly
335
+ the same as SparseConvNet and MinkowskiNet; if :attr:`nempty` is false, the
336
+ convolution is performed on all octree nodes, which is essential for shape
337
+ reconstruction tasks and can also be used in classification and segmentation
338
+ (with slightly better performance and larger memory cost).
339
+ '''
340
+
341
+ def __init__(self, in_channels: int, out_channels: int,
342
+ kernel_size: List[int] = [3], stride: int = 1,
343
+ nempty: bool = False, direct_method: bool = False,
344
+ use_bias: bool = False, max_buffer: int = int(2e8)):
345
+ super().__init__(
346
+ in_channels, out_channels, kernel_size, stride, nempty, max_buffer)
347
+
348
+ self.direct_method = direct_method
349
+ self.use_bias = use_bias
350
+ self.weights = torch.nn.Parameter(torch.Tensor(*self.weights_shape))
351
+ if self.use_bias:
352
+ self.bias = torch.nn.Parameter(torch.Tensor(out_channels))
353
+ self.reset_parameters()
354
+
355
+ def reset_parameters(self):
356
+ xavier_uniform_(self.weights)
357
+ if self.use_bias:
358
+ torch.nn.init.zeros_(self.bias)
359
+
360
+ def is_conv_layer(self): return True
361
+
362
+ def forward(self, data: torch.Tensor, octree: Octree, depth: int):
363
+ r''' Defines the octree convolution.
364
+
365
+ Args:
366
+ data (torch.Tensor): The input data.
367
+ octree (Octree): The corresponding octree.
368
+ depth (int): The depth of current octree.
369
+ '''
370
+
371
+ if self.direct_method:
372
+ col = octree2col(
373
+ data, octree, depth, self.kernel, self.stride, self.nempty)
374
+ out = torch.mm(col.flatten(1), self.weights.flatten(0, 1))
375
+ else:
376
+ out = octree_conv(
377
+ data, self.weights, octree, depth, self.in_channels,
378
+ self.out_channels, self.kernel_size, self.stride, self.nempty,
379
+ self.max_buffer)
380
+
381
+ if self.use_bias:
382
+ out += self.bias
383
+
384
+ if self.stride == 2 and not self.nempty:
385
+ out = octree_pad(out, octree, depth-1)
386
+ return out
387
+
388
+ def extra_repr(self) -> str:
389
+ r''' Sets the extra representation of the module.
390
+ '''
391
+
392
+ return ('in_channels={}, out_channels={}, kernel_size={}, stride={}, '
393
+ 'nempty={}, bias={}').format(self.in_channels, self.out_channels,
394
+ self.kernel_size, self.stride, self.nempty, self.use_bias) # noqa
395
+
396
+
397
+ class OctreeDeconv(OctreeConv):
398
+ r''' Performs octree deconvolution.
399
+
400
+ Please refer to :class:`OctreeConv` for the meaning of the arguments.
401
+ '''
402
+
403
+ def is_conv_layer(self): return False
404
+
405
+ def forward(self, data: torch.Tensor, octree: Octree, depth: int):
406
+ r''' Defines the octree deconvolution.
407
+
408
+ Please refer to :meth:`OctreeConv.forward` for the meaning of the arguments.
409
+ '''
410
+
411
+ depth_col = depth
412
+ if self.stride == 2:
413
+ depth_col = depth + 1
414
+ if not self.nempty:
415
+ data = octree_depad(data, octree, depth)
416
+
417
+ if self.direct_method:
418
+ col = torch.mm(data, self.weights.flatten(0, 1).t())
419
+ col = col.view(col.shape[0], self.kdim, -1)
420
+ out = col2octree(
421
+ col, octree, depth_col, self.kernel, self.stride, self.nempty)
422
+ else:
423
+ out = octree_deconv(
424
+ data, self.weights, octree, depth, self.in_channels,
425
+ self.out_channels, self.kernel_size, self.stride, self.nempty,
426
+ self.max_buffer)
427
+
428
+ if self.use_bias:
429
+ out += self.bias
430
+ return out