ocnn 2.2.8__py3-none-any.whl → 2.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ocnn/__init__.py +24 -24
- ocnn/dataset.py +160 -160
- ocnn/models/__init__.py +29 -29
- ocnn/models/autoencoder.py +155 -155
- ocnn/models/hrnet.py +192 -192
- ocnn/models/image2shape.py +128 -128
- ocnn/models/lenet.py +46 -46
- ocnn/models/ounet.py +94 -94
- ocnn/models/resnet.py +53 -53
- ocnn/models/segnet.py +72 -72
- ocnn/models/unet.py +105 -105
- ocnn/modules/__init__.py +26 -26
- ocnn/modules/modules.py +303 -303
- ocnn/modules/resblocks.py +158 -158
- ocnn/nn/__init__.py +45 -44
- ocnn/nn/kernels/__init__.py +14 -0
- ocnn/nn/kernels/autotuner.py +416 -0
- ocnn/nn/kernels/config.py +67 -0
- ocnn/nn/kernels/conv_bwd_implicit_gemm.py +229 -0
- ocnn/nn/kernels/conv_bwd_implicit_gemm_splitk.py +347 -0
- ocnn/nn/kernels/conv_fwd_implicit_gemm.py +109 -0
- ocnn/nn/kernels/conv_fwd_implicit_gemm_splitk.py +150 -0
- ocnn/nn/kernels/utils.py +44 -0
- ocnn/nn/octree2col.py +53 -53
- ocnn/nn/octree2vox.py +50 -50
- ocnn/nn/octree_align.py +46 -46
- ocnn/nn/octree_conv.py +430 -429
- ocnn/nn/octree_conv_t.py +148 -0
- ocnn/nn/octree_drop.py +55 -55
- ocnn/nn/octree_dwconv.py +222 -222
- ocnn/nn/octree_gconv.py +79 -79
- ocnn/nn/octree_interp.py +196 -196
- ocnn/nn/octree_norm.py +126 -126
- ocnn/nn/octree_pad.py +39 -39
- ocnn/nn/octree_pool.py +200 -200
- ocnn/octree/__init__.py +22 -22
- ocnn/octree/octree.py +770 -770
- ocnn/octree/points.py +384 -323
- ocnn/octree/shuffled_key.py +115 -115
- ocnn/utils.py +205 -205
- {ocnn-2.2.8.dist-info → ocnn-2.3.0.dist-info}/METADATA +117 -111
- ocnn-2.3.0.dist-info/RECORD +45 -0
- {ocnn-2.2.8.dist-info → ocnn-2.3.0.dist-info}/WHEEL +1 -1
- {ocnn-2.2.8.dist-info → ocnn-2.3.0.dist-info}/licenses/LICENSE +21 -21
- ocnn-2.2.8.dist-info/RECORD +0 -36
- {ocnn-2.2.8.dist-info → ocnn-2.3.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,150 @@
|
|
|
1
|
+
import math
|
|
2
|
+
import torch
|
|
3
|
+
import triton
|
|
4
|
+
import triton.language as tl
|
|
5
|
+
from .utils import get_num_sm
|
|
6
|
+
from .autotuner import triton_autotune, autotune
|
|
7
|
+
from . import config
|
|
8
|
+
from .conv_fwd_implicit_gemm import conv_fwd_implicit_gemm_kernel
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@triton_autotune(
    configs=config.autotune_config,
    key=['LOGN', 'Ci', 'Co', 'V', 'SPLITK', 'allow_tf32'],
)
@triton.jit
def conv_fwd_implicit_gemm_splitk_kernel(
    input,
    weight,
    bias,
    neighbor,
    output,
    # Tensor dimensions
    N, LOGN, Ci, Co, V: tl.constexpr,
    # Meta-parameters
    B1: tl.constexpr,  # Block size for N dimension
    B2: tl.constexpr,  # Block size for Co dimension
    BK: tl.constexpr,  # Block size for K dimension (V * Ci)
    SPLITK: tl.constexpr,  # Split K dimension
    allow_tf32: tl.constexpr,  # Allow TF32 precision for matmuls
):
  """
  Sparse submanifold convolution forward kernel using implicit GEMM with
  split K dimension.

  Each program computes one (B1, B2) tile of a partial output for its slice
  of the K = V * Ci reduction; the caller sums the SPLITK partial outputs
  along dim 0 afterwards.

  Args:
    input (pointer): A pointer to the input tensor of shape (N, Ci)
    weight (pointer): A pointer to the weight tensor of shape (Co, V, Ci)
    bias (pointer): A pointer to the bias tensor of shape (Co)
    neighbor (pointer): A pointer to the neighbor tensor of shape (N, V);
        an entry of -1 marks a missing neighbor
    output (pointer): A pointer to the partial-output tensor of shape
        (SPLITK, N, Co); slice ``block_id_k`` holds this program's result
  """
  block_id_k = tl.program_id(axis=1)  # SplitK dimension
  block_id = tl.program_id(axis=0)
  block_dim_co = tl.cdiv(Co, B2)
  block_id_co = block_id % block_dim_co
  block_id_n = block_id // block_dim_co

  # Create pointers for submatrices of A and B.
  # The K dimension is traversed in (v, bk) order: v indexes the kernel
  # volume, bk the Ci-sub-block; this program handles [k_start, k_end).
  num_k = tl.cdiv(Ci, BK)  # Number of blocks in K dimension
  k_start = tl.cdiv(num_k * V * block_id_k, SPLITK)
  k_end = tl.cdiv(num_k * V * (block_id_k + 1), SPLITK)
  # The `% N` / `% Co` wrap keeps loads in-bounds for the last, partial tile;
  # the final store masks out the wrapped rows/columns.
  offset_n = (block_id_n * B1 + tl.arange(0, B1)) % N  # (B1,)
  offset_co = (block_id_co * B2 + tl.arange(0, B2)) % Co  # (B2,)
  offset_k = tl.arange(0, BK)  # (BK,)

  # Create a block of the output matrix C.
  accumulator = tl.zeros((B1, B2), dtype=tl.float32)
  curr_v = k_start // num_k
  curr_bk = k_start % num_k
  weight_offset_base = curr_v * Ci + curr_bk * BK

  weight_ptr = weight + weight_offset_base + (offset_co[None, :] * V * Ci + offset_k[:, None])  # (BK, B2)

  # Iterate along V*Ci dimension.
  for k in range(k_start, k_end):
    v = k // num_k
    bk = k % num_k
    # Calculate pointers to input matrix: rows are gathered through the
    # neighbor table of kernel-volume entry v.
    neighbor_offset_n = tl.load(neighbor + offset_n * V + v).to(tl.int64)  # (B1,)
    input_ptr = input + bk * BK + (neighbor_offset_n[:, None].to(tl.int64) * Ci + offset_k[None, :])  # (B1, BK)
    # Load the next block of input and weight; missing neighbors (-1) and
    # the Ci tail beyond the last full BK block are masked to zero.
    neigh_mask = neighbor_offset_n != -1
    k_mask = offset_k < Ci - bk * BK
    input_block = tl.load(input_ptr, mask=neigh_mask[:, None] & k_mask[None, :], other=0.0)
    weight_block = tl.load(weight_ptr, mask=k_mask[:, None], other=0.0)
    # Accumulate along the K dimension.
    accumulator = tl.dot(input_block, weight_block, accumulator,
                         input_precision='tf32' if allow_tf32 else 'ieee')  # (B1, B2)
    # Advance the pointers to the next Ci block. Advancing by the (possibly
    # truncated) block width also lands on the next v's first block when bk
    # wraps around.
    weight_ptr += min(BK, Ci - bk * BK)

  # Add bias only in the first K-slice so it is counted once after the
  # caller sums the partial outputs.
  if bias is not None and block_id_k == 0:
    bias_block = tl.load(bias + offset_co)
    accumulator += bias_block[None, :]

  # Write back the block of the output matrix with masks.
  out_offset_n = block_id_n * B1 + tl.arange(0, B1)
  out_offset_co = block_id_co * B2 + tl.arange(0, B2)
  out_ptr = output + block_id_k * N * Co + (out_offset_n[:, None] * Co + out_offset_co[None, :])
  out_mask = (out_offset_n[:, None] < N) & (out_offset_co[None, :] < Co)
  tl.store(out_ptr, accumulator, mask=out_mask)
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def conv_fwd_implicit_gemm_splitk_configs(input, weight, bias, neighbor):
  """Enumerates candidate ``SPLITK`` values (powers of two) for autotuning.

  The candidate range is derived from the number of (128 x 128) output tiles
  the problem produces relative to the SM count of the current device, so
  that the launched grid covers between roughly 1x and 32x the SM count.
  """
  num_rows, num_cols = neighbor.shape[0], weight.shape[0]
  tiles = ((num_rows + 127) // 128) * ((num_cols + 127) // 128)
  num_sm = get_num_sm()
  lo = max(0, int(math.log2(num_sm / tiles)))
  hi = max(1, int(math.log2(32 * num_sm / tiles) + 1))
  return [{'SPLITK': 2 ** e} for e in range(lo, hi)]
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def conv_fwd_implicit_gemm_splitk_keys(input, weight, bias, neighbor):
  """Builds a human-readable cache key from the problem dimensions.

  The row count is reported as a power of two so that problems of similar
  magnitude share one autotuning entry.
  """
  num_rows = neighbor.shape[0]
  in_channels = input.shape[1]
  out_channels, volume = weight.shape[0], weight.shape[1]
  logn = int(math.log2(num_rows))
  return '(2^{}, {}, {}, {})'.format(logn, in_channels, out_channels, volume)
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
@autotune(
    config_fn=conv_fwd_implicit_gemm_splitk_configs,
    key_fn=conv_fwd_implicit_gemm_splitk_keys,
)
def conv_fwd_implicit_gemm_splitk(
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    neighbor: torch.Tensor,
    SPLITK: int = 1,
) -> torch.Tensor:
  """
  Sparse submanifold convolution forward pass via implicit GEMM.

  Dispatches to the plain kernel when ``SPLITK == 1`` and to the split-K
  kernel otherwise; the ``SPLITK`` value is normally chosen by the
  ``autotune`` decorator from ``conv_fwd_implicit_gemm_splitk_configs``.

  Args:
    input (torch.Tensor): Input features of shape (N, Ci); must be contiguous.
    weight (torch.Tensor): Weights of shape (Co, V, Ci); must be contiguous.
    bias (torch.Tensor): Bias of shape (Co), or None.
    neighbor (torch.Tensor): Neighbor indices of shape (N, V); must be
        contiguous, with -1 marking missing neighbors.
    SPLITK (int): Number of partial sums along the K (= V * Ci) dimension.

  Returns:
    torch.Tensor: Output features of shape (N, Co) in ``input``'s dtype.
  """
  assert input.shape[1] == weight.shape[2], "Incompatible dimensions"
  assert input.is_contiguous(), "Matrix input must be contiguous"
  assert weight.is_contiguous(), "Matrix weight must be contiguous"
  assert neighbor.is_contiguous(), "Matrix neighbor must be contiguous"
  N, Ci, Co, V = neighbor.shape[0], input.shape[1], weight.shape[0], weight.shape[1]
  LOGN = int(math.log2(N))
  # Launch the kernel.
  if SPLITK == 1:
    output = torch.empty((N, Co), device=input.device, dtype=input.dtype)
    # 1-D grid over the (N / B1) x (Co / B2) output tiles.
    grid = lambda META: (triton.cdiv(Co, META['B2']) * triton.cdiv(N, META['B1']),)
    conv_fwd_implicit_gemm_kernel[grid](
        input, weight, bias, neighbor, output,
        N, LOGN, Ci, Co, V,
        allow_tf32=config.allow_tf32,
    )
    return output
  else:
    # Partial results are accumulated in float32 (one (N, Co) slice per
    # K-split, written by grid axis 1), then reduced over dim 0.
    output = torch.empty((SPLITK, N, Co), device=input.device, dtype=torch.float32)
    grid = lambda META: (triton.cdiv(Co, META['B2']) * triton.cdiv(N, META['B1']), SPLITK)
    conv_fwd_implicit_gemm_splitk_kernel[grid](
        input, weight, bias, neighbor, output,
        N, LOGN, Ci, Co, V,
        SPLITK=SPLITK,
        allow_tf32=config.allow_tf32,
    )
    return output.sum(dim=0).to(input.dtype)
|
ocnn/nn/kernels/utils.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
from typing import *
|
|
2
|
+
import torch
|
|
3
|
+
import triton
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def get_gpu_name():
  """Returns the device name string of the current CUDA device."""
  name = torch.cuda.get_device_name()
  return name
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def get_platform_name():
  """Returns the GPU runtime platform: ``'hip'``, ``'cuda'``, or ``'unknown'``.

  ``'hip'`` is reported for ROCm builds of PyTorch (``torch.version.hip`` is
  set); ``'unknown'`` when no GPU runtime is available.
  """
  if not torch.cuda.is_available():
    return 'unknown'
  hip_version = getattr(torch.version, 'hip', None)
  return 'hip' if hip_version is not None else 'cuda'
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def get_num_sm():
  """Returns the multiprocessor (SM) count of the default CUDA device."""
  props = torch.cuda.get_device_properties("cuda")
  return props.multi_processor_count
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def get_autotune_config(
    default: List[triton.Config] = None,
    platform: Dict[str, List[triton.Config]] = None,
    device: Dict[str, List[triton.Config]] = None,
) -> List[triton.Config]:
  """
  Get the autotune configuration for the current platform and device.

  Lookup order: a case-insensitive substring match of the :attr:`device`
  keys against the GPU name, then of the :attr:`platform` keys against the
  platform name, then :attr:`default`. Raises ``ValueError`` when nothing
  matches and no default is given.
  """
  if device is not None:
    gpu = get_gpu_name().lower()
    for pattern, cfgs in device.items():
      if pattern.lower() in gpu:
        return cfgs

  if platform is not None:
    plat = get_platform_name().lower()
    for pattern, cfgs in platform.items():
      if pattern.lower() in plat:
        return cfgs

  if default is None:
    raise ValueError("No autotune configuration found for the current platform and device.")
  return default
|
ocnn/nn/octree2col.py
CHANGED
|
@@ -1,53 +1,53 @@
|
|
|
1
|
-
# --------------------------------------------------------
|
|
2
|
-
# Octree-based Sparse Convolutional Neural Networks
|
|
3
|
-
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
|
|
4
|
-
# Licensed under The MIT License [see LICENSE for details]
|
|
5
|
-
# Written by Peng-Shuai Wang
|
|
6
|
-
# --------------------------------------------------------
|
|
7
|
-
|
|
8
|
-
import torch
|
|
9
|
-
import torch.nn
|
|
10
|
-
|
|
11
|
-
from ocnn.octree import Octree
|
|
12
|
-
from ocnn.utils import scatter_add
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
def octree2col(data: torch.Tensor, octree: Octree, depth: int,
|
|
16
|
-
kernel_size: str = '333', stride: int = 1, nempty: bool = False):
|
|
17
|
-
r''' Gathers the neighboring features for convolutions.
|
|
18
|
-
|
|
19
|
-
Args:
|
|
20
|
-
data (torch.Tensor): The input data.
|
|
21
|
-
octree (Octree): The corresponding octree.
|
|
22
|
-
depth (int): The depth of current octree.
|
|
23
|
-
kernel_size (str): The kernel shape, choose from :obj:`333`, :obj:`311`,
|
|
24
|
-
:obj:`131`, :obj:`113`, :obj:`222`, :obj:`331`, :obj:`133`, and
|
|
25
|
-
:obj:`313`.
|
|
26
|
-
stride (int): The stride of neighborhoods (:obj:`1` or :obj:`2`). If the
|
|
27
|
-
stride is :obj:`2`, it always returns the neighborhood of the first
|
|
28
|
-
siblings, and the number of elements of output tensor is
|
|
29
|
-
:obj:`octree.nnum[depth] / 8`.
|
|
30
|
-
nempty (bool): If True, only returns the neighborhoods of the non-empty
|
|
31
|
-
octree nodes.
|
|
32
|
-
'''
|
|
33
|
-
|
|
34
|
-
neigh = octree.get_neigh(depth, kernel_size, stride, nempty)
|
|
35
|
-
size = (neigh.shape[0], neigh.shape[1], data.shape[1])
|
|
36
|
-
out = torch.zeros(size, dtype=data.dtype, device=data.device)
|
|
37
|
-
valid = neigh >= 0
|
|
38
|
-
out[valid] = data[neigh[valid]] # (N, K, C)
|
|
39
|
-
return out
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
def col2octree(data: torch.Tensor, octree: Octree, depth: int,
|
|
43
|
-
kernel_size: str = '333', stride: int = 1, nempty: bool = False):
|
|
44
|
-
r''' Scatters the convolution features to an octree.
|
|
45
|
-
|
|
46
|
-
Please refer to :func:`octree2col` for the usage of function parameters.
|
|
47
|
-
'''
|
|
48
|
-
|
|
49
|
-
neigh = octree.get_neigh(depth, kernel_size, stride, nempty)
|
|
50
|
-
valid = neigh >= 0
|
|
51
|
-
dim_size = octree.nnum_nempty[depth] if nempty else octree.nnum[depth]
|
|
52
|
-
out = scatter_add(data[valid], neigh[valid], dim=0, dim_size=dim_size)
|
|
53
|
-
return out
|
|
1
|
+
# --------------------------------------------------------
|
|
2
|
+
# Octree-based Sparse Convolutional Neural Networks
|
|
3
|
+
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
|
|
4
|
+
# Licensed under The MIT License [see LICENSE for details]
|
|
5
|
+
# Written by Peng-Shuai Wang
|
|
6
|
+
# --------------------------------------------------------
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
import torch.nn
|
|
10
|
+
|
|
11
|
+
from ocnn.octree import Octree
|
|
12
|
+
from ocnn.utils import scatter_add
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def octree2col(data: torch.Tensor, octree: Octree, depth: int,
               kernel_size: str = '333', stride: int = 1, nempty: bool = False):
  r''' Gathers neighboring octree-node features for convolutions.

  Args:
    data (torch.Tensor): The input feature of shape (num_nodes, channels).
    octree (Octree): The corresponding octree.
    depth (int): The depth of current octree.
    kernel_size (str): The kernel shape, choose from :obj:`333`, :obj:`311`,
        :obj:`131`, :obj:`113`, :obj:`222`, :obj:`331`, :obj:`133`, and
        :obj:`313`.
    stride (int): The stride of neighborhoods (:obj:`1` or :obj:`2`). If the
        stride is :obj:`2`, it always returns the neighborhood of the first
        siblings, and the number of elements of output tensor is
        :obj:`octree.nnum[depth] / 8`.
    nempty (bool): If True, only returns the neighborhoods of the non-empty
        octree nodes.
  '''

  neigh = octree.get_neigh(depth, kernel_size, stride, nempty)
  num, knum = neigh.shape[0], neigh.shape[1]
  channel = data.shape[1]

  # Missing neighbors are marked by negative indices; their rows stay zero.
  out = torch.zeros(num, knum, channel, dtype=data.dtype, device=data.device)
  mask = neigh >= 0
  out[mask] = data[neigh[mask]]  # (N, K, C)
  return out
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def col2octree(data: torch.Tensor, octree: Octree, depth: int,
               kernel_size: str = '333', stride: int = 1, nempty: bool = False):
  r''' Scatters convolution features back onto octree nodes, summing the
  contributions that land on the same node.

  Please refer to :func:`octree2col` for the usage of function parameters.
  '''

  neigh = octree.get_neigh(depth, kernel_size, stride, nempty)
  mask = neigh >= 0  # drop entries that correspond to missing neighbors
  node_num = octree.nnum_nempty[depth] if nempty else octree.nnum[depth]
  return scatter_add(data[mask], neigh[mask], dim=0, dim_size=node_num)
|
ocnn/nn/octree2vox.py
CHANGED
|
@@ -1,50 +1,50 @@
|
|
|
1
|
-
# --------------------------------------------------------
|
|
2
|
-
# Octree-based Sparse Convolutional Neural Networks
|
|
3
|
-
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
|
|
4
|
-
# Licensed under The MIT License [see LICENSE for details]
|
|
5
|
-
# Written by Peng-Shuai Wang
|
|
6
|
-
# --------------------------------------------------------
|
|
7
|
-
|
|
8
|
-
import torch
|
|
9
|
-
|
|
10
|
-
from ocnn.octree import Octree
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
def octree2voxel(data: torch.Tensor, octree: Octree, depth: int,
|
|
14
|
-
nempty: bool = False):
|
|
15
|
-
r''' Converts the input feature to the full-voxel-based representation.
|
|
16
|
-
|
|
17
|
-
Args:
|
|
18
|
-
data (torch.Tensor): The input feature.
|
|
19
|
-
octree (Octree): The corresponding octree.
|
|
20
|
-
depth (int): The depth of current octree.
|
|
21
|
-
nempty (bool): If True, :attr:`data` only contains the features of non-empty
|
|
22
|
-
octree nodes.
|
|
23
|
-
'''
|
|
24
|
-
|
|
25
|
-
x, y, z, b = octree.xyzb(depth, nempty)
|
|
26
|
-
|
|
27
|
-
num = 1 << depth
|
|
28
|
-
channel = data.shape[1]
|
|
29
|
-
vox = data.new_zeros([octree.batch_size, num, num, num, channel])
|
|
30
|
-
vox[b, x, y, z] = data
|
|
31
|
-
return vox
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
class Octree2Voxel(torch.nn.Module):
|
|
35
|
-
r''' Converts the input feature to the full-voxel-based representation.
|
|
36
|
-
|
|
37
|
-
Please refer to :func:`octree2voxel` for details.
|
|
38
|
-
'''
|
|
39
|
-
|
|
40
|
-
def __init__(self, nempty: bool = False):
|
|
41
|
-
super().__init__()
|
|
42
|
-
self.nempty = nempty
|
|
43
|
-
|
|
44
|
-
def forward(self, data: torch.Tensor, octree: Octree, depth: int):
|
|
45
|
-
r''''''
|
|
46
|
-
|
|
47
|
-
return octree2voxel(data, octree, depth, self.nempty)
|
|
48
|
-
|
|
49
|
-
def extra_repr(self) -> str:
|
|
50
|
-
return 'nempty={}'.format(self.nempty)
|
|
1
|
+
# --------------------------------------------------------
|
|
2
|
+
# Octree-based Sparse Convolutional Neural Networks
|
|
3
|
+
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
|
|
4
|
+
# Licensed under The MIT License [see LICENSE for details]
|
|
5
|
+
# Written by Peng-Shuai Wang
|
|
6
|
+
# --------------------------------------------------------
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
|
|
10
|
+
from ocnn.octree import Octree
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def octree2voxel(data: torch.Tensor, octree: Octree, depth: int,
                 nempty: bool = False):
  r''' Converts the input feature to the full-voxel-based representation.

  Args:
    data (torch.Tensor): The input feature of shape (num_nodes, channels).
    octree (Octree): The corresponding octree.
    depth (int): The depth of current octree.
    nempty (bool): If True, :attr:`data` only contains the features of
        non-empty octree nodes.
  '''

  x, y, z, b = octree.xyzb(depth, nempty)

  # A dense (batch, 2^d, 2^d, 2^d, C) grid; untouched voxels stay zero.
  size = 2 ** depth
  channel = data.shape[1]
  vox = data.new_zeros([octree.batch_size, size, size, size, channel])
  vox[b, x, y, z] = data
  return vox
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class Octree2Voxel(torch.nn.Module):
  r''' Converts octree features to the full-voxel-based representation.

  A thin module wrapper around :func:`octree2voxel`; see that function for
  details.
  '''

  def __init__(self, nempty: bool = False):
    super().__init__()
    # Whether the incoming features cover only non-empty octree nodes.
    self.nempty = nempty

  def forward(self, data: torch.Tensor, octree: Octree, depth: int):
    r''''''

    voxel = octree2voxel(data, octree, depth, self.nempty)
    return voxel

  def extra_repr(self) -> str:
    return 'nempty={}'.format(self.nempty)
|
ocnn/nn/octree_align.py
CHANGED
|
@@ -1,46 +1,46 @@
|
|
|
1
|
-
# --------------------------------------------------------
|
|
2
|
-
# Octree-based Sparse Convolutional Neural Networks
|
|
3
|
-
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
|
|
4
|
-
# Licensed under The MIT License [see LICENSE for details]
|
|
5
|
-
# Written by Peng-Shuai Wang
|
|
6
|
-
# --------------------------------------------------------
|
|
7
|
-
|
|
8
|
-
import torch
|
|
9
|
-
|
|
10
|
-
from ocnn.octree import Octree
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
def search_value(value: torch.Tensor, key: torch.Tensor, query: torch.Tensor):
|
|
14
|
-
r''' Searches values according to sorted shuffled keys.
|
|
15
|
-
|
|
16
|
-
Args:
|
|
17
|
-
value (torch.Tensor): The input tensor with shape (N, C).
|
|
18
|
-
key (torch.Tensor): The key tensor corresponds to :attr:`value` with shape
|
|
19
|
-
(N,), which contains sorted shuffled keys of an octree.
|
|
20
|
-
query (torch.Tensor): The query tensor, which also contains shuffled keys.
|
|
21
|
-
'''
|
|
22
|
-
|
|
23
|
-
# deal with out-of-bound queries, the indices of these queries
|
|
24
|
-
# returned by torch.searchsorted equal to `key.shape[0]`
|
|
25
|
-
out_of_bound = query > key[-1]
|
|
26
|
-
|
|
27
|
-
# search
|
|
28
|
-
idx = torch.searchsorted(key, query)
|
|
29
|
-
idx[out_of_bound] = -1 # to avoid overflow when executing the following line
|
|
30
|
-
found = key[idx] == query
|
|
31
|
-
|
|
32
|
-
# assign the found value to the output
|
|
33
|
-
out = torch.zeros(query.shape[0], value.shape[1], device=value.device)
|
|
34
|
-
out[found] = value[idx[found]]
|
|
35
|
-
return out
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
def octree_align(value: torch.Tensor, octree: Octree, octree_query: Octree,
|
|
39
|
-
depth: int, nempty: bool = False):
|
|
40
|
-
r''' Wraps :func:`octree_align` to take octrees as input for convenience.
|
|
41
|
-
'''
|
|
42
|
-
|
|
43
|
-
key = octree.key(depth, nempty)
|
|
44
|
-
query = octree_query.key(depth, nempty)
|
|
45
|
-
assert key.shape[0] == value.shape[0]
|
|
46
|
-
return search_value(value, key, query)
|
|
1
|
+
# --------------------------------------------------------
|
|
2
|
+
# Octree-based Sparse Convolutional Neural Networks
|
|
3
|
+
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
|
|
4
|
+
# Licensed under The MIT License [see LICENSE for details]
|
|
5
|
+
# Written by Peng-Shuai Wang
|
|
6
|
+
# --------------------------------------------------------
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
|
|
10
|
+
from ocnn.octree import Octree
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def search_value(value: torch.Tensor, key: torch.Tensor, query: torch.Tensor):
  r''' Searches values according to sorted shuffled keys.

  Args:
    value (torch.Tensor): The input tensor with shape (N, C).
    key (torch.Tensor): The key tensor corresponds to :attr:`value` with shape
        (N,), which contains sorted shuffled keys of an octree.
    query (torch.Tensor): The query tensor, which also contains shuffled keys.

  Returns:
    A tensor of shape (M, C) with the same dtype and device as :attr:`value`:
    rows of :attr:`value` whose key equals the corresponding query, and zeros
    for queries that are not found.
  '''

  # deal with out-of-bound queries, the indices of these queries
  # returned by torch.searchsorted equal to `key.shape[0]`
  out_of_bound = query > key[-1]

  # search
  idx = torch.searchsorted(key, query)
  idx[out_of_bound] = -1  # to avoid overflow when executing the following line
  found = key[idx] == query

  # assign the found value to the output. Use `new_zeros` so that the output
  # keeps the dtype and device of `value` (a plain `torch.zeros` silently
  # downcast non-float32 inputs to float32).
  out = value.new_zeros(query.shape[0], value.shape[1])
  out[found] = value[idx[found]]
  return out
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def octree_align(value: torch.Tensor, octree: Octree, octree_query: Octree,
                 depth: int, nempty: bool = False):
  r''' Wraps :func:`search_value` to take octrees as input for convenience.
  '''

  src_key = octree.key(depth, nempty)
  dst_key = octree_query.key(depth, nempty)
  assert src_key.shape[0] == value.shape[0]
  return search_value(value, src_key, dst_key)
|