ocnn 2.2.1__py3-none-any.whl → 2.2.2__py3-none-any.whl

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ocnn/modules/__init__.py CHANGED
@@ -1,20 +1,20 @@
1
- # --------------------------------------------------------
2
- # Octree-based Sparse Convolutional Neural Networks
3
- # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
- # Licensed under The MIT License [see LICENSE for details]
5
- # Written by Peng-Shuai Wang
6
- # --------------------------------------------------------
7
-
8
- from .modules import (InputFeature,
9
- OctreeConvBn, OctreeConvBnRelu, OctreeDeconvBnRelu,
10
- Conv1x1, Conv1x1Bn, Conv1x1BnRelu, FcBnRelu,)
11
- from .resblocks import OctreeResBlock, OctreeResBlock2, OctreeResBlocks
12
-
13
- __all__ = [
14
- 'InputFeature',
15
- 'OctreeConvBn', 'OctreeConvBnRelu', 'OctreeDeconvBnRelu',
16
- 'Conv1x1', 'Conv1x1Bn', 'Conv1x1BnRelu', 'FcBnRelu',
17
- 'OctreeResBlock', 'OctreeResBlock2', 'OctreeResBlocks',
18
- ]
19
-
20
- classes = __all__
1
+ # --------------------------------------------------------
2
+ # Octree-based Sparse Convolutional Neural Networks
3
+ # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Peng-Shuai Wang
6
+ # --------------------------------------------------------
7
+
8
+ from .modules import (InputFeature,
9
+ OctreeConvBn, OctreeConvBnRelu, OctreeDeconvBnRelu,
10
+ Conv1x1, Conv1x1Bn, Conv1x1BnRelu, FcBnRelu,)
11
+ from .resblocks import OctreeResBlock, OctreeResBlock2, OctreeResBlocks
12
+
13
+ __all__ = [
14
+ 'InputFeature',
15
+ 'OctreeConvBn', 'OctreeConvBnRelu', 'OctreeDeconvBnRelu',
16
+ 'Conv1x1', 'Conv1x1Bn', 'Conv1x1BnRelu', 'FcBnRelu',
17
+ 'OctreeResBlock', 'OctreeResBlock2', 'OctreeResBlocks',
18
+ ]
19
+
20
+ classes = __all__
ocnn/modules/modules.py CHANGED
@@ -1,231 +1,193 @@
1
- # --------------------------------------------------------
2
- # Octree-based Sparse Convolutional Neural Networks
3
- # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
- # Licensed under The MIT License [see LICENSE for details]
5
- # Written by Peng-Shuai Wang
6
- # --------------------------------------------------------
7
-
8
- import torch
9
- import torch.utils.checkpoint
10
- from typing import List
11
-
12
- import ocnn
13
- from ocnn.nn import OctreeConv, OctreeDeconv
14
- from ocnn.octree import Octree
15
-
16
-
17
- bn_momentum, bn_eps = 0.01, 0.001 # the default value of Tensorflow 1.x
18
- # bn_momentum, bn_eps = 0.1, 1e-05 # the default value of pytorch
19
-
20
-
21
- def ckpt_conv_wrapper(conv_op, data, octree):
22
- # The dummy tensor is a workaround when the checkpoint is used for the first conv layer:
23
- # https://discuss.pytorch.org/t/checkpoint-with-no-grad-requiring-inputs-problem/19117/11
24
- dummy = torch.ones(1, dtype=torch.float32, requires_grad=True)
25
-
26
- def conv_wrapper(data, octree, dummy_tensor):
27
- return conv_op(data, octree)
28
-
29
- return torch.utils.checkpoint.checkpoint(conv_wrapper, data, octree, dummy)
30
-
31
-
32
- class OctreeConvBn(torch.nn.Module):
33
- r''' A sequence of :class:`OctreeConv` and :obj:`BatchNorm`.
34
-
35
- Please refer to :class:`ocnn.nn.OctreeConv` for details on the parameters.
36
- '''
37
-
38
- def __init__(self, in_channels: int, out_channels: int,
39
- kernel_size: List[int] = [3], stride: int = 1,
40
- nempty: bool = False):
41
- super().__init__()
42
- self.conv = OctreeConv(
43
- in_channels, out_channels, kernel_size, stride, nempty)
44
- self.bn = torch.nn.BatchNorm1d(out_channels, bn_eps, bn_momentum)
45
-
46
- def forward(self, data: torch.Tensor, octree: Octree, depth: int):
47
- r''''''
48
-
49
- out = self.conv(data, octree, depth)
50
- out = self.bn(out)
51
- return out
52
-
53
-
54
- class OctreeConvBnRelu(torch.nn.Module):
55
- r''' A sequence of :class:`OctreeConv`, :obj:`BatchNorm`, and :obj:`Relu`.
56
-
57
- Please refer to :class:`ocnn.nn.OctreeConv` for details on the parameters.
58
- '''
59
-
60
- def __init__(self, in_channels: int, out_channels: int,
61
- kernel_size: List[int] = [3], stride: int = 1,
62
- nempty: bool = False):
63
- super().__init__()
64
- self.conv = OctreeConv(
65
- in_channels, out_channels, kernel_size, stride, nempty)
66
- self.bn = torch.nn.BatchNorm1d(out_channels, bn_eps, bn_momentum)
67
- self.relu = torch.nn.ReLU(inplace=True)
68
-
69
- def forward(self, data: torch.Tensor, octree: Octree, depth: int):
70
- r''''''
71
-
72
- out = self.conv(data, octree, depth)
73
- out = self.bn(out)
74
- out = self.relu(out)
75
- return out
76
-
77
-
78
- class OctreeDeconvBnRelu(torch.nn.Module):
79
- r''' A sequence of :class:`OctreeDeconv`, :obj:`BatchNorm`, and :obj:`Relu`.
80
-
81
- Please refer to :class:`ocnn.nn.OctreeDeconv` for details on the parameters.
82
- '''
83
-
84
- def __init__(self, in_channels: int, out_channels: int,
85
- kernel_size: List[int] = [3], stride: int = 1,
86
- nempty: bool = False):
87
- super().__init__()
88
- self.deconv = OctreeDeconv(
89
- in_channels, out_channels, kernel_size, stride, nempty)
90
- self.bn = torch.nn.BatchNorm1d(out_channels, bn_eps, bn_momentum)
91
- self.relu = torch.nn.ReLU(inplace=True)
92
-
93
- def forward(self, data: torch.Tensor, octree: Octree, depth: int):
94
- r''''''
95
-
96
- out = self.deconv(data, octree, depth)
97
- out = self.bn(out)
98
- out = self.relu(out)
99
- return out
100
-
101
-
102
- class Conv1x1(torch.nn.Module):
103
- r''' Performs a convolution with kernel :obj:`(1,1,1)`.
104
-
105
- The shape of octree features is :obj:`(N, C)`, where :obj:`N` is the node
106
- number and :obj:`C` is the feature channel. Therefore, :class:`Conv1x1` can be
107
- implemented with :class:`torch.nn.Linear`.
108
- '''
109
-
110
- def __init__(self, in_channels: int, out_channels: int, use_bias: bool = False):
111
- super().__init__()
112
- self.linear = torch.nn.Linear(in_channels, out_channels, use_bias)
113
-
114
- def forward(self, data: torch.Tensor):
115
- r''''''
116
-
117
- return self.linear(data)
118
-
119
-
120
- class Conv1x1Bn(torch.nn.Module):
121
- r''' A sequence of :class:`Conv1x1` and :class:`BatchNorm`.
122
- '''
123
-
124
- def __init__(self, in_channels: int, out_channels: int):
125
- super().__init__()
126
- self.conv = Conv1x1(in_channels, out_channels, use_bias=False)
127
- self.bn = torch.nn.BatchNorm1d(out_channels, bn_eps, bn_momentum)
128
-
129
- def forward(self, data: torch.Tensor):
130
- r''''''
131
-
132
- out = self.conv(data)
133
- out = self.bn(out)
134
- return out
135
-
136
-
137
- class Conv1x1BnRelu(torch.nn.Module):
138
- r''' A sequence of :class:`Conv1x1`, :class:`BatchNorm` and :class:`Relu`.
139
- '''
140
-
141
- def __init__(self, in_channels: int, out_channels: int):
142
- super().__init__()
143
- self.conv = Conv1x1(in_channels, out_channels, use_bias=False)
144
- self.bn = torch.nn.BatchNorm1d(out_channels, bn_eps, bn_momentum)
145
- self.relu = torch.nn.ReLU(inplace=True)
146
-
147
- def forward(self, data: torch.Tensor):
148
- r''''''
149
-
150
- out = self.conv(data)
151
- out = self.bn(out)
152
- out = self.relu(out)
153
- return out
154
-
155
-
156
- class FcBnRelu(torch.nn.Module):
157
- r''' A sequence of :class:`FC`, :class:`BatchNorm` and :class:`Relu`.
158
- '''
159
-
160
- def __init__(self, in_channels: int, out_channels: int):
161
- super().__init__()
162
- self.flatten = torch.nn.Flatten(start_dim=1)
163
- self.fc = torch.nn.Linear(in_channels, out_channels, bias=False)
164
- self.bn = torch.nn.BatchNorm1d(out_channels, bn_eps, bn_momentum)
165
- self.relu = torch.nn.ReLU(inplace=True)
166
-
167
- def forward(self, data):
168
- r''''''
169
-
170
- out = self.flatten(data)
171
- out = self.fc(out)
172
- out = self.bn(out)
173
- out = self.relu(out)
174
- return out
175
-
176
-
177
- class InputFeature(torch.nn.Module):
178
- r''' Returns the initial input feature stored in octree.
179
-
180
- Args:
181
- feature (str): A string used to indicate which features to extract from the
182
- input octree. If the character :obj:`N` is in :attr:`feature`, the
183
- normal signal is extracted (3 channels). Similarly, if :obj:`D` is in
184
- :attr:`feature`, the local displacement is extracted (1 channels). If
185
- :obj:`L` is in :attr:`feature`, the local coordinates of the averaged
186
- points in each octree node is extracted (3 channels). If :attr:`P` is in
187
- :attr:`feature`, the global coordinates are extracted (3 channels). If
188
- :attr:`F` is in :attr:`feature`, other features (like colors) are
189
- extracted (k channels).
190
- nempty (bool): If false, gets the features of all octree nodes.
191
- '''
192
-
193
- def __init__(self, feature: str = 'NDF', nempty: bool = False):
194
- super().__init__()
195
- self.nempty = nempty
196
- self.feature = feature.upper()
197
-
198
- def forward(self, octree: Octree):
199
- r''''''
200
-
201
- features = list()
202
- depth = octree.depth
203
- if 'N' in self.feature:
204
- features.append(octree.normals[depth])
205
-
206
- if 'L' in self.feature or 'D' in self.feature:
207
- local_points = octree.points[depth].frac() - 0.5
208
-
209
- if 'D' in self.feature:
210
- dis = torch.sum(local_points * octree.normals[depth], dim=1, keepdim=True)
211
- features.append(dis)
212
-
213
- if 'L' in self.feature:
214
- features.append(local_points)
215
-
216
- if 'P' in self.feature:
217
- scale = 2 ** (1 - depth) # normalize [0, 2^depth] -> [-1, 1]
218
- global_points = octree.points[depth] * scale - 1.0
219
- features.append(global_points)
220
-
221
- if 'F' in self.feature:
222
- features.append(octree.features[depth])
223
-
224
- out = torch.cat(features, dim=1)
225
- if not self.nempty:
226
- out = ocnn.nn.octree_pad(out, octree, depth)
227
- return out
228
-
229
- def extra_repr(self) -> str:
230
- r''''''
231
- return 'feature={}, nempty={}'.format(self.feature, self.nempty)
1
+ # --------------------------------------------------------
2
+ # Octree-based Sparse Convolutional Neural Networks
3
+ # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Peng-Shuai Wang
6
+ # --------------------------------------------------------
7
+
8
+ import torch
9
+ import torch.utils.checkpoint
10
+ from typing import List
11
+
12
+ from ocnn.nn import OctreeConv, OctreeDeconv
13
+ from ocnn.octree import Octree
14
+
15
+
16
+ # bn_momentum, bn_eps = 0.01, 0.001 # the default value of Tensorflow 1.x
17
+ # bn_momentum, bn_eps = 0.1, 1e-05 # the default value of pytorch
18
+
19
+
20
+ def ckpt_conv_wrapper(conv_op, data, octree):
21
+ # The dummy tensor is a workaround when the checkpoint is used for the first conv layer:
22
+ # https://discuss.pytorch.org/t/checkpoint-with-no-grad-requiring-inputs-problem/19117/11
23
+ dummy = torch.ones(1, dtype=torch.float32, requires_grad=True)
24
+
25
+ def conv_wrapper(data, octree, dummy_tensor):
26
+ return conv_op(data, octree)
27
+
28
+ return torch.utils.checkpoint.checkpoint(conv_wrapper, data, octree, dummy)
29
+
30
+
31
+ class OctreeConvBn(torch.nn.Module):
32
+ r''' A sequence of :class:`OctreeConv` and :obj:`BatchNorm`.
33
+
34
+ Please refer to :class:`ocnn.nn.OctreeConv` for details on the parameters.
35
+ '''
36
+
37
+ def __init__(self, in_channels: int, out_channels: int,
38
+ kernel_size: List[int] = [3], stride: int = 1,
39
+ nempty: bool = False):
40
+ super().__init__()
41
+ self.conv = OctreeConv(
42
+ in_channels, out_channels, kernel_size, stride, nempty)
43
+ self.bn = torch.nn.BatchNorm1d(out_channels) #, bn_eps, bn_momentum)
44
+
45
+ def forward(self, data: torch.Tensor, octree: Octree, depth: int):
46
+ r''''''
47
+
48
+ out = self.conv(data, octree, depth)
49
+ out = self.bn(out)
50
+ return out
51
+
52
+
53
+ class OctreeConvBnRelu(torch.nn.Module):
54
+ r''' A sequence of :class:`OctreeConv`, :obj:`BatchNorm`, and :obj:`Relu`.
55
+
56
+ Please refer to :class:`ocnn.nn.OctreeConv` for details on the parameters.
57
+ '''
58
+
59
+ def __init__(self, in_channels: int, out_channels: int,
60
+ kernel_size: List[int] = [3], stride: int = 1,
61
+ nempty: bool = False):
62
+ super().__init__()
63
+ self.conv = OctreeConv(
64
+ in_channels, out_channels, kernel_size, stride, nempty)
65
+ self.bn = torch.nn.BatchNorm1d(out_channels) #, bn_eps, bn_momentum)
66
+ self.relu = torch.nn.ReLU(inplace=True)
67
+
68
+ def forward(self, data: torch.Tensor, octree: Octree, depth: int):
69
+ r''''''
70
+
71
+ out = self.conv(data, octree, depth)
72
+ out = self.bn(out)
73
+ out = self.relu(out)
74
+ return out
75
+
76
+
77
+ class OctreeDeconvBnRelu(torch.nn.Module):
78
+ r''' A sequence of :class:`OctreeDeconv`, :obj:`BatchNorm`, and :obj:`Relu`.
79
+
80
+ Please refer to :class:`ocnn.nn.OctreeDeconv` for details on the parameters.
81
+ '''
82
+
83
+ def __init__(self, in_channels: int, out_channels: int,
84
+ kernel_size: List[int] = [3], stride: int = 1,
85
+ nempty: bool = False):
86
+ super().__init__()
87
+ self.deconv = OctreeDeconv(
88
+ in_channels, out_channels, kernel_size, stride, nempty)
89
+ self.bn = torch.nn.BatchNorm1d(out_channels) #, bn_eps, bn_momentum)
90
+ self.relu = torch.nn.ReLU(inplace=True)
91
+
92
+ def forward(self, data: torch.Tensor, octree: Octree, depth: int):
93
+ r''''''
94
+
95
+ out = self.deconv(data, octree, depth)
96
+ out = self.bn(out)
97
+ out = self.relu(out)
98
+ return out
99
+
100
+
101
+ class Conv1x1(torch.nn.Module):
102
+ r''' Performs a convolution with kernel :obj:`(1,1,1)`.
103
+
104
+ The shape of octree features is :obj:`(N, C)`, where :obj:`N` is the node
105
+ number and :obj:`C` is the feature channel. Therefore, :class:`Conv1x1` can be
106
+ implemented with :class:`torch.nn.Linear`.
107
+ '''
108
+
109
+ def __init__(self, in_channels: int, out_channels: int, use_bias: bool = False):
110
+ super().__init__()
111
+ self.linear = torch.nn.Linear(in_channels, out_channels, use_bias)
112
+
113
+ def forward(self, data: torch.Tensor):
114
+ r''''''
115
+
116
+ return self.linear(data)
117
+
118
+
119
+ class Conv1x1Bn(torch.nn.Module):
120
+ r''' A sequence of :class:`Conv1x1` and :class:`BatchNorm`.
121
+ '''
122
+
123
+ def __init__(self, in_channels: int, out_channels: int):
124
+ super().__init__()
125
+ self.conv = Conv1x1(in_channels, out_channels, use_bias=False)
126
+ self.bn = torch.nn.BatchNorm1d(out_channels) #, bn_eps, bn_momentum)
127
+
128
+ def forward(self, data: torch.Tensor):
129
+ r''''''
130
+
131
+ out = self.conv(data)
132
+ out = self.bn(out)
133
+ return out
134
+
135
+
136
+ class Conv1x1BnRelu(torch.nn.Module):
137
+ r''' A sequence of :class:`Conv1x1`, :class:`BatchNorm` and :class:`Relu`.
138
+ '''
139
+
140
+ def __init__(self, in_channels: int, out_channels: int):
141
+ super().__init__()
142
+ self.conv = Conv1x1(in_channels, out_channels, use_bias=False)
143
+ self.bn = torch.nn.BatchNorm1d(out_channels) #, bn_eps, bn_momentum)
144
+ self.relu = torch.nn.ReLU(inplace=True)
145
+
146
+ def forward(self, data: torch.Tensor):
147
+ r''''''
148
+
149
+ out = self.conv(data)
150
+ out = self.bn(out)
151
+ out = self.relu(out)
152
+ return out
153
+
154
+
155
+ class FcBnRelu(torch.nn.Module):
156
+ r''' A sequence of :class:`FC`, :class:`BatchNorm` and :class:`Relu`.
157
+ '''
158
+
159
+ def __init__(self, in_channels: int, out_channels: int):
160
+ super().__init__()
161
+ self.flatten = torch.nn.Flatten(start_dim=1)
162
+ self.fc = torch.nn.Linear(in_channels, out_channels, bias=False)
163
+ self.bn = torch.nn.BatchNorm1d(out_channels) #, bn_eps, bn_momentum)
164
+ self.relu = torch.nn.ReLU(inplace=True)
165
+
166
+ def forward(self, data):
167
+ r''''''
168
+
169
+ out = self.flatten(data)
170
+ out = self.fc(out)
171
+ out = self.bn(out)
172
+ out = self.relu(out)
173
+ return out
174
+
175
+
176
+ class InputFeature(torch.nn.Module):
177
+ r''' Returns the initial input feature stored in octree.
178
+
179
+ Refer to :func:`ocnn.octree.Octree.get_input_feature` for details.
180
+ '''
181
+
182
+ def __init__(self, feature: str = 'NDF', nempty: bool = False):
183
+ super().__init__()
184
+ self.nempty = nempty
185
+ self.feature = feature.upper()
186
+
187
+ def forward(self, octree: Octree):
188
+ r''''''
189
+ return octree.get_input_feature(self.feature, self.nempty)
190
+
191
+ def extra_repr(self) -> str:
192
+ r''''''
193
+ return 'feature={}, nempty={}'.format(self.feature, self.nempty)