ocnn 2.2.1__py3-none-any.whl → 2.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ocnn/nn/octree_drop.py CHANGED
@@ -1,55 +1,55 @@
1
- # --------------------------------------------------------
2
- # Octree-based Sparse Convolutional Neural Networks
3
- # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
- # Licensed under The MIT License [see LICENSE for details]
5
- # Written by Peng-Shuai Wang
6
- # --------------------------------------------------------
7
-
8
- import torch
9
- from typing import Optional
10
-
11
- from ocnn.octree import Octree
12
-
13
-
14
class OctreeDropPath(torch.nn.Module):
    r'''Drop paths (Stochastic Depth) per sample when applied in main path of
    residual blocks, following the logic of :func:`timm.models.layers.DropPath`.

    During training, each sample of the batch is either kept entirely or zeroed
    entirely. The per-sample decision is broadcast to every row of the input
    through the batch index of that row.

    Args:
      drop_prob (float): The probability of drop paths. It is expected to lie in
          :obj:`[0, 1]`; for values above 1, :obj:`1 - drop_prob` is negative
          and the generated mask may contain :obj:`-1` entries.
      nempty (bool): Indicate whether the input data only contains features of
          the non-empty octree nodes or not.
      scale_by_keep (bool): Whether to scale the kept features proportionally,
          i.e. divide them by :obj:`1 - drop_prob`.
    '''

    def __init__(self, drop_prob: float = 0.0, nempty: bool = False,
                 scale_by_keep: bool = True):
        super().__init__()

        self.drop_prob = drop_prob
        self.nempty = nempty
        self.scale_by_keep = scale_by_keep

    def forward(self, data: torch.Tensor, octree: Octree, depth: int,
                batch_id: Optional[torch.Tensor] = None):
        r''' Randomly zeros out whole samples of :attr:`data` during training.

        Args:
          data (torch.Tensor): The input features; row :obj:`i` belongs to the
              sample :obj:`batch_id[i]`.
          octree (Octree): The octree that provides :obj:`batch_size` and, when
              :attr:`batch_id` is None, the per-node batch indices.
          depth (int): The octree depth of :attr:`data`; only used to query
              :attr:`batch_id` from :attr:`octree`.
          batch_id (torch.Tensor, optional): Precomputed batch index of each row
              of :attr:`data`. If None, it is computed via
              :obj:`octree.batch_id(depth, nempty)`.

        Returns:
          torch.Tensor: :attr:`data` itself when not training or when
          :obj:`drop_prob <= 0`; otherwise :attr:`data` multiplied by the
          per-sample mask.
        '''

        if self.drop_prob <= 0.0 or not self.training:
            return data

        batch_size = octree.batch_size
        keep_prob = 1 - self.drop_prob
        # floor(U[0, 1) + keep_prob) is 1 with probability keep_prob, else 0
        # (for keep_prob in [0, 1]).
        rnd_tensor = torch.rand(batch_size, 1, dtype=data.dtype, device=data.device)
        rnd_tensor = torch.floor(rnd_tensor + keep_prob)
        if keep_prob > 0.0 and self.scale_by_keep:
            # Rescale so the expected value of the output equals the input.
            rnd_tensor.div_(keep_prob)

        if batch_id is None:
            batch_id = octree.batch_id(depth, self.nempty)
        # Gather one mask value per row; shape (rows, 1) broadcasts over channels.
        drop_mask = rnd_tensor[batch_id]
        output = data * drop_mask
        return output

    def extra_repr(self) -> str:
        return ('drop_prob={:.4f}, nempty={}, scale_by_keep={}').format(
            self.drop_prob, self.nempty, self.scale_by_keep)  # noqa
