triton-windows 3.1.0.post17__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of triton-windows might be problematic. Click here for more details.

Files changed (248) hide show
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +73 -0
  3. triton/backends/__init__.py +50 -0
  4. triton/backends/amd/compiler.py +262 -0
  5. triton/backends/amd/driver.c +211 -0
  6. triton/backends/amd/driver.py +497 -0
  7. triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +358 -0
  8. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +1031 -0
  9. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +1612 -0
  10. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +1337 -0
  11. triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +293 -0
  12. triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +32 -0
  13. triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +174 -0
  14. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +829 -0
  15. triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +1809 -0
  16. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +108 -0
  17. triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +124 -0
  18. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +405 -0
  19. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +196 -0
  20. triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +565 -0
  21. triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +2226 -0
  22. triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +104 -0
  23. triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +244 -0
  24. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +494 -0
  25. triton/backends/amd/include/hip/amd_detail/concepts.hpp +30 -0
  26. triton/backends/amd/include/hip/amd_detail/device_library_decls.h +133 -0
  27. triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +218 -0
  28. triton/backends/amd/include/hip/amd_detail/grid_launch.h +67 -0
  29. triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +50 -0
  30. triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +26 -0
  31. triton/backends/amd/include/hip/amd_detail/helpers.hpp +137 -0
  32. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +1350 -0
  33. triton/backends/amd/include/hip/amd_detail/hip_assert.h +101 -0
  34. triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +242 -0
  35. triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +254 -0
  36. triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +96 -0
  37. triton/backends/amd/include/hip/amd_detail/hip_ldg.h +100 -0
  38. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +10169 -0
  39. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +77 -0
  40. triton/backends/amd/include/hip/amd_detail/host_defines.h +180 -0
  41. triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +102 -0
  42. triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +798 -0
  43. triton/backends/amd/include/hip/amd_detail/math_fwd.h +698 -0
  44. triton/backends/amd/include/hip/amd_detail/ockl_image.h +177 -0
  45. triton/backends/amd/include/hip/amd_detail/program_state.hpp +107 -0
  46. triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +491 -0
  47. triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +478 -0
  48. triton/backends/amd/include/hip/channel_descriptor.h +39 -0
  49. triton/backends/amd/include/hip/device_functions.h +38 -0
  50. triton/backends/amd/include/hip/driver_types.h +468 -0
  51. triton/backends/amd/include/hip/hip_bf16.h +36 -0
  52. triton/backends/amd/include/hip/hip_bfloat16.h +44 -0
  53. triton/backends/amd/include/hip/hip_common.h +100 -0
  54. triton/backends/amd/include/hip/hip_complex.h +38 -0
  55. triton/backends/amd/include/hip/hip_cooperative_groups.h +46 -0
  56. triton/backends/amd/include/hip/hip_deprecated.h +95 -0
  57. triton/backends/amd/include/hip/hip_ext.h +159 -0
  58. triton/backends/amd/include/hip/hip_fp16.h +36 -0
  59. triton/backends/amd/include/hip/hip_gl_interop.h +32 -0
  60. triton/backends/amd/include/hip/hip_hcc.h +24 -0
  61. triton/backends/amd/include/hip/hip_math_constants.h +36 -0
  62. triton/backends/amd/include/hip/hip_profile.h +27 -0
  63. triton/backends/amd/include/hip/hip_runtime.h +75 -0
  64. triton/backends/amd/include/hip/hip_runtime_api.h +8919 -0
  65. triton/backends/amd/include/hip/hip_texture_types.h +29 -0
  66. triton/backends/amd/include/hip/hip_vector_types.h +41 -0
  67. triton/backends/amd/include/hip/hip_version.h +17 -0
  68. triton/backends/amd/include/hip/hiprtc.h +421 -0
  69. triton/backends/amd/include/hip/library_types.h +78 -0
  70. triton/backends/amd/include/hip/math_functions.h +42 -0
  71. triton/backends/amd/include/hip/surface_types.h +63 -0
  72. triton/backends/amd/include/hip/texture_types.h +194 -0
  73. triton/backends/amd/include/hsa/Brig.h +1131 -0
  74. triton/backends/amd/include/hsa/amd_hsa_common.h +91 -0
  75. triton/backends/amd/include/hsa/amd_hsa_elf.h +435 -0
  76. triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +269 -0
  77. triton/backends/amd/include/hsa/amd_hsa_queue.h +109 -0
  78. triton/backends/amd/include/hsa/amd_hsa_signal.h +80 -0
  79. triton/backends/amd/include/hsa/hsa.h +5729 -0
  80. triton/backends/amd/include/hsa/hsa_amd_tool.h +91 -0
  81. triton/backends/amd/include/hsa/hsa_api_trace.h +566 -0
  82. triton/backends/amd/include/hsa/hsa_ext_amd.h +3090 -0
  83. triton/backends/amd/include/hsa/hsa_ext_finalize.h +531 -0
  84. triton/backends/amd/include/hsa/hsa_ext_image.h +1454 -0
  85. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +488 -0
  86. triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +667 -0
  87. triton/backends/amd/include/roctracer/ext/prof_protocol.h +107 -0
  88. triton/backends/amd/include/roctracer/hip_ostream_ops.h +4435 -0
  89. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +1467 -0
  90. triton/backends/amd/include/roctracer/hsa_prof_str.h +3027 -0
  91. triton/backends/amd/include/roctracer/roctracer.h +779 -0
  92. triton/backends/amd/include/roctracer/roctracer_ext.h +81 -0
  93. triton/backends/amd/include/roctracer/roctracer_hcc.h +24 -0
  94. triton/backends/amd/include/roctracer/roctracer_hip.h +37 -0
  95. triton/backends/amd/include/roctracer/roctracer_hsa.h +112 -0
  96. triton/backends/amd/include/roctracer/roctracer_plugin.h +137 -0
  97. triton/backends/amd/include/roctracer/roctracer_roctx.h +67 -0
  98. triton/backends/amd/include/roctracer/roctx.h +229 -0
  99. triton/backends/amd/lib/ockl.bc +0 -0
  100. triton/backends/amd/lib/ocml.bc +0 -0
  101. triton/backends/compiler.py +76 -0
  102. triton/backends/driver.py +34 -0
  103. triton/backends/nvidia/__init__.py +0 -0
  104. triton/backends/nvidia/bin/ptxas.exe +0 -0
  105. triton/backends/nvidia/compiler.py +347 -0
  106. triton/backends/nvidia/driver.c +451 -0
  107. triton/backends/nvidia/driver.py +430 -0
  108. triton/backends/nvidia/include/cuda.h +24359 -0
  109. triton/backends/nvidia/lib/libdevice.10.bc +0 -0
  110. triton/backends/nvidia/lib/x64/cuda.lib +0 -0
  111. triton/compiler/__init__.py +4 -0
  112. triton/compiler/code_generator.py +1302 -0
  113. triton/compiler/compiler.py +416 -0
  114. triton/compiler/errors.py +51 -0
  115. triton/compiler/make_launcher.py +0 -0
  116. triton/errors.py +5 -0
  117. triton/language/__init__.py +284 -0
  118. triton/language/core.py +2621 -0
  119. triton/language/extra/__init__.py +4 -0
  120. triton/language/extra/cuda/__init__.py +8 -0
  121. triton/language/extra/cuda/libdevice.py +1629 -0
  122. triton/language/extra/cuda/utils.py +109 -0
  123. triton/language/extra/hip/__init__.py +3 -0
  124. triton/language/extra/hip/libdevice.py +468 -0
  125. triton/language/extra/libdevice.py +1213 -0
  126. triton/language/math.py +250 -0
  127. triton/language/random.py +207 -0
  128. triton/language/semantic.py +1621 -0
  129. triton/language/standard.py +441 -0
  130. triton/ops/__init__.py +7 -0
  131. triton/ops/blocksparse/__init__.py +7 -0
  132. triton/ops/blocksparse/matmul.py +432 -0
  133. triton/ops/blocksparse/softmax.py +228 -0
  134. triton/ops/cross_entropy.py +96 -0
  135. triton/ops/flash_attention.py +466 -0
  136. triton/ops/matmul.py +219 -0
  137. triton/ops/matmul_perf_model.py +171 -0
  138. triton/runtime/__init__.py +23 -0
  139. triton/runtime/autotuner.py +361 -0
  140. triton/runtime/build.py +129 -0
  141. triton/runtime/cache.py +289 -0
  142. triton/runtime/driver.py +60 -0
  143. triton/runtime/errors.py +26 -0
  144. triton/runtime/interpreter.py +1127 -0
  145. triton/runtime/jit.py +956 -0
  146. triton/runtime/tcc/include/_mingw.h +170 -0
  147. triton/runtime/tcc/include/assert.h +57 -0
  148. triton/runtime/tcc/include/conio.h +409 -0
  149. triton/runtime/tcc/include/ctype.h +281 -0
  150. triton/runtime/tcc/include/dir.h +31 -0
  151. triton/runtime/tcc/include/direct.h +68 -0
  152. triton/runtime/tcc/include/dirent.h +135 -0
  153. triton/runtime/tcc/include/dos.h +55 -0
  154. triton/runtime/tcc/include/errno.h +75 -0
  155. triton/runtime/tcc/include/excpt.h +123 -0
  156. triton/runtime/tcc/include/fcntl.h +52 -0
  157. triton/runtime/tcc/include/fenv.h +108 -0
  158. triton/runtime/tcc/include/float.h +57 -0
  159. triton/runtime/tcc/include/inttypes.h +297 -0
  160. triton/runtime/tcc/include/io.h +418 -0
  161. triton/runtime/tcc/include/limits.h +111 -0
  162. triton/runtime/tcc/include/locale.h +91 -0
  163. triton/runtime/tcc/include/malloc.h +181 -0
  164. triton/runtime/tcc/include/math.h +737 -0
  165. triton/runtime/tcc/include/mem.h +13 -0
  166. triton/runtime/tcc/include/memory.h +40 -0
  167. triton/runtime/tcc/include/process.h +176 -0
  168. triton/runtime/tcc/include/sec_api/conio_s.h +42 -0
  169. triton/runtime/tcc/include/sec_api/crtdbg_s.h +19 -0
  170. triton/runtime/tcc/include/sec_api/io_s.h +33 -0
  171. triton/runtime/tcc/include/sec_api/mbstring_s.h +52 -0
  172. triton/runtime/tcc/include/sec_api/search_s.h +25 -0
  173. triton/runtime/tcc/include/sec_api/stdio_s.h +145 -0
  174. triton/runtime/tcc/include/sec_api/stdlib_s.h +67 -0
  175. triton/runtime/tcc/include/sec_api/stralign_s.h +30 -0
  176. triton/runtime/tcc/include/sec_api/string_s.h +41 -0
  177. triton/runtime/tcc/include/sec_api/sys/timeb_s.h +34 -0
  178. triton/runtime/tcc/include/sec_api/tchar_s.h +266 -0
  179. triton/runtime/tcc/include/sec_api/time_s.h +61 -0
  180. triton/runtime/tcc/include/sec_api/wchar_s.h +128 -0
  181. triton/runtime/tcc/include/setjmp.h +160 -0
  182. triton/runtime/tcc/include/share.h +28 -0
  183. triton/runtime/tcc/include/signal.h +63 -0
  184. triton/runtime/tcc/include/stdarg.h +79 -0
  185. triton/runtime/tcc/include/stdbool.h +11 -0
  186. triton/runtime/tcc/include/stddef.h +54 -0
  187. triton/runtime/tcc/include/stdint.h +212 -0
  188. triton/runtime/tcc/include/stdio.h +429 -0
  189. triton/runtime/tcc/include/stdlib.h +580 -0
  190. triton/runtime/tcc/include/string.h +164 -0
  191. triton/runtime/tcc/include/sys/fcntl.h +13 -0
  192. triton/runtime/tcc/include/sys/file.h +14 -0
  193. triton/runtime/tcc/include/sys/locking.h +30 -0
  194. triton/runtime/tcc/include/sys/stat.h +290 -0
  195. triton/runtime/tcc/include/sys/time.h +69 -0
  196. triton/runtime/tcc/include/sys/timeb.h +133 -0
  197. triton/runtime/tcc/include/sys/types.h +118 -0
  198. triton/runtime/tcc/include/sys/unistd.h +14 -0
  199. triton/runtime/tcc/include/sys/utime.h +146 -0
  200. triton/runtime/tcc/include/tcc/tcc_libm.h +201 -0
  201. triton/runtime/tcc/include/tcclib.h +80 -0
  202. triton/runtime/tcc/include/tchar.h +1102 -0
  203. triton/runtime/tcc/include/time.h +287 -0
  204. triton/runtime/tcc/include/vadefs.h +11 -0
  205. triton/runtime/tcc/include/values.h +4 -0
  206. triton/runtime/tcc/include/varargs.h +12 -0
  207. triton/runtime/tcc/include/wchar.h +873 -0
  208. triton/runtime/tcc/include/wctype.h +172 -0
  209. triton/runtime/tcc/include/winapi/basetsd.h +149 -0
  210. triton/runtime/tcc/include/winapi/basetyps.h +85 -0
  211. triton/runtime/tcc/include/winapi/guiddef.h +156 -0
  212. triton/runtime/tcc/include/winapi/poppack.h +8 -0
  213. triton/runtime/tcc/include/winapi/pshpack1.h +8 -0
  214. triton/runtime/tcc/include/winapi/pshpack2.h +8 -0
  215. triton/runtime/tcc/include/winapi/pshpack4.h +8 -0
  216. triton/runtime/tcc/include/winapi/pshpack8.h +8 -0
  217. triton/runtime/tcc/include/winapi/winbase.h +2951 -0
  218. triton/runtime/tcc/include/winapi/wincon.h +301 -0
  219. triton/runtime/tcc/include/winapi/windef.h +293 -0
  220. triton/runtime/tcc/include/winapi/windows.h +127 -0
  221. triton/runtime/tcc/include/winapi/winerror.h +3166 -0
  222. triton/runtime/tcc/include/winapi/wingdi.h +4080 -0
  223. triton/runtime/tcc/include/winapi/winnt.h +5835 -0
  224. triton/runtime/tcc/include/winapi/winreg.h +272 -0
  225. triton/runtime/tcc/include/winapi/winuser.h +5651 -0
  226. triton/runtime/tcc/include/winapi/winver.h +160 -0
  227. triton/runtime/tcc/lib/cuda.def +697 -0
  228. triton/runtime/tcc/lib/gdi32.def +337 -0
  229. triton/runtime/tcc/lib/kernel32.def +770 -0
  230. triton/runtime/tcc/lib/libtcc1-64.a +0 -0
  231. triton/runtime/tcc/lib/msvcrt.def +1399 -0
  232. triton/runtime/tcc/lib/python3.def +810 -0
  233. triton/runtime/tcc/lib/user32.def +658 -0
  234. triton/runtime/tcc/libtcc.dll +0 -0
  235. triton/runtime/tcc/tcc.exe +0 -0
  236. triton/testing.py +496 -0
  237. triton/tools/__init__.py +0 -0
  238. triton/tools/build_extern.py +365 -0
  239. triton/tools/compile.c +67 -0
  240. triton/tools/compile.h +14 -0
  241. triton/tools/compile.py +145 -0
  242. triton/tools/disasm.py +142 -0
  243. triton/tools/link.py +322 -0
  244. triton/windows_utils.py +373 -0
  245. triton_windows-3.1.0.post17.dist-info/METADATA +41 -0
  246. triton_windows-3.1.0.post17.dist-info/RECORD +248 -0
  247. triton_windows-3.1.0.post17.dist-info/WHEEL +5 -0
  248. triton_windows-3.1.0.post17.dist-info/top_level.txt +14 -0
