ocnn 2.2.8__py3-none-any.whl → 2.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ocnn/__init__.py +24 -24
- ocnn/dataset.py +160 -160
- ocnn/models/__init__.py +29 -29
- ocnn/models/autoencoder.py +155 -155
- ocnn/models/hrnet.py +192 -192
- ocnn/models/image2shape.py +128 -128
- ocnn/models/lenet.py +46 -46
- ocnn/models/ounet.py +94 -94
- ocnn/models/resnet.py +53 -53
- ocnn/models/segnet.py +72 -72
- ocnn/models/unet.py +105 -105
- ocnn/modules/__init__.py +26 -26
- ocnn/modules/modules.py +303 -303
- ocnn/modules/resblocks.py +158 -158
- ocnn/nn/__init__.py +45 -44
- ocnn/nn/kernels/__init__.py +14 -0
- ocnn/nn/kernels/autotuner.py +416 -0
- ocnn/nn/kernels/config.py +67 -0
- ocnn/nn/kernels/conv_bwd_implicit_gemm.py +229 -0
- ocnn/nn/kernels/conv_bwd_implicit_gemm_splitk.py +347 -0
- ocnn/nn/kernels/conv_fwd_implicit_gemm.py +109 -0
- ocnn/nn/kernels/conv_fwd_implicit_gemm_splitk.py +150 -0
- ocnn/nn/kernels/utils.py +44 -0
- ocnn/nn/octree2col.py +53 -53
- ocnn/nn/octree2vox.py +50 -50
- ocnn/nn/octree_align.py +46 -46
- ocnn/nn/octree_conv.py +430 -429
- ocnn/nn/octree_conv_t.py +148 -0
- ocnn/nn/octree_drop.py +55 -55
- ocnn/nn/octree_dwconv.py +222 -222
- ocnn/nn/octree_gconv.py +79 -79
- ocnn/nn/octree_interp.py +196 -196
- ocnn/nn/octree_norm.py +126 -126
- ocnn/nn/octree_pad.py +39 -39
- ocnn/nn/octree_pool.py +200 -200
- ocnn/octree/__init__.py +22 -22
- ocnn/octree/octree.py +770 -770
- ocnn/octree/points.py +384 -323
- ocnn/octree/shuffled_key.py +115 -115
- ocnn/utils.py +205 -205
- {ocnn-2.2.8.dist-info → ocnn-2.3.0.dist-info}/METADATA +117 -111
- ocnn-2.3.0.dist-info/RECORD +45 -0
- {ocnn-2.2.8.dist-info → ocnn-2.3.0.dist-info}/WHEEL +1 -1
- {ocnn-2.2.8.dist-info → ocnn-2.3.0.dist-info}/licenses/LICENSE +21 -21
- ocnn-2.2.8.dist-info/RECORD +0 -36
- {ocnn-2.2.8.dist-info → ocnn-2.3.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,229 @@
|
|
|
1
|
+
from typing import *
|
|
2
|
+
import math
|
|
3
|
+
import torch
|
|
4
|
+
import triton
|
|
5
|
+
import triton.language as tl
|
|
6
|
+
from .autotuner import triton_autotune
|
|
7
|
+
from . import config
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@triton_autotune(
    configs=config.autotune_config,
    key=['LOGN', 'Ci', 'Co', 'V', 'allow_tf32'],
)
@triton.jit
def conv_bwd_input_implicit_gemm_kernel(
    grad_output,
    weight,
    neighbor,
    grad_input,
    # Tensor dimensions
    N, LOGN, Ci, Co, V: tl.constexpr,
    # Meta-parameters
    B1: tl.constexpr,  # Block size for N dimension
    B2: tl.constexpr,  # Block size for Ci dimension
    BK: tl.constexpr,  # Block size for K dimension (V * Co)
    allow_tf32: tl.constexpr,  # Allow TF32 precision for matmuls
):
    """
    Sparse submanifold convolution backward to input kernel using implicit GEMM.

    Each program computes one (B1, B2) tile of

        grad_input[n, ci] = sum_{v, co} grad_output[neighbor[n, V - 1 - v], co] * weight[co, v, ci]

    looping over the V kernel offsets and, within each offset, over Co in
    chunks of BK. Accumulation is done in fp32.

    Args:
        grad_output (pointer): A pointer to the gradient of the output tensor of shape (N, Co)
        weight (pointer): A pointer to the weight tensor of shape (Co, V, Ci)
        neighbor (pointer): A pointer to the neighbor tensor of shape (N, V);
            entries equal to -1 are treated as missing neighbors
        grad_input (pointer): A pointer to the gradient of the input tensor of shape (N, Ci)
    """
    # 1-D launch grid: program ids enumerate (n-block, ci-block) pairs, ci fastest.
    block_id = tl.program_id(axis=0)
    block_dim_ci = tl.cdiv(Ci, B2)
    block_id_ci = block_id % block_dim_ci
    block_id_n = block_id // block_dim_ci

    num_k = tl.cdiv(Co, BK)  # Number of Co chunks per kernel offset
    # Indices are wrapped so every load stays in bounds; lanes beyond N / Ci
    # are discarded by the store mask at the end.
    offset_n = (block_id_n * B1 + tl.arange(0, B1)) % N  # (B1,)
    offset_ci = (block_id_ci * B2 + tl.arange(0, B2)) % Ci  # (B2,)
    offset_k = tl.arange(0, BK)  # (BK,)

    accumulator = tl.zeros((B1, B2), dtype=tl.float32)

    # Iterate along the V*Co reduction dimension.
    for k in range(num_k * V):
        v = k // num_k
        bk = k % num_k
        # Weight slice v is paired with neighbor column V - 1 - v (the reversed
        # offset; presumably the mirror of offset v in a symmetric stencil -- TODO confirm).
        neighbor_offset_n = tl.load(neighbor + offset_n * V + V - 1 - v)  # (B1,)
        # int64: the gathered row index times Co can exceed the int32 range.
        grad_output_ptr = grad_output + bk * BK + (neighbor_offset_n[:, None].to(tl.int64) * Co + offset_k[None, :])  # (B1, BK)
        # weight[co, v, ci] lives at (co * V + v) * Ci + ci.
        weight_ptr = weight + (((offset_k[:, None] + bk * BK) * V + v) * Ci + offset_ci[None, :])  # (BK, B2)
        neigh_mask = neighbor_offset_n != -1
        k_mask = offset_k < Co - bk * BK
        grad_output_block = tl.load(grad_output_ptr, mask=neigh_mask[:, None] & k_mask[None, :], other=0.0)
        weight_block = tl.load(weight_ptr, mask=k_mask[:, None], other=0.0)
        accumulator = tl.dot(grad_output_block, weight_block, accumulator,
                             input_precision='tf32' if allow_tf32 else 'ieee')  # (B1, B2)
    c = accumulator.to(grad_output.type.element_ty)

    # Write back the tile, masking rows >= N and columns >= Ci.
    grad_input_offset_n = block_id_n * B1 + tl.arange(0, B1)
    grad_input_offset_ci = block_id_ci * B2 + tl.arange(0, B2)
    # Promote the row index to int64, as the gather above does: N * Ci can exceed
    # 2**31 for large inputs, and an int32 offset would wrap and write out of bounds.
    grad_input_ptr = grad_input + (grad_input_offset_n[:, None].to(tl.int64) * Ci + grad_input_offset_ci[None, :])
    grad_input_mask = (grad_input_offset_n[:, None] < N) & (grad_input_offset_ci[None, :] < Ci)
    tl.store(grad_input_ptr, c, mask=grad_input_mask)
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def _bci_heuristic(meta):
    # Ci sub-tile width: Ci rounded up to a power of two (tl.dot needs
    # power-of-two tile sizes), capped so it never exceeds the B2 tile width.
    return min(triton.next_power_of_2(meta['Ci']), meta['B2'])


heuristics = {
    # Width of the Ci sub-tile inside one B2-wide (V * Ci) tile.
    'BCi': _bci_heuristic,
    # How many kernel offsets share one B2-wide tile (at least one).
    'BV': lambda meta: max(1, meta['B2'] // _bci_heuristic(meta)),
}
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
@triton_autotune(
    configs=config.autotune_config,
    key=['LOGN', 'Ci', 'Co', 'V', 'allow_tf32'],
)
@triton.heuristics(heuristics)
@triton.jit
def conv_bwd_weight_implicit_gemm_kernel(
    grad_output,
    input,
    neighbor,
    grad_weight,
    # Tensor dimensions
    N, LOGN, Ci, Co, V: tl.constexpr,
    # Meta-parameters
    B1: tl.constexpr,  # Block size for Co dimension
    B2: tl.constexpr,  # Block size for V * Ci dimension
    BK: tl.constexpr,  # Block size for K dimension (N)
    BV: tl.constexpr,  # Block size for V dimension
    BCi: tl.constexpr,  # Block size for Ci dimension
    allow_tf32: tl.constexpr,  # Allow TF32 precision for matmuls
):
    """
    Sparse submanifold convolution backward to weight kernel using implicit GEMM.

    Each program computes one (B1, BV * BCi) tile of

        grad_weight[co, v, ci] = sum_n grad_output[n, co] * input[neighbor[n, v], ci]

    reducing over all N rows in chunks of BK, accumulating in fp32.

    Args:
        grad_output (pointer): A pointer to the gradient of the output tensor of shape (N, Co)
        input (pointer): A pointer to the input tensor of shape (N, Ci)
        neighbor (pointer): A pointer to the neighbor tensor of shape (N, V);
            entries equal to -1 are treated as missing neighbors
        grad_weight (pointer): A pointer to the gradient of the weight tensor of shape (Co, V, Ci)
    """
    pid_co = tl.program_id(axis=0)
    pid_vci = tl.program_id(axis=1)

    # Axis 1 enumerates (v-block, ci-block) pairs with the ci block varying fastest.
    n_ci_blocks = tl.cdiv(Ci, BCi)
    pid_v = pid_vci // n_ci_blocks
    pid_ci = pid_vci % n_ci_blocks

    # Unwrapped tile indices; these drive the store address and mask.
    co_idx = pid_co * B1 + tl.arange(0, B1)  # (B1,)
    v_idx = pid_v * BV + tl.arange(0, BV)  # (BV,)
    ci_idx = pid_ci * BCi + tl.arange(0, BCi)  # (BCi,)

    # Wrapped copies keep every load in bounds; lanes past Co / V / Ci are
    # dropped by the store mask.
    co_load = co_idx % Co
    v_load = v_idx % V
    ci_load = ci_idx % Ci
    rows = tl.arange(0, BK)  # (BK,)

    nbr_ptrs = neighbor + (rows[:, None] * V + v_load[None, :])  # (BK, BV)
    gout_ptrs = grad_output + (rows[None, :] * Co + co_load[:, None])  # (B1, BK)

    acc = tl.zeros((B1, BV * BCi), dtype=tl.float32)
    for step in range(tl.cdiv(N, BK)):
        row_ok = rows < N - step * BK
        # Padded rows read as -1, i.e. "no neighbor", so their input is zeroed.
        src_rows = tl.load(nbr_ptrs, mask=row_ok[:, None], other=-1)  # (BK, BV)
        src_ptrs = input + (src_rows[:, :, None].to(tl.int64) * Ci + ci_load[None, None, :])  # (BK, BV, BCi)
        gout_tile = tl.load(gout_ptrs, mask=row_ok[None, :], other=0.0)
        src_tile = tl.load(src_ptrs, mask=src_rows[:, :, None] != -1, other=0.0).reshape(BK, BV * BCi)
        acc = tl.dot(gout_tile, src_tile, acc,
                     input_precision='tf32' if allow_tf32 else 'ieee')  # (B1, BV * BCi)
        # Step both row-major operands forward by BK rows.
        gout_ptrs += BK * Co
        nbr_ptrs += BK * V
    result = acc.to(grad_output.type.element_ty)

    # Flattened offset inside the (V, Ci) part of grad_weight: v * Ci + ci.
    vci_off = (v_idx[:, None] * Ci + ci_idx[None, :]).reshape(BV * BCi)  # (BV*BCi,)
    vci_ok = ((v_idx < V)[:, None] & (ci_idx < Ci)[None, :]).reshape(BV * BCi)  # (BV*BCi,)
    out_ptrs = grad_weight + (co_idx[:, None] * V * Ci + vci_off[None, :])
    out_mask = (co_idx[:, None] < Co) & vci_ok[None, :]
    tl.store(out_ptrs, result, mask=out_mask)
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
def conv_bwd_implicit_gemm(
    grad_output: torch.Tensor,
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    neighbor: torch.Tensor,
    needs_input_grad: List[bool],
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
    """
    Backward pass of the implicit-GEMM sparse convolution (no split-K).

    Args:
        grad_output: gradient w.r.t. the output, shape (N, Co); must be contiguous.
        input: input features, shape (N, Ci); must be contiguous.
        weight: weights, shape (Co, V, Ci); must be contiguous.
        bias: not read here; the bias gradient is computed from ``grad_output``.
        neighbor: neighbor table, shape (N, V); must be contiguous.
        needs_input_grad: flags selecting which of the (input, weight, bias)
            gradients to compute.

    Returns:
        ``(grad_input, grad_weight, grad_bias)``; an entry is None when its
        flag is False.
    """
    assert grad_output.is_contiguous(), "Matrix grad_output must be contiguous"
    assert input.shape[1] == weight.shape[2], "Incompatible dimensions"
    assert input.is_contiguous(), "Matrix input must be contiguous"
    assert weight.is_contiguous(), "Matrix weight must be contiguous"
    assert neighbor.is_contiguous(), "Matrix neighbor must be contiguous"
    N, Ci, Co, V = neighbor.shape[0], input.shape[1], weight.shape[0], weight.shape[1]

    grad_input, grad_weight, grad_bias = None, None, None

    if N == 0:
        # No rows: math.log2(0) would raise and the kernels index with `% N`.
        # The input gradient is empty and the weight gradient is exactly zero.
        if needs_input_grad[0]:
            grad_input = torch.empty((0, Ci), device=input.device, dtype=input.dtype)
        if needs_input_grad[1]:
            grad_weight = torch.zeros((Co, V, Ci), device=weight.device, dtype=weight.dtype)
        if needs_input_grad[2]:
            grad_bias = grad_output.sum(0)
        return grad_input, grad_weight, grad_bias

    # Coarse problem-size bucket used as an autotune key by the kernels.
    LOGN = int(math.log2(N))

    # Grad for input: one program per (B1 x B2) tile of the (N, Ci) result.
    if needs_input_grad[0]:
        grad_input = torch.empty((N, Ci), device=input.device, dtype=input.dtype)
        grid = lambda META: (triton.cdiv(Ci, META['B2']) * triton.cdiv(N, META['B1']),)
        conv_bwd_input_implicit_gemm_kernel[grid](
            grad_output,
            weight,
            neighbor,
            grad_input,
            N, LOGN, Ci, Co, V,
            allow_tf32=config.allow_tf32,
        )

    # Grad for weight: V and Ci are tiled separately so non-power-of-2 channel
    # counts are fully covered.
    if needs_input_grad[1]:
        grad_weight = torch.empty((Co, V, Ci), device=weight.device, dtype=weight.dtype)
        grid = lambda META: (triton.cdiv(Co, META['B1']), triton.cdiv(V, META['BV']) * triton.cdiv(Ci, META['BCi']))
        conv_bwd_weight_implicit_gemm_kernel[grid](
            grad_output,
            input,
            neighbor,
            grad_weight,
            N, LOGN, Ci, Co, V,
            allow_tf32=config.allow_tf32,
        )

    # Grad for bias: sum of grad_output over all rows.
    if needs_input_grad[2]:
        grad_bias = grad_output.sum(0)

    return grad_input, grad_weight, grad_bias
|
|
@@ -0,0 +1,347 @@
|
|
|
1
|
+
from typing import *
|
|
2
|
+
import math
|
|
3
|
+
import torch
|
|
4
|
+
import triton
|
|
5
|
+
import triton.language as tl
|
|
6
|
+
from .utils import get_num_sm
|
|
7
|
+
from .autotuner import triton_autotune, autotune
|
|
8
|
+
from . import config
|
|
9
|
+
from .conv_bwd_implicit_gemm import (
|
|
10
|
+
conv_bwd_input_implicit_gemm_kernel,
|
|
11
|
+
conv_bwd_weight_implicit_gemm_kernel,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@triton_autotune(
    configs=config.autotune_config,
    key=['LOGN', 'Ci', 'Co', 'V', 'SPLITK', 'allow_tf32'],
)
@triton.jit
def conv_bwd_input_implicit_gemm_splitk_kernel(
    grad_output,
    weight,
    neighbor,
    grad_input,
    # Tensor dimensions
    N, LOGN, Ci, Co, V: tl.constexpr,
    # Meta-parameters
    B1: tl.constexpr,  # Block size for N dimension
    B2: tl.constexpr,  # Block size for Ci dimension
    BK: tl.constexpr,  # Block size for K dimension (V * Co)
    SPLITK: tl.constexpr,  # Split K dimension
    allow_tf32: tl.constexpr,  # Allow TF32 precision for matmuls
):
    """
    Split-K sparse submanifold convolution backward to input kernel (implicit GEMM).

    Same math as the non-split kernel, but the V*Co reduction (num_k * V chunks)
    is divided into SPLITK contiguous ranges, one per program along axis 1. Each
    split writes its fp32 partial sum, without down-casting, into its own slice
    of ``grad_input``.

    Args:
        grad_output (pointer): A pointer to the gradient of the output tensor of shape (N, Co)
        weight (pointer): A pointer to the weight tensor of shape (Co, V, Ci)
        neighbor (pointer): A pointer to the neighbor tensor of shape (N, V);
            entries equal to -1 are treated as missing neighbors
        grad_input (pointer): A pointer to the partial-sum buffer of shape (SPLITK, N, Ci)
    """
    block_id_k = tl.program_id(axis=1)  # SplitK dimension
    block_id = tl.program_id(axis=0)
    block_dim_ci = tl.cdiv(Ci, B2)
    block_id_ci = block_id % block_dim_ci
    block_id_n = block_id // block_dim_ci

    num_k = tl.cdiv(Co, BK)  # Number of Co chunks per kernel offset
    # This split's half-open range of (v, Co-chunk) steps.
    k_start = tl.cdiv(num_k * V * block_id_k, SPLITK)
    k_end = tl.cdiv(num_k * V * (block_id_k + 1), SPLITK)
    # Wrapped indices keep loads in bounds; the store mask drops extra lanes.
    offset_n = (block_id_n * B1 + tl.arange(0, B1)) % N  # (B1,)
    offset_ci = (block_id_ci * B2 + tl.arange(0, B2)) % Ci  # (B2,)
    offset_k = tl.arange(0, BK)  # (BK,)

    accumulator = tl.zeros((B1, B2), dtype=tl.float32)

    for k in range(k_start, k_end):
        v = k // num_k
        bk = k % num_k
        # Weight slice v is paired with neighbor column V - 1 - v (reversed offset).
        neighbor_offset_n = tl.load(neighbor + offset_n * V + V - 1 - v)  # (B1,)
        grad_output_ptr = grad_output + bk * BK + (neighbor_offset_n[:, None].to(tl.int64) * Co + offset_k[None, :])  # (B1, BK)
        weight_ptr = weight + (((offset_k[:, None] + bk * BK) * V + v) * Ci + offset_ci[None, :])  # (BK, B2)
        neigh_mask = neighbor_offset_n != -1
        k_mask = offset_k < Co - bk * BK
        grad_output_block = tl.load(grad_output_ptr, mask=neigh_mask[:, None] & k_mask[None, :], other=0.0)
        weight_block = tl.load(weight_ptr, mask=k_mask[:, None], other=0.0)
        accumulator = tl.dot(grad_output_block, weight_block, accumulator,
                             input_precision='tf32' if allow_tf32 else 'ieee')  # (B1, B2)

    # Write this split's fp32 partial tile into slice block_id_k.
    grad_input_offset_n = block_id_n * B1 + tl.arange(0, B1)
    grad_input_offset_ci = block_id_ci * B2 + tl.arange(0, B2)
    # Both the slice offset (block_id_k * N * Ci) and the row offset (n * Ci)
    # are computed in int64; in int32 they can wrap for large N * Ci.
    grad_input_ptr = (grad_input + block_id_k.to(tl.int64) * N * Ci
                      + (grad_input_offset_n[:, None].to(tl.int64) * Ci + grad_input_offset_ci[None, :]))
    grad_input_mask = (grad_input_offset_n[:, None] < N) & (grad_input_offset_ci[None, :] < Ci)
    tl.store(grad_input_ptr, accumulator, mask=grad_input_mask)
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def _bci_heuristic(meta):
    # Ci sub-tile width: Ci rounded up to a power of two for efficient tl.dot,
    # capped so it never exceeds the B2 tile width.
    return min(triton.next_power_of_2(meta['Ci']), meta['B2'])


heuristics = {
    # Width of the Ci sub-tile inside one B2-wide (V * Ci) tile.
    'BCi': _bci_heuristic,
    # How many kernel offsets share one B2-wide tile (at least one).
    'BV': lambda meta: max(1, meta['B2'] // _bci_heuristic(meta)),
}
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
@triton_autotune(
    configs=config.autotune_config,
    key=['LOGN', 'Ci', 'Co', 'V', 'SPLITK', 'allow_tf32'],
)
@triton.heuristics(heuristics)
@triton.jit
def conv_bwd_weight_implicit_gemm_splitk_kernel(
    grad_output,
    input,
    neighbor,
    grad_weight,
    # Tensor dimensions
    N, LOGN, Ci, Co, V: tl.constexpr,
    # Meta-parameters
    B1: tl.constexpr,  # Block size for Co dimension
    B2: tl.constexpr,  # Block size for V * Ci dimension
    BK: tl.constexpr,  # Block size for K dimension (N)
    BV: tl.constexpr,  # Block size for V dimension
    BCi: tl.constexpr,  # Block size for Ci dimension
    SPLITK: tl.constexpr,  # Split K dimension
    allow_tf32: tl.constexpr,  # Allow TF32 precision for matmuls
):
    """
    Split-K sparse submanifold convolution backward to weight kernel (implicit GEMM).

    Computes one (B1, BV * BCi) tile of

        grad_weight[co, v, ci] = sum_n grad_output[n, co] * input[neighbor[n, v], ci]

    but only over this program's contiguous share of the N rows (axis 2 selects
    one of SPLITK shares). The fp32 partial tile is stored, without down-casting,
    into slice ``block_id_k`` of ``grad_weight``.

    Args:
        grad_output (pointer): A pointer to the gradient of the output tensor of shape (N, Co)
        input (pointer): A pointer to the input tensor of shape (N, Ci)
        neighbor (pointer): A pointer to the neighbor tensor of shape (N, V);
            entries equal to -1 are treated as missing neighbors
        grad_weight (pointer): A pointer to the partial-sum buffer of shape (SPLITK, Co, V, Ci)
    """
    block_id_co = tl.program_id(axis=0)
    block_id_vci = tl.program_id(axis=1)
    block_id_k = tl.program_id(axis=2)

    # This split's half-open range of BK-row chunks.
    num_k = tl.cdiv(N, BK)
    k_start = tl.cdiv(num_k * block_id_k, SPLITK)
    k_end = tl.cdiv(num_k * (block_id_k + 1), SPLITK)

    # Axis 1 enumerates (v-block, ci-block) pairs with the ci block varying fastest.
    num_ci_blocks = tl.cdiv(Ci, BCi)
    block_id_v = block_id_vci // num_ci_blocks
    block_id_ci = block_id_vci % num_ci_blocks

    # Unwrapped tile indices (drive the store) and wrapped copies (keep loads in bounds).
    grad_weight_offset_co = block_id_co * B1 + tl.arange(0, B1)  # (B1,)
    global_v = block_id_v * BV + tl.arange(0, BV)  # (BV,)
    global_ci = block_id_ci * BCi + tl.arange(0, BCi)  # (BCi,)
    offset_co = grad_weight_offset_co % Co
    offset_v = global_v % V
    offset_ci = global_ci % Ci
    offset_k = tl.arange(0, BK)  # (BK,)

    # First row of this split. Promote to int64 before scaling: for late splits
    # k_start * BK approaches N, and N * Co (or N * V) can exceed 2**31, which
    # would wrap an int32 offset and send the base pointers out of bounds.
    row_start = k_start.to(tl.int64) * BK
    neighbor_ptr = neighbor + row_start * V + (offset_k[:, None] * V + offset_v[None, :])  # (BK, BV)
    grad_output_ptr = grad_output + row_start * Co + (offset_k[None, :] * Co + offset_co[:, None])  # (B1, BK)

    accumulator = tl.zeros((B1, BV * BCi), dtype=tl.float32)

    for k in range(k_start, k_end):
        mask = offset_k < N - k * BK
        # Padded rows read as -1 ("no neighbor"), so their input is zeroed.
        input_offset_n = tl.load(neighbor_ptr, mask=mask[:, None], other=-1)  # (BK, BV)
        input_ptr = input + (input_offset_n[:, :, None].to(tl.int64) * Ci + offset_ci[None, None, :])  # (BK, BV, BCi)
        grad_output_block = tl.load(grad_output_ptr, mask=mask[None, :], other=0.0)
        input_block = tl.load(input_ptr, mask=input_offset_n[:, :, None] != -1, other=0.0).reshape(BK, BV * BCi)
        accumulator = tl.dot(grad_output_block, input_block, accumulator,
                             input_precision='tf32' if allow_tf32 else 'ieee')  # (B1, BV * BCi)
        # Advance both row-major operands by BK rows.
        grad_output_ptr += BK * Co
        neighbor_ptr += BK * V

    # Flattened offset inside the (V, Ci) part of grad_weight: v * Ci + ci.
    grad_weight_offset_vci = (global_v[:, None] * Ci + global_ci[None, :]).reshape(BV * BCi)  # (BV*BCi,)
    grad_weight_ptr = (grad_weight + block_id_k * Co * V * Ci
                       + (grad_weight_offset_co[:, None] * V * Ci + grad_weight_offset_vci[None, :]))

    v_mask = (global_v < V)[:, None]  # (BV, 1)
    ci_mask = (global_ci < Ci)[None, :]  # (1, BCi)
    vci_mask = (v_mask & ci_mask).reshape(BV * BCi)  # (BV*BCi,)
    grad_weight_mask = (grad_weight_offset_co[:, None] < Co) & vci_mask[None, :]
    tl.store(grad_weight_ptr, accumulator, mask=grad_weight_mask)
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
def conv_bwd_input_implicit_gemm_splitk_configs(grad_output, weight, neighbor):
    """
    Enumerate candidate SPLITK values for the input-gradient kernel.

    The (N, Ci) result is counted in 128 x 128 tiles. SPLITK then ranges over
    powers of two, from roughly enough splits to give one tile per SM up to
    roughly 32 tiles per SM (always at least SPLITK = 1 as a candidate range).
    """
    num_rows, c_in = neighbor.shape[0], weight.shape[-1]
    tiles = ((num_rows + 127) // 128) * ((c_in + 127) // 128)
    num_sm = get_num_sm()
    lo = max(0, int(math.log2(num_sm / tiles)))
    hi = max(1, int(math.log2(32 * num_sm / tiles) + 1))
    return [{'SPLITK': 2 ** e} for e in range(lo, hi)]
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
def conv_bwd_input_implicit_gemm_splitk_keys(grad_output, weight, neighbor):
    """Autotune cache key: floor(log2(N)) bucket followed by Ci, Co and V."""
    num_rows = neighbor.shape[0]
    c_out, kernel_vol, c_in = weight.shape[0], weight.shape[1], weight.shape[-1]
    log_rows = int(math.log2(num_rows))
    return f'(2^{log_rows}, {c_in}, {c_out}, {kernel_vol})'
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
@autotune(
    config_fn=conv_bwd_input_implicit_gemm_splitk_configs,
    key_fn=conv_bwd_input_implicit_gemm_splitk_keys,
)
def conv_bwd_input_implicit_gemm_splitk(
    grad_output: torch.Tensor,
    weight: torch.Tensor,
    neighbor: torch.Tensor,
    SPLITK: int = 1,
) -> torch.Tensor:
    """
    Gradient of the sparse convolution with respect to its input.

    With ``SPLITK == 1`` the non-split kernel writes the result directly.
    Otherwise the V*Co reduction is divided into ``SPLITK`` slices that each
    write an fp32 partial result; the partials are summed here and cast back.

    Returns:
        Tensor of shape (N, Ci) with dtype ``weight.dtype``.
    """
    num_rows = neighbor.shape[0]
    c_out, kernel_vol, c_in = weight.shape[0], weight.shape[1], weight.shape[-1]
    log_rows = int(math.log2(num_rows))

    def tile_count(meta):
        # One program per (B1 x B2) tile of the (N, Ci) result.
        return triton.cdiv(c_in, meta['B2']) * triton.cdiv(num_rows, meta['B1'])

    if SPLITK != 1:
        partial = torch.empty((SPLITK, num_rows, c_in), device=weight.device, dtype=torch.float32)
        conv_bwd_input_implicit_gemm_splitk_kernel[lambda meta: (tile_count(meta), SPLITK)](
            grad_output,
            weight,
            neighbor,
            partial,
            num_rows, log_rows, c_in, c_out, kernel_vol,
            SPLITK=SPLITK,
            allow_tf32=config.allow_tf32,
        )
        return partial.sum(0).to(weight.dtype)

    grad_input = torch.empty((num_rows, c_in), device=weight.device, dtype=weight.dtype)
    conv_bwd_input_implicit_gemm_kernel[lambda meta: (tile_count(meta),)](
        grad_output,
        weight,
        neighbor,
        grad_input,
        num_rows, log_rows, c_in, c_out, kernel_vol,
        allow_tf32=config.allow_tf32,
    )
    return grad_input
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
def conv_bwd_weight_implicit_gemm_splitk_configs(grad_output, input, neighbor):
    """
    Enumerate candidate SPLITK values for the weight-gradient kernel.

    The (Co, V * Ci) result is counted in 128 x 128 tiles. SPLITK then ranges
    over powers of two, from roughly enough splits to give one tile per SM up
    to roughly 32 tiles per SM.
    """
    c_out, kernel_vol, c_in = grad_output.shape[1], neighbor.shape[1], input.shape[1]
    tiles = ((c_out + 127) // 128) * ((kernel_vol * c_in + 127) // 128)
    num_sm = get_num_sm()
    lo = max(0, int(math.log2(num_sm / tiles)))
    hi = max(1, int(math.log2(32 * num_sm / tiles) + 1))
    return [{'SPLITK': 2 ** e} for e in range(lo, hi)]
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
def conv_bwd_weight_implicit_gemm_splitk_keys(grad_output, input, neighbor):
    """Autotune cache key: floor(log2(N)) bucket followed by Ci, Co and V."""
    num_rows, kernel_vol = neighbor.shape[0], neighbor.shape[1]
    c_in, c_out = input.shape[1], grad_output.shape[1]
    log_rows = int(math.log2(num_rows))
    return f'(2^{log_rows}, {c_in}, {c_out}, {kernel_vol})'
|
|
265
|
+
|
|
266
|
+
|
|
267
|
+
@autotune(
    config_fn=conv_bwd_weight_implicit_gemm_splitk_configs,
    key_fn=conv_bwd_weight_implicit_gemm_splitk_keys,
)
def conv_bwd_weight_implicit_gemm_splitk(
    grad_output: torch.Tensor,
    input: torch.Tensor,
    neighbor: torch.Tensor,
    SPLITK: int = 1,
) -> torch.Tensor:
    """
    Gradient of the sparse convolution with respect to its weight.

    With ``SPLITK == 1`` the non-split kernel writes the result directly.
    Otherwise the N-row reduction is divided into ``SPLITK`` shares that each
    write an fp32 partial result; the partials are summed here and cast back.

    Returns:
        Tensor of shape (Co, V, Ci) with dtype ``grad_output.dtype``.
    """
    num_rows, kernel_vol = neighbor.shape[0], neighbor.shape[1]
    c_in, c_out = input.shape[1], grad_output.shape[1]
    log_rows = int(math.log2(num_rows))

    def weight_tiles(meta):
        # Tiles along Co, then (V-block, Ci-block) pairs. V and Ci are divided
        # separately so non-power-of-2 channel counts are fully covered.
        return (
            triton.cdiv(c_out, meta['B1']),
            triton.cdiv(kernel_vol, meta['BV']) * triton.cdiv(c_in, meta['BCi']),
        )

    if SPLITK != 1:
        partial = torch.empty((SPLITK, c_out, kernel_vol, c_in), device=grad_output.device, dtype=torch.float32)
        conv_bwd_weight_implicit_gemm_splitk_kernel[lambda meta: weight_tiles(meta) + (SPLITK,)](
            grad_output,
            input,
            neighbor,
            partial,
            num_rows, log_rows, c_in, c_out, kernel_vol,
            SPLITK=SPLITK,
            allow_tf32=config.allow_tf32,
        )
        return partial.sum(0).to(grad_output.dtype)

    grad_weight = torch.empty((c_out, kernel_vol, c_in), device=grad_output.device, dtype=grad_output.dtype)
    conv_bwd_weight_implicit_gemm_kernel[weight_tiles](
        grad_output,
        input,
        neighbor,
        grad_weight,
        num_rows, log_rows, c_in, c_out, kernel_vol,
        allow_tf32=config.allow_tf32,
    )
    return grad_weight
|
|
307
|
+
|
|
308
|
+
|
|
309
|
+
def conv_bwd_implicit_gemm_splitk(
    grad_output: torch.Tensor,
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    neighbor: torch.Tensor,
    needs_input_grad: List[bool],
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
    """
    Backward pass of the implicit-GEMM sparse convolution using autotuned split-K.

    Args:
        grad_output: gradient w.r.t. the output, shape (N, Co); must be contiguous.
        input: input features, shape (N, Ci); must be contiguous.
        weight: weights, shape (Co, V, Ci); must be contiguous.
        bias: not read here; the bias gradient is computed from ``grad_output``.
        neighbor: neighbor table, shape (N, V); must be contiguous.
        needs_input_grad: flags selecting which of the (input, weight, bias)
            gradients to compute.

    Returns:
        ``(grad_input, grad_weight, grad_bias)``; an entry is None when its
        flag is False.
    """
    assert grad_output.is_contiguous(), "Matrix grad_output must be contiguous"
    assert input.shape[1] == weight.shape[2], "Incompatible dimensions"
    assert input.is_contiguous(), "Matrix input must be contiguous"
    assert weight.is_contiguous(), "Matrix weight must be contiguous"
    assert neighbor.is_contiguous(), "Matrix neighbor must be contiguous"
    N, Ci, Co, V = neighbor.shape[0], input.shape[1], weight.shape[0], weight.shape[1]

    grad_input, grad_weight, grad_bias = None, None, None

    if N == 0:
        # No rows: the autotune key functions take math.log2(N) and the kernels
        # index with `% N`, so short-circuit. The input gradient is empty and the
        # weight gradient is exactly zero.
        if needs_input_grad[0]:
            grad_input = torch.empty((0, Ci), device=input.device, dtype=input.dtype)
        if needs_input_grad[1]:
            grad_weight = torch.zeros((Co, V, Ci), device=weight.device, dtype=weight.dtype)
        if needs_input_grad[2]:
            grad_bias = grad_output.sum(0)
        return grad_input, grad_weight, grad_bias

    # Grad for input (SPLITK is chosen by the autotune wrapper).
    if needs_input_grad[0]:
        grad_input = conv_bwd_input_implicit_gemm_splitk(
            grad_output,
            weight,
            neighbor,
        )

    # Grad for weight (SPLITK is chosen by the autotune wrapper).
    if needs_input_grad[1]:
        grad_weight = conv_bwd_weight_implicit_gemm_splitk(
            grad_output,
            input,
            neighbor,
        )

    # Grad for bias: sum of grad_output over all rows.
    if needs_input_grad[2]:
        grad_bias = grad_output.sum(0)

    return grad_input, grad_weight, grad_bias
|
|
@@ -0,0 +1,109 @@
|
|
|
1
|
+
from typing import *
|
|
2
|
+
import math
|
|
3
|
+
import torch
|
|
4
|
+
import triton
|
|
5
|
+
import triton.language as tl
|
|
6
|
+
from .autotuner import triton_autotune
|
|
7
|
+
from . import config
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@triton_autotune(
    configs=config.autotune_config,
    key=['LOGN', 'Ci', 'Co', 'V', 'allow_tf32'],
)
@triton.jit
def conv_fwd_implicit_gemm_kernel(
    input,
    weight,
    bias,
    neighbor,
    output,
    # Tensor dimensions
    N, LOGN, Ci, Co, V: tl.constexpr,
    # Meta-parameters
    B1: tl.constexpr,  # Block size for N dimension
    B2: tl.constexpr,  # Block size for Co dimension
    BK: tl.constexpr,  # Block size for K dimension (V * Ci)
    allow_tf32: tl.constexpr,  # Allow TF32 precision for matmuls
):
    """
    Sparse submanifold convolution forward kernel using implicit GEMM.

    Each program computes one (B1, B2) tile of

        output[n, co] = bias[co] + sum_{v, ci} input[neighbor[n, v], ci] * weight[co, v, ci]

    iterating over the V kernel offsets and, within each, over Ci in chunks of
    BK. Accumulation (including the bias add) is done in fp32.

    Args:
        input (pointer): A pointer to the input tensor of shape (N, Ci)
        weight (pointer): A pointer to the weight tensor of shape (Co, V, Ci)
        bias (pointer): A pointer to the bias tensor of shape (Co), or None
        neighbor (pointer): A pointer to the neighbor tensor of shape (N, V);
            entries equal to -1 are treated as missing neighbors
        output (pointer): A pointer to the output tensor of shape (N, Co)
    """
    # 1-D launch grid: program ids enumerate (n-block, co-block) pairs, co fastest.
    block_id = tl.program_id(axis=0)
    block_dim_co = tl.cdiv(Co, B2)
    block_id_co = block_id % block_dim_co
    block_id_n = block_id // block_dim_co

    num_k = tl.cdiv(Ci, BK)  # Number of Ci chunks per kernel offset
    # Wrapped indices keep every load in bounds; lanes past N / Co are
    # discarded by the store mask.
    offset_n = (block_id_n * B1 + tl.arange(0, B1)) % N  # (B1,)
    offset_co = (block_id_co * B2 + tl.arange(0, B2)) % Co  # (B2,)
    offset_k = tl.arange(0, BK)  # (BK,)

    accumulator = tl.zeros((B1, B2), dtype=tl.float32)

    # weight[co, v, ci] lives at co * V * Ci + v * Ci + ci; start at v = 0, ci = 0.
    weight_ptr = weight + (offset_co[None, :] * V * Ci + offset_k[:, None])  # (BK, B2)

    # Iterate along the V*Ci reduction dimension.
    for k in range(num_k * V):
        v = k // num_k
        bk = k % num_k
        neighbor_offset_n = tl.load(neighbor + offset_n * V + v)  # (B1,)
        # int64: the gathered row index times Ci can exceed the int32 range.
        input_ptr = input + bk * BK + (neighbor_offset_n[:, None].to(tl.int64) * Ci + offset_k[None, :])  # (B1, BK)
        neigh_mask = neighbor_offset_n != -1
        k_mask = offset_k < Ci - bk * BK
        input_block = tl.load(input_ptr, mask=neigh_mask[:, None] & k_mask[None, :], other=0.0)
        weight_block = tl.load(weight_ptr, mask=k_mask[:, None], other=0.0)
        accumulator = tl.dot(input_block, weight_block, accumulator,
                             input_precision='tf32' if allow_tf32 else 'ieee')  # (B1, B2)
        # Advance to the next Ci chunk. Over one v the steps sum to Ci, so the
        # pointer then lands on the start of the next kernel offset.
        weight_ptr += min(BK, Ci - bk * BK)

    # Add the bias in fp32 before the single down-cast; adding it after the
    # cast would round the result twice in reduced-precision dtypes.
    if bias is not None:
        bias_block = tl.load(bias + offset_co)  # (B2,)
        accumulator += bias_block[None, :].to(tl.float32)
    c = accumulator.to(input.type.element_ty)

    # Write back the tile, masking rows >= N and columns >= Co.
    out_offset_n = block_id_n * B1 + tl.arange(0, B1)
    out_offset_co = block_id_co * B2 + tl.arange(0, B2)
    # Promote the row index to int64, as the gather above does: N * Co can exceed
    # 2**31 for large inputs, and an int32 offset would wrap and write out of bounds.
    out_ptr = output + (out_offset_n[:, None].to(tl.int64) * Co + out_offset_co[None, :])
    out_mask = (out_offset_n[:, None] < N) & (out_offset_co[None, :] < Co)
    tl.store(out_ptr, c, mask=out_mask)
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def conv_fwd_implicit_gemm(
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor],
    neighbor: torch.Tensor,
) -> torch.Tensor:
    """
    Forward pass of the implicit-GEMM sparse convolution.

    Args:
        input: input features, shape (N, Ci); must be contiguous.
        weight: weights, shape (Co, V, Ci); must be contiguous.
        bias: bias of shape (Co,), or None (the kernel skips the bias add).
        neighbor: neighbor table, shape (N, V); must be contiguous.

    Returns:
        Output features of shape (N, Co) with ``input``'s dtype and device.
    """
    assert input.shape[1] == weight.shape[2], "Incompatible dimensions"
    assert input.is_contiguous(), "Matrix input must be contiguous"
    assert weight.is_contiguous(), "Matrix weight must be contiguous"
    assert neighbor.is_contiguous(), "Matrix neighbor must be contiguous"
    N, Ci, Co, V = neighbor.shape[0], input.shape[1], weight.shape[0], weight.shape[1]

    output = torch.empty((N, Co), device=input.device, dtype=input.dtype)
    if N == 0:
        # Nothing to compute; also avoids math.log2(0) and the kernel's `% N`.
        return output

    # Coarse problem-size bucket used as an autotune key by the kernel.
    LOGN = int(math.log2(N))
    # One program per (B1 x B2) tile of the (N, Co) output.
    grid = lambda META: (triton.cdiv(Co, META['B2']) * triton.cdiv(N, META['B1']),)
    conv_fwd_implicit_gemm_kernel[grid](
        input, weight, bias, neighbor, output,
        N, LOGN, Ci, Co, V,
        allow_tf32=config.allow_tf32,
    )
    return output
|