triton-windows 3.3.1.post21__cp311-cp311-win_amd64.whl → 3.4.0.post21__cp311-cp311-win_amd64.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of triton-windows might be problematic. Click here for more details.

Files changed (68)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +4 -1
  3. triton/_filecheck.py +87 -0
  4. triton/_internal_testing.py +26 -15
  5. triton/_utils.py +110 -21
  6. triton/backends/__init__.py +20 -23
  7. triton/backends/amd/__init__.py +0 -0
  8. triton/backends/amd/compiler.py +112 -78
  9. triton/backends/amd/driver.c +5 -2
  10. triton/backends/amd/driver.py +143 -46
  11. triton/backends/compiler.py +7 -21
  12. triton/backends/nvidia/bin/ptxas.exe +0 -0
  13. triton/backends/nvidia/compiler.py +94 -94
  14. triton/backends/nvidia/driver.c +90 -98
  15. triton/backends/nvidia/driver.py +296 -125
  16. triton/compiler/code_generator.py +212 -111
  17. triton/compiler/compiler.py +110 -25
  18. triton/experimental/__init__.py +0 -0
  19. triton/experimental/gluon/__init__.py +4 -0
  20. triton/experimental/gluon/_compiler.py +0 -0
  21. triton/experimental/gluon/_runtime.py +99 -0
  22. triton/experimental/gluon/language/__init__.py +18 -0
  23. triton/experimental/gluon/language/_core.py +312 -0
  24. triton/experimental/gluon/language/_layouts.py +230 -0
  25. triton/experimental/gluon/language/_math.py +12 -0
  26. triton/experimental/gluon/language/_semantic.py +287 -0
  27. triton/experimental/gluon/language/_standard.py +47 -0
  28. triton/experimental/gluon/language/nvidia/__init__.py +4 -0
  29. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +202 -0
  30. triton/experimental/gluon/language/nvidia/blackwell/tma.py +32 -0
  31. triton/experimental/gluon/language/nvidia/hopper/__init__.py +11 -0
  32. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +51 -0
  33. triton/experimental/gluon/language/nvidia/hopper/tma.py +96 -0
  34. triton/experimental/gluon/nvidia/__init__.py +4 -0
  35. triton/experimental/gluon/nvidia/blackwell.py +3 -0
  36. triton/experimental/gluon/nvidia/hopper.py +40 -0
  37. triton/knobs.py +481 -0
  38. triton/language/__init__.py +39 -14
  39. triton/language/core.py +794 -537
  40. triton/language/extra/cuda/__init__.py +10 -7
  41. triton/language/extra/cuda/gdc.py +42 -0
  42. triton/language/extra/cuda/libdevice.py +394 -394
  43. triton/language/extra/cuda/utils.py +21 -21
  44. triton/language/extra/hip/libdevice.py +113 -104
  45. triton/language/math.py +65 -66
  46. triton/language/random.py +12 -2
  47. triton/language/semantic.py +1706 -1770
  48. triton/language/standard.py +116 -51
  49. triton/runtime/autotuner.py +117 -59
  50. triton/runtime/build.py +73 -9
  51. triton/runtime/cache.py +18 -47
  52. triton/runtime/driver.py +32 -29
  53. triton/runtime/interpreter.py +72 -35
  54. triton/runtime/jit.py +146 -110
  55. triton/testing.py +16 -12
  56. triton/tools/disasm.py +3 -4
  57. triton/tools/tensor_descriptor.py +36 -0
  58. triton/windows_utils.py +47 -83
  59. {triton_windows-3.3.1.post21.dist-info → triton_windows-3.4.0.post21.dist-info}/METADATA +7 -2
  60. {triton_windows-3.3.1.post21.dist-info → triton_windows-3.4.0.post21.dist-info}/RECORD +64 -41
  61. triton_windows-3.4.0.post21.dist-info/entry_points.txt +3 -0
  62. triton_windows-3.4.0.post21.dist-info/licenses/LICENSE +23 -0
  63. triton_windows-3.4.0.post21.dist-info/top_level.txt +1 -0
  64. triton/language/_utils.py +0 -21
  65. triton/language/extra/cuda/_experimental_tma.py +0 -106
  66. triton/tools/experimental_descriptor.py +0 -32
  67. triton_windows-3.3.1.post21.dist-info/top_level.txt +0 -14
  68. {triton_windows-3.3.1.post21.dist-info → triton_windows-3.4.0.post21.dist-info}/WHEEL +0 -0
