triton-windows 3.3.1.post21__cp312-cp312-win_amd64.whl → 3.4.0.post21__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of triton-windows has been flagged as potentially problematic. See the registry's advisory page for details.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +4 -1
- triton/_filecheck.py +87 -0
- triton/_internal_testing.py +26 -15
- triton/_utils.py +110 -21
- triton/backends/__init__.py +20 -23
- triton/backends/amd/__init__.py +0 -0
- triton/backends/amd/compiler.py +112 -78
- triton/backends/amd/driver.c +5 -2
- triton/backends/amd/driver.py +143 -46
- triton/backends/compiler.py +7 -21
- triton/backends/nvidia/bin/ptxas.exe +0 -0
- triton/backends/nvidia/compiler.py +94 -94
- triton/backends/nvidia/driver.c +90 -98
- triton/backends/nvidia/driver.py +296 -125
- triton/compiler/code_generator.py +212 -111
- triton/compiler/compiler.py +110 -25
- triton/experimental/__init__.py +0 -0
- triton/experimental/gluon/__init__.py +4 -0
- triton/experimental/gluon/_compiler.py +0 -0
- triton/experimental/gluon/_runtime.py +99 -0
- triton/experimental/gluon/language/__init__.py +18 -0
- triton/experimental/gluon/language/_core.py +312 -0
- triton/experimental/gluon/language/_layouts.py +230 -0
- triton/experimental/gluon/language/_math.py +12 -0
- triton/experimental/gluon/language/_semantic.py +287 -0
- triton/experimental/gluon/language/_standard.py +47 -0
- triton/experimental/gluon/language/nvidia/__init__.py +4 -0
- triton/experimental/gluon/language/nvidia/blackwell/__init__.py +202 -0
- triton/experimental/gluon/language/nvidia/blackwell/tma.py +32 -0
- triton/experimental/gluon/language/nvidia/hopper/__init__.py +11 -0
- triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +51 -0
- triton/experimental/gluon/language/nvidia/hopper/tma.py +96 -0
- triton/experimental/gluon/nvidia/__init__.py +4 -0
- triton/experimental/gluon/nvidia/blackwell.py +3 -0
- triton/experimental/gluon/nvidia/hopper.py +40 -0
- triton/knobs.py +481 -0
- triton/language/__init__.py +39 -14
- triton/language/core.py +794 -537
- triton/language/extra/cuda/__init__.py +10 -7
- triton/language/extra/cuda/gdc.py +42 -0
- triton/language/extra/cuda/libdevice.py +394 -394
- triton/language/extra/cuda/utils.py +21 -21
- triton/language/extra/hip/libdevice.py +113 -104
- triton/language/math.py +65 -66
- triton/language/random.py +12 -2
- triton/language/semantic.py +1706 -1770
- triton/language/standard.py +116 -51
- triton/runtime/autotuner.py +117 -59
- triton/runtime/build.py +73 -9
- triton/runtime/cache.py +18 -47
- triton/runtime/driver.py +32 -29
- triton/runtime/interpreter.py +72 -35
- triton/runtime/jit.py +146 -110
- triton/testing.py +16 -12
- triton/tools/disasm.py +3 -4
- triton/tools/tensor_descriptor.py +36 -0
- triton/windows_utils.py +47 -83
- {triton_windows-3.3.1.post21.dist-info → triton_windows-3.4.0.post21.dist-info}/METADATA +7 -2
- {triton_windows-3.3.1.post21.dist-info → triton_windows-3.4.0.post21.dist-info}/RECORD +64 -41
- triton_windows-3.4.0.post21.dist-info/entry_points.txt +3 -0
- triton_windows-3.4.0.post21.dist-info/licenses/LICENSE +23 -0
- triton_windows-3.4.0.post21.dist-info/top_level.txt +1 -0
- triton/language/_utils.py +0 -21
- triton/language/extra/cuda/_experimental_tma.py +0 -106
- triton/tools/experimental_descriptor.py +0 -32
- triton_windows-3.3.1.post21.dist-info/top_level.txt +0 -14
- {triton_windows-3.3.1.post21.dist-info → triton_windows-3.4.0.post21.dist-info}/WHEEL +0 -0
|
@@ -2,25 +2,25 @@ from triton.language import core


 @core.extern
-def globaltimer(_builder=None):
+def globaltimer(_semantic=None):
     return core.inline_asm_elementwise("mov.u64 $0, %globaltimer;", "=l", [], dtype=core.int64, is_pure=False, pack=1,
-                                       _builder=_builder)
+                                       _semantic=_semantic)


 @core.extern
