ocnn 2.2.5__py3-none-any.whl → 2.2.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ocnn/__init__.py +24 -24
- ocnn/dataset.py +160 -160
- ocnn/models/__init__.py +29 -29
- ocnn/models/autoencoder.py +155 -155
- ocnn/models/hrnet.py +192 -192
- ocnn/models/image2shape.py +128 -128
- ocnn/models/lenet.py +46 -46
- ocnn/models/ounet.py +94 -94
- ocnn/models/resnet.py +53 -53
- ocnn/models/segnet.py +72 -72
- ocnn/models/unet.py +105 -105
- ocnn/modules/__init__.py +26 -26
- ocnn/modules/modules.py +303 -303
- ocnn/modules/resblocks.py +158 -158
- ocnn/nn/__init__.py +44 -44
- ocnn/nn/octree2col.py +53 -53
- ocnn/nn/octree2vox.py +50 -50
- ocnn/nn/octree_align.py +46 -46
- ocnn/nn/octree_conv.py +429 -429
- ocnn/nn/octree_drop.py +55 -55
- ocnn/nn/octree_dwconv.py +222 -222
- ocnn/nn/octree_gconv.py +79 -79
- ocnn/nn/octree_interp.py +196 -196
- ocnn/nn/octree_norm.py +126 -126
- ocnn/nn/octree_pad.py +39 -39
- ocnn/nn/octree_pool.py +200 -200
- ocnn/octree/__init__.py +22 -22
- ocnn/octree/octree.py +661 -659
- ocnn/octree/points.py +323 -322
- ocnn/octree/shuffled_key.py +115 -115
- ocnn/utils.py +205 -205
- ocnn-2.2.7.dist-info/METADATA +112 -0
- ocnn-2.2.7.dist-info/RECORD +36 -0
- {ocnn-2.2.5.dist-info → ocnn-2.2.7.dist-info}/WHEEL +1 -1
- {ocnn-2.2.5.dist-info → ocnn-2.2.7.dist-info/licenses}/LICENSE +21 -21
- ocnn-2.2.5.dist-info/METADATA +0 -80
- ocnn-2.2.5.dist-info/RECORD +0 -36
- {ocnn-2.2.5.dist-info → ocnn-2.2.7.dist-info}/top_level.txt +0 -0
ocnn/nn/octree_gconv.py
CHANGED
|
@@ -1,79 +1,79 @@
|
|
|
1
|
-
# --------------------------------------------------------
|
|
2
|
-
# Octree-based Sparse Convolutional Neural Networks
|
|
3
|
-
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
|
|
4
|
-
# Licensed under The MIT License [see LICENSE for details]
|
|
5
|
-
# Written by Peng-Shuai Wang
|
|
6
|
-
# --------------------------------------------------------
|
|
7
|
-
|
|
8
|
-
import torch
|
|
9
|
-
import torch.nn
|
|
10
|
-
from typing import List
|
|
11
|
-
|
|
12
|
-
import ocnn
|
|
13
|
-
from ocnn.octree import Octree
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
class OctreeGroupConv(torch.nn.Module):
    r''' Performs octree-based group convolution.

    The input channels are split into :attr:`group` equal chunks. Each chunk is
    processed by its own :class:`ocnn.nn.OctreeConv`, and the per-group outputs
    are concatenated back along the channel dimension.

    Args:
      in_channels (int): Number of input channels; must be divisible by
          :attr:`group`.
      out_channels (int): Number of output channels; must be divisible by
          :attr:`group`.
      kernel_size (List(int)): The kernel shape, choose from :obj:`[3]`, :obj:`[2]`,
          :obj:`[3,3,3]`, :obj:`[3,1,1]`, :obj:`[1,3,1]`, :obj:`[1,1,3]`,
          :obj:`[2,2,2]`, :obj:`[3,3,1]`, :obj:`[1,3,3]`, and :obj:`[3,1,3]`.
      stride (int): The stride of the convolution (:obj:`1` or :obj:`2`).
      nempty (bool): If True, only performs the convolution on non-empty
          octree nodes.
      use_bias (bool): If True, add a bias term to the convolution.
      group (int): The number of groups.

    .. note::
      The groups are processed one after another in a Python loop, so the
      performance is not optimal. Use this module only when the group number
      is small, otherwise it may be slow.
    '''

    def __init__(self, in_channels: int, out_channels: int,
                 kernel_size: List[int] = [3], stride: int = 1,
                 nempty: bool = False, use_bias: bool = False,
                 group: int = 1):
        super().__init__()

        # Both channel counts have to split evenly across the groups.
        assert in_channels % group == 0 and out_channels % group == 0

        self.group = group
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.in_channels_per_group = in_channels // group
        self.out_channels_per_group = out_channels // group

        # One independent convolution per group. NOTE: the same `kernel_size`
        # object (the shared default list unless the caller passes one) is
        # forwarded to every OctreeConv.
        per_group_convs = []
        for _ in range(group):
            per_group_convs.append(ocnn.nn.OctreeConv(
                self.in_channels_per_group, self.out_channels_per_group,
                kernel_size, stride, nempty, use_bias=use_bias))
        self.convs = torch.nn.ModuleList(per_group_convs)

    def forward(self, data: torch.Tensor, octree: Octree, depth: int):
        r''' Applies the group convolution.

        Args:
          data (torch.Tensor): The input features; ``data.shape[1]`` must equal
              :attr:`in_channels`.
          octree (Octree): The corresponding octree.
          depth (int): The depth of current octree.
        '''

        assert data.shape[1] == self.in_channels

        # Channel chunks in group order, each paired with its own convolution.
        chunks = torch.split(data, self.in_channels_per_group, dim=1)
        group_outputs = [conv(chunk, octree, depth)
                         for conv, chunk in zip(self.convs, chunks)]
        return torch.cat(group_outputs, dim=1)

    def extra_repr(self) -> str:
        r''' Returns the extra representation shown when printing the module.
        '''
        return (f'in_channels={self.in_channels}, '
                f'out_channels={self.out_channels}, group={self.group}')
