ocnn 2.2.5__py3-none-any.whl → 2.2.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ocnn/__init__.py +24 -24
- ocnn/dataset.py +160 -160
- ocnn/models/__init__.py +29 -29
- ocnn/models/autoencoder.py +155 -155
- ocnn/models/hrnet.py +192 -192
- ocnn/models/image2shape.py +128 -128
- ocnn/models/lenet.py +46 -46
- ocnn/models/ounet.py +94 -94
- ocnn/models/resnet.py +53 -53
- ocnn/models/segnet.py +72 -72
- ocnn/models/unet.py +105 -105
- ocnn/modules/__init__.py +26 -26
- ocnn/modules/modules.py +303 -303
- ocnn/modules/resblocks.py +158 -158
- ocnn/nn/__init__.py +44 -44
- ocnn/nn/octree2col.py +53 -53
- ocnn/nn/octree2vox.py +50 -50
- ocnn/nn/octree_align.py +46 -46
- ocnn/nn/octree_conv.py +429 -429
- ocnn/nn/octree_drop.py +55 -55
- ocnn/nn/octree_dwconv.py +222 -222
- ocnn/nn/octree_gconv.py +79 -79
- ocnn/nn/octree_interp.py +196 -196
- ocnn/nn/octree_norm.py +126 -126
- ocnn/nn/octree_pad.py +39 -39
- ocnn/nn/octree_pool.py +200 -200
- ocnn/octree/__init__.py +22 -22
- ocnn/octree/octree.py +661 -659
- ocnn/octree/points.py +323 -322
- ocnn/octree/shuffled_key.py +115 -115
- ocnn/utils.py +205 -205
- ocnn-2.2.7.dist-info/METADATA +112 -0
- ocnn-2.2.7.dist-info/RECORD +36 -0
- {ocnn-2.2.5.dist-info → ocnn-2.2.7.dist-info}/WHEEL +1 -1
- {ocnn-2.2.5.dist-info → ocnn-2.2.7.dist-info/licenses}/LICENSE +21 -21
- ocnn-2.2.5.dist-info/METADATA +0 -80
- ocnn-2.2.5.dist-info/RECORD +0 -36
- {ocnn-2.2.5.dist-info → ocnn-2.2.7.dist-info}/top_level.txt +0 -0
ocnn/modules/modules.py
CHANGED
|
@@ -1,303 +1,303 @@
|
|
|
1
|
-
# --------------------------------------------------------
|
|
2
|
-
# Octree-based Sparse Convolutional Neural Networks
|
|
3
|
-
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
|
|
4
|
-
# Licensed under The MIT License [see LICENSE for details]
|
|
5
|
-
# Written by Peng-Shuai Wang
|
|
6
|
-
# --------------------------------------------------------
|
|
7
|
-
|
|
8
|
-
import torch
|
|
9
|
-
import torch.utils.checkpoint
|
|
10
|
-
from typing import List
|
|
11
|
-
|
|
12
|
-
from ocnn.nn import OctreeConv, OctreeDeconv, OctreeGroupNorm
|
|
13
|
-
from ocnn.octree import Octree
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
# bn_momentum, bn_eps = 0.01, 0.001 # the default value of Tensorflow 1.x
|
|
17
|
-
# bn_momentum, bn_eps = 0.1, 1e-05 # the default value of pytorch
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
def ckpt_conv_wrapper(conv_op, data, octree):
|
|
21
|
-
# The dummy tensor is a workaround when the checkpoint is used for the first conv layer:
|
|
22
|
-
# https://discuss.pytorch.org/t/checkpoint-with-no-grad-requiring-inputs-problem/19117/11
|
|
23
|
-
dummy = torch.ones(1, dtype=torch.float32, requires_grad=True)
|
|
24
|
-
|
|
25
|
-
def conv_wrapper(data, octree, dummy_tensor):
|
|
26
|
-
return conv_op(data, octree)
|
|
27
|
-
|
|
28
|
-
return torch.utils.checkpoint.checkpoint(conv_wrapper, data, octree, dummy)
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
class OctreeConvBn(torch.nn.Module):
|
|
32
|
-
r''' A sequence of :class:`OctreeConv` and :obj:`BatchNorm`.
|
|
33
|
-
|
|
34
|
-
Please refer to :class:`ocnn.nn.OctreeConv` for details on the parameters.
|
|
35
|
-
'''
|
|
36
|
-
|
|
37
|
-
def __init__(self, in_channels: int, out_channels: int,
|
|
38
|
-
kernel_size: List[int] = [3], stride: int = 1,
|
|
39
|
-
nempty: bool = False):
|
|
40
|
-
super().__init__()
|
|
41
|
-
self.conv = OctreeConv(
|
|
42
|
-
in_channels, out_channels, kernel_size, stride, nempty)
|
|
43
|
-
self.bn = torch.nn.BatchNorm1d(out_channels)
|
|
44
|
-
|
|
45
|
-
def forward(self, data: torch.Tensor, octree: Octree, depth: int):
|
|
46
|
-
r''''''
|
|
47
|
-
|
|
48
|
-
out = self.conv(data, octree, depth)
|
|
49
|
-
out = self.bn(out)
|
|
50
|
-
return out
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
class OctreeConvBnRelu(torch.nn.Module):
|
|
54
|
-
r''' A sequence of :class:`OctreeConv`, :obj:`BatchNorm`, and :obj:`Relu`.
|
|
55
|
-
|
|
56
|
-
Please refer to :class:`ocnn.nn.OctreeConv` for details on the parameters.
|
|
57
|
-
'''
|
|
58
|
-
|
|
59
|
-
def __init__(self, in_channels: int, out_channels: int,
|
|
60
|
-
kernel_size: List[int] = [3], stride: int = 1,
|
|
61
|
-
nempty: bool = False):
|
|
62
|
-
super().__init__()
|
|
63
|
-
self.conv = OctreeConv(
|
|
64
|
-
in_channels, out_channels, kernel_size, stride, nempty)
|
|
65
|
-
self.bn = torch.nn.BatchNorm1d(out_channels)
|
|
66
|
-
self.relu = torch.nn.ReLU(inplace=True)
|
|
67
|
-
|
|
68
|
-
def forward(self, data: torch.Tensor, octree: Octree, depth: int):
|
|
69
|
-
r''''''
|
|
70
|
-
|
|
71
|
-
out = self.conv(data, octree, depth)
|
|
72
|
-
out = self.bn(out)
|
|
73
|
-
out = self.relu(out)
|
|
74
|
-
return out
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
class OctreeDeconvBnRelu(torch.nn.Module):
|
|
78
|
-
r''' A sequence of :class:`OctreeDeconv`, :obj:`BatchNorm`, and :obj:`Relu`.
|
|
79
|
-
|
|
80
|
-
Please refer to :class:`ocnn.nn.OctreeDeconv` for details on the parameters.
|
|
81
|
-
'''
|
|
82
|
-
|
|
83
|
-
def __init__(self, in_channels: int, out_channels: int,
|
|
84
|
-
kernel_size: List[int] = [3], stride: int = 1,
|
|
85
|
-
nempty: bool = False):
|
|
86
|
-
super().__init__()
|
|
87
|
-
self.deconv = OctreeDeconv(
|
|
88
|
-
in_channels, out_channels, kernel_size, stride, nempty)
|
|
89
|
-
self.bn = torch.nn.BatchNorm1d(out_channels)
|
|
90
|
-
self.relu = torch.nn.ReLU(inplace=True)
|
|
91
|
-
|
|
92
|
-
def forward(self, data: torch.Tensor, octree: Octree, depth: int):
|
|
93
|
-
r''''''
|
|
94
|
-
|
|
95
|
-
out = self.deconv(data, octree, depth)
|
|
96
|
-
out = self.bn(out)
|
|
97
|
-
out = self.relu(out)
|
|
98
|
-
return out
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
class Conv1x1(torch.nn.Module):
|
|
102
|
-
r''' Performs a convolution with kernel :obj:`(1,1,1)`.
|
|
103
|
-
|
|
104
|
-
The shape of octree features is :obj:`(N, C)`, where :obj:`N` is the node
|
|
105
|
-
number and :obj:`C` is the feature channel. Therefore, :class:`Conv1x1` can be
|
|
106
|
-
implemented with :class:`torch.nn.Linear`.
|
|
107
|
-
'''
|
|
108
|
-
|
|
109
|
-
def __init__(self, in_channels: int, out_channels: int, use_bias: bool = False):
|
|
110
|
-
super().__init__()
|
|
111
|
-
self.linear = torch.nn.Linear(in_channels, out_channels, use_bias)
|
|
112
|
-
|
|
113
|
-
def forward(self, data: torch.Tensor):
|
|
114
|
-
r''''''
|
|
115
|
-
|
|
116
|
-
return self.linear(data)
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
class Conv1x1Bn(torch.nn.Module):
|
|
120
|
-
r''' A sequence of :class:`Conv1x1` and :class:`BatchNorm`.
|
|
121
|
-
'''
|
|
122
|
-
|
|
123
|
-
def __init__(self, in_channels: int, out_channels: int):
|
|
124
|
-
super().__init__()
|
|
125
|
-
self.conv = Conv1x1(in_channels, out_channels, use_bias=False)
|
|
126
|
-
self.bn = torch.nn.BatchNorm1d(out_channels)
|
|
127
|
-
|
|
128
|
-
def forward(self, data: torch.Tensor):
|
|
129
|
-
r''''''
|
|
130
|
-
|
|
131
|
-
out = self.conv(data)
|
|
132
|
-
out = self.bn(out)
|
|
133
|
-
return out
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
class Conv1x1BnRelu(torch.nn.Module):
|
|
137
|
-
r''' A sequence of :class:`Conv1x1`, :class:`BatchNorm` and :class:`Relu`.
|
|
138
|
-
'''
|
|
139
|
-
|
|
140
|
-
def __init__(self, in_channels: int, out_channels: int):
|
|
141
|
-
super().__init__()
|
|
142
|
-
self.conv = Conv1x1(in_channels, out_channels, use_bias=False)
|
|
143
|
-
self.bn = torch.nn.BatchNorm1d(out_channels)
|
|
144
|
-
self.relu = torch.nn.ReLU(inplace=True)
|
|
145
|
-
|
|
146
|
-
def forward(self, data: torch.Tensor):
|
|
147
|
-
r''''''
|
|
148
|
-
|
|
149
|
-
out = self.conv(data)
|
|
150
|
-
out = self.bn(out)
|
|
151
|
-
out = self.relu(out)
|
|
152
|
-
return out
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
class FcBnRelu(torch.nn.Module):
|
|
156
|
-
r''' A sequence of :class:`FC`, :class:`BatchNorm` and :class:`Relu`.
|
|
157
|
-
'''
|
|
158
|
-
|
|
159
|
-
def __init__(self, in_channels: int, out_channels: int):
|
|
160
|
-
super().__init__()
|
|
161
|
-
self.flatten = torch.nn.Flatten(start_dim=1)
|
|
162
|
-
self.fc = torch.nn.Linear(in_channels, out_channels, bias=False)
|
|
163
|
-
self.bn = torch.nn.BatchNorm1d(out_channels)
|
|
164
|
-
self.relu = torch.nn.ReLU(inplace=True)
|
|
165
|
-
|
|
166
|
-
def forward(self, data):
|
|
167
|
-
r''''''
|
|
168
|
-
|
|
169
|
-
out = self.flatten(data)
|
|
170
|
-
out = self.fc(out)
|
|
171
|
-
out = self.bn(out)
|
|
172
|
-
out = self.relu(out)
|
|
173
|
-
return out
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
class OctreeConvGn(torch.nn.Module):
|
|
177
|
-
r''' A sequence of :class:`OctreeConv` and :obj:`OctreeGroupNorm`.
|
|
178
|
-
|
|
179
|
-
Please refer to :class:`ocnn.nn.OctreeConv` for details on the parameters.
|
|
180
|
-
'''
|
|
181
|
-
|
|
182
|
-
def __init__(self, in_channels: int, out_channels: int, group: int,
|
|
183
|
-
kernel_size: List[int] = [3], stride: int = 1,
|
|
184
|
-
nempty: bool = False):
|
|
185
|
-
super().__init__()
|
|
186
|
-
self.conv = OctreeConv(
|
|
187
|
-
in_channels, out_channels, kernel_size, stride, nempty)
|
|
188
|
-
self.gn = OctreeGroupNorm(out_channels, group=group, nempty=nempty)
|
|
189
|
-
|
|
190
|
-
def forward(self, data: torch.Tensor, octree: Octree, depth: int):
|
|
191
|
-
r''''''
|
|
192
|
-
|
|
193
|
-
out = self.conv(data, octree, depth)
|
|
194
|
-
out = self.gn(out, octree, depth)
|
|
195
|
-
return out
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
class OctreeConvGnRelu(torch.nn.Module):
|
|
199
|
-
r''' A sequence of :class:`OctreeConv`, :obj:`OctreeGroupNorm`, and :obj:`Relu`.
|
|
200
|
-
|
|
201
|
-
Please refer to :class:`ocnn.nn.OctreeConv` for details on the parameters.
|
|
202
|
-
'''
|
|
203
|
-
|
|
204
|
-
def __init__(self, in_channels: int, out_channels: int, group: int,
|
|
205
|
-
kernel_size: List[int] = [3], stride: int = 1,
|
|
206
|
-
nempty: bool = False):
|
|
207
|
-
super().__init__()
|
|
208
|
-
self.stride = stride
|
|
209
|
-
self.conv = OctreeConv(
|
|
210
|
-
in_channels, out_channels, kernel_size, stride, nempty)
|
|
211
|
-
self.gn = OctreeGroupNorm(out_channels, group=group, nempty=nempty)
|
|
212
|
-
self.relu = torch.nn.ReLU(inplace=True)
|
|
213
|
-
|
|
214
|
-
def forward(self, data: torch.Tensor, octree: Octree, depth: int):
|
|
215
|
-
r''''''
|
|
216
|
-
|
|
217
|
-
out = self.conv(data, octree, depth)
|
|
218
|
-
out = self.gn(out, octree, depth if self.stride == 1 else depth - 1)
|
|
219
|
-
out = self.relu(out)
|
|
220
|
-
return out
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
class OctreeDeconvGnRelu(torch.nn.Module):
|
|
224
|
-
r''' A sequence of :class:`OctreeDeconv`, :obj:`OctreeGroupNorm`, and :obj:`Relu`.
|
|
225
|
-
|
|
226
|
-
Please refer to :class:`ocnn.nn.OctreeConv` for details on the parameters.
|
|
227
|
-
'''
|
|
228
|
-
|
|
229
|
-
def __init__(self, in_channels: int, out_channels: int, group: int,
|
|
230
|
-
kernel_size: List[int] = [3], stride: int = 1,
|
|
231
|
-
nempty: bool = False):
|
|
232
|
-
super().__init__()
|
|
233
|
-
self.stride = stride
|
|
234
|
-
self.deconv = OctreeDeconv(
|
|
235
|
-
in_channels, out_channels, kernel_size, stride, nempty)
|
|
236
|
-
self.gn = OctreeGroupNorm(out_channels, group=group, nempty=nempty)
|
|
237
|
-
self.relu = torch.nn.ReLU(inplace=True)
|
|
238
|
-
|
|
239
|
-
def forward(self, data: torch.Tensor, octree: Octree, depth: int):
|
|
240
|
-
r''''''
|
|
241
|
-
|
|
242
|
-
out = self.deconv(data, octree, depth)
|
|
243
|
-
out = self.gn(out, octree, depth if self.stride == 1 else depth + 1)
|
|
244
|
-
out = self.relu(out)
|
|
245
|
-
return out
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
class Conv1x1Gn(torch.nn.Module):
|
|
249
|
-
r''' A sequence of :class:`Conv1x1`, :class:`OctreeGroupNorm`.
|
|
250
|
-
'''
|
|
251
|
-
|
|
252
|
-
def __init__(self, in_channels: int, out_channels: int, group: int,
|
|
253
|
-
nempty: bool = False):
|
|
254
|
-
super().__init__()
|
|
255
|
-
self.conv = Conv1x1(in_channels, out_channels, use_bias=False)
|
|
256
|
-
self.gn = OctreeGroupNorm(out_channels, group=group, nempty=nempty)
|
|
257
|
-
|
|
258
|
-
def forward(self, data: torch.Tensor, octree: Octree, depth: int):
|
|
259
|
-
r''''''
|
|
260
|
-
|
|
261
|
-
out = self.conv(data)
|
|
262
|
-
out = self.gn(out, octree, depth)
|
|
263
|
-
return out
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
class Conv1x1GnRelu(torch.nn.Module):
|
|
267
|
-
r''' A sequence of :class:`Conv1x1`, :class:`OctreeGroupNorm` and :class:`Relu`.
|
|
268
|
-
'''
|
|
269
|
-
|
|
270
|
-
def __init__(self, in_channels: int, out_channels: int, group: int,
|
|
271
|
-
nempty: bool = False):
|
|
272
|
-
super().__init__()
|
|
273
|
-
self.conv = Conv1x1(in_channels, out_channels, use_bias=False)
|
|
274
|
-
self.gn = OctreeGroupNorm(out_channels, group=group, nempty=nempty)
|
|
275
|
-
self.relu = torch.nn.ReLU(inplace=True)
|
|
276
|
-
|
|
277
|
-
def forward(self, data: torch.Tensor, octree: Octree, depth: int):
|
|
278
|
-
r''''''
|
|
279
|
-
|
|
280
|
-
out = self.conv(data)
|
|
281
|
-
out = self.gn(out, octree, depth)
|
|
282
|
-
out = self.relu(out)
|
|
283
|
-
return out
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
class InputFeature(torch.nn.Module):
|
|
287
|
-
r''' Returns the initial input feature stored in octree.
|
|
288
|
-
|
|
289
|
-
Refer to :func:`ocnn.octree.Octree.get_input_feature` for details.
|
|
290
|
-
'''
|
|
291
|
-
|
|
292
|
-
def __init__(self, feature: str = 'NDF', nempty: bool = False):
|
|
293
|
-
super().__init__()
|
|
294
|
-
self.nempty = nempty
|
|
295
|
-
self.feature = feature.upper()
|
|
296
|
-
|
|
297
|
-
def forward(self, octree: Octree):
|
|
298
|
-
r''''''
|
|
299
|
-
return octree.get_input_feature(self.feature, self.nempty)
|
|
300
|
-
|
|
301
|
-
def extra_repr(self) -> str:
|
|
302
|
-
r''''''
|
|
303
|
-
return 'feature={}, nempty={}'.format(self.feature, self.nempty)
|
|
1
|
+
# --------------------------------------------------------
|
|
2
|
+
# Octree-based Sparse Convolutional Neural Networks
|
|
3
|
+
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
|
|
4
|
+
# Licensed under The MIT License [see LICENSE for details]
|
|
5
|
+
# Written by Peng-Shuai Wang
|
|
6
|
+
# --------------------------------------------------------
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
import torch.utils.checkpoint
|
|
10
|
+
from typing import List
|
|
11
|
+
|
|
12
|
+
from ocnn.nn import OctreeConv, OctreeDeconv, OctreeGroupNorm
|
|
13
|
+
from ocnn.octree import Octree
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
# bn_momentum, bn_eps = 0.01, 0.001 # the default value of Tensorflow 1.x
|
|
17
|
+
# bn_momentum, bn_eps = 0.1, 1e-05 # the default value of pytorch
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def ckpt_conv_wrapper(conv_op, data, octree):
  r''' Evaluates :attr:`conv_op(data, octree)` under gradient checkpointing,
  trading extra compute in the backward pass for lower peak memory.
  '''

  # Workaround: `torch.utils.checkpoint` needs at least one input that
  # requires gradients, which may not hold when the checkpointed op is the
  # very first conv layer. Feeding an extra grad-enabled anchor tensor fixes
  # this. See:
  # https://discuss.pytorch.org/t/checkpoint-with-no-grad-requiring-inputs-problem/19117/11
  grad_anchor = torch.ones(1, dtype=torch.float32, requires_grad=True)

  def _run(feat, tree, anchor):
    # `anchor` is intentionally unused; it only carries requires_grad=True.
    return conv_op(feat, tree)

  return torch.utils.checkpoint.checkpoint(_run, data, octree, grad_anchor)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class OctreeConvBn(torch.nn.Module):
  r''' Combines :class:`OctreeConv` with a :obj:`BatchNorm` layer.

  Please refer to :class:`ocnn.nn.OctreeConv` for details on the parameters.
  '''

  def __init__(self, in_channels: int, out_channels: int,
               kernel_size: List[int] = [3], stride: int = 1,
               nempty: bool = False):
    super().__init__()
    self.conv = OctreeConv(
        in_channels, out_channels, kernel_size, stride, nempty)
    self.bn = torch.nn.BatchNorm1d(out_channels)

  def forward(self, data: torch.Tensor, octree: Octree, depth: int):
    r''''''

    # Convolve on the octree at `depth`, then normalize per channel.
    return self.bn(self.conv(data, octree, depth))
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class OctreeConvBnRelu(torch.nn.Module):
  r''' Combines :class:`OctreeConv`, :obj:`BatchNorm`, and :obj:`Relu` into a
  single module.

  Please refer to :class:`ocnn.nn.OctreeConv` for details on the parameters.
  '''

  def __init__(self, in_channels: int, out_channels: int,
               kernel_size: List[int] = [3], stride: int = 1,
               nempty: bool = False):
    super().__init__()
    self.conv = OctreeConv(
        in_channels, out_channels, kernel_size, stride, nempty)
    self.bn = torch.nn.BatchNorm1d(out_channels)
    self.relu = torch.nn.ReLU(inplace=True)

  def forward(self, data: torch.Tensor, octree: Octree, depth: int):
    r''''''

    # conv -> batch norm -> relu, in that order.
    out = self.conv(data, octree, depth)
    return self.relu(self.bn(out))
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
class OctreeDeconvBnRelu(torch.nn.Module):
  r''' Combines :class:`OctreeDeconv`, :obj:`BatchNorm`, and :obj:`Relu` into
  a single module.

  Please refer to :class:`ocnn.nn.OctreeDeconv` for details on the parameters.
  '''

  def __init__(self, in_channels: int, out_channels: int,
               kernel_size: List[int] = [3], stride: int = 1,
               nempty: bool = False):
    super().__init__()
    self.deconv = OctreeDeconv(
        in_channels, out_channels, kernel_size, stride, nempty)
    self.bn = torch.nn.BatchNorm1d(out_channels)
    self.relu = torch.nn.ReLU(inplace=True)

  def forward(self, data: torch.Tensor, octree: Octree, depth: int):
    r''''''

    # deconv -> batch norm -> relu, in that order.
    out = self.deconv(data, octree, depth)
    return self.relu(self.bn(out))
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
class Conv1x1(torch.nn.Module):
  r''' A convolution with kernel :obj:`(1,1,1)` over octree features.

  Octree features are stored as a dense tensor of shape :obj:`(N, C)`, with
  :obj:`N` the node number and :obj:`C` the channel number, so a
  kernel-:obj:`(1,1,1)` convolution is exactly a per-node linear map and is
  implemented here with :class:`torch.nn.Linear`.
  '''

  def __init__(self, in_channels: int, out_channels: int, use_bias: bool = False):
    super().__init__()
    self.linear = torch.nn.Linear(in_channels, out_channels, bias=use_bias)

  def forward(self, data: torch.Tensor):
    r''''''

    out = self.linear(data)
    return out
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
class Conv1x1Bn(torch.nn.Module):
  r''' Combines :class:`Conv1x1` with a :class:`BatchNorm` layer.
  '''

  def __init__(self, in_channels: int, out_channels: int):
    super().__init__()
    self.conv = Conv1x1(in_channels, out_channels, use_bias=False)
    self.bn = torch.nn.BatchNorm1d(out_channels)

  def forward(self, data: torch.Tensor):
    r''''''

    # 1x1 conv followed by per-channel batch normalization.
    return self.bn(self.conv(data))
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
class Conv1x1BnRelu(torch.nn.Module):
  r''' Combines :class:`Conv1x1`, :class:`BatchNorm` and :class:`Relu` into a
  single module.
  '''

  def __init__(self, in_channels: int, out_channels: int):
    super().__init__()
    self.conv = Conv1x1(in_channels, out_channels, use_bias=False)
    self.bn = torch.nn.BatchNorm1d(out_channels)
    self.relu = torch.nn.ReLU(inplace=True)

  def forward(self, data: torch.Tensor):
    r''''''

    # conv -> batch norm -> relu, in that order.
    return self.relu(self.bn(self.conv(data)))
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
class FcBnRelu(torch.nn.Module):
  r''' Combines a fully-connected layer with :class:`BatchNorm` and
  :class:`Relu`.
  '''

  def __init__(self, in_channels: int, out_channels: int):
    super().__init__()
    self.flatten = torch.nn.Flatten(start_dim=1)
    self.fc = torch.nn.Linear(in_channels, out_channels, bias=False)
    self.bn = torch.nn.BatchNorm1d(out_channels)
    self.relu = torch.nn.ReLU(inplace=True)

  def forward(self, data):
    r''''''

    # Collapse every dimension after the batch dimension, then apply
    # fc -> batch norm -> relu.
    return self.relu(self.bn(self.fc(self.flatten(data))))
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
class OctreeConvGn(torch.nn.Module):
  r''' A sequence of :class:`OctreeConv` and :obj:`OctreeGroupNorm`.

  Please refer to :class:`ocnn.nn.OctreeConv` for details on the parameters.
  '''

  def __init__(self, in_channels: int, out_channels: int, group: int,
               kernel_size: List[int] = [3], stride: int = 1,
               nempty: bool = False):
    super().__init__()
    self.stride = stride
    self.conv = OctreeConv(
        in_channels, out_channels, kernel_size, stride, nempty)
    self.gn = OctreeGroupNorm(out_channels, group=group, nempty=nempty)

  def forward(self, data: torch.Tensor, octree: Octree, depth: int):
    r''''''

    out = self.conv(data, octree, depth)
    # Fix: a strided convolution moves the features one level up the octree,
    # so the group norm must run at `depth - 1` in that case. This mirrors
    # OctreeConvGnRelu/OctreeDeconvGnRelu, which already adjust the depth;
    # previously this class passed the stale `depth` when stride != 1.
    out = self.gn(out, octree, depth if self.stride == 1 else depth - 1)
    return out
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
class OctreeConvGnRelu(torch.nn.Module):
  r''' Combines :class:`OctreeConv`, :obj:`OctreeGroupNorm`, and :obj:`Relu`
  into a single module.

  Please refer to :class:`ocnn.nn.OctreeConv` for details on the parameters.
  '''

  def __init__(self, in_channels: int, out_channels: int, group: int,
               kernel_size: List[int] = [3], stride: int = 1,
               nempty: bool = False):
    super().__init__()
    self.stride = stride
    self.conv = OctreeConv(
        in_channels, out_channels, kernel_size, stride, nempty)
    self.gn = OctreeGroupNorm(out_channels, group=group, nempty=nempty)
    self.relu = torch.nn.ReLU(inplace=True)

  def forward(self, data: torch.Tensor, octree: Octree, depth: int):
    r''''''

    out = self.conv(data, octree, depth)
    # A strided convolution moves the features one level up the octree, so
    # normalize at `depth - 1` in that case.
    gn_depth = depth if self.stride == 1 else depth - 1
    out = self.gn(out, octree, gn_depth)
    return self.relu(out)
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
class OctreeDeconvGnRelu(torch.nn.Module):
  r''' Combines :class:`OctreeDeconv`, :obj:`OctreeGroupNorm`, and :obj:`Relu`
  into a single module.

  Please refer to :class:`ocnn.nn.OctreeConv` for details on the parameters.
  '''

  def __init__(self, in_channels: int, out_channels: int, group: int,
               kernel_size: List[int] = [3], stride: int = 1,
               nempty: bool = False):
    super().__init__()
    self.stride = stride
    self.deconv = OctreeDeconv(
        in_channels, out_channels, kernel_size, stride, nempty)
    self.gn = OctreeGroupNorm(out_channels, group=group, nempty=nempty)
    self.relu = torch.nn.ReLU(inplace=True)

  def forward(self, data: torch.Tensor, octree: Octree, depth: int):
    r''''''

    out = self.deconv(data, octree, depth)
    # A strided deconvolution moves the features one level down the octree
    # (upsampling), so normalize at `depth + 1` in that case.
    gn_depth = depth if self.stride == 1 else depth + 1
    out = self.gn(out, octree, gn_depth)
    return self.relu(out)
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
class Conv1x1Gn(torch.nn.Module):
  r''' Combines :class:`Conv1x1` with :class:`OctreeGroupNorm`.
  '''

  def __init__(self, in_channels: int, out_channels: int, group: int,
               nempty: bool = False):
    super().__init__()
    self.conv = Conv1x1(in_channels, out_channels, use_bias=False)
    self.gn = OctreeGroupNorm(out_channels, group=group, nempty=nempty)

  def forward(self, data: torch.Tensor, octree: Octree, depth: int):
    r''''''

    # A 1x1 conv keeps the octree depth unchanged, so `depth` is forwarded
    # to the group norm as-is.
    return self.gn(self.conv(data), octree, depth)
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
class Conv1x1GnRelu(torch.nn.Module):
  r''' Combines :class:`Conv1x1`, :class:`OctreeGroupNorm` and :class:`Relu`
  into a single module.
  '''

  def __init__(self, in_channels: int, out_channels: int, group: int,
               nempty: bool = False):
    super().__init__()
    self.conv = Conv1x1(in_channels, out_channels, use_bias=False)
    self.gn = OctreeGroupNorm(out_channels, group=group, nempty=nempty)
    self.relu = torch.nn.ReLU(inplace=True)

  def forward(self, data: torch.Tensor, octree: Octree, depth: int):
    r''''''

    # A 1x1 conv keeps the octree depth unchanged, so `depth` is forwarded
    # to the group norm as-is.
    out = self.gn(self.conv(data), octree, depth)
    return self.relu(out)
|
|
284
|
+
|
|
285
|
+
|
|
286
|
+
class InputFeature(torch.nn.Module):
  r''' Extracts the initial input feature stored in an octree.

  Refer to :func:`ocnn.octree.Octree.get_input_feature` for details.
  '''

  def __init__(self, feature: str = 'NDF', nempty: bool = False):
    super().__init__()
    self.nempty = nempty
    # Feature codes are matched case-insensitively; normalize once here.
    self.feature = feature.upper()

  def forward(self, octree: Octree):
    r''''''
    return octree.get_input_feature(self.feature, self.nempty)

  def extra_repr(self) -> str:
    r''''''
    return 'feature={}, nempty={}'.format(self.feature, self.nempty)
|