ocnn 2.2.6__py3-none-any.whl → 2.2.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ocnn/__init__.py +24 -24
- ocnn/dataset.py +160 -160
- ocnn/models/__init__.py +29 -29
- ocnn/models/autoencoder.py +155 -155
- ocnn/models/hrnet.py +192 -192
- ocnn/models/image2shape.py +128 -128
- ocnn/models/lenet.py +46 -46
- ocnn/models/ounet.py +94 -94
- ocnn/models/resnet.py +53 -53
- ocnn/models/segnet.py +72 -72
- ocnn/models/unet.py +105 -105
- ocnn/modules/__init__.py +26 -26
- ocnn/modules/modules.py +303 -303
- ocnn/modules/resblocks.py +158 -158
- ocnn/nn/__init__.py +44 -44
- ocnn/nn/octree2col.py +53 -53
- ocnn/nn/octree2vox.py +50 -50
- ocnn/nn/octree_align.py +46 -46
- ocnn/nn/octree_conv.py +429 -429
- ocnn/nn/octree_drop.py +55 -55
- ocnn/nn/octree_dwconv.py +222 -222
- ocnn/nn/octree_gconv.py +79 -79
- ocnn/nn/octree_interp.py +196 -196
- ocnn/nn/octree_norm.py +126 -126
- ocnn/nn/octree_pad.py +39 -39
- ocnn/nn/octree_pool.py +200 -200
- ocnn/octree/__init__.py +22 -22
- ocnn/octree/octree.py +661 -661
- ocnn/octree/points.py +323 -322
- ocnn/octree/shuffled_key.py +115 -115
- ocnn/utils.py +205 -205
- {ocnn-2.2.6.dist-info → ocnn-2.2.7.dist-info}/METADATA +112 -91
- ocnn-2.2.7.dist-info/RECORD +36 -0
- {ocnn-2.2.6.dist-info → ocnn-2.2.7.dist-info}/WHEEL +1 -1
- {ocnn-2.2.6.dist-info → ocnn-2.2.7.dist-info/licenses}/LICENSE +21 -21
- ocnn-2.2.6.dist-info/RECORD +0 -36
- {ocnn-2.2.6.dist-info → ocnn-2.2.7.dist-info}/top_level.txt +0 -0
ocnn/octree/shuffled_key.py
CHANGED
|
@@ -1,115 +1,115 @@
|
|
|
1
|
-
# --------------------------------------------------------
|
|
2
|
-
# Octree-based Sparse Convolutional Neural Networks
|
|
3
|
-
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
|
|
4
|
-
# Licensed under The MIT License [see LICENSE for details]
|
|
5
|
-
# Written by Peng-Shuai Wang
|
|
6
|
-
# --------------------------------------------------------
|
|
7
|
-
|
|
8
|
-
import torch
|
|
9
|
-
from typing import Optional, Union
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
class KeyLUT:
    r'''Pre-computed look-up tables for shuffled-key encoding and decoding.

    The three encode tables map an 8-bit coordinate value (0..255) to its
    bit-interleaved 24-bit contribution to a key, for x, y and z respectively.
    The decode table maps a 9-bit key chunk (0..511) back to three 3-bit
    coordinates. All tables are built once on CPU; copies for other devices
    are created lazily and cached on first request.
    '''

    def __init__(self):
        cpu = torch.device('cpu')
        values8 = torch.arange(256, dtype=torch.int64)
        values9 = torch.arange(512, dtype=torch.int64)
        zeros8 = torch.zeros(256, dtype=torch.int64)

        enc_x = self.xyz2key(values8, zeros8, zeros8, 8)
        enc_y = self.xyz2key(zeros8, values8, zeros8, 8)
        enc_z = self.xyz2key(zeros8, zeros8, values8, 8)
        self._encode = {cpu: (enc_x, enc_y, enc_z)}
        self._decode = {cpu: self.key2xyz(values9, 9)}

    def encode_lut(self, device=torch.device('cpu')):
        '''Returns the (x, y, z) encode tables on `device`, caching the copy.'''
        if device in self._encode:
            return self._encode[device]
        source = self._encode[torch.device('cpu')]
        self._encode[device] = tuple(table.to(device) for table in source)
        return self._encode[device]

    def decode_lut(self, device=torch.device('cpu')):
        '''Returns the (x, y, z) decode tables on `device`, caching the copy.'''
        if device in self._decode:
            return self._decode[device]
        source = self._decode[torch.device('cpu')]
        self._decode[device] = tuple(table.to(device) for table in source)
        return self._decode[device]

    def xyz2key(self, x, y, z, depth):
        '''Interleaves the low `depth` bits of x, y and z with a plain loop.

        Bit `i` of x, y and z is placed at key bit `3*i + 2`, `3*i + 1` and
        `3*i` respectively.
        '''
        key = torch.zeros_like(x)
        for level in range(depth):
            bit = 1 << level
            for coord, offset in ((x, 2), (y, 1), (z, 0)):
                key = key | ((coord & bit) << (2 * level + offset))
        return key

    def key2xyz(self, key, depth):
        '''Inverse of `xyz2key`: de-interleaves `depth` levels of `key`.'''
        out = [torch.zeros_like(key), torch.zeros_like(key), torch.zeros_like(key)]
        for level in range(depth):
            for axis, offset in enumerate((2, 1, 0)):
                picked = key & (1 << (3 * level + offset))
                out[axis] = out[axis] | (picked >> (2 * level + offset))
        return tuple(out)


