triton-windows 3.3.1.post19__cp310-cp310-win_amd64.whl → 3.5.0.post21__cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of triton-windows might be problematic. Click here for more details.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +11 -2
- triton/_filecheck.py +97 -0
- triton/_internal_testing.py +95 -18
- triton/_utils.py +112 -21
- triton/backends/__init__.py +20 -23
- triton/backends/amd/__init__.py +0 -0
- triton/backends/amd/compiler.py +161 -119
- triton/backends/amd/driver.c +118 -46
- triton/backends/amd/driver.py +274 -96
- triton/backends/compiler.py +7 -21
- triton/backends/driver.py +13 -0
- triton/backends/nvidia/bin/ptxas.exe +0 -0
- triton/backends/nvidia/compiler.py +163 -106
- triton/backends/nvidia/driver.c +166 -101
- triton/backends/nvidia/driver.py +384 -202
- triton/compiler/__init__.py +5 -2
- triton/compiler/code_generator.py +439 -231
- triton/compiler/compiler.py +152 -84
- triton/experimental/__init__.py +0 -0
- triton/experimental/gluon/__init__.py +5 -0
- triton/experimental/gluon/_compiler.py +0 -0
- triton/experimental/gluon/_runtime.py +102 -0
- triton/experimental/gluon/language/__init__.py +119 -0
- triton/experimental/gluon/language/_core.py +490 -0
- triton/experimental/gluon/language/_layouts.py +583 -0
- triton/experimental/gluon/language/_math.py +20 -0
- triton/experimental/gluon/language/_semantic.py +380 -0
- triton/experimental/gluon/language/_standard.py +80 -0
- triton/experimental/gluon/language/amd/__init__.py +4 -0
- triton/experimental/gluon/language/amd/_layouts.py +96 -0
- triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
- triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
- triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
- triton/experimental/gluon/language/extra/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/__init__.py +4 -0
- triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
- triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
- triton/experimental/gluon/language/nvidia/blackwell/__init__.py +387 -0
- triton/experimental/gluon/language/nvidia/blackwell/tma.py +52 -0
- triton/experimental/gluon/language/nvidia/hopper/__init__.py +132 -0
- triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +34 -0
- triton/experimental/gluon/language/nvidia/hopper/tma.py +97 -0
- triton/experimental/gluon/nvidia/__init__.py +4 -0
- triton/experimental/gluon/nvidia/blackwell.py +3 -0
- triton/experimental/gluon/nvidia/hopper.py +45 -0
- triton/knobs.py +546 -0
- triton/language/__init__.py +50 -19
- triton/language/core.py +909 -572
- triton/language/extra/cuda/__init__.py +10 -7
- triton/language/extra/cuda/gdc.py +42 -0
- triton/language/extra/cuda/libdevice.py +394 -394
- triton/language/extra/cuda/utils.py +21 -21
- triton/language/extra/hip/__init__.py +3 -1
- triton/language/extra/hip/libdevice.py +120 -104
- triton/language/extra/hip/utils.py +35 -0
- triton/language/extra/libdevice.py +4 -0
- triton/language/math.py +65 -66
- triton/language/random.py +12 -2
- triton/language/semantic.py +1757 -1768
- triton/language/standard.py +127 -62
- triton/language/target_info.py +54 -0
- triton/runtime/_allocation.py +15 -3
- triton/runtime/_async_compile.py +55 -0
- triton/runtime/autotuner.py +117 -60
- triton/runtime/build.py +83 -17
- triton/runtime/cache.py +61 -47
- triton/runtime/driver.py +25 -47
- triton/runtime/interpreter.py +95 -50
- triton/runtime/jit.py +445 -248
- triton/runtime/tcc/include/_mingw.h +8 -10
- triton/runtime/tcc/include/assert.h +5 -0
- triton/runtime/tcc/include/errno.h +1 -1
- triton/runtime/tcc/include/float.h +21 -3
- triton/runtime/tcc/include/iso646.h +36 -0
- triton/runtime/tcc/include/limits.h +5 -0
- triton/runtime/tcc/include/malloc.h +2 -2
- triton/runtime/tcc/include/math.h +21 -261
- triton/runtime/tcc/include/stdalign.h +16 -0
- triton/runtime/tcc/include/stdarg.h +5 -70
- triton/runtime/tcc/include/stdatomic.h +171 -0
- triton/runtime/tcc/include/stddef.h +7 -19
- triton/runtime/tcc/include/stdlib.h +15 -4
- triton/runtime/tcc/include/stdnoreturn.h +7 -0
- triton/runtime/tcc/include/sys/stat.h +2 -2
- triton/runtime/tcc/include/sys/types.h +5 -0
- triton/runtime/tcc/include/tcc/tcc_libm.h +444 -27
- triton/runtime/tcc/include/tccdefs.h +342 -0
- triton/runtime/tcc/include/tgmath.h +89 -0
- triton/runtime/tcc/include/uchar.h +33 -0
- triton/runtime/tcc/include/unistd.h +1 -0
- triton/runtime/tcc/include/winapi/qos.h +72 -0
- triton/runtime/tcc/include/winapi/shellapi.h +59 -0
- triton/runtime/tcc/include/winapi/winbase.h +9 -2
- triton/runtime/tcc/include/winapi/wincon.h +8 -0
- triton/runtime/tcc/include/winapi/windows.h +1 -1
- triton/runtime/tcc/include/winapi/winnls.h +778 -0
- triton/runtime/tcc/include/winapi/winnt.h +9 -7
- triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
- triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
- triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
- triton/runtime/tcc/lib/libtcc1.a +0 -0
- triton/runtime/tcc/lib/python314.def +1800 -0
- triton/runtime/tcc/lib/python314t.def +1809 -0
- triton/runtime/tcc/libtcc.dll +0 -0
- triton/runtime/tcc/tcc.exe +0 -0
- triton/testing.py +16 -12
- triton/tools/compile.py +62 -14
- triton/tools/disasm.py +3 -4
- triton/tools/extra/cuda/compile.c +1 -0
- triton/tools/extra/hip/compile.cpp +66 -0
- triton/tools/extra/hip/compile.h +13 -0
- triton/tools/ragged_tma.py +92 -0
- triton/tools/tensor_descriptor.py +34 -0
- triton/windows_utils.py +52 -81
- {triton_windows-3.3.1.post19.dist-info → triton_windows-3.5.0.post21.dist-info}/METADATA +8 -4
- triton_windows-3.5.0.post21.dist-info/RECORD +217 -0
- triton_windows-3.5.0.post21.dist-info/entry_points.txt +3 -0
- triton_windows-3.5.0.post21.dist-info/licenses/LICENSE +23 -0
- triton_windows-3.5.0.post21.dist-info/top_level.txt +1 -0
- triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +0 -358
- triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +0 -1010
- triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +0 -1638
- triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +0 -1814
- triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +0 -293
- triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +0 -32
- triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +0 -174
- triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +0 -835
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +0 -1809
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +0 -1391
- triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +0 -108
- triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +0 -124
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +0 -405
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +0 -196
- triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +0 -565
- triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +0 -2226
- triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +0 -104
- triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +0 -244
- triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +0 -538
- triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +0 -288
- triton/backends/amd/include/hip/amd_detail/concepts.hpp +0 -30
- triton/backends/amd/include/hip/amd_detail/device_library_decls.h +0 -133
- triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +0 -218
- triton/backends/amd/include/hip/amd_detail/grid_launch.h +0 -67
- triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +0 -50
- triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +0 -26
- triton/backends/amd/include/hip/amd_detail/helpers.hpp +0 -137
- triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +0 -1446
- triton/backends/amd/include/hip/amd_detail/hip_assert.h +0 -101
- triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +0 -242
- triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +0 -254
- triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +0 -96
- triton/backends/amd/include/hip/amd_detail/hip_ldg.h +0 -100
- triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +0 -10570
- triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +0 -78
- triton/backends/amd/include/hip/amd_detail/host_defines.h +0 -184
- triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +0 -102
- triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +0 -798
- triton/backends/amd/include/hip/amd_detail/math_fwd.h +0 -698
- triton/backends/amd/include/hip/amd_detail/ockl_image.h +0 -177
- triton/backends/amd/include/hip/amd_detail/program_state.hpp +0 -107
- triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +0 -491
- triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +0 -478
- triton/backends/amd/include/hip/channel_descriptor.h +0 -39
- triton/backends/amd/include/hip/device_functions.h +0 -38
- triton/backends/amd/include/hip/driver_types.h +0 -468
- triton/backends/amd/include/hip/hip_bf16.h +0 -36
- triton/backends/amd/include/hip/hip_bfloat16.h +0 -44
- triton/backends/amd/include/hip/hip_common.h +0 -100
- triton/backends/amd/include/hip/hip_complex.h +0 -38
- triton/backends/amd/include/hip/hip_cooperative_groups.h +0 -46
- triton/backends/amd/include/hip/hip_deprecated.h +0 -95
- triton/backends/amd/include/hip/hip_ext.h +0 -161
- triton/backends/amd/include/hip/hip_fp16.h +0 -36
- triton/backends/amd/include/hip/hip_fp8.h +0 -33
- triton/backends/amd/include/hip/hip_gl_interop.h +0 -32
- triton/backends/amd/include/hip/hip_hcc.h +0 -24
- triton/backends/amd/include/hip/hip_math_constants.h +0 -36
- triton/backends/amd/include/hip/hip_profile.h +0 -27
- triton/backends/amd/include/hip/hip_runtime.h +0 -75
- triton/backends/amd/include/hip/hip_runtime_api.h +0 -9261
- triton/backends/amd/include/hip/hip_texture_types.h +0 -29
- triton/backends/amd/include/hip/hip_vector_types.h +0 -41
- triton/backends/amd/include/hip/hip_version.h +0 -17
- triton/backends/amd/include/hip/hiprtc.h +0 -421
- triton/backends/amd/include/hip/library_types.h +0 -78
- triton/backends/amd/include/hip/math_functions.h +0 -42
- triton/backends/amd/include/hip/surface_types.h +0 -63
- triton/backends/amd/include/hip/texture_types.h +0 -194
- triton/backends/amd/include/hsa/Brig.h +0 -1131
- triton/backends/amd/include/hsa/amd_hsa_common.h +0 -91
- triton/backends/amd/include/hsa/amd_hsa_elf.h +0 -462
- triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +0 -269
- triton/backends/amd/include/hsa/amd_hsa_queue.h +0 -109
- triton/backends/amd/include/hsa/amd_hsa_signal.h +0 -80
- triton/backends/amd/include/hsa/hsa.h +0 -5738
- triton/backends/amd/include/hsa/hsa_amd_tool.h +0 -91
- triton/backends/amd/include/hsa/hsa_api_trace.h +0 -579
- triton/backends/amd/include/hsa/hsa_api_trace_version.h +0 -68
- triton/backends/amd/include/hsa/hsa_ext_amd.h +0 -3146
- triton/backends/amd/include/hsa/hsa_ext_finalize.h +0 -531
- triton/backends/amd/include/hsa/hsa_ext_image.h +0 -1454
- triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +0 -488
- triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +0 -667
- triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +0 -416
- triton/backends/amd/include/roctracer/ext/prof_protocol.h +0 -107
- triton/backends/amd/include/roctracer/hip_ostream_ops.h +0 -4515
- triton/backends/amd/include/roctracer/hsa_ostream_ops.h +0 -1727
- triton/backends/amd/include/roctracer/hsa_prof_str.h +0 -3059
- triton/backends/amd/include/roctracer/roctracer.h +0 -779
- triton/backends/amd/include/roctracer/roctracer_ext.h +0 -81
- triton/backends/amd/include/roctracer/roctracer_hcc.h +0 -24
- triton/backends/amd/include/roctracer/roctracer_hip.h +0 -37
- triton/backends/amd/include/roctracer/roctracer_hsa.h +0 -112
- triton/backends/amd/include/roctracer/roctracer_plugin.h +0 -137
- triton/backends/amd/include/roctracer/roctracer_roctx.h +0 -67
- triton/backends/amd/include/roctracer/roctx.h +0 -229
- triton/language/_utils.py +0 -21
- triton/language/extra/cuda/_experimental_tma.py +0 -106
- triton/runtime/tcc/lib/libtcc1-64.a +0 -0
- triton/tools/experimental_descriptor.py +0 -32
- triton_windows-3.3.1.post19.dist-info/RECORD +0 -260
- triton_windows-3.3.1.post19.dist-info/top_level.txt +0 -14
- {triton_windows-3.3.1.post19.dist-info → triton_windows-3.5.0.post21.dist-info}/WHEEL +0 -0
triton/language/math.py
CHANGED
|
@@ -1,5 +1,4 @@
|
|
|
1
1
|
from . import core
|
|
2
|
-
from . import semantic
|
|
3
2
|
from functools import wraps
|
|
4
3
|
from typing import List
|
|
5
4
|
|
|
@@ -85,107 +84,107 @@ def _add_math_3arg_docstr(name: str) -> core.Callable[[T], T]:
|
|
|
85
84
|
@core.builtin
|
|
86
85
|
@_check_dtype(dtypes=["int32", "int64", "uint32", "uint64"])
|
|
87
86
|
@_add_math_2arg_docstr("most significant N bits of the 2N-bit product")
|
|
88
|
-
def umulhi(x, y,
|
|
89
|
-
x =
|
|
90
|
-
y =
|
|
91
|
-
x, y = core.binary_op_type_legalization(x, y,
|
|
92
|
-
return core.tensor(
|
|
87
|
+
def umulhi(x, y, _semantic=None):
|
|
88
|
+
x = _semantic.to_tensor(x)
|
|
89
|
+
y = _semantic.to_tensor(y)
|
|
90
|
+
x, y = core.binary_op_type_legalization(x, y, _semantic)
|
|
91
|
+
return core.tensor(_semantic.builder.create_umulhi(x.handle, y.handle), x.type)
|
|
93
92
|
|
|
94
93
|
|
|
95
94
|
@core.builtin
|
|
96
95
|
@_check_dtype(dtypes=["fp32", "fp64"])
|
|
97
96
|
@_add_math_1arg_docstr("exponential")
|
|
98
97
|
@core._tensor_member_fn
|
|
99
|
-
def exp(x,
|
|
100
|
-
x =
|
|
101
|
-
return core.tensor(
|
|
98
|
+
def exp(x, _semantic=None):
|
|
99
|
+
x = _semantic.to_tensor(x)
|
|
100
|
+
return core.tensor(_semantic.builder.create_exp(x.handle), x.type)
|
|
102
101
|
|
|
103
102
|
|
|
104
103
|
@core.builtin
|
|
105
104
|
@_check_dtype(dtypes=["fp32", "fp64"])
|
|
106
105
|
@_add_math_1arg_docstr("exponential (base 2)")
|
|
107
106
|
@core._tensor_member_fn
|
|
108
|
-
def exp2(x,
|
|
109
|
-
x =
|
|
110
|
-
return core.tensor(
|
|
107
|
+
def exp2(x, _semantic=None):
|
|
108
|
+
x = _semantic.to_tensor(x)
|
|
109
|
+
return core.tensor(_semantic.builder.create_exp2(x.handle), x.type)
|
|
111
110
|
|
|
112
111
|
|
|
113
112
|
@core.builtin
|
|
114
113
|
@_check_dtype(dtypes=["fp32", "fp64"])
|
|
115
114
|
@_add_math_1arg_docstr("natural logarithm")
|
|
116
115
|
@core._tensor_member_fn
|
|
117
|
-
def log(x,
|
|
118
|
-
x =
|
|
119
|
-
return core.tensor(
|
|
116
|
+
def log(x, _semantic=None):
|
|
117
|
+
x = _semantic.to_tensor(x)
|
|
118
|
+
return core.tensor(_semantic.builder.create_log(x.handle), x.type)
|
|
120
119
|
|
|
121
120
|
|
|
122
121
|
@core.builtin
|
|
123
122
|
@_check_dtype(dtypes=["fp32", "fp64"])
|
|
124
123
|
@_add_math_1arg_docstr("logarithm (base 2)")
|
|
125
124
|
@core._tensor_member_fn
|
|
126
|
-
def log2(x,
|
|
127
|
-
x =
|
|
128
|
-
return core.tensor(
|
|
125
|
+
def log2(x, _semantic=None):
|
|
126
|
+
x = _semantic.to_tensor(x)
|
|
127
|
+
return core.tensor(_semantic.builder.create_log2(x.handle), x.type)
|
|
129
128
|
|
|
130
129
|
|
|
131
130
|
@core.builtin
|
|
132
131
|
@_check_dtype(dtypes=["fp32", "fp64"])
|
|
133
132
|
@_add_math_1arg_docstr("cosine")
|
|
134
133
|
@core._tensor_member_fn
|
|
135
|
-
def cos(x,
|
|
136
|
-
x =
|
|
137
|
-
return core.tensor(
|
|
134
|
+
def cos(x, _semantic=None):
|
|
135
|
+
x = _semantic.to_tensor(x)
|
|
136
|
+
return core.tensor(_semantic.builder.create_cos(x.handle), x.type)
|
|
138
137
|
|
|
139
138
|
|
|
140
139
|
@core.builtin
|
|
141
140
|
@_check_dtype(dtypes=["fp32", "fp64"])
|
|
142
141
|
@_add_math_1arg_docstr("sine")
|
|
143
142
|
@core._tensor_member_fn
|
|
144
|
-
def sin(x,
|
|
145
|
-
x =
|
|
146
|
-
return core.tensor(
|
|
143
|
+
def sin(x, _semantic=None):
|
|
144
|
+
x = _semantic.to_tensor(x)
|
|
145
|
+
return core.tensor(_semantic.builder.create_sin(x.handle), x.type)
|
|
147
146
|
|
|
148
147
|
|
|
149
148
|
@core.builtin
|
|
150
149
|
@_check_dtype(dtypes=["fp32", "fp64"])
|
|
151
150
|
@_add_math_1arg_docstr("fast square root")
|
|
152
151
|
@core._tensor_member_fn
|
|
153
|
-
def sqrt(x,
|
|
154
|
-
x =
|
|
155
|
-
return core.tensor(
|
|
152
|
+
def sqrt(x, _semantic=None):
|
|
153
|
+
x = _semantic.to_tensor(x)
|
|
154
|
+
return core.tensor(_semantic.builder.create_sqrt(x.handle), x.type)
|
|
156
155
|
|
|
157
156
|
|
|
158
157
|
@core.builtin
|
|
159
158
|
@_check_dtype(dtypes=["fp32"])
|
|
160
159
|
@_add_math_1arg_docstr("precise square root (rounding to nearest wrt the IEEE standard)")
|
|
161
160
|
@core._tensor_member_fn
|
|
162
|
-
def sqrt_rn(x,
|
|
163
|
-
x =
|
|
164
|
-
return core.tensor(
|
|
161
|
+
def sqrt_rn(x, _semantic=None):
|
|
162
|
+
x = _semantic.to_tensor(x)
|
|
163
|
+
return core.tensor(_semantic.builder.create_precise_sqrt(x.handle), x.type)
|
|
165
164
|
|
|
166
165
|
|
|
167
166
|
@core.builtin
|
|
168
167
|
@_check_dtype(dtypes=["fp32", "fp64"])
|
|
169
168
|
@_add_math_1arg_docstr("inverse square root")
|
|
170
169
|
@core._tensor_member_fn
|
|
171
|
-
def rsqrt(x,
|
|
172
|
-
x =
|
|
173
|
-
return core.tensor(
|
|
170
|
+
def rsqrt(x, _semantic=None):
|
|
171
|
+
x = _semantic.to_tensor(x)
|
|
172
|
+
return core.tensor(_semantic.builder.create_rsqrt(x.handle), x.type)
|
|
174
173
|
|
|
175
174
|
|
|
176
175
|
@core._tensor_member_fn
|
|
177
176
|
@core.builtin
|
|
178
177
|
@_add_math_1arg_docstr("absolute value")
|
|
179
|
-
def abs(x,
|
|
180
|
-
x =
|
|
178
|
+
def abs(x, _semantic=None):
|
|
179
|
+
x = _semantic.to_tensor(x)
|
|
181
180
|
dtype = x.dtype
|
|
182
181
|
if dtype.is_fp8e4b15():
|
|
183
|
-
mask = core.full(x.shape, 0x7F, core.int8,
|
|
184
|
-
return core.tensor(
|
|
182
|
+
mask = core.full(x.shape, 0x7F, core.int8, _semantic=_semantic)
|
|
183
|
+
return core.tensor(_semantic.builder.create_and(x.handle, mask.handle), x.type)
|
|
185
184
|
elif dtype.is_floating():
|
|
186
|
-
return core.tensor(
|
|
185
|
+
return core.tensor(_semantic.builder.create_fabs(x.handle), x.type)
|
|
187
186
|
elif dtype.is_int_signed():
|
|
188
|
-
return core.tensor(
|
|
187
|
+
return core.tensor(_semantic.builder.create_iabs(x.handle), x.type)
|
|
189
188
|
elif dtype.is_int_unsigned():
|
|
190
189
|
return x # no-op
|
|
191
190
|
else:
|
|
@@ -194,57 +193,57 @@ def abs(x, _builder=None):
|
|
|
194
193
|
|
|
195
194
|
@core.builtin
|
|
196
195
|
@_add_math_2arg_docstr("fast division")
|
|
197
|
-
def fdiv(x, y, ieee_rounding=False,
|
|
198
|
-
ieee_rounding = core.
|
|
199
|
-
x =
|
|
200
|
-
y =
|
|
201
|
-
return
|
|
196
|
+
def fdiv(x, y, ieee_rounding=False, _semantic=None):
|
|
197
|
+
ieee_rounding = core._unwrap_if_constexpr(ieee_rounding)
|
|
198
|
+
x = _semantic.to_tensor(x)
|
|
199
|
+
y = _semantic.to_tensor(y)
|
|
200
|
+
return _semantic.fdiv(x, y, ieee_rounding)
|
|
202
201
|
|
|
203
202
|
|
|
204
203
|
@core.builtin
|
|
205
204
|
@_check_dtype(dtypes=["fp32"])
|
|
206
205
|
@_add_math_2arg_docstr("precise division (rounding to nearest wrt the IEEE standard)")
|
|
207
|
-
def div_rn(x, y,
|
|
208
|
-
x =
|
|
209
|
-
y =
|
|
210
|
-
x, y = core.binary_op_type_legalization(x, y,
|
|
211
|
-
return core.tensor(
|
|
206
|
+
def div_rn(x, y, _semantic=None):
|
|
207
|
+
x = _semantic.to_tensor(x)
|
|
208
|
+
y = _semantic.to_tensor(y)
|
|
209
|
+
x, y = core.binary_op_type_legalization(x, y, _semantic)
|
|
210
|
+
return core.tensor(_semantic.builder.create_precise_divf(x.handle, y.handle), x.type)
|
|
212
211
|
|
|
213
212
|
|
|
214
213
|
@core.builtin
|
|
215
214
|
@_check_dtype(dtypes=["fp32", "fp64"])
|
|
216
215
|
@_add_math_1arg_docstr("error function")
|
|
217
216
|
@core._tensor_member_fn
|
|
218
|
-
def erf(x,
|
|
219
|
-
x =
|
|
220
|
-
return core.tensor(
|
|
217
|
+
def erf(x, _semantic=None):
|
|
218
|
+
x = _semantic.to_tensor(x)
|
|
219
|
+
return core.tensor(_semantic.builder.create_erf(x.handle), x.type)
|
|
221
220
|
|
|
222
221
|
|
|
223
222
|
@core.builtin
|
|
224
223
|
@_check_dtype(dtypes=["fp32", "fp64"])
|
|
225
224
|
@_add_math_1arg_docstr("floor")
|
|
226
225
|
@core._tensor_member_fn
|
|
227
|
-
def floor(x,
|
|
228
|
-
x =
|
|
229
|
-
return core.tensor(
|
|
226
|
+
def floor(x, _semantic=None):
|
|
227
|
+
x = _semantic.to_tensor(x)
|
|
228
|
+
return core.tensor(_semantic.builder.create_floor(x.handle), x.type)
|
|
230
229
|
|
|
231
230
|
|
|
232
231
|
@core.builtin
|
|
233
232
|
@_check_dtype(dtypes=["fp32", "fp64"])
|
|
234
233
|
@_add_math_1arg_docstr("ceil")
|
|
235
234
|
@core._tensor_member_fn
|
|
236
|
-
def ceil(x,
|
|
237
|
-
x =
|
|
238
|
-
return core.tensor(
|
|
235
|
+
def ceil(x, _semantic=None):
|
|
236
|
+
x = _semantic.to_tensor(x)
|
|
237
|
+
return core.tensor(_semantic.builder.create_ceil(x.handle), x.type)
|
|
239
238
|
|
|
240
239
|
|
|
241
240
|
@core.builtin
|
|
242
241
|
@_add_math_3arg_docstr("fused multiply-add")
|
|
243
|
-
def fma(x, y, z,
|
|
244
|
-
x =
|
|
245
|
-
y =
|
|
246
|
-
z =
|
|
247
|
-
x, y = core.binary_op_type_legalization(x, y,
|
|
248
|
-
z, x = core.binary_op_type_legalization(z, x,
|
|
249
|
-
z, y = core.binary_op_type_legalization(z, y,
|
|
250
|
-
return core.tensor(
|
|
242
|
+
def fma(x, y, z, _semantic=None):
|
|
243
|
+
x = _semantic.to_tensor(x)
|
|
244
|
+
y = _semantic.to_tensor(y)
|
|
245
|
+
z = _semantic.to_tensor(z)
|
|
246
|
+
x, y = core.binary_op_type_legalization(x, y, _semantic)
|
|
247
|
+
z, x = core.binary_op_type_legalization(z, x, _semantic)
|
|
248
|
+
z, y = core.binary_op_type_legalization(z, y, _semantic)
|
|
249
|
+
return core.tensor(_semantic.builder.create_fma(x.handle, y.handle, z.handle), x.type)
|
triton/language/random.py
CHANGED
|
@@ -51,6 +51,7 @@ def philox(seed, c0, c1, c2, c3, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
|
|
51
51
|
c1 = tl.to_tensor(c1)
|
|
52
52
|
c2 = tl.to_tensor(c2)
|
|
53
53
|
c3 = tl.to_tensor(c3)
|
|
54
|
+
|
|
54
55
|
if tl.constexpr(c0.dtype.primitive_bitwidth) == 32:
|
|
55
56
|
int_dtype = tl.uint32
|
|
56
57
|
seed_hi = ((seed >> 32) & 0xffffffff).to(tl.uint32)
|
|
@@ -60,6 +61,7 @@ def philox(seed, c0, c1, c2, c3, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
|
|
60
61
|
int_dtype = tl.uint64
|
|
61
62
|
seed_hi = tl.full((1, ), 0, dtype=int_dtype)
|
|
62
63
|
seed_lo = seed
|
|
64
|
+
|
|
63
65
|
c0 = c0.to(int_dtype, bitcast=True)
|
|
64
66
|
c1 = c1.to(int_dtype, bitcast=True)
|
|
65
67
|
c2 = c2.to(int_dtype, bitcast=True)
|
|
@@ -96,8 +98,16 @@ def randint4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
|
|
|
96
98
|
:param offsets: The offsets to generate random numbers for.
|
|
97
99
|
"""
|
|
98
100
|
# _0 = tl.zeros(offset.shape, offset.dtype)
|
|
99
|
-
|
|
100
|
-
|
|
101
|
+
|
|
102
|
+
offset_lo = offset.to(tl.uint32)
|
|
103
|
+
_0 = offset_lo * 0
|
|
104
|
+
|
|
105
|
+
if tl.constexpr(offset.dtype.primitive_bitwidth) > 32:
|
|
106
|
+
offset_hi = (offset >> 32).to(tl.uint32)
|
|
107
|
+
else:
|
|
108
|
+
offset_hi = _0
|
|
109
|
+
|
|
110
|
+
return philox(seed, offset_lo, offset_hi, _0, _0, n_rounds)
|
|
101
111
|
|
|
102
112
|
|
|
103
113
|
# -------------------
|