triton-windows 3.2.0.post11__cp312-cp312-win_amd64.whl → 3.3.0a0.post11__cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of triton-windows might be problematic.

Files changed (68)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +3 -3
  3. triton/_internal_testing.py +59 -4
  4. triton/_utils.py +35 -0
  5. triton/backends/amd/compiler.py +121 -74
  6. triton/backends/amd/driver.py +77 -43
  7. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +28 -49
  8. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +35 -9
  9. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +761 -284
  10. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +9 -3
  11. triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +1391 -0
  12. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +3 -3
  13. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +44 -0
  14. triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +288 -0
  15. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +110 -14
  16. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +504 -103
  17. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +2 -1
  18. triton/backends/amd/include/hip/amd_detail/host_defines.h +4 -0
  19. triton/backends/amd/include/hip/hip_ext.h +4 -2
  20. triton/backends/amd/include/hip/hip_fp8.h +33 -0
  21. triton/backends/amd/include/hip/hip_runtime_api.h +375 -33
  22. triton/backends/amd/include/hip/hip_version.h +3 -3
  23. triton/backends/amd/include/hip/hiprtc.h +25 -25
  24. triton/backends/amd/include/hsa/amd_hsa_elf.h +40 -14
  25. triton/backends/amd/include/hsa/hsa.h +11 -2
  26. triton/backends/amd/include/hsa/hsa_api_trace.h +30 -17
  27. triton/backends/amd/include/hsa/hsa_api_trace_version.h +68 -0
  28. triton/backends/amd/include/hsa/hsa_ext_amd.h +83 -27
  29. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +46 -46
  30. triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +416 -0
  31. triton/backends/amd/include/roctracer/hip_ostream_ops.h +84 -4
  32. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +260 -0
  33. triton/backends/amd/include/roctracer/hsa_prof_str.h +51 -19
  34. triton/backends/amd/lib/asanrtl.bc +0 -0
  35. triton/backends/compiler.py +25 -225
  36. triton/backends/driver.py +7 -2
  37. triton/backends/nvidia/bin/ptxas.exe +0 -0
  38. triton/backends/nvidia/compiler.py +135 -90
  39. triton/backends/nvidia/driver.c +0 -1
  40. triton/backends/nvidia/driver.py +135 -49
  41. triton/backends/nvidia/include/cuda.h +2162 -241
  42. triton/backends/nvidia/lib/x64/cuda.lib +0 -0
  43. triton/compiler/__init__.py +2 -2
  44. triton/compiler/code_generator.py +334 -231
  45. triton/compiler/compiler.py +77 -66
  46. triton/language/__init__.py +22 -5
  47. triton/language/core.py +448 -74
  48. triton/language/extra/cuda/_experimental_tma.py +3 -5
  49. triton/language/math.py +1 -1
  50. triton/language/random.py +2 -1
  51. triton/language/semantic.py +206 -52
  52. triton/language/standard.py +35 -18
  53. triton/runtime/_allocation.py +32 -0
  54. triton/runtime/autotuner.py +27 -32
  55. triton/runtime/build.py +1 -48
  56. triton/runtime/cache.py +6 -6
  57. triton/runtime/errors.py +10 -0
  58. triton/runtime/interpreter.py +179 -45
  59. triton/runtime/jit.py +149 -190
  60. triton/testing.py +39 -11
  61. triton/tools/compile.py +27 -20
  62. triton/tools/{compile.c → extra/cuda/compile.c} +1 -0
  63. triton/tools/mxfp.py +301 -0
  64. {triton_windows-3.2.0.post11.dist-info → triton_windows-3.3.0a0.post11.dist-info}/METADATA +5 -2
  65. {triton_windows-3.2.0.post11.dist-info → triton_windows-3.3.0a0.post11.dist-info}/RECORD +68 -59
  66. {triton_windows-3.2.0.post11.dist-info → triton_windows-3.3.0a0.post11.dist-info}/top_level.txt +2 -0
  67. /triton/tools/{compile.h → extra/cuda/compile.h} +0 -0
  68. {triton_windows-3.2.0.post11.dist-info → triton_windows-3.3.0a0.post11.dist-info}/WHEEL +0 -0

triton/backends/nvidia/driver.py +135 -49

@@ -1,11 +1,13 @@
 import functools
 import os
+import sysconfig
 import hashlib
 import subprocess
 import tempfile
 from pathlib import Path
 from triton.runtime.build import _build
 from triton.runtime.cache import get_cache_manager
+from triton.runtime import _allocation
 from triton.backends.compiler import GPUTarget
 from triton.backends.driver import GPUDriver
 
@@ -53,14 +55,17 @@ def library_dirs():
     return [libdevice_dir, *libcuda_dirs()]
 
 
+@functools.lru_cache()
+def platform_key():
+    from platform import machine, system, architecture
+    return ",".join([machine(), system(), *architecture()])
+
+
 def compile_module_from_src(src, name):
-    key = hashlib.sha256(src.encode("utf-8")).hexdigest()
+    key = hashlib.sha256((src + platform_key()).encode("utf-8")).hexdigest()
     cache = get_cache_manager(key)
-    if os.name == "nt":
-        so_name = f"{name}.pyd"
-    else:
-        so_name = f"{name}.so"
-    cache_path = cache.get_file(so_name)
+    ext = sysconfig.get_config_var("EXT_SUFFIX").split(".")[-1]
+    cache_path = cache.get_file(f"{name}.{ext}")
     if cache_path is None:
         with tempfile.TemporaryDirectory() as tmpdir:
             src_path = os.path.join(tmpdir, f"{name}.c")
@@ -68,7 +73,7 @@ def compile_module_from_src(src, name):
             f.write(src)
         so = _build(name, src_path, tmpdir, library_dirs(), include_dir, libraries)
         with open(so, "rb") as f:
-            cache_path = cache.put(f.read(), so_name, binary=True)
+            cache_path = cache.put(f.read(), f"{name}.{ext}", binary=True)
     import importlib.util
     spec = importlib.util.spec_from_file_location(name, cache_path)
     mod = importlib.util.module_from_spec(spec)
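
Context for the hunk above: the compiled launcher module is cached under a key that now also hashes the platform, and the cached artifact's extension is taken from the running interpreter instead of an os.name branch. A minimal sketch of what those two new inputs evaluate to, using only standard-library calls (the printed strings are example outputs, not guarantees):

    # Illustrative: the new cache-key inputs on a given interpreter.
    import sysconfig
    from platform import machine, system, architecture

    # EXT_SUFFIX ends in ".pyd" on Windows CPython (e.g. ".cp312-win_amd64.pyd")
    # and in ".so" on Linux/macOS, so the last dot-separated component picks the
    # right extension without a per-OS branch.
    ext = sysconfig.get_config_var("EXT_SUFFIX").split(".")[-1]

    # The platform string folded into the sha256, e.g. "AMD64,Windows,64bit,WindowsPE".
    key_suffix = ",".join([machine(), system(), *architecture()])
    print(ext, key_suffix)
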
@@ -126,22 +131,32 @@ def ty_to_cpp(ty):
     }[ty]
 
 
-def make_launcher(constants, signature, ids):
-    # Record the end of regular arguments;
-    # subsequent arguments are architecture-specific descriptors, such as tensor descriptors for CUDA.
-    arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items())
+def make_launcher(constants, signature):
+
+    def _serialize_signature(sig):
+        if isinstance(sig, tuple):
+            return ','.join(map(_serialize_signature, sig))
+        return sig
 
     def _extracted_type(ty):
+        if isinstance(ty, tuple):
+            val = ','.join(map(_extracted_type, ty))
+            return f"[{val}]"
         if ty[0] == '*':
             return "PyObject*"
-        if ty == "nvTmaDesc":
+        if ty in ("constexpr", "nvTmaDesc"):
             return "PyObject*"
-
         return ty_to_cpp(ty)
 
     def format_of(ty):
+        if isinstance(ty, tuple):
+            val = ''.join(map(format_of, ty))
+            return f"({val})"
+        if ty[0] == '*':
+            return "O"
+        if ty in ("constexpr", "nvTmaDesc"):
+            return "O"
         return {
-            "PyObject*": "O",
             "float": "f",
             "double": "d",
             "long": "l",
@@ -153,12 +168,17 @@ def make_launcher(constants, signature, ids):
             "uint16_t": "H",
             "uint32_t": "I",
             "uint64_t": "K",
-        }[ty]
+        }[ty_to_cpp(ty)]
 
-    args_format = ''.join([format_of(_extracted_type(ty)) for ty in signature.values()])
-    format = "iiiKKOOOO" + args_format
+    args_format = ''.join([format_of(ty) for ty in signature.values()])
+    format = "iiiKKpOOOOO" + args_format
+    signature = ','.join(map(_serialize_signature, signature.values()))
+    signature = list(filter(bool, signature.split(',')))
+    signature = {i: s for i, s in enumerate(signature)}
     args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''
-
+    # Record the end of regular arguments;
+    # subsequent arguments are architecture-specific descriptors, such as tensor descriptors for CUDA.
+    arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items() if ty != "constexpr")
     internal_args_list = []
     for i, ty in signature.items():
         if ty[0] == "*":
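
The new _serialize_signature/_extracted_type handling is what lets make_launcher accept tuple-valued (nested) argument types and constexpr entries: tuples are flattened into consecutive positional slots before the C code is generated, and the PyArg_ParseTuple format gains a 'p' (bool) for launch_cooperative_grid plus one extra 'O' for the global scratch object. A small, self-contained illustration of the flattening step on a made-up signature (the type strings are the usual Triton spellings; the dict itself is invented for the example):

    # Toy signature: arg 0 is a pointer, arg 1 a (pointer, i32) pair, arg 2 a constexpr.
    def _serialize_signature(sig):
        if isinstance(sig, tuple):
            return ','.join(map(_serialize_signature, sig))
        return sig

    signature = {0: "*fp32", 1: ("*fp32", "i32"), 2: "constexpr"}
    flat = ','.join(map(_serialize_signature, signature.values()))
    flat = {i: s for i, s in enumerate(filter(bool, flat.split(',')))}
    print(flat)  # {0: '*fp32', 1: '*fp32', 2: 'i32', 3: 'constexpr'}
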
@@ -166,16 +186,23 @@ def make_launcher(constants, signature, ids):
         elif ty == "nvTmaDesc":
             # Note: we have to dereference the pointer
             internal_args_list.append(f"*tma_ptr{i}")
-        else:
+        elif ty != "constexpr":
             internal_args_list.append(f"_arg{i}")
+    params = range(len(signature))
 
     # generate glue code
-    params = [i for i in signature.keys() if i not in constants]
-    if params:
-        params_decl = ", ".join(f"&arg{i}" for i in params)
-        params_decl = f"void *params[] = {{ {params_decl} }};"
-    else:
-        params_decl = "void **params = NULL;"
+    newline = '\n '
+    ptr_decls = [
+        f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;"
+        for i, ty in signature.items()
+        if ty[0] == "*"
+    ]
+    tma_decls = [
+        f"CUtensorMap* tma_ptr{i} = getTmaDesc(_arg{i}); if (!tma_ptr{i}) return NULL;" for i, ty in signature.items()
+        if ty == "nvTmaDesc"
+    ]
+    params = [f"&arg{i}" for i, ty in signature.items() if ty != "constexpr"]
+    params.append("&global_scratch")
    src = f"""
#include \"cuda.h\"
#include <stdbool.h>
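
Given a flattened signature, the lists built above become the argument plumbing of the generated C: constexpr entries never reach the kernel parameter array, and &global_scratch is always appended as the final parameter. Hand-evaluating the same comprehensions on an invented signature:

    # Illustrative expansion only; mirrors the comprehensions in make_launcher.
    signature = {0: "*fp32", 1: "i32", 2: "constexpr"}

    params = [f"&arg{i}" for i, ty in signature.items() if ty != "constexpr"]
    params.append("&global_scratch")
    print(f"void *params[] = {{ {', '.join(params)} }};")
    # -> void *params[] = { &arg0, &arg1, &global_scratch };

    ptr_decls = [
        f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;"
        for i, ty in signature.items()
        if ty[0] == "*"
    ]
    print(ptr_decls)  # one guarded getPointer() line per pointer argument
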
@@ -248,19 +275,50 @@ static cuLaunchKernelEx_t getLaunchKernelExHandle() {{
 }}
 #endif
 
-static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, CUstream stream, CUfunction function{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
-  {params_decl}
+static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, CUstream stream, CUfunction function, CUdeviceptr global_scratch{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
+  void *params[] = {{ {', '.join(params)} }};
   if (gridX*gridY*gridZ > 0) {{
-    if (num_ctas == 1) {{
+    if ((num_ctas == 1) && (0 == launch_cooperative_grid)) {{
       CUDA_CHECK(cuLaunchKernel(function, gridX, gridY, gridZ, 32*num_warps, 1, 1, shared_memory, stream, params, 0));
+    }} else if ((num_ctas == 1) && (0 != launch_cooperative_grid)) {{
+      CUlaunchAttribute launchAttr[1];
+      CUlaunchAttribute coopAttr = {{ .id = CU_LAUNCH_ATTRIBUTE_COOPERATIVE, .value = 1}};
+      launchAttr[0] = coopAttr;
+
+      CUlaunchConfig config;
+      config.gridDimX = gridX;
+      config.gridDimY = gridY;
+      config.gridDimZ = gridZ;
+      config.blockDimX = 32 * num_warps;
+      config.blockDimY = 1;
+      config.blockDimZ = 1;
+      config.sharedMemBytes = shared_memory;
+      config.hStream = stream;
+      config.attrs = launchAttr;
+      config.numAttrs = 1;
+
+      static cuLaunchKernelEx_t cuLaunchKernelExHandle = NULL;
+      if (cuLaunchKernelExHandle == NULL) {{
+        cuLaunchKernelExHandle = getLaunchKernelExHandle();
+      }}
+      CUDA_CHECK(cuLaunchKernelExHandle(&config, function, params, 0));
+
     }} else {{
-      CUlaunchAttribute launchAttr[2];
+      CUlaunchAttribute launchAttr[3];
       launchAttr[0].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
       launchAttr[0].value.clusterDim.x = clusterDimX;
       launchAttr[0].value.clusterDim.y = clusterDimY;
       launchAttr[0].value.clusterDim.z = clusterDimZ;
       launchAttr[1].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE;
       launchAttr[1].value.clusterSchedulingPolicyPreference = CU_CLUSTER_SCHEDULING_POLICY_SPREAD;
+
+      unsigned numAttrs = 2;
+      if (0 != launch_cooperative_grid) {{
+        CUlaunchAttribute coopAttr = {{ .id = CU_LAUNCH_ATTRIBUTE_COOPERATIVE, .value = 1}};
+        launchAttr[2] = coopAttr;
+        numAttrs = 3;
+      }}
+
       CUlaunchConfig config;
       config.gridDimX = gridX * clusterDimX;
       config.gridDimY = gridY * clusterDimY;
@@ -271,7 +329,7 @@ static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas
       config.sharedMemBytes = shared_memory;
       config.hStream = stream;
       config.attrs = launchAttr;
-      config.numAttrs = 2;
+      config.numAttrs = numAttrs;
       static cuLaunchKernelEx_t cuLaunchKernelExHandle = NULL;
       if (cuLaunchKernelExHandle == NULL) {{
         cuLaunchKernelExHandle = getLaunchKernelExHandle();
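
Both new paths above go through cuLaunchKernelEx with CU_LAUNCH_ATTRIBUTE_COOPERATIVE set, the driver-API counterpart of cudaLaunchCooperativeKernel, which is what permits grid-wide synchronization between blocks. The flag reaches this launcher through the new launch_cooperative_grid metadata field; assuming it is surfaced to kernel authors as a launch-time keyword of the same name (an assumption based on the field name, not something this diff shows), usage would look roughly like:

    # Hypothetical usage sketch; the kwarg name mirrors the metadata field above.
    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def add_one(x_ptr, n, BLOCK: tl.constexpr):
        offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
        mask = offs < n
        tl.store(x_ptr + offs, tl.load(x_ptr + offs, mask=mask) + 1, mask=mask)

    x = torch.zeros(4096, device="cuda")
    grid = (triton.cdiv(x.numel(), 1024), )
    # Request a cooperative grid so programs may synchronize device-wide.
    add_one[grid](x, x.numel(), BLOCK=1024, launch_cooperative_grid=True)
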
@@ -396,14 +454,17 @@ static PyObject* launch(PyObject* self, PyObject* args) {{
   int gridX, gridY, gridZ;
   uint64_t _stream;
   uint64_t _function;
+  int launch_cooperative_grid;
   PyObject *launch_enter_hook = NULL;
   PyObject *launch_exit_hook = NULL;
   PyObject *kernel_metadata = NULL;
   PyObject *launch_metadata = NULL;
-  {' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])}
-  if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &_stream, &_function,
+  PyObject *global_scratch_obj = NULL;
+  {newline.join([f"{_extracted_type(ty)} _arg{i};" for i, ty in signature.items()])}
+  if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ,
+                       &_stream, &_function, &launch_cooperative_grid, &global_scratch_obj,
                        &kernel_metadata, &launch_metadata,
-                       &launch_enter_hook, &launch_exit_hook {args_list})) {{
+                       &launch_enter_hook, &launch_exit_hook{args_list})) {{
     return NULL;
   }}
 
@@ -422,11 +483,20 @@ static PyObject* launch(PyObject* self, PyObject* args) {{
     return NULL;
   }}
 
+  CUdeviceptr global_scratch = 0;
+  if (global_scratch_obj != Py_None) {{
+    DevicePtrInfo global_scratch_info = getPointer(global_scratch_obj, -1);
+    if (!global_scratch_info.valid) {{
+      return NULL;
+    }}
+    global_scratch = global_scratch_info.dev_ptr;
+  }}
+
   // raise exception asap
-  {"".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])};
-  {"".join([f"CUtensorMap* tma_ptr{i} = getTmaDesc(_arg{i}); if (!tma_ptr{i}) return NULL;" if ty == "nvTmaDesc" else "" for i, ty in signature.items()])};
+  {newline.join(ptr_decls)}
+  {newline.join(tma_decls)}
   Py_BEGIN_ALLOW_THREADS;
-  _launch(gridX, gridY, gridZ, num_warps, num_ctas, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (CUstream)_stream, (CUfunction)_function{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''});
+  _launch(gridX, gridY, gridZ, num_warps, num_ctas, launch_cooperative_grid, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (CUstream)_stream, (CUfunction)_function, global_scratch{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''});
   Py_END_ALLOW_THREADS;
   if (PyErr_Occurred()) {{
     return NULL;
@@ -441,9 +511,7 @@ static PyObject* launch(PyObject* self, PyObject* args) {{
 
   }}
 
-  // return None
-  Py_INCREF(Py_None);
-  return Py_None;
+  Py_RETURN_NONE;
 }}
 
 static PyMethodDef ModuleMethods[] = {{
@@ -474,17 +542,25 @@ PyMODINIT_FUNC PyInit___triton_launcher(void) {{
 class CudaLauncher(object):
 
     def __init__(self, src, metadata):
-        ids = {"ids_of_const_exprs": src.fn.constexprs if hasattr(src, "fn") else tuple()}
         constants = src.constants if hasattr(src, "constants") else dict()
-        cst_key = lambda i: src.fn.arg_names.index(i) if isinstance(i, str) else i
-        constants = {cst_key(key): value for key, value in constants.items()}
-        signature = {cst_key(key): value for key, value in src.signature.items()}
-        src = make_launcher(constants, signature, ids)
+        arg_idx = lambda x: (src.fn.arg_names.index(x), ) if isinstance(x, str) else x
+        constants = {arg_idx(idx): value for idx, value in constants.items()}
+        signature = {idx: value for idx, value in src.signature.items()}
+        src = make_launcher(constants, signature)
         mod = compile_module_from_src(src, "__triton_launcher")
         self.launch = mod.launch
-
-    def __call__(self, *args, **kwargs):
-        self.launch(*args, **kwargs)
+        self.global_scratch_size = metadata.global_scratch_size
+        self.global_scratch_align = metadata.global_scratch_align
+        self.launch_cooperative_grid = metadata.launch_cooperative_grid
+
+    def __call__(self, gridX, gridY, gridZ, stream, function, *args):
+        if self.global_scratch_size > 0:
+            grid_size = gridX * gridY * gridZ
+            alloc_size = grid_size * self.global_scratch_size
+            global_scratch = _allocation._allocator(alloc_size, self.global_scratch_align, stream)
+        else:
+            global_scratch = None
+        self.launch(gridX, gridY, gridZ, stream, function, self.launch_cooperative_grid, global_scratch, *args)
 
 
 class CudaDriver(GPUDriver):
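
CudaLauncher.__call__ now sizes a per-launch scratch buffer (grid size times metadata.global_scratch_size) and obtains it from the new triton/runtime/_allocation.py listed in this diff. In upstream Triton 3.3 that module is fed through triton.set_allocator; assuming this build mirrors that API (names outside the diff are assumptions), registering a torch-backed allocator looks roughly like this:

    # Sketch: give Triton a way to allocate global scratch / workspace buffers.
    import torch
    import triton

    def scratch_alloc(size: int, alignment: int, stream):
        # Any object exposing .data_ptr() works; torch's caching allocator
        # returns buffers that satisfy the usual alignment requirements.
        return torch.empty(size, dtype=torch.int8, device="cuda")

    triton.set_allocator(scratch_alloc)
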
@@ -501,14 +577,21 @@ class CudaDriver(GPUDriver):
         warp_size = 32
         return GPUTarget("cuda", capability, warp_size)
 
+    def get_active_torch_device(self):
+        import torch
+        return torch.device("cuda", self.get_current_device())
+
     def get_device_interface(self):
         import torch
         return torch.cuda
 
     @staticmethod
     def is_active():
-        import torch
-        return torch.cuda.is_available() and (torch.version.hip is None)
+        try:
+            import torch
+            return torch.cuda.is_available() and (torch.version.hip is None)
+        except ImportError:
+            return False
 
     def get_benchmarker(self):
         from triton.testing import do_bench
@@ -522,3 +605,6 @@ class CudaDriver(GPUDriver):
         # doesn't contain any input data before the run
         cache_size = 256 * 1024 * 1024
         return torch.empty(int(cache_size // 4), dtype=torch.int, device='cuda')
+
+    def clear_cache(self, cache):
+        cache.zero_()