1
+ # --------------------------------------------------------
2
+ # Octree-based Sparse Convolutional Neural Networks
3
+ # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Peng-Shuai Wang
6
+ # --------------------------------------------------------
7
+
8
+ import torch
9
+ from typing import Optional
10
+
11
+ from ocnn.octree import Octree
12
+
13
+
14
class OctreeDropPath(torch.nn.Module):
    r'''Drop paths (Stochastic Depth) per sample when applied in main path of
    residual blocks, following the logic of :func:`timm.models.layers.DropPath`.

    During training, each sample of the batch is either kept entirely or zeroed
    entirely. The per-sample decision is broadcast to every row of the input
    through the batch index of that row.

    Args:
      drop_prob (float): The probability of drop paths. It is expected to lie in
          :obj:`[0, 1]`; for values above 1, :obj:`1 - drop_prob` is negative
          and the generated mask may contain :obj:`-1` entries.
      nempty (bool): Indicate whether the input data only contains features of
          the non-empty octree nodes or not.
      scale_by_keep (bool): Whether to scale the kept features proportionally,
          i.e. divide them by :obj:`1 - drop_prob`.
    '''

    def __init__(self, drop_prob: float = 0.0, nempty: bool = False,
                 scale_by_keep: bool = True):
        super().__init__()

        self.drop_prob = drop_prob
        self.nempty = nempty
        self.scale_by_keep = scale_by_keep

    def forward(self, data: torch.Tensor, octree: Octree, depth: int,
                batch_id: Optional[torch.Tensor] = None):
        r''' Randomly zeros out whole samples of :attr:`data` during training.

        Args:
          data (torch.Tensor): The input features; row :obj:`i` belongs to the
              sample :obj:`batch_id[i]`.
          octree (Octree): The octree that provides :obj:`batch_size` and, when
              :attr:`batch_id` is None, the per-node batch indices.
          depth (int): The octree depth of :attr:`data`; only used to query
              :attr:`batch_id` from :attr:`octree`.
          batch_id (torch.Tensor, optional): Precomputed batch index of each row
              of :attr:`data`. If None, it is computed via
              :obj:`octree.batch_id(depth, nempty)`.

        Returns:
          torch.Tensor: :attr:`data` itself when not training or when
          :obj:`drop_prob <= 0`; otherwise :attr:`data` multiplied by the
          per-sample mask.
        '''

        if self.drop_prob <= 0.0 or not self.training:
            return data

        batch_size = octree.batch_size
        keep_prob = 1 - self.drop_prob
        # floor(U[0, 1) + keep_prob) is 1 with probability keep_prob, else 0
        # (for keep_prob in [0, 1]).
        rnd_tensor = torch.rand(batch_size, 1, dtype=data.dtype, device=data.device)
        rnd_tensor = torch.floor(rnd_tensor + keep_prob)
        if keep_prob > 0.0 and self.scale_by_keep:
            # Rescale so the expected value of the output equals the input.
            rnd_tensor.div_(keep_prob)

        if batch_id is None:
            batch_id = octree.batch_id(depth, self.nempty)
        # Gather one mask value per row; shape (rows, 1) broadcasts over channels.
        drop_mask = rnd_tensor[batch_id]
        output = data * drop_mask
        return output

    def extra_repr(self) -> str:
        return ('drop_prob={:.4f}, nempty={}, scale_by_keep={}').format(
            self.drop_prob, self.nempty, self.scale_by_keep)  # noqa
