triton-windows 3.3.0.post19__cp312-cp312-win_amd64.whl → 3.4.0.post20__cp312-cp312-win_amd64.whl
This diff shows the changes between publicly released versions of the package as they appear in their respective public registries, and is provided for informational purposes only.
This version of triton-windows has been flagged as a potentially problematic release.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +4 -1
- triton/_filecheck.py +87 -0
- triton/_internal_testing.py +26 -15
- triton/_utils.py +110 -21
- triton/backends/__init__.py +20 -23
- triton/backends/amd/__init__.py +0 -0
- triton/backends/amd/compiler.py +112 -78
- triton/backends/amd/driver.c +5 -2
- triton/backends/amd/driver.py +149 -47
- triton/backends/compiler.py +7 -21
- triton/backends/nvidia/bin/ptxas.exe +0 -0
- triton/backends/nvidia/compiler.py +92 -93
- triton/backends/nvidia/driver.c +90 -98
- triton/backends/nvidia/driver.py +303 -128
- triton/compiler/code_generator.py +212 -111
- triton/compiler/compiler.py +110 -25
- triton/experimental/__init__.py +0 -0
- triton/experimental/gluon/__init__.py +4 -0
- triton/experimental/gluon/_compiler.py +0 -0
- triton/experimental/gluon/_runtime.py +99 -0
- triton/experimental/gluon/language/__init__.py +18 -0
- triton/experimental/gluon/language/_core.py +312 -0
- triton/experimental/gluon/language/_layouts.py +230 -0
- triton/experimental/gluon/language/_math.py +12 -0
- triton/experimental/gluon/language/_semantic.py +287 -0
- triton/experimental/gluon/language/_standard.py +47 -0
- triton/experimental/gluon/language/nvidia/__init__.py +4 -0
- triton/experimental/gluon/language/nvidia/blackwell/__init__.py +202 -0
- triton/experimental/gluon/language/nvidia/blackwell/tma.py +32 -0
- triton/experimental/gluon/language/nvidia/hopper/__init__.py +11 -0
- triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +51 -0
- triton/experimental/gluon/language/nvidia/hopper/tma.py +96 -0
- triton/experimental/gluon/nvidia/__init__.py +4 -0
- triton/experimental/gluon/nvidia/blackwell.py +3 -0
- triton/experimental/gluon/nvidia/hopper.py +40 -0
- triton/knobs.py +481 -0
- triton/language/__init__.py +39 -14
- triton/language/core.py +794 -537
- triton/language/extra/cuda/__init__.py +10 -7
- triton/language/extra/cuda/gdc.py +42 -0
- triton/language/extra/cuda/libdevice.py +394 -394
- triton/language/extra/cuda/utils.py +21 -21
- triton/language/extra/hip/libdevice.py +113 -104
- triton/language/math.py +65 -66
- triton/language/random.py +12 -2
- triton/language/semantic.py +1706 -1770
- triton/language/standard.py +116 -51
- triton/runtime/autotuner.py +117 -59
- triton/runtime/build.py +76 -12
- triton/runtime/cache.py +18 -47
- triton/runtime/driver.py +32 -29
- triton/runtime/interpreter.py +72 -35
- triton/runtime/jit.py +146 -110
- triton/runtime/tcc/lib/python310.def +1610 -0
- triton/runtime/tcc/lib/python311.def +1633 -0
- triton/runtime/tcc/lib/python312.def +1703 -0
- triton/runtime/tcc/lib/python313.def +1651 -0
- triton/runtime/tcc/lib/python313t.def +1656 -0
- triton/runtime/tcc/lib/python39.def +1644 -0
- triton/runtime/tcc/lib/python3t.def +905 -0
- triton/testing.py +16 -12
- triton/tools/disasm.py +3 -4
- triton/tools/tensor_descriptor.py +36 -0
- triton/windows_utils.py +14 -6
- {triton_windows-3.3.0.post19.dist-info → triton_windows-3.4.0.post20.dist-info}/METADATA +7 -2
- triton_windows-3.4.0.post20.dist-info/RECORD +186 -0
- {triton_windows-3.3.0.post19.dist-info → triton_windows-3.4.0.post20.dist-info}/WHEEL +1 -1
- triton_windows-3.4.0.post20.dist-info/entry_points.txt +3 -0
- triton_windows-3.4.0.post20.dist-info/licenses/LICENSE +23 -0
- triton_windows-3.4.0.post20.dist-info/top_level.txt +1 -0
- triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +0 -358
- triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +0 -1010
- triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +0 -1638
- triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +0 -1814
- triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +0 -293
- triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +0 -32
- triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +0 -174
- triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +0 -835
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +0 -1809
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +0 -1391
- triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +0 -108
- triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +0 -124
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +0 -405
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +0 -196
- triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +0 -565
- triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +0 -2226
- triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +0 -104
- triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +0 -244
- triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +0 -538
- triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +0 -288
- triton/backends/amd/include/hip/amd_detail/concepts.hpp +0 -30
- triton/backends/amd/include/hip/amd_detail/device_library_decls.h +0 -133
- triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +0 -218
- triton/backends/amd/include/hip/amd_detail/grid_launch.h +0 -67
- triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +0 -50
- triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +0 -26
- triton/backends/amd/include/hip/amd_detail/helpers.hpp +0 -137
- triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +0 -1446
- triton/backends/amd/include/hip/amd_detail/hip_assert.h +0 -101
- triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +0 -242
- triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +0 -254
- triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +0 -96
- triton/backends/amd/include/hip/amd_detail/hip_ldg.h +0 -100
- triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +0 -10570
- triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +0 -78
- triton/backends/amd/include/hip/amd_detail/host_defines.h +0 -184
- triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +0 -102
- triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +0 -798
- triton/backends/amd/include/hip/amd_detail/math_fwd.h +0 -698
- triton/backends/amd/include/hip/amd_detail/ockl_image.h +0 -177
- triton/backends/amd/include/hip/amd_detail/program_state.hpp +0 -107
- triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +0 -491
- triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +0 -478
- triton/backends/amd/include/hip/channel_descriptor.h +0 -39
- triton/backends/amd/include/hip/device_functions.h +0 -38
- triton/backends/amd/include/hip/driver_types.h +0 -468
- triton/backends/amd/include/hip/hip_bf16.h +0 -36
- triton/backends/amd/include/hip/hip_bfloat16.h +0 -44
- triton/backends/amd/include/hip/hip_common.h +0 -100
- triton/backends/amd/include/hip/hip_complex.h +0 -38
- triton/backends/amd/include/hip/hip_cooperative_groups.h +0 -46
- triton/backends/amd/include/hip/hip_deprecated.h +0 -95
- triton/backends/amd/include/hip/hip_ext.h +0 -161
- triton/backends/amd/include/hip/hip_fp16.h +0 -36
- triton/backends/amd/include/hip/hip_fp8.h +0 -33
- triton/backends/amd/include/hip/hip_gl_interop.h +0 -32
- triton/backends/amd/include/hip/hip_hcc.h +0 -24
- triton/backends/amd/include/hip/hip_math_constants.h +0 -36
- triton/backends/amd/include/hip/hip_profile.h +0 -27
- triton/backends/amd/include/hip/hip_runtime.h +0 -75
- triton/backends/amd/include/hip/hip_runtime_api.h +0 -9261
- triton/backends/amd/include/hip/hip_texture_types.h +0 -29
- triton/backends/amd/include/hip/hip_vector_types.h +0 -41
- triton/backends/amd/include/hip/hip_version.h +0 -17
- triton/backends/amd/include/hip/hiprtc.h +0 -421
- triton/backends/amd/include/hip/library_types.h +0 -78
- triton/backends/amd/include/hip/math_functions.h +0 -42
- triton/backends/amd/include/hip/surface_types.h +0 -63
- triton/backends/amd/include/hip/texture_types.h +0 -194
- triton/backends/amd/include/hsa/Brig.h +0 -1131
- triton/backends/amd/include/hsa/amd_hsa_common.h +0 -91
- triton/backends/amd/include/hsa/amd_hsa_elf.h +0 -462
- triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +0 -269
- triton/backends/amd/include/hsa/amd_hsa_queue.h +0 -109
- triton/backends/amd/include/hsa/amd_hsa_signal.h +0 -80
- triton/backends/amd/include/hsa/hsa.h +0 -5738
- triton/backends/amd/include/hsa/hsa_amd_tool.h +0 -91
- triton/backends/amd/include/hsa/hsa_api_trace.h +0 -579
- triton/backends/amd/include/hsa/hsa_api_trace_version.h +0 -68
- triton/backends/amd/include/hsa/hsa_ext_amd.h +0 -3146
- triton/backends/amd/include/hsa/hsa_ext_finalize.h +0 -531
- triton/backends/amd/include/hsa/hsa_ext_image.h +0 -1454
- triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +0 -488
- triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +0 -667
- triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +0 -416
- triton/backends/amd/include/roctracer/ext/prof_protocol.h +0 -107
- triton/backends/amd/include/roctracer/hip_ostream_ops.h +0 -4515
- triton/backends/amd/include/roctracer/hsa_ostream_ops.h +0 -1727
- triton/backends/amd/include/roctracer/hsa_prof_str.h +0 -3059
- triton/backends/amd/include/roctracer/roctracer.h +0 -779
- triton/backends/amd/include/roctracer/roctracer_ext.h +0 -81
- triton/backends/amd/include/roctracer/roctracer_hcc.h +0 -24
- triton/backends/amd/include/roctracer/roctracer_hip.h +0 -37
- triton/backends/amd/include/roctracer/roctracer_hsa.h +0 -112
- triton/backends/amd/include/roctracer/roctracer_plugin.h +0 -137
- triton/backends/amd/include/roctracer/roctracer_roctx.h +0 -67
- triton/backends/amd/include/roctracer/roctx.h +0 -229
- triton/language/_utils.py +0 -21
- triton/language/extra/cuda/_experimental_tma.py +0 -106
- triton/tools/experimental_descriptor.py +0 -32
- triton_windows-3.3.0.post19.dist-info/RECORD +0 -253
- triton_windows-3.3.0.post19.dist-info/top_level.txt +0 -14
triton/backends/nvidia/compiler.py
CHANGED

@@ -1,5 +1,6 @@
-from triton.backends.compiler import BaseBackend, GPUTarget
+from triton.backends.compiler import BaseBackend, GPUTarget, Language
 from triton._C.libtriton import ir, passes, llvm, nvidia
+from triton import knobs
 from triton.runtime.errors import PTXASError
 
 from dataclasses import dataclass
@@ -13,7 +14,6 @@ import signal
 import os
 import subprocess
 from pathlib import Path
-import sysconfig
 
 
 def min_dot_size(target: GPUTarget):
@@ -30,46 +30,16 @@ def min_dot_size(target: GPUTarget):
     return check_dot_compatibility
 
 
-@functools.lru_cache()
-def _path_to_binary(binary: str):
-    paths = [
-        os.environ.get(f"TRITON_{binary.upper()}_PATH", ""),
-    ]
-    binary += sysconfig.get_config_var("EXE")
-    paths += [
-        os.path.join(os.path.dirname(__file__), "bin", binary),
-    ]
-    if os.name == "nt":
-        from triton.windows_utils import find_cuda
-        cuda_bin_path, _, _ = find_cuda()
-        if cuda_bin_path:
-            paths += [os.path.join(cuda_bin_path, binary)]
-
-    for path in paths:
-        if os.path.exists(path) and os.path.isfile(path):
-            result = subprocess.check_output([path, "--version"], stderr=subprocess.STDOUT)
-            if result is not None:
-                version = re.search(r".*release (\d+\.\d+).*", result.decode("utf-8"), flags=re.MULTILINE)
-                if version is not None:
-                    return path, version.group(1)
-    raise RuntimeError(f"Cannot find {binary}")
+def get_ptxas() -> knobs.NvidiaTool:
+    return knobs.nvidia.ptxas
 
 
 @functools.lru_cache()
-def get_ptxas(arch: int):
-    if …:
-        name = "ptxas"
-    else:
-        name = "ptxas-blackwell" if arch >= 100 else "ptxas"
-    return _path_to_binary(name)
-
-
-@functools.lru_cache()
-def get_ptxas_version(arch: int):
-    mock_ver = os.environ.get('TRITON_MOCK_PTX_VERSION')
+def get_ptxas_version():
+    mock_ver = knobs.nvidia.mock_ptx_version
     if mock_ver is not None:
         return mock_ver  # This is not really a version of ptxas, but it is good enough for testing
-    version = subprocess.check_output([get_ptxas(…
+    version = subprocess.check_output([get_ptxas().path, "--version"]).decode("utf-8")
     return version
 
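Note: the removed _path_to_binary filesystem probing is superseded by the new triton/knobs.py module (+481 lines in this release), which centralizes the environment-variable configuration previously read inline (TRITON_PTXAS_PATH, TRITON_MOCK_PTX_VERSION, and similar). A minimal usage sketch, assuming only the .path and .version fields that are visible in these hunks:

    from triton.backends.nvidia.compiler import get_ptxas

    tool = get_ptxas()      # a knobs.NvidiaTool resolved by triton.knobs
    print(tool.path)        # path to the ptxas executable
    print(tool.version)     # toolkit version string, later fed to ptx_get_version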
@@ -95,7 +65,7 @@ def ptx_get_version(cuda_version) -> int:
 def get_ptx_version_from_options(options, arch: int):
     ptx_version = options.ptx_version
     if ptx_version is None:
-        _, cuda_version = get_ptxas(arch)
+        cuda_version = get_ptxas().version
         ptx_version = ptx_get_version(cuda_version)
     return ptx_version
 
@@ -141,19 +111,18 @@ class CUDAOptions:
     num_warps: int = 4
     num_ctas: int = 1
     num_stages: int = 3
-    num_buffers_warp_spec: int = 0
-    num_consumer_groups: int = 0
-    reg_dec_producer: int = 0
-    reg_inc_consumer: int = 0
     # maxnreg corresponds to the ptx parameter .maxnreg, which controls the
     # maximum number of 32-bit registers used by one thread.
     maxnreg: Optional[int] = None
     cluster_dims: tuple = (1, 1, 1)
     ptx_version: int = None
+    ptx_options: str = None
+    ir_override: Optional[str] = None  # filename of a user-defined IR (*.{ttir|ttgir|llir|ptx})
     enable_fp_fusion: bool = True
     launch_cooperative_grid: bool = False
+    launch_pdl: bool = False
     supported_fp8_dtypes: Tuple[str] = ("fp8e5", "fp8e4b15")
-    deprecated_fp8_dtypes: Tuple[str] = ()
+    deprecated_fp8_dot_operand_dtypes: Tuple[str] = ()
     default_dot_input_precision: str = "tf32"
     allowed_dot_input_precisions: Tuple[str] = ("tf32", "tf32x3", "ieee")
     max_num_imprecise_acc_default: bool = None
@@ -167,7 +136,8 @@ class CUDAOptions:
         default_libdir = Path(__file__).parent / 'lib'
         extern_libs = {} if self.extern_libs is None else dict(self.extern_libs)
         if not extern_libs.get('libdevice', None):
-            extern_libs['libdevice'] = os.getenv("TRITON_LIBDEVICE_PATH", str(default_libdir / 'libdevice.10.bc'))
+            extern_libs['libdevice'] = knobs.nvidia.libdevice_path or str(default_libdir / 'libdevice.10.bc')
+
         object.__setattr__(self, 'extern_libs', tuple(extern_libs.items()))
         assert self.num_warps > 0 and (self.num_warps & (self.num_warps - 1)) == 0, \
             "num_warps must be a power of 2"
@@ -192,12 +162,16 @@ class CUDABackend(BaseBackend):
             raise ValueError(f"TRITON_OVERRIDE_ARCH must have the form {pattern}")
         return int(match.group(1))
 
+    def get_target_name(self, options) -> str:
+        capability = self._parse_arch(options.arch)
+        return f"cuda:{capability}"
+
     def __init__(self, target: GPUTarget) -> None:
         super().__init__(target)
         self.binary_ext = "cubin"
 
     def parse_options(self, opts) -> Any:
-        args = {'arch': os.getenv("TRITON_OVERRIDE_ARCH", f"sm{self.target.arch}")}
+        args = {'arch': knobs.runtime.override_arch or f"sm{self.target.arch}"}
         args.update({k: opts[k] for k in CUDAOptions.__dataclass_fields__.keys() if k in opts if opts[k] is not None})
         capability = int(self._parse_arch(args["arch"]))
 
@@ -207,12 +181,12 @@ class CUDABackend(BaseBackend):
             supported_fp8_dtypes.add("fp8e4nv")
         args["supported_fp8_dtypes"] = tuple(sorted(supported_fp8_dtypes))
 
-        if "deprecated_fp8_dtypes" not in args:
+        if "deprecated_fp8_dot_operand_dtypes" not in args:
             if capability >= 90:
-                args["deprecated_fp8_dtypes"] = ("fp8e4b15", )
+                args["deprecated_fp8_dot_operand_dtypes"] = ("fp8e4b15", )
 
         if "enable_fp_fusion" not in args:
-            args["enable_fp_fusion"] = os.getenv("TRITON_DEFAULT_FP_FUSION", "1") == "1"
+            args["enable_fp_fusion"] = knobs.language.default_fp_fusion
 
         args["max_num_imprecise_acc_default"] = 2**30 if capability == 90 else 0
 
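Note: three compile options are new in this release: ptx_options (extra flags appended to the ptxas command line in make_cubin below), ir_override (compile from a user-supplied *.ttir/ttgir/llir/ptx file), and launch_pdl (programmatic dependent launch; see also the new triton/language/extra/cuda/gdc.py in the file list). A hypothetical launch forwarding the new fields, assuming they are accepted as kernel keyword arguments the same way num_warps and num_stages are:

    import triton
    import triton.language as tl

    @triton.jit
    def copy_kernel(src_ptr, dst_ptr, n, BLOCK: tl.constexpr):
        offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
        mask = offs < n
        tl.store(dst_ptr + offs, tl.load(src_ptr + offs, mask=mask), mask=mask)

    # Hypothetical launch (commented out; needs CUDA tensors):
    # copy_kernel[(grid,)](src, dst, n, BLOCK=1024,
    #                      ptx_options="-v", launch_pdl=True)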
@@ -246,11 +220,13 @@ class CUDABackend(BaseBackend):
         nvidia.load_dialects(ctx)
 
     @staticmethod
-    def make_ttir(mod, metadata, opt):
+    def make_ttir(mod, metadata, opt, capability):
         pm = ir.pass_manager(mod.context)
         pm.enable_debug()
         passes.common.add_inliner(pm)
         passes.ttir.add_rewrite_tensor_pointer(pm)
+        if capability // 10 < 9:
+            passes.ttir.add_rewrite_tensor_descriptor_to_pointer(pm)
         passes.common.add_canonicalizer(pm)
         passes.ttir.add_combine(pm)
         passes.ttir.add_reorder_broadcast(pm)
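Note: gating the new rewrite-tensor-descriptor-to-pointer pass on capability < 9 lets descriptor-based kernels degrade to plain pointer arithmetic on pre-Hopper GPUs instead of failing. The host-side counterpart is the new triton/tools/tensor_descriptor.py in the file list above; a sketch of how it would be used (the from_tensor signature is an assumption):

    import torch
    from triton.tools.tensor_descriptor import TensorDescriptor

    x = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
    # Hypothetical: describe x in 64x64 blocks; Hopper+ lowers this to TMA,
    # while older GPUs now take the pointer fallback added in this hunk.
    desc = TensorDescriptor.from_tensor(x, block_shape=[64, 64])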
@@ -262,6 +238,10 @@ class CUDABackend(BaseBackend):
 
     @staticmethod
     def make_ttgir(mod, metadata, opt, capability):
+        # Set maxnreg on all kernels, if it was provided.
+        if opt.maxnreg is not None:
+            mod.set_attr("ttg.maxnreg", ir.builder(mod.context).get_int32_attr(opt.maxnreg))
+
         cluster_info = nvidia.ClusterInfo()
         if opt.cluster_dims is not None:
             cluster_info.clusterDimX = opt.cluster_dims[0]
@@ -281,56 +261,69 @@ class CUDABackend(BaseBackend):
         passes.ttgpuir.add_accelerate_matmul(pm)
         passes.ttgpuir.add_remove_layout_conversions(pm)
         passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80)
-        passes.…
+        nvidia.passes.ttnvgpuir.add_optimize_descriptor_encoding(pm)
+        passes.ttir.add_loop_aware_cse(pm)
         if capability // 10 in [8, 9]:
             passes.ttgpuir.add_fuse_nested_loops(pm)
             passes.common.add_canonicalizer(pm)
-            passes.…
-            passes.ttgpuir.add_optimize_accumulator_init(pm)
+            passes.ttir.add_triton_licm(pm)
             passes.common.add_canonicalizer(pm)
             passes.ttgpuir.add_combine_tensor_select_and_if(pm)
-            passes.…
-            passes.ttgpuir.…
-            passes.ttgpuir.…
-            passes.ttgpuir.add_ws_code_partition(pm, opt.num_buffers_warp_spec, opt.num_consumer_groups,
-                                                 opt.reg_dec_producer, opt.reg_inc_consumer)
+            nvidia.passes.hopper.add_hopper_warpspec(pm, opt.num_stages, dump_enabled)
+            passes.ttgpuir.add_assign_latencies(pm, opt.num_stages)
+            passes.ttgpuir.add_schedule_loops(pm)
             passes.ttgpuir.add_pipeline(pm, opt.num_stages, dump_enabled)
-            passes.ttgpuir.add_ping_pong_sync(pm, opt.num_consumer_groups)
-            passes.ttgpuir.add_ws_lowering(pm, opt.num_consumer_groups)
         elif capability // 10 >= 10:
             passes.ttgpuir.add_fuse_nested_loops(pm)
             passes.common.add_canonicalizer(pm)
-            passes.…
+            passes.ttir.add_triton_licm(pm)
             passes.ttgpuir.add_optimize_accumulator_init(pm)
-            passes.ttgpuir.…
-            passes.…
-            passes.ttgpuir.…
-            passes.ttgpuir.…
-            …
+            passes.ttgpuir.add_hoist_tmem_alloc(pm)
+            nvidia.passes.ttnvgpuir.add_promote_lhs_to_tmem(pm)
+            passes.ttgpuir.add_assign_latencies(pm, opt.num_stages)
+            passes.ttgpuir.add_schedule_loops(pm)
+            passes.ttgpuir.add_warp_specialize(pm, opt.num_stages)
             passes.ttgpuir.add_pipeline(pm, opt.num_stages, dump_enabled)
             passes.ttgpuir.add_combine_tensor_select_and_if(pm)
-            nvidia.passes.ttnvgpuir.…
-            nvidia.passes.ttnvgpuir.add_keep_acc_in_tmem(pm)
-            passes.ttgpuir.add_ws_lowering(pm, opt.num_consumer_groups)
-            passes.common.add_canonicalizer(pm)
+            nvidia.passes.ttnvgpuir.add_remove_tmem_tokens(pm)
         else:
-            passes.…
+            passes.ttir.add_triton_licm(pm)
+        passes.common.add_canonicalizer(pm)
+        passes.ttir.add_loop_aware_cse(pm)
         passes.ttgpuir.add_prefetch(pm)
         passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80)
         passes.ttgpuir.add_coalesce_async_copy(pm)
+        nvidia.passes.ttnvgpuir.add_optimize_tmem_layouts(pm)
         passes.ttgpuir.add_remove_layout_conversions(pm)
+        nvidia.passes.ttnvgpuir.add_interleave_tmem(pm)
         passes.ttgpuir.add_reduce_data_duplication(pm)
         passes.ttgpuir.add_reorder_instructions(pm)
-        passes.…
+        passes.ttir.add_loop_aware_cse(pm)
         passes.common.add_symbol_dce(pm)
         if capability // 10 >= 9:
-            nvidia.passes.ttnvgpuir.add_fence_insertion(pm)
             nvidia.passes.ttnvgpuir.add_tma_lowering(pm)
+            nvidia.passes.ttnvgpuir.add_fence_insertion(pm)
+        passes.common.add_sccp(pm)
         passes.common.add_canonicalizer(pm)
-        if capability // 10 >= 9:
-            passes.ttgpuir.add_ws_canonicalization(pm, opt.num_consumer_groups)
         pm.run(mod)
         metadata["cluster_dims"] = (cluster_info.clusterDimX, cluster_info.clusterDimY, cluster_info.clusterDimZ)
+        tensordesc_meta = mod.get_tensordesc_metadata()
+        metadata["tensordesc_meta"] = tensordesc_meta
+        return mod
+
+    def ttgir_opt(self, src, metadata, options, capability):
+        mod = src
+        pm = ir.pass_manager(mod.context)
+        pm.enable_debug()
+
+        passes.ttgpuir.add_inliner(pm)
+        passes.common.add_sccp(pm)
+        passes.ttir.add_loop_aware_cse(pm)
+        passes.ttgpuir.add_canonicalizer(pm)
+        passes.ttgpuir.add_combine_tensor_select_and_if(pm)
+
+        pm.run(mod)
+        metadata["tensordesc_meta"] = mod.get_tensordesc_metadata()
         return mod
 
     def make_llir(self, src, metadata, options, capability):
@@ -356,28 +349,23 @@ class CUDABackend(BaseBackend):
         passes.common.add_canonicalizer(pm)
         passes.common.add_cse(pm)
         passes.common.add_symbol_dce(pm)
-        if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0":
+        if not knobs.compilation.disable_line_info:
             passes.llvmir.add_di_scope(pm)
         pm.run(mod)
         # LLVM-IR (MLIR) -> LLVM-IR (LLVM)
         llvm.init_targets()
         context = llvm.context()
-        if os.environ.get("TRITON_ENABLE_ASAN", "0") == "1":
+        if knobs.compilation.enable_asan:
             raise RuntimeError(
                 "Address Sanitizer Error: Address sanitizer is currently only supported on the AMD backend")
         llvm_mod = llvm.to_module(mod, context)
         proc = sm_arch_from_capability(capability)
         features = get_features(options, self.target.arch)
         triple = 'nvptx64-nvidia-cuda'
+        nvidia.set_short_ptr()
         llvm.attach_datalayout(llvm_mod, triple, proc, features)
         nvidia.set_nvvm_reflect_ftz(llvm_mod)
 
-        # Set maxnreg on all kernels, if it was provided.
-        if options.maxnreg is not None:
-            for k in llvm_mod.get_functions():
-                if not k.is_declaration() and k.is_external_linkage():
-                    k.set_nvvm_maxnreg(options.maxnreg)
-
         if options.extern_libs:
             paths = [path for (name, path) in options.extern_libs]
             llvm.link_extern_libs(llvm_mod, paths)
@@ -404,7 +392,7 @@ class CUDABackend(BaseBackend):
         triple = 'nvptx64-nvidia-cuda'
         proc = sm_arch_from_capability(capability)
         features = get_features(opt, self.target.arch)
-        ret = llvm.translate_to_asm(src, triple, proc, features, ['nvptx-short-ptr'], opt.enable_fp_fusion, False)
+        ret = llvm.translate_to_asm(src, triple, proc, features, [], opt.enable_fp_fusion, False)
         # Find kernel names (there should only be one)
         names = re.findall(r".visible .entry ([a-zA-Z_][a-zA-Z0-9_]*)", ret)
         assert len(names) == 1
@@ -415,25 +403,33 @@ class CUDABackend(BaseBackend):
         ret = re.sub(r'\.target sm_\d+', f'.target sm_{capability}', ret, flags=re.MULTILINE)
         # Remove the debug flag that prevents ptxas from optimizing the code
         ret = re.sub(r",\s*debug|debug,\s*", "", ret)
-        if os.environ.get("NVPTX_ENABLE_DUMP", "0") == "1":
+        if knobs.nvidia.dump_nvptx:
             print("// -----// NVPTX Dump //----- //")
             print(ret)
         return ret
 
     def make_cubin(self, src, metadata, opt, capability):
-        ptxas…
+        ptxas = get_ptxas().path
         with tempfile.NamedTemporaryFile(delete=False, mode='w', suffix='.ptx') as fsrc, \
             tempfile.NamedTemporaryFile(delete=False, mode='r', suffix='.log') as flog:
             fsrc.write(src)
             fsrc.flush()
             fbin = fsrc.name + '.o'
 
-            line_info = ["-lineinfo", "-suppress-debug-info"] if os.environ.get("TRITON_DISABLE_LINE_INFO",
-                                                                                "0") == "1" else ["-lineinfo"]
+            line_info = ["-lineinfo", "-suppress-debug-info"] if knobs.compilation.disable_line_info else ["-lineinfo"]
             fmad = [] if opt.enable_fp_fusion else ['--fmad=false']
             arch = sm_arch_from_capability(capability)
-            …
+
+            # Disable ptxas optimizations if requested
+            disable_opt = ['--opt-level', '0'] if knobs.nvidia.disable_ptxas_opt else []
+
+            # Accept more ptxas options if provided
+            ptx_extra_options = opt.ptx_options.split(" ") if opt.ptx_options else []
+
+            ptxas_cmd = [
+                ptxas, *line_info, *fmad, '-v', *disable_opt, *ptx_extra_options, f'--gpu-name={arch}', fsrc.name, '-o',
+                fbin
+            ]
             try:
                 # close_fds=True on Windows and False on Linux, see https://github.com/triton-lang/triton/pull/4357
                 # On Windows, both stdout and stderr need to be redirected to flog
@@ -462,15 +458,18 @@ class CUDABackend(BaseBackend):
                 try_remove(flog.name)
             return cubin
 
-    def add_stages(self, stages, options):
+    def add_stages(self, stages, options, language):
         capability = self._parse_arch(options.arch)
-        stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options)
-        stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, capability)
+        if language == Language.TRITON:
+            stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options, capability)
+            stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, capability)
+        elif language == Language.GLUON:
+            stages["ttgir"] = lambda src, metadata: self.ttgir_opt(src, metadata, options, capability)
         stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, capability)
         stages["ptx"] = lambda src, metadata: self.make_ptx(src, metadata, options, self.target.arch)
         stages["cubin"] = lambda src, metadata: self.make_cubin(src, metadata, options, self.target.arch)
 
     @functools.lru_cache()
     def hash(self):
-        version = get_ptxas_version(self.target.arch)
+        version = get_ptxas_version()
         return f'{version}-{self.target.arch}'
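Note: add_stages now receives a language argument, so Gluon kernels (the new triton/experimental/gluon package in the file list) enter the pipeline at TTGIR via ttgir_opt and skip the Triton-IR stages entirely. Each stage is a callable (src, metadata) -> lowered_src; schematically, running the pipeline amounts to the sketch below (illustrative only, not Triton's actual driver loop):

    def run_pipeline(stages, src, metadata):
        # Insertion order is significant: ttir -> ttgir -> llir -> ptx -> cubin
        # for Language.TRITON, but ttgir -> llir -> ptx -> cubin for Language.GLUON.
        for stage in stages.values():
            src = stage(src, metadata)
        return src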
triton/backends/nvidia/driver.c
CHANGED

@@ -10,7 +10,6 @@
 
 #include <stdbool.h>
 #define PY_SSIZE_T_CLEAN
-#define Py_LIMITED_API 0x03090000
 #include <Python.h>
 
 // Raises a Python exception and returns false if code is not CUDA_SUCCESS.
@@ -112,6 +111,7 @@ static PyObject *loadBinary(PyObject *self, PyObject *args) {
   CUmodule mod;
   int32_t n_regs = 0;
   int32_t n_spills = 0;
+  int32_t n_max_threads = 0;
   // create driver handles
   CUcontext pctx = 0;
 
@@ -132,6 +132,8 @@ static PyObject *loadBinary(PyObject *self, PyObject *args) {
   CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
       cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun));
   n_spills /= 4;
+  CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuFuncGetAttribute(
+      &n_max_threads, CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK, fun));
   // set dynamic shared memory if necessary
   int shared_optin;
   CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuDeviceGetAttribute(
@@ -155,8 +157,8 @@ static PyObject *loadBinary(PyObject *self, PyObject *args) {
   if (PyErr_Occurred()) {
     return NULL;
   }
-  return Py_BuildValue("(KKii)", (uint64_t)mod, (uint64_t)fun, n_regs,
-                       n_spills);
+  return Py_BuildValue("(KKiii)", (uint64_t)mod, (uint64_t)fun, n_regs,
+                       n_spills, n_max_threads);
 }
 
 typedef CUresult (*cuOccupancyMaxActiveClusters_t)(
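Note: loadBinary now returns a 5-tuple; n_max_threads exposes CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK, presumably so the runtime can reject launches whose block size exceeds what the compiled kernel supports. A sketch of a consuming call (the real call sites are in triton/backends/nvidia/driver.py, also rewritten in this release):

    def load(utils, name, cubin, shared, device):
        # Previously: mod, func, n_regs, n_spills = utils.load_binary(...)
        mod, func, n_regs, n_spills, n_max_threads = utils.load_binary(
            name, cubin, shared, device)
        return func, n_max_threads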
@@ -308,112 +310,103 @@ static PyObject *setPrintfFifoSize(PyObject *self, PyObject *args) {
   return Py_None;
 }
 
-
-// This is a useful to test TMA operations independently.
-static PyObject *fill1DTMADescriptor(PyObject *self, PyObject *args) {
-  unsigned long long global_address;
-  uint64_t dim;
-  uint32_t tensorDim;
-  int elementSize;
+static PyObject *fillTMADescriptor(PyObject *self, PyObject *args) {
   unsigned long long desc_address;
-  …
+  unsigned long long global_address;
+  int swizzle;
+  int elemSize;
+  int elemType;
+  PyObject *blockSize;
+  PyObject *shape;
+  PyObject *strides;
+
+  if (!PyArg_ParseTuple(args, "KKiiiOOO", &desc_address, &global_address,
+                        &swizzle, &elemSize, &elemType, &blockSize, &shape,
+                        &strides)) {
     return NULL;
   }
-  …
+
+  PyObject *blockSizeFast = NULL;
+  PyObject *shapeFast = NULL;
+  PyObject *stridesFast = NULL;
+  PyObject *result = NULL;
+
+  uint32_t blockSizeInt[5];
+  uint64_t shapeInt[5];
+  uint64_t stridesLL[5];
+
+  blockSizeFast = PySequence_Fast(blockSize, "blockSize must be a sequence");
+  if (!blockSizeFast)
+    goto cleanup;
+  int rank = PySequence_Fast_GET_SIZE(blockSizeFast);
+
+  for (int i = 0; i < rank; ++i) {
+    PyObject *item = PySequence_Fast_GET_ITEM(blockSizeFast, i);
+    if (!PyLong_Check(item)) {
+      PyErr_SetString(PyExc_TypeError, "block size must be an int");
+      goto cleanup;
+    }
+    blockSizeInt[rank - i - 1] = PyLong_AsLongLong(item);
   }
-  assert((elementSize * tensorDim) >= 32 && "block size too small.");
-  int rank = 1;
-  static cuTensorMapEncodeTiled_t cuTensorMapEncodeTiled = NULL;
-  INITIALIZE_FUNCTION_POINTER_IF_NULL(cuTensorMapEncodeTiled,
-                                      getCuTensorMapEncodeTiledHandle);
-  CUDA_CHECK_AND_RETURN_NULL(cuTensorMapEncodeTiled(
-      (CUtensorMap *)desc_address, type, rank, (void *)global_address, dims,
-      globalStrides, boxDim, elementStrides, CU_TENSOR_MAP_INTERLEAVE_NONE,
-      CU_TENSOR_MAP_SWIZZLE_NONE, CU_TENSOR_MAP_L2_PROMOTION_NONE,
-      CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE));
-  Py_INCREF(Py_None);
-  return Py_None;
-}
 
-…
-  unsigned long long desc_address;
-  if (!PyArg_ParseTuple(args, "KKKiiiK", &global_address, &dims[1], &dims[0],
-                        &tensorDims[1], &tensorDims[0], &elementSize,
-                        &desc_address)) {
-    return NULL;
+  shapeFast = PySequence_Fast(shape, "shape must be a sequence");
+  if (!shapeFast)
+    goto cleanup;
+
+  if (rank != PySequence_Fast_GET_SIZE(shapeFast)) {
+    PyErr_SetString(PyExc_RuntimeError, "Rank mismatch");
+    goto cleanup;
   }
-  …
-    break;
-  case 2:
-    type = CU_TENSOR_MAP_DATA_TYPE_UINT16;
-    break;
-  case 4:
-    type = CU_TENSOR_MAP_DATA_TYPE_UINT32;
-    break;
-  default:
-    PyErr_SetString(PyExc_ValueError, "elementSize must be 1, 2, or 4");
+  for (int i = 0; i < rank; ++i) {
+    PyObject *item = PySequence_Fast_GET_ITEM(shapeFast, i);
+    if (!PyLong_Check(item)) {
+      PyErr_SetString(PyExc_TypeError, "shape must be an int");
+      goto cleanup;
+    }
+    shapeInt[rank - i - 1] = PyLong_AsLong(item);
   }
-  …
-  if (…
-    swizzle = CU_TENSOR_MAP_SWIZZLE_64B;
-  } else if (contigDimSizeInByte >= 32) {
-    swizzle = CU_TENSOR_MAP_SWIZZLE_32B;
-  } else {
-    assert(false && "block size too small.");
+
+  stridesFast = PySequence_Fast(strides, "strides must be a sequence");
+  if (!stridesFast)
+    goto cleanup;
+
+  if (rank != PySequence_Fast_GET_SIZE(stridesFast)) {
+    PyErr_SetString(PyExc_RuntimeError, "Rank mismatch");
+    goto cleanup;
   }
-  …
+  for (int i = 0; i + 1 < rank; ++i) {
+    PyObject *item = PySequence_Fast_GET_ITEM(stridesFast, i);
+    if (!PyLong_Check(item)) {
+      PyErr_SetString(PyExc_TypeError, "shape must be an int");
+      goto cleanup;
+    }
+    stridesLL[rank - i - 2] = elemSize * PyLong_AsLongLong(item);
   }
+  stridesLL[rank - 1] =
+      shapeInt[rank - 1] * (rank == 1 ? elemSize : stridesLL[rank - 2]);
+  Py_DECREF(blockSizeFast);
+  blockSizeFast = NULL;
+  Py_DECREF(shapeFast);
+  shapeFast = NULL;
+  Py_DECREF(stridesFast);
+  stridesFast = NULL;
+
+  uint32_t elementStrides[5] = {1, 1, 1, 1, 1};
   static cuTensorMapEncodeTiled_t cuTensorMapEncodeTiled = NULL;
   INITIALIZE_FUNCTION_POINTER_IF_NULL(cuTensorMapEncodeTiled,
                                       getCuTensorMapEncodeTiledHandle);
   CUDA_CHECK_AND_RETURN_NULL(cuTensorMapEncodeTiled(
-      (CUtensorMap *)desc_address, …
-      CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE));
-  …
+      (CUtensorMap *)desc_address, elemType, rank, (void *)global_address,
+      shapeInt, stridesLL, blockSizeInt, elementStrides,
+      CU_TENSOR_MAP_INTERLEAVE_NONE, swizzle,
+      CU_TENSOR_MAP_L2_PROMOTION_L2_128B, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE));
+  Py_RETURN_NONE;
+
+cleanup:
+  Py_XDECREF(blockSizeFast);
+  Py_XDECREF(shapeFast);
+  Py_XDECREF(stridesFast);
+  return result;
 }
 
 static PyMethodDef ModuleMethods[] = {
@@ -429,8 +422,7 @@ static PyMethodDef ModuleMethods[] = {
     "being dropped. This inherits all the limitations of this call; in "
     "particular it's an error to change this value after launching any kernel "
    "that calls printf()."},
-    {"fill_1d_tma_descriptor", fill1DTMADescriptor, METH_VARARGS, "doc"},
-    {"fill_2d_tma_descriptor", fill2DTMADescriptor, METH_VARARGS, "doc"},
+    {"fill_tma_descriptor", fillTMADescriptor, METH_VARARGS, "doc"},
 
     {NULL, NULL, 0, NULL} // sentinel
 };
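Note: the fixed-rank fill_1d_tma_descriptor/fill_2d_tma_descriptor entry points collapse into one rank-generic fill_tma_descriptor (up to rank 5, per the array sizes above). A hypothetical Python-side call matching the "KKiiiOOO" parse string (the helper and access path here are assumptions; the real callers live in the updated driver.py):

    def fill_desc(utils, desc_ptr, tensor, block_shape, swizzle, elem_type):
        # Argument order follows PyArg_ParseTuple: descriptor address, global
        # address, swizzle mode, element size, element type, then blockSize/
        # shape/strides sequences with the outermost dimension first (the C
        # code reverses them into CUDA's innermost-first arrays).
        utils.fill_tma_descriptor(desc_ptr, tensor.data_ptr(), swizzle,
                                  tensor.element_size(), elem_type,
                                  list(block_shape), list(tensor.shape),
                                  list(tensor.stride()))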