triton-windows 3.2.0.post11__cp312-cp312-win_amd64.whl → 3.3.0a0.post11__cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of triton-windows might be problematic.

Files changed (68)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +3 -3
  3. triton/_internal_testing.py +59 -4
  4. triton/_utils.py +35 -0
  5. triton/backends/amd/compiler.py +121 -74
  6. triton/backends/amd/driver.py +77 -43
  7. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +28 -49
  8. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +35 -9
  9. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +761 -284
  10. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +9 -3
  11. triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +1391 -0
  12. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +3 -3
  13. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +44 -0
  14. triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +288 -0
  15. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +110 -14
  16. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +504 -103
  17. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +2 -1
  18. triton/backends/amd/include/hip/amd_detail/host_defines.h +4 -0
  19. triton/backends/amd/include/hip/hip_ext.h +4 -2
  20. triton/backends/amd/include/hip/hip_fp8.h +33 -0
  21. triton/backends/amd/include/hip/hip_runtime_api.h +375 -33
  22. triton/backends/amd/include/hip/hip_version.h +3 -3
  23. triton/backends/amd/include/hip/hiprtc.h +25 -25
  24. triton/backends/amd/include/hsa/amd_hsa_elf.h +40 -14
  25. triton/backends/amd/include/hsa/hsa.h +11 -2
  26. triton/backends/amd/include/hsa/hsa_api_trace.h +30 -17
  27. triton/backends/amd/include/hsa/hsa_api_trace_version.h +68 -0
  28. triton/backends/amd/include/hsa/hsa_ext_amd.h +83 -27
  29. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +46 -46
  30. triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +416 -0
  31. triton/backends/amd/include/roctracer/hip_ostream_ops.h +84 -4
  32. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +260 -0
  33. triton/backends/amd/include/roctracer/hsa_prof_str.h +51 -19
  34. triton/backends/amd/lib/asanrtl.bc +0 -0
  35. triton/backends/compiler.py +25 -225
  36. triton/backends/driver.py +7 -2
  37. triton/backends/nvidia/bin/ptxas.exe +0 -0
  38. triton/backends/nvidia/compiler.py +135 -90
  39. triton/backends/nvidia/driver.c +0 -1
  40. triton/backends/nvidia/driver.py +135 -49
  41. triton/backends/nvidia/include/cuda.h +2162 -241
  42. triton/backends/nvidia/lib/x64/cuda.lib +0 -0
  43. triton/compiler/__init__.py +2 -2
  44. triton/compiler/code_generator.py +334 -231
  45. triton/compiler/compiler.py +77 -66
  46. triton/language/__init__.py +22 -5
  47. triton/language/core.py +448 -74
  48. triton/language/extra/cuda/_experimental_tma.py +3 -5
  49. triton/language/math.py +1 -1
  50. triton/language/random.py +2 -1
  51. triton/language/semantic.py +206 -52
  52. triton/language/standard.py +35 -18
  53. triton/runtime/_allocation.py +32 -0
  54. triton/runtime/autotuner.py +27 -32
  55. triton/runtime/build.py +1 -48
  56. triton/runtime/cache.py +6 -6
  57. triton/runtime/errors.py +10 -0
  58. triton/runtime/interpreter.py +179 -45
  59. triton/runtime/jit.py +149 -190
  60. triton/testing.py +39 -11
  61. triton/tools/compile.py +27 -20
  62. triton/tools/{compile.c → extra/cuda/compile.c} +1 -0
  63. triton/tools/mxfp.py +301 -0
  64. {triton_windows-3.2.0.post11.dist-info → triton_windows-3.3.0a0.post11.dist-info}/METADATA +5 -2
  65. {triton_windows-3.2.0.post11.dist-info → triton_windows-3.3.0a0.post11.dist-info}/RECORD +68 -59
  66. {triton_windows-3.2.0.post11.dist-info → triton_windows-3.3.0a0.post11.dist-info}/top_level.txt +2 -0
  67. /triton/tools/{compile.h → extra/cuda/compile.h} +0 -0
  68. {triton_windows-3.2.0.post11.dist-info → triton_windows-3.3.0a0.post11.dist-info}/WHEEL +0 -0
@@ -3,7 +3,7 @@ import hashlib
 import json
 from .._C.libtriton import get_cache_invalidating_env_vars, ir
 from ..backends import backends
-from ..backends.compiler import GPUTarget, AttrsDescriptor
+from ..backends.compiler import GPUTarget
 from .. import __version__
 from ..runtime.autotuner import OutOfResources
 from ..runtime.cache import get_cache_manager, get_dump_manager, get_override_manager
@@ -15,6 +15,7 @@ from pathlib import Path
 import re
 import functools
 import os
+import sysconfig

 # - ^\s*tt\.func\s+ : match the start of the string, any leading whitespace, the keyword func,
 # and any following whitespace
@@ -24,19 +25,13 @@ import os
 # - (\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\)) : match a pair of parentheses enclosing
 # zero or more arguments separated by commas, and capture it as group 2 (the argument list)
 # - (attributes \{[\S\s]+\})? : optionally match attributes enclosed in braces and capture it as group 3
-mlir_prototype_pattern = r"^\s*tt\.func\s+(?:public\s+)?(@\w+)(\((?:%\w+: [\S\s]+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\))\s*(attributes \{[\S\s]+\})?\s+\{\s*$"
 ptx_prototype_pattern = r"\.(?:visible|extern)\s+\.(?:entry|func)\s+(\w+)\s*\(([^)]*)\)"
 prototype_pattern = {
-    "ttir": mlir_prototype_pattern,
-    "ttgir": mlir_prototype_pattern,
     "ptx": ptx_prototype_pattern,
 }

-mlir_arg_type_pattern = r'%\w+: ((?:[^,\s<)]+|<[^>]+>)+(?: {[^}]+})?),?'
 ptx_arg_type_pattern = r"\.param\s+\.(\w+)"
 arg_type_pattern = {
-    "ttir": mlir_arg_type_pattern,
-    "ttgir": mlir_arg_type_pattern,
     "ptx": ptx_arg_type_pattern,
 }

@@ -54,46 +49,32 @@ def convert_type_repr(x):
     return x


-def _get_num_warps_from_ir_str(src: str):
-    ttgir_num_warps_pattern = r'"triton_gpu.num-warps"\s?=\s?(\d+)\s?:'
-    # TODO(jlebar): Using a regex to get num-warps is a hack, and will break if
-    # e.g. someone has an instruction (not module) attribute named "num-warps".
-    num_warps_matches = re.findall(ttgir_num_warps_pattern, src)
-    assert len(num_warps_matches) == 1, "Expected exactly one match for num_warps"
-    num_warps = int(num_warps_matches[0])
-    return num_warps
-
-
 class ASTSource:

-    def __init__(self, fn, signature, constants=None, attrs=None) -> None:
+    def __init__(self, fn, signature, constexprs=None, attrs=None) -> None:
         self.fn = fn
         self.ext = "ttir"
         self.name = fn.__name__
         self.signature = signature
-        self.constants = constants
-        self.attrs = attrs
+        self.constants = dict()
+        if constexprs is not None:
+            for k, v in constexprs.items():
+                k = (fn.arg_names.index(k), ) if isinstance(k, str) else k
+                assert isinstance(k, tuple)
+                self.constants[k] = v
+        self.attrs = attrs or dict()
         if isinstance(self.signature, str):
             self.signature = {k: v.strip() for k, v in enumerate(self.signature.split(","))}
-        else:
-            for k in self.signature.keys():
-                if not isinstance(k, str):
-                    raise TypeError("Signature keys must be string")
-        if self.constants is None:
-            self.constants = {}
-        else:
-            for k in self.constants.keys():
-                if not isinstance(k, str):
-                    raise TypeError("Constants keys must be string")
-        if self.attrs is None:
-            self.attrs = AttrsDescriptor()
+        # else:
+        #     for k in self.signature.keys():
+        #         if not isinstance(k, str):
+        #             raise TypeError("Signature keys must be string")

     def hash(self):
         sorted_sig = [v for k, v in sorted(self.signature.items())]
-        # Note - we stringify the keys here to allow sorting to work for cases
-        # where constants have mixed int/str keys.
-        sorted_constants = sorted((str(k), v) for k, v in self.constants.items())
-        key = f"{self.fn.cache_key}-{self.attrs.hash()}-{sorted_sig}-{sorted_constants}"
+        get_key = lambda x: x.cache_key if hasattr(x, 'cache_key') else str(x)
+        constants_key = '-'.join([get_key(v) for k, v in sorted(self.constants.items())])
+        key = f"{self.fn.cache_key}-{str(self.attrs)}-{sorted_sig}-{constants_key}"
         return hashlib.sha256(key.encode("utf-8")).hexdigest()

     def make_ir(self, options, codegen_fns, module_map, context):
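
Note on the ASTSource hunk above: the constructor now takes `constexprs` (keyed by argument name or by index tuple) instead of `constants`, and `attrs` defaults to a plain dict. A minimal sketch of how the new signature might be used directly; the kernel, the signature strings, and the values below are illustrative, not taken from this package:

import triton
import triton.language as tl
from triton.compiler import ASTSource  # assumed to remain re-exported here

@triton.jit
def add_one(x_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    tl.store(x_ptr + offs, tl.load(x_ptr + offs, mask=mask) + 1, mask=mask)

# String keys in `constexprs` are rewritten to (arg_index,) tuples by __init__.
src = ASTSource(
    fn=add_one,
    signature={"x_ptr": "*fp32", "n": "i32", "BLOCK": "constexpr"},
    constexprs={"BLOCK": 128},
)
assert src.constants == {(2,): 128}   # "BLOCK" is argument index 2
# src can then be handed to triton.compile(src) as before.
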
@@ -106,28 +87,42 @@ class ASTSource:

 class IRSource:

-    def __init__(self, path):
+    def __init__(self, path, context, backend):
         self.path = path
         path = Path(path)
         self.ext = path.suffix[1:]
         self.src = path.read_text()
-        match = re.search(prototype_pattern[self.ext], self.src, re.MULTILINE)
-        self.name = match.group(1)
-        signature = match.group(2)
-        types = re.findall(arg_type_pattern[self.ext], signature)
-        self.signature = {k: convert_type_repr(ty) for k, ty in enumerate(types)}
+        ir.load_dialects(context)
+        backend.load_dialects(context)
+
+        # We don't have a easy-to-use PTX parser that we can use, so keep that regex for now.
+        # TODO - replace with a proper parser
+        if self.ext == "ptx":
+            match = re.search(prototype_pattern[self.ext], self.src, re.MULTILINE)
+            self.name = match.group(1)
+            signature = match.group(2)
+            types = re.findall(arg_type_pattern[self.ext], signature)
+            self.signature = {k: convert_type_repr(ty) for k, ty in enumerate(types)}
+        else:
+            self.module = ir.parse_mlir_module(self.path, context)
+            fn_name = self.module.get_entry_func_name()
+            self.name = "@" + fn_name
+            funcOp = self.module.get_function(fn_name)
+            func_ty = self.module.get_function_signature(funcOp)
+            self.signature = {k: ty for k, ty in enumerate(func_ty)}

     def hash(self):
         return hashlib.sha256(self.src.encode("utf-8")).hexdigest()

     def make_ir(self, options, codegen_fns, module_map, context):
-        module = ir.parse_mlir_module(self.path, context)
-        module.context = context
-        return module
+        self.module.context = context
+        return self.module

     def parse_options(self):
         if self.ext == "ttgir":
-            return {'num_warps': _get_num_warps_from_ir_str(self.src)}
+            num_warps = self.module.get_int_attr("ttg.num-warps")
+            assert num_warps is not None, "Unable to parse ttg.num-warps attribute"
+            return {'num_warps': num_warps}
         return dict()


@@ -151,11 +146,8 @@ def triton_key():

     # backend
     libtriton_hash = hashlib.sha256()
-    if os.name == "nt":
-        so_name = "libtriton.pyd"
-    else:
-        so_name = "libtriton.so"
-    with open(os.path.join(TRITON_PATH, f"_C/{so_name}"), "rb") as f:
+    ext = sysconfig.get_config_var("EXT_SUFFIX").split(".")[-1]
+    with open(os.path.join(TRITON_PATH, "_C", f"libtriton.{ext}"), "rb") as f:
         while True:
             chunk = f.read(1024**2)
             if not chunk:
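
The hunk above replaces the hard-coded libtriton.pyd / libtriton.so choice with the interpreter's own extension-module suffix. A stdlib-only illustration of why this works on both platforms:

import sysconfig

# EXT_SUFFIX is e.g. ".cp312-win_amd64.pyd" on this wheel's CPython 3.12 Windows build
# and ".cpython-312-x86_64-linux-gnu.so" on a typical Linux build, so the last
# dot-separated field is "pyd" or "so" respectively.
ext = sysconfig.get_config_var("EXT_SUFFIX").split(".")[-1]
print(f"libtriton.{ext}")
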
@@ -175,9 +167,9 @@ def parse(full_name, ext, context):
         module = ir.parse_mlir_module(full_name, context)
         module.context = context
         return module
-    if ext == "llir" or ext == "ptx":
+    if ext == "llir" or ext == "ptx" or ext == "amdgcn":
         return Path(full_name).read_text()
-    if ext == "cubin":
+    if ext == "cubin" or ext == "hsaco":
         return Path(full_name).read_bytes()


@@ -200,6 +192,7 @@ def filter_traceback(e: BaseException):
         "/triton/compiler/code_generator.py",
         "/ast.py",
     ]
+    BAD_FILES = [bad_file.replace("/", os.sep) for bad_file in BAD_FILES]

     tb = e.__traceback__
     frames = []
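
The added line above is a Windows-specific fix: traceback frame filenames use the native separator, so the hard-coded POSIX paths in BAD_FILES would never match on Windows. A stdlib-only illustration:

import os

bad_file = "/triton/compiler/code_generator.py"
print(bad_file.replace("/", os.sep))
# POSIX:   /triton/compiler/code_generator.py  (unchanged)
# Windows: \triton\compiler\code_generator.py  (now matches Windows-style frame paths)
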
@@ -227,7 +220,9 @@ def compile(src, target=None, options=None):
     # create backend
     if ir_source:
         assert isinstance(src, str), "source must be either AST or a filepath"
-        src = IRSource(src)
+        context = ir.context()
+        src = IRSource(src, context, backend)
+
     extra_options = src.parse_options()
     options = backend.parse_options(dict(options or dict(), **extra_options))
     # create cache manager
@@ -239,6 +234,7 @@ def compile(src, target=None, options=None):
     # core changes to make it easier to track kernels by hash.
     enable_override = os.environ.get("TRITON_KERNEL_OVERRIDE", "0") == "1"
     enable_ir_dump = os.environ.get("TRITON_KERNEL_DUMP", "0") == "1"
+    store_only_binary = os.environ.get("TRITON_STORE_BINARY_ONLY", "0") == "1"
     fn_override_manager = get_override_manager(src.hash()) if enable_override else None
     fn_dump_manager = get_dump_manager(src.hash()) if enable_ir_dump else None
     # Pre-truncate the file name here to avoid hitting the 255 character limit on common platforms.
@@ -252,7 +248,6 @@ def compile(src, target=None, options=None):
     always_compile = os.environ.get("TRITON_ALWAYS_COMPILE", "0") == "1"
     if not always_compile and metadata_path is not None:
         # cache hit!
-        metadata = json.loads(Path(metadata_path).read_text())
         return CompiledKernel(src, metadata_group, hash)
     # initialize metadata
     metadata = {
@@ -261,6 +256,7 @@ def compile(src, target=None, options=None):
         **options.__dict__,
         **env_vars,
     }
+    metadata["triton_version"] = __version__
    # run compilation pipeline and populate metadata
     stages = dict()
     backend.add_stages(stages, options)
@@ -268,10 +264,15 @@ def compile(src, target=None, options=None):
     # when the source is an IR file, don't apply the passes related to this stage. This makes it easier to write IR level tests.
     if ir_source:
         first_stage += 1
-    context = ir.context()
-    ir.load_dialects(context)
-    backend.load_dialects(context)
-    codegen_fns = backend.get_codegen_implementation()
+
+    # For IRSource, we have already grabbed the context + called both
+    # ir.load_dialects and backend.load_dialects.
+    if not isinstance(src, IRSource):
+        context = ir.context()
+        ir.load_dialects(context)
+        backend.load_dialects(context)
+
+    codegen_fns = backend.get_codegen_implementation(options)
     module_map = backend.get_module_map()
     try:
         module = src.make_ir(options, codegen_fns, module_map, context)
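
Taken together with the IRSource hunks earlier in this file, the MLIR context is now created before the IRSource is built, dialects are loaded exactly once, and for .ttgir inputs num_warps comes from the module's "ttg.num-warps" attribute instead of a regex over the source text. A hedged sketch of the user-facing path this affects; "kernel.ttgir" is a hypothetical file on disk:

import triton

# Passing a path string routes compilation through IRSource.
compiled = triton.compile("kernel.ttgir")
print(compiled.metadata.num_warps)  # option value recorded in the cached metadata
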
@@ -285,7 +286,9 @@ def compile(src, target=None, options=None):
         if (fn_override_manager is not None and (full_name := fn_override_manager.get_file(ir_filename)) is not None):
             print(f"\nOverriding kernel with file {full_name}")
             next_module = parse(full_name, ext, context)
-        metadata_group[ir_filename] = fn_cache_manager.put(next_module, ir_filename)
+        # If TRITON_STORE_BINARY_ONLY is 1, only store cubin/hsaco/json
+        if (not store_only_binary) or (ext in ("cubin", "hsaco", "json")):
+            metadata_group[ir_filename] = fn_cache_manager.put(next_module, ir_filename)
         if fn_dump_manager is not None:
             fn_dump_manager.put(next_module, ir_filename)
         # use an env variable to parse ir from file
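
The new TRITON_STORE_BINARY_ONLY flag gates what lands in the on-disk cache: intermediate stages (such as ttir, ttgir, llir, ptx, amdgcn) are skipped and only the final binary plus the metadata JSON are stored. A hedged usage sketch:

import os

# Must be set before the first kernel compilation in the process.
os.environ["TRITON_STORE_BINARY_ONLY"] = "1"

# Compile kernels as usual; each kernel's cache directory should then contain
# only the cubin/hsaco entry and the metadata .json entry.
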
@@ -302,7 +305,13 @@ def compile(src, target=None, options=None):
     # This is needed to safely finalize threads pool inside context: if current process forks before
     # python GC deletes context object, thread pool in child process will be invalid, which could
     # lead to child crash or hang.
-    context.disable_multithreading()
+    #
+    # However disabling multithreading causes the code to hang if the ASAN pass is enabled
+    # this is likely due to the llvm-symbolizer forking a process
+    # TODO: Reconcile the difference here between the ASAN and non-ASAN path with enabling
+    # multithreading in the MLIR context
+    if not os.environ.get("TRITON_ENABLE_ASAN", "0") == "1":
+        context.disable_multithreading()
     # return handle to compiled kernel
     return CompiledKernel(src, metadata_group, hash)

@@ -390,6 +399,11 @@ class CompiledKernel:
         max_shared = driver.active.utils.get_device_properties(device)["max_shared_mem"]
         if self.metadata.shared > max_shared:
             raise OutOfResources(self.metadata.shared, max_shared, "shared memory")
+        if hasattr(self.metadata, "tmem_size") and self.metadata.tmem_size is not None:
+            # Use blackwell max tmem size for now, this should be moved in device properties
+            max_tmem_size = 512  # tmem size in number of columns
+            if self.metadata.tmem_size > max_tmem_size:
+                raise OutOfResources(self.metadata.tmem_size, max_tmem_size, "tensor memory")
         # TODO: n_regs, n_spills should be metadata generated when calling `ptxas`
         self.module, self.function, self.n_regs, self.n_spills = driver.active.utils.load_binary(
             self.name, self.kernel, self.metadata.shared, device)
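
The added block mirrors the existing shared-memory check: kernels whose metadata report more than 512 tensor-memory columns (the Blackwell limit hard-coded here) now fail with the same OutOfResources exception. An illustrative construction of that error, using positional arguments exactly as the check does; the numbers are made up:

from triton.runtime.autotuner import OutOfResources  # same import this file uses

err = OutOfResources(600, 512, "tensor memory")  # required, limit, resource name
print(err)
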
@@ -408,11 +422,8 @@ class CompiledKernel:
         arg_dict = {}
         arg_idx = 0
         for i, arg_name in enumerate(self.src.fn.arg_names):
-            if i in self.src.fn.constexprs:
-                arg_dict[arg_name] = self.src.constants[arg_name]
-            else:
-                arg_dict[arg_name] = args[arg_idx]
-                arg_idx += 1
+            arg_dict[arg_name] = args[arg_idx]
+            arg_idx += 1
         ret.add(self.src.fn.launch_metadata, (grid, self.metadata, arg_dict))
         return ret

@@ -28,6 +28,9 @@ from .core import (
     TRITON_MAX_TENSOR_NUMEL,
     _experimental_descriptor_load,
     _experimental_descriptor_store,
+    _experimental_make_tensor_descriptor,
+    _experimental_reinterpret_tensor_descriptor,
+    _experimental_tensor_descriptor,
     add,
     advance,
     arange,
@@ -66,7 +69,7 @@ from .core import (
     float8e5,
     float8e5b16,
     full,
-    function_type,
+    gather,
     histogram,
     inline_asm_elementwise,
     int1,
@@ -91,6 +94,7 @@ from .core import (
     range,
     reduce,
     reshape,
+    slice,
     split,
     static_assert,
     static_print,
@@ -98,6 +102,8 @@ from .core import (
     store,
     tensor,
     trans,
+    tuple,
+    tuple_type,
     uint16,
     uint32,
     uint64,
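
Among the export changes in these import-list hunks, `function_type` is dropped and `gather` is now re-exported from triton.language. A minimal kernel sketch using it; the pointer names and block size are illustrative, and the indices are assumed to be in-range integers:

import triton
import triton.language as tl

@triton.jit
def gather_block(src_ptr, idx_ptr, out_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    src = tl.load(src_ptr + offs)   # BLOCK source values
    idx = tl.load(idx_ptr + offs)   # BLOCK integer indices into src
    out = tl.gather(src, idx, 0)    # gather along axis 0
    tl.store(out_ptr + offs, out)
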
@@ -126,6 +132,9 @@ __all__ = [
     "TRITON_MAX_TENSOR_NUMEL",
     "_experimental_descriptor_load",
     "_experimental_descriptor_store",
+    "_experimental_make_tensor_descriptor",
+    "_experimental_reinterpret_tensor_descriptor",
+    "_experimental_tensor_descriptor",
     "abs",
     "add",
     "advance",
@@ -146,7 +155,6 @@ __all__ = [
     "block_type",
     "broadcast",
     "broadcast_to",
-    "builtin",
     "cat",
     "cast",
     "cdiv",
@@ -182,7 +190,7 @@ __all__ = [
     "floor",
     "fma",
     "full",
-    "function_type",
+    "gather",
     "histogram",
     "inline_asm_elementwise",
     "interleave",
@@ -191,7 +199,6 @@ __all__ = [
     "int32",
     "int64",
     "int8",
-    "ir",
     "join",
     "load",
     "log",
@@ -225,6 +232,7 @@ __all__ = [
     "reduce",
     "reshape",
     "rsqrt",
+    "slice",
     "sigmoid",
     "sin",
     "softmax",
@@ -240,7 +248,7 @@ __all__ = [
     "swizzle2d",
     "tensor",
     "trans",
-    "triton",
+    "tuple",
     "uint16",
     "uint32",
     "uint64",
@@ -257,6 +265,12 @@ __all__ = [


 def str_to_ty(name):
+    from builtins import tuple
+
+    if isinstance(name, tuple):
+        fields = type(name).__dict__.get("_fields", None)
+        return tuple_type([str_to_ty(x) for x in name], fields)
+
     if name[0] == "*":
         name = name[1:]
         const = False
@@ -269,6 +283,9 @@ def str_to_ty(name):
     if name == "nvTmaDesc":
         return nv_tma_desc_type()

+    if name == "constexpr":
+        return constexpr
+
     tys = {
         "fp8e4nv": float8e4nv,
         "fp8e4b8": float8e4b8,