@@ -0,0 +1,250 @@
1
+ from . import core
2
+ from . import semantic
3
+ from functools import wraps
4
+ from typing import List
5
+
6
# Generic type variable used by the decorator annotations in this module.
T = core.TypeVar('T')
7
+
8
+
9
def _check_dtype(dtypes: List[str]) -> T:
    """
    Build a decorator that rejects tensor arguments whose scalar dtype name is
    not listed in ``dtypes``.

    This follows libdevice's convention of restricting the data types accepted
    by math functions. Accelerators/GPUs do not support many float16 and
    bfloat16 math operations, so instead of accepting every dtype the caller is
    told which dtype was passed and is expected to cast explicitly to a
    supported one.

    Only arguments that are already ``core.tensor`` instances are inspected;
    every other positional or keyword value is passed through unchecked.

    :param dtypes: names of the accepted scalar dtypes
    :return: a decorator wrapping a function with the dtype check
    :raises ValueError: (from the wrapped call) if a tensor argument has a
        scalar dtype name outside ``dtypes``
    """

    def wrapper(fn):

        @wraps(fn)
        def check(*args, **kwargs):
            # Positional and keyword values are validated alike.
            for candidate in (*args, *kwargs.values()):
                if not isinstance(candidate, core.tensor):
                    continue
                scalar_name = candidate.type.scalar.name
                if scalar_name not in dtypes:
                    raise ValueError(f"Expected dtype {dtypes} but got {scalar_name}")
            return fn(*args, **kwargs)

        return check

    return wrapper
