ocnn 2.2.1__py3-none-any.whl → 2.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ocnn/nn/octree_interp.py CHANGED
@@ -1,196 +1,196 @@
1
- # --------------------------------------------------------
2
- # Octree-based Sparse Convolutional Neural Networks
3
- # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
- # Licensed under The MIT License [see LICENSE for details]
5
- # Written by Peng-Shuai Wang
6
- # --------------------------------------------------------
7
-
8
- import torch
9
- import torch.sparse
10
- from typing import List, Optional
11
-
12
- import ocnn
13
- from ocnn.octree import Octree
14
-
15
-
16
def octree_nearest_pts(data: torch.Tensor, octree: Octree, depth: int,
                       pts: torch.Tensor, nempty: bool = False,
                       bound_check: bool = False):
    r''' Nearest-neighbor interpolation of octree features at input points.

    Every point is looked up with :func:`Octree.search_xyzb` and receives the
    feature of the node that is found; points without a matching node get
    zeros.

    Args:
      data (torch.Tensor): The input features. The first dimension must equal
          the node number of :attr:`octree` at :attr:`depth`; any trailing
          feature dimensions are copied through unchanged.
      octree (Octree): The octree to interpolate.
      depth (int): The depth of the data.
      pts (torch.Tensor): The coordinates of the points with shape
          :obj:`(N, 4)`, i.e. :obj:`N x (x, y, z, batch)`.
      nempty (bool): If true, the :attr:`data` only contains features of
          non-empty octree nodes.
      bound_check (bool): If true, points outside :obj:`[0, 2^depth)` are
          treated as misses even if the search returns a node for them.

    Returns:
      torch.Tensor: A tensor of shape :obj:`(N, *data.shape[1:])` with the
      dtype and device of :attr:`data`.

    .. note::
      The :attr:`pts` MUST be scaled into :obj:`[0, 2^depth)`.
    '''

    nnum = octree.nnum_nempty[depth] if nempty else octree.nnum[depth]
    assert data.shape[0] == nnum, 'The shape of input data is wrong.'

    idx = octree.search_xyzb(pts, depth, nempty)
    valid = idx > -1  # the search marks a miss with an index that is not > -1
    if bound_check:
        xyz = pts[:, :3]
        bound = torch.logical_and(xyz >= 0, xyz < 2**depth).all(1)
        valid = torch.logical_and(valid, bound)

    # Keep every trailing feature dimension instead of assuming (N, C) data.
    size = (pts.shape[0],) + tuple(data.shape[1:])
    out = torch.zeros(size, device=data.device, dtype=data.dtype)
    out[valid] = data.index_select(0, idx[valid])
    return out
48
-
49
-
50
def octree_linear_pts(data: torch.Tensor, octree: Octree, depth: int,
                      pts: torch.Tensor, nempty: bool = False,
                      bound_check: bool = False):
    r''' Linear interpolation of octree features at input points.

    Features are treated as samples located at voxel centers. For every point
    the 8 surrounding voxels are searched in the octree, each one is weighted
    by the product of its per-axis linear weights, and the weights are
    renormalized over the voxels that actually exist. A point without any
    valid neighbor yields zeros.

    Refer to :func:`octree_nearest_pts` for the meaning of the arguments.
    :attr:`data` is multiplied with a sparse matrix, so it is expected to be
    2-dimensional.

    Returns:
      torch.Tensor: The interpolated features, one row per point, with the
      dtype of :attr:`data`.
    '''

    nnum = octree.nnum_nempty[depth] if nempty else octree.nnum[depth]
    assert data.shape[0] == nnum, 'The shape of input data is wrong.'

    device = data.device
    # Integer offsets of the 8 corners of a unit cell.
    grid = torch.tensor(
        [[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1],
         [1, 0, 0], [1, 0, 1], [1, 1, 0], [1, 1, 1]], device=device)

    # 1. Neighborhood searching
    xyzf = pts[:, :3] - 0.5   # the value is defined on the center of each voxel
    xyzi = xyzf.floor()       # the integer part (N, 3)
    frac = xyzf - xyzi        # the fraction part (N, 3)

    xyzn = (xyzi.unsqueeze(1) + grid).view(-1, 3)
    batch = pts[:, 3].unsqueeze(1).repeat(1, 8).view(-1, 1)
    idx = octree.search_xyzb(torch.cat([xyzn, batch], dim=1), depth, nempty)
    valid = idx > -1  # valid indices
    if bound_check:
        bound = torch.logical_and(xyzn >= 0, xyzn < 2**depth).all(1)
        valid = torch.logical_and(valid, bound)
    idx = idx[valid]

    # 2. Build the sparse matrix: row = point id, column = octree node id
    npt = pts.shape[0]
    ids = torch.arange(npt, device=idx.device)
    ids = ids.unsqueeze(1).repeat(1, 8).view(-1)
    ids = ids[valid]
    indices = torch.stack([ids, idx], dim=0).long()

    frac = (1.0 - grid) - frac.unsqueeze(dim=1)  # (8, 3) - (N, 1, 3) -> (N, 8, 3)
    weight = frac.prod(dim=2).abs().view(-1)     # (8*N,)
    # The weights inherit the dtype of `pts`; cast them so that the sparse
    # products below do not mix dtypes (e.g. float64 points, float32 features).
    weight = weight[valid].to(data.dtype)

    h = data.shape[0]
    mat = torch.sparse_coo_tensor(indices, weight, [npt, h], device=device)

    # 3. Interpolation, normalized by the sum of the weights that were used
    output = torch.sparse.mm(mat, data)
    ones = torch.ones(h, 1, dtype=data.dtype, device=device)
    norm = torch.sparse.mm(mat, ones)
    output = torch.div(output, norm + 1e-12)  # epsilon avoids 0 / 0
    return output