triton/_C/libtriton.pyd CHANGED
Binary file
triton/__init__.py CHANGED
@@ -1,5 +1,5 @@
1
1
  """isort:skip_file"""
2
- __version__ = '3.3.1'
2
+ __version__ = '3.4.0'
3
3
 
4
4
  # ---------------------------------------
5
5
  # Note: import order is significant here.
@@ -26,6 +26,8 @@ from . import language
26
26
  from . import testing
27
27
  from . import tools
28
28
 
29
+ must_use_result = language.core.must_use_result
30
+
29
31
  __all__ = [
30
32
  "autotune",
31
33
  "cdiv",
@@ -39,6 +41,7 @@ __all__ = [
39
41
  "KernelInterface",
40
42
  "language",
41
43
  "MockTensor",
44
+ "must_use_result",
42
45
  "next_power_of_2",
43
46
  "OutOfResources",
44
47
  "reinterpret",
triton/_filecheck.py ADDED
@@ -0,0 +1,87 @@
1
+ import os
2
+ import inspect
3
+ import subprocess
4
+ import tempfile
5
+
6
+ import triton
7
+ from triton.compiler import ASTSource, make_backend
8
+ from triton.backends.compiler import GPUTarget
9
+ from triton.experimental.gluon._runtime import GluonASTSource
10
+ from triton._C.libtriton import ir
11
+
12
+ # ===-----------------------------------------------------------------------===#
13
+ # filecheck_test
14
+ # ===-----------------------------------------------------------------------===#
15
+
16
+ # Stub target for testing the frontend.
17
+ stub_target = GPUTarget("cuda", 100, 32)
18
+ stub_backend = make_backend(stub_target)
19
+
20
+ triton_dir = os.path.dirname(__file__)
21
+ filecheck_path = os.path.join(triton_dir, "FileCheck")
22
+
23
+
24
+ class MatchError(ValueError):
25
+
26
+ def __init__(self, message, module_str):
27
+ super().__init__(message)
28
+ self.module_str = module_str
29
+
30
+ def __str__(self):
31
+ return f"{super().__str__()}\n{self.module_str}"
32
+
33
+
34
+ def run_filecheck(name, module_str, check_template):
35
+ with tempfile.TemporaryDirectory() as tempdir:
36
+ temp_module = os.path.join(tempdir, "module")
37
+ with open(temp_module, "w") as temp:
38
+ temp.write(module_str)
39
+
40
+ temp_expected = os.path.join(tempdir, "expected")
41
+ with open(temp_expected, "w") as temp:
42
+ temp.write(check_template)
43
+
44
+ try:
45
+ subprocess.check_output([filecheck_path, temp_expected, "--input-file", temp_module],
46
+ stderr=subprocess.STDOUT)
47
+ except subprocess.CalledProcessError as error:
48
+ decoded = error.output.decode('unicode_escape')
49
+ raise ValueError(decoded)
50
+
51
+
52
+ def run_parser(kernel_fn):
53
+ sigkeys = [x.name for x in kernel_fn.params]
54
+ sigvals = [f"arg{i}" for i in range(len(sigkeys))]
55
+ signature = {k: v for (k, v) in zip(sigkeys, sigvals)}
56
+ source_cls = GluonASTSource if kernel_fn.is_gluon() else ASTSource
57
+ src = source_cls(fn=kernel_fn, signature=signature)
58
+
59
+ context = ir.context()
60
+ ir.load_dialects(context)
61
+ stub_backend.load_dialects(context)
62
+
63
+ extra_options = src.parse_options()
64
+ options = stub_backend.parse_options(dict(**extra_options))
65
+ codegen_fns = stub_backend.get_codegen_implementation(options)
66
+ module_map = stub_backend.get_module_map()
67
+ module = src.make_ir(options, codegen_fns, module_map, context)
68
+ assert module.verify()
69
+ return module
70
+
71
+
72
+ def run_filecheck_test(kernel_fn):
73
+ assert isinstance(kernel_fn, triton.runtime.JITFunction)
74
+ check_template = inspect.getsource(kernel_fn.fn)
75
+ if check_template is None:
76
+ raise ValueError("kernel function must have a docstring with FileCheck template")
77
+ mlir_module = run_parser(kernel_fn)
78
+
79
+ run_filecheck("placeholder", mlir_module.str_nodebug(), check_template)
80
+
81
+
82
+ def filecheck_test(fn):
83
+
84
+ def test_fn():
85
+ run_filecheck_test(fn)
86
+
87
+ return test_fn
triton/_internal_testing.py CHANGED
@@ -4,7 +4,7 @@ import numpy as np
4
4
  import torch
5
5
  import triton
6
6
  import triton.language as tl
7
- from triton.backends.nvidia.compiler import _path_to_binary
7
+ from triton import knobs
8
8
  import pytest
9
9
 
10
10
  from numpy.random import RandomState
@@ -20,6 +20,7 @@ dtypes = integral_dtypes + float_dtypes
20
20
  dtypes_with_bfloat16 = dtypes + ['bfloat16']
21
21
  torch_float8_dtypes = ['float8_e4m3fn', 'float8_e5m2']
22
22
  torch_dtypes = ['bool'] + int_dtypes + ['uint8'] + float_dtypes + ['bfloat16']
23
+ tma_dtypes = sorted(set(dtypes_with_bfloat16) - {"int64", "uint64", "float64"})
23
24
 
24
25
 
25
26
  def is_interpreter():
@@ -46,29 +47,29 @@ def is_hip():
46
47
  return False if target is None else target.backend == "hip"
47
48
 
48
49
 
49
- def is_hip_mi200():
50
+ def is_hip_cdna2():
50
51
  target = get_current_target()
51
- if target is None or target.backend != 'hip':
52
- return False
53
- return target.arch == 'gfx90a'
52
+ return target is not None and target.backend == 'hip' and target.arch == 'gfx90a'
54
53
 
55
54
 
56
- def is_hip_mi300():
55
+ def is_hip_cdna3():
57
56
  target = get_current_target()
58
- if target is None or target.backend != 'hip':
59
- return False
60
- return target.arch in ('gfx940', 'gfx941', 'gfx942')
57
+ return target is not None and target.backend == 'hip' and target.arch == 'gfx942'
61
58
 
62
59
 
63
- def is_hip_mi350():
60
+ def is_hip_cdna4():
64
61
  target = get_current_target()
65
- if target is None or target.backend != 'hip':
66
- return False
67
- return target.arch in ('gfx950')
62
+ return target is not None and target.backend == 'hip' and target.arch == 'gfx950'
63
+
64
+
65
+ def is_hip_gfx12():
66
+ target = get_current_target()
67
+ print(target.arch)
68
+ return target is not None and target.backend == 'hip' and 'gfx12' in target.arch
68
69
 
69
70
 
70
71
  def is_hip_cdna():
71
- return is_hip_mi200() or is_hip_mi300() or is_hip_mi350()
72
+ return is_hip_cdna2() or is_hip_cdna3() or is_hip_cdna4()
72
73
 
73
74
 
74
75
  def is_xpu():
@@ -161,7 +162,7 @@ def supports_tma(byval_only=False):
161
162
  return True
162
163
  if not is_cuda():
163
164
  return False
164
- _, cuda_version = _path_to_binary("ptxas")
165
+ cuda_version = knobs.nvidia.ptxas.version
165
166
  min_cuda_version = (12, 0) if byval_only else (12, 3)
166
167
  cuda_version_tuple = tuple(map(int, cuda_version.split(".")))
167
168
  assert len(cuda_version_tuple) == 2, cuda_version_tuple
@@ -176,3 +177,13 @@ def tma_skip_msg(byval_only=False):
176
177
 
177
178
 
178
179
  requires_tma = pytest.mark.skipif(not supports_tma(), reason=tma_skip_msg())
180
+
181
+
182
+ def default_alloc_fn(size: int, align: int, _):
183
+ return torch.empty(size, dtype=torch.int8, device="cuda")
184
+
185
+
186
+ def unwrap_tensor(t: Union[torch.Tensor, triton.runtime.jit.TensorWrapper]) -> torch.Tensor:
187
+ if isinstance(t, triton.runtime.jit.TensorWrapper):
188
+ return t.base
189
+ return t
triton/_utils.py CHANGED
@@ -1,35 +1,124 @@
1
+ from __future__ import annotations
2
+
1
3
  from functools import reduce
4
+ from typing import Any, Callable, TYPE_CHECKING, Union, List, Dict
5
+
6
+ if TYPE_CHECKING:
7
+ from .language import core
8
+ IterableType = Union[list[Any], tuple[Any, ...], core.tuple, core.tuple_type]
9
+ ObjPath = tuple[int, ...]
2
10
 
11
+ TRITON_MAX_TENSOR_NUMEL = 1048576
3
12
 
4
- def get_iterable_path(iterable, path):
5
- return reduce(lambda a, idx: a[idx], path, iterable)
6
13
 
14
+ def get_iterable_path(iterable: IterableType, path: ObjPath) -> Any:
15
+ return reduce(lambda a, idx: a[idx], path, iterable) # type: ignore[index]
7
16
 
8
- def set_iterable_path(iterable, path, val):
17
+
18
+ def set_iterable_path(iterable: IterableType, path: tuple[int, ...], val: Any):
19
+ assert len(path) != 0
9
20
  prev = iterable if len(path) == 1 else get_iterable_path(iterable, path[:-1])
10
- prev[path[-1]] = val
21
+ prev[path[-1]] = val # type: ignore[index]
11
22
 
12
23
 
13
- def find_paths_if(iterable, pred):
24
+ def find_paths_if(iterable: Union[IterableType, Any], pred: Callable[[ObjPath, Any], bool]) -> list[ObjPath]:
14
25
  from .language import core
15
- is_iterable = lambda x: isinstance(x, (list, tuple, core.tuple, core.tuple_type))
16
- ret = dict()
26
+ is_iterable: Callable[[Any], bool] = lambda x: isinstance(x, (list, tuple, core.tuple, core.tuple_type))
27
+ # We need to use dict so that ordering is maintained, while set doesn't guarantee order
28
+ ret: dict[ObjPath, None] = {}
17
29
 
18
- def _impl(current, path):
19
- path = (path[0], ) if len(path) == 1 else tuple(path)
30
+ def _impl(path: tuple[int, ...], current: Any):
20
31
  if is_iterable(current):
21
32
  for idx, item in enumerate(current):
22
- _impl(item, path + (idx, ))
33
+ _impl((*path, idx), item)
23
34
  elif pred(path, current):
24
- if len(path) == 1:
25
- ret[(path[0], )] = None
26
- else:
27
- ret[tuple(path)] = None
28
-
29
- if is_iterable(iterable):
30
- _impl(iterable, [])
31
- elif pred(list(), iterable):
32
- ret = {tuple(): None}
33
- else:
34
- ret = dict()
35
+ ret[path] = None
36
+
37
+ _impl((), iterable)
38
+
35
39
  return list(ret.keys())
40
+
41
+
42
+ def is_power_of_two(x):
43
+ return (x & (x - 1)) == 0
44
+
45
+
46
+ def validate_block_shape(shape: List[int]):
47
+ numel = 1
48
+ for i, d in enumerate(shape):
49
+ if not isinstance(d, int):
50
+ raise TypeError(f"Shape element {i} must have type `constexpr[int]`, got `constexpr[{type(d)}]")
51
+ if not is_power_of_two(d):
52
+ raise ValueError(f"Shape element {i} must be a power of 2")
53
+ numel *= d
54
+
55
+ if numel > TRITON_MAX_TENSOR_NUMEL:
56
+ raise ValueError(f"numel ({numel}) exceeds triton maximum tensor numel ({TRITON_MAX_TENSOR_NUMEL})")
57
+ return numel
58
+
59
+
60
+ type_canonicalisation_dict = {
61
+ # we canonicalise all bools to be unsigned:
62
+ "bool": "u1",
63
+ "int1": "u1",
64
+ "uint1": "u1",
65
+ "i1": "u1",
66
+ # floating-point dtypes:
67
+ "float8e4nv": "fp8e4nv",
68
+ "float8e5": "fp8e5",
69
+ "float8e4b15": "fp8e4b15",
70
+ "float8_e4m3fn": "fp8e4nv",
71
+ "float8e4b8": "fp8e4b8",
72
+ "float8_e4m3fnuz": "fp8e4b8",
73
+ "float8_e5m2": "fp8e5",
74
+ "float8e5b16": "fp8e5b16",
75
+ "float8_e5m2fnuz": "fp8e5b16",
76
+ "half": "fp16",
77
+ "float16": "fp16",
78
+ "bfloat16": "bf16",
79
+ "float": "fp32",
80
+ "float32": "fp32",
81
+ "double": "fp64",
82
+ "float64": "fp64",
83
+ # signed integers:
84
+ "int8": "i8",
85
+ "int16": "i16",
86
+ "int": "i32",
87
+ "int32": "i32",
88
+ "int64": "i64",
89
+ # unsigned integers:
90
+ "uint8": "u8",
91
+ "uint16": "u16",
92
+ "uint32": "u32",
93
+ "uint64": "u64",
94
+ "void": "void",
95
+ }
96
+
97
+ for v in list(type_canonicalisation_dict.values()):
98
+ type_canonicalisation_dict[v] = v
99
+
100
+
101
+ def canonicalize_dtype(dtype):
102
+ dtype_str = str(dtype).split(".")[-1]
103
+ return type_canonicalisation_dict[dtype_str]
104
+
105
+
106
+ BITWIDTH_DICT: Dict[str, int] = {
107
+ **{f"u{n}": n
108
+ for n in (1, 8, 16, 32, 64)},
109
+ **{f"i{n}": n
110
+ for n in (1, 8, 16, 32, 64)},
111
+ **{f"fp{n}": n
112
+ for n in (16, 32, 64)},
113
+ **{f"fp8{suffix}": 8
114
+ for suffix in ("e4nv", "e4b15", "e4b8", "e5", "e5b16")},
115
+ "bf16": 16,
116
+ "void": 0,
117
+ }
118
+
119
+ for k, v in type_canonicalisation_dict.items():
120
+ BITWIDTH_DICT[k] = BITWIDTH_DICT[v]
121
+
122
+
123
+ def get_primitive_bitwidth(dtype: str) -> int:
124
+ return BITWIDTH_DICT[dtype]
triton/backends/__init__.py CHANGED
@@ -1,20 +1,22 @@
1
- import os
2
- import importlib.util
1
+ import importlib
3
2
  import inspect
3
+ import sys
4
4
  from dataclasses import dataclass
5
+ from typing import Type, TypeVar, Union
6
+ from types import ModuleType
5
7
  from .driver import DriverBase
6
8
  from .compiler import BaseBackend
7
9
 
10
+ if sys.version_info >= (3, 10):
11
+ from importlib.metadata import entry_points
12
+ else:
13
+ from importlib_metadata import entry_points
8
14
 
9
- def _load_module(name, path):
10
- spec = importlib.util.spec_from_file_location(name, path)
11
- module = importlib.util.module_from_spec(spec)
12
- spec.loader.exec_module(module)
13
- return module
15
+ T = TypeVar("T", bound=Union[BaseBackend, DriverBase])
14
16
 
15
17
 
16
- def _find_concrete_subclasses(module, base_class):
17
- ret = []
18
+ def _find_concrete_subclasses(module: ModuleType, base_class: Type[T]) -> Type[T]:
19
+ ret: list[Type[T]] = []
18
20
  for attr_name in dir(module):
19
21
  attr = getattr(module, attr_name)
20
22
  if isinstance(attr, type) and issubclass(attr, base_class) and not inspect.isabstract(attr):
@@ -28,23 +30,18 @@ def _find_concrete_subclasses(module, base_class):
28
30
 
29
31
  @dataclass(frozen=True)
30
32
  class Backend:
31
- compiler: BaseBackend = None
32
- driver: DriverBase = None
33
+ compiler: Type[BaseBackend]
34
+ driver: Type[DriverBase]
33
35
 
34
36
 
35
- def _discover_backends():
37
+ def _discover_backends() -> dict[str, Backend]:
36
38
  backends = dict()
37
- root = os.path.dirname(__file__)
38
- for name in os.listdir(root):
39
- if not os.path.isdir(os.path.join(root, name)):
40
- continue
41
- if name.startswith('__'):
42
- continue
43
- compiler = _load_module(name, os.path.join(root, name, 'compiler.py'))
44
- driver = _load_module(name, os.path.join(root, name, 'driver.py'))
45
- backends[name] = Backend(_find_concrete_subclasses(compiler, BaseBackend),
46
- _find_concrete_subclasses(driver, DriverBase))
39
+ for ep in entry_points().select(group="triton.backends"):
40
+ compiler = importlib.import_module(f"{ep.value}.compiler")
41
+ driver = importlib.import_module(f"{ep.value}.driver")
42
+ backends[ep.name] = Backend(_find_concrete_subclasses(compiler, BaseBackend), # type: ignore
43
+ _find_concrete_subclasses(driver, DriverBase)) # type: ignore
47
44
  return backends
48
45
 
49
46
 
50
- backends = _discover_backends()
47
+ backends: dict[str, Backend] = _discover_backends()
triton/backends/amd/__init__.py CHANGED
File without changes