triton-windows 3.2.0.post12__cp39-cp39-win_amd64.whl → 3.3.0a0.post12__cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of triton-windows might be problematic.

Files changed (68)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +3 -3
  3. triton/_internal_testing.py +59 -4
  4. triton/_utils.py +35 -0
  5. triton/backends/amd/compiler.py +121 -74
  6. triton/backends/amd/driver.py +77 -43
  7. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +28 -49
  8. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +35 -9
  9. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +761 -284
  10. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +9 -3
  11. triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +1391 -0
  12. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +3 -3
  13. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +44 -0
  14. triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +288 -0
  15. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +110 -14
  16. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +504 -103
  17. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +2 -1
  18. triton/backends/amd/include/hip/amd_detail/host_defines.h +4 -0
  19. triton/backends/amd/include/hip/hip_ext.h +4 -2
  20. triton/backends/amd/include/hip/hip_fp8.h +33 -0
  21. triton/backends/amd/include/hip/hip_runtime_api.h +375 -33
  22. triton/backends/amd/include/hip/hip_version.h +3 -3
  23. triton/backends/amd/include/hip/hiprtc.h +25 -25
  24. triton/backends/amd/include/hsa/amd_hsa_elf.h +40 -14
  25. triton/backends/amd/include/hsa/hsa.h +11 -2
  26. triton/backends/amd/include/hsa/hsa_api_trace.h +30 -17
  27. triton/backends/amd/include/hsa/hsa_api_trace_version.h +68 -0
  28. triton/backends/amd/include/hsa/hsa_ext_amd.h +83 -27
  29. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +46 -46
  30. triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +416 -0
  31. triton/backends/amd/include/roctracer/hip_ostream_ops.h +84 -4
  32. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +260 -0
  33. triton/backends/amd/include/roctracer/hsa_prof_str.h +51 -19
  34. triton/backends/amd/lib/asanrtl.bc +0 -0
  35. triton/backends/compiler.py +25 -225
  36. triton/backends/driver.py +7 -2
  37. triton/backends/nvidia/bin/ptxas.exe +0 -0
  38. triton/backends/nvidia/compiler.py +135 -90
  39. triton/backends/nvidia/driver.c +0 -1
  40. triton/backends/nvidia/driver.py +135 -49
  41. triton/backends/nvidia/include/cuda.h +2162 -241
  42. triton/backends/nvidia/lib/x64/cuda.lib +0 -0
  43. triton/compiler/__init__.py +2 -2
  44. triton/compiler/code_generator.py +334 -231
  45. triton/compiler/compiler.py +77 -66
  46. triton/language/__init__.py +22 -5
  47. triton/language/core.py +448 -74
  48. triton/language/extra/cuda/_experimental_tma.py +3 -5
  49. triton/language/math.py +1 -1
  50. triton/language/random.py +2 -1
  51. triton/language/semantic.py +206 -52
  52. triton/language/standard.py +35 -18
  53. triton/runtime/_allocation.py +32 -0
  54. triton/runtime/autotuner.py +27 -32
  55. triton/runtime/build.py +1 -48
  56. triton/runtime/cache.py +6 -6
  57. triton/runtime/errors.py +10 -0
  58. triton/runtime/interpreter.py +179 -45
  59. triton/runtime/jit.py +149 -190
  60. triton/testing.py +39 -11
  61. triton/tools/compile.py +27 -20
  62. triton/tools/{compile.c → extra/cuda/compile.c} +1 -0
  63. triton/tools/mxfp.py +301 -0
  64. {triton_windows-3.2.0.post12.dist-info → triton_windows-3.3.0a0.post12.dist-info}/METADATA +5 -2
  65. {triton_windows-3.2.0.post12.dist-info → triton_windows-3.3.0a0.post12.dist-info}/RECORD +68 -59
  66. {triton_windows-3.2.0.post12.dist-info → triton_windows-3.3.0a0.post12.dist-info}/top_level.txt +2 -0
  67. /triton/tools/{compile.h → extra/cuda/compile.h} +0 -0
  68. {triton_windows-3.2.0.post12.dist-info → triton_windows-3.3.0a0.post12.dist-info}/WHEEL +0 -0
triton/runtime/jit.py CHANGED
@@ -11,6 +11,7 @@ from functools import cached_property
 from typing import Callable, Generic, Iterable, Optional, TypeVar, Union, overload, Dict, Any, Tuple
 from ..runtime.driver import driver
 from types import ModuleType
+from .._utils import find_paths_if, get_iterable_path
 
 TRITON_MODULE = __name__[:-len(".runtime.jit")]
 
@@ -275,47 +276,63 @@ class KernelParam:
         return self._param.default != inspect.Parameter.empty
 
 
-def compute_spec_key(v, align):
-
-    if align and hasattr(v, "data_ptr") and (v.data_ptr() % 16 == 0):
-        return "D"
-    elif isinstance(v, int):
-        # bool is a subclass of int, so we don't check explicitly above.
-        if align and (v % 16 == 0):
-            return "D"
-        elif v == 1:
-            return "1"
-    return "N"
+dtype2str = {}
+specialize_impl_cache = []
 
 
-dtype2str = {}
+def create_specialize_impl():
+    if specialize_impl_cache:
+        return specialize_impl_cache[-1]
 
+    from ..language import constexpr
 
-def mangle_type(arg, is_const=False):
+    def specialize_impl(arg, specialize_extra, is_const=False, specialize_value=True, align=True):
 
-    if arg is None:
-        return "none"
-    elif isinstance(arg, bool):
-        return "i1"
-    elif isinstance(arg, int):
-        if -(2**31) <= arg and arg <= 2**31 - 1:
-            return "i32"
-        elif 2**63 <= arg and arg <= 2**64 - 1:
-            return "u64"
+        if arg is None:
+            return ("constexpr", None)
+        elif isinstance(arg, JITFunction):
+            return ("constexpr", arg.cache_key)
+        elif isinstance(arg, constexpr):
+            return ("constexpr", arg)
+        elif isinstance(arg, bool):
+            return ("i1", None)
+        elif isinstance(arg, int):
+            key = specialize_extra(arg, "int", align=align) if specialize_value else None
+            if arg == 1 and specialize_value:
+                return ("constexpr", 1)
+            elif -(2**31) <= arg and arg <= 2**31 - 1:
+                return ("i32", key)
+            elif 2**63 <= arg and arg <= 2**64 - 1:
+                return ("u64", key)
+            else:
+                return ("i64", key)
+        elif isinstance(arg, float):
+            return ("fp32", None)
+        elif hasattr(arg, "tma_desc_cpu_ptr"):
+            return ("nvTmaDesc", None)
+        elif isinstance(arg, tuple):
+            spec = [specialize_impl(x, specialize_extra) for x in arg]
+            make_tuple = lambda vals: type(arg)(*vals) if hasattr(arg, "_fields") else tuple(vals)
+            tys = make_tuple([x[0] for x in spec])
+            keys = make_tuple([x[1] for x in spec])
+            return (tys, keys)
         else:
-            return "i64"
-    elif isinstance(arg, float):
-        return "fp32"
-    elif hasattr(arg, "tma_desc_cpu_ptr"):
-        return "nvTmaDesc"
-    else:
-        # dtypes are hashable so we can memoize this mapping:
-        dsk = (arg.dtype, is_const)
-        res = dtype2str.get(dsk, None)
-        if res is None:
-            res = ("*k" if dsk[1] else "*") + type_canonicalisation_dict[str(dsk[0]).split('.')[-1]]
-            dtype2str[dsk] = res
-        return res
+            # dtypes are hashable so we can memoize this mapping:
+            dsk = (arg.dtype, is_const)
+            res = dtype2str.get(dsk, None)
+            if res is None:
+                res = ("*k" if dsk[1] else "*") + type_canonicalisation_dict[str(dsk[0]).split('.')[-1]]
+                dtype2str[dsk] = res
+            key = specialize_extra(arg, "tensor", align=align) if specialize_value else None
+            return (res, key)
+
+    specialize_impl_cache.append(specialize_impl)
+    return specialize_impl
+
+
+def mangle_type(arg, specialize=False):
+    specialize_impl = create_specialize_impl()
+    return specialize_impl(arg, lambda _, **kwargs: None, specialize_value=specialize)[0]
 
 
 class KernelInterface(Generic[T]):
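
Note: `specialize_impl` replaces the old `compute_spec_key`/`mangle_type` pair and returns a `(type, specialization_key)` tuple per argument instead of a bare type string. The sketch below is a hand-written illustration of that mapping for plain Python scalars only (the tensor, tuple and TMA-descriptor branches and the backend-supplied keys are left out); it is not the library code itself.

    # Hypothetical, trimmed-down illustration of the (type, key) pairs above.
    def toy_specialize(arg, specialize_value=True):
        if arg is None:
            return ("constexpr", None)
        if isinstance(arg, bool):          # bool is checked before int
            return ("i1", None)
        if isinstance(arg, int):
            if arg == 1 and specialize_value:
                return ("constexpr", 1)    # the value 1 is baked in as a constant
            return ("i32", None) if -(2**31) <= arg <= 2**31 - 1 else ("i64", None)
        if isinstance(arg, float):
            return ("fp32", None)
        raise TypeError(f"unsupported: {arg!r}")

    assert toy_specialize(None) == ("constexpr", None)
    assert toy_specialize(True) == ("i1", None)
    assert toy_specialize(1) == ("constexpr", 1)
    assert toy_specialize(7) == ("i32", None)
    assert toy_specialize(2**40) == ("i64", None)
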
@@ -335,8 +352,9 @@ def serialize_specialization_data(name, signature, constants, attrs, options, ke
     constants = {key: str(value) if value.__class__.__name__ == "dtype" else value for key, value in constants.items()}
     import json
     obj = {
-        'name': name, 'signature': signature, 'constants': constants, 'attrs': attrs.to_dict(), 'options':
-        options.__dict__, 'key': key
+        'name': name, 'signature': signature, 'constant_keys': [list(x) for x in constants.keys()], 'constant_vals':
+        list(constants.values()), 'attrs_keys': [list(x) for x in attrs.keys()], 'attrs_vals': list(attrs.values()),
+        'options': options.__dict__, 'key': key
     }
     serialized_obj = json.dumps(obj)
     return serialized_obj
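
Note: the serialized blob now stores constants and attributes as parallel key/value lists because their keys are tuples (argument paths), which JSON objects cannot represent as keys. A rough sketch of the round trip, with made-up field values:

    import json

    # Hypothetical payload mirroring the keys written above; the values are invented.
    obj = {
        'name': 'add_kernel',
        'signature': {'x_ptr': '*fp32', 'n_elements': 'i32', 'BLOCK_SIZE': 'constexpr'},
        'constant_keys': [['BLOCK_SIZE']], 'constant_vals': [1024],
        'attrs_keys': [['x_ptr']], 'attrs_vals': [['D']],
        'options': {'num_warps': 4}, 'key': 'example-key',
    }
    blob = json.dumps(obj)

    # On the way back in (compare with preload() further down), tuple keys are restored:
    loaded = json.loads(blob)
    constants = dict(zip(map(tuple, loaded['constant_keys']), loaded['constant_vals']))
    assert constants == {('BLOCK_SIZE',): 1024}
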
@@ -349,50 +367,32 @@ def create_function_from_signature(sig, kparams, backend):
     basis to avoid having to run these expensive functions -- which constitute
     much of the kernel launch overhead -- every time we run the kernel.
     """
-
     assert len(sig.parameters) == len(kparams)
-
     # Create the function argument list and the dict entries for the return statement
-    func_args = []
-    dict_entries = []
-    constexpr_vals = []
-    non_constexpr_vals = []
-    signature_types = []
-    specialisations = []
-
-    for ((name, sp), kp) in zip(sig.parameters.items(), kparams):
-        if sp.default is inspect.Parameter.empty:
-            func_args.append(name)
-            dict_entries.append(f"'{name}': {name}")
-        else:
-            func_args.append(f"{name}=default_{name}")
-            dict_entries.append(f"'{name}': {name}")
+    specialization = []
+    # signature
+    for name, kp in zip(sig.parameters.keys(), kparams):
         if kp.is_constexpr:
-            constexpr_vals.append(name)
+            specialization.append(f'("constexpr", {name})')
         else:
-            non_constexpr_vals.append(name)
-            if not kp.do_not_specialize:
-                if not kp.do_not_specialize_on_alignment:
-                    specialisations.append('compute_spec_key(%s, align=True)' % name)
-                else:
-                    specialisations.append('compute_spec_key(%s, align=False)' % name)
+            is_const = 'True' if kp.is_const else 'False'
+            specialize = 'False' if kp.do_not_specialize else 'True'
+            align = 'False' if kp.do_not_specialize_on_alignment else 'True'
+            ret = f"specialize_impl({name}, specialize_extra, {is_const}, {specialize}, {align})"
             if kp.annotation_type:
-                signature_types.append('"%s"' % kp.annotation_type)
+                specialization.append(f'("{kp.annotation_type}",) + {ret}[1:]')
             else:
-                signature_types.append('mangle_type(%s, %s)' % (name, 'True' if kp.is_const else 'False'))
-
-    cache_key = ''.join([x + ', ' for x in signature_types + specialisations])
-    constexpr_vals = ''.join([x + ', ' for x in constexpr_vals])
-    non_constexpr_vals = ''.join([x + ', ' for x in non_constexpr_vals])
-
-    func_args.append('**excess_kwargs')
+                specialization.append(f"{ret}")
 
+    # compute argument string for a given parameter
+    arg = lambda x: x[0] if x[1].default is inspect.Parameter.empty else f"{x[0]}=default_{x[0]}"
     # Join all arguments into a function definition string
-    args_str = ', '.join(func_args)
-    dict_str = ', '.join(dict_entries)
-    func_body = "def dynamic_func(%s):\n    return {%s}, (%s), (%s), (%s), excess_kwargs" % (
-        args_str, dict_str, cache_key, constexpr_vals, non_constexpr_vals)
-
+    func_body = f"""
+def dynamic_func({", ".join(list(map(arg, sig.parameters.items())) + ["**options"])}):
+    params = {{{', '.join([f"'{name}': {name}" for name in sig.parameters.keys()])}}}
+    specialization = [{','.join(specialization)}]
+    return params, specialization, options
+"""
 
     # Prepare defaults to be inserted into function namespace
     func_namespace = {
@@ -400,8 +400,9 @@ def create_function_from_signature(sig, kparams, backend):
         if param.default is not inspect.Parameter.empty
     }
 
-    func_namespace['mangle_type'] = mangle_type
-    func_namespace['compute_spec_key'] = backend.compute_spec_key
+    func_namespace["JITFunction"] = JITFunction
+    func_namespace["specialize_impl"] = create_specialize_impl()
+    func_namespace["specialize_extra"] = backend.get_arg_specialization
 
     # Execute the function string in func_namespace to create the function
     exec(func_body, func_namespace)
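
Note: the binder is still built by exec-ing a generated function, but it now returns `(params, specialization, options)` instead of the old five-tuple. For a kernel declared as `def add(x, n, BLOCK: tl.constexpr)` the generated source looks roughly like the sketch below (reconstructed by hand from the f-string above; the two stubs stand in for the real `func_namespace` entries):

    # Stand-ins for the entries create_function_from_signature puts in func_namespace.
    def specialize_impl(arg, specialize_extra, is_const, specialize, align):   # stub
        return ("i32" if isinstance(arg, int) else "*fp32", specialize_extra(arg))

    specialize_extra = lambda arg, *a, **k: None                               # stub

    # Roughly what exec(func_body, func_namespace) compiles for `def add(x, n, BLOCK: tl.constexpr)`.
    def dynamic_func(x, n, BLOCK, **options):
        params = {'x': x, 'n': n, 'BLOCK': BLOCK}
        specialization = [
            specialize_impl(x, specialize_extra, False, True, True),
            specialize_impl(n, specialize_extra, False, True, True),
            ("constexpr", BLOCK),
        ]
        return params, specialization, options

    params, spec, opts = dynamic_func(0x7f00, 5, 128, num_warps=4)
    assert spec[2] == ("constexpr", 128) and opts == {'num_warps': 4}
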
@@ -446,43 +447,6 @@ class JITFunction(KernelInterface[T]):
     # cache_hook will always be called before compilation and compiled_hook after.
     compiled_hook = None
 
-    @staticmethod
-    def _key_of(arg):
-        if hasattr(arg, "dtype"):
-            return arg.dtype
-        elif isinstance(arg, bool):
-            return "i1"
-        elif isinstance(arg, int):
-            if -(2**31) <= arg and arg <= 2**31 - 1:
-                return "i32"
-            elif 2**63 <= arg and arg <= 2**64 - 1:
-                return "u64"
-            else:
-                return "i64"
-        elif isinstance(arg, float):
-            return "fp32"
-        elif arg is None:
-            return None
-        else:
-            raise TypeError(f"Unsupported type {type(arg)} for {arg}")
-
-    @staticmethod
-    def _type_of(key, is_const=False):
-        # `None` is nullptr. Implicitly convert to *i8.
-        if key is None:
-            return "*i8"
-        elif isinstance(key, str):
-            return key
-
-        dtype_str = str(key).split(".")[-1]
-        dtype_str = type_canonicalisation_dict[dtype_str]
-        const_str = "*k" if is_const else "*"
-        return const_str + dtype_str
-
-    def _make_constants(self, constexpr_key):
-        constants = dict(zip(self.constexprs, constexpr_key))
-        return constants
-
     def _call_hook(
         self,
         key,
@@ -501,7 +465,7 @@ class JITFunction(KernelInterface[T]):
         name = self.fn.__name__
         module = self.fn.__module__
         arg_reprs = ", ".join([f"{param.name}: {ty}" for param, ty in zip(self.params, key[1])])
-        repr = f"{name}[num_warps={options.num_warps}, num_ctas={options.num_ctas}, num_stages={options.num_stages}, enable_fp_fusion={options.enable_fp_fusion}]({arg_reprs})"
+        repr = f"{name}[num_warps={options.num_warps}, num_ctas={options.num_ctas}, num_stages={options.num_stages}, enable_fp_fusion={options.enable_fp_fusion}, launch_cooperative_grid={options.launch_cooperative_grid}]({arg_reprs})"
 
         class JitFunctionInfo:
 
@@ -521,6 +485,7 @@ class JITFunction(KernelInterface[T]):
             'num_ctas': options.num_ctas,
             'num_stages': options.num_stages,
             'enable_fp_fusion': options.enable_fp_fusion,
+            'launch_cooperative_grid': options.launch_cooperative_grid,
             'extern_libs': options.extern_libs,
             'configs': configs,
             'specialization_data': specialization_data,
@@ -544,89 +509,66 @@ class JITFunction(KernelInterface[T]):
         assert callable(hook)
         self.pre_run_hooks.append(hook)
 
-    def create_binder(self, backend):
+    def create_binder(self):
         """
         Precompute as much as possible.
         """
         from ..compiler import CompiledKernel, compile, ASTSource, make_backend
+        target = driver.active.get_current_target()
+        backend = make_backend(target)
         self.CompiledKernel = CompiledKernel
         self.compile = compile
         self.ASTSource = ASTSource
-        self.make_backend = make_backend
-        self.binder = create_function_from_signature(self.signature, self.params, backend)
-        self.constexpr_indices = [i for (i, p) in enumerate(self.params) if p.is_constexpr]
-        self.non_constexpr_indices = [i for (i, p) in enumerate(self.params) if not p.is_constexpr]
-        self.specialised_indices = [
-            i for (i, p) in enumerate(self.params) if (not p.do_not_specialize) and (not p.is_constexpr)
-        ]
+        binder = create_function_from_signature(self.signature, self.params, backend)
+        return {}, target, backend, binder
 
     def run(self, *args, grid, warmup, **kwargs):
-        kwargs["debug"] = kwargs.get("debug", False) or os.environ.get("TRITON_DEBUG", "0") == "1"
+        kwargs["debug"] = kwargs.get("debug", self.debug) or os.environ.get("TRITON_DEBUG", "0") == "1"
 
         # parse options
-        from ..compiler import make_backend
         device = driver.active.get_current_device()
         stream = driver.active.get_current_stream(device)
-        target = driver.active.get_current_target()
-        backend = make_backend(target)
 
         # Execute pre run hooks with args and kwargs
         for hook in self.pre_run_hooks:
            hook(*args, **kwargs)
 
-        if self.binder is None:
-            self.create_binder(backend)
-
-        bound_args, sig_and_spec, constexpr_vals, non_constexpr_vals, excess_kwargs = self.binder(*args, **kwargs)
+        kernel_cache, target, backend, binder = self.device_caches[device]
+        bound_args, specialization, options = binder(*args, **kwargs)
 
         # compute cache key
-        key = ''.join(sig_and_spec) + str((constexpr_vals, excess_kwargs))
-        kernel = self.cache[device].get(key, None)
+        key = str(specialization) + str(options)
+        kernel = kernel_cache.get(key, None)
 
+        # Kernel is not cached; we have to compile.
         if kernel is None:
-            # Kernel is not cached; we have to compile.
+            # options
            options = backend.parse_options(kwargs)
-
-            # deprecated arguments
+            # signature
+            sigkeys = [x.name for x in self.params]
+            sigvals = [x[0] for x in specialization]
+            signature = {k: v for (k, v) in zip(sigkeys, sigvals)}
+            # check arguments
            assert "device_type" not in kwargs, "device_type option is deprecated; current target will be used"
            assert "device" not in kwargs, "device option is deprecated; current device will be used"
            assert "stream" not in kwargs, "stream option is deprecated; current stream will be used"
-            for k in excess_kwargs:
-                if k not in options.__dict__:
+            for k in kwargs:
+                if k not in options.__dict__ and k not in sigkeys:
                    raise KeyError("Keyword argument %s was specified but unrecognised" % k)
-
-            bound_vals = tuple(bound_args.values())
-
-            # `None` is nullptr. Implicitly convert to *i8. This needs to be
-            # done here rather than when we build the signature as otherwise
-            # the kernel cache key could not distinguish between byte pointers
-            # and None arguments, resulting in a downstream mismatch:
-            sigkeys = [self.params[i].name for i in self.non_constexpr_indices]
-            sigvals = sig_and_spec[:len(sigkeys)]
-            signature = {k: ('*i8' if (v == 'none') else v) for (k, v) in zip(sigkeys, sigvals)}
-
-            configs = (backend.get_attrs_descriptor(self.params, bound_vals), )
-            constant_params = configs[0].get_constants()
-            constants = {
-                p.name: v
-                for (v, p) in zip(bound_vals, self.params)
-                if p.is_constexpr or (p.num in constant_params) or v is None
-            }
-            for i, arg in constants.items():
-                if callable(arg):
-                    raise TypeError(f"Callable constexpr at index {i} is not supported")
-
-            if self._call_hook(key, signature, device, constants, options, configs, warmup, before=True):
+            # constexprs
+            constexprs = find_paths_if(sigvals, lambda _, val: val == "constexpr")
+            constexprs = {path: get_iterable_path(list(bound_args.values()), path) for path in constexprs}
+            # attributes
+            attrvals = [x[1] for x in specialization]
+            attrs = find_paths_if(attrvals, lambda _, x: isinstance(x, str))
+            attrs = {k: backend.parse_attr(get_iterable_path(attrvals, k)) for k in attrs}
+            if self._call_hook(key, signature, device, constexprs, options, [attrs], warmup, before=True):
                return None
            # compile the kernel
-            src = self.ASTSource(self, signature, constants, configs[0])
-            kernel = self.compile(
-                src,
-                target=target,
-                options=options.__dict__,
-            )
-            self.cache[device][key] = kernel
-            self._call_hook(key, signature, device, constants, options, configs, warmup, before=False)
+            src = self.ASTSource(self, signature, constexprs, attrs)
+            kernel = self.compile(src, target=target, options=options.__dict__)
+            kernel_cache[key] = kernel
+            self._call_hook(key, signature, device, constexprs, options, [attrs], warmup, before=False)
 
         # Check that used global values have not changed.
         not_present = object()
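
Note: `run()` now derives the signature, the constexpr values and the specialization attributes from the `(type, key)` pairs, using `find_paths_if`/`get_iterable_path` from `triton._utils` so that nested tuple arguments are addressed by index paths. The toy stand-ins below only mimic how those helpers are used here; they are not the actual `triton._utils` implementations.

    # Toy stand-ins, assumed semantics only: walk nested lists/tuples and return index paths.
    def find_paths_if(vals, pred, prefix=()):
        paths = []
        if isinstance(vals, (list, tuple)):
            for i, v in enumerate(vals):
                paths += find_paths_if(v, pred, prefix + (i,))
        elif pred(prefix, vals):
            paths.append(prefix)
        return paths

    def get_iterable_path(vals, path):
        for i in path:
            vals = vals[i]
        return vals

    sigvals = ["*fp32", "i32", "constexpr", ("i32", "constexpr")]
    bound = [0x7f00, 64, 128, (7, 16)]
    constexprs = {p: get_iterable_path(bound, p)
                  for p in find_paths_if(sigvals, lambda _, v: v == "constexpr")}
    assert constexprs == {(2,): 128, (3, 1): 16}
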
@@ -639,21 +581,21 @@ class JITFunction(KernelInterface[T]):
         # canonicalize grid
         assert grid is not None
         if callable(grid):
-            # Arguments are passed as a dict to `grid`, by contract.
-            # TODO(jlebar): In the new launch API, pass the compiler flags as a
-            # second parameter to `grid`.
             grid = grid(bound_args)
         grid_size = len(grid)
         grid_0 = grid[0]
         grid_1 = grid[1] if grid_size > 1 else 1
         grid_2 = grid[2] if grid_size > 2 else 1
-
         # launch kernel
-        launch_metadata = kernel.launch_metadata(grid, stream, *non_constexpr_vals)
-        kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
-                   self.CompiledKernel.launch_enter_hook, self.CompiledKernel.launch_exit_hook, *non_constexpr_vals)
+        launch_metadata = kernel.launch_metadata(grid, stream, *bound_args.values())
+        kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata,
+                   launch_metadata, self.CompiledKernel.launch_enter_hook, self.CompiledKernel.launch_exit_hook,
+                   *bound_args.values())
         return kernel
 
+    def repr(self, _):
+        return self._fn_name if self._repr is None else self._repr(_)
+
     def __init__(self, fn, version=None, do_not_specialize=None, do_not_specialize_on_alignment=None, debug=None,
                  noinline=None, repr=None, launch_metadata=None):
         do_not_specialize = do_not_specialize if do_not_specialize else []
@@ -666,11 +608,10 @@ class JITFunction(KernelInterface[T]):
         self.do_not_specialize = do_not_specialize
         self.do_not_specialize_on_alignment = do_not_specialize_on_alignment
         self.starting_line_number = inspect.getsourcelines(fn)[1]
-        self.repr = lambda _: fn.__name__ if repr is None else repr(_)
+        self._repr = repr
+        self._fn_name = fn.__name__
         self.launch_metadata = launch_metadata
 
-        self.binder = None
-
         self.params = []
         for i, param in enumerate(self.signature.parameters.values()):
             dns = i in do_not_specialize or param.name in do_not_specialize
@@ -678,10 +619,11 @@ class JITFunction(KernelInterface[T]):
             self.params.append(KernelParam(i, param, dns, dns_oa))
 
         # function source code (without decorators)
-        self.src = textwrap.dedent(inspect.getsource(fn))
-        self.src = self.src[re.search(r"^def\s+\w+\s*\(", self.src, re.MULTILINE).start():]
+        src = textwrap.dedent(inspect.getsource(fn))
+        src = src[re.search(r"^def\s+\w+\s*\(", src, re.MULTILINE).start():]
+        self._unsafe_update_src(src)
         # cache of just-in-time compiled kernels
-        self.cache = defaultdict(dict)
+        self.device_caches = defaultdict(self.create_binder)
         self.hash = None
 
         # Map of global variables used by the function and any functions it
@@ -698,6 +640,7 @@ class JITFunction(KernelInterface[T]):
         # JITFunction can be instantiated as kernel
         # when called with a grid using __getitem__
         self.kernel = None
+        self.debug = debug
         self.noinline = noinline
 
         # TODO(jlebar): Remove uses of these fields outside this file, then
@@ -729,7 +672,6 @@ class JITFunction(KernelInterface[T]):
 
     def preload(self, specialization_data):
         from ..compiler import compile, ASTSource
-        from triton.backends.compiler import AttrsDescriptor
         import json
         import triton.language as tl
         device = driver.active.get_current_device()
@@ -737,19 +679,24 @@ class JITFunction(KernelInterface[T]):
         if deserialized_obj['name'] != self.fn.__name__:
             raise RuntimeError(
                 f"Specialization data is for {deserialized_obj['name']} but trying to preload for {self.fn.__name__}")
+        constant_keys = map(tuple, deserialized_obj['constant_keys'])
+        constant_vals = deserialized_obj['constant_vals']
         constants = {
             key: tl.dtype(value) if tl.dtype.is_dtype(value) else value
-            for key, value in deserialized_obj['constants'].items()
+            for key, value in zip(constant_keys, constant_vals)
         }
+        attrs_keys = map(tuple, deserialized_obj['attrs_keys'])
+        attrs_vals = deserialized_obj['attrs_vals']
+        attrs = dict(zip(attrs_keys, attrs_vals))
         signature = dict(deserialized_obj['signature'].items())
-        src = ASTSource(self, signature, constants, AttrsDescriptor.from_dict(deserialized_obj['attrs']))
+        src = ASTSource(self, signature, constants, attrs)
         options = {
             key: tuple(value) if isinstance(value, list) else value
             for key, value in deserialized_obj['options'].items()
         }
         key = deserialized_obj['key']
         kernel = compile(src, None, options)
-        self.cache[device][key] = kernel
+        self.device_caches[device][0][key] = kernel
         return kernel
 
     # we do not parse `src` in the constructor because
@@ -766,11 +713,20 @@ class JITFunction(KernelInterface[T]):
             raise RuntimeError("Cannot call @triton.jit'd outside of the scope of a kernel")
 
     def __setattr__(self, name, value):
-        super(JITFunction, self).__setattr__(name, value)
-        # - when `.src` attribute is set, cache path needs
-        #   to be reinitialized
+        # - when `.src` attribute is set, cache key of all callers need to be re-computed
         if name == "src":
-            self.hash = None
+            raise AttributeError(f"Cannot set attribute '{name}' directly. "
+                                 f"Use '_unsafe_update_src()' and manually clear `.hash` of all callers"
+                                 f"instead.")
+        super(JITFunction, self).__setattr__(name, value)
+
+    def _unsafe_update_src(self, new_src):
+        """
+        The only method allowed to modify src.
+        Bypasses the __setattr__ restriction by calling super().__setattr__ directly.
+        """
+        self.hash = None
+        super().__setattr__('src', new_src)
 
     def __repr__(self):
         return f"JITFunction({self.module}:{self.fn.__name__})"
@@ -896,8 +852,8 @@ class TensorWrapper:
     def data_ptr(self):
         return self.base.data_ptr()
 
-    def stride(self, i):
-        return self.base.stride(i)
+    def stride(self, *args):
+        return self.base.stride(*args)
 
     def __str__(self) -> str:
         return f"TensorWrapper[{self.dtype}]({self.base})"
@@ -917,6 +873,9 @@ class TensorWrapper:
     def to(self, device):
         return TensorWrapper(self.base.to(device), self.dtype)
 
+    def new_empty(self, sizes):
+        return TensorWrapper(self.base.new_empty(sizes), self.dtype)
+
 
 def reinterpret(tensor, dtype):
     if isinstance(tensor, TensorWrapper):
triton/testing.py CHANGED
@@ -1,5 +1,7 @@
 import functools
+import math
 import os
+import statistics
 import subprocess
 import sys
 from contextlib import contextmanager
@@ -17,16 +19,42 @@ def nvsmi(attrs):
     return ret
 
 
+# pure Python implementation of np.quantile/torch.quantile
+# to avoid unnecessary runtime dependency on numpy/torch
+
+
+def _quantile(a, q):
+    n = len(a)
+    a = sorted(a)
+
+    def get_quantile(q):
+        if not (0 <= q <= 1):
+            raise ValueError("Quantiles must be in the range [0, 1]")
+        point = q * (n - 1)
+        lower = math.floor(point)
+        upper = math.ceil(point)
+        t = point - lower
+        return (1 - t) * a[lower] + t * a[upper]
+
+    return [get_quantile(q) for q in q]
+
+
 def _summarize_statistics(times, quantiles, return_mode):
-    import torch
     if quantiles is not None:
-        ret = torch.quantile(times, torch.tensor(quantiles, dtype=torch.float)).tolist()
+        ret = _quantile(times, quantiles)
         if len(ret) == 1:
             ret = ret[0]
         return ret
     if return_mode == "all":
-        return times.tolist()
-    return getattr(torch, return_mode)(times).item()
+        return times
+    elif return_mode == "min":
+        return min(times)
+    elif return_mode == "max":
+        return max(times)
+    elif return_mode == "mean":
+        return statistics.mean(times)
+    elif return_mode == "median":
+        return statistics.median(times)
 
 
 def do_bench_cudagraph(fn, rep=20, grad_to_none=None, quantiles=None, return_mode="mean"):
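
Note: `_quantile` reproduces the default linear-interpolation rule of `np.quantile`/`torch.quantile` on sorted values at the fractional index `q * (n - 1)`. A quick sanity check, assuming the function is importable from this wheel as `triton.testing._quantile`:

    from triton.testing import _quantile

    assert _quantile([4.0, 1.0, 3.0, 2.0], [0.0, 0.5, 1.0]) == [1.0, 2.5, 4.0]
    assert _quantile([10.0, 20.0], [0.25]) == [12.5]
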
@@ -39,7 +67,7 @@ do_bench_cudagraph(fn, rep=20, grad_to_none=None, quantiles=None, return_mod
     :type rep: int
     :param grad_to_none: Reset the gradient of the provided tensor to None
     :type grad_to_none: torch.tensor, optional
-    :param return_mode: The statistical measure to return. Options are "min", "max", "mean", "median", or "all" Default is "mean".
+    :param return_mode: The statistical measure to return. Options are "min", "max", "mean", "median", or "all". Default is "mean".
     :type return_mode: str
     """
     import torch
@@ -89,7 +117,7 @@ do_bench_cudagraph(fn, rep=20, grad_to_none=None, quantiles=None, return_mod
             end_event.record()
             torch.cuda.synchronize()
             ret += [start_event.elapsed_time(end_event) / n_repeat]
-        return _summarize_statistics(torch.tensor(ret), quantiles, return_mode)
+        return _summarize_statistics(ret, quantiles, return_mode)
 
 
 def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_mode="mean"):
@@ -107,10 +135,10 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_m
     :type grad_to_none: torch.tensor, optional
     :param quantiles: Performance percentile to return in addition to the median.
     :type quantiles: list[float], optional
-    :param return_mode: The statistical measure to return. Options are "min", "max", "mean", "median", or "all" Default is "mean". :type return_mode: str
+    :param return_mode: The statistical measure to return. Options are "min", "max", "mean", "median", or "all". Default is "mean".
+    :type return_mode: str
     """
     assert return_mode in ["min", "max", "mean", "median", "all"]
-    import torch
 
     di = runtime.driver.active.get_device_interface()
 
@@ -124,7 +152,7 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_m
     end_event = di.Event(enable_timing=True)
     start_event.record()
     for _ in range(5):
-        cache.zero_()
+        runtime.driver.active.clear_cache(cache)
         fn()
     end_event.record()
     di.synchronize()
@@ -147,14 +175,14 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_m
             for x in grad_to_none:
                 x.grad = None
         # we clear the L2 cache before each run
-        cache.zero_()
+        runtime.driver.active.clear_cache(cache)
         # record time of `fn`
        start_event[i].record()
        fn()
        end_event[i].record()
    # Record clocks
    di.synchronize()
-    times = torch.tensor([s.elapsed_time(e) for s, e in zip(start_event, end_event)], dtype=torch.float)
+    times = [s.elapsed_time(e) for s, e in zip(start_event, end_event)]
    return _summarize_statistics(times, quantiles, return_mode)
 
 
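
Note: with timings kept as a plain Python list, `_summarize_statistics` needs only `min`/`max` and the `statistics` module instead of torch reductions. A compact sketch of the same dispatch (names here are illustrative, not the module's API):

    import statistics

    def summarize(times, return_mode):        # illustrative only
        if return_mode == "all":
            return times
        if return_mode in ("min", "max"):
            return {"min": min, "max": max}[return_mode](times)
        return getattr(statistics, return_mode)(times)   # "mean" or "median"

    times_ms = [1.2, 0.9, 1.1, 1.0]
    assert summarize(times_ms, "min") == 0.9
    assert summarize(times_ms, "median") == statistics.median(times_ms)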