ocnn 2.2.1__py3-none-any.whl → 2.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ocnn/__init__.py +24 -24
- ocnn/dataset.py +160 -158
- ocnn/models/__init__.py +29 -27
- ocnn/models/autoencoder.py +155 -165
- ocnn/models/hrnet.py +192 -192
- ocnn/models/image2shape.py +128 -0
- ocnn/models/lenet.py +46 -46
- ocnn/models/ounet.py +94 -94
- ocnn/models/resnet.py +53 -53
- ocnn/models/segnet.py +72 -72
- ocnn/models/unet.py +105 -105
- ocnn/modules/__init__.py +20 -20
- ocnn/modules/modules.py +193 -231
- ocnn/modules/resblocks.py +124 -124
- ocnn/nn/__init__.py +43 -42
- ocnn/nn/octree2col.py +53 -53
- ocnn/nn/octree2vox.py +50 -50
- ocnn/nn/octree_align.py +46 -46
- ocnn/nn/octree_conv.py +429 -411
- ocnn/nn/octree_drop.py +55 -55
- ocnn/nn/octree_dwconv.py +222 -204
- ocnn/nn/octree_gconv.py +79 -0
- ocnn/nn/octree_interp.py +196 -196
- ocnn/nn/octree_norm.py +86 -86
- ocnn/nn/octree_pad.py +39 -39
- ocnn/nn/octree_pool.py +200 -200
- ocnn/octree/__init__.py +22 -21
- ocnn/octree/octree.py +639 -601
- ocnn/octree/points.py +322 -298
- ocnn/octree/shuffled_key.py +115 -115
- ocnn/utils.py +205 -153
- {ocnn-2.2.1.dist-info → ocnn-2.2.3.dist-info}/LICENSE +21 -21
- {ocnn-2.2.1.dist-info → ocnn-2.2.3.dist-info}/METADATA +79 -65
- ocnn-2.2.3.dist-info/RECORD +36 -0
- {ocnn-2.2.1.dist-info → ocnn-2.2.3.dist-info}/WHEEL +1 -1
- ocnn-2.2.1.dist-info/RECORD +0 -34
- {ocnn-2.2.1.dist-info → ocnn-2.2.3.dist-info}/top_level.txt +0 -0
ocnn/nn/octree_interp.py
CHANGED
|
@@ -1,196 +1,196 @@
|
|
|
1
|
-
# --------------------------------------------------------
|
|
2
|
-
# Octree-based Sparse Convolutional Neural Networks
|
|
3
|
-
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
|
|
4
|
-
# Licensed under The MIT License [see LICENSE for details]
|
|
5
|
-
# Written by Peng-Shuai Wang
|
|
6
|
-
# --------------------------------------------------------
|
|
7
|
-
|
|
8
|
-
import torch
|
|
9
|
-
import torch.sparse
|
|
10
|
-
from typing import List, Optional
|
|
11
|
-
|
|
12
|
-
import ocnn
|
|
13
|
-
from ocnn.octree import Octree
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
def octree_nearest_pts(data: torch.Tensor, octree: Octree, depth: int,
                       pts: torch.Tensor, nempty: bool = False,
                       bound_check: bool = False):
  ''' Nearest-neighbor interpolation of octree features at the input points.

  Each point takes the feature row of the octree node at :attr:`depth` that
  the octree search returns for it; points without a matching node get an
  all-zero feature.

  Args:
    data (torch.Tensor): The input data, one row per octree node at
        :attr:`depth`.
    octree (Octree): The octree to interpolate.
    depth (int): The depth of the data.
    pts (torch.Tensor): The coordinates of the points with shape :obj:`(N, 4)`,
        i.e. :obj:`N x (x, y, z, batch)`.
    nempty (bool): If true, the :attr:`data` only contains features of non-empty
        octree nodes.
    bound_check (bool): If true, points outside :obj:`[0, 2^depth)` are also
        treated as having no matching node.

  Returns:
    torch.Tensor: A tensor of shape :obj:`(N, data.shape[1])` with the dtype and
    device of :attr:`data`.

  .. note::
    The :attr:`pts` MUST be scaled into :obj:`[0, 2^depth)`.
  '''

  expected_rows = octree.nnum_nempty[depth] if nempty else octree.nnum[depth]
  assert data.shape[0] == expected_rows, 'The shape of input data is wrong.'

  # Negative indices returned by the search mark points without a valid node.
  node_ids = octree.search_xyzb(pts, depth, nempty)
  hit = node_ids > -1
  if bound_check:
    coords = pts[:, :3]
    inside = ((coords >= 0) & (coords < 2**depth)).all(1)
    hit = hit & inside

  # Rows of points without a valid node stay zero.
  feats = data.new_zeros((pts.shape[0], data.shape[1]))
  feats[hit] = data.index_select(0, node_ids[hit])
  return feats
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
def octree_linear_pts(data: torch.Tensor, octree: Octree, depth: int,
                      pts: torch.Tensor, nempty: bool = False,
                      bound_check: bool = False):
  ''' Trilinear interpolation of octree features at the input points.

  Refer to :func:`octree_nearest_pts` for the meaning of the arguments.

  For every point, the 8 surrounding voxel centers at :attr:`depth` are looked
  up in the octree. Their features are blended with trilinear weights, which
  are renormalized over the neighbors that were actually found. Points with no
  valid neighbor get an all-zero feature.

  Returns:
    torch.Tensor: A tensor of shape :obj:`(N, data.shape[1])` with the dtype of
    :attr:`data`.
  '''

  nnum = octree.nnum_nempty[depth] if nempty else octree.nnum[depth]
  assert data.shape[0] == nnum, 'The shape of input data is wrong.'

  device = data.device
  grid = torch.tensor(
      [[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1],
       [1, 0, 0], [1, 0, 1], [1, 1, 0], [1, 1, 1]], device=device)

  # 1. Neighborhood searching
  xyzf = pts[:, :3] - 0.5  # the value is defined on the center of each voxel
  xyzi = xyzf.floor()      # the integer part (N, 3)
  frac = xyzf - xyzi       # the fraction part (N, 3)

  xyzn = (xyzi.unsqueeze(1) + grid).view(-1, 3)  # 8 neighbors per point
  batch = pts[:, 3].unsqueeze(1).repeat(1, 8).view(-1, 1)
  idx = octree.search_xyzb(torch.cat([xyzn, batch], dim=1), depth, nempty)
  valid = idx > -1  # valid indices
  if bound_check:
    bound = torch.logical_and(xyzn >= 0, xyzn < 2**depth).all(1)
    valid = torch.logical_and(valid, bound)
  idx = idx[valid]

  # 2. Build the sparse matrix (row: point, column: octree node)
  npt = pts.shape[0]
  ids = torch.arange(npt, device=idx.device)
  ids = ids.unsqueeze(1).repeat(1, 8).view(-1)
  ids = ids[valid]
  indices = torch.stack([ids, idx], dim=0).long()

  frac = (1.0 - grid) - frac.unsqueeze(dim=1)  # (8, 3) - (N, 1, 3) -> (N, 8, 3)
  weight = frac.prod(dim=2).abs().view(-1)     # (8*N,)
  # The weights inherit the dtype of `pts`; torch.sparse.mm needs both
  # operands to share a dtype, so match the features (e.g. float64 data).
  weight = weight[valid].to(data.dtype)

  h = data.shape[0]
  mat = torch.sparse_coo_tensor(indices, weight, [npt, h], device=device)

  # 3. Interpolation, normalized by the total weight of the found neighbors
  output = torch.sparse.mm(mat, data)
  ones = torch.ones(h, 1, dtype=data.dtype, device=device)
  norm = torch.sparse.mm(mat, ones)
  output = torch.div(output, norm + 1e-12)  # epsilon: points with no neighbor
  return output
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
class OctreeInterp(torch.nn.Module):
  r''' Interpolates the points with an octree feature.

  Refer to :func:`octree_nearest_pts` for a description of arguments.

  Args:
    method (str): ``'linear'`` selects :func:`octree_linear_pts`; any other
        value selects :func:`octree_nearest_pts`.
    nempty (bool): If true, the input data only contains features of
        non-empty octree nodes.
    bound_check (bool): If true, points (or neighbors) outside
        :obj:`[0, 2^depth)` are ignored.
    rescale_pts (bool): If true, the xyz coordinates of the input points are
        mapped from :obj:`[-1, 1]` to :obj:`[0, 2^depth]` before interpolation.
        The caller's tensor is not modified.
  '''

  def __init__(self, method: str = 'linear', nempty: bool = False,
               bound_check: bool = False, rescale_pts: bool = True):
    super().__init__()
    self.method = method
    self.nempty = nempty
    self.bound_check = bound_check
    self.rescale_pts = rescale_pts
    self.func = octree_linear_pts if method == 'linear' else octree_nearest_pts

  def forward(self, data: torch.Tensor, octree: Octree, depth: int,
              pts: torch.Tensor):
    r''''''

    # rescale points from [-1, 1] to [0, 2^depth]. Work on a copy: writing into
    # the caller's `pts` would rescale it again on every later call with the
    # same tensor (and fails for leaf tensors that require grad).
    if self.rescale_pts:
      scale = 2 ** (depth - 1)
      pts = pts.clone()
      pts[:, :3] = (pts[:, :3] + 1.0) * scale

    return self.func(data, octree, depth, pts, self.nempty, self.bound_check)

  def extra_repr(self) -> str:
    r''' Sets the extra representation of the module.
    '''

    return ('method={}, nempty={}, bound_check={}, rescale_pts={}').format(
        self.method, self.nempty, self.bound_check, self.rescale_pts)  # noqa
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
def octree_nearest_upsample(data: torch.Tensor, octree: Octree, depth: int,
                            nempty: bool = False):
  r''' Upsamples the octree node features from :attr:`depth` to
  :attr:`(depth+1)` with nearest-neighbor interpolation.

  The feature row of every non-empty node at :attr:`depth` is copied 8 times.

  Args:
    data (torch.Tensor): The input data.
    octree (Octree): The octree to interpolate.
    depth (int): The depth of the data.
    nempty (bool): If true, the :attr:`data` only contains features of non-empty
        octree nodes; the output is then also restricted to the non-empty
        nodes at :attr:`(depth+1)`.
  '''

  expected_rows = octree.nnum_nempty[depth] if nempty else octree.nnum[depth]
  assert data.shape[0] == expected_rows, 'The shape of input data is wrong.'

  # Only non-empty nodes are expanded, so drop the rows of empty nodes first.
  parents = data if nempty else ocnn.nn.octree_depad(data, octree, depth)
  # Repeat each parent row 8 times consecutively: (M, C) -> (8*M, C).
  # NOTE(review): assumes the octree stores the 8 children of each non-empty
  # node contiguously, in parent order -- confirm against Octree.
  children = parents.unsqueeze(1).repeat(1, 8, 1).flatten(end_dim=1)
  if nempty:
    children = ocnn.nn.octree_depad(children, octree, depth + 1)  # !!! depth+1
  return children
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
class OctreeUpsample(torch.nn.Module):
  r''' Upsamples the octree node features from :attr:`depth` to
  :attr:`(target_depth)`.

  Refer to :func:`octree_nearest_pts` and :func:`octree_linear_pts` for details.

  Args:
    method (str): ``'linear'`` selects linear interpolation; any other value
        selects nearest-neighbor interpolation.
    nempty (bool): If true, the features only cover non-empty octree nodes.
  '''

  def __init__(self, method: str = 'linear', nempty: bool = False):
    super().__init__()
    self.method = method
    self.nempty = nempty
    if method == 'linear':
      self.func = octree_linear_pts
    else:
      self.func = octree_nearest_pts

  def forward(self, data: torch.Tensor, octree: Octree, depth: int,
              target_depth: Optional[int] = None):
    r''''''

    target = depth + 1 if target_depth is None else target_depth
    if target == depth:
      return data  # nothing to upsample
    assert target >= depth, 'target_depth must be larger than depth'

    # A single nearest-neighbor level has a dedicated (copy-based) path.
    if self.method == 'nearest' and target == depth + 1:
      return octree_nearest_upsample(data, octree, depth, self.nempty)

    # Centers of the target-depth voxels, expressed in the coordinate frame
    # of `depth`, are used as the query points.
    coords = torch.stack(octree.xyzb(target, self.nempty), dim=1).float()
    coords[:, :3] = (coords[:, :3] + 0.5) * (2**(depth - target))  # !!! rescale
    return self.func(data, octree, depth, coords, self.nempty)

  def extra_repr(self) -> str:
    r''' Sets the extra representation of the module.
    '''

    template = 'method={}, nempty={}'
    return template.format(self.method, self.nempty)
