ocnn 2.2.8__py3-none-any.whl → 2.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. ocnn/__init__.py +24 -24
  2. ocnn/dataset.py +160 -160
  3. ocnn/models/__init__.py +29 -29
  4. ocnn/models/autoencoder.py +155 -155
  5. ocnn/models/hrnet.py +192 -192
  6. ocnn/models/image2shape.py +128 -128
  7. ocnn/models/lenet.py +46 -46
  8. ocnn/models/ounet.py +94 -94
  9. ocnn/models/resnet.py +53 -53
  10. ocnn/models/segnet.py +72 -72
  11. ocnn/models/unet.py +105 -105
  12. ocnn/modules/__init__.py +26 -26
  13. ocnn/modules/modules.py +303 -303
  14. ocnn/modules/resblocks.py +158 -158
  15. ocnn/nn/__init__.py +45 -44
  16. ocnn/nn/kernels/__init__.py +14 -0
  17. ocnn/nn/kernels/autotuner.py +416 -0
  18. ocnn/nn/kernels/config.py +67 -0
  19. ocnn/nn/kernels/conv_bwd_implicit_gemm.py +229 -0
  20. ocnn/nn/kernels/conv_bwd_implicit_gemm_splitk.py +347 -0
  21. ocnn/nn/kernels/conv_fwd_implicit_gemm.py +109 -0
  22. ocnn/nn/kernels/conv_fwd_implicit_gemm_splitk.py +150 -0
  23. ocnn/nn/kernels/utils.py +44 -0
  24. ocnn/nn/octree2col.py +53 -53
  25. ocnn/nn/octree2vox.py +50 -50
  26. ocnn/nn/octree_align.py +46 -46
  27. ocnn/nn/octree_conv.py +430 -429
  28. ocnn/nn/octree_conv_t.py +148 -0
  29. ocnn/nn/octree_drop.py +55 -55
  30. ocnn/nn/octree_dwconv.py +222 -222
  31. ocnn/nn/octree_gconv.py +79 -79
  32. ocnn/nn/octree_interp.py +196 -196
  33. ocnn/nn/octree_norm.py +126 -126
  34. ocnn/nn/octree_pad.py +39 -39
  35. ocnn/nn/octree_pool.py +200 -200
  36. ocnn/octree/__init__.py +22 -22
  37. ocnn/octree/octree.py +770 -770
  38. ocnn/octree/points.py +384 -323
  39. ocnn/octree/shuffled_key.py +115 -115
  40. ocnn/utils.py +205 -205
  41. {ocnn-2.2.8.dist-info → ocnn-2.3.0.dist-info}/METADATA +117 -111
  42. ocnn-2.3.0.dist-info/RECORD +45 -0
  43. {ocnn-2.2.8.dist-info → ocnn-2.3.0.dist-info}/WHEEL +1 -1
  44. {ocnn-2.2.8.dist-info → ocnn-2.3.0.dist-info}/licenses/LICENSE +21 -21
  45. ocnn-2.2.8.dist-info/RECORD +0 -36
  46. {ocnn-2.2.8.dist-info → ocnn-2.3.0.dist-info}/top_level.txt +0 -0
ocnn/nn/octree_interp.py CHANGED
@@ -1,196 +1,196 @@
1
- # --------------------------------------------------------
2
- # Octree-based Sparse Convolutional Neural Networks
3
- # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
- # Licensed under The MIT License [see LICENSE for details]
5
- # Written by Peng-Shuai Wang
6
- # --------------------------------------------------------
7
-
8
- import torch
9
- import torch.sparse
10
- from typing import List, Optional
11
-
12
- import ocnn
13
- from ocnn.octree import Octree
14
-
15
-
16
def octree_nearest_pts(data: torch.Tensor, octree: Octree, depth: int,
                       pts: torch.Tensor, nempty: bool = False,
                       bound_check: bool = False):
  ''' Nearest-neighbor interpolation of octree features at the input points.

  Each point takes the feature of the node at :attr:`depth` returned by
  :meth:`Octree.search_xyzb` for its coordinates. A point receives an all-zero
  feature if the search returns ``-1``, or if :attr:`bound_check` is set and
  the point lies outside the valid range.

  Args:
    data (torch.Tensor): The input features, one row per octree node.
    octree (Octree): The octree to interpolate.
    depth (int): The depth of the data.
    pts (torch.Tensor): The coordinates of the points with shape :obj:`(N, 4)`,
        i.e. :obj:`N x (x, y, z, batch)`.
    nempty (bool): If true, the :attr:`data` only contains features of
        non-empty octree nodes.
    bound_check (bool): If true, check whether the point is in
        :obj:`[0, 2^depth)`.

  Returns:
    torch.Tensor: The interpolated features with shape
    :obj:`(N, data.shape[1])`.

  .. note::
    The :attr:`pts` MUST be scaled into :obj:`[0, 2^depth)`.
  '''

  num_nodes = octree.nnum_nempty[depth] if nempty else octree.nnum[depth]
  assert data.shape[0] == num_nodes, 'The shape of input data is wrong.'

  node_idx = octree.search_xyzb(pts, depth, nempty)
  found = node_idx > -1
  if bound_check:
    xyz = pts[:, :3]
    inside = ((xyz >= 0) & (xyz < 2**depth)).all(1)
    found = found & inside

  # Rows of points without a matching node stay zero.
  out = data.new_zeros((pts.shape[0], data.shape[1]))
  out[found] = data.index_select(0, node_idx[found])
  return out
