triton-windows 3.5.0.post21__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of triton-windows might be problematic. Click here for more details.

Files changed (217)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +82 -0
  3. triton/_filecheck.py +97 -0
  4. triton/_internal_testing.py +255 -0
  5. triton/_utils.py +126 -0
  6. triton/backends/__init__.py +47 -0
  7. triton/backends/amd/__init__.py +0 -0
  8. triton/backends/amd/compiler.py +461 -0
  9. triton/backends/amd/driver.c +283 -0
  10. triton/backends/amd/driver.py +724 -0
  11. triton/backends/amd/lib/asanrtl.bc +0 -0
  12. triton/backends/amd/lib/ockl.bc +0 -0
  13. triton/backends/amd/lib/ocml.bc +0 -0
  14. triton/backends/compiler.py +90 -0
  15. triton/backends/driver.py +66 -0
  16. triton/backends/nvidia/__init__.py +0 -0
  17. triton/backends/nvidia/bin/ptxas.exe +0 -0
  18. triton/backends/nvidia/compiler.py +533 -0
  19. triton/backends/nvidia/driver.c +517 -0
  20. triton/backends/nvidia/driver.py +799 -0
  21. triton/backends/nvidia/include/cuda.h +26280 -0
  22. triton/backends/nvidia/lib/libdevice.10.bc +0 -0
  23. triton/backends/nvidia/lib/x64/cuda.lib +0 -0
  24. triton/compiler/__init__.py +7 -0
  25. triton/compiler/code_generator.py +1614 -0
  26. triton/compiler/compiler.py +509 -0
  27. triton/compiler/errors.py +51 -0
  28. triton/compiler/make_launcher.py +0 -0
  29. triton/errors.py +5 -0
  30. triton/experimental/__init__.py +0 -0
  31. triton/experimental/gluon/__init__.py +5 -0
  32. triton/experimental/gluon/_compiler.py +0 -0
  33. triton/experimental/gluon/_runtime.py +102 -0
  34. triton/experimental/gluon/language/__init__.py +119 -0
  35. triton/experimental/gluon/language/_core.py +490 -0
  36. triton/experimental/gluon/language/_layouts.py +583 -0
  37. triton/experimental/gluon/language/_math.py +20 -0
  38. triton/experimental/gluon/language/_semantic.py +380 -0
  39. triton/experimental/gluon/language/_standard.py +80 -0
  40. triton/experimental/gluon/language/amd/__init__.py +4 -0
  41. triton/experimental/gluon/language/amd/_layouts.py +96 -0
  42. triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
  43. triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
  44. triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
  45. triton/experimental/gluon/language/extra/__init__.py +3 -0
  46. triton/experimental/gluon/language/nvidia/__init__.py +4 -0
  47. triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
  48. triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
  49. triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
  50. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +387 -0
  51. triton/experimental/gluon/language/nvidia/blackwell/tma.py +52 -0
  52. triton/experimental/gluon/language/nvidia/hopper/__init__.py +132 -0
  53. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +34 -0
  54. triton/experimental/gluon/language/nvidia/hopper/tma.py +97 -0
  55. triton/experimental/gluon/nvidia/__init__.py +4 -0
  56. triton/experimental/gluon/nvidia/blackwell.py +3 -0
  57. triton/experimental/gluon/nvidia/hopper.py +45 -0
  58. triton/knobs.py +546 -0
  59. triton/language/__init__.py +342 -0
  60. triton/language/core.py +3405 -0
  61. triton/language/extra/__init__.py +26 -0
  62. triton/language/extra/cuda/__init__.py +16 -0
  63. triton/language/extra/cuda/gdc.py +42 -0
  64. triton/language/extra/cuda/libdevice.py +1629 -0
  65. triton/language/extra/cuda/utils.py +109 -0
  66. triton/language/extra/hip/__init__.py +5 -0
  67. triton/language/extra/hip/libdevice.py +491 -0
  68. triton/language/extra/hip/utils.py +35 -0
  69. triton/language/extra/libdevice.py +790 -0
  70. triton/language/math.py +249 -0
  71. triton/language/random.py +218 -0
  72. triton/language/semantic.py +1939 -0
  73. triton/language/standard.py +534 -0
  74. triton/language/target_info.py +54 -0
  75. triton/runtime/__init__.py +23 -0
  76. triton/runtime/_allocation.py +44 -0
  77. triton/runtime/_async_compile.py +55 -0
  78. triton/runtime/autotuner.py +476 -0
  79. triton/runtime/build.py +168 -0
  80. triton/runtime/cache.py +317 -0
  81. triton/runtime/driver.py +38 -0
  82. triton/runtime/errors.py +36 -0
  83. triton/runtime/interpreter.py +1414 -0
  84. triton/runtime/jit.py +1107 -0
  85. triton/runtime/tcc/include/_mingw.h +168 -0
  86. triton/runtime/tcc/include/assert.h +62 -0
  87. triton/runtime/tcc/include/conio.h +409 -0
  88. triton/runtime/tcc/include/ctype.h +281 -0
  89. triton/runtime/tcc/include/dir.h +31 -0
  90. triton/runtime/tcc/include/direct.h +68 -0
  91. triton/runtime/tcc/include/dirent.h +135 -0
  92. triton/runtime/tcc/include/dos.h +55 -0
  93. triton/runtime/tcc/include/errno.h +75 -0
  94. triton/runtime/tcc/include/excpt.h +123 -0
  95. triton/runtime/tcc/include/fcntl.h +52 -0
  96. triton/runtime/tcc/include/fenv.h +108 -0
  97. triton/runtime/tcc/include/float.h +75 -0
  98. triton/runtime/tcc/include/inttypes.h +297 -0
  99. triton/runtime/tcc/include/io.h +418 -0
  100. triton/runtime/tcc/include/iso646.h +36 -0
  101. triton/runtime/tcc/include/limits.h +116 -0
  102. triton/runtime/tcc/include/locale.h +91 -0
  103. triton/runtime/tcc/include/malloc.h +181 -0
  104. triton/runtime/tcc/include/math.h +497 -0
  105. triton/runtime/tcc/include/mem.h +13 -0
  106. triton/runtime/tcc/include/memory.h +40 -0
  107. triton/runtime/tcc/include/process.h +176 -0
  108. triton/runtime/tcc/include/sec_api/conio_s.h +42 -0
  109. triton/runtime/tcc/include/sec_api/crtdbg_s.h +19 -0
  110. triton/runtime/tcc/include/sec_api/io_s.h +33 -0
  111. triton/runtime/tcc/include/sec_api/mbstring_s.h +52 -0
  112. triton/runtime/tcc/include/sec_api/search_s.h +25 -0
  113. triton/runtime/tcc/include/sec_api/stdio_s.h +145 -0
  114. triton/runtime/tcc/include/sec_api/stdlib_s.h +67 -0
  115. triton/runtime/tcc/include/sec_api/stralign_s.h +30 -0
  116. triton/runtime/tcc/include/sec_api/string_s.h +41 -0
  117. triton/runtime/tcc/include/sec_api/sys/timeb_s.h +34 -0
  118. triton/runtime/tcc/include/sec_api/tchar_s.h +266 -0
  119. triton/runtime/tcc/include/sec_api/time_s.h +61 -0
  120. triton/runtime/tcc/include/sec_api/wchar_s.h +128 -0
  121. triton/runtime/tcc/include/setjmp.h +160 -0
  122. triton/runtime/tcc/include/share.h +28 -0
  123. triton/runtime/tcc/include/signal.h +63 -0
  124. triton/runtime/tcc/include/stdalign.h +16 -0
  125. triton/runtime/tcc/include/stdarg.h +14 -0
  126. triton/runtime/tcc/include/stdatomic.h +171 -0
  127. triton/runtime/tcc/include/stdbool.h +11 -0
  128. triton/runtime/tcc/include/stddef.h +42 -0
  129. triton/runtime/tcc/include/stdint.h +212 -0
  130. triton/runtime/tcc/include/stdio.h +429 -0
  131. triton/runtime/tcc/include/stdlib.h +591 -0
  132. triton/runtime/tcc/include/stdnoreturn.h +7 -0
  133. triton/runtime/tcc/include/string.h +164 -0
  134. triton/runtime/tcc/include/sys/fcntl.h +13 -0
  135. triton/runtime/tcc/include/sys/file.h +14 -0
  136. triton/runtime/tcc/include/sys/locking.h +30 -0
  137. triton/runtime/tcc/include/sys/stat.h +290 -0
  138. triton/runtime/tcc/include/sys/time.h +69 -0
  139. triton/runtime/tcc/include/sys/timeb.h +133 -0
  140. triton/runtime/tcc/include/sys/types.h +123 -0
  141. triton/runtime/tcc/include/sys/unistd.h +14 -0
  142. triton/runtime/tcc/include/sys/utime.h +146 -0
  143. triton/runtime/tcc/include/tcc/tcc_libm.h +618 -0
  144. triton/runtime/tcc/include/tccdefs.h +342 -0
  145. triton/runtime/tcc/include/tcclib.h +80 -0
  146. triton/runtime/tcc/include/tchar.h +1102 -0
  147. triton/runtime/tcc/include/tgmath.h +89 -0
  148. triton/runtime/tcc/include/time.h +287 -0
  149. triton/runtime/tcc/include/uchar.h +33 -0
  150. triton/runtime/tcc/include/unistd.h +1 -0
  151. triton/runtime/tcc/include/vadefs.h +11 -0
  152. triton/runtime/tcc/include/values.h +4 -0
  153. triton/runtime/tcc/include/varargs.h +12 -0
  154. triton/runtime/tcc/include/wchar.h +873 -0
  155. triton/runtime/tcc/include/wctype.h +172 -0
  156. triton/runtime/tcc/include/winapi/basetsd.h +149 -0
  157. triton/runtime/tcc/include/winapi/basetyps.h +85 -0
  158. triton/runtime/tcc/include/winapi/guiddef.h +156 -0
  159. triton/runtime/tcc/include/winapi/poppack.h +8 -0
  160. triton/runtime/tcc/include/winapi/pshpack1.h +8 -0
  161. triton/runtime/tcc/include/winapi/pshpack2.h +8 -0
  162. triton/runtime/tcc/include/winapi/pshpack4.h +8 -0
  163. triton/runtime/tcc/include/winapi/pshpack8.h +8 -0
  164. triton/runtime/tcc/include/winapi/qos.h +72 -0
  165. triton/runtime/tcc/include/winapi/shellapi.h +59 -0
  166. triton/runtime/tcc/include/winapi/winbase.h +2958 -0
  167. triton/runtime/tcc/include/winapi/wincon.h +309 -0
  168. triton/runtime/tcc/include/winapi/windef.h +293 -0
  169. triton/runtime/tcc/include/winapi/windows.h +127 -0
  170. triton/runtime/tcc/include/winapi/winerror.h +3166 -0
  171. triton/runtime/tcc/include/winapi/wingdi.h +4080 -0
  172. triton/runtime/tcc/include/winapi/winnls.h +778 -0
  173. triton/runtime/tcc/include/winapi/winnt.h +5837 -0
  174. triton/runtime/tcc/include/winapi/winreg.h +272 -0
  175. triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
  176. triton/runtime/tcc/include/winapi/winuser.h +5651 -0
  177. triton/runtime/tcc/include/winapi/winver.h +160 -0
  178. triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
  179. triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
  180. triton/runtime/tcc/lib/cuda.def +697 -0
  181. triton/runtime/tcc/lib/gdi32.def +337 -0
  182. triton/runtime/tcc/lib/kernel32.def +770 -0
  183. triton/runtime/tcc/lib/libtcc1.a +0 -0
  184. triton/runtime/tcc/lib/msvcrt.def +1399 -0
  185. triton/runtime/tcc/lib/python3.def +810 -0
  186. triton/runtime/tcc/lib/python310.def +1610 -0
  187. triton/runtime/tcc/lib/python311.def +1633 -0
  188. triton/runtime/tcc/lib/python312.def +1703 -0
  189. triton/runtime/tcc/lib/python313.def +1651 -0
  190. triton/runtime/tcc/lib/python313t.def +1656 -0
  191. triton/runtime/tcc/lib/python314.def +1800 -0
  192. triton/runtime/tcc/lib/python314t.def +1809 -0
  193. triton/runtime/tcc/lib/python39.def +1644 -0
  194. triton/runtime/tcc/lib/python3t.def +905 -0
  195. triton/runtime/tcc/lib/user32.def +658 -0
  196. triton/runtime/tcc/libtcc.dll +0 -0
  197. triton/runtime/tcc/tcc.exe +0 -0
  198. triton/testing.py +543 -0
  199. triton/tools/__init__.py +0 -0
  200. triton/tools/build_extern.py +365 -0
  201. triton/tools/compile.py +210 -0
  202. triton/tools/disasm.py +143 -0
  203. triton/tools/extra/cuda/compile.c +70 -0
  204. triton/tools/extra/cuda/compile.h +14 -0
  205. triton/tools/extra/hip/compile.cpp +66 -0
  206. triton/tools/extra/hip/compile.h +13 -0
  207. triton/tools/link.py +322 -0
  208. triton/tools/mxfp.py +301 -0
  209. triton/tools/ragged_tma.py +92 -0
  210. triton/tools/tensor_descriptor.py +34 -0
  211. triton/windows_utils.py +405 -0
  212. triton_windows-3.5.0.post21.dist-info/METADATA +46 -0
  213. triton_windows-3.5.0.post21.dist-info/RECORD +217 -0
  214. triton_windows-3.5.0.post21.dist-info/WHEEL +5 -0
  215. triton_windows-3.5.0.post21.dist-info/entry_points.txt +3 -0
  216. triton_windows-3.5.0.post21.dist-info/licenses/LICENSE +23 -0
  217. triton_windows-3.5.0.post21.dist-info/top_level.txt +1 -0