|
|
1
|
+
# --------------------------------------------------------
|
|
2
|
+
# Octree-based Sparse Convolutional Neural Networks
|
|
3
|
+
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
|
|
4
|
+
# Licensed under The MIT License [see LICENSE for details]
|
|
5
|
+
# Written by Peng-Shuai Wang
|
|
6
|
+
# --------------------------------------------------------
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
import torch.sparse
|
|
10
|
+
from typing import List, Optional
|
|
11
|
+
|
|
12
|
+
import ocnn
|
|
13
|
+
from ocnn.octree import Octree
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def octree_nearest_pts(data: torch.Tensor, octree: Octree, depth: int,
                       pts: torch.Tensor, nempty: bool = False,
                       bound_check: bool = False):
  ''' Nearest-neighbor interpolation of octree features at the input points.

  Each point takes the feature row of the octree node at :attr:`depth` that
  the octree search returns for it; points without a matching node get an
  all-zero feature.

  Args:
    data (torch.Tensor): The input data, one row per octree node at
        :attr:`depth`.
    octree (Octree): The octree to interpolate.
    depth (int): The depth of the data.
    pts (torch.Tensor): The coordinates of the points with shape :obj:`(N, 4)`,
        i.e. :obj:`N x (x, y, z, batch)`.
    nempty (bool): If true, the :attr:`data` only contains features of non-empty
        octree nodes.
    bound_check (bool): If true, points outside :obj:`[0, 2^depth)` are also
        treated as having no matching node.

  Returns:
    torch.Tensor: A tensor of shape :obj:`(N, data.shape[1])` with the dtype and
    device of :attr:`data`.

  .. note::
    The :attr:`pts` MUST be scaled into :obj:`[0, 2^depth)`.
  '''

  expected_rows = octree.nnum_nempty[depth] if nempty else octree.nnum[depth]
  assert data.shape[0] == expected_rows, 'The shape of input data is wrong.'

  # Negative indices returned by the search mark points without a valid node.
  node_ids = octree.search_xyzb(pts, depth, nempty)
  hit = node_ids > -1
  if bound_check:
    coords = pts[:, :3]
    inside = ((coords >= 0) & (coords < 2**depth)).all(1)
    hit = hit & inside

  # Rows of points without a valid node stay zero.
  feats = data.new_zeros((pts.shape[0], data.shape[1]))
  feats[hit] = data.index_select(0, node_ids[hit])
  return feats
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def octree_linear_pts(data: torch.Tensor, octree: Octree, depth: int,
                      pts: torch.Tensor, nempty: bool = False,
                      bound_check: bool = False):
  ''' Trilinear interpolation of octree features at the input points.

  Refer to :func:`octree_nearest_pts` for the meaning of the arguments.

  For every point, the 8 surrounding voxel centers at :attr:`depth` are looked
  up in the octree. Their features are blended with trilinear weights, which
  are renormalized over the neighbors that were actually found. Points with no
  valid neighbor get an all-zero feature.

  Returns:
    torch.Tensor: A tensor of shape :obj:`(N, data.shape[1])` with the dtype of
    :attr:`data`.
  '''

  nnum = octree.nnum_nempty[depth] if nempty else octree.nnum[depth]
  assert data.shape[0] == nnum, 'The shape of input data is wrong.'

  device = data.device
  grid = torch.tensor(
      [[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1],
       [1, 0, 0], [1, 0, 1], [1, 1, 0], [1, 1, 1]], device=device)

  # 1. Neighborhood searching
  xyzf = pts[:, :3] - 0.5  # the value is defined on the center of each voxel
  xyzi = xyzf.floor()      # the integer part (N, 3)
  frac = xyzf - xyzi       # the fraction part (N, 3)

  xyzn = (xyzi.unsqueeze(1) + grid).view(-1, 3)  # 8 neighbors per point
  batch = pts[:, 3].unsqueeze(1).repeat(1, 8).view(-1, 1)
  idx = octree.search_xyzb(torch.cat([xyzn, batch], dim=1), depth, nempty)
  valid = idx > -1  # valid indices
  if bound_check:
    bound = torch.logical_and(xyzn >= 0, xyzn < 2**depth).all(1)
    valid = torch.logical_and(valid, bound)
  idx = idx[valid]

  # 2. Build the sparse matrix (row: point, column: octree node)
  npt = pts.shape[0]
  ids = torch.arange(npt, device=idx.device)
  ids = ids.unsqueeze(1).repeat(1, 8).view(-1)
  ids = ids[valid]
  indices = torch.stack([ids, idx], dim=0).long()

  frac = (1.0 - grid) - frac.unsqueeze(dim=1)  # (8, 3) - (N, 1, 3) -> (N, 8, 3)
  weight = frac.prod(dim=2).abs().view(-1)     # (8*N,)
  # The weights inherit the dtype of `pts`; torch.sparse.mm needs both
  # operands to share a dtype, so match the features (e.g. float64 data).
  weight = weight[valid].to(data.dtype)

  h = data.shape[0]
  mat = torch.sparse_coo_tensor(indices, weight, [npt, h], device=device)

  # 3. Interpolation, normalized by the total weight of the found neighbors
  output = torch.sparse.mm(mat, data)
  ones = torch.ones(h, 1, dtype=data.dtype, device=device)
  norm = torch.sparse.mm(mat, ones)
  output = torch.div(output, norm + 1e-12)  # epsilon: points with no neighbor
  return output
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
class OctreeInterp(torch.nn.Module):
  r''' Interpolates the points with an octree feature.

  Refer to :func:`octree_nearest_pts` for a description of arguments.

  Args:
    method (str): ``'linear'`` selects :func:`octree_linear_pts`; any other
        value selects :func:`octree_nearest_pts`.
    nempty (bool): If true, the input data only contains features of
        non-empty octree nodes.
    bound_check (bool): If true, points (or neighbors) outside
        :obj:`[0, 2^depth)` are ignored.
    rescale_pts (bool): If true, the xyz coordinates of the input points are
        mapped from :obj:`[-1, 1]` to :obj:`[0, 2^depth]` before interpolation.
        The caller's tensor is not modified.
  '''

  def __init__(self, method: str = 'linear', nempty: bool = False,
               bound_check: bool = False, rescale_pts: bool = True):
    super().__init__()
    self.method = method
    self.nempty = nempty
    self.bound_check = bound_check
    self.rescale_pts = rescale_pts
    self.func = octree_linear_pts if method == 'linear' else octree_nearest_pts

  def forward(self, data: torch.Tensor, octree: Octree, depth: int,
              pts: torch.Tensor):
    r''''''

    # rescale points from [-1, 1] to [0, 2^depth]. Work on a copy: writing into
    # the caller's `pts` would rescale it again on every later call with the
    # same tensor (and fails for leaf tensors that require grad).
    if self.rescale_pts:
      scale = 2 ** (depth - 1)
      pts = pts.clone()
      pts[:, :3] = (pts[:, :3] + 1.0) * scale

    return self.func(data, octree, depth, pts, self.nempty, self.bound_check)

  def extra_repr(self) -> str:
    r''' Sets the extra representation of the module.
    '''

    return ('method={}, nempty={}, bound_check={}, rescale_pts={}').format(
        self.method, self.nempty, self.bound_check, self.rescale_pts)  # noqa
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def octree_nearest_upsample(data: torch.Tensor, octree: Octree, depth: int,
                            nempty: bool = False):
  r''' Upsamples the octree node features from :attr:`depth` to
  :attr:`(depth+1)` with nearest-neighbor interpolation.

  The feature row of every non-empty node at :attr:`depth` is copied 8 times.

  Args:
    data (torch.Tensor): The input data.
    octree (Octree): The octree to interpolate.
    depth (int): The depth of the data.
    nempty (bool): If true, the :attr:`data` only contains features of non-empty
        octree nodes; the output is then also restricted to the non-empty
        nodes at :attr:`(depth+1)`.
  '''

  expected_rows = octree.nnum_nempty[depth] if nempty else octree.nnum[depth]
  assert data.shape[0] == expected_rows, 'The shape of input data is wrong.'

  # Only non-empty nodes are expanded, so drop the rows of empty nodes first.
  parents = data if nempty else ocnn.nn.octree_depad(data, octree, depth)
  # Repeat each parent row 8 times consecutively: (M, C) -> (8*M, C).
  # NOTE(review): assumes the octree stores the 8 children of each non-empty
  # node contiguously, in parent order -- confirm against Octree.
  children = parents.unsqueeze(1).repeat(1, 8, 1).flatten(end_dim=1)
  if nempty:
    children = ocnn.nn.octree_depad(children, octree, depth + 1)  # !!! depth+1
  return children
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
class OctreeUpsample(torch.nn.Module):
  r''' Upsamples the octree node features from :attr:`depth` to
  :attr:`(target_depth)`.

  Refer to :func:`octree_nearest_pts` and :func:`octree_linear_pts` for details.

  Args:
    method (str): ``'linear'`` selects linear interpolation; any other value
        selects nearest-neighbor interpolation.
    nempty (bool): If true, the features only cover non-empty octree nodes.
  '''

  def __init__(self, method: str = 'linear', nempty: bool = False):
    super().__init__()
    self.method = method
    self.nempty = nempty
    if method == 'linear':
      self.func = octree_linear_pts
    else:
      self.func = octree_nearest_pts

  def forward(self, data: torch.Tensor, octree: Octree, depth: int,
              target_depth: Optional[int] = None):
    r''''''

    target = depth + 1 if target_depth is None else target_depth
    if target == depth:
      return data  # nothing to upsample
    assert target >= depth, 'target_depth must be larger than depth'

    # A single nearest-neighbor level has a dedicated (copy-based) path.
    if self.method == 'nearest' and target == depth + 1:
      return octree_nearest_upsample(data, octree, depth, self.nempty)

    # Centers of the target-depth voxels, expressed in the coordinate frame
    # of `depth`, are used as the query points.
    coords = torch.stack(octree.xyzb(target, self.nempty), dim=1).float()
    coords[:, :3] = (coords[:, :3] + 0.5) * (2**(depth - target))  # !!! rescale
    return self.func(data, octree, depth, coords, self.nempty)

  def extra_repr(self) -> str:
    r''' Sets the extra representation of the module.
    '''

    template = 'method={}, nempty={}'
    return template.format(self.method, self.nempty)
