triton-windows 3.4.0.post20__cp312-cp312-win_amd64.whl → 3.5.0.post21__cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of triton-windows has been flagged as potentially problematic.
Files changed (107)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +8 -2
  3. triton/_filecheck.py +24 -14
  4. triton/_internal_testing.py +70 -4
  5. triton/_utils.py +3 -1
  6. triton/backends/amd/compiler.py +68 -60
  7. triton/backends/amd/driver.c +113 -44
  8. triton/backends/amd/driver.py +133 -57
  9. triton/backends/driver.py +13 -0
  10. triton/backends/nvidia/compiler.py +80 -22
  11. triton/backends/nvidia/driver.c +88 -15
  12. triton/backends/nvidia/driver.py +130 -123
  13. triton/compiler/__init__.py +5 -2
  14. triton/compiler/code_generator.py +270 -163
  15. triton/compiler/compiler.py +45 -62
  16. triton/experimental/gluon/__init__.py +3 -2
  17. triton/experimental/gluon/_runtime.py +9 -6
  18. triton/experimental/gluon/language/__init__.py +117 -16
  19. triton/experimental/gluon/language/_core.py +246 -68
  20. triton/experimental/gluon/language/_layouts.py +398 -45
  21. triton/experimental/gluon/language/_math.py +17 -9
  22. triton/experimental/gluon/language/_semantic.py +130 -37
  23. triton/experimental/gluon/language/_standard.py +55 -22
  24. triton/experimental/gluon/language/amd/__init__.py +4 -0
  25. triton/experimental/gluon/language/amd/_layouts.py +96 -0
  26. triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
  27. triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
  28. triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
  29. triton/experimental/gluon/language/extra/__init__.py +3 -0
  30. triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
  31. triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
  32. triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
  33. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +192 -7
  34. triton/experimental/gluon/language/nvidia/blackwell/tma.py +20 -0
  35. triton/experimental/gluon/language/nvidia/hopper/__init__.py +124 -3
  36. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +20 -37
  37. triton/experimental/gluon/language/nvidia/hopper/tma.py +4 -3
  38. triton/experimental/gluon/nvidia/hopper.py +6 -1
  39. triton/knobs.py +132 -67
  40. triton/language/__init__.py +16 -10
  41. triton/language/core.py +163 -83
  42. triton/language/extra/cuda/gdc.py +6 -6
  43. triton/language/extra/hip/__init__.py +3 -1
  44. triton/language/extra/hip/libdevice.py +7 -0
  45. triton/language/extra/hip/utils.py +35 -0
  46. triton/language/extra/libdevice.py +4 -0
  47. triton/language/semantic.py +76 -23
  48. triton/language/standard.py +14 -14
  49. triton/language/target_info.py +54 -0
  50. triton/runtime/_allocation.py +15 -3
  51. triton/runtime/_async_compile.py +55 -0
  52. triton/runtime/autotuner.py +4 -5
  53. triton/runtime/build.py +11 -9
  54. triton/runtime/cache.py +44 -1
  55. triton/runtime/driver.py +16 -41
  56. triton/runtime/interpreter.py +31 -23
  57. triton/runtime/jit.py +318 -157
  58. triton/runtime/tcc/include/_mingw.h +8 -10
  59. triton/runtime/tcc/include/assert.h +5 -0
  60. triton/runtime/tcc/include/errno.h +1 -1
  61. triton/runtime/tcc/include/float.h +21 -3
  62. triton/runtime/tcc/include/iso646.h +36 -0
  63. triton/runtime/tcc/include/limits.h +5 -0
  64. triton/runtime/tcc/include/malloc.h +2 -2
  65. triton/runtime/tcc/include/math.h +21 -261
  66. triton/runtime/tcc/include/stdalign.h +16 -0
  67. triton/runtime/tcc/include/stdarg.h +5 -70
  68. triton/runtime/tcc/include/stdatomic.h +171 -0
  69. triton/runtime/tcc/include/stddef.h +7 -19
  70. triton/runtime/tcc/include/stdlib.h +15 -4
  71. triton/runtime/tcc/include/stdnoreturn.h +7 -0
  72. triton/runtime/tcc/include/sys/stat.h +2 -2
  73. triton/runtime/tcc/include/sys/types.h +5 -0
  74. triton/runtime/tcc/include/tcc/tcc_libm.h +444 -27
  75. triton/runtime/tcc/include/tccdefs.h +342 -0
  76. triton/runtime/tcc/include/tgmath.h +89 -0
  77. triton/runtime/tcc/include/uchar.h +33 -0
  78. triton/runtime/tcc/include/unistd.h +1 -0
  79. triton/runtime/tcc/include/winapi/qos.h +72 -0
  80. triton/runtime/tcc/include/winapi/shellapi.h +59 -0
  81. triton/runtime/tcc/include/winapi/winbase.h +9 -2
  82. triton/runtime/tcc/include/winapi/wincon.h +8 -0
  83. triton/runtime/tcc/include/winapi/windows.h +1 -1
  84. triton/runtime/tcc/include/winapi/winnls.h +778 -0
  85. triton/runtime/tcc/include/winapi/winnt.h +9 -7
  86. triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
  87. triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
  88. triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
  89. triton/runtime/tcc/lib/libtcc1.a +0 -0
  90. triton/runtime/tcc/lib/python314.def +1800 -0
  91. triton/runtime/tcc/lib/python314t.def +1809 -0
  92. triton/runtime/tcc/libtcc.dll +0 -0
  93. triton/runtime/tcc/tcc.exe +0 -0
  94. triton/tools/compile.py +62 -14
  95. triton/tools/extra/cuda/compile.c +1 -0
  96. triton/tools/extra/hip/compile.cpp +66 -0
  97. triton/tools/extra/hip/compile.h +13 -0
  98. triton/tools/ragged_tma.py +92 -0
  99. triton/tools/tensor_descriptor.py +7 -9
  100. triton/windows_utils.py +42 -79
  101. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/METADATA +3 -4
  102. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/RECORD +106 -75
  103. triton/runtime/tcc/lib/libtcc1-64.a +0 -0
  104. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/WHEEL +0 -0
  105. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/entry_points.txt +0 -0
  106. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/licenses/LICENSE +0 -0
  107. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/top_level.txt +0 -0
triton/experimental/gluon/language/_core.py +246 -68

@@ -1,4 +1,5 @@
 from __future__ import annotations
+import math
 from typing import TypeVar, List, TYPE_CHECKING, Tuple
 from functools import wraps
 
@@ -37,38 +38,17 @@ from triton.language.core import (
     float64,
     _unwrap_if_constexpr,
     _unwrap_shape,
+    static_range,
     tensor,
     tuple,
     tuple_type,
 )
 
-_IMPORT_FROM_TRITON: List[str] = [
-    "expand_dims",
-    "join",
-    "load",
-    "maximum",
-    "minimum",
-    "permute",
-    "program_id",
-    "reduce",
-    "reshape",
-    "split",
-    "static_assert",
-    "static_print",
-    "store",
-    "to_tensor",
-    "where",
-    "inline_asm_elementwise",
-]
-
+# We define __all__ only to appease the python linter, these are not used in
+# this file but we want to import them anyway so they are importable from here.
 __all__ = [
     "constexpr",
-    "base_value",
-    "base_type",
-    "dtype",
-    "block_type",
     "pointer_type",
-    "tuple_type",
     "void",
     "int1",
     "int8",
@@ -83,24 +63,14 @@ __all__ = [
     "float8e5b16",
     "float8e4nv",
     "float8e4b8",
-    "float8e4b8",
     "float8e4b15",
     "float16",
     "bfloat16",
     "float32",
     "float64",
-    "_unwrap_if_constexpr",
-    "tensor",
+    "static_range",
     "tuple",
     "tuple_type",
-    "thread_barrier",
-    "arange",
-    "full",
-    "convert_layout",
-    "allocate_shared_memory",
-    "shared_memory_descriptor",
-    "warp_specialize",
-    *_IMPORT_FROM_TRITON,
 ]
 
 T = TypeVar("T")
@@ -109,6 +79,57 @@ T = TypeVar("T")
 GLUON_BUILTIN = "__triton_builtin__"
 
 
+def builtin(fn: T) -> T:
+    """Mark a function as a builtin."""
+    assert callable(fn)
+
+    @wraps(fn)
+    def wrapper(*args, **kwargs):
+        if "_semantic" not in kwargs or kwargs["_semantic"] is None:
+            raise ValueError("Did you forget to add @triton.gluon.jit ? "
+                             "(`_semantic` argument must be provided outside of JIT functions.)")
+        return fn(*args, **kwargs)
+
+    setattr(wrapper, GLUON_BUILTIN, True)
+
+    return wrapper
+
+
+# Explicitly import forwarded Triton language symbols so mypy sees them.
+associative_scan = builtin(tl_core.associative_scan)
+atomic_add = builtin(tl_core.atomic_add)
+atomic_and = builtin(tl_core.atomic_and)
+atomic_cas = builtin(tl_core.atomic_cas)
+atomic_max = builtin(tl_core.atomic_max)
+atomic_min = builtin(tl_core.atomic_min)
+atomic_or = builtin(tl_core.atomic_or)
+atomic_xchg = builtin(tl_core.atomic_xchg)
+atomic_xor = builtin(tl_core.atomic_xor)
+broadcast = builtin(tl_core.broadcast)
+device_assert = builtin(tl_core.device_assert)
+expand_dims = builtin(tl_core.expand_dims)
+inline_asm_elementwise = builtin(tl_core.inline_asm_elementwise)
+join = builtin(tl_core.join)
+load = builtin(tl_core.load)
+map_elementwise = builtin(tl_core.map_elementwise)
+max_constancy = builtin(tl_core.max_constancy)
+max_contiguous = builtin(tl_core.max_contiguous)
+maximum = builtin(tl_core.maximum)
+minimum = builtin(tl_core.minimum)
+multiple_of = builtin(tl_core.multiple_of)
+num_programs = builtin(tl_core.num_programs)
+permute = builtin(tl_core.permute)
+program_id = builtin(tl_core.program_id)
+reduce = builtin(tl_core.reduce)
+reshape = builtin(tl_core.reshape)
+split = builtin(tl_core.split)
+static_assert = builtin(tl_core.static_assert)
+static_print = builtin(tl_core.static_print)
+store = builtin(tl_core.store)
+to_tensor = builtin(tl_core.to_tensor)
+where = builtin(tl_core.where)
+
+
 class distributed_type(block_type):
 
     def __init__(self, element_ty: dtype, shape: List[int], layout):
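The hunk above replaces the old `_IMPORT_FROM_TRITON` re-export loop with explicit `builtin(...)` assignments so static type checkers can see the forwarded symbols. A minimal sketch of what the wrapper enforces; the import paths follow the package layout in this diff, but the tiny kernel itself is illustrative only:

```python
# Illustrative sketch, not from the package: shows the effect of the builtin() wrapper above.
from triton.experimental import gluon
from triton.experimental.gluon import language as gl

@gluon.jit
def copy_one(src, dst):
    pid = gl.program_id(0)              # forwarded tl_core.program_id; the JIT injects _semantic
    gl.store(dst + pid, gl.load(src + pid))

# Calling a forwarded builtin outside a @gluon.jit function leaves _semantic unset,
# so the wrapper raises:
#   ValueError: Did you forget to add @triton.gluon.jit ? ...
```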
@@ -131,21 +152,10 @@ class distributed_type(block_type):
     def with_element_ty(self, scalar_ty: dtype) -> block_type:
         return distributed_type(scalar_ty, self.shape, self.layout)
 
-
-def builtin(fn: T) -> T:
-    """Mark a function as a builtin."""
-    assert callable(fn)
-
-    @wraps(fn)
-    def wrapper(*args, **kwargs):
-        if "_semantic" not in kwargs or kwargs["_semantic"] is None:
-            raise ValueError("Did you forget to add @triton.gluon.jit ? "
-                             "(`_semantic` argument must be provided outside of JIT functions.)")
-        return fn(*args, **kwargs)
-
-    setattr(wrapper, GLUON_BUILTIN, True)
-
-    return wrapper
+    def __eq__(self, other) -> bool:
+        if not isinstance(other, distributed_type):
+            return False
+        return super().__eq__(other) and self.layout == other.layout
 
 
 class shared_memory_descriptor_type(base_type):
@@ -188,6 +198,9 @@ class shared_memory_descriptor_type(base_type):
 
 
 class shared_memory_descriptor(base_value):
+    """
+    Represents a handle to a shared memory allocation in Gluon IR.
+    """
 
     def __init__(self, handle, element_ty, shape, layout, alloc_shape):
         self.handle = handle
@@ -208,6 +221,10 @@ class shared_memory_descriptor(base_value):
     def rank(self):
         return len(self.shape)
 
+    @property
+    def numel(self) -> int:
+        return math.prod(self.shape)
+
     @property
     def layout(self):
         return self.type.layout
@@ -216,16 +233,42 @@ class shared_memory_descriptor(base_value):
         return str(self.type)
 
     @builtin
-    def load(self, layout, _semantic: GluonSemantic) -> tensor:
+    def load(self, layout, _semantic: GluonSemantic = None) -> tensor:
+        """
+        Load a tensor from shared memory.
+
+        Args:
+            layout (DistributedLayout): The destination layout of the tensor.
+
+        Returns:
+            tensor: A Gluon tensor containing the loaded data.
+        """
         layout = _unwrap_if_constexpr(layout)
         return _semantic.shared_load(self, layout)
 
     @builtin
-    def store(self, value, _semantic: GluonSemantic) -> None:
+    def store(self, value, _semantic: GluonSemantic = None) -> None:
+        """
+        Store a tensor into shared memory.
+
+        Args:
+            value (tensor): The tensor whose contents to store.
+        """
         return _semantic.shared_store(self, value)
 
     @builtin
     def slice(self, start, length, dim=0, _semantic: GluonSemantic = None) -> shared_memory_descriptor:
+        """
+        Create a subview of shared memory by slicing along a given dimension.
+
+        Args:
+            start (int): The starting index of the slice.
+            length (int): The length of the slice.
+            dim (int): The dimension to slice (default: 0).
+
+        Returns:
+            shared_memory_descriptor: Descriptor for the sliced subview.
+        """
         start = _unwrap_if_constexpr(start)
         length = _unwrap_if_constexpr(length)
         dim = _unwrap_if_constexpr(dim)
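With `_semantic` now defaulting to `None` on `load` and `store`, all descriptor methods can be called the same way from JIT code. A rough usage sketch; `BlockedLayout` and `SwizzledSharedLayout` are assumed layout constructors from the Gluon layouts module and are not shown in this diff:

```python
# Sketch under assumptions: BlockedLayout / SwizzledSharedLayout names are not confirmed by this diff.
from triton.experimental import gluon
from triton.experimental.gluon import language as gl

@gluon.jit
def smem_roundtrip(N: gl.constexpr):
    blocked = gl.BlockedLayout([1], [32], [4], [0])       # assumed register layout constructor
    shared = gl.SwizzledSharedLayout(1, 1, 1, order=[0])  # assumed shared layout constructor
    x = gl.arange(0, N, layout=blocked)
    smem = gl.allocate_shared_memory(gl.int32, [N], shared, x)
    y = smem.load(blocked)    # _semantic defaults to None and is filled in by the JIT
    smem.store(y + 1)
```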
@@ -233,23 +276,60 @@ class shared_memory_descriptor(base_value):
 
     @builtin
     def index(self, index, _semantic: GluonSemantic = None) -> shared_memory_descriptor:
+        """
+        Create a subview of shared memory by indexing along the first dimension.
+
+        Args:
+            index (int): The index at which to take the subview.
+
+        Returns:
+            shared_memory_descriptor: Descriptor for the indexed subview.
+        """
         index = _unwrap_if_constexpr(index)
         return _semantic.memdesc_index(self, index)
 
     @builtin
-    def permute(self, order, _semantic: GluonSemantic) -> shared_memory_descriptor:
+    def permute(self, order, _semantic: GluonSemantic = None) -> shared_memory_descriptor:
+        """
+        Permute the dimensions of the shared memory descriptor.
+
+        Args:
+            order (List[int]): The new ordering of dimensions.
+
+        Returns:
+            shared_memory_descriptor: Descriptor with permuted dimensions.
+        """
         order = [_unwrap_if_constexpr(o) for o in order]
         return _semantic.memdesc_trans(self, order)
 
     @builtin
-    def reshape(self, shape, layout, _semantic: GluonSemantic) -> shared_memory_descriptor:
+    def reshape(self, shape, _semantic: GluonSemantic = None) -> shared_memory_descriptor:
+        """
+        Reshape the shared memory descriptor to a new shape and layout.
+
+        Args:
+            shape (List[int]): The target shape.
+
+        Returns:
+            shared_memory_descriptor: Descriptor with the new shape and layout.
+        """
         shape = [_unwrap_if_constexpr(s) for s in shape]
-        layout = _unwrap_if_constexpr(layout)
 
-        return _semantic.memdesc_reshape(self, shape, layout)
+        return _semantic.memdesc_reshape(self, shape)
 
     @builtin
     def _reinterpret(self, dtype, shape, layout, _semantic: GluonSemantic = None) -> shared_memory_descriptor:
+        """
+        Reinterpret the shared memory descriptor as a different dtype, shape, or layout.
+
+        Args:
+            dtype (dtype): The new data type.
+            shape (List[int]): The new shape.
+            layout (SharedLayout): The new layout.
+
+        Returns:
+            shared_memory_descriptor: Descriptor with updated type and layout.
+        """
         dtype = _unwrap_if_constexpr(dtype)
         shape = [_unwrap_if_constexpr(s) for s in shape]
         layout = _unwrap_if_constexpr(layout)
@@ -258,16 +338,25 @@ class shared_memory_descriptor(base_value):
 
     @builtin
     def _keep_alive(self, _semantic: GluonSemantic = None) -> None:
+        """
+        Dummy use to keep the shared memory descriptor alive.
+        """
         return _semantic.shared_dealloc(self)
 
 
-for name in _IMPORT_FROM_TRITON:
-    fn = getattr(tl_core, name)
-    globals()[name] = builtin(fn)
-
-
 @builtin
-def arange(start, end, layout, _semantic=None):
+def arange(start, end, layout=None, _semantic=None):
+    """
+    Generate a sequence tensor with values in [start, end) using a specified layout.
+
+    Args:
+        start (int): Inclusive start of the sequence.
+        end (int): Exclusive end of the sequence.
+        layout (DistributedLayout): The layout of the output tensor. Defaults to AutoLayout.
+
+    Returns:
+        tensor: A 1D tensor containing sequential values.
+    """
     start = _unwrap_if_constexpr(start)
     end = _unwrap_if_constexpr(end)
     layout = _unwrap_if_constexpr(layout)
@@ -275,13 +364,36 @@ def arange(start, end, layout, _semantic=None):
 
 
 @builtin
-def convert_layout(value, layout, _semantic=None):
+def convert_layout(value, layout, assert_trivial=False, _semantic=None):
+    """
+    Convert a tensor to a different distributed layout.
+
+    Args:
+        value (tensor): The input tensor.
+        layout (DistributedLayout): The target layout.
+        assert_trivial (bool): If True, asserts that the conversion is trivial (no data movement).
+
+    Returns:
+        tensor: The tensor with the new layout.
+    """
     layout = _unwrap_if_constexpr(layout)
-    return _semantic.convert_layout(value, layout)
+    return _semantic.convert_layout(value, layout, assert_trivial)
 
 
 @builtin
-def full(shape, value, dtype, layout, _semantic=None):
+def full(shape, value, dtype, layout=None, _semantic=None):
+    """
+    Create a tensor filled with a scalar value, with specified shape, dtype, and layout.
+
+    Args:
+        shape (Sequence[int]): The shape of the tensor.
+        value (int or float): The fill value.
+        dtype (dtype): The data type for the tensor.
+        layout (Optional[DistributedLayout]): The layout of the output tensor, defaults to AutoLayout().
+
+    Returns:
+        tensor: A tensor where every element equals value.
+    """
     shape = _unwrap_shape(shape)
     value = _unwrap_if_constexpr(value)
     dtype = _unwrap_if_constexpr(dtype)
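`arange` and `full` above now accept `layout=None` (falling back to `AutoLayout` per their docstrings), and `convert_layout` gains an `assert_trivial` flag. A minimal sketch of the relaxed call shapes; `BlockedLayout` is again an assumed constructor:

```python
# Sketch under assumptions: BlockedLayout is not shown in this diff.
from triton.experimental import gluon
from triton.experimental.gluon import language as gl

@gluon.jit
def fill_and_convert(N: gl.constexpr):
    x = gl.arange(0, N)                  # layout omitted -> AutoLayout
    ones = gl.full([N], 1, gl.int32)     # layout omitted -> AutoLayout
    blocked = gl.BlockedLayout([1], [32], [4], [0])   # assumed constructor
    y = gl.convert_layout(x + ones, blocked, assert_trivial=False)
```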
@@ -290,7 +402,40 @@ def full(shape, value, dtype, layout, _semantic=None):
 
 
 @builtin
-def allocate_shared_memory(element_ty, shape, layout, value=None, _semantic=None):
+def histogram(input, num_bins, mask=None, layout=None, _semantic=None, _generator=None):
+    """
+    Compute a histogram of a 1D integer tensor.
+
+    Args:
+        input (tensor): 1D tensor of integer values.
+        num_bins (int): Number of bins. Bins have width 1 and start at 0.
+        mask (Optional[tensor]): Boolean mask to exclude elements when False.
+        layout (DistributedLayout): Destination layout of the output histogram.
+
+    Returns:
+        tensor: 1D int32 tensor of length `num_bins` with the requested layout.
+    """
+    num_bins = _unwrap_if_constexpr(num_bins)
+    layout = _unwrap_if_constexpr(layout)
+    if mask is not None:
+        mask = _semantic.to_tensor(mask)
+    return _semantic.histogram(input, num_bins, mask, layout)
+
+
+@builtin
+def allocate_shared_memory(element_ty, shape, layout, value=None, _semantic=None) -> shared_memory_descriptor:
+    """
+    Allocate shared memory for a tensor with the given element type, shape, and layout.
+
+    Args:
+        element_ty (dtype): The element data type.
+        shape (Sequence[int]): The dimensions of the shared memory.
+        layout (SharedLayout): The shared memory layout.
+        value (tensor, optional): Initial value to copy into shared memory.
+
+    Returns:
+        shared_memory_descriptor: Descriptor for the allocated memory.
+    """
     element_ty = _unwrap_if_constexpr(element_ty)
     shape = _unwrap_if_constexpr(shape)
     shape = [_unwrap_if_constexpr(s) for s in shape]
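The new `histogram` builtin parallels the Triton-level `tl.histogram` but adds an explicit destination `layout` and an optional `mask`. A hedged usage sketch (the layout constructor and the mask expression are illustrative assumptions):

```python
# Sketch under assumptions: BlockedLayout is not shown in this diff; mask/modulo usage is illustrative.
from triton.experimental import gluon
from triton.experimental.gluon import language as gl

@gluon.jit
def hist_kernel(N: gl.constexpr, BINS: gl.constexpr):
    layout = gl.BlockedLayout([1], [32], [4], [0])   # assumed constructor
    x = gl.arange(0, N, layout=layout)
    evens_only = (x % 2) == 0                        # optional mask; pass mask=None to count everything
    h = gl.histogram(x % BINS, BINS, mask=evens_only, layout=layout)  # int32 tensor of length BINS
```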
@@ -299,14 +444,47 @@ def allocate_shared_memory(element_ty, shape, layout, value=None, _semantic=None
 
 
 @builtin
-def warp_specialize(args, default_partition, worker_partitions, worker_num_warps, worker_num_regs,  #
+def set_auto_layout(value, layout, _semantic=None):
+    """
+    Set a a tensor with AutoLayout to a concrete layout
+
+    Args:
+        value (tensor): The input tensor.
+        layout (DistribtedLayout): The target layout.
+
+    Returns:
+        tensor: The tensor with the new layout.
+    """
+    layout = _unwrap_if_constexpr(layout)
+    return _semantic.set_auto_layout(value, layout)
+
+
+@builtin
+def warp_specialize(default_args, default_partition, worker_args, worker_partitions, worker_num_warps, worker_num_regs,
                     _semantic=None, _generator=None):
+    """
+    Create a warp-specialized execution region, partitioning work across warps.
+
+    Args:
+        default_args (List[Any]): Arguments for the default region.
+        default_partition (callable): Function to build the default execution region.
+        worker_args (List[Any]): Arguments for each warp partition.
+        worker_partitions (List[callable]): Functions for each warp partition.
+        worker_num_warps (List[int]): Number of warps per partition.
+        worker_num_regs (List[int]): Number of registers per partition.
+
+    Returns:
+        Tuple[Any, ...]: Results from the default region.
+    """
     worker_num_warps = [_unwrap_if_constexpr(w) for w in worker_num_warps]
     worker_num_regs = [_unwrap_if_constexpr(r) for r in worker_num_regs]
-    return _semantic.warp_specialize(args, default_partition, worker_partitions, worker_num_warps,  #
+    return _semantic.warp_specialize(default_args, default_partition, worker_args, worker_partitions, worker_num_warps,
                                      worker_num_regs, _generator)
 
 
 @builtin
 def thread_barrier(_semantic=None):
+    """
+    Insert a barrier to synchronize threads within a CTA.
+    """
     return _semantic.debug_barrier()
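Finally, `warp_specialize` now takes separate `default_args` and `worker_args` lists instead of a single `args` tuple, and the new `set_auto_layout` pins an `AutoLayout` tensor to a concrete layout. A sketch of the new call shape; the partition bodies and the register/warp counts below are placeholders, not taken from the package:

```python
# Sketch only: partition functions and the numeric values are hypothetical placeholders.
from triton.experimental import gluon
from triton.experimental.gluon import language as gl

@gluon.jit
def default_partition(x):
    return x + 1

@gluon.jit
def worker_partition(y):
    gl.static_print("worker args:", y)

@gluon.jit
def ws_kernel(x, y):
    # 3.4.0.post20: warp_specialize(args, default_partition, worker_partitions, num_warps, num_regs)
    # 3.5.0.post21: default and worker arguments are passed separately.
    out = gl.warp_specialize([x], default_partition, [y], [worker_partition],
                             worker_num_warps=[4], worker_num_regs=[232])
```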