100
-
101
-
102
class OctreeInterp(torch.nn.Module):
    r''' Interpolates octree node features at arbitrary query points.

    Args:
      method (str): :obj:`'linear'` selects :func:`octree_linear_pts`; any
          other value selects :func:`octree_nearest_pts`.
      nempty (bool): Forwarded to the interpolation function.
      bound_check (bool): Forwarded to the interpolation function.
      rescale_pts (bool): If true, the point coordinates are mapped from
          :obj:`[-1, 1]` to :obj:`[0, 2^depth]` before the interpolation.

    Refer to :func:`octree_nearest_pts` for a description of the forwarded
    arguments.
    '''

    def __init__(self, method: str = 'linear', nempty: bool = False,
                 bound_check: bool = False, rescale_pts: bool = True):
        super().__init__()
        self.method = method
        self.nempty = nempty
        self.bound_check = bound_check
        self.rescale_pts = rescale_pts
        self.func = octree_linear_pts if method == 'linear' else octree_nearest_pts

    def forward(self, data: torch.Tensor, octree: Octree, depth: int,
                pts: torch.Tensor):
        r''' Interpolates :attr:`data` at :attr:`pts`.

        Args:
          data (torch.Tensor): The octree features at :attr:`depth`.
          octree (Octree): The octree to interpolate.
          depth (int): The depth of the data.
          pts (torch.Tensor): The points with shape :obj:`(N, 4)`, i.e.
              :obj:`N x (x, y, z, batch)`. The tensor is not modified.
        '''

        # rescale points from [-1, 1] to [0, 2^depth]
        if self.rescale_pts:
            scale = 2 ** (depth - 1)
            # Build a new tensor rather than writing into the caller's `pts`:
            # an in-place update would rescale the same points again on the
            # next call and fails for leaf tensors that require grad.
            pts = torch.cat([(pts[:, :3] + 1.0) * scale, pts[:, 3:]], dim=1)

        return self.func(data, octree, depth, pts, self.nempty, self.bound_check)

    def extra_repr(self) -> str:
        r''' Sets the extra representation of the module.
        '''

        return ('method={}, nempty={}, bound_check={}, rescale_pts={}').format(
            self.method, self.nempty, self.bound_check, self.rescale_pts)
