triton-windows 3.4.0.post20-cp312-cp312-win_amd64.whl → 3.5.0.post21-cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of triton-windows might be problematic.

Files changed (107)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +8 -2
  3. triton/_filecheck.py +24 -14
  4. triton/_internal_testing.py +70 -4
  5. triton/_utils.py +3 -1
  6. triton/backends/amd/compiler.py +68 -60
  7. triton/backends/amd/driver.c +113 -44
  8. triton/backends/amd/driver.py +133 -57
  9. triton/backends/driver.py +13 -0
  10. triton/backends/nvidia/compiler.py +80 -22
  11. triton/backends/nvidia/driver.c +88 -15
  12. triton/backends/nvidia/driver.py +130 -123
  13. triton/compiler/__init__.py +5 -2
  14. triton/compiler/code_generator.py +270 -163
  15. triton/compiler/compiler.py +45 -62
  16. triton/experimental/gluon/__init__.py +3 -2
  17. triton/experimental/gluon/_runtime.py +9 -6
  18. triton/experimental/gluon/language/__init__.py +117 -16
  19. triton/experimental/gluon/language/_core.py +246 -68
  20. triton/experimental/gluon/language/_layouts.py +398 -45
  21. triton/experimental/gluon/language/_math.py +17 -9
  22. triton/experimental/gluon/language/_semantic.py +130 -37
  23. triton/experimental/gluon/language/_standard.py +55 -22
  24. triton/experimental/gluon/language/amd/__init__.py +4 -0
  25. triton/experimental/gluon/language/amd/_layouts.py +96 -0
  26. triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
  27. triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
  28. triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
  29. triton/experimental/gluon/language/extra/__init__.py +3 -0
  30. triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
  31. triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
  32. triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
  33. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +192 -7
  34. triton/experimental/gluon/language/nvidia/blackwell/tma.py +20 -0
  35. triton/experimental/gluon/language/nvidia/hopper/__init__.py +124 -3
  36. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +20 -37
  37. triton/experimental/gluon/language/nvidia/hopper/tma.py +4 -3
  38. triton/experimental/gluon/nvidia/hopper.py +6 -1
  39. triton/knobs.py +132 -67
  40. triton/language/__init__.py +16 -10
  41. triton/language/core.py +163 -83
  42. triton/language/extra/cuda/gdc.py +6 -6
  43. triton/language/extra/hip/__init__.py +3 -1
  44. triton/language/extra/hip/libdevice.py +7 -0
  45. triton/language/extra/hip/utils.py +35 -0
  46. triton/language/extra/libdevice.py +4 -0
  47. triton/language/semantic.py +76 -23
  48. triton/language/standard.py +14 -14
  49. triton/language/target_info.py +54 -0
  50. triton/runtime/_allocation.py +15 -3
  51. triton/runtime/_async_compile.py +55 -0
  52. triton/runtime/autotuner.py +4 -5
  53. triton/runtime/build.py +11 -9
  54. triton/runtime/cache.py +44 -1
  55. triton/runtime/driver.py +16 -41
  56. triton/runtime/interpreter.py +31 -23
  57. triton/runtime/jit.py +318 -157
  58. triton/runtime/tcc/include/_mingw.h +8 -10
  59. triton/runtime/tcc/include/assert.h +5 -0
  60. triton/runtime/tcc/include/errno.h +1 -1
  61. triton/runtime/tcc/include/float.h +21 -3
  62. triton/runtime/tcc/include/iso646.h +36 -0
  63. triton/runtime/tcc/include/limits.h +5 -0
  64. triton/runtime/tcc/include/malloc.h +2 -2
  65. triton/runtime/tcc/include/math.h +21 -261
  66. triton/runtime/tcc/include/stdalign.h +16 -0
  67. triton/runtime/tcc/include/stdarg.h +5 -70
  68. triton/runtime/tcc/include/stdatomic.h +171 -0
  69. triton/runtime/tcc/include/stddef.h +7 -19
  70. triton/runtime/tcc/include/stdlib.h +15 -4
  71. triton/runtime/tcc/include/stdnoreturn.h +7 -0
  72. triton/runtime/tcc/include/sys/stat.h +2 -2
  73. triton/runtime/tcc/include/sys/types.h +5 -0
  74. triton/runtime/tcc/include/tcc/tcc_libm.h +444 -27
  75. triton/runtime/tcc/include/tccdefs.h +342 -0
  76. triton/runtime/tcc/include/tgmath.h +89 -0
  77. triton/runtime/tcc/include/uchar.h +33 -0
  78. triton/runtime/tcc/include/unistd.h +1 -0
  79. triton/runtime/tcc/include/winapi/qos.h +72 -0
  80. triton/runtime/tcc/include/winapi/shellapi.h +59 -0
  81. triton/runtime/tcc/include/winapi/winbase.h +9 -2
  82. triton/runtime/tcc/include/winapi/wincon.h +8 -0
  83. triton/runtime/tcc/include/winapi/windows.h +1 -1
  84. triton/runtime/tcc/include/winapi/winnls.h +778 -0
  85. triton/runtime/tcc/include/winapi/winnt.h +9 -7
  86. triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
  87. triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
  88. triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
  89. triton/runtime/tcc/lib/libtcc1.a +0 -0
  90. triton/runtime/tcc/lib/python314.def +1800 -0
  91. triton/runtime/tcc/lib/python314t.def +1809 -0
  92. triton/runtime/tcc/libtcc.dll +0 -0
  93. triton/runtime/tcc/tcc.exe +0 -0
  94. triton/tools/compile.py +62 -14
  95. triton/tools/extra/cuda/compile.c +1 -0
  96. triton/tools/extra/hip/compile.cpp +66 -0
  97. triton/tools/extra/hip/compile.h +13 -0
  98. triton/tools/ragged_tma.py +92 -0
  99. triton/tools/tensor_descriptor.py +7 -9
  100. triton/windows_utils.py +42 -79
  101. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/METADATA +3 -4
  102. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/RECORD +106 -75
  103. triton/runtime/tcc/lib/libtcc1-64.a +0 -0
  104. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/WHEEL +0 -0
  105. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/entry_points.txt +0 -0
  106. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/licenses/LICENSE +0 -0
  107. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/top_level.txt +0 -0
triton/backends/nvidia/compiler.py

@@ -22,10 +22,11 @@ def min_dot_size(target: GPUTarget):
         lhs_bitwidth = lhs_type.scalar.primitive_bitwidth
         rhs_bitwidth = rhs_type.scalar.primitive_bitwidth
         assert lhs_bitwidth == rhs_bitwidth, "lhs and rhs bitwidth must be the same"
+        # For small M/N the input we can still use tensorcores with padding.
         if lhs_bitwidth == 8:
-            return (16, 16, 32)
+            return (1, 1, 32)
         else:
-            return (16, 16, 16)
+            return (1, 1, 16)

     return check_dot_compatibility

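Note: the relaxed minimums mean a dot with M or N below 16 is no longer rejected outright; the backend can pad small operands up to tensor-core shapes. A standalone re-implementation of just the new lower bounds (illustration, not the packaged code):

    def min_dot_size_new(lhs_bitwidth: int) -> tuple:
        # Returns (min_m, min_n, min_k); 3.5.0 drops min_m/min_n from 16 to 1.
        return (1, 1, 32) if lhs_bitwidth == 8 else (1, 1, 16)

    assert min_dot_size_new(8) == (1, 1, 32)    # fp8 operands
    assert min_dot_size_new(16) == (1, 1, 16)   # fp16/bf16 operands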
@@ -59,6 +60,11 @@ def ptx_get_version(cuda_version) -> int:
         return 70 + minor
     if major == 10:
         return 63 + minor
+
+    if major >= 13:
+        base_ptx = 90
+        return base_ptx + (major - 13) * 10 + minor
+
     raise RuntimeError("Triton only support CUDA 10.0 or higher, but got CUDA version: " + cuda_version)


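Note: the new branch maps CUDA 13+ releases to PTX ISA versions with a single formula. A standalone re-derivation with worked values (illustration, not the packaged code):

    def ptx_for_cuda13_plus(major: int, minor: int) -> int:
        # base PTX 90 at CUDA 13.0; each CUDA major adds 10, each minor adds 1
        assert major >= 13
        return 90 + (major - 13) * 10 + minor

    assert ptx_for_cuda13_plus(13, 0) == 90   # CUDA 13.0 -> PTX ISA 9.0
    assert ptx_for_cuda13_plus(13, 1) == 91   # CUDA 13.1 -> PTX ISA 9.1
    assert ptx_for_cuda13_plus(14, 0) == 100  # CUDA 14.0 -> PTX ISA 10.0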
@@ -111,6 +117,7 @@ class CUDAOptions:
     num_warps: int = 4
     num_ctas: int = 1
     num_stages: int = 3
+    warp_size: int = 32
     # maxnreg corresponds to the ptx parameter .maxnreg, which controls the
     # maximum number of 32-bit registers used by one thread.
     maxnreg: Optional[int] = None
@@ -121,7 +128,7 @@ class CUDAOptions:
     enable_fp_fusion: bool = True
     launch_cooperative_grid: bool = False
     launch_pdl: bool = False
-    supported_fp8_dtypes: Tuple[str] = ("fp8e5", "fp8e4b15")
+    supported_fp8_dtypes: Tuple[str] = ("fp8e4nv", "fp8e5", "fp8e4b15")
     deprecated_fp8_dot_operand_dtypes: Tuple[str] = ()
     default_dot_input_precision: str = "tf32"
     allowed_dot_input_precisions: Tuple[str] = ("tf32", "tf32x3", "ieee")
@@ -131,6 +138,7 @@ class CUDAOptions:
     backend_name: str = 'cuda'
     sanitize_overflow: bool = True
     arch: str = None
+    instrumentation_mode: str = ""

     def __post_init__(self):
         default_libdir = Path(__file__).parent / 'lib'
@@ -150,6 +158,7 @@ class CUDAOptions:


 class CUDABackend(BaseBackend):
+    instrumentation = None

     @staticmethod
     def supports_target(target: GPUTarget):
@@ -175,10 +184,13 @@ class CUDABackend(BaseBackend):
         args.update({k: opts[k] for k in CUDAOptions.__dataclass_fields__.keys() if k in opts if opts[k] is not None})
         capability = int(self._parse_arch(args["arch"]))

+        if args.get("num_ctas", 1) > 1 and capability < 90:
+            raise ValueError((f"num_ctas > 1 requires NVIDIA SM90+ (Hopper). "
+                              f"Current target is sm_{capability}. This configuration will fail. "
+                              f"Please set num_ctas=1 or target an SM90+ GPU."))
+
         if "supported_fp8_dtypes" not in args:
             supported_fp8_dtypes = set(CUDAOptions.supported_fp8_dtypes)
-            if capability >= 89:
-                supported_fp8_dtypes.add("fp8e4nv")
             args["supported_fp8_dtypes"] = tuple(sorted(supported_fp8_dtypes))

         if "deprecated_fp8_dot_operand_dtypes" not in args:
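Note: two behavior changes land here: multi-CTA configurations are now validated eagerly, and fp8e4nv becomes a default supported dtype instead of being added only for capability >= 89 (see the CUDAOptions hunk above). A standalone sketch of the eager check (illustrative, mirrors the hunk):

    def validate_num_ctas(num_ctas: int, capability: int) -> None:
        # Thread-block clusters require Hopper; reject early on older targets
        # instead of failing deep inside compilation.
        if num_ctas > 1 and capability < 90:
            raise ValueError(f"num_ctas > 1 requires NVIDIA SM90+ (Hopper). "
                             f"Current target is sm_{capability}.")

    validate_num_ctas(2, 90)       # OK: clusters on Hopper
    try:
        validate_num_ctas(2, 86)   # Ampere: now raises at option parsing
    except ValueError as exc:
        print(exc)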
@@ -218,6 +230,8 @@ class CUDABackend(BaseBackend):

     def load_dialects(self, ctx):
         nvidia.load_dialects(ctx)
+        if CUDABackend.instrumentation:
+            CUDABackend.instrumentation.load_dialects(ctx)

     @staticmethod
     def make_ttir(mod, metadata, opt, capability):
@@ -278,13 +292,15 @@ class CUDABackend(BaseBackend):
             passes.common.add_canonicalizer(pm)
             passes.ttir.add_triton_licm(pm)
             passes.ttgpuir.add_optimize_accumulator_init(pm)
-            passes.ttgpuir.add_hoist_tmem_alloc(pm)
+            passes.ttgpuir.add_hoist_tmem_alloc(pm, False)
             nvidia.passes.ttnvgpuir.add_promote_lhs_to_tmem(pm)
             passes.ttgpuir.add_assign_latencies(pm, opt.num_stages)
             passes.ttgpuir.add_schedule_loops(pm)
             passes.ttgpuir.add_warp_specialize(pm, opt.num_stages)
             passes.ttgpuir.add_pipeline(pm, opt.num_stages, dump_enabled)
             passes.ttgpuir.add_combine_tensor_select_and_if(pm)
+            # hoist again and allow hoisting out of if statements
+            passes.ttgpuir.add_hoist_tmem_alloc(pm, True)
             nvidia.passes.ttnvgpuir.add_remove_tmem_tokens(pm)
         else:
             passes.ttir.add_triton_licm(pm)
@@ -302,24 +318,28 @@ class CUDABackend(BaseBackend):
         passes.common.add_symbol_dce(pm)
         if capability // 10 >= 9:
             nvidia.passes.ttnvgpuir.add_tma_lowering(pm)
-            nvidia.passes.ttnvgpuir.add_fence_insertion(pm)
+        nvidia.passes.ttnvgpuir.add_fence_insertion(pm, capability)
+        nvidia.passes.ttnvgpuir.add_lower_mma(pm)
         passes.common.add_sccp(pm)
+        passes.common.add_cse(pm)
         passes.common.add_canonicalizer(pm)
+
         pm.run(mod)
         metadata["cluster_dims"] = (cluster_info.clusterDimX, cluster_info.clusterDimY, cluster_info.clusterDimZ)
         tensordesc_meta = mod.get_tensordesc_metadata()
         metadata["tensordesc_meta"] = tensordesc_meta
         return mod

-    def ttgir_opt(self, src, metadata, options, capability):
+    def gluon_to_ttgir(self, src, metadata, options, capability):
         mod = src
         pm = ir.pass_manager(mod.context)
         pm.enable_debug()

-        passes.ttgpuir.add_inliner(pm)
+        passes.gluon.add_inliner(pm)
+        passes.gluon.add_resolve_auto_encodings(pm)
         passes.common.add_sccp(pm)
         passes.ttir.add_loop_aware_cse(pm)
-        passes.ttgpuir.add_canonicalizer(pm)
+        passes.gluon.add_canonicalizer(pm)
         passes.ttgpuir.add_combine_tensor_select_and_if(pm)

         pm.run(mod)
@@ -334,13 +354,19 @@ class CUDABackend(BaseBackend):
         pm = ir.pass_manager(mod.context)
         pm.enable_debug()

-        nvidia.passes.ttnvgpuir.add_lower_mma(pm)
         passes.ttgpuir.add_combine_tensor_select_and_if(pm)
         passes.ttgpuir.add_allocate_warp_groups(pm)
         passes.convert.add_scf_to_cf(pm)
-        passes.ttgpuir.add_allocate_shared_memory(pm)
+        nvidia.passes.ttgpuir.add_allocate_shared_memory_nv(pm, capability, ptx_version)
         nvidia.passes.ttnvgpuir.add_allocate_tensor_memory(pm)
+        if knobs.compilation.enable_experimental_consan:
+            # Call ConcurrencySanitizerPass here, before allocating global scratch memory but after allocating tensor and shared
+            passes.ttgpuir.add_concurrency_sanitizer(pm)
         passes.ttgpuir.add_allocate_global_scratch_memory(pm)
+        nvidia.passes.ttnvgpuir.add_proxy_fence_insertion(pm, capability)
+        # instrumentation point here so we can override IRs above (e.g., ttir and ttgir)
+        if CUDABackend.instrumentation:
+            CUDABackend.instrumentation.patch("ttgpuir_to_llvmir", pm, mod.context)
         nvidia.passes.ttgpuir.add_to_llvmir(pm, capability, ptx_version)
         passes.common.add_canonicalizer(pm)
         passes.common.add_cse(pm)
@@ -349,8 +375,12 @@ class CUDABackend(BaseBackend):
         passes.common.add_canonicalizer(pm)
         passes.common.add_cse(pm)
         passes.common.add_symbol_dce(pm)
+        passes.convert.add_nvvm_to_llvm(pm)
         if not knobs.compilation.disable_line_info:
             passes.llvmir.add_di_scope(pm)
+        if CUDABackend.instrumentation:
+            CUDABackend.instrumentation.patch("llvmir_to_llvm", pm, mod.context)
+
         pm.run(mod)
         # LLVM-IR (MLIR) -> LLVM-IR (LLVM)
         llvm.init_targets()
@@ -366,7 +396,7 @@ class CUDABackend(BaseBackend):
         llvm.attach_datalayout(llvm_mod, triple, proc, features)
         nvidia.set_nvvm_reflect_ftz(llvm_mod)

-        if options.extern_libs:
+        if options.extern_libs and nvidia.has_extern_deps(llvm_mod):
             paths = [path for (name, path) in options.extern_libs]
             llvm.link_extern_libs(llvm_mod, paths)

@@ -381,6 +411,8 @@ class CUDABackend(BaseBackend):
         metadata["tmem_size"] = src.get_int_attr("ttg.tensor_memory_size")
         metadata["global_scratch_size"] = src.get_int_attr("ttg.global_scratch_memory_size")
         metadata["global_scratch_align"] = src.get_int_attr("ttg.global_scratch_memory_alignment")
+        metadata["profile_scratch_size"] = src.get_int_attr("ttg.profile_scratch_memory_size") or 0
+        metadata["profile_scratch_align"] = src.get_int_attr("ttg.profile_scratch_memory_alignment") or 1
         ret = str(llvm_mod)
         del llvm_mod
         del context
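Note: the `or` fallbacks give safe defaults when the profiling attributes are absent (get_int_attr presumably yields a falsy value such as None in that case). A minimal sketch with get_int_attr stubbed out:

    def read_profile_meta(get_int_attr):
        # Missing attributes fall back to 0 bytes / 1-byte alignment.
        size = get_int_attr("ttg.profile_scratch_memory_size") or 0
        align = get_int_attr("ttg.profile_scratch_memory_alignment") or 1
        return size, align

    assert read_profile_meta(lambda name: None) == (0, 1)   # no profiling
    assert read_profile_meta(lambda name: 256) == (256, 256)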
@@ -416,8 +448,18 @@ class CUDABackend(BaseBackend):
             fsrc.flush()
             fbin = fsrc.name + '.o'

-            line_info = ["-lineinfo", "-suppress-debug-info"] if knobs.compilation.disable_line_info else ["-lineinfo"]
-            fmad = [] if opt.enable_fp_fusion else ['--fmad=false']
+            debug_info = []
+            if knobs.compilation.disable_line_info:
+                # This option is ignored if used without -lineinfo
+                debug_info += ["-lineinfo", "-suppress-debug-info"]
+            elif knobs.nvidia.disable_ptxas_opt:
+                # Synthesize complete debug info
+                debug_info += ["-g"]
+            else:
+                # Only emit line info
+                debug_info += ["-lineinfo"]
+
+            fmad = [] if opt.enable_fp_fusion else ["--fmad=false"]
             arch = sm_arch_from_capability(capability)

             # Disable ptxas optimizations if requested
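Note: the two-way choice becomes a three-way branch, so disabling ptxas optimizations now also requests full debug info. The resulting flag sets, re-derived as a standalone sketch:

    def ptxas_debug_flags(disable_line_info: bool, disable_ptxas_opt: bool) -> list:
        if disable_line_info:
            # -suppress-debug-info is ignored unless -lineinfo is also passed
            return ["-lineinfo", "-suppress-debug-info"]
        if disable_ptxas_opt:
            return ["-g"]          # full debug info for unoptimized builds
        return ["-lineinfo"]       # default: line info only

    assert ptxas_debug_flags(False, False) == ["-lineinfo"]
    assert ptxas_debug_flags(False, True) == ["-g"]
    assert ptxas_debug_flags(True, True) == ["-lineinfo", "-suppress-debug-info"]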
@@ -427,13 +469,18 @@ class CUDABackend(BaseBackend):
             ptx_extra_options = opt.ptx_options.split(" ") if opt.ptx_options else []

             ptxas_cmd = [
-                ptxas, *line_info, *fmad, '-v', *disable_opt, *ptx_extra_options, f'--gpu-name={arch}', fsrc.name, '-o',
-                fbin
+                ptxas, *debug_info, *fmad, '-v', *disable_opt, *ptx_extra_options, f'--gpu-name={arch}', fsrc.name,
+                '-o', fbin
             ]
             try:
                 # close_fds=True on Windows and False on Linux, see https://github.com/triton-lang/triton/pull/4357
                 # On Windows, both stdout and stderr need to be redirected to flog
-                subprocess.run(ptxas_cmd, check=True, close_fds=True if os.name == 'nt' else False, stdout=flog, stderr=flog)
+                subprocess.run(ptxas_cmd, check=True, close_fds=True if os.name == 'nt' else False, stdout=flog,
+                               stderr=flog)
+                if knobs.nvidia.dump_ptxas_log:
+                    with open(flog.name) as log_file:
+                        print(log_file.read())
+
             except subprocess.CalledProcessError as e:
                 with open(flog.name) as log_file:
                     log = log_file.read()
@@ -445,9 +492,20 @@ class CUDABackend(BaseBackend):
                 else:
                     error = f'`ptxas` failed with error code {e.returncode}'

-                raise PTXASError(f"{error}\n"
-                                 f"`ptxas` stderr:\n{log}\n"
-                                 f'Repro command: {" ".join(ptxas_cmd)}\n')
+                error = (f"{error}\n"
+                         f"`ptxas` stderr:\n{log}\n"
+                         f'Repro command: {" ".join(ptxas_cmd)}\n')
+
+                print(f"""
+
+================================================================
+{error}
+
+{src}
+================================================================
+please share the reproducer above with Triton project.
+""")
+                raise PTXASError(error)

             with open(fbin, 'rb') as f:
                 cubin = f.read()
@@ -464,7 +522,7 @@ class CUDABackend(BaseBackend):
             stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options, capability)
             stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, capability)
         elif language == Language.GLUON:
-            stages["ttgir"] = lambda src, metadata: self.ttgir_opt(src, metadata, options, capability)
+            stages["ttgir"] = lambda src, metadata: self.gluon_to_ttgir(src, metadata, options, capability)
         stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, capability)
         stages["ptx"] = lambda src, metadata: self.make_ptx(src, metadata, options, self.target.arch)
         stages["cubin"] = lambda src, metadata: self.make_cubin(src, metadata, options, self.target.arch)
triton/backends/nvidia/driver.c

@@ -9,9 +9,15 @@
 #endif

 #include <stdbool.h>
+#include <stdlib.h>
 #define PY_SSIZE_T_CLEAN
 #include <Python.h>

+typedef struct {
+  PyObject_HEAD
+  _Alignas(128) CUtensorMap tensorMap;
+} PyCUtensorMapObject;
+
 // Raises a Python exception and returns false if code is not CUDA_SUCCESS.
 static bool gpuAssert(CUresult code, const char *file, int line) {
   if (code == CUDA_SUCCESS)
@@ -34,7 +40,7 @@ static bool gpuAssert(CUresult code, const char *file, int line) {
 #define CUDA_CHECK_AND_RETURN_NULL(ans)                                        \
   do {                                                                         \
     if (!gpuAssert((ans), __FILE__, __LINE__))                                 \
-      return NULL;                                                             \
+      goto cleanup;                                                            \
   } while (0)

 // To be used inside a Py_{BEGIN,END}_ALLOW_THREADS block.
@@ -52,7 +58,7 @@ static bool gpuAssert(CUresult code, const char *file, int line) {
     if ((funcPointer) == NULL) {                                               \
       (funcPointer) = (initializerFunction)();                                 \
       if ((funcPointer) == NULL) {                                             \
-        return NULL;                                                           \
+        goto cleanup;                                                          \
       }                                                                        \
     }                                                                          \
   } while (0)
@@ -95,6 +101,9 @@ static PyObject *getDeviceProperties(PyObject *self, PyObject *args) {
                        warp_size, "sm_clock_rate", sm_clock_rate,
                        "mem_clock_rate", mem_clock_rate, "mem_bus_width",
                        mem_bus_width);
+
+cleanup:
+  return NULL;
 }

 static PyObject *loadBinary(PyObject *self, PyObject *args) {
@@ -268,6 +277,9 @@ static PyObject *occupancyMaxActiveClusters(PyObject *self, PyObject *args) {
       cuOccupancyMaxActiveClusters(&maxActiveClusters, func, &config));
   Py_END_ALLOW_THREADS;
   return PyLong_FromLong(maxActiveClusters);
+
+cleanup:
+  return NULL;
 }

 static PyObject *setPrintfFifoSize(PyObject *self, PyObject *args) {
@@ -306,12 +318,57 @@ static PyObject *setPrintfFifoSize(PyObject *self, PyObject *args) {
   }

   Py_END_ALLOW_THREADS;
-  Py_INCREF(Py_None);
-  return Py_None;
+  Py_RETURN_NONE;
+}
+
+static PyObject *PyCUtensorMap_alloc(PyTypeObject *type, Py_ssize_t n_items) {
+  PyCUtensorMapObject *self = NULL;
+  void *mem = NULL;
+  size_t size = type->tp_basicsize;
+
+#ifdef _WIN32
+  mem = _aligned_malloc(size, 128);
+  if (mem == NULL) {
+#else
+  if (posix_memalign(&mem, 128, size) != 0) {
+#endif
+    PyErr_NoMemory();
+    return NULL;
+  }
+
+  self = (PyCUtensorMapObject *)mem;
+  PyObject_INIT(self, type);
+  return (PyObject *)self;
+}
+
+static void PyCUtensorMap_dealloc(PyObject *self) {
+  Py_TYPE(self)->tp_free(self);
+}
+
+static void PyCUtensorMap_free(void *ptr) {
+#ifdef _WIN32
+  _aligned_free(ptr);
+#else
+  free(ptr);
+#endif
 }

+// clang-format off
+static PyTypeObject PyCUtensorMapType = {
+    PyVarObject_HEAD_INIT(NULL, 0)
+    .tp_name = "triton.backends.nvidia.PyCUtensorMap",
+    .tp_basicsize = sizeof(PyCUtensorMapObject),
+    .tp_itemsize = 0,
+    .tp_flags = Py_TPFLAGS_DEFAULT,
+    .tp_doc = "<PyCUtensorMap object>",
+    .tp_new = PyType_GenericNew,
+    .tp_alloc = PyCUtensorMap_alloc,
+    .tp_dealloc = (destructor)PyCUtensorMap_dealloc,
+    .tp_free = PyCUtensorMap_free,
+};
+// clang-format on
+
 static PyObject *fillTMADescriptor(PyObject *self, PyObject *args) {
-  unsigned long long desc_address;
   unsigned long long global_address;
   int swizzle;
   int elemSize;
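Note: the custom tp_alloc/tp_free pair exists because CPython's default object allocator does not guarantee the 128-byte alignment that the embedded CUtensorMap field is declared with (_Alignas(128) above); the C code uses the platform's native aligned allocators. For intuition, the classic over-allocate-and-round-up alternative, as a standalone Python sketch (the object size is a stand-in):

    import ctypes

    obj_size = 256                           # stand-in for the object size
    raw = ctypes.create_string_buffer(obj_size + 127)   # over-allocate
    addr = ctypes.addressof(raw)
    aligned = (addr + 127) & ~127            # round up to 128-byte boundary
    assert aligned % 128 == 0
    assert aligned + obj_size <= addr + len(raw)        # still in bounds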
@@ -319,17 +376,22 @@ static PyObject *fillTMADescriptor(PyObject *self, PyObject *args) {
   PyObject *blockSize;
   PyObject *shape;
   PyObject *strides;
+  int padding;
+
+  if (!PyArg_ParseTuple(args, "KiiiOOOi", &global_address, &swizzle, &elemSize,
+                        &elemType, &blockSize, &shape, &strides, &padding)) {
+    return NULL;
+  }

-  if (!PyArg_ParseTuple(args, "KKiiiOOO", &desc_address, &global_address,
-                        &swizzle, &elemSize, &elemType, &blockSize, &shape,
-                        &strides)) {
+  PyCUtensorMapObject *desc = (PyCUtensorMapObject *)PyObject_CallObject(
+      (PyObject *)&PyCUtensorMapType, NULL);
+  if (!desc) {
     return NULL;
   }

   PyObject *blockSizeFast = NULL;
   PyObject *shapeFast = NULL;
   PyObject *stridesFast = NULL;
-  PyObject *result = NULL;

   uint32_t blockSizeInt[5];
   uint64_t shapeInt[5];
@@ -391,22 +453,27 @@ static PyObject *fillTMADescriptor(PyObject *self, PyObject *args) {
   Py_DECREF(stridesFast);
   stridesFast = NULL;

+  CUtensorMapFloatOOBfill fill =
+      (padding == 1) ? CU_TENSOR_MAP_FLOAT_OOB_FILL_NAN_REQUEST_ZERO_FMA
+                     : CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE;
+
   uint32_t elementStrides[5] = {1, 1, 1, 1, 1};
   static cuTensorMapEncodeTiled_t cuTensorMapEncodeTiled = NULL;
   INITIALIZE_FUNCTION_POINTER_IF_NULL(cuTensorMapEncodeTiled,
                                       getCuTensorMapEncodeTiledHandle);
   CUDA_CHECK_AND_RETURN_NULL(cuTensorMapEncodeTiled(
-      (CUtensorMap *)desc_address, elemType, rank, (void *)global_address,
-      shapeInt, stridesLL, blockSizeInt, elementStrides,
-      CU_TENSOR_MAP_INTERLEAVE_NONE, swizzle,
-      CU_TENSOR_MAP_L2_PROMOTION_L2_128B, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE));
-  Py_RETURN_NONE;
+      &desc->tensorMap, elemType, rank, (void *)global_address, shapeInt,
+      stridesLL, blockSizeInt, elementStrides, CU_TENSOR_MAP_INTERLEAVE_NONE,
+      swizzle, CU_TENSOR_MAP_L2_PROMOTION_L2_128B, fill));
+
+  return (PyObject *)desc;

 cleanup:
   Py_XDECREF(blockSizeFast);
   Py_XDECREF(shapeFast);
   Py_XDECREF(stridesFast);
-  return result;
+  Py_XDECREF(desc);
+  return NULL;
 }

 static PyMethodDef ModuleMethods[] = {
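Note: taken together, the fillTMADescriptor hunks change the binding's contract: instead of filling a caller-provided desc_address and returning None, it now parses (global_address, swizzle, elemSize, elemType, blockSize, shape, strides, padding) per the "KiiiOOOi" format and returns a newly allocated PyCUtensorMap, released on every failure path. The new padding flag picks the float out-of-bounds fill mode; a trivial standalone mapping (strings mirror the CUDA enum names):

    def oob_fill_mode(padding: int) -> str:
        # padding == 1 requests NaN fill (zeroed in FMA) for OOB float reads;
        # anything else keeps the default no-fill behavior.
        if padding == 1:
            return "CU_TENSOR_MAP_FLOAT_OOB_FILL_NAN_REQUEST_ZERO_FMA"
        return "CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE"

    assert oob_fill_mode(1).endswith("NAN_REQUEST_ZERO_FMA")
    assert oob_fill_mode(0) == "CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE"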
@@ -433,12 +500,18 @@ static struct PyModuleDef ModuleDef = {PyModuleDef_HEAD_INIT, "cuda_utils",
                                        ModuleMethods};

 PyMODINIT_FUNC PyInit_cuda_utils(void) {
+  if (PyType_Ready(&PyCUtensorMapType) < 0) {
+    return NULL;
+  }
+
   PyObject *m = PyModule_Create(&ModuleDef);
   if (m == NULL) {
     return NULL;
   }

   PyModule_AddFunctions(m, ModuleMethods);
+  Py_INCREF(&PyCUtensorMapType);
+  PyModule_AddObject(m, "PyCUtensorMap", (PyObject *)&PyCUtensorMapType);

   return m;
 }