ocnn/nn/octree_dwconv.py CHANGED
@@ -1,204 +1,204 @@
1
- # --------------------------------------------------------
2
- # Octree-based Sparse Convolutional Neural Networks
3
- # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
- # Licensed under The MIT License [see LICENSE for details]
5
- # Written by Peng-Shuai Wang
6
- # --------------------------------------------------------
7
-
8
- import torch
9
- import torch.nn
10
- from torch.autograd import Function
11
- from typing import List
12
-
13
- from ocnn.octree import Octree
14
- from ocnn.utils import scatter_add, xavier_uniform_
15
- from .octree_pad import octree_pad
16
- from .octree_conv import OctreeConvBase
17
-
18
-
19
class OctreeDWConvBase(OctreeConvBase):
    r''' Shared computation of the octree-based depth-wise convolution.

    The number of output channels equals the number of input channels, and
    output channel :obj:`c` only depends on input channel :obj:`c`. The weights
    therefore have the shape :obj:`(kdim, 1, channels)`.
    '''

    def __init__(self, in_channels: int, kernel_size: List[int] = [3],
                 stride: int = 1, nempty: bool = False,
                 max_buffer: int = int(2e8)):
        # Depth-wise: the output channel count is the input channel count.
        super().__init__(
            in_channels, in_channels, kernel_size, stride, nempty, max_buffer)
        self.weights_shape = (self.kdim, 1, self.out_channels)

    def is_conv_layer(self):
        return True

    def forward_gemm(self, out: torch.Tensor, data: torch.Tensor,
                     weights: torch.Tensor):
        r''' Performs the forward pass of octree-based depth-wise convolution.

        The neighbor table :obj:`self.neigh` is processed in chunks of
        :obj:`self.buffer_h` rows so that the gathered buffer stays bounded.

        Args:
          out (torch.Tensor): The pre-allocated output, written chunk by chunk.
          data (torch.Tensor): The input features.
          weights (torch.Tensor): The kernel weights (presumably shaped like
              :obj:`self.weights_shape`).

        Returns:
          torch.Tensor: :attr:`out`, filled in place.
        '''
        kernel = weights.flatten(0, 1)
        total = self.neigh.shape[0]
        buffer = data.new_empty(self.buffer_shape)

        for chunk in range(self.buffer_n):
            begin = chunk * self.buffer_h
            stop = begin + self.buffer_h

            # The last chunk may be partial: shrink the buffer to fit it.
            if stop > total:
                surplus = stop - total
                stop = total
                buffer = buffer.split([self.buffer_h - surplus, surplus])[0]

            # octree2col: gather neighbor features; entries whose neighbor
            # index is negative stay zero.
            neigh = self.neigh[begin:stop]
            mask = neigh >= 0
            buffer.fill_(0)
            buffer[mask] = data[neigh[mask]]

            # Per-channel weighted sum over the kernel positions.
            out[begin:stop] = torch.einsum('ikc,kc->ic', buffer, kernel)
        return out

    def backward_gemm(self, out: torch.Tensor, grad: torch.Tensor,
                      weights: torch.Tensor):
        r''' Performs the backward pass w.r.t. the input data.

        Args:
          out (torch.Tensor): The gradient accumulator for the input data;
              per-chunk contributions are scattered into it.
          grad (torch.Tensor): The gradient of the output.
          weights (torch.Tensor): The kernel weights.

        Returns:
          torch.Tensor: The accumulated gradient returned by :func:`scatter_add`.
        '''
        kernel = weights.flatten(0, 1)
        total = self.neigh.shape[0]

        for chunk in range(self.buffer_n):
            begin = chunk * self.buffer_h
            stop = min(begin + self.buffer_h, total)

            # Distribute each output-row gradient over the kernel positions.
            cols = torch.einsum('ic,kc->ikc', grad[begin:stop], kernel)

            # col2octree: accumulate onto the nodes with a valid neighbor index.
            neigh = self.neigh[begin:stop]
            mask = neigh >= 0
            out = scatter_add(cols[mask], neigh[mask], dim=0, out=out)

        return out

    def weight_gemm(self, out: torch.Tensor, data: torch.Tensor, grad: torch.Tensor):
        r''' Computes the gradient of the weights.

        Args:
          out (torch.Tensor): The weight-gradient accumulator; it is accumulated
              through :obj:`out.flatten(0, 1)`.
          data (torch.Tensor): The input features.
          grad (torch.Tensor): The gradient of the output.

        Returns:
          torch.Tensor: The accumulated gradient, reshaped to :attr:`out`'s
          original shape.
        '''
        orig_shape = out.shape
        acc = out.flatten(0, 1)
        total = self.neigh.shape[0]
        buffer = data.new_empty(self.buffer_shape)

        for chunk in range(self.buffer_n):
            begin = chunk * self.buffer_h
            stop = begin + self.buffer_h

            # The last chunk may be partial: shrink the buffer to fit it.
            if stop > total:
                surplus = stop - total
                stop = total
                buffer = buffer.split([self.buffer_h - surplus, surplus])[0]

            # octree2col, identical to the forward pass.
            neigh = self.neigh[begin:stop]
            mask = neigh >= 0
            buffer.fill_(0)
            buffer[mask] = data[neigh[mask]]

            # Accumulate the per-channel gradient of every kernel position.
            acc += torch.einsum('ikc,ic->kc', buffer, grad[begin:stop])
        return acc.view(orig_shape)