48
-
49
-
50
def octree_linear_pts(data: torch.Tensor, octree: Octree, depth: int,
                      pts: torch.Tensor, nempty: bool = False,
                      bound_check: bool = False):
  ''' Trilinear interpolation of octree features at the input points.

  Features are treated as values at the centers of the voxels at
  :attr:`depth`. For every point, the 8 surrounding voxel centers are looked up
  with :meth:`Octree.search_xyzb`. The features of the neighbors that are found
  are blended with trilinear weights. The result is divided by the total weight
  of the neighbors found, so a point with no neighbor at all gets a zero
  feature.

  Refer to :func:`octree_nearest_pts` for the meaning of the arguments.

  Returns:
    torch.Tensor: The interpolated features with shape
    :obj:`(N, data.shape[1])` and the dtype of :attr:`data`.
  '''

  nnum = octree.nnum_nempty[depth] if nempty else octree.nnum[depth]
  assert data.shape[0] == nnum, 'The shape of input data is wrong.'

  device = data.device
  grid = torch.tensor(
      [[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1],
       [1, 0, 0], [1, 0, 1], [1, 1, 0], [1, 1, 1]], device=device)

  # 1. Neighborhood searching
  xyzf = pts[:, :3] - 0.5  # the value is defined on the center of each voxel
  xyzi = xyzf.floor()      # the integer part (N, 3)
  frac = xyzf - xyzi       # the fraction part (N, 3)

  xyzn = (xyzi.unsqueeze(1) + grid).view(-1, 3)  # (8*N, 3) neighbor coords
  batch = pts[:, 3].unsqueeze(1).repeat(1, 8).view(-1, 1)
  idx = octree.search_xyzb(torch.cat([xyzn, batch], dim=1), depth, nempty)
  valid = idx > -1  # entries of -1 are treated as "neighbor not found"
  if bound_check:
    bound = torch.logical_and(xyzn >= 0, xyzn < 2**depth).all(1)
    valid = torch.logical_and(valid, bound)
  idx = idx[valid]

  # 2. Build the sparse (N, H) interpolation matrix
  npt = pts.shape[0]
  ids = torch.arange(npt, device=idx.device)
  ids = ids.unsqueeze(1).repeat(1, 8).view(-1)
  ids = ids[valid]
  indices = torch.stack([ids, idx], dim=0).long()

  # Along each axis the weight is (1 - f) for offset 0 and f for offset 1.
  frac = (1.0 - grid) - frac.unsqueeze(dim=1)  # (8, 3) - (N, 1, 3) -> (N, 8, 3)
  weight = frac.prod(dim=2).abs().view(-1)     # (8*N,)
  # The weights inherit the dtype of `pts`. Cast them to the dtype of `data`
  # so the sparse products below do not fail on a dtype mismatch
  # (e.g. float32 points with float64 features).
  weight = weight[valid].to(data.dtype)

  h = data.shape[0]
  mat = torch.sparse_coo_tensor(indices, weight, [npt, h], device=device)

  # 3. Interpolation, normalized by the total weight of the found neighbors
  output = torch.sparse.mm(mat, data)
  ones = torch.ones(h, 1, dtype=data.dtype, device=device)
  norm = torch.sparse.mm(mat, ones)
  output = torch.div(output, norm + 1e-12)
  return output