|
ocnn/nn/octree_norm.py
CHANGED
|
@@ -1,86 +1,86 @@
|
|
|
1
|
-
# --------------------------------------------------------
|
|
2
|
-
# Octree-based Sparse Convolutional Neural Networks
|
|
3
|
-
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
|
|
4
|
-
# Licensed under The MIT License [see LICENSE for details]
|
|
5
|
-
# Written by Peng-Shuai Wang
|
|
6
|
-
# --------------------------------------------------------
|
|
7
|
-
|
|
8
|
-
import torch
|
|
9
|
-
import torch.nn
|
|
10
|
-
|
|
11
|
-
from ocnn.octree import Octree
|
|
12
|
-
from ocnn.utils import scatter_add
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
# Batch normalization for octree features: a plain alias of
# torch.nn.BatchNorm1d, exposed under an octree-specific name alongside
# OctreeGroupNorm / OctreeInstanceNorm. NOTE(review): presumably applied to
# (N, C) node features without any octree-aware grouping -- confirm callers.
OctreeBatchNorm = torch.nn.BatchNorm1d
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
class OctreeGroupNorm(torch.nn.Module):
  r''' A group normalization layer for the octree.

  Statistics are computed separately for each sample in the batch (grouped by
  the octree's batch ids) and each group of channels. A learnable per-channel
  affine transform is applied afterwards.

  Args:
    in_channels (int): Number of feature channels.
    group (int): Number of channel groups; must divide :attr:`in_channels`.
    nempty (bool): If true, the input only contains features of non-empty
        octree nodes.
  '''

  def __init__(self, in_channels: int, group: int, nempty: bool = False):
    super().__init__()
    self.eps = 1e-5
    self.nempty = nempty
    self.group = group
    self.in_channels = in_channels

    assert in_channels % group == 0, 'in_channels must be divisible by group'
    self.channels_per_group = in_channels // group

    # torch.empty replaces the legacy torch.Tensor(...) constructor; the values
    # are initialized by reset_parameters() right below.
    self.weights = torch.nn.Parameter(torch.empty(1, in_channels))
    self.bias = torch.nn.Parameter(torch.empty(1, in_channels))
    self.reset_parameters()

  def reset_parameters(self):
    torch.nn.init.ones_(self.weights)
    torch.nn.init.zeros_(self.bias)

  def forward(self, data: torch.Tensor, octree: Octree, depth: int):
    r''''''

    batch_size = octree.batch_size
    batch_id = octree.batch_id(depth, self.nempty)  # scatter index per row
    ones = data.new_ones([data.shape[0], 1])
    count = scatter_add(ones, batch_id, dim=0, dim_size=batch_size)
    count = count * self.channels_per_group  # element number in each group
    inv_count = 1.0 / (count + self.eps)  # there might be 0 element sometimes

    # Per-sample, per-group mean, broadcast back to every channel of the group.
    mean = scatter_add(data, batch_id, dim=0, dim_size=batch_size) * inv_count
    mean = self._adjust_for_group(mean)
    out = data - mean.index_select(0, batch_id)

    # Per-sample, per-group (biased) variance of the centered data.
    var = scatter_add(out**2, batch_id, dim=0, dim_size=batch_size) * inv_count
    var = self._adjust_for_group(var)
    inv_std = 1.0 / (var + self.eps).sqrt()
    out = out * inv_std.index_select(0, batch_id)

    # Learnable per-channel affine transform.
    out = out * self.weights + self.bias
    return out

  def _adjust_for_group(self, tensor: torch.Tensor):
    r''' Sums :attr:`tensor` within each channel group and writes the group
    sum back to every channel of that group. The tensor is returned unchanged
    when each group has a single channel.
    '''

    if self.channels_per_group > 1:
      tensor = (tensor.reshape(-1, self.group, self.channels_per_group)
                      .sum(-1, keepdim=True)
                      .repeat(1, 1, self.channels_per_group)
                      .reshape(-1, self.in_channels))
    return tensor

  def extra_repr(self) -> str:
    return ('in_channels={}, group={}, nempty={}').format(
            self.in_channels, self.group, self.nempty)  # noqa
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
class OctreeInstanceNorm(OctreeGroupNorm):
  r''' An instance normalization layer for the octree.

  This is :class:`OctreeGroupNorm` with one channel per group, i.e.
  ``group == in_channels``.
  '''

  def __init__(self, in_channels: int, nempty: bool = False):
    # Every channel forms its own group.
    groups = in_channels
    super().__init__(in_channels=in_channels, group=groups, nempty=nempty)

  def extra_repr(self) -> str:
    template = 'in_channels={}, nempty={}'
    return template.format(self.in_channels, self.nempty)