119
-
120
-
121
class OctreeDWConvFunction(Function):
    r''' Wraps the octree-based depth-wise convolution for auto-diff.
    '''

    @staticmethod
    def forward(ctx, data: torch.Tensor, weights: torch.Tensor, octree: Octree,
                depth: int, in_channels: int, kernel_size: List[int] = [3, 3, 3],
                stride: int = 1, nempty: bool = False, max_buffer: int = int(2e8)):
        r''' Builds an :class:`OctreeDWConvBase` for :attr:`octree` at
        :attr:`depth` and runs its forward gemm on :attr:`data`.
        '''
        conv = OctreeDWConvBase(
            in_channels, kernel_size, stride, nempty, max_buffer)
        conv.setup(octree, depth)
        result = conv.forward_gemm(conv.check_and_init(data), data, weights)

        # The configured helper is reused by backward(); its gemm methods read
        # the neighbor table prepared during setup.
        ctx.save_for_backward(data, weights)
        ctx.octree_conv = conv
        return result

    @staticmethod
    def backward(ctx, grad):
        r''' Returns the gradients w.r.t. :attr:`data` and :attr:`weights`;
        the remaining seven forward arguments receive None.
        '''
        data, weights = ctx.saved_tensors
        conv = ctx.octree_conv
        need_data_grad, need_weight_grad = ctx.needs_input_grad[:2]

        grad_data = None
        if need_data_grad:
            grad_data = conv.backward_gemm(torch.zeros_like(data), grad, weights)

        grad_weights = None
        if need_weight_grad:
            grad_weights = conv.weight_gemm(torch.zeros_like(weights), data, grad)

        # octree, depth, in_channels, kernel_size, stride, nempty, max_buffer
        return (grad_data, grad_weights) + (None,) * 7
155
-
156
-
157
# alias: functional entry point; calling it runs OctreeDWConvFunction.forward
# (and registers its backward with autograd) for the given arguments.
octree_dwconv = OctreeDWConvFunction.apply
159
-
160
-
161
class OctreeDWConv(OctreeDWConvBase, torch.nn.Module):
    r''' Performs octree-based depth-wise convolution.

    Please refer to :class:`ocnn.nn.OctreeConv` for the meaning of the arguments.
    Additionally, :attr:`use_bias` adds a learnable per-channel bias.

    .. note::
      This implementation relies on :func:`torch.einsum`, which is relatively
      slow; further optimization is needed to speed it up.
    '''

    def __init__(self, in_channels: int, kernel_size: List[int] = [3],
                 stride: int = 1, nempty: bool = False, use_bias: bool = False,
                 max_buffer: int = int(2e8)):
        super().__init__(in_channels, kernel_size, stride, nempty, max_buffer)

        self.use_bias = use_bias
        # Parameters are allocated uninitialized and filled by reset_parameters().
        self.weights = torch.nn.Parameter(torch.Tensor(*self.weights_shape))
        if use_bias:
            self.bias = torch.nn.Parameter(torch.Tensor(in_channels))
        self.reset_parameters()

    def reset_parameters(self):
        r''' Xavier-uniform weights and, if present, a zero bias. '''
        xavier_uniform_(self.weights)
        if self.use_bias:
            torch.nn.init.zeros_(self.bias)

    def forward(self, data: torch.Tensor, octree: Octree, depth: int):
        r''' Applies the depth-wise convolution to :attr:`data`.

        Args:
          data (torch.Tensor): The input features.
          octree (Octree): The corresponding octree.
          depth (int): The depth of the input data.

        Returns:
          torch.Tensor: The convolved features (plus the bias when enabled). For
          :obj:`stride == 2` with :obj:`nempty == False`, the result is passed
          through :func:`octree_pad` at :obj:`depth - 1`.
        '''
        out = octree_dwconv(
            data, self.weights, octree, depth, self.in_channels,
            self.kernel_size, self.stride, self.nempty, self.max_buffer)

        if self.use_bias:
            out += self.bias

        needs_pad = self.stride == 2 and not self.nempty
        if needs_pad:
            out = octree_pad(out, octree, depth-1)
        return out

    def extra_repr(self) -> str:
        return (f'in_channels={self.in_channels}, '
                f'out_channels={self.out_channels}, '
                f'kernel_size={self.kernel_size}, stride={self.stride}, '
                f'nempty={self.nempty}, bias={self.use_bias}')
