ocnn 2.2.8__py3-none-any.whl → 2.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ocnn/__init__.py +24 -24
- ocnn/dataset.py +160 -160
- ocnn/models/__init__.py +29 -29
- ocnn/models/autoencoder.py +155 -155
- ocnn/models/hrnet.py +192 -192
- ocnn/models/image2shape.py +128 -128
- ocnn/models/lenet.py +46 -46
- ocnn/models/ounet.py +94 -94
- ocnn/models/resnet.py +53 -53
- ocnn/models/segnet.py +72 -72
- ocnn/models/unet.py +105 -105
- ocnn/modules/__init__.py +26 -26
- ocnn/modules/modules.py +303 -303
- ocnn/modules/resblocks.py +158 -158
- ocnn/nn/__init__.py +45 -44
- ocnn/nn/kernels/__init__.py +14 -0
- ocnn/nn/kernels/autotuner.py +416 -0
- ocnn/nn/kernels/config.py +67 -0
- ocnn/nn/kernels/conv_bwd_implicit_gemm.py +229 -0
- ocnn/nn/kernels/conv_bwd_implicit_gemm_splitk.py +347 -0
- ocnn/nn/kernels/conv_fwd_implicit_gemm.py +109 -0
- ocnn/nn/kernels/conv_fwd_implicit_gemm_splitk.py +150 -0
- ocnn/nn/kernels/utils.py +44 -0
- ocnn/nn/octree2col.py +53 -53
- ocnn/nn/octree2vox.py +50 -50
- ocnn/nn/octree_align.py +46 -46
- ocnn/nn/octree_conv.py +430 -429
- ocnn/nn/octree_conv_t.py +148 -0
- ocnn/nn/octree_drop.py +55 -55
- ocnn/nn/octree_dwconv.py +222 -222
- ocnn/nn/octree_gconv.py +79 -79
- ocnn/nn/octree_interp.py +196 -196
- ocnn/nn/octree_norm.py +126 -126
- ocnn/nn/octree_pad.py +39 -39
- ocnn/nn/octree_pool.py +200 -200
- ocnn/octree/__init__.py +22 -22
- ocnn/octree/octree.py +770 -770
- ocnn/octree/points.py +384 -323
- ocnn/octree/shuffled_key.py +115 -115
- ocnn/utils.py +205 -205
- {ocnn-2.2.8.dist-info → ocnn-2.3.0.dist-info}/METADATA +117 -111
- ocnn-2.3.0.dist-info/RECORD +45 -0
- {ocnn-2.2.8.dist-info → ocnn-2.3.0.dist-info}/WHEEL +1 -1
- {ocnn-2.2.8.dist-info → ocnn-2.3.0.dist-info}/licenses/LICENSE +21 -21
- ocnn-2.2.8.dist-info/RECORD +0 -36
- {ocnn-2.2.8.dist-info → ocnn-2.3.0.dist-info}/top_level.txt +0 -0
ocnn/nn/octree_conv_t.py
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
1
|
+
# --------------------------------------------------------
|
|
2
|
+
# Octree-based Sparse Convolutional Neural Networks
|
|
3
|
+
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
|
|
4
|
+
# Licensed under The MIT License [see LICENSE for details]
|
|
5
|
+
# Written by Peng-Shuai Wang
|
|
6
|
+
# --------------------------------------------------------
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
import torch.nn
|
|
10
|
+
from torch.autograd import Function
|
|
11
|
+
from typing import List
|
|
12
|
+
|
|
13
|
+
from ocnn.octree import Octree
|
|
14
|
+
from ocnn.nn import OctreeConv
|
|
15
|
+
from ocnn.utils import xavier_uniform_, resize_with_last_val, list2str
|
|
16
|
+
from ocnn.nn.kernels import conv_fwd_implicit_gemm_splitk, conv_bwd_implicit_gemm_splitk
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class OctreeConvTritonFunction(Function):
    r''' Autograd wrapper that pairs the split-K implicit-GEMM forward kernel
    with its backward kernel, so the octree convolution can be differentiated.
    '''

    @staticmethod
    def forward(ctx, data: torch.Tensor, weights: torch.Tensor, bias: torch.Tensor,
                neigh: torch.Tensor):
        r''' Runs the forward kernel and stores its inputs for :meth:`backward`.

        Args:
          ctx: The autograd context object.
          data (torch.Tensor): The input features.
          weights (torch.Tensor): The convolution weights.
          bias (torch.Tensor or None): The optional bias; :obj:`None` is passed
              to the kernel unchanged.
          neigh (torch.Tensor): The neighbor tensor handed to the kernel.

        Returns:
          The output of :func:`conv_fwd_implicit_gemm_splitk`.
        '''

        # All tensors are made contiguous before they reach the kernel.
        data, weights, neigh = (
            tensor.contiguous() for tensor in (data, weights, neigh))
        bias = None if bias is None else bias.contiguous()

        result = conv_fwd_implicit_gemm_splitk(data, weights, bias, neigh)
        ctx.save_for_backward(data, weights, bias, neigh)
        return result

    @staticmethod
    def backward(ctx, grad):
        r''' Computes the gradients with the backward kernel.

        Args:
          ctx: The autograd context object filled by :meth:`forward`.
          grad (torch.Tensor): The gradient with respect to the output.

        Returns:
          The gradients for :attr:`data`, :attr:`weights` and :attr:`bias`,
          followed by :obj:`None` for :attr:`neigh`.
        '''

        saved = ctx.saved_tensors  # (data, weights, bias, neigh)
        grad_input, grad_weight, grad_bias = conv_bwd_implicit_gemm_splitk(
            grad.contiguous(), *saved, ctx.needs_input_grad)
        return grad_input, grad_weight, grad_bias, None
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
# alias
|
|
46
|
+
# Functional entry point: calling `octree_conv_triton(data, weights, bias, neigh)`
# goes through `OctreeConvTritonFunction.apply`.
octree_conv_triton = OctreeConvTritonFunction.apply
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class OctreeConvTriton(torch.nn.Module):
    r''' Performs octree convolution with the triton-based kernels.

    Args:
      in_channels (int): Number of input channels.
      out_channels (int): Number of output channels.
      kernel_size (List(int)): The kernel shape, only :obj:`[3]` and :obj:`[3,3,3]`
          are supported now for the triton implementation. The given list is
          copied and never modified.
      stride (int): The stride of the convolution, only :obj:`1` is supported now.
      nempty (bool): If True, only performs the convolution on non-empty octree
          nodes; otherwise, performs the convolution on all octree nodes.
      direct_method (bool): Not used by this class; accepted and ignored.
      use_bias (bool): If True, add a bias term to the convolution.
      max_buffer (int): Not used by this class; accepted and ignored.

    .. note::
      Each non-empty octree node has exactly 8 children nodes, among which some
      children nodes are non-empty and some are empty. If :attr:`nempty` is true,
      the convolution is performed on non-empty octree nodes only, which is exactly
      the same as SparseConvNet and MinkowskiNet; if :attr:`nempty` is false, the
      convolution is performed on all octree nodes, which is essential for shape
      reconstruction tasks and can also be used in classification and segmentation
      (with slightly better performance and larger memory cost).
    '''

    def __init__(self, in_channels: int, out_channels: int,
                 kernel_size: List[int] = [3], stride: int = 1,
                 nempty: bool = False, direct_method: bool = False,
                 use_bias: bool = False, max_buffer: int = int(2e8)):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Work on a private copy: `resize_with_last_val` is defined outside this
        # file, and whatever it does to its argument must reach neither the
        # shared default list in the signature nor a list owned by the caller.
        self.kernel_size = resize_with_last_val(list(kernel_size))
        self.kernel = list2str(self.kernel_size)
        self.stride = stride
        self.nempty = nempty
        self.use_bias = use_bias
        assert self.stride == 1, 'Only stride=1 is supported now.'
        assert self.kernel == '333', 'Only kernel_size=[3,3,3] is supported now.'

        # The weights are stored as (kernel volume, in_channels, out_channels).
        self.kdim = self.kernel_size[0] * self.kernel_size[1] * self.kernel_size[2]
        self.weights_shape = (self.kdim, self.in_channels, self.out_channels)
        self.weights = torch.nn.Parameter(torch.Tensor(*self.weights_shape))
        self.bias = (torch.nn.Parameter(torch.Tensor(self.out_channels))
                     if use_bias else None)
        self.reset_parameters()

    def reset_parameters(self):
        r''' Initializes the weights with :func:`xavier_uniform_` and, when a bias
        is used, sets the bias to zero.
        '''

        xavier_uniform_(self.weights)
        if self.use_bias:
            torch.nn.init.zeros_(self.bias)

    def forward(self, data: torch.Tensor, octree: Octree, depth: int):
        r''' Defines the octree convolution.

        Args:
          data (torch.Tensor): The input data.
          octree (Octree): The corresponding octree.
          depth (int): The depth of current octree.

        Returns:
          The output of :func:`octree_conv_triton`.
        '''

        # TODO: remove the permute operation by changing the kernel implementation
        weight = self.weights.permute(2, 0, 1)  # (V,Ci,Co) -> (Co,V,Ci)
        neigh = octree.get_neigh(depth, self.kernel, self.stride, self.nempty)
        out = octree_conv_triton(data, weight, self.bias, neigh)
        return out

    def extra_repr(self) -> str:
        r''' Sets the extra representation of the module.
        '''

        return ('triton, in_channels={}, out_channels={}, kernel_size={}, stride={}, '
                'nempty={}, bias={}').format(self.in_channels, self.out_channels,
                 self.kernel_size, self.stride, self.nempty, self.use_bias)  # noqa
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
# alias
|
|
124
|
+
# Short name bound to the same class object as `OctreeConvTriton`.
OctreeConvT = OctreeConvTriton
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def convert_conv_triton(module: torch.nn.Module) -> torch.nn.Module:
    r''' Convert OctreeConv modules to OctreeConvTriton modules in a network.

    A module is replaced only if it is an :class:`OctreeConv` with
    :obj:`stride == 1` and :obj:`kernel_size == [3, 3, 3]`; every other module
    is kept as it is. The children of :attr:`module` are processed recursively.

    Args:
      module (torch.nn.Module): The input module.

    Returns:
      Either :attr:`module` itself (with its children converted) or the
      :class:`OctreeConvTriton` that replaces it. A replacement shares the
      parameter objects of the original layer and inherits its
      :attr:`training` flag.
    '''

    module_out = module
    # NOTE(review): `isinstance` also matches subclasses of OctreeConv; confirm
    # that every such subclass with stride 1 and a 3x3x3 kernel can safely be
    # replaced by a plain OctreeConvTriton.
    if (isinstance(module, OctreeConv) and
            module.stride == 1 and module.kernel_size == [3, 3, 3]):
        module_out = OctreeConvTriton(
            module.in_channels, module.out_channels, module.kernel_size,
            module.stride, module.nempty, use_bias=module.use_bias)
        # Re-use the existing parameters instead of the freshly initialized
        # ones. Rebinding attributes runs no tensor operation, so no
        # `torch.no_grad()` guard is needed.
        module_out.weights = module.weights
        if module.use_bias:
            module_out.bias = module.bias
        # A new module starts in training mode; keep the mode of the layer
        # that is being replaced so an eval-mode network stays in eval mode.
        module_out.training = module.training

    for name, child in module.named_children():
        module_out.add_module(name, convert_conv_triton(child))
    return module_out