32
+
33
+
34
def _add_math_1arg_docstr(name: str) -> core.Callable[[T], T]:
    # Decorator factory: the returned decorator overwrites ``func.__doc__`` with
    # a generated description of a one-argument element-wise operation.
    # NOTE(review): the whitespace inside the template literal could not be
    # recovered from the diff view; a 4-space indent is assumed -- confirm
    # against the packaged file.

    def _decorator(func: T) -> T:
        template = """
    Computes the element-wise {name} of :code:`x`.

    :param x: the input values
    :type x: Block
    """
        func.__doc__ = template.format(name=name)
        return func

    return _decorator
47
+
48
+
49
def _add_math_2arg_docstr(name: str) -> core.Callable[[T], T]:
    # Decorator factory: the returned decorator overwrites ``func.__doc__`` with
    # a generated description of a two-argument element-wise operation.
    # NOTE(review): the whitespace inside the template literal could not be
    # recovered from the diff view; a 4-space indent is assumed -- confirm
    # against the packaged file.

    def _decorator(func: T) -> T:
        template = """
    Computes the element-wise {name} of :code:`x` and :code:`y`.

    :param x: the input values
    :type x: Block
    :param y: the input values
    :type y: Block
    """
        func.__doc__ = template.format(name=name)
        return func

    return _decorator
