triton-windows 3.2.0.post12__cp39-cp39-win_amd64.whl → 3.3.0a0.post12__cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of triton-windows might be problematic.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +3 -3
- triton/_internal_testing.py +59 -4
- triton/_utils.py +35 -0
- triton/backends/amd/compiler.py +121 -74
- triton/backends/amd/driver.py +77 -43
- triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +28 -49
- triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +35 -9
- triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +761 -284
- triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +9 -3
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +1391 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +3 -3
- triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +44 -0
- triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +288 -0
- triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +110 -14
- triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +504 -103
- triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +2 -1
- triton/backends/amd/include/hip/amd_detail/host_defines.h +4 -0
- triton/backends/amd/include/hip/hip_ext.h +4 -2
- triton/backends/amd/include/hip/hip_fp8.h +33 -0
- triton/backends/amd/include/hip/hip_runtime_api.h +375 -33
- triton/backends/amd/include/hip/hip_version.h +3 -3
- triton/backends/amd/include/hip/hiprtc.h +25 -25
- triton/backends/amd/include/hsa/amd_hsa_elf.h +40 -14
- triton/backends/amd/include/hsa/hsa.h +11 -2
- triton/backends/amd/include/hsa/hsa_api_trace.h +30 -17
- triton/backends/amd/include/hsa/hsa_api_trace_version.h +68 -0
- triton/backends/amd/include/hsa/hsa_ext_amd.h +83 -27
- triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +46 -46
- triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +416 -0
- triton/backends/amd/include/roctracer/hip_ostream_ops.h +84 -4
- triton/backends/amd/include/roctracer/hsa_ostream_ops.h +260 -0
- triton/backends/amd/include/roctracer/hsa_prof_str.h +51 -19
- triton/backends/amd/lib/asanrtl.bc +0 -0
- triton/backends/compiler.py +25 -225
- triton/backends/driver.py +7 -2
- triton/backends/nvidia/bin/ptxas.exe +0 -0
- triton/backends/nvidia/compiler.py +135 -90
- triton/backends/nvidia/driver.c +0 -1
- triton/backends/nvidia/driver.py +135 -49
- triton/backends/nvidia/include/cuda.h +2162 -241
- triton/backends/nvidia/lib/x64/cuda.lib +0 -0
- triton/compiler/__init__.py +2 -2
- triton/compiler/code_generator.py +334 -231
- triton/compiler/compiler.py +77 -66
- triton/language/__init__.py +22 -5
- triton/language/core.py +448 -74
- triton/language/extra/cuda/_experimental_tma.py +3 -5
- triton/language/math.py +1 -1
- triton/language/random.py +2 -1
- triton/language/semantic.py +206 -52
- triton/language/standard.py +35 -18
- triton/runtime/_allocation.py +32 -0
- triton/runtime/autotuner.py +27 -32
- triton/runtime/build.py +1 -48
- triton/runtime/cache.py +6 -6
- triton/runtime/errors.py +10 -0
- triton/runtime/interpreter.py +179 -45
- triton/runtime/jit.py +149 -190
- triton/testing.py +39 -11
- triton/tools/compile.py +27 -20
- triton/tools/{compile.c → extra/cuda/compile.c} +1 -0
- triton/tools/mxfp.py +301 -0
- {triton_windows-3.2.0.post12.dist-info → triton_windows-3.3.0a0.post12.dist-info}/METADATA +5 -2
- {triton_windows-3.2.0.post12.dist-info → triton_windows-3.3.0a0.post12.dist-info}/RECORD +68 -59
- {triton_windows-3.2.0.post12.dist-info → triton_windows-3.3.0a0.post12.dist-info}/top_level.txt +2 -0
- /triton/tools/{compile.h → extra/cuda/compile.h} +0 -0
- {triton_windows-3.2.0.post12.dist-info → triton_windows-3.3.0a0.post12.dist-info}/WHEEL +0 -0
triton/runtime/jit.py
CHANGED
@@ -11,6 +11,7 @@ from functools import cached_property
 from typing import Callable, Generic, Iterable, Optional, TypeVar, Union, overload, Dict, Any, Tuple
 from ..runtime.driver import driver
 from types import ModuleType
+from .._utils import find_paths_if, get_iterable_path
 
 TRITON_MODULE = __name__[:-len(".runtime.jit")]
 
@@ -275,47 +276,63 @@ class KernelParam:
         return self._param.default != inspect.Parameter.empty
 
 
-
-
-    if align and hasattr(v, "data_ptr") and (v.data_ptr() % 16 == 0):
-        return "D"
-    elif isinstance(v, int):
-        # bool is a subclass of int, so we don't check explicitly above.
-        if align and (v % 16 == 0):
-            return "D"
-        elif v == 1:
-            return "1"
-    return "N"
+dtype2str = {}
+specialize_impl_cache = []
 
 
-
+def create_specialize_impl():
+    if specialize_impl_cache:
+        return specialize_impl_cache[-1]
 
+    from ..language import constexpr
 
-def
+    def specialize_impl(arg, specialize_extra, is_const=False, specialize_value=True, align=True):
 
-
-
-
-
-
-
-
-
-
+        if arg is None:
+            return ("constexpr", None)
+        elif isinstance(arg, JITFunction):
+            return ("constexpr", arg.cache_key)
+        elif isinstance(arg, constexpr):
+            return ("constexpr", arg)
+        elif isinstance(arg, bool):
+            return ("i1", None)
+        elif isinstance(arg, int):
+            key = specialize_extra(arg, "int", align=align) if specialize_value else None
+            if arg == 1 and specialize_value:
+                return ("constexpr", 1)
+            elif -(2**31) <= arg and arg <= 2**31 - 1:
+                return ("i32", key)
+            elif 2**63 <= arg and arg <= 2**64 - 1:
+                return ("u64", key)
+            else:
+                return ("i64", key)
+        elif isinstance(arg, float):
+            return ("fp32", None)
+        elif hasattr(arg, "tma_desc_cpu_ptr"):
+            return ("nvTmaDesc", None)
+        elif isinstance(arg, tuple):
+            spec = [specialize_impl(x, specialize_extra) for x in arg]
+            make_tuple = lambda vals: type(arg)(*vals) if hasattr(arg, "_fields") else tuple(vals)
+            tys = make_tuple([x[0] for x in spec])
+            keys = make_tuple([x[1] for x in spec])
+            return (tys, keys)
         else:
-
-
-
-
-
-
-
-
-
-
-
-
-
+            # dtypes are hashable so we can memoize this mapping:
+            dsk = (arg.dtype, is_const)
+            res = dtype2str.get(dsk, None)
+            if res is None:
+                res = ("*k" if dsk[1] else "*") + type_canonicalisation_dict[str(dsk[0]).split('.')[-1]]
+                dtype2str[dsk] = res
+            key = specialize_extra(arg, "tensor", align=align) if specialize_value else None
+            return (res, key)
+
+    specialize_impl_cache.append(specialize_impl)
+    return specialize_impl
+
+
+def mangle_type(arg, specialize=False):
+    specialize_impl = create_specialize_impl()
+    return specialize_impl(arg, lambda _, **kwargs: None, specialize_value=specialize)[0]
 
 
 class KernelInterface(Generic[T]):
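Note on the new specialization scheme: `specialize_impl` folds the old `compute_spec_key`/`mangle_type` pair into a single function that returns a `(type, specialization_key)` tuple per argument. The standalone sketch below mirrors only the scalar rules visible in the hunk above (tensor, tuple, and TMA-descriptor branches omitted, and the backend `specialize_extra` hook stubbed out to `None`); it is an illustration, not the shipped implementation:

# Simplified rendering of the scalar branches of specialize_impl (assumption:
# no backend alignment key, so every key is None except the constexpr cases).
def classify(arg, specialize_value=True):
    if arg is None:
        return ("constexpr", None)
    if isinstance(arg, bool):          # checked before int: bool is a subclass of int
        return ("i1", None)
    if isinstance(arg, int):
        if arg == 1 and specialize_value:
            return ("constexpr", 1)    # the value 1 is baked into the kernel
        if -(2**31) <= arg <= 2**31 - 1:
            return ("i32", None)
        if 2**63 <= arg <= 2**64 - 1:
            return ("u64", None)       # only representable as unsigned 64-bit
        return ("i64", None)
    if isinstance(arg, float):
        return ("fp32", None)
    raise TypeError(f"unsupported: {type(arg)}")

assert classify(1) == ("constexpr", 1)
assert classify(7) == ("i32", None)
assert classify(2**40) == ("i64", None)   # too big for i32, fits in i64
assert classify(2**63) == ("u64", None)
assert classify(0.5) == ("fp32", None)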
@@ -335,8 +352,9 @@ def serialize_specialization_data(name, signature, constants, attrs, options, key):
     constants = {key: str(value) if value.__class__.__name__ == "dtype" else value for key, value in constants.items()}
     import json
     obj = {
-        'name': name, 'signature': signature, '
-
+        'name': name, 'signature': signature, 'constant_keys': [list(x) for x in constants.keys()], 'constant_vals':
+        list(constants.values()), 'attrs_keys': [list(x) for x in attrs.keys()], 'attrs_vals': list(attrs.values()),
+        'options': options.__dict__, 'key': key
     }
     serialized_obj = json.dumps(obj)
     return serialized_obj
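The reworked serialization flattens `constants` and `attrs` into parallel key/value lists because their keys are now paths (tuples), which JSON objects cannot represent directly. A hedged sketch of the round trip, with hypothetical values:

import json

# Hypothetical constants keyed by argument paths (tuples), as the binder produces.
constants = {(0,): 64, (2,): None}
obj = {
    'name': 'kernel', 'signature': {'x_ptr': '*fp32'},
    'constant_keys': [list(k) for k in constants.keys()],   # [[0], [2]]
    'constant_vals': list(constants.values()),              # [64, None]
}
round_tripped = json.loads(json.dumps(obj))
# preload() later restores tuple keys with map(tuple, ...).
assert [tuple(k) for k in round_tripped['constant_keys']] == [(0,), (2,)]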
@@ -349,50 +367,32 @@ def create_function_from_signature(sig, kparams, backend):
     basis to avoid having to run these expensive functions -- which constitute
     much of the kernel launch overhead -- every time we run the kernel.
     """
-
     assert len(sig.parameters) == len(kparams)
-
     # Create the function argument list and the dict entries for the return statement
-
-
-
-    non_constexpr_vals = []
-    signature_types = []
-    specialisations = []
-
-    for ((name, sp), kp) in zip(sig.parameters.items(), kparams):
-        if sp.default is inspect.Parameter.empty:
-            func_args.append(name)
-            dict_entries.append(f"'{name}': {name}")
-        else:
-            func_args.append(f"{name}=default_{name}")
-            dict_entries.append(f"'{name}': {name}")
+    specialization = []
+    # signature
+    for name, kp in zip(sig.parameters.keys(), kparams):
         if kp.is_constexpr:
-
+            specialization.append(f'("constexpr", {name})')
         else:
-
-            if
-
-
-            else:
-                specialisations.append('compute_spec_key(%s, align=False)' % name)
+            is_const = 'True' if kp.is_const else 'False'
+            specialize = 'False' if kp.do_not_specialize else 'True'
+            align = 'False' if kp.do_not_specialize_on_alignment else 'True'
+            ret = f"specialize_impl({name}, specialize_extra, {is_const}, {specialize}, {align})"
            if kp.annotation_type:
-
+                specialization.append(f'("{kp.annotation_type}",) + {ret}[1:]')
             else:
-
-
-    cache_key = ''.join([x + ', ' for x in signature_types + specialisations])
-    constexpr_vals = ''.join([x + ', ' for x in constexpr_vals])
-    non_constexpr_vals = ''.join([x + ', ' for x in non_constexpr_vals])
-
-    func_args.append('**excess_kwargs')
+                specialization.append(f"{ret}")
 
+    # compute argument string for a given parameter
+    arg = lambda x: x[0] if x[1].default is inspect.Parameter.empty else f"{x[0]}=default_{x[0]}"
     # Join all arguments into a function definition string
-
-
-
-
-
+    func_body = f"""
+def dynamic_func({", ".join(list(map(arg, sig.parameters.items())) + ["**options"])}):
+    params = {{{', '.join([f"'{name}': {name}" for name in sig.parameters.keys()])}}}
+    specialization = [{','.join(specialization)}]
+    return params, specialization, options
+"""
     # Prepare defaults to be inserted into function namespace
     func_namespace = {
         f"default_{name}": param.default
@@ -400,8 +400,9 @@ def create_function_from_signature(sig, kparams, backend):
         if param.default is not inspect.Parameter.empty
     }
 
-    func_namespace[
-    func_namespace[
+    func_namespace["JITFunction"] = JITFunction
+    func_namespace["specialize_impl"] = create_specialize_impl()
+    func_namespace["specialize_extra"] = backend.get_arg_specialization
 
     # Execute the function string in func_namespace to create the function
     exec(func_body, func_namespace)
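To make the codegen above concrete: for a hypothetical kernel declared as `def kernel(x_ptr, n, BLOCK: tl.constexpr)`, the `exec()`d binder source would look roughly like the `dynamic_func` below. The stand-ins for `specialize_impl`/`specialize_extra` are assumptions for the sake of a runnable sketch; the real ones are placed in `func_namespace` by the generator:

# Hypothetical stand-ins for names the generator injects into func_namespace.
specialize_impl = lambda arg, extra, is_const, spec, align: ("i32", None)
specialize_extra = None

# Approximately the generated source for kernel(x_ptr, n, BLOCK: tl.constexpr):
def dynamic_func(x_ptr, n, BLOCK, **options):
    params = {'x_ptr': x_ptr, 'n': n, 'BLOCK': BLOCK}
    specialization = [
        specialize_impl(x_ptr, specialize_extra, False, True, True),
        specialize_impl(n, specialize_extra, False, True, True),
        ("constexpr", BLOCK),
    ]
    return params, specialization, options

params, spec, opts = dynamic_func(0x1000, 7, 128, num_warps=4)
assert opts == {'num_warps': 4} and spec[-1] == ("constexpr", 128)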
@@ -446,43 +447,6 @@ class JITFunction(KernelInterface[T]):
     # cache_hook will always be called before compilation and compiled_hook after.
     compiled_hook = None
 
-    @staticmethod
-    def _key_of(arg):
-        if hasattr(arg, "dtype"):
-            return arg.dtype
-        elif isinstance(arg, bool):
-            return "i1"
-        elif isinstance(arg, int):
-            if -(2**31) <= arg and arg <= 2**31 - 1:
-                return "i32"
-            elif 2**63 <= arg and arg <= 2**64 - 1:
-                return "u64"
-            else:
-                return "i64"
-        elif isinstance(arg, float):
-            return "fp32"
-        elif arg is None:
-            return None
-        else:
-            raise TypeError(f"Unsupported type {type(arg)} for {arg}")
-
-    @staticmethod
-    def _type_of(key, is_const=False):
-        # `None` is nullptr. Implicitly convert to *i8.
-        if key is None:
-            return "*i8"
-        elif isinstance(key, str):
-            return key
-
-        dtype_str = str(key).split(".")[-1]
-        dtype_str = type_canonicalisation_dict[dtype_str]
-        const_str = "*k" if is_const else "*"
-        return const_str + dtype_str
-
-    def _make_constants(self, constexpr_key):
-        constants = dict(zip(self.constexprs, constexpr_key))
-        return constants
-
     def _call_hook(
         self,
         key,
@@ -501,7 +465,7 @@ class JITFunction(KernelInterface[T]):
         name = self.fn.__name__
         module = self.fn.__module__
         arg_reprs = ", ".join([f"{param.name}: {ty}" for param, ty in zip(self.params, key[1])])
-        repr = f"{name}[num_warps={options.num_warps}, num_ctas={options.num_ctas}, num_stages={options.num_stages}, enable_fp_fusion={options.enable_fp_fusion}]({arg_reprs})"
+        repr = f"{name}[num_warps={options.num_warps}, num_ctas={options.num_ctas}, num_stages={options.num_stages}, enable_fp_fusion={options.enable_fp_fusion}, launch_cooperative_grid={options.launch_cooperative_grid}]({arg_reprs})"
 
         class JitFunctionInfo:
 
@@ -521,6 +485,7 @@ class JITFunction(KernelInterface[T]):
             'num_ctas': options.num_ctas,
             'num_stages': options.num_stages,
             'enable_fp_fusion': options.enable_fp_fusion,
+            'launch_cooperative_grid': options.launch_cooperative_grid,
             'extern_libs': options.extern_libs,
             'configs': configs,
             'specialization_data': specialization_data,
@@ -544,89 +509,66 @@ class JITFunction(KernelInterface[T]):
         assert callable(hook)
         self.pre_run_hooks.append(hook)
 
-    def create_binder(self
+    def create_binder(self):
         """
         Precompute as much as possible.
         """
         from ..compiler import CompiledKernel, compile, ASTSource, make_backend
+        target = driver.active.get_current_target()
+        backend = make_backend(target)
         self.CompiledKernel = CompiledKernel
         self.compile = compile
         self.ASTSource = ASTSource
-        self.
-
-        self.constexpr_indices = [i for (i, p) in enumerate(self.params) if p.is_constexpr]
-        self.non_constexpr_indices = [i for (i, p) in enumerate(self.params) if not p.is_constexpr]
-        self.specialised_indices = [
-            i for (i, p) in enumerate(self.params) if (not p.do_not_specialize) and (not p.is_constexpr)
-        ]
+        binder = create_function_from_signature(self.signature, self.params, backend)
+        return {}, target, backend, binder
 
     def run(self, *args, grid, warmup, **kwargs):
-        kwargs["debug"] = kwargs.get("debug",
+        kwargs["debug"] = kwargs.get("debug", self.debug) or os.environ.get("TRITON_DEBUG", "0") == "1"
 
         # parse options
-        from ..compiler import make_backend
         device = driver.active.get_current_device()
         stream = driver.active.get_current_stream(device)
-        target = driver.active.get_current_target()
-        backend = make_backend(target)
 
         # Execute pre run hooks with args and kwargs
         for hook in self.pre_run_hooks:
             hook(*args, **kwargs)
 
-
-
-
-        bound_args, sig_and_spec, constexpr_vals, non_constexpr_vals, excess_kwargs = self.binder(*args, **kwargs)
+        kernel_cache, target, backend, binder = self.device_caches[device]
+        bound_args, specialization, options = binder(*args, **kwargs)
 
         # compute cache key
-        key =
-        kernel =
+        key = str(specialization) + str(options)
+        kernel = kernel_cache.get(key, None)
 
+        # Kernel is not cached; we have to compile.
         if kernel is None:
-            #
+            # options
            options = backend.parse_options(kwargs)
-
-
+            # signature
+            sigkeys = [x.name for x in self.params]
+            sigvals = [x[0] for x in specialization]
+            signature = {k: v for (k, v) in zip(sigkeys, sigvals)}
+            # check arguments
             assert "device_type" not in kwargs, "device_type option is deprecated; current target will be used"
             assert "device" not in kwargs, "device option is deprecated; current device will be used"
             assert "stream" not in kwargs, "stream option is deprecated; current stream will be used"
-            for k in
-                if k not in options.__dict__:
+            for k in kwargs:
+                if k not in options.__dict__ and k not in sigkeys:
                     raise KeyError("Keyword argument %s was specified but unrecognised" % k)
-
-
-
-            #
-
-
-
-
-            sigvals = sig_and_spec[:len(sigkeys)]
-            signature = {k: ('*i8' if (v == 'none') else v) for (k, v) in zip(sigkeys, sigvals)}
-
-            configs = (backend.get_attrs_descriptor(self.params, bound_vals), )
-            constant_params = configs[0].get_constants()
-            constants = {
-                p.name: v
-                for (v, p) in zip(bound_vals, self.params)
-                if p.is_constexpr or (p.num in constant_params) or v is None
-            }
-            for i, arg in constants.items():
-                if callable(arg):
-                    raise TypeError(f"Callable constexpr at index {i} is not supported")
-
-            if self._call_hook(key, signature, device, constants, options, configs, warmup, before=True):
+            # constexprs
+            constexprs = find_paths_if(sigvals, lambda _, val: val == "constexpr")
+            constexprs = {path: get_iterable_path(list(bound_args.values()), path) for path in constexprs}
+            # attributes
+            attrvals = [x[1] for x in specialization]
+            attrs = find_paths_if(attrvals, lambda _, x: isinstance(x, str))
+            attrs = {k: backend.parse_attr(get_iterable_path(attrvals, k)) for k in attrs}
+            if self._call_hook(key, signature, device, constexprs, options, [attrs], warmup, before=True):
                 return None
             # compile the kernel
-            src = self.ASTSource(self, signature,
-            kernel = self.compile(
-
-
-                options=options.__dict__,
-            )
-            self.cache[device][key] = kernel
-            self._call_hook(key, signature, device, constants, options, configs, warmup, before=False)
+            src = self.ASTSource(self, signature, constexprs, attrs)
+            kernel = self.compile(src, target=target, options=options.__dict__)
+            kernel_cache[key] = kernel
+            self._call_hook(key, signature, device, constexprs, options, [attrs], warmup, before=False)
 
         # Check that used global values have not changed.
         not_present = object()
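A note on the caching change in run(): the per-function `self.cache`/`self.binder` pair is replaced by a per-device state tuple built lazily on first use. A minimal standalone sketch of the pattern, with a hypothetical factory standing in for `JITFunction.create_binder`:

from collections import defaultdict

def make_device_state():
    # stand-in for create_binder(); returns (kernel_cache, target, backend, binder)
    return ({}, "target", "backend", lambda *a, **k: (a, k))

device_caches = defaultdict(make_device_state)
kernel_cache, target, backend, binder = device_caches[0]   # built on first access
kernel_cache["key"] = "compiled-kernel"
assert device_caches[0][0]["key"] == "compiled-kernel"      # same state on reuse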
@@ -639,21 +581,21 @@ class JITFunction(KernelInterface[T]):
         # canonicalize grid
         assert grid is not None
         if callable(grid):
-            # Arguments are passed as a dict to `grid`, by contract.
-            # TODO(jlebar): In the new launch API, pass the compiler flags as a
-            # second parameter to `grid`.
             grid = grid(bound_args)
         grid_size = len(grid)
         grid_0 = grid[0]
         grid_1 = grid[1] if grid_size > 1 else 1
         grid_2 = grid[2] if grid_size > 2 else 1
-
         # launch kernel
-        launch_metadata = kernel.launch_metadata(grid, stream, *
-        kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata,
-                   self.CompiledKernel.launch_enter_hook, self.CompiledKernel.launch_exit_hook,
+        launch_metadata = kernel.launch_metadata(grid, stream, *bound_args.values())
+        kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata,
+                   launch_metadata, self.CompiledKernel.launch_enter_hook, self.CompiledKernel.launch_exit_hook,
+                   *bound_args.values())
         return kernel
 
+    def repr(self, _):
+        return self._fn_name if self._repr is None else self._repr(_)
+
     def __init__(self, fn, version=None, do_not_specialize=None, do_not_specialize_on_alignment=None, debug=None,
                  noinline=None, repr=None, launch_metadata=None):
         do_not_specialize = do_not_specialize if do_not_specialize else []
@@ -666,11 +608,10 @@ class JITFunction(KernelInterface[T]):
         self.do_not_specialize = do_not_specialize
         self.do_not_specialize_on_alignment = do_not_specialize_on_alignment
         self.starting_line_number = inspect.getsourcelines(fn)[1]
-        self.
+        self._repr = repr
+        self._fn_name = fn.__name__
         self.launch_metadata = launch_metadata
 
-        self.binder = None
-
         self.params = []
         for i, param in enumerate(self.signature.parameters.values()):
             dns = i in do_not_specialize or param.name in do_not_specialize
@@ -678,10 +619,11 @@ class JITFunction(KernelInterface[T]):
             self.params.append(KernelParam(i, param, dns, dns_oa))
 
         # function source code (without decorators)
-
-
+        src = textwrap.dedent(inspect.getsource(fn))
+        src = src[re.search(r"^def\s+\w+\s*\(", src, re.MULTILINE).start():]
+        self._unsafe_update_src(src)
         # cache of just-in-time compiled kernels
-        self.
+        self.device_caches = defaultdict(self.create_binder)
         self.hash = None
 
         # Map of global variables used by the function and any functions it
@@ -698,6 +640,7 @@ class JITFunction(KernelInterface[T]):
         # JITFunction can be instantiated as kernel
         # when called with a grid using __getitem__
         self.kernel = None
+        self.debug = debug
         self.noinline = noinline
 
         # TODO(jlebar): Remove uses of these fields outside this file, then
@@ -729,7 +672,6 @@ class JITFunction(KernelInterface[T]):
 
     def preload(self, specialization_data):
         from ..compiler import compile, ASTSource
-        from triton.backends.compiler import AttrsDescriptor
         import json
         import triton.language as tl
         device = driver.active.get_current_device()
@@ -737,19 +679,24 @@ class JITFunction(KernelInterface[T]):
         if deserialized_obj['name'] != self.fn.__name__:
             raise RuntimeError(
                 f"Specialization data is for {deserialized_obj['name']} but trying to preload for {self.fn.__name__}")
+        constant_keys = map(tuple, deserialized_obj['constant_keys'])
+        constant_vals = deserialized_obj['constant_vals']
         constants = {
             key: tl.dtype(value) if tl.dtype.is_dtype(value) else value
-            for key, value in
+            for key, value in zip(constant_keys, constant_vals)
         }
+        attrs_keys = map(tuple, deserialized_obj['attrs_keys'])
+        attrs_vals = deserialized_obj['attrs_vals']
+        attrs = dict(zip(attrs_keys, attrs_vals))
         signature = dict(deserialized_obj['signature'].items())
-        src = ASTSource(self, signature, constants,
+        src = ASTSource(self, signature, constants, attrs)
         options = {
             key: tuple(value) if isinstance(value, list) else value
             for key, value in deserialized_obj['options'].items()
         }
         key = deserialized_obj['key']
         kernel = compile(src, None, options)
-        self.
+        self.device_caches[device][0][key] = kernel
         return kernel
 
         # we do not parse `src` in the constructor because
@@ -766,11 +713,20 @@ class JITFunction(KernelInterface[T]):
             raise RuntimeError("Cannot call @triton.jit'd outside of the scope of a kernel")
 
     def __setattr__(self, name, value):
-
-        # - when `.src` attribute is set, cache path needs
-        # to be reinitialized
+        # - when `.src` attribute is set, cache key of all callers need to be re-computed
         if name == "src":
-
+            raise AttributeError(f"Cannot set attribute '{name}' directly. "
+                                 f"Use '_unsafe_update_src()' and manually clear `.hash` of all callers"
+                                 f"instead.")
+        super(JITFunction, self).__setattr__(name, value)
+
+    def _unsafe_update_src(self, new_src):
+        """
+        The only method allowed to modify src.
+        Bypasses the __setattr__ restriction by calling super().__setattr__ directly.
+        """
+        self.hash = None
+        super().__setattr__('src', new_src)
 
     def __repr__(self):
         return f"JITFunction({self.module}:{self.fn.__name__})"
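The `__setattr__`/`_unsafe_update_src` pair implements a simple attribute-guard pattern: direct writes to `.src` are rejected, and the sanctioned update path also invalidates the cached hash. A minimal standalone sketch of the same pattern:

class Guarded:
    def __init__(self, src):
        self.hash = None
        self._unsafe_update_src(src)

    def __setattr__(self, name, value):
        if name == "src":
            raise AttributeError("use _unsafe_update_src() instead")
        super().__setattr__(name, value)

    def _unsafe_update_src(self, new_src):
        self.hash = None                    # stale hash must be recomputed
        super().__setattr__('src', new_src)  # bypass the guard deliberately

g = Guarded("def k(): pass")
try:
    g.src = "oops"                          # blocked by the guard
except AttributeError:
    pass
g._unsafe_update_src("def k2(): pass")
assert g.src == "def k2(): pass" and g.hash is None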
@@ -896,8 +852,8 @@ class TensorWrapper:
     def data_ptr(self):
         return self.base.data_ptr()
 
-    def stride(self,
-        return self.base.stride(
+    def stride(self, *args):
+        return self.base.stride(*args)
 
     def __str__(self) -> str:
         return f"TensorWrapper[{self.dtype}]({self.base})"
@@ -917,6 +873,9 @@ class TensorWrapper:
     def to(self, device):
         return TensorWrapper(self.base.to(device), self.dtype)
 
+    def new_empty(self, sizes):
+        return TensorWrapper(self.base.new_empty(sizes), self.dtype)
+
 
 def reinterpret(tensor, dtype):
     if isinstance(tensor, TensorWrapper):
|
triton/testing.py
CHANGED
|
@@ -1,5 +1,7 @@
|
|
|
1
1
|
import functools
|
|
2
|
+
import math
|
|
2
3
|
import os
|
|
4
|
+
import statistics
|
|
3
5
|
import subprocess
|
|
4
6
|
import sys
|
|
5
7
|
from contextlib import contextmanager
|
|
@@ -17,16 +19,42 @@ def nvsmi(attrs):
|
|
|
17
19
|
return ret
|
|
18
20
|
|
|
19
21
|
|
|
22
|
+
# pure Python implementation of np.quantile/torch.quantile
|
|
23
|
+
# to avoid unnecessary runtime dependency on numpy/torch
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def _quantile(a, q):
|
|
27
|
+
n = len(a)
|
|
28
|
+
a = sorted(a)
|
|
29
|
+
|
|
30
|
+
def get_quantile(q):
|
|
31
|
+
if not (0 <= q <= 1):
|
|
32
|
+
raise ValueError("Quantiles must be in the range [0, 1]")
|
|
33
|
+
point = q * (n - 1)
|
|
34
|
+
lower = math.floor(point)
|
|
35
|
+
upper = math.ceil(point)
|
|
36
|
+
t = point - lower
|
|
37
|
+
return (1 - t) * a[lower] + t * a[upper]
|
|
38
|
+
|
|
39
|
+
return [get_quantile(q) for q in q]
|
|
40
|
+
|
|
41
|
+
|
|
20
42
|
def _summarize_statistics(times, quantiles, return_mode):
|
|
21
|
-
import torch
|
|
22
43
|
if quantiles is not None:
|
|
23
|
-
ret =
|
|
44
|
+
ret = _quantile(times, quantiles)
|
|
24
45
|
if len(ret) == 1:
|
|
25
46
|
ret = ret[0]
|
|
26
47
|
return ret
|
|
27
48
|
if return_mode == "all":
|
|
28
|
-
return times
|
|
29
|
-
|
|
49
|
+
return times
|
|
50
|
+
elif return_mode == "min":
|
|
51
|
+
return min(times)
|
|
52
|
+
elif return_mode == "max":
|
|
53
|
+
return max(times)
|
|
54
|
+
elif return_mode == "mean":
|
|
55
|
+
return statistics.mean(times)
|
|
56
|
+
elif return_mode == "median":
|
|
57
|
+
return statistics.median(times)
|
|
30
58
|
|
|
31
59
|
|
|
32
60
|
def do_bench_cudagraph(fn, rep=20, grad_to_none=None, quantiles=None, return_mode="mean"):
|
|
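The new `_quantile` helper uses the same linear-interpolation rule as np.quantile's default: quantile q sits at fractional position q*(n-1) in the sorted data. A worked example (the helper is restated verbatim so the check runs standalone):

import math

def _quantile(a, q):
    n = len(a)
    a = sorted(a)

    def get_quantile(q):
        if not (0 <= q <= 1):
            raise ValueError("Quantiles must be in the range [0, 1]")
        point = q * (n - 1)
        lower = math.floor(point)
        upper = math.ceil(point)
        t = point - lower
        return (1 - t) * a[lower] + t * a[upper]

    return [get_quantile(q) for q in q]

# With sorted data [10, 20, 30, 40]: q=0.5 -> position 1.5, halfway between 20 and 30.
assert _quantile([40, 10, 30, 20], [0.0, 0.5, 1.0]) == [10.0, 25.0, 40.0]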
@@ -39,7 +67,7 @@ def do_bench_cudagraph(fn, rep=20, grad_to_none=None, quantiles=None, return_mode="mean"):
     :type rep: int
     :param grad_to_none: Reset the gradient of the provided tensor to None
     :type grad_to_none: torch.tensor, optional
-    :param return_mode: The statistical measure to return. Options are "min", "max", "mean", "median", or "all" Default is "mean".
+    :param return_mode: The statistical measure to return. Options are "min", "max", "mean", "median", or "all". Default is "mean".
     :type return_mode: str
     """
     import torch
@@ -89,7 +117,7 @@ def do_bench_cudagraph(fn, rep=20, grad_to_none=None, quantiles=None, return_mode="mean"):
             end_event.record()
             torch.cuda.synchronize()
             ret += [start_event.elapsed_time(end_event) / n_repeat]
-        return _summarize_statistics(
+        return _summarize_statistics(ret, quantiles, return_mode)
 
 
 def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_mode="mean"):
@@ -107,10 +135,10 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_mode="mean"):
     :type grad_to_none: torch.tensor, optional
     :param quantiles: Performance percentile to return in addition to the median.
     :type quantiles: list[float], optional
-    :param return_mode: The statistical measure to return. Options are "min", "max", "mean", "median", or "all" Default is "mean".
+    :param return_mode: The statistical measure to return. Options are "min", "max", "mean", "median", or "all". Default is "mean".
+    :type return_mode: str
     """
     assert return_mode in ["min", "max", "mean", "median", "all"]
-    import torch
 
     di = runtime.driver.active.get_device_interface()
 
@@ -124,7 +152,7 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_mode="mean"):
     end_event = di.Event(enable_timing=True)
     start_event.record()
     for _ in range(5):
-
+        runtime.driver.active.clear_cache(cache)
         fn()
     end_event.record()
     di.synchronize()
@@ -147,14 +175,14 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_mode="mean"):
         for x in grad_to_none:
             x.grad = None
         # we clear the L2 cache before each run
-
+        runtime.driver.active.clear_cache(cache)
         # record time of `fn`
         start_event[i].record()
         fn()
         end_event[i].record()
     # Record clocks
     di.synchronize()
-    times =
+    times = [s.elapsed_time(e) for s, e in zip(start_event, end_event)]
    return _summarize_statistics(times, quantiles, return_mode)
 
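Taken together, the testing.py changes drop torch from the summary path: `_summarize_statistics` now dispatches to the stdlib, which is why `import math` and `import statistics` were added at the top of the file. A quick standalone check of the stdlib calls the new branches rely on:

import statistics

times = [1.0, 2.0, 3.0, 4.0]
assert min(times) == 1.0                    # return_mode="min"
assert max(times) == 4.0                    # return_mode="max"
assert statistics.mean(times) == 2.5        # return_mode="mean"
assert statistics.median(times) == 2.5      # return_mode="median"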