triton-windows 3.3.1.post19__cp310-cp310-win_amd64.whl → 3.5.0.post21__cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of triton-windows might be problematic.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +11 -2
- triton/_filecheck.py +97 -0
- triton/_internal_testing.py +95 -18
- triton/_utils.py +112 -21
- triton/backends/__init__.py +20 -23
- triton/backends/amd/__init__.py +0 -0
- triton/backends/amd/compiler.py +161 -119
- triton/backends/amd/driver.c +118 -46
- triton/backends/amd/driver.py +274 -96
- triton/backends/compiler.py +7 -21
- triton/backends/driver.py +13 -0
- triton/backends/nvidia/bin/ptxas.exe +0 -0
- triton/backends/nvidia/compiler.py +163 -106
- triton/backends/nvidia/driver.c +166 -101
- triton/backends/nvidia/driver.py +384 -202
- triton/compiler/__init__.py +5 -2
- triton/compiler/code_generator.py +439 -231
- triton/compiler/compiler.py +152 -84
- triton/experimental/__init__.py +0 -0
- triton/experimental/gluon/__init__.py +5 -0
- triton/experimental/gluon/_compiler.py +0 -0
- triton/experimental/gluon/_runtime.py +102 -0
- triton/experimental/gluon/language/__init__.py +119 -0
- triton/experimental/gluon/language/_core.py +490 -0
- triton/experimental/gluon/language/_layouts.py +583 -0
- triton/experimental/gluon/language/_math.py +20 -0
- triton/experimental/gluon/language/_semantic.py +380 -0
- triton/experimental/gluon/language/_standard.py +80 -0
- triton/experimental/gluon/language/amd/__init__.py +4 -0
- triton/experimental/gluon/language/amd/_layouts.py +96 -0
- triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
- triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
- triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
- triton/experimental/gluon/language/extra/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/__init__.py +4 -0
- triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
- triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
- triton/experimental/gluon/language/nvidia/blackwell/__init__.py +387 -0
- triton/experimental/gluon/language/nvidia/blackwell/tma.py +52 -0
- triton/experimental/gluon/language/nvidia/hopper/__init__.py +132 -0
- triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +34 -0
- triton/experimental/gluon/language/nvidia/hopper/tma.py +97 -0
- triton/experimental/gluon/nvidia/__init__.py +4 -0
- triton/experimental/gluon/nvidia/blackwell.py +3 -0
- triton/experimental/gluon/nvidia/hopper.py +45 -0
- triton/knobs.py +546 -0
- triton/language/__init__.py +50 -19
- triton/language/core.py +909 -572
- triton/language/extra/cuda/__init__.py +10 -7
- triton/language/extra/cuda/gdc.py +42 -0
- triton/language/extra/cuda/libdevice.py +394 -394
- triton/language/extra/cuda/utils.py +21 -21
- triton/language/extra/hip/__init__.py +3 -1
- triton/language/extra/hip/libdevice.py +120 -104
- triton/language/extra/hip/utils.py +35 -0
- triton/language/extra/libdevice.py +4 -0
- triton/language/math.py +65 -66
- triton/language/random.py +12 -2
- triton/language/semantic.py +1757 -1768
- triton/language/standard.py +127 -62
- triton/language/target_info.py +54 -0
- triton/runtime/_allocation.py +15 -3
- triton/runtime/_async_compile.py +55 -0
- triton/runtime/autotuner.py +117 -60
- triton/runtime/build.py +83 -17
- triton/runtime/cache.py +61 -47
- triton/runtime/driver.py +25 -47
- triton/runtime/interpreter.py +95 -50
- triton/runtime/jit.py +445 -248
- triton/runtime/tcc/include/_mingw.h +8 -10
- triton/runtime/tcc/include/assert.h +5 -0
- triton/runtime/tcc/include/errno.h +1 -1
- triton/runtime/tcc/include/float.h +21 -3
- triton/runtime/tcc/include/iso646.h +36 -0
- triton/runtime/tcc/include/limits.h +5 -0
- triton/runtime/tcc/include/malloc.h +2 -2
- triton/runtime/tcc/include/math.h +21 -261
- triton/runtime/tcc/include/stdalign.h +16 -0
- triton/runtime/tcc/include/stdarg.h +5 -70
- triton/runtime/tcc/include/stdatomic.h +171 -0
- triton/runtime/tcc/include/stddef.h +7 -19
- triton/runtime/tcc/include/stdlib.h +15 -4
- triton/runtime/tcc/include/stdnoreturn.h +7 -0
- triton/runtime/tcc/include/sys/stat.h +2 -2
- triton/runtime/tcc/include/sys/types.h +5 -0
- triton/runtime/tcc/include/tcc/tcc_libm.h +444 -27
- triton/runtime/tcc/include/tccdefs.h +342 -0
- triton/runtime/tcc/include/tgmath.h +89 -0
- triton/runtime/tcc/include/uchar.h +33 -0
- triton/runtime/tcc/include/unistd.h +1 -0
- triton/runtime/tcc/include/winapi/qos.h +72 -0
- triton/runtime/tcc/include/winapi/shellapi.h +59 -0
- triton/runtime/tcc/include/winapi/winbase.h +9 -2
- triton/runtime/tcc/include/winapi/wincon.h +8 -0
- triton/runtime/tcc/include/winapi/windows.h +1 -1
- triton/runtime/tcc/include/winapi/winnls.h +778 -0
- triton/runtime/tcc/include/winapi/winnt.h +9 -7
- triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
- triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
- triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
- triton/runtime/tcc/lib/libtcc1.a +0 -0
- triton/runtime/tcc/lib/python314.def +1800 -0
- triton/runtime/tcc/lib/python314t.def +1809 -0
- triton/runtime/tcc/libtcc.dll +0 -0
- triton/runtime/tcc/tcc.exe +0 -0
- triton/testing.py +16 -12
- triton/tools/compile.py +62 -14
- triton/tools/disasm.py +3 -4
- triton/tools/extra/cuda/compile.c +1 -0
- triton/tools/extra/hip/compile.cpp +66 -0
- triton/tools/extra/hip/compile.h +13 -0
- triton/tools/ragged_tma.py +92 -0
- triton/tools/tensor_descriptor.py +34 -0
- triton/windows_utils.py +52 -81
- {triton_windows-3.3.1.post19.dist-info → triton_windows-3.5.0.post21.dist-info}/METADATA +8 -4
- triton_windows-3.5.0.post21.dist-info/RECORD +217 -0
- triton_windows-3.5.0.post21.dist-info/entry_points.txt +3 -0
- triton_windows-3.5.0.post21.dist-info/licenses/LICENSE +23 -0
- triton_windows-3.5.0.post21.dist-info/top_level.txt +1 -0
- triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +0 -358
- triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +0 -1010
- triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +0 -1638
- triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +0 -1814
- triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +0 -293
- triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +0 -32
- triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +0 -174
- triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +0 -835
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +0 -1809
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +0 -1391
- triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +0 -108
- triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +0 -124
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +0 -405
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +0 -196
- triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +0 -565
- triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +0 -2226
- triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +0 -104
- triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +0 -244
- triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +0 -538
- triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +0 -288
- triton/backends/amd/include/hip/amd_detail/concepts.hpp +0 -30
- triton/backends/amd/include/hip/amd_detail/device_library_decls.h +0 -133
- triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +0 -218
- triton/backends/amd/include/hip/amd_detail/grid_launch.h +0 -67
- triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +0 -50
- triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +0 -26
- triton/backends/amd/include/hip/amd_detail/helpers.hpp +0 -137
- triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +0 -1446
- triton/backends/amd/include/hip/amd_detail/hip_assert.h +0 -101
- triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +0 -242
- triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +0 -254
- triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +0 -96
- triton/backends/amd/include/hip/amd_detail/hip_ldg.h +0 -100
- triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +0 -10570
- triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +0 -78
- triton/backends/amd/include/hip/amd_detail/host_defines.h +0 -184
- triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +0 -102
- triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +0 -798
- triton/backends/amd/include/hip/amd_detail/math_fwd.h +0 -698
- triton/backends/amd/include/hip/amd_detail/ockl_image.h +0 -177
- triton/backends/amd/include/hip/amd_detail/program_state.hpp +0 -107
- triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +0 -491
- triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +0 -478
- triton/backends/amd/include/hip/channel_descriptor.h +0 -39
- triton/backends/amd/include/hip/device_functions.h +0 -38
- triton/backends/amd/include/hip/driver_types.h +0 -468
- triton/backends/amd/include/hip/hip_bf16.h +0 -36
- triton/backends/amd/include/hip/hip_bfloat16.h +0 -44
- triton/backends/amd/include/hip/hip_common.h +0 -100
- triton/backends/amd/include/hip/hip_complex.h +0 -38
- triton/backends/amd/include/hip/hip_cooperative_groups.h +0 -46
- triton/backends/amd/include/hip/hip_deprecated.h +0 -95
- triton/backends/amd/include/hip/hip_ext.h +0 -161
- triton/backends/amd/include/hip/hip_fp16.h +0 -36
- triton/backends/amd/include/hip/hip_fp8.h +0 -33
- triton/backends/amd/include/hip/hip_gl_interop.h +0 -32
- triton/backends/amd/include/hip/hip_hcc.h +0 -24
- triton/backends/amd/include/hip/hip_math_constants.h +0 -36
- triton/backends/amd/include/hip/hip_profile.h +0 -27
- triton/backends/amd/include/hip/hip_runtime.h +0 -75
- triton/backends/amd/include/hip/hip_runtime_api.h +0 -9261
- triton/backends/amd/include/hip/hip_texture_types.h +0 -29
- triton/backends/amd/include/hip/hip_vector_types.h +0 -41
- triton/backends/amd/include/hip/hip_version.h +0 -17
- triton/backends/amd/include/hip/hiprtc.h +0 -421
- triton/backends/amd/include/hip/library_types.h +0 -78
- triton/backends/amd/include/hip/math_functions.h +0 -42
- triton/backends/amd/include/hip/surface_types.h +0 -63
- triton/backends/amd/include/hip/texture_types.h +0 -194
- triton/backends/amd/include/hsa/Brig.h +0 -1131
- triton/backends/amd/include/hsa/amd_hsa_common.h +0 -91
- triton/backends/amd/include/hsa/amd_hsa_elf.h +0 -462
- triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +0 -269
- triton/backends/amd/include/hsa/amd_hsa_queue.h +0 -109
- triton/backends/amd/include/hsa/amd_hsa_signal.h +0 -80
- triton/backends/amd/include/hsa/hsa.h +0 -5738
- triton/backends/amd/include/hsa/hsa_amd_tool.h +0 -91
- triton/backends/amd/include/hsa/hsa_api_trace.h +0 -579
- triton/backends/amd/include/hsa/hsa_api_trace_version.h +0 -68
- triton/backends/amd/include/hsa/hsa_ext_amd.h +0 -3146
- triton/backends/amd/include/hsa/hsa_ext_finalize.h +0 -531
- triton/backends/amd/include/hsa/hsa_ext_image.h +0 -1454
- triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +0 -488
- triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +0 -667
- triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +0 -416
- triton/backends/amd/include/roctracer/ext/prof_protocol.h +0 -107
- triton/backends/amd/include/roctracer/hip_ostream_ops.h +0 -4515
- triton/backends/amd/include/roctracer/hsa_ostream_ops.h +0 -1727
- triton/backends/amd/include/roctracer/hsa_prof_str.h +0 -3059
- triton/backends/amd/include/roctracer/roctracer.h +0 -779
- triton/backends/amd/include/roctracer/roctracer_ext.h +0 -81
- triton/backends/amd/include/roctracer/roctracer_hcc.h +0 -24
- triton/backends/amd/include/roctracer/roctracer_hip.h +0 -37
- triton/backends/amd/include/roctracer/roctracer_hsa.h +0 -112
- triton/backends/amd/include/roctracer/roctracer_plugin.h +0 -137
- triton/backends/amd/include/roctracer/roctracer_roctx.h +0 -67
- triton/backends/amd/include/roctracer/roctx.h +0 -229
- triton/language/_utils.py +0 -21
- triton/language/extra/cuda/_experimental_tma.py +0 -106
- triton/runtime/tcc/lib/libtcc1-64.a +0 -0
- triton/tools/experimental_descriptor.py +0 -32
- triton_windows-3.3.1.post19.dist-info/RECORD +0 -260
- triton_windows-3.3.1.post19.dist-info/top_level.txt +0 -14
- {triton_windows-3.3.1.post19.dist-info → triton_windows-3.5.0.post21.dist-info}/WHEEL +0 -0
triton/language/standard.py
CHANGED
@@ -1,24 +1,25 @@
 from __future__ import annotations
 
-from ..runtime.jit import jit
+from ..runtime.jit import jit, constexpr_function
 from . import core
 from . import math
 
 # constexpr utilities
 
 
-
+@constexpr_function
+def _log2(i):
     log2 = 0
-    n = i
+    n = i
     while n > 1:
         n >>= 1
         log2 += 1
-    return
+    return log2
 
 
-
-
-    return
+@constexpr_function
+def _is_power_of_two(i):
+    return (i & (i - 1)) == 0 and i != 0
 
 
 # -----------------------
@@ -50,10 +51,14 @@ def sigmoid(x):
 @core._tensor_member_fn
 @jit
 @math._add_math_1arg_docstr("softmax")
-def softmax(x, ieee_rounding=False):
-
+def softmax(x, dim=None, keep_dims=False, ieee_rounding=False):
+    if dim is None:
+        _dim: core.constexpr = 0
+    else:
+        _dim: core.constexpr = dim
+    z = x - max(x, _dim, keep_dims=keep_dims)
     num = math.exp(z)
-    den = sum(num,
+    den = sum(num, _dim, keep_dims=keep_dims)
     return math.fdiv(num, den, ieee_rounding)
 
 
@@ -259,8 +264,8 @@ def _sum_combine(a, b):
 # sum
 
 
-
-
+@constexpr_function
+def _pick_sum_dtype(in_dtype, dtype):
     if dtype is not None:
         return dtype
 
@@ -302,15 +307,37 @@ def xor_sum(input, axis=None, keep_dims=False):
     return core.reduce(input, axis, _xor_combine, keep_dims=keep_dims)
 
 
+# or reduction
+
+
+@jit
+def _or_combine(x, y):
+    return x | y
+
+
+@core._tensor_member_fn
+@jit
+@core._add_reduction_docstr("reduce_or")
+def reduce_or(input, axis, keep_dims=False):
+    core.static_assert(input.type.scalar.is_int(), "reduce_or only supported for integers")
+    return core.reduce(input, axis, _or_combine, keep_dims=keep_dims)
+
+
 # cumsum
 
 
 @core._tensor_member_fn
 @jit
-@core._add_scan_docstr("cumsum")
-def cumsum(input, axis=0, reverse=False):
+@core._add_scan_docstr("cumsum", dtype_arg="dtype")
+def cumsum(input, axis=0, reverse=False, dtype: core.constexpr = None):
     # todo rename this to a generic function name
+
     input = core._promote_bfloat16_to_float32(input)
+    out_dtype: core.constexpr = _pick_sum_dtype(input.dtype, dtype)
+
+    if out_dtype is not None:
+        input = input.to(out_dtype)
+
     return core.associative_scan(input, axis, _sum_combine, reverse)
 
 
@@ -335,53 +362,63 @@ def cumprod(input, axis=0, reverse=False):
 
 
 @jit
-def
-
-
-
-
-
-
-
-
-
-
+def _indicator(n_dims: core.constexpr, j: core.constexpr):
+    ar = core.arange(0, 2)
+    ar = core.reshape(ar, [1] * (n_dims - j - 1) + [2] + [1] * j)
+    return ar
+
+
+@jit
+def _compare_and_swap(x, flip, i: core.constexpr):
+    # compare-and-swap on the ith *innermost* dimension
+    n_dims: core.constexpr = _log2(x.numel)
+
+    # flip along middle dimension (the bitwise XORs will be optimised away):
     idtype = core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True)
-    ileft = left.to(idtype, bitcast=True)
-    iright = right.to(idtype, bitcast=True)
     ix = x.to(idtype, bitcast=True)
-
-
+    iy = ix ^ xor_sum(ix, n_dims - 1 - i, True)
+    y = iy.to(x.dtype, bitcast=True)
+
+    # determines whether we are in the right (rather than left) position along the axis:
+    is_right = _indicator(n_dims, i)
+
+    # conditional swap:
+    ret = core.where((x > y) != (flip ^ is_right), y, x)
+    return ret
 
 
 @jit
-def
+def _bitonic_merge_hypercube(x, stage: core.constexpr, order: core.constexpr):
     '''
     order_type 0 == ascending
     order_type 1 == descending
     order_type 2 == alternating
     '''
-    n_outer: core.constexpr = x.numel >> n_dims
-    core.static_assert(stage <= n_dims)
     # flip denotes whether to re-arrange sub-sequences of elements in ascending or
     # descending order.
     # if flip = 00000000... then all elements will be re-arranged ascendingly at this stage
     # if flip = 00110011... then all the elements will be re-arranged alternatingly (with
     # a stride of 2) at this stage
     if order == 2:
-
-        flip = core.reshape(core.broadcast_to(core.arange(0, 2)[None, :, None], shape), x.shape)
+        flip = _indicator(_log2(x.numel), stage)
     else:
         flip = order
     # perform `stage` rounds of `compare-and-swap`
     for i in core.static_range(stage):
-        x = _compare_and_swap(x, flip,
+        x = _compare_and_swap(x, flip, stage - 1 - i)
     return x
 
 
-@core._tensor_member_fn
 @jit
-def
+def _bitonic_merge(x, stage: core.constexpr, order: core.constexpr, n_dims: core.constexpr):
+    h = core.reshape(x, [2] * _log2(x.numel))
+    h = _bitonic_merge_hypercube(h, stage, order)
+    x = core.reshape(h, x.shape)
+    return x
+
+
+@jit
+def sort_impl(x, k: core.constexpr = None, dim: core.constexpr = None, descending: core.constexpr = core.CONSTEXPR_0):
     """
     Sorts a tensor along a specified dimension.
 
@@ -389,29 +426,64 @@ def sort(x, dim: core.constexpr = None, descending: core.constexpr = core.CONSTE
     :type x: Tensor
     :param dim: The dimension along which to sort the tensor. If None, the tensor is sorted along the last dimension. Currently, only sorting along the last dimension is supported.
     :type dim: int, optional
+    :param k: the number of top elements to select. If none, assume k = x.shape[dim]
+    :type k: int, optional
    :param descending: If set to True, the tensor is sorted in descending order. If set to False, the tensor is sorted in ascending order.
     :type descending: bool, optional
     """
     # handle default dimension or check that it is the most minor dim
     _dim: core.constexpr = len(x.shape) - 1 if dim is None else dim
     core.static_assert(_dim == len(x.shape) - 1, "only minor dimension is currently supported")
-
-
-
-
+
+    log_n: core.constexpr = _log2(x.shape[_dim])
+    log_k: core.constexpr = log_n if k is None else _log2(k)
+
+    n_dims: core.constexpr = _log2(x.numel)
+
+    # reshape to hypercube:
+    h = core.reshape(x, [2] * n_dims)
+
+    # run first log_k bitonic sort iterations:
+    for i in core.static_range(1, log_k + 1):
+        h = _bitonic_merge_hypercube(h, i, 2 if i < log_n else descending)
+
+    # select top k elements using bitonic top-k
+    # https://www.doc.ic.ac.uk/~hlgr/pdfs/MassivelyParallelTopK.pdf
+    for i in core.static_range(log_k + 1, log_n + 1):
+        h = max(h, axis=(_log2(h.numel) - 1 - log_k)) if descending else min(h, axis=(_log2(h.numel) - 1 - log_k))
+        h = _bitonic_merge_hypercube(h, log_k, 2 if i < log_n else descending)
+
+    # reshape back:
+    x = core.reshape(h, x.shape[:-1] + [2**log_k])
     return x
 
 
-
+@jit
+def sort(x, dim: core.constexpr = None, descending: core.constexpr = core.CONSTEXPR_0):
+    return sort_impl(x, dim=dim, descending=descending)
+
+
+@jit
+def topk(x, k: core.constexpr, dim: core.constexpr = None):
+    return sort_impl(x, k=k, dim=dim, descending=True)
+
+
+@jit
+def bitonic_merge(x, dim: core.constexpr = None, descending: core.constexpr = core.CONSTEXPR_0):
+    # handle default dimension or check that it is the most minor dim
+    _dim: core.constexpr = len(x.shape) - 1 if dim is None else dim
+    core.static_assert(_dim == len(x.shape) - 1, "only minor dimension is currently supported")
+    n_dims: core.constexpr = _log2(x.shape[-1])
+    return _bitonic_merge(x, n_dims, descending, n_dims)
 
 
+@constexpr_function
 def _get_flip_dim(dim, shape):
-    dim = core._unwrap_if_constexpr(dim)
-    shape = core._unwrap_if_constexpr(shape)
     if dim is None:
         dim = len(shape) - 1
-
-
+    if dim < 0:  # flip doesn't work if dim < 0 because the xor-swap for loop will start/end at the wrong index
+        dim += len(shape)
+    return dim
 
 
 @core._tensor_member_fn
@@ -422,26 +494,19 @@ def flip(x, dim=None):
 
     :param x: the first input tensor
     :type x: Block
-    :param dim: the dimension to flip along
+    :param dim: the dimension to flip along
     :type dim: int
     """
-    core.static_assert(
-    core.
-
-
-    steps: core.constexpr = _log2(x.numel)
-    start: core.constexpr = _log2(x.numel) - _log2(x.shape[_get_flip_dim(dim, x.shape)])
+    core.static_assert(-len(x.shape) <= dim and dim < len(x.shape))
+    _dim: core.constexpr = _get_flip_dim(dim, x.shape)
+    core.static_assert(_is_power_of_two(x.shape[_dim]))
+    steps: core.constexpr = _log2(x.shape[_dim])
 
+    # reshape the swap dimension to (2, 2, ..., 2)
     idtype = core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True)
-    y = core.reshape(x.to(idtype, bitcast=True), [2] * steps)
-
-
-    for i in core.static_range(start, steps):
-        flip2 = flip
-        for j in core.static_range(0, steps + 1):
-            if j != i and j != i + 1:
-                flip2 = core.expand_dims(flip2, j)
-        y = sum(y * flip2, i + 1, keep_dims=True, dtype=y.dtype)
+    y = core.reshape(x.to(idtype, bitcast=True), x.shape[:_dim] + [2] * steps + x.shape[_dim + 1:])
+    for i in core.static_range(steps):
+        y = y ^ xor_sum(y, _dim + i, True)
     x = core.reshape(y, x.shape).to(x.dtype, bitcast=True)
     return x
 
triton/language/target_info.py
ADDED
@@ -0,0 +1,54 @@
+from triton.runtime import driver
+from triton.runtime.jit import constexpr_function
+
+__all__ = ["current_target"]
+
+
+def current_target():
+    try:
+        active_driver = driver.active
+    except RuntimeError:
+        # If there is no active driver, return None
+        return None
+    return active_driver.get_current_target()
+
+
+current_target.__triton_builtin__ = True
+
+
+@constexpr_function
+def is_cuda():
+    target = current_target()
+    return target is not None and target.backend == "cuda"
+
+
+@constexpr_function
+def cuda_capability_geq(major, minor=0):
+    """
+    Determines whether we have compute capability >= (major, minor) and
+    returns this as a constexpr boolean. This can be used for guarding
+    inline asm implementations that require a certain compute capability.
+    """
+    target = current_target()
+    if target is None or target.backend != "cuda":
+        return False
+    assert isinstance(target.arch, int)
+    return target.arch >= major * 10 + minor
+
+
+@constexpr_function
+def is_hip():
+    target = current_target()
+    return target is not None and target.backend == "hip"
+
+
+@constexpr_function
+def is_hip_cdna3():
+    target = current_target()
+    return target is not None and target.arch == "gfx942"
+
+
+@constexpr_function
+def is_hip_cdna4():
+    target = current_target()
+    return target is not None and target.arch == "gfx950"
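The target_info helpers above are constexpr_functions, so they fold to compile-time constants inside a @triton.jit kernel. A minimal sketch of guarding an architecture-specific path with them; the kernel and its element-wise body are illustrative, not taken from the diff:

import triton
import triton.language as tl
from triton.language import target_info


@triton.jit
def scale_kernel(x_ptr, y_ptr, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    x = tl.load(x_ptr + offs)
    # cuda_capability_geq() evaluates at compile time, so this branch is
    # folded away when compiling for pre-Hopper targets
    if target_info.cuda_capability_geq(9, 0):
        x = x * 2
    tl.store(y_ptr + offs, x)

The sort/topk additions in triton/language/standard.py above can be exercised along the following lines; this assumes the helpers are re-exported as tl.sort and tl.topk in 3.5.0, and the row length N and selection size K (both powers of two) are illustrative:

@triton.jit
def topk_rows_kernel(x_ptr, out_ptr, N: tl.constexpr, K: tl.constexpr):
    # one program per row; N and K must be powers of two for the bitonic network
    row = tl.program_id(0)
    x = tl.load(x_ptr + row * N + tl.arange(0, N))
    top = tl.topk(x, K)  # the K largest values of the row
    tl.store(out_ptr + row * K + tl.arange(0, K), top)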
triton/runtime/_allocation.py
CHANGED
@@ -1,4 +1,5 @@
 from typing import Optional, Protocol
+from contextvars import ContextVar
 
 
 class Buffer(Protocol):
@@ -20,7 +21,7 @@ class NullAllocator:
 "Use triton.set_allocator to specify an allocator.")
 
 
-_allocator: Allocator = NullAllocator()
+_allocator: ContextVar[Allocator] = ContextVar("_allocator", default=NullAllocator())
 
 
 def set_allocator(allocator: Allocator):
@@ -28,5 +29,16 @@ def set_allocator(allocator: Allocator):
     The allocator function is called during kernel launch for kernels that
     require additional global memory workspace.
     """
-
-
+    _allocator.set(allocator)
+
+
+_profile_allocator: Allocator = ContextVar("_allocator", default=NullAllocator())
+
+
+def set_profile_allocator(allocator: Optional[Allocator]):
+    """
+    The profile allocator function is called before kernel launch for kernels
+    that require additional global memory workspace.
+    """
+    global _profile_allocator
+    _profile_allocator.set(allocator)
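With the allocator now held in a ContextVar, registration still goes through the public triton.set_allocator entry point. A minimal sketch using a PyTorch tensor as the workspace buffer; the (size, alignment, stream) signature follows the Allocator protocol defined earlier in this file (not shown in the hunks), torch tensors expose data_ptr() and so satisfy the Buffer protocol, and the CUDA device is an assumption:

import torch
import triton


def workspace_alloc(size: int, alignment: int, stream):
    # the returned object only needs to expose data_ptr() (Buffer protocol)
    return torch.empty(size, dtype=torch.int8, device="cuda")


triton.set_allocator(workspace_alloc)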
triton/runtime/_async_compile.py
ADDED
@@ -0,0 +1,55 @@
+from __future__ import annotations
+from typing import Callable, Optional
+from concurrent.futures import Executor, as_completed, Future
+from contextvars import ContextVar
+
+active_mode: ContextVar[Optional[AsyncCompileMode]] = ContextVar("async_compile_active_mode", default=None)
+
+
+class FutureKernel:
+
+    def __init__(self, finalize_compile: Callable, future: Future):
+        self.finalize_compile = finalize_compile
+        self.kernel = None
+        self.future = future
+
+    def result(self):
+        if self.kernel is not None:
+            return self.kernel
+
+        kernel = self.future.result()
+        self.finalize_compile(kernel)
+        self.kernel = kernel
+        return kernel
+
+
+class AsyncCompileMode:
+
+    def __init__(self, executor: Executor):
+        self.executor = executor
+        self.raw_futures = []
+        self.future_kernels = {}
+
+    def submit(self, key, compile_fn, finalize_fn):
+        future = self.future_kernels.get(key)
+        if future is not None:
+            return future
+
+        future = self.executor.submit(compile_fn)
+        future._key = key
+        self.raw_futures.append(future)
+        future_kernel = FutureKernel(finalize_fn, future)
+        self.future_kernels[key] = future_kernel
+        return future_kernel
+
+    def __enter__(self):
+        if active_mode.get() is not None:
+            raise RuntimeError("Another AsyncCompileMode is already active")
+        active_mode.set(self)
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        # Finalize any outstanding compiles
+        for future in as_completed(self.raw_futures):
+            self.future_kernels[future._key].result()
+        active_mode.set(None)
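AsyncCompileMode records itself in a context variable and hands compile jobs to the supplied Executor, finalizing any outstanding futures on exit. A hedged sketch of driving it from a thread pool; the module is private, and whether kernel launches inside the block actually defer their compilation depends on the runtime/jit.py changes listed above rather than on this file alone:

from concurrent.futures import ThreadPoolExecutor

from triton.runtime._async_compile import AsyncCompileMode

with ThreadPoolExecutor(max_workers=4) as pool:
    with AsyncCompileMode(pool):
        # kernels launched here may have their compilation submitted to `pool`;
        # __exit__ waits for every outstanding compile and finalizes it
        ...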