|
|
1
|
+
# --------------------------------------------------------
|
|
2
|
+
# Octree-based Sparse Convolutional Neural Networks
|
|
3
|
+
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
|
|
4
|
+
# Licensed under The MIT License [see LICENSE for details]
|
|
5
|
+
# Written by Peng-Shuai Wang
|
|
6
|
+
# --------------------------------------------------------
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
import torch.nn
|
|
10
|
+
|
|
11
|
+
from ocnn.octree import Octree
|
|
12
|
+
from ocnn.utils import scatter_add
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
# Batch normalization for octree features: a plain alias of
# torch.nn.BatchNorm1d, exposed under an octree-specific name alongside
# OctreeGroupNorm / OctreeInstanceNorm. NOTE(review): presumably applied to
# (N, C) node features without any octree-aware grouping -- confirm callers.
OctreeBatchNorm = torch.nn.BatchNorm1d
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class OctreeGroupNorm(torch.nn.Module):
  r''' A group normalization layer for the octree.

  Statistics are computed separately for each sample in the batch (grouped by
  the octree's batch ids) and each group of channels. A learnable per-channel
  affine transform is applied afterwards.

  Args:
    in_channels (int): Number of feature channels.
    group (int): Number of channel groups; must divide :attr:`in_channels`.
    nempty (bool): If true, the input only contains features of non-empty
        octree nodes.
  '''

  def __init__(self, in_channels: int, group: int, nempty: bool = False):
    super().__init__()
    self.eps = 1e-5
    self.nempty = nempty
    self.group = group
    self.in_channels = in_channels

    assert in_channels % group == 0, 'in_channels must be divisible by group'
    self.channels_per_group = in_channels // group

    # torch.empty replaces the legacy torch.Tensor(...) constructor; the values
    # are initialized by reset_parameters() right below.
    self.weights = torch.nn.Parameter(torch.empty(1, in_channels))
    self.bias = torch.nn.Parameter(torch.empty(1, in_channels))
    self.reset_parameters()

  def reset_parameters(self):
    torch.nn.init.ones_(self.weights)
    torch.nn.init.zeros_(self.bias)

  def forward(self, data: torch.Tensor, octree: Octree, depth: int):
    r''''''

    batch_size = octree.batch_size
    batch_id = octree.batch_id(depth, self.nempty)  # scatter index per row
    ones = data.new_ones([data.shape[0], 1])
    count = scatter_add(ones, batch_id, dim=0, dim_size=batch_size)
    count = count * self.channels_per_group  # element number in each group
    inv_count = 1.0 / (count + self.eps)  # there might be 0 element sometimes

    # Per-sample, per-group mean, broadcast back to every channel of the group.
    mean = scatter_add(data, batch_id, dim=0, dim_size=batch_size) * inv_count
    mean = self._adjust_for_group(mean)
    out = data - mean.index_select(0, batch_id)

    # Per-sample, per-group (biased) variance of the centered data.
    var = scatter_add(out**2, batch_id, dim=0, dim_size=batch_size) * inv_count
    var = self._adjust_for_group(var)
    inv_std = 1.0 / (var + self.eps).sqrt()
    out = out * inv_std.index_select(0, batch_id)

    # Learnable per-channel affine transform.
    out = out * self.weights + self.bias
    return out

  def _adjust_for_group(self, tensor: torch.Tensor):
    r''' Sums :attr:`tensor` within each channel group and writes the group
    sum back to every channel of that group. The tensor is returned unchanged
    when each group has a single channel.
    '''

    if self.channels_per_group > 1:
      tensor = (tensor.reshape(-1, self.group, self.channels_per_group)
                      .sum(-1, keepdim=True)
                      .repeat(1, 1, self.channels_per_group)
                      .reshape(-1, self.in_channels))
    return tensor

  def extra_repr(self) -> str:
    return ('in_channels={}, group={}, nempty={}').format(
            self.in_channels, self.group, self.nempty)  # noqa
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
class OctreeInstanceNorm(OctreeGroupNorm):
  r''' An instance normalization layer for the octree.

  This is :class:`OctreeGroupNorm` with one channel per group, i.e.
  ``group == in_channels``.
  '''

  def __init__(self, in_channels: int, nempty: bool = False):
    # Every channel forms its own group.
    groups = in_channels
    super().__init__(in_channels=in_channels, group=groups, nempty=nempty)

  def extra_repr(self) -> str:
    template = 'in_channels={}, nempty={}'
    return template.format(self.in_channels, self.nempty)