134
-
135
-
136
def octree_nearest_upsample(data: torch.Tensor, octree: Octree, depth: int,
                            nempty: bool = False):
    r''' Upsamples the octree node features from :attr:`depth` to
    :attr:`(depth+1)` with the nearest-neighbor interpolation.

    Every parent feature is copied to its 8 children, which are stored
    consecutively.

    Args:
      data (torch.Tensor): The input data; its first dimension must equal the
          node number of :attr:`octree` at :attr:`depth`.
      octree (Octree): The octree to interpolate.
      depth (int): The depth of the data.
      nempty (bool): If true, the :attr:`data` only contains features of
          non-empty octree nodes.

    Returns:
      torch.Tensor: The features at :attr:`(depth+1)`.
    '''

    nnum = octree.nnum_nempty[depth] if nempty else octree.nnum[depth]
    assert data.shape[0] == nnum, 'The shape of input data is wrong.'

    out = data
    if not nempty:
        # Drop the rows of empty nodes at `depth` before duplicating
        # (presumably because only non-empty nodes have children).
        out = ocnn.nn.octree_depad(out, octree, depth)
    # Row i is repeated 8 times in a row; equivalent to the former
    # unsqueeze/repeat/flatten chain but independent of the feature rank.
    out = out.repeat_interleave(8, dim=0)
    if nempty:
        # Keep only the non-empty children.
        out = ocnn.nn.octree_depad(out, octree, depth + 1)  # !!! depth+1
    return out
159
-
160
-
161
class OctreeUpsample(torch.nn.Module):
    r''' Upsamples the octree node features from :attr:`depth` to
    :attr:`target_depth`.

    The node centers at the target depth are used as query points for the
    point interpolation selected by :attr:`method` (:obj:`'linear'`, otherwise
    nearest). Refer to :func:`octree_nearest_pts` for details.
    '''

    def __init__(self, method: str = 'linear', nempty: bool = False):
        super().__init__()
        self.method, self.nempty = method, nempty
        if method == 'linear':
            self.func = octree_linear_pts
        else:
            self.func = octree_nearest_pts

    def forward(self, data: torch.Tensor, octree: Octree, depth: int,
                target_depth: Optional[int] = None):
        r''' Upsamples :attr:`data`; :attr:`target_depth` defaults to
        :obj:`depth + 1`.
        '''

        target = depth + 1 if target_depth is None else target_depth
        if target == depth:
            return data  # nothing to upsample
        assert target >= depth, 'target_depth must be larger than depth'

        # A single-level nearest upsampling has a dedicated implementation.
        one_level = target == depth + 1
        if one_level and self.method == 'nearest':
            return octree_nearest_upsample(data, octree, depth, self.nempty)

        # Node coordinates at the target depth, expressed at the scale of
        # `depth`; the 0.5 offset moves each coordinate to its node center.
        coords = torch.stack(octree.xyzb(target, self.nempty), dim=1).float()
        ratio = 2 ** (depth - target)
        coords[:, :3] = (coords[:, :3] + 0.5) * ratio
        return self.func(data, octree, depth, coords, self.nempty)

    def extra_repr(self) -> str:
        r''' Sets the extra representation of the module.
        '''

        template = 'method={}, nempty={}'
        return template.format(self.method, self.nempty)
1
+ # --------------------------------------------------------
2
+ # Octree-based Sparse Convolutional Neural Networks
3
+ # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Peng-Shuai Wang
6
+ # --------------------------------------------------------
7
+
8
+ import torch
9
+ import torch.sparse
10
+ from typing import List, Optional
11
+
12
+ import ocnn
13
+ from ocnn.octree import Octree
14
+
15
+
16
def octree_nearest_pts(data: torch.Tensor, octree: Octree, depth: int,
                       pts: torch.Tensor, nempty: bool = False,
                       bound_check: bool = False):
    r''' Nearest-neighbor interpolation of octree features at input points.

    Every point is looked up with :func:`Octree.search_xyzb` and receives the
    feature of the node that is found; points without a matching node get
    zeros.

    Args:
      data (torch.Tensor): The input features. The first dimension must equal
          the node number of :attr:`octree` at :attr:`depth`; any trailing
          feature dimensions are copied through unchanged.
      octree (Octree): The octree to interpolate.
      depth (int): The depth of the data.
      pts (torch.Tensor): The coordinates of the points with shape
          :obj:`(N, 4)`, i.e. :obj:`N x (x, y, z, batch)`.
      nempty (bool): If true, the :attr:`data` only contains features of
          non-empty octree nodes.
      bound_check (bool): If true, points outside :obj:`[0, 2^depth)` are
          treated as misses even if the search returns a node for them.

    Returns:
      torch.Tensor: A tensor of shape :obj:`(N, *data.shape[1:])` with the
      dtype and device of :attr:`data`.

    .. note::
      The :attr:`pts` MUST be scaled into :obj:`[0, 2^depth)`.
    '''

    nnum = octree.nnum_nempty[depth] if nempty else octree.nnum[depth]
    assert data.shape[0] == nnum, 'The shape of input data is wrong.'

    idx = octree.search_xyzb(pts, depth, nempty)
    valid = idx > -1  # the search marks a miss with an index that is not > -1
    if bound_check:
        xyz = pts[:, :3]
        bound = torch.logical_and(xyz >= 0, xyz < 2**depth).all(1)
        valid = torch.logical_and(valid, bound)

    # Keep every trailing feature dimension instead of assuming (N, C) data.
    size = (pts.shape[0],) + tuple(data.shape[1:])
    out = torch.zeros(size, device=data.device, dtype=data.dtype)
    out[valid] = data.index_select(0, idx[valid])
    return out
