triton-windows 3.1.0.post17__cp39-cp39-win_amd64.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
Potentially problematic release.
This version of triton-windows might be problematic.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +73 -0
- triton/backends/__init__.py +50 -0
- triton/backends/amd/compiler.py +262 -0
- triton/backends/amd/driver.c +211 -0
- triton/backends/amd/driver.py +497 -0
- triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +358 -0
- triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +1031 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +1612 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +1337 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +293 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +32 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +174 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +829 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +1809 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +108 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +124 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +405 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +196 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +565 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +2226 -0
- triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +104 -0
- triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +244 -0
- triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +494 -0
- triton/backends/amd/include/hip/amd_detail/concepts.hpp +30 -0
- triton/backends/amd/include/hip/amd_detail/device_library_decls.h +133 -0
- triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +218 -0
- triton/backends/amd/include/hip/amd_detail/grid_launch.h +67 -0
- triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +50 -0
- triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +26 -0
- triton/backends/amd/include/hip/amd_detail/helpers.hpp +137 -0
- triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +1350 -0
- triton/backends/amd/include/hip/amd_detail/hip_assert.h +101 -0
- triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +242 -0
- triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +254 -0
- triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +96 -0
- triton/backends/amd/include/hip/amd_detail/hip_ldg.h +100 -0
- triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +10169 -0
- triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +77 -0
- triton/backends/amd/include/hip/amd_detail/host_defines.h +180 -0
- triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +102 -0
- triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +798 -0
- triton/backends/amd/include/hip/amd_detail/math_fwd.h +698 -0
- triton/backends/amd/include/hip/amd_detail/ockl_image.h +177 -0
- triton/backends/amd/include/hip/amd_detail/program_state.hpp +107 -0
- triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +491 -0
- triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +478 -0
- triton/backends/amd/include/hip/channel_descriptor.h +39 -0
- triton/backends/amd/include/hip/device_functions.h +38 -0
- triton/backends/amd/include/hip/driver_types.h +468 -0
- triton/backends/amd/include/hip/hip_bf16.h +36 -0
- triton/backends/amd/include/hip/hip_bfloat16.h +44 -0
- triton/backends/amd/include/hip/hip_common.h +100 -0
- triton/backends/amd/include/hip/hip_complex.h +38 -0
- triton/backends/amd/include/hip/hip_cooperative_groups.h +46 -0
- triton/backends/amd/include/hip/hip_deprecated.h +95 -0
- triton/backends/amd/include/hip/hip_ext.h +159 -0
- triton/backends/amd/include/hip/hip_fp16.h +36 -0
- triton/backends/amd/include/hip/hip_gl_interop.h +32 -0
- triton/backends/amd/include/hip/hip_hcc.h +24 -0
- triton/backends/amd/include/hip/hip_math_constants.h +36 -0
- triton/backends/amd/include/hip/hip_profile.h +27 -0
- triton/backends/amd/include/hip/hip_runtime.h +75 -0
- triton/backends/amd/include/hip/hip_runtime_api.h +8919 -0
- triton/backends/amd/include/hip/hip_texture_types.h +29 -0
- triton/backends/amd/include/hip/hip_vector_types.h +41 -0
- triton/backends/amd/include/hip/hip_version.h +17 -0
- triton/backends/amd/include/hip/hiprtc.h +421 -0
- triton/backends/amd/include/hip/library_types.h +78 -0
- triton/backends/amd/include/hip/math_functions.h +42 -0
- triton/backends/amd/include/hip/surface_types.h +63 -0
- triton/backends/amd/include/hip/texture_types.h +194 -0
- triton/backends/amd/include/hsa/Brig.h +1131 -0
- triton/backends/amd/include/hsa/amd_hsa_common.h +91 -0
- triton/backends/amd/include/hsa/amd_hsa_elf.h +435 -0
- triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +269 -0
- triton/backends/amd/include/hsa/amd_hsa_queue.h +109 -0
- triton/backends/amd/include/hsa/amd_hsa_signal.h +80 -0
- triton/backends/amd/include/hsa/hsa.h +5729 -0
- triton/backends/amd/include/hsa/hsa_amd_tool.h +91 -0
- triton/backends/amd/include/hsa/hsa_api_trace.h +566 -0
- triton/backends/amd/include/hsa/hsa_ext_amd.h +3090 -0
- triton/backends/amd/include/hsa/hsa_ext_finalize.h +531 -0
- triton/backends/amd/include/hsa/hsa_ext_image.h +1454 -0
- triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +488 -0
- triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +667 -0
- triton/backends/amd/include/roctracer/ext/prof_protocol.h +107 -0
- triton/backends/amd/include/roctracer/hip_ostream_ops.h +4435 -0
- triton/backends/amd/include/roctracer/hsa_ostream_ops.h +1467 -0
- triton/backends/amd/include/roctracer/hsa_prof_str.h +3027 -0
- triton/backends/amd/include/roctracer/roctracer.h +779 -0
- triton/backends/amd/include/roctracer/roctracer_ext.h +81 -0
- triton/backends/amd/include/roctracer/roctracer_hcc.h +24 -0
- triton/backends/amd/include/roctracer/roctracer_hip.h +37 -0
- triton/backends/amd/include/roctracer/roctracer_hsa.h +112 -0
- triton/backends/amd/include/roctracer/roctracer_plugin.h +137 -0
- triton/backends/amd/include/roctracer/roctracer_roctx.h +67 -0
- triton/backends/amd/include/roctracer/roctx.h +229 -0
- triton/backends/amd/lib/ockl.bc +0 -0
- triton/backends/amd/lib/ocml.bc +0 -0
- triton/backends/compiler.py +76 -0
- triton/backends/driver.py +34 -0
- triton/backends/nvidia/__init__.py +0 -0
- triton/backends/nvidia/bin/ptxas.exe +0 -0
- triton/backends/nvidia/compiler.py +347 -0
- triton/backends/nvidia/driver.c +451 -0
- triton/backends/nvidia/driver.py +430 -0
- triton/backends/nvidia/include/cuda.h +24359 -0
- triton/backends/nvidia/lib/libdevice.10.bc +0 -0
- triton/backends/nvidia/lib/x64/cuda.lib +0 -0
- triton/compiler/__init__.py +4 -0
- triton/compiler/code_generator.py +1302 -0
- triton/compiler/compiler.py +416 -0
- triton/compiler/errors.py +51 -0
- triton/compiler/make_launcher.py +0 -0
- triton/errors.py +5 -0
- triton/language/__init__.py +284 -0
- triton/language/core.py +2621 -0
- triton/language/extra/__init__.py +4 -0
- triton/language/extra/cuda/__init__.py +8 -0
- triton/language/extra/cuda/libdevice.py +1629 -0
- triton/language/extra/cuda/utils.py +109 -0
- triton/language/extra/hip/__init__.py +3 -0
- triton/language/extra/hip/libdevice.py +468 -0
- triton/language/extra/libdevice.py +1213 -0
- triton/language/math.py +250 -0
- triton/language/random.py +207 -0
- triton/language/semantic.py +1621 -0
- triton/language/standard.py +441 -0
- triton/ops/__init__.py +7 -0
- triton/ops/blocksparse/__init__.py +7 -0
- triton/ops/blocksparse/matmul.py +432 -0
- triton/ops/blocksparse/softmax.py +228 -0
- triton/ops/cross_entropy.py +96 -0
- triton/ops/flash_attention.py +466 -0
- triton/ops/matmul.py +219 -0
- triton/ops/matmul_perf_model.py +171 -0
- triton/runtime/__init__.py +23 -0
- triton/runtime/autotuner.py +361 -0
- triton/runtime/build.py +129 -0
- triton/runtime/cache.py +289 -0
- triton/runtime/driver.py +60 -0
- triton/runtime/errors.py +26 -0
- triton/runtime/interpreter.py +1127 -0
- triton/runtime/jit.py +956 -0
- triton/runtime/tcc/include/_mingw.h +170 -0
- triton/runtime/tcc/include/assert.h +57 -0
- triton/runtime/tcc/include/conio.h +409 -0
- triton/runtime/tcc/include/ctype.h +281 -0
- triton/runtime/tcc/include/dir.h +31 -0
- triton/runtime/tcc/include/direct.h +68 -0
- triton/runtime/tcc/include/dirent.h +135 -0
- triton/runtime/tcc/include/dos.h +55 -0
- triton/runtime/tcc/include/errno.h +75 -0
- triton/runtime/tcc/include/excpt.h +123 -0
- triton/runtime/tcc/include/fcntl.h +52 -0
- triton/runtime/tcc/include/fenv.h +108 -0
- triton/runtime/tcc/include/float.h +57 -0
- triton/runtime/tcc/include/inttypes.h +297 -0
- triton/runtime/tcc/include/io.h +418 -0
- triton/runtime/tcc/include/limits.h +111 -0
- triton/runtime/tcc/include/locale.h +91 -0
- triton/runtime/tcc/include/malloc.h +181 -0
- triton/runtime/tcc/include/math.h +737 -0
- triton/runtime/tcc/include/mem.h +13 -0
- triton/runtime/tcc/include/memory.h +40 -0
- triton/runtime/tcc/include/process.h +176 -0
- triton/runtime/tcc/include/sec_api/conio_s.h +42 -0
- triton/runtime/tcc/include/sec_api/crtdbg_s.h +19 -0
- triton/runtime/tcc/include/sec_api/io_s.h +33 -0
- triton/runtime/tcc/include/sec_api/mbstring_s.h +52 -0
- triton/runtime/tcc/include/sec_api/search_s.h +25 -0
- triton/runtime/tcc/include/sec_api/stdio_s.h +145 -0
- triton/runtime/tcc/include/sec_api/stdlib_s.h +67 -0
- triton/runtime/tcc/include/sec_api/stralign_s.h +30 -0
- triton/runtime/tcc/include/sec_api/string_s.h +41 -0
- triton/runtime/tcc/include/sec_api/sys/timeb_s.h +34 -0
- triton/runtime/tcc/include/sec_api/tchar_s.h +266 -0
- triton/runtime/tcc/include/sec_api/time_s.h +61 -0
- triton/runtime/tcc/include/sec_api/wchar_s.h +128 -0
- triton/runtime/tcc/include/setjmp.h +160 -0
- triton/runtime/tcc/include/share.h +28 -0
- triton/runtime/tcc/include/signal.h +63 -0
- triton/runtime/tcc/include/stdarg.h +79 -0
- triton/runtime/tcc/include/stdbool.h +11 -0
- triton/runtime/tcc/include/stddef.h +54 -0
- triton/runtime/tcc/include/stdint.h +212 -0
- triton/runtime/tcc/include/stdio.h +429 -0
- triton/runtime/tcc/include/stdlib.h +580 -0
- triton/runtime/tcc/include/string.h +164 -0
- triton/runtime/tcc/include/sys/fcntl.h +13 -0
- triton/runtime/tcc/include/sys/file.h +14 -0
- triton/runtime/tcc/include/sys/locking.h +30 -0
- triton/runtime/tcc/include/sys/stat.h +290 -0
- triton/runtime/tcc/include/sys/time.h +69 -0
- triton/runtime/tcc/include/sys/timeb.h +133 -0
- triton/runtime/tcc/include/sys/types.h +118 -0
- triton/runtime/tcc/include/sys/unistd.h +14 -0
- triton/runtime/tcc/include/sys/utime.h +146 -0
- triton/runtime/tcc/include/tcc/tcc_libm.h +201 -0
- triton/runtime/tcc/include/tcclib.h +80 -0
- triton/runtime/tcc/include/tchar.h +1102 -0
- triton/runtime/tcc/include/time.h +287 -0
- triton/runtime/tcc/include/vadefs.h +11 -0
- triton/runtime/tcc/include/values.h +4 -0
- triton/runtime/tcc/include/varargs.h +12 -0
- triton/runtime/tcc/include/wchar.h +873 -0
- triton/runtime/tcc/include/wctype.h +172 -0
- triton/runtime/tcc/include/winapi/basetsd.h +149 -0
- triton/runtime/tcc/include/winapi/basetyps.h +85 -0
- triton/runtime/tcc/include/winapi/guiddef.h +156 -0
- triton/runtime/tcc/include/winapi/poppack.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack1.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack2.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack4.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack8.h +8 -0
- triton/runtime/tcc/include/winapi/winbase.h +2951 -0
- triton/runtime/tcc/include/winapi/wincon.h +301 -0
- triton/runtime/tcc/include/winapi/windef.h +293 -0
- triton/runtime/tcc/include/winapi/windows.h +127 -0
- triton/runtime/tcc/include/winapi/winerror.h +3166 -0
- triton/runtime/tcc/include/winapi/wingdi.h +4080 -0
- triton/runtime/tcc/include/winapi/winnt.h +5835 -0
- triton/runtime/tcc/include/winapi/winreg.h +272 -0
- triton/runtime/tcc/include/winapi/winuser.h +5651 -0
- triton/runtime/tcc/include/winapi/winver.h +160 -0
- triton/runtime/tcc/lib/cuda.def +697 -0
- triton/runtime/tcc/lib/gdi32.def +337 -0
- triton/runtime/tcc/lib/kernel32.def +770 -0
- triton/runtime/tcc/lib/libtcc1-64.a +0 -0
- triton/runtime/tcc/lib/msvcrt.def +1399 -0
- triton/runtime/tcc/lib/python3.def +810 -0
- triton/runtime/tcc/lib/user32.def +658 -0
- triton/runtime/tcc/libtcc.dll +0 -0
- triton/runtime/tcc/tcc.exe +0 -0
- triton/testing.py +496 -0
- triton/tools/__init__.py +0 -0
- triton/tools/build_extern.py +365 -0
- triton/tools/compile.c +67 -0
- triton/tools/compile.h +14 -0
- triton/tools/compile.py +145 -0
- triton/tools/disasm.py +142 -0
- triton/tools/link.py +322 -0
- triton/windows_utils.py +373 -0
- triton_windows-3.1.0.post17.dist-info/METADATA +41 -0
- triton_windows-3.1.0.post17.dist-info/RECORD +248 -0
- triton_windows-3.1.0.post17.dist-info/WHEEL +5 -0
- triton_windows-3.1.0.post17.dist-info/top_level.txt +14 -0
triton/ops/matmul.py
ADDED
@@ -0,0 +1,219 @@
import torch

from .. import Config, autotune, cdiv, heuristics, jit
from .. import language as tl
from .matmul_perf_model import early_config_prune, estimate_matmul_time

_ordered_datatypes = [torch.int8, torch.float16, torch.bfloat16, torch.float32]


def upcast_if_fp8(a):
    if "fp8" in str(a):
        return torch.float16
    return a


def get_higher_dtype(a, b):
    a = upcast_if_fp8(a)
    b = upcast_if_fp8(b)
    if a is b:
        return a

    assert a in _ordered_datatypes
    assert b in _ordered_datatypes

    for d in _ordered_datatypes:
        if a is d:
            return b
        if b is d:
            return a


def init_to_zero(name):
    return lambda nargs: nargs[name].zero_()


def get_configs_io_bound():
    configs = []
    for num_stages in [2, 3, 4, 5, 6]:
        for block_m in [16, 32]:
            for block_k in [32, 64]:
                for block_n in [32, 64, 128, 256]:
                    num_warps = 2 if block_n <= 64 else 4
                    configs.append(
                        Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1},
                               num_stages=num_stages, num_warps=num_warps))
                    # split_k
                    for split_k in [2, 4, 8, 16]:
                        configs.append(
                            Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
                                   num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
    return configs


@autotune(
    configs=[
        # basic configs for compute-bound matmuls
        Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
        Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
        Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
        # good for int8
        Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
        Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
        Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
    ] + get_configs_io_bound(),
    key=['M', 'N', 'K'],
    prune_configs_by={
        'early_config_prune': early_config_prune,
        'perf_model': estimate_matmul_time,
        'top_k': 10,
    },
)
@heuristics({
    'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
})
@jit
def _kernel(A, B, C, M, N, K,  #
            stride_am, stride_ak,  #
            stride_bk, stride_bn,  #
            stride_cm, stride_cn,  #
            acc_dtype: tl.constexpr,  #
            input_precision: tl.constexpr,  #
            fp8_fast_accum: tl.constexpr,  #
            BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,  #
            GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, AB_DTYPE: tl.constexpr  #
            ):
    # matrix multiplication
    pid = tl.program_id(0)
    pid_z = tl.program_id(1)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    # do matrix multiplication
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
    # pointers
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=acc_dtype)
    for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            k_remaining = K - k * (BLOCK_K * SPLIT_K)
            _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty)
            a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0)
            b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0)
        if AB_DTYPE is not None:
            a = a.to(AB_DTYPE)
            b = b.to(AB_DTYPE)
        if fp8_fast_accum:
            acc = tl.dot(a, b, acc, out_dtype=acc_dtype, input_precision=input_precision)
        else:
            acc += tl.dot(a, b, out_dtype=acc_dtype, input_precision=input_precision)
        A += BLOCK_K * SPLIT_K * stride_ak
        B += BLOCK_K * SPLIT_K * stride_bk
    acc = acc.to(C.dtype.element_ty)
    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    # handles write-back with reduction-splitting
    if SPLIT_K == 1:
        tl.store(C, acc, mask=mask)
    else:
        tl.atomic_add(C, acc, mask=mask)


class _matmul(torch.autograd.Function):
    kernel = _kernel

    _locks = {}

    @staticmethod
    def _call(a, b, acc_dtype, input_precision, fp8_fast_accum, output_dtype):
        device = a.device
        # handle non-contiguous inputs if necessary
        if a.stride(0) > 1 and a.stride(1) > 1:
            a = a.contiguous()
        if b.stride(0) > 1 and b.stride(1) > 1:
            b = b.contiguous()
        # checks constraints
        assert a.shape[1] == b.shape[0], "incompatible dimensions"
        M, K = a.shape
        _, N = b.shape

        # common type between a and b
        ab_dtype = get_higher_dtype(a.dtype, b.dtype)

        # allocates output
        if (output_dtype is None):
            output_dtype = ab_dtype

        c = torch.empty((M, N), device=device, dtype=output_dtype)

        # Allowed types for acc_type given the types of a and b.
        supported_acc_dtypes = {
            torch.float16: (torch.float32, torch.float16), torch.bfloat16: (torch.float32, torch.bfloat16),
            torch.float32: (torch.float32, ), torch.int8: (torch.int32, )
        }

        if acc_dtype is None:
            acc_dtype = supported_acc_dtypes[ab_dtype][0]
        else:
            assert isinstance(acc_dtype, torch.dtype), "acc_dtype must be a torch.dtype"
            assert acc_dtype in supported_acc_dtypes[a.dtype], "acc_dtype not compatible with the type of a"
            assert acc_dtype in supported_acc_dtypes[b.dtype], "acc_dtype not compatible with the type of b"

        def to_tl_type(ty):
            return getattr(tl, str(ty).split(".")[-1])

        acc_dtype = to_tl_type(acc_dtype)
        ab_dtype = to_tl_type(ab_dtype)
        output_dtype = to_tl_type(output_dtype)

        # Tensor cores support input with mixed float8 types.
        if a.dtype in [tl.float8e4nv, tl.float8e5] and b.dtype in [tl.float8e4nv, tl.float8e5]:
            ab_dtype = None
        # launch kernel
        grid = lambda META: (cdiv(M, META['BLOCK_M']) * cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
        _kernel[grid](
            a, b, c, M, N, K,  #
            a.stride(0), a.stride(1),  #
            b.stride(0), b.stride(1),  #
            c.stride(0), c.stride(1),  #
            acc_dtype=acc_dtype,  #
            input_precision=input_precision,  #
            fp8_fast_accum=fp8_fast_accum,  #
            GROUP_M=8, AB_DTYPE=ab_dtype)
        return c

    @staticmethod
    def forward(ctx, a, b, acc_dtype=None, input_precision=None, fp8_fast_accum=True, output_dtype=None):
        return _matmul._call(a, b, acc_dtype=acc_dtype, input_precision=input_precision, fp8_fast_accum=fp8_fast_accum,
                             output_dtype=output_dtype)


matmul = _matmul.apply
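For reference, a minimal usage sketch for the matmul op defined above (not part of the packaged diff). It assumes a CUDA-capable GPU; the shapes, dtypes, and tolerances are illustrative only.

import torch
from triton.ops.matmul import matmul

a = torch.randn((512, 256), device="cuda", dtype=torch.float16)
b = torch.randn((256, 1024), device="cuda", dtype=torch.float16)

# forward() takes (a, b, acc_dtype=None, input_precision=None,
# fp8_fast_accum=True, output_dtype=None); autotuning runs on the
# first call for each (M, N, K) key.
c = matmul(a, b)
torch.testing.assert_close(c, a @ b, rtol=1e-2, atol=1e-2)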
triton/ops/matmul_perf_model.py
ADDED
@@ -0,0 +1,171 @@
import functools
import heapq

import torch

from .. import cdiv
from ..runtime import driver
from ..testing import (get_dram_gbps, get_max_simd_tflops, get_max_tensorcore_tflops, nvsmi)


@functools.lru_cache()
def get_clock_rate_in_khz():
    try:
        return nvsmi(['clocks.max.sm'])[0] * 1e3
    except FileNotFoundError:
        import pynvml

        pynvml.nvmlInit()
        handle = pynvml.nvmlDeviceGetHandleByIndex(0)
        return pynvml.nvmlDeviceGetMaxClockInfo(handle, pynvml.NVML_CLOCK_SM) * 1e3


def get_tensorcore_tflops(device, num_ctas, num_warps, dtype):
    ''' return compute throughput in TOPS '''
    total_warps = num_ctas * min(num_warps, 4)
    num_subcores = driver.active.utils.get_device_properties(device)["multiprocessor_count"] * 4  # on recent GPUs
    tflops = min(num_subcores, total_warps) / num_subcores * get_max_tensorcore_tflops(
        dtype, get_clock_rate_in_khz(), device)
    return tflops


def get_simd_tflops(device, num_ctas, num_warps, dtype):
    ''' return compute throughput in TOPS '''
    total_warps = num_ctas * min(num_warps, 4)
    num_subcores = driver.active.utils.get_device_properties(device)["multiprocessor_count"] * 4  # on recent GPUs
    tflops = min(num_subcores, total_warps) / num_subcores * get_max_simd_tflops(dtype, get_clock_rate_in_khz(), device)
    return tflops


def get_tflops(device, num_ctas, num_warps, dtype):
    capability = torch.cuda.get_device_capability(device)
    if capability[0] < 8 and dtype == torch.float32:
        return get_simd_tflops(device, num_ctas, num_warps, dtype)
    return get_tensorcore_tflops(device, num_ctas, num_warps, dtype)


def estimate_matmul_time(
        # backend, device,
        num_warps, num_stages,  #
        A, B, C,  #
        M, N, K,  #
        BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K,  #
        debug=False, **kwargs  #
):
    ''' return estimated running time in ms
          = max(compute, loading) + store '''
    device = torch.cuda.current_device()
    dtype = A.dtype
    dtsize = A.element_size()

    num_cta_m = cdiv(M, BLOCK_M)
    num_cta_n = cdiv(N, BLOCK_N)
    num_cta_k = SPLIT_K
    num_ctas = num_cta_m * num_cta_n * num_cta_k

    # If the input is smaller than the block size
    M, N = max(M, BLOCK_M), max(N, BLOCK_N)

    # time to compute
    total_ops = 2 * M * N * K / (1024 * 1024 * 1024)  # GOPS
    tput = get_tflops(device, num_ctas, num_warps, dtype)
    compute_ms = total_ops / tput

    # time to load data
    num_sm = driver.active.utils.get_device_properties(device)["multiprocessor_count"]
    active_cta_ratio = min(1, num_ctas / num_sm)
    active_cta_ratio_bw1 = min(1, num_ctas / 32)  # 32 active ctas are enough to saturate
    active_cta_ratio_bw2 = max(min(1, (num_ctas - 32) / (108 - 32)), 0)  # 32-108, remaining 5%
    dram_bw = get_dram_gbps(device) * (active_cta_ratio_bw1 * 0.95 + active_cta_ratio_bw2 * 0.05)  # in GB/s
    l2_bw = dram_bw * 4  # rough estimation (should be 4.7 for A100?)
    # assume 80% of (following) loads are in L2 cache
    load_a_dram = M * K * dtsize * (1 + 0.2 * (num_cta_n - 1))
    load_a_l2 = M * K * dtsize * 0.8 * (num_cta_n - 1)
    load_b_dram = N * K * dtsize * (1 + 0.2 * (num_cta_m - 1))
    load_b_l2 = N * K * dtsize * 0.8 * (num_cta_m - 1)
    # total
    total_dram = (load_a_dram + load_b_dram) / (1024 * 1024)  # MB
    total_l2 = (load_a_l2 + load_b_l2) / (1024 * 1024)
    # loading time in ms
    load_ms = total_dram / dram_bw + total_l2 / l2_bw

    # estimate storing time
    store_bw = dram_bw * 0.6  # :o
    store_c_dram = M * N * dtsize * SPLIT_K / (1024 * 1024)  # MB
    if SPLIT_K == 1:
        store_ms = store_c_dram / store_bw
    else:
        reduce_bw = store_bw
        store_ms = store_c_dram / reduce_bw
        # c.zero_()
        zero_ms = M * N * 2 / (1024 * 1024) / store_bw
        store_ms += zero_ms

    total_time_ms = max(compute_ms, load_ms) + store_ms
    if debug:
        print(f'Total time: {total_time_ms}ms, compute time: {compute_ms}ms, '
              f'loading time: {load_ms}ms, store time: {store_ms}ms, '
              f'Activate CTAs: {active_cta_ratio*100}%')
    return total_time_ms


def early_config_prune(configs, named_args, **kwargs):
    device = torch.cuda.current_device()
    capability = torch.cuda.get_device_capability()
    # BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages
    dtsize = named_args['A'].element_size()
    dtype = named_args['A'].dtype

    # 1. make sure we have enough smem
    pruned_configs = []
    for config in configs:
        kw = config.kwargs
        BLOCK_M, BLOCK_N, BLOCK_K, num_stages = \
            kw['BLOCK_M'], kw['BLOCK_N'], kw['BLOCK_K'], config.num_stages

        max_shared_memory = driver.active.utils.get_device_properties(device)["max_shared_mem"]
        required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize
        if required_shared_memory <= max_shared_memory:
            pruned_configs.append(config)
    configs = pruned_configs

    # Some dtypes do not allow atomic_add
    if dtype not in [torch.float16, torch.float32]:
        configs = [config for config in configs if config.kwargs['SPLIT_K'] == 1]

    # group configs by (BLOCK_M,_N,_K, SPLIT_K, num_warps)
    configs_map = {}
    for config in configs:
        kw = config.kwargs
        BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages = \
            kw['BLOCK_M'], kw['BLOCK_N'], kw['BLOCK_K'], kw['SPLIT_K'], config.num_warps, config.num_stages

        key = (BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps)
        if key in configs_map:
            configs_map[key].append((config, num_stages))
        else:
            configs_map[key] = [(config, num_stages)]

    pruned_configs = []
    for k, v in configs_map.items():
        BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps = k
        if capability[0] >= 8:
            # compute cycles (only works for ampere GPUs)
            mmas = BLOCK_M * BLOCK_N * BLOCK_K / (16 * 8 * 16)
            mma_cycles = mmas / min(4, num_warps) * 8

            ldgsts_latency = 300  # Does this matter?
            optimal_num_stages = ldgsts_latency / mma_cycles

            # nearest stages, prefer large #stages
            nearest = heapq.nsmallest(
                2, v, key=lambda x: 10 + abs(x[1] - optimal_num_stages)
                if (x[1] - optimal_num_stages) < 0 else x[1] - optimal_num_stages)

            for n in nearest:
                pruned_configs.append(n[0])
        else:  # Volta & Turing only supports num_stages <= 2
            random_config = v[0][0]
            random_config.num_stages = 2
            pruned_configs.append(random_config)
    return pruned_configs
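A hedged sketch of calling the performance model directly, outside the autotuner (not part of the packaged diff). It assumes a CUDA GPU and that either nvidia-smi or pynvml is available for get_clock_rate_in_khz(); the block sizes mirror one of the compute-bound configs listed in matmul.py above.

import torch
from triton.ops.matmul_perf_model import estimate_matmul_time

M, N, K = 4096, 4096, 4096
a = torch.randn((M, K), device="cuda", dtype=torch.float16)
b = torch.randn((K, N), device="cuda", dtype=torch.float16)
c = torch.empty((M, N), device="cuda", dtype=torch.float16)

# Predicted runtime in ms for BLOCK_M=128, BLOCK_N=256, BLOCK_K=32, SPLIT_K=1;
# debug=True prints the compute/load/store breakdown.
est_ms = estimate_matmul_time(num_warps=8, num_stages=3, A=a, B=b, C=c,
                              M=M, N=N, K=K, BLOCK_M=128, BLOCK_N=256,
                              BLOCK_K=32, SPLIT_K=1, debug=True)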
triton/runtime/__init__.py
ADDED
@@ -0,0 +1,23 @@
from .autotuner import (Autotuner, Config, Heuristics, autotune, heuristics)
from .cache import RedisRemoteCacheBackend, RemoteCacheBackend
from .driver import driver
from .jit import JITFunction, KernelInterface, MockTensor, TensorWrapper, reinterpret
from .errors import OutOfResources, InterpreterError

__all__ = [
    "autotune",
    "Autotuner",
    "Config",
    "driver",
    "Heuristics",
    "heuristics",
    "InterpreterError",
    "JITFunction",
    "KernelInterface",
    "MockTensor",
    "OutOfResources",
    "RedisRemoteCacheBackend",
    "reinterpret",
    "RemoteCacheBackend",
    "TensorWrapper",
]
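A short hedged sketch (not part of the packaged diff) of the re-exported driver in use; it requires a working GPU backend, and the device index 0 is illustrative.

from triton.runtime import driver

# Same device-properties call the autotuner's perf model relies on above.
props = driver.active.utils.get_device_properties(0)
print(props["multiprocessor_count"], props["max_shared_mem"])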