|
ocnn/nn/octree_drop.py
CHANGED
|
@@ -1,55 +1,55 @@
|
|
|
1
|
-
# --------------------------------------------------------
|
|
2
|
-
# Octree-based Sparse Convolutional Neural Networks
|
|
3
|
-
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
|
|
4
|
-
# Licensed under The MIT License [see LICENSE for details]
|
|
5
|
-
# Written by Peng-Shuai Wang
|
|
6
|
-
# --------------------------------------------------------
|
|
7
|
-
|
|
8
|
-
import torch
|
|
9
|
-
from typing import Optional
|
|
10
|
-
|
|
11
|
-
from ocnn.octree import Octree
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
class OctreeDropPath(torch.nn.Module):
    r''' Stochastic depth (drop path) applied per sample, meant for the main
    path of residual blocks and following the logic of
    :func:`timm.models.layers.DropPath`.

    Args:
      drop_prob (float): The probability of drop paths.
      nempty (bool): Indicate whether the input data only contains features of the
          non-empty octree nodes or not.
      scale_by_keep (bool): Whether to scale the kept features proportionally,
          i.e. divide them by :obj:`1 - drop_prob`.
    '''

    def __init__(self, drop_prob: float = 0.0, nempty: bool = False,
                 scale_by_keep: bool = True):
        super().__init__()

        self.scale_by_keep = scale_by_keep
        self.nempty = nempty
        self.drop_prob = drop_prob

    def forward(self, data: torch.Tensor, octree: Octree, depth: int,
                batch_id: Optional[torch.Tensor] = None):
        r''' Multiplies the features of each sample by a random keep/drop factor.

        Args:
          data (torch.Tensor): The input features.
          octree (Octree): The corresponding octree; provides the batch size and,
              if :attr:`batch_id` is not given, the batch index of each node.
          depth (int): The octree depth used to query the batch indices.
          batch_id (torch.Tensor, optional): Precomputed batch indices that are
              used to pick the per-sample factor for each row of :attr:`data`.

        Returns:
          :attr:`data` itself when the module is in evaluation mode or
          :attr:`drop_prob` is not positive; otherwise the masked features.
        '''

        if not self.training or self.drop_prob <= 0.0:
            return data

        keep_prob = 1 - self.drop_prob
        # One factor per sample: floor(u + keep_prob) with u drawn from [0, 1).
        mask = torch.rand(
            octree.batch_size, 1, dtype=data.dtype, device=data.device)
        mask = (mask + keep_prob).floor()
        if self.scale_by_keep and keep_prob > 0.0:
            mask.div_(keep_prob)

        if batch_id is None:
            batch_id = octree.batch_id(depth, self.nempty)
        return data * mask[batch_id]

    def extra_repr(self) -> str:
        r''' Sets the extra representation of the module.
        '''

        template = 'drop_prob={:.4f}, nempty={}, scale_by_keep={}'
        return template.format(self.drop_prob, self.nempty, self.scale_by_keep)
