triton-windows 3.4.0.post20__cp311-cp311-win_amd64.whl → 3.5.0.post21__cp311-cp311-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of triton-windows might be problematic.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +8 -2
- triton/_filecheck.py +24 -14
- triton/_internal_testing.py +70 -4
- triton/_utils.py +3 -1
- triton/backends/amd/compiler.py +68 -60
- triton/backends/amd/driver.c +113 -44
- triton/backends/amd/driver.py +133 -57
- triton/backends/driver.py +13 -0
- triton/backends/nvidia/compiler.py +80 -22
- triton/backends/nvidia/driver.c +88 -15
- triton/backends/nvidia/driver.py +130 -123
- triton/compiler/__init__.py +5 -2
- triton/compiler/code_generator.py +270 -163
- triton/compiler/compiler.py +45 -62
- triton/experimental/gluon/__init__.py +3 -2
- triton/experimental/gluon/_runtime.py +9 -6
- triton/experimental/gluon/language/__init__.py +117 -16
- triton/experimental/gluon/language/_core.py +246 -68
- triton/experimental/gluon/language/_layouts.py +398 -45
- triton/experimental/gluon/language/_math.py +17 -9
- triton/experimental/gluon/language/_semantic.py +130 -37
- triton/experimental/gluon/language/_standard.py +55 -22
- triton/experimental/gluon/language/amd/__init__.py +4 -0
- triton/experimental/gluon/language/amd/_layouts.py +96 -0
- triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
- triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
- triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
- triton/experimental/gluon/language/extra/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
- triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
- triton/experimental/gluon/language/nvidia/blackwell/__init__.py +192 -7
- triton/experimental/gluon/language/nvidia/blackwell/tma.py +20 -0
- triton/experimental/gluon/language/nvidia/hopper/__init__.py +124 -3
- triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +20 -37
- triton/experimental/gluon/language/nvidia/hopper/tma.py +4 -3
- triton/experimental/gluon/nvidia/hopper.py +6 -1
- triton/knobs.py +132 -67
- triton/language/__init__.py +16 -10
- triton/language/core.py +163 -83
- triton/language/extra/cuda/gdc.py +6 -6
- triton/language/extra/hip/__init__.py +3 -1
- triton/language/extra/hip/libdevice.py +7 -0
- triton/language/extra/hip/utils.py +35 -0
- triton/language/extra/libdevice.py +4 -0
- triton/language/semantic.py +76 -23
- triton/language/standard.py +14 -14
- triton/language/target_info.py +54 -0
- triton/runtime/_allocation.py +15 -3
- triton/runtime/_async_compile.py +55 -0
- triton/runtime/autotuner.py +4 -5
- triton/runtime/build.py +11 -9
- triton/runtime/cache.py +44 -1
- triton/runtime/driver.py +16 -41
- triton/runtime/interpreter.py +31 -23
- triton/runtime/jit.py +318 -157
- triton/runtime/tcc/include/_mingw.h +8 -10
- triton/runtime/tcc/include/assert.h +5 -0
- triton/runtime/tcc/include/errno.h +1 -1
- triton/runtime/tcc/include/float.h +21 -3
- triton/runtime/tcc/include/iso646.h +36 -0
- triton/runtime/tcc/include/limits.h +5 -0
- triton/runtime/tcc/include/malloc.h +2 -2
- triton/runtime/tcc/include/math.h +21 -261
- triton/runtime/tcc/include/stdalign.h +16 -0
- triton/runtime/tcc/include/stdarg.h +5 -70
- triton/runtime/tcc/include/stdatomic.h +171 -0
- triton/runtime/tcc/include/stddef.h +7 -19
- triton/runtime/tcc/include/stdlib.h +15 -4
- triton/runtime/tcc/include/stdnoreturn.h +7 -0
- triton/runtime/tcc/include/sys/stat.h +2 -2
- triton/runtime/tcc/include/sys/types.h +5 -0
- triton/runtime/tcc/include/tcc/tcc_libm.h +444 -27
- triton/runtime/tcc/include/tccdefs.h +342 -0
- triton/runtime/tcc/include/tgmath.h +89 -0
- triton/runtime/tcc/include/uchar.h +33 -0
- triton/runtime/tcc/include/unistd.h +1 -0
- triton/runtime/tcc/include/winapi/qos.h +72 -0
- triton/runtime/tcc/include/winapi/shellapi.h +59 -0
- triton/runtime/tcc/include/winapi/winbase.h +9 -2
- triton/runtime/tcc/include/winapi/wincon.h +8 -0
- triton/runtime/tcc/include/winapi/windows.h +1 -1
- triton/runtime/tcc/include/winapi/winnls.h +778 -0
- triton/runtime/tcc/include/winapi/winnt.h +9 -7
- triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
- triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
- triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
- triton/runtime/tcc/lib/libtcc1.a +0 -0
- triton/runtime/tcc/lib/python314.def +1800 -0
- triton/runtime/tcc/lib/python314t.def +1809 -0
- triton/runtime/tcc/libtcc.dll +0 -0
- triton/runtime/tcc/tcc.exe +0 -0
- triton/tools/compile.py +62 -14
- triton/tools/extra/cuda/compile.c +1 -0
- triton/tools/extra/hip/compile.cpp +66 -0
- triton/tools/extra/hip/compile.h +13 -0
- triton/tools/ragged_tma.py +92 -0
- triton/tools/tensor_descriptor.py +7 -9
- triton/windows_utils.py +42 -79
- {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/METADATA +3 -4
- {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/RECORD +106 -75
- triton/runtime/tcc/lib/libtcc1-64.a +0 -0
- {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/WHEEL +0 -0
- {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/entry_points.txt +0 -0
- {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/licenses/LICENSE +0 -0
- {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/top_level.txt +0 -0
triton/experimental/gluon/language/_core.py (+246 -68):

@@ -1,4 +1,5 @@
 from __future__ import annotations
+import math
 from typing import TypeVar, List, TYPE_CHECKING, Tuple
 from functools import wraps
 
@@ -37,38 +38,17 @@ from triton.language.core import (
     float64,
     _unwrap_if_constexpr,
     _unwrap_shape,
+    static_range,
     tensor,
     tuple,
     tuple_type,
 )
 
-
-
-    "join",
-    "load",
-    "maximum",
-    "minimum",
-    "permute",
-    "program_id",
-    "reduce",
-    "reshape",
-    "split",
-    "static_assert",
-    "static_print",
-    "store",
-    "to_tensor",
-    "where",
-    "inline_asm_elementwise",
-]
-
+# We define __all__ only to appease the python linter, these are not used in
+# this file but we want to import them anyway so they are importable from here.
 __all__ = [
     "constexpr",
-    "base_value",
-    "base_type",
-    "dtype",
-    "block_type",
     "pointer_type",
-    "tuple_type",
     "void",
     "int1",
     "int8",
@@ -83,24 +63,14 @@ __all__ = [
     "float8e5b16",
     "float8e4nv",
     "float8e4b8",
-    "float8e4b8",
     "float8e4b15",
     "float16",
     "bfloat16",
     "float32",
     "float64",
-    "
-    "tensor",
+    "static_range",
     "tuple",
     "tuple_type",
-    "thread_barrier",
-    "arange",
-    "full",
-    "convert_layout",
-    "allocate_shared_memory",
-    "shared_memory_descriptor",
-    "warp_specialize",
-    *_IMPORT_FROM_TRITON,
 ]
 
 T = TypeVar("T")
@@ -109,6 +79,57 @@
 GLUON_BUILTIN = "__triton_builtin__"
 
 
+def builtin(fn: T) -> T:
+    """Mark a function as a builtin."""
+    assert callable(fn)
+
+    @wraps(fn)
+    def wrapper(*args, **kwargs):
+        if "_semantic" not in kwargs or kwargs["_semantic"] is None:
+            raise ValueError("Did you forget to add @triton.gluon.jit ? "
+                             "(`_semantic` argument must be provided outside of JIT functions.)")
+        return fn(*args, **kwargs)
+
+    setattr(wrapper, GLUON_BUILTIN, True)
+
+    return wrapper
+
+
+# Explicitly import forwarded Triton language symbols so mypy sees them.
+associative_scan = builtin(tl_core.associative_scan)
+atomic_add = builtin(tl_core.atomic_add)
+atomic_and = builtin(tl_core.atomic_and)
+atomic_cas = builtin(tl_core.atomic_cas)
+atomic_max = builtin(tl_core.atomic_max)
+atomic_min = builtin(tl_core.atomic_min)
+atomic_or = builtin(tl_core.atomic_or)
+atomic_xchg = builtin(tl_core.atomic_xchg)
+atomic_xor = builtin(tl_core.atomic_xor)
+broadcast = builtin(tl_core.broadcast)
+device_assert = builtin(tl_core.device_assert)
+expand_dims = builtin(tl_core.expand_dims)
+inline_asm_elementwise = builtin(tl_core.inline_asm_elementwise)
+join = builtin(tl_core.join)
+load = builtin(tl_core.load)
+map_elementwise = builtin(tl_core.map_elementwise)
+max_constancy = builtin(tl_core.max_constancy)
+max_contiguous = builtin(tl_core.max_contiguous)
+maximum = builtin(tl_core.maximum)
+minimum = builtin(tl_core.minimum)
+multiple_of = builtin(tl_core.multiple_of)
+num_programs = builtin(tl_core.num_programs)
+permute = builtin(tl_core.permute)
+program_id = builtin(tl_core.program_id)
+reduce = builtin(tl_core.reduce)
+reshape = builtin(tl_core.reshape)
+split = builtin(tl_core.split)
+static_assert = builtin(tl_core.static_assert)
+static_print = builtin(tl_core.static_print)
+store = builtin(tl_core.store)
+to_tensor = builtin(tl_core.to_tensor)
+where = builtin(tl_core.where)
+
+
 class distributed_type(block_type):
 
     def __init__(self, element_ty: dtype, shape: List[int], layout):
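
The forwarding assignments above replace the old `for name in _IMPORT_FROM_TRITON` loop (its removal appears further down), so every re-exported builtin now passes through the same `_semantic` guard. A minimal host-side sketch of what that guard means for callers; it imports `_core` directly since that is the module patched here, and assumes nothing beyond the installed package:

# Calling a forwarded Gluon builtin outside a JIT-compiled function provides
# no _semantic object, so the builtin() wrapper added above rejects the call.
from triton.experimental.gluon.language import _core as gl_core

try:
    gl_core.program_id(0)
except ValueError as err:
    print(err)  # "Did you forget to add @triton.gluon.jit ? ..."
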
@@ -131,21 +152,10 @@ class distributed_type(block_type):
     def with_element_ty(self, scalar_ty: dtype) -> block_type:
         return distributed_type(scalar_ty, self.shape, self.layout)
 
-
-
-
-
-
-    @wraps(fn)
-    def wrapper(*args, **kwargs):
-        if "_semantic" not in kwargs or kwargs["_semantic"] is None:
-            raise ValueError("Did you forget to add @triton.gluon.jit ? "
-                             "(`_semantic` argument must be provided outside of JIT functions.)")
-        return fn(*args, **kwargs)
-
-    setattr(wrapper, GLUON_BUILTIN, True)
-
-    return wrapper
+    def __eq__(self, other) -> bool:
+        if not isinstance(other, distributed_type):
+            return False
+        return super().__eq__(other) and self.layout == other.layout
 
 
 class shared_memory_descriptor_type(base_type):
@@ -188,6 +198,9 @@ class shared_memory_descriptor_type(base_type):
 
 
 class shared_memory_descriptor(base_value):
+    """
+    Represents a handle to a shared memory allocation in Gluon IR.
+    """
 
     def __init__(self, handle, element_ty, shape, layout, alloc_shape):
         self.handle = handle
@@ -208,6 +221,10 @@ class shared_memory_descriptor(base_value):
     def rank(self):
         return len(self.shape)
 
+    @property
+    def numel(self) -> int:
+        return math.prod(self.shape)
+
     @property
     def layout(self):
         return self.type.layout
@@ -216,16 +233,42 @@ class shared_memory_descriptor(base_value):
         return str(self.type)
 
     @builtin
-    def load(self, layout, _semantic: GluonSemantic) -> tensor:
+    def load(self, layout, _semantic: GluonSemantic = None) -> tensor:
+        """
+        Load a tensor from shared memory.
+
+        Args:
+            layout (DistributedLayout): The destination layout of the tensor.
+
+        Returns:
+            tensor: A Gluon tensor containing the loaded data.
+        """
         layout = _unwrap_if_constexpr(layout)
         return _semantic.shared_load(self, layout)
 
     @builtin
-    def store(self, value, _semantic: GluonSemantic) -> None:
+    def store(self, value, _semantic: GluonSemantic = None) -> None:
+        """
+        Store a tensor into shared memory.
+
+        Args:
+            value (tensor): The tensor whose contents to store.
+        """
         return _semantic.shared_store(self, value)
 
     @builtin
     def slice(self, start, length, dim=0, _semantic: GluonSemantic = None) -> shared_memory_descriptor:
+        """
+        Create a subview of shared memory by slicing along a given dimension.
+
+        Args:
+            start (int): The starting index of the slice.
+            length (int): The length of the slice.
+            dim (int): The dimension to slice (default: 0).
+
+        Returns:
+            shared_memory_descriptor: Descriptor for the sliced subview.
+        """
         start = _unwrap_if_constexpr(start)
         length = _unwrap_if_constexpr(length)
         dim = _unwrap_if_constexpr(dim)
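
A sketch of how the newly documented descriptor methods compose inside a kernel. This is not taken from the diff: `gluon.jit` and the idea of passing layouts in as `constexpr` parameters are assumptions (concrete layout classes live in `_layouts.py`, also changed in this release); only `arange`, `load`, `store`, `allocate_shared_memory`, and the descriptor's `store`/`load` come from the code above.

from triton.experimental import gluon
from triton.experimental.gluon import language as gl


@gluon.jit
def copy_via_smem(in_ptr, out_ptr, BLOCK: gl.constexpr,
                  blocked: gl.constexpr,       # a distributed (register) layout, built on the host
                  smem_layout: gl.constexpr):  # a shared-memory layout, built on the host
    offs = gl.arange(0, BLOCK, layout=blocked)
    x = gl.load(in_ptr + offs)

    # Stage the tile through shared memory with the descriptor API above.
    smem = gl.allocate_shared_memory(x.dtype, [BLOCK], smem_layout)
    smem.store(x)           # store(value): write the tensor into shared memory
    y = smem.load(blocked)  # load(layout): read it back with an explicit distributed layout
    gl.store(out_ptr + offs, y)
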
@@ -233,23 +276,60 @@ class shared_memory_descriptor(base_value):
 
     @builtin
     def index(self, index, _semantic: GluonSemantic = None) -> shared_memory_descriptor:
+        """
+        Create a subview of shared memory by indexing along the first dimension.
+
+        Args:
+            index (int): The index at which to take the subview.
+
+        Returns:
+            shared_memory_descriptor: Descriptor for the indexed subview.
+        """
         index = _unwrap_if_constexpr(index)
         return _semantic.memdesc_index(self, index)
 
     @builtin
-    def permute(self, order, _semantic: GluonSemantic) -> shared_memory_descriptor:
+    def permute(self, order, _semantic: GluonSemantic = None) -> shared_memory_descriptor:
+        """
+        Permute the dimensions of the shared memory descriptor.
+
+        Args:
+            order (List[int]): The new ordering of dimensions.
+
+        Returns:
+            shared_memory_descriptor: Descriptor with permuted dimensions.
+        """
         order = [_unwrap_if_constexpr(o) for o in order]
         return _semantic.memdesc_trans(self, order)
 
     @builtin
-    def reshape(self, shape,
+    def reshape(self, shape, _semantic: GluonSemantic = None) -> shared_memory_descriptor:
+        """
+        Reshape the shared memory descriptor to a new shape and layout.
+
+        Args:
+            shape (List[int]): The target shape.
+
+        Returns:
+            shared_memory_descriptor: Descriptor with the new shape and layout.
+        """
         shape = [_unwrap_if_constexpr(s) for s in shape]
-        layout = _unwrap_if_constexpr(layout)
 
-        return _semantic.memdesc_reshape(self, shape
+        return _semantic.memdesc_reshape(self, shape)
 
     @builtin
     def _reinterpret(self, dtype, shape, layout, _semantic: GluonSemantic = None) -> shared_memory_descriptor:
+        """
+        Reinterpret the shared memory descriptor as a different dtype, shape, or layout.
+
+        Args:
+            dtype (dtype): The new data type.
+            shape (List[int]): The new shape.
+            layout (SharedLayout): The new layout.
+
+        Returns:
+            shared_memory_descriptor: Descriptor with updated type and layout.
+        """
         dtype = _unwrap_if_constexpr(dtype)
         shape = [_unwrap_if_constexpr(s) for s in shape]
         layout = _unwrap_if_constexpr(layout)
@@ -258,16 +338,25 @@ class shared_memory_descriptor(base_value):
 
     @builtin
     def _keep_alive(self, _semantic: GluonSemantic = None) -> None:
+        """
+        Dummy use to keep the shared memory descriptor alive.
+        """
         return _semantic.shared_dealloc(self)
 
 
-for name in _IMPORT_FROM_TRITON:
-    fn = getattr(tl_core, name)
-    globals()[name] = builtin(fn)
-
-
 @builtin
-def arange(start, end, layout, _semantic=None):
+def arange(start, end, layout=None, _semantic=None):
+    """
+    Generate a sequence tensor with values in [start, end) using a specified layout.
+
+    Args:
+        start (int): Inclusive start of the sequence.
+        end (int): Exclusive end of the sequence.
+        layout (DistributedLayout): The layout of the output tensor. Defaults to AutoLayout.
+
+    Returns:
+        tensor: A 1D tensor containing sequential values.
+    """
     start = _unwrap_if_constexpr(start)
     end = _unwrap_if_constexpr(end)
    layout = _unwrap_if_constexpr(layout)
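
With `layout=None` now accepted, index generation can start out layout-free and be pinned later with `set_auto_layout` (added near the end of this diff). A hypothetical fragment under the same assumptions as the previous sketch:

from triton.experimental import gluon
from triton.experimental.gluon import language as gl


@gluon.jit
def iota_kernel(out_ptr, BLOCK: gl.constexpr, blocked: gl.constexpr):
    offs = gl.arange(0, BLOCK)                 # no layout: starts with an automatic layout
    offs = gl.set_auto_layout(offs, blocked)   # pin to a concrete layout before use
    gl.store(out_ptr + offs, offs)
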
@@ -275,13 +364,36 @@ def arange(start, end, layout, _semantic=None):
 
 
 @builtin
-def convert_layout(value, layout, _semantic=None):
+def convert_layout(value, layout, assert_trivial=False, _semantic=None):
+    """
+    Convert a tensor to a different distributed layout.
+
+    Args:
+        value (tensor): The input tensor.
+        layout (DistributedLayout): The target layout.
+        assert_trivial (bool): If True, asserts that the conversion is trivial (no data movement).
+
+    Returns:
+        tensor: The tensor with the new layout.
+    """
     layout = _unwrap_if_constexpr(layout)
-    return _semantic.convert_layout(value, layout)
+    return _semantic.convert_layout(value, layout, assert_trivial)
 
 
 @builtin
-def full(shape, value, dtype, layout, _semantic=None):
+def full(shape, value, dtype, layout=None, _semantic=None):
+    """
+    Create a tensor filled with a scalar value, with specified shape, dtype, and layout.
+
+    Args:
+        shape (Sequence[int]): The shape of the tensor.
+        value (int or float): The fill value.
+        dtype (dtype): The data type for the tensor.
+        layout (Optional[DistributedLayout]): The layout of the output tensor, defaults to AutoLayout().
+
+    Returns:
+        tensor: A tensor where every element equals value.
+    """
     shape = _unwrap_shape(shape)
     value = _unwrap_if_constexpr(value)
     dtype = _unwrap_if_constexpr(dtype)
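
The user-visible changes here are the optional `layout` on `full` and the `assert_trivial` flag on `convert_layout`. A hedged fragment, with the same assumptions as the sketches above, showing both:

from triton.experimental import gluon
from triton.experimental.gluon import language as gl


@gluon.jit
def fill_kernel(out_ptr, BLOCK: gl.constexpr, blocked: gl.constexpr):
    ones = gl.full([BLOCK], 1, gl.float32)    # layout omitted: automatic layout
    ones = gl.set_auto_layout(ones, blocked)
    # assert_trivial=True asks the compiler to reject any conversion that
    # would actually move data between threads.
    ones = gl.convert_layout(ones, blocked, assert_trivial=True)
    offs = gl.arange(0, BLOCK, layout=blocked)
    gl.store(out_ptr + offs, ones)
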
@@ -290,7 +402,40 @@ def full(shape, value, dtype, layout, _semantic=None):
 
 
 @builtin
-def
+def histogram(input, num_bins, mask=None, layout=None, _semantic=None, _generator=None):
+    """
+    Compute a histogram of a 1D integer tensor.
+
+    Args:
+        input (tensor): 1D tensor of integer values.
+        num_bins (int): Number of bins. Bins have width 1 and start at 0.
+        mask (Optional[tensor]): Boolean mask to exclude elements when False.
+        layout (DistributedLayout): Destination layout of the output histogram.
+
+    Returns:
+        tensor: 1D int32 tensor of length `num_bins` with the requested layout.
+    """
+    num_bins = _unwrap_if_constexpr(num_bins)
+    layout = _unwrap_if_constexpr(layout)
+    if mask is not None:
+        mask = _semantic.to_tensor(mask)
+    return _semantic.histogram(input, num_bins, mask, layout)
+
+
+@builtin
+def allocate_shared_memory(element_ty, shape, layout, value=None, _semantic=None) -> shared_memory_descriptor:
+    """
+    Allocate shared memory for a tensor with the given element type, shape, and layout.
+
+    Args:
+        element_ty (dtype): The element data type.
+        shape (Sequence[int]): The dimensions of the shared memory.
+        layout (SharedLayout): The shared memory layout.
+        value (tensor, optional): Initial value to copy into shared memory.
+
+    Returns:
+        shared_memory_descriptor: Descriptor for the allocated memory.
+    """
     element_ty = _unwrap_if_constexpr(element_ty)
     shape = _unwrap_if_constexpr(shape)
     shape = [_unwrap_if_constexpr(s) for s in shape]
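
`histogram` takes an optional mask and an explicit destination layout, with unit-width bins starting at 0 per the docstring. A hypothetical usage fragment under the same assumptions as the earlier sketches:

from triton.experimental import gluon
from triton.experimental.gluon import language as gl


@gluon.jit
def masked_hist(data_ptr, out_ptr, N, BLOCK: gl.constexpr, BINS: gl.constexpr,
                src_layout: gl.constexpr, hist_layout: gl.constexpr):
    offs = gl.arange(0, BLOCK, layout=src_layout)
    mask = offs < N
    vals = gl.load(data_ptr + offs, mask=mask, other=0)
    # Lanes masked off above are also excluded from the histogram; the result
    # is a 1D int32 tensor of length BINS in the requested layout.
    hist = gl.histogram(vals, BINS, mask=mask, layout=hist_layout)
    gl.store(out_ptr + gl.arange(0, BINS, layout=hist_layout), hist)
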
@@ -299,14 +444,47 @@ def allocate_shared_memory(element_ty, shape, layout, value=None, _semantic=None
 
 
 @builtin
-def
+def set_auto_layout(value, layout, _semantic=None):
+    """
+    Set a a tensor with AutoLayout to a concrete layout
+
+    Args:
+        value (tensor): The input tensor.
+        layout (DistribtedLayout): The target layout.
+
+    Returns:
+        tensor: The tensor with the new layout.
+    """
+    layout = _unwrap_if_constexpr(layout)
+    return _semantic.set_auto_layout(value, layout)
+
+
+@builtin
+def warp_specialize(default_args, default_partition, worker_args, worker_partitions, worker_num_warps, worker_num_regs,
                     _semantic=None, _generator=None):
+    """
+    Create a warp-specialized execution region, partitioning work across warps.
+
+    Args:
+        default_args (List[Any]): Arguments for the default region.
+        default_partition (callable): Function to build the default execution region.
+        worker_args (List[Any]): Arguments for each warp partition.
+        worker_partitions (List[callable]): Functions for each warp partition.
+        worker_num_warps (List[int]): Number of warps per partition.
+        worker_num_regs (List[int]): Number of registers per partition.
+
+    Returns:
+        Tuple[Any, ...]: Results from the default region.
+    """
     worker_num_warps = [_unwrap_if_constexpr(w) for w in worker_num_warps]
     worker_num_regs = [_unwrap_if_constexpr(r) for r in worker_num_regs]
-    return _semantic.warp_specialize(
+    return _semantic.warp_specialize(default_args, default_partition, worker_args, worker_partitions, worker_num_warps,
                                      worker_num_regs, _generator)
 
 
 @builtin
 def thread_barrier(_semantic=None):
+    """
+    Insert a barrier to synchronize threads within a CTA.
+    """
     return _semantic.debug_barrier()