triton-windows 3.4.0.post20__cp310-cp310-win_amd64.whl → 3.5.0.post21__cp310-cp310-win_amd64.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.

Potentially problematic release.

Files changed (107)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +8 -2
  3. triton/_filecheck.py +24 -14
  4. triton/_internal_testing.py +70 -4
  5. triton/_utils.py +3 -1
  6. triton/backends/amd/compiler.py +68 -60
  7. triton/backends/amd/driver.c +113 -44
  8. triton/backends/amd/driver.py +133 -57
  9. triton/backends/driver.py +13 -0
  10. triton/backends/nvidia/compiler.py +80 -22
  11. triton/backends/nvidia/driver.c +88 -15
  12. triton/backends/nvidia/driver.py +130 -123
  13. triton/compiler/__init__.py +5 -2
  14. triton/compiler/code_generator.py +270 -163
  15. triton/compiler/compiler.py +45 -62
  16. triton/experimental/gluon/__init__.py +3 -2
  17. triton/experimental/gluon/_runtime.py +9 -6
  18. triton/experimental/gluon/language/__init__.py +117 -16
  19. triton/experimental/gluon/language/_core.py +246 -68
  20. triton/experimental/gluon/language/_layouts.py +398 -45
  21. triton/experimental/gluon/language/_math.py +17 -9
  22. triton/experimental/gluon/language/_semantic.py +130 -37
  23. triton/experimental/gluon/language/_standard.py +55 -22
  24. triton/experimental/gluon/language/amd/__init__.py +4 -0
  25. triton/experimental/gluon/language/amd/_layouts.py +96 -0
  26. triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
  27. triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
  28. triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
  29. triton/experimental/gluon/language/extra/__init__.py +3 -0
  30. triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
  31. triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
  32. triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
  33. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +192 -7
  34. triton/experimental/gluon/language/nvidia/blackwell/tma.py +20 -0
  35. triton/experimental/gluon/language/nvidia/hopper/__init__.py +124 -3
  36. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +20 -37
  37. triton/experimental/gluon/language/nvidia/hopper/tma.py +4 -3
  38. triton/experimental/gluon/nvidia/hopper.py +6 -1
  39. triton/knobs.py +132 -67
  40. triton/language/__init__.py +16 -10
  41. triton/language/core.py +163 -83
  42. triton/language/extra/cuda/gdc.py +6 -6
  43. triton/language/extra/hip/__init__.py +3 -1
  44. triton/language/extra/hip/libdevice.py +7 -0
  45. triton/language/extra/hip/utils.py +35 -0
  46. triton/language/extra/libdevice.py +4 -0
  47. triton/language/semantic.py +76 -23
  48. triton/language/standard.py +14 -14
  49. triton/language/target_info.py +54 -0
  50. triton/runtime/_allocation.py +15 -3
  51. triton/runtime/_async_compile.py +55 -0
  52. triton/runtime/autotuner.py +4 -5
  53. triton/runtime/build.py +11 -9
  54. triton/runtime/cache.py +44 -1
  55. triton/runtime/driver.py +16 -41
  56. triton/runtime/interpreter.py +31 -23
  57. triton/runtime/jit.py +318 -157
  58. triton/runtime/tcc/include/_mingw.h +8 -10
  59. triton/runtime/tcc/include/assert.h +5 -0
  60. triton/runtime/tcc/include/errno.h +1 -1
  61. triton/runtime/tcc/include/float.h +21 -3
  62. triton/runtime/tcc/include/iso646.h +36 -0
  63. triton/runtime/tcc/include/limits.h +5 -0
  64. triton/runtime/tcc/include/malloc.h +2 -2
  65. triton/runtime/tcc/include/math.h +21 -261
  66. triton/runtime/tcc/include/stdalign.h +16 -0
  67. triton/runtime/tcc/include/stdarg.h +5 -70
  68. triton/runtime/tcc/include/stdatomic.h +171 -0
  69. triton/runtime/tcc/include/stddef.h +7 -19
  70. triton/runtime/tcc/include/stdlib.h +15 -4
  71. triton/runtime/tcc/include/stdnoreturn.h +7 -0
  72. triton/runtime/tcc/include/sys/stat.h +2 -2
  73. triton/runtime/tcc/include/sys/types.h +5 -0
  74. triton/runtime/tcc/include/tcc/tcc_libm.h +444 -27
  75. triton/runtime/tcc/include/tccdefs.h +342 -0
  76. triton/runtime/tcc/include/tgmath.h +89 -0
  77. triton/runtime/tcc/include/uchar.h +33 -0
  78. triton/runtime/tcc/include/unistd.h +1 -0
  79. triton/runtime/tcc/include/winapi/qos.h +72 -0
  80. triton/runtime/tcc/include/winapi/shellapi.h +59 -0
  81. triton/runtime/tcc/include/winapi/winbase.h +9 -2
  82. triton/runtime/tcc/include/winapi/wincon.h +8 -0
  83. triton/runtime/tcc/include/winapi/windows.h +1 -1
  84. triton/runtime/tcc/include/winapi/winnls.h +778 -0
  85. triton/runtime/tcc/include/winapi/winnt.h +9 -7
  86. triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
  87. triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
  88. triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
  89. triton/runtime/tcc/lib/libtcc1.a +0 -0
  90. triton/runtime/tcc/lib/python314.def +1800 -0
  91. triton/runtime/tcc/lib/python314t.def +1809 -0
  92. triton/runtime/tcc/libtcc.dll +0 -0
  93. triton/runtime/tcc/tcc.exe +0 -0
  94. triton/tools/compile.py +62 -14
  95. triton/tools/extra/cuda/compile.c +1 -0
  96. triton/tools/extra/hip/compile.cpp +66 -0
  97. triton/tools/extra/hip/compile.h +13 -0
  98. triton/tools/ragged_tma.py +92 -0
  99. triton/tools/tensor_descriptor.py +7 -9
  100. triton/windows_utils.py +42 -79
  101. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/METADATA +3 -4
  102. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/RECORD +106 -75
  103. triton/runtime/tcc/lib/libtcc1-64.a +0 -0
  104. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/WHEEL +0 -0
  105. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/entry_points.txt +0 -0
  106. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/licenses/LICENSE +0 -0
  107. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/top_level.txt +0 -0
triton/language/__init__.py CHANGED
@@ -55,9 +55,10 @@ from .core import (
     cat,
     cast,
     clamp,
+    condition,
     const,
     constexpr,
-    constexpr_function,
+    constexpr_type,
     debug_barrier,
     device_assert,
     device_print,
@@ -85,6 +86,7 @@ from .core import (
     join,
     load,
     make_block_ptr,
+    map_elementwise,
     max_constancy,
     max_contiguous,
     maximum,
@@ -130,6 +132,7 @@ from .random import (
     randn4x,
     uint_to_uniform_float,
 )
+from . import target_info
 
 __all__ = [
     "PropagateNan",
@@ -165,9 +168,10 @@ __all__ = [
     "cdiv",
     "ceil",
     "clamp",
+    "condition",
     "const",
     "constexpr",
-    "constexpr_function",
+    "constexpr_type",
     "cos",
     "cumprod",
     "cumsum",
@@ -210,6 +214,7 @@ __all__ = [
     "log",
     "log2",
     "make_block_ptr",
+    "map_elementwise",
     "math",
     "max",
     "max_constancy",
@@ -252,6 +257,7 @@ __all__ = [
     "store",
     "sum",
     "swizzle2d",
+    "target_info",
     "tensor",
     "topk",
     "trans",
@@ -271,12 +277,12 @@ __all__ = [
 ]
 
 
-def str_to_ty(name):
+def str_to_ty(name, c):
     from builtins import tuple
 
     if isinstance(name, tuple):
         fields = type(name).__dict__.get("_fields", None)
-        return tuple_type([str_to_ty(x) for x in name], fields)
+        return tuple_type([str_to_ty(x, c) for x in name], fields)
 
     if name[0] == "*":
         name = name[1:]
@@ -284,17 +290,17 @@ def str_to_ty(name):
     if name[0] == "k":
         name = name[1:]
         const = True
-        ty = str_to_ty(name)
+        ty = str_to_ty(name, c)
         return pointer_type(element_ty=ty, const=const)
 
     if name.startswith("tensordesc"):
         inner = name.split("<")[1].rstrip(">")
-        dtype, rest = inner.split("[", maxsplit=2)
-        block_shape, rest = rest.split("]", maxsplit=2)
+        dtype, rest = inner.split("[", maxsplit=1)
+        block_shape, rest = rest.split("]", maxsplit=1)
         block_shape = [int(s.strip()) for s in block_shape.rstrip("]").split(",")]
         layout = rest.lstrip(",")
         is_gluon = len(layout)
-        dtype = str_to_ty(dtype)
+        dtype = str_to_ty(dtype, None)
         ndim = len(block_shape)
         shape_type = tuple_type([int32] * ndim)
         # FIXME: Last dim stride should be constexpr(1)
@@ -308,8 +314,8 @@ def str_to_ty(name):
             return gluon_tensor_descriptor_type(block, shape_type, stride_type, layout)
         return tensor_descriptor_type(block, shape_type, stride_type)
 
-    if name == "constexpr":
-        return constexpr
+    if name.startswith("constexpr"):
+        return constexpr_type(c)
 
     tys = {
         "fp8e4nv": float8e4nv,
triton/language/core.py CHANGED
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import math
 from warnings import warn
 from contextlib import contextmanager
 from enum import Enum
@@ -9,7 +10,7 @@ from typing import Union, Callable, List, Sequence, TypeVar, Optional, Tuple
 from dataclasses import dataclass
 import builtins
 from .. import knobs
-from ..runtime.jit import jit, JITFunction
+from ..runtime.jit import JITCallable
 import inspect
 
 from .._C.libtriton import ir
@@ -86,7 +87,7 @@ def _tensor_member_fn(fn: T) -> T:
     if is_builtin(fn):
         setattr(wrapper, TRITON_BUILTIN, True)
 
-    setattr(tensor, fn.__name__, fn if isinstance(fn, JITFunction) else wrapper)
+    setattr(tensor, fn.__name__, fn if isinstance(fn, JITCallable) else wrapper)
     return fn
 
 
@@ -152,10 +153,10 @@ class base_value:
 
 class base_type:
 
-    def __eq__(self, other):
+    def __eq__(self, other) -> bool:
         raise NotImplementedError("Types must implement __eq__")
 
-    def __ne__(self, other):
+    def __ne__(self, other) -> bool:
         return not (self == other)
 
     def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[base_value, int]:
@@ -178,10 +179,13 @@ class constexpr_type(base_type):
         self.value = value
 
     def __eq__(self, other):
-        return self.value == other.value
+        return isinstance(other, constexpr_type) and self.value == other.value
 
     def __repr__(self) -> str:
-        return f"constexpr[{self.value}]"
+        return f"constexpr_type[{self.value}]"
+
+    def __hash__(self):
+        return hash(self.value)
 
     def mangle(self) -> str:
         return repr(self)
@@ -199,15 +203,17 @@ class constexpr(base_value):
     """
 
     def __init__(self, value):
-        if isinstance(value, constexpr):
-            self.value = value.value
-        else:
-            self.value = value
+        while isinstance(value, constexpr):
+            value = value.value
+        self.value = value
         self.type = constexpr_type(value)
 
     def __repr__(self) -> str:
         return f"constexpr[{self.value}]"
 
+    def __hash__(self):
+        return hash((self.value, self.type))
+
     def _flatten_ir(self, handles: List[ir.value]) -> None:
         return
 
@@ -334,32 +340,6 @@ class constexpr(base_value):
         return self.value.__getitem__(*args)
 
 
-def constexpr_function(f):
-    """
-    Wraps an arbitrary Python function so that it can be called at
-    compile-time on constexpr arguments in a Triton function and
-    returns a constexpr result.
-    """
-
-    @wraps(f)
-    def wrapper(*args, _semantic=None, **kwargs):
-        # de-constexpr arguments and discard the _semantic keyword argument:
-        args = [_unwrap_if_constexpr(x) for x in args]
-        kwargs = {k: _unwrap_if_constexpr(v) for (k, v) in kwargs.items()}
-
-        # call the raw Python function f:
-        res = f(*args, **kwargs)
-
-        # convert result back to a Triton constexpr:
-        return constexpr(res)
-
-    # disguise the function as a Triton builtin to avoid raising an error
-    # that we're calling a non-JIT function from within a Triton kernel:
-    wrapper.__triton_builtin__ = True
-    wrapper.__module__ = constexpr_function.__module__
-    return wrapper
-
-
 CONSTEXPR_0 = constexpr(0)
 
 
@@ -572,7 +552,8 @@ class dtype(base_type):
     def is_const():
         return False
 
-    def __eq__(self, other: dtype):
+    def __eq__(self, other) -> bool:
+        other = _unwrap_if_constexpr(other)
         if not isinstance(other, dtype):
             return False
         return self.name == other.name
@@ -696,7 +677,8 @@ class pointer_type(dtype):
     def is_const(self):
         return self.const
 
-    def __eq__(self, other: pointer_type) -> bool:
+    def __eq__(self, other) -> bool:
+        other = _unwrap_if_constexpr(other)
         if not isinstance(other, pointer_type):
             return False
         return self.element_ty == other.element_ty and self.address_space == other.address_space and self.const == other.const
@@ -753,6 +735,10 @@ class block_type(dtype):
     def scalar(self):
         return self.element_ty
 
+    @property
+    def nbytes(self):
+        return self.numel * (self.element_ty.primitive_bitwidth // 8)
+
     def mangle(self) -> str:
         elt = self.scalar.mangle()
         shape = '_'.join(map(str, self.shape))
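
The new nbytes property is simply numel times the element size in bytes; a small illustrative check (host-side construction shown only for the arithmetic):

    import triton.language as tl

    blk = tl.block_type(tl.float16, [64, 64])
    assert blk.nbytes == 64 * 64 * (16 // 8)  # 8192 bytes for a 64x64 fp16 block
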
@@ -879,10 +865,7 @@ class tensor(base_value):
         self.handle = handle
         # Block shape
         self.shape = type.shape if type.is_block() else ()
-        self.numel = 1
-        for s in self.shape:
-            self.numel *= s
-        self.numel = constexpr(self.numel)
+        self.numel = constexpr(math.prod(self.shape))
         self.type = type  # Tensor type (can be block_type)
         # Following the practice in pytorch, dtype is scalar type
         self.dtype = type.scalar
@@ -1268,19 +1251,20 @@ class tensor(base_value):
         ...
 
 
-class tuple(base_value):
-
-    def __init__(self, args: Sequence, type: tuple_type = None):
-        self.values = [i for i in args]
-
-        def get_type(x):
-            if isinstance(x, dtype):
-                return dtype
-            if isinstance(x, (int, float)):
-                return constexpr
-            return x.type
-
-        self.type = type or tuple_type([get_type(x) for x in self.values])
+def _type_for_tuple_values(values, fields=None):
+    return tuple_type([constexpr_type(x) if isinstance(x, (int, float, dtype)) else x.type for x in values], fields)
+
+
+class tuple(base_value):
+
+    def __init__(self, args: Sequence, type: Optional[tuple_type] = None):
+        self.values = [i for i in args]
+        if isinstance(type, tuple_type):
+            self.type = type
+        elif type is not None:  # make_template in ASTFunction.deserialize may pass us a list/tuple
+            self.type = tuple_type(type)
+        else:
+            self.type = _type_for_tuple_values(self.values)
 
     def __getitem__(self, idx: constexpr):
         if isinstance(idx, int):
@@ -1295,11 +1279,11 @@ class tuple(base_value):
         return self.values[self.type.fields.index(name)]
 
     # TODO: remove
-    def __setitem__(self, idx: constexpr, value):
-        if isinstance(idx, int):
-            idx = constexpr(idx)
-        assert isinstance(idx, constexpr)
+    def _setitem(self, idx, value):
+        idx = _unwrap_if_constexpr(idx)
+        assert isinstance(idx, int)
         self.values[idx] = value
+        self.type = _type_for_tuple_values(self.values, self.type.fields)
 
     def __add__(self, other):
         other = _normalize_tuple(other)
@@ -1560,7 +1544,7 @@ def _aggregate(cls):
     def __new__(this_cls, *args, _semantic=None, _generator=None, **kwargs):
         # Call into the user-defined constructor.
         instance = this_cls._get_instance()
-        if isinstance(cls.__init__, JITFunction):
+        if isinstance(cls.__init__, JITCallable):
             raise ValueError(f"{cls.__name__}.__init__ cannot be a @triton.jit function")
         extra_kwargs = {}
         if "_semantic" in inspect.signature(cls.__init__).parameters:
@@ -1594,7 +1578,7 @@ def _aggregate(cls):
             [(name, getattr(self, name).type) for name in cls.__annotations__.keys()])
 
     for (name, member) in inspect.getmembers(cls):
-        if inspect.isfunction(member) or inspect.ismethod(member) or isinstance(member, JITFunction):
+        if inspect.isfunction(member) or inspect.ismethod(member) or isinstance(member, JITCallable):
             if name != "__init__":
                 setattr(aggregate_value, name, member)
 
@@ -1828,11 +1812,6 @@ def join(a, b, _semantic=None):
     return _semantic.join(a, b)
 
 
-@jit
-def _take_first(a, b):
-    return a
-
-
 def _unsplat(x, _semantic=None, _generator=None):
     """
     Convert a single-element tensor to a scalar.
@@ -1843,10 +1822,7 @@ def _unsplat(x, _semantic=None, _generator=None):
     for d in x.shape:
         numel *= d
     assert numel == 1, "can only unsplat single-element tensors"
-    if len(x.shape) >= 2:
-        x = _semantic.reshape(x, [1])
-    x = typing.cast(tensor, reduce(x, 0, _take_first, _semantic=_semantic, _generator=_generator))
-    return x
+    return _semantic.unsplat(x)
 
 
 @_tensor_member_fn
@@ -2252,6 +2228,7 @@ def make_tensor_descriptor(
     shape: List[tensor],
     strides: List[tensor],
    block_shape: List[constexpr],
+    padding_option="zero",
     _semantic=None,
 ) -> tensor_descriptor:
     """Make a tensor descriptor object
@@ -2301,7 +2278,9 @@ def make_tensor_descriptor(
         inplace_abs[grid](x, M, N, M_BLOCK, N_BLOCK)
 
     """
-    return _semantic.make_tensor_descriptor(base, shape, strides, block_shape)
+
+    padding_option = _unwrap_if_constexpr(padding_option)
+    return _semantic.make_tensor_descriptor(base, shape, strides, block_shape, padding_option)
 
 
 # -----------------------
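
A hedged usage sketch of the new padding_option keyword; the kernel body, shapes, and names are illustrative, and "zero" is the default shown in the signature above:

    import triton
    import triton.language as tl

    @triton.jit
    def copy_tile(ptr, M, N, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr):
        desc = tl.make_tensor_descriptor(
            ptr, shape=[M, N], strides=[N, 1], block_shape=[M_BLOCK, N_BLOCK],
            padding_option="zero",  # new keyword in this release
        )
        tile = desc.load([0, 0])
        desc.store([0, 0], tile)
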
@@ -2784,6 +2763,79 @@ def gather(src, index, axis, _semantic=None):
     return _semantic.gather(src, index, axis)
 
 
+@builtin
+def map_elementwise(
+    scalar_fn: Callable[..., Tuple[tensor, ...]],
+    *args: tensor,
+    pack=1,
+    _semantic=None,
+    _generator=None,
+):
+    '''
+    Map a scalar function over a tensor.
+
+    The input tensors :code:`args` are implicitly broadcasted to the same shape.
+
+    This may be useful in allowing control flow over single elements in a tensor,
+    for example a multi-branch function where one branch is more expensive. With
+    :code:`tl.where` you are forced to calculate both sides of the branch, but
+    with an if we only execute one side.
+
+    .. highlight:: python
+    .. code-block:: python
+
+        @triton.jit
+        def selu_scalar(x, alpha):
+            if x > 0:
+                return a
+            else:
+                return alpha * (tl.exp(x) - 1)
+
+        @triton.jit
+        def selu(x, alpha):
+            return tl.map_elementwise(selu_scalar, x, alpha)
+
+    :param scalar_fn: the function to map over.
+    :param pack: the number of elements to be processed by one function call.
+    :return: one tensor or a tuple of tensors, depending on the mapped function.
+    '''
+    # Build the block for the nested region first to discover the return types
+    assert pack >= 1
+    in_scalar_tys = [t.type.scalar for t in args]
+    builder = _semantic.builder
+    block = builder.new_block()
+    scalar_args = []
+    for i, ty in enumerate(in_scalar_tys):
+        for j in builtins.range(pack):
+            block.add_argument(ty.to_ir(builder))
+            scalar_args.append(tensor(block.arg(i * pack + j), ty))
+
+    with _insertion_guard(builder):
+        builder.set_insertion_point_to_start(block)
+        scalar_results = _generator.call_JitFunction(scalar_fn, scalar_args, kwargs={})
+
+    is_single = isinstance(scalar_results, tensor)
+    if is_single:
+        scalar_results = scalar_results,
+
+    handles = [r.handle for r in scalar_results]
+    builder.create_map_elementwise_ret(handles)
+
+    fn_result_types = [x.type for x in scalar_results]
+    scalar_result_types = fn_result_types
+    if pack > 1:
+        scalar_result_types = fn_result_types[::pack]
+        for offset in builtins.range(1, pack):
+            assert scalar_result_types == fn_result_types[offset::pack], "type mismatch in unpacked results"
+
+    def make_elementwise_region(elementwise_op):
+        region = elementwise_op.get_region(0)
+        region.push_back(block)
+
+    result = _semantic.map_elementwise(args, scalar_result_types, pack, make_elementwise_region)
+    return result[0] if is_single else result
+
+
 # -----------------------
 # Compiler Hint Ops
 # -----------------------
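
Beyond the selu example in the docstring, the implementation above also accepts scalar functions that return several values, in which case map_elementwise returns a tuple of tensors. A hedged sketch (function and variable names are illustrative):

    import triton
    import triton.language as tl

    @triton.jit
    def minmax_scalar(a, b):
        # per-element control flow; returning two values yields two output tensors
        if a < b:
            return a, b
        else:
            return b, a

    @triton.jit
    def sort_pairs(x, y):
        lo, hi = tl.map_elementwise(minmax_scalar, x, y)
        return lo, hi
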
@@ -2941,7 +2993,7 @@ def device_print(prefix, *args, hex=False, _semantic=None):
 
 
 @builtin
-def device_assert(cond, msg="", _semantic=None):
+def device_assert(cond, msg="", mask=None, _semantic=None):
     '''
     Assert the condition at runtime from the device. Requires that the environment variable :code:`TRITON_DEBUG`
     is set to a value besides :code:`0` in order for this to have any effect.
@@ -2960,7 +3012,10 @@ def device_assert(cond, msg="", _semantic=None):
     :param msg: the message to print if the assertion fails. This is required to be a string literal.
     '''
     msg = _unwrap_if_constexpr(msg)
-    return _semantic.device_assert(_semantic.to_tensor(cond), msg)
+    mask = _unwrap_if_constexpr(mask)
+    if mask is not None:
+        mask = _semantic.to_tensor(mask)
+    return _semantic.device_assert(_semantic.to_tensor(cond), msg, mask)
 
 
 @builtin
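
A hedged sketch of the new mask argument, which restricts the device-side assertion to the selected elements (variable names are illustrative; TRITON_DEBUG must still be enabled for the assert to fire):

    import triton
    import triton.language as tl

    @triton.jit
    def check_nonnegative(ptr, n, BLOCK: tl.constexpr):
        offs = tl.arange(0, BLOCK)
        in_bounds = offs < n
        x = tl.load(ptr + offs, mask=in_bounds, other=0)
        # only assert on in-bounds elements
        tl.device_assert(x >= 0, "negative input", mask=in_bounds)
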
@@ -3098,7 +3153,7 @@ def inline_asm_elementwise(asm: str, constraints: str, args: Sequence, dtype: Un
 # -----------------------
 
 
-class static_range:
+class static_range(base_value):
     """
     Iterator that counts upward forever.
 
@@ -3154,7 +3209,7 @@ class async_task:
         self.builder.unset_async_task_ids()
 
 
-class range:
+class range(base_value):
     """
     Iterator that counts upward forever.
 
@@ -3189,6 +3244,9 @@ class range:
         The compiler will attempt to partition memory, MMA, and vector
         operations in the loop into separate async partitions. This will
         increase the total number of warps required by the kernel.
+    :param disable_licm: Tells the compiler it shouldn't hoist loop invariant
+        code outside the loop. This is often useful to avoid creating long liveranges
+        within a loop.
 
         Note that warp specialization is only supported on Blackwell GPUs and
         only works on simple matmul loops. Support for arbitrary loops will be
@@ -3196,7 +3254,7 @@ class range:
     """
 
     def __init__(self, arg1, arg2=None, step=None, num_stages=None, loop_unroll_factor=None,
-                 disallow_acc_multi_buffer=False, flatten=False, warp_specialize=False):
+                 disallow_acc_multi_buffer=False, flatten=False, warp_specialize=False, disable_licm=False):
         if step is None:
             self.step = constexpr(1)
         else:
@@ -3212,6 +3270,7 @@ class range:
         self.disallow_acc_multi_buffer = disallow_acc_multi_buffer
         self.flatten = flatten
         self.warp_specialize = warp_specialize
+        self.disable_licm = disable_licm
 
     def __iter__(self):
         raise RuntimeError("tl.range can only be used in @triton.jit'd functions")
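
A hedged sketch of the new disable_licm flag on tl.range (the loop body and names are illustrative):

    import triton
    import triton.language as tl

    @triton.jit
    def accumulate(ptr, out_ptr, n, BLOCK: tl.constexpr):
        acc = tl.zeros([BLOCK], dtype=tl.float32)
        # ask the compiler not to hoist loop-invariant code out of this loop
        for i in tl.range(0, n, disable_licm=True):
            acc += tl.load(ptr + i * BLOCK + tl.arange(0, BLOCK))
        tl.store(out_ptr + tl.arange(0, BLOCK), acc)
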
@@ -3220,13 +3279,36 @@ class range:
         raise RuntimeError("tl.range can only be used in @triton.jit'd functions")
 
 
+class condition(base_value):
+    """
+    While loop condition wrapper.
+
+    .. highlight:: python
+    .. code-block:: python
+
+        @triton.jit
+        def kernel(...):
+            while tl.condition(c, disable_licm)
+                ...
+    :note: This is a special wrapper used to annotate while loops in the context of
+        :code:`triton.jit` functions. It allows user to pass extra attributes to the compiler.
+    :param disable_licm: Tells the compiler it shouldn't hoist loop invariant
+        code outside the loop. This is often useful to avoid creating long liveranges
+        within a loop.
+    """
+
+    def __init__(self, arg1, disable_licm=False):
+        self.condition = arg1
+        self.disable_licm = disable_licm
+
+
 # -----------------------
 # Extern functions
 # -----------------------
 
 
-def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, ret_shape: tuple,
-             is_pure: bool, _semantic):
+def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, ret_type: dtype, is_pure: bool,
+             _semantic):
     '''
     Dispatch a function to a library
     :param func: the function to dispatch
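
tl.condition plays the same role for while loops that the new range flag plays for for loops; a hedged sketch based on the docstring above (names are illustrative):

    import triton
    import triton.language as tl

    @triton.jit
    def double_until(x, limit):
        # attach disable_licm to the while loop via the condition wrapper
        while tl.condition(x < limit, disable_licm=True):
            x = x * 2.0
        return x
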
@@ -3234,7 +3316,7 @@ def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dic
     :param lib_path: the path of the library
     :param args: the arguments of the function
     :param arg_type_symbol_dict: the type of the arguments
-    :param ret_shape: the shape of the return value
+    :param ret_type: the type of the return value
     :return: the return value of the function
     '''
     if len(arg_type_symbol_dict) == 0:
@@ -3261,9 +3343,6 @@ def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dic
                          f"Expect one of {arg_type_symbol_dict.keys()}, got {arg_types}")
     else:
         symbol = arg_type_symbol_dict[arg_types][0]
-        ret_type = arg_type_symbol_dict[arg_types][1]
-        if ret_shape:
-            ret_type = block_type(ret_type, ret_shape)
         builder = _semantic.builder
         return tensor(func(lib_name, lib_path, symbol, arg_list, ret_type.to_ir(builder), is_pure), ret_type)
 
@@ -3282,15 +3361,16 @@ def extern_elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol
     '''
     dispatch_args = args.copy()
     all_scalar = True
-    ret_shape = None
     arg_types = []
     for i in builtins.range(len(dispatch_args)):
         dispatch_args[i] = _semantic.to_tensor(dispatch_args[i])
         arg_types.append(dispatch_args[i].dtype)
         if dispatch_args[i].type.is_block():
             all_scalar = False
+
+    arg_types = tuple(arg_types)
+    ret_type = arg_type_symbol_dict[arg_types][1]
     if len(arg_types) > 0:
-        arg_types = tuple(arg_types)
         arithmetic_check = True
         # If there's a type tuple that is not supported by the library, we will do arithmetic check
         if arg_types in arg_type_symbol_dict:
@@ -3305,9 +3385,9 @@ def extern_elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol
             dispatch_args[i], _ = _semantic.binary_op_type_checking_impl(dispatch_args[i], broadcast_arg,
                                                                          arithmetic_check=arithmetic_check)
         if not all_scalar:
-            ret_shape = broadcast_arg.shape
+            ret_type = broadcast_arg.type.with_element_ty(ret_type)
     func = _semantic.builder.create_extern_elementwise
-    return dispatch(func, lib_name, lib_path, dispatch_args, arg_type_symbol_dict, ret_shape, is_pure, _semantic)
+    return dispatch(func, lib_name, lib_path, dispatch_args, arg_type_symbol_dict, ret_type, is_pure, _semantic)
 
 
 def binary_op_type_legalization(lhs, rhs, semantic):
triton/language/extra/cuda/gdc.py CHANGED
@@ -10,22 +10,22 @@ from triton.language import core
 
 
 @core.extern
-def gdc_wait(_builder=None):
+def gdc_wait(_semantic=None):
     """
     GDC wait is a blocking instruction that waits for all instructions in a prior kernel to complete before continuing.
     This ensures all memory operations happening before the wait is visible to instructions after it,
     e.g. if the prior kernel writes to address "x" the new values will be visible in this kernel after the wait.
 
-    This instruction is also safe to execute when programatic dependent launch is disabled.
+    This instruction is also safe to execute when programmatic dependent launch is disabled.
 
     See https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-griddepcontrol for more details.
     """
     core.inline_asm_elementwise("griddepcontrol.wait; // dummy $0", "=r", [], dtype=core.int32, is_pure=False, pack=1,
-                                _builder=_builder)
+                                _semantic=_semantic)
 
 
 @core.extern
-def gdc_launch_dependents(_builder=None):
+def gdc_launch_dependents(_semantic=None):
     """
     This operation when launched with programmatic dependent launch signals that
     the next program may launch once all programs in the current kernel
@@ -34,9 +34,9 @@ def gdc_launch_dependents(_builder=None):
     Repeated calls to this function have no effect past the first call, and the first call should be
     treated by the programmer as a hint to the runtime system to launch the next kernel.
 
-    This instruction is also safe to execute when programatic dependent launch is disabled.
+    This instruction is also safe to execute when programmatic dependent launch is disabled.
 
     See https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-griddepcontrol for more details.
     """
     core.inline_asm_elementwise("griddepcontrol.launch_dependents; // dummy $0", "=r", [], dtype=core.int32,
-                                is_pure=False, pack=1, _builder=_builder)
+                                is_pure=False, pack=1, _semantic=_semantic)
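
A hedged sketch of how these two helpers are typically paired under programmatic dependent launch; the import path mirrors the file location in this wheel, and the kernel body is illustrative:

    import triton
    import triton.language as tl
    from triton.language.extra.cuda.gdc import gdc_wait, gdc_launch_dependents

    @triton.jit
    def consumer(ptr, BLOCK: tl.constexpr):
        # make the prior kernel's writes visible before reading them
        gdc_wait()
        x = tl.load(ptr + tl.arange(0, BLOCK))
        tl.store(ptr + tl.arange(0, BLOCK), x + 1)
        # hint that the dependent kernel may now be launched
        gdc_launch_dependents()
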
triton/language/extra/hip/__init__.py CHANGED
@@ -1,3 +1,5 @@
 from . import libdevice
 
-__all__ = ["libdevice"]
+from .utils import memrealtime
+
+__all__ = ["libdevice", "memrealtime"]
triton/language/extra/hip/libdevice.py CHANGED
@@ -73,6 +73,13 @@ def fast_expf(arg0, _semantic=None):
     }, is_pure=True, _semantic=_semantic)
 
 
+@core.extern
+def fast_tanhf(arg0, _semantic=None):
+    return core.extern_elementwise("", "", [arg0], {
+        (core.dtype("fp32"), ): ("__triton_hip_fast_tanhf", core.dtype("fp32")),
+    }, is_pure=True, _semantic=_semantic)
+
+
 @core.extern
 def fast_dividef(arg0, arg1, _semantic=None):
     return core.extern_elementwise("", "", [arg0, arg1], {
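
A hedged sketch of calling the new fast_tanhf wrapper from a kernel; only the fp32 mapping is declared above, and the kernel names are illustrative:

    import triton
    import triton.language as tl
    from triton.language.extra.hip import libdevice

    @triton.jit
    def tanh_kernel(x_ptr, y_ptr, BLOCK: tl.constexpr):
        offs = tl.arange(0, BLOCK)
        x = tl.load(x_ptr + offs)
        tl.store(y_ptr + offs, libdevice.fast_tanhf(x))
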
triton/language/extra/hip/utils.py ADDED
@@ -0,0 +1,35 @@
+from triton.language import core
+
+
+@core.extern
+def memrealtime(_semantic=None):
+    """
+    Returns a 64-bit real time-counter value
+    """
+    target_arch = _semantic.builder.options.arch
+    if 'gfx11' in target_arch or 'gfx12' in target_arch:
+        return core.inline_asm_elementwise(
+            """
+            s_sendmsg_rtn_b64 $0, sendmsg(MSG_RTN_GET_REALTIME)
+            s_waitcnt lgkmcnt(0)
+            """,
+            "=r",
+            [],
+            dtype=core.int64,
+            is_pure=False,
+            pack=1,
+            _semantic=_semantic,
+        )
+    else:
+        return core.inline_asm_elementwise(
+            """
+            s_memrealtime $0
+            s_waitcnt vmcnt(0)
+            """,
+            "=r",
+            [],
+            dtype=core.int64,
+            is_pure=False,
+            pack=1,
+            _semantic=_semantic,
+        )
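
A hedged sketch of timing a region of a kernel with the new memrealtime helper on AMD hardware; it assumes t_ptr points at a 64-bit integer, and the measured body is illustrative:

    import triton
    import triton.language as tl
    from triton.language.extra.hip import memrealtime

    @triton.jit
    def timed_scale(x_ptr, t_ptr, BLOCK: tl.constexpr):
        start = memrealtime()                 # 64-bit realtime counter
        offs = tl.arange(0, BLOCK)
        tl.store(x_ptr + offs, tl.load(x_ptr + offs) * 2.0)
        stop = memrealtime()
        tl.store(t_ptr, stop - start)         # elapsed ticks
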