|
|
1
|
+
# --------------------------------------------------------
|
|
2
|
+
# Octree-based Sparse Convolutional Neural Networks
|
|
3
|
+
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
|
|
4
|
+
# Licensed under The MIT License [see LICENSE for details]
|
|
5
|
+
# Written by Peng-Shuai Wang
|
|
6
|
+
# --------------------------------------------------------
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
from typing import Optional
|
|
10
|
+
|
|
11
|
+
from ocnn.octree import Octree
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class OctreeDropPath(torch.nn.Module):
    r''' Stochastic depth (drop path) applied per sample, meant for the main
    path of residual blocks and following the logic of
    :func:`timm.models.layers.DropPath`.

    Args:
      drop_prob (float): The probability of drop paths.
      nempty (bool): Indicate whether the input data only contains features of the
          non-empty octree nodes or not.
      scale_by_keep (bool): Whether to scale the kept features proportionally,
          i.e. divide them by :obj:`1 - drop_prob`.
    '''

    def __init__(self, drop_prob: float = 0.0, nempty: bool = False,
                 scale_by_keep: bool = True):
        super().__init__()

        self.scale_by_keep = scale_by_keep
        self.nempty = nempty
        self.drop_prob = drop_prob

    def forward(self, data: torch.Tensor, octree: Octree, depth: int,
                batch_id: Optional[torch.Tensor] = None):
        r''' Multiplies the features of each sample by a random keep/drop factor.

        Args:
          data (torch.Tensor): The input features.
          octree (Octree): The corresponding octree; provides the batch size and,
              if :attr:`batch_id` is not given, the batch index of each node.
          depth (int): The octree depth used to query the batch indices.
          batch_id (torch.Tensor, optional): Precomputed batch indices that are
              used to pick the per-sample factor for each row of :attr:`data`.

        Returns:
          :attr:`data` itself when the module is in evaluation mode or
          :attr:`drop_prob` is not positive; otherwise the masked features.
        '''

        if not self.training or self.drop_prob <= 0.0:
            return data

        keep_prob = 1 - self.drop_prob
        # One factor per sample: floor(u + keep_prob) with u drawn from [0, 1).
        mask = torch.rand(
            octree.batch_size, 1, dtype=data.dtype, device=data.device)
        mask = (mask + keep_prob).floor()
        if self.scale_by_keep and keep_prob > 0.0:
            mask.div_(keep_prob)

        if batch_id is None:
            batch_id = octree.batch_id(depth, self.nempty)
        return data * mask[batch_id]

    def extra_repr(self) -> str:
        r''' Sets the extra representation of the module.
        '''

        template = 'drop_prob={:.4f}, nempty={}, scale_by_keep={}'
        return template.format(self.drop_prob, self.nempty, self.scale_by_keep)
|