48
+
49
+
50
def octree_linear_pts(data: torch.Tensor, octree: Octree, depth: int,
                      pts: torch.Tensor, nempty: bool = False,
                      bound_check: bool = False):
    r''' Linear interpolation of octree features at input points.

    Features are treated as samples located at voxel centers. For every point
    the 8 surrounding voxels are searched in the octree, each one is weighted
    by the product of its per-axis linear weights, and the weights are
    renormalized over the voxels that actually exist. A point without any
    valid neighbor yields zeros.

    Refer to :func:`octree_nearest_pts` for the meaning of the arguments.
    :attr:`data` is multiplied with a sparse matrix, so it is expected to be
    2-dimensional.

    Returns:
      torch.Tensor: The interpolated features, one row per point, with the
      dtype of :attr:`data`.
    '''

    nnum = octree.nnum_nempty[depth] if nempty else octree.nnum[depth]
    assert data.shape[0] == nnum, 'The shape of input data is wrong.'

    device = data.device
    # Integer offsets of the 8 corners of a unit cell.
    grid = torch.tensor(
        [[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1],
         [1, 0, 0], [1, 0, 1], [1, 1, 0], [1, 1, 1]], device=device)

    # 1. Neighborhood searching
    xyzf = pts[:, :3] - 0.5   # the value is defined on the center of each voxel
    xyzi = xyzf.floor()       # the integer part (N, 3)
    frac = xyzf - xyzi        # the fraction part (N, 3)

    xyzn = (xyzi.unsqueeze(1) + grid).view(-1, 3)
    batch = pts[:, 3].unsqueeze(1).repeat(1, 8).view(-1, 1)
    idx = octree.search_xyzb(torch.cat([xyzn, batch], dim=1), depth, nempty)
    valid = idx > -1  # valid indices
    if bound_check:
        bound = torch.logical_and(xyzn >= 0, xyzn < 2**depth).all(1)
        valid = torch.logical_and(valid, bound)
    idx = idx[valid]

    # 2. Build the sparse matrix: row = point id, column = octree node id
    npt = pts.shape[0]
    ids = torch.arange(npt, device=idx.device)
    ids = ids.unsqueeze(1).repeat(1, 8).view(-1)
    ids = ids[valid]
    indices = torch.stack([ids, idx], dim=0).long()

    frac = (1.0 - grid) - frac.unsqueeze(dim=1)  # (8, 3) - (N, 1, 3) -> (N, 8, 3)
    weight = frac.prod(dim=2).abs().view(-1)     # (8*N,)
    # The weights inherit the dtype of `pts`; cast them so that the sparse
    # products below do not mix dtypes (e.g. float64 points, float32 features).
    weight = weight[valid].to(data.dtype)

    h = data.shape[0]
    mat = torch.sparse_coo_tensor(indices, weight, [npt, h], device=device)

    # 3. Interpolation, normalized by the sum of the weights that were used
    output = torch.sparse.mm(mat, data)
    ones = torch.ones(h, 1, dtype=data.dtype, device=device)
    norm = torch.sparse.mm(mat, ones)
    output = torch.div(output, norm + 1e-12)  # epsilon avoids 0 / 0
    return output
