triton-windows 3.3.1.post21__cp312-cp312-win_amd64.whl → 3.4.0.post21__cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.

Potentially problematic release.


This version of triton-windows might be problematic. Click here for more details.

Files changed (68):
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +4 -1
  3. triton/_filecheck.py +87 -0
  4. triton/_internal_testing.py +26 -15
  5. triton/_utils.py +110 -21
  6. triton/backends/__init__.py +20 -23
  7. triton/backends/amd/__init__.py +0 -0
  8. triton/backends/amd/compiler.py +112 -78
  9. triton/backends/amd/driver.c +5 -2
  10. triton/backends/amd/driver.py +143 -46
  11. triton/backends/compiler.py +7 -21
  12. triton/backends/nvidia/bin/ptxas.exe +0 -0
  13. triton/backends/nvidia/compiler.py +94 -94
  14. triton/backends/nvidia/driver.c +90 -98
  15. triton/backends/nvidia/driver.py +296 -125
  16. triton/compiler/code_generator.py +212 -111
  17. triton/compiler/compiler.py +110 -25
  18. triton/experimental/__init__.py +0 -0
  19. triton/experimental/gluon/__init__.py +4 -0
  20. triton/experimental/gluon/_compiler.py +0 -0
  21. triton/experimental/gluon/_runtime.py +99 -0
  22. triton/experimental/gluon/language/__init__.py +18 -0
  23. triton/experimental/gluon/language/_core.py +312 -0
  24. triton/experimental/gluon/language/_layouts.py +230 -0
  25. triton/experimental/gluon/language/_math.py +12 -0
  26. triton/experimental/gluon/language/_semantic.py +287 -0
  27. triton/experimental/gluon/language/_standard.py +47 -0
  28. triton/experimental/gluon/language/nvidia/__init__.py +4 -0
  29. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +202 -0
  30. triton/experimental/gluon/language/nvidia/blackwell/tma.py +32 -0
  31. triton/experimental/gluon/language/nvidia/hopper/__init__.py +11 -0
  32. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +51 -0
  33. triton/experimental/gluon/language/nvidia/hopper/tma.py +96 -0
  34. triton/experimental/gluon/nvidia/__init__.py +4 -0
  35. triton/experimental/gluon/nvidia/blackwell.py +3 -0
  36. triton/experimental/gluon/nvidia/hopper.py +40 -0
  37. triton/knobs.py +481 -0
  38. triton/language/__init__.py +39 -14
  39. triton/language/core.py +794 -537
  40. triton/language/extra/cuda/__init__.py +10 -7
  41. triton/language/extra/cuda/gdc.py +42 -0
  42. triton/language/extra/cuda/libdevice.py +394 -394
  43. triton/language/extra/cuda/utils.py +21 -21
  44. triton/language/extra/hip/libdevice.py +113 -104
  45. triton/language/math.py +65 -66
  46. triton/language/random.py +12 -2
  47. triton/language/semantic.py +1706 -1770
  48. triton/language/standard.py +116 -51
  49. triton/runtime/autotuner.py +117 -59
  50. triton/runtime/build.py +73 -9
  51. triton/runtime/cache.py +18 -47
  52. triton/runtime/driver.py +32 -29
  53. triton/runtime/interpreter.py +72 -35
  54. triton/runtime/jit.py +146 -110
  55. triton/testing.py +16 -12
  56. triton/tools/disasm.py +3 -4
  57. triton/tools/tensor_descriptor.py +36 -0
  58. triton/windows_utils.py +47 -83
  59. {triton_windows-3.3.1.post21.dist-info → triton_windows-3.4.0.post21.dist-info}/METADATA +7 -2
  60. {triton_windows-3.3.1.post21.dist-info → triton_windows-3.4.0.post21.dist-info}/RECORD +64 -41
  61. triton_windows-3.4.0.post21.dist-info/entry_points.txt +3 -0
  62. triton_windows-3.4.0.post21.dist-info/licenses/LICENSE +23 -0
  63. triton_windows-3.4.0.post21.dist-info/top_level.txt +1 -0
  64. triton/language/_utils.py +0 -21
  65. triton/language/extra/cuda/_experimental_tma.py +0 -106
  66. triton/tools/experimental_descriptor.py +0 -32
  67. triton_windows-3.3.1.post21.dist-info/top_level.txt +0 -14
  68. {triton_windows-3.3.1.post21.dist-info → triton_windows-3.4.0.post21.dist-info}/WHEEL +0 -0
@@ -2,25 +2,25 @@ from triton.language import core
2
2
 
3
3
 
4
4
  @core.extern
5
- def globaltimer(_builder=None):
5
+ def globaltimer(_semantic=None):
6
6
  return core.inline_asm_elementwise("mov.u64 $0, %globaltimer;", "=l", [], dtype=core.int64, is_pure=False, pack=1,
7
- _builder=_builder)
7
+ _semantic=_semantic)
8
8
 
9
9
 
10
10
  @core.extern
11
- def smid(_builder=None):
11
+ def smid(_semantic=None):
12
12
  return core.inline_asm_elementwise("mov.u32 $0, %smid;", "=r", [], dtype=core.int32, is_pure=True, pack=1,
13
- _builder=_builder)
13
+ _semantic=_semantic)
14
14
 
