triton-windows 3.2.0.post11__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of triton-windows might be problematic. Click here for more details.

Files changed (154)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +85 -0
  3. triton/_internal_testing.py +123 -0
  4. triton/backends/__init__.py +50 -0
  5. triton/backends/amd/compiler.py +368 -0
  6. triton/backends/amd/driver.c +211 -0
  7. triton/backends/amd/driver.py +512 -0
  8. triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +358 -0
  9. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +1031 -0
  10. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +1612 -0
  11. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +1337 -0
  12. triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +293 -0
  13. triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +32 -0
  14. triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +174 -0
  15. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +829 -0
  16. triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +1809 -0
  17. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +108 -0
  18. triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +124 -0
  19. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +405 -0
  20. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +196 -0
  21. triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +565 -0
  22. triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +2226 -0
  23. triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +104 -0
  24. triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +244 -0
  25. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +494 -0
  26. triton/backends/amd/include/hip/amd_detail/concepts.hpp +30 -0
  27. triton/backends/amd/include/hip/amd_detail/device_library_decls.h +133 -0
  28. triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +218 -0
  29. triton/backends/amd/include/hip/amd_detail/grid_launch.h +67 -0
  30. triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +50 -0
  31. triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +26 -0
  32. triton/backends/amd/include/hip/amd_detail/helpers.hpp +137 -0
  33. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +1350 -0
  34. triton/backends/amd/include/hip/amd_detail/hip_assert.h +101 -0
  35. triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +242 -0
  36. triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +254 -0
  37. triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +96 -0
  38. triton/backends/amd/include/hip/amd_detail/hip_ldg.h +100 -0
  39. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +10169 -0
  40. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +77 -0
  41. triton/backends/amd/include/hip/amd_detail/host_defines.h +180 -0
  42. triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +102 -0
  43. triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +798 -0
  44. triton/backends/amd/include/hip/amd_detail/math_fwd.h +698 -0
  45. triton/backends/amd/include/hip/amd_detail/ockl_image.h +177 -0
  46. triton/backends/amd/include/hip/amd_detail/program_state.hpp +107 -0
  47. triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +491 -0
  48. triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +478 -0
  49. triton/backends/amd/include/hip/channel_descriptor.h +39 -0
  50. triton/backends/amd/include/hip/device_functions.h +38 -0
  51. triton/backends/amd/include/hip/driver_types.h +468 -0
  52. triton/backends/amd/include/hip/hip_bf16.h +36 -0
  53. triton/backends/amd/include/hip/hip_bfloat16.h +44 -0
  54. triton/backends/amd/include/hip/hip_common.h +100 -0
  55. triton/backends/amd/include/hip/hip_complex.h +38 -0
  56. triton/backends/amd/include/hip/hip_cooperative_groups.h +46 -0
  57. triton/backends/amd/include/hip/hip_deprecated.h +95 -0
  58. triton/backends/amd/include/hip/hip_ext.h +159 -0
  59. triton/backends/amd/include/hip/hip_fp16.h +36 -0
  60. triton/backends/amd/include/hip/hip_gl_interop.h +32 -0
  61. triton/backends/amd/include/hip/hip_hcc.h +24 -0
  62. triton/backends/amd/include/hip/hip_math_constants.h +36 -0
  63. triton/backends/amd/include/hip/hip_profile.h +27 -0
  64. triton/backends/amd/include/hip/hip_runtime.h +75 -0
  65. triton/backends/amd/include/hip/hip_runtime_api.h +8919 -0
  66. triton/backends/amd/include/hip/hip_texture_types.h +29 -0
  67. triton/backends/amd/include/hip/hip_vector_types.h +41 -0
  68. triton/backends/amd/include/hip/hip_version.h +17 -0
  69. triton/backends/amd/include/hip/hiprtc.h +421 -0
  70. triton/backends/amd/include/hip/library_types.h +78 -0
  71. triton/backends/amd/include/hip/math_functions.h +42 -0
  72. triton/backends/amd/include/hip/surface_types.h +63 -0
  73. triton/backends/amd/include/hip/texture_types.h +194 -0
  74. triton/backends/amd/include/hsa/Brig.h +1131 -0
  75. triton/backends/amd/include/hsa/amd_hsa_common.h +91 -0
  76. triton/backends/amd/include/hsa/amd_hsa_elf.h +436 -0
  77. triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +269 -0
  78. triton/backends/amd/include/hsa/amd_hsa_queue.h +109 -0
  79. triton/backends/amd/include/hsa/amd_hsa_signal.h +80 -0
  80. triton/backends/amd/include/hsa/hsa.h +5729 -0
  81. triton/backends/amd/include/hsa/hsa_amd_tool.h +91 -0
  82. triton/backends/amd/include/hsa/hsa_api_trace.h +566 -0
  83. triton/backends/amd/include/hsa/hsa_ext_amd.h +3090 -0
  84. triton/backends/amd/include/hsa/hsa_ext_finalize.h +531 -0
  85. triton/backends/amd/include/hsa/hsa_ext_image.h +1454 -0
  86. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +488 -0
  87. triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +667 -0
  88. triton/backends/amd/include/roctracer/ext/prof_protocol.h +107 -0
  89. triton/backends/amd/include/roctracer/hip_ostream_ops.h +4435 -0
  90. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +1467 -0
  91. triton/backends/amd/include/roctracer/hsa_prof_str.h +3027 -0
  92. triton/backends/amd/include/roctracer/roctracer.h +779 -0
  93. triton/backends/amd/include/roctracer/roctracer_ext.h +81 -0
  94. triton/backends/amd/include/roctracer/roctracer_hcc.h +24 -0
  95. triton/backends/amd/include/roctracer/roctracer_hip.h +37 -0
  96. triton/backends/amd/include/roctracer/roctracer_hsa.h +112 -0
  97. triton/backends/amd/include/roctracer/roctracer_plugin.h +137 -0
  98. triton/backends/amd/include/roctracer/roctracer_roctx.h +67 -0
  99. triton/backends/amd/include/roctracer/roctx.h +229 -0
  100. triton/backends/amd/lib/ockl.bc +0 -0
  101. triton/backends/amd/lib/ocml.bc +0 -0
  102. triton/backends/compiler.py +304 -0
  103. triton/backends/driver.py +48 -0
  104. triton/backends/nvidia/__init__.py +0 -0
  105. triton/backends/nvidia/bin/ptxas.exe +0 -0
  106. triton/backends/nvidia/compiler.py +410 -0
  107. triton/backends/nvidia/driver.c +451 -0
  108. triton/backends/nvidia/driver.py +524 -0
  109. triton/backends/nvidia/include/cuda.h +24359 -0
  110. triton/backends/nvidia/lib/libdevice.10.bc +0 -0
  111. triton/backends/nvidia/lib/x64/cuda.lib +0 -0
  112. triton/compiler/__init__.py +4 -0
  113. triton/compiler/code_generator.py +1303 -0
  114. triton/compiler/compiler.py +430 -0
  115. triton/compiler/errors.py +51 -0
  116. triton/compiler/make_launcher.py +0 -0
  117. triton/errors.py +5 -0
  118. triton/language/__init__.py +294 -0
  119. triton/language/_utils.py +21 -0
  120. triton/language/core.py +2694 -0
  121. triton/language/extra/__init__.py +26 -0
  122. triton/language/extra/cuda/__init__.py +13 -0
  123. triton/language/extra/cuda/_experimental_tma.py +108 -0
  124. triton/language/extra/cuda/libdevice.py +1629 -0
  125. triton/language/extra/cuda/utils.py +109 -0
  126. triton/language/extra/hip/__init__.py +3 -0
  127. triton/language/extra/hip/libdevice.py +475 -0
  128. triton/language/extra/libdevice.py +786 -0
  129. triton/language/math.py +250 -0
  130. triton/language/random.py +207 -0
  131. triton/language/semantic.py +1796 -0
  132. triton/language/standard.py +452 -0
  133. triton/runtime/__init__.py +23 -0
  134. triton/runtime/autotuner.py +408 -0
  135. triton/runtime/build.py +111 -0
  136. triton/runtime/cache.py +295 -0
  137. triton/runtime/driver.py +60 -0
  138. triton/runtime/errors.py +26 -0
  139. triton/runtime/interpreter.py +1235 -0
  140. triton/runtime/jit.py +951 -0
  141. triton/testing.py +511 -0
  142. triton/tools/__init__.py +0 -0
  143. triton/tools/build_extern.py +365 -0
  144. triton/tools/compile.c +67 -0
  145. triton/tools/compile.h +14 -0
  146. triton/tools/compile.py +155 -0
  147. triton/tools/disasm.py +144 -0
  148. triton/tools/experimental_descriptor.py +32 -0
  149. triton/tools/link.py +322 -0
  150. triton/windows_utils.py +375 -0
  151. triton_windows-3.2.0.post11.dist-info/METADATA +39 -0
  152. triton_windows-3.2.0.post11.dist-info/RECORD +154 -0
  153. triton_windows-3.2.0.post11.dist-info/WHEEL +5 -0
  154. triton_windows-3.2.0.post11.dist-info/top_level.txt +12 -0