64
+
65
+
66
def _add_math_3arg_docstr(name: str) -> core.Callable[[T], T]:
    # Decorator factory: the returned decorator overwrites ``func.__doc__`` with
    # a generated description of a three-argument element-wise operation.
    # NOTE(review): the whitespace inside the template literal could not be
    # recovered from the diff view; a 4-space indent is assumed -- confirm
    # against the packaged file.

    def _decorator(func: T) -> T:
        template = """
    Computes the element-wise {name} of :code:`x`, :code:`y`, and :code:`z`.

    :param x: the input values
    :type x: Block
    :param y: the input values
    :type y: Block
    :param z: the input values
    :type z: Block
    """
        func.__doc__ = template.format(name=name)
        return func

    return _decorator
83
+
84
+
85
@core.builtin
@_check_dtype(dtypes=["int32", "int64", "uint32", "uint64"])
@_add_math_2arg_docstr("most significant N bits of the 2N-bit product")
def umulhi(x, y, _builder=None):
    # The user-facing docstring is installed by _add_math_2arg_docstr above.
    lhs = core._to_tensor(x, _builder)
    rhs = core._to_tensor(y, _builder)
    # Legalize the operand pair before emitting the op; the result takes the
    # type of the legalized left operand.
    lhs, rhs = core.binary_op_type_legalization(lhs, rhs, _builder)
    high_bits = _builder.create_umulhi(lhs.handle, rhs.handle)
    return core.tensor(high_bits, lhs.type)