1
+ # --------------------------------------------------------
2
+ # Octree-based Sparse Convolutional Neural Networks
3
+ # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Peng-Shuai Wang
6
+ # --------------------------------------------------------
7
+
8
+ import torch
9
+ import torch.nn
10
+ from torch.autograd import Function
11
+ from typing import List
12
+
13
+ from ocnn.octree import Octree
14
+ from ocnn.utils import scatter_add, xavier_uniform_
15
+ from .octree_pad import octree_pad
16
+ from .octree_conv import OctreeConvBase
17
+
18
+
19
class OctreeDWConvBase(OctreeConvBase):
    r''' Shared computation of the octree-based depth-wise convolution.

    The number of output channels equals the number of input channels, and
    output channel :obj:`c` only depends on input channel :obj:`c`. The weights
    therefore have the shape :obj:`(kdim, 1, channels)`.
    '''

    def __init__(self, in_channels: int, kernel_size: List[int] = [3],
                 stride: int = 1, nempty: bool = False,
                 max_buffer: int = int(2e8)):
        # Depth-wise: the output channel count is the input channel count.
        super().__init__(
            in_channels, in_channels, kernel_size, stride, nempty, max_buffer)
        self.weights_shape = (self.kdim, 1, self.out_channels)

    def is_conv_layer(self):
        return True

    def forward_gemm(self, out: torch.Tensor, data: torch.Tensor,
                     weights: torch.Tensor):
        r''' Performs the forward pass of octree-based depth-wise convolution.

        The neighbor table :obj:`self.neigh` is processed in chunks of
        :obj:`self.buffer_h` rows so that the gathered buffer stays bounded.

        Args:
          out (torch.Tensor): The pre-allocated output, written chunk by chunk.
          data (torch.Tensor): The input features.
          weights (torch.Tensor): The kernel weights (presumably shaped like
              :obj:`self.weights_shape`).

        Returns:
          torch.Tensor: :attr:`out`, filled in place.
        '''
        kernel = weights.flatten(0, 1)
        total = self.neigh.shape[0]
        buffer = data.new_empty(self.buffer_shape)

        for chunk in range(self.buffer_n):
            begin = chunk * self.buffer_h
            stop = begin + self.buffer_h

            # The last chunk may be partial: shrink the buffer to fit it.
            if stop > total:
                surplus = stop - total
                stop = total
                buffer = buffer.split([self.buffer_h - surplus, surplus])[0]

            # octree2col: gather neighbor features; entries whose neighbor
            # index is negative stay zero.
            neigh = self.neigh[begin:stop]
            mask = neigh >= 0
            buffer.fill_(0)
            buffer[mask] = data[neigh[mask]]

            # Per-channel weighted sum over the kernel positions.
            out[begin:stop] = torch.einsum('ikc,kc->ic', buffer, kernel)
        return out

    def backward_gemm(self, out: torch.Tensor, grad: torch.Tensor,
                      weights: torch.Tensor):
        r''' Performs the backward pass w.r.t. the input data.

        Args:
          out (torch.Tensor): The gradient accumulator for the input data;
              per-chunk contributions are scattered into it.
          grad (torch.Tensor): The gradient of the output.
          weights (torch.Tensor): The kernel weights.

        Returns:
          torch.Tensor: The accumulated gradient returned by :func:`scatter_add`.
        '''
        kernel = weights.flatten(0, 1)
        total = self.neigh.shape[0]

        for chunk in range(self.buffer_n):
            begin = chunk * self.buffer_h
            stop = min(begin + self.buffer_h, total)

            # Distribute each output-row gradient over the kernel positions.
            cols = torch.einsum('ic,kc->ikc', grad[begin:stop], kernel)

            # col2octree: accumulate onto the nodes with a valid neighbor index.
            neigh = self.neigh[begin:stop]
            mask = neigh >= 0
            out = scatter_add(cols[mask], neigh[mask], dim=0, out=out)

        return out

    def weight_gemm(self, out: torch.Tensor, data: torch.Tensor, grad: torch.Tensor):
        r''' Computes the gradient of the weights.

        Args:
          out (torch.Tensor): The weight-gradient accumulator; it is accumulated
              through :obj:`out.flatten(0, 1)`.
          data (torch.Tensor): The input features.
          grad (torch.Tensor): The gradient of the output.

        Returns:
          torch.Tensor: The accumulated gradient, reshaped to :attr:`out`'s
          original shape.
        '''
        orig_shape = out.shape
        acc = out.flatten(0, 1)
        total = self.neigh.shape[0]
        buffer = data.new_empty(self.buffer_shape)

        for chunk in range(self.buffer_n):
            begin = chunk * self.buffer_h
            stop = begin + self.buffer_h

            # The last chunk may be partial: shrink the buffer to fit it.
            if stop > total:
                surplus = stop - total
                stop = total
                buffer = buffer.split([self.buffer_h - surplus, surplus])[0]

            # octree2col, identical to the forward pass.
            neigh = self.neigh[begin:stop]
            mask = neigh >= 0
            buffer.fill_(0)
            buffer[mask] = data[neigh[mask]]

            # Accumulate the per-channel gradient of every kernel position.
            acc += torch.einsum('ikc,ic->kc', buffer, grad[begin:stop])
        return acc.view(orig_shape)
