ocnn 2.2.6__py3-none-any.whl → 2.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ocnn/nn/octree_norm.py CHANGED
@@ -1,126 +1,126 @@
1
- # --------------------------------------------------------
2
- # Octree-based Sparse Convolutional Neural Networks
3
- # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
- # Licensed under The MIT License [see LICENSE for details]
5
- # Written by Peng-Shuai Wang
6
- # --------------------------------------------------------
7
-
8
- import torch
9
- import torch.nn
10
- from typing import Optional
11
-
12
- from ocnn.octree import Octree
13
- from ocnn.utils import scatter_add
14
-
15
-
16
- OctreeBatchNorm = torch.nn.BatchNorm1d
17
-
18
-
19
class OctreeGroupNorm(torch.nn.Module):
    r''' A group normalization layer for the octree.

    The statistics are computed separately for every sample of the batch (as
    identified by :func:`octree.batch_id`) and for every group of channels.

    Args:
      in_channels (int): The number of channels of the input features.
      group (int): The requested number of groups. It is reduced when the
          resulting groups would contain fewer than :attr:`min_group_channels`
          channels, but it is never reduced below 1.
      nempty (bool): Forwarded to :func:`octree.batch_id` when the batch index
          of each row is looked up. (Default: :obj:`False`)
      min_group_channels (int): The minimum number of channels per group.
          (Default: :obj:`4`)
    '''

    def __init__(self, in_channels: int, group: int, nempty: bool = False,
                 min_group_channels: int = 4):
        super().__init__()
        self.eps = 1e-5
        self.nempty = nempty
        self.group = group
        self.in_channels = in_channels
        self.min_group_channels = min_group_channels
        if self.min_group_channels * self.group > in_channels:
            # The floor division yields 0 when `in_channels` is smaller than
            # `min_group_channels`; clamp to one group so that the modulo and
            # the division below do not raise ZeroDivisionError.
            self.group = max(in_channels // self.min_group_channels, 1)

        assert in_channels % self.group == 0, (
            'in_channels ({}) must be divisible by group ({})'.format(
                in_channels, self.group))
        self.channels_per_group = in_channels // self.group

        # `torch.empty` replaces the legacy `torch.Tensor(...)` constructor;
        # the values are set by `reset_parameters` right below.
        self.weights = torch.nn.Parameter(torch.empty(1, in_channels))
        self.bias = torch.nn.Parameter(torch.empty(1, in_channels))
        self.reset_parameters()

    def reset_parameters(self):
        r''' Resets the scale to one and the shift to zero.
        '''

        torch.nn.init.ones_(self.weights)
        torch.nn.init.zeros_(self.bias)

    def forward(self, data: torch.Tensor, octree: 'Octree', depth: int):
        r''' Normalizes :attr:`data` per batch sample and per channel group.

        Args:
          data (torch.Tensor): The input features with :attr:`in_channels`
              columns; row :obj:`i` belongs to sample :obj:`batch_id[i]`.
          octree (Octree): The octree providing :obj:`batch_size` and
              :func:`batch_id`.
          depth (int): The octree depth passed to :func:`octree.batch_id`.

        Returns:
          torch.Tensor: The normalized features scaled by :attr:`weights` and
          shifted by :attr:`bias`.
        '''

        batch_size = octree.batch_size
        batch_id = octree.batch_id(depth, self.nempty)

        # Number of rows per sample, turned into the number of elements that
        # contribute to each group statistic.
        ones = data.new_ones([data.shape[0], 1])
        count = scatter_add(ones, batch_id, dim=0, dim_size=batch_size)
        count = count * self.channels_per_group
        # `eps` keeps the reciprocal finite when a sample has no rows.
        inv_count = 1.0 / (count + self.eps)

        # Per-sample channel sums scaled by the group element count; summing
        # them within each group then gives the group mean.
        mean = scatter_add(data, batch_id, dim=0, dim_size=batch_size) * inv_count
        mean = self._adjust_for_group(mean)
        out = data - mean.index_select(0, batch_id)

        # Same two-step reduction for the variance of the centered features.
        var = scatter_add(out**2, batch_id, dim=0, dim_size=batch_size) * inv_count
        var = self._adjust_for_group(var)
        inv_std = 1.0 / (var + self.eps).sqrt()
        out = out * inv_std.index_select(0, batch_id)

        out = out * self.weights + self.bias
        return out

    def _adjust_for_group(self, tensor: torch.Tensor):
        r''' Sums :attr:`tensor` over the channels of each group and writes the
        sum back to every channel of that group.
        '''

        if self.channels_per_group > 1:
            tensor = (tensor.reshape(-1, self.group, self.channels_per_group)
                            .sum(-1, keepdim=True)
                            .repeat(1, 1, self.channels_per_group)
                            .reshape(-1, self.in_channels))
        return tensor

    def extra_repr(self) -> str:
        return ('in_channels={}, group={}, nempty={}, min_group_channels={}').format(
            self.in_channels, self.group, self.nempty, self.min_group_channels)
81
-
82
-
83
class OctreeInstanceNorm(OctreeGroupNorm):
    r''' Instance normalization for the octree, expressed as a group
    normalization in which every channel forms a group of its own.
    '''

    def __init__(self, in_channels: int, nempty: bool = False):
        # One group per channel, hence a minimum of one channel per group.
        super().__init__(
            in_channels=in_channels,
            group=in_channels,
            nempty=nempty,
            min_group_channels=1,
        )

    def extra_repr(self) -> str:
        fields = (self.in_channels, self.nempty)
        return ('in_channels={}, nempty={}').format(*fields)
93
-
94
-
95
class OctreeNorm(torch.nn.Module):
    r''' A normalization layer for the octree. It encapsulates octree-based batch,
    group and instance normalization.

    Args:
      in_channels (int): The number of channels of the input features.
      norm_type (str): One of :obj:`'batch_norm'`, :obj:`'group_norm'` and
          :obj:`'instance_norm'`. (Default: :obj:`'batch_norm'`)
      group (int): The number of groups, used by :obj:`'group_norm'` only.
          (Default: :obj:`32`)
      min_group_channels (int): The minimum number of channels per group, used
          by :obj:`'group_norm'` only. (Default: :obj:`4`)
      nempty (bool): Forwarded to the group and instance normalization layers.
          (Default: :obj:`False`)

    Raises:
      ValueError: If :attr:`norm_type` is not one of the supported values.
    '''

    def __init__(self, in_channels: int, norm_type: str = 'batch_norm',
                 group: int = 32, min_group_channels: int = 4,
                 nempty: bool = False):
        super().__init__()
        self.in_channels = in_channels
        self.norm_type = norm_type
        self.group = group
        self.min_group_channels = min_group_channels
        self.nempty = nempty

        if self.norm_type == 'batch_norm':
            self.norm = torch.nn.BatchNorm1d(in_channels)
        elif self.norm_type == 'group_norm':
            # Pass by keyword: the third positional parameter of
            # OctreeGroupNorm is `nempty`, not `min_group_channels`.
            self.norm = OctreeGroupNorm(
                in_channels, group, nempty=nempty,
                min_group_channels=min_group_channels)
        elif self.norm_type == 'instance_norm':
            self.norm = OctreeInstanceNorm(in_channels, nempty=nempty)
        else:
            raise ValueError(
                'Unsupported norm_type: {!r}'.format(self.norm_type))

    def forward(self, x: torch.Tensor, octree: Optional['Octree'] = None,
                depth: Optional[int] = None):
        r''' Applies the selected normalization to :attr:`x`.

        Args:
          x (torch.Tensor): The input features.
          octree (Octree, optional): The corresponding octree; required by
              :obj:`'group_norm'` and :obj:`'instance_norm'`.
          depth (int, optional): The octree depth; required by
              :obj:`'group_norm'` and :obj:`'instance_norm'`.

        Raises:
          ValueError: If :attr:`octree` or :attr:`depth` is missing for a
              normalization that needs them, or if :attr:`norm_type` is not
              supported.
        '''

        if self.norm_type == 'batch_norm':
            output = self.norm(x)
        elif (self.norm_type == 'group_norm' or
              self.norm_type == 'instance_norm'):
            if octree is None or depth is None:
                raise ValueError(
                    '{} requires both `octree` and `depth`'.format(
                        self.norm_type))
            output = self.norm(x, octree, depth)
        else:
            raise ValueError(
                'Unsupported norm_type: {!r}'.format(self.norm_type))
        return output
1
+ # --------------------------------------------------------
2
+ # Octree-based Sparse Convolutional Neural Networks
3
+ # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Peng-Shuai Wang
6
+ # --------------------------------------------------------
7
+
8
+ import torch
9
+ import torch.nn
10
+ from typing import Optional
11
+
12
+ from ocnn.octree import Octree
13
+ from ocnn.utils import scatter_add
14
+
15
+
16
+ OctreeBatchNorm = torch.nn.BatchNorm1d
17
+
18
+
19
class OctreeGroupNorm(torch.nn.Module):
    r''' A group normalization layer for the octree.

    The statistics are computed separately for every sample of the batch (as
    identified by :func:`octree.batch_id`) and for every group of channels.

    Args:
      in_channels (int): The number of channels of the input features.
      group (int): The requested number of groups. It is reduced when the
          resulting groups would contain fewer than :attr:`min_group_channels`
          channels, but it is never reduced below 1.
      nempty (bool): Forwarded to :func:`octree.batch_id` when the batch index
          of each row is looked up. (Default: :obj:`False`)
      min_group_channels (int): The minimum number of channels per group.
          (Default: :obj:`4`)
    '''

    def __init__(self, in_channels: int, group: int, nempty: bool = False,
                 min_group_channels: int = 4):
        super().__init__()
        self.eps = 1e-5
        self.nempty = nempty
        self.group = group
        self.in_channels = in_channels
        self.min_group_channels = min_group_channels
        if self.min_group_channels * self.group > in_channels:
            # The floor division yields 0 when `in_channels` is smaller than
            # `min_group_channels`; clamp to one group so that the modulo and
            # the division below do not raise ZeroDivisionError.
            self.group = max(in_channels // self.min_group_channels, 1)

        assert in_channels % self.group == 0, (
            'in_channels ({}) must be divisible by group ({})'.format(
                in_channels, self.group))
        self.channels_per_group = in_channels // self.group

        # `torch.empty` replaces the legacy `torch.Tensor(...)` constructor;
        # the values are set by `reset_parameters` right below.
        self.weights = torch.nn.Parameter(torch.empty(1, in_channels))
        self.bias = torch.nn.Parameter(torch.empty(1, in_channels))
        self.reset_parameters()

    def reset_parameters(self):
        r''' Resets the scale to one and the shift to zero.
        '''

        torch.nn.init.ones_(self.weights)
        torch.nn.init.zeros_(self.bias)

    def forward(self, data: torch.Tensor, octree: 'Octree', depth: int):
        r''' Normalizes :attr:`data` per batch sample and per channel group.

        Args:
          data (torch.Tensor): The input features with :attr:`in_channels`
              columns; row :obj:`i` belongs to sample :obj:`batch_id[i]`.
          octree (Octree): The octree providing :obj:`batch_size` and
              :func:`batch_id`.
          depth (int): The octree depth passed to :func:`octree.batch_id`.

        Returns:
          torch.Tensor: The normalized features scaled by :attr:`weights` and
          shifted by :attr:`bias`.
        '''

        batch_size = octree.batch_size
        batch_id = octree.batch_id(depth, self.nempty)

        # Number of rows per sample, turned into the number of elements that
        # contribute to each group statistic.
        ones = data.new_ones([data.shape[0], 1])
        count = scatter_add(ones, batch_id, dim=0, dim_size=batch_size)
        count = count * self.channels_per_group
        # `eps` keeps the reciprocal finite when a sample has no rows.
        inv_count = 1.0 / (count + self.eps)

        # Per-sample channel sums scaled by the group element count; summing
        # them within each group then gives the group mean.
        mean = scatter_add(data, batch_id, dim=0, dim_size=batch_size) * inv_count
        mean = self._adjust_for_group(mean)
        out = data - mean.index_select(0, batch_id)

        # Same two-step reduction for the variance of the centered features.
        var = scatter_add(out**2, batch_id, dim=0, dim_size=batch_size) * inv_count
        var = self._adjust_for_group(var)
        inv_std = 1.0 / (var + self.eps).sqrt()
        out = out * inv_std.index_select(0, batch_id)

        out = out * self.weights + self.bias
        return out

    def _adjust_for_group(self, tensor: torch.Tensor):
        r''' Sums :attr:`tensor` over the channels of each group and writes the
        sum back to every channel of that group.
        '''

        if self.channels_per_group > 1:
            tensor = (tensor.reshape(-1, self.group, self.channels_per_group)
                            .sum(-1, keepdim=True)
                            .repeat(1, 1, self.channels_per_group)
                            .reshape(-1, self.in_channels))
        return tensor

    def extra_repr(self) -> str:
        return ('in_channels={}, group={}, nempty={}, min_group_channels={}').format(
            self.in_channels, self.group, self.nempty, self.min_group_channels)
81
+
82
+
83
class OctreeInstanceNorm(OctreeGroupNorm):
    r''' Instance normalization for the octree, expressed as a group
    normalization in which every channel forms a group of its own.
    '''

    def __init__(self, in_channels: int, nempty: bool = False):
        # One group per channel, hence a minimum of one channel per group.
        super().__init__(
            in_channels=in_channels,
            group=in_channels,
            nempty=nempty,
            min_group_channels=1,
        )

    def extra_repr(self) -> str:
        fields = (self.in_channels, self.nempty)
        return ('in_channels={}, nempty={}').format(*fields)
93
+
94
+
95
class OctreeNorm(torch.nn.Module):
    r''' A normalization layer for the octree. It encapsulates octree-based batch,
    group and instance normalization.

    Args:
      in_channels (int): The number of channels of the input features.
      norm_type (str): One of :obj:`'batch_norm'`, :obj:`'group_norm'` and
          :obj:`'instance_norm'`. (Default: :obj:`'batch_norm'`)
      group (int): The number of groups, used by :obj:`'group_norm'` only.
          (Default: :obj:`32`)
      min_group_channels (int): The minimum number of channels per group, used
          by :obj:`'group_norm'` only. (Default: :obj:`4`)
      nempty (bool): Forwarded to the group and instance normalization layers.
          (Default: :obj:`False`)

    Raises:
      ValueError: If :attr:`norm_type` is not one of the supported values.
    '''

    def __init__(self, in_channels: int, norm_type: str = 'batch_norm',
                 group: int = 32, min_group_channels: int = 4,
                 nempty: bool = False):
        super().__init__()
        self.in_channels = in_channels
        self.norm_type = norm_type
        self.group = group
        self.min_group_channels = min_group_channels
        self.nempty = nempty

        if self.norm_type == 'batch_norm':
            self.norm = torch.nn.BatchNorm1d(in_channels)
        elif self.norm_type == 'group_norm':
            # Pass by keyword: the third positional parameter of
            # OctreeGroupNorm is `nempty`, not `min_group_channels`.
            self.norm = OctreeGroupNorm(
                in_channels, group, nempty=nempty,
                min_group_channels=min_group_channels)
        elif self.norm_type == 'instance_norm':
            self.norm = OctreeInstanceNorm(in_channels, nempty=nempty)
        else:
            raise ValueError(
                'Unsupported norm_type: {!r}'.format(self.norm_type))

    def forward(self, x: torch.Tensor, octree: Optional['Octree'] = None,
                depth: Optional[int] = None):
        r''' Applies the selected normalization to :attr:`x`.

        Args:
          x (torch.Tensor): The input features.
          octree (Octree, optional): The corresponding octree; required by
              :obj:`'group_norm'` and :obj:`'instance_norm'`.
          depth (int, optional): The octree depth; required by
              :obj:`'group_norm'` and :obj:`'instance_norm'`.

        Raises:
          ValueError: If :attr:`octree` or :attr:`depth` is missing for a
              normalization that needs them, or if :attr:`norm_type` is not
              supported.
        '''

        if self.norm_type == 'batch_norm':
            output = self.norm(x)
        elif (self.norm_type == 'group_norm' or
              self.norm_type == 'instance_norm'):
            if octree is None or depth is None:
                raise ValueError(
                    '{} requires both `octree` and `depth`'.format(
                        self.norm_type))
            output = self.norm(x, octree, depth)
        else:
            raise ValueError(
                'Unsupported norm_type: {!r}'.format(self.norm_type))
        return output
ocnn/nn/octree_pad.py CHANGED
@@ -1,39 +1,39 @@
1
- # --------------------------------------------------------
2
- # Octree-based Sparse Convolutional Neural Networks
3
- # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
- # Licensed under The MIT License [see LICENSE for details]
5
- # Written by Peng-Shuai Wang
6
- # --------------------------------------------------------
7
-
8
- import torch
9
-
10
- from ocnn.octree import Octree
11
-
12
-
13
def octree_pad(data: torch.Tensor, octree: Octree, depth: int, val: float = 0.0):
    r''' Expands :attr:`data` to one row per octree node at :attr:`depth`.

    The rows selected by :func:`octree.nempty_mask` receive :attr:`data`; every
    other row is filled with :attr:`val`.

    Args:
      data (torch.Tensor): The input tensor with its number of elements equal to
          the non-empty octree node number.
      octree (Octree): The corresponding octree.
      depth (int): The depth of current octree.
      val (float): The padding value. (Default: :obj:`0.0`)
    '''

    nempty_rows = octree.nempty_mask(depth)
    node_num, channel = octree.nnum[depth], data.shape[1]
    # Start from a tensor of shape (N, C) filled with the padding value.
    padded = torch.full(
        (node_num, channel), val, dtype=data.dtype, device=data.device)
    padded[nempty_rows] = data
    return padded
30
-
31
-
32
def octree_depad(data: torch.Tensor, octree: Octree, depth: int):
    r''' Reverse operation of :func:`octree_pad`: keeps only the rows of
    :attr:`data` selected by :func:`octree.nempty_mask`.

    Please refer to :func:`octree_pad` for the meaning of the arguments.
    '''

    return data[octree.nempty_mask(depth)]
1
+ # --------------------------------------------------------
2
+ # Octree-based Sparse Convolutional Neural Networks
3
+ # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Peng-Shuai Wang
6
+ # --------------------------------------------------------
7
+
8
+ import torch
9
+
10
+ from ocnn.octree import Octree
11
+
12
+
13
def octree_pad(data: torch.Tensor, octree: Octree, depth: int, val: float = 0.0):
    r''' Expands :attr:`data` to one row per octree node at :attr:`depth`.

    The rows selected by :func:`octree.nempty_mask` receive :attr:`data`; every
    other row is filled with :attr:`val`.

    Args:
      data (torch.Tensor): The input tensor with its number of elements equal to
          the non-empty octree node number.
      octree (Octree): The corresponding octree.
      depth (int): The depth of current octree.
      val (float): The padding value. (Default: :obj:`0.0`)
    '''

    nempty_rows = octree.nempty_mask(depth)
    node_num, channel = octree.nnum[depth], data.shape[1]
    # Start from a tensor of shape (N, C) filled with the padding value.
    padded = torch.full(
        (node_num, channel), val, dtype=data.dtype, device=data.device)
    padded[nempty_rows] = data
    return padded
30
+
31
+
32
def octree_depad(data: torch.Tensor, octree: Octree, depth: int):
    r''' Reverse operation of :func:`octree_pad`: keeps only the rows of
    :attr:`data` selected by :func:`octree.nempty_mask`.

    Please refer to :func:`octree_pad` for the meaning of the arguments.
    '''

    return data[octree.nempty_mask(depth)]