93
+
94
+
95
@core.builtin
@_check_dtype(dtypes=["fp32", "fp64"])
@_add_math_1arg_docstr("exponential")
@core._tensor_member_fn
def exp(x, _builder=None):
    # The user-facing docstring is installed by _add_math_1arg_docstr above.
    operand = core._to_tensor(x, _builder)
    result_handle = _builder.create_exp(operand.handle)
    # The result carries the operand's type.
    return core.tensor(result_handle, operand.type)
102
+
103
+
104
@core.builtin
@_check_dtype(dtypes=["fp32", "fp64"])
@_add_math_1arg_docstr("exponential (base 2)")
@core._tensor_member_fn
def exp2(x, _builder=None):
    # The user-facing docstring is installed by _add_math_1arg_docstr above.
    operand = core._to_tensor(x, _builder)
    result_handle = _builder.create_exp2(operand.handle)
    # The result carries the operand's type.
    return core.tensor(result_handle, operand.type)
111
+
112
+
113
@core.builtin
@_check_dtype(dtypes=["fp32", "fp64"])
@_add_math_1arg_docstr("natural logarithm")
@core._tensor_member_fn
def log(x, _builder=None):
    # The user-facing docstring is installed by _add_math_1arg_docstr above.
    operand = core._to_tensor(x, _builder)
    result_handle = _builder.create_log(operand.handle)
    # The result carries the operand's type.
    return core.tensor(result_handle, operand.type)
120
+
121
+
122
@core.builtin
@_check_dtype(dtypes=["fp32", "fp64"])
@_add_math_1arg_docstr("logarithm (base 2)")
@core._tensor_member_fn
def log2(x, _builder=None):
    # The user-facing docstring is installed by _add_math_1arg_docstr above.
    operand = core._to_tensor(x, _builder)
    result_handle = _builder.create_log2(operand.handle)
    # The result carries the operand's type.
    return core.tensor(result_handle, operand.type)
129
+
130
+
131
@core.builtin
@_check_dtype(dtypes=["fp32", "fp64"])
@_add_math_1arg_docstr("cosine")
@core._tensor_member_fn
def cos(x, _builder=None):
    # The user-facing docstring is installed by _add_math_1arg_docstr above.
    operand = core._to_tensor(x, _builder)
    result_handle = _builder.create_cos(operand.handle)
    # The result carries the operand's type.
    return core.tensor(result_handle, operand.type)
138
+
139
+
140
@core.builtin
@_check_dtype(dtypes=["fp32", "fp64"])
@_add_math_1arg_docstr("sine")
@core._tensor_member_fn
def sin(x, _builder=None):
    # The user-facing docstring is installed by _add_math_1arg_docstr above.
    operand = core._to_tensor(x, _builder)
    result_handle = _builder.create_sin(operand.handle)
    # The result carries the operand's type.
    return core.tensor(result_handle, operand.type)
147
+
148
+
149
@core.builtin
@_check_dtype(dtypes=["fp32", "fp64"])
@_add_math_1arg_docstr("fast square root")
@core._tensor_member_fn
def sqrt(x, _builder=None):
    # The user-facing docstring is installed by _add_math_1arg_docstr above.
    operand = core._to_tensor(x, _builder)
    result_handle = _builder.create_sqrt(operand.handle)
    # The result carries the operand's type.
    return core.tensor(result_handle, operand.type)
156
+
157
+
158
@core.builtin
@_check_dtype(dtypes=["fp32"])
@_add_math_1arg_docstr("precise square root (rounding to nearest)")
@core._tensor_member_fn
def sqrt_rn(x, _builder=None):
    # The user-facing docstring is installed by _add_math_1arg_docstr above.
    # Unlike `sqrt`, the dtype check here accepts only fp32 tensors and the
    # builder's "precise" square-root op is emitted.
    operand = core._to_tensor(x, _builder)
    result_handle = _builder.create_precise_sqrt(operand.handle)
    return core.tensor(result_handle, operand.type)
165
+
166
+
167
@core.builtin
@_check_dtype(dtypes=["fp32", "fp64"])
@_add_math_1arg_docstr("inverse square root")
@core._tensor_member_fn
def rsqrt(x, _builder=None):
    # The user-facing docstring is installed by _add_math_1arg_docstr above.
    operand = core._to_tensor(x, _builder)
    result_handle = _builder.create_rsqrt(operand.handle)
    # The result carries the operand's type.
    return core.tensor(result_handle, operand.type)
174
+
175
+
176
@core.builtin
@_add_math_1arg_docstr("absolute value")
@core._tensor_member_fn
def abs(x, _builder=None):
    # The user-facing docstring is installed by _add_math_1arg_docstr above.
    # Dispatches on the dtype of the (tensor-converted) argument:
    #   fp8e4b15      -> bitwise AND with 0x7F
    #   floating      -> builder fabs
    #   signed int    -> builder iabs
    #   unsigned int  -> returned unchanged
    # Raises AssertionError for any other dtype.
    x = core._to_tensor(x, _builder)
    dtype = x.dtype
    if dtype.is_fp8e4b15():
        # Tested before the generic floating-point case. 0x7F keeps the low
        # seven bits of each element and clears the top one (presumably the
        # sign bit -- confirm against the fp8e4b15 layout).
        mask = core.full(x.shape, 0x7F, core.int8, _builder=_builder)
        return core.tensor(_builder.create_and(x.handle, mask.handle), x.type)
    if dtype.is_floating():
        return core.tensor(_builder.create_fabs(x.handle), x.type)
    if dtype.is_int_signed():
        return core.tensor(_builder.create_iabs(x.handle), x.type)
    if dtype.is_int_unsigned():
        return x  # no-op
    # Raised explicitly rather than via `assert False`: assert statements are
    # removed under `python -O`, which would let this function fall through
    # and return None for an unexpected dtype.
    raise AssertionError(f"Unexpected dtype {dtype}")