@@ -0,0 +1,249 @@
1
+ from . import core
2
+ from functools import wraps
3
+ from typing import List
4
+
5
# Generic type variable shared by the decorator-factory annotations in this module.
T = core.TypeVar("T")
6
+
7
+
8
def _check_dtype(dtypes: List[str]) -> core.Callable[[T], T]:
    """
    Return a decorator that validates the scalar dtype of every tensor argument.

    We're following libdevice's convention to check accepted data types for math functions.
    It is not a good practice to support all data types as accelerators/GPUs don't support
    many float16 and bfloat16 math operations.
    We should let users know which data type they are using so they can insert an explicit
    cast to a supported one.

    :param dtypes: scalar dtype names (e.g. ``"fp32"``) that tensor arguments may have.
    :return: a decorator. The decorated function raises ``ValueError`` when any positional
        or keyword argument that is a ``core.tensor`` has a scalar dtype not listed in
        ``dtypes``. Non-tensor arguments (Python scalars, ``_semantic``, ...) are not checked.
    """

    def wrapper(fn: T) -> T:

        @wraps(fn)
        def check(*args, **kwargs):
            # Positional arguments first, then keyword values: the first offending
            # tensor reported is the same one the original traversal would find.
            for arg in (*args, *kwargs.values()):
                if isinstance(arg, core.tensor) and arg.type.scalar.name not in dtypes:
                    raise ValueError(f"Expected dtype {dtypes} but got {arg.type.scalar.name}")
            return fn(*args, **kwargs)

        return check

    return wrapper