100
+
101
+
102
class OctreeInterp(torch.nn.Module):
    r''' Interpolates octree node features at arbitrary query points.

    Args:
      method (str): :obj:`'linear'` selects :func:`octree_linear_pts`; any
          other value selects :func:`octree_nearest_pts`.
      nempty (bool): Forwarded to the interpolation function.
      bound_check (bool): Forwarded to the interpolation function.
      rescale_pts (bool): If true, the point coordinates are mapped from
          :obj:`[-1, 1]` to :obj:`[0, 2^depth]` before the interpolation.

    Refer to :func:`octree_nearest_pts` for a description of the forwarded
    arguments.
    '''

    def __init__(self, method: str = 'linear', nempty: bool = False,
                 bound_check: bool = False, rescale_pts: bool = True):
        super().__init__()
        self.method = method
        self.nempty = nempty
        self.bound_check = bound_check
        self.rescale_pts = rescale_pts
        self.func = octree_linear_pts if method == 'linear' else octree_nearest_pts

    def forward(self, data: torch.Tensor, octree: Octree, depth: int,
                pts: torch.Tensor):
        r''' Interpolates :attr:`data` at :attr:`pts`.

        Args:
          data (torch.Tensor): The octree features at :attr:`depth`.
          octree (Octree): The octree to interpolate.
          depth (int): The depth of the data.
          pts (torch.Tensor): The points with shape :obj:`(N, 4)`, i.e.
              :obj:`N x (x, y, z, batch)`. The tensor is not modified.
        '''

        # rescale points from [-1, 1] to [0, 2^depth]
        if self.rescale_pts:
            scale = 2 ** (depth - 1)
            # Build a new tensor rather than writing into the caller's `pts`:
            # an in-place update would rescale the same points again on the
            # next call and fails for leaf tensors that require grad.
            pts = torch.cat([(pts[:, :3] + 1.0) * scale, pts[:, 3:]], dim=1)

        return self.func(data, octree, depth, pts, self.nempty, self.bound_check)

    def extra_repr(self) -> str:
        r''' Sets the extra representation of the module.
        '''

        return ('method={}, nempty={}, bound_check={}, rescale_pts={}').format(
            self.method, self.nempty, self.bound_check, self.rescale_pts)
134
+
135
+
136
def octree_nearest_upsample(data: torch.Tensor, octree: Octree, depth: int,
                            nempty: bool = False):
    r''' Upsamples the octree node features from :attr:`depth` to
    :attr:`(depth+1)` with the nearest-neighbor interpolation.

    Every parent feature is copied to its 8 children, which are stored
    consecutively.

    Args:
      data (torch.Tensor): The input data; its first dimension must equal the
          node number of :attr:`octree` at :attr:`depth`.
      octree (Octree): The octree to interpolate.
      depth (int): The depth of the data.
      nempty (bool): If true, the :attr:`data` only contains features of
          non-empty octree nodes.

    Returns:
      torch.Tensor: The features at :attr:`(depth+1)`.
    '''

    nnum = octree.nnum_nempty[depth] if nempty else octree.nnum[depth]
    assert data.shape[0] == nnum, 'The shape of input data is wrong.'

    out = data
    if not nempty:
        # Drop the rows of empty nodes at `depth` before duplicating
        # (presumably because only non-empty nodes have children).
        out = ocnn.nn.octree_depad(out, octree, depth)
    # Row i is repeated 8 times in a row; equivalent to the former
    # unsqueeze/repeat/flatten chain but independent of the feature rank.
    out = out.repeat_interleave(8, dim=0)
    if nempty:
        # Keep only the non-empty children.
        out = ocnn.nn.octree_depad(out, octree, depth + 1)  # !!! depth+1
    return out
159
+
160
+
161
class OctreeUpsample(torch.nn.Module):
    r''' Upsamples the octree node features from :attr:`depth` to
    :attr:`target_depth`.

    The node centers at the target depth are used as query points for the
    point interpolation selected by :attr:`method` (:obj:`'linear'`, otherwise
    nearest). Refer to :func:`octree_nearest_pts` for details.
    '''

    def __init__(self, method: str = 'linear', nempty: bool = False):
        super().__init__()
        self.method, self.nempty = method, nempty
        if method == 'linear':
            self.func = octree_linear_pts
        else:
            self.func = octree_nearest_pts

    def forward(self, data: torch.Tensor, octree: Octree, depth: int,
                target_depth: Optional[int] = None):
        r''' Upsamples :attr:`data`; :attr:`target_depth` defaults to
        :obj:`depth + 1`.
        '''

        target = depth + 1 if target_depth is None else target_depth
        if target == depth:
            return data  # nothing to upsample
        assert target >= depth, 'target_depth must be larger than depth'

        # A single-level nearest upsampling has a dedicated implementation.
        one_level = target == depth + 1
        if one_level and self.method == 'nearest':
            return octree_nearest_upsample(data, octree, depth, self.nempty)

        # Node coordinates at the target depth, expressed at the scale of
        # `depth`; the 0.5 offset moves each coordinate to its node center.
        coords = torch.stack(octree.xyzb(target, self.nempty), dim=1).float()
        ratio = 2 ** (depth - target)
        coords[:, :3] = (coords[:, :3] + 0.5) * ratio
        return self.func(data, octree, depth, coords, self.nempty)

    def extra_repr(self) -> str:
        r''' Sets the extra representation of the module.
        '''

        template = 'method={}, nempty={}'
        return template.format(self.method, self.nempty)
