ocnn 2.2.0__py3-none-any.whl → 2.2.2__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
ocnn/modules/resblocks.py CHANGED
@@ -1,124 +1,124 @@
- # --------------------------------------------------------
- # Octree-based Sparse Convolutional Neural Networks
- # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
- # Licensed under The MIT License [see LICENSE for details]
- # Written by Peng-Shuai Wang
- # --------------------------------------------------------
-
- import torch
- import torch.utils.checkpoint
-
- from ocnn.octree import Octree
- from ocnn.nn import OctreeMaxPool
- from ocnn.modules import Conv1x1BnRelu, OctreeConvBnRelu, Conv1x1Bn, OctreeConvBn
-
-
- class OctreeResBlock(torch.nn.Module):
-   r''' Octree-based ResNet block in a bottleneck style. The block is composed of
-   a series of :obj:`Conv1x1`, :obj:`Conv3x3`, and :obj:`Conv1x1`.
-
-   Args:
-     in_channels (int): Number of input channels.
-     out_channels (int): Number of output channels.
-     stride (int): The stride of the block (:obj:`1` or :obj:`2`).
-     bottleneck (int): The input and output channels of the :obj:`Conv3x3` is
-         equal to the input channel divided by :attr:`bottleneck`.
-     nempty (bool): If True, only performs the convolution on non-empty
-         octree nodes.
-   '''
-
-   def __init__(self, in_channels: int, out_channels: int, stride: int = 1,
-                bottleneck: int = 4, nempty: bool = False):
-     super().__init__()
-     self.in_channels = in_channels
-     self.out_channels = out_channels
-     self.bottleneck = bottleneck
-     self.stride = stride
-     channelb = int(out_channels / bottleneck)
-
-     if self.stride == 2:
-       self.max_pool = OctreeMaxPool(nempty)
-     self.conv1x1a = Conv1x1BnRelu(in_channels, channelb)
-     self.conv3x3 = OctreeConvBnRelu(channelb, channelb, nempty=nempty)
-     self.conv1x1b = Conv1x1Bn(channelb, out_channels)
-     if self.in_channels != self.out_channels:
-       self.conv1x1c = Conv1x1Bn(in_channels, out_channels)
-     self.relu = torch.nn.ReLU(inplace=True)
-
-   def forward(self, data: torch.Tensor, octree: Octree, depth: int):
-     r''''''
-
-     if self.stride == 2:
-       data = self.max_pool(data, octree, depth)
-       depth = depth - 1
-     conv1 = self.conv1x1a(data)
-     conv2 = self.conv3x3(conv1, octree, depth)
-     conv3 = self.conv1x1b(conv2)
-     if self.in_channels != self.out_channels:
-       data = self.conv1x1c(data)
-     out = self.relu(conv3 + data)
-     return out
-
-
- class OctreeResBlock2(torch.nn.Module):
-   r''' Basic Octree-based ResNet block. The block is composed of
-   a series of :obj:`Conv3x3` and :obj:`Conv3x3`.
-
-   Refer to :class:`OctreeResBlock` for the details of arguments.
-   '''
-
-   def __init__(self, in_channels, out_channels, stride=1, bottleneck=1,
-                nempty=False):
-     super().__init__()
-     self.in_channels = in_channels
-     self.out_channels = out_channels
-     self.stride = stride
-     channelb = int(out_channels / bottleneck)
-
-     if self.stride == 2:
-       self.maxpool = OctreeMaxPool(self.depth)
-     self.conv3x3a = OctreeConvBnRelu(in_channels, channelb, nempty=nempty)
-     self.conv3x3b = OctreeConvBn(channelb, out_channels, nempty=nempty)
-     if self.in_channels != self.out_channels:
-       self.conv1x1 = Conv1x1Bn(in_channels, out_channels)
-     self.relu = torch.nn.ReLU(inplace=True)
-
-   def forward(self, data: torch.Tensor, octree: Octree, depth: int):
-     r''''''
-
-     if self.stride == 2:
-       data = self.maxpool(data, octree, depth)
-       depth = depth - 1
-     conv1 = self.conv3x3a(data, octree, depth)
-     conv2 = self.conv3x3b(conv1, octree, depth)
-     if self.in_channels != self.out_channels:
-       data = self.conv1x1(data)
-     out = self.relu(conv2 + data)
-     return out
-
-
- class OctreeResBlocks(torch.nn.Module):
-   r''' A sequence of :attr:`resblk_num` ResNet blocks.
-   '''
-
-   def __init__(self, in_channels, out_channels, resblk_num, bottleneck=4,
-                nempty=False, resblk=OctreeResBlock, use_checkpoint=False):
-     super().__init__()
-     self.resblk_num = resblk_num
-     self.use_checkpoint = use_checkpoint
-     channels = [in_channels] + [out_channels] * resblk_num
-
-     self.resblks = torch.nn.ModuleList(
-         [resblk(channels[i], channels[i+1], 1, bottleneck, nempty)
-          for i in range(self.resblk_num)])
-
-   def forward(self, data: torch.Tensor, octree: Octree, depth: int):
-     r''''''
-
-     for i in range(self.resblk_num):
-       if self.use_checkpoint:
-         data = torch.utils.checkpoint.checkpoint(
-             self.resblks[i], data, octree, depth)
-       else:
-         data = self.resblks[i](data, octree, depth)
-     return data
+ # --------------------------------------------------------
+ # Octree-based Sparse Convolutional Neural Networks
+ # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
+ # Licensed under The MIT License [see LICENSE for details]
+ # Written by Peng-Shuai Wang
+ # --------------------------------------------------------
+
+ import torch
+ import torch.utils.checkpoint
+
+ from ocnn.octree import Octree
+ from ocnn.nn import OctreeMaxPool
+ from ocnn.modules import Conv1x1BnRelu, OctreeConvBnRelu, Conv1x1Bn, OctreeConvBn
+
+
+ class OctreeResBlock(torch.nn.Module):
+   r''' Octree-based ResNet block in a bottleneck style. The block is composed of
+   a series of :obj:`Conv1x1`, :obj:`Conv3x3`, and :obj:`Conv1x1`.
+
+   Args:
+     in_channels (int): Number of input channels.
+     out_channels (int): Number of output channels.
+     stride (int): The stride of the block (:obj:`1` or :obj:`2`).
+     bottleneck (int): The input and output channels of the :obj:`Conv3x3` is
+         equal to the input channel divided by :attr:`bottleneck`.
+     nempty (bool): If True, only performs the convolution on non-empty
+         octree nodes.
+   '''
+
+   def __init__(self, in_channels: int, out_channels: int, stride: int = 1,
+                bottleneck: int = 4, nempty: bool = False):
+     super().__init__()
+     self.in_channels = in_channels
+     self.out_channels = out_channels
+     self.bottleneck = bottleneck
+     self.stride = stride
+     channelb = int(out_channels / bottleneck)
+
+     if self.stride == 2:
+       self.max_pool = OctreeMaxPool(nempty)
+     self.conv1x1a = Conv1x1BnRelu(in_channels, channelb)
+     self.conv3x3 = OctreeConvBnRelu(channelb, channelb, nempty=nempty)
+     self.conv1x1b = Conv1x1Bn(channelb, out_channels)
+     if self.in_channels != self.out_channels:
+       self.conv1x1c = Conv1x1Bn(in_channels, out_channels)
+     self.relu = torch.nn.ReLU(inplace=True)
+
+   def forward(self, data: torch.Tensor, octree: Octree, depth: int):
+     r''''''
+
+     if self.stride == 2:
+       data = self.max_pool(data, octree, depth)
+       depth = depth - 1
+     conv1 = self.conv1x1a(data)
+     conv2 = self.conv3x3(conv1, octree, depth)
+     conv3 = self.conv1x1b(conv2)
+     if self.in_channels != self.out_channels:
+       data = self.conv1x1c(data)
+     out = self.relu(conv3 + data)
+     return out
+
+
+ class OctreeResBlock2(torch.nn.Module):
+   r''' Basic Octree-based ResNet block. The block is composed of
+   a series of :obj:`Conv3x3` and :obj:`Conv3x3`.
+
+   Refer to :class:`OctreeResBlock` for the details of arguments.
+   '''
+
+   def __init__(self, in_channels, out_channels, stride=1, bottleneck=1,
+                nempty=False):
+     super().__init__()
+     self.in_channels = in_channels
+     self.out_channels = out_channels
+     self.stride = stride
+     channelb = int(out_channels / bottleneck)
+
+     if self.stride == 2:
+       self.maxpool = OctreeMaxPool(self.depth)
+     self.conv3x3a = OctreeConvBnRelu(in_channels, channelb, nempty=nempty)
+     self.conv3x3b = OctreeConvBn(channelb, out_channels, nempty=nempty)
+     if self.in_channels != self.out_channels:
+       self.conv1x1 = Conv1x1Bn(in_channels, out_channels)
+     self.relu = torch.nn.ReLU(inplace=True)
+
+   def forward(self, data: torch.Tensor, octree: Octree, depth: int):
+     r''''''
+
+     if self.stride == 2:
+       data = self.maxpool(data, octree, depth)
+       depth = depth - 1
+     conv1 = self.conv3x3a(data, octree, depth)
+     conv2 = self.conv3x3b(conv1, octree, depth)
+     if self.in_channels != self.out_channels:
+       data = self.conv1x1(data)
+     out = self.relu(conv2 + data)
+     return out
+
+
+ class OctreeResBlocks(torch.nn.Module):
+   r''' A sequence of :attr:`resblk_num` ResNet blocks.
+   '''
+
+   def __init__(self, in_channels, out_channels, resblk_num, bottleneck=4,
+                nempty=False, resblk=OctreeResBlock, use_checkpoint=False):
+     super().__init__()
+     self.resblk_num = resblk_num
+     self.use_checkpoint = use_checkpoint
+     channels = [in_channels] + [out_channels] * resblk_num
+
+     self.resblks = torch.nn.ModuleList(
+         [resblk(channels[i], channels[i+1], 1, bottleneck, nempty)
+          for i in range(self.resblk_num)])
+
+   def forward(self, data: torch.Tensor, octree: Octree, depth: int):
+     r''''''
+
+     for i in range(self.resblk_num):
+       if self.use_checkpoint:
+         data = torch.utils.checkpoint.checkpoint(
+             self.resblks[i], data, octree, depth)
+       else:
+         data = self.resblks[i](data, octree, depth)
+     return data
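Note that the removed and re-added lines above are textually identical, so the code of this module does not change between 2.2.0 and 2.2.2. For orientation, the following is a minimal construction sketch of the blocks defined above; it assumes the wheel is installed and that OctreeResBlock and OctreeResBlocks are re-exported by ocnn.modules, which this file's own imports suggest. Building an actual octree for the forward pass is out of scope here.

    import ocnn

    # Two bottleneck residual blocks mapping 32 -> 64 channels, following the
    # constructor arguments documented above; bottleneck=4 gives the inner
    # Conv3x3 a width of 64 / 4 = 16 channels.
    blocks = ocnn.modules.OctreeResBlocks(
        in_channels=32, out_channels=64, resblk_num=2, bottleneck=4,
        nempty=False, resblk=ocnn.modules.OctreeResBlock, use_checkpoint=False)

    # The forward pass additionally needs an ocnn.octree.Octree and a depth:
    #   out = blocks(data, octree, depth)
    # where data holds one feature row per octree node at the given depth.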
ocnn/nn/__init__.py CHANGED
@@ -1,40 +1,42 @@
- # --------------------------------------------------------
- # Octree-based Sparse Convolutional Neural Networks
- # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
- # Licensed under The MIT License [see LICENSE for details]
- # Written by Peng-Shuai Wang
- # --------------------------------------------------------
-
- from .octree2vox import octree2voxel, Octree2Voxel
- from .octree2col import octree2col, col2octree
- from .octree_pad import octree_pad, octree_depad
- from .octree_interp import (octree_nearest_pts, octree_linear_pts,
-                             OctreeInterp, OctreeUpsample)
- from .octree_pool import (octree_max_pool, OctreeMaxPool,
-                           octree_max_unpool, OctreeMaxUnpool,
-                           octree_global_pool, OctreeGlobalPool,
-                           octree_avg_pool, OctreeAvgPool,)
- from .octree_conv import OctreeConv, OctreeDeconv
- from .octree_dwconv import OctreeDWConv
- from .octree_norm import OctreeInstanceNorm, OctreeBatchNorm
- from .octree_drop import OctreeDropPath
-
-
- __all__ = [
-     'octree2voxel',
-     'octree2col', 'col2octree',
-     'octree_pad', 'octree_depad',
-     'octree_nearest_pts', 'octree_linear_pts',
-     'octree_max_pool', 'octree_max_unpool',
-     'octree_global_pool', 'octree_avg_pool',
-     'Octree2Voxel',
-     'OctreeMaxPool', 'OctreeMaxUnpool',
-     'OctreeGlobalPool', 'OctreeAvgPool',
-     'OctreeConv', 'OctreeDeconv',
-     'OctreeDWConv',
-     'OctreeInterp', 'OctreeUpsample',
-     'OctreeInstanceNorm', 'OctreeBatchNorm',
-     'OctreeDropPath',
- ]
-
- classes = __all__
+ # --------------------------------------------------------
+ # Octree-based Sparse Convolutional Neural Networks
+ # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
+ # Licensed under The MIT License [see LICENSE for details]
+ # Written by Peng-Shuai Wang
+ # --------------------------------------------------------
+
+ from .octree2vox import octree2voxel, Octree2Voxel
+ from .octree2col import octree2col, col2octree
+ from .octree_pad import octree_pad, octree_depad
+ from .octree_interp import (octree_nearest_pts, octree_linear_pts,
+                             OctreeInterp, OctreeUpsample)
+ from .octree_pool import (octree_max_pool, OctreeMaxPool,
+                           octree_max_unpool, OctreeMaxUnpool,
+                           octree_global_pool, OctreeGlobalPool,
+                           octree_avg_pool, OctreeAvgPool,)
+ from .octree_conv import OctreeConv, OctreeDeconv
+ from .octree_dwconv import OctreeDWConv
+ from .octree_norm import OctreeBatchNorm, OctreeGroupNorm, OctreeInstanceNorm
+ from .octree_drop import OctreeDropPath
+ from .octree_align import search_value, octree_align
+
+
+ __all__ = [
+     'octree2voxel',
+     'octree2col', 'col2octree',
+     'octree_pad', 'octree_depad',
+     'octree_nearest_pts', 'octree_linear_pts',
+     'octree_max_pool', 'octree_max_unpool',
+     'octree_global_pool', 'octree_avg_pool',
+     'Octree2Voxel',
+     'OctreeMaxPool', 'OctreeMaxUnpool',
+     'OctreeGlobalPool', 'OctreeAvgPool',
+     'OctreeConv', 'OctreeDeconv',
+     'OctreeDWConv',
+     'OctreeInterp', 'OctreeUpsample',
+     'OctreeInstanceNorm', 'OctreeBatchNorm', 'OctreeGroupNorm',
+     'OctreeDropPath',
+     'search_value', 'octree_align',
+ ]
+
+ classes = __all__
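The substantive changes in this file are the two added imports and the corresponding __all__ entries: the group-normalization layer and the new alignment helpers join the public ocnn.nn namespace. Assuming the 2.2.2 wheel is installed, they can be imported directly:

    # Names newly exported from ocnn.nn as of 2.2.2 (per the __all__ list above).
    from ocnn.nn import OctreeGroupNorm, search_value, octree_align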
ocnn/nn/octree2col.py CHANGED
@@ -1,53 +1,53 @@
- # --------------------------------------------------------
- # Octree-based Sparse Convolutional Neural Networks
- # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
- # Licensed under The MIT License [see LICENSE for details]
- # Written by Peng-Shuai Wang
- # --------------------------------------------------------
-
- import torch
- import torch.nn
-
- from ocnn.octree import Octree
- from ocnn.utils import scatter_add
-
-
- def octree2col(data: torch.Tensor, octree: Octree, depth: int,
-                kernel_size: str = '333', stride: int = 1, nempty: bool = False):
-   r''' Gathers the neighboring features for convolutions.
-
-   Args:
-     data (torch.Tensor): The input data.
-     octree (Octree): The corresponding octree.
-     depth (int): The depth of current octree.
-     kernel_size (str): The kernel shape, choose from :obj:`333`, :obj:`311`,
-         :obj:`131`, :obj:`113`, :obj:`222`, :obj:`331`, :obj:`133`, and
-         :obj:`313`.
-     stride (int): The stride of neighborhoods (:obj:`1` or :obj:`2`). If the
-         stride is :obj:`2`, it always returns the neighborhood of the first
-         siblings, and the number of elements of output tensor is
-         :obj:`octree.nnum[depth] / 8`.
-     nempty (bool): If True, only returns the neighborhoods of the non-empty
-         octree nodes.
-   '''
-
-   neigh = octree.get_neigh(depth, kernel_size, stride, nempty)
-   size = (neigh.shape[0], neigh.shape[1], data.shape[1])
-   out = torch.zeros(size, dtype=data.dtype, device=data.device)
-   valid = neigh >= 0
-   out[valid] = data[neigh[valid]]  # (N, K, C)
-   return out
-
-
- def col2octree(data: torch.Tensor, octree: Octree, depth: int,
-                kernel_size: str = '333', stride: int = 1, nempty: bool = False):
-   r''' Scatters the convolution features to an octree.
-
-   Please refer to :func:`octree2col` for the usage of function parameters.
-   '''
-
-   neigh = octree.get_neigh(depth, kernel_size, stride, nempty)
-   valid = neigh >= 0
-   dim_size = octree.nnum_nempty[depth] if nempty else octree.nnum[depth]
-   out = scatter_add(data[valid], neigh[valid], dim=0, dim_size=dim_size)
-   return out
+ # --------------------------------------------------------
+ # Octree-based Sparse Convolutional Neural Networks
+ # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
+ # Licensed under The MIT License [see LICENSE for details]
+ # Written by Peng-Shuai Wang
+ # --------------------------------------------------------
+
+ import torch
+ import torch.nn
+
+ from ocnn.octree import Octree
+ from ocnn.utils import scatter_add
+
+
+ def octree2col(data: torch.Tensor, octree: Octree, depth: int,
+                kernel_size: str = '333', stride: int = 1, nempty: bool = False):
+   r''' Gathers the neighboring features for convolutions.
+
+   Args:
+     data (torch.Tensor): The input data.
+     octree (Octree): The corresponding octree.
+     depth (int): The depth of current octree.
+     kernel_size (str): The kernel shape, choose from :obj:`333`, :obj:`311`,
+         :obj:`131`, :obj:`113`, :obj:`222`, :obj:`331`, :obj:`133`, and
+         :obj:`313`.
+     stride (int): The stride of neighborhoods (:obj:`1` or :obj:`2`). If the
+         stride is :obj:`2`, it always returns the neighborhood of the first
+         siblings, and the number of elements of output tensor is
+         :obj:`octree.nnum[depth] / 8`.
+     nempty (bool): If True, only returns the neighborhoods of the non-empty
+         octree nodes.
+   '''
+
+   neigh = octree.get_neigh(depth, kernel_size, stride, nempty)
+   size = (neigh.shape[0], neigh.shape[1], data.shape[1])
+   out = torch.zeros(size, dtype=data.dtype, device=data.device)
+   valid = neigh >= 0
+   out[valid] = data[neigh[valid]]  # (N, K, C)
+   return out
+
+
+ def col2octree(data: torch.Tensor, octree: Octree, depth: int,
+                kernel_size: str = '333', stride: int = 1, nempty: bool = False):
+   r''' Scatters the convolution features to an octree.
+
+   Please refer to :func:`octree2col` for the usage of function parameters.
+   '''
+
+   neigh = octree.get_neigh(depth, kernel_size, stride, nempty)
+   valid = neigh >= 0
+   dim_size = octree.nnum_nempty[depth] if nempty else octree.nnum[depth]
+   out = scatter_add(data[valid], neigh[valid], dim=0, dim_size=dim_size)
+   return out
ocnn/nn/octree2vox.py CHANGED
@@ -1,50 +1,50 @@
- # --------------------------------------------------------
- # Octree-based Sparse Convolutional Neural Networks
- # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
- # Licensed under The MIT License [see LICENSE for details]
- # Written by Peng-Shuai Wang
- # --------------------------------------------------------
-
- import torch
-
- from ocnn.octree import Octree
-
-
- def octree2voxel(data: torch.Tensor, octree: Octree, depth: int,
-                  nempty: bool = False):
-   r''' Converts the input feature to the full-voxel-based representation.
-
-   Args:
-     data (torch.Tensor): The input feature.
-     octree (Octree): The corresponding octree.
-     depth (int): The depth of current octree.
-     nempty (bool): If True, :attr:`data` only contains the features of non-empty
-         octree nodes.
-   '''
-
-   x, y, z, b = octree.xyzb(depth, nempty)
-
-   num = 1 << depth
-   channel = data.shape[1]
-   vox = data.new_zeros([octree.batch_size, num, num, num, channel])
-   vox[b, x, y, z] = data
-   return vox
-
-
- class Octree2Voxel(torch.nn.Module):
-   r''' Converts the input feature to the full-voxel-based representation.
-
-   Please refer to :func:`octree2voxel` for details.
-   '''
-
-   def __init__(self, nempty: bool = False):
-     super().__init__()
-     self.nempty = nempty
-
-   def forward(self, data: torch.Tensor, octree: Octree, depth: int):
-     r''''''
-
-     return octree2voxel(data, octree, depth, self.nempty)
-
-   def extra_repr(self) -> str:
-     return 'nempty={}'.format(self.nempty)
+ # --------------------------------------------------------
+ # Octree-based Sparse Convolutional Neural Networks
+ # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
+ # Licensed under The MIT License [see LICENSE for details]
+ # Written by Peng-Shuai Wang
+ # --------------------------------------------------------
+
+ import torch
+
+ from ocnn.octree import Octree
+
+
+ def octree2voxel(data: torch.Tensor, octree: Octree, depth: int,
+                  nempty: bool = False):
+   r''' Converts the input feature to the full-voxel-based representation.
+
+   Args:
+     data (torch.Tensor): The input feature.
+     octree (Octree): The corresponding octree.
+     depth (int): The depth of current octree.
+     nempty (bool): If True, :attr:`data` only contains the features of non-empty
+         octree nodes.
+   '''
+
+   x, y, z, b = octree.xyzb(depth, nempty)
+
+   num = 1 << depth
+   channel = data.shape[1]
+   vox = data.new_zeros([octree.batch_size, num, num, num, channel])
+   vox[b, x, y, z] = data
+   return vox
+
+
+ class Octree2Voxel(torch.nn.Module):
+   r''' Converts the input feature to the full-voxel-based representation.
+
+   Please refer to :func:`octree2voxel` for details.
+   '''
+
+   def __init__(self, nempty: bool = False):
+     super().__init__()
+     self.nempty = nempty
+
+   def forward(self, data: torch.Tensor, octree: Octree, depth: int):
+     r''''''
+
+     return octree2voxel(data, octree, depth, self.nempty)
+
+   def extra_repr(self) -> str:
+     return 'nempty={}'.format(self.nempty)
ocnn/nn/octree_align.py ADDED
@@ -0,0 +1,46 @@
+ # --------------------------------------------------------
+ # Octree-based Sparse Convolutional Neural Networks
+ # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
+ # Licensed under The MIT License [see LICENSE for details]
+ # Written by Peng-Shuai Wang
+ # --------------------------------------------------------
+
+ import torch
+
+ from ocnn.octree import Octree
+
+
+ def search_value(value: torch.Tensor, key: torch.Tensor, query: torch.Tensor):
+   r''' Searches values according to sorted shuffled keys.
+
+   Args:
+     value (torch.Tensor): The input tensor with shape (N, C).
+     key (torch.Tensor): The key tensor corresponds to :attr:`value` with shape
+         (N,), which contains sorted shuffled keys of an octree.
+     query (torch.Tensor): The query tensor, which also contains shuffled keys.
+   '''
+
+   # deal with out-of-bound queries, the indices of these queries
+   # returned by torch.searchsorted equal to `key.shape[0]`
+   out_of_bound = query > key[-1]
+
+   # search
+   idx = torch.searchsorted(key, query)
+   idx[out_of_bound] = -1  # to avoid overflow when executing the following line
+   found = key[idx] == query
+
+   # assign the found value to the output
+   out = torch.zeros(query.shape[0], value.shape[1], device=value.device)
+   out[found] = value[idx[found]]
+   return out
+
+
+ def octree_align(value: torch.Tensor, octree: Octree, octree_query: Octree,
+                  depth: int, nempty: bool = False):
+   r''' Wraps :func:`octree_align` to take octrees as input for convenience.
+   '''
+
+   key = octree.key(depth, nempty)
+   query = octree_query.key(depth, nempty)
+   assert key.shape[0] == value.shape[0]
+   return search_value(value, key, query)
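Since search_value operates on plain tensors, its behaviour can be checked without building an octree. A minimal sketch based on the code above, assuming the 2.2.2 wheel is installed so that ocnn.nn exports the function:

    import torch
    from ocnn.nn import search_value

    key = torch.tensor([2, 5, 7, 11])        # sorted keys, one per row of value
    value = torch.arange(8, dtype=torch.float32).view(4, 2)
    query = torch.tensor([5, 3, 11, 42])     # 5 and 11 exist; 3 and 42 do not

    out = search_value(value, key, query)
    # Rows for the found keys are copied from value; missing keys yield zero rows:
    # out == [[2., 3.], [0., 0.], [6., 7.], [0., 0.]]

octree_align performs the same lookup between two octrees: the output has one row per node of octree_query at the given depth, filled with the corresponding row of value when the node's shuffled key also exists in octree, and zeros otherwise.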