ocnn 2.2.6-py3-none-any.whl → 2.2.7-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ocnn/__init__.py +24 -24
- ocnn/dataset.py +160 -160
- ocnn/models/__init__.py +29 -29
- ocnn/models/autoencoder.py +155 -155
- ocnn/models/hrnet.py +192 -192
- ocnn/models/image2shape.py +128 -128
- ocnn/models/lenet.py +46 -46
- ocnn/models/ounet.py +94 -94
- ocnn/models/resnet.py +53 -53
- ocnn/models/segnet.py +72 -72
- ocnn/models/unet.py +105 -105
- ocnn/modules/__init__.py +26 -26
- ocnn/modules/modules.py +303 -303
- ocnn/modules/resblocks.py +158 -158
- ocnn/nn/__init__.py +44 -44
- ocnn/nn/octree2col.py +53 -53
- ocnn/nn/octree2vox.py +50 -50
- ocnn/nn/octree_align.py +46 -46
- ocnn/nn/octree_conv.py +429 -429
- ocnn/nn/octree_drop.py +55 -55
- ocnn/nn/octree_dwconv.py +222 -222
- ocnn/nn/octree_gconv.py +79 -79
- ocnn/nn/octree_interp.py +196 -196
- ocnn/nn/octree_norm.py +126 -126
- ocnn/nn/octree_pad.py +39 -39
- ocnn/nn/octree_pool.py +200 -200
- ocnn/octree/__init__.py +22 -22
- ocnn/octree/octree.py +661 -661
- ocnn/octree/points.py +323 -322
- ocnn/octree/shuffled_key.py +115 -115
- ocnn/utils.py +205 -205
- {ocnn-2.2.6.dist-info → ocnn-2.2.7.dist-info}/METADATA +112 -91
- ocnn-2.2.7.dist-info/RECORD +36 -0
- {ocnn-2.2.6.dist-info → ocnn-2.2.7.dist-info}/WHEEL +1 -1
- {ocnn-2.2.6.dist-info → ocnn-2.2.7.dist-info/licenses}/LICENSE +21 -21
- ocnn-2.2.6.dist-info/RECORD +0 -36
- {ocnn-2.2.6.dist-info → ocnn-2.2.7.dist-info}/top_level.txt +0 -0
ocnn/nn/octree_norm.py
CHANGED
@@ -1,126 +1,126 @@
All 126 lines are removed and re-added; the old and new file contents are identical as rendered below.

# --------------------------------------------------------
# Octree-based Sparse Convolutional Neural Networks
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
# Licensed under The MIT License [see LICENSE for details]
# Written by Peng-Shuai Wang
# --------------------------------------------------------

import torch
import torch.nn
from typing import Optional

from ocnn.octree import Octree
from ocnn.utils import scatter_add


OctreeBatchNorm = torch.nn.BatchNorm1d


class OctreeGroupNorm(torch.nn.Module):
  r''' An group normalization layer for the octree.
  '''

  def __init__(self, in_channels: int, group: int, nempty: bool = False,
               min_group_channels: int = 4):
    super().__init__()
    self.eps = 1e-5
    self.nempty = nempty
    self.group = group
    self.in_channels = in_channels
    self.min_group_channels = min_group_channels
    if self.min_group_channels * self.group > in_channels:
      self.group = in_channels // self.min_group_channels

    assert in_channels % self.group == 0
    self.channels_per_group = in_channels // self.group

    self.weights = torch.nn.Parameter(torch.Tensor(1, in_channels))
    self.bias = torch.nn.Parameter(torch.Tensor(1, in_channels))
    self.reset_parameters()

  def reset_parameters(self):
    torch.nn.init.ones_(self.weights)
    torch.nn.init.zeros_(self.bias)

  def forward(self, data: torch.Tensor, octree: Octree, depth: int):
    r''''''

    batch_size = octree.batch_size
    batch_id = octree.batch_id(depth, self.nempty)
    ones = data.new_ones([data.shape[0], 1])
    count = scatter_add(ones, batch_id, dim=0, dim_size=batch_size)
    count = count * self.channels_per_group  # element number in each group
    inv_count = 1.0 / (count + self.eps)  # there might be 0 element sometimes

    mean = scatter_add(data, batch_id, dim=0, dim_size=batch_size) * inv_count
    mean = self._adjust_for_group(mean)
    out = data - mean.index_select(0, batch_id)

    var = scatter_add(out**2, batch_id, dim=0, dim_size=batch_size) * inv_count
    var = self._adjust_for_group(var)
    inv_std = 1.0 / (var + self.eps).sqrt()
    out = out * inv_std.index_select(0, batch_id)

    out = out * self.weights + self.bias
    return out

  def _adjust_for_group(self, tensor: torch.Tensor):
    r''' Adjust the tensor for the group.
    '''

    if self.channels_per_group > 1:
      tensor = (tensor.reshape(-1, self.group, self.channels_per_group)
                      .sum(-1, keepdim=True)
                      .repeat(1, 1, self.channels_per_group)
                      .reshape(-1, self.in_channels))
    return tensor

  def extra_repr(self) -> str:
    return ('in_channels={}, group={}, nempty={}, min_group_channels={}').format(
        self.in_channels, self.group, self.nempty, self.min_group_channels)


class OctreeInstanceNorm(OctreeGroupNorm):
  r''' An instance normalization layer for the octree.
  '''

  def __init__(self, in_channels: int, nempty: bool = False):
    super().__init__(in_channels=in_channels, group=in_channels, nempty=nempty,
                     min_group_channels=1)  # NOTE: group=in_channels

  def extra_repr(self) -> str:
    return ('in_channels={}, nempty={}').format(self.in_channels, self.nempty)


class OctreeNorm(torch.nn.Module):
  r''' A normalization layer for the octree. It encapsulates octree-based batch,
  group and instance normalization.
  '''

  def __init__(self, in_channels: int, norm_type: str = 'batch_norm',
               group: int = 32, min_group_channels: int = 4):
    super().__init__()
    self.in_channels = in_channels
    self.norm_type = norm_type
    self.group = group
    self.min_group_channels = min_group_channels

    if self.norm_type == 'batch_norm':
      self.norm = torch.nn.BatchNorm1d(in_channels)
    elif self.norm_type == 'group_norm':
      self.norm = OctreeGroupNorm(in_channels, group, min_group_channels)
    elif self.norm_type == 'instance_norm':
      self.norm = OctreeInstanceNorm(in_channels)
    else:
      raise ValueError

  def forward(self, x: torch.Tensor, octree: Optional[Octree] = None,
              depth: Optional[int] = None):
    if self.norm_type == 'batch_norm':
      output = self.norm(x)
    elif (self.norm_type == 'group_norm' or
          self.norm_type == 'instance_norm'):
      output = self.norm(x, octree, depth)
    else:
      raise ValueError
    return output
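The sketch below reproduces the per-batch, per-group statistics computed in OctreeGroupNorm.forward using plain PyTorch only. It is illustrative, not part of the ocnn API: the hand-made batch_id and the local scatter_sum / per_group helpers stand in for a real Octree and for ocnn.utils.scatter_add.

# A self-contained sketch of the per-batch, per-group normalization arithmetic
# in OctreeGroupNorm.forward. No Octree is built: `batch_id` is hand-made and
# index_add_ stands in for ocnn.utils.scatter_add. All names are hypothetical.
import torch

N, C, group, batch_size = 6, 8, 4, 2            # 6 octree nodes, 8 channels, 2 samples
channels_per_group = C // group                  # = 2
data = torch.randn(N, C)                         # per-node features at one depth
batch_id = torch.tensor([0, 0, 0, 1, 1, 1])      # which sample each node belongs to
eps = 1e-5

def scatter_sum(src, index, dim_size):
  # minimal stand-in for ocnn.utils.scatter_add along dim 0
  out = src.new_zeros(dim_size, src.shape[1])
  return out.index_add_(0, index, src)

def per_group(t):
  # replicate _adjust_for_group: share statistics across channels of a group
  return (t.reshape(-1, group, channels_per_group)
           .sum(-1, keepdim=True)
           .repeat(1, 1, channels_per_group)
           .reshape(-1, C))

ones = data.new_ones(N, 1)
count = scatter_sum(ones, batch_id, batch_size) * channels_per_group
inv_count = 1.0 / (count + eps)

mean = per_group(scatter_sum(data, batch_id, batch_size) * inv_count)
out = data - mean.index_select(0, batch_id)
var = per_group(scatter_sum(out**2, batch_id, batch_size) * inv_count)
out = out * (1.0 / (var + eps).sqrt()).index_select(0, batch_id)
print(out.shape)  # torch.Size([6, 8]) -- normalized features, before the affine weights/bias

Note how the per-group averaging comes from folding channels_per_group into count, so _adjust_for_group only needs to sum and broadcast within each group rather than divide again.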
ocnn/nn/octree_pad.py
CHANGED
@@ -1,39 +1,39 @@
All 39 lines are removed and re-added; the old and new file contents are identical as rendered below.

# --------------------------------------------------------
# Octree-based Sparse Convolutional Neural Networks
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
# Licensed under The MIT License [see LICENSE for details]
# Written by Peng-Shuai Wang
# --------------------------------------------------------

import torch

from ocnn.octree import Octree


def octree_pad(data: torch.Tensor, octree: Octree, depth: int, val: float = 0.0):
  r''' Pads :attr:`val` to make the number of elements of :attr:`data` equal to
  the octree node number.

  Args:
    data (torch.Tensor): The input tensor with its number of elements equal to the
        non-empty octree node number.
    octree (Octree): The corresponding octree.
    depth (int): The depth of current octree.
    val (float): The padding value. (Default: :obj:`0.0`)
  '''

  mask = octree.nempty_mask(depth)
  size = (octree.nnum[depth], data.shape[1])  # (N, C)
  out = torch.full(size, val, dtype=data.dtype, device=data.device)
  out[mask] = data
  return out


def octree_depad(data: torch.Tensor, octree: Octree, depth: int):
  r''' Reverse operation of :func:`octree_depad`.

  Please refer to :func:`octree_depad` for the meaning of the arguments.
  '''

  mask = octree.nempty_mask(depth)
  return data[mask]
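The round-trip sketch below illustrates the pad/depad semantics with plain PyTorch. It is a hedged illustration rather than real ocnn usage: the hand-made boolean mask and nnum stand in for octree.nempty_mask(depth) and octree.nnum[depth].

# A minimal round-trip sketch of octree_pad / octree_depad semantics.
# No real Octree is built; `mask` mimics octree.nempty_mask(depth).
import torch

nnum = 5                                                # total nodes at this depth
mask = torch.tensor([True, False, True, True, False])   # non-empty nodes
features = torch.arange(6, dtype=torch.float32).reshape(3, 2)  # one row per non-empty node

# pad: scatter the 3 non-empty rows into an (nnum, C) buffer filled with 0.0
padded = torch.full((nnum, features.shape[1]), 0.0, dtype=features.dtype)
padded[mask] = features

# depad: drop the rows of empty nodes again
depadded = padded[mask]
print(torch.equal(depadded, features))  # True -- depad undoes pad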