-def smid(_builder=None):
+def smid(_semantic=None):
     return core.inline_asm_elementwise("mov.u32 $0, %smid;", "=r", [], dtype=core.int32, is_pure=True, pack=1,
-                                       _builder=_builder)
+                                       _semantic=_semantic)


 @core.builtin
-def num_threads(_builder=None):
-    return core.constexpr(_builder.options.num_warps * 32)
+def num_threads(_semantic=None):
+    return core.constexpr(_semantic.builder.options.num_warps * 32)


 @core.builtin
-def num_warps(_builder=None):
-    return core.constexpr(_builder.options.num_warps)
+def num_warps(_semantic=None):
+    return core.constexpr(_semantic.builder.options.num_warps)


 # ----- FP8E4M3B15 ------
@@ -31,7 +31,7 @@ def num_warps(_builder=None):
 # - the exponent bias is 15 instead of 7
 # - 0xff and 0x7f are mapped to +-1.750 instead of +-nan
 @core.builtin
-def convert_fp8e4b15_to_float16(arg, _builder=None):
+def convert_fp8e4b15_to_float16(arg, _semantic=None):
     return core.inline_asm_elementwise(
         "{ \n"
         ".reg .b32 a<2>, b<2>; \n"
@@ -44,11 +44,11 @@ def convert_fp8e4b15_to_float16(arg, _builder=None):
         "lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n"
         "shl.b32 $1, b1, 7; \n"
         "} \n", "=r,=r,r", [arg], dtype=core.float16, is_pure=True, pack=4,
-        _builder=_builder)
+        _semantic=_semantic)


 @core.builtin
-def convert_float16_to_fp8e4b15(arg, has_minx2, _builder=None):
+def convert_float16_to_fp8e4b15(arg, has_minx2, _semantic=None):
     asm = """{
         .reg .pred p<4>;
         .reg .b32 a<2>, b<2>;
@@ -80,30 +80,30 @@ def convert_float16_to_fp8e4b15(arg, has_minx2, _builder=None):
         prmt.b32 $0, b0, b1, 0x7531;
     }"""
     return core.inline_asm_elementwise(asm, "=r,r,r", [arg], dtype=core.float8e4b15, is_pure=True, pack=4,
-                                       _builder=_builder)
+                                       _semantic=_semantic)


 @core.builtin
-def convert_custom_float8(arg, dst_ty, fp_downcast_rounding, has_minx2, _builder=None):
+def convert_custom_float8(arg, dst_ty, fp_downcast_rounding, has_minx2, _semantic=None):
     if arg.type.scalar.is_fp8e4b15():
-        upcast_val = convert_fp8e4b15_to_float16(arg, _builder=_builder)
+        upcast_val = convert_fp8e4b15_to_float16(arg, _semantic=_semantic)
         if dst_ty.scalar.is_fp32():
-            upcast_val = upcast_val.to(core.float32, _builder=_builder)
+            upcast_val = upcast_val.to(core.float32, _semantic=_semantic)
         return upcast_val

     assert arg.type.scalar.is_fp16() or arg.type.scalar.is_fp32()
     downcast_val = arg
     if arg.type.scalar.is_fp32():
-        downcast_val = downcast_val.to(core.float16, fp_downcast_rounding="rtz", _builder=_builder)
-    downcast_val = convert_float16_to_fp8e4b15(downcast_val, has_minx2=has_minx2, _builder=_builder)
+        downcast_val = downcast_val.to(core.float16, fp_downcast_rounding="rtz", _semantic=_semantic)
+    downcast_val = convert_float16_to_fp8e4b15(downcast_val, has_minx2=has_minx2, _semantic=_semantic)
     return downcast_val


 @core.builtin
-def convert_custom_float8_sm80(arg, dst_ty, fp_downcast_rounding=None, _builder=None):
-    return convert_custom_float8(arg, dst_ty, fp_downcast_rounding, has_minx2=True, _builder=_builder)
+def convert_custom_float8_sm80(arg, dst_ty, fp_downcast_rounding=None, _semantic=None):
+    return convert_custom_float8(arg, dst_ty, fp_downcast_rounding, has_minx2=True, _semantic=_semantic)


 @core.builtin
-def convert_custom_float8_sm70(arg, dst_ty, fp_downcast_rounding=None, _builder=None):
-    return convert_custom_float8(arg, dst_ty, fp_downcast_rounding, has_minx2=False, _builder=_builder)
+def convert_custom_float8_sm70(arg, dst_ty, fp_downcast_rounding=None, _semantic=None):
+    return convert_custom_float8(arg, dst_ty, fp_downcast_rounding, has_minx2=False, _semantic=_semantic)