triton-windows 3.2.0.post12__cp312-cp312-win_amd64.whl → 3.3.0a0.post12__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of triton-windows has been flagged as possibly problematic by the registry scanner.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +3 -3
- triton/_internal_testing.py +59 -4
- triton/_utils.py +35 -0
- triton/backends/amd/compiler.py +121 -74
- triton/backends/amd/driver.py +77 -43
- triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +28 -49
- triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +35 -9
- triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +761 -284
- triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +9 -3
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +1391 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +3 -3
- triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +44 -0
- triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +288 -0
- triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +110 -14
- triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +504 -103
- triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +2 -1
- triton/backends/amd/include/hip/amd_detail/host_defines.h +4 -0
- triton/backends/amd/include/hip/hip_ext.h +4 -2
- triton/backends/amd/include/hip/hip_fp8.h +33 -0
- triton/backends/amd/include/hip/hip_runtime_api.h +375 -33
- triton/backends/amd/include/hip/hip_version.h +3 -3
- triton/backends/amd/include/hip/hiprtc.h +25 -25
- triton/backends/amd/include/hsa/amd_hsa_elf.h +40 -14
- triton/backends/amd/include/hsa/hsa.h +11 -2
- triton/backends/amd/include/hsa/hsa_api_trace.h +30 -17
- triton/backends/amd/include/hsa/hsa_api_trace_version.h +68 -0
- triton/backends/amd/include/hsa/hsa_ext_amd.h +83 -27
- triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +46 -46
- triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +416 -0
- triton/backends/amd/include/roctracer/hip_ostream_ops.h +84 -4
- triton/backends/amd/include/roctracer/hsa_ostream_ops.h +260 -0
- triton/backends/amd/include/roctracer/hsa_prof_str.h +51 -19
- triton/backends/amd/lib/asanrtl.bc +0 -0
- triton/backends/compiler.py +25 -225
- triton/backends/driver.py +7 -2
- triton/backends/nvidia/bin/ptxas.exe +0 -0
- triton/backends/nvidia/compiler.py +135 -90
- triton/backends/nvidia/driver.c +0 -1
- triton/backends/nvidia/driver.py +135 -49
- triton/backends/nvidia/include/cuda.h +2162 -241
- triton/backends/nvidia/lib/x64/cuda.lib +0 -0
- triton/compiler/__init__.py +2 -2
- triton/compiler/code_generator.py +334 -231
- triton/compiler/compiler.py +77 -66
- triton/language/__init__.py +22 -5
- triton/language/core.py +448 -74
- triton/language/extra/cuda/_experimental_tma.py +3 -5
- triton/language/math.py +1 -1
- triton/language/random.py +2 -1
- triton/language/semantic.py +206 -52
- triton/language/standard.py +35 -18
- triton/runtime/_allocation.py +32 -0
- triton/runtime/autotuner.py +27 -32
- triton/runtime/build.py +1 -48
- triton/runtime/cache.py +6 -6
- triton/runtime/errors.py +10 -0
- triton/runtime/interpreter.py +179 -45
- triton/runtime/jit.py +149 -190
- triton/testing.py +39 -11
- triton/tools/compile.py +27 -20
- triton/tools/{compile.c → extra/cuda/compile.c} +1 -0
- triton/tools/mxfp.py +301 -0
- {triton_windows-3.2.0.post12.dist-info → triton_windows-3.3.0a0.post12.dist-info}/METADATA +5 -2
- {triton_windows-3.2.0.post12.dist-info → triton_windows-3.3.0a0.post12.dist-info}/RECORD +68 -59
- {triton_windows-3.2.0.post12.dist-info → triton_windows-3.3.0a0.post12.dist-info}/top_level.txt +2 -0
- /triton/tools/{compile.h → extra/cuda/compile.h} +0 -0
- {triton_windows-3.2.0.post12.dist-info → triton_windows-3.3.0a0.post12.dist-info}/WHEEL +0 -0
triton/runtime/autotuner.py
CHANGED
@@ -4,10 +4,10 @@ import builtins
 import os
 import time
 import inspect
-from typing import Dict
+from typing import Dict, Tuple, List, Optional
 
 from .jit import KernelInterface
-from .errors import OutOfResources
+from .errors import OutOfResources, PTXASError
 from .driver import driver
 
 
@@ -23,7 +23,7 @@ class Autotuner(KernelInterface):
         restore_value,
         pre_hook=None,
         post_hook=None,
-        prune_configs_by: Dict = None,
+        prune_configs_by: Optional[Dict] = None,
         warmup=None,
         rep=None,
         use_cuda_graph=False,
@@ -36,14 +36,11 @@ class Autotuner(KernelInterface):
             'prune_num_stages_by'(optional): a function used to prune num_stages. It takes configs:List[Config] as its input, and returns pruned configs.
         """
         if not configs:
-            self.configs = [
-                Config({}, num_warps=4, num_stages=2, num_ctas=1, num_buffers_warp_spec=0, num_consumer_groups=0,
-                       reg_dec_producer=0, reg_inc_consumer=0)
-            ]
+            self.configs = [Config({}, num_warps=4, num_stages=3, num_ctas=1)]
         else:
             self.configs = configs
         self.keys = key
-        self.cache = {}
+        self.cache: Dict[Tuple, Config] = {}
         self.arg_names = arg_names
 
         # Reset to zero or restore values
@@ -134,6 +131,10 @@ class Autotuner(KernelInterface):
     def _bench(self, *args, config, **meta):
         from ..compiler.errors import CompileTimeAssertionFailure
 
+        verbose = os.environ.get("TRITON_PRINT_AUTOTUNING", None) == "1"
+        if verbose:
+            print(f"Autotuning kernel {self.base_fn.__name__} with config {config}")
+
         # check for conflicts, i.e. meta-parameters both provided
         # as kwargs and by the autotuner
         conflicts = meta.keys() & config.kwargs.keys()
@@ -164,7 +165,9 @@
 
         try:
             return self.do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8))
-        except (OutOfResources, CompileTimeAssertionFailure):
+        except (OutOfResources, CompileTimeAssertionFailure, PTXASError) as e:
+            if verbose:
+                print(f"Autotuning failed with {e}")
             return [float("inf"), float("inf"), float("inf")]
 
     def run(self, *args, **kwargs):
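The two hunks above add opt-in logging to the benchmarking loop: with TRITON_PRINT_AUTOTUNING=1, each candidate config is printed before it is timed, and configs that fail with OutOfResources, CompileTimeAssertionFailure, or the new PTXASError are reported and scored as infinitely slow instead of aborting the search. A minimal sketch of how this surfaces to users; the kernel itself is our illustration, not part of the diff:

```python
import os
os.environ["TRITON_PRINT_AUTOTUNING"] = "1"  # read inside Autotuner._bench

import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 1024}, num_warps=4),
        triton.Config({"BLOCK_SIZE": 2048}, num_warps=8),
    ],
    key=["n_elements"],
)
@triton.jit
def add_one(x_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offs < n_elements
    tl.store(x_ptr + offs, tl.load(x_ptr + offs, mask=mask) + 1, mask=mask)

# On first launch, each config is logged as it is benchmarked, e.g.
#   Autotuning kernel add_one with config BLOCK_SIZE: 1024, num_warps: 4, ...
```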
@@ -208,7 +211,7 @@
             self.nargs = None
         return ret
 
-    def prune_configs(self, kwargs):
+    def prune_configs(self, kwargs: Dict) -> List[Config]:
         pruned_configs = self.configs
         if self.early_config_prune:
             pruned_configs = self.early_config_prune(self.configs, self.nargs, **kwargs)
@@ -216,6 +219,10 @@
             top_k = self.configs_top_k
             if isinstance(top_k, float) and top_k <= 1.0:
                 top_k = int(len(self.configs) * top_k)
+            elif not isinstance(top_k, int):
+                # Slice index must be an integer
+                raise TypeError("Error while pruning configs, top_k must be either 1) a float <= 1.0 or 2) an int")
+
             if len(pruned_configs) > top_k:
                 est_timing = {
                     config: self.perf_model(
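prune_configs also gains validation of top_k: a float <= 1.0 keeps that fraction of the configs, an int keeps that many, and anything else now raises TypeError up front instead of failing later at the list slice. A hedged sketch of prune_configs_by usage; the perf_model here is a made-up scoring function, not part of Triton:

```python
import triton
import triton.language as tl

def perf_model(**all_args):
    # Hypothetical estimator: pretend fewer warps is always faster.
    return all_args["num_warps"]

@triton.autotune(
    configs=[triton.Config({"BLOCK": 64}, num_warps=w) for w in (1, 2, 4, 8)],
    key=["n"],
    # top_k must be an int, or a float <= 1.0 (a fraction of the config list).
    prune_configs_by={"perf_model": perf_model, "top_k": 2},
)
@triton.jit
def zero_fill(x_ptr, n, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    tl.store(x_ptr + offs, 0.0, mask=offs < n)
```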
@@ -262,16 +269,11 @@ class Config:
                      function are args.
     """
 
-    def __init__(self, kwargs, num_warps=4, num_stages=3, num_ctas=1, num_buffers_warp_spec=0, num_consumer_groups=0,
-                 reg_dec_producer=0, reg_inc_consumer=0, maxnreg=None, pre_hook=None):
+    def __init__(self, kwargs, num_warps=4, num_stages=3, num_ctas=1, maxnreg=None, pre_hook=None):
         self.kwargs = kwargs
         self.num_warps = num_warps
         self.num_ctas = num_ctas
         self.num_stages = num_stages
-        self.num_buffers_warp_spec = num_buffers_warp_spec
-        self.num_consumer_groups = num_consumer_groups
-        self.reg_dec_producer = reg_dec_producer
-        self.reg_inc_consumer = reg_inc_consumer
         self.maxnreg = maxnreg
         self.pre_hook = pre_hook
 
@@ -283,10 +285,6 @@ class Config:
                 ("num_warps", self.num_warps),
                 ("num_ctas", self.num_ctas),
                 ("num_stages", self.num_stages),
-                ("num_buffers_warp_spec", self.num_buffers_warp_spec),
-                ("num_consumer_groups", self.num_consumer_groups),
-                ("reg_dec_producer", self.reg_dec_producer),
-                ("reg_inc_consumer", self.reg_inc_consumer),
                 ("maxnreg", self.maxnreg),
             ) if v is not None
         }
@@ -299,10 +297,6 @@ class Config:
         res.append(f"num_warps: {self.num_warps}")
         res.append(f"num_ctas: {self.num_ctas}")
         res.append(f"num_stages: {self.num_stages}")
-        res.append(f"num_buffers_warp_spec: {self.num_buffers_warp_spec}")
-        res.append(f"num_consumer_groups: {self.num_consumer_groups}")
-        res.append(f"reg_dec_producer: {self.reg_dec_producer}")
-        res.append(f"reg_inc_consumer: {self.reg_inc_consumer}")
         res.append(f"maxnreg: {self.maxnreg}")
         return ", ".join(res)
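Config drops the warp-specialization knobs in 3.3.0a0, so configs written against 3.2 need trimming; passing the removed keywords now raises TypeError. A migration sketch:

```python
import triton

# 3.2.x accepted (now a TypeError):
# triton.Config({"BLOCK": 128}, num_warps=4, num_stages=2,
#               num_buffers_warp_spec=0, num_consumer_groups=0,
#               reg_dec_producer=0, reg_inc_consumer=0)

# 3.3.0a0: only the core launch parameters remain.
cfg = triton.Config({"BLOCK": 128}, num_warps=4, num_stages=2, num_ctas=1)
print(cfg)  # BLOCK: 128, num_warps: 4, num_ctas: 1, num_stages: 2, maxnreg: None
```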
@@ -323,8 +317,8 @@ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, restore_va
             # the value of x_size changes
         )
         @triton.jit
-        def kernel(x_ptr, x_size, **META):
-            BLOCK_SIZE = META['BLOCK_SIZE']
+        def kernel(x_ptr, x_size, BLOCK_SIZE: tl.constexpr):
+            ...
     :note: When all the configurations are evaluated, the kernel will run multiple times.
            This means that whatever value the kernel updates will be updated multiple times.
            To avoid this undesired behavior, you can use the `reset_to_zero` argument, which
@@ -367,7 +361,7 @@ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, restore_va
     def decorator(fn):
         return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, restore_value, pre_hook=pre_hook,
                          post_hook=post_hook, prune_configs_by=prune_configs_by, warmup=warmup, rep=rep,
-                         use_cuda_graph=use_cuda_graph)
+                         use_cuda_graph=use_cuda_graph, do_bench=do_bench)
 
     return decorator
 
@@ -388,18 +382,19 @@ class Heuristics(KernelInterface):
 def heuristics(values):
     """
     Decorator for specifying how the values of certain meta-parameters may be computed.
-    This is useful for cases where auto-tuning is prohibitevely expensive, or just not applicable.
+    This is useful for cases where auto-tuning is prohibitively expensive, or just not applicable.
 
     .. highlight:: python
     .. code-block:: python
 
-        @triton.heuristics(values={'BLOCK_SIZE': lambda args: 2 ** int(math.ceil(math.log2(args[1])))})
+        # smallest power-of-two >= x_size
+        @triton.heuristics(values={'BLOCK_SIZE': lambda args: triton.next_power_of_2(args['x_size'])})
         @triton.jit
-        def kernel(x_ptr, x_size, **META):
-            BLOCK_SIZE = META['BLOCK_SIZE'] # smallest power-of-two >= x_size
+        def kernel(x_ptr, x_size, BLOCK_SIZE: tl.constexpr):
+            ...
     :param values: a dictionary of meta-parameter names and functions that compute the value of the meta-parameter.
                    each such function takes a list of positional arguments as input.
-    :type values: dict[str, Callable[[list[Any]], Any]]
+    :type values: dict[str, Callable[[dict[str, Any]], Any]]
     """
 
     def decorator(fn):
triton/runtime/build.py
CHANGED
@@ -1,26 +1,12 @@
-import contextlib
-import sys
-import io
 import sysconfig
 import os
 import shutil
 import subprocess
-import setuptools
 
 if os.name == "nt":
     from triton.windows_utils import find_msvc_winsdk, find_python
 
 
-@contextlib.contextmanager
-def quiet():
-    old_stdout, old_stderr = sys.stdout, sys.stderr
-    sys.stdout, sys.stderr = io.StringIO(), io.StringIO()
-    try:
-        yield
-    finally:
-        sys.stdout, sys.stderr = old_stdout, old_stderr
-
-
 def _cc_cmd(cc, src, out, include_dirs, library_dirs, libraries):
     if cc.lower().endswith("cl") or cc.lower().endswith("cl.exe"):
         out_base = os.path.splitext(out)[0]
@@ -74,38 +60,5 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries):
     include_dirs += msvc_winsdk_inc_dirs
     library_dirs += msvc_winsdk_lib_dirs
     cc_cmd = _cc_cmd(cc, src, so, include_dirs, library_dirs, libraries)
-    ret = subprocess.check_call(cc_cmd)
-    if ret == 0:
-        return so
-    # fallback on setuptools
-    extra_compile_args = []
-    if cc.lower().endswith("cl") or cc.lower().endswith("cl.exe"):
-        extra_compile_args += ["/O2"]
-    else:
-        extra_compile_args += ["-O3"]
-    # extra arguments
-    extra_link_args = []
-    # create extension module
-    ext = setuptools.Extension(
-        name=name,
-        language='c',
-        sources=[src],
-        include_dirs=include_dirs,
-        extra_compile_args=extra_compile_args,
-        extra_link_args=extra_link_args,
-        library_dirs=library_dirs,
-        libraries=libraries,
-    )
-    # build extension module
-    args = ['build_ext']
-    args.append('--build-temp=' + srcdir)
-    args.append('--build-lib=' + srcdir)
-    args.append('-q')
-    args = dict(
-        name=name,
-        ext_modules=[ext],
-        script_args=args,
-    )
-    with quiet():
-        setuptools.setup(**args)
+    subprocess.check_call(cc_cmd, stdout=subprocess.DEVNULL)
     return so
triton/runtime/cache.py
CHANGED
@@ -256,9 +256,9 @@ __cache_cls = FileCacheManager
 __cache_cls_nme = "DEFAULT"
 
 
-def _base64(key):
+def _base32(key):
     # Assume key is a hex string.
-    return base64.b64encode(bytes.fromhex(key), altchars=b"-_").decode("utf-8").rstrip("=")
+    return base64.b32encode(bytes.fromhex(key)).decode("utf-8").rstrip("=")
 
 
 def get_cache_manager(key) -> CacheManager:
@@ -274,15 +274,15 @@ def get_cache_manager(key) -> CacheManager:
         __cache_cls = getattr(module, clz_nme)
         __cache_cls_nme = user_cache_manager
 
-    return __cache_cls(_base64(key))
+    return __cache_cls(_base32(key))
 
 
 def get_override_manager(key) -> CacheManager:
-    return __cache_cls(_base64(key), override=True)
+    return __cache_cls(_base32(key), override=True)
 
 
 def get_dump_manager(key) -> CacheManager:
-    return __cache_cls(_base64(key), dump=True)
+    return __cache_cls(_base32(key), dump=True)
 
 
 def make_so_cache_key(version_hash, signature, constants, ids, **kwargs):
@@ -292,4 +292,4 @@ def make_so_cache_key(version_hash, signature, constants, ids, **kwargs):
     for kw in kwargs:
         key = f"{key}-{kwargs.get(kw)}"
     key = hashlib.sha256(key.encode("utf-8")).hexdigest()
-    return _base64(key)
+    return _base32(key)
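Cache directory names switch from base64 to base32 encoding of the hex key. A plausible motivation (our inference; the diff only shows the encoding change) is that base64 is case-sensitive while Windows filesystems usually are not, so two distinct keys could collide in the same directory name; base32 uses only uppercase letters and the digits 2-7. The helper's behavior:

```python
import base64
import hashlib

def _base32(key):
    # Assume key is a hex string, as in triton/runtime/cache.py.
    return base64.b32encode(bytes.fromhex(key)).decode("utf-8").rstrip("=")

key = hashlib.sha256(b"some kernel source").hexdigest()  # 64 hex chars
print(_base32(key))  # 52 chars from the A-Z / 2-7 alphabet, no padding
```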
triton/runtime/errors.py
CHANGED
@@ -24,3 +24,13 @@ class OutOfResources(TritonError):
     def __reduce__(self):
         # this is necessary to make CompilationError picklable
         return (type(self), (self.required, self.limit, self.name))
+
+
+class PTXASError(TritonError):
+
+    def __init__(self, error_message: Optional[str] = None):
+        self.error_message = error_message
+
+    def __str__(self) -> str:
+        error_message = self.error_message or ""
+        return f"PTXAS error: {error_message}"
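PTXASError gives ptxas failures a dedicated exception type so callers such as the autotuner (see autotuner.py above) can skip a config whose PTX assembly fails rather than crash the tuning run. Basic behavior:

```python
from triton.runtime.errors import PTXASError

err = PTXASError("ptxas fatal: unexpected instruction")
print(err)           # PTXAS error: ptxas fatal: unexpected instruction
print(PTXASError())  # PTXAS error:   (message defaults to empty)
```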
triton/runtime/interpreter.py
CHANGED
@@ -1,7 +1,7 @@
 import ast
 import textwrap
 import inspect
-from typing import Tuple
+from typing import Tuple, List
 
 import math
 import numpy as np
@@ -21,7 +21,7 @@ class TensorHandle:
         '''
         data: numpy array
         dtype: triton type, either pointer_type or scalar_type.
-        we don't store block_type here because the shape information is already …
+        we don't store block_type here because the shape information is already available in the data field
         attr: a dictionary of attributes
         '''
         self.data = data
@@ -46,27 +46,63 @@
 
 class BlockPointerHandle:
 
-    def __init__(self, base, shape, strides, offsets, tensor_shape, order):
+    def __init__(self, base, shape, strides, offsets, block_shape, order):
         self.base = base
         self.shape = shape
         self.strides = strides
         self.offsets = offsets
-        self.tensor_shape = tensor_shape
+        self.block_shape = block_shape
         self.order = order
 
     def materialize_pointers(self, boundary_check):
         dtype_tt = self.base.get_element_ty()
         n_bytes = dtype_tt.primitive_bitwidth // 8
-        tensor_shape = self.tensor_shape
-        ptrs = np.broadcast_to(self.base.data, tensor_shape)
-        masks = np.ones(tensor_shape, dtype=bool)
-        for dim in range(len(tensor_shape)):
-            bcast_dims = [1] * len(tensor_shape)
-            bcast_dims[dim] = tensor_shape[dim]
-            off = (self.offsets[dim].data + np.arange(tensor_shape[dim])).reshape(bcast_dims)
+        ptrs = np.broadcast_to(self.base.data, self.block_shape)
+        masks = np.ones(self.block_shape, dtype=bool)
+        for dim in range(len(self.block_shape)):
+            bcast_dims = [1] * len(self.block_shape)
+            bcast_dims[dim] = self.block_shape[dim]
+            off = (self.offsets[dim].data + np.arange(self.block_shape[dim])).reshape(bcast_dims)
             ptrs = ptrs + (n_bytes * off * self.strides[dim].data).astype(np.uint64)
             if dim in boundary_check:
-                masks = …
+                masks = masks & (off < self.shape[dim].data) & (off >= 0)
+        ptrs = TensorHandle(ptrs, self.base.dtype.scalar)
+        return ptrs, masks
+
+
+class TensorDescHandle:
+
+    def __init__(self, base: TensorHandle, shape: List[TensorHandle], strides: List[TensorHandle],
+                 block_shape: List[int]):
+        self.base = base
+        self.ndim = len(shape)
+        self.shape = shape
+        self.strides = strides
+        self.block_shape = block_shape
+
+    def validate(self):
+        assert self.base.data.item() % 16 == 0, "base must be 16-byte aligned"
+        assert len(self.strides) == self.ndim
+        assert len(self.block_shape) == self.ndim
+
+        for stride in self.strides[:-1]:
+            assert stride.data.item() % 16 == 0, "stride must be 16-byte aligned"
+        assert self.strides[-1].data.item() == 1, "last dim must be contiguous"
+
+    def materialize_pointers(self, offsets: List[TensorHandle]):
+        assert len(offsets) == self.ndim
+        scalar_ty = self.base.dtype.element_ty
+        itemsize = scalar_ty.primitive_bitwidth // 8
+        assert (offsets[-1].data * itemsize) % 16 == 0, "block offset start must be 16-byte aligned"
+
+        ptrs = np.broadcast_to(self.base.data, self.block_shape)
+        masks = np.ones(self.block_shape, dtype=bool)
+        for dim in range(len(self.block_shape)):
+            bcast_dims = [1] * len(self.block_shape)
+            bcast_dims[dim] = self.block_shape[dim]
+            off = (offsets[dim].data + np.arange(self.block_shape[dim])).reshape(bcast_dims)
+            ptrs = ptrs + (itemsize * off * self.strides[dim].data).astype(np.uint64)
+            masks = masks & (0 <= off) & (off < self.shape[dim].data)
         ptrs = TensorHandle(ptrs, self.base.dtype.scalar)
         return ptrs, masks
 
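The new TensorDescHandle mirrors BlockPointerHandle's pointer materialization: each dimension contributes a broadcasted offset term, building an N-D block of pointers as base + sum(off_d * stride_d) * itemsize, with bounds masks accumulated the same way. A standalone numpy sketch of the broadcast trick, with made-up values:

```python
import numpy as np

base = np.uint64(0)       # pretend base address
block_shape = (4, 8)
strides = (8, 1)          # row-major: last dim contiguous
itemsize = 4              # e.g. float32
starts = (0, 0)           # block offsets per dim

ptrs = np.broadcast_to(base, block_shape)
for dim in range(len(block_shape)):
    bcast = [1] * len(block_shape)
    bcast[dim] = block_shape[dim]
    off = (starts[dim] + np.arange(block_shape[dim])).reshape(bcast)
    ptrs = ptrs + (itemsize * off * strides[dim]).astype(np.uint64)

print(ptrs[1, 2])  # (1*8 + 2*1) * 4 = 40
```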
@@ -242,7 +278,7 @@ class InterpreterBuilder:
         self.options = InterpreterOptions()
         self.codegen_fns = {}
         self.codegen_fns["convert_custom_types"] = ExtraFunctions._convert_custom_types
-        self.codegen_fns["min_dot_size"] = lambda lhsType, rhsType: (…)
+        self.codegen_fns["min_dot_size"] = lambda lhsType, rhsType: (1, 1, 1)
 
     def set_grid_idx(self, x, y, z):
         if not x < self.grid_dim[0]:
@@ -419,7 +455,7 @@ class InterpreterBuilder:
     create_fadd = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.add)
     create_fmul = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.multiply)
     create_fdiv = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.divide)
-    create_frem = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.remainder)
+    create_frem = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.fmod)
     create_fsub = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.subtract)
     create_mul = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.multiply)
     create_precise_divf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.divide)
@@ -557,6 +593,9 @@ class InterpreterBuilder:
     def create_histogram(self, data, bins):
         return TensorHandle(np.histogram(data.data, bins=bins, range=(0, bins))[0], tl.int32)
 
+    def create_gather(self, src, indices, axis):
+        return TensorHandle(np.take_along_axis(src.data, indices.data, axis=axis), src.dtype.scalar)
+
     # pointer arithmetic
 
     def create_addptr(self, ptr, offset):
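Two semantic fixes land here. create_frem now uses np.fmod, whose result takes the sign of the dividend (matching frem on GPU), where np.remainder follows the divisor; and the new create_gather maps tl.gather onto np.take_along_axis. Both are easy to verify in plain numpy:

```python
import numpy as np

print(np.remainder(-7.0, 3.0))  # 2.0  (sign of divisor: the old, wrong frem)
print(np.fmod(-7.0, 3.0))       # -1.0 (sign of dividend: correct frem)

src = np.array([[10, 20, 30], [40, 50, 60]])
idx = np.array([[2, 0, 1], [0, 0, 2]])
print(np.take_along_axis(src, idx, axis=1))
# [[30 10 20]
#  [40 40 60]]
```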
@@ -655,21 +694,61 @@ class InterpreterBuilder:
         # Triton's barrier applies to each program in a grid, so it's a no-op in the interpreter
         pass
 
-    def create_make_block_ptr(self, base, shape, strides, offsets, tensor_shape, order):
+    def create_make_block_ptr(self, base, shape, strides, offsets, block_shape, order):
         # Create new offsets to avoid modifying the original
         new_offsets = [offset.clone() for offset in offsets]
-        return BlockPointerHandle(base, shape, strides, new_offsets, tensor_shape, order)
+        return BlockPointerHandle(base, shape, strides, new_offsets, block_shape, order)
 
     def create_advance(self, ptr, offsets):
         if len(ptr.offsets) != len(offsets):
             raise ValueError("len(ptr.offsets) != len(offsets)")
         # Create new offsets to avoid modifying the original
         new_offsets = [offset.clone() for offset in ptr.offsets]
-        ret = BlockPointerHandle(ptr.base, ptr.shape, ptr.strides, new_offsets, ptr.tensor_shape, ptr.order)
+        ret = BlockPointerHandle(ptr.base, ptr.shape, ptr.strides, new_offsets, ptr.block_shape, ptr.order)
         for i in range(len(offsets)):
             ret.offsets[i].data += offsets[i].data
         return ret
 
+    def create_make_tensor_descriptor(
+        self,
+        base: TensorHandle,
+        shape: List[TensorHandle],
+        strides: List[TensorHandle],
+        tensor_shape: List[int],
+    ):
+        desc = TensorDescHandle(base, shape, strides, tensor_shape)
+        desc.validate()
+        return desc
+
+    def create_descriptor_load(self, desc: TensorDescHandle, indices: List[TensorHandle], cache_modifier,
+                               eviction_policy):
+        assert isinstance(desc, TensorDescHandle)
+        ptrs, mask = desc.materialize_pointers(indices)
+        return self.create_masked_load(ptrs, mask, other=None, cache_modifier=cache_modifier,
+                                       eviction_policy=eviction_policy, is_volatile=False)
+
+    def create_descriptor_store(self, desc: TensorDescHandle, value: TensorHandle, indices: List[TensorHandle]):
+        ptrs, mask = desc.materialize_pointers(indices)
+        return self.create_masked_store(ptrs, value, mask, None, None)
+
+    def create_descriptor_gather(self, desc: TensorDescHandle, x_offsets: TensorHandle, y_offset: TensorHandle, type):
+        dtype = desc.base.dtype.element_ty
+        np_dtype = _get_np_dtype(dtype)
+        result = np.zeros([x_offsets.data.shape[0], desc.block_shape[-1]], dtype=np_dtype)
+        cache_modifier = None
+        eviction_policy = None
+        for i, x_offset in enumerate(x_offsets.data):
+            indices = [TensorHandle(x_offset, tl.int32), y_offset]
+            result[i, :] = self.create_descriptor_load(desc, indices, cache_modifier, eviction_policy).data
+        return TensorHandle(result, dtype)
+
+    def create_descriptor_scatter(self, desc: TensorDescHandle, value: TensorHandle, x_offsets: TensorHandle,
+                                  y_offset: TensorHandle):
+        for i, x_offset in enumerate(x_offsets.data):
+            slice = TensorHandle(value.data[i], value.dtype)
+            indices = [TensorHandle(x_offset, tl.int32), y_offset]
+            self.create_descriptor_store(desc, slice, indices)
+
     def get_all_ones_value(self, type):
         np_type = _get_np_dtype(type)
         if "int" in np_type.name:
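create_descriptor_gather loads one [1, block] row per x-offset and stacks the rows; create_descriptor_scatter is the symmetric store. The same picture in plain numpy (made-up values and variable names):

```python
import numpy as np

data = np.arange(24).reshape(4, 6)   # the tensor behind the descriptor
x_offsets = np.array([3, 0, 2])      # which rows to gather
y_offset = 2                         # starting column of each block row
block_cols = 4

out = np.stack([data[r, y_offset:y_offset + block_cols] for r in x_offsets])
print(out)
# [[20 21 22 23]
#  [ 2  3  4  5]
#  [14 15 16 17]]
```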
@@ -701,7 +780,12 @@ def _patch_lang_tensor(tensor):
         return bool(data) if data.size == 1 else True
 
     def _get_transpose(self):
-        return tl.core.tensor(TensorHandle(np.transpose(self.handle.data), self.handle.dtype), self.type)
+        handle = TensorHandle(np.transpose(self.handle.data), self.handle.dtype)
+        assert self.type.is_block()
+        block_shape = list(self.type.shape)
+        block_shape[-1], block_shape[-2] = block_shape[-2], block_shape[-1]
+        res_ty = tl.core.block_type(self.dtype, block_shape)
+        return tl.core.tensor(handle, res_ty)
 
     tensor.__index__ = lambda self: int(self.handle.data)
     tensor.__bool__ = lambda self: _get_bool(self)
@@ -710,7 +794,7 @@ def _patch_lang_tensor(tensor):
     tensor.T = property(_get_transpose)
 
 
-class ReduceScanOpIneterface:
+class ReduceScanOpInterface:
 
     def __init__(self, axis, combine_fn):
         self.axis = axis
@@ -727,10 +811,12 @@ class ReduceScanOpIneterface:
         self.check_axis(arg.shape, self.axis)
 
     def to_tensor(self, ret, dtype):
+        np_dtype = _get_np_dtype(dtype)
         if hasattr(ret, "shape") and ret.shape:
-            ret_type = tl.block_type(dtype, list(ret.shape))
+            ret = ret.astype(np_dtype)
+            ret_type = tl.block_type(dtype, list(ret.shape))
         else:
-            ret = np.array([ret])
+            ret = np.array([ret], dtype=np_dtype)
             ret_type = dtype
         return tl.core.tensor(TensorHandle(ret, dtype.scalar), ret_type)
 
@@ -744,7 +830,7 @@ class ReduceScanOpIneterface:
         raise NotImplementedError("apply_impl not implemented")
 
 
-class ReduceOps(ReduceScanOpIneterface):
+class ReduceOps(ReduceScanOpInterface):
 
     def __init__(self, axis, combine_fn, keep_dims):
         super().__init__(axis, combine_fn)
@@ -840,7 +926,7 @@ class ReduceOps(ReduceScanOpIneterface):
         return self.generic_reduce(input)
 
 
-class ScanOps(ReduceScanOpIneterface):
+class ScanOps(ReduceScanOpInterface):
 
     def __init__(self, axis, combine_fn, reverse):
         super().__init__(axis, combine_fn)
@@ -989,7 +1075,7 @@ def _patch_lang_core(lang):
     lang.static_assert = _new_static_assert
     lang.static_print = print
     lang.dtype.to_ir = _new_to_ir
-    lang.multiple_of = partial(_set_attr, name="tt.…")
+    lang.multiple_of = partial(_set_attr, name="tt.divisibility")
     lang.max_contiguous = partial(_set_attr, name="tt.contiguity")
    lang.max_constancy = partial(_set_attr, name="tt.constancy")
 
@@ -997,7 +1083,7 @@ def _patch_lang_core(lang):
 
 
 def _patch_lang(fn):
-    langs = [value for _, value in fn.__globals__.items() if value in [tl, tl.core]]
+    langs = [value for _, value in fn.__globals__.items() if inspect.ismodule(value) and value in [tl, tl.core]]
     assert len(langs) >= 1, "triton.language must be visible from within jit'd function"
     for lang in langs:
         _patch_builtin(lang, interpreter_builder)
@@ -1006,12 +1092,22 @@ def _patch_lang(fn):
     _patch_builtin(lang.math, interpreter_builder)
     _patch_lang_tensor(lang.tensor)
     _patch_lang_core(lang)
+    _patch_builtin(tl.core._experimental_tensor_descriptor_base, interpreter_builder)
+
+
+def _tuple_create(arg, contents):
+    # NamedTuples and tuples have different construction semantics. NamedTuple
+    # has a constructor that takes individual arguments, while tuple takes an
+    # iterable. Both have type "tuple" making it difficult to distinguish
+    # between them, but only NamedTuple has "_fields" and apparently this is how
+    # everyone does the check.
+    return type(arg)(*contents) if hasattr(arg, "_fields") else type(arg)(contents)
 
 
 # TODO: wrap everything in triton tensors
 def _implicit_cvt(arg):
     if isinstance(arg, int):
-        ty = tl.str_to_ty(triton.runtime.jit.…)
+        ty = tl.str_to_ty(triton.runtime.jit.mangle_type(arg))
         dtype = np.int32
         if -2**31 <= arg < 2**31:
             dtype = np.int32
@@ -1026,16 +1122,27 @@ def _implicit_cvt(arg):
         handle = TensorHandle(np.array([arg], dtype=dtype), ty)
         return tl.tensor(handle, ty)
     if hasattr(arg, "data_ptr"):
-        ty = tl.str_to_ty(triton.runtime.jit.…)
+        ty = tl.str_to_ty(triton.runtime.jit.mangle_type(arg))
         handle = TensorHandle(np.array([arg.data_ptr()], dtype=np.uint64), ty)
         return tl.tensor(handle, ty)
+    elif isinstance(arg, tuple):
+        return _tuple_create(arg, map(_implicit_cvt, arg))
     return arg
 
 
 interpreter_builder = InterpreterBuilder()
 
-# These keywords are not supported by the interpreter
-RESERVED_KWS = […]
+
+def _unwrap_tensor(t):
+    if isinstance(t, triton.runtime.jit.TensorWrapper):
+        return t.base
+    return t
+
+
+def _rewrap_tensor(t, original_tensor):
+    if isinstance(original_tensor, triton.runtime.jit.TensorWrapper):
+        return triton.runtime.jit.TensorWrapper(t, original_tensor.dtype)
+    return t
 
 
 class GridExecutor:
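_tuple_create exists because plain tuples are built from a single iterable while NamedTuple subclasses take individual field arguments, and checking for _fields is the conventional way to tell them apart. Its behavior in isolation:

```python
from collections import namedtuple

Point = namedtuple("Point", ["x", "y"])

def _tuple_create(arg, contents):
    # NamedTuple constructors take fields; plain tuple takes one iterable.
    return type(arg)(*contents) if hasattr(arg, "_fields") else type(arg)(contents)

print(_tuple_create(Point(1, 2), [10, 20]))  # Point(x=10, y=20)
print(_tuple_create((1, 2), [10, 20]))       # (10, 20)
```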
@@ -1050,37 +1157,64 @@ class GridExecutor:
         self.constexprs = [name for name in arg_names if __annotations__.get(name) == "constexpr"]
 
     def _init_args_hst(self, args_dev, kwargs):
-        args_hst = []
-        for arg in args_dev:
-            if hasattr(arg, "data_ptr"):
-                args_hst.append(arg.cpu())
-            else:
-                args_hst.append(arg)
+        storages = {}
+
+        def _to_cpu(arg):
+            if isinstance(arg, tuple):
+                return _tuple_create(arg, map(_to_cpu, arg))
+            elif not hasattr(arg, "data_ptr"):
+                return arg
+
+            unwrapped_arg = _unwrap_tensor(arg)
+            if unwrapped_arg.untyped_storage().data_ptr() not in storages:
+                storage = unwrapped_arg.untyped_storage()
+                storages[storage.data_ptr()] = storage.cpu()
+
+            storage = storages[unwrapped_arg.untyped_storage().data_ptr()]
+            cpu_arg = unwrapped_arg.new_empty(0, device='cpu')
+            cpu_arg.set_(storage, unwrapped_arg.storage_offset(), unwrapped_arg.size(), unwrapped_arg.stride())
+            cpu_arg = _rewrap_tensor(cpu_arg, original_tensor=arg)
+            return cpu_arg
+
+        args_hst = [_to_cpu(arg) for arg in args_dev]
+
         # Process keyword arguments
         kwargs_hst = {}
         for key, value in kwargs.items():
-            if hasattr(value, "data_ptr"):
-                kwargs_hst[key] = value.cpu()
-            else:
-                kwargs_hst[key] = value
+            kwargs_hst[key] = _to_cpu(value)
         return args_hst, kwargs_hst
 
     def _restore_args_dev(self, args_dev, args_hst, kwargs, kwargs_hst):
-        for arg_dev, arg_hst in zip(args_dev, args_hst):
+        storages = {}
+
+        def _from_cpu(arg_dev, arg_hst):
             if hasattr(arg_dev, "data_ptr"):
-                arg_dev.copy_(arg_hst.to(arg_dev.device))
+                # No need to rewrap because this just modifies internal
+                arg_dev, arg_hst = _unwrap_tensor(arg_dev), _unwrap_tensor(arg_hst)
+                storages[arg_dev.untyped_storage().data_ptr()] = (arg_dev.untyped_storage(), arg_hst.untyped_storage())
+            elif isinstance(arg_dev, tuple):
+                for (arg_dev, arg_hst) in zip(arg_dev, arg_hst):
+                    _from_cpu(arg_dev, arg_hst)
+
+        for arg_dev, arg_hst in zip(args_dev, args_hst):
+            _from_cpu(arg_dev, arg_hst)
 
         # Restore keyword arguments
         for key, kwarg_dev in kwargs.items():
             kwarg_hst = kwargs_hst[key]
-            if hasattr(kwarg_dev, "data_ptr"):
-                kwarg_dev.copy_(kwarg_hst.to(kwarg_dev.device))
+            _from_cpu(kwarg_dev, kwarg_hst)
+
+        for (arg_dev, arg_hst) in storages.values():
+            arg_dev.copy_(arg_hst)
 
     def __call__(self, *args_dev, **kwargs):
-        # removes reserved keywords from kwargs
-        kwargs = {k: v for k, v in kwargs.items() if k not in RESERVED_KWS}
         if kwargs.pop("warmup", False):
             return
+        # Removes not used reserved keywords from kwargs
+        # Triton doesn't support keyword-only, variable positional or variable keyword arguments
+        # It's safe to inspect only positional or keyword arguments (i.e., argspec.args)
+        argspec = inspect.getfullargspec(self.fn)
+        kwargs = {k: v for k, v in kwargs.items() if k in argspec.args}
         # copy arguments to the host
         args_hst, kwargs_hst = self._init_args_hst(args_dev, kwargs)
         # remaps core language functions to interpreted ones