ocnn 2.2.8__py3-none-any.whl → 2.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. ocnn/__init__.py +24 -24
  2. ocnn/dataset.py +160 -160
  3. ocnn/models/__init__.py +29 -29
  4. ocnn/models/autoencoder.py +155 -155
  5. ocnn/models/hrnet.py +192 -192
  6. ocnn/models/image2shape.py +128 -128
  7. ocnn/models/lenet.py +46 -46
  8. ocnn/models/ounet.py +94 -94
  9. ocnn/models/resnet.py +53 -53
  10. ocnn/models/segnet.py +72 -72
  11. ocnn/models/unet.py +105 -105
  12. ocnn/modules/__init__.py +26 -26
  13. ocnn/modules/modules.py +303 -303
  14. ocnn/modules/resblocks.py +158 -158
  15. ocnn/nn/__init__.py +45 -44
  16. ocnn/nn/kernels/__init__.py +14 -0
  17. ocnn/nn/kernels/autotuner.py +416 -0
  18. ocnn/nn/kernels/config.py +67 -0
  19. ocnn/nn/kernels/conv_bwd_implicit_gemm.py +229 -0
  20. ocnn/nn/kernels/conv_bwd_implicit_gemm_splitk.py +347 -0
  21. ocnn/nn/kernels/conv_fwd_implicit_gemm.py +109 -0
  22. ocnn/nn/kernels/conv_fwd_implicit_gemm_splitk.py +150 -0
  23. ocnn/nn/kernels/utils.py +44 -0
  24. ocnn/nn/octree2col.py +53 -53
  25. ocnn/nn/octree2vox.py +50 -50
  26. ocnn/nn/octree_align.py +46 -46
  27. ocnn/nn/octree_conv.py +430 -429
  28. ocnn/nn/octree_conv_t.py +148 -0
  29. ocnn/nn/octree_drop.py +55 -55
  30. ocnn/nn/octree_dwconv.py +222 -222
  31. ocnn/nn/octree_gconv.py +79 -79
  32. ocnn/nn/octree_interp.py +196 -196
  33. ocnn/nn/octree_norm.py +126 -126
  34. ocnn/nn/octree_pad.py +39 -39
  35. ocnn/nn/octree_pool.py +200 -200
  36. ocnn/octree/__init__.py +22 -22
  37. ocnn/octree/octree.py +770 -770
  38. ocnn/octree/points.py +384 -323
  39. ocnn/octree/shuffled_key.py +115 -115
  40. ocnn/utils.py +205 -205
  41. {ocnn-2.2.8.dist-info → ocnn-2.3.0.dist-info}/METADATA +117 -111
  42. ocnn-2.3.0.dist-info/RECORD +45 -0
  43. {ocnn-2.2.8.dist-info → ocnn-2.3.0.dist-info}/WHEEL +1 -1
  44. {ocnn-2.2.8.dist-info → ocnn-2.3.0.dist-info}/licenses/LICENSE +21 -21
  45. ocnn-2.2.8.dist-info/RECORD +0 -36
  46. {ocnn-2.2.8.dist-info → ocnn-2.3.0.dist-info}/top_level.txt +0 -0
