ocnn 2.2.5__py3-none-any.whl → 2.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ocnn/nn/octree_drop.py CHANGED
@@ -1,55 +1,55 @@
1
- # --------------------------------------------------------
2
- # Octree-based Sparse Convolutional Neural Networks
3
- # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
- # Licensed under The MIT License [see LICENSE for details]
5
- # Written by Peng-Shuai Wang
6
- # --------------------------------------------------------
7
-
8
- import torch
9
- from typing import Optional
10
-
11
- from ocnn.octree import Octree
12
-
13
-
14
- class OctreeDropPath(torch.nn.Module):
15
- r'''Drop paths (Stochastic Depth) per sample when applied in main path of
16
- residual blocks, following the logic of :func:`timm.models.layers.DropPath`.
17
-
18
- Args:
19
- drop_prob (int): The probability of drop paths.
20
- nempty (bool): Indicate whether the input data only contains features of the
21
- non-empty octree nodes or not.
22
- scale_by_keep (bool): Whether to scale the kept features proportionally.
23
- '''
24
-
25
- def __init__(self, drop_prob: float = 0.0, nempty: bool = False,
26
- scale_by_keep: bool = True):
27
- super().__init__()
28
-
29
- self.drop_prob = drop_prob
30
- self.nempty = nempty
31
- self.scale_by_keep = scale_by_keep
32
-
33
- def forward(self, data: torch.Tensor, octree: Octree, depth: int,
34
- batch_id: Optional[torch.Tensor] = None):
35
- r''''''
36
-
37
- if self.drop_prob <= 0.0 or not self.training:
38
- return data
39
-
40
- batch_size = octree.batch_size
41
- keep_prob = 1 - self.drop_prob
42
- rnd_tensor = torch.rand(batch_size, 1, dtype=data.dtype, device=data.device)
43
- rnd_tensor = torch.floor(rnd_tensor + keep_prob)
44
- if keep_prob > 0.0 and self.scale_by_keep:
45
- rnd_tensor.div_(keep_prob)
46
-
47
- if batch_id is None:
48
- batch_id = octree.batch_id(depth, self.nempty)
49
- drop_mask = rnd_tensor[batch_id]
50
- output = data * drop_mask
51
- return output
52
-
53
- def extra_repr(self) -> str:
54
- return ('drop_prob={:.4f}, nempty={}, scale_by_keep={}').format(
55
- self.drop_prob, self.nempty, self.scale_by_keep) # noqa
1
+ # --------------------------------------------------------
2
+ # Octree-based Sparse Convolutional Neural Networks
3
+ # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Peng-Shuai Wang
6
+ # --------------------------------------------------------
7
+
8
+ import torch
9
+ from typing import Optional
10
+
11
+ from ocnn.octree import Octree
12
+
13
+
14
+ class OctreeDropPath(torch.nn.Module):
15
+ r'''Drop paths (Stochastic Depth) per sample when applied in main path of
16
+ residual blocks, following the logic of :func:`timm.models.layers.DropPath`.
17
+
18
+ Args:
19
+ drop_prob (int): The probability of drop paths.
20
+ nempty (bool): Indicate whether the input data only contains features of the
21
+ non-empty octree nodes or not.
22
+ scale_by_keep (bool): Whether to scale the kept features proportionally.
23
+ '''
24
+
25
+ def __init__(self, drop_prob: float = 0.0, nempty: bool = False,
26
+ scale_by_keep: bool = True):
27
+ super().__init__()
28
+
29
+ self.drop_prob = drop_prob
30
+ self.nempty = nempty
31
+ self.scale_by_keep = scale_by_keep
32
+
33
+ def forward(self, data: torch.Tensor, octree: Octree, depth: int,
34
+ batch_id: Optional[torch.Tensor] = None):
35
+ r''''''
36
+
37
+ if self.drop_prob <= 0.0 or not self.training:
38
+ return data
39
+
40
+ batch_size = octree.batch_size
41
+ keep_prob = 1 - self.drop_prob
42
+ rnd_tensor = torch.rand(batch_size, 1, dtype=data.dtype, device=data.device)
43
+ rnd_tensor = torch.floor(rnd_tensor + keep_prob)
44
+ if keep_prob > 0.0 and self.scale_by_keep:
45
+ rnd_tensor.div_(keep_prob)
46
+
47
+ if batch_id is None:
48
+ batch_id = octree.batch_id(depth, self.nempty)
49
+ drop_mask = rnd_tensor[batch_id]
50
+ output = data * drop_mask
51
+ return output
52
+
53
+ def extra_repr(self) -> str:
54
+ return ('drop_prob={:.4f}, nempty={}, scale_by_keep={}').format(
55
+ self.drop_prob, self.nempty, self.scale_by_keep) # noqa
ocnn/nn/octree_dwconv.py CHANGED
@@ -1,222 +1,222 @@
1
- # --------------------------------------------------------
2
- # Octree-based Sparse Convolutional Neural Networks
3
- # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
- # Licensed under The MIT License [see LICENSE for details]
5
- # Written by Peng-Shuai Wang
6
- # --------------------------------------------------------
7
-
8
- import torch
9
- import torch.nn
10
- from torch.autograd import Function
11
- from typing import List
12
-
13
- from ocnn.octree import Octree
14
- from ocnn.utils import scatter_add, xavier_uniform_
15
- from .octree_pad import octree_pad
16
- from .octree_conv import OctreeConvBase
17
-
18
-
19
- class OctreeDWConvBase(OctreeConvBase):
20
-
21
- def __init__(self, in_channels: int, kernel_size: List[int] = [3],
22
- stride: int = 1, nempty: bool = False,
23
- max_buffer: int = int(2e8)):
24
- super().__init__(
25
- in_channels, in_channels, kernel_size, stride, nempty, max_buffer)
26
- self.weights_shape = (self.kdim, 1, self.out_channels)
27
-
28
- def is_conv_layer(self): return True
29
-
30
- def forward_gemm(self, out: torch.Tensor, data: torch.Tensor,
31
- weights: torch.Tensor):
32
- r''' Peforms the forward pass of octree-based convolution.
33
- '''
34
-
35
- # Type check
36
- if data.dtype != out.dtype:
37
- data = data.to(out.dtype)
38
- if weights.dtype != out.dtype:
39
- weights = weights.to(out.dtype)
40
-
41
- # Initialize the buffer
42
- buffer = data.new_empty(self.buffer_shape)
43
-
44
- # Loop over each sub-matrix
45
- for i in range(self.buffer_n):
46
- start = i * self.buffer_h
47
- end = (i + 1) * self.buffer_h
48
-
49
- # The boundary case in the last iteration
50
- if end > self.neigh.shape[0]:
51
- dis = end - self.neigh.shape[0]
52
- end = self.neigh.shape[0]
53
- buffer, _ = buffer.split([self.buffer_h-dis, dis])
54
-
55
- # Perform octree2col
56
- neigh_i = self.neigh[start:end]
57
- valid = neigh_i >= 0
58
- buffer.fill_(0)
59
- buffer[valid] = data[neigh_i[valid]]
60
-
61
- # The sub-matrix gemm
62
- # out[start:end] = torch.mm(buffer.flatten(1, 2), weights.flatten(0, 1))
63
- out[start:end] = torch.einsum('ikc,kc->ic', buffer, weights.flatten(0, 1))
64
- return out
65
-
66
- def backward_gemm(self, out: torch.Tensor, grad: torch.Tensor,
67
- weights: torch.Tensor):
68
- r''' Performs the backward pass of octree-based convolution.
69
- '''
70
-
71
- # Type check
72
- if grad.dtype != out.dtype:
73
- grad = grad.to(out.dtype)
74
- if weights.dtype != out.dtype:
75
- weights = weights.to(out.dtype)
76
-
77
- # Loop over each sub-matrix
78
- for i in range(self.buffer_n):
79
- start = i * self.buffer_h
80
- end = (i + 1) * self.buffer_h
81
-
82
- # The boundary case in the last iteration
83
- if end > self.neigh.shape[0]:
84
- end = self.neigh.shape[0]
85
-
86
- # The sub-matrix gemm
87
- # buffer = torch.mm(grad[start:end], weights.flatten(0, 1).t())
88
- # buffer = buffer.view(-1, self.buffer_shape[1], self.buffer_shape[2])
89
- buffer = torch.einsum(
90
- 'ic,kc->ikc', grad[start:end], weights.flatten(0, 1))
91
-
92
- # Performs col2octree
93
- neigh_i = self.neigh[start:end]
94
- valid = neigh_i >= 0
95
- out = scatter_add(buffer[valid], neigh_i[valid], dim=0, out=out)
96
-
97
- return out
98
-
99
- def weight_gemm(self, out: torch.Tensor, data: torch.Tensor, grad: torch.Tensor):
100
- r''' Computes the gradient of the weight matrix.
101
- '''
102
-
103
- # Type check
104
- if data.dtype != out.dtype:
105
- data = data.to(out.dtype)
106
- if grad.dtype != out.dtype:
107
- grad = grad.to(out.dtype)
108
-
109
- # Record the shape of out
110
- out_shape = out.shape
111
- out = out.flatten(0, 1)
112
-
113
- # Initialize the buffer
114
- buffer = data.new_empty(self.buffer_shape)
115
-
116
- # Loop over each sub-matrix
117
- for i in range(self.buffer_n):
118
- start = i * self.buffer_h
119
- end = (i + 1) * self.buffer_h
120
-
121
- # The boundary case in the last iteration
122
- if end > self.neigh.shape[0]:
123
- d = end - self.neigh.shape[0]
124
- end = self.neigh.shape[0]
125
- buffer, _ = buffer.split([self.buffer_h-d, d])
126
-
127
- # Perform octree2col
128
- neigh_i = self.neigh[start:end]
129
- valid = neigh_i >= 0
130
- buffer.fill_(0)
131
- buffer[valid] = data[neigh_i[valid]]
132
-
133
- # Accumulate the gradient via gemm
134
- # out.addmm_(buffer.flatten(1, 2).t(), grad[start:end])
135
- out += torch.einsum('ikc,ic->kc', buffer, grad[start:end])
136
- return out.view(out_shape)
137
-
138
-
139
- class OctreeDWConvFunction(Function):
140
- r''' Wrap the octree convolution for auto-diff.
141
- '''
142
-
143
- @staticmethod
144
- def forward(ctx, data: torch.Tensor, weights: torch.Tensor, octree: Octree,
145
- depth: int, in_channels: int, kernel_size: List[int] = [3, 3, 3],
146
- stride: int = 1, nempty: bool = False, max_buffer: int = int(2e8)):
147
- octree_conv = OctreeDWConvBase(
148
- in_channels, kernel_size, stride, nempty, max_buffer)
149
- octree_conv.setup(octree, depth)
150
- out = octree_conv.check_and_init(data)
151
- out = octree_conv.forward_gemm(out, data, weights)
152
-
153
- ctx.save_for_backward(data, weights)
154
- ctx.octree_conv = octree_conv
155
- return out
156
-
157
- @staticmethod
158
- def backward(ctx, grad):
159
- data, weights = ctx.saved_tensors
160
- octree_conv = ctx.octree_conv
161
-
162
- grad_out = None
163
- if ctx.needs_input_grad[0]:
164
- grad_out = torch.zeros_like(data)
165
- grad_out = octree_conv.backward_gemm(grad_out, grad, weights)
166
-
167
- grad_w = None
168
- if ctx.needs_input_grad[1]:
169
- grad_w = torch.zeros_like(weights)
170
- grad_w = octree_conv.weight_gemm(grad_w, data, grad)
171
-
172
- return (grad_out, grad_w) + (None,) * 7
173
-
174
-
175
- # alias
176
- octree_dwconv = OctreeDWConvFunction.apply
177
-
178
-
179
- class OctreeDWConv(OctreeDWConvBase, torch.nn.Module):
180
- r''' Performs octree-based depth-wise convolution.
181
-
182
- Please refer to :class:`ocnn.nn.OctreeConv` for the meaning of the arguments.
183
-
184
- .. note::
185
- This implementation uses the :func:`torch.einsum` and I find that the speed
186
- is relatively slow. Further optimization is needed to speed it up.
187
- '''
188
-
189
- def __init__(self, in_channels: int, kernel_size: List[int] = [3],
190
- stride: int = 1, nempty: bool = False, use_bias: bool = False,
191
- max_buffer: int = int(2e8)):
192
- super().__init__(in_channels, kernel_size, stride, nempty, max_buffer)
193
-
194
- self.use_bias = use_bias
195
- self.weights = torch.nn.Parameter(torch.Tensor(*self.weights_shape))
196
- if self.use_bias:
197
- self.bias = torch.nn.Parameter(torch.Tensor(in_channels))
198
- self.reset_parameters()
199
-
200
- def reset_parameters(self):
201
- xavier_uniform_(self.weights)
202
- if self.use_bias:
203
- torch.nn.init.zeros_(self.bias)
204
-
205
- def forward(self, data: torch.Tensor, octree: Octree, depth: int):
206
- r''''''
207
-
208
- out = octree_dwconv(
209
- data, self.weights, octree, depth, self.in_channels,
210
- self.kernel_size, self.stride, self.nempty, self.max_buffer)
211
-
212
- if self.use_bias:
213
- out += self.bias
214
-
215
- if self.stride == 2 and not self.nempty:
216
- out = octree_pad(out, octree, depth-1)
217
- return out
218
-
219
- def extra_repr(self) -> str:
220
- return ('in_channels={}, out_channels={}, kernel_size={}, stride={}, '
221
- 'nempty={}, bias={}').format(self.in_channels, self.out_channels,
222
- self.kernel_size, self.stride, self.nempty, self.use_bias) # noqa
1
+ # --------------------------------------------------------
2
+ # Octree-based Sparse Convolutional Neural Networks
3
+ # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Peng-Shuai Wang
6
+ # --------------------------------------------------------
7
+
8
+ import torch
9
+ import torch.nn
10
+ from torch.autograd import Function
11
+ from typing import List
12
+
13
+ from ocnn.octree import Octree
14
+ from ocnn.utils import scatter_add, xavier_uniform_
15
+ from .octree_pad import octree_pad
16
+ from .octree_conv import OctreeConvBase
17
+
18
+
19
+ class OctreeDWConvBase(OctreeConvBase):
20
+
21
+ def __init__(self, in_channels: int, kernel_size: List[int] = [3],
22
+ stride: int = 1, nempty: bool = False,
23
+ max_buffer: int = int(2e8)):
24
+ super().__init__(
25
+ in_channels, in_channels, kernel_size, stride, nempty, max_buffer)
26
+ self.weights_shape = (self.kdim, 1, self.out_channels)
27
+
28
+ def is_conv_layer(self): return True
29
+
30
+ def forward_gemm(self, out: torch.Tensor, data: torch.Tensor,
31
+ weights: torch.Tensor):
32
+ r''' Peforms the forward pass of octree-based convolution.
33
+ '''
34
+
35
+ # Type check
36
+ if data.dtype != out.dtype:
37
+ data = data.to(out.dtype)
38
+ if weights.dtype != out.dtype:
39
+ weights = weights.to(out.dtype)
40
+
41
+ # Initialize the buffer
42
+ buffer = data.new_empty(self.buffer_shape)
43
+
44
+ # Loop over each sub-matrix
45
+ for i in range(self.buffer_n):
46
+ start = i * self.buffer_h
47
+ end = (i + 1) * self.buffer_h
48
+
49
+ # The boundary case in the last iteration
50
+ if end > self.neigh.shape[0]:
51
+ dis = end - self.neigh.shape[0]
52
+ end = self.neigh.shape[0]
53
+ buffer, _ = buffer.split([self.buffer_h-dis, dis])
54
+
55
+ # Perform octree2col
56
+ neigh_i = self.neigh[start:end]
57
+ valid = neigh_i >= 0
58
+ buffer.fill_(0)
59
+ buffer[valid] = data[neigh_i[valid]]
60
+
61
+ # The sub-matrix gemm
62
+ # out[start:end] = torch.mm(buffer.flatten(1, 2), weights.flatten(0, 1))
63
+ out[start:end] = torch.einsum('ikc,kc->ic', buffer, weights.flatten(0, 1))
64
+ return out
65
+
66
+ def backward_gemm(self, out: torch.Tensor, grad: torch.Tensor,
67
+ weights: torch.Tensor):
68
+ r''' Performs the backward pass of octree-based convolution.
69
+ '''
70
+
71
+ # Type check
72
+ if grad.dtype != out.dtype:
73
+ grad = grad.to(out.dtype)
74
+ if weights.dtype != out.dtype:
75
+ weights = weights.to(out.dtype)
76
+
77
+ # Loop over each sub-matrix
78
+ for i in range(self.buffer_n):
79
+ start = i * self.buffer_h
80
+ end = (i + 1) * self.buffer_h
81
+
82
+ # The boundary case in the last iteration
83
+ if end > self.neigh.shape[0]:
84
+ end = self.neigh.shape[0]
85
+
86
+ # The sub-matrix gemm
87
+ # buffer = torch.mm(grad[start:end], weights.flatten(0, 1).t())
88
+ # buffer = buffer.view(-1, self.buffer_shape[1], self.buffer_shape[2])
89
+ buffer = torch.einsum(
90
+ 'ic,kc->ikc', grad[start:end], weights.flatten(0, 1))
91
+
92
+ # Performs col2octree
93
+ neigh_i = self.neigh[start:end]
94
+ valid = neigh_i >= 0
95
+ out = scatter_add(buffer[valid], neigh_i[valid], dim=0, out=out)
96
+
97
+ return out
98
+
99
+ def weight_gemm(self, out: torch.Tensor, data: torch.Tensor, grad: torch.Tensor):
100
+ r''' Computes the gradient of the weight matrix.
101
+ '''
102
+
103
+ # Type check
104
+ if data.dtype != out.dtype:
105
+ data = data.to(out.dtype)
106
+ if grad.dtype != out.dtype:
107
+ grad = grad.to(out.dtype)
108
+
109
+ # Record the shape of out
110
+ out_shape = out.shape
111
+ out = out.flatten(0, 1)
112
+
113
+ # Initialize the buffer
114
+ buffer = data.new_empty(self.buffer_shape)
115
+
116
+ # Loop over each sub-matrix
117
+ for i in range(self.buffer_n):
118
+ start = i * self.buffer_h
119
+ end = (i + 1) * self.buffer_h
120
+
121
+ # The boundary case in the last iteration
122
+ if end > self.neigh.shape[0]:
123
+ d = end - self.neigh.shape[0]
124
+ end = self.neigh.shape[0]
125
+ buffer, _ = buffer.split([self.buffer_h-d, d])
126
+
127
+ # Perform octree2col
128
+ neigh_i = self.neigh[start:end]
129
+ valid = neigh_i >= 0
130
+ buffer.fill_(0)
131
+ buffer[valid] = data[neigh_i[valid]]
132
+
133
+ # Accumulate the gradient via gemm
134
+ # out.addmm_(buffer.flatten(1, 2).t(), grad[start:end])
135
+ out += torch.einsum('ikc,ic->kc', buffer, grad[start:end])
136
+ return out.view(out_shape)
137
+
138
+
139
+ class OctreeDWConvFunction(Function):
140
+ r''' Wrap the octree convolution for auto-diff.
141
+ '''
142
+
143
+ @staticmethod
144
+ def forward(ctx, data: torch.Tensor, weights: torch.Tensor, octree: Octree,
145
+ depth: int, in_channels: int, kernel_size: List[int] = [3, 3, 3],
146
+ stride: int = 1, nempty: bool = False, max_buffer: int = int(2e8)):
147
+ octree_conv = OctreeDWConvBase(
148
+ in_channels, kernel_size, stride, nempty, max_buffer)
149
+ octree_conv.setup(octree, depth)
150
+ out = octree_conv.check_and_init(data)
151
+ out = octree_conv.forward_gemm(out, data, weights)
152
+
153
+ ctx.save_for_backward(data, weights)
154
+ ctx.octree_conv = octree_conv
155
+ return out
156
+
157
+ @staticmethod
158
+ def backward(ctx, grad):
159
+ data, weights = ctx.saved_tensors
160
+ octree_conv = ctx.octree_conv
161
+
162
+ grad_out = None
163
+ if ctx.needs_input_grad[0]:
164
+ grad_out = torch.zeros_like(data)
165
+ grad_out = octree_conv.backward_gemm(grad_out, grad, weights)
166
+
167
+ grad_w = None
168
+ if ctx.needs_input_grad[1]:
169
+ grad_w = torch.zeros_like(weights)
170
+ grad_w = octree_conv.weight_gemm(grad_w, data, grad)
171
+
172
+ return (grad_out, grad_w) + (None,) * 7
173
+
174
+
175
+ # alias
176
+ octree_dwconv = OctreeDWConvFunction.apply
177
+
178
+
179
+ class OctreeDWConv(OctreeDWConvBase, torch.nn.Module):
180
+ r''' Performs octree-based depth-wise convolution.
181
+
182
+ Please refer to :class:`ocnn.nn.OctreeConv` for the meaning of the arguments.
183
+
184
+ .. note::
185
+ This implementation uses the :func:`torch.einsum` and I find that the speed
186
+ is relatively slow. Further optimization is needed to speed it up.
187
+ '''
188
+
189
+ def __init__(self, in_channels: int, kernel_size: List[int] = [3],
190
+ stride: int = 1, nempty: bool = False, use_bias: bool = False,
191
+ max_buffer: int = int(2e8)):
192
+ super().__init__(in_channels, kernel_size, stride, nempty, max_buffer)
193
+
194
+ self.use_bias = use_bias
195
+ self.weights = torch.nn.Parameter(torch.Tensor(*self.weights_shape))
196
+ if self.use_bias:
197
+ self.bias = torch.nn.Parameter(torch.Tensor(in_channels))
198
+ self.reset_parameters()
199
+
200
+ def reset_parameters(self):
201
+ xavier_uniform_(self.weights)
202
+ if self.use_bias:
203
+ torch.nn.init.zeros_(self.bias)
204
+
205
+ def forward(self, data: torch.Tensor, octree: Octree, depth: int):
206
+ r''''''
207
+
208
+ out = octree_dwconv(
209
+ data, self.weights, octree, depth, self.in_channels,
210
+ self.kernel_size, self.stride, self.nempty, self.max_buffer)
211
+
212
+ if self.use_bias:
213
+ out += self.bias
214
+
215
+ if self.stride == 2 and not self.nempty:
216
+ out = octree_pad(out, octree, depth-1)
217
+ return out
218
+
219
+ def extra_repr(self) -> str:
220
+ return ('in_channels={}, out_channels={}, kernel_size={}, stride={}, '
221
+ 'nempty={}, bias={}').format(self.in_channels, self.out_channels,
222
+ self.kernel_size, self.stride, self.nempty, self.use_bias) # noqa