ocnn 2.2.7__py3-none-any.whl → 2.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,229 @@
+ from typing import *
+ import math
+ import torch
+ import triton
+ import triton.language as tl
+ from .autotuner import triton_autotune
+ from . import config
+
+
+ @triton_autotune(
+     configs=config.autotune_config,
+     key=['LOGN', 'Ci', 'Co', 'V', 'allow_tf32'],
+ )
+ @triton.jit
+ def conv_bwd_input_implicit_gemm_kernel(
+     grad_output,
+     weight,
+     neighbor,
+     grad_input,
+     # Tensor dimensions
+     N, LOGN, Ci, Co, V: tl.constexpr,
+     # Meta-parameters
+     B1: tl.constexpr,  # Block size for N dimension
+     B2: tl.constexpr,  # Block size for Ci dimension
+     BK: tl.constexpr,  # Block size for K dimension (V * Co)
+     allow_tf32: tl.constexpr,  # Allow TF32 precision for matmuls
+ ):
+     """
+     Sparse submanifold convolution backward-to-input kernel using implicit GEMM.
+
+     Args:
+         grad_output (pointer): A pointer to the gradient of the output tensor of shape (N, Co)
+         weight (pointer): A pointer to the weight tensor of shape (Co, V, Ci)
+         neighbor (pointer): A pointer to the neighbor tensor of shape (N, V)
+         grad_input (pointer): A pointer to the gradient of the input tensor of shape (N, Ci)
+     """
+     block_id = tl.program_id(axis=0)
+     block_dim_ci = tl.cdiv(Ci, B2)
+     block_id_ci = block_id % block_dim_ci
+     block_id_n = block_id // block_dim_ci
+
+     # Create pointers for submatrices of A and B.
+     num_k = tl.cdiv(Co, BK)  # Number of blocks in K dimension
+     offset_n = (block_id_n * B1 + tl.arange(0, B1)) % N  # (B1,)
+     offset_ci = (block_id_ci * B2 + tl.arange(0, B2)) % Ci  # (B2,)
+     offset_k = tl.arange(0, BK)  # (BK,)
+
+     # Create a block of the output matrix C.
+     accumulator = tl.zeros((B1, B2), dtype=tl.float32)
+
+     # Iterate along the V*Co dimension.
+     for k in range(num_k * V):
+         v = k // num_k
+         bk = k % num_k
+         # Calculate pointers to the grad_output matrix (mirrored neighbor index).
+         neighbor_offset_n = tl.load(neighbor + offset_n * V + V - 1 - v)  # (B1,)
+         grad_output_ptr = grad_output + bk * BK + (neighbor_offset_n[:, None].to(tl.int64) * Co + offset_k[None, :])  # (B1, BK)
+         # Calculate pointers to the weight matrix.
+         weight_ptr = weight + (((offset_k[:, None] + bk * BK) * V + v) * Ci + offset_ci[None, :])  # (BK, B2)
+         # Load the next blocks of grad_output and weight.
+         neigh_mask = neighbor_offset_n != -1
+         k_mask = offset_k < Co - bk * BK
+         grad_output_block = tl.load(grad_output_ptr, mask=neigh_mask[:, None] & k_mask[None, :], other=0.0)
+         weight_block = tl.load(weight_ptr, mask=k_mask[:, None], other=0.0)
+         # Accumulate along the K dimension.
+         accumulator = tl.dot(grad_output_block, weight_block, accumulator,
+                              input_precision='tf32' if allow_tf32 else 'ieee')  # (B1, B2)
+     c = accumulator.to(grad_output.type.element_ty)
+
+     # Write back the block of the output matrix with masks.
+     grad_input_offset_n = block_id_n * B1 + tl.arange(0, B1)
+     grad_input_offset_ci = block_id_ci * B2 + tl.arange(0, B2)
+     grad_input_ptr = grad_input + (grad_input_offset_n[:, None] * Ci + grad_input_offset_ci[None, :])
+     grad_input_mask = (grad_input_offset_n[:, None] < N) & (grad_input_offset_ci[None, :] < Ci)
+     tl.store(grad_input_ptr, c, mask=grad_input_mask)
+
+
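For reference, this kernel computes, for every point `n` and input channel `ci`, the sum over taps `v` and output channels `co` of `grad_output[neighbor[n, V-1-v], co] * weight[co, v, ci]`, skipping taps where `neighbor` is -1. A minimal dense PyTorch sketch of the same contraction (illustrative only, not part of the package):

```python
import torch

def conv_bwd_input_reference(grad_output, weight, neighbor):
    # grad_output: (N, Co), weight: (Co, V, Ci), neighbor: (N, V), -1 marks a missing tap.
    N, V = neighbor.shape
    grad_input = torch.zeros(N, weight.shape[2], dtype=torch.float32, device=grad_output.device)
    for v in range(V):
        idx = neighbor[:, V - 1 - v]                 # mirrored tap, as in the kernel
        valid = (idx != -1)[:, None]
        go = grad_output[idx.clamp(min=0)].float() * valid
        grad_input += go @ weight[:, v, :].float()   # (N, Co) @ (Co, Ci)
    return grad_input.to(grad_output.dtype)
```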
+ heuristics = {
+     # BCi must be a power of 2 for tl.dot, but should not exceed Ci or B2
+     'BCi': lambda meta: min(triton.next_power_of_2(meta['Ci']), meta['B2']),
+     # BV = B2 // BCi: one block covers BV taps of BCi channels each
+     'BV': lambda meta: max(1, meta['B2'] // min(triton.next_power_of_2(meta['Ci']), meta['B2'])),
+ }
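A quick sanity check of these heuristics in plain Python (values on the right are what the lambdas produce; `B2 = 64` is just an example tile size):

```python
import triton

def bci_bv(Ci, B2):
    BCi = min(triton.next_power_of_2(Ci), B2)
    BV = max(1, B2 // BCi)
    return BCi, BV

print(bci_bv(3, 64))    # (4, 16): 16 taps x 4 channels per block
print(bci_bv(96, 64))   # (64, 1): 1 tap x 64 channels per block
```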
+
+
+ @triton_autotune(
+     configs=config.autotune_config,
+     key=['LOGN', 'Ci', 'Co', 'V', 'allow_tf32'],
+ )
+ @triton.heuristics(heuristics)
+ @triton.jit
+ def conv_bwd_weight_implicit_gemm_kernel(
+     grad_output,
+     input,
+     neighbor,
+     grad_weight,
+     # Tensor dimensions
+     N, LOGN, Ci, Co, V: tl.constexpr,
+     # Meta-parameters
+     B1: tl.constexpr,  # Block size for Co dimension
+     B2: tl.constexpr,  # Block size for V * Ci dimension
+     BK: tl.constexpr,  # Block size for K dimension (N)
+     BV: tl.constexpr,  # Block size for V dimension
+     BCi: tl.constexpr,  # Block size for Ci dimension
+     allow_tf32: tl.constexpr,  # Allow TF32 precision for matmuls
+ ):
+     """
+     Sparse submanifold convolution backward-to-weight kernel using implicit GEMM.
+
+     Args:
+         grad_output (pointer): A pointer to the gradient of the output tensor of shape (N, Co)
+         input (pointer): A pointer to the input tensor of shape (N, Ci)
+         neighbor (pointer): A pointer to the neighbor tensor of shape (N, V)
+         grad_weight (pointer): A pointer to the gradient of the weight tensor of shape (Co, V, Ci)
+     """
+     block_id_co = tl.program_id(axis=0)
+     block_id_vci = tl.program_id(axis=1)
+
+     # Create pointers for submatrices of A and B.
+     num_k = tl.cdiv(N, BK)  # Number of blocks in K dimension
+     # Use cdiv to handle non-power-of-2 Ci correctly
+     num_ci_blocks = tl.cdiv(Ci, BCi)
+     offset_co = (block_id_co * B1 + tl.arange(0, B1)) % Co  # (B1,)
+     offset_v = (tl.arange(0, BV) + (block_id_vci // num_ci_blocks) * BV) % V  # (BV,)
+     offset_ci = (tl.arange(0, BCi) + (block_id_vci % num_ci_blocks) * BCi) % Ci  # (BCi,)
+     offset_k = tl.arange(0, BK)  # (BK,)
+     neighbor_ptr = neighbor + (offset_k[:, None] * V + offset_v[None, :])  # (BK, BV)
+     grad_output_ptr = grad_output + (offset_k[None, :] * Co + offset_co[:, None])  # (B1, BK)
+
+     # Create a block of the output matrix C.
+     accumulator = tl.zeros((B1, BV * BCi), dtype=tl.float32)
+
+     # Iterate along the N (K) dimension.
+     for k in range(num_k):
+         mask = offset_k < N - k * BK
+         # Calculate pointers to the input matrix.
+         input_offset_n = tl.load(neighbor_ptr, mask=mask[:, None], other=-1)  # (BK, BV)
+         input_ptr = input + (input_offset_n[:, :, None].to(tl.int64) * Ci + offset_ci[None, None, :])  # (BK, BV, BCi)
+         # Load the next blocks of grad_output and input.
+         grad_output_block = tl.load(grad_output_ptr, mask=mask[None, :], other=0.0)
+         input_block = tl.load(input_ptr, mask=input_offset_n[:, :, None] != -1, other=0.0).reshape(BK, BV * BCi)
+         # Accumulate along the K dimension.
+         accumulator = tl.dot(grad_output_block, input_block, accumulator,
+                              input_precision='tf32' if allow_tf32 else 'ieee')  # (B1, BV*BCi)
+         # Advance pointers.
+         grad_output_ptr += BK * Co
+         neighbor_ptr += BK * V
+     c = accumulator.to(grad_output.type.element_ty)
+
+     # Write back the block of the output matrix with masks.
+     # Decompose block_id_vci into block_id_v and block_id_ci
+     block_id_v = block_id_vci // num_ci_blocks
+     block_id_ci = block_id_vci % num_ci_blocks
+
+     grad_weight_offset_co = block_id_co * B1 + tl.arange(0, B1)
+
+     # Compute V*Ci linear indices, accounting for the (V, Ci) layout
+     local_v = tl.arange(0, BV)
+     local_ci = tl.arange(0, BCi)
+     global_v = block_id_v * BV + local_v  # (BV,)
+     global_ci = block_id_ci * BCi + local_ci  # (BCi,)
+
+     # Linear index in V*Ci space: v * Ci + ci
+     grad_weight_offset_vci = (global_v[:, None] * Ci + global_ci[None, :]).reshape(BV * BCi)  # (BV*BCi,)
+
+     grad_weight_ptr = grad_weight + (grad_weight_offset_co[:, None] * V * Ci + grad_weight_offset_vci[None, :])
+
+     # Mask out-of-range V and Ci positions
+     v_mask = (global_v < V)[:, None]  # (BV, 1)
+     ci_mask = (global_ci < Ci)[None, :]  # (1, BCi)
+     vci_mask = (v_mask & ci_mask).reshape(BV * BCi)  # (BV*BCi,)
+     grad_weight_mask = (grad_weight_offset_co[:, None] < Co) & vci_mask[None, :]
+     tl.store(grad_weight_ptr, c, mask=grad_weight_mask)
+
+
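In dense terms this kernel computes `grad_weight[co, v, ci] = Σ_n grad_output[n, co] * input[neighbor[n, v], ci]`, with missing taps (-1) contributing zero. An equivalent PyTorch sketch (illustrative only):

```python
import torch

def conv_bwd_weight_reference(grad_output, input, neighbor):
    # grad_output: (N, Co), input: (N, Ci), neighbor: (N, V), -1 marks a missing tap.
    V = neighbor.shape[1]
    Co, Ci = grad_output.shape[1], input.shape[1]
    grad_weight = torch.zeros(Co, V, Ci, dtype=torch.float32, device=input.device)
    for v in range(V):
        idx = neighbor[:, v]
        valid = (idx != -1)[:, None]
        x = input[idx.clamp(min=0)].float() * valid       # gathered inputs, zeroed where missing
        grad_weight[:, v, :] = grad_output.float().T @ x  # (Co, N) @ (N, Ci)
    return grad_weight.to(grad_output.dtype)
```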
+ def conv_bwd_implicit_gemm(
+     grad_output: torch.Tensor,
+     input: torch.Tensor,
+     weight: torch.Tensor,
+     bias: torch.Tensor,
+     neighbor: torch.Tensor,
+     needs_input_grad: List[bool],
+ ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
+     assert grad_output.is_contiguous(), "Matrix grad_output must be contiguous"
+     assert input.shape[1] == weight.shape[2], "Incompatible dimensions"
+     assert input.is_contiguous(), "Matrix input must be contiguous"
+     assert weight.is_contiguous(), "Matrix weight must be contiguous"
+     assert neighbor.is_contiguous(), "Matrix neighbor must be contiguous"
+     N, Ci, Co, V = neighbor.shape[0], input.shape[1], weight.shape[0], weight.shape[1]
+     LOGN = int(math.log2(N))
+
+     grad_input, grad_weight, grad_bias = None, None, None
+
+     # Grad for input
+     if needs_input_grad[0]:
+         # Allocate the output matrix.
+         grad_input = torch.empty((N, Ci), device=input.device, dtype=input.dtype)
+         # Launch the kernel.
+         grid = lambda META: (triton.cdiv(Ci, META['B2']) * triton.cdiv(N, META['B1']),)
+         conv_bwd_input_implicit_gemm_kernel[grid](
+             grad_output,
+             weight,
+             neighbor,
+             grad_input,
+             N, LOGN, Ci, Co, V,
+             allow_tf32=config.allow_tf32,
+         )
+
+     # Grad for weight
+     if needs_input_grad[1]:
+         # Allocate the output matrix.
+         grad_weight = torch.empty((Co, V, Ci), device=weight.device, dtype=weight.dtype)
+         # Launch the kernel.
+         # Use cdiv separately for V and Ci to correctly handle non-power-of-2 channels
+         grid = lambda META: (triton.cdiv(Co, META['B1']), triton.cdiv(V, META['BV']) * triton.cdiv(Ci, META['BCi']))
+         conv_bwd_weight_implicit_gemm_kernel[grid](
+             grad_output,
+             input,
+             neighbor,
+             grad_weight,
+             N, LOGN, Ci, Co, V,
+             allow_tf32=config.allow_tf32,
+         )
+
+     # Grad for bias
+     if needs_input_grad[2]:
+         grad_bias = grad_output.sum(0)
+
+     return grad_input, grad_weight, grad_bias
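A hedged sketch of how a driver like this is typically wired into autograd (the Function subclass below is illustrative, not the package's actual binding; `conv_fwd_implicit_gemm` is defined in the forward module further down):

```python
import torch

class SparseConvImplicitGemm(torch.autograd.Function):  # hypothetical wrapper
    @staticmethod
    def forward(ctx, input, weight, bias, neighbor):
        ctx.save_for_backward(input, weight, bias, neighbor)
        return conv_fwd_implicit_gemm(input, weight, bias, neighbor)

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias, neighbor = ctx.saved_tensors
        grad_input, grad_weight, grad_bias = conv_bwd_implicit_gemm(
            grad_output.contiguous(), input, weight, bias, neighbor,
            ctx.needs_input_grad[:3])
        # No gradient for the integer neighbor table.
        return grad_input, grad_weight, grad_bias, None
```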
@@ -0,0 +1,347 @@
+ from typing import *
+ import math
+ import torch
+ import triton
+ import triton.language as tl
+ from .utils import get_num_sm
+ from .autotuner import triton_autotune, autotune
+ from . import config
+ from .conv_bwd_implicit_gemm import (
+     conv_bwd_input_implicit_gemm_kernel,
+     conv_bwd_weight_implicit_gemm_kernel,
+ )
+
+
+ @triton_autotune(
+     configs=config.autotune_config,
+     key=['LOGN', 'Ci', 'Co', 'V', 'SPLITK', 'allow_tf32'],
+ )
+ @triton.jit
+ def conv_bwd_input_implicit_gemm_splitk_kernel(
+     grad_output,
+     weight,
+     neighbor,
+     grad_input,
+     # Tensor dimensions
+     N, LOGN, Ci, Co, V: tl.constexpr,
+     # Meta-parameters
+     B1: tl.constexpr,  # Block size for N dimension
+     B2: tl.constexpr,  # Block size for Ci dimension
+     BK: tl.constexpr,  # Block size for K dimension (V * Co)
+     SPLITK: tl.constexpr,  # Number of splits along the K dimension
+     allow_tf32: tl.constexpr,  # Allow TF32 precision for matmuls
+ ):
+     """
+     Split-K variant of the sparse submanifold convolution backward-to-input kernel
+     using implicit GEMM.
+
+     Args:
+         grad_output (pointer): A pointer to the gradient of the output tensor of shape (N, Co)
+         weight (pointer): A pointer to the weight tensor of shape (Co, V, Ci)
+         neighbor (pointer): A pointer to the neighbor tensor of shape (N, V)
+         grad_input (pointer): A pointer to the float32 partial sums of shape (SPLITK, N, Ci);
+             each program writes its partial result into slice block_id_k
+     """
+     block_id_k = tl.program_id(axis=1)  # SplitK dimension
+     block_id = tl.program_id(axis=0)
+     block_dim_ci = tl.cdiv(Ci, B2)
+     block_id_ci = block_id % block_dim_ci
+     block_id_n = block_id // block_dim_ci
+
+     # Create pointers for submatrices of A and B.
+     num_k = tl.cdiv(Co, BK)  # Number of blocks in K dimension
+     k_start = tl.cdiv(num_k * V * block_id_k, SPLITK)
+     k_end = tl.cdiv(num_k * V * (block_id_k + 1), SPLITK)
+     offset_n = (block_id_n * B1 + tl.arange(0, B1)) % N  # (B1,)
+     offset_ci = (block_id_ci * B2 + tl.arange(0, B2)) % Ci  # (B2,)
+     offset_k = tl.arange(0, BK)  # (BK,)
+
+     # Create a block of the output matrix C.
+     accumulator = tl.zeros((B1, B2), dtype=tl.float32)
+
+     # Iterate along this split's share of the V*Co dimension.
+     for k in range(k_start, k_end):
+         v = k // num_k
+         bk = k % num_k
+         # Calculate pointers to the grad_output matrix (mirrored neighbor index).
+         neighbor_offset_n = tl.load(neighbor + offset_n * V + V - 1 - v)  # (B1,)
+         grad_output_ptr = grad_output + bk * BK + (neighbor_offset_n[:, None].to(tl.int64) * Co + offset_k[None, :])  # (B1, BK)
+         # Calculate pointers to the weight matrix.
+         weight_ptr = weight + (((offset_k[:, None] + bk * BK) * V + v) * Ci + offset_ci[None, :])  # (BK, B2)
+         # Load the next blocks of grad_output and weight.
+         neigh_mask = neighbor_offset_n != -1
+         k_mask = offset_k < Co - bk * BK
+         grad_output_block = tl.load(grad_output_ptr, mask=neigh_mask[:, None] & k_mask[None, :], other=0.0)
+         weight_block = tl.load(weight_ptr, mask=k_mask[:, None], other=0.0)
+         # Accumulate along the K dimension.
+         accumulator = tl.dot(grad_output_block, weight_block, accumulator,
+                              input_precision='tf32' if allow_tf32 else 'ieee')  # (B1, B2)
+
+     # Write back the block of the partial-sum matrix with masks.
+     grad_input_offset_n = block_id_n * B1 + tl.arange(0, B1)
+     grad_input_offset_ci = block_id_ci * B2 + tl.arange(0, B2)
+     grad_input_ptr = grad_input + block_id_k * N * Ci + (grad_input_offset_n[:, None] * Ci + grad_input_offset_ci[None, :])
+     grad_input_mask = (grad_input_offset_n[:, None] < N) & (grad_input_offset_ci[None, :] < Ci)
+     tl.store(grad_input_ptr, accumulator, mask=grad_input_mask)
+
+
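The split is done with the same ceil-division arithmetic as `tl.cdiv`, so the K blocks are carved into SPLITK contiguous, disjoint, near-equal ranges; each range's partial sum lands in one slice of the (SPLITK, N, Ci) buffer and the slices are summed afterwards. A quick check in plain Python:

```python
def splitk_ranges(total_k, SPLITK):
    cdiv = lambda a, b: -(-a // b)   # ceil division, like tl.cdiv
    return [(cdiv(total_k * i, SPLITK), cdiv(total_k * (i + 1), SPLITK))
            for i in range(SPLITK)]

# 10 K-blocks split 4 ways:
print(splitk_ranges(10, 4))   # [(0, 3), (3, 5), (5, 8), (8, 10)]
```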
+ heuristics = {
+     # BCi must be a power of 2 for tl.dot, but should not exceed Ci or B2
+     'BCi': lambda meta: min(triton.next_power_of_2(meta['Ci']), meta['B2']),
+     # BV = B2 // BCi: one block covers BV taps of BCi channels each
+     'BV': lambda meta: max(1, meta['B2'] // min(triton.next_power_of_2(meta['Ci']), meta['B2'])),
+ }
+
+
+ @triton_autotune(
+     configs=config.autotune_config,
+     key=['LOGN', 'Ci', 'Co', 'V', 'SPLITK', 'allow_tf32'],
+ )
+ @triton.heuristics(heuristics)
+ @triton.jit
+ def conv_bwd_weight_implicit_gemm_splitk_kernel(
+     grad_output,
+     input,
+     neighbor,
+     grad_weight,
+     # Tensor dimensions
+     N, LOGN, Ci, Co, V: tl.constexpr,
+     # Meta-parameters
+     B1: tl.constexpr,  # Block size for Co dimension
+     B2: tl.constexpr,  # Block size for V * Ci dimension
+     BK: tl.constexpr,  # Block size for K dimension (N)
+     BV: tl.constexpr,  # Block size for V dimension
+     BCi: tl.constexpr,  # Block size for Ci dimension
+     SPLITK: tl.constexpr,  # Number of splits along the K dimension
+     allow_tf32: tl.constexpr,  # Allow TF32 precision for matmuls
+ ):
+     """
+     Split-K variant of the sparse submanifold convolution backward-to-weight kernel
+     using implicit GEMM.
+
+     Args:
+         grad_output (pointer): A pointer to the gradient of the output tensor of shape (N, Co)
+         input (pointer): A pointer to the input tensor of shape (N, Ci)
+         neighbor (pointer): A pointer to the neighbor tensor of shape (N, V)
+         grad_weight (pointer): A pointer to the float32 partial sums of shape (SPLITK, Co, V, Ci);
+             each program writes its partial result into slice block_id_k
+     """
+     block_id_co = tl.program_id(axis=0)
+     block_id_vci = tl.program_id(axis=1)
+     block_id_k = tl.program_id(axis=2)
+
+     # Create pointers for submatrices of A and B.
+     num_k = tl.cdiv(N, BK)  # Number of blocks in K dimension
+     k_start = tl.cdiv(num_k * block_id_k, SPLITK)
+     k_end = tl.cdiv(num_k * (block_id_k + 1), SPLITK)
+     # Use cdiv to handle non-power-of-2 Ci correctly
+     num_ci_blocks = tl.cdiv(Ci, BCi)
+     offset_co = (block_id_co * B1 + tl.arange(0, B1)) % Co  # (B1,)
+     offset_v = (tl.arange(0, BV) + (block_id_vci // num_ci_blocks) * BV) % V  # (BV,)
+     offset_ci = (tl.arange(0, BCi) + (block_id_vci % num_ci_blocks) * BCi) % Ci  # (BCi,)
+     offset_k = tl.arange(0, BK)  # (BK,)
+     neighbor_ptr = neighbor + k_start * BK * V + (offset_k[:, None] * V + offset_v[None, :])  # (BK, BV)
+     grad_output_ptr = grad_output + k_start * BK * Co + (offset_k[None, :] * Co + offset_co[:, None])  # (B1, BK)
+
+     # Create a block of the output matrix C.
+     accumulator = tl.zeros((B1, BV * BCi), dtype=tl.float32)
+
+     # Iterate along this split's share of the N (K) dimension.
+     for k in range(k_start, k_end):
+         mask = offset_k < N - k * BK
+         # Calculate pointers to the input matrix.
+         input_offset_n = tl.load(neighbor_ptr, mask=mask[:, None], other=-1)  # (BK, BV)
+         input_ptr = input + (input_offset_n[:, :, None].to(tl.int64) * Ci + offset_ci[None, None, :])  # (BK, BV, BCi)
+         # Load the next blocks of grad_output and input.
+         grad_output_block = tl.load(grad_output_ptr, mask=mask[None, :], other=0.0)
+         input_block = tl.load(input_ptr, mask=input_offset_n[:, :, None] != -1, other=0.0).reshape(BK, BV * BCi)
+         # Accumulate along the K dimension.
+         accumulator = tl.dot(grad_output_block, input_block, accumulator,
+                              input_precision='tf32' if allow_tf32 else 'ieee')  # (B1, BV*BCi)
+         # Advance pointers.
+         grad_output_ptr += BK * Co
+         neighbor_ptr += BK * V
+
+     # Write back the block of the partial-sum matrix with masks.
+     # Decompose block_id_vci into block_id_v and block_id_ci
+     block_id_v = block_id_vci // num_ci_blocks
+     block_id_ci = block_id_vci % num_ci_blocks
+
+     grad_weight_offset_co = block_id_co * B1 + tl.arange(0, B1)
+
+     # Compute V*Ci linear indices, accounting for the (V, Ci) layout
+     local_v = tl.arange(0, BV)
+     local_ci = tl.arange(0, BCi)
+     global_v = block_id_v * BV + local_v  # (BV,)
+     global_ci = block_id_ci * BCi + local_ci  # (BCi,)
+
+     # Linear index in V*Ci space: v * Ci + ci
+     grad_weight_offset_vci = (global_v[:, None] * Ci + global_ci[None, :]).reshape(BV * BCi)  # (BV*BCi,)
+
+     grad_weight_ptr = grad_weight + block_id_k * Co * V * Ci + (grad_weight_offset_co[:, None] * V * Ci + grad_weight_offset_vci[None, :])
+
+     # Mask out-of-range V and Ci positions
+     v_mask = (global_v < V)[:, None]  # (BV, 1)
+     ci_mask = (global_ci < Ci)[None, :]  # (1, BCi)
+     vci_mask = (v_mask & ci_mask).reshape(BV * BCi)  # (BV*BCi,)
+     grad_weight_mask = (grad_weight_offset_co[:, None] < Co) & vci_mask[None, :]
+     tl.store(grad_weight_ptr, accumulator, mask=grad_weight_mask)
+
+
+ def conv_bwd_input_implicit_gemm_splitk_configs(grad_output, weight, neighbor):
+     N, Ci = neighbor.shape[0], weight.shape[-1]
+     MAX_NB1 = (N + 128 - 1) // 128
+     MAX_NB2 = (Ci + 128 - 1) // 128
+     NUM_BLOCKS = MAX_NB1 * MAX_NB2
+     MIN_NUM_BLOCKS = get_num_sm()
+     MAX_NUM_BLOCKS = 32 * get_num_sm()
+     MIN_NUM_BLOCKS_LOG2 = max(0, int(math.log2(MIN_NUM_BLOCKS / NUM_BLOCKS)))
+     MAX_NUM_BLOCKS_LOG2 = max(1, int(math.log2(MAX_NUM_BLOCKS / NUM_BLOCKS) + 1))
+     configs = []
+     for i in range(MIN_NUM_BLOCKS_LOG2, MAX_NUM_BLOCKS_LOG2):
+         configs.append({'SPLITK': 2 ** i})
+     return configs
+
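The candidate list scales SPLITK with how far the natural grid falls short of filling the GPU, targeting roughly 1x to 32x the SM count. A standalone rerun of the arithmetic (108 SMs is an assumed device value; `get_num_sm()` queries the real one):

```python
import math

def splitk_candidates(NUM_BLOCKS, num_sm=108):
    lo = max(0, int(math.log2(num_sm / NUM_BLOCKS)))
    hi = max(1, int(math.log2(32 * num_sm / NUM_BLOCKS) + 1))
    return [2 ** i for i in range(lo, hi)]

print(splitk_candidates(8192))   # [1]: a large grid already saturates the GPU
print(splitk_candidates(128))    # [1, 2, 4, 8, 16]: a small grid gets split-K candidates
```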
+
+ def conv_bwd_input_implicit_gemm_splitk_keys(grad_output, weight, neighbor):
+     N, Ci, Co, V = neighbor.shape[0], weight.shape[-1], weight.shape[0], weight.shape[1]
+     return f'(2^{int(math.log2(N))}, {Ci}, {Co}, {V})'
+
+
+ @autotune(
+     config_fn=conv_bwd_input_implicit_gemm_splitk_configs,
+     key_fn=conv_bwd_input_implicit_gemm_splitk_keys,
+ )
+ def conv_bwd_input_implicit_gemm_splitk(
+     grad_output: torch.Tensor,
+     weight: torch.Tensor,
+     neighbor: torch.Tensor,
+     SPLITK: int = 1,
+ ) -> torch.Tensor:
+     N, Ci, Co, V = neighbor.shape[0], weight.shape[-1], weight.shape[0], weight.shape[1]
+     LOGN = int(math.log2(N))
+     # Launch the kernel.
+     if SPLITK == 1:
+         grad_input = torch.empty((N, Ci), device=weight.device, dtype=weight.dtype)
+         grid = lambda META: (triton.cdiv(Ci, META['B2']) * triton.cdiv(N, META['B1']),)
+         conv_bwd_input_implicit_gemm_kernel[grid](
+             grad_output,
+             weight,
+             neighbor,
+             grad_input,
+             N, LOGN, Ci, Co, V,
+             allow_tf32=config.allow_tf32,
+         )
+         return grad_input
+     else:
+         grad_input = torch.empty((SPLITK, N, Ci), device=weight.device, dtype=torch.float32)
+         grid = lambda META: (triton.cdiv(Ci, META['B2']) * triton.cdiv(N, META['B1']), SPLITK)
+         conv_bwd_input_implicit_gemm_splitk_kernel[grid](
+             grad_output,
+             weight,
+             neighbor,
+             grad_input,
+             N, LOGN, Ci, Co, V,
+             SPLITK=SPLITK,
+             allow_tf32=config.allow_tf32,
+         )
+         return grad_input.sum(0).to(weight.dtype)
+
+
+ def conv_bwd_weight_implicit_gemm_splitk_configs(grad_output, input, neighbor):
+     Co, V, Ci = grad_output.shape[1], neighbor.shape[1], input.shape[1]
+     MAX_NB1 = (Co + 128 - 1) // 128
+     MAX_NB2 = (V * Ci + 128 - 1) // 128
+     NUM_BLOCKS = MAX_NB1 * MAX_NB2
+     MIN_NUM_BLOCKS = get_num_sm()
+     MAX_NUM_BLOCKS = 32 * get_num_sm()
+     MIN_NUM_BLOCKS_LOG2 = max(0, int(math.log2(MIN_NUM_BLOCKS / NUM_BLOCKS)))
+     MAX_NUM_BLOCKS_LOG2 = max(1, int(math.log2(MAX_NUM_BLOCKS / NUM_BLOCKS) + 1))
+     configs = []
+     for i in range(MIN_NUM_BLOCKS_LOG2, MAX_NUM_BLOCKS_LOG2):
+         configs.append({'SPLITK': 2 ** i})
+     return configs
+
+
+ def conv_bwd_weight_implicit_gemm_splitk_keys(grad_output, input, neighbor):
+     N, Ci, Co, V = neighbor.shape[0], input.shape[1], grad_output.shape[1], neighbor.shape[1]
+     return f'(2^{int(math.log2(N))}, {Ci}, {Co}, {V})'
+
+
+ @autotune(
+     config_fn=conv_bwd_weight_implicit_gemm_splitk_configs,
+     key_fn=conv_bwd_weight_implicit_gemm_splitk_keys,
+ )
+ def conv_bwd_weight_implicit_gemm_splitk(
+     grad_output: torch.Tensor,
+     input: torch.Tensor,
+     neighbor: torch.Tensor,
+     SPLITK: int = 1,
+ ) -> torch.Tensor:
+     N, Ci, Co, V = neighbor.shape[0], input.shape[1], grad_output.shape[1], neighbor.shape[1]
+     LOGN = int(math.log2(N))
+     # Launch the kernel.
+     if SPLITK == 1:
+         grad_weight = torch.empty((Co, V, Ci), device=grad_output.device, dtype=grad_output.dtype)
+         # Use cdiv separately for V and Ci to correctly handle non-power-of-2 channels
+         grid = lambda META: (triton.cdiv(Co, META['B1']), triton.cdiv(V, META['BV']) * triton.cdiv(Ci, META['BCi']))
+         conv_bwd_weight_implicit_gemm_kernel[grid](
+             grad_output,
+             input,
+             neighbor,
+             grad_weight,
+             N, LOGN, Ci, Co, V,
+             allow_tf32=config.allow_tf32,
+         )
+         return grad_weight
+     else:
+         grad_weight = torch.empty((SPLITK, Co, V, Ci), device=grad_output.device, dtype=torch.float32)
+         # Use cdiv separately for V and Ci to correctly handle non-power-of-2 channels
+         grid = lambda META: (triton.cdiv(Co, META['B1']), triton.cdiv(V, META['BV']) * triton.cdiv(Ci, META['BCi']), SPLITK)
+         conv_bwd_weight_implicit_gemm_splitk_kernel[grid](
+             grad_output,
+             input,
+             neighbor,
+             grad_weight,
+             N, LOGN, Ci, Co, V,
+             SPLITK=SPLITK,
+             allow_tf32=config.allow_tf32,
+         )
+         return grad_weight.sum(0).to(grad_output.dtype)
+
+
+ def conv_bwd_implicit_gemm_splitk(
+     grad_output: torch.Tensor,
+     input: torch.Tensor,
+     weight: torch.Tensor,
+     bias: torch.Tensor,
+     neighbor: torch.Tensor,
+     needs_input_grad: List[bool],
+ ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
+     assert grad_output.is_contiguous(), "Matrix grad_output must be contiguous"
+     assert input.shape[1] == weight.shape[2], "Incompatible dimensions"
+     assert input.is_contiguous(), "Matrix input must be contiguous"
+     assert weight.is_contiguous(), "Matrix weight must be contiguous"
+     assert neighbor.is_contiguous(), "Matrix neighbor must be contiguous"
+     N, Ci, Co, V = neighbor.shape[0], input.shape[1], weight.shape[0], weight.shape[1]
+     LOGN = int(math.log2(N))
+
+     grad_input, grad_weight, grad_bias = None, None, None
+
+     # Grad for input
+     if needs_input_grad[0]:
+         grad_input = conv_bwd_input_implicit_gemm_splitk(
+             grad_output,
+             weight,
+             neighbor,
+         )
+
+     # Grad for weight
+     if needs_input_grad[1]:
+         grad_weight = conv_bwd_weight_implicit_gemm_splitk(
+             grad_output,
+             input,
+             neighbor,
+         )
+
+     # Grad for bias
+     if needs_input_grad[2]:
+         grad_bias = grad_output.sum(0)
+
+     return grad_input, grad_weight, grad_bias
@@ -0,0 +1,109 @@
+ from typing import *
+ import math
+ import torch
+ import triton
+ import triton.language as tl
+ from .autotuner import triton_autotune
+ from . import config
+
+
+ @triton_autotune(
+     configs=config.autotune_config,
+     key=['LOGN', 'Ci', 'Co', 'V', 'allow_tf32'],
+ )
+ @triton.jit
+ def conv_fwd_implicit_gemm_kernel(
+     input,
+     weight,
+     bias,
+     neighbor,
+     output,
+     # Tensor dimensions
+     N, LOGN, Ci, Co, V: tl.constexpr,
+     # Meta-parameters
+     B1: tl.constexpr,  # Block size for N dimension
+     B2: tl.constexpr,  # Block size for Co dimension
+     BK: tl.constexpr,  # Block size for K dimension (V * Ci)
+     allow_tf32: tl.constexpr,  # Allow TF32 precision for matmuls
+ ):
+     """
+     Sparse submanifold convolution forward kernel using implicit GEMM.
+
+     Args:
+         input (pointer): A pointer to the input tensor of shape (N, Ci)
+         weight (pointer): A pointer to the weight tensor of shape (Co, V, Ci)
+         bias (pointer): A pointer to the bias tensor of shape (Co,)
+         neighbor (pointer): A pointer to the neighbor tensor of shape (N, V)
+         output (pointer): A pointer to the output tensor of shape (N, Co)
+     """
+     block_id = tl.program_id(axis=0)
+     block_dim_co = tl.cdiv(Co, B2)
+     block_id_co = block_id % block_dim_co
+     block_id_n = block_id // block_dim_co
+
+     # Create pointers for submatrices of A and B.
+     num_k = tl.cdiv(Ci, BK)  # Number of blocks in K dimension
+     offset_n = (block_id_n * B1 + tl.arange(0, B1)) % N  # (B1,)
+     offset_co = (block_id_co * B2 + tl.arange(0, B2)) % Co  # (B2,)
+     offset_k = tl.arange(0, BK)  # (BK,)
+
+     # Create a block of the output matrix C.
+     accumulator = tl.zeros((B1, B2), dtype=tl.float32)
+
+     # Calculate pointers to the weight matrix.
+     weight_ptr = weight + (offset_co[None, :] * V * Ci + offset_k[:, None])  # (BK, B2)
+
+     # Iterate along the V*Ci dimension.
+     for k in range(num_k * V):
+         v = k // num_k
+         bk = k % num_k
+         # Calculate pointers to the input matrix.
+         neighbor_offset_n = tl.load(neighbor + offset_n * V + v)  # (B1,)
+         input_ptr = input + bk * BK + (neighbor_offset_n[:, None].to(tl.int64) * Ci + offset_k[None, :])  # (B1, BK)
+         # Load the next blocks of input and weight.
+         neigh_mask = neighbor_offset_n != -1
+         k_mask = offset_k < Ci - bk * BK
+         input_block = tl.load(input_ptr, mask=neigh_mask[:, None] & k_mask[None, :], other=0.0)
+         weight_block = tl.load(weight_ptr, mask=k_mask[:, None], other=0.0)
+         # Accumulate along the K dimension.
+         accumulator = tl.dot(input_block, weight_block, accumulator,
+                              input_precision='tf32' if allow_tf32 else 'ieee')  # (B1, B2)
+         # Advance the pointers to the next Ci block.
+         weight_ptr += min(BK, Ci - bk * BK)
+     c = accumulator.to(input.type.element_ty)
+
+     # Add the bias.
+     if bias is not None:
+         bias_block = tl.load(bias + offset_co)
+         c += bias_block[None, :]
+
+     # Write back the block of the output matrix with masks.
+     out_offset_n = block_id_n * B1 + tl.arange(0, B1)
+     out_offset_co = block_id_co * B2 + tl.arange(0, B2)
+     out_ptr = output + (out_offset_n[:, None] * Co + out_offset_co[None, :])
+     out_mask = (out_offset_n[:, None] < N) & (out_offset_co[None, :] < Co)
+     tl.store(out_ptr, c, mask=out_mask)
+
+
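The kernel realizes `output[n, co] = Σ_{v, ci} input[neighbor[n, v], ci] * weight[co, v, ci] + bias[co]`, with missing taps (-1) skipped. A dense PyTorch sketch of the same computation (illustrative only):

```python
import torch

def conv_fwd_reference(input, weight, bias, neighbor):
    # input: (N, Ci), weight: (Co, V, Ci), neighbor: (N, V), -1 marks a missing tap.
    N, V = neighbor.shape
    output = torch.zeros(N, weight.shape[0], dtype=torch.float32, device=input.device)
    for v in range(V):
        idx = neighbor[:, v]
        valid = (idx != -1)[:, None]
        x = input[idx.clamp(min=0)].float() * valid   # gathered inputs, zeroed where missing
        output += x @ weight[:, v, :].float().T       # (N, Ci) @ (Ci, Co)
    if bias is not None:
        output += bias.float()
    return output.to(input.dtype)
```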
+ def conv_fwd_implicit_gemm(
+     input: torch.Tensor,
+     weight: torch.Tensor,
+     bias: torch.Tensor,
+     neighbor: torch.Tensor,
+ ) -> torch.Tensor:
+     assert input.shape[1] == weight.shape[2], "Incompatible dimensions"
+     assert input.is_contiguous(), "Matrix input must be contiguous"
+     assert weight.is_contiguous(), "Matrix weight must be contiguous"
+     assert neighbor.is_contiguous(), "Matrix neighbor must be contiguous"
+     N, Ci, Co, V = neighbor.shape[0], input.shape[1], weight.shape[0], weight.shape[1]
+     LOGN = int(math.log2(N))
+     # Allocate the output matrix.
+     output = torch.empty((N, Co), device=input.device, dtype=input.dtype)
+     # Launch the kernel.
+     grid = lambda META: (triton.cdiv(Co, META['B2']) * triton.cdiv(N, META['B1']),)
+     conv_fwd_implicit_gemm_kernel[grid](
+         input, weight, bias, neighbor, output,
+         N, LOGN, Ci, Co, V,
+         allow_tf32=config.allow_tf32,
+     )
+     return output
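A hedged usage sketch (the random neighbor table is a stand-in; real callers build `neighbor` from the octree structure, and the int32 dtype here is an assumption):

```python
import torch

N, Ci, Co, V = 4096, 32, 64, 27
input = torch.randn(N, Ci, device='cuda', dtype=torch.float16)
weight = torch.randn(Co, V, Ci, device='cuda', dtype=torch.float16)
bias = torch.randn(Co, device='cuda', dtype=torch.float16)
# -1 entries mark taps that fall outside the sparse voxel set.
neighbor = torch.randint(-1, N, (N, V), device='cuda', dtype=torch.int32)

output = conv_fwd_implicit_gemm(input, weight, bias, neighbor)  # (N, Co)
```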