31
+
32
+
33
def _add_math_1arg_docstr(name: str) -> core.Callable[[T], T]:
    """Return a decorator that installs the standard one-argument math docstring for ``name``."""
    rendered = """
    Computes the element-wise {name} of :code:`x`.

    :param x: the input values
    :type x: Block
    """.format(name=name)

    def _decorator(func: T) -> T:
        # Overwrites any docstring the decorated function already had.
        func.__doc__ = rendered
        return func

    return _decorator
46
+
47
+
48
def _add_math_2arg_docstr(name: str) -> core.Callable[[T], T]:
    """Return a decorator that installs the standard two-argument math docstring for ``name``."""
    rendered = """
    Computes the element-wise {name} of :code:`x` and :code:`y`.

    :param x: the input values
    :type x: Block
    :param y: the input values
    :type y: Block
    """.format(name=name)

    def _decorator(func: T) -> T:
        # Overwrites any docstring the decorated function already had.
        func.__doc__ = rendered
        return func

    return _decorator
63
+
64
+
65
def _add_math_3arg_docstr(name: str) -> core.Callable[[T], T]:
    """Return a decorator that installs the standard three-argument math docstring for ``name``."""
    rendered = """
    Computes the element-wise {name} of :code:`x`, :code:`y`, and :code:`z`.

    :param x: the input values
    :type x: Block
    :param y: the input values
    :type y: Block
    :param z: the input values
    :type z: Block
    """.format(name=name)

    def _decorator(func: T) -> T:
        # Overwrites any docstring the decorated function already had.
        func.__doc__ = rendered
        return func

    return _decorator