|
|
1
|
+
# --------------------------------------------------------
|
|
2
|
+
# Octree-based Sparse Convolutional Neural Networks
|
|
3
|
+
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
|
|
4
|
+
# Licensed under The MIT License [see LICENSE for details]
|
|
5
|
+
# Written by Peng-Shuai Wang
|
|
6
|
+
# --------------------------------------------------------
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
import torch.nn
|
|
10
|
+
from typing import List
|
|
11
|
+
|
|
12
|
+
import ocnn
|
|
13
|
+
from ocnn.octree import Octree
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class OctreeGroupConv(torch.nn.Module):
    r''' Performs octree-based group convolution.

    The input channels are split into :attr:`group` equal chunks. Each chunk is
    processed by its own :class:`ocnn.nn.OctreeConv`, and the per-group outputs
    are concatenated back along the channel dimension.

    Args:
      in_channels (int): Number of input channels; must be divisible by
          :attr:`group`.
      out_channels (int): Number of output channels; must be divisible by
          :attr:`group`.
      kernel_size (List(int)): The kernel shape, choose from :obj:`[3]`, :obj:`[2]`,
          :obj:`[3,3,3]`, :obj:`[3,1,1]`, :obj:`[1,3,1]`, :obj:`[1,1,3]`,
          :obj:`[2,2,2]`, :obj:`[3,3,1]`, :obj:`[1,3,3]`, and :obj:`[3,1,3]`.
      stride (int): The stride of the convolution (:obj:`1` or :obj:`2`).
      nempty (bool): If True, only performs the convolution on non-empty
          octree nodes.
      use_bias (bool): If True, add a bias term to the convolution.
      group (int): The number of groups.

    .. note::
      The groups are processed one after another in a Python loop, so the
      performance is not optimal. Use this module only when the group number
      is small, otherwise it may be slow.
    '''

    def __init__(self, in_channels: int, out_channels: int,
                 kernel_size: List[int] = [3], stride: int = 1,
                 nempty: bool = False, use_bias: bool = False,
                 group: int = 1):
        super().__init__()

        # Both channel counts have to split evenly across the groups.
        assert in_channels % group == 0 and out_channels % group == 0

        self.group = group
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.in_channels_per_group = in_channels // group
        self.out_channels_per_group = out_channels // group

        # One independent convolution per group. NOTE: the same `kernel_size`
        # object (the shared default list unless the caller passes one) is
        # forwarded to every OctreeConv.
        per_group_convs = []
        for _ in range(group):
            per_group_convs.append(ocnn.nn.OctreeConv(
                self.in_channels_per_group, self.out_channels_per_group,
                kernel_size, stride, nempty, use_bias=use_bias))
        self.convs = torch.nn.ModuleList(per_group_convs)

    def forward(self, data: torch.Tensor, octree: Octree, depth: int):
        r''' Applies the group convolution.

        Args:
          data (torch.Tensor): The input features; ``data.shape[1]`` must equal
              :attr:`in_channels`.
          octree (Octree): The corresponding octree.
          depth (int): The depth of current octree.
        '''

        assert data.shape[1] == self.in_channels

        # Channel chunks in group order, each paired with its own convolution.
        chunks = torch.split(data, self.in_channels_per_group, dim=1)
        group_outputs = [conv(chunk, octree, depth)
                         for conv, chunk in zip(self.convs, chunks)]
        return torch.cat(group_outputs, dim=1)

    def extra_repr(self) -> str:
        r''' Returns the extra representation shown when printing the module.
        '''
        return (f'in_channels={self.in_channels}, '
                f'out_channels={self.out_channels}, group={self.group}')