15
15
 
16
16
  @core.builtin
17
- def num_threads(_builder=None):
18
- return core.constexpr(_builder.options.num_warps * 32)
17
+ def num_threads(_semantic=None):
18
+ return core.constexpr(_semantic.builder.options.num_warps * 32)
19
19
 
20
20
 
21
21
  @core.builtin
22
- def num_warps(_builder=None):
23
- return core.constexpr(_builder.options.num_warps)
22
+ def num_warps(_semantic=None):
23
+ return core.constexpr(_semantic.builder.options.num_warps)
24
24
 
25
25
 
26
26
  # ----- FP8E4M3B15 ------
@@ -31,7 +31,7 @@ def num_warps(_builder=None):
31
31
  # - the exponent bias is 15 instead of 7
32
32
  # - 0xff and 0x7f are mapped to +-1.750 instead of +-nan
33
33
  @core.builtin
34
- def convert_fp8e4b15_to_float16(arg, _builder=None):
34
+ def convert_fp8e4b15_to_float16(arg, _semantic=None):
35
35
  return core.inline_asm_elementwise(
36
36
  "{ \n"
37
37
  ".reg .b32 a<2>, b<2>; \n"
@@ -44,11 +44,11 @@ def convert_fp8e4b15_to_float16(arg, _builder=None):
44
44
  "lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n"
45
45
  "shl.b32 $1, b1, 7; \n"
46
46
  "} \n", "=r,=r,r", [arg], dtype=core.float16, is_pure=True, pack=4,
47
- _builder=_builder)
47
+ _semantic=_semantic)
48
48
 
49
49
 
50
50
  @core.builtin
51
- def convert_float16_to_fp8e4b15(arg, has_minx2, _builder=None):
51
+ def convert_float16_to_fp8e4b15(arg, has_minx2, _semantic=None):
52
52
  asm = """{
53
53
  .reg .pred p<4>;
54
54
  .reg .b32 a<2>, b<2>;
@@ -80,30 +80,30 @@ def convert_float16_to_fp8e4b15(arg, has_minx2, _builder=None):
80
80
  prmt.b32 $0, b0, b1, 0x7531;
81
81
  }"""
82
82
  return core.inline_asm_elementwise(asm, "=r,r,r", [arg], dtype=core.float8e4b15, is_pure=True, pack=4,
83
- _builder=_builder)
83
+ _semantic=_semantic)
84
84
 
85
85
 
86
86
  @core.builtin
87
- def convert_custom_float8(arg, dst_ty, fp_downcast_rounding, has_minx2, _builder=None):
87
+ def convert_custom_float8(arg, dst_ty, fp_downcast_rounding, has_minx2, _semantic=None):
88
88
  if arg.type.scalar.is_fp8e4b15():
89
- upcast_val = convert_fp8e4b15_to_float16(arg, _builder=_builder)
89
+ upcast_val = convert_fp8e4b15_to_float16(arg, _semantic=_semantic)
90
90
  if dst_ty.scalar.is_fp32():
91
- upcast_val = upcast_val.to(core.float32, _builder=_builder)
91
+ upcast_val = upcast_val.to(core.float32, _semantic=_semantic)
92
92
  return upcast_val
93
93
 
94
94
  assert arg.type.scalar.is_fp16() or arg.type.scalar.is_fp32()
95
95
  downcast_val = arg
96
96
  if arg.type.scalar.is_fp32():
97
- downcast_val = downcast_val.to(core.float16, fp_downcast_rounding="rtz", _builder=_builder)
98
- downcast_val = convert_float16_to_fp8e4b15(downcast_val, has_minx2=has_minx2, _builder=_builder)
97
+ downcast_val = downcast_val.to(core.float16, fp_downcast_rounding="rtz", _semantic=_semantic)
98
+ downcast_val = convert_float16_to_fp8e4b15(downcast_val, has_minx2=has_minx2, _semantic=_semantic)
99
99
  return downcast_val
100
100
 
101
101
 
102
102
  @core.builtin
103
- def convert_custom_float8_sm80(arg, dst_ty, fp_downcast_rounding=None, _builder=None):
104
- return convert_custom_float8(arg, dst_ty, fp_downcast_rounding, has_minx2=True, _builder=_builder)
103
+ def convert_custom_float8_sm80(arg, dst_ty, fp_downcast_rounding=None, _semantic=None):
104
+ return convert_custom_float8(arg, dst_ty, fp_downcast_rounding, has_minx2=True, _semantic=_semantic)
105
105
 
106
106
 
107
107
  @core.builtin
108
- def convert_custom_float8_sm70(arg, dst_ty, fp_downcast_rounding=None, _builder=None):
109
- return convert_custom_float8(arg, dst_ty, fp_downcast_rounding, has_minx2=False, _builder=_builder)
108
+ def convert_custom_float8_sm70(arg, dst_ty, fp_downcast_rounding=None, _semantic=None):
109
+ return convert_custom_float8(arg, dst_ty, fp_downcast_rounding, has_minx2=False, _semantic=_semantic)