|
ocnn/nn/octree_pad.py
CHANGED
|
@@ -1,39 +1,39 @@
|
|
|
1
|
-
# --------------------------------------------------------
|
|
2
|
-
# Octree-based Sparse Convolutional Neural Networks
|
|
3
|
-
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
|
|
4
|
-
# Licensed under The MIT License [see LICENSE for details]
|
|
5
|
-
# Written by Peng-Shuai Wang
|
|
6
|
-
# --------------------------------------------------------
|
|
7
|
-
|
|
8
|
-
import torch
|
|
9
|
-
|
|
10
|
-
from ocnn.octree import Octree
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
def octree_pad(data: torch.Tensor, octree: Octree, depth: int, val: float = 0.0):
  r''' Pads :attr:`val` to make the number of elements of :attr:`data` equal to
  the octree node number.

  Args:
    data (torch.Tensor): The input tensor with its number of elements equal to the
        non-empty octree node number.
    octree (Octree): The corresponding octree.
    depth (int): The depth of current octree.
    val (float): The padding value. (Default: :obj:`0.0`)

  Returns:
    torch.Tensor: A tensor with :obj:`octree.nnum[depth]` rows (dtype and
    device of :attr:`data`). Rows of non-empty nodes hold :attr:`data`; all
    other rows are filled with :attr:`val`.
  '''

  nempty_mask = octree.nempty_mask(depth)
  num_rows, num_cols = octree.nnum[depth], data.shape[1]
  padded = torch.full((num_rows, num_cols), val,
                      dtype=data.dtype, device=data.device)
  padded[nempty_mask] = data  # scatter the non-empty rows into place
  return padded
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
def octree_depad(data: torch.Tensor, octree: Octree, depth: int):
  r''' Reverse operation of :func:`octree_pad`: keeps only the rows of
  :attr:`data` that belong to non-empty octree nodes.

  Args:
    data (torch.Tensor): The input tensor with one row per octree node at
        :attr:`depth` (e.g. the output of :func:`octree_pad`).
    octree (Octree): The corresponding octree.
    depth (int): The depth of current octree.

  Returns:
    torch.Tensor: The rows of :attr:`data` selected by the octree's non-empty
    mask at :attr:`depth`, in their original order.
  '''

  mask = octree.nempty_mask(depth)
  return data[mask]
