ocnn 2.2.7__py3-none-any.whl → 2.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ocnn/__init__.py +1 -1
- ocnn/models/resnet.py +2 -2
- ocnn/nn/__init__.py +2 -1
- ocnn/nn/kernels/__init__.py +14 -0
- ocnn/nn/kernels/autotuner.py +416 -0
- ocnn/nn/kernels/config.py +67 -0
- ocnn/nn/kernels/conv_bwd_implicit_gemm.py +229 -0
- ocnn/nn/kernels/conv_bwd_implicit_gemm_splitk.py +347 -0
- ocnn/nn/kernels/conv_fwd_implicit_gemm.py +109 -0
- ocnn/nn/kernels/conv_fwd_implicit_gemm_splitk.py +150 -0
- ocnn/nn/kernels/utils.py +44 -0
- ocnn/nn/octree_conv.py +2 -1
- ocnn/nn/octree_conv_t.py +148 -0
- ocnn/nn/octree_pad.py +4 -4
- ocnn/octree/octree.py +218 -109
- ocnn/octree/points.py +95 -34
- {ocnn-2.2.7.dist-info → ocnn-2.3.0.dist-info}/METADATA +11 -6
- {ocnn-2.2.7.dist-info → ocnn-2.3.0.dist-info}/RECORD +21 -12
- {ocnn-2.2.7.dist-info → ocnn-2.3.0.dist-info}/WHEEL +1 -1
- {ocnn-2.2.7.dist-info → ocnn-2.3.0.dist-info}/licenses/LICENSE +0 -0
- {ocnn-2.2.7.dist-info → ocnn-2.3.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,150 @@
|
|
|
1
|
+
import math
|
|
2
|
+
import torch
|
|
3
|
+
import triton
|
|
4
|
+
import triton.language as tl
|
|
5
|
+
from .utils import get_num_sm
|
|
6
|
+
from .autotuner import triton_autotune, autotune
|
|
7
|
+
from . import config
|
|
8
|
+
from .conv_fwd_implicit_gemm import conv_fwd_implicit_gemm_kernel
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@triton_autotune(
    configs=config.autotune_config,
    key=['LOGN', 'Ci', 'Co', 'V', 'SPLITK', 'allow_tf32'],
)
@triton.jit
def conv_fwd_implicit_gemm_splitk_kernel(
    input,
    weight,
    bias,
    neighbor,
    output,
    # Tensor dimensions
    N, LOGN, Ci, Co, V: tl.constexpr,
    # Meta-parameters
    B1: tl.constexpr,  # Block size for N dimension
    B2: tl.constexpr,  # Block size for Co dimension
    BK: tl.constexpr,  # Block size for K dimension (V * Ci)
    SPLITK: tl.constexpr,  # Split K dimension
    allow_tf32: tl.constexpr,  # Allow TF32 precision for matmuls
):
    """
    Sparse submanifold convolution forward kernel using implicit GEMM with split K dimension.

    Launch grid: axis 0 enumerates the (N-block, Co-block) output tiles, with the
    Co-block index varying fastest; axis 1 enumerates the SPLITK partitions of the
    flattened V * Ci reduction. Each partition writes its partial sums into its own
    (N, Co) slice ``output[block_id_k]``; the host wrapper sums over the slices.

    Args:
        input (pointer): A pointer to the input tensor of shape (N, Ci); rows are
            addressed through ``neighbor``.
        weight (pointer): A pointer to the weight tensor of shape (Co, V, Ci)
        bias (pointer): A pointer to the bias tensor of shape (Co), or None
        neighbor (pointer): A pointer to the neighbor tensor of shape (N, V); an
            entry of -1 marks a missing neighbor, which contributes zeros.
        output (pointer): A pointer to the partial-sum output buffer of shape
            (SPLITK, N, Co)
    """
    block_id_k = tl.program_id(axis=1)  # SplitK dimension
    block_id = tl.program_id(axis=0)
    block_dim_co = tl.cdiv(Co, B2)
    block_id_co = block_id % block_dim_co
    block_id_n = block_id // block_dim_co

    # Create pointers for submatrices of A and B.
    num_k = tl.cdiv(Ci, BK)  # Number of blocks in K dimension
    # The flattened reduction index k runs over [0, num_k * V): k // num_k selects
    # the kernel offset v and k % num_k selects the BK-wide chunk of Ci. This
    # split owns the contiguous sub-range [k_start, k_end), which may be empty.
    k_start = tl.cdiv(num_k * V * block_id_k, SPLITK)
    k_end = tl.cdiv(num_k * V * (block_id_k + 1), SPLITK)
    # Wrap tail rows/columns back into range so the loads stay in bounds; the
    # store mask at the end discards those lanes.
    offset_n = (block_id_n * B1 + tl.arange(0, B1)) % N  # (B1,)
    offset_co = (block_id_co * B2 + tl.arange(0, B2)) % Co  # (B2,)
    offset_k = tl.arange(0, BK)  # (BK,)

    # Create a block of the output matrix C.
    accumulator = tl.zeros((B1, B2), dtype=tl.float32)
    curr_v = k_start // num_k
    curr_bk = k_start % num_k
    weight_offset_base = curr_v * Ci + curr_bk * BK

    # Weight tile for the current (v, bk): rows walk the Ci chunk, columns walk Co.
    weight_ptr = weight + weight_offset_base + (offset_co[None, :] * V * Ci + offset_k[:, None])  # (BK, B2)

    # Iterate along V*Ci dimension.
    for k in range(k_start, k_end):
        v = k // num_k
        bk = k % num_k
        # Calculate pointers to input matrix.
        neighbor_offset_n = tl.load(neighbor + offset_n * V + v).to(tl.int64)  # (B1,)
        input_ptr = input + bk * BK + (neighbor_offset_n[:, None].to(tl.int64) * Ci + offset_k[None, :])  # (B1, BK)
        # Load the next block of input and weight.
        # Missing neighbors (-1) and the Ci tail past the last chunk load as 0.
        neigh_mask = neighbor_offset_n != -1
        k_mask = offset_k < Ci - bk * BK
        input_block = tl.load(input_ptr, mask=neigh_mask[:, None] & k_mask[None, :], other=0.0)
        weight_block = tl.load(weight_ptr, mask=k_mask[:, None], other=0.0)
        # Accumulate along the K dimension.
        accumulator = tl.dot(input_block, weight_block, accumulator,
                             input_precision='tf32' if allow_tf32 else 'ieee')  # (B1, B2)
        # Advance the pointers to the next Ci block.
        # Within one kernel offset the step is BK; after the last (possibly
        # partial) chunk the step is the Ci remainder, which lands exactly on
        # the start of offset v + 1, i.e. (v + 1) * Ci.
        weight_ptr += min(BK, Ci - bk * BK)

    # add bias
    # Only split 0 adds the bias, so it is counted once after the host-side sum.
    if bias is not None and block_id_k == 0:
        bias_block = tl.load(bias + offset_co)
        accumulator += bias_block[None, :]

    # Write back the block of the output matrix with masks.
    # Each split writes to its own (N, Co) slice of the partial-sum buffer.
    out_offset_n = block_id_n * B1 + tl.arange(0, B1)
    out_offset_co = block_id_co * B2 + tl.arange(0, B2)
    out_ptr = output + block_id_k * N * Co + (out_offset_n[:, None] * Co + out_offset_co[None, :])
    out_mask = (out_offset_n[:, None] < N) & (out_offset_co[None, :] < Co)
    tl.store(out_ptr, accumulator, mask=out_mask)
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def conv_fwd_implicit_gemm_splitk_configs(input, weight, bias, neighbor):
    """Enumerate the SPLITK candidates for ``conv_fwd_implicit_gemm_splitk``.

    The number of output tiles is estimated with 128 x 128 tiles. Candidate
    SPLITK values are powers of two chosen so that tiles * SPLITK ranges from
    roughly one wave of SMs up to about 32 waves. At least one candidate is
    always returned.

    Args:
        input: Input features (unused; part of the autotune callback signature).
        weight: Weights whose first dimension is Co.
        bias: Bias (unused; part of the autotune callback signature).
        neighbor: Neighbor table whose first dimension is N.

    Returns:
        list[dict]: One ``{'SPLITK': 2 ** i}`` dict per candidate.
    """
    N, Co = neighbor.shape[0], weight.shape[0]
    # Clamp to 1 so that an empty input (N == 0) does not divide by zero below.
    num_blocks = max(1, ((N + 127) // 128) * ((Co + 127) // 128))
    num_sm = get_num_sm()
    min_log2 = max(0, int(math.log2(num_sm / num_blocks)))
    max_log2 = max(1, int(math.log2(32 * num_sm / num_blocks) + 1))
    return [{'SPLITK': 2 ** i} for i in range(min_log2, max_log2)]
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def conv_fwd_implicit_gemm_splitk_keys(input, weight, bias, neighbor):
    """Build the autotune key for ``conv_fwd_implicit_gemm_splitk``.

    N is bucketed by ``floor(log2(N))``, while Ci, Co and V are kept exact.
    As a result, problems of similar size share a key. An empty input
    (N == 0) falls into the ``2^0`` bucket instead of raising from
    ``math.log2(0)``.

    Returns:
        str: A key of the form ``'(2^k, Ci, Co, V)'``.
    """
    N, Ci, Co, V = neighbor.shape[0], input.shape[1], weight.shape[0], weight.shape[1]
    log_n = int(math.log2(max(N, 1)))
    return f'(2^{log_n}, {Ci}, {Co}, {V})'
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
@autotune(
    config_fn=conv_fwd_implicit_gemm_splitk_configs,
    key_fn=conv_fwd_implicit_gemm_splitk_keys,
)
def conv_fwd_implicit_gemm_splitk(
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    neighbor: torch.Tensor,
    SPLITK: int = 1,
) -> torch.Tensor:
    """Sparse convolution forward pass via (optionally split-K) implicit GEMM.

    Args:
        input: Contiguous input features of shape (M, Ci); rows are addressed
            through ``neighbor``.
        weight: Contiguous weights of shape (Co, V, Ci).
        bias: Bias of shape (Co,), or None.
        neighbor: Contiguous neighbor table of shape (N, V); -1 marks a missing
            neighbor.
        SPLITK: Number of partitions of the V * Ci reduction. This is presumably
            supplied by the autotuner from
            ``conv_fwd_implicit_gemm_splitk_configs``. With 1, the plain
            implicit-GEMM kernel writes the output directly. With larger
            values, the split-K kernel writes float32 partial sums into a
            (SPLITK, N, Co) buffer, which is summed and cast to
            ``input.dtype``.

    Returns:
        torch.Tensor: Output features of shape (N, Co) with ``input.dtype``.
    """
    assert input.shape[1] == weight.shape[2], "Incompatible dimensions"
    assert input.is_contiguous(), "Matrix input must be contiguous"
    assert weight.is_contiguous(), "Matrix weight must be contiguous"
    assert neighbor.is_contiguous(), "Matrix neighbor must be contiguous"
    N, Ci, Co, V = neighbor.shape[0], input.shape[1], weight.shape[0], weight.shape[1]
    if N == 0:
        # No output rows: skip the launch and math.log2(0), which would raise.
        return torch.empty((0, Co), device=input.device, dtype=input.dtype)
    LOGN = int(math.log2(N))
    # Launch the kernel.
    if SPLITK == 1:
        output = torch.empty((N, Co), device=input.device, dtype=input.dtype)
        grid = lambda META: (triton.cdiv(Co, META['B2']) * triton.cdiv(N, META['B1']),)
        conv_fwd_implicit_gemm_kernel[grid](
            input, weight, bias, neighbor, output,
            N, LOGN, Ci, Co, V,
            allow_tf32=config.allow_tf32,
        )
        return output
    else:
        # Each split writes its own float32 slice; reduce them on the host.
        output = torch.empty((SPLITK, N, Co), device=input.device, dtype=torch.float32)
        grid = lambda META: (triton.cdiv(Co, META['B2']) * triton.cdiv(N, META['B1']), SPLITK)
        conv_fwd_implicit_gemm_splitk_kernel[grid](
            input, weight, bias, neighbor, output,
            N, LOGN, Ci, Co, V,
            SPLITK=SPLITK,
            allow_tf32=config.allow_tf32,
        )
        return output.sum(dim=0).to(input.dtype)
|
ocnn/nn/kernels/utils.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
from typing import *
|
|
2
|
+
import torch
|
|
3
|
+
import triton
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def get_gpu_name():
    """Return the device name PyTorch reports for the current CUDA device."""
    device_name = torch.cuda.get_device_name()
    return device_name
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def get_platform_name():
    """Name the GPU backend PyTorch is running on.

    Returns:
        str: ``'hip'`` for a ROCm build with a usable device, ``'cuda'`` for
        any other build with a usable device, and ``'unknown'`` when
        ``torch.cuda.is_available()`` is False.
    """
    if not torch.cuda.is_available():
        return 'unknown'
    built_with_hip = getattr(torch.version, 'hip', None) is not None
    return 'hip' if built_with_hip else 'cuda'
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def get_num_sm():
    """Return the multiprocessor count of the device that ``"cuda"`` resolves to."""
    properties = torch.cuda.get_device_properties("cuda")
    return properties.multi_processor_count
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def get_autotune_config(
    default: Optional[List[triton.Config]] = None,
    platform: Optional[Dict[str, List[triton.Config]]] = None,
    device: Optional[Dict[str, List[triton.Config]]] = None,
) -> List[triton.Config]:
    """
    Get the autotune configuration for the current platform and device.

    Lookup order:
      1. ``device``: the first key that is a case-insensitive substring of the
         current GPU name. This step is skipped when no GPU is available.
      2. ``platform``: the first key that is a case-insensitive substring of
         ``get_platform_name()`` (``'cuda'``, ``'hip'`` or ``'unknown'``).
      3. ``default``.

    Raises:
        ValueError: If nothing matched and ``default`` is None.
    """
    # torch.cuda.get_device_name() raises when no GPU is present; skip device
    # matching in that case so the platform/default fallbacks still apply.
    if device is not None and torch.cuda.is_available():
        gpu_name = get_gpu_name().lower()
        for key, value in device.items():
            if key.lower() in gpu_name:
                return value

    if platform is not None:
        platform_name = get_platform_name().lower()
        for key, value in platform.items():
            if key.lower() in platform_name:
                return value

    if default is None:
        raise ValueError("No autotune configuration found for the current platform and device.")
    return default
|
ocnn/nn/octree_conv.py
CHANGED
|
@@ -98,7 +98,8 @@ class OctreeConvBase:
|
|
|
98
98
|
|
|
99
99
|
# Check the shape of input data
|
|
100
100
|
check = tuple(data.shape) == self.in_shape
|
|
101
|
-
assert check, 'The shape of input data is wrong
|
|
101
|
+
assert check, ('The shape of input data is wrong: ' +
|
|
102
|
+
'expected {}, got {}.'.format(self.in_shape, data.shape))
|
|
102
103
|
|
|
103
104
|
# Init the output data
|
|
104
105
|
out = data.new_zeros(self.out_shape)
|
ocnn/nn/octree_conv_t.py
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
1
|
+
# --------------------------------------------------------
|
|
2
|
+
# Octree-based Sparse Convolutional Neural Networks
|
|
3
|
+
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
|
|
4
|
+
# Licensed under The MIT License [see LICENSE for details]
|
|
5
|
+
# Written by Peng-Shuai Wang
|
|
6
|
+
# --------------------------------------------------------
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
import torch.nn
|
|
10
|
+
from torch.autograd import Function
|
|
11
|
+
from typing import List
|
|
12
|
+
|
|
13
|
+
from ocnn.octree import Octree
|
|
14
|
+
from ocnn.nn import OctreeConv
|
|
15
|
+
from ocnn.utils import xavier_uniform_, resize_with_last_val, list2str
|
|
16
|
+
from ocnn.nn.kernels import conv_fwd_implicit_gemm_splitk, conv_bwd_implicit_gemm_splitk
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class OctreeConvTritonFunction(Function):
  r''' Autograd wrapper around the Triton implicit-GEMM octree convolution.

  The forward pass runs :func:`conv_fwd_implicit_gemm_splitk`. The backward
  pass delegates to :func:`conv_bwd_implicit_gemm_splitk`, forwarding
  ``ctx.needs_input_grad`` so the kernel knows which gradients are required.
  '''

  @staticmethod
  def forward(ctx, data: torch.Tensor, weights: torch.Tensor, bias: torch.Tensor,
              neigh: torch.Tensor):
    # The kernels assert contiguous inputs, so normalize the layouts first.
    data, weights, neigh = [t.contiguous() for t in (data, weights, neigh)]
    bias = bias if bias is None else bias.contiguous()

    result = conv_fwd_implicit_gemm_splitk(data, weights, bias, neigh)
    ctx.save_for_backward(data, weights, bias, neigh)
    return result

  @staticmethod
  def backward(ctx, grad):
    data, weights, bias, neigh = ctx.saved_tensors
    grads = conv_bwd_implicit_gemm_splitk(
        grad.contiguous(), data, weights, bias, neigh, ctx.needs_input_grad)
    grad_input, grad_weight, grad_bias = grads
    # No gradient flows to the neighbor index table.
    return grad_input, grad_weight, grad_bias, None
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
# alias: functional entry point for the Triton octree convolution; it takes
# (data, weights, bias, neigh) as accepted by OctreeConvTritonFunction.forward.
octree_conv_triton = OctreeConvTritonFunction.apply
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class OctreeConvTriton(torch.nn.Module):
  r''' Performs octree convolution with the Triton implicit-GEMM kernels.

  Args:
    in_channels (int): Number of input channels.
    out_channels (int): Number of output channels.
    kernel_size (List(int)): The kernel shape, only :obj:`[3]` and :obj:`[3,3,3]`
        are supported now for the triton implementation.
    stride (int): The stride of the convolution, only :obj:`1` is supported now.
    nempty (bool): If True, only performs the convolution on non-empty octree
        nodes; otherwise, performs the convolution on all octree nodes.
    direct_method (bool): Unused by this implementation (presumably kept so
        the signature mirrors :class:`OctreeConv`).
    use_bias (bool): If True, add a bias term to the convolution.
    max_buffer (int): Unused by this implementation (presumably kept so the
        signature mirrors :class:`OctreeConv`).

  .. note::
    Each non-empty octree node has exactly 8 children nodes, among which some
    children nodes are non-empty and some are empty. If :attr:`nempty` is true,
    the convolution is performed on non-empty octree nodes only, which is exactly
    the same as SparseConvNet and MinkowskiNet; if :attr:`nempty` is false, the
    convolution is performed on all octree nodes, which is essential for shape
    reconstruction tasks and can also be used in classification and segmentation
    (with slightly better performance and larger memory cost).
  '''

  def __init__(self, in_channels: int, out_channels: int,
               kernel_size: List[int] = [3], stride: int = 1,
               nempty: bool = False, direct_method: bool = False,
               use_bias: bool = False, max_buffer: int = int(2e8)):
    super().__init__()
    self.in_channels = in_channels
    self.out_channels = out_channels
    # Work on a private copy: the default ``[3]`` is shared by every call, and
    # callers such as ``convert_conv_triton`` pass another module's
    # ``kernel_size`` list, so the result must not alias either of them.
    self.kernel_size = resize_with_last_val(list(kernel_size))
    self.kernel = list2str(self.kernel_size)
    self.stride = stride
    self.nempty = nempty
    self.use_bias = use_bias
    assert self.stride == 1, 'Only stride=1 is supported now.'
    assert self.kernel == '333', 'Only kernel_size=[3,3,3] is supported now.'

    self.kdim = self.kernel_size[0] * self.kernel_size[1] * self.kernel_size[2]
    # Stored as (V, Ci, Co); ``forward`` permutes it to the (Co, V, Ci) layout
    # that the kernels consume.
    self.weights_shape = (self.kdim, self.in_channels, self.out_channels)
    self.weights = torch.nn.Parameter(torch.Tensor(*self.weights_shape))
    self.bias = (torch.nn.Parameter(torch.Tensor(self.out_channels))
                 if use_bias else None)
    self.reset_parameters()

  def reset_parameters(self):
    r''' Xavier-initializes the weights and zeroes the bias (if any).
    '''
    xavier_uniform_(self.weights)
    if self.use_bias:
      torch.nn.init.zeros_(self.bias)

  def forward(self, data: torch.Tensor, octree: Octree, depth: int):
    r''' Defines the octree convolution.

    Args:
      data (torch.Tensor): The input data.
      octree (Octree): The corresponding octree.
      depth (int): The depth of current octree.
    '''

    # TODO: remove the permute operation by changing the kernel implementation
    weight = self.weights.permute(2, 0, 1)  # (V,Ci,Co) -> (Co,V,Ci)
    neigh = octree.get_neigh(depth, self.kernel, self.stride, self.nempty)
    out = octree_conv_triton(data, weight, self.bias, neigh)
    return out

  def extra_repr(self) -> str:
    r''' Sets the extra representation of the module.
    '''

    return ('triton, in_channels={}, out_channels={}, kernel_size={}, stride={}, '
            'nempty={}, bias={}').format(self.in_channels, self.out_channels,
             self.kernel_size, self.stride, self.nempty, self.use_bias)  # noqa
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
# alias: short name for OctreeConvTriton.
OctreeConvT = OctreeConvTriton
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def convert_conv_triton(module: torch.nn.Module) -> torch.nn.Module:
  r''' Recursively swap eligible :class:`OctreeConv` layers for
  :class:`OctreeConvTriton` layers.

  A layer is swapped only when it is an :class:`OctreeConv` instance with
  stride 1 and kernel size :obj:`[3, 3, 3]`. The replacement reuses the
  original weight (and bias) Parameter objects, so no values are copied. Every
  child is converted recursively and re-attached under its original name; a
  module that is not swapped is therefore modified in place.

  NOTE(review): ``isinstance`` also matches subclasses of :class:`OctreeConv`;
  confirm that none of them use a different weight layout or semantics.

  Args:
    module (torch.nn.Module): The input module.

  Returns:
    torch.nn.Module: The converted module (possibly ``module`` itself).
  '''

  eligible = (isinstance(module, OctreeConv) and module.stride == 1
              and module.kernel_size == [3, 3, 3])
  if not eligible:
    converted = module
  else:
    converted = OctreeConvTriton(
        module.in_channels, module.out_channels, module.kernel_size,
        module.stride, module.nempty, use_bias=module.use_bias)
    with torch.no_grad():
      # Share the existing Parameters rather than copying their values.
      converted.weights = module.weights
      if module.use_bias:
        converted.bias = module.bias

  for child_name, child in module.named_children():
    converted.add_module(child_name, convert_conv_triton(child))
  return converted
|
ocnn/nn/octree_pad.py
CHANGED
|
@@ -22,10 +22,10 @@ def octree_pad(data: torch.Tensor, octree: Octree, depth: int, val: float = 0.0)
|
|
|
22
22
|
val (float): The padding value. (Default: :obj:`0.0`)
|
|
23
23
|
'''
|
|
24
24
|
|
|
25
|
-
|
|
25
|
+
idx = octree.nempty_index(depth)
|
|
26
26
|
size = (octree.nnum[depth], data.shape[1]) # (N, C)
|
|
27
27
|
out = torch.full(size, val, dtype=data.dtype, device=data.device)
|
|
28
|
-
out[
|
|
28
|
+
out[idx] = data
|
|
29
29
|
return out
|
|
30
30
|
|
|
31
31
|
|
|
@@ -35,5 +35,5 @@ def octree_depad(data: torch.Tensor, octree: Octree, depth: int):
|
|
|
35
35
|
Please refer to :func:`octree_depad` for the meaning of the arguments.
|
|
36
36
|
'''
|
|
37
37
|
|
|
38
|
-
|
|
39
|
-
return data[
|
|
38
|
+
idx = octree.nempty_index(depth)
|
|
39
|
+
return data[idx]
|