ocnn/nn/octree_norm.py CHANGED
@@ -1,86 +1,86 @@
1
- # --------------------------------------------------------
2
- # Octree-based Sparse Convolutional Neural Networks
3
- # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
- # Licensed under The MIT License [see LICENSE for details]
5
- # Written by Peng-Shuai Wang
6
- # --------------------------------------------------------
7
-
8
- import torch
9
- import torch.nn
10
-
11
- from ocnn.octree import Octree
12
- from ocnn.utils import scatter_add
13
-
14
-
15
# Batch normalization for octree features is exposed as a plain alias of
# torch.nn.BatchNorm1d; no octree-specific wrapper is defined.
OctreeBatchNorm = torch.nn.BatchNorm1d
16
-
17
-
18
class OctreeGroupNorm(torch.nn.Module):
    r''' A group normalization layer for the octree.

    The statistics are computed separately for every sample of the batch (as
    given by :func:`Octree.batch_id`) and every group of channels.

    Args:
      in_channels (int): The number of feature channels.
      group (int): The number of groups; must divide :attr:`in_channels`.
      nempty (bool): If true, the data only contains features of non-empty
          octree nodes.
    '''

    def __init__(self, in_channels: int, group: int, nempty: bool = False):
        super().__init__()
        assert in_channels % group == 0

        self.eps = 1e-5
        self.nempty = nempty
        self.group = group
        self.in_channels = in_channels
        self.channels_per_group = in_channels // group

        # Learnable per-channel scale and shift, filled by reset_parameters().
        self.weights = torch.nn.Parameter(torch.Tensor(1, in_channels))
        self.bias = torch.nn.Parameter(torch.Tensor(1, in_channels))
        self.reset_parameters()

    def reset_parameters(self):
        r''' Resets the scale to one and the shift to zero.
        '''

        torch.nn.init.ones_(self.weights)
        torch.nn.init.zeros_(self.bias)

    def forward(self, data: torch.Tensor, octree: Octree, depth: int):
        r''' Normalizes :attr:`data`, the features of :attr:`octree` at
        :attr:`depth`, and applies the learnable scale and shift.
        '''

        num_samples = octree.batch_size
        sample_id = octree.batch_id(depth, self.nempty)

        def sum_per_sample(values: torch.Tensor):
            return scatter_add(values, sample_id, dim=0, dim_size=num_samples)

        # element number in each group; eps guards samples with 0 elements
        node_count = sum_per_sample(data.new_ones([data.shape[0], 1]))
        elem_count = node_count * self.channels_per_group
        inv_count = 1.0 / (elem_count + self.eps)

        mean = self._adjust_for_group(sum_per_sample(data) * inv_count)
        centered = data - mean.index_select(0, sample_id)

        var = self._adjust_for_group(sum_per_sample(centered**2) * inv_count)
        inv_std = 1.0 / (var + self.eps).sqrt()
        normalized = centered * inv_std.index_select(0, sample_id)

        return normalized * self.weights + self.bias

    def _adjust_for_group(self, tensor: torch.Tensor):
        r''' Sums the per-channel values inside every group and broadcasts the
        sum back to all channels of that group.
        '''

        if self.channels_per_group <= 1:
            return tensor  # one channel per group: nothing to pool

        grouped = tensor.reshape(-1, self.group, self.channels_per_group)
        pooled = grouped.sum(-1, keepdim=True)
        tiled = pooled.repeat(1, 1, self.channels_per_group)
        return tiled.reshape(-1, self.in_channels)

    def extra_repr(self) -> str:
        r''' Sets the extra representation of the module.
        '''

        template = 'in_channels={}, group={}, nempty={}'
        return template.format(self.in_channels, self.group, self.nempty)
