ocnn 2.2.7__py3-none-any.whl → 2.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,150 @@
+ import math
+ import torch
+ import triton
+ import triton.language as tl
+ from .utils import get_num_sm
+ from .autotuner import triton_autotune, autotune
+ from . import config
+ from .conv_fwd_implicit_gemm import conv_fwd_implicit_gemm_kernel
+
+
+ @triton_autotune(
+     configs=config.autotune_config,
+     key=['LOGN', 'Ci', 'Co', 'V', 'SPLITK', 'allow_tf32'],
+ )
+ @triton.jit
+ def conv_fwd_implicit_gemm_splitk_kernel(
+     input,
+     weight,
+     bias,
+     neighbor,
+     output,
+     # Tensor dimensions
+     N, LOGN, Ci, Co, V: tl.constexpr,
+     # Meta-parameters
+     B1: tl.constexpr,          # Block size for the N dimension
+     B2: tl.constexpr,          # Block size for the Co dimension
+     BK: tl.constexpr,          # Block size for the K dimension (V * Ci)
+     SPLITK: tl.constexpr,      # Number of splits along the K dimension
+     allow_tf32: tl.constexpr,  # Allow TF32 precision for matmuls
+ ):
+   """
+   Sparse submanifold convolution forward kernel using implicit GEMM with a split K dimension.
+
+   Args:
+     input (pointer): A pointer to the input tensor of shape (N, Ci)
+     weight (pointer): A pointer to the weight tensor of shape (Co, V, Ci)
+     bias (pointer): A pointer to the bias tensor of shape (Co,)
+     neighbor (pointer): A pointer to the neighbor tensor of shape (N, V)
+     output (pointer): A pointer to the output tensor of shape (SPLITK, N, Co)
+   """
+   block_id_k = tl.program_id(axis=1)  # Split-K dimension
+   block_id = tl.program_id(axis=0)
+   block_dim_co = tl.cdiv(Co, B2)
+   block_id_co = block_id % block_dim_co
+   block_id_n = block_id // block_dim_co
+
+   # Create pointers for submatrices of A and B.
+   num_k = tl.cdiv(Ci, BK)  # Number of blocks in the K dimension
+   k_start = tl.cdiv(num_k * V * block_id_k, SPLITK)
+   k_end = tl.cdiv(num_k * V * (block_id_k + 1), SPLITK)
+   offset_n = (block_id_n * B1 + tl.arange(0, B1)) % N     # (B1,)
+   offset_co = (block_id_co * B2 + tl.arange(0, B2)) % Co  # (B2,)
+   offset_k = tl.arange(0, BK)                             # (BK,)
+
+   # Create a block of the output matrix C.
+   accumulator = tl.zeros((B1, B2), dtype=tl.float32)
+   curr_v = k_start // num_k
+   curr_bk = k_start % num_k
+   weight_offset_base = curr_v * Ci + curr_bk * BK
+
+   weight_ptr = weight + weight_offset_base + (offset_co[None, :] * V * Ci + offset_k[:, None])  # (BK, B2)
+
+   # Iterate along the V*Ci dimension.
+   for k in range(k_start, k_end):
+     v = k // num_k
+     bk = k % num_k
+     # Calculate pointers to the input matrix.
+     neighbor_offset_n = tl.load(neighbor + offset_n * V + v).to(tl.int64)  # (B1,)
+     input_ptr = input + bk * BK + (neighbor_offset_n[:, None].to(tl.int64) * Ci + offset_k[None, :])  # (B1, BK)
+     # Load the next block of input and weight.
+     neigh_mask = neighbor_offset_n != -1
+     k_mask = offset_k < Ci - bk * BK
+     input_block = tl.load(input_ptr, mask=neigh_mask[:, None] & k_mask[None, :], other=0.0)
+     weight_block = tl.load(weight_ptr, mask=k_mask[:, None], other=0.0)
+     # Accumulate along the K dimension.
+     accumulator = tl.dot(input_block, weight_block, accumulator,
+                          input_precision='tf32' if allow_tf32 else 'ieee')  # (B1, B2)
+     # Advance the pointers to the next Ci block.
+     weight_ptr += min(BK, Ci - bk * BK)
+
+   # Add the bias (only once, in the first split).
+   if bias is not None and block_id_k == 0:
+     bias_block = tl.load(bias + offset_co)
+     accumulator += bias_block[None, :]
+
+   # Write back the block of the output matrix with masks.
+   out_offset_n = block_id_n * B1 + tl.arange(0, B1)
+   out_offset_co = block_id_co * B2 + tl.arange(0, B2)
+   out_ptr = output + block_id_k * N * Co + (out_offset_n[:, None] * Co + out_offset_co[None, :])
+   out_mask = (out_offset_n[:, None] < N) & (out_offset_co[None, :] < Co)
+   tl.store(out_ptr, accumulator, mask=out_mask)
+
+
+ def conv_fwd_implicit_gemm_splitk_configs(input, weight, bias, neighbor):
+   N, Co = neighbor.shape[0], weight.shape[0]
+   MAX_NB1 = (N + 128 - 1) // 128
+   MAX_NB2 = (Co + 128 - 1) // 128
+   NUM_BLOCKS = MAX_NB1 * MAX_NB2
+   MIN_NUM_BLOCKS = get_num_sm()
+   MAX_NUM_BLOCKS = 32 * get_num_sm()
+   MIN_NUM_BLOCKS_LOG2 = max(0, int(math.log2(MIN_NUM_BLOCKS / NUM_BLOCKS)))
+   MAX_NUM_BLOCKS_LOG2 = max(1, int(math.log2(MAX_NUM_BLOCKS / NUM_BLOCKS) + 1))
+   configs = []
+   for i in range(MIN_NUM_BLOCKS_LOG2, MAX_NUM_BLOCKS_LOG2):
+     configs.append({'SPLITK': 2 ** i})
+   return configs
+
+
+ def conv_fwd_implicit_gemm_splitk_keys(input, weight, bias, neighbor):
+   N, Ci, Co, V = neighbor.shape[0], input.shape[1], weight.shape[0], weight.shape[1]
+   return f'(2^{int(math.log2(N))}, {Ci}, {Co}, {V})'
+
+
+ @autotune(
+     config_fn=conv_fwd_implicit_gemm_splitk_configs,
+     key_fn=conv_fwd_implicit_gemm_splitk_keys,
+ )
+ def conv_fwd_implicit_gemm_splitk(
+     input: torch.Tensor,
+     weight: torch.Tensor,
+     bias: torch.Tensor,
+     neighbor: torch.Tensor,
+     SPLITK: int = 1,
+ ) -> torch.Tensor:
+   assert input.shape[1] == weight.shape[2], "Incompatible dimensions"
+   assert input.is_contiguous(), "Matrix input must be contiguous"
+   assert weight.is_contiguous(), "Matrix weight must be contiguous"
+   assert neighbor.is_contiguous(), "Matrix neighbor must be contiguous"
+   N, Ci, Co, V = neighbor.shape[0], input.shape[1], weight.shape[0], weight.shape[1]
+   LOGN = int(math.log2(N))
+   # Launch the kernel.
+   if SPLITK == 1:
+     output = torch.empty((N, Co), device=input.device, dtype=input.dtype)
+     grid = lambda META: (triton.cdiv(Co, META['B2']) * triton.cdiv(N, META['B1']),)
+     conv_fwd_implicit_gemm_kernel[grid](
+         input, weight, bias, neighbor, output,
+         N, LOGN, Ci, Co, V,
+         allow_tf32=config.allow_tf32,
+     )
+     return output
+   else:
+     output = torch.empty((SPLITK, N, Co), device=input.device, dtype=torch.float32)
+     grid = lambda META: (triton.cdiv(Co, META['B2']) * triton.cdiv(N, META['B1']), SPLITK)
+     conv_fwd_implicit_gemm_splitk_kernel[grid](
+         input, weight, bias, neighbor, output,
+         N, LOGN, Ci, Co, V,
+         SPLITK=SPLITK,
+         allow_tf32=config.allow_tf32,
+     )
+     return output.sum(dim=0).to(input.dtype)
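
In the kernel above, each program along the second grid axis accumulates a partial sum over its slice of the V*Ci reduction dimension and writes it to its own (N, Co) slab of the (SPLITK, N, Co) float32 buffer; the wrapper then reduces the slabs with `output.sum(dim=0)`. As a point of reference for what the implicit GEMM computes, the following pure-PyTorch sketch reproduces the same gather-multiply-accumulate. It is not part of the package: the function name and the zero-row trick for handling `-1` neighbor entries are illustrative assumptions chosen to mirror the kernel's `neighbor_offset_n != -1` masking.

import torch

def conv_fwd_reference(input, weight, bias, neighbor):
  # input: (N, Ci), weight: (Co, V, Ci), neighbor: (N, V) with -1 marking a missing neighbor.
  N, Ci = input.shape
  padded = torch.cat([input, input.new_zeros(1, Ci)], dim=0)  # append an all-zero row
  idx = neighbor.clone()
  idx[idx < 0] = N                                            # route -1 entries to the zero row
  gathered = padded[idx]                                      # (N, V, Ci)
  out = torch.einsum('nvc,ovc->no', gathered, weight)         # (N, Co)
  if bias is not None:
    out = out + bias
  return out

The Triton kernel reaches the same result without materializing the (N, V, Ci) gather: it masks the loads of missing neighbors to zero and feeds the blocks straight into tl.dot.
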
@@ -0,0 +1,44 @@
+ from typing import *
+ import torch
+ import triton
+
+
+ def get_gpu_name():
+   return torch.cuda.get_device_name()
+
+
+ def get_platform_name():
+   if torch.cuda.is_available():
+     if getattr(torch.version, 'hip', None) is not None:
+       return 'hip'
+     return 'cuda'
+   return 'unknown'
+
+
+ def get_num_sm():
+   return torch.cuda.get_device_properties("cuda").multi_processor_count
+
+
+ def get_autotune_config(
+     default: List[triton.Config] = None,
+     platform: Dict[str, List[triton.Config]] = None,
+     device: Dict[str, List[triton.Config]] = None,
+ ) -> List[triton.Config]:
+   """
+   Get the autotune configuration for the current platform and device.
+   """
+   if device is not None:
+     gpu_name = get_gpu_name()
+     for key, value in device.items():
+       if key.lower() in gpu_name.lower():
+         return value
+
+   if platform is not None:
+     platform_name = get_platform_name()
+     for key, value in platform.items():
+       if key.lower() in platform_name.lower():
+         return value
+
+   if default is None:
+     raise ValueError("No autotune configuration found for the current platform and device.")
+   return default
ocnn/nn/octree_conv.py CHANGED
@@ -98,7 +98,8 @@ class OctreeConvBase:
 
     # Check the shape of input data
     check = tuple(data.shape) == self.in_shape
-    assert check, 'The shape of input data is wrong.'
+    assert check, ('The shape of input data is wrong: ' +
+                   'expected {}, got {}.'.format(self.in_shape, data.shape))
 
     # Init the output data
     out = data.new_zeros(self.out_shape)
@@ -0,0 +1,148 @@
+ # --------------------------------------------------------
+ # Octree-based Sparse Convolutional Neural Networks
+ # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
+ # Licensed under The MIT License [see LICENSE for details]
+ # Written by Peng-Shuai Wang
+ # --------------------------------------------------------
+
+ import torch
+ import torch.nn
+ from torch.autograd import Function
+ from typing import List
+
+ from ocnn.octree import Octree
+ from ocnn.nn import OctreeConv
+ from ocnn.utils import xavier_uniform_, resize_with_last_val, list2str
+ from ocnn.nn.kernels import conv_fwd_implicit_gemm_splitk, conv_bwd_implicit_gemm_splitk
+
+
+ class OctreeConvTritonFunction(Function):
+   r''' Wraps the octree convolution for auto-diff.
+   '''
+
+   @staticmethod
+   def forward(ctx, data: torch.Tensor, weights: torch.Tensor, bias: torch.Tensor,
+               neigh: torch.Tensor):
+     data = data.contiguous()
+     weights = weights.contiguous()
+     neigh = neigh.contiguous()
+     if bias is not None:
+       bias = bias.contiguous()
+
+     out = conv_fwd_implicit_gemm_splitk(data, weights, bias, neigh)
+     ctx.save_for_backward(data, weights, bias, neigh)
+     return out
+
+   @staticmethod
+   def backward(ctx, grad):
+     data, weights, bias, neigh = ctx.saved_tensors
+     grad = grad.contiguous()
+     grad_input, grad_weight, grad_bias = conv_bwd_implicit_gemm_splitk(
+         grad, data, weights, bias, neigh, ctx.needs_input_grad)
+     return grad_input, grad_weight, grad_bias, None
+
+
+ # alias
+ octree_conv_triton = OctreeConvTritonFunction.apply
+
+
+ class OctreeConvTriton(torch.nn.Module):
+   r''' Performs octree convolution.
+
+   Args:
+     in_channels (int): Number of input channels.
+     out_channels (int): Number of output channels.
+     kernel_size (List(int)): The kernel shape; only :obj:`[3]` and :obj:`[3,3,3]`
+         are supported by the triton implementation for now.
+     stride (int): The stride of the convolution; only :obj:`1` is supported now.
+     nempty (bool): If True, only performs the convolution on non-empty octree
+         nodes; otherwise, performs the convolution on all octree nodes.
+     use_bias (bool): If True, adds a bias term to the convolution.
+
+   .. note::
+     Each non-empty octree node has exactly 8 child nodes, among which some
+     are non-empty and some are empty. If :attr:`nempty` is true,
+     the convolution is performed on non-empty octree nodes only, which is exactly
+     the same as SparseConvNet and MinkowskiNet; if :attr:`nempty` is false, the
+     convolution is performed on all octree nodes, which is essential for shape
+     reconstruction tasks and can also be used in classification and segmentation
+     (with slightly better performance and a larger memory cost).
+   '''
+
+   def __init__(self, in_channels: int, out_channels: int,
+                kernel_size: List[int] = [3], stride: int = 1,
+                nempty: bool = False, direct_method: bool = False,
+                use_bias: bool = False, max_buffer: int = int(2e8)):
+     super().__init__()
+     self.in_channels = in_channels
+     self.out_channels = out_channels
+     self.kernel_size = resize_with_last_val(kernel_size)
+     self.kernel = list2str(self.kernel_size)
+     self.stride = stride
+     self.nempty = nempty
+     self.use_bias = use_bias
+     assert self.stride == 1, 'Only stride=1 is supported now.'
+     assert self.kernel == '333', 'Only kernel_size=[3,3,3] is supported now.'
+
+     self.kdim = self.kernel_size[0] * self.kernel_size[1] * self.kernel_size[2]
+     self.weights_shape = (self.kdim, self.in_channels, self.out_channels)
+     self.weights = torch.nn.Parameter(torch.Tensor(*self.weights_shape))
+     self.bias = (torch.nn.Parameter(torch.Tensor(self.out_channels))
+                  if use_bias else None)
+     self.reset_parameters()
+
+   def reset_parameters(self):
+     xavier_uniform_(self.weights)
+     if self.use_bias:
+       torch.nn.init.zeros_(self.bias)
+
+   def forward(self, data: torch.Tensor, octree: Octree, depth: int):
+     r''' Defines the octree convolution.
+
+     Args:
+       data (torch.Tensor): The input data.
+       octree (Octree): The corresponding octree.
+       depth (int): The depth of the current octree.
+     '''
+
+     # TODO: remove the permute operation by changing the kernel implementation
+     weight = self.weights.permute(2, 0, 1)  # (V, Ci, Co) -> (Co, V, Ci)
+     neigh = octree.get_neigh(depth, self.kernel, self.stride, self.nempty)
+     out = octree_conv_triton(data, weight, self.bias, neigh)
+     return out
+
+   def extra_repr(self) -> str:
+     r''' Sets the extra representation of the module.
+     '''
+
+     return ('triton, in_channels={}, out_channels={}, kernel_size={}, stride={}, '
+             'nempty={}, bias={}').format(self.in_channels, self.out_channels,
+             self.kernel_size, self.stride, self.nempty, self.use_bias)  # noqa
+
+
+ # alias
+ OctreeConvT = OctreeConvTriton
+
+
+ def convert_conv_triton(module: torch.nn.Module) -> torch.nn.Module:
+   r''' Converts OctreeConv modules to OctreeConvTriton modules in a network.
+
+   Args:
+     module (torch.nn.Module): The input module.
+   '''
+
+   module_out = module
+   if (isinstance(module, OctreeConv) and
+       module.stride == 1 and module.kernel_size == [3, 3, 3]):
+     module_out = OctreeConvTriton(
+         module.in_channels, module.out_channels, module.kernel_size,
+         module.stride, module.nempty, use_bias=module.use_bias)
+     with torch.no_grad():
+       module_out.weights = module.weights
+       if module.use_bias:
+         module_out.bias = module.bias
+
+   for name, child in module.named_children():
+     module_out.add_module(name, convert_conv_triton(child))
+   del module
+   return module_out
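
A hedged sketch of how convert_conv_triton might be applied to an existing ocnn network; the import path of the new module and the model constructor below are illustrative assumptions, not taken from the package.

import torch
import ocnn
from ocnn.nn.octree_conv_triton import convert_conv_triton  # import path assumed

model = ocnn.models.LeNet(in_channels=3, out_channels=40, stages=3)  # constructor assumed
model = convert_conv_triton(model)  # replaces eligible stride-1, 3x3x3 OctreeConv layers
model = model.cuda()                # the Triton kernels require a CUDA/HIP device
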
ocnn/nn/octree_pad.py CHANGED
@@ -22,10 +22,10 @@ def octree_pad(data: torch.Tensor, octree: Octree, depth: int, val: float = 0.0)
     val (float): The padding value. (Default: :obj:`0.0`)
   '''
 
-  mask = octree.nempty_mask(depth)
+  idx = octree.nempty_index(depth)
   size = (octree.nnum[depth], data.shape[1])  # (N, C)
   out = torch.full(size, val, dtype=data.dtype, device=data.device)
-  out[mask] = data
+  out[idx] = data
   return out
 
 
@@ -35,5 +35,5 @@ def octree_depad(data: torch.Tensor, octree: Octree, depth: int):
   Please refer to :func:`octree_depad` for the meaning of the arguments.
   '''
 
-  mask = octree.nempty_mask(depth)
-  return data[mask]
+  idx = octree.nempty_index(depth)
+  return data[idx]
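
The change above swaps a boolean mask for an integer index when scattering into (octree_pad) and selecting from (octree_depad) the padded tensor. Assuming `nempty_index` returns the positions of the True entries of the former `nempty_mask`, both forms address the same rows; indexing with precomputed integer indices also avoids the implicit `nonzero` that boolean-mask indexing performs, which is presumably the motivation. A small standalone check of the equivalence (not from the package):

import torch

mask = torch.tensor([True, False, True, True])
idx = mask.nonzero(as_tuple=True)[0]          # tensor([0, 2, 3])
data = torch.arange(8, dtype=torch.float32).reshape(4, 2)

assert torch.equal(data[mask], data[idx])     # depad: both select the non-empty rows

out_mask = torch.zeros(4, 2)
out_idx = torch.zeros(4, 2)
out_mask[mask] = data[mask]
out_idx[idx] = data[idx]
assert torch.equal(out_mask, out_idx)         # pad: both scatter into the same rows
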