193
+
194
+
195
@core.builtin
@_add_math_2arg_docstr("fast division")
def fdiv(x, y, ieee_rounding=False, _builder=None):
    # The user-facing docstring is installed by _add_math_2arg_docstr above.
    # `ieee_rounding` may arrive wrapped as a constexpr; unwrap it first.
    rounding_flag = core._constexpr_to_value(ieee_rounding)
    lhs = core._to_tensor(x, _builder)
    rhs = core._to_tensor(y, _builder)
    # The actual lowering is delegated to the semantic layer.
    return semantic.fdiv(lhs, rhs, rounding_flag, _builder)
202
+
203
+
204
@core.builtin
@_check_dtype(dtypes=["fp32"])
@_add_math_2arg_docstr("precise division (rounding to nearest)")
def div_rn(x, y, _builder=None):
    # The user-facing docstring is installed by _add_math_2arg_docstr above.
    lhs = core._to_tensor(x, _builder)
    rhs = core._to_tensor(y, _builder)
    # Legalize the operand pair before emitting the op; the result takes the
    # type of the legalized left operand.
    lhs, rhs = core.binary_op_type_legalization(lhs, rhs, _builder)
    quotient = _builder.create_precise_divf(lhs.handle, rhs.handle)
    return core.tensor(quotient, lhs.type)
212
+
213
+
214
@core.builtin
@_check_dtype(dtypes=["fp32", "fp64"])
@_add_math_1arg_docstr("error function")
@core._tensor_member_fn
def erf(x, _builder=None):
    # The user-facing docstring is installed by _add_math_1arg_docstr above.
    operand = core._to_tensor(x, _builder)
    result_handle = _builder.create_erf(operand.handle)
    # The result carries the operand's type.
    return core.tensor(result_handle, operand.type)
221
+
222
+
223
@core.builtin
@_check_dtype(dtypes=["fp32", "fp64"])
@_add_math_1arg_docstr("floor")
@core._tensor_member_fn
def floor(x, _builder=None):
    # The user-facing docstring is installed by _add_math_1arg_docstr above.
    operand = core._to_tensor(x, _builder)
    result_handle = _builder.create_floor(operand.handle)
    # The result carries the operand's type.
    return core.tensor(result_handle, operand.type)
230
+
231
+
232
@core.builtin
@_check_dtype(dtypes=["fp32", "fp64"])
@_add_math_1arg_docstr("ceil")
@core._tensor_member_fn
def ceil(x, _builder=None):
    # The user-facing docstring is installed by _add_math_1arg_docstr above.
    operand = core._to_tensor(x, _builder)
    result_handle = _builder.create_ceil(operand.handle)
    # The result carries the operand's type.
    return core.tensor(result_handle, operand.type)
239
+
240
+
241
@core.builtin
@_add_math_3arg_docstr("fused multiply-add")
def fma(x, y, z, _builder=None):
    # The user-facing docstring is installed by _add_math_3arg_docstr above.
    a = core._to_tensor(x, _builder)
    b = core._to_tensor(y, _builder)
    c = core._to_tensor(z, _builder)
    # Pairwise legalization in a fixed order: (a, b), then (c, a), then (c, b).
    a, b = core.binary_op_type_legalization(a, b, _builder)
    c, a = core.binary_op_type_legalization(c, a, _builder)
    c, b = core.binary_op_type_legalization(c, b, _builder)
    fused = _builder.create_fma(a.handle, b.handle, c.handle)
    # The result takes the type of the first legalized operand.
    return core.tensor(fused, a.type)