82
+
83
+
84
@core.builtin
@_check_dtype(dtypes=["int32", "int64", "uint32", "uint64"])
@_add_math_2arg_docstr("most significant N bits of the 2N-bit product")
def umulhi(x, y, _semantic=None):
    # Both operands are converted to tensors (x first, then y) and passed through
    # `core.binary_op_type_legalization` before emitting the builder's
    # `create_umulhi`. The result carries the legalized type of `x`.
    # Public docs are generated by `_add_math_2arg_docstr`, which replaces `__doc__`.
    lhs, rhs = core.binary_op_type_legalization(_semantic.to_tensor(x), _semantic.to_tensor(y), _semantic)
    high_bits = _semantic.builder.create_umulhi(lhs.handle, rhs.handle)
    return core.tensor(high_bits, lhs.type)
92
+
93
+
94
@core.builtin
@_check_dtype(dtypes=["fp32", "fp64"])
@_add_math_1arg_docstr("exponential")
@core._tensor_member_fn
def exp(x, _semantic=None):
    # Emits the builder's `create_exp`. Tensor arguments must be fp32/fp64
    # (enforced by `_check_dtype`), and the result keeps the input tensor's type.
    # Public docs are generated by `_add_math_1arg_docstr`, which replaces `__doc__`.
    operand = _semantic.to_tensor(x)
    out_handle = _semantic.builder.create_exp(operand.handle)
    return core.tensor(out_handle, operand.type)
101
+
102
+
103
@core.builtin
@_check_dtype(dtypes=["fp32", "fp64"])
@_add_math_1arg_docstr("exponential (base 2)")
@core._tensor_member_fn
def exp2(x, _semantic=None):
    # Emits the builder's `create_exp2`. Tensor arguments must be fp32/fp64
    # (enforced by `_check_dtype`), and the result keeps the input tensor's type.
    # Public docs are generated by `_add_math_1arg_docstr`, which replaces `__doc__`.
    operand = _semantic.to_tensor(x)
    out_handle = _semantic.builder.create_exp2(operand.handle)
    return core.tensor(out_handle, operand.type)
110
+
111
+
112
@core.builtin
@_check_dtype(dtypes=["fp32", "fp64"])
@_add_math_1arg_docstr("natural logarithm")
@core._tensor_member_fn
def log(x, _semantic=None):
    # Emits the builder's `create_log`. Tensor arguments must be fp32/fp64
    # (enforced by `_check_dtype`), and the result keeps the input tensor's type.
    # Public docs are generated by `_add_math_1arg_docstr`, which replaces `__doc__`.
    operand = _semantic.to_tensor(x)
    out_handle = _semantic.builder.create_log(operand.handle)
    return core.tensor(out_handle, operand.type)
119
+
120
+
121
@core.builtin
@_check_dtype(dtypes=["fp32", "fp64"])
@_add_math_1arg_docstr("logarithm (base 2)")
@core._tensor_member_fn
def log2(x, _semantic=None):
    # Emits the builder's `create_log2`. Tensor arguments must be fp32/fp64
    # (enforced by `_check_dtype`), and the result keeps the input tensor's type.
    # Public docs are generated by `_add_math_1arg_docstr`, which replaces `__doc__`.
    operand = _semantic.to_tensor(x)
    out_handle = _semantic.builder.create_log2(operand.handle)
    return core.tensor(out_handle, operand.type)