100
-
101
-
102
class OctreeInterp(torch.nn.Module):
  r''' Interpolates octree features at the input points.

  Args:
    method (str): ``'linear'`` selects :func:`octree_linear_pts`; any other
        value selects :func:`octree_nearest_pts`.
    nempty (bool): If true, the features only contain non-empty octree nodes.
    bound_check (bool): If true, points/neighbors outside :obj:`[0, 2^depth)`
        are ignored.
    rescale_pts (bool): If true, the xyz coordinates of the input points are
        expected in :obj:`[-1, 1]` and are rescaled to :obj:`[0, 2^depth]`
        before interpolation. The caller's tensor is not modified.

  Refer to :func:`octree_nearest_pts` for a description of the forward
  arguments.
  '''

  def __init__(self, method: str = 'linear', nempty: bool = False,
               bound_check: bool = False, rescale_pts: bool = True):
    super().__init__()
    self.method = method
    self.nempty = nempty
    self.bound_check = bound_check
    self.rescale_pts = rescale_pts
    self.func = octree_linear_pts if method == 'linear' else octree_nearest_pts

  def forward(self, data: torch.Tensor, octree: 'Octree', depth: int,
              pts: torch.Tensor):
    r''''''

    # rescale points from [-1, 1] to [0, 2^depth]
    if self.rescale_pts:
      scale = 2 ** (depth - 1)
      # Work on a copy: writing into `pts` in place would silently change the
      # caller's tensor, and a repeated call would rescale it a second time.
      pts = pts.clone()
      pts[:, :3] = (pts[:, :3] + 1.0) * scale

    return self.func(data, octree, depth, pts, self.nempty, self.bound_check)

  def extra_repr(self) -> str:
    r''' Sets the extra representation of the module.
    '''

    return ('method={}, nempty={}, bound_check={}, rescale_pts={}').format(
        self.method, self.nempty, self.bound_check, self.rescale_pts)  # noqa
134
-
135
-
136
def octree_nearest_upsample(data: torch.Tensor, octree: Octree, depth: int,
                            nempty: bool = False):
  r''' Upsamples the octree node features from :attr:`depth` to
  :attr:`(depth+1)` by copying each parent feature to its children.

  Args:
    data (torch.Tensor): The input data, one row per octree node at
        :attr:`depth`.
    octree (Octree): The octree to interpolate.
    depth (int): The depth of the data.
    nempty (bool): If true, the :attr:`data` only contains features of
        non-empty octree nodes, and so does the output at :attr:`depth+1`.
  '''

  expected = octree.nnum_nempty[depth] if nempty else octree.nnum[depth]
  assert data.shape[0] == expected, 'The shape of input data is wrong.'

  if nempty:
    parents = data
  else:
    # Drop the empty nodes so that only the non-empty ones are expanded.
    parents = ocnn.nn.octree_depad(data, octree, depth)

  # Each parent row is repeated 8 times, consecutively.
  children = parents.unsqueeze(1).repeat(1, 8, 1).flatten(end_dim=1)

  if nempty:
    # Keep only the non-empty children -- note: depth + 1, not depth.
    children = ocnn.nn.octree_depad(children, octree, depth + 1)
  return children
159
-
160
-
161
class OctreeUpsample(torch.nn.Module):
  r''' Upsamples the octree node features from :attr:`depth` to
  :attr:`(target_depth)`.

  Args:
    method (str): ``'linear'`` uses :func:`octree_linear_pts`; any other value
        uses nearest-neighbor interpolation.
    nempty (bool): If true, the features only contain non-empty octree nodes.

  Refer to :func:`octree_nearest_pts` for details.
  '''

  def __init__(self, method: str = 'linear', nempty: bool = False):
    super().__init__()
    self.method = method
    self.nempty = nempty
    if method == 'linear':
      self.func = octree_linear_pts
    else:
      self.func = octree_nearest_pts

  def forward(self, data: torch.Tensor, octree: Octree, depth: int,
              target_depth: Optional[int] = None):
    r''''''

    if target_depth is None:
      target_depth = depth + 1  # default: go up by one level
    if target_depth == depth:
      return data  # same depth, nothing to do
    assert target_depth >= depth, 'target_depth must be larger than depth'

    # One-level nearest upsampling has its own implementation that needs no
    # point search.
    if self.method == 'nearest' and target_depth == depth + 1:
      return octree_nearest_upsample(data, octree, depth, self.nempty)

    # Query the node centers at `target_depth`, expressed in the coordinate
    # frame of `depth`.
    ratio = 2 ** (depth - target_depth)
    xyzb = octree.xyzb(target_depth, self.nempty)
    pts = torch.stack(xyzb, dim=1).float()
    pts[:, :3] = (pts[:, :3] + 0.5) * ratio
    return self.func(data, octree, depth, pts, self.nempty)

  def extra_repr(self) -> str:
    r''' Sets the extra representation of the module.
    '''

    return 'method={}, nempty={}'.format(self.method, self.nempty)