@@ -0,0 +1,109 @@
1
+ from triton.language import core
2
+
3
+
4
@core.extern
def globaltimer(_builder=None):
    """Read the PTX ``%globaltimer`` register into an int64 value.

    The inline assembly is emitted with ``is_pure=False`` and ``pack=1``.
    """
    ptx = "mov.u64 $0, %globaltimer;"
    return core.inline_asm_elementwise(ptx, "=l", [], dtype=core.int64, is_pure=False, pack=1, _builder=_builder)
8
+
9
+
10
@core.extern
def smid(_builder=None):
    """Read the PTX ``%smid`` register into an int32 value.

    The inline assembly is emitted with ``is_pure=True`` and ``pack=1``.
    """
    ptx = "mov.u32 $0, %smid;"
    return core.inline_asm_elementwise(ptx, "=r", [], dtype=core.int32, is_pure=True, pack=1, _builder=_builder)
14
+
15
+
16
@core.builtin
def num_threads(_builder=None):
    """Return ``num_warps * 32`` (taken from the builder options) as a constexpr."""
    warps = _builder.options.num_warps
    return core.constexpr(warps * 32)
19
+
20
+
21
@core.builtin
def num_warps(_builder=None):
    """Return the ``num_warps`` builder option wrapped in a constexpr."""
    options = _builder.options
    return core.constexpr(options.num_warps)