128
+
129
+
130
@core.builtin
@_check_dtype(dtypes=["fp32", "fp64"])
@_add_math_1arg_docstr("cosine")
@core._tensor_member_fn
def cos(x, _semantic=None):
    # Emits the builder's `create_cos`. Tensor arguments must be fp32/fp64
    # (enforced by `_check_dtype`), and the result keeps the input tensor's type.
    # Public docs are generated by `_add_math_1arg_docstr`, which replaces `__doc__`.
    operand = _semantic.to_tensor(x)
    out_handle = _semantic.builder.create_cos(operand.handle)
    return core.tensor(out_handle, operand.type)
137
+
138
+
139
@core.builtin
@_check_dtype(dtypes=["fp32", "fp64"])
@_add_math_1arg_docstr("sine")
@core._tensor_member_fn
def sin(x, _semantic=None):
    # Emits the builder's `create_sin`. Tensor arguments must be fp32/fp64
    # (enforced by `_check_dtype`), and the result keeps the input tensor's type.
    # Public docs are generated by `_add_math_1arg_docstr`, which replaces `__doc__`.
    operand = _semantic.to_tensor(x)
    out_handle = _semantic.builder.create_sin(operand.handle)
    return core.tensor(out_handle, operand.type)
146
+
147
+
148
@core.builtin
@_check_dtype(dtypes=["fp32", "fp64"])
@_add_math_1arg_docstr("fast square root")
@core._tensor_member_fn
def sqrt(x, _semantic=None):
    # Emits the builder's `create_sqrt` (see `sqrt_rn` for the precise variant).
    # Tensor arguments must be fp32/fp64 (enforced by `_check_dtype`), and the
    # result keeps the input tensor's type.
    # Public docs are generated by `_add_math_1arg_docstr`, which replaces `__doc__`.
    operand = _semantic.to_tensor(x)
    out_handle = _semantic.builder.create_sqrt(operand.handle)
    return core.tensor(out_handle, operand.type)
155
+
156
+
157
@core.builtin
@_check_dtype(dtypes=["fp32"])
@_add_math_1arg_docstr("precise square root (rounding to nearest wrt the IEEE standard)")
@core._tensor_member_fn
def sqrt_rn(x, _semantic=None):
    # Emits the builder's `create_precise_sqrt`. Unlike `sqrt`, only fp32 tensors
    # are accepted by `_check_dtype`. The result keeps the input tensor's type.
    # Public docs are generated by `_add_math_1arg_docstr`, which replaces `__doc__`.
    operand = _semantic.to_tensor(x)
    out_handle = _semantic.builder.create_precise_sqrt(operand.handle)
    return core.tensor(out_handle, operand.type)
164
+
165
+
166
@core.builtin
@_check_dtype(dtypes=["fp32", "fp64"])
@_add_math_1arg_docstr("inverse square root")
@core._tensor_member_fn
def rsqrt(x, _semantic=None):
    # Emits the builder's `create_rsqrt`. Tensor arguments must be fp32/fp64
    # (enforced by `_check_dtype`), and the result keeps the input tensor's type.
    # Public docs are generated by `_add_math_1arg_docstr`, which replaces `__doc__`.
    operand = _semantic.to_tensor(x)
    out_handle = _semantic.builder.create_rsqrt(operand.handle)
    return core.tensor(out_handle, operand.type)
173
+
174
+
175
@core._tensor_member_fn
@core.builtin
@_add_math_1arg_docstr("absolute value")
def abs(x, _semantic=None):
    # Element-wise absolute value, dispatched on the scalar dtype:
    #   * fp8e4b15: AND every element with an int8 0x7F mask, clearing the top bit;
    #   * other floating types: the builder's `create_fabs`;
    #   * signed integers: the builder's `create_iabs`;
    #   * unsigned integers: returned unchanged.
    # Public docs are generated by `_add_math_1arg_docstr`, which replaces `__doc__`.
    x = _semantic.to_tensor(x)
    dtype = x.dtype
    if dtype.is_fp8e4b15():
        # Checked before `is_floating()` so this format takes the bit-mask path.
        mask = core.full(x.shape, 0x7F, core.int8, _semantic=_semantic)
        return core.tensor(_semantic.builder.create_and(x.handle, mask.handle), x.type)
    if dtype.is_floating():
        return core.tensor(_semantic.builder.create_fabs(x.handle), x.type)
    if dtype.is_int_signed():
        return core.tensor(_semantic.builder.create_iabs(x.handle), x.type)
    if dtype.is_int_unsigned():
        return x  # no-op
    # Explicit raise instead of `assert False`: asserts are stripped under
    # `python -O`, which would make this builtin silently return None. The
    # exception type and message are unchanged.
    raise AssertionError(f"Unexpected dtype {dtype}")
