triton-windows 3.3.1.post19__cp311-cp311-win_amd64.whl → 3.5.0.post21__cp311-cp311-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of triton-windows might be problematic. Click here for more details.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +11 -2
- triton/_filecheck.py +97 -0
- triton/_internal_testing.py +95 -18
- triton/_utils.py +112 -21
- triton/backends/__init__.py +20 -23
- triton/backends/amd/__init__.py +0 -0
- triton/backends/amd/compiler.py +161 -119
- triton/backends/amd/driver.c +118 -46
- triton/backends/amd/driver.py +274 -96
- triton/backends/compiler.py +7 -21
- triton/backends/driver.py +13 -0
- triton/backends/nvidia/bin/ptxas.exe +0 -0
- triton/backends/nvidia/compiler.py +163 -106
- triton/backends/nvidia/driver.c +166 -101
- triton/backends/nvidia/driver.py +384 -202
- triton/compiler/__init__.py +5 -2
- triton/compiler/code_generator.py +439 -231
- triton/compiler/compiler.py +152 -84
- triton/experimental/__init__.py +0 -0
- triton/experimental/gluon/__init__.py +5 -0
- triton/experimental/gluon/_compiler.py +0 -0
- triton/experimental/gluon/_runtime.py +102 -0
- triton/experimental/gluon/language/__init__.py +119 -0
- triton/experimental/gluon/language/_core.py +490 -0
- triton/experimental/gluon/language/_layouts.py +583 -0
- triton/experimental/gluon/language/_math.py +20 -0
- triton/experimental/gluon/language/_semantic.py +380 -0
- triton/experimental/gluon/language/_standard.py +80 -0
- triton/experimental/gluon/language/amd/__init__.py +4 -0
- triton/experimental/gluon/language/amd/_layouts.py +96 -0
- triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
- triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
- triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
- triton/experimental/gluon/language/extra/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/__init__.py +4 -0
- triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
- triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
- triton/experimental/gluon/language/nvidia/blackwell/__init__.py +387 -0
- triton/experimental/gluon/language/nvidia/blackwell/tma.py +52 -0
- triton/experimental/gluon/language/nvidia/hopper/__init__.py +132 -0
- triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +34 -0
- triton/experimental/gluon/language/nvidia/hopper/tma.py +97 -0
- triton/experimental/gluon/nvidia/__init__.py +4 -0
- triton/experimental/gluon/nvidia/blackwell.py +3 -0
- triton/experimental/gluon/nvidia/hopper.py +45 -0
- triton/knobs.py +546 -0
- triton/language/__init__.py +50 -19
- triton/language/core.py +909 -572
- triton/language/extra/cuda/__init__.py +10 -7
- triton/language/extra/cuda/gdc.py +42 -0
- triton/language/extra/cuda/libdevice.py +394 -394
- triton/language/extra/cuda/utils.py +21 -21
- triton/language/extra/hip/__init__.py +3 -1
- triton/language/extra/hip/libdevice.py +120 -104
- triton/language/extra/hip/utils.py +35 -0
- triton/language/extra/libdevice.py +4 -0
- triton/language/math.py +65 -66
- triton/language/random.py +12 -2
- triton/language/semantic.py +1757 -1768
- triton/language/standard.py +127 -62
- triton/language/target_info.py +54 -0
- triton/runtime/_allocation.py +15 -3
- triton/runtime/_async_compile.py +55 -0
- triton/runtime/autotuner.py +117 -60
- triton/runtime/build.py +83 -17
- triton/runtime/cache.py +61 -47
- triton/runtime/driver.py +25 -47
- triton/runtime/interpreter.py +95 -50
- triton/runtime/jit.py +445 -248
- triton/runtime/tcc/include/_mingw.h +8 -10
- triton/runtime/tcc/include/assert.h +5 -0
- triton/runtime/tcc/include/errno.h +1 -1
- triton/runtime/tcc/include/float.h +21 -3
- triton/runtime/tcc/include/iso646.h +36 -0
- triton/runtime/tcc/include/limits.h +5 -0
- triton/runtime/tcc/include/malloc.h +2 -2
- triton/runtime/tcc/include/math.h +21 -261
- triton/runtime/tcc/include/stdalign.h +16 -0
- triton/runtime/tcc/include/stdarg.h +5 -70
- triton/runtime/tcc/include/stdatomic.h +171 -0
- triton/runtime/tcc/include/stddef.h +7 -19
- triton/runtime/tcc/include/stdlib.h +15 -4
- triton/runtime/tcc/include/stdnoreturn.h +7 -0
- triton/runtime/tcc/include/sys/stat.h +2 -2
- triton/runtime/tcc/include/sys/types.h +5 -0
- triton/runtime/tcc/include/tcc/tcc_libm.h +444 -27
- triton/runtime/tcc/include/tccdefs.h +342 -0
- triton/runtime/tcc/include/tgmath.h +89 -0
- triton/runtime/tcc/include/uchar.h +33 -0
- triton/runtime/tcc/include/unistd.h +1 -0
- triton/runtime/tcc/include/winapi/qos.h +72 -0
- triton/runtime/tcc/include/winapi/shellapi.h +59 -0
- triton/runtime/tcc/include/winapi/winbase.h +9 -2
- triton/runtime/tcc/include/winapi/wincon.h +8 -0
- triton/runtime/tcc/include/winapi/windows.h +1 -1
- triton/runtime/tcc/include/winapi/winnls.h +778 -0
- triton/runtime/tcc/include/winapi/winnt.h +9 -7
- triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
- triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
- triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
- triton/runtime/tcc/lib/libtcc1.a +0 -0
- triton/runtime/tcc/lib/python314.def +1800 -0
- triton/runtime/tcc/lib/python314t.def +1809 -0
- triton/runtime/tcc/libtcc.dll +0 -0
- triton/runtime/tcc/tcc.exe +0 -0
- triton/testing.py +16 -12
- triton/tools/compile.py +62 -14
- triton/tools/disasm.py +3 -4
- triton/tools/extra/cuda/compile.c +1 -0
- triton/tools/extra/hip/compile.cpp +66 -0
- triton/tools/extra/hip/compile.h +13 -0
- triton/tools/ragged_tma.py +92 -0
- triton/tools/tensor_descriptor.py +34 -0
- triton/windows_utils.py +52 -81
- {triton_windows-3.3.1.post19.dist-info → triton_windows-3.5.0.post21.dist-info}/METADATA +8 -4
- triton_windows-3.5.0.post21.dist-info/RECORD +217 -0
- triton_windows-3.5.0.post21.dist-info/entry_points.txt +3 -0
- triton_windows-3.5.0.post21.dist-info/licenses/LICENSE +23 -0
- triton_windows-3.5.0.post21.dist-info/top_level.txt +1 -0
- triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +0 -358
- triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +0 -1010
- triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +0 -1638
- triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +0 -1814
- triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +0 -293
- triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +0 -32
- triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +0 -174
- triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +0 -835
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +0 -1809
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +0 -1391
- triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +0 -108
- triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +0 -124
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +0 -405
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +0 -196
- triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +0 -565
- triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +0 -2226
- triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +0 -104
- triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +0 -244
- triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +0 -538
- triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +0 -288
- triton/backends/amd/include/hip/amd_detail/concepts.hpp +0 -30
- triton/backends/amd/include/hip/amd_detail/device_library_decls.h +0 -133
- triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +0 -218
- triton/backends/amd/include/hip/amd_detail/grid_launch.h +0 -67
- triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +0 -50
- triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +0 -26
- triton/backends/amd/include/hip/amd_detail/helpers.hpp +0 -137
- triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +0 -1446
- triton/backends/amd/include/hip/amd_detail/hip_assert.h +0 -101
- triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +0 -242
- triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +0 -254
- triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +0 -96
- triton/backends/amd/include/hip/amd_detail/hip_ldg.h +0 -100
- triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +0 -10570
- triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +0 -78
- triton/backends/amd/include/hip/amd_detail/host_defines.h +0 -184
- triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +0 -102
- triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +0 -798
- triton/backends/amd/include/hip/amd_detail/math_fwd.h +0 -698
- triton/backends/amd/include/hip/amd_detail/ockl_image.h +0 -177
- triton/backends/amd/include/hip/amd_detail/program_state.hpp +0 -107
- triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +0 -491
- triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +0 -478
- triton/backends/amd/include/hip/channel_descriptor.h +0 -39
- triton/backends/amd/include/hip/device_functions.h +0 -38
- triton/backends/amd/include/hip/driver_types.h +0 -468
- triton/backends/amd/include/hip/hip_bf16.h +0 -36
- triton/backends/amd/include/hip/hip_bfloat16.h +0 -44
- triton/backends/amd/include/hip/hip_common.h +0 -100
- triton/backends/amd/include/hip/hip_complex.h +0 -38
- triton/backends/amd/include/hip/hip_cooperative_groups.h +0 -46
- triton/backends/amd/include/hip/hip_deprecated.h +0 -95
- triton/backends/amd/include/hip/hip_ext.h +0 -161
- triton/backends/amd/include/hip/hip_fp16.h +0 -36
- triton/backends/amd/include/hip/hip_fp8.h +0 -33
- triton/backends/amd/include/hip/hip_gl_interop.h +0 -32
- triton/backends/amd/include/hip/hip_hcc.h +0 -24
- triton/backends/amd/include/hip/hip_math_constants.h +0 -36
- triton/backends/amd/include/hip/hip_profile.h +0 -27
- triton/backends/amd/include/hip/hip_runtime.h +0 -75
- triton/backends/amd/include/hip/hip_runtime_api.h +0 -9261
- triton/backends/amd/include/hip/hip_texture_types.h +0 -29
- triton/backends/amd/include/hip/hip_vector_types.h +0 -41
- triton/backends/amd/include/hip/hip_version.h +0 -17
- triton/backends/amd/include/hip/hiprtc.h +0 -421
- triton/backends/amd/include/hip/library_types.h +0 -78
- triton/backends/amd/include/hip/math_functions.h +0 -42
- triton/backends/amd/include/hip/surface_types.h +0 -63
- triton/backends/amd/include/hip/texture_types.h +0 -194
- triton/backends/amd/include/hsa/Brig.h +0 -1131
- triton/backends/amd/include/hsa/amd_hsa_common.h +0 -91
- triton/backends/amd/include/hsa/amd_hsa_elf.h +0 -462
- triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +0 -269
- triton/backends/amd/include/hsa/amd_hsa_queue.h +0 -109
- triton/backends/amd/include/hsa/amd_hsa_signal.h +0 -80
- triton/backends/amd/include/hsa/hsa.h +0 -5738
- triton/backends/amd/include/hsa/hsa_amd_tool.h +0 -91
- triton/backends/amd/include/hsa/hsa_api_trace.h +0 -579
- triton/backends/amd/include/hsa/hsa_api_trace_version.h +0 -68
- triton/backends/amd/include/hsa/hsa_ext_amd.h +0 -3146
- triton/backends/amd/include/hsa/hsa_ext_finalize.h +0 -531
- triton/backends/amd/include/hsa/hsa_ext_image.h +0 -1454
- triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +0 -488
- triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +0 -667
- triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +0 -416
- triton/backends/amd/include/roctracer/ext/prof_protocol.h +0 -107
- triton/backends/amd/include/roctracer/hip_ostream_ops.h +0 -4515
- triton/backends/amd/include/roctracer/hsa_ostream_ops.h +0 -1727
- triton/backends/amd/include/roctracer/hsa_prof_str.h +0 -3059
- triton/backends/amd/include/roctracer/roctracer.h +0 -779
- triton/backends/amd/include/roctracer/roctracer_ext.h +0 -81
- triton/backends/amd/include/roctracer/roctracer_hcc.h +0 -24
- triton/backends/amd/include/roctracer/roctracer_hip.h +0 -37
- triton/backends/amd/include/roctracer/roctracer_hsa.h +0 -112
- triton/backends/amd/include/roctracer/roctracer_plugin.h +0 -137
- triton/backends/amd/include/roctracer/roctracer_roctx.h +0 -67
- triton/backends/amd/include/roctracer/roctx.h +0 -229
- triton/language/_utils.py +0 -21
- triton/language/extra/cuda/_experimental_tma.py +0 -106
- triton/runtime/tcc/lib/libtcc1-64.a +0 -0
- triton/tools/experimental_descriptor.py +0 -32
- triton_windows-3.3.1.post19.dist-info/RECORD +0 -260
- triton_windows-3.3.1.post19.dist-info/top_level.txt +0 -14
- {triton_windows-3.3.1.post19.dist-info → triton_windows-3.5.0.post21.dist-info}/WHEEL +0 -0
triton/language/__init__.py
CHANGED
|
@@ -6,6 +6,7 @@ from . import extra
|
|
|
6
6
|
from .standard import (
|
|
7
7
|
argmax,
|
|
8
8
|
argmin,
|
|
9
|
+
bitonic_merge,
|
|
9
10
|
cdiv,
|
|
10
11
|
cumprod,
|
|
11
12
|
cumsum,
|
|
@@ -14,11 +15,13 @@ from .standard import (
|
|
|
14
15
|
max,
|
|
15
16
|
min,
|
|
16
17
|
ravel,
|
|
18
|
+
reduce_or,
|
|
17
19
|
sigmoid,
|
|
18
20
|
softmax,
|
|
19
21
|
sort,
|
|
20
22
|
sum,
|
|
21
23
|
swizzle2d,
|
|
24
|
+
topk,
|
|
22
25
|
xor_sum,
|
|
23
26
|
zeros,
|
|
24
27
|
zeros_like,
|
|
@@ -26,16 +29,17 @@ from .standard import (
|
|
|
26
29
|
from .core import (
|
|
27
30
|
PropagateNan,
|
|
28
31
|
TRITON_MAX_TENSOR_NUMEL,
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
32
|
+
load_tensor_descriptor,
|
|
33
|
+
store_tensor_descriptor,
|
|
34
|
+
make_tensor_descriptor,
|
|
35
|
+
tensor_descriptor,
|
|
36
|
+
tensor_descriptor_type,
|
|
34
37
|
add,
|
|
35
38
|
advance,
|
|
36
39
|
arange,
|
|
37
40
|
associative_scan,
|
|
38
41
|
assume,
|
|
42
|
+
async_task,
|
|
39
43
|
atomic_add,
|
|
40
44
|
atomic_and,
|
|
41
45
|
atomic_cas,
|
|
@@ -51,8 +55,10 @@ from .core import (
|
|
|
51
55
|
cat,
|
|
52
56
|
cast,
|
|
53
57
|
clamp,
|
|
58
|
+
condition,
|
|
54
59
|
const,
|
|
55
60
|
constexpr,
|
|
61
|
+
constexpr_type,
|
|
56
62
|
debug_barrier,
|
|
57
63
|
device_assert,
|
|
58
64
|
device_print,
|
|
@@ -80,6 +86,7 @@ from .core import (
|
|
|
80
86
|
join,
|
|
81
87
|
load,
|
|
82
88
|
make_block_ptr,
|
|
89
|
+
map_elementwise,
|
|
83
90
|
max_constancy,
|
|
84
91
|
max_contiguous,
|
|
85
92
|
maximum,
|
|
@@ -89,7 +96,6 @@ from .core import (
|
|
|
89
96
|
permute,
|
|
90
97
|
pi32_t,
|
|
91
98
|
pointer_type,
|
|
92
|
-
nv_tma_desc_type,
|
|
93
99
|
program_id,
|
|
94
100
|
range,
|
|
95
101
|
reduce,
|
|
@@ -126,15 +132,15 @@ from .random import (
|
|
|
126
132
|
randn4x,
|
|
127
133
|
uint_to_uniform_float,
|
|
128
134
|
)
|
|
135
|
+
from . import target_info
|
|
129
136
|
|
|
130
137
|
__all__ = [
|
|
131
138
|
"PropagateNan",
|
|
132
139
|
"TRITON_MAX_TENSOR_NUMEL",
|
|
133
|
-
"
|
|
134
|
-
"
|
|
135
|
-
"
|
|
136
|
-
"
|
|
137
|
-
"_experimental_tensor_descriptor",
|
|
140
|
+
"load_tensor_descriptor",
|
|
141
|
+
"store_tensor_descriptor",
|
|
142
|
+
"make_tensor_descriptor",
|
|
143
|
+
"tensor_descriptor",
|
|
138
144
|
"abs",
|
|
139
145
|
"add",
|
|
140
146
|
"advance",
|
|
@@ -143,6 +149,7 @@ __all__ = [
|
|
|
143
149
|
"argmin",
|
|
144
150
|
"associative_scan",
|
|
145
151
|
"assume",
|
|
152
|
+
"async_task",
|
|
146
153
|
"atomic_add",
|
|
147
154
|
"atomic_and",
|
|
148
155
|
"atomic_cas",
|
|
@@ -152,6 +159,7 @@ __all__ = [
|
|
|
152
159
|
"atomic_xchg",
|
|
153
160
|
"atomic_xor",
|
|
154
161
|
"bfloat16",
|
|
162
|
+
"bitonic_merge",
|
|
155
163
|
"block_type",
|
|
156
164
|
"broadcast",
|
|
157
165
|
"broadcast_to",
|
|
@@ -160,8 +168,10 @@ __all__ = [
|
|
|
160
168
|
"cdiv",
|
|
161
169
|
"ceil",
|
|
162
170
|
"clamp",
|
|
171
|
+
"condition",
|
|
163
172
|
"const",
|
|
164
173
|
"constexpr",
|
|
174
|
+
"constexpr_type",
|
|
165
175
|
"cos",
|
|
166
176
|
"cumprod",
|
|
167
177
|
"cumsum",
|
|
@@ -204,6 +214,7 @@ __all__ = [
|
|
|
204
214
|
"log",
|
|
205
215
|
"log2",
|
|
206
216
|
"make_block_ptr",
|
|
217
|
+
"map_elementwise",
|
|
207
218
|
"math",
|
|
208
219
|
"max",
|
|
209
220
|
"max_constancy",
|
|
@@ -219,7 +230,6 @@ __all__ = [
|
|
|
219
230
|
"philox_impl",
|
|
220
231
|
"pi32_t",
|
|
221
232
|
"pointer_type",
|
|
222
|
-
"nv_tma_desc_type",
|
|
223
233
|
"program_id",
|
|
224
234
|
"rand",
|
|
225
235
|
"rand4x",
|
|
@@ -230,6 +240,7 @@ __all__ = [
|
|
|
230
240
|
"range",
|
|
231
241
|
"ravel",
|
|
232
242
|
"reduce",
|
|
243
|
+
"reduce_or",
|
|
233
244
|
"reshape",
|
|
234
245
|
"rsqrt",
|
|
235
246
|
"slice",
|
|
@@ -246,7 +257,9 @@ __all__ = [
|
|
|
246
257
|
"store",
|
|
247
258
|
"sum",
|
|
248
259
|
"swizzle2d",
|
|
260
|
+
"target_info",
|
|
249
261
|
"tensor",
|
|
262
|
+
"topk",
|
|
250
263
|
"trans",
|
|
251
264
|
"tuple",
|
|
252
265
|
"uint16",
|
|
@@ -264,12 +277,12 @@ __all__ = [
|
|
|
264
277
|
]
|
|
265
278
|
|
|
266
279
|
|
|
267
|
-
def str_to_ty(name):
|
|
280
|
+
def str_to_ty(name, c):
|
|
268
281
|
from builtins import tuple
|
|
269
282
|
|
|
270
283
|
if isinstance(name, tuple):
|
|
271
284
|
fields = type(name).__dict__.get("_fields", None)
|
|
272
|
-
return tuple_type([str_to_ty(x) for x in name], fields)
|
|
285
|
+
return tuple_type([str_to_ty(x, c) for x in name], fields)
|
|
273
286
|
|
|
274
287
|
if name[0] == "*":
|
|
275
288
|
name = name[1:]
|
|
@@ -277,14 +290,32 @@ def str_to_ty(name):
|
|
|
277
290
|
if name[0] == "k":
|
|
278
291
|
name = name[1:]
|
|
279
292
|
const = True
|
|
280
|
-
ty = str_to_ty(name)
|
|
293
|
+
ty = str_to_ty(name, c)
|
|
281
294
|
return pointer_type(element_ty=ty, const=const)
|
|
282
295
|
|
|
283
|
-
if name
|
|
284
|
-
|
|
296
|
+
if name.startswith("tensordesc"):
|
|
297
|
+
inner = name.split("<")[1].rstrip(">")
|
|
298
|
+
dtype, rest = inner.split("[", maxsplit=1)
|
|
299
|
+
block_shape, rest = rest.split("]", maxsplit=1)
|
|
300
|
+
block_shape = [int(s.strip()) for s in block_shape.rstrip("]").split(",")]
|
|
301
|
+
layout = rest.lstrip(",")
|
|
302
|
+
is_gluon = len(layout)
|
|
303
|
+
dtype = str_to_ty(dtype, None)
|
|
304
|
+
ndim = len(block_shape)
|
|
305
|
+
shape_type = tuple_type([int32] * ndim)
|
|
306
|
+
# FIXME: Last dim stride should be constexpr(1)
|
|
307
|
+
stride_type = tuple_type(([int64] * ndim))
|
|
308
|
+
block = block_type(dtype, block_shape)
|
|
309
|
+
if is_gluon:
|
|
310
|
+
from triton.experimental.gluon.language._layouts import NVMMASharedLayout
|
|
311
|
+
from triton.experimental.gluon.language.nvidia.hopper.tma import tensor_descriptor_type as gluon_tensor_descriptor_type
|
|
312
|
+
layout = eval(layout, dict(NVMMASharedLayout=NVMMASharedLayout))
|
|
313
|
+
assert isinstance(layout, NVMMASharedLayout)
|
|
314
|
+
return gluon_tensor_descriptor_type(block, shape_type, stride_type, layout)
|
|
315
|
+
return tensor_descriptor_type(block, shape_type, stride_type)
|
|
285
316
|
|
|
286
|
-
if name
|
|
287
|
-
return
|
|
317
|
+
if name.startswith("constexpr"):
|
|
318
|
+
return constexpr_type(c)
|
|
288
319
|
|
|
289
320
|
tys = {
|
|
290
321
|
"fp8e4nv": float8e4nv,
|