76
-
77
-
78
class OctreeInstanceNorm(OctreeGroupNorm):
    r''' An instance normalization layer for the octree, implemented as a
    group normalization with one channel per group.
    '''

    def __init__(self, in_channels: int, nempty: bool = False):
        # group == in_channels puts every channel into its own group.
        super().__init__(in_channels, in_channels, nempty)

    def extra_repr(self) -> str:
        r''' Sets the extra representation of the module.
        '''

        template = 'in_channels={}, nempty={}'
        return template.format(self.in_channels, self.nempty)
1
+ # --------------------------------------------------------
2
+ # Octree-based Sparse Convolutional Neural Networks
3
+ # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Peng-Shuai Wang
6
+ # --------------------------------------------------------
7
+
8
+ import torch
9
+ import torch.nn
10
+
11
+ from ocnn.octree import Octree
12
+ from ocnn.utils import scatter_add
13
+
14
+
15
# Batch normalization for octree features is exposed as a plain alias of
# torch.nn.BatchNorm1d; no octree-specific wrapper is defined.
OctreeBatchNorm = torch.nn.BatchNorm1d
16
+
17
+
18
class OctreeGroupNorm(torch.nn.Module):
    r''' A group normalization layer for the octree.

    The statistics are computed separately for every sample of the batch (as
    given by :func:`Octree.batch_id`) and every group of channels.

    Args:
      in_channels (int): The number of feature channels.
      group (int): The number of groups; must divide :attr:`in_channels`.
      nempty (bool): If true, the data only contains features of non-empty
          octree nodes.
    '''

    def __init__(self, in_channels: int, group: int, nempty: bool = False):
        super().__init__()
        assert in_channels % group == 0

        self.eps = 1e-5
        self.nempty = nempty
        self.group = group
        self.in_channels = in_channels
        self.channels_per_group = in_channels // group

        # Learnable per-channel scale and shift, filled by reset_parameters().
        self.weights = torch.nn.Parameter(torch.Tensor(1, in_channels))
        self.bias = torch.nn.Parameter(torch.Tensor(1, in_channels))
        self.reset_parameters()

    def reset_parameters(self):
        r''' Resets the scale to one and the shift to zero.
        '''

        torch.nn.init.ones_(self.weights)
        torch.nn.init.zeros_(self.bias)

    def forward(self, data: torch.Tensor, octree: Octree, depth: int):
        r''' Normalizes :attr:`data`, the features of :attr:`octree` at
        :attr:`depth`, and applies the learnable scale and shift.
        '''

        num_samples = octree.batch_size
        sample_id = octree.batch_id(depth, self.nempty)

        def sum_per_sample(values: torch.Tensor):
            return scatter_add(values, sample_id, dim=0, dim_size=num_samples)

        # element number in each group; eps guards samples with 0 elements
        node_count = sum_per_sample(data.new_ones([data.shape[0], 1]))
        elem_count = node_count * self.channels_per_group
        inv_count = 1.0 / (elem_count + self.eps)

        mean = self._adjust_for_group(sum_per_sample(data) * inv_count)
        centered = data - mean.index_select(0, sample_id)

        var = self._adjust_for_group(sum_per_sample(centered**2) * inv_count)
        inv_std = 1.0 / (var + self.eps).sqrt()
        normalized = centered * inv_std.index_select(0, sample_id)

        return normalized * self.weights + self.bias

    def _adjust_for_group(self, tensor: torch.Tensor):
        r''' Sums the per-channel values inside every group and broadcasts the
        sum back to all channels of that group.
        '''

        if self.channels_per_group <= 1:
            return tensor  # one channel per group: nothing to pool

        grouped = tensor.reshape(-1, self.group, self.channels_per_group)
        pooled = grouped.sum(-1, keepdim=True)
        tiled = pooled.repeat(1, 1, self.channels_per_group)
        return tiled.reshape(-1, self.in_channels)

    def extra_repr(self) -> str:
        r''' Sets the extra representation of the module.
        '''

        template = 'in_channels={}, group={}, nempty={}'
        return template.format(self.in_channels, self.group, self.nempty)