119
+
120
+
121
class OctreeDWConvFunction(Function):
    r''' Wraps the octree-based depth-wise convolution for auto-diff.
    '''

    @staticmethod
    def forward(ctx, data: torch.Tensor, weights: torch.Tensor, octree: Octree,
                depth: int, in_channels: int, kernel_size: List[int] = [3, 3, 3],
                stride: int = 1, nempty: bool = False, max_buffer: int = int(2e8)):
        r''' Builds an :class:`OctreeDWConvBase` for :attr:`octree` at
        :attr:`depth` and runs its forward gemm on :attr:`data`.
        '''
        conv = OctreeDWConvBase(
            in_channels, kernel_size, stride, nempty, max_buffer)
        conv.setup(octree, depth)
        result = conv.forward_gemm(conv.check_and_init(data), data, weights)

        # The configured helper is reused by backward(); its gemm methods read
        # the neighbor table prepared during setup.
        ctx.save_for_backward(data, weights)
        ctx.octree_conv = conv
        return result

    @staticmethod
    def backward(ctx, grad):
        r''' Returns the gradients w.r.t. :attr:`data` and :attr:`weights`;
        the remaining seven forward arguments receive None.
        '''
        data, weights = ctx.saved_tensors
        conv = ctx.octree_conv
        need_data_grad, need_weight_grad = ctx.needs_input_grad[:2]

        grad_data = None
        if need_data_grad:
            grad_data = conv.backward_gemm(torch.zeros_like(data), grad, weights)

        grad_weights = None
        if need_weight_grad:
            grad_weights = conv.weight_gemm(torch.zeros_like(weights), data, grad)

        # octree, depth, in_channels, kernel_size, stride, nempty, max_buffer
        return (grad_data, grad_weights) + (None,) * 7
155
+
156
+
157
# alias: functional entry point; calling it runs OctreeDWConvFunction.forward
# (and registers its backward with autograd) for the given arguments.
octree_dwconv = OctreeDWConvFunction.apply
159
+
160
+
161
class OctreeDWConv(OctreeDWConvBase, torch.nn.Module):
    r''' Performs octree-based depth-wise convolution.

    Please refer to :class:`ocnn.nn.OctreeConv` for the meaning of the arguments.
    Additionally, :attr:`use_bias` adds a learnable per-channel bias.

    .. note::
      This implementation relies on :func:`torch.einsum`, which is relatively
      slow; further optimization is needed to speed it up.
    '''

    def __init__(self, in_channels: int, kernel_size: List[int] = [3],
                 stride: int = 1, nempty: bool = False, use_bias: bool = False,
                 max_buffer: int = int(2e8)):
        super().__init__(in_channels, kernel_size, stride, nempty, max_buffer)

        self.use_bias = use_bias
        # Parameters are allocated uninitialized and filled by reset_parameters().
        self.weights = torch.nn.Parameter(torch.Tensor(*self.weights_shape))
        if use_bias:
            self.bias = torch.nn.Parameter(torch.Tensor(in_channels))
        self.reset_parameters()

    def reset_parameters(self):
        r''' Xavier-uniform weights and, if present, a zero bias. '''
        xavier_uniform_(self.weights)
        if self.use_bias:
            torch.nn.init.zeros_(self.bias)

    def forward(self, data: torch.Tensor, octree: Octree, depth: int):
        r''' Applies the depth-wise convolution to :attr:`data`.

        Args:
          data (torch.Tensor): The input features.
          octree (Octree): The corresponding octree.
          depth (int): The depth of the input data.

        Returns:
          torch.Tensor: The convolved features (plus the bias when enabled). For
          :obj:`stride == 2` with :obj:`nempty == False`, the result is passed
          through :func:`octree_pad` at :obj:`depth - 1`.
        '''
        out = octree_dwconv(
            data, self.weights, octree, depth, self.in_channels,
            self.kernel_size, self.stride, self.nempty, self.max_buffer)

        if self.use_bias:
            out += self.bias

        needs_pad = self.stride == 2 and not self.nempty
        if needs_pad:
            out = octree_pad(out, octree, depth-1)
        return out

    def extra_repr(self) -> str:
        return (f'in_channels={self.in_channels}, '
                f'out_channels={self.out_channels}, '
                f'kernel_size={self.kernel_size}, stride={self.stride}, '
                f'nempty={self.nempty}, bias={self.use_bias}')