1
+ # --------------------------------------------------------
2
+ # Octree-based Sparse Convolutional Neural Networks
3
+ # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Peng-Shuai Wang
6
+ # --------------------------------------------------------
7
+
8
+ import torch
9
+ import torch.sparse
10
+ from typing import List, Optional
11
+
12
+ import ocnn
13
+ from ocnn.octree import Octree
14
+
15
+
16
def octree_nearest_pts(data: torch.Tensor, octree: Octree, depth: int,
                       pts: torch.Tensor, nempty: bool = False,
                       bound_check: bool = False):
  ''' Nearest-neighbor interpolation of octree features at the input points.

  Each point takes the feature of the node at :attr:`depth` returned by
  :meth:`Octree.search_xyzb` for its coordinates. A point receives an all-zero
  feature if the search returns ``-1``, or if :attr:`bound_check` is set and
  the point lies outside the valid range.

  Args:
    data (torch.Tensor): The input features, one row per octree node.
    octree (Octree): The octree to interpolate.
    depth (int): The depth of the data.
    pts (torch.Tensor): The coordinates of the points with shape :obj:`(N, 4)`,
        i.e. :obj:`N x (x, y, z, batch)`.
    nempty (bool): If true, the :attr:`data` only contains features of
        non-empty octree nodes.
    bound_check (bool): If true, check whether the point is in
        :obj:`[0, 2^depth)`.

  Returns:
    torch.Tensor: The interpolated features with shape
    :obj:`(N, data.shape[1])`.

  .. note::
    The :attr:`pts` MUST be scaled into :obj:`[0, 2^depth)`.
  '''

  num_nodes = octree.nnum_nempty[depth] if nempty else octree.nnum[depth]
  assert data.shape[0] == num_nodes, 'The shape of input data is wrong.'

  node_idx = octree.search_xyzb(pts, depth, nempty)
  found = node_idx > -1
  if bound_check:
    xyz = pts[:, :3]
    inside = ((xyz >= 0) & (xyz < 2**depth)).all(1)
    found = found & inside

  # Rows of points without a matching node stay zero.
  out = data.new_zeros((pts.shape[0], data.shape[1]))
  out[found] = data.index_select(0, node_idx[found])
  return out
48
+
49
+
50
def octree_linear_pts(data: torch.Tensor, octree: Octree, depth: int,
                      pts: torch.Tensor, nempty: bool = False,
                      bound_check: bool = False):
  ''' Trilinear interpolation of octree features at the input points.

  Features are treated as values at the centers of the voxels at
  :attr:`depth`. For every point, the 8 surrounding voxel centers are looked up
  with :meth:`Octree.search_xyzb`. The features of the neighbors that are found
  are blended with trilinear weights. The result is divided by the total weight
  of the neighbors found, so a point with no neighbor at all gets a zero
  feature.

  Refer to :func:`octree_nearest_pts` for the meaning of the arguments.

  Returns:
    torch.Tensor: The interpolated features with shape
    :obj:`(N, data.shape[1])` and the dtype of :attr:`data`.
  '''

  nnum = octree.nnum_nempty[depth] if nempty else octree.nnum[depth]
  assert data.shape[0] == nnum, 'The shape of input data is wrong.'

  device = data.device
  grid = torch.tensor(
      [[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1],
       [1, 0, 0], [1, 0, 1], [1, 1, 0], [1, 1, 1]], device=device)

  # 1. Neighborhood searching
  xyzf = pts[:, :3] - 0.5  # the value is defined on the center of each voxel
  xyzi = xyzf.floor()      # the integer part (N, 3)
  frac = xyzf - xyzi       # the fraction part (N, 3)

  xyzn = (xyzi.unsqueeze(1) + grid).view(-1, 3)  # (8*N, 3) neighbor coords
  batch = pts[:, 3].unsqueeze(1).repeat(1, 8).view(-1, 1)
  idx = octree.search_xyzb(torch.cat([xyzn, batch], dim=1), depth, nempty)
  valid = idx > -1  # entries of -1 are treated as "neighbor not found"
  if bound_check:
    bound = torch.logical_and(xyzn >= 0, xyzn < 2**depth).all(1)
    valid = torch.logical_and(valid, bound)
  idx = idx[valid]

  # 2. Build the sparse (N, H) interpolation matrix
  npt = pts.shape[0]
  ids = torch.arange(npt, device=idx.device)
  ids = ids.unsqueeze(1).repeat(1, 8).view(-1)
  ids = ids[valid]
  indices = torch.stack([ids, idx], dim=0).long()

  # Along each axis the weight is (1 - f) for offset 0 and f for offset 1.
  frac = (1.0 - grid) - frac.unsqueeze(dim=1)  # (8, 3) - (N, 1, 3) -> (N, 8, 3)
  weight = frac.prod(dim=2).abs().view(-1)     # (8*N,)
  # The weights inherit the dtype of `pts`. Cast them to the dtype of `data`
  # so the sparse products below do not fail on a dtype mismatch
  # (e.g. float32 points with float64 features).
  weight = weight[valid].to(data.dtype)

  h = data.shape[0]
  mat = torch.sparse_coo_tensor(indices, weight, [npt, h], device=device)

  # 3. Interpolation, normalized by the total weight of the found neighbors
  output = torch.sparse.mm(mat, data)
  ones = torch.ones(h, 1, dtype=data.dtype, device=device)
  norm = torch.sparse.mm(mat, ones)
  output = torch.div(output, norm + 1e-12)
  return output