192
+
193
+
194
@core.builtin
@_add_math_2arg_docstr("fast division")
def fdiv(x, y, ieee_rounding=False, _semantic=None):
    # Thin wrapper over `_semantic.fdiv`. `ieee_rounding` may arrive wrapped as a
    # constexpr, so it is unwrapped to a plain value before the operands are converted.
    # NOTE(review): the generated docstring does not mention `ieee_rounding`; how it
    # is honoured is decided inside `_semantic.fdiv`, which is not visible here.
    rounding = core._unwrap_if_constexpr(ieee_rounding)
    numerator = _semantic.to_tensor(x)
    denominator = _semantic.to_tensor(y)
    return _semantic.fdiv(numerator, denominator, rounding)
201
+
202
+
203
@core.builtin
@_check_dtype(dtypes=["fp32"])
@_add_math_2arg_docstr("precise division (rounding to nearest wrt the IEEE standard)")
def div_rn(x, y, _semantic=None):
    # Converts both operands (x first, then y) and legalizes them with
    # `core.binary_op_type_legalization`, then emits the builder's
    # `create_precise_divf`. Tensor arguments must be fp32 (enforced by `_check_dtype`).
    # Public docs are generated by `_add_math_2arg_docstr`, which replaces `__doc__`.
    numerator, denominator = core.binary_op_type_legalization(
        _semantic.to_tensor(x), _semantic.to_tensor(y), _semantic)
    quotient = _semantic.builder.create_precise_divf(numerator.handle, denominator.handle)
    return core.tensor(quotient, numerator.type)
211
+
212
+
213
@core.builtin
@_check_dtype(dtypes=["fp32", "fp64"])
@_add_math_1arg_docstr("error function")
@core._tensor_member_fn
def erf(x, _semantic=None):
    # Emits the builder's `create_erf`. Tensor arguments must be fp32/fp64
    # (enforced by `_check_dtype`), and the result keeps the input tensor's type.
    # Public docs are generated by `_add_math_1arg_docstr`, which replaces `__doc__`.
    operand = _semantic.to_tensor(x)
    out_handle = _semantic.builder.create_erf(operand.handle)
    return core.tensor(out_handle, operand.type)
220
+
221
+
222
@core.builtin
@_check_dtype(dtypes=["fp32", "fp64"])
@_add_math_1arg_docstr("floor")
@core._tensor_member_fn
def floor(x, _semantic=None):
    # Emits the builder's `create_floor`. Tensor arguments must be fp32/fp64
    # (enforced by `_check_dtype`), and the result keeps the input tensor's type.
    # Public docs are generated by `_add_math_1arg_docstr`, which replaces `__doc__`.
    operand = _semantic.to_tensor(x)
    out_handle = _semantic.builder.create_floor(operand.handle)
    return core.tensor(out_handle, operand.type)
229
+
230
+
231
@core.builtin
@_check_dtype(dtypes=["fp32", "fp64"])
@_add_math_1arg_docstr("ceil")
@core._tensor_member_fn
def ceil(x, _semantic=None):
    # Emits the builder's `create_ceil`. Tensor arguments must be fp32/fp64
    # (enforced by `_check_dtype`), and the result keeps the input tensor's type.
    # Public docs are generated by `_add_math_1arg_docstr`, which replaces `__doc__`.
    operand = _semantic.to_tensor(x)
    out_handle = _semantic.builder.create_ceil(operand.handle)
    return core.tensor(out_handle, operand.type)
238
+
239
+
240
@core.builtin
@_add_math_3arg_docstr("fused multiply-add")
def fma(x, y, z, _semantic=None):
    # Converts the three operands, then legalizes them pairwise in a fixed order:
    # (x, y), then (z, x), then (z, y). The fused op is emitted via the builder's
    # `create_fma` and the result carries the legalized type of `x`.
    # Public docs are generated by `_add_math_3arg_docstr`, which replaces `__doc__`.
    a = _semantic.to_tensor(x)
    b = _semantic.to_tensor(y)
    c = _semantic.to_tensor(z)
    a, b = core.binary_op_type_legalization(a, b, _semantic)
    c, a = core.binary_op_type_legalization(c, a, _semantic)
    c, b = core.binary_op_type_legalization(c, b, _semantic)
    fused = _semantic.builder.create_fma(a.handle, b.handle, c.handle)
    return core.tensor(fused, a.type)