24
+
25
+
26
+ # ----- FP8E4M3B15 ------
27
+ # This data-type is a variant of the standard FP8E4M3 format.
28
+ # It was designed for fast software conversion to FP16 on
29
+ # nvidia GPUs that do not support it natively.
30
+ # This is the same format as FP8E4M3Nv, but:
31
+ # - the exponent bias is 15 instead of 7
32
+ # - 0xff and 0x7f are mapped to +-1.750 instead of +-nan
33
@core.builtin
def convert_fp8e4b15_to_float16(arg, _builder=None):
    """Convert ``arg`` to float16 with an inline PTX sequence.

    The assembly is issued with ``pack=4`` and the constraint string
    ``"=r,=r,r"`` (outputs ``$0``/``$1``, input ``$2``).
    """
    ptx_lines = (
        "{ \n",
        ".reg .b32 a<2>, b<2>; \n",
        "prmt.b32 a0, 0, $2, 0x5746; \n",
        "and.b32 b0, a0, 0x7f007f00; \n",
        "and.b32 b1, a0, 0x00ff00ff; \n",
        "and.b32 a1, a0, 0x00800080; \n",
        "shr.b32 b0, b0, 1; \n",
        "add.u32 b1, b1, a1; \n",
        "lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n",
        "shl.b32 $1, b1, 7; \n",
        "} \n",
    )
    ptx = "".join(ptx_lines)
    return core.inline_asm_elementwise(ptx, "=r,=r,r", [arg], dtype=core.float16, is_pure=True, pack=4,
                                       _builder=_builder)
48
+
49
+
50
@core.builtin
def convert_float16_to_fp8e4b15(arg, has_minx2, _builder=None):
    """Convert float16 ``arg`` to ``float8e4b15`` with an inline PTX sequence.

    The PTX text is assembled from three pieces: a common prologue, a
    clamping step, and a common epilogue. ``has_minx2`` selects the clamping
    step: ``min.f16x2`` when truthy, otherwise a ``setp``/``selp`` sequence.
    """
    prologue = ("{\n"
                ".reg .pred p<4>;\n"
                ".reg .b32 a<2>, b<2>;\n"
                ".reg .b16 c<4>;\n"
                ".reg .b16 max_val_f16;\n"
                ".reg .b32 max_val_f16x2;\n"
                "mov.b16 max_val_f16, 0x3F00;\n"
                "mov.b32 max_val_f16x2, 0x3F003F00;\n"
                "and.b32 a0, $1, 0x7fff7fff;\n"
                "and.b32 a1, $2, 0x7fff7fff;")
    clamp_with_min = ("min.f16x2 a0, a0, max_val_f16x2;\n"
                      "min.f16x2 a1, a1, max_val_f16x2;")
    clamp_with_select = ("setp.lt.f16x2 p0|p1, a0, max_val_f16x2;\n"
                         "setp.lt.f16x2 p2|p3, a1, max_val_f16x2;\n"
                         "mov.b32 {c0, c1}, a0;\n"
                         "mov.b32 {c2, c3}, a1;\n"
                         "selp.b16 c0, c0, max_val_f16, p0;\n"
                         "selp.b16 c1, c1, max_val_f16, p1;\n"
                         "selp.b16 c2, c2, max_val_f16, p2;\n"
                         "selp.b16 c3, c3, max_val_f16, p3;\n"
                         "mov.b32 a0, {c0, c1};\n"
                         "mov.b32 a1, {c2, c3};")
    epilogue = ("mad.lo.u32 a0, a0, 2, 0x00800080;\n"
                "mad.lo.u32 a1, a1, 2, 0x00800080;\n"
                "lop3.b32 b0, $1, 0x80008000, a0, 0xea;\n"
                "lop3.b32 b1, $2, 0x80008000, a1, 0xea;\n"
                "prmt.b32 $0, b0, b1, 0x7531;\n"
                "}")
    # The pieces are joined without separators, exactly as the incremental
    # string concatenation did.
    asm = prologue + (clamp_with_min if has_minx2 else clamp_with_select) + epilogue
    return core.inline_asm_elementwise(asm, "=r,r,r", [arg], dtype=core.float8e4b15, is_pure=True, pack=4,
                                       _builder=_builder)