|
ocnn/nn/octree_interp.py
CHANGED
|
@@ -1,196 +1,196 @@
|
|
|
1
|
-
# --------------------------------------------------------
|
|
2
|
-
# Octree-based Sparse Convolutional Neural Networks
|
|
3
|
-
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
|
|
4
|
-
# Licensed under The MIT License [see LICENSE for details]
|
|
5
|
-
# Written by Peng-Shuai Wang
|
|
6
|
-
# --------------------------------------------------------
|
|
7
|
-
|
|
8
|
-
import torch
|
|
9
|
-
import torch.sparse
|
|
10
|
-
from typing import List, Optional
|
|
11
|
-
|
|
12
|
-
import ocnn
|
|
13
|
-
from ocnn.octree import Octree
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
def octree_nearest_pts(data: torch.Tensor, octree: Octree, depth: int,
                       pts: torch.Tensor, nempty: bool = False,
                       bound_check: bool = False):
    r''' The nearest-neighbor interpolation with input points.

    Each point receives the feature row of the octree node at :attr:`depth`
    returned by :meth:`Octree.search_xyzb`. Points without a matching node
    (search index ``-1``), or outside the bound when :attr:`bound_check` is
    set, receive zeros.

    Args:
      data (torch.Tensor): The input data, one row per octree node at
          :attr:`depth`.
      octree (Octree): The octree to interpolate.
      depth (int): The depth of the data.
      pts (torch.Tensor): The coordinates of the points with shape :obj:`(N, 4)`,
          i.e. :obj:`N x (x, y, z, batch)`.
      nempty (bool): If true, the :attr:`data` only contains features of non-empty
          octree nodes.
      bound_check (bool): If true, check whether the point is in :obj:`[0, 2^depth)`.

    Returns:
      torch.Tensor: A tensor of shape :obj:`(N, data.shape[1])` with the dtype
      and device of :attr:`data`.

    .. note::
      The :attr:`pts` MUST be scaled into :obj:`[0, 2^depth)`.
    '''

    expected = octree.nnum_nempty[depth] if nempty else octree.nnum[depth]
    assert data.shape[0] == expected, 'The shape of input data is wrong.'

    node_idx = octree.search_xyzb(pts, depth, nempty)
    mask = node_idx > -1  # -1 marks points with no matching node
    if bound_check:
        xyz = pts[:, :3]
        inside = ((xyz >= 0) & (xyz < 2**depth)).all(1)
        mask = mask & inside

    out = data.new_zeros((pts.shape[0], data.shape[1]))
    out[mask] = data.index_select(0, node_idx[mask])
    return out
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
def octree_linear_pts(data: torch.Tensor, octree: Octree, depth: int,
                      pts: torch.Tensor, nempty: bool = False,
                      bound_check: bool = False):
    r''' Trilinear interpolation with input points.

    Features are treated as located at voxel centers. For every point, the 8
    surrounding voxel centers at :attr:`depth` are looked up in the octree and
    their features are blended with trilinear weights. Missing neighbors
    (search index ``-1``, or out of bound when :attr:`bound_check` is set) are
    dropped and the remaining weights are renormalized. A point with no valid
    neighbor gets an all-zero feature.

    Refer to :func:`octree_nearest_pts` for the meaning of the arguments.

    Returns:
      torch.Tensor: A tensor of shape :obj:`(N, data.shape[1])` with the dtype
      of :attr:`data`.
    '''

    nnum = octree.nnum_nempty[depth] if nempty else octree.nnum[depth]
    assert data.shape[0] == nnum, 'The shape of input data is wrong.'

    device = data.device
    # Offsets of the 8 corners of the cell that contains each shifted point.
    grid = torch.tensor(
        [[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1],
         [1, 0, 0], [1, 0, 1], [1, 1, 0], [1, 1, 1]], device=device)

    # 1. Neighborhood searching
    xyzf = pts[:, :3] - 0.5  # the value is defined on the center of each voxel
    xyzi = xyzf.floor()      # the integer part (N, 3)
    frac = xyzf - xyzi       # the fraction part (N, 3)

    xyzn = (xyzi.unsqueeze(1) + grid).view(-1, 3)  # neighbor coords (8*N, 3)
    batch = pts[:, 3].unsqueeze(1).repeat(1, 8).view(-1, 1)
    idx = octree.search_xyzb(torch.cat([xyzn, batch], dim=1), depth, nempty)
    valid = idx > -1  # valid indices
    if bound_check:
        bound = torch.logical_and(xyzn >= 0, xyzn < 2**depth).all(1)
        valid = torch.logical_and(valid, bound)
    idx = idx[valid]

    # 2. Build the sparse (npt x nnum) interpolation matrix
    npt = pts.shape[0]
    ids = torch.arange(npt, device=idx.device)
    ids = ids.unsqueeze(1).repeat(1, 8).view(-1)
    ids = ids[valid]
    indices = torch.stack([ids, idx], dim=0).long()

    frac = (1.0 - grid) - frac.unsqueeze(dim=1)  # (8, 3) - (N, 1, 3) -> (N, 8, 3)
    weight = frac.prod(dim=2).abs().view(-1)     # (8*N,)
    # The weights inherit the dtype of `pts`, which may differ from `data`
    # (e.g. float32 points with float64 features). torch.sparse.mm requires
    # both operands to share a dtype, so cast to the feature dtype.
    weight = weight[valid].to(data.dtype)

    h = data.shape[0]
    mat = torch.sparse_coo_tensor(indices, weight, [npt, h], device=device)

    # 3. Interpolation, normalized by the sum of the valid weights per point
    output = torch.sparse.mm(mat, data)
    ones = torch.ones(h, 1, dtype=data.dtype, device=device)
    norm = torch.sparse.mm(mat, ones)
    output = torch.div(output, norm + 1e-12)
    return output
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
class OctreeInterp(torch.nn.Module):
    r''' Interpolates the points with an octree feature.

    Args:
      method (str): :obj:`'linear'` selects :func:`octree_linear_pts`; any
          other value selects :func:`octree_nearest_pts`.
      nempty (bool): If true, the input features only contain non-empty
          octree nodes.
      bound_check (bool): Forwarded to the interpolation function as
          ``bound_check``.
      rescale_pts (bool): If true, the xyz coordinates of the query points are
          expected in :obj:`[-1, 1]` and are mapped to :obj:`[0, 2^depth]`
          before interpolating. The caller's tensor is not modified.

    Refer to :func:`octree_nearest_pts` for a description of arguments.
    '''

    def __init__(self, method: str = 'linear', nempty: bool = False,
                 bound_check: bool = False, rescale_pts: bool = True):
        super().__init__()
        self.method = method
        self.nempty = nempty
        self.bound_check = bound_check
        self.rescale_pts = rescale_pts
        self.func = octree_linear_pts if method == 'linear' else octree_nearest_pts

    def forward(self, data: torch.Tensor, octree: Octree, depth: int,
                pts: torch.Tensor):
        r''' Interpolates :attr:`data` (features at :attr:`depth`) at
        :attr:`pts`, an :obj:`(N, 4)` tensor of :obj:`(x, y, z, batch)`.
        '''

        # rescale points from [-1, 1] to [0, 2^depth]. Work on a copy: writing
        # into `pts` in place would mutate the caller's tensor, rescale it again
        # on every repeated call, and fail on leaf tensors that require grad.
        if self.rescale_pts:
            scale = 2 ** (depth - 1)
            pts = pts.clone()
            pts[:, :3] = (pts[:, :3] + 1.0) * scale

        return self.func(data, octree, depth, pts, self.nempty, self.bound_check)

    def extra_repr(self) -> str:
        r''' Sets the extra representation of the module.
        '''

        return ('method={}, nempty={}, bound_check={}, rescale_pts={}').format(
            self.method, self.nempty, self.bound_check, self.rescale_pts)  # noqa
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
def octree_nearest_upsample(data: torch.Tensor, octree: Octree, depth: int,
                            nempty: bool = False):
    r''' Upsamples the octree node features from :attr:`depth` to :attr:`(depth+1)`
    with the nearest-neighbor interpolation.

    Each (depadded) parent row is repeated 8 consecutive times, presumably one
    copy per child octant.

    Args:
      data (torch.Tensor): The input data, one row per octree node at
          :attr:`depth`.
      octree (Octree): The octree to interpolate.
      depth (int): The depth of the data.
      nempty (bool): If true, the :attr:`data` only contains features of non-empty
          octree nodes.
    '''

    expected = octree.nnum_nempty[depth] if nempty else octree.nnum[depth]
    assert data.shape[0] == expected, 'The shape of input data is wrong.'

    # Full-layout input is first reduced with octree_depad at `depth`
    # (presumably dropping the rows of empty nodes).
    parents = data if nempty else ocnn.nn.octree_depad(data, octree, depth)
    children = parents.unsqueeze(1).repeat(1, 8, 1).flatten(end_dim=1)
    if not nempty:
        return children
    return ocnn.nn.octree_depad(children, octree, depth + 1)  # NOTE: depth+1
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
class OctreeUpsample(torch.nn.Module):
    r''' Upsamples the octree node features from :attr:`depth` to
    :attr:`(target_depth)`.

    Nearest-neighbor upsampling by exactly one level uses
    :func:`octree_nearest_upsample`. Every other case samples the features at
    the centers of the target-depth nodes with :func:`octree_linear_pts`
    (when :attr:`method` is :obj:`'linear'`) or :func:`octree_nearest_pts`.

    Refer to :func:`octree_nearest_pts` for details.
    '''

    def __init__(self, method: str = 'linear', nempty: bool = False):
        super().__init__()
        self.method = method
        self.nempty = nempty
        if method == 'linear':
            self.func = octree_linear_pts
        else:
            self.func = octree_nearest_pts

    def forward(self, data: torch.Tensor, octree: Octree, depth: int,
                target_depth: Optional[int] = None):
        r''' Upsamples :attr:`data` from :attr:`depth` to :attr:`target_depth`
        (defaults to :obj:`depth + 1`).
        '''

        target_depth = depth + 1 if target_depth is None else target_depth
        if target_depth == depth:
            return data  # return, do nothing
        assert target_depth >= depth, 'target_depth must be larger than depth'

        if self.method == 'nearest' and target_depth == depth + 1:
            return octree_nearest_upsample(data, octree, depth, self.nempty)

        # Centers of the target-depth nodes, expressed in units of voxels at
        # `depth` (the frame the interpolation functions expect).
        coords = torch.stack(octree.xyzb(target_depth, self.nempty), dim=1).float()
        ratio = 2 ** (depth - target_depth)
        coords[:, :3] = (coords[:, :3] + 0.5) * ratio  # !!! rescale
        return self.func(data, octree, depth, coords, self.nempty)

    def extra_repr(self) -> str:
        r''' Returns the extra representation shown when printing the module.
        '''
        return f'method={self.method}, nempty={self.nempty}'