@@ -0,0 +1,218 @@
1
+ from ..runtime.jit import jit
2
+ from . import core as tl
3
+ from . import math
4
+
5
# Number of Philox rounds used whenever a caller does not pass `n_rounds`.
N_ROUNDS_DEFAULT = 10
6
+
7
+ # -------------------
8
+ # randint
9
+ # -------------------
10
+
11
+
12
@jit
def philox_impl(c0, c1, c2, c3, k0, k1, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
    """
    Run `n_rounds` rounds of Philox for state (c0, c1, c2, c3) and key (k0, k1).

    The multiplier and key-increment constants are picked from the dtype of `c0`:
    the 32-bit set when `c0` is uint32, otherwise the 64-bit set (with a static
    assertion that `c0` is uint64). Returns the updated state (c0, c1, c2, c3).
    """
    if c0.dtype == tl.uint32:
        PHILOX_KEY_A: tl.constexpr = 0x9E3779B9
        PHILOX_KEY_B: tl.constexpr = 0xBB67AE85
        PHILOX_ROUND_A: tl.constexpr = 0xD2511F53
        PHILOX_ROUND_B: tl.constexpr = 0xCD9E8D57
    else:
        tl.static_assert(c0.dtype == tl.uint64, "dtype not supported in philox_impl")
        PHILOX_KEY_A: tl.constexpr = 0x9E3779B97F4A7C15
        PHILOX_KEY_B: tl.constexpr = 0xBB67AE8584CAA73B
        PHILOX_ROUND_A: tl.constexpr = 0xD2E7470EE14C6C93
        PHILOX_ROUND_B: tl.constexpr = 0xCA5A826395121157

    for _ in tl.static_range(n_rounds):
        # update random state
        A = PHILOX_ROUND_A
        B = PHILOX_ROUND_B
        # Snapshot c0/c2: every new word must be computed from this round's
        # inputs, but c0 and c2 are overwritten before c1 and c3 are computed.
        # (c1 and c3 are read below before they themselves are overwritten.)
        _c0, _c2 = c0, c2
        # High halves of the products (umulhi) are mixed with the other words and the key.
        c0 = math.umulhi(B, _c2) ^ c1 ^ k0
        c2 = math.umulhi(A, _c0) ^ c3 ^ k1
        # Wrapping products, with overflow sanitizing disabled.
        c1 = tl.mul(B, _c2, sanitize_overflow=False)
        c3 = tl.mul(A, _c0, sanitize_overflow=False)
        # raise key: advance both key words by their per-round increments
        k0 = tl.add(k0, PHILOX_KEY_A, sanitize_overflow=False)
        k1 = tl.add(k1, PHILOX_KEY_B, sanitize_overflow=False)
    return c0, c1, c2, c3
43
+
44
+
45
@jit
def philox(seed, c0, c1, c2, c3, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
    """
    Run Philox on the counter words (c0, c1, c2, c3) keyed by the integer `seed`.

    The state width follows the bit width of `c0`:
    * 32-bit counters: uint32 state, with the seed (widened to uint64) split
      into low/high 32-bit key words.
    * 64-bit counters: uint64 state, keyed by (seed, 0).
    Counters are bitcast to the unsigned state type before mixing.
    """
    seed_t = tl.to_tensor(seed)
    tl.static_assert(seed_t.dtype.is_int())
    seed_u64 = seed_t.to(tl.uint64)
    c0 = tl.to_tensor(c0)
    c1 = tl.to_tensor(c1)
    c2 = tl.to_tensor(c2)
    c3 = tl.to_tensor(c3)

    if tl.constexpr(c0.dtype.primitive_bitwidth) == 32:
        state_dtype = tl.uint32
        key_hi = ((seed_u64 >> 32) & 0xffffffff).to(tl.uint32)
        key_lo = (seed_u64 & 0xffffffff).to(tl.uint32)
    else:
        tl.static_assert(tl.constexpr(c0.dtype.primitive_bitwidth) == 64, "bitwidth not supported in philox")
        state_dtype = tl.uint64
        key_hi = tl.full((1, ), 0, dtype=state_dtype)
        key_lo = seed_u64

    return philox_impl(
        c0.to(state_dtype, bitcast=True),
        c1.to(state_dtype, bitcast=True),
        c2.to(state_dtype, bitcast=True),
        c3.to(state_dtype, bitcast=True),
        key_lo,
        key_hi,
        n_rounds,
    )
70
+
71
+
72
@jit
def randint(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
    """
    Given a :code:`seed` scalar and an :code:`offset` block, returns a single
    block of random :code:`uint32`.

    If you need multiple streams of random numbers,
    using `randint4x` is likely to be faster than calling `randint` 4 times.

    :param seed: The seed for generating random numbers.
    :param offset: The offsets to generate random numbers for.
    :param n_rounds: The number of Philox rounds to run.
    """
    # Keep only the first of the four Philox output words.
    ret, _, _, _ = randint4x(seed, offset, n_rounds)
    return ret
86
+
87
+
88
@jit
def randint4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
    """
    Given a :code:`seed` scalar and an :code:`offset` block, returns four
    blocks of random :code:`uint32`.

    This is the maximally efficient entry point
    to Triton's Philox pseudo-random number generator.

    :param seed: The seed for generating random numbers.
    :param offset: The offsets to generate random numbers for.
    :param n_rounds: The number of Philox rounds to run.
    """
    # The (low 32 bits of the) offset becomes the first counter word.
    offset_lo = offset.to(tl.uint32)
    # Zeros shaped like `offset_lo`, used for counter words that carry no data.
    _0 = offset_lo * 0

    if tl.constexpr(offset.dtype.primitive_bitwidth) > 32:
        # Wide offsets: the upper 32 bits become the second counter word, so
        # offsets that differ only above bit 31 map to different counters.
        offset_hi = (offset >> 32).to(tl.uint32)
    else:
        offset_hi = _0

    return philox(seed, offset_lo, offset_hi, _0, _0, n_rounds)
111
+
112
+
113
+ # -------------------
114
+ # rand
115
+ # -------------------
116
+
117
+ # @jit
118
+ # def uint32_to_uniform_float(x):
119
+ # """
120
+ # Numerically stable function to convert a random uint32 into a random float uniformly sampled in [0, 1).
121
+ # """
122
+ # two_to_the_minus_32: tl.constexpr = 2.328306e-10
123
+ # return x * two_to_the_minus_32
124
+
125
+
126
@jit
def uint_to_uniform_float(x):
    """
    Numerically stable function to convert a random uint into a random float uniformly sampled in [0, 1).

    Accepts 32-bit (uint32/int32) or 64-bit (uint64/int64) inputs. The bits are
    reinterpreted as signed, negative values are folded onto the non-negative
    range, and the result is multiplied by a scale just below 2**-(N_BITS - 1).
    """
    # TODO: fix frontend issues and cleanup
    # conditions can be simplified
    # scale is slightly smaller than 2**-(N_BITS - 1), so that even the largest
    # folded input maps strictly below 1.0 after float rounding. For example, the
    # 64-bit constant equals ((2**23 - 1) / 2**23) * 2**-63.
    if tl.constexpr(x.dtype == tl.uint32) or tl.constexpr(x.dtype == tl.int32):
        # maximum value such that `MAX_INT * scale < 1.0` (with float rounding)
        x = x.to(tl.int32, bitcast=True)
        scale = 4.6566127342e-10
    else:
        tl.static_assert(tl.constexpr(x.dtype == tl.uint64) or tl.constexpr(x.dtype == tl.int64))
        x = x.to(tl.int64, bitcast=True)
        scale = 1.0842020432385337e-19
    # Fold negatives: v -> -v - 1 maps [MIN_INT, -1] onto [0, MAX_INT], so every
    # bit pattern lands in the non-negative range.
    x = tl.where(x < 0, -x - 1, x)
    return x * scale
144
+
145
+
146
@jit
def rand(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
    """
    Given a :code:`seed` scalar and an :code:`offset` block,
    returns a block of random :code:`float32` in :math:`U(0, 1)`.

    :param seed: The seed for generating random numbers.
    :param offset: The offsets to generate random numbers for.
    :param n_rounds: The number of Philox rounds to run.
    """
    source = randint(seed, offset, n_rounds)
    return uint_to_uniform_float(source)
157
+
158
+
159
@jit
def rand4x(seed, offsets, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
    """
    Given a :code:`seed` scalar and an :code:`offsets` block,
    returns 4 blocks of random :code:`float32` in :math:`U(0, 1)`.

    Each block is one of the four words from `randint4x`, mapped to [0, 1)
    with `uint_to_uniform_float`.

    :param seed: The seed for generating random numbers.
    :param offsets: The offsets to generate random numbers for.
    :param n_rounds: The number of Philox rounds to run.
    """
    r0, r1, r2, r3 = randint4x(seed, offsets, n_rounds)
    return (
        uint_to_uniform_float(r0),
        uint_to_uniform_float(r1),
        uint_to_uniform_float(r2),
        uint_to_uniform_float(r3),
    )
174
+
175
+
176
+ # -------------------
177
+ # randn
178
+ # -------------------
179
+
180
+
181
@jit
def pair_uniform_to_normal(u1, u2):
    """
    Box-Muller transform: map two uniform blocks to two normal blocks.

    `u1` is clamped to at least 1e-7 before its logarithm is taken, so a zero
    input never reaches log(0).
    """
    clamped = tl.maximum(1.0e-7, u1)
    radius = math.sqrt(-2.0 * math.log(clamped))
    angle = 6.283185307179586 * u2
    return radius * math.cos(angle), radius * math.sin(angle)
188
+
189
+
190
@jit
def randn(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
    """
    Given a :code:`seed` scalar and an :code:`offset` block,
    returns a block of random :code:`float32` in :math:`\\mathcal{N}(0, 1)`.

    :param seed: The seed for generating random numbers.
    :param offset: The offsets to generate random numbers for.
    :param n_rounds: The number of Philox rounds to run.
    """
    # One Box-Muller pair needs only the first two Philox words; the second
    # normal sample of the pair is discarded.
    i1, i2, _, _ = randint4x(seed, offset, n_rounds)
    u1 = uint_to_uniform_float(i1)
    u2 = uint_to_uniform_float(i2)
    n1, _ = pair_uniform_to_normal(u1, u2)
    return n1
204
+
205
+
206
@jit
def randn4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
    """
    Given a :code:`seed` scalar and an :code:`offset` block,
    returns 4 blocks of random :code:`float32` in :math:`\\mathcal{N}(0, 1)`.

    :param seed: The seed for generating random numbers.
    :param offset: The offsets to generate random numbers for.
    :param n_rounds: The number of Philox rounds to run.
    """
    # Four uniforms -> two Box-Muller pairs -> four normal samples.
    u1, u2, u3, u4 = rand4x(seed, offset, n_rounds)
    n1, n2 = pair_uniform_to_normal(u1, u2)
    n3, n4 = pair_uniform_to_normal(u3, u4)
    return n1, n2, n3, n4