ocnn 2.2.8-py3-none-any.whl → 2.3.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. ocnn/__init__.py +24 -24
  2. ocnn/dataset.py +160 -160
  3. ocnn/models/__init__.py +29 -29
  4. ocnn/models/autoencoder.py +155 -155
  5. ocnn/models/hrnet.py +192 -192
  6. ocnn/models/image2shape.py +128 -128
  7. ocnn/models/lenet.py +46 -46
  8. ocnn/models/ounet.py +94 -94
  9. ocnn/models/resnet.py +53 -53
  10. ocnn/models/segnet.py +72 -72
  11. ocnn/models/unet.py +105 -105
  12. ocnn/modules/__init__.py +26 -26
  13. ocnn/modules/modules.py +303 -303
  14. ocnn/modules/resblocks.py +158 -158
  15. ocnn/nn/__init__.py +45 -44
  16. ocnn/nn/kernels/__init__.py +14 -0
  17. ocnn/nn/kernels/autotuner.py +416 -0
  18. ocnn/nn/kernels/config.py +67 -0
  19. ocnn/nn/kernels/conv_bwd_implicit_gemm.py +229 -0
  20. ocnn/nn/kernels/conv_bwd_implicit_gemm_splitk.py +347 -0
  21. ocnn/nn/kernels/conv_fwd_implicit_gemm.py +109 -0
  22. ocnn/nn/kernels/conv_fwd_implicit_gemm_splitk.py +150 -0
  23. ocnn/nn/kernels/utils.py +44 -0
  24. ocnn/nn/octree2col.py +53 -53
  25. ocnn/nn/octree2vox.py +50 -50
  26. ocnn/nn/octree_align.py +46 -46
  27. ocnn/nn/octree_conv.py +430 -429
  28. ocnn/nn/octree_conv_t.py +148 -0
  29. ocnn/nn/octree_drop.py +55 -55
  30. ocnn/nn/octree_dwconv.py +222 -222
  31. ocnn/nn/octree_gconv.py +79 -79
  32. ocnn/nn/octree_interp.py +196 -196
  33. ocnn/nn/octree_norm.py +126 -126
  34. ocnn/nn/octree_pad.py +39 -39
  35. ocnn/nn/octree_pool.py +200 -200
  36. ocnn/octree/__init__.py +22 -22
  37. ocnn/octree/octree.py +770 -770
  38. ocnn/octree/points.py +384 -323
  39. ocnn/octree/shuffled_key.py +115 -115
  40. ocnn/utils.py +205 -205
  41. {ocnn-2.2.8.dist-info → ocnn-2.3.0.dist-info}/METADATA +117 -111
  42. ocnn-2.3.0.dist-info/RECORD +45 -0
  43. {ocnn-2.2.8.dist-info → ocnn-2.3.0.dist-info}/WHEEL +1 -1
  44. {ocnn-2.2.8.dist-info → ocnn-2.3.0.dist-info}/licenses/LICENSE +21 -21
  45. ocnn-2.2.8.dist-info/RECORD +0 -36
  46. {ocnn-2.2.8.dist-info → ocnn-2.3.0.dist-info}/top_level.txt +0 -0
ocnn/nn/octree_conv_t.py ADDED
@@ -0,0 +1,148 @@
+ # --------------------------------------------------------
+ # Octree-based Sparse Convolutional Neural Networks
+ # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
+ # Licensed under The MIT License [see LICENSE for details]
+ # Written by Peng-Shuai Wang
+ # --------------------------------------------------------
+
+ import torch
+ import torch.nn
+ from torch.autograd import Function
+ from typing import List
+
+ from ocnn.octree import Octree
+ from ocnn.nn import OctreeConv
+ from ocnn.utils import xavier_uniform_, resize_with_last_val, list2str
+ from ocnn.nn.kernels import conv_fwd_implicit_gemm_splitk, conv_bwd_implicit_gemm_splitk
+
+
+ class OctreeConvTritonFunction(Function):
+   r''' Wrap the octree convolution for auto-diff.
+   '''
+
+   @staticmethod
+   def forward(ctx, data: torch.Tensor, weights: torch.Tensor, bias: torch.Tensor,
+               neigh: torch.Tensor):
+     data = data.contiguous()
+     weights = weights.contiguous()
+     neigh = neigh.contiguous()
+     if bias is not None:
+       bias = bias.contiguous()
+
+     out = conv_fwd_implicit_gemm_splitk(data, weights, bias, neigh)
+     ctx.save_for_backward(data, weights, bias, neigh)
+     return out
+
+   @staticmethod
+   def backward(ctx, grad):
+     data, weights, bias, neigh = ctx.saved_tensors
+     grad = grad.contiguous()
+     grad_input, grad_weight, grad_bias = conv_bwd_implicit_gemm_splitk(
+         grad, data, weights, bias, neigh, ctx.needs_input_grad)
+     return grad_input, grad_weight, grad_bias, None
+
+
+ # alias
+ octree_conv_triton = OctreeConvTritonFunction.apply
+
+
+ class OctreeConvTriton(torch.nn.Module):
+   r''' Performs octree convolution.
+
+   Args:
+     in_channels (int): Number of input channels.
+     out_channels (int): Number of output channels.
+     kernel_size (List(int)): The kernel shape, only :obj:`[3]` and :obj:`[3,3,3]`
+         are supported now for the triton implementation.
+     stride (int): The stride of the convolution, only :obj:`1` is supported now.
+     nempty (bool): If True, only performs the convolution on non-empty octree
+         nodes; otherwise, performs the convolution on all octree nodes.
+     use_bias (bool): If True, add a bias term to the convolution.
+
+   .. note::
+     Each non-empty octree node has exactly 8 children nodes, among which some
+     children nodes are non-empty and some are empty. If :attr:`nempty` is true,
+     the convolution is performed on non-empty octree nodes only, which is exactly
+     the same as SparseConvNet and MinkowskiNet; if :attr:`nempty` is false, the
+     convolution is performed on all octree nodes, which is essential for shape
+     reconstruction tasks and can also be used in classification and segmentation
+     (with slightly better performance and larger memory cost).
+   '''
+
+   def __init__(self, in_channels: int, out_channels: int,
+                kernel_size: List[int] = [3], stride: int = 1,
+                nempty: bool = False, direct_method: bool = False,
+                use_bias: bool = False, max_buffer: int = int(2e8)):
+     super().__init__()
+     self.in_channels = in_channels
+     self.out_channels = out_channels
+     self.kernel_size = resize_with_last_val(kernel_size)
+     self.kernel = list2str(self.kernel_size)
+     self.stride = stride
+     self.nempty = nempty
+     self.use_bias = use_bias
+     assert self.stride == 1, 'Only stride=1 is supported now.'
+     assert self.kernel == '333', 'Only kernel_size=[3,3,3] is supported now.'
+
+     self.kdim = self.kernel_size[0] * self.kernel_size[1] * self.kernel_size[2]
+     self.weights_shape = (self.kdim, self.in_channels, self.out_channels)
+     self.weights = torch.nn.Parameter(torch.Tensor(*self.weights_shape))
+     self.bias = (torch.nn.Parameter(torch.Tensor(self.out_channels))
+                  if use_bias else None)
+     self.reset_parameters()
+
+   def reset_parameters(self):
+     xavier_uniform_(self.weights)
+     if self.use_bias:
+       torch.nn.init.zeros_(self.bias)
+
+   def forward(self, data: torch.Tensor, octree: Octree, depth: int):
+     r''' Defines the octree convolution.
+
+     Args:
+       data (torch.Tensor): The input data.
+       octree (Octree): The corresponding octree.
+       depth (int): The depth of the current octree.
+     '''
+
+     # TODO: remove the permute operation by changing the kernel implementation
+     weight = self.weights.permute(2, 0, 1)  # (V,Ci,Co) -> (Co,V,Ci)
+     neigh = octree.get_neigh(depth, self.kernel, self.stride, self.nempty)
+     out = octree_conv_triton(data, weight, self.bias, neigh)
+     return out
+
+   def extra_repr(self) -> str:
+     r''' Sets the extra representation of the module.
+     '''
+
+     return ('triton, in_channels={}, out_channels={}, kernel_size={}, stride={}, '
+             'nempty={}, bias={}').format(self.in_channels, self.out_channels,
+              self.kernel_size, self.stride, self.nempty, self.use_bias)  # noqa
+
+
+ # alias
+ OctreeConvT = OctreeConvTriton
+
+
+ def convert_conv_triton(module: torch.nn.Module) -> torch.nn.Module:
+   r''' Convert OctreeConv modules to OctreeConvTriton modules in a network.
+
+   Args:
+     module (torch.nn.Module): The input module.
+   '''
+
+   module_out = module
+   if (isinstance(module, OctreeConv) and
+       module.stride == 1 and module.kernel_size == [3, 3, 3]):
+     module_out = OctreeConvTriton(
+         module.in_channels, module.out_channels, module.kernel_size,
+         module.stride, module.nempty, use_bias=module.use_bias)
+     with torch.no_grad():
+       module_out.weights = module.weights
+       if module.use_bias:
+         module_out.bias = module.bias
+
+   for name, child in module.named_children():
+     module_out.add_module(name, convert_conv_triton(child))
+   del module
+   return module_out
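
The new file above ships a converter, convert_conv_triton, which recursively walks a network and swaps every eligible OctreeConv (stride 1, kernel_size [3, 3, 3] after expansion) for an OctreeConvTriton that reuses the original weight and bias Parameter objects. A minimal usage sketch, assuming the function is importable from ocnn.nn.octree_conv_t (the export path is not visible in this diff):

import torch
from ocnn.nn import OctreeConv
from ocnn.nn.octree_conv_t import convert_conv_triton  # assumed import path

# A toy container with one eligible layer; kernel_size=[3] expands to
# [3, 3, 3] inside OctreeConv, so it passes the converter's checks.
net = torch.nn.Sequential()
net.add_module('conv', OctreeConv(8, 16, kernel_size=[3], stride=1, use_bias=True))

net = convert_conv_triton(net)
print(net.conv)  # now an OctreeConvTriton holding the original parameters

Because the converter assigns the existing Parameter objects rather than copying them, references held elsewhere (e.g. by an optimizer) still point at the live weights.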
ocnn/nn/octree_drop.py CHANGED
@@ -1,55 +1,55 @@
- # --------------------------------------------------------
- # Octree-based Sparse Convolutional Neural Networks
- # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
- # Licensed under The MIT License [see LICENSE for details]
- # Written by Peng-Shuai Wang
- # --------------------------------------------------------
-
- import torch
- from typing import Optional
-
- from ocnn.octree import Octree
-
-
- class OctreeDropPath(torch.nn.Module):
-   r'''Drop paths (Stochastic Depth) per sample when applied in the main path of
-   residual blocks, following the logic of :func:`timm.models.layers.DropPath`.
-
-   Args:
-     drop_prob (float): The probability of drop paths.
-     nempty (bool): Indicate whether the input data only contains features of the
-         non-empty octree nodes or not.
-     scale_by_keep (bool): Whether to scale the kept features proportionally.
-   '''
-
-   def __init__(self, drop_prob: float = 0.0, nempty: bool = False,
-                scale_by_keep: bool = True):
-     super().__init__()
-
-     self.drop_prob = drop_prob
-     self.nempty = nempty
-     self.scale_by_keep = scale_by_keep
-
-   def forward(self, data: torch.Tensor, octree: Octree, depth: int,
-               batch_id: Optional[torch.Tensor] = None):
-     r''''''
-
-     if self.drop_prob <= 0.0 or not self.training:
-       return data
-
-     batch_size = octree.batch_size
-     keep_prob = 1 - self.drop_prob
-     rnd_tensor = torch.rand(batch_size, 1, dtype=data.dtype, device=data.device)
-     rnd_tensor = torch.floor(rnd_tensor + keep_prob)
-     if keep_prob > 0.0 and self.scale_by_keep:
-       rnd_tensor.div_(keep_prob)
-
-     if batch_id is None:
-       batch_id = octree.batch_id(depth, self.nempty)
-     drop_mask = rnd_tensor[batch_id]
-     output = data * drop_mask
-     return output
-
-   def extra_repr(self) -> str:
-     return ('drop_prob={:.4f}, nempty={}, scale_by_keep={}').format(
-         self.drop_prob, self.nempty, self.scale_by_keep)  # noqa
+ # --------------------------------------------------------
+ # Octree-based Sparse Convolutional Neural Networks
+ # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
+ # Licensed under The MIT License [see LICENSE for details]
+ # Written by Peng-Shuai Wang
+ # --------------------------------------------------------
+
+ import torch
+ from typing import Optional
+
+ from ocnn.octree import Octree
+
+
+ class OctreeDropPath(torch.nn.Module):
+   r'''Drop paths (Stochastic Depth) per sample when applied in the main path of
+   residual blocks, following the logic of :func:`timm.models.layers.DropPath`.
+
+   Args:
+     drop_prob (float): The probability of drop paths.
+     nempty (bool): Indicate whether the input data only contains features of the
+         non-empty octree nodes or not.
+     scale_by_keep (bool): Whether to scale the kept features proportionally.
+   '''
+
+   def __init__(self, drop_prob: float = 0.0, nempty: bool = False,
+                scale_by_keep: bool = True):
+     super().__init__()
+
+     self.drop_prob = drop_prob
+     self.nempty = nempty
+     self.scale_by_keep = scale_by_keep
+
+   def forward(self, data: torch.Tensor, octree: Octree, depth: int,
+               batch_id: Optional[torch.Tensor] = None):
+     r''''''
+
+     if self.drop_prob <= 0.0 or not self.training:
+       return data
+
+     batch_size = octree.batch_size
+     keep_prob = 1 - self.drop_prob
+     rnd_tensor = torch.rand(batch_size, 1, dtype=data.dtype, device=data.device)
+     rnd_tensor = torch.floor(rnd_tensor + keep_prob)
+     if keep_prob > 0.0 and self.scale_by_keep:
+       rnd_tensor.div_(keep_prob)
+
+     if batch_id is None:
+       batch_id = octree.batch_id(depth, self.nempty)
+     drop_mask = rnd_tensor[batch_id]
+     output = data * drop_mask
+     return output
+
+   def extra_repr(self) -> str:
+     return ('drop_prob={:.4f}, nempty={}, scale_by_keep={}').format(
+         self.drop_prob, self.nempty, self.scale_by_keep)  # noqa
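
For reference, OctreeDropPath keeps each sample with probability keep_prob = 1 - drop_prob and, when scale_by_keep is set, rescales kept features by 1/keep_prob so the output matches the input in expectation. A self-contained sketch of that per-sample masking (illustrative tensors, not from the package):

import torch

drop_prob, batch_size = 0.25, 4
keep_prob = 1 - drop_prob

# One Bernoulli(keep_prob) draw per sample: floor(U + keep_prob) is 1 with
# probability keep_prob and 0 otherwise.
rnd_tensor = torch.floor(torch.rand(batch_size, 1) + keep_prob)
rnd_tensor = rnd_tensor / keep_prob  # scale_by_keep: preserve the expectation

# batch_id maps each octree node at the current depth to its sample index,
# so indexing broadcasts the per-sample mask over all nodes of a sample.
batch_id = torch.tensor([0, 0, 1, 2, 2, 3])  # hypothetical node-to-sample map
data = torch.ones(6, 8)                      # 6 nodes, 8 feature channels
output = data * rnd_tensor[batch_id]         # rows are all 0 or all 1/keep_prob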