76
+
77
+
78
class OctreeInstanceNorm(OctreeGroupNorm):
    r''' An instance normalization layer for the octree, implemented as a
    group normalization with one channel per group.
    '''

    def __init__(self, in_channels: int, nempty: bool = False):
        # group == in_channels puts every channel into its own group.
        super().__init__(in_channels, in_channels, nempty)

    def extra_repr(self) -> str:
        r''' Sets the extra representation of the module.
        '''

        template = 'in_channels={}, nempty={}'
        return template.format(self.in_channels, self.nempty)
ocnn/nn/octree_pad.py CHANGED
@@ -1,39 +1,39 @@
1
- # --------------------------------------------------------
2
- # Octree-based Sparse Convolutional Neural Networks
3
- # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
- # Licensed under The MIT License [see LICENSE for details]
5
- # Written by Peng-Shuai Wang
6
- # --------------------------------------------------------
7
-
8
- import torch
9
-
10
- from ..octree import Octree
11
-
12
-
13
def octree_pad(data: torch.Tensor, octree: Octree, depth: int, val: float = 0.0):
    r''' Pads :attr:`val` to make the number of elements of :attr:`data` equal
    to the octree node number.

    Rows of :attr:`data` are written to the positions of the non-empty nodes
    (as given by :func:`Octree.nempty_mask`); all other rows are filled with
    :attr:`val`.

    Args:
      data (torch.Tensor): The input tensor with its number of elements equal
          to the non-empty octree node number. Any trailing feature
          dimensions are preserved.
      octree (Octree): The corresponding octree.
      depth (int): The depth of current octree.
      val (float): The padding value. (Default: :obj:`0.0`)

    Returns:
      torch.Tensor: A tensor of shape :obj:`(nnum[depth], *data.shape[1:])`
      with the dtype and device of :attr:`data`.
    '''

    mask = octree.nempty_mask(depth)
    # Keep every trailing dimension instead of assuming (N, C) data.
    size = (octree.nnum[depth],) + tuple(data.shape[1:])
    out = torch.full(size, val, dtype=data.dtype, device=data.device)
    out[mask] = data
    return out
30
-
31
-
32
def octree_depad(data: torch.Tensor, octree: Octree, depth: int):
    r''' Reverse operation of :func:`octree_pad`: keeps only the rows of
    :attr:`data` that belong to non-empty octree nodes.

    Please refer to :func:`octree_pad` for the meaning of the arguments; here
    :attr:`data` is indexed with the non-empty mask of :attr:`octree` at
    :attr:`depth`.
    '''

    return data[octree.nempty_mask(depth)]
1
+ # --------------------------------------------------------
2
+ # Octree-based Sparse Convolutional Neural Networks
3
+ # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Peng-Shuai Wang
6
+ # --------------------------------------------------------
7
+
8
+ import torch
9
+
10
+ from ocnn.octree import Octree
11
+
12
+
13
def octree_pad(data: torch.Tensor, octree: Octree, depth: int, val: float = 0.0):
    r''' Pads :attr:`val` to make the number of elements of :attr:`data` equal
    to the octree node number.

    Rows of :attr:`data` are written to the positions of the non-empty nodes
    (as given by :func:`Octree.nempty_mask`); all other rows are filled with
    :attr:`val`.

    Args:
      data (torch.Tensor): The input tensor with its number of elements equal
          to the non-empty octree node number. Any trailing feature
          dimensions are preserved.
      octree (Octree): The corresponding octree.
      depth (int): The depth of current octree.
      val (float): The padding value. (Default: :obj:`0.0`)

    Returns:
      torch.Tensor: A tensor of shape :obj:`(nnum[depth], *data.shape[1:])`
      with the dtype and device of :attr:`data`.
    '''

    mask = octree.nempty_mask(depth)
    # Keep every trailing dimension instead of assuming (N, C) data.
    size = (octree.nnum[depth],) + tuple(data.shape[1:])
    out = torch.full(size, val, dtype=data.dtype, device=data.device)
    out[mask] = data
    return out
30
+
31
+
32
def octree_depad(data: torch.Tensor, octree: Octree, depth: int):
    r''' Reverse operation of :func:`octree_pad`: keeps only the rows of
    :attr:`data` that belong to non-empty octree nodes.

    Please refer to :func:`octree_pad` for the meaning of the arguments; here
    :attr:`data` is indexed with the non-empty mask of :attr:`octree` at
    :attr:`depth`.
    '''

    return data[octree.nempty_mask(depth)]