triton-windows 3.3.1.post19__cp311-cp311-win_amd64.whl → 3.5.0.post21__cp311-cp311-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of triton-windows might be problematic.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +11 -2
- triton/_filecheck.py +97 -0
- triton/_internal_testing.py +95 -18
- triton/_utils.py +112 -21
- triton/backends/__init__.py +20 -23
- triton/backends/amd/__init__.py +0 -0
- triton/backends/amd/compiler.py +161 -119
- triton/backends/amd/driver.c +118 -46
- triton/backends/amd/driver.py +274 -96
- triton/backends/compiler.py +7 -21
- triton/backends/driver.py +13 -0
- triton/backends/nvidia/bin/ptxas.exe +0 -0
- triton/backends/nvidia/compiler.py +163 -106
- triton/backends/nvidia/driver.c +166 -101
- triton/backends/nvidia/driver.py +384 -202
- triton/compiler/__init__.py +5 -2
- triton/compiler/code_generator.py +439 -231
- triton/compiler/compiler.py +152 -84
- triton/experimental/__init__.py +0 -0
- triton/experimental/gluon/__init__.py +5 -0
- triton/experimental/gluon/_compiler.py +0 -0
- triton/experimental/gluon/_runtime.py +102 -0
- triton/experimental/gluon/language/__init__.py +119 -0
- triton/experimental/gluon/language/_core.py +490 -0
- triton/experimental/gluon/language/_layouts.py +583 -0
- triton/experimental/gluon/language/_math.py +20 -0
- triton/experimental/gluon/language/_semantic.py +380 -0
- triton/experimental/gluon/language/_standard.py +80 -0
- triton/experimental/gluon/language/amd/__init__.py +4 -0
- triton/experimental/gluon/language/amd/_layouts.py +96 -0
- triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
- triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
- triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
- triton/experimental/gluon/language/extra/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/__init__.py +4 -0
- triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
- triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
- triton/experimental/gluon/language/nvidia/blackwell/__init__.py +387 -0
- triton/experimental/gluon/language/nvidia/blackwell/tma.py +52 -0
- triton/experimental/gluon/language/nvidia/hopper/__init__.py +132 -0
- triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +34 -0
- triton/experimental/gluon/language/nvidia/hopper/tma.py +97 -0
- triton/experimental/gluon/nvidia/__init__.py +4 -0
- triton/experimental/gluon/nvidia/blackwell.py +3 -0
- triton/experimental/gluon/nvidia/hopper.py +45 -0
- triton/knobs.py +546 -0
- triton/language/__init__.py +50 -19
- triton/language/core.py +909 -572
- triton/language/extra/cuda/__init__.py +10 -7
- triton/language/extra/cuda/gdc.py +42 -0
- triton/language/extra/cuda/libdevice.py +394 -394
- triton/language/extra/cuda/utils.py +21 -21
- triton/language/extra/hip/__init__.py +3 -1
- triton/language/extra/hip/libdevice.py +120 -104
- triton/language/extra/hip/utils.py +35 -0
- triton/language/extra/libdevice.py +4 -0
- triton/language/math.py +65 -66
- triton/language/random.py +12 -2
- triton/language/semantic.py +1757 -1768
- triton/language/standard.py +127 -62
- triton/language/target_info.py +54 -0
- triton/runtime/_allocation.py +15 -3
- triton/runtime/_async_compile.py +55 -0
- triton/runtime/autotuner.py +117 -60
- triton/runtime/build.py +83 -17
- triton/runtime/cache.py +61 -47
- triton/runtime/driver.py +25 -47
- triton/runtime/interpreter.py +95 -50
- triton/runtime/jit.py +445 -248
- triton/runtime/tcc/include/_mingw.h +8 -10
- triton/runtime/tcc/include/assert.h +5 -0
- triton/runtime/tcc/include/errno.h +1 -1
- triton/runtime/tcc/include/float.h +21 -3
- triton/runtime/tcc/include/iso646.h +36 -0
- triton/runtime/tcc/include/limits.h +5 -0
- triton/runtime/tcc/include/malloc.h +2 -2
- triton/runtime/tcc/include/math.h +21 -261
- triton/runtime/tcc/include/stdalign.h +16 -0
- triton/runtime/tcc/include/stdarg.h +5 -70
- triton/runtime/tcc/include/stdatomic.h +171 -0
- triton/runtime/tcc/include/stddef.h +7 -19
- triton/runtime/tcc/include/stdlib.h +15 -4
- triton/runtime/tcc/include/stdnoreturn.h +7 -0
- triton/runtime/tcc/include/sys/stat.h +2 -2
- triton/runtime/tcc/include/sys/types.h +5 -0
- triton/runtime/tcc/include/tcc/tcc_libm.h +444 -27
- triton/runtime/tcc/include/tccdefs.h +342 -0
- triton/runtime/tcc/include/tgmath.h +89 -0
- triton/runtime/tcc/include/uchar.h +33 -0
- triton/runtime/tcc/include/unistd.h +1 -0
- triton/runtime/tcc/include/winapi/qos.h +72 -0
- triton/runtime/tcc/include/winapi/shellapi.h +59 -0
- triton/runtime/tcc/include/winapi/winbase.h +9 -2
- triton/runtime/tcc/include/winapi/wincon.h +8 -0
- triton/runtime/tcc/include/winapi/windows.h +1 -1
- triton/runtime/tcc/include/winapi/winnls.h +778 -0
- triton/runtime/tcc/include/winapi/winnt.h +9 -7
- triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
- triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
- triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
- triton/runtime/tcc/lib/libtcc1.a +0 -0
- triton/runtime/tcc/lib/python314.def +1800 -0
- triton/runtime/tcc/lib/python314t.def +1809 -0
- triton/runtime/tcc/libtcc.dll +0 -0
- triton/runtime/tcc/tcc.exe +0 -0
- triton/testing.py +16 -12
- triton/tools/compile.py +62 -14
- triton/tools/disasm.py +3 -4
- triton/tools/extra/cuda/compile.c +1 -0
- triton/tools/extra/hip/compile.cpp +66 -0
- triton/tools/extra/hip/compile.h +13 -0
- triton/tools/ragged_tma.py +92 -0
- triton/tools/tensor_descriptor.py +34 -0
- triton/windows_utils.py +52 -81
- {triton_windows-3.3.1.post19.dist-info → triton_windows-3.5.0.post21.dist-info}/METADATA +8 -4
- triton_windows-3.5.0.post21.dist-info/RECORD +217 -0
- triton_windows-3.5.0.post21.dist-info/entry_points.txt +3 -0
- triton_windows-3.5.0.post21.dist-info/licenses/LICENSE +23 -0
- triton_windows-3.5.0.post21.dist-info/top_level.txt +1 -0
- triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +0 -358
- triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +0 -1010
- triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +0 -1638
- triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +0 -1814
- triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +0 -293
- triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +0 -32
- triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +0 -174
- triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +0 -835
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +0 -1809
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +0 -1391
- triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +0 -108
- triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +0 -124
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +0 -405
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +0 -196
- triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +0 -565
- triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +0 -2226
- triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +0 -104
- triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +0 -244
- triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +0 -538
- triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +0 -288
- triton/backends/amd/include/hip/amd_detail/concepts.hpp +0 -30
- triton/backends/amd/include/hip/amd_detail/device_library_decls.h +0 -133
- triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +0 -218
- triton/backends/amd/include/hip/amd_detail/grid_launch.h +0 -67
- triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +0 -50
- triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +0 -26
- triton/backends/amd/include/hip/amd_detail/helpers.hpp +0 -137
- triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +0 -1446
- triton/backends/amd/include/hip/amd_detail/hip_assert.h +0 -101
- triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +0 -242
- triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +0 -254
- triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +0 -96
- triton/backends/amd/include/hip/amd_detail/hip_ldg.h +0 -100
- triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +0 -10570
- triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +0 -78
- triton/backends/amd/include/hip/amd_detail/host_defines.h +0 -184
- triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +0 -102
- triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +0 -798
- triton/backends/amd/include/hip/amd_detail/math_fwd.h +0 -698
- triton/backends/amd/include/hip/amd_detail/ockl_image.h +0 -177
- triton/backends/amd/include/hip/amd_detail/program_state.hpp +0 -107
- triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +0 -491
- triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +0 -478
- triton/backends/amd/include/hip/channel_descriptor.h +0 -39
- triton/backends/amd/include/hip/device_functions.h +0 -38
- triton/backends/amd/include/hip/driver_types.h +0 -468
- triton/backends/amd/include/hip/hip_bf16.h +0 -36
- triton/backends/amd/include/hip/hip_bfloat16.h +0 -44
- triton/backends/amd/include/hip/hip_common.h +0 -100
- triton/backends/amd/include/hip/hip_complex.h +0 -38
- triton/backends/amd/include/hip/hip_cooperative_groups.h +0 -46
- triton/backends/amd/include/hip/hip_deprecated.h +0 -95
- triton/backends/amd/include/hip/hip_ext.h +0 -161
- triton/backends/amd/include/hip/hip_fp16.h +0 -36
- triton/backends/amd/include/hip/hip_fp8.h +0 -33
- triton/backends/amd/include/hip/hip_gl_interop.h +0 -32
- triton/backends/amd/include/hip/hip_hcc.h +0 -24
- triton/backends/amd/include/hip/hip_math_constants.h +0 -36
- triton/backends/amd/include/hip/hip_profile.h +0 -27
- triton/backends/amd/include/hip/hip_runtime.h +0 -75
- triton/backends/amd/include/hip/hip_runtime_api.h +0 -9261
- triton/backends/amd/include/hip/hip_texture_types.h +0 -29
- triton/backends/amd/include/hip/hip_vector_types.h +0 -41
- triton/backends/amd/include/hip/hip_version.h +0 -17
- triton/backends/amd/include/hip/hiprtc.h +0 -421
- triton/backends/amd/include/hip/library_types.h +0 -78
- triton/backends/amd/include/hip/math_functions.h +0 -42
- triton/backends/amd/include/hip/surface_types.h +0 -63
- triton/backends/amd/include/hip/texture_types.h +0 -194
- triton/backends/amd/include/hsa/Brig.h +0 -1131
- triton/backends/amd/include/hsa/amd_hsa_common.h +0 -91
- triton/backends/amd/include/hsa/amd_hsa_elf.h +0 -462
- triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +0 -269
- triton/backends/amd/include/hsa/amd_hsa_queue.h +0 -109
- triton/backends/amd/include/hsa/amd_hsa_signal.h +0 -80
- triton/backends/amd/include/hsa/hsa.h +0 -5738
- triton/backends/amd/include/hsa/hsa_amd_tool.h +0 -91
- triton/backends/amd/include/hsa/hsa_api_trace.h +0 -579
- triton/backends/amd/include/hsa/hsa_api_trace_version.h +0 -68
- triton/backends/amd/include/hsa/hsa_ext_amd.h +0 -3146
- triton/backends/amd/include/hsa/hsa_ext_finalize.h +0 -531
- triton/backends/amd/include/hsa/hsa_ext_image.h +0 -1454
- triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +0 -488
- triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +0 -667
- triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +0 -416
- triton/backends/amd/include/roctracer/ext/prof_protocol.h +0 -107
- triton/backends/amd/include/roctracer/hip_ostream_ops.h +0 -4515
- triton/backends/amd/include/roctracer/hsa_ostream_ops.h +0 -1727
- triton/backends/amd/include/roctracer/hsa_prof_str.h +0 -3059
- triton/backends/amd/include/roctracer/roctracer.h +0 -779
- triton/backends/amd/include/roctracer/roctracer_ext.h +0 -81
- triton/backends/amd/include/roctracer/roctracer_hcc.h +0 -24
- triton/backends/amd/include/roctracer/roctracer_hip.h +0 -37
- triton/backends/amd/include/roctracer/roctracer_hsa.h +0 -112
- triton/backends/amd/include/roctracer/roctracer_plugin.h +0 -137
- triton/backends/amd/include/roctracer/roctracer_roctx.h +0 -67
- triton/backends/amd/include/roctracer/roctx.h +0 -229
- triton/language/_utils.py +0 -21
- triton/language/extra/cuda/_experimental_tma.py +0 -106
- triton/runtime/tcc/lib/libtcc1-64.a +0 -0
- triton/tools/experimental_descriptor.py +0 -32
- triton_windows-3.3.1.post19.dist-info/RECORD +0 -260
- triton_windows-3.3.1.post19.dist-info/top_level.txt +0 -14
- {triton_windows-3.3.1.post19.dist-info → triton_windows-3.5.0.post21.dist-info}/WHEEL +0 -0
triton/backends/nvidia/driver.c
CHANGED
@@ -9,10 +9,15 @@
 #endif
 
 #include <stdbool.h>
+#include <stdlib.h>
 #define PY_SSIZE_T_CLEAN
-#define Py_LIMITED_API 0x03090000
 #include <Python.h>
 
+typedef struct {
+  PyObject_HEAD
+  _Alignas(128) CUtensorMap tensorMap;
+} PyCUtensorMapObject;
+
 // Raises a Python exception and returns false if code is not CUDA_SUCCESS.
 static bool gpuAssert(CUresult code, const char *file, int line) {
   if (code == CUDA_SUCCESS)
@@ -35,7 +40,7 @@ static bool gpuAssert(CUresult code, const char *file, int line) {
 #define CUDA_CHECK_AND_RETURN_NULL(ans) \
   do { \
     if (!gpuAssert((ans), __FILE__, __LINE__)) \
-  [removed line not captured]
+      goto cleanup; \
   } while (0)
 
 // To be used inside a Py_{BEGIN,END}_ALLOW_THREADS block.
@@ -53,7 +58,7 @@ static bool gpuAssert(CUresult code, const char *file, int line) {
     if ((funcPointer) == NULL) { \
       (funcPointer) = (initializerFunction)(); \
       if ((funcPointer) == NULL) { \
-  [removed line not captured]
+        goto cleanup; \
       } \
     } \
   } while (0)
@@ -96,6 +101,9 @@ static PyObject *getDeviceProperties(PyObject *self, PyObject *args) {
                        warp_size, "sm_clock_rate", sm_clock_rate,
                        "mem_clock_rate", mem_clock_rate, "mem_bus_width",
                        mem_bus_width);
+
+cleanup:
+  return NULL;
 }
 
 static PyObject *loadBinary(PyObject *self, PyObject *args) {
@@ -112,6 +120,7 @@ static PyObject *loadBinary(PyObject *self, PyObject *args) {
   CUmodule mod;
   int32_t n_regs = 0;
   int32_t n_spills = 0;
+  int32_t n_max_threads = 0;
   // create driver handles
   CUcontext pctx = 0;
 
@@ -132,6 +141,8 @@
   CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
       cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun));
   n_spills /= 4;
+  CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuFuncGetAttribute(
+      &n_max_threads, CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK, fun));
   // set dynamic shared memory if necessary
   int shared_optin;
   CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuDeviceGetAttribute(
@@ -155,8 +166,8 @@
   if (PyErr_Occurred()) {
     return NULL;
   }
-  return Py_BuildValue("( […]
-      n_spills);
+  return Py_BuildValue("(KKiii)", (uint64_t)mod, (uint64_t)fun, n_regs,
+                       n_spills, n_max_threads);
 }
 
 typedef CUresult (*cuOccupancyMaxActiveClusters_t)(
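The loadBinary hunks above add a fifth item to the tuple handed back to Python: Py_BuildValue("(KKiii)", ...) now also packs the CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK query. A minimal caller-side sketch of the new return contract; the Python method name load_binary, the argument list, and the wrapper itself are assumptions for illustration, not taken from this diff:

```python
def load_kernel(cuda_utils, name, cubin, shared_mem_bytes, device):
    """Load a cubin through the cuda_utils extension built from driver.c.

    Only the return shape is taken from the diff above: 3.3.x returned
    (module, function, n_regs, n_spills); 3.5.x also returns n_max_threads.
    The argument order here is an assumption for illustration only.
    """
    module, function, n_regs, n_spills, n_max_threads = cuda_utils.load_binary(
        name, cubin, shared_mem_bytes, device)
    return {
        "module": module,                 # CUmodule handle as an integer
        "function": function,             # CUfunction handle as an integer
        "n_regs": n_regs,                 # registers per thread
        "n_spills": n_spills,             # spill slots (local size bytes / 4)
        "n_max_threads": n_max_threads,   # max threads per block for this kernel
    }
```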
@@ -266,6 +277,9 @@ static PyObject *occupancyMaxActiveClusters(PyObject *self, PyObject *args) {
       cuOccupancyMaxActiveClusters(&maxActiveClusters, func, &config));
   Py_END_ALLOW_THREADS;
   return PyLong_FromLong(maxActiveClusters);
+
+cleanup:
+  return NULL;
 }
 
 static PyObject *setPrintfFifoSize(PyObject *self, PyObject *args) {
@@ -304,116 +318,162 @@ static PyObject *setPrintfFifoSize(PyObject *self, PyObject *args) {
   }
 
   Py_END_ALLOW_THREADS;
-  [removed line not captured]
-  return Py_None;
+  Py_RETURN_NONE;
 }
 
-  [12 removed lines not captured]
-  uint64_t dims[1] = {dim};
-  uint64_t globalStrides[1] = {dim * elementSize};
-  uint32_t boxDim[1] = {tensorDim};
-  uint32_t elementStrides[1] = {1};
-  CUtensorMapDataType type;
-  switch (elementSize) {
-  case 1:
-    type = CU_TENSOR_MAP_DATA_TYPE_UINT8;
-    break;
-  case 2:
-    type = CU_TENSOR_MAP_DATA_TYPE_UINT16;
-    break;
-  case 4:
-    type = CU_TENSOR_MAP_DATA_TYPE_UINT32;
-    break;
-  default:
-    PyErr_SetString(PyExc_ValueError, "elementSize must be 1, 2, or 4");
+static PyObject *PyCUtensorMap_alloc(PyTypeObject *type, Py_ssize_t n_items) {
+  PyCUtensorMapObject *self = NULL;
+  void *mem = NULL;
+  size_t size = type->tp_basicsize;
+
+#ifdef _WIN32
+  mem = _aligned_malloc(size, 128);
+  if (mem == NULL) {
+#else
+  if (posix_memalign(&mem, 128, size) != 0) {
+#endif
+    PyErr_NoMemory();
     return NULL;
   }
-  [12 removed lines not captured]
+
+  self = (PyCUtensorMapObject *)mem;
+  PyObject_INIT(self, type);
+  return (PyObject *)self;
+}
+
+static void PyCUtensorMap_dealloc(PyObject *self) {
+  Py_TYPE(self)->tp_free(self);
+}
+
+static void PyCUtensorMap_free(void *ptr) {
+#ifdef _WIN32
+  _aligned_free(ptr);
+#else
+  free(ptr);
+#endif
 }
 
-  // […]
-  [2 removed lines not captured]
+// clang-format off
+static PyTypeObject PyCUtensorMapType = {
+    PyVarObject_HEAD_INIT(NULL, 0)
+    .tp_name = "triton.backends.nvidia.PyCUtensorMap",
+    .tp_basicsize = sizeof(PyCUtensorMapObject),
+    .tp_itemsize = 0,
+    .tp_flags = Py_TPFLAGS_DEFAULT,
+    .tp_doc = "<PyCUtensorMap object>",
+    .tp_new = PyType_GenericNew,
+    .tp_alloc = PyCUtensorMap_alloc,
+    .tp_dealloc = (destructor)PyCUtensorMap_dealloc,
+    .tp_free = PyCUtensorMap_free,
+};
+// clang-format on
+
+static PyObject *fillTMADescriptor(PyObject *self, PyObject *args) {
   unsigned long long global_address;
-  [2 removed lines not captured]
-  int […]
-  [4 removed lines not captured]
+  int swizzle;
+  int elemSize;
+  int elemType;
+  PyObject *blockSize;
+  PyObject *shape;
+  PyObject *strides;
+  int padding;
+
+  if (!PyArg_ParseTuple(args, "KiiiOOOi", &global_address, &swizzle, &elemSize,
+                        &elemType, &blockSize, &shape, &strides, &padding)) {
     return NULL;
   }
-  [16 removed lines not captured]
+
+  PyCUtensorMapObject *desc = (PyCUtensorMapObject *)PyObject_CallObject(
+      (PyObject *)&PyCUtensorMapType, NULL);
+  if (!desc) {
+    return NULL;
+  }
+
+  PyObject *blockSizeFast = NULL;
+  PyObject *shapeFast = NULL;
+  PyObject *stridesFast = NULL;
+
+  uint32_t blockSizeInt[5];
+  uint64_t shapeInt[5];
+  uint64_t stridesLL[5];
+
+  blockSizeFast = PySequence_Fast(blockSize, "blockSize must be a sequence");
+  if (!blockSizeFast)
+    goto cleanup;
+  int rank = PySequence_Fast_GET_SIZE(blockSizeFast);
+
+  for (int i = 0; i < rank; ++i) {
+    PyObject *item = PySequence_Fast_GET_ITEM(blockSizeFast, i);
+    if (!PyLong_Check(item)) {
+      PyErr_SetString(PyExc_TypeError, "block size must be an int");
+      goto cleanup;
+    }
+    blockSizeInt[rank - i - 1] = PyLong_AsLongLong(item);
+  }
+
+  shapeFast = PySequence_Fast(shape, "shape must be a sequence");
+  if (!shapeFast)
+    goto cleanup;
+
+  if (rank != PySequence_Fast_GET_SIZE(shapeFast)) {
+    PyErr_SetString(PyExc_RuntimeError, "Rank mismatch");
+    goto cleanup;
   }
-  int […]
-  [6 removed lines not captured]
-  } else if (contigDimSizeInByte >= 64) {
-    swizzle = CU_TENSOR_MAP_SWIZZLE_64B;
-  } else if (contigDimSizeInByte >= 32) {
-    swizzle = CU_TENSOR_MAP_SWIZZLE_32B;
-  } else {
-    assert(false && "block size too small.");
+  for (int i = 0; i < rank; ++i) {
+    PyObject *item = PySequence_Fast_GET_ITEM(shapeFast, i);
+    if (!PyLong_Check(item)) {
+      PyErr_SetString(PyExc_TypeError, "shape must be an int");
+      goto cleanup;
+    }
+    shapeInt[rank - i - 1] = PyLong_AsLong(item);
   }
-  [6 removed lines not captured]
+
+  stridesFast = PySequence_Fast(strides, "strides must be a sequence");
+  if (!stridesFast)
+    goto cleanup;
+
+  if (rank != PySequence_Fast_GET_SIZE(stridesFast)) {
+    PyErr_SetString(PyExc_RuntimeError, "Rank mismatch");
+    goto cleanup;
   }
+  for (int i = 0; i + 1 < rank; ++i) {
+    PyObject *item = PySequence_Fast_GET_ITEM(stridesFast, i);
+    if (!PyLong_Check(item)) {
+      PyErr_SetString(PyExc_TypeError, "shape must be an int");
+      goto cleanup;
+    }
+    stridesLL[rank - i - 2] = elemSize * PyLong_AsLongLong(item);
+  }
+  stridesLL[rank - 1] =
+      shapeInt[rank - 1] * (rank == 1 ? elemSize : stridesLL[rank - 2]);
+  Py_DECREF(blockSizeFast);
+  blockSizeFast = NULL;
+  Py_DECREF(shapeFast);
+  shapeFast = NULL;
+  Py_DECREF(stridesFast);
+  stridesFast = NULL;
+
+  CUtensorMapFloatOOBfill fill =
+      (padding == 1) ? CU_TENSOR_MAP_FLOAT_OOB_FILL_NAN_REQUEST_ZERO_FMA
+                     : CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE;
+
+  uint32_t elementStrides[5] = {1, 1, 1, 1, 1};
   static cuTensorMapEncodeTiled_t cuTensorMapEncodeTiled = NULL;
   INITIALIZE_FUNCTION_POINTER_IF_NULL(cuTensorMapEncodeTiled,
                                       getCuTensorMapEncodeTiledHandle);
   CUDA_CHECK_AND_RETURN_NULL(cuTensorMapEncodeTiled(
-      [2 removed lines not captured]
-      swizzle, CU_TENSOR_MAP_L2_PROMOTION_L2_128B,
-      [3 removed lines not captured]
+      &desc->tensorMap, elemType, rank, (void *)global_address, shapeInt,
+      stridesLL, blockSizeInt, elementStrides, CU_TENSOR_MAP_INTERLEAVE_NONE,
+      swizzle, CU_TENSOR_MAP_L2_PROMOTION_L2_128B, fill));
+
+  return (PyObject *)desc;
+
+cleanup:
+  Py_XDECREF(blockSizeFast);
+  Py_XDECREF(shapeFast);
+  Py_XDECREF(stridesFast);
+  Py_XDECREF(desc);
+  return NULL;
 }
 
 static PyMethodDef ModuleMethods[] = {
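The rewritten fillTMADescriptor above replaces the old fixed-rank helpers with one rank-generic entry point that returns a freshly allocated, 128-byte-aligned PyCUtensorMap instead of filling a caller-provided buffer. Its argument order follows PyArg_ParseTuple(args, "KiiiOOOi", ...): device pointer, swizzle, element size, element type, then the blockSize/shape/strides sequences and a padding flag. A hedged usage sketch; the wrapper name, default codes, and list conversions below are illustrative, not values Triton's Python layer actually passes:

```python
def make_tma_descriptor(cuda_utils, tensor_ptr, shape, strides, block_size, *,
                        swizzle=0, elem_size=2, elem_type=0, padding=0):
    """Call the new fill_tma_descriptor binding shown in the diff above.

    Positional order mirrors "KiiiOOOi": an unsigned 64-bit device pointer,
    three ints (swizzle, element size in bytes, CUtensorMapDataType code),
    three integer sequences, and an int padding flag (1 requests NaN OOB fill).
    The default codes here are placeholders, not values used by Triton itself.
    """
    desc = cuda_utils.fill_tma_descriptor(
        tensor_ptr, swizzle, elem_size, elem_type,
        list(block_size), list(shape), list(strides), padding)
    # desc is a PyCUtensorMap holding a 128-byte-aligned CUtensorMap; keep a
    # reference to it for as long as kernels launched with it may still run.
    return desc
```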
@@ -429,8 +489,7 @@ static PyMethodDef ModuleMethods[] = {
      "being dropped. This inherits all the limitations of this call; in "
      "particular it's an error to change this value after launching any kernel "
      "that calls printf()."},
-    {" […]
-    {"fill_2d_tma_descriptor", fill2DTMADescriptor, METH_VARARGS, "doc"},
+    {"fill_tma_descriptor", fillTMADescriptor, METH_VARARGS, "doc"},
 
     {NULL, NULL, 0, NULL} // sentinel
 };
@@ -441,12 +500,18 @@ static struct PyModuleDef ModuleDef = {PyModuleDef_HEAD_INIT, "cuda_utils",
                                        ModuleMethods};
 
 PyMODINIT_FUNC PyInit_cuda_utils(void) {
+  if (PyType_Ready(&PyCUtensorMapType) < 0) {
+    return NULL;
+  }
+
   PyObject *m = PyModule_Create(&ModuleDef);
   if (m == NULL) {
     return NULL;
   }
 
   PyModule_AddFunctions(m, ModuleMethods);
+  Py_INCREF(&PyCUtensorMapType);
+  PyModule_AddObject(m, "PyCUtensorMap", (PyObject *)&PyCUtensorMapType);
 
   return m;
 }
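Since PyInit_cuda_utils now readies the type and exposes it as the module attribute PyCUtensorMap, Python code can recognize descriptor objects by type. A small sketch that assumes only what the module-init hunk above adds; the helper name is invented:

```python
def is_tma_descriptor(cuda_utils, obj):
    """Return True if obj was produced by cuda_utils.fill_tma_descriptor.

    Relies only on the PyCUtensorMap attribute that PyInit_cuda_utils
    registers via PyModule_AddObject in the diff above.
    """
    return isinstance(obj, cuda_utils.PyCUtensorMap)
```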