@@ -0,0 +1,207 @@
1
+ from ..runtime.jit import jit
2
+ from . import core as tl
3
+ from . import math
4
+
5
# Used as the default value of every `n_rounds` parameter in this module.
N_ROUNDS_DEFAULT = 10  # Default number of rounds for philox
6
+
7
+ # -------------------
8
+ # randint
9
+ # -------------------
10
+
11
+
12
@jit
def philox_impl(c0, c1, c2, c3, k0, k1, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
    """
    Run `n_rounds` rounds of Philox for state (c0, c1, c2, c3) and key (k0, k1).

    The key-increment and round-multiplier constants are chosen from the dtype
    of :code:`c0`: one set for :code:`uint32` and another for :code:`uint64`.
    Any other dtype fails the static assertion.

    :param c0: first state word; its dtype selects the constants
    :param c1: second state word
    :param c2: third state word
    :param c3: fourth state word
    :param k0: first key word
    :param k1: second key word
    :param n_rounds: number of rounds to run (constexpr)
    :return: the updated state as the tuple :code:`(c0, c1, c2, c3)`
    """
    if c0.dtype == tl.uint32:
        # 32-bit constants
        PHILOX_KEY_A: tl.constexpr = 0x9E3779B9
        PHILOX_KEY_B: tl.constexpr = 0xBB67AE85
        PHILOX_ROUND_A: tl.constexpr = 0xD2511F53
        PHILOX_ROUND_B: tl.constexpr = 0xCD9E8D57
    else:
        # 64-bit constants; uint64 is the only other dtype accepted
        tl.static_assert(c0.dtype == tl.uint64, "dtype not supported in philox_impl")
        PHILOX_KEY_A: tl.constexpr = 0x9E3779B97F4A7C15
        PHILOX_KEY_B: tl.constexpr = 0xBB67AE8584CAA73B
        PHILOX_ROUND_A: tl.constexpr = 0xD2E7470EE14C6C93
        PHILOX_ROUND_B: tl.constexpr = 0xCA5A826395121157

    for _ in tl.static_range(n_rounds):
        # for _ in range(n_rounds):
        # update random state
        A = PHILOX_ROUND_A
        B = PHILOX_ROUND_B
        # keep the pre-update c0/c2: both are overwritten below but still
        # needed for the c1/c3 products
        _c0, _c2 = c0, c2
        # new c0/c2: high part of the multiplier product (umulhi) XORed with
        # the neighbouring state word and a key word
        c0 = math.umulhi(B, _c2) ^ c1 ^ k0
        c2 = math.umulhi(A, _c0) ^ c3 ^ k1
        # new c1/c3: plain product of the multiplier and the saved word
        c1 = B * _c2
        c3 = A * _c0
        # raise key
        k0 = k0 + PHILOX_KEY_A
        k1 = k1 + PHILOX_KEY_B
    return c0, c1, c2, c3
43
+
44
+
45
@jit
def philox(seed, c0, c1, c2, c3, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
    """
    Convert the inputs to tensors, derive two key words from :code:`seed`, and
    run `philox_impl` on the counters (c0, c1, c2, c3).

    :param seed: the seed; converted to :code:`uint64` before the key words are derived
    :param c0: first counter; its bit width (32 or 64) selects the unsigned
        integer type that all four counters are bitcast to
    :param c1: second counter
    :param c2: third counter
    :param c3: fourth counter
    :param n_rounds: number of rounds, forwarded to `philox_impl`
    :return: the result of `philox_impl` for the prepared counters and key
    """
    seed = tl.to_tensor(seed)
    c0 = tl.to_tensor(c0)
    c1 = tl.to_tensor(c1)
    c2 = tl.to_tensor(c2)
    c3 = tl.to_tensor(c3)
    seed = seed.to(tl.uint64)
    if tl.constexpr(c0.dtype.primitive_bitwidth) == 32:
        int_dtype = tl.uint32
        # 32-bit counters: split the 64-bit seed into upper and lower 32-bit halves
        seed_hi = ((seed >> 32) & 0xffffffff).to(tl.uint32)
        seed_lo = (seed & 0xffffffff).to(tl.uint32)
    else:
        tl.static_assert(tl.constexpr(c0.dtype.primitive_bitwidth) == 64, "bitwidth not supported in philox")
        int_dtype = tl.uint64
        # 64-bit counters: the whole seed is the low key word, the high key word is zero
        seed_hi = tl.full((1, ), 0, dtype=int_dtype)
        seed_lo = seed
    # reinterpret the counters as unsigned integers of the selected width (bitcast)
    c0 = c0.to(int_dtype, bitcast=True)
    c1 = c1.to(int_dtype, bitcast=True)
    c2 = c2.to(int_dtype, bitcast=True)
    c3 = c3.to(int_dtype, bitcast=True)
    # key order is (seed_lo, seed_hi): k0 = seed_lo, k1 = seed_hi
    return philox_impl(c0, c1, c2, c3, seed_lo, seed_hi, n_rounds)
67
+
68
+
69
@jit
def randint(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
    """
    Produce a single block of random :code:`int32` from a :code:`seed` scalar
    and an :code:`offset` block.

    Only the first of the four blocks generated by `randint4x` is kept, so if
    you need multiple streams of random numbers, one call to `randint4x` is
    likely to be faster than four calls to `randint`.

    :param seed: The seed for generating random numbers.
    :param offset: The offsets to generate random numbers for.
    """
    first_block, _, _, _ = randint4x(seed, offset, n_rounds)
    return first_block
83
+
84
+
85
@jit
def randint4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
    """
    Produce four blocks of random :code:`int32` from a :code:`seed` scalar and
    an :code:`offset` block.

    This is the maximally efficient entry point
    to Triton's Philox pseudo-random number generator.

    :param seed: The seed for generating random numbers.
    :param offset: The offsets to generate random numbers for.
    """
    # The three remaining Philox counters are zero; they are derived from
    # `offset` itself (offset * 0) rather than built with tl.zeros.
    zero_block = offset * 0
    return philox(seed, offset, zero_block, zero_block, zero_block, n_rounds)
100
+
101
+
102
+ # -------------------
103
+ # rand
104
+ # -------------------
105
+
106
+ # @jit
107
+ # def uint32_to_uniform_float(x):
108
+ # """
109
+ # Numerically stable function to convert a random uint32 into a random float uniformly sampled in [0, 1).
110
+ # """
111
+ # two_to_the_minus_32: tl.constexpr = 2.328306e-10
112
+ # return x * two_to_the_minus_32
113
+
114
+
115
@jit
def uint_to_uniform_float(x):
    """
    Numerically stable function to convert a random uint into a random float uniformly sampled in [0, 1).

    Accepts 32-bit (:code:`uint32`/:code:`int32`) or 64-bit
    (:code:`uint64`/:code:`int64`) integer input; any other dtype fails the
    static assertion.

    :param x: the random integers to convert
    :return: :code:`x` reinterpreted as signed, folded to a non-negative value
        and multiplied by a width-dependent scale
    """
    # TODO: fix frontend issues and cleanup
    # conditions can be simplified
    # scale is ((2**23 - 1) / 2**23) * 2**(N_BITS - 1)
    # NOTE(review): both constants below are slightly smaller than
    # 2**-(N_BITS - 1), so the exponent in the formula above looks like it
    # should be negative -- confirm.
    if tl.constexpr(x.dtype == tl.uint32) or tl.constexpr(x.dtype == tl.int32):
        # maximum value such that `MAX_INT * scale < 1.0` (with float rounding)
        x = x.to(tl.int32, bitcast=True)
        scale = 4.6566127342e-10
    else:
        tl.static_assert(tl.constexpr(x.dtype == tl.uint64) or tl.constexpr(x.dtype == tl.int64))
        x = x.to(tl.int64, bitcast=True)
        scale = 1.0842020432385337e-19
    # fold negative values onto the non-negative range (-1 -> 0, -2 -> 1, ...)
    x = tl.where(x < 0, -x - 1, x)
    return x * scale
133
+
134
+
135
@jit
def rand(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
    """
    Produce a block of random :code:`float32` in :math:`U(0, 1)` from a
    :code:`seed` scalar and an :code:`offset` block.

    :param seed: The seed for generating random numbers.
    :param offset: The offsets to generate random numbers for.
    """
    # Draw random integers first, then map them to uniform floats.
    random_bits = randint(seed, offset, n_rounds)
    return uint_to_uniform_float(random_bits)
146
+
147
+
148
@jit
def rand4x(seed, offsets, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
    """
    Produce 4 blocks of random :code:`float32` in :math:`U(0, 1)` from a
    :code:`seed` scalar and an :code:`offsets` block.

    :param seed: The seed for generating random numbers.
    :param offsets: The offsets to generate random numbers for.
    """
    # One Philox call yields four integer blocks; each is converted separately.
    bits_a, bits_b, bits_c, bits_d = randint4x(seed, offsets, n_rounds)
    uniform_a = uint_to_uniform_float(bits_a)
    uniform_b = uint_to_uniform_float(bits_b)
    uniform_c = uint_to_uniform_float(bits_c)
    uniform_d = uint_to_uniform_float(bits_d)
    return uniform_a, uniform_b, uniform_c, uniform_d
163
+
164
+
165
+ # -------------------
166
+ # randn
167
+ # -------------------
168
+
169
+
170
@jit
def pair_uniform_to_normal(u1, u2):
    """Box-Muller transform: turn two uniform samples into two normal samples."""
    # Clamp the first sample away from zero before taking its logarithm.
    clamped = tl.maximum(1.0e-7, u1)
    # Angle: 2*pi times the second sample.
    theta = 6.283185307179586 * u2
    radius = math.sqrt(-2.0 * math.log(clamped))
    return radius * math.cos(theta), radius * math.sin(theta)
177
+
178
+
179
@jit
def randn(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
    """
    Produce a block of random :code:`float32` in :math:`\\mathcal{N}(0, 1)`
    from a :code:`seed` scalar and an :code:`offset` block.

    :param seed: The seed for generating random numbers.
    :param offset: The offsets to generate random numbers for.
    """
    # Only the first two integer blocks are needed for one Box-Muller pair.
    bits_a, bits_b, _, _ = randint4x(seed, offset, n_rounds)
    uniform_a = uint_to_uniform_float(bits_a)
    uniform_b = uint_to_uniform_float(bits_b)
    # The second normal sample of the pair is discarded.
    normal_a, _ = pair_uniform_to_normal(uniform_a, uniform_b)
    return normal_a
193
+
194
+
195
@jit
def randn4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
    """
    Produce 4 blocks of random :code:`float32` in :math:`\\mathcal{N}(0, 1)`
    from a :code:`seed` scalar and an :code:`offset` block.

    :param seed: The seed for generating random numbers.
    :param offset: The offsets to generate random numbers for.
    """
    # Four uniform blocks are paired up: (first, second) and (third, fourth).
    uniform_a, uniform_b, uniform_c, uniform_d = rand4x(seed, offset, n_rounds)
    normal_a, normal_b = pair_uniform_to_normal(uniform_a, uniform_b)
    normal_c, normal_d = pair_uniform_to_normal(uniform_c, uniform_d)
    return normal_a, normal_b, normal_c, normal_d