ocnn 2.2.8__py3-none-any.whl → 2.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. ocnn/__init__.py +24 -24
  2. ocnn/dataset.py +160 -160
  3. ocnn/models/__init__.py +29 -29
  4. ocnn/models/autoencoder.py +155 -155
  5. ocnn/models/hrnet.py +192 -192
  6. ocnn/models/image2shape.py +128 -128
  7. ocnn/models/lenet.py +46 -46
  8. ocnn/models/ounet.py +94 -94
  9. ocnn/models/resnet.py +53 -53
  10. ocnn/models/segnet.py +72 -72
  11. ocnn/models/unet.py +105 -105
  12. ocnn/modules/__init__.py +26 -26
  13. ocnn/modules/modules.py +303 -303
  14. ocnn/modules/resblocks.py +158 -158
  15. ocnn/nn/__init__.py +45 -44
  16. ocnn/nn/kernels/__init__.py +14 -0
  17. ocnn/nn/kernels/autotuner.py +416 -0
  18. ocnn/nn/kernels/config.py +67 -0
  19. ocnn/nn/kernels/conv_bwd_implicit_gemm.py +229 -0
  20. ocnn/nn/kernels/conv_bwd_implicit_gemm_splitk.py +347 -0
  21. ocnn/nn/kernels/conv_fwd_implicit_gemm.py +109 -0
  22. ocnn/nn/kernels/conv_fwd_implicit_gemm_splitk.py +150 -0
  23. ocnn/nn/kernels/utils.py +44 -0
  24. ocnn/nn/octree2col.py +53 -53
  25. ocnn/nn/octree2vox.py +50 -50
  26. ocnn/nn/octree_align.py +46 -46
  27. ocnn/nn/octree_conv.py +430 -429
  28. ocnn/nn/octree_conv_t.py +148 -0
  29. ocnn/nn/octree_drop.py +55 -55
  30. ocnn/nn/octree_dwconv.py +222 -222
  31. ocnn/nn/octree_gconv.py +79 -79
  32. ocnn/nn/octree_interp.py +196 -196
  33. ocnn/nn/octree_norm.py +126 -126
  34. ocnn/nn/octree_pad.py +39 -39
  35. ocnn/nn/octree_pool.py +200 -200
  36. ocnn/octree/__init__.py +22 -22
  37. ocnn/octree/octree.py +770 -770
  38. ocnn/octree/points.py +384 -323
  39. ocnn/octree/shuffled_key.py +115 -115
  40. ocnn/utils.py +205 -205
  41. {ocnn-2.2.8.dist-info → ocnn-2.3.0.dist-info}/METADATA +117 -111
  42. ocnn-2.3.0.dist-info/RECORD +45 -0
  43. {ocnn-2.2.8.dist-info → ocnn-2.3.0.dist-info}/WHEEL +1 -1
  44. {ocnn-2.2.8.dist-info → ocnn-2.3.0.dist-info}/licenses/LICENSE +21 -21
  45. ocnn-2.2.8.dist-info/RECORD +0 -36
  46. {ocnn-2.2.8.dist-info → ocnn-2.3.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,150 @@
1
+ import math
2
+ import torch
3
+ import triton
4
+ import triton.language as tl
5
+ from .utils import get_num_sm
6
+ from .autotuner import triton_autotune, autotune
7
+ from . import config
8
+ from .conv_fwd_implicit_gemm import conv_fwd_implicit_gemm_kernel
9
+
10
+
11
@triton_autotune(
    configs=config.autotune_config,
    key=['LOGN', 'Ci', 'Co', 'V', 'SPLITK', 'allow_tf32'],
)
@triton.jit
def conv_fwd_implicit_gemm_splitk_kernel(
    input,
    weight,
    bias,
    neighbor,
    output,
    # Tensor dimensions
    N, LOGN, Ci, Co, V: tl.constexpr,
    # Meta-parameters
    B1: tl.constexpr,  # Tile size along N
    B2: tl.constexpr,  # Tile size along Co
    BK: tl.constexpr,  # Tile size along the reduction dimension (V * Ci)
    SPLITK: tl.constexpr,  # Number of splits of the reduction dimension
    allow_tf32: tl.constexpr,  # Use TF32 input precision in tl.dot
):
    """
    Split-K forward kernel for sparse submanifold convolution (implicit GEMM).

    Grid axis 0 enumerates (N-tile, Co-tile) pairs and grid axis 1 enumerates
    the K splits. Each program reduces its share of the ``V * cdiv(Ci, BK)``
    reduction steps into a float32 tile and stores that partial tile into
    slice ``program_id(1)`` of ``output``.

    Args:
        input (pointer): A pointer to the input tensor of shape (N, Ci)
        weight (pointer): A pointer to the weight tensor of shape (Co, V, Ci)
        bias (pointer): A pointer to the bias tensor of shape (Co), or None
        neighbor (pointer): A pointer to the neighbor tensor of shape (N, V);
            entries equal to -1 are masked out and contribute zeros
        output (pointer): A pointer to the partial-result buffer, written at
            offset ``split * N * Co + n * Co + co``
    """
    pid = tl.program_id(axis=0)
    pid_k = tl.program_id(axis=1)  # index of this program's K split

    # Decompose the flat axis-0 id into (N tile, Co tile).
    num_co_tiles = tl.cdiv(Co, B2)
    pid_co = pid % num_co_tiles
    pid_n = pid // num_co_tiles

    # A reduction step is one (kernel offset v, Ci chunk) pair; this program
    # handles the steps in [step_begin, step_end).
    num_k = tl.cdiv(Ci, BK)
    step_begin = tl.cdiv(num_k * V * pid_k, SPLITK)
    step_end = tl.cdiv(num_k * V * (pid_k + 1), SPLITK)

    # Row/column indices are wrapped so that loads stay in range; the store
    # at the end uses unwrapped indices together with a bounds mask.
    rows = (pid_n * B1 + tl.arange(0, B1)) % N  # (B1,)
    cols = (pid_co * B2 + tl.arange(0, B2)) % Co  # (B2,)
    ks = tl.arange(0, BK)  # (BK,)

    acc = tl.zeros((B1, B2), dtype=tl.float32)

    # Position the weight pointer at the first step owned by this split.
    first_v = step_begin // num_k
    first_bk = step_begin % num_k
    w_base = first_v * Ci + first_bk * BK
    w_ptr = weight + w_base + (cols[None, :] * V * Ci + ks[:, None])  # (BK, B2)

    for step in range(step_begin, step_end):
        v = step // num_k
        bk = step % num_k

        # Gather the input rows referenced by kernel offset v.
        nbr = tl.load(neighbor + rows * V + v).to(tl.int64)  # (B1,)
        in_ptr = input + bk * BK + (nbr[:, None] * Ci + ks[None, :])  # (B1, BK)

        has_nbr = nbr != -1
        in_ci = ks < Ci - bk * BK
        a = tl.load(in_ptr, mask=has_nbr[:, None] & in_ci[None, :], other=0.0)
        b = tl.load(w_ptr, mask=in_ci[:, None], other=0.0)

        acc = tl.dot(a, b, acc,
                     input_precision='tf32' if allow_tf32 else 'ieee')  # (B1, B2)

        # Advance by the number of valid channels in this chunk so that the
        # pointer lands on the start of the next chunk (or next kernel offset).
        w_ptr += min(BK, Ci - bk * BK)

    # Only split 0 adds the bias, so it appears in exactly one partial slice.
    if bias is not None and pid_k == 0:
        bias_vals = tl.load(bias + cols)
        acc += bias_vals[None, :]

    # Masked store of the partial tile into this split's slice.
    store_rows = pid_n * B1 + tl.arange(0, B1)
    store_cols = pid_co * B2 + tl.arange(0, B2)
    dst = output + pid_k * N * Co + (store_rows[:, None] * Co + store_cols[None, :])
    in_bounds = (store_rows[:, None] < N) & (store_cols[None, :] < Co)
    tl.store(dst, acc, mask=in_bounds)
92
+
93
+
94
def conv_fwd_implicit_gemm_splitk_configs(input, weight, bias, neighbor):
    """
    Build the candidate ``SPLITK`` values for autotuning.

    The candidates are powers of two. The exponent range is derived from the
    ratio between the device's multiprocessor count (and 32 times that count)
    and the number of 128 x 128 output tiles covering an (N, Co) result.

    Returns:
        A list of dicts of the form ``{'SPLITK': 2 ** i}``.
    """
    tile = 128
    n_rows, n_cols = neighbor.shape[0], weight.shape[0]
    row_tiles = (n_rows + tile - 1) // tile
    col_tiles = (n_cols + tile - 1) // tile
    tiles = row_tiles * col_tiles

    num_sm = get_num_sm()
    # int() truncates toward zero, so ratios below 1 are clamped by max().
    low_exp = max(0, int(math.log2(num_sm / tiles)))
    high_exp = max(1, int(math.log2(32 * num_sm / tiles) + 1))
    return [{'SPLITK': 2 ** exp} for exp in range(low_exp, high_exp)]
107
+
108
+
109
def conv_fwd_implicit_gemm_splitk_keys(input, weight, bias, neighbor):
    """
    Build the autotune cache key for a problem size.

    The key contains floor(log2(N)) instead of N itself, together with the
    exact Ci, Co and V, so row counts sharing the same power-of-two bucket
    produce the same key.
    """
    log_n = int(math.log2(neighbor.shape[0]))
    ci = input.shape[1]
    co, v = weight.shape[0], weight.shape[1]
    return f'(2^{log_n}, {ci}, {co}, {v})'
112
+
113
+
114
@autotune(
    config_fn=conv_fwd_implicit_gemm_splitk_configs,
    key_fn=conv_fwd_implicit_gemm_splitk_keys,
)
def conv_fwd_implicit_gemm_splitk(
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    neighbor: torch.Tensor,
    SPLITK: int = 1,
) -> torch.Tensor:
    """
    Launch the implicit-GEMM forward convolution, optionally with split K.

    Args:
        input: Contiguous tensor whose second dimension is Ci.
        weight: Contiguous tensor indexed as (Co, V, Ci).
        bias: Bias tensor forwarded to the kernel.
        neighbor: Contiguous neighbor-index tensor with N rows.
        SPLITK: Number of splits of the reduction dimension. With 1, the
            non-split kernel writes the result directly; otherwise float32
            partial results are summed over the split axis.

    Returns:
        A tensor of shape (N, Co) with the dtype of ``input``.
    """
    assert input.shape[1] == weight.shape[2], "Incompatible dimensions"
    assert input.is_contiguous(), "Matrix input must be contiguous"
    assert weight.is_contiguous(), "Matrix weight must be contiguous"
    assert neighbor.is_contiguous(), "Matrix neighbor must be contiguous"

    N, Co = neighbor.shape[0], weight.shape[0]
    Ci, V = input.shape[1], weight.shape[1]
    LOGN = int(math.log2(N))

    def num_tiles(META):
        # Number of (N tile, Co tile) pairs for the tuned tile sizes.
        return triton.cdiv(Co, META['B2']) * triton.cdiv(N, META['B1'])

    if SPLITK != 1:
        # One float32 slice per split; reduced on the host side below.
        partial = torch.empty((SPLITK, N, Co), device=input.device, dtype=torch.float32)
        conv_fwd_implicit_gemm_splitk_kernel[lambda META: (num_tiles(META), SPLITK)](
            input, weight, bias, neighbor, partial,
            N, LOGN, Ci, Co, V,
            SPLITK=SPLITK,
            allow_tf32=config.allow_tf32,
        )
        return partial.sum(dim=0).to(input.dtype)

    output = torch.empty((N, Co), device=input.device, dtype=input.dtype)
    conv_fwd_implicit_gemm_kernel[lambda META: (num_tiles(META),)](
        input, weight, bias, neighbor, output,
        N, LOGN, Ci, Co, V,
        allow_tf32=config.allow_tf32,
    )
    return output
@@ -0,0 +1,44 @@
1
+ from typing import *
2
+ import torch
3
+ import triton
4
+
5
+
6
def get_gpu_name():
    """Return the name reported by ``torch.cuda.get_device_name()``."""
    name = torch.cuda.get_device_name()
    return name
8
+
9
+
10
def get_platform_name():
    """
    Identify the accelerator platform.

    Returns:
        ``'unknown'`` when ``torch.cuda.is_available()`` is false, ``'hip'``
        when ``torch.version.hip`` is set, and ``'cuda'`` otherwise.
    """
    if not torch.cuda.is_available():
        return 'unknown'
    is_hip = getattr(torch.version, 'hip', None) is not None
    return 'hip' if is_hip else 'cuda'
16
+
17
+
18
def get_num_sm():
    """Return ``multi_processor_count`` from the properties of the "cuda" device."""
    properties = torch.cuda.get_device_properties("cuda")
    return properties.multi_processor_count
20
+
21
+
22
def get_autotune_config(
    default: List[triton.Config] = None,
    platform: Dict[str, List[triton.Config]] = None,
    device: Dict[str, List[triton.Config]] = None,
) -> List[triton.Config]:
    """
    Select the autotune configuration for the current device and platform.

    Lookup order: ``device`` entries first, then ``platform`` entries, then
    ``default``. A table entry matches when its key is a case-insensitive
    substring of the GPU name (for ``device``) or of the platform name (for
    ``platform``); the first matching entry in iteration order wins.

    Raises:
        ValueError: If nothing matched and ``default`` is None.
    """
    def first_match(table, name):
        # Returns (matched, value) so that a stored None value is still a hit.
        lowered = name.lower()
        for pattern, configs in table.items():
            if pattern.lower() in lowered:
                return True, configs
        return False, None

    if device is not None:
        matched, configs = first_match(device, get_gpu_name())
        if matched:
            return configs

    if platform is not None:
        matched, configs = first_match(platform, get_platform_name())
        if matched:
            return configs

    if default is not None:
        return default
    raise ValueError("No autotune configuration found for the current platform and device.")
ocnn/nn/octree2col.py CHANGED
@@ -1,53 +1,53 @@
1
- # --------------------------------------------------------
2
- # Octree-based Sparse Convolutional Neural Networks
3
- # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
- # Licensed under The MIT License [see LICENSE for details]
5
- # Written by Peng-Shuai Wang
6
- # --------------------------------------------------------
7
-
8
- import torch
9
- import torch.nn
10
-
11
- from ocnn.octree import Octree
12
- from ocnn.utils import scatter_add
13
-
14
-
15
def octree2col(data: torch.Tensor, octree: Octree, depth: int,
               kernel_size: str = '333', stride: int = 1, nempty: bool = False):
    r''' Collects, for every octree node, the features of its neighbors.

    Args:
        data (torch.Tensor): The input data.
        octree (Octree): The corresponding octree.
        depth (int): The depth of current octree.
        kernel_size (str): The kernel shape, choose from :obj:`333`, :obj:`311`,
            :obj:`131`, :obj:`113`, :obj:`222`, :obj:`331`, :obj:`133`, and
            :obj:`313`.
        stride (int): The stride of neighborhoods (:obj:`1` or :obj:`2`). If the
            stride is :obj:`2`, it always returns the neighborhood of the first
            siblings, and the number of elements of output tensor is
            :obj:`octree.nnum[depth] / 8`.
        nempty (bool): If True, only returns the neighborhoods of the non-empty
            octree nodes.

    Returns:
        A tensor of shape (N, K, C); slots whose neighbor index is negative
        are left as zeros.
    '''

    neigh = octree.get_neigh(depth, kernel_size, stride, nempty)
    num_nodes, num_neigh = neigh.shape[0], neigh.shape[1]
    has_neigh = neigh >= 0

    col = data.new_zeros((num_nodes, num_neigh, data.shape[1]))
    col[has_neigh] = data[neigh[has_neigh]]  # (N, K, C)
    return col
40
-
41
-
42
def col2octree(data: torch.Tensor, octree: Octree, depth: int,
               kernel_size: str = '333', stride: int = 1, nempty: bool = False):
    r''' Accumulates per-neighbor features back onto the octree nodes.

    Entries of :attr:`data` whose neighbor index is non-negative are summed
    into the row given by that index; the remaining entries are dropped.

    Please refer to :func:`octree2col` for the usage of function parameters.
    '''

    neigh = octree.get_neigh(depth, kernel_size, stride, nempty)
    has_neigh = neigh >= 0
    if nempty:
        num_out = octree.nnum_nempty[depth]
    else:
        num_out = octree.nnum[depth]
    return scatter_add(data[has_neigh], neigh[has_neigh], dim=0, dim_size=num_out)
1
+ # --------------------------------------------------------
2
+ # Octree-based Sparse Convolutional Neural Networks
3
+ # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Peng-Shuai Wang
6
+ # --------------------------------------------------------
7
+
8
+ import torch
9
+ import torch.nn
10
+
11
+ from ocnn.octree import Octree
12
+ from ocnn.utils import scatter_add
13
+
14
+
15
def octree2col(data: torch.Tensor, octree: Octree, depth: int,
               kernel_size: str = '333', stride: int = 1, nempty: bool = False):
    r''' Collects, for every octree node, the features of its neighbors.

    Args:
        data (torch.Tensor): The input data.
        octree (Octree): The corresponding octree.
        depth (int): The depth of current octree.
        kernel_size (str): The kernel shape, choose from :obj:`333`, :obj:`311`,
            :obj:`131`, :obj:`113`, :obj:`222`, :obj:`331`, :obj:`133`, and
            :obj:`313`.
        stride (int): The stride of neighborhoods (:obj:`1` or :obj:`2`). If the
            stride is :obj:`2`, it always returns the neighborhood of the first
            siblings, and the number of elements of output tensor is
            :obj:`octree.nnum[depth] / 8`.
        nempty (bool): If True, only returns the neighborhoods of the non-empty
            octree nodes.

    Returns:
        A tensor of shape (N, K, C); slots whose neighbor index is negative
        are left as zeros.
    '''

    neigh = octree.get_neigh(depth, kernel_size, stride, nempty)
    num_nodes, num_neigh = neigh.shape[0], neigh.shape[1]
    has_neigh = neigh >= 0

    col = data.new_zeros((num_nodes, num_neigh, data.shape[1]))
    col[has_neigh] = data[neigh[has_neigh]]  # (N, K, C)
    return col
40
+
41
+
42
def col2octree(data: torch.Tensor, octree: Octree, depth: int,
               kernel_size: str = '333', stride: int = 1, nempty: bool = False):
    r''' Accumulates per-neighbor features back onto the octree nodes.

    Entries of :attr:`data` whose neighbor index is non-negative are summed
    into the row given by that index; the remaining entries are dropped.

    Please refer to :func:`octree2col` for the usage of function parameters.
    '''

    neigh = octree.get_neigh(depth, kernel_size, stride, nempty)
    has_neigh = neigh >= 0
    if nempty:
        num_out = octree.nnum_nempty[depth]
    else:
        num_out = octree.nnum[depth]
    return scatter_add(data[has_neigh], neigh[has_neigh], dim=0, dim_size=num_out)
ocnn/nn/octree2vox.py CHANGED
@@ -1,50 +1,50 @@
1
- # --------------------------------------------------------
2
- # Octree-based Sparse Convolutional Neural Networks
3
- # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
- # Licensed under The MIT License [see LICENSE for details]
5
- # Written by Peng-Shuai Wang
6
- # --------------------------------------------------------
7
-
8
- import torch
9
-
10
- from ocnn.octree import Octree
11
-
12
-
13
def octree2voxel(data: torch.Tensor, octree: Octree, depth: int,
                 nempty: bool = False):
    r''' Scatters octree node features into a dense voxel grid.

    Args:
        data (torch.Tensor): The input feature.
        octree (Octree): The corresponding octree.
        depth (int): The depth of current octree.
        nempty (bool): If True, :attr:`data` only contains the features of non-empty
            octree nodes.

    Returns:
        A zero-initialized tensor of shape
        (batch_size, 2^depth, 2^depth, 2^depth, C) with :attr:`data` written
        at the (b, x, y, z) position of each node.
    '''

    x, y, z, b = octree.xyzb(depth, nempty)

    resolution = 1 << depth
    grid_shape = [octree.batch_size, resolution, resolution, resolution,
                  data.shape[1]]
    voxels = data.new_zeros(grid_shape)
    voxels[b, x, y, z] = data
    return voxels
32
-
33
-
34
class Octree2Voxel(torch.nn.Module):
    r''' Module wrapper that converts features to the full-voxel representation.

    Please refer to :func:`octree2voxel` for details.
    '''

    def __init__(self, nempty: bool = False):
        super().__init__()
        # Forwarded unchanged to octree2voxel on every call.
        self.nempty = nempty

    def forward(self, data: torch.Tensor, octree: Octree, depth: int):
        r''' Calls :func:`octree2voxel` with the stored :attr:`nempty` flag.
        '''

        voxels = octree2voxel(data, octree, depth, self.nempty)
        return voxels

    def extra_repr(self) -> str:
        return f'nempty={self.nempty}'
1
+ # --------------------------------------------------------
2
+ # Octree-based Sparse Convolutional Neural Networks
3
+ # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Peng-Shuai Wang
6
+ # --------------------------------------------------------
7
+
8
+ import torch
9
+
10
+ from ocnn.octree import Octree
11
+
12
+
13
def octree2voxel(data: torch.Tensor, octree: Octree, depth: int,
                 nempty: bool = False):
    r''' Scatters octree node features into a dense voxel grid.

    Args:
        data (torch.Tensor): The input feature.
        octree (Octree): The corresponding octree.
        depth (int): The depth of current octree.
        nempty (bool): If True, :attr:`data` only contains the features of non-empty
            octree nodes.

    Returns:
        A zero-initialized tensor of shape
        (batch_size, 2^depth, 2^depth, 2^depth, C) with :attr:`data` written
        at the (b, x, y, z) position of each node.
    '''

    x, y, z, b = octree.xyzb(depth, nempty)

    resolution = 1 << depth
    grid_shape = [octree.batch_size, resolution, resolution, resolution,
                  data.shape[1]]
    voxels = data.new_zeros(grid_shape)
    voxels[b, x, y, z] = data
    return voxels
32
+
33
+
34
class Octree2Voxel(torch.nn.Module):
    r''' Module wrapper that converts features to the full-voxel representation.

    Please refer to :func:`octree2voxel` for details.
    '''

    def __init__(self, nempty: bool = False):
        super().__init__()
        # Forwarded unchanged to octree2voxel on every call.
        self.nempty = nempty

    def forward(self, data: torch.Tensor, octree: Octree, depth: int):
        r''' Calls :func:`octree2voxel` with the stored :attr:`nempty` flag.
        '''

        voxels = octree2voxel(data, octree, depth, self.nempty)
        return voxels

    def extra_repr(self) -> str:
        return f'nempty={self.nempty}'
ocnn/nn/octree_align.py CHANGED
@@ -1,46 +1,46 @@
1
- # --------------------------------------------------------
2
- # Octree-based Sparse Convolutional Neural Networks
3
- # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
- # Licensed under The MIT License [see LICENSE for details]
5
- # Written by Peng-Shuai Wang
6
- # --------------------------------------------------------
7
-
8
- import torch
9
-
10
- from ocnn.octree import Octree
11
-
12
-
13
def search_value(value: torch.Tensor, key: torch.Tensor, query: torch.Tensor):
    r''' Searches values according to sorted shuffled keys.

    Args:
        value (torch.Tensor): The input tensor with shape (N, C).
        key (torch.Tensor): The key tensor corresponds to :attr:`value` with shape
            (N,), which contains sorted shuffled keys of an octree.
        query (torch.Tensor): The query tensor, which also contains shuffled keys.

    Returns:
        A tensor of shape (Q, C) with the dtype and device of :attr:`value`,
        where Q is the number of queries. Rows whose query is not present in
        :attr:`key` are zeros.
    '''

    # Allocate with the dtype of `value`; a default-dtype buffer cannot take a
    # masked assignment from a tensor of a different dtype.
    out = torch.zeros(query.shape[0], value.shape[1],
                      dtype=value.dtype, device=value.device)

    # With no keys nothing can be found, and `key[-1]` below would fail.
    if key.shape[0] == 0:
        return out

    # deal with out-of-bound queries, the indices of these queries
    # returned by torch.searchsorted equal to `key.shape[0]`
    out_of_bound = query > key[-1]

    # search
    idx = torch.searchsorted(key, query)
    idx[out_of_bound] = -1  # keep `key[idx]` in range; these never compare equal
    found = key[idx] == query

    # assign the found value to the output
    out[found] = value[idx[found]]
    return out
36
-
37
-
38
def octree_align(value: torch.Tensor, octree: Octree, octree_query: Octree,
                 depth: int, nempty: bool = False):
    r''' Wraps :func:`search_value` to take octrees as input for convenience.

    Args:
        value (torch.Tensor): The features attached to :attr:`octree`; its
            first dimension must equal the number of keys of :attr:`octree`
            at :attr:`depth`.
        octree (Octree): The octree whose keys correspond to :attr:`value`.
        octree_query (Octree): The octree whose keys are used as queries.
        depth (int): The depth at which the keys are taken.
        nempty (bool): Forwarded to :meth:`Octree.key` for both octrees.
    '''

    key = octree.key(depth, nempty)
    query = octree_query.key(depth, nempty)
    assert key.shape[0] == value.shape[0], \
        'The number of keys must equal the number of rows of `value`'
    return search_value(value, key, query)
1
+ # --------------------------------------------------------
2
+ # Octree-based Sparse Convolutional Neural Networks
3
+ # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Peng-Shuai Wang
6
+ # --------------------------------------------------------
7
+
8
+ import torch
9
+
10
+ from ocnn.octree import Octree
11
+
12
+
13
def search_value(value: torch.Tensor, key: torch.Tensor, query: torch.Tensor):
    r''' Searches values according to sorted shuffled keys.

    Args:
        value (torch.Tensor): The input tensor with shape (N, C).
        key (torch.Tensor): The key tensor corresponds to :attr:`value` with shape
            (N,), which contains sorted shuffled keys of an octree.
        query (torch.Tensor): The query tensor, which also contains shuffled keys.

    Returns:
        A tensor of shape (Q, C) with the dtype and device of :attr:`value`,
        where Q is the number of queries. Rows whose query is not present in
        :attr:`key` are zeros.
    '''

    # Allocate with the dtype of `value`; a default-dtype buffer cannot take a
    # masked assignment from a tensor of a different dtype.
    out = torch.zeros(query.shape[0], value.shape[1],
                      dtype=value.dtype, device=value.device)

    # With no keys nothing can be found, and `key[-1]` below would fail.
    if key.shape[0] == 0:
        return out

    # deal with out-of-bound queries, the indices of these queries
    # returned by torch.searchsorted equal to `key.shape[0]`
    out_of_bound = query > key[-1]

    # search
    idx = torch.searchsorted(key, query)
    idx[out_of_bound] = -1  # keep `key[idx]` in range; these never compare equal
    found = key[idx] == query

    # assign the found value to the output
    out[found] = value[idx[found]]
    return out
36
+
37
+
38
def octree_align(value: torch.Tensor, octree: Octree, octree_query: Octree,
                 depth: int, nempty: bool = False):
    r''' Wraps :func:`search_value` to take octrees as input for convenience.

    Args:
        value (torch.Tensor): The features attached to :attr:`octree`; its
            first dimension must equal the number of keys of :attr:`octree`
            at :attr:`depth`.
        octree (Octree): The octree whose keys correspond to :attr:`value`.
        octree_query (Octree): The octree whose keys are used as queries.
        depth (int): The depth at which the keys are taken.
        nempty (bool): Forwarded to :meth:`Octree.key` for both octrees.
    '''

    key = octree.key(depth, nempty)
    query = octree_query.key(depth, nempty)
    assert key.shape[0] == value.shape[0], \
        'The number of keys must equal the number of rows of `value`'
    return search_value(value, key, query)