triton-windows 3.5.0.post21__cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of triton-windows might be problematic. Click here for more details.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +82 -0
- triton/_filecheck.py +97 -0
- triton/_internal_testing.py +255 -0
- triton/_utils.py +126 -0
- triton/backends/__init__.py +47 -0
- triton/backends/amd/__init__.py +0 -0
- triton/backends/amd/compiler.py +461 -0
- triton/backends/amd/driver.c +283 -0
- triton/backends/amd/driver.py +724 -0
- triton/backends/amd/lib/asanrtl.bc +0 -0
- triton/backends/amd/lib/ockl.bc +0 -0
- triton/backends/amd/lib/ocml.bc +0 -0
- triton/backends/compiler.py +90 -0
- triton/backends/driver.py +66 -0
- triton/backends/nvidia/__init__.py +0 -0
- triton/backends/nvidia/bin/ptxas.exe +0 -0
- triton/backends/nvidia/compiler.py +533 -0
- triton/backends/nvidia/driver.c +517 -0
- triton/backends/nvidia/driver.py +799 -0
- triton/backends/nvidia/include/cuda.h +26280 -0
- triton/backends/nvidia/lib/libdevice.10.bc +0 -0
- triton/backends/nvidia/lib/x64/cuda.lib +0 -0
- triton/compiler/__init__.py +7 -0
- triton/compiler/code_generator.py +1614 -0
- triton/compiler/compiler.py +509 -0
- triton/compiler/errors.py +51 -0
- triton/compiler/make_launcher.py +0 -0
- triton/errors.py +5 -0
- triton/experimental/__init__.py +0 -0
- triton/experimental/gluon/__init__.py +5 -0
- triton/experimental/gluon/_compiler.py +0 -0
- triton/experimental/gluon/_runtime.py +102 -0
- triton/experimental/gluon/language/__init__.py +119 -0
- triton/experimental/gluon/language/_core.py +490 -0
- triton/experimental/gluon/language/_layouts.py +583 -0
- triton/experimental/gluon/language/_math.py +20 -0
- triton/experimental/gluon/language/_semantic.py +380 -0
- triton/experimental/gluon/language/_standard.py +80 -0
- triton/experimental/gluon/language/amd/__init__.py +4 -0
- triton/experimental/gluon/language/amd/_layouts.py +96 -0
- triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
- triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
- triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
- triton/experimental/gluon/language/extra/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/__init__.py +4 -0
- triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
- triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
- triton/experimental/gluon/language/nvidia/blackwell/__init__.py +387 -0
- triton/experimental/gluon/language/nvidia/blackwell/tma.py +52 -0
- triton/experimental/gluon/language/nvidia/hopper/__init__.py +132 -0
- triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +34 -0
- triton/experimental/gluon/language/nvidia/hopper/tma.py +97 -0
- triton/experimental/gluon/nvidia/__init__.py +4 -0
- triton/experimental/gluon/nvidia/blackwell.py +3 -0
- triton/experimental/gluon/nvidia/hopper.py +45 -0
- triton/knobs.py +546 -0
- triton/language/__init__.py +342 -0
- triton/language/core.py +3405 -0
- triton/language/extra/__init__.py +26 -0
- triton/language/extra/cuda/__init__.py +16 -0
- triton/language/extra/cuda/gdc.py +42 -0
- triton/language/extra/cuda/libdevice.py +1629 -0
- triton/language/extra/cuda/utils.py +109 -0
- triton/language/extra/hip/__init__.py +5 -0
- triton/language/extra/hip/libdevice.py +491 -0
- triton/language/extra/hip/utils.py +35 -0
- triton/language/extra/libdevice.py +790 -0
- triton/language/math.py +249 -0
- triton/language/random.py +218 -0
- triton/language/semantic.py +1939 -0
- triton/language/standard.py +534 -0
- triton/language/target_info.py +54 -0
- triton/runtime/__init__.py +23 -0
- triton/runtime/_allocation.py +44 -0
- triton/runtime/_async_compile.py +55 -0
- triton/runtime/autotuner.py +476 -0
- triton/runtime/build.py +168 -0
- triton/runtime/cache.py +317 -0
- triton/runtime/driver.py +38 -0
- triton/runtime/errors.py +36 -0
- triton/runtime/interpreter.py +1414 -0
- triton/runtime/jit.py +1107 -0
- triton/runtime/tcc/include/_mingw.h +168 -0
- triton/runtime/tcc/include/assert.h +62 -0
- triton/runtime/tcc/include/conio.h +409 -0
- triton/runtime/tcc/include/ctype.h +281 -0
- triton/runtime/tcc/include/dir.h +31 -0
- triton/runtime/tcc/include/direct.h +68 -0
- triton/runtime/tcc/include/dirent.h +135 -0
- triton/runtime/tcc/include/dos.h +55 -0
- triton/runtime/tcc/include/errno.h +75 -0
- triton/runtime/tcc/include/excpt.h +123 -0
- triton/runtime/tcc/include/fcntl.h +52 -0
- triton/runtime/tcc/include/fenv.h +108 -0
- triton/runtime/tcc/include/float.h +75 -0
- triton/runtime/tcc/include/inttypes.h +297 -0
- triton/runtime/tcc/include/io.h +418 -0
- triton/runtime/tcc/include/iso646.h +36 -0
- triton/runtime/tcc/include/limits.h +116 -0
- triton/runtime/tcc/include/locale.h +91 -0
- triton/runtime/tcc/include/malloc.h +181 -0
- triton/runtime/tcc/include/math.h +497 -0
- triton/runtime/tcc/include/mem.h +13 -0
- triton/runtime/tcc/include/memory.h +40 -0
- triton/runtime/tcc/include/process.h +176 -0
- triton/runtime/tcc/include/sec_api/conio_s.h +42 -0
- triton/runtime/tcc/include/sec_api/crtdbg_s.h +19 -0
- triton/runtime/tcc/include/sec_api/io_s.h +33 -0
- triton/runtime/tcc/include/sec_api/mbstring_s.h +52 -0
- triton/runtime/tcc/include/sec_api/search_s.h +25 -0
- triton/runtime/tcc/include/sec_api/stdio_s.h +145 -0
- triton/runtime/tcc/include/sec_api/stdlib_s.h +67 -0
- triton/runtime/tcc/include/sec_api/stralign_s.h +30 -0
- triton/runtime/tcc/include/sec_api/string_s.h +41 -0
- triton/runtime/tcc/include/sec_api/sys/timeb_s.h +34 -0
- triton/runtime/tcc/include/sec_api/tchar_s.h +266 -0
- triton/runtime/tcc/include/sec_api/time_s.h +61 -0
- triton/runtime/tcc/include/sec_api/wchar_s.h +128 -0
- triton/runtime/tcc/include/setjmp.h +160 -0
- triton/runtime/tcc/include/share.h +28 -0
- triton/runtime/tcc/include/signal.h +63 -0
- triton/runtime/tcc/include/stdalign.h +16 -0
- triton/runtime/tcc/include/stdarg.h +14 -0
- triton/runtime/tcc/include/stdatomic.h +171 -0
- triton/runtime/tcc/include/stdbool.h +11 -0
- triton/runtime/tcc/include/stddef.h +42 -0
- triton/runtime/tcc/include/stdint.h +212 -0
- triton/runtime/tcc/include/stdio.h +429 -0
- triton/runtime/tcc/include/stdlib.h +591 -0
- triton/runtime/tcc/include/stdnoreturn.h +7 -0
- triton/runtime/tcc/include/string.h +164 -0
- triton/runtime/tcc/include/sys/fcntl.h +13 -0
- triton/runtime/tcc/include/sys/file.h +14 -0
- triton/runtime/tcc/include/sys/locking.h +30 -0
- triton/runtime/tcc/include/sys/stat.h +290 -0
- triton/runtime/tcc/include/sys/time.h +69 -0
- triton/runtime/tcc/include/sys/timeb.h +133 -0
- triton/runtime/tcc/include/sys/types.h +123 -0
- triton/runtime/tcc/include/sys/unistd.h +14 -0
- triton/runtime/tcc/include/sys/utime.h +146 -0
- triton/runtime/tcc/include/tcc/tcc_libm.h +618 -0
- triton/runtime/tcc/include/tccdefs.h +342 -0
- triton/runtime/tcc/include/tcclib.h +80 -0
- triton/runtime/tcc/include/tchar.h +1102 -0
- triton/runtime/tcc/include/tgmath.h +89 -0
- triton/runtime/tcc/include/time.h +287 -0
- triton/runtime/tcc/include/uchar.h +33 -0
- triton/runtime/tcc/include/unistd.h +1 -0
- triton/runtime/tcc/include/vadefs.h +11 -0
- triton/runtime/tcc/include/values.h +4 -0
- triton/runtime/tcc/include/varargs.h +12 -0
- triton/runtime/tcc/include/wchar.h +873 -0
- triton/runtime/tcc/include/wctype.h +172 -0
- triton/runtime/tcc/include/winapi/basetsd.h +149 -0
- triton/runtime/tcc/include/winapi/basetyps.h +85 -0
- triton/runtime/tcc/include/winapi/guiddef.h +156 -0
- triton/runtime/tcc/include/winapi/poppack.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack1.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack2.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack4.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack8.h +8 -0
- triton/runtime/tcc/include/winapi/qos.h +72 -0
- triton/runtime/tcc/include/winapi/shellapi.h +59 -0
- triton/runtime/tcc/include/winapi/winbase.h +2958 -0
- triton/runtime/tcc/include/winapi/wincon.h +309 -0
- triton/runtime/tcc/include/winapi/windef.h +293 -0
- triton/runtime/tcc/include/winapi/windows.h +127 -0
- triton/runtime/tcc/include/winapi/winerror.h +3166 -0
- triton/runtime/tcc/include/winapi/wingdi.h +4080 -0
- triton/runtime/tcc/include/winapi/winnls.h +778 -0
- triton/runtime/tcc/include/winapi/winnt.h +5837 -0
- triton/runtime/tcc/include/winapi/winreg.h +272 -0
- triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
- triton/runtime/tcc/include/winapi/winuser.h +5651 -0
- triton/runtime/tcc/include/winapi/winver.h +160 -0
- triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
- triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
- triton/runtime/tcc/lib/cuda.def +697 -0
- triton/runtime/tcc/lib/gdi32.def +337 -0
- triton/runtime/tcc/lib/kernel32.def +770 -0
- triton/runtime/tcc/lib/libtcc1.a +0 -0
- triton/runtime/tcc/lib/msvcrt.def +1399 -0
- triton/runtime/tcc/lib/python3.def +810 -0
- triton/runtime/tcc/lib/python310.def +1610 -0
- triton/runtime/tcc/lib/python311.def +1633 -0
- triton/runtime/tcc/lib/python312.def +1703 -0
- triton/runtime/tcc/lib/python313.def +1651 -0
- triton/runtime/tcc/lib/python313t.def +1656 -0
- triton/runtime/tcc/lib/python314.def +1800 -0
- triton/runtime/tcc/lib/python314t.def +1809 -0
- triton/runtime/tcc/lib/python39.def +1644 -0
- triton/runtime/tcc/lib/python3t.def +905 -0
- triton/runtime/tcc/lib/user32.def +658 -0
- triton/runtime/tcc/libtcc.dll +0 -0
- triton/runtime/tcc/tcc.exe +0 -0
- triton/testing.py +543 -0
- triton/tools/__init__.py +0 -0
- triton/tools/build_extern.py +365 -0
- triton/tools/compile.py +210 -0
- triton/tools/disasm.py +143 -0
- triton/tools/extra/cuda/compile.c +70 -0
- triton/tools/extra/cuda/compile.h +14 -0
- triton/tools/extra/hip/compile.cpp +66 -0
- triton/tools/extra/hip/compile.h +13 -0
- triton/tools/link.py +322 -0
- triton/tools/mxfp.py +301 -0
- triton/tools/ragged_tma.py +92 -0
- triton/tools/tensor_descriptor.py +34 -0
- triton/windows_utils.py +405 -0
- triton_windows-3.5.0.post21.dist-info/METADATA +46 -0
- triton_windows-3.5.0.post21.dist-info/RECORD +217 -0
- triton_windows-3.5.0.post21.dist-info/WHEEL +5 -0
- triton_windows-3.5.0.post21.dist-info/entry_points.txt +3 -0
- triton_windows-3.5.0.post21.dist-info/licenses/LICENSE +23 -0
- triton_windows-3.5.0.post21.dist-info/top_level.txt +1 -0
triton/language/math.py
ADDED
|
@@ -0,0 +1,249 @@
|
|
|
1
|
+
from . import core
|
|
2
|
+
from functools import wraps
|
|
3
|
+
from typing import List
|
|
4
|
+
|
|
5
|
+
# Generic type variable used by the decorator annotations in this module;
# `TypeVar` is looked up through the `core` namespace.
T = core.TypeVar('T')
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def _check_dtype(dtypes: List[str]) -> T:
    """
    Build a decorator that restricts a math function to the scalar dtypes
    named in `dtypes`.

    This mirrors libdevice's convention of accepting only an explicit set of
    data types. Accelerators/GPUs lack many float16 and bfloat16 math
    operations, so instead of silently supporting every dtype the caller is
    told which dtype was rejected and is expected to cast explicitly to a
    supported one.

    :param dtypes: accepted scalar dtype names (e.g. ``"fp32"``).
    :return: a decorator; the decorated function raises ``ValueError`` when
        any positional or keyword argument that is a ``core.tensor`` has a
        scalar dtype name outside `dtypes`. Non-tensor arguments are ignored.
    """

    def wrapper(fn):

        @wraps(fn)
        def check(*args, **kwargs):
            # Positional and keyword arguments are validated alike, in call order.
            for candidate in (*args, *kwargs.values()):
                if not isinstance(candidate, core.tensor):
                    continue
                scalar_name = candidate.type.scalar.name
                if scalar_name not in dtypes:
                    raise ValueError(f"Expected dtype {dtypes} but got {scalar_name}")
            return fn(*args, **kwargs)

        return check

    return wrapper
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def _add_math_1arg_docstr(name: str) -> core.Callable[[T], T]:
    # Decorator factory: the returned decorator overwrites the function's
    # `__doc__` with a one-argument template mentioning `name`.
    # NOTE(review): the template's interior indentation is not recoverable from
    # the extracted source; its lines are reproduced flush-left as they appear.

    def _attach(func: T) -> T:
        func.__doc__ = """
Computes the element-wise {name} of :code:`x`.

:param x: the input values
:type x: Block
""".format(name=name)
        return func

    return _attach
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def _add_math_2arg_docstr(name: str) -> core.Callable[[T], T]:
    # Decorator factory: the returned decorator overwrites the function's
    # `__doc__` with a two-argument template mentioning `name`.
    # NOTE(review): the template's interior indentation is not recoverable from
    # the extracted source; its lines are reproduced flush-left as they appear.

    def _attach(func: T) -> T:
        func.__doc__ = """
Computes the element-wise {name} of :code:`x` and :code:`y`.

:param x: the input values
:type x: Block
:param y: the input values
:type y: Block
""".format(name=name)
        return func

    return _attach
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def _add_math_3arg_docstr(name: str) -> core.Callable[[T], T]:
    # Decorator factory: the returned decorator overwrites the function's
    # `__doc__` with a three-argument template mentioning `name`.
    # NOTE(review): the template's interior indentation is not recoverable from
    # the extracted source; its lines are reproduced flush-left as they appear.

    def _attach(func: T) -> T:
        func.__doc__ = """
Computes the element-wise {name} of :code:`x`, :code:`y`, and :code:`z`.

:param x: the input values
:type x: Block
:param y: the input values
:type y: Block
:param z: the input values
:type z: Block
""".format(name=name)
        return func

    return _attach
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
@core.builtin
@_check_dtype(dtypes=["int32", "int64", "uint32", "uint64"])
@_add_math_2arg_docstr("most significant N bits of the 2N-bit product")
def umulhi(x, y, _semantic=None):
    # The user-facing docstring is attached by the decorator above. The dtype
    # check runs on the raw arguments, i.e. before the `to_tensor` calls here.
    lhs, rhs = core.binary_op_type_legalization(
        _semantic.to_tensor(x),
        _semantic.to_tensor(y),
        _semantic,
    )
    high_handle = _semantic.builder.create_umulhi(lhs.handle, rhs.handle)
    # The result reuses the (legalized) type of the first operand.
    return core.tensor(high_handle, lhs.type)
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
@core.builtin
@_check_dtype(dtypes=["fp32", "fp64"])
@_add_math_1arg_docstr("exponential")
@core._tensor_member_fn
def exp(x, _semantic=None):
    # The user-facing docstring is attached by the decorator above.
    operand = _semantic.to_tensor(x)
    result_handle = _semantic.builder.create_exp(operand.handle)
    # The result keeps the operand's type.
    return core.tensor(result_handle, operand.type)
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
@core.builtin
@_check_dtype(dtypes=["fp32", "fp64"])
@_add_math_1arg_docstr("exponential (base 2)")
@core._tensor_member_fn
def exp2(x, _semantic=None):
    # The user-facing docstring is attached by the decorator above.
    operand = _semantic.to_tensor(x)
    result_handle = _semantic.builder.create_exp2(operand.handle)
    # The result keeps the operand's type.
    return core.tensor(result_handle, operand.type)
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
@core.builtin
@_check_dtype(dtypes=["fp32", "fp64"])
@_add_math_1arg_docstr("natural logarithm")
@core._tensor_member_fn
def log(x, _semantic=None):
    # The user-facing docstring is attached by the decorator above.
    operand = _semantic.to_tensor(x)
    result_handle = _semantic.builder.create_log(operand.handle)
    # The result keeps the operand's type.
    return core.tensor(result_handle, operand.type)
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
@core.builtin
@_check_dtype(dtypes=["fp32", "fp64"])
@_add_math_1arg_docstr("logarithm (base 2)")
@core._tensor_member_fn
def log2(x, _semantic=None):
    # The user-facing docstring is attached by the decorator above.
    operand = _semantic.to_tensor(x)
    result_handle = _semantic.builder.create_log2(operand.handle)
    # The result keeps the operand's type.
    return core.tensor(result_handle, operand.type)
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
@core.builtin
@_check_dtype(dtypes=["fp32", "fp64"])
@_add_math_1arg_docstr("cosine")
@core._tensor_member_fn
def cos(x, _semantic=None):
    # The user-facing docstring is attached by the decorator above.
    operand = _semantic.to_tensor(x)
    result_handle = _semantic.builder.create_cos(operand.handle)
    # The result keeps the operand's type.
    return core.tensor(result_handle, operand.type)
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
@core.builtin
@_check_dtype(dtypes=["fp32", "fp64"])
@_add_math_1arg_docstr("sine")
@core._tensor_member_fn
def sin(x, _semantic=None):
    # The user-facing docstring is attached by the decorator above.
    operand = _semantic.to_tensor(x)
    result_handle = _semantic.builder.create_sin(operand.handle)
    # The result keeps the operand's type.
    return core.tensor(result_handle, operand.type)
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
@core.builtin
@_check_dtype(dtypes=["fp32", "fp64"])
@_add_math_1arg_docstr("fast square root")
@core._tensor_member_fn
def sqrt(x, _semantic=None):
    # The user-facing docstring is attached by the decorator above.
    operand = _semantic.to_tensor(x)
    result_handle = _semantic.builder.create_sqrt(operand.handle)
    # The result keeps the operand's type.
    return core.tensor(result_handle, operand.type)
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
@core.builtin
@_check_dtype(dtypes=["fp32"])
@_add_math_1arg_docstr("precise square root (rounding to nearest wrt the IEEE standard)")
@core._tensor_member_fn
def sqrt_rn(x, _semantic=None):
    # The user-facing docstring is attached by the decorator above. Unlike
    # `sqrt`, tensor arguments are restricted to fp32 and the builder's
    # `create_precise_sqrt` is used.
    operand = _semantic.to_tensor(x)
    result_handle = _semantic.builder.create_precise_sqrt(operand.handle)
    # The result keeps the operand's type.
    return core.tensor(result_handle, operand.type)
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
@core.builtin
@_check_dtype(dtypes=["fp32", "fp64"])
@_add_math_1arg_docstr("inverse square root")
@core._tensor_member_fn
def rsqrt(x, _semantic=None):
    # The user-facing docstring is attached by the decorator above.
    operand = _semantic.to_tensor(x)
    result_handle = _semantic.builder.create_rsqrt(operand.handle)
    # The result keeps the operand's type.
    return core.tensor(result_handle, operand.type)
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
@core._tensor_member_fn
@core.builtin
@_add_math_1arg_docstr("absolute value")
def abs(x, _semantic=None):
    # The user-facing docstring is attached by the decorator above.
    # Dispatches on the dtype of `x`; raises AssertionError for a dtype that
    # none of the branches recognises.
    x = _semantic.to_tensor(x)
    dtype = x.dtype

    # fp8e4b15 is tested before the general floating-point branch and is
    # handled by AND-ing with an int8 mask of 0x7F (every bit but the top one).
    # NOTE(review): presumably `is_floating()` would also match this dtype,
    # which is why the order matters - confirm.
    if dtype.is_fp8e4b15():
        mask = core.full(x.shape, 0x7F, core.int8, _semantic=_semantic)
        return core.tensor(_semantic.builder.create_and(x.handle, mask.handle), x.type)
    if dtype.is_floating():
        return core.tensor(_semantic.builder.create_fabs(x.handle), x.type)
    if dtype.is_int_signed():
        return core.tensor(_semantic.builder.create_iabs(x.handle), x.type)
    if dtype.is_int_unsigned():
        return x  # no-op

    # Raised explicitly rather than via `assert False`: an assert statement is
    # removed under `python -O`, which would let this function fall through
    # and return None for an unexpected dtype.
    raise AssertionError(f"Unexpected dtype {dtype}")
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
@core.builtin
@_add_math_2arg_docstr("fast division")
def fdiv(x, y, ieee_rounding=False, _semantic=None):
    # The user-facing docstring is attached by the decorator above.
    # `ieee_rounding` may arrive wrapped as a constexpr, so it is unwrapped
    # before being handed to the semantic layer.
    use_ieee_rounding = core._unwrap_if_constexpr(ieee_rounding)
    numerator = _semantic.to_tensor(x)
    denominator = _semantic.to_tensor(y)
    return _semantic.fdiv(numerator, denominator, use_ieee_rounding)
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
@core.builtin
@_check_dtype(dtypes=["fp32"])
@_add_math_2arg_docstr("precise division (rounding to nearest wrt the IEEE standard)")
def div_rn(x, y, _semantic=None):
    # The user-facing docstring is attached by the decorator above. The dtype
    # check runs on the raw arguments, i.e. before the `to_tensor` calls here.
    lhs, rhs = core.binary_op_type_legalization(
        _semantic.to_tensor(x),
        _semantic.to_tensor(y),
        _semantic,
    )
    quotient_handle = _semantic.builder.create_precise_divf(lhs.handle, rhs.handle)
    # The result reuses the (legalized) type of the first operand.
    return core.tensor(quotient_handle, lhs.type)
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
@core.builtin
@_check_dtype(dtypes=["fp32", "fp64"])
@_add_math_1arg_docstr("error function")
@core._tensor_member_fn
def erf(x, _semantic=None):
    # The user-facing docstring is attached by the decorator above.
    operand = _semantic.to_tensor(x)
    result_handle = _semantic.builder.create_erf(operand.handle)
    # The result keeps the operand's type.
    return core.tensor(result_handle, operand.type)
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
@core.builtin
@_check_dtype(dtypes=["fp32", "fp64"])
@_add_math_1arg_docstr("floor")
@core._tensor_member_fn
def floor(x, _semantic=None):
    # The user-facing docstring is attached by the decorator above.
    operand = _semantic.to_tensor(x)
    result_handle = _semantic.builder.create_floor(operand.handle)
    # The result keeps the operand's type.
    return core.tensor(result_handle, operand.type)
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
@core.builtin
@_check_dtype(dtypes=["fp32", "fp64"])
@_add_math_1arg_docstr("ceil")
@core._tensor_member_fn
def ceil(x, _semantic=None):
    # The user-facing docstring is attached by the decorator above.
    operand = _semantic.to_tensor(x)
    result_handle = _semantic.builder.create_ceil(operand.handle)
    # The result keeps the operand's type.
    return core.tensor(result_handle, operand.type)
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
@core.builtin
@_add_math_3arg_docstr("fused multiply-add")
def fma(x, y, z, _semantic=None):
    # The user-facing docstring is attached by the decorator above.
    a = _semantic.to_tensor(x)
    b = _semantic.to_tensor(y)
    c = _semantic.to_tensor(z)

    # Legalize the operands pairwise, in this order: (x, y), then (z, x),
    # then (z, y). Each call rebinds both members of the pair.
    legalize = core.binary_op_type_legalization
    a, b = legalize(a, b, _semantic)
    c, a = legalize(c, a, _semantic)
    c, b = legalize(c, b, _semantic)

    fused_handle = _semantic.builder.create_fma(a.handle, b.handle, c.handle)
    # The result reuses the (legalized) type of the first operand.
    return core.tensor(fused_handle, a.type)