ocnn/nn/octree_dwconv.py CHANGED
@@ -1,222 +1,222 @@
1
- # --------------------------------------------------------
2
- # Octree-based Sparse Convolutional Neural Networks
3
- # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
- # Licensed under The MIT License [see LICENSE for details]
5
- # Written by Peng-Shuai Wang
6
- # --------------------------------------------------------
7
-
8
- import torch
9
- import torch.nn
10
- from torch.autograd import Function
11
- from typing import List
12
-
13
- from ocnn.octree import Octree
14
- from ocnn.utils import scatter_add, xavier_uniform_
15
- from .octree_pad import octree_pad
16
- from .octree_conv import OctreeConvBase
17
-
18
-
19
- class OctreeDWConvBase(OctreeConvBase):
20
-
21
- def __init__(self, in_channels: int, kernel_size: List[int] = [3],
22
- stride: int = 1, nempty: bool = False,
23
- max_buffer: int = int(2e8)):
24
- super().__init__(
25
- in_channels, in_channels, kernel_size, stride, nempty, max_buffer)
26
- self.weights_shape = (self.kdim, 1, self.out_channels)
27
-
28
- def is_conv_layer(self): return True
29
-
30
- def forward_gemm(self, out: torch.Tensor, data: torch.Tensor,
31
- weights: torch.Tensor):
32
- r''' Peforms the forward pass of octree-based convolution.
33
- '''
34
-
35
- # Type check
36
- if data.dtype != out.dtype:
37
- data = data.to(out.dtype)
38
- if weights.dtype != out.dtype:
39
- weights = weights.to(out.dtype)
40
-
41
- # Initialize the buffer
42
- buffer = data.new_empty(self.buffer_shape)
43
-
44
- # Loop over each sub-matrix
45
- for i in range(self.buffer_n):
46
- start = i * self.buffer_h
47
- end = (i + 1) * self.buffer_h
48
-
49
- # The boundary case in the last iteration
50
- if end > self.neigh.shape[0]:
51
- dis = end - self.neigh.shape[0]
52
- end = self.neigh.shape[0]
53
- buffer, _ = buffer.split([self.buffer_h-dis, dis])
54
-
55
- # Perform octree2col
56
- neigh_i = self.neigh[start:end]
57
- valid = neigh_i >= 0
58
- buffer.fill_(0)
59
- buffer[valid] = data[neigh_i[valid]]
60
-
61
- # The sub-matrix gemm
62
- # out[start:end] = torch.mm(buffer.flatten(1, 2), weights.flatten(0, 1))
63
- out[start:end] = torch.einsum('ikc,kc->ic', buffer, weights.flatten(0, 1))
64
- return out
65
-
66
- def backward_gemm(self, out: torch.Tensor, grad: torch.Tensor,
67
- weights: torch.Tensor):
68
- r''' Performs the backward pass of octree-based convolution.
69
- '''
70
-
71
- # Type check
72
- if grad.dtype != out.dtype:
73
- grad = grad.to(out.dtype)
74
- if weights.dtype != out.dtype:
75
- weights = weights.to(out.dtype)
76
-
77
- # Loop over each sub-matrix
78
- for i in range(self.buffer_n):
79
- start = i * self.buffer_h
80
- end = (i + 1) * self.buffer_h
81
-
82
- # The boundary case in the last iteration
83
- if end > self.neigh.shape[0]:
84
- end = self.neigh.shape[0]
85
-
86
- # The sub-matrix gemm
87
- # buffer = torch.mm(grad[start:end], weights.flatten(0, 1).t())
88
- # buffer = buffer.view(-1, self.buffer_shape[1], self.buffer_shape[2])
89
- buffer = torch.einsum(
90
- 'ic,kc->ikc', grad[start:end], weights.flatten(0, 1))
91
-
92
- # Performs col2octree
93
- neigh_i = self.neigh[start:end]
94
- valid = neigh_i >= 0
95
- out = scatter_add(buffer[valid], neigh_i[valid], dim=0, out=out)
96
-
97
- return out
98
-
99
- def weight_gemm(self, out: torch.Tensor, data: torch.Tensor, grad: torch.Tensor):
100
- r''' Computes the gradient of the weight matrix.
101
- '''
102
-
103
- # Type check
104
- if data.dtype != out.dtype:
105
- data = data.to(out.dtype)
106
- if grad.dtype != out.dtype:
107
- grad = grad.to(out.dtype)
108
-
109
- # Record the shape of out
110
- out_shape = out.shape
111
- out = out.flatten(0, 1)
112
-
113
- # Initialize the buffer
114
- buffer = data.new_empty(self.buffer_shape)
115
-
116
- # Loop over each sub-matrix
117
- for i in range(self.buffer_n):
118
- start = i * self.buffer_h
119
- end = (i + 1) * self.buffer_h
120
-
121
- # The boundary case in the last iteration
122
- if end > self.neigh.shape[0]:
123
- d = end - self.neigh.shape[0]
124
- end = self.neigh.shape[0]
125
- buffer, _ = buffer.split([self.buffer_h-d, d])
126
-
127
- # Perform octree2col
128
- neigh_i = self.neigh[start:end]
129
- valid = neigh_i >= 0
130
- buffer.fill_(0)
131
- buffer[valid] = data[neigh_i[valid]]
132
-
133
- # Accumulate the gradient via gemm
134
- # out.addmm_(buffer.flatten(1, 2).t(), grad[start:end])
135
- out += torch.einsum('ikc,ic->kc', buffer, grad[start:end])
136
- return out.view(out_shape)
137
-
138
-
139
- class OctreeDWConvFunction(Function):
140
- r''' Wrap the octree convolution for auto-diff.
141
- '''
142
-
143
- @staticmethod
144
- def forward(ctx, data: torch.Tensor, weights: torch.Tensor, octree: Octree,
145
- depth: int, in_channels: int, kernel_size: List[int] = [3, 3, 3],
146
- stride: int = 1, nempty: bool = False, max_buffer: int = int(2e8)):
147
- octree_conv = OctreeDWConvBase(
148
- in_channels, kernel_size, stride, nempty, max_buffer)
149
- octree_conv.setup(octree, depth)
150
- out = octree_conv.check_and_init(data)
151
- out = octree_conv.forward_gemm(out, data, weights)
152
-
153
- ctx.save_for_backward(data, weights)
154
- ctx.octree_conv = octree_conv
155
- return out
156
-
157
- @staticmethod
158
- def backward(ctx, grad):
159
- data, weights = ctx.saved_tensors
160
- octree_conv = ctx.octree_conv
161
-
162
- grad_out = None
163
- if ctx.needs_input_grad[0]:
164
- grad_out = torch.zeros_like(data)
165
- grad_out = octree_conv.backward_gemm(grad_out, grad, weights)
166
-
167
- grad_w = None
168
- if ctx.needs_input_grad[1]:
169
- grad_w = torch.zeros_like(weights)
170
- grad_w = octree_conv.weight_gemm(grad_w, data, grad)
171
-
172
- return (grad_out, grad_w) + (None,) * 7
173
-
174
-
175
- # alias
176
- octree_dwconv = OctreeDWConvFunction.apply
177
-
178
-
179
- class OctreeDWConv(OctreeDWConvBase, torch.nn.Module):
180
- r''' Performs octree-based depth-wise convolution.
181
-
182
- Please refer to :class:`ocnn.nn.OctreeConv` for the meaning of the arguments.
183
-
184
- .. note::
185
- This implementation uses the :func:`torch.einsum` and I find that the speed
186
- is relatively slow. Further optimization is needed to speed it up.
187
- '''
188
-
189
- def __init__(self, in_channels: int, kernel_size: List[int] = [3],
190
- stride: int = 1, nempty: bool = False, use_bias: bool = False,
191
- max_buffer: int = int(2e8)):
192
- super().__init__(in_channels, kernel_size, stride, nempty, max_buffer)
193
-
194
- self.use_bias = use_bias
195
- self.weights = torch.nn.Parameter(torch.Tensor(*self.weights_shape))
196
- if self.use_bias:
197
- self.bias = torch.nn.Parameter(torch.Tensor(in_channels))
198
- self.reset_parameters()
199
-
200
- def reset_parameters(self):
201
- xavier_uniform_(self.weights)
202
- if self.use_bias:
203
- torch.nn.init.zeros_(self.bias)
204
-
205
- def forward(self, data: torch.Tensor, octree: Octree, depth: int):
206
- r''''''
207
-
208
- out = octree_dwconv(
209
- data, self.weights, octree, depth, self.in_channels,
210
- self.kernel_size, self.stride, self.nempty, self.max_buffer)
211
-
212
- if self.use_bias:
213
- out += self.bias
214
-
215
- if self.stride == 2 and not self.nempty:
216
- out = octree_pad(out, octree, depth-1)
217
- return out
218
-
219
- def extra_repr(self) -> str:
220
- return ('in_channels={}, out_channels={}, kernel_size={}, stride={}, '
221
- 'nempty={}, bias={}').format(self.in_channels, self.out_channels,
222
- self.kernel_size, self.stride, self.nempty, self.use_bias) # noqa
1
+ # --------------------------------------------------------
2
+ # Octree-based Sparse Convolutional Neural Networks
3
+ # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Peng-Shuai Wang
6
+ # --------------------------------------------------------
7
+
8
+ import torch
9
+ import torch.nn
10
+ from torch.autograd import Function
11
+ from typing import List
12
+
13
+ from ocnn.octree import Octree
14
+ from ocnn.utils import scatter_add, xavier_uniform_
15
+ from .octree_pad import octree_pad
16
+ from .octree_conv import OctreeConvBase
17
+
18
+
19
+ class OctreeDWConvBase(OctreeConvBase):
20
+
21
+ def __init__(self, in_channels: int, kernel_size: List[int] = [3],
22
+ stride: int = 1, nempty: bool = False,
23
+ max_buffer: int = int(2e8)):
24
+ super().__init__(
25
+ in_channels, in_channels, kernel_size, stride, nempty, max_buffer)
26
+ self.weights_shape = (self.kdim, 1, self.out_channels)
27
+
28
+ def is_conv_layer(self): return True
29
+
30
+ def forward_gemm(self, out: torch.Tensor, data: torch.Tensor,
31
+ weights: torch.Tensor):
32
+ r''' Peforms the forward pass of octree-based convolution.
33
+ '''
34
+
35
+ # Type check
36
+ if data.dtype != out.dtype:
37
+ data = data.to(out.dtype)
38
+ if weights.dtype != out.dtype:
39
+ weights = weights.to(out.dtype)
40
+
41
+ # Initialize the buffer
42
+ buffer = data.new_empty(self.buffer_shape)
43
+
44
+ # Loop over each sub-matrix
45
+ for i in range(self.buffer_n):
46
+ start = i * self.buffer_h
47
+ end = (i + 1) * self.buffer_h
48
+
49
+ # The boundary case in the last iteration
50
+ if end > self.neigh.shape[0]:
51
+ dis = end - self.neigh.shape[0]
52
+ end = self.neigh.shape[0]
53
+ buffer, _ = buffer.split([self.buffer_h-dis, dis])
54
+
55
+ # Perform octree2col
56
+ neigh_i = self.neigh[start:end]
57
+ valid = neigh_i >= 0
58
+ buffer.fill_(0)
59
+ buffer[valid] = data[neigh_i[valid]]
60
+
61
+ # The sub-matrix gemm
62
+ # out[start:end] = torch.mm(buffer.flatten(1, 2), weights.flatten(0, 1))
63
+ out[start:end] = torch.einsum('ikc,kc->ic', buffer, weights.flatten(0, 1))
64
+ return out
65
+
66
+ def backward_gemm(self, out: torch.Tensor, grad: torch.Tensor,
67
+ weights: torch.Tensor):
68
+ r''' Performs the backward pass of octree-based convolution.
69
+ '''
70
+
71
+ # Type check
72
+ if grad.dtype != out.dtype:
73
+ grad = grad.to(out.dtype)
74
+ if weights.dtype != out.dtype:
75
+ weights = weights.to(out.dtype)
76
+
77
+ # Loop over each sub-matrix
78
+ for i in range(self.buffer_n):
79
+ start = i * self.buffer_h
80
+ end = (i + 1) * self.buffer_h
81
+
82
+ # The boundary case in the last iteration
83
+ if end > self.neigh.shape[0]:
84
+ end = self.neigh.shape[0]
85
+
86
+ # The sub-matrix gemm
87
+ # buffer = torch.mm(grad[start:end], weights.flatten(0, 1).t())
88
+ # buffer = buffer.view(-1, self.buffer_shape[1], self.buffer_shape[2])
89
+ buffer = torch.einsum(
90
+ 'ic,kc->ikc', grad[start:end], weights.flatten(0, 1))
91
+
92
+ # Performs col2octree
93
+ neigh_i = self.neigh[start:end]
94
+ valid = neigh_i >= 0
95
+ out = scatter_add(buffer[valid], neigh_i[valid], dim=0, out=out)
96
+
97
+ return out
98
+
99
+ def weight_gemm(self, out: torch.Tensor, data: torch.Tensor, grad: torch.Tensor):
100
+ r''' Computes the gradient of the weight matrix.
101
+ '''
102
+
103
+ # Type check
104
+ if data.dtype != out.dtype:
105
+ data = data.to(out.dtype)
106
+ if grad.dtype != out.dtype:
107
+ grad = grad.to(out.dtype)
108
+
109
+ # Record the shape of out
110
+ out_shape = out.shape
111
+ out = out.flatten(0, 1)
112
+
113
+ # Initialize the buffer
114
+ buffer = data.new_empty(self.buffer_shape)
115
+
116
+ # Loop over each sub-matrix
117
+ for i in range(self.buffer_n):
118
+ start = i * self.buffer_h
119
+ end = (i + 1) * self.buffer_h
120
+
121
+ # The boundary case in the last iteration
122
+ if end > self.neigh.shape[0]:
123
+ d = end - self.neigh.shape[0]
124
+ end = self.neigh.shape[0]
125
+ buffer, _ = buffer.split([self.buffer_h-d, d])
126
+
127
+ # Perform octree2col
128
+ neigh_i = self.neigh[start:end]
129
+ valid = neigh_i >= 0
130
+ buffer.fill_(0)
131
+ buffer[valid] = data[neigh_i[valid]]
132
+
133
+ # Accumulate the gradient via gemm
134
+ # out.addmm_(buffer.flatten(1, 2).t(), grad[start:end])
135
+ out += torch.einsum('ikc,ic->kc', buffer, grad[start:end])
136
+ return out.view(out_shape)
137
+
138
+
139
+ class OctreeDWConvFunction(Function):
140
+ r''' Wrap the octree convolution for auto-diff.
141
+ '''
142
+
143
+ @staticmethod
144
+ def forward(ctx, data: torch.Tensor, weights: torch.Tensor, octree: Octree,
145
+ depth: int, in_channels: int, kernel_size: List[int] = [3, 3, 3],
146
+ stride: int = 1, nempty: bool = False, max_buffer: int = int(2e8)):
147
+ octree_conv = OctreeDWConvBase(
148
+ in_channels, kernel_size, stride, nempty, max_buffer)
149
+ octree_conv.setup(octree, depth)
150
+ out = octree_conv.check_and_init(data)
151
+ out = octree_conv.forward_gemm(out, data, weights)
152
+
153
+ ctx.save_for_backward(data, weights)
154
+ ctx.octree_conv = octree_conv
155
+ return out
156
+
157
+ @staticmethod
158
+ def backward(ctx, grad):
159
+ data, weights = ctx.saved_tensors
160
+ octree_conv = ctx.octree_conv
161
+
162
+ grad_out = None
163
+ if ctx.needs_input_grad[0]:
164
+ grad_out = torch.zeros_like(data)
165
+ grad_out = octree_conv.backward_gemm(grad_out, grad, weights)
166
+
167
+ grad_w = None
168
+ if ctx.needs_input_grad[1]:
169
+ grad_w = torch.zeros_like(weights)
170
+ grad_w = octree_conv.weight_gemm(grad_w, data, grad)
171
+
172
+ return (grad_out, grad_w) + (None,) * 7
173
+
174
+
175
+ # alias
176
+ octree_dwconv = OctreeDWConvFunction.apply
177
+
178
+
179
+ class OctreeDWConv(OctreeDWConvBase, torch.nn.Module):
180
+ r''' Performs octree-based depth-wise convolution.
181
+
182
+ Please refer to :class:`ocnn.nn.OctreeConv` for the meaning of the arguments.
183
+
184
+ .. note::
185
+ This implementation uses the :func:`torch.einsum` and I find that the speed
186
+ is relatively slow. Further optimization is needed to speed it up.
187
+ '''
188
+
189
+ def __init__(self, in_channels: int, kernel_size: List[int] = [3],
190
+ stride: int = 1, nempty: bool = False, use_bias: bool = False,
191
+ max_buffer: int = int(2e8)):
192
+ super().__init__(in_channels, kernel_size, stride, nempty, max_buffer)
193
+
194
+ self.use_bias = use_bias
195
+ self.weights = torch.nn.Parameter(torch.Tensor(*self.weights_shape))
196
+ if self.use_bias:
197
+ self.bias = torch.nn.Parameter(torch.Tensor(in_channels))
198
+ self.reset_parameters()
199
+
200
+ def reset_parameters(self):
201
+ xavier_uniform_(self.weights)
202
+ if self.use_bias:
203
+ torch.nn.init.zeros_(self.bias)
204
+
205
+ def forward(self, data: torch.Tensor, octree: Octree, depth: int):
206
+ r''''''
207
+
208
+ out = octree_dwconv(
209
+ data, self.weights, octree, depth, self.in_channels,
210
+ self.kernel_size, self.stride, self.nempty, self.max_buffer)
211
+
212
+ if self.use_bias:
213
+ out += self.bias
214
+
215
+ if self.stride == 2 and not self.nempty:
216
+ out = octree_pad(out, octree, depth-1)
217
+ return out
218
+
219
+ def extra_repr(self) -> str:
220
+ return ('in_channels={}, out_channels={}, kernel_size={}, stride={}, '
221
+ 'nempty={}, bias={}').format(self.in_channels, self.out_channels,
222
+ self.kernel_size, self.stride, self.nempty, self.use_bias) # noqa
ocnn/nn/octree_gconv.py CHANGED
@@ -1,79 +1,79 @@
1
- # --------------------------------------------------------
2
- # Octree-based Sparse Convolutional Neural Networks
3
- # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
- # Licensed under The MIT License [see LICENSE for details]
5
- # Written by Peng-Shuai Wang
6
- # --------------------------------------------------------
7
-
8
- import torch
9
- import torch.nn
10
- from typing import List
11
-
12
- import ocnn
13
- from ocnn.octree import Octree
14
-
15
-
16
- class OctreeGroupConv(torch.nn.Module):
17
- r''' Performs octree-based group convolution.
18
-
19
- Args:
20
- in_channels (int): Number of input channels.
21
- out_channels (int): Number of output channels.
22
- kernel_size (List(int)): The kernel shape, choose from :obj:`[3]`, :obj:`[2]`,
23
- :obj:`[3,3,3]`, :obj:`[3,1,1]`, :obj:`[1,3,1]`, :obj:`[1,1,3]`,
24
- :obj:`[2,2,2]`, :obj:`[3,3,1]`, :obj:`[1,3,3]`, and :obj:`[3,1,3]`.
25
- stride (int): The stride of the convolution (:obj:`1` or :obj:`2`).
26
- nempty (bool): If True, only performs the convolution on non-empty
27
- octree nodes.
28
- use_bias (bool): If True, add a bias term to the convolution.
29
- group (int): The number of groups.
30
-
31
- .. note::
32
- Perform octree-based group convolution with a for-loop. The performance is
33
- not optimal. Use this module only when the group number is small, otherwise
34
- it may be slow.
35
- '''
36
-
37
- def __init__(self, in_channels: int, out_channels: int,
38
- kernel_size: List[int] = [3], stride: int = 1,
39
- nempty: bool = False, use_bias: bool = False,
40
- group: int = 1):
41
- super().__init__()
42
-
43
- self.group = group
44
- self.in_channels = in_channels
45
- self.out_channels = out_channels
46
- self.in_channels_per_group = in_channels // group
47
- self.out_channels_per_group = out_channels // group
48
- assert in_channels % group == 0 and out_channels % group == 0
49
-
50
- self.convs = torch.nn.ModuleList([ocnn.nn.OctreeConv(
51
- self.in_channels_per_group, self.out_channels_per_group,
52
- kernel_size, stride, nempty, use_bias=use_bias)
53
- for _ in range(group)])
54
-
55
- def forward(self, data: torch.Tensor, octree: Octree, depth: int):
56
- r''' Defines the octree-based group convolution.
57
-
58
- Args:
59
- data (torch.Tensor): The input data.
60
- octree (Octree): The corresponding octree.
61
- depth (int): The depth of current octree.
62
- '''
63
-
64
- channels = data.shape[1]
65
- assert channels == self.in_channels
66
-
67
- outs = [None] * self.group
68
- slices = torch.split(data, self.in_channels_per_group, dim=1)
69
- for i in range(self.group):
70
- outs[i] = self.convs[i](slices[i], octree, depth)
71
- out = torch.cat(outs, dim=1)
72
- return out
73
-
74
- def extra_repr(self) -> str:
75
- r''' Sets the extra representation of the module.
76
- '''
77
-
78
- return ('in_channels={}, out_channels={}, group={}').format(
79
- self.in_channels, self.out_channels, self.group) # noqa
1
+ # --------------------------------------------------------
2
+ # Octree-based Sparse Convolutional Neural Networks
3
+ # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Peng-Shuai Wang
6
+ # --------------------------------------------------------
7
+
8
+ import torch
9
+ import torch.nn
10
+ from typing import List
11
+
12
+ import ocnn
13
+ from ocnn.octree import Octree
14
+
15
+
16
+ class OctreeGroupConv(torch.nn.Module):
17
+ r''' Performs octree-based group convolution.
18
+
19
+ Args:
20
+ in_channels (int): Number of input channels.
21
+ out_channels (int): Number of output channels.
22
+ kernel_size (List(int)): The kernel shape, choose from :obj:`[3]`, :obj:`[2]`,
23
+ :obj:`[3,3,3]`, :obj:`[3,1,1]`, :obj:`[1,3,1]`, :obj:`[1,1,3]`,
24
+ :obj:`[2,2,2]`, :obj:`[3,3,1]`, :obj:`[1,3,3]`, and :obj:`[3,1,3]`.
25
+ stride (int): The stride of the convolution (:obj:`1` or :obj:`2`).
26
+ nempty (bool): If True, only performs the convolution on non-empty
27
+ octree nodes.
28
+ use_bias (bool): If True, add a bias term to the convolution.
29
+ group (int): The number of groups.
30
+
31
+ .. note::
32
+ Perform octree-based group convolution with a for-loop. The performance is
33
+ not optimal. Use this module only when the group number is small, otherwise
34
+ it may be slow.
35
+ '''
36
+
37
+ def __init__(self, in_channels: int, out_channels: int,
38
+ kernel_size: List[int] = [3], stride: int = 1,
39
+ nempty: bool = False, use_bias: bool = False,
40
+ group: int = 1):
41
+ super().__init__()
42
+
43
+ self.group = group
44
+ self.in_channels = in_channels
45
+ self.out_channels = out_channels
46
+ self.in_channels_per_group = in_channels // group
47
+ self.out_channels_per_group = out_channels // group
48
+ assert in_channels % group == 0 and out_channels % group == 0
49
+
50
+ self.convs = torch.nn.ModuleList([ocnn.nn.OctreeConv(
51
+ self.in_channels_per_group, self.out_channels_per_group,
52
+ kernel_size, stride, nempty, use_bias=use_bias)
53
+ for _ in range(group)])
54
+
55
+ def forward(self, data: torch.Tensor, octree: Octree, depth: int):
56
+ r''' Defines the octree-based group convolution.
57
+
58
+ Args:
59
+ data (torch.Tensor): The input data.
60
+ octree (Octree): The corresponding octree.
61
+ depth (int): The depth of current octree.
62
+ '''
63
+
64
+ channels = data.shape[1]
65
+ assert channels == self.in_channels
66
+
67
+ outs = [None] * self.group
68
+ slices = torch.split(data, self.in_channels_per_group, dim=1)
69
+ for i in range(self.group):
70
+ outs[i] = self.convs[i](slices[i], octree, depth)
71
+ out = torch.cat(outs, dim=1)
72
+ return out
73
+
74
+ def extra_repr(self) -> str:
75
+ r''' Sets the extra representation of the module.
76
+ '''
77
+
78
+ return ('in_channels={}, out_channels={}, group={}').format(
79
+ self.in_channels, self.out_channels, self.group) # noqa