|
|
1
|
+
# --------------------------------------------------------
|
|
2
|
+
# Octree-based Sparse Convolutional Neural Networks
|
|
3
|
+
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
|
|
4
|
+
# Licensed under The MIT License [see LICENSE for details]
|
|
5
|
+
# Written by Peng-Shuai Wang
|
|
6
|
+
# --------------------------------------------------------
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
import torch.sparse
|
|
10
|
+
from typing import List, Optional
|
|
11
|
+
|
|
12
|
+
import ocnn
|
|
13
|
+
from ocnn.octree import Octree
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def octree_nearest_pts(data: torch.Tensor, octree: Octree, depth: int,
                       pts: torch.Tensor, nempty: bool = False,
                       bound_check: bool = False):
    r''' The nearest-neighbor interpolation with input points.

    Each point receives the feature row of the octree node at :attr:`depth`
    returned by :meth:`Octree.search_xyzb`. Points without a matching node
    (search index ``-1``), or outside the bound when :attr:`bound_check` is
    set, receive zeros.

    Args:
      data (torch.Tensor): The input data, one row per octree node at
          :attr:`depth`.
      octree (Octree): The octree to interpolate.
      depth (int): The depth of the data.
      pts (torch.Tensor): The coordinates of the points with shape :obj:`(N, 4)`,
          i.e. :obj:`N x (x, y, z, batch)`.
      nempty (bool): If true, the :attr:`data` only contains features of non-empty
          octree nodes.
      bound_check (bool): If true, check whether the point is in :obj:`[0, 2^depth)`.

    Returns:
      torch.Tensor: A tensor of shape :obj:`(N, data.shape[1])` with the dtype
      and device of :attr:`data`.

    .. note::
      The :attr:`pts` MUST be scaled into :obj:`[0, 2^depth)`.
    '''

    expected = octree.nnum_nempty[depth] if nempty else octree.nnum[depth]
    assert data.shape[0] == expected, 'The shape of input data is wrong.'

    node_idx = octree.search_xyzb(pts, depth, nempty)
    mask = node_idx > -1  # -1 marks points with no matching node
    if bound_check:
        xyz = pts[:, :3]
        inside = ((xyz >= 0) & (xyz < 2**depth)).all(1)
        mask = mask & inside

    out = data.new_zeros((pts.shape[0], data.shape[1]))
    out[mask] = data.index_select(0, node_idx[mask])
    return out
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def octree_linear_pts(data: torch.Tensor, octree: Octree, depth: int,
                      pts: torch.Tensor, nempty: bool = False,
                      bound_check: bool = False):
    r''' Trilinear interpolation with input points.

    Features are treated as located at voxel centers. For every point, the 8
    surrounding voxel centers at :attr:`depth` are looked up in the octree and
    their features are blended with trilinear weights. Missing neighbors
    (search index ``-1``, or out of bound when :attr:`bound_check` is set) are
    dropped and the remaining weights are renormalized. A point with no valid
    neighbor gets an all-zero feature.

    Refer to :func:`octree_nearest_pts` for the meaning of the arguments.

    Returns:
      torch.Tensor: A tensor of shape :obj:`(N, data.shape[1])` with the dtype
      of :attr:`data`.
    '''

    nnum = octree.nnum_nempty[depth] if nempty else octree.nnum[depth]
    assert data.shape[0] == nnum, 'The shape of input data is wrong.'

    device = data.device
    # Offsets of the 8 corners of the cell that contains each shifted point.
    grid = torch.tensor(
        [[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1],
         [1, 0, 0], [1, 0, 1], [1, 1, 0], [1, 1, 1]], device=device)

    # 1. Neighborhood searching
    xyzf = pts[:, :3] - 0.5  # the value is defined on the center of each voxel
    xyzi = xyzf.floor()      # the integer part (N, 3)
    frac = xyzf - xyzi       # the fraction part (N, 3)

    xyzn = (xyzi.unsqueeze(1) + grid).view(-1, 3)  # neighbor coords (8*N, 3)
    batch = pts[:, 3].unsqueeze(1).repeat(1, 8).view(-1, 1)
    idx = octree.search_xyzb(torch.cat([xyzn, batch], dim=1), depth, nempty)
    valid = idx > -1  # valid indices
    if bound_check:
        bound = torch.logical_and(xyzn >= 0, xyzn < 2**depth).all(1)
        valid = torch.logical_and(valid, bound)
    idx = idx[valid]

    # 2. Build the sparse (npt x nnum) interpolation matrix
    npt = pts.shape[0]
    ids = torch.arange(npt, device=idx.device)
    ids = ids.unsqueeze(1).repeat(1, 8).view(-1)
    ids = ids[valid]
    indices = torch.stack([ids, idx], dim=0).long()

    frac = (1.0 - grid) - frac.unsqueeze(dim=1)  # (8, 3) - (N, 1, 3) -> (N, 8, 3)
    weight = frac.prod(dim=2).abs().view(-1)     # (8*N,)
    # The weights inherit the dtype of `pts`, which may differ from `data`
    # (e.g. float32 points with float64 features). torch.sparse.mm requires
    # both operands to share a dtype, so cast to the feature dtype.
    weight = weight[valid].to(data.dtype)

    h = data.shape[0]
    mat = torch.sparse_coo_tensor(indices, weight, [npt, h], device=device)

    # 3. Interpolation, normalized by the sum of the valid weights per point
    output = torch.sparse.mm(mat, data)
    ones = torch.ones(h, 1, dtype=data.dtype, device=device)
    norm = torch.sparse.mm(mat, ones)
    output = torch.div(output, norm + 1e-12)
    return output
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
class OctreeInterp(torch.nn.Module):
    r''' Interpolates the points with an octree feature.

    Args:
      method (str): :obj:`'linear'` selects :func:`octree_linear_pts`; any
          other value selects :func:`octree_nearest_pts`.
      nempty (bool): If true, the input features only contain non-empty
          octree nodes.
      bound_check (bool): Forwarded to the interpolation function as
          ``bound_check``.
      rescale_pts (bool): If true, the xyz coordinates of the query points are
          expected in :obj:`[-1, 1]` and are mapped to :obj:`[0, 2^depth]`
          before interpolating. The caller's tensor is not modified.

    Refer to :func:`octree_nearest_pts` for a description of arguments.
    '''

    def __init__(self, method: str = 'linear', nempty: bool = False,
                 bound_check: bool = False, rescale_pts: bool = True):
        super().__init__()
        self.method = method
        self.nempty = nempty
        self.bound_check = bound_check
        self.rescale_pts = rescale_pts
        self.func = octree_linear_pts if method == 'linear' else octree_nearest_pts

    def forward(self, data: torch.Tensor, octree: Octree, depth: int,
                pts: torch.Tensor):
        r''' Interpolates :attr:`data` (features at :attr:`depth`) at
        :attr:`pts`, an :obj:`(N, 4)` tensor of :obj:`(x, y, z, batch)`.
        '''

        # rescale points from [-1, 1] to [0, 2^depth]. Work on a copy: writing
        # into `pts` in place would mutate the caller's tensor, rescale it again
        # on every repeated call, and fail on leaf tensors that require grad.
        if self.rescale_pts:
            scale = 2 ** (depth - 1)
            pts = pts.clone()
            pts[:, :3] = (pts[:, :3] + 1.0) * scale

        return self.func(data, octree, depth, pts, self.nempty, self.bound_check)

    def extra_repr(self) -> str:
        r''' Sets the extra representation of the module.
        '''

        return ('method={}, nempty={}, bound_check={}, rescale_pts={}').format(
            self.method, self.nempty, self.bound_check, self.rescale_pts)  # noqa
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def octree_nearest_upsample(data: torch.Tensor, octree: Octree, depth: int,
                            nempty: bool = False):
    r''' Upsamples the octree node features from :attr:`depth` to :attr:`(depth+1)`
    with the nearest-neighbor interpolation.

    Each (depadded) parent row is repeated 8 consecutive times, presumably one
    copy per child octant.

    Args:
      data (torch.Tensor): The input data, one row per octree node at
          :attr:`depth`.
      octree (Octree): The octree to interpolate.
      depth (int): The depth of the data.
      nempty (bool): If true, the :attr:`data` only contains features of non-empty
          octree nodes.
    '''

    expected = octree.nnum_nempty[depth] if nempty else octree.nnum[depth]
    assert data.shape[0] == expected, 'The shape of input data is wrong.'

    # Full-layout input is first reduced with octree_depad at `depth`
    # (presumably dropping the rows of empty nodes).
    parents = data if nempty else ocnn.nn.octree_depad(data, octree, depth)
    children = parents.unsqueeze(1).repeat(1, 8, 1).flatten(end_dim=1)
    if not nempty:
        return children
    return ocnn.nn.octree_depad(children, octree, depth + 1)  # NOTE: depth+1
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
class OctreeUpsample(torch.nn.Module):
    r''' Upsamples the octree node features from :attr:`depth` to
    :attr:`(target_depth)`.

    Nearest-neighbor upsampling by exactly one level uses
    :func:`octree_nearest_upsample`. Every other case samples the features at
    the centers of the target-depth nodes with :func:`octree_linear_pts`
    (when :attr:`method` is :obj:`'linear'`) or :func:`octree_nearest_pts`.

    Refer to :func:`octree_nearest_pts` for details.
    '''

    def __init__(self, method: str = 'linear', nempty: bool = False):
        super().__init__()
        self.method = method
        self.nempty = nempty
        if method == 'linear':
            self.func = octree_linear_pts
        else:
            self.func = octree_nearest_pts

    def forward(self, data: torch.Tensor, octree: Octree, depth: int,
                target_depth: Optional[int] = None):
        r''' Upsamples :attr:`data` from :attr:`depth` to :attr:`target_depth`
        (defaults to :obj:`depth + 1`).
        '''

        target_depth = depth + 1 if target_depth is None else target_depth
        if target_depth == depth:
            return data  # return, do nothing
        assert target_depth >= depth, 'target_depth must be larger than depth'

        if self.method == 'nearest' and target_depth == depth + 1:
            return octree_nearest_upsample(data, octree, depth, self.nempty)

        # Centers of the target-depth nodes, expressed in units of voxels at
        # `depth` (the frame the interpolation functions expect).
        coords = torch.stack(octree.xyzb(target_depth, self.nempty), dim=1).float()
        ratio = 2 ** (depth - target_depth)
        coords[:, :3] = (coords[:, :3] + 0.5) * ratio  # !!! rescale
        return self.func(data, octree, depth, coords, self.nempty)

    def extra_repr(self) -> str:
        r''' Returns the extra representation shown when printing the module.
        '''
        return f'method={self.method}, nempty={self.nempty}'
|