|
|
triton/language/random.py
ADDED
|
@@ -0,0 +1,218 @@
|
|
|
1
|
+
from ..runtime.jit import jit
|
|
2
|
+
from . import core as tl
|
|
3
|
+
from . import math
|
|
4
|
+
|
|
5
|
+
N_ROUNDS_DEFAULT = 10 # Default number of rounds for philox
|
|
6
|
+
|
|
7
|
+
# -------------------
|
|
8
|
+
# randint
|
|
9
|
+
# -------------------
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@jit
def philox_impl(c0, c1, c2, c3, k0, k1, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
    """
    Run `n_rounds` rounds of Philox for state (c0, c1, c2, c3) and key (k0, k1).

    The round multipliers and key increments are chosen from the dtype of
    `c0`: a 32-bit set for `tl.uint32` and a 64-bit set for `tl.uint64`. Any
    other dtype fails the `tl.static_assert` in the `else` branch.

    :param c0: first state word; its dtype selects the constants.
    :param c1: second state word.
    :param c2: third state word.
    :param c3: fourth state word.
    :param k0: first key word.
    :param k1: second key word.
    :param n_rounds: number of rounds to apply (constexpr).
    :return: the tuple `(c0, c1, c2, c3)` after the last round.
    """
    if c0.dtype == tl.uint32:
        PHILOX_KEY_A: tl.constexpr = 0x9E3779B9
        PHILOX_KEY_B: tl.constexpr = 0xBB67AE85
        PHILOX_ROUND_A: tl.constexpr = 0xD2511F53
        PHILOX_ROUND_B: tl.constexpr = 0xCD9E8D57
    else:
        tl.static_assert(c0.dtype == tl.uint64, "dtype not supported in philox_impl")
        PHILOX_KEY_A: tl.constexpr = 0x9E3779B97F4A7C15
        PHILOX_KEY_B: tl.constexpr = 0xBB67AE8584CAA73B
        PHILOX_ROUND_A: tl.constexpr = 0xD2E7470EE14C6C93
        PHILOX_ROUND_B: tl.constexpr = 0xCA5A826395121157

    for _ in tl.static_range(n_rounds):
        # for _ in range(n_rounds):
        # update random state
        A = PHILOX_ROUND_A
        B = PHILOX_ROUND_B
        # Keep the pre-round c0/c2: both are overwritten below before the
        # c1/c3 products read them.
        _c0, _c2 = c0, c2
        # umulhi supplies the most significant half of each product.
        c0 = math.umulhi(B, _c2) ^ c1 ^ k0
        c2 = math.umulhi(A, _c0) ^ c3 ^ k1
        # NOTE(review): presumably the wrapped low half of the same products;
        # overflow sanitizing is explicitly switched off - confirm.
        c1 = tl.mul(B, _c2, sanitize_overflow=False)
        c3 = tl.mul(A, _c0, sanitize_overflow=False)
        # raise key
        k0 = tl.add(k0, PHILOX_KEY_A, sanitize_overflow=False)
        k1 = tl.add(k1, PHILOX_KEY_B, sanitize_overflow=False)
    return c0, c1, c2, c3
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
@jit
def philox(seed, c0, c1, c2, c3, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
    """
    Prepare the key and counter words and run `philox_impl` on them.

    `seed` must have an integer dtype and is converted to `tl.uint64`. The
    bit width of `c0` picks the working integer type:

    - 32 bits: the seed is split into its low and high 32-bit halves, which
      become the two key words;
    - 64 bits: the whole seed is the first key word and the second key word
      is a zero of shape `(1, )`;
    - anything else fails the `tl.static_assert` in the `else` branch.

    :param seed: integer seed.
    :param c0: first counter word; its bit width selects 32- or 64-bit mode.
    :param c1: second counter word.
    :param c2: third counter word.
    :param c3: fourth counter word.
    :param n_rounds: number of Philox rounds (constexpr).
    :return: the four words returned by `philox_impl`.
    """
    seed = tl.to_tensor(seed)
    tl.static_assert(seed.dtype.is_int())
    seed = seed.to(tl.uint64)
    c0 = tl.to_tensor(c0)
    c1 = tl.to_tensor(c1)
    c2 = tl.to_tensor(c2)
    c3 = tl.to_tensor(c3)

    # Only c0's width is inspected.
    # NOTE(review): c1..c3 are assumed to share that width - confirm with callers.
    if tl.constexpr(c0.dtype.primitive_bitwidth) == 32:
        int_dtype = tl.uint32
        seed_hi = ((seed >> 32) & 0xffffffff).to(tl.uint32)
        seed_lo = (seed & 0xffffffff).to(tl.uint32)
    else:
        tl.static_assert(tl.constexpr(c0.dtype.primitive_bitwidth) == 64, "bitwidth not supported in philox")
        int_dtype = tl.uint64
        seed_hi = tl.full((1, ), 0, dtype=int_dtype)
        seed_lo = seed

    # Reinterpret the counters as the unsigned type chosen above (bitcast).
    c0 = c0.to(int_dtype, bitcast=True)
    c1 = c1.to(int_dtype, bitcast=True)
    c2 = c2.to(int_dtype, bitcast=True)
    c3 = c3.to(int_dtype, bitcast=True)
    return philox_impl(c0, c1, c2, c3, seed_lo, seed_hi, n_rounds)
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
@jit
def randint(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
    """
    Given a :code:`seed` scalar and an :code:`offset` block, returns a single
    block of random :code:`int32`.

    If you need multiple streams of random numbers,
    using `randint4x` is likely to be faster than calling `randint` 4 times.

    :param seed: The seed for generating random numbers.
    :param offset: The offsets to generate random numbers for.
    :param n_rounds: Number of Philox rounds (constexpr).
    """
    # Only the first of the four blocks produced by randint4x is returned.
    # NOTE(review): on this path philox bitcasts its counters to tl.uint32, so
    # the block is presumably unsigned - confirm before relying on signedness.
    ret, _, _, _ = randint4x(seed, offset, n_rounds)
    return ret
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
@jit
def randint4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
    """
    Given a :code:`seed` scalar and an :code:`offset` block, returns four
    blocks of random :code:`int32`.

    This is the maximally efficient entry point
    to Triton's Philox pseudo-random number generator.

    :param seed: The seed for generating random numbers.
    :param offset: The offsets to generate random numbers for.
    :param n_rounds: Number of Philox rounds (constexpr).
    """
    # _0 = tl.zeros(offset.shape, offset.dtype)

    # The counter passed to philox is (offset_lo, offset_hi, 0, 0).
    offset_lo = offset.to(tl.uint32)
    # Zero block derived from offset_lo by multiplying it by 0.
    _0 = offset_lo * 0

    # Offsets wider than 32 bits contribute their upper bits as the second
    # counter word; otherwise that word is zero.
    if tl.constexpr(offset.dtype.primitive_bitwidth) > 32:
        offset_hi = (offset >> 32).to(tl.uint32)
    else:
        offset_hi = _0

    return philox(seed, offset_lo, offset_hi, _0, _0, n_rounds)
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
# -------------------
|
|
114
|
+
# rand
|
|
115
|
+
# -------------------
|
|
116
|
+
|
|
117
|
+
# @jit
|
|
118
|
+
# def uint32_to_uniform_float(x):
|
|
119
|
+
# """
|
|
120
|
+
# Numerically stable function to convert a random uint32 into a random float uniformly sampled in [0, 1).
|
|
121
|
+
# """
|
|
122
|
+
# two_to_the_minus_32: tl.constexpr = 2.328306e-10
|
|
123
|
+
# return x * two_to_the_minus_32
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
@jit
def uint_to_uniform_float(x):
    """
    Numerically stable function to convert a random uint into a random float uniformly sampled in [0, 1).

    Accepts 32-bit (`tl.uint32` / `tl.int32`) or 64-bit (`tl.uint64` /
    `tl.int64`) integer input; any other dtype fails the `tl.static_assert`
    in the `else` branch.

    :param x: block of random integers.
    :return: `x`, reinterpreted as signed and folded to non-negative values,
        multiplied by a width-dependent scale.
    """
    # TODO: fix frontend issues and cleanup
    # conditions can be simplified
    # scale is approximately ((2**23 - 1) / 2**23) * 2**-(N_BITS - 1)
    if tl.constexpr(x.dtype == tl.uint32) or tl.constexpr(x.dtype == tl.int32):
        # Reinterpret the bits as a signed 32-bit integer.
        x = x.to(tl.int32, bitcast=True)
        # maximum value such that `MAX_INT * scale < 1.0` (with float rounding)
        scale = 4.6566127342e-10
    else:
        tl.static_assert(tl.constexpr(x.dtype == tl.uint64) or tl.constexpr(x.dtype == tl.int64))
        # Reinterpret the bits as a signed 64-bit integer.
        x = x.to(tl.int64, bitcast=True)
        scale = 1.0842020432385337e-19
    # Fold negative values onto the non-negative range: x < 0 becomes -x - 1.
    x = tl.where(x < 0, -x - 1, x)
    return x * scale
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
@jit
def rand(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
    """
    Given a :code:`seed` scalar and an :code:`offset` block,
    returns a block of random :code:`float32` in :math:`U(0, 1)`.

    :param seed: The seed for generating random numbers.
    :param offset: The offsets to generate random numbers for.
    :param n_rounds: Number of Philox rounds (constexpr).
    """
    # Draw one block of random integers, then map it to floats.
    source = randint(seed, offset, n_rounds)
    return uint_to_uniform_float(source)
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
@jit
def rand4x(seed, offsets, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
    """
    Given a :code:`seed` scalar and an :code:`offsets` block,
    returns 4 blocks of random :code:`float32` in :math:`U(0, 1)`.

    :param seed: The seed for generating random numbers.
    :param offsets: The offsets to generate random numbers for.
    :param n_rounds: Number of Philox rounds (constexpr).
    """
    # One randint4x call yields four integer blocks; each is mapped to floats
    # independently.
    i1, i2, i3, i4 = randint4x(seed, offsets, n_rounds)
    u1 = uint_to_uniform_float(i1)
    u2 = uint_to_uniform_float(i2)
    u3 = uint_to_uniform_float(i3)
    u4 = uint_to_uniform_float(i4)
    return u1, u2, u3, u4
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
# -------------------
|
|
177
|
+
# randn
|
|
178
|
+
# -------------------
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
@jit
def pair_uniform_to_normal(u1, u2):
    """
    Box-Muller transform

    :param u1: first block of uniform samples; used for the radius.
    :param u2: second block of uniform samples; used for the angle.
    :return: the pair `(r * cos(th), r * sin(th))`.
    """
    # Clamp u1 from below so the logarithm is never evaluated at 0.
    u1 = tl.maximum(1.0e-7, u1)
    # Angle: 2 * pi * u2.
    th = 6.283185307179586 * u2
    # Radius: sqrt(-2 * ln(u1)).
    r = math.sqrt(-2.0 * math.log(u1))
    return r * math.cos(th), r * math.sin(th)
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
@jit
def randn(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
    """
    Given a :code:`seed` scalar and an :code:`offset` block,
    returns a block of random :code:`float32` in :math:`\\mathcal{N}(0, 1)`.

    :param seed: The seed for generating random numbers.
    :param offset: The offsets to generate random numbers for.
    :param n_rounds: Number of Philox rounds (constexpr).
    """
    # Only two of the four integer blocks are needed for one Box-Muller pair.
    i1, i2, _, _ = randint4x(seed, offset, n_rounds)
    u1 = uint_to_uniform_float(i1)
    u2 = uint_to_uniform_float(i2)
    # The second normal of the pair is discarded.
    n1, _ = pair_uniform_to_normal(u1, u2)
    return n1
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
@jit
def randn4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
    """
    Given a :code:`seed` scalar and an :code:`offset` block,
    returns 4 blocks of random :code:`float32` in :math:`\\mathcal{N}(0, 1)`.

    :param seed: The seed for generating random numbers.
    :param offset: The offsets to generate random numbers for.
    :param n_rounds: Number of Philox rounds (constexpr).
    """
    # Four uniform blocks feed two Box-Muller pairs: (u1, u2) and (u3, u4).
    u1, u2, u3, u4 = rand4x(seed, offset, n_rounds)
    n1, n2 = pair_uniform_to_normal(u1, u2)
    n3, n4 = pair_uniform_to_normal(u3, u4)
    return n1, n2, n3, n4
|