100
+
101
+
102
class OctreeInterp(torch.nn.Module):
  r''' Interpolates octree features at the input points.

  Args:
    method (str): ``'linear'`` selects :func:`octree_linear_pts`; any other
        value selects :func:`octree_nearest_pts`.
    nempty (bool): If true, the features only contain non-empty octree nodes.
    bound_check (bool): If true, points/neighbors outside :obj:`[0, 2^depth)`
        are ignored.
    rescale_pts (bool): If true, the xyz coordinates of the input points are
        expected in :obj:`[-1, 1]` and are rescaled to :obj:`[0, 2^depth]`
        before interpolation. The caller's tensor is not modified.

  Refer to :func:`octree_nearest_pts` for a description of the forward
  arguments.
  '''

  def __init__(self, method: str = 'linear', nempty: bool = False,
               bound_check: bool = False, rescale_pts: bool = True):
    super().__init__()
    self.method = method
    self.nempty = nempty
    self.bound_check = bound_check
    self.rescale_pts = rescale_pts
    self.func = octree_linear_pts if method == 'linear' else octree_nearest_pts

  def forward(self, data: torch.Tensor, octree: 'Octree', depth: int,
              pts: torch.Tensor):
    r''''''

    # rescale points from [-1, 1] to [0, 2^depth]
    if self.rescale_pts:
      scale = 2 ** (depth - 1)
      # Work on a copy: writing into `pts` in place would silently change the
      # caller's tensor, and a repeated call would rescale it a second time.
      pts = pts.clone()
      pts[:, :3] = (pts[:, :3] + 1.0) * scale

    return self.func(data, octree, depth, pts, self.nempty, self.bound_check)

  def extra_repr(self) -> str:
    r''' Sets the extra representation of the module.
    '''

    return ('method={}, nempty={}, bound_check={}, rescale_pts={}').format(
        self.method, self.nempty, self.bound_check, self.rescale_pts)  # noqa
134
+
135
+
136
def octree_nearest_upsample(data: torch.Tensor, octree: Octree, depth: int,
                            nempty: bool = False):
  r''' Upsamples the octree node features from :attr:`depth` to
  :attr:`(depth+1)` by copying each parent feature to its children.

  Args:
    data (torch.Tensor): The input data, one row per octree node at
        :attr:`depth`.
    octree (Octree): The octree to interpolate.
    depth (int): The depth of the data.
    nempty (bool): If true, the :attr:`data` only contains features of
        non-empty octree nodes, and so does the output at :attr:`depth+1`.
  '''

  expected = octree.nnum_nempty[depth] if nempty else octree.nnum[depth]
  assert data.shape[0] == expected, 'The shape of input data is wrong.'

  if nempty:
    parents = data
  else:
    # Drop the empty nodes so that only the non-empty ones are expanded.
    parents = ocnn.nn.octree_depad(data, octree, depth)

  # Each parent row is repeated 8 times, consecutively.
  children = parents.unsqueeze(1).repeat(1, 8, 1).flatten(end_dim=1)

  if nempty:
    # Keep only the non-empty children -- note: depth + 1, not depth.
    children = ocnn.nn.octree_depad(children, octree, depth + 1)
  return children
159
+
160
+
161
class OctreeUpsample(torch.nn.Module):
  r''' Upsamples the octree node features from :attr:`depth` to
  :attr:`(target_depth)`.

  Args:
    method (str): ``'linear'`` uses :func:`octree_linear_pts`; any other value
        uses nearest-neighbor interpolation.
    nempty (bool): If true, the features only contain non-empty octree nodes.

  Refer to :func:`octree_nearest_pts` for details.
  '''

  def __init__(self, method: str = 'linear', nempty: bool = False):
    super().__init__()
    self.method = method
    self.nempty = nempty
    if method == 'linear':
      self.func = octree_linear_pts
    else:
      self.func = octree_nearest_pts

  def forward(self, data: torch.Tensor, octree: Octree, depth: int,
              target_depth: Optional[int] = None):
    r''''''

    if target_depth is None:
      target_depth = depth + 1  # default: go up by one level
    if target_depth == depth:
      return data  # same depth, nothing to do
    assert target_depth >= depth, 'target_depth must be larger than depth'

    # One-level nearest upsampling has its own implementation that needs no
    # point search.
    if self.method == 'nearest' and target_depth == depth + 1:
      return octree_nearest_upsample(data, octree, depth, self.nempty)

    # Query the node centers at `target_depth`, expressed in the coordinate
    # frame of `depth`.
    ratio = 2 ** (depth - target_depth)
    xyzb = octree.xyzb(target_depth, self.nempty)
    pts = torch.stack(xyzb, dim=1).float()
    pts[:, :3] = (pts[:, :3] + 0.5) * ratio
    return self.func(data, octree, depth, pts, self.nempty)

  def extra_repr(self) -> str:
    r''' Sets the extra representation of the module.
    '''

    return 'method={}, nempty={}'.format(self.method, self.nempty)