@@ -0,0 +1,79 @@
1
+ # --------------------------------------------------------
2
+ # Octree-based Sparse Convolutional Neural Networks
3
+ # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Peng-Shuai Wang
6
+ # --------------------------------------------------------
7
+
8
+ import torch
9
+ import torch.nn
10
+ from typing import List
11
+
12
+ import ocnn
13
+ from ocnn.octree import Octree
14
+
15
+
16
class OctreeGroupConv(torch.nn.Module):
    r''' Performs octree-based group convolution.

    The input channels are split into :attr:`group` equal slices. Each slice is
    processed by its own :class:`ocnn.nn.OctreeConv`, and the per-group outputs
    are concatenated along the channel dimension.

    Args:
      in_channels (int): Number of input channels; must be divisible by
          :attr:`group`.
      out_channels (int): Number of output channels; must be divisible by
          :attr:`group`.
      kernel_size (List(int)): The kernel shape, choose from :obj:`[3]`, :obj:`[2]`,
          :obj:`[3,3,3]`, :obj:`[3,1,1]`, :obj:`[1,3,1]`, :obj:`[1,1,3]`,
          :obj:`[2,2,2]`, :obj:`[3,3,1]`, :obj:`[1,3,3]`, and :obj:`[3,1,3]`.
      stride (int): The stride of the convolution (:obj:`1` or :obj:`2`).
      nempty (bool): If True, only performs the convolution on non-empty
          octree nodes.
      use_bias (bool): If True, add a bias term to the convolution.
      group (int): The number of groups; must be a positive integer.

    .. note::
      Perform octree-based group convolution with a for-loop. The performance is
      not optimal. Use this module only when the group number is small, otherwise
      it may be slow.
    '''

    def __init__(self, in_channels: int, out_channels: int,
                 kernel_size: List[int] = [3], stride: int = 1,
                 nempty: bool = False, use_bias: bool = False,
                 group: int = 1):
        super().__init__()

        # Validate before dividing: group == 0 would otherwise raise a bare
        # ZeroDivisionError, and a negative group would pass the modulo check.
        assert group > 0, 'group must be a positive integer, got {}'.format(group)
        assert in_channels % group == 0 and out_channels % group == 0, (
            'in_channels ({}) and out_channels ({}) must be divisible by '
            'group ({})'.format(in_channels, out_channels, group))

        self.group = group
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.in_channels_per_group = in_channels // group
        self.out_channels_per_group = out_channels // group

        self.convs = torch.nn.ModuleList([
            ocnn.nn.OctreeConv(
                self.in_channels_per_group, self.out_channels_per_group,
                kernel_size, stride, nempty, use_bias=use_bias)
            for _ in range(group)])

    def forward(self, data: torch.Tensor, octree: Octree, depth: int):
        r''' Defines the octree-based group convolution.

        Args:
          data (torch.Tensor): The input data with :attr:`in_channels` channels
              along dim 1.
          octree (Octree): The corresponding octree.
          depth (int): The depth of current octree.

        Returns:
          torch.Tensor: The per-group outputs concatenated along dim 1.
        '''

        channels = data.shape[1]
        assert channels == self.in_channels, (
            'expected {} input channels, got {}'.format(self.in_channels, channels))

        # Exactly `group` slices, since in_channels is divisible by group.
        slices = torch.split(data, self.in_channels_per_group, dim=1)
        outs = [conv(chunk, octree, depth)
                for conv, chunk in zip(self.convs, slices)]
        return torch.cat(outs, dim=1)

    def extra_repr(self) -> str:
        r''' Sets the extra representation of the module.
        '''

        return ('in_channels={}, out_channels={}, group={}').format(
            self.in_channels, self.out_channels, self.group)  # noqa