84
+
85
+
86
@core.builtin
def convert_custom_float8(arg, dst_ty, fp_downcast_rounding, has_minx2, _builder=None):
    """Convert between fp8e4b15 and fp16/fp32.

    * fp8e4b15 input: upcast to float16, then to float32 when ``dst_ty`` has
      an fp32 scalar type.
    * fp16/fp32 input: fp32 is first narrowed to float16 with ``"rtz"``
      rounding, then the float16 value is converted to fp8e4b15.

    ``fp_downcast_rounding`` is accepted but not read by this function.
    ``has_minx2`` is forwarded to ``convert_float16_to_fp8e4b15``.
    """
    src_scalar = arg.type.scalar

    # Upcast path.
    if src_scalar.is_fp8e4b15():
        widened = convert_fp8e4b15_to_float16(arg, _builder=_builder)
        if not dst_ty.scalar.is_fp32():
            return widened
        return widened.to(core.float32, _builder=_builder)

    # Downcast path: only fp16 and fp32 sources are supported.
    assert src_scalar.is_fp16() or src_scalar.is_fp32()
    half = arg
    if src_scalar.is_fp32():
        half = arg.to(core.float16, fp_downcast_rounding="rtz", _builder=_builder)
    return convert_float16_to_fp8e4b15(half, has_minx2=has_minx2, _builder=_builder)
100
+
101
+
102
@core.builtin
def convert_custom_float8_sm80(arg, dst_ty, fp_downcast_rounding=None, _builder=None):
    """Call ``convert_custom_float8`` with ``has_minx2=True``."""
    converted = convert_custom_float8(arg, dst_ty, fp_downcast_rounding, has_minx2=True, _builder=_builder)
    return converted
105
+
106
+
107
@core.builtin
def convert_custom_float8_sm70(arg, dst_ty, fp_downcast_rounding=None, _builder=None):
    """Call ``convert_custom_float8`` with ``has_minx2=False``."""
    converted = convert_custom_float8(arg, dst_ty, fp_downcast_rounding, has_minx2=False, _builder=_builder)
    return converted
@@ -0,0 +1,3 @@
1
+ from . import libdevice
2
+
3
+ __all__ = ["libdevice"]
@@ -0,0 +1,475 @@
1
+ from triton.language import core
2
+
3
+
4
@core.extern
def abs(arg0, _builder=None):
    """Call the extern ``abs`` routine matching the dtype of ``arg0``.

    int32/int64 -> ``__triton_hip_iabs``; fp32/fp64 -> ``__triton_hip_fabs``.
    Each overload returns the same dtype as its input.
    """
    i32, i64 = core.dtype("int32"), core.dtype("int64")
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (i32, ): ("__triton_hip_iabs", i32),
        (i64, ): ("__triton_hip_iabs", i64),
        (f32, ): ("__triton_hip_fabs", f32),
        (f64, ): ("__triton_hip_fabs", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _builder=_builder)