ocnn/nn/octree_norm.py CHANGED
@@ -1,126 +1,126 @@
1
- # --------------------------------------------------------
2
- # Octree-based Sparse Convolutional Neural Networks
3
- # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
- # Licensed under The MIT License [see LICENSE for details]
5
- # Written by Peng-Shuai Wang
6
- # --------------------------------------------------------
7
-
8
- import torch
9
- import torch.nn
10
- from typing import Optional
11
-
12
- from ocnn.octree import Octree
13
- from ocnn.utils import scatter_add
14
-
15
-
16
# Octree batch normalization needs no octree information: it is the standard
# 1-D batch norm applied directly to the node-feature matrix.
OctreeBatchNorm = torch.nn.modules.batchnorm.BatchNorm1d
17
-
18
-
19
class OctreeGroupNorm(torch.nn.Module):
  r''' A group normalization layer for the octree.

  Statistics are computed per batch element (nodes are grouped by
  :meth:`Octree.batch_id`) and per channel group. A learnable per-channel
  affine transform is applied afterwards.

  Args:
    in_channels (int): The number of input channels.
    group (int): The requested number of channel groups. If the groups would
        have fewer than :attr:`min_group_channels` channels each, the number
        of groups is reduced to ``in_channels // min_group_channels`` (at
        least 1).
    nempty (bool): If true, the input only contains features of non-empty
        octree nodes.
    min_group_channels (int): The minimum number of channels per group.
  '''

  def __init__(self, in_channels: int, group: int, nempty: bool = False,
               min_group_channels: int = 4):
    super().__init__()
    self.eps = 1e-5
    self.nempty = nempty
    self.group = group
    self.in_channels = in_channels
    self.min_group_channels = min_group_channels
    if self.min_group_channels * self.group > in_channels:
      # Clamp to at least one group: with in_channels < min_group_channels the
      # floor division gives 0, and `in_channels % 0` below would raise.
      self.group = max(in_channels // self.min_group_channels, 1)

    assert in_channels % self.group == 0, \
        'in_channels must be divisible by the number of groups'
    self.channels_per_group = in_channels // self.group

    self.weights = torch.nn.Parameter(torch.empty(1, in_channels))
    self.bias = torch.nn.Parameter(torch.empty(1, in_channels))
    self.reset_parameters()

  def reset_parameters(self):
    torch.nn.init.ones_(self.weights)
    torch.nn.init.zeros_(self.bias)

  def forward(self, data: torch.Tensor, octree: 'Octree', depth: int):
    r''''''

    batch_size = octree.batch_size
    batch_id = octree.batch_id(depth, self.nempty)
    ones = data.new_ones([data.shape[0], 1])
    count = scatter_add(ones, batch_id, dim=0, dim_size=batch_size)
    count = count * self.channels_per_group  # element number in each group
    inv_count = 1.0 / (count + self.eps)  # there might be 0 element sometimes

    mean = scatter_add(data, batch_id, dim=0, dim_size=batch_size) * inv_count
    mean = self._adjust_for_group(mean)
    out = data - mean.index_select(0, batch_id)

    var = scatter_add(out**2, batch_id, dim=0, dim_size=batch_size) * inv_count
    var = self._adjust_for_group(var)
    inv_std = 1.0 / (var + self.eps).sqrt()
    out = out * inv_std.index_select(0, batch_id)

    out = out * self.weights + self.bias
    return out

  def _adjust_for_group(self, tensor: torch.Tensor):
    r''' Sums the per-channel statistics within each channel group and
    broadcasts the group sum back to every channel of that group.
    '''

    if self.channels_per_group > 1:
      tensor = (tensor.reshape(-1, self.group, self.channels_per_group)
                .sum(-1, keepdim=True)
                .repeat(1, 1, self.channels_per_group)
                .reshape(-1, self.in_channels))
    return tensor

  def extra_repr(self) -> str:
    return ('in_channels={}, group={}, nempty={}, min_group_channels={}').format(
        self.in_channels, self.group, self.nempty, self.min_group_channels)
81
-
82
-
83
class OctreeInstanceNorm(OctreeGroupNorm):
  r''' An instance normalization layer for the octree.

  This is :class:`OctreeGroupNorm` with exactly one channel per group, so each
  channel of each batch element is normalized on its own.
  '''

  def __init__(self, in_channels: int, nempty: bool = False):
    # One group per channel; min_group_channels=1 prevents the group count
    # from being reduced.
    super().__init__(in_channels=in_channels, group=in_channels,
                     min_group_channels=1, nempty=nempty)

  def extra_repr(self) -> str:
    template = 'in_channels={}, nempty={}'
    return template.format(self.in_channels, self.nempty)
93
-
94
-
95
class OctreeNorm(torch.nn.Module):
  r''' A normalization layer for the octree. It encapsulates octree-based batch,
  group and instance normalization.

  Args:
    in_channels (int): The number of input channels.
    norm_type (str): One of ``'batch_norm'``, ``'group_norm'`` or
        ``'instance_norm'``.
    group (int): The number of channel groups for ``'group_norm'``.
    min_group_channels (int): The minimum number of channels per group for
        ``'group_norm'``.
    nempty (bool): If true, the input only contains features of non-empty
        octree nodes. Used by ``'group_norm'`` and ``'instance_norm'``.

  Raises:
    ValueError: If :attr:`norm_type` is not supported.
  '''

  def __init__(self, in_channels: int, norm_type: str = 'batch_norm',
               group: int = 32, min_group_channels: int = 4,
               nempty: bool = False):
    super().__init__()
    self.in_channels = in_channels
    self.norm_type = norm_type
    self.group = group
    self.min_group_channels = min_group_channels
    self.nempty = nempty

    if self.norm_type == 'batch_norm':
      self.norm = torch.nn.BatchNorm1d(in_channels)
    elif self.norm_type == 'group_norm':
      # Pass the optional arguments by keyword: the third positional parameter
      # of OctreeGroupNorm is `nempty`, not `min_group_channels`.
      self.norm = OctreeGroupNorm(in_channels, group, nempty=nempty,
                                  min_group_channels=min_group_channels)
    elif self.norm_type == 'instance_norm':
      self.norm = OctreeInstanceNorm(in_channels, nempty=nempty)
    else:
      raise ValueError('Unsupported norm_type: {}'.format(norm_type))

  def forward(self, x: torch.Tensor, octree: Optional['Octree'] = None,
              depth: Optional[int] = None):
    r''' Normalizes :attr:`x`. The :attr:`octree` and :attr:`depth` are only
    used by group and instance normalization.
    '''
    if self.norm_type == 'batch_norm':
      output = self.norm(x)
    elif self.norm_type in ('group_norm', 'instance_norm'):
      output = self.norm(x, octree, depth)
    else:
      raise ValueError('Unsupported norm_type: {}'.format(self.norm_type))
    return output
1
+ # --------------------------------------------------------
2
+ # Octree-based Sparse Convolutional Neural Networks
3
+ # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Peng-Shuai Wang
6
+ # --------------------------------------------------------
7
+
8
+ import torch
9
+ import torch.nn
10
+ from typing import Optional
11
+
12
+ from ocnn.octree import Octree
13
+ from ocnn.utils import scatter_add
14
+
15
+
16
# Octree batch normalization needs no octree information: it is the standard
# 1-D batch norm applied directly to the node-feature matrix.
OctreeBatchNorm = torch.nn.modules.batchnorm.BatchNorm1d
17
+
18
+
19
class OctreeGroupNorm(torch.nn.Module):
  r''' A group normalization layer for the octree.

  Statistics are computed per batch element (nodes are grouped by
  :meth:`Octree.batch_id`) and per channel group. A learnable per-channel
  affine transform is applied afterwards.

  Args:
    in_channels (int): The number of input channels.
    group (int): The requested number of channel groups. If the groups would
        have fewer than :attr:`min_group_channels` channels each, the number
        of groups is reduced to ``in_channels // min_group_channels`` (at
        least 1).
    nempty (bool): If true, the input only contains features of non-empty
        octree nodes.
    min_group_channels (int): The minimum number of channels per group.
  '''

  def __init__(self, in_channels: int, group: int, nempty: bool = False,
               min_group_channels: int = 4):
    super().__init__()
    self.eps = 1e-5
    self.nempty = nempty
    self.group = group
    self.in_channels = in_channels
    self.min_group_channels = min_group_channels
    if self.min_group_channels * self.group > in_channels:
      # Clamp to at least one group: with in_channels < min_group_channels the
      # floor division gives 0, and `in_channels % 0` below would raise.
      self.group = max(in_channels // self.min_group_channels, 1)

    assert in_channels % self.group == 0, \
        'in_channels must be divisible by the number of groups'
    self.channels_per_group = in_channels // self.group

    self.weights = torch.nn.Parameter(torch.empty(1, in_channels))
    self.bias = torch.nn.Parameter(torch.empty(1, in_channels))
    self.reset_parameters()

  def reset_parameters(self):
    torch.nn.init.ones_(self.weights)
    torch.nn.init.zeros_(self.bias)

  def forward(self, data: torch.Tensor, octree: 'Octree', depth: int):
    r''''''

    batch_size = octree.batch_size
    batch_id = octree.batch_id(depth, self.nempty)
    ones = data.new_ones([data.shape[0], 1])
    count = scatter_add(ones, batch_id, dim=0, dim_size=batch_size)
    count = count * self.channels_per_group  # element number in each group
    inv_count = 1.0 / (count + self.eps)  # there might be 0 element sometimes

    mean = scatter_add(data, batch_id, dim=0, dim_size=batch_size) * inv_count
    mean = self._adjust_for_group(mean)
    out = data - mean.index_select(0, batch_id)

    var = scatter_add(out**2, batch_id, dim=0, dim_size=batch_size) * inv_count
    var = self._adjust_for_group(var)
    inv_std = 1.0 / (var + self.eps).sqrt()
    out = out * inv_std.index_select(0, batch_id)

    out = out * self.weights + self.bias
    return out

  def _adjust_for_group(self, tensor: torch.Tensor):
    r''' Sums the per-channel statistics within each channel group and
    broadcasts the group sum back to every channel of that group.
    '''

    if self.channels_per_group > 1:
      tensor = (tensor.reshape(-1, self.group, self.channels_per_group)
                .sum(-1, keepdim=True)
                .repeat(1, 1, self.channels_per_group)
                .reshape(-1, self.in_channels))
    return tensor

  def extra_repr(self) -> str:
    return ('in_channels={}, group={}, nempty={}, min_group_channels={}').format(
        self.in_channels, self.group, self.nempty, self.min_group_channels)
81
+
82
+
83
class OctreeInstanceNorm(OctreeGroupNorm):
  r''' An instance normalization layer for the octree.

  This is :class:`OctreeGroupNorm` with exactly one channel per group, so each
  channel of each batch element is normalized on its own.
  '''

  def __init__(self, in_channels: int, nempty: bool = False):
    # One group per channel; min_group_channels=1 prevents the group count
    # from being reduced.
    super().__init__(in_channels=in_channels, group=in_channels,
                     min_group_channels=1, nempty=nempty)

  def extra_repr(self) -> str:
    template = 'in_channels={}, nempty={}'
    return template.format(self.in_channels, self.nempty)
93
+
94
+
95
class OctreeNorm(torch.nn.Module):
  r''' A normalization layer for the octree. It encapsulates octree-based batch,
  group and instance normalization.

  Args:
    in_channels (int): The number of input channels.
    norm_type (str): One of ``'batch_norm'``, ``'group_norm'`` or
        ``'instance_norm'``.
    group (int): The number of channel groups for ``'group_norm'``.
    min_group_channels (int): The minimum number of channels per group for
        ``'group_norm'``.
    nempty (bool): If true, the input only contains features of non-empty
        octree nodes. Used by ``'group_norm'`` and ``'instance_norm'``.

  Raises:
    ValueError: If :attr:`norm_type` is not supported.
  '''

  def __init__(self, in_channels: int, norm_type: str = 'batch_norm',
               group: int = 32, min_group_channels: int = 4,
               nempty: bool = False):
    super().__init__()
    self.in_channels = in_channels
    self.norm_type = norm_type
    self.group = group
    self.min_group_channels = min_group_channels
    self.nempty = nempty

    if self.norm_type == 'batch_norm':
      self.norm = torch.nn.BatchNorm1d(in_channels)
    elif self.norm_type == 'group_norm':
      # Pass the optional arguments by keyword: the third positional parameter
      # of OctreeGroupNorm is `nempty`, not `min_group_channels`.
      self.norm = OctreeGroupNorm(in_channels, group, nempty=nempty,
                                  min_group_channels=min_group_channels)
    elif self.norm_type == 'instance_norm':
      self.norm = OctreeInstanceNorm(in_channels, nempty=nempty)
    else:
      raise ValueError('Unsupported norm_type: {}'.format(norm_type))

  def forward(self, x: torch.Tensor, octree: Optional['Octree'] = None,
              depth: Optional[int] = None):
    r''' Normalizes :attr:`x`. The :attr:`octree` and :attr:`depth` are only
    used by group and instance normalization.
    '''
    if self.norm_type == 'batch_norm':
      output = self.norm(x)
    elif self.norm_type in ('group_norm', 'instance_norm'):
      output = self.norm(x, octree, depth)
    else:
      raise ValueError('Unsupported norm_type: {}'.format(self.norm_type))
    return output