triton-windows 3.2.0.post11__cp312-cp312-win_amd64.whl → 3.3.0a0.post11__cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of triton-windows is flagged as potentially problematic.

Files changed (68)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +3 -3
  3. triton/_internal_testing.py +59 -4
  4. triton/_utils.py +35 -0
  5. triton/backends/amd/compiler.py +121 -74
  6. triton/backends/amd/driver.py +77 -43
  7. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +28 -49
  8. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +35 -9
  9. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +761 -284
  10. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +9 -3
  11. triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +1391 -0
  12. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +3 -3
  13. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +44 -0
  14. triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +288 -0
  15. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +110 -14
  16. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +504 -103
  17. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +2 -1
  18. triton/backends/amd/include/hip/amd_detail/host_defines.h +4 -0
  19. triton/backends/amd/include/hip/hip_ext.h +4 -2
  20. triton/backends/amd/include/hip/hip_fp8.h +33 -0
  21. triton/backends/amd/include/hip/hip_runtime_api.h +375 -33
  22. triton/backends/amd/include/hip/hip_version.h +3 -3
  23. triton/backends/amd/include/hip/hiprtc.h +25 -25
  24. triton/backends/amd/include/hsa/amd_hsa_elf.h +40 -14
  25. triton/backends/amd/include/hsa/hsa.h +11 -2
  26. triton/backends/amd/include/hsa/hsa_api_trace.h +30 -17
  27. triton/backends/amd/include/hsa/hsa_api_trace_version.h +68 -0
  28. triton/backends/amd/include/hsa/hsa_ext_amd.h +83 -27
  29. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +46 -46
  30. triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +416 -0
  31. triton/backends/amd/include/roctracer/hip_ostream_ops.h +84 -4
  32. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +260 -0
  33. triton/backends/amd/include/roctracer/hsa_prof_str.h +51 -19
  34. triton/backends/amd/lib/asanrtl.bc +0 -0
  35. triton/backends/compiler.py +25 -225
  36. triton/backends/driver.py +7 -2
  37. triton/backends/nvidia/bin/ptxas.exe +0 -0
  38. triton/backends/nvidia/compiler.py +135 -90
  39. triton/backends/nvidia/driver.c +0 -1
  40. triton/backends/nvidia/driver.py +135 -49
  41. triton/backends/nvidia/include/cuda.h +2162 -241
  42. triton/backends/nvidia/lib/x64/cuda.lib +0 -0
  43. triton/compiler/__init__.py +2 -2
  44. triton/compiler/code_generator.py +334 -231
  45. triton/compiler/compiler.py +77 -66
  46. triton/language/__init__.py +22 -5
  47. triton/language/core.py +448 -74
  48. triton/language/extra/cuda/_experimental_tma.py +3 -5
  49. triton/language/math.py +1 -1
  50. triton/language/random.py +2 -1
  51. triton/language/semantic.py +206 -52
  52. triton/language/standard.py +35 -18
  53. triton/runtime/_allocation.py +32 -0
  54. triton/runtime/autotuner.py +27 -32
  55. triton/runtime/build.py +1 -48
  56. triton/runtime/cache.py +6 -6
  57. triton/runtime/errors.py +10 -0
  58. triton/runtime/interpreter.py +179 -45
  59. triton/runtime/jit.py +149 -190
  60. triton/testing.py +39 -11
  61. triton/tools/compile.py +27 -20
  62. triton/tools/{compile.c → extra/cuda/compile.c} +1 -0
  63. triton/tools/mxfp.py +301 -0
  64. {triton_windows-3.2.0.post11.dist-info → triton_windows-3.3.0a0.post11.dist-info}/METADATA +5 -2
  65. {triton_windows-3.2.0.post11.dist-info → triton_windows-3.3.0a0.post11.dist-info}/RECORD +68 -59
  66. {triton_windows-3.2.0.post11.dist-info → triton_windows-3.3.0a0.post11.dist-info}/top_level.txt +2 -0
  67. /triton/tools/{compile.h → extra/cuda/compile.h} +0 -0
  68. {triton_windows-3.2.0.post11.dist-info → triton_windows-3.3.0a0.post11.dist-info}/WHEEL +0 -0
triton/runtime/autotuner.py CHANGED
@@ -4,10 +4,10 @@ import builtins
  import os
  import time
  import inspect
- from typing import Dict
+ from typing import Dict, Tuple, List, Optional

  from .jit import KernelInterface
- from .errors import OutOfResources
+ from .errors import OutOfResources, PTXASError
  from .driver import driver


@@ -23,7 +23,7 @@ class Autotuner(KernelInterface):
  restore_value,
  pre_hook=None,
  post_hook=None,
- prune_configs_by: Dict = None,
+ prune_configs_by: Optional[Dict] = None,
  warmup=None,
  rep=None,
  use_cuda_graph=False,
@@ -36,14 +36,11 @@ class Autotuner(KernelInterface):
  'prune_num_stages_by'(optional): a function used to prune num_stages. It takes configs:List[Config] as its input, and returns pruned configs.
  """
  if not configs:
- self.configs = [
- Config({}, num_warps=4, num_stages=2, num_ctas=1, num_buffers_warp_spec=0, num_consumer_groups=0,
- reg_dec_producer=0, reg_inc_consumer=0)
- ]
+ self.configs = [Config({}, num_warps=4, num_stages=3, num_ctas=1)]
  else:
  self.configs = configs
  self.keys = key
- self.cache = {}
+ self.cache: Dict[Tuple, Config] = {}
  self.arg_names = arg_names

  # Reset to zero or restore values
@@ -134,6 +131,10 @@ class Autotuner(KernelInterface):
  def _bench(self, *args, config, **meta):
  from ..compiler.errors import CompileTimeAssertionFailure

+ verbose = os.environ.get("TRITON_PRINT_AUTOTUNING", None) == "1"
+ if verbose:
+ print(f"Autotuning kernel {self.base_fn.__name__} with config {config}")
+
  # check for conflicts, i.e. meta-parameters both provided
  # as kwargs and by the autotuner
  conflicts = meta.keys() & config.kwargs.keys()
@@ -164,7 +165,9 @@ class Autotuner(KernelInterface):

  try:
  return self.do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8))
- except (OutOfResources, CompileTimeAssertionFailure):
+ except (OutOfResources, CompileTimeAssertionFailure, PTXASError) as e:
+ if verbose:
+ print(f"Autotuning failed with {e}")
  return [float("inf"), float("inf"), float("inf")]

  def run(self, *args, **kwargs):
@@ -208,7 +211,7 @@ class Autotuner(KernelInterface):
  self.nargs = None
  return ret

- def prune_configs(self, kwargs):
+ def prune_configs(self, kwargs: Dict) -> List[Config]:
  pruned_configs = self.configs
  if self.early_config_prune:
  pruned_configs = self.early_config_prune(self.configs, self.nargs, **kwargs)
@@ -216,6 +219,10 @@ class Autotuner(KernelInterface):
  top_k = self.configs_top_k
  if isinstance(top_k, float) and top_k <= 1.0:
  top_k = int(len(self.configs) * top_k)
+ elif not isinstance(top_k, int):
+ # Slice index must be an integer
+ raise TypeError("Error while pruning configs, top_k must be either 1) a float <= 1.0 or 2) an int")
+
  if len(pruned_configs) > top_k:
  est_timing = {
  config: self.perf_model(
@@ -262,16 +269,11 @@ class Config:
  function are args.
  """

- def __init__(self, kwargs, num_warps=4, num_stages=2, num_ctas=1, num_buffers_warp_spec=0, num_consumer_groups=0,
- reg_dec_producer=0, reg_inc_consumer=0, maxnreg=None, pre_hook=None):
+ def __init__(self, kwargs, num_warps=4, num_stages=3, num_ctas=1, maxnreg=None, pre_hook=None):
  self.kwargs = kwargs
  self.num_warps = num_warps
  self.num_ctas = num_ctas
  self.num_stages = num_stages
- self.num_buffers_warp_spec = num_buffers_warp_spec
- self.num_consumer_groups = num_consumer_groups
- self.reg_dec_producer = reg_dec_producer
- self.reg_inc_consumer = reg_inc_consumer
  self.maxnreg = maxnreg
  self.pre_hook = pre_hook

@@ -283,10 +285,6 @@ class Config:
  ("num_warps", self.num_warps),
  ("num_ctas", self.num_ctas),
  ("num_stages", self.num_stages),
- ("num_buffers_warp_spec", self.num_buffers_warp_spec),
- ("num_consumer_groups", self.num_consumer_groups),
- ("reg_dec_producer", self.reg_dec_producer),
- ("reg_inc_consumer", self.reg_inc_consumer),
  ("maxnreg", self.maxnreg),
  ) if v is not None
  }
@@ -299,10 +297,6 @@ class Config:
  res.append(f"num_warps: {self.num_warps}")
  res.append(f"num_ctas: {self.num_ctas}")
  res.append(f"num_stages: {self.num_stages}")
- res.append(f"num_buffers_warp_spec: {self.num_buffers_warp_spec}")
- res.append(f"num_consumer_groups: {self.num_consumer_groups}")
- res.append(f"reg_dec_producer: {self.reg_dec_producer}")
- res.append(f"reg_inc_consumer: {self.reg_inc_consumer}")
  res.append(f"maxnreg: {self.maxnreg}")
  return ", ".join(res)

@@ -323,8 +317,8 @@ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, restore_va
  # the value of x_size changes
  )
  @triton.jit
- def kernel(x_ptr, x_size, **META):
- BLOCK_SIZE = META['BLOCK_SIZE']
+ def kernel(x_ptr, x_size, BLOCK_SIZE: tl.constexpr):
+ ...
  :note: When all the configurations are evaluated, the kernel will run multiple times.
  This means that whatever value the kernel updates will be updated multiple times.
  To avoid this undesired behavior, you can use the `reset_to_zero` argument, which
@@ -367,7 +361,7 @@ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, restore_va
  def decorator(fn):
  return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, restore_value, pre_hook=pre_hook,
  post_hook=post_hook, prune_configs_by=prune_configs_by, warmup=warmup, rep=rep,
- use_cuda_graph=use_cuda_graph)
+ use_cuda_graph=use_cuda_graph, do_bench=do_bench)

  return decorator

@@ -388,18 +382,19 @@ class Heuristics(KernelInterface):
  def heuristics(values):
  """
  Decorator for specifying how the values of certain meta-parameters may be computed.
- This is useful for cases where auto-tuning is prohibitevely expensive, or just not applicable.
+ This is useful for cases where auto-tuning is prohibitively expensive, or just not applicable.

  .. highlight:: python
  .. code-block:: python

- @triton.heuristics(values={'BLOCK_SIZE': lambda args: 2 ** int(math.ceil(math.log2(args[1])))})
+ # smallest power-of-two >= x_size
+ @triton.heuristics(values={'BLOCK_SIZE': lambda args: triton.next_power_of_2(args['x_size'])})
  @triton.jit
- def kernel(x_ptr, x_size, **META):
- BLOCK_SIZE = META['BLOCK_SIZE'] # smallest power-of-two >= x_size
+ def kernel(x_ptr, x_size, BLOCK_SIZE: tl.constexpr):
+ ...
  :param values: a dictionary of meta-parameter names and functions that compute the value of the meta-parameter.
  each such function takes a list of positional arguments as input.
- :type values: dict[str, Callable[[list[Any]], Any]]
+ :type values: dict[str, Callable[[dict[str, Any]], Any]]
  """

  def decorator(fn):
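Note on the hunks above: Config loses the warp-specialization knobs (num_buffers_warp_spec, num_consumer_groups, reg_dec_producer, reg_inc_consumer) and the default num_stages moves from 2 to 3, while _bench now prints each tried config and any OutOfResources / CompileTimeAssertionFailure / PTXASError failure when TRITON_PRINT_AUTOTUNING=1 is set. A minimal, hedged sketch of the resulting 3.3-style usage (the kernel name and block sizes are made up for illustration):

    import triton
    import triton.language as tl

    @triton.autotune(
        configs=[
            # only num_warps / num_stages / num_ctas / maxnreg remain as launch knobs
            triton.Config({"BLOCK_SIZE": 128}, num_warps=4, num_stages=3),
            triton.Config({"BLOCK_SIZE": 256}, num_warps=8, num_stages=3),
        ],
        key=["x_size"],
    )
    @triton.jit
    def kernel(x_ptr, x_size, BLOCK_SIZE: tl.constexpr):
        ...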
triton/runtime/build.py CHANGED
@@ -1,26 +1,12 @@
- import contextlib
- import sys
- import io
  import sysconfig
  import os
  import shutil
  import subprocess
- import setuptools

  if os.name == "nt":
  from triton.windows_utils import find_msvc_winsdk, find_python


- @contextlib.contextmanager
- def quiet():
- old_stdout, old_stderr = sys.stdout, sys.stderr
- sys.stdout, sys.stderr = io.StringIO(), io.StringIO()
- try:
- yield
- finally:
- sys.stdout, sys.stderr = old_stdout, old_stderr
-
-
  def _cc_cmd(cc, src, out, include_dirs, library_dirs, libraries):
  if cc.lower().endswith("cl") or cc.lower().endswith("cl.exe"):
  out_base = os.path.splitext(out)[0]
@@ -74,38 +60,5 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries):
  include_dirs += msvc_winsdk_inc_dirs
  library_dirs += msvc_winsdk_lib_dirs
  cc_cmd = _cc_cmd(cc, src, so, include_dirs, library_dirs, libraries)
- ret = subprocess.check_call(cc_cmd)
- if ret == 0:
- return so
- # fallback on setuptools
- extra_compile_args = []
- if cc.lower().endswith("cl") or cc.lower().endswith("cl.exe"):
- extra_compile_args += ["/O2"]
- else:
- extra_compile_args += ["-O3"]
- # extra arguments
- extra_link_args = []
- # create extension module
- ext = setuptools.Extension(
- name=name,
- language='c',
- sources=[src],
- include_dirs=include_dirs,
- extra_compile_args=extra_compile_args,
- extra_link_args=extra_link_args,
- library_dirs=library_dirs,
- libraries=libraries,
- )
- # build extension module
- args = ['build_ext']
- args.append('--build-temp=' + srcdir)
- args.append('--build-lib=' + srcdir)
- args.append('-q')
- args = dict(
- name=name,
- ext_modules=[ext],
- script_args=args,
- )
- with quiet():
- setuptools.setup(**args)
+ subprocess.check_call(cc_cmd, stdout=subprocess.DEVNULL)
  return so
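For context on this simplification: subprocess.check_call either returns 0 or raises CalledProcessError, so the old "if ret == 0: return so" guard could never be false and the setuptools fallback (together with the quiet() helper that silenced it) was unreachable. A small standalone illustration of the error path, using a deliberately failing command rather than a real compiler invocation:

    import subprocess
    import sys

    try:
        # exits with status 1, so check_call raises instead of returning a code
        subprocess.check_call([sys.executable, "-c", "raise SystemExit(1)"],
                              stdout=subprocess.DEVNULL)
    except subprocess.CalledProcessError as e:
        print(f"compilation step failed with exit code {e.returncode}")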
triton/runtime/cache.py CHANGED
@@ -256,9 +256,9 @@ __cache_cls = FileCacheManager
  __cache_cls_nme = "DEFAULT"


- def _base64(key):
+ def _base32(key):
  # Assume key is a hex string.
- return base64.urlsafe_b64encode(bytes.fromhex(key)).decode("utf-8").rstrip("=")
+ return base64.b32encode(bytes.fromhex(key)).decode("utf-8").rstrip("=")


  def get_cache_manager(key) -> CacheManager:
@@ -274,15 +274,15 @@ def get_cache_manager(key) -> CacheManager:
  __cache_cls = getattr(module, clz_nme)
  __cache_cls_nme = user_cache_manager

- return __cache_cls(_base64(key))
+ return __cache_cls(_base32(key))


  def get_override_manager(key) -> CacheManager:
- return __cache_cls(_base64(key), override=True)
+ return __cache_cls(_base32(key), override=True)


  def get_dump_manager(key) -> CacheManager:
- return __cache_cls(_base64(key), dump=True)
+ return __cache_cls(_base32(key), dump=True)


  def make_so_cache_key(version_hash, signature, constants, ids, **kwargs):
@@ -292,4 +292,4 @@ def make_so_cache_key(version_hash, signature, constants, ids, **kwargs):
  for kw in kwargs:
  key = f"{key}-{kwargs.get(kw)}"
  key = hashlib.sha256(key.encode("utf-8")).hexdigest()
- return _base64(key)
+ return _base32(key)
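The cache-key encoding switches from URL-safe base64 to base32. Base32 output uses only A-Z and 2-7, so two keys can no longer differ only by letter case, which presumably matters on the case-insensitive filesystems this Windows build targets. A quick, hedged comparison of the two encodings on a stand-in hex digest (the key value is invented):

    import base64

    key = "9f2c" * 16  # stand-in for a hashlib.sha256(...).hexdigest() value
    raw = bytes.fromhex(key)

    print(base64.urlsafe_b64encode(raw).decode().rstrip("="))  # old: mixed case
    print(base64.b32encode(raw).decode().rstrip("="))          # new: A-Z / 2-7 only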
triton/runtime/errors.py CHANGED
@@ -24,3 +24,13 @@ class OutOfResources(TritonError):
  def __reduce__(self):
  # this is necessary to make CompilationError picklable
  return (type(self), (self.required, self.limit, self.name))
+
+
+ class PTXASError(TritonError):
+
+ def __init__(self, error_message: Optional[str] = None):
+ self.error_message = error_message
+
+ def __str__(self) -> str:
+ error_message = self.error_message or ""
+ return f"PTXAS error: {error_message}"
triton/runtime/interpreter.py CHANGED
@@ -1,7 +1,7 @@
  import ast
  import textwrap
  import inspect
- from typing import Tuple
+ from typing import Tuple, List

  import math
  import numpy as np
@@ -21,7 +21,7 @@ class TensorHandle:
  '''
  data: numpy array
  dtype: triton type, either pointer_type or scalar_type.
- we don't store block_type here because the shape information is already availale in the data field
+ we don't store block_type here because the shape information is already available in the data field
  attr: a dictionary of attributes
  '''
  self.data = data
@@ -46,27 +46,63 @@ class TensorHandle:

  class BlockPointerHandle:

- def __init__(self, base, shape, strides, offsets, tensor_shape, order):
+ def __init__(self, base, shape, strides, offsets, block_shape, order):
  self.base = base
  self.shape = shape
  self.strides = strides
  self.offsets = offsets
- self.tensor_shape = tensor_shape
+ self.block_shape = block_shape
  self.order = order

  def materialize_pointers(self, boundary_check):
  dtype_tt = self.base.get_element_ty()
  n_bytes = dtype_tt.primitive_bitwidth // 8
- tensor_shape = self.tensor_shape
- ptrs = np.broadcast_to(self.base.data, self.tensor_shape)
- masks = np.ones(self.tensor_shape, dtype=bool)
- for dim in range(len(tensor_shape)):
- bcast_dims = [1] * len(tensor_shape)
- bcast_dims[dim] = tensor_shape[dim]
- off = (self.offsets[dim].data + np.arange(tensor_shape[dim])).reshape(bcast_dims)
+ ptrs = np.broadcast_to(self.base.data, self.block_shape)
+ masks = np.ones(self.block_shape, dtype=bool)
+ for dim in range(len(self.block_shape)):
+ bcast_dims = [1] * len(self.block_shape)
+ bcast_dims[dim] = self.block_shape[dim]
+ off = (self.offsets[dim].data + np.arange(self.block_shape[dim])).reshape(bcast_dims)
  ptrs = ptrs + (n_bytes * off * self.strides[dim].data).astype(np.uint64)
  if dim in boundary_check:
- masks = np.logical_and(masks, off < self.shape[dim].data)
+ masks = masks & (off < self.shape[dim].data) & (off >= 0)
+ ptrs = TensorHandle(ptrs, self.base.dtype.scalar)
+ return ptrs, masks
+
+
+ class TensorDescHandle:
+
+ def __init__(self, base: TensorHandle, shape: List[TensorHandle], strides: List[TensorHandle],
+ block_shape: List[int]):
+ self.base = base
+ self.ndim = len(shape)
+ self.shape = shape
+ self.strides = strides
+ self.block_shape = block_shape
+
+ def validate(self):
+ assert self.base.data.item() % 16 == 0, "base must be 16-byte aligned"
+ assert len(self.strides) == self.ndim
+ assert len(self.block_shape) == self.ndim
+
+ for stride in self.strides[:-1]:
+ assert stride.data.item() % 16 == 0, "stride must be 16-byte aligned"
+ assert self.strides[-1].data.item() == 1, "last dim must be contiguous"
+
+ def materialize_pointers(self, offsets: List[TensorHandle]):
+ assert len(offsets) == self.ndim
+ scalar_ty = self.base.dtype.element_ty
+ itemsize = scalar_ty.primitive_bitwidth // 8
+ assert (offsets[-1].data * itemsize) % 16 == 0, "block offset start must be 16-byte aligned"
+
+ ptrs = np.broadcast_to(self.base.data, self.block_shape)
+ masks = np.ones(self.block_shape, dtype=bool)
+ for dim in range(len(self.block_shape)):
+ bcast_dims = [1] * len(self.block_shape)
+ bcast_dims[dim] = self.block_shape[dim]
+ off = (offsets[dim].data + np.arange(self.block_shape[dim])).reshape(bcast_dims)
+ ptrs = ptrs + (itemsize * off * self.strides[dim].data).astype(np.uint64)
+ masks = masks & (0 <= off) & (off < self.shape[dim].data)
  ptrs = TensorHandle(ptrs, self.base.dtype.scalar)
  return ptrs, masks

@@ -242,7 +278,7 @@ class InterpreterBuilder:
  self.options = InterpreterOptions()
  self.codegen_fns = {}
  self.codegen_fns["convert_custom_types"] = ExtraFunctions._convert_custom_types
- self.codegen_fns["min_dot_size"] = lambda lhsType, rhsType: (16, 16, 16)
+ self.codegen_fns["min_dot_size"] = lambda lhsType, rhsType: (1, 1, 1)

  def set_grid_idx(self, x, y, z):
  if not x < self.grid_dim[0]:
@@ -419,7 +455,7 @@ class InterpreterBuilder:
  create_fadd = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.add)
  create_fmul = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.multiply)
  create_fdiv = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.divide)
- create_frem = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.remainder)
+ create_frem = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.fmod)
  create_fsub = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.subtract)
  create_mul = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.multiply)
  create_precise_divf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.divide)
@@ -557,6 +593,9 @@ class InterpreterBuilder:
  def create_histogram(self, data, bins):
  return TensorHandle(np.histogram(data.data, bins=bins, range=(0, bins))[0], tl.int32)

+ def create_gather(self, src, indices, axis):
+ return TensorHandle(np.take_along_axis(src.data, indices.data, axis=axis), src.dtype.scalar)
+
  # pointer arithmetic

  def create_addptr(self, ptr, offset):
@@ -655,21 +694,61 @@ class InterpreterBuilder:
  # Triton's barrier applies to each program in a grid, so it's a no-op in the interpreter
  pass

- def create_make_block_ptr(self, base, shape, strides, offsets, tensor_shape, order):
+ def create_make_block_ptr(self, base, shape, strides, offsets, block_shape, order):
  # Create new offsets to avoid modifying the original
  new_offsets = [offset.clone() for offset in offsets]
- return BlockPointerHandle(base, shape, strides, new_offsets, tensor_shape, order)
+ return BlockPointerHandle(base, shape, strides, new_offsets, block_shape, order)

  def create_advance(self, ptr, offsets):
  if len(ptr.offsets) != len(offsets):
  raise ValueError("len(ptr.offsets) != len(offsets)")
  # Create new offsets to avoid modifying the original
  new_offsets = [offset.clone() for offset in ptr.offsets]
- ret = BlockPointerHandle(ptr.base, ptr.shape, ptr.strides, new_offsets, ptr.tensor_shape, ptr.order)
+ ret = BlockPointerHandle(ptr.base, ptr.shape, ptr.strides, new_offsets, ptr.block_shape, ptr.order)
  for i in range(len(offsets)):
  ret.offsets[i].data += offsets[i].data
  return ret

+ def create_make_tensor_descriptor(
+ self,
+ base: TensorHandle,
+ shape: List[TensorHandle],
+ strides: List[TensorHandle],
+ tensor_shape: List[int],
+ ):
+ desc = TensorDescHandle(base, shape, strides, tensor_shape)
+ desc.validate()
+ return desc
+
+ def create_descriptor_load(self, desc: TensorDescHandle, indices: List[TensorHandle], cache_modifier,
+ eviction_policy):
+ assert isinstance(desc, TensorDescHandle)
+ ptrs, mask = desc.materialize_pointers(indices)
+ return self.create_masked_load(ptrs, mask, other=None, cache_modifier=cache_modifier,
+ eviction_policy=eviction_policy, is_volatile=False)
+
+ def create_descriptor_store(self, desc: TensorDescHandle, value: TensorHandle, indices: List[TensorHandle]):
+ ptrs, mask = desc.materialize_pointers(indices)
+ return self.create_masked_store(ptrs, value, mask, None, None)
+
+ def create_descriptor_gather(self, desc: TensorDescHandle, x_offsets: TensorHandle, y_offset: TensorHandle, type):
+ dtype = desc.base.dtype.element_ty
+ np_dtype = _get_np_dtype(dtype)
+ result = np.zeros([x_offsets.data.shape[0], desc.block_shape[-1]], dtype=np_dtype)
+ cache_modifier = None
+ eviction_policy = None
+ for i, x_offset in enumerate(x_offsets.data):
+ indices = [TensorHandle(x_offset, tl.int32), y_offset]
+ result[i, :] = self.create_descriptor_load(desc, indices, cache_modifier, eviction_policy).data
+ return TensorHandle(result, dtype)
+
+ def create_descriptor_scatter(self, desc: TensorDescHandle, value: TensorHandle, x_offsets: TensorHandle,
+ y_offset: TensorHandle):
+ for i, x_offset in enumerate(x_offsets.data):
+ slice = TensorHandle(value.data[i], value.dtype)
+ indices = [TensorHandle(x_offset, tl.int32), y_offset]
+ self.create_descriptor_store(desc, slice, indices)
+
  def get_all_ones_value(self, type):
  np_type = _get_np_dtype(type)
  if "int" in np_type.name:
@@ -701,7 +780,12 @@ def _patch_lang_tensor(tensor):
  return bool(data) if data.size == 1 else True

  def _get_transpose(self):
- return tl.core.tensor(TensorHandle(np.transpose(self.handle.data), self.handle.dtype), self.dtype.scalar)
+ handle = TensorHandle(np.transpose(self.handle.data), self.handle.dtype)
+ assert self.type.is_block()
+ block_shape = list(self.type.shape)
+ block_shape[-1], block_shape[-2] = block_shape[-2], block_shape[-1]
+ res_ty = tl.core.block_type(self.dtype, block_shape)
+ return tl.core.tensor(handle, res_ty)

  tensor.__index__ = lambda self: int(self.handle.data)
  tensor.__bool__ = lambda self: _get_bool(self)
@@ -710,7 +794,7 @@ def _patch_lang_tensor(tensor):
  tensor.T = property(_get_transpose)


- class ReduceScanOpIneterface:
+ class ReduceScanOpInterface:

  def __init__(self, axis, combine_fn):
  self.axis = axis
@@ -727,10 +811,12 @@ class ReduceScanOpIneterface:
  self.check_axis(arg.shape, self.axis)

  def to_tensor(self, ret, dtype):
+ np_dtype = _get_np_dtype(dtype)
  if hasattr(ret, "shape") and ret.shape:
- ret_type = tl.block_type(dtype, ret.shape)
+ ret = ret.astype(np_dtype)
+ ret_type = tl.block_type(dtype, list(ret.shape))
  else:
- ret = np.array([ret]).astype(_get_np_dtype(dtype))
+ ret = np.array([ret], dtype=np_dtype)
  ret_type = dtype
  return tl.core.tensor(TensorHandle(ret, dtype.scalar), ret_type)

@@ -744,7 +830,7 @@ class ReduceScanOpIneterface:
  raise NotImplementedError("apply_impl not implemented")


- class ReduceOps(ReduceScanOpIneterface):
+ class ReduceOps(ReduceScanOpInterface):

  def __init__(self, axis, combine_fn, keep_dims):
  super().__init__(axis, combine_fn)
@@ -840,7 +926,7 @@ class ReduceOps(ReduceScanOpIneterface):
  return self.generic_reduce(input)


- class ScanOps(ReduceScanOpIneterface):
+ class ScanOps(ReduceScanOpInterface):

  def __init__(self, axis, combine_fn, reverse):
  super().__init__(axis, combine_fn)
@@ -989,7 +1075,7 @@ def _patch_lang_core(lang):
  lang.static_assert = _new_static_assert
  lang.static_print = print
  lang.dtype.to_ir = _new_to_ir
- lang.multiple_of = partial(_set_attr, name="tt.divisiblity")
+ lang.multiple_of = partial(_set_attr, name="tt.divisibility")
  lang.max_contiguous = partial(_set_attr, name="tt.contiguity")
  lang.max_constancy = partial(_set_attr, name="tt.constancy")

@@ -997,7 +1083,7 @@ def _patch_lang_core(lang):


  def _patch_lang(fn):
- langs = [value for _, value in fn.__globals__.items() if value in [tl, tl.core]]
+ langs = [value for _, value in fn.__globals__.items() if inspect.ismodule(value) and value in [tl, tl.core]]
  assert len(langs) >= 1, "triton.language must be visible from within jit'd function"
  for lang in langs:
  _patch_builtin(lang, interpreter_builder)
@@ -1006,12 +1092,22 @@ def _patch_lang(fn):
  _patch_builtin(lang.math, interpreter_builder)
  _patch_lang_tensor(lang.tensor)
  _patch_lang_core(lang)
+ _patch_builtin(tl.core._experimental_tensor_descriptor_base, interpreter_builder)
+
+
+ def _tuple_create(arg, contents):
+ # NamedTuples and tuples have different construction semantics. NamedTuple
+ # has a constructor that takes individual arguments, while tuple takes an
+ # iterable. Both have type "tuple" making it difficult to distinguish
+ # between them, but only NamedTuple has "_fields" and apparently this is how
+ # everyone does the check.
+ return type(arg)(*contents) if hasattr(arg, "_fields") else type(arg)(contents)


  # TODO: wrap everything in triton tensors
  def _implicit_cvt(arg):
  if isinstance(arg, int):
- ty = tl.str_to_ty(triton.runtime.jit.JITFunction._type_of(triton.runtime.jit.JITFunction._key_of(arg)))
+ ty = tl.str_to_ty(triton.runtime.jit.mangle_type(arg))
  dtype = np.int32
  if -2**31 <= arg < 2**31:
  dtype = np.int32
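The `_fields` check in `_tuple_create` above is what lets the interpreter rebuild both plain tuples and NamedTuple arguments with the same helper. A standalone demonstration of that distinction (the Pair type and values are made up; the helper body is copied from the hunk):

    from collections import namedtuple

    Pair = namedtuple("Pair", ["a", "b"])

    def _tuple_create(arg, contents):
        # NamedTuple constructors take individual fields; plain tuple takes one iterable
        return type(arg)(*contents) if hasattr(arg, "_fields") else type(arg)(contents)

    print(_tuple_create(Pair(1, 2), [10, 20]))  # Pair(a=10, b=20)
    print(_tuple_create((1, 2), [10, 20]))      # (10, 20)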
@@ -1026,16 +1122,27 @@ def _implicit_cvt(arg):
  handle = TensorHandle(np.array([arg], dtype=dtype), ty)
  return tl.tensor(handle, ty)
  if hasattr(arg, "data_ptr"):
- ty = tl.str_to_ty(triton.runtime.jit.JITFunction._type_of(triton.runtime.jit.JITFunction._key_of(arg)))
+ ty = tl.str_to_ty(triton.runtime.jit.mangle_type(arg))
  handle = TensorHandle(np.array([arg.data_ptr()], dtype=np.uint64), ty)
  return tl.tensor(handle, ty)
+ elif isinstance(arg, tuple):
+ return _tuple_create(arg, map(_implicit_cvt, arg))
  return arg


  interpreter_builder = InterpreterBuilder()

- # These keywords are not supported by the interpreter
- RESERVED_KWS = ["num_warps", "num_stages", "num_ctas", "enable_fp_fusion", "grid", "maxnreg"]
+
+ def _unwrap_tensor(t):
+ if isinstance(t, triton.runtime.jit.TensorWrapper):
+ return t.base
+ return t
+
+
+ def _rewrap_tensor(t, original_tensor):
+ if isinstance(original_tensor, triton.runtime.jit.TensorWrapper):
+ return triton.runtime.jit.TensorWrapper(t, original_tensor.dtype)
+ return t


  class GridExecutor:
@@ -1050,37 +1157,64 @@ class GridExecutor:
  self.constexprs = [name for name in arg_names if __annotations__.get(name) == "constexpr"]

  def _init_args_hst(self, args_dev, kwargs):
- args_hst = []
- for arg in args_dev:
- if hasattr(arg, "data_ptr"):
- args_hst.append(arg.cpu())
- else:
- args_hst.append(arg)
+ storages = {}
+
+ def _to_cpu(arg):
+ if isinstance(arg, tuple):
+ return _tuple_create(arg, map(_to_cpu, arg))
+ elif not hasattr(arg, "data_ptr"):
+ return arg
+
+ unwrapped_arg = _unwrap_tensor(arg)
+ if unwrapped_arg.untyped_storage().data_ptr() not in storages:
+ storage = unwrapped_arg.untyped_storage()
+ storages[storage.data_ptr()] = storage.cpu()
+
+ storage = storages[unwrapped_arg.untyped_storage().data_ptr()]
+ cpu_arg = unwrapped_arg.new_empty(0, device='cpu')
+ cpu_arg.set_(storage, unwrapped_arg.storage_offset(), unwrapped_arg.size(), unwrapped_arg.stride())
+ cpu_arg = _rewrap_tensor(cpu_arg, original_tensor=arg)
+ return cpu_arg
+
+ args_hst = [_to_cpu(arg) for arg in args_dev]
+
  # Process keyword arguments
  kwargs_hst = {}
  for key, value in kwargs.items():
- if hasattr(value, "data_ptr"):
- kwargs_hst[key] = value.cpu()
- else:
- kwargs_hst[key] = value
+ kwargs_hst[key] = _to_cpu(value)
  return args_hst, kwargs_hst

  def _restore_args_dev(self, args_dev, args_hst, kwargs, kwargs_hst):
- for arg_dev, arg_hst in zip(args_dev, args_hst):
+ storages = {}
+
+ def _from_cpu(arg_dev, arg_hst):
  if hasattr(arg_dev, "data_ptr"):
- arg_dev.data.copy_(arg_hst.to(arg_dev.device).data)
+ # No need to rewrap because this just modifies internal
+ arg_dev, arg_hst = _unwrap_tensor(arg_dev), _unwrap_tensor(arg_hst)
+ storages[arg_dev.untyped_storage().data_ptr()] = (arg_dev.untyped_storage(), arg_hst.untyped_storage())
+ elif isinstance(arg_dev, tuple):
+ for (arg_dev, arg_hst) in zip(arg_dev, arg_hst):
+ _from_cpu(arg_dev, arg_hst)
+
+ for arg_dev, arg_hst in zip(args_dev, args_hst):
+ _from_cpu(arg_dev, arg_hst)

  # Restore keyword arguments
  for key, kwarg_dev in kwargs.items():
  kwarg_hst = kwargs_hst[key]
- if hasattr(kwarg_dev, "data_ptr"):
- kwarg_dev.data.copy_(kwarg_hst.to(kwarg_dev.device).data)
+ _from_cpu(kwarg_dev, kwarg_hst)
+
+ for (arg_dev, arg_hst) in storages.values():
+ arg_dev.copy_(arg_hst)

  def __call__(self, *args_dev, **kwargs):
- # removes reserved keywords from kwargs
- kwargs = {k: v for k, v in kwargs.items() if k not in RESERVED_KWS}
  if kwargs.pop("warmup", False):
  return
+ # Removes not used reserved keywords from kwargs
+ # Triton doesn't support keyword-only, variable positional or variable keyword arguments
+ # It's safe to inspect only positional or keyword arguments (i.e., argspec.args)
+ argspec = inspect.getfullargspec(self.fn)
+ kwargs = {k: v for k, v in kwargs.items() if k in argspec.args}
  # copy arguments to the host
  args_hst, kwargs_hst = self._init_args_hst(args_dev, kwargs)
  # remaps core language functions to interpreted ones
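In the __call__ hunk above, the hard-coded RESERVED_KWS list is replaced by filtering against the wrapped function's own argspec: any launch-only keyword the interpreted function does not declare is simply dropped. A standalone sketch of that filtering step (kernel_fn and the keyword values are hypothetical):

    import inspect

    def kernel_fn(x_ptr, n_elements, BLOCK_SIZE):
        ...

    launch_kwargs = {"n_elements": 1024, "BLOCK_SIZE": 128, "num_warps": 4, "grid": (1,)}
    argspec = inspect.getfullargspec(kernel_fn)
    kept = {k: v for k, v in launch_kwargs.items() if k in argspec.args}
    print(kept)  # {'n_elements': 1024, 'BLOCK_SIZE': 128} -- launch-only knobs dropped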