13
+
14
+
15
@core.extern
def floor(arg0, _builder=None):
    """Call the extern ``floor`` routine matching the dtype of ``arg0``.

    fp32 -> ``__ocml_floor_f32``; fp64 -> ``__ocml_floor_f64``.
    """
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32, ): ("__ocml_floor_f32", f32),
        (f64, ): ("__ocml_floor_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _builder=_builder)
22
+
23
+
24
@core.extern
def rsqrt(arg0, _builder=None):
    """Call the extern ``rsqrt`` routine matching the dtype of ``arg0``.

    fp32 -> ``__ocml_rsqrt_f32``; fp64 -> ``__ocml_rsqrt_f64``.
    """
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32, ): ("__ocml_rsqrt_f32", f32),
        (f64, ): ("__ocml_rsqrt_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _builder=_builder)
31
+
32
+
33
@core.extern
def ceil(arg0, _builder=None):
    """Call the extern ``ceil`` routine matching the dtype of ``arg0``.

    fp32 -> ``__ocml_ceil_f32``; fp64 -> ``__ocml_ceil_f64``.
    """
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32, ): ("__ocml_ceil_f32", f32),
        (f64, ): ("__ocml_ceil_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _builder=_builder)
40
+
41
+
42
@core.extern
def trunc(arg0, _builder=None):
    """Call the extern ``trunc`` routine matching the dtype of ``arg0``.

    fp32 -> ``__ocml_trunc_f32``; fp64 -> ``__ocml_trunc_f64``.
    """
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32, ): ("__ocml_trunc_f32", f32),
        (f64, ): ("__ocml_trunc_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _builder=_builder)
49
+
50
+
51
@core.extern
def exp2(arg0, _builder=None):
    """Call the extern ``exp2`` routine matching the dtype of ``arg0``.

    fp32 -> ``__ocml_exp2_f32``; fp64 -> ``__ocml_exp2_f64``.
    """
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32, ): ("__ocml_exp2_f32", f32),
        (f64, ): ("__ocml_exp2_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _builder=_builder)
58
+
59
+
60
@core.extern
def exp(arg0, _builder=None):
    """Call the extern ``exp`` routine matching the dtype of ``arg0``.

    fp32 -> ``__ocml_exp_f32``; fp64 -> ``__ocml_exp_f64``.
    """
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32, ): ("__ocml_exp_f32", f32),
        (f64, ): ("__ocml_exp_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _builder=_builder)
67
+
68
+
69
@core.extern
def fast_expf(arg0, _builder=None):
    """Call the extern ``__triton_hip_fast_expf`` routine (fp32 only)."""
    f32 = core.dtype("fp32")
    overloads = {
        (f32, ): ("__triton_hip_fast_expf", f32),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _builder=_builder)
74
+
75
+
76
@core.extern
def fast_dividef(arg0, arg1, _builder=None):
    """Call the extern ``__triton_hip_fast_fdividef`` routine (fp32, fp32 only)."""
    f32 = core.dtype("fp32")
    overloads = {
        (f32, f32): ("__triton_hip_fast_fdividef", f32),
    }
    return core.extern_elementwise("", "", [arg0, arg1], overloads, is_pure=True, _builder=_builder)
81
+
82
+
83
@core.extern
def sqrt(arg0, _builder=None):
    """Call the extern ``sqrt`` routine matching the dtype of ``arg0``.

    fp32 -> ``__ocml_sqrt_f32``; fp64 -> ``__ocml_sqrt_f64``.
    """
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32, ): ("__ocml_sqrt_f32", f32),
        (f64, ): ("__ocml_sqrt_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _builder=_builder)
90
+
91
+
92
@core.extern
def llrint(arg0, _builder=None):
    """Call the extern ``__triton_hip_llrint`` routine.

    Both the fp32 and the fp64 overload use the same symbol and return int64.
    """
    i64 = core.dtype("int64")
    overloads = {
        (core.dtype("fp32"), ): ("__triton_hip_llrint", i64),
        (core.dtype("fp64"), ): ("__triton_hip_llrint", i64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _builder=_builder)
99
+
100
+
101
@core.extern
def nearbyint(arg0, _builder=None):
    """Call the extern ``nearbyint`` routine matching the dtype of ``arg0``.

    fp32 -> ``__ocml_nearbyint_f32``; fp64 -> ``__ocml_nearbyint_f64``.
    """
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32, ): ("__ocml_nearbyint_f32", f32),
        (f64, ): ("__ocml_nearbyint_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _builder=_builder)
110
+
111
+
112
@core.extern
def isnan(arg0, _builder=None):
    """Call the extern ``isnan`` routine and cast its int32 result to int1.

    fp32 -> ``__ocml_isnan_f32``; fp64 -> ``__ocml_isnan_f64``.
    """
    i32 = core.dtype("int32")
    overloads = {
        (core.dtype("fp32"), ): ("__ocml_isnan_f32", i32),
        (core.dtype("fp64"), ): ("__ocml_isnan_f64", i32),
    }
    flag = core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _builder=_builder)
    return flag.to(core.int1, _builder=_builder)
121
+
122
+
123
@core.extern
def signbit(arg0, _builder=None):
    """Call the extern ``signbit`` routine; the result is returned as int32.

    fp32 -> ``__ocml_signbit_f32``; fp64 -> ``__ocml_signbit_f64``.
    """
    i32 = core.dtype("int32")
    overloads = {
        (core.dtype("fp32"), ): ("__ocml_signbit_f32", i32),
        (core.dtype("fp64"), ): ("__ocml_signbit_f64", i32),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _builder=_builder)
132
+
133
+
134
@core.extern
def copysign(arg0, arg1, _builder=None):
    """Call the extern ``copysign`` routine matching the argument dtypes.

    (fp32, fp32) -> ``__ocml_copysign_f32``; (fp64, fp64) -> ``__ocml_copysign_f64``.
    """
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32, f32): ("__ocml_copysign_f32", f32),
        (f64, f64): ("__ocml_copysign_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0, arg1], overloads, is_pure=True, _builder=_builder)
141
+
142
+
143
@core.extern
def isinf(arg0, _builder=None):
    """Call the extern ``isinf`` routine and cast its int32 result to int1.

    fp32 -> ``__ocml_isinf_f32``; fp64 -> ``__ocml_isinf_f64``.
    """
    i32 = core.dtype("int32")
    overloads = {
        (core.dtype("fp32"), ): ("__ocml_isinf_f32", i32),
        (core.dtype("fp64"), ): ("__ocml_isinf_f64", i32),
    }
    flag = core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _builder=_builder)
    return flag.to(core.int1, _builder=_builder)