|
|
1
|
+
# --------------------------------------------------------
|
|
2
|
+
# Octree-based Sparse Convolutional Neural Networks
|
|
3
|
+
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
|
|
4
|
+
# Licensed under The MIT License [see LICENSE for details]
|
|
5
|
+
# Written by Peng-Shuai Wang
|
|
6
|
+
# --------------------------------------------------------
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
|
|
10
|
+
from ocnn.octree import Octree
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def octree_pad(data: torch.Tensor, octree: Octree, depth: int, val: float = 0.0):
  r''' Pads :attr:`val` to make the number of elements of :attr:`data` equal to
  the octree node number.

  Args:
    data (torch.Tensor): The input tensor with its number of elements equal to the
        non-empty octree node number.
    octree (Octree): The corresponding octree.
    depth (int): The depth of current octree.
    val (float): The padding value. (Default: :obj:`0.0`)

  Returns:
    torch.Tensor: A tensor with :obj:`octree.nnum[depth]` rows (dtype and
    device of :attr:`data`). Rows of non-empty nodes hold :attr:`data`; all
    other rows are filled with :attr:`val`.
  '''

  nempty_mask = octree.nempty_mask(depth)
  num_rows, num_cols = octree.nnum[depth], data.shape[1]
  padded = torch.full((num_rows, num_cols), val,
                      dtype=data.dtype, device=data.device)
  padded[nempty_mask] = data  # scatter the non-empty rows into place
  return padded
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def octree_depad(data: torch.Tensor, octree: Octree, depth: int):
  r''' Reverse operation of :func:`octree_pad`: keeps only the rows of
  :attr:`data` that belong to non-empty octree nodes.

  Args:
    data (torch.Tensor): The input tensor with one row per octree node at
        :attr:`depth` (e.g. the output of :func:`octree_pad`).
    octree (Octree): The corresponding octree.
    depth (int): The depth of current octree.

  Returns:
    torch.Tensor: The rows of :attr:`data` selected by the octree's non-empty
    mask at :attr:`depth`, in their original order.
  '''

  mask = octree.nempty_mask(depth)
  return data[mask]
|