# Shared module-level instance; the module functions `xyz2key` / `key2xyz`
# read their look-up tables from it.
_key_lut = KeyLUT()
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
def xyz2key(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor,
            b: Optional[Union[torch.Tensor, int]] = None, depth: int = 16):
    r'''Encodes :attr:`x`, :attr:`y`, :attr:`z` coordinates to the shuffled keys
    based on pre-computed look up tables. The speed of this function is much
    faster than the method based on for-loop.

    Args:
      x (torch.Tensor): The x coordinate.
      y (torch.Tensor): The y coordinate.
      z (torch.Tensor): The z coordinate.
      b (torch.Tensor or int): The batch index of the coordinates, and should be
          smaller than 32768. If :attr:`b` is :obj:`torch.Tensor`, the size of
          :attr:`b` must be the same as :attr:`x`, :attr:`y`, and :attr:`z`.
      depth (int): The depth of the shuffled key, and must be smaller than 17 (< 17).

    Returns:
      torch.Tensor: The int64 keys. The interleaved xyz bits occupy the low
      48 bits; the batch index, when given, is stored from bit 48 upward.
    '''

    EX, EY, EZ = _key_lut.encode_lut(x.device)
    x, y, z = x.long(), y.long(), z.long()

    # The low 8 bits of each coordinate produce the low 24 bits of the key.
    mask = 255 if depth > 8 else (1 << depth) - 1
    key = EX[x & mask] | EY[y & mask] | EZ[z & mask]
    if depth > 8:
        # The remaining (depth - 8) bits produce the key bits from 24 upward.
        mask = (1 << (depth - 8)) - 1
        key16 = EX[(x >> 8) & mask] | EY[(y >> 8) & mask] | EZ[(z >> 8) & mask]
        key = key16 << 24 | key

    if b is not None:
        # `b` may be a plain Python int (see the signature); only tensors have
        # `.long()`. `int << 48 | tensor` already yields an int64 tensor.
        if isinstance(b, torch.Tensor):
            b = b.long()
        key = b << 48 | key

    return key
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
def key2xyz(key: torch.Tensor, depth: int = 16):
    r'''Decodes the shuffled key to :attr:`x`, :attr:`y`, :attr:`z` coordinates
    and the batch index based on pre-computed look up tables.

    Args:
      key (torch.Tensor): The shuffled key.
      depth (int): The depth of the shuffled key, and must be smaller than 17 (< 17).

    Returns:
      tuple: :obj:`(x, y, z, b)`, where :obj:`b` holds the key bits from 48
      upward (the batch index).
    '''

    DX, DY, DZ = _key_lut.decode_lut(key.device)
    x = torch.zeros_like(key)
    y = torch.zeros_like(key)
    z = torch.zeros_like(key)

    # Split the key into the batch part and the 48-bit interleaved code.
    b = key >> 48
    code = key & ((1 << 48) - 1)

    # Every 9-bit chunk of the code decodes to 3 bits of x, y and z.
    for chunk in range((depth + 2) // 3):
        idx = (code >> (9 * chunk)) & 511
        shift = 3 * chunk
        x = x | (DX[idx] << shift)
        y = y | (DY[idx] << shift)
        z = z | (DZ[idx] << shift)

    return x, y, z, b
|
|
1
|
+
# --------------------------------------------------------
|
|
2
|
+
# Octree-based Sparse Convolutional Neural Networks
|
|
3
|
+
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
|
|
4
|
+
# Licensed under The MIT License [see LICENSE for details]
|
|
5
|
+
# Written by Peng-Shuai Wang
|
|
6
|
+
# --------------------------------------------------------
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
from typing import Optional, Union
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class KeyLUT:
    r'''Pre-computed look-up tables for shuffled-key encoding and decoding.

    The three encode tables map an 8-bit coordinate value (0..255) to its
    bit-interleaved 24-bit contribution to a key, for x, y and z respectively.
    The decode table maps a 9-bit key chunk (0..511) back to three 3-bit
    coordinates. All tables are built once on CPU; copies for other devices
    are created lazily and cached on first request.
    '''

    def __init__(self):
        cpu = torch.device('cpu')
        values8 = torch.arange(256, dtype=torch.int64)
        values9 = torch.arange(512, dtype=torch.int64)
        zeros8 = torch.zeros(256, dtype=torch.int64)

        enc_x = self.xyz2key(values8, zeros8, zeros8, 8)
        enc_y = self.xyz2key(zeros8, values8, zeros8, 8)
        enc_z = self.xyz2key(zeros8, zeros8, values8, 8)
        self._encode = {cpu: (enc_x, enc_y, enc_z)}
        self._decode = {cpu: self.key2xyz(values9, 9)}

    def encode_lut(self, device=torch.device('cpu')):
        '''Returns the (x, y, z) encode tables on `device`, caching the copy.'''
        if device in self._encode:
            return self._encode[device]
        source = self._encode[torch.device('cpu')]
        self._encode[device] = tuple(table.to(device) for table in source)
        return self._encode[device]

    def decode_lut(self, device=torch.device('cpu')):
        '''Returns the (x, y, z) decode tables on `device`, caching the copy.'''
        if device in self._decode:
            return self._decode[device]
        source = self._decode[torch.device('cpu')]
        self._decode[device] = tuple(table.to(device) for table in source)
        return self._decode[device]

    def xyz2key(self, x, y, z, depth):
        '''Interleaves the low `depth` bits of x, y and z with a plain loop.

        Bit `i` of x, y and z is placed at key bit `3*i + 2`, `3*i + 1` and
        `3*i` respectively.
        '''
        key = torch.zeros_like(x)
        for level in range(depth):
            bit = 1 << level
            for coord, offset in ((x, 2), (y, 1), (z, 0)):
                key = key | ((coord & bit) << (2 * level + offset))
        return key

    def key2xyz(self, key, depth):
        '''Inverse of `xyz2key`: de-interleaves `depth` levels of `key`.'''
        out = [torch.zeros_like(key), torch.zeros_like(key), torch.zeros_like(key)]
        for level in range(depth):
            for axis, offset in enumerate((2, 1, 0)):
                picked = key & (1 << (3 * level + offset))
                out[axis] = out[axis] | (picked >> (2 * level + offset))
        return tuple(out)


# Shared module-level instance; the module functions `xyz2key` / `key2xyz`
# read their look-up tables from it.
_key_lut = KeyLUT()
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def xyz2key(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor,
            b: Optional[Union[torch.Tensor, int]] = None, depth: int = 16):
    r'''Encodes :attr:`x`, :attr:`y`, :attr:`z` coordinates to the shuffled keys
    based on pre-computed look up tables. The speed of this function is much
    faster than the method based on for-loop.

    Args:
      x (torch.Tensor): The x coordinate.
      y (torch.Tensor): The y coordinate.
      z (torch.Tensor): The z coordinate.
      b (torch.Tensor or int): The batch index of the coordinates, and should be
          smaller than 32768. If :attr:`b` is :obj:`torch.Tensor`, the size of
          :attr:`b` must be the same as :attr:`x`, :attr:`y`, and :attr:`z`.
      depth (int): The depth of the shuffled key, and must be smaller than 17 (< 17).

    Returns:
      torch.Tensor: The int64 keys. The interleaved xyz bits occupy the low
      48 bits; the batch index, when given, is stored from bit 48 upward.
    '''

    EX, EY, EZ = _key_lut.encode_lut(x.device)
    x, y, z = x.long(), y.long(), z.long()

    # The low 8 bits of each coordinate produce the low 24 bits of the key.
    mask = 255 if depth > 8 else (1 << depth) - 1
    key = EX[x & mask] | EY[y & mask] | EZ[z & mask]
    if depth > 8:
        # The remaining (depth - 8) bits produce the key bits from 24 upward.
        mask = (1 << (depth - 8)) - 1
        key16 = EX[(x >> 8) & mask] | EY[(y >> 8) & mask] | EZ[(z >> 8) & mask]
        key = key16 << 24 | key

    if b is not None:
        # `b` may be a plain Python int (see the signature); only tensors have
        # `.long()`. `int << 48 | tensor` already yields an int64 tensor.
        if isinstance(b, torch.Tensor):
            b = b.long()
        key = b << 48 | key

    return key
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def key2xyz(key: torch.Tensor, depth: int = 16):
    r'''Decodes the shuffled key to :attr:`x`, :attr:`y`, :attr:`z` coordinates
    and the batch index based on pre-computed look up tables.

    Args:
      key (torch.Tensor): The shuffled key.
      depth (int): The depth of the shuffled key, and must be smaller than 17 (< 17).

    Returns:
      tuple: :obj:`(x, y, z, b)`, where :obj:`b` holds the key bits from 48
      upward (the batch index).
    '''

    DX, DY, DZ = _key_lut.decode_lut(key.device)
    x = torch.zeros_like(key)
    y = torch.zeros_like(key)
    z = torch.zeros_like(key)

    # Split the key into the batch part and the 48-bit interleaved code.
    b = key >> 48
    code = key & ((1 << 48) - 1)

    # Every 9-bit chunk of the code decodes to 3 bits of x, y and z.
    for chunk in range((depth + 2) // 3):
        idx = (code >> (9 * chunk)) & 511
        shift = 3 * chunk
        x = x | (DX[idx] << shift)
        y = y | (DY[idx] << shift)
        z = z | (DZ[idx] << shift)

    return x, y, z, b
|
ocnn/utils.py
CHANGED
|
@@ -1,205 +1,205 @@
|
|
|
1
|
-
# --------------------------------------------------------
|
|
2
|
-
# Octree-based Sparse Convolutional Neural Networks
|
|
3
|
-
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
|
|
4
|
-
# Licensed under The MIT License [see LICENSE for details]
|
|
5
|
-
# Written by Peng-Shuai Wang
|
|
6
|
-
# --------------------------------------------------------
|
|
7
|
-
|
|
8
|
-
import math
|
|
9
|
-
import torch
|
|
10
|
-
from typing import Optional
|
|
11
|
-
from packaging import version
|
|
12
|
-
|
|
13
|
-
import ocnn
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
# Public names exported by a star-import of this module.
# NOTE(review): `range_grid` and `broadcast` are defined in this module but
# are not listed here -- confirm whether that is intentional.
__all__ = [
    'trunc_div', 'meshgrid', 'cumsum', 'scatter_add', 'xavier_uniform_',
    'resize_with_last_val', 'list2str', 'build_example_octree',
]
# Alias of `__all__` (the very same list object).
classes = __all__
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
def trunc_div(input, other):
    r''' Wraps :func:`torch.div` for compatibility. It rounds the results of the
    division towards zero and is equivalent to C-style integer division.

    Args:
      input (torch.Tensor): The dividend.
      other: The divisor, anything accepted by :func:`torch.div`.

    Returns:
      torch.Tensor: :obj:`input / other` rounded towards zero.
    '''

    # Compare only the release part of the version. Under PEP 440 a local
    # build such as '1.7.1+cu110' sorts *after* '1.7.1', which would wrongly
    # select the `rounding_mode` keyword that PyTorch 1.7.1 does not accept.
    torch_version = version.parse(version.parse(torch.__version__).base_version)
    if torch_version > version.parse('1.7.1'):
        return torch.div(input, other, rounding_mode='trunc')
    # In these old releases `floor_divide` rounds towards zero; PyTorch only
    # switched it to true floor division in 1.13.
    return torch.floor_divide(input, other)
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
def meshgrid(*tensors, indexing: Optional[str] = None):
    r''' Wraps :func:`torch.meshgrid` for compatibility.

    The :attr:`indexing` keyword is only forwarded on PyTorch releases newer
    than 1.9.1. Older releases do not accept it and always use ``'ij'``
    indexing, so :attr:`indexing` is ignored there.

    Args:
      tensors (torch.Tensor): The 1D tensors that define the grid axes.
      indexing (str, optional): ``'ij'`` or ``'xy'``; forwarded to
          :func:`torch.meshgrid` when supported.
    '''

    # Compare only the release part of the version. Under PEP 440 a local
    # build such as '1.9.1+cu111' sorts *after* '1.9.1', which would wrongly
    # pass `indexing=` to a release that does not support it.
    torch_version = version.parse(version.parse(torch.__version__).base_version)
    if torch_version > version.parse('1.9.1'):
        return torch.meshgrid(*tensors, indexing=indexing)
    return torch.meshgrid(*tensors)
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
def range_grid(min: int, max: int, device: torch.device = 'cpu'):
    r''' Builds a 3D mesh grid over the closed integer range :obj:`[min, max]`.

    Args:
      min (int): The smallest coordinate value on each axis.
      max (int): The largest coordinate value on each axis (inclusive).
      device (torch.device, optional): The device on which the grid is created.

    Returns:
      torch.Tensor: A :obj:`torch.long` tensor of shape :obj:`(N, 3)` with
      :obj:`N = (max - min + 1) ** 3`. Points are listed with the first
      coordinate varying slowest ('ij' indexing).

    Example:
      >>> grid = range_grid(0, 1)
      >>> print(grid)
      tensor([[0, 0, 0],
              [0, 0, 1],
              [0, 1, 0],
              [0, 1, 1],
              [1, 0, 0],
              [1, 0, 1],
              [1, 1, 0],
              [1, 1, 1]])
    '''

    axis = torch.arange(min, max + 1, dtype=torch.long, device=device)
    gx, gy, gz = meshgrid(axis, axis, axis, indexing='ij')
    return torch.stack([gx, gy, gz], dim=-1).view(-1, 3)
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
def cumsum(data: torch.Tensor, dim: int, exclusive: bool = False):
    r''' Extends :func:`torch.cumsum` with the input argument :attr:`exclusive`.

    Args:
      data (torch.Tensor): The input data.
      dim (int): The dimension to do the operation over.
      exclusive (bool): If false, the behavior is the same as :func:`torch.cumsum`.
          If true, a zero slice is prepended along :attr:`dim`, so the output
          is larger by 1 than :attr:`data` in that dimension and its entry
          :obj:`i` is the sum of the first :obj:`i` input entries.
    '''

    running = torch.cumsum(data, dim)
    if not exclusive:
        return running

    # A single zero slice along `dim`, matching `running` in dtype and device.
    pad_shape = list(data.size())
    pad_shape[dim] = 1
    return torch.cat([running.new_zeros(pad_shape), running], dim)
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
def broadcast(src: torch.Tensor, other: torch.Tensor, dim: int):
    r''' Broadcast :attr:`src` according to :attr:`other`, originally from the
    library `pytorch_scatter`.

    A 1D :attr:`src` is aligned with axis :attr:`dim` of :attr:`other`; any
    :attr:`src` then gets trailing singleton axes until it has as many
    dimensions as :attr:`other`, and is finally expanded to its shape.
    '''

    axis = dim + other.dim() if dim < 0 else dim

    if src.dim() == 1:
        # Leading singleton axes put the 1D data on axis `axis`.
        for _ in range(axis):
            src = src.unsqueeze(0)
    while src.dim() < other.dim():
        src = src.unsqueeze(-1)

    return src.expand_as(other)
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
def scatter_add(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                out: Optional[torch.Tensor] = None,
                dim_size: Optional[int] = None,) -> torch.Tensor:
    r''' Reduces all values from the :attr:`src` tensor into :attr:`out` at the
    indices specified in the :attr:`index` tensor along a given axis :attr:`dim`.
    This is a wrapper of :meth:`torch.Tensor.scatter_add_` in a broadcasting
    fashion.

    Args:
      src (torch.Tensor): The source tensor.
      index (torch.Tensor): The indices of elements to scatter.
      dim (int): The axis along which to index, (default: :obj:`-1`).
      out (torch.Tensor or None): The destination tensor; updated in place
          when given.
      dim_size (int or None): If :attr:`out` is not given, automatically create
          output with size :attr:`dim_size` at dimension :attr:`dim`. If
          :attr:`dim_size` is not given, a minimal sized output tensor according
          to :obj:`index.max() + 1` is returned.
    '''

    index = broadcast(index, src, dim)

    if out is None:
        shape = list(src.size())
        if dim_size is not None:
            shape[dim] = dim_size
        else:
            # An empty index yields an empty output along `dim`.
            shape[dim] = int(index.max()) + 1 if index.numel() > 0 else 0
        out = torch.zeros(shape, dtype=src.dtype, device=src.device)

    return out.scatter_add_(dim, index, src)
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
def xavier_uniform_(weights: torch.Tensor):
    r''' Initialize convolution weights with the same method as
    :obj:`torch.nn.init.xavier_uniform_`.

    :obj:`torch.nn.init.xavier_uniform_` expects a tensor of shape
    :obj:`(out_c, in_c, kdim)`, so it cannot be used for
    :class:`ocnn.nn.OctreeConv`, whose :attr:`weights` have shape
    :obj:`(kdim, in_c, out_c)`. The tensor is filled in place and
    :obj:`None` is returned.
    '''

    kdim, in_c, out_c = weights.shape[0], weights.shape[1], weights.shape[2]
    fan_in, fan_out = kdim * in_c, kdim * out_c
    std = math.sqrt(2.0 / float(fan_in + fan_out))
    bound = math.sqrt(3.0) * std  # U(-a, a) has standard deviation a / sqrt(3)
    torch.nn.init.uniform_(weights, -bound, bound)
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
def resize_with_last_val(list_in: list, num: int = 3):
    r''' Pads :attr:`list_in` in place up to :attr:`num` elements by repeating
    its last element, and returns the same list object.

    :attr:`list_in` must be a :obj:`list` with at most :attr:`num` elements
    (checked with an ``assert``).
    '''

    assert type(list_in) is list and len(list_in) <= num
    missing = num - len(list_in)
    if missing > 0:
        list_in.extend([list_in[-1]] * missing)
    return list_in
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
def list2str(list_in: list):
    r''' Returns a string representation of :attr:`list_in`: the :func:`str`
    form of every element concatenated with no separator, e.g.
    :obj:`[3, 3, 1]` becomes :obj:`'331'`.
    '''

    return ''.join(map(str, list_in))
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
def build_example_octree(depth: int = 5, full_depth: int = 2, pt_num: int = 3):
    r''' Builds an example octree on CPU from at most 3 points.

    Args:
      depth (int): The octree depth passed to :class:`ocnn.octree.Octree`.
      full_depth (int): The full depth passed to :class:`ocnn.octree.Octree`.
      pt_num (int): How many of the three hard-coded points to use (1 to 3).
    '''

    # Hard-coded sample point cloud with normals, 2-channel features and labels.
    points = torch.Tensor([[-1, -1, -1], [0, 0, -1], [0.0625, 0.0625, -1]])
    normals = torch.Tensor([[1, 0, 0], [-1, 0, 0], [0, 1, 0]])
    features = torch.Tensor([[1, -1], [2, -2], [3, -3]])
    labels = torch.Tensor([[0], [2], [2]])

    assert 0 < pt_num <= 3
    first = slice(None, pt_num)
    point_cloud = ocnn.octree.Points(
        points[first], normals[first], features[first], labels[first])

    octree = ocnn.octree.Octree(depth, full_depth)
    octree.build_octree(point_cloud)
    return octree
|
|
1
|
+
# --------------------------------------------------------
|
|
2
|
+
# Octree-based Sparse Convolutional Neural Networks
|
|
3
|
+
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
|
|
4
|
+
# Licensed under The MIT License [see LICENSE for details]
|
|
5
|
+
# Written by Peng-Shuai Wang
|
|
6
|
+
# --------------------------------------------------------
|
|
7
|
+
|
|
8
|
+
import math
|
|
9
|
+
import torch
|
|
10
|
+
from typing import Optional
|
|
11
|
+
from packaging import version
|
|
12
|
+
|
|
13
|
+
import ocnn
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
# Public names exported by a star-import of this module.
# NOTE(review): `range_grid` and `broadcast` are defined in this module but
# are not listed here -- confirm whether that is intentional.
__all__ = [
    'trunc_div', 'meshgrid', 'cumsum', 'scatter_add', 'xavier_uniform_',
    'resize_with_last_val', 'list2str', 'build_example_octree',
]
# Alias of `__all__` (the very same list object).
classes = __all__
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def trunc_div(input, other):
    r''' Wraps :func:`torch.div` for compatibility. It rounds the results of the
    division towards zero and is equivalent to C-style integer division.

    Args:
      input (torch.Tensor): The dividend.
      other: The divisor, anything accepted by :func:`torch.div`.

    Returns:
      torch.Tensor: :obj:`input / other` rounded towards zero.
    '''

    # Compare only the release part of the version. Under PEP 440 a local
    # build such as '1.7.1+cu110' sorts *after* '1.7.1', which would wrongly
    # select the `rounding_mode` keyword that PyTorch 1.7.1 does not accept.
    torch_version = version.parse(version.parse(torch.__version__).base_version)
    if torch_version > version.parse('1.7.1'):
        return torch.div(input, other, rounding_mode='trunc')
    # In these old releases `floor_divide` rounds towards zero; PyTorch only
    # switched it to true floor division in 1.13.
    return torch.floor_divide(input, other)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def meshgrid(*tensors, indexing: Optional[str] = None):
    r''' Wraps :func:`torch.meshgrid` for compatibility.

    The :attr:`indexing` keyword is only forwarded on PyTorch releases newer
    than 1.9.1. Older releases do not accept it and always use ``'ij'``
    indexing, so :attr:`indexing` is ignored there.

    Args:
      tensors (torch.Tensor): The 1D tensors that define the grid axes.
      indexing (str, optional): ``'ij'`` or ``'xy'``; forwarded to
          :func:`torch.meshgrid` when supported.
    '''

    # Compare only the release part of the version. Under PEP 440 a local
    # build such as '1.9.1+cu111' sorts *after* '1.9.1', which would wrongly
    # pass `indexing=` to a release that does not support it.
    torch_version = version.parse(version.parse(torch.__version__).base_version)
    if torch_version > version.parse('1.9.1'):
        return torch.meshgrid(*tensors, indexing=indexing)
    return torch.meshgrid(*tensors)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def range_grid(min: int, max: int, device: torch.device = 'cpu'):
    r''' Builds a 3D mesh grid over the closed integer range :obj:`[min, max]`.

    Args:
      min (int): The smallest coordinate value on each axis.
      max (int): The largest coordinate value on each axis (inclusive).
      device (torch.device, optional): The device on which the grid is created.

    Returns:
      torch.Tensor: A :obj:`torch.long` tensor of shape :obj:`(N, 3)` with
      :obj:`N = (max - min + 1) ** 3`. Points are listed with the first
      coordinate varying slowest ('ij' indexing).

    Example:
      >>> grid = range_grid(0, 1)
      >>> print(grid)
      tensor([[0, 0, 0],
              [0, 0, 1],
              [0, 1, 0],
              [0, 1, 1],
              [1, 0, 0],
              [1, 0, 1],
              [1, 1, 0],
              [1, 1, 1]])
    '''

    axis = torch.arange(min, max + 1, dtype=torch.long, device=device)
    gx, gy, gz = meshgrid(axis, axis, axis, indexing='ij')
    return torch.stack([gx, gy, gz], dim=-1).view(-1, 3)
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def cumsum(data: torch.Tensor, dim: int, exclusive: bool = False):
    r''' Extends :func:`torch.cumsum` with the input argument :attr:`exclusive`.

    Args:
      data (torch.Tensor): The input data.
      dim (int): The dimension to do the operation over.
      exclusive (bool): If false, the behavior is the same as :func:`torch.cumsum`.
          If true, a zero slice is prepended along :attr:`dim`, so the output
          is larger by 1 than :attr:`data` in that dimension and its entry
          :obj:`i` is the sum of the first :obj:`i` input entries.
    '''

    running = torch.cumsum(data, dim)
    if not exclusive:
        return running

    # A single zero slice along `dim`, matching `running` in dtype and device.
    pad_shape = list(data.size())
    pad_shape[dim] = 1
    return torch.cat([running.new_zeros(pad_shape), running], dim)
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def broadcast(src: torch.Tensor, other: torch.Tensor, dim: int):
    r''' Broadcast :attr:`src` according to :attr:`other`, originally from the
    library `pytorch_scatter`.

    A 1D :attr:`src` is aligned with axis :attr:`dim` of :attr:`other`; any
    :attr:`src` then gets trailing singleton axes until it has as many
    dimensions as :attr:`other`, and is finally expanded to its shape.
    '''

    axis = dim + other.dim() if dim < 0 else dim

    if src.dim() == 1:
        # Leading singleton axes put the 1D data on axis `axis`.
        for _ in range(axis):
            src = src.unsqueeze(0)
    while src.dim() < other.dim():
        src = src.unsqueeze(-1)

    return src.expand_as(other)
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def scatter_add(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                out: Optional[torch.Tensor] = None,
                dim_size: Optional[int] = None,) -> torch.Tensor:
    r''' Reduces all values from the :attr:`src` tensor into :attr:`out` at the
    indices specified in the :attr:`index` tensor along a given axis :attr:`dim`.
    This is a wrapper of :meth:`torch.Tensor.scatter_add_` in a broadcasting
    fashion.

    Args:
      src (torch.Tensor): The source tensor.
      index (torch.Tensor): The indices of elements to scatter.
      dim (int): The axis along which to index, (default: :obj:`-1`).
      out (torch.Tensor or None): The destination tensor; updated in place
          when given.
      dim_size (int or None): If :attr:`out` is not given, automatically create
          output with size :attr:`dim_size` at dimension :attr:`dim`. If
          :attr:`dim_size` is not given, a minimal sized output tensor according
          to :obj:`index.max() + 1` is returned.
    '''

    index = broadcast(index, src, dim)

    if out is None:
        shape = list(src.size())
        if dim_size is not None:
            shape[dim] = dim_size
        else:
            # An empty index yields an empty output along `dim`.
            shape[dim] = int(index.max()) + 1 if index.numel() > 0 else 0
        out = torch.zeros(shape, dtype=src.dtype, device=src.device)

    return out.scatter_add_(dim, index, src)
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def xavier_uniform_(weights: torch.Tensor):
    r''' Initialize convolution weights with the same method as
    :obj:`torch.nn.init.xavier_uniform_`.

    :obj:`torch.nn.init.xavier_uniform_` expects a tensor of shape
    :obj:`(out_c, in_c, kdim)`, so it cannot be used for
    :class:`ocnn.nn.OctreeConv`, whose :attr:`weights` have shape
    :obj:`(kdim, in_c, out_c)`. The tensor is filled in place and
    :obj:`None` is returned.
    '''

    kdim, in_c, out_c = weights.shape[0], weights.shape[1], weights.shape[2]
    fan_in, fan_out = kdim * in_c, kdim * out_c
    std = math.sqrt(2.0 / float(fan_in + fan_out))
    bound = math.sqrt(3.0) * std  # U(-a, a) has standard deviation a / sqrt(3)
    torch.nn.init.uniform_(weights, -bound, bound)
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
def resize_with_last_val(list_in: list, num: int = 3):
    r''' Pads :attr:`list_in` in place up to :attr:`num` elements by repeating
    its last element, and returns the same list object.

    :attr:`list_in` must be a :obj:`list` with at most :attr:`num` elements
    (checked with an ``assert``).
    '''

    assert type(list_in) is list and len(list_in) <= num
    missing = num - len(list_in)
    if missing > 0:
        list_in.extend([list_in[-1]] * missing)
    return list_in
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
def list2str(list_in: list):
    r''' Returns a string representation of :attr:`list_in`: the :func:`str`
    form of every element concatenated with no separator, e.g.
    :obj:`[3, 3, 1]` becomes :obj:`'331'`.
    '''

    return ''.join(map(str, list_in))
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
def build_example_octree(depth: int = 5, full_depth: int = 2, pt_num: int = 3):
    r''' Builds an example octree on CPU from at most 3 points.

    Args:
      depth (int): The octree depth passed to :class:`ocnn.octree.Octree`.
      full_depth (int): The full depth passed to :class:`ocnn.octree.Octree`.
      pt_num (int): How many of the three hard-coded points to use (1 to 3).
    '''

    # Hard-coded sample point cloud with normals, 2-channel features and labels.
    points = torch.Tensor([[-1, -1, -1], [0, 0, -1], [0.0625, 0.0625, -1]])
    normals = torch.Tensor([[1, 0, 0], [-1, 0, 0], [0, 1, 0]])
    features = torch.Tensor([[1, -1], [2, -2], [3, -3]])
    labels = torch.Tensor([[0], [2], [2]])

    assert 0 < pt_num <= 3
    first = slice(None, pt_num)
    point_cloud = ocnn.octree.Points(
        points[first], normals[first], features[first], labels[first])

    octree = ocnn.octree.Octree(depth, full_depth)
    octree.build_octree(point_cloud)
    return octree
|