150
+
151
+
152
@core.extern
def nextafter(arg0, arg1, _builder=None):
    """Call the extern ``nextafter`` routine matching the argument dtypes.

    (fp32, fp32) -> ``__ocml_nextafter_f32``; (fp64, fp64) -> ``__ocml_nextafter_f64``.
    """
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32, f32): ("__ocml_nextafter_f32", f32),
        (f64, f64): ("__ocml_nextafter_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0, arg1], overloads, is_pure=True, _builder=_builder)
159
+
160
+
161
@core.extern
def sin(arg0, _builder=None):
    """Call the extern ``sin`` routine matching the dtype of ``arg0``.

    fp32 -> ``__ocml_sin_f32``; fp64 -> ``__ocml_sin_f64``.
    """
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32, ): ("__ocml_sin_f32", f32),
        (f64, ): ("__ocml_sin_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _builder=_builder)
168
+
169
+
170
@core.extern
def cos(arg0, _builder=None):
    """Call the extern ``cos`` routine matching the dtype of ``arg0``.

    fp32 -> ``__ocml_cos_f32``; fp64 -> ``__ocml_cos_f64``.
    """
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32, ): ("__ocml_cos_f32", f32),
        (f64, ): ("__ocml_cos_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _builder=_builder)
177
+
178
+
179
@core.extern
def tan(arg0, _builder=None):
    """Call the extern ``tan`` routine matching the dtype of ``arg0``.

    fp32 -> ``__ocml_tan_f32``; fp64 -> ``__ocml_tan_f64``.
    """
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32, ): ("__ocml_tan_f32", f32),
        (f64, ): ("__ocml_tan_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _builder=_builder)
186
+
187
+
188
@core.extern
def log2(arg0, _builder=None):
    """Call the extern ``log2`` routine matching the dtype of ``arg0``.

    fp32 -> ``__ocml_log2_f32``; fp64 -> ``__ocml_log2_f64``.
    """
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32, ): ("__ocml_log2_f32", f32),
        (f64, ): ("__ocml_log2_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _builder=_builder)
195
+
196
+
197
@core.extern
def cosh(arg0, _builder=None):
    """Call the extern ``cosh`` routine matching the dtype of ``arg0``.

    fp32 -> ``__ocml_cosh_f32``; fp64 -> ``__ocml_cosh_f64``.
    """
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32, ): ("__ocml_cosh_f32", f32),
        (f64, ): ("__ocml_cosh_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _builder=_builder)
204
+
205
+
206
@core.extern
def sinh(arg0, _builder=None):
    """Call the extern ``sinh`` routine matching the dtype of ``arg0``.

    fp32 -> ``__ocml_sinh_f32``; fp64 -> ``__ocml_sinh_f64``.
    """
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32, ): ("__ocml_sinh_f32", f32),
        (f64, ): ("__ocml_sinh_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _builder=_builder)
213
+
214
+
215
@core.extern
def tanh(arg0, _builder=None):
    """Call the extern ``tanh`` routine matching the dtype of ``arg0``.

    fp32 -> ``__ocml_tanh_f32``; fp64 -> ``__ocml_tanh_f64``.
    """
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32, ): ("__ocml_tanh_f32", f32),
        (f64, ): ("__ocml_tanh_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _builder=_builder)
222
+
223
+
224
@core.extern
def atan2(arg0, arg1, _builder=None):
    """Call the extern ``atan2`` routine matching the argument dtypes.

    (fp32, fp32) -> ``__ocml_atan2_f32``; (fp64, fp64) -> ``__ocml_atan2_f64``.
    """
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32, f32): ("__ocml_atan2_f32", f32),
        (f64, f64): ("__ocml_atan2_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0, arg1], overloads, is_pure=True, _builder=_builder)
231
+
232
+
233
@core.extern
def atan(arg0, _builder=None):
    """Call the extern ``atan`` routine matching the dtype of ``arg0``.

    fp32 -> ``__ocml_atan_f32``; fp64 -> ``__ocml_atan_f64``.
    """
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32, ): ("__ocml_atan_f32", f32),
        (f64, ): ("__ocml_atan_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _builder=_builder)
240
+
241
+
242
@core.extern
def asin(arg0, _builder=None):
    """Call the extern ``asin`` routine matching the dtype of ``arg0``.

    fp32 -> ``__ocml_asin_f32``; fp64 -> ``__ocml_asin_f64``.
    """
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32, ): ("__ocml_asin_f32", f32),
        (f64, ): ("__ocml_asin_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _builder=_builder)
249
+
250
+
251
@core.extern
def acos(arg0, _builder=None):
    """Call the extern ``acos`` routine matching the dtype of ``arg0``.

    fp32 -> ``__ocml_acos_f32``; fp64 -> ``__ocml_acos_f64``.
    """
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32, ): ("__ocml_acos_f32", f32),
        (f64, ): ("__ocml_acos_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _builder=_builder)
258
+
259
+
260
@core.extern
def log(arg0, _builder=None):
    """Call the extern ``log`` routine matching the dtype of ``arg0``.

    fp32 -> ``__ocml_log_f32``; fp64 -> ``__ocml_log_f64``.
    """
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32, ): ("__ocml_log_f32", f32),
        (f64, ): ("__ocml_log_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _builder=_builder)
267
+
268
+
269
@core.extern
def log10(arg0, _builder=None):
    """Call the extern ``log10`` routine matching the dtype of ``arg0``.

    fp32 -> ``__ocml_log10_f32``; fp64 -> ``__ocml_log10_f64``.
    """
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32, ): ("__ocml_log10_f32", f32),
        (f64, ): ("__ocml_log10_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _builder=_builder)
276
+
277
+
278
@core.extern
def log1p(arg0, _builder=None):
    """Call the extern ``log1p`` routine matching the dtype of ``arg0``.

    fp32 -> ``__ocml_log1p_f32``; fp64 -> ``__ocml_log1p_f64``.
    """
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32, ): ("__ocml_log1p_f32", f32),
        (f64, ): ("__ocml_log1p_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _builder=_builder)
285
+
286
+
287
@core.extern
def acosh(arg0, _builder=None):
    """Call the extern ``acosh`` routine matching the dtype of ``arg0``.

    fp32 -> ``__ocml_acosh_f32``; fp64 -> ``__ocml_acosh_f64``.
    """
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32, ): ("__ocml_acosh_f32", f32),
        (f64, ): ("__ocml_acosh_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _builder=_builder)
294
+
295
+
296
@core.extern
def asinh(arg0, _builder=None):
    """Call the extern ``asinh`` routine matching the dtype of ``arg0``.

    fp32 -> ``__ocml_asinh_f32``; fp64 -> ``__ocml_asinh_f64``.
    """
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32, ): ("__ocml_asinh_f32", f32),
        (f64, ): ("__ocml_asinh_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _builder=_builder)
303
+
304
+
305
@core.extern
def atanh(arg0, _builder=None):
    """Call the extern ``atanh`` routine matching the dtype of ``arg0``.

    fp32 -> ``__ocml_atanh_f32``; fp64 -> ``__ocml_atanh_f64``.
    """
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32, ): ("__ocml_atanh_f32", f32),
        (f64, ): ("__ocml_atanh_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _builder=_builder)
312
+
313
+
314
@core.extern
def expm1(arg0, _builder=None):
    """Call the extern ``expm1`` routine matching the dtype of ``arg0``.

    fp32 -> ``__ocml_expm1_f32``; fp64 -> ``__ocml_expm1_f64``.
    """
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32, ): ("__ocml_expm1_f32", f32),
        (f64, ): ("__ocml_expm1_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _builder=_builder)
321
+
322
+
323
@core.extern
def hypot(arg0, arg1, _builder=None):
    """Call the extern ``hypot`` routine matching the argument dtypes.

    (fp32, fp32) -> ``__ocml_hypot_f32``; (fp64, fp64) -> ``__ocml_hypot_f64``.
    """
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32, f32): ("__ocml_hypot_f32", f32),
        (f64, f64): ("__ocml_hypot_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0, arg1], overloads, is_pure=True, _builder=_builder)
330
+
331
+
332
@core.extern
def j0(arg0, _builder=None):
    """Call the extern ``j0`` routine matching the dtype of ``arg0``.

    fp32 -> ``__ocml_j0_f32``; fp64 -> ``__ocml_j0_f64``.
    """
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32, ): ("__ocml_j0_f32", f32),
        (f64, ): ("__ocml_j0_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _builder=_builder)
339
+
340
+
341
@core.extern
def j1(arg0, _builder=None):
    """Call the extern ``j1`` routine matching the dtype of ``arg0``.

    fp32 -> ``__ocml_j1_f32``; fp64 -> ``__ocml_j1_f64``.
    """
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32, ): ("__ocml_j1_f32", f32),
        (f64, ): ("__ocml_j1_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _builder=_builder)
348
+
349
+
350
@core.extern
def y0(arg0, _builder=None):
    """Call the extern ``y0`` routine matching the dtype of ``arg0``.

    fp32 -> ``__ocml_y0_f32``; fp64 -> ``__ocml_y0_f64``.
    """
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32, ): ("__ocml_y0_f32", f32),
        (f64, ): ("__ocml_y0_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _builder=_builder)
357
+
358
+
359
@core.extern
def y1(arg0, _builder=None):
    """Call the extern ``y1`` routine matching the dtype of ``arg0``.

    fp32 -> ``__ocml_y1_f32``; fp64 -> ``__ocml_y1_f64``.
    """
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32, ): ("__ocml_y1_f32", f32),
        (f64, ): ("__ocml_y1_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _builder=_builder)
366
+
367
+
368
@core.extern
def cyl_bessel_i0(arg0, _builder=None):
    """Call the extern ``i0`` routine matching the dtype of ``arg0``.

    fp32 -> ``__ocml_i0_f32``; fp64 -> ``__ocml_i0_f64``.
    """
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32, ): ("__ocml_i0_f32", f32),
        (f64, ): ("__ocml_i0_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _builder=_builder)
375
+
376
+
377
@core.extern
def cyl_bessel_i1(arg0, _builder=None):
    """Call the extern ``i1`` routine matching the dtype of ``arg0``.

    fp32 -> ``__ocml_i1_f32``; fp64 -> ``__ocml_i1_f64``.
    """
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32, ): ("__ocml_i1_f32", f32),
        (f64, ): ("__ocml_i1_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _builder=_builder)
384
+
385
+
386
@core.extern
def erf(arg0, _builder=None):
    """Call the extern ``erf`` routine matching the dtype of ``arg0``.

    fp32 -> ``__ocml_erf_f32``; fp64 -> ``__ocml_erf_f64``.
    """
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32, ): ("__ocml_erf_f32", f32),
        (f64, ): ("__ocml_erf_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _builder=_builder)
393
+
394
+
395
@core.extern
def erfinv(arg0, _builder=None):
    """Call the extern ``erfinv`` routine matching the dtype of ``arg0``.

    fp32 -> ``__ocml_erfinv_f32``; fp64 -> ``__ocml_erfinv_f64``.
    """
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32, ): ("__ocml_erfinv_f32", f32),
        (f64, ): ("__ocml_erfinv_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _builder=_builder)
402
+
403
+
404
@core.extern
def erfc(arg0, _builder=None):
    """Call the extern ``erfc`` routine matching the dtype of ``arg0``.

    fp32 -> ``__ocml_erfc_f32``; fp64 -> ``__ocml_erfc_f64``.
    """
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32, ): ("__ocml_erfc_f32", f32),
        (f64, ): ("__ocml_erfc_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _builder=_builder)
411
+
412
+
413
@core.extern
def erfcx(arg0, _builder=None):
    """Call the extern ``erfcx`` routine matching the dtype of ``arg0``.

    fp32 -> ``__ocml_erfcx_f32``; fp64 -> ``__ocml_erfcx_f64``.
    """
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32, ): ("__ocml_erfcx_f32", f32),
        (f64, ): ("__ocml_erfcx_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _builder=_builder)
420
+
421
+
422
@core.extern
def lgamma(arg0, _builder=None):
    """Call the extern ``lgamma`` routine matching the dtype of ``arg0``.

    fp32 -> ``__ocml_lgamma_f32``; fp64 -> ``__ocml_lgamma_f64``.
    """
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32, ): ("__ocml_lgamma_f32", f32),
        (f64, ): ("__ocml_lgamma_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _builder=_builder)
429
+
430
+
431
@core.extern
def ldexp(arg0, arg1, _builder=None):
    """Call the extern ``ldexp`` routine matching the argument dtypes.

    (fp32, int32) -> ``__ocml_ldexp_f32``; (fp64, int32) -> ``__ocml_ldexp_f64``.
    The result has the dtype of the first argument.
    """
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    i32 = core.dtype("int32")
    overloads = {
        (f32, i32): ("__ocml_ldexp_f32", f32),
        (f64, i32): ("__ocml_ldexp_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0, arg1], overloads, is_pure=True, _builder=_builder)
438
+
439
+
440
@core.extern
def fmod(arg0, arg1, _builder=None):
    """Call the extern ``fmod`` routine matching the argument dtypes.

    (fp32, fp32) -> ``__ocml_fmod_f32``; (fp64, fp64) -> ``__ocml_fmod_f64``.
    """
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32, f32): ("__ocml_fmod_f32", f32),
        (f64, f64): ("__ocml_fmod_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0, arg1], overloads, is_pure=True, _builder=_builder)
447
+
448
+
449
@core.extern
def fma(arg0, arg1, arg2, _builder=None):
    """Call the extern ``fma`` routine matching the argument dtypes.

    Three fp32 arguments -> ``__ocml_fma_f32``; three fp64 arguments ->
    ``__ocml_fma_f64``.
    """
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    overloads = {
        (f32, f32, f32): ("__ocml_fma_f32", f32),
        (f64, f64, f64): ("__ocml_fma_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0, arg1, arg2], overloads, is_pure=True, _builder=_builder)
456
+
457
+
458
@core.extern
def pow(arg0, arg1, _builder=None):
    """Call the extern power routine matching the argument dtypes.

    An int32 second argument selects ``__ocml_pown_f32`` / ``__ocml_pown_f64``;
    a floating-point second argument of the same dtype as the first selects
    ``__ocml_pow_f32`` / ``__ocml_pow_f64``. The result has the dtype of the
    first argument.
    """
    f32, f64 = core.dtype("fp32"), core.dtype("fp64")
    i32 = core.dtype("int32")
    overloads = {
        (f32, i32): ("__ocml_pown_f32", f32),
        (f64, i32): ("__ocml_pown_f64", f64),
        (f32, f32): ("__ocml_pow_f32", f32),
        (f64, f64): ("__ocml_pow_f64", f64),
    }
    return core.extern_elementwise("", "", [arg0, arg1], overloads, is_pure=True, _builder=_builder)
467
+
468
+
469
@core.extern
def ilogb(arg0, _builder=None):
    """Call the extern ``ilogb`` routine; the result is returned as int32.

    fp32 -> ``__ocml_ilogb_f32``; fp64 -> ``__ocml_ilogb_f64``.
    """
    i32 = core.dtype("int32")
    overloads = {
        (core.dtype("fp32"), ): ("__ocml_ilogb_f32", i32),
        (core.dtype("fp64"), ): ("__ocml_ilogb_f64", i32),
    }
    return core.extern_elementwise("", "", [arg0], overloads, is_pure=True, _builder=_builder)