triton-windows 3.2.0.post11__cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of triton-windows might be problematic.

Files changed (154)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +85 -0
  3. triton/_internal_testing.py +123 -0
  4. triton/backends/__init__.py +50 -0
  5. triton/backends/amd/compiler.py +368 -0
  6. triton/backends/amd/driver.c +211 -0
  7. triton/backends/amd/driver.py +512 -0
  8. triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +358 -0
  9. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +1031 -0
  10. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +1612 -0
  11. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +1337 -0
  12. triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +293 -0
  13. triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +32 -0
  14. triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +174 -0
  15. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +829 -0
  16. triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +1809 -0
  17. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +108 -0
  18. triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +124 -0
  19. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +405 -0
  20. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +196 -0
  21. triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +565 -0
  22. triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +2226 -0
  23. triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +104 -0
  24. triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +244 -0
  25. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +494 -0
  26. triton/backends/amd/include/hip/amd_detail/concepts.hpp +30 -0
  27. triton/backends/amd/include/hip/amd_detail/device_library_decls.h +133 -0
  28. triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +218 -0
  29. triton/backends/amd/include/hip/amd_detail/grid_launch.h +67 -0
  30. triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +50 -0
  31. triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +26 -0
  32. triton/backends/amd/include/hip/amd_detail/helpers.hpp +137 -0
  33. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +1350 -0
  34. triton/backends/amd/include/hip/amd_detail/hip_assert.h +101 -0
  35. triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +242 -0
  36. triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +254 -0
  37. triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +96 -0
  38. triton/backends/amd/include/hip/amd_detail/hip_ldg.h +100 -0
  39. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +10169 -0
  40. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +77 -0
  41. triton/backends/amd/include/hip/amd_detail/host_defines.h +180 -0
  42. triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +102 -0
  43. triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +798 -0
  44. triton/backends/amd/include/hip/amd_detail/math_fwd.h +698 -0
  45. triton/backends/amd/include/hip/amd_detail/ockl_image.h +177 -0
  46. triton/backends/amd/include/hip/amd_detail/program_state.hpp +107 -0
  47. triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +491 -0
  48. triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +478 -0
  49. triton/backends/amd/include/hip/channel_descriptor.h +39 -0
  50. triton/backends/amd/include/hip/device_functions.h +38 -0
  51. triton/backends/amd/include/hip/driver_types.h +468 -0
  52. triton/backends/amd/include/hip/hip_bf16.h +36 -0
  53. triton/backends/amd/include/hip/hip_bfloat16.h +44 -0
  54. triton/backends/amd/include/hip/hip_common.h +100 -0
  55. triton/backends/amd/include/hip/hip_complex.h +38 -0
  56. triton/backends/amd/include/hip/hip_cooperative_groups.h +46 -0
  57. triton/backends/amd/include/hip/hip_deprecated.h +95 -0
  58. triton/backends/amd/include/hip/hip_ext.h +159 -0
  59. triton/backends/amd/include/hip/hip_fp16.h +36 -0
  60. triton/backends/amd/include/hip/hip_gl_interop.h +32 -0
  61. triton/backends/amd/include/hip/hip_hcc.h +24 -0
  62. triton/backends/amd/include/hip/hip_math_constants.h +36 -0
  63. triton/backends/amd/include/hip/hip_profile.h +27 -0
  64. triton/backends/amd/include/hip/hip_runtime.h +75 -0
  65. triton/backends/amd/include/hip/hip_runtime_api.h +8919 -0
  66. triton/backends/amd/include/hip/hip_texture_types.h +29 -0
  67. triton/backends/amd/include/hip/hip_vector_types.h +41 -0
  68. triton/backends/amd/include/hip/hip_version.h +17 -0
  69. triton/backends/amd/include/hip/hiprtc.h +421 -0
  70. triton/backends/amd/include/hip/library_types.h +78 -0
  71. triton/backends/amd/include/hip/math_functions.h +42 -0
  72. triton/backends/amd/include/hip/surface_types.h +63 -0
  73. triton/backends/amd/include/hip/texture_types.h +194 -0
  74. triton/backends/amd/include/hsa/Brig.h +1131 -0
  75. triton/backends/amd/include/hsa/amd_hsa_common.h +91 -0
  76. triton/backends/amd/include/hsa/amd_hsa_elf.h +436 -0
  77. triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +269 -0
  78. triton/backends/amd/include/hsa/amd_hsa_queue.h +109 -0
  79. triton/backends/amd/include/hsa/amd_hsa_signal.h +80 -0
  80. triton/backends/amd/include/hsa/hsa.h +5729 -0
  81. triton/backends/amd/include/hsa/hsa_amd_tool.h +91 -0
  82. triton/backends/amd/include/hsa/hsa_api_trace.h +566 -0
  83. triton/backends/amd/include/hsa/hsa_ext_amd.h +3090 -0
  84. triton/backends/amd/include/hsa/hsa_ext_finalize.h +531 -0
  85. triton/backends/amd/include/hsa/hsa_ext_image.h +1454 -0
  86. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +488 -0
  87. triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +667 -0
  88. triton/backends/amd/include/roctracer/ext/prof_protocol.h +107 -0
  89. triton/backends/amd/include/roctracer/hip_ostream_ops.h +4435 -0
  90. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +1467 -0
  91. triton/backends/amd/include/roctracer/hsa_prof_str.h +3027 -0
  92. triton/backends/amd/include/roctracer/roctracer.h +779 -0
  93. triton/backends/amd/include/roctracer/roctracer_ext.h +81 -0
  94. triton/backends/amd/include/roctracer/roctracer_hcc.h +24 -0
  95. triton/backends/amd/include/roctracer/roctracer_hip.h +37 -0
  96. triton/backends/amd/include/roctracer/roctracer_hsa.h +112 -0
  97. triton/backends/amd/include/roctracer/roctracer_plugin.h +137 -0
  98. triton/backends/amd/include/roctracer/roctracer_roctx.h +67 -0
  99. triton/backends/amd/include/roctracer/roctx.h +229 -0
  100. triton/backends/amd/lib/ockl.bc +0 -0
  101. triton/backends/amd/lib/ocml.bc +0 -0
  102. triton/backends/compiler.py +304 -0
  103. triton/backends/driver.py +48 -0
  104. triton/backends/nvidia/__init__.py +0 -0
  105. triton/backends/nvidia/bin/ptxas.exe +0 -0
  106. triton/backends/nvidia/compiler.py +410 -0
  107. triton/backends/nvidia/driver.c +451 -0
  108. triton/backends/nvidia/driver.py +524 -0
  109. triton/backends/nvidia/include/cuda.h +24359 -0
  110. triton/backends/nvidia/lib/libdevice.10.bc +0 -0
  111. triton/backends/nvidia/lib/x64/cuda.lib +0 -0
  112. triton/compiler/__init__.py +4 -0
  113. triton/compiler/code_generator.py +1303 -0
  114. triton/compiler/compiler.py +430 -0
  115. triton/compiler/errors.py +51 -0
  116. triton/compiler/make_launcher.py +0 -0
  117. triton/errors.py +5 -0
  118. triton/language/__init__.py +294 -0
  119. triton/language/_utils.py +21 -0
  120. triton/language/core.py +2694 -0
  121. triton/language/extra/__init__.py +26 -0
  122. triton/language/extra/cuda/__init__.py +13 -0
  123. triton/language/extra/cuda/_experimental_tma.py +108 -0
  124. triton/language/extra/cuda/libdevice.py +1629 -0
  125. triton/language/extra/cuda/utils.py +109 -0
  126. triton/language/extra/hip/__init__.py +3 -0
  127. triton/language/extra/hip/libdevice.py +475 -0
  128. triton/language/extra/libdevice.py +786 -0
  129. triton/language/math.py +250 -0
  130. triton/language/random.py +207 -0
  131. triton/language/semantic.py +1796 -0
  132. triton/language/standard.py +452 -0
  133. triton/runtime/__init__.py +23 -0
  134. triton/runtime/autotuner.py +408 -0
  135. triton/runtime/build.py +111 -0
  136. triton/runtime/cache.py +295 -0
  137. triton/runtime/driver.py +60 -0
  138. triton/runtime/errors.py +26 -0
  139. triton/runtime/interpreter.py +1235 -0
  140. triton/runtime/jit.py +951 -0
  141. triton/testing.py +511 -0
  142. triton/tools/__init__.py +0 -0
  143. triton/tools/build_extern.py +365 -0
  144. triton/tools/compile.c +67 -0
  145. triton/tools/compile.h +14 -0
  146. triton/tools/compile.py +155 -0
  147. triton/tools/disasm.py +144 -0
  148. triton/tools/experimental_descriptor.py +32 -0
  149. triton/tools/link.py +322 -0
  150. triton/windows_utils.py +375 -0
  151. triton_windows-3.2.0.post11.dist-info/METADATA +39 -0
  152. triton_windows-3.2.0.post11.dist-info/RECORD +154 -0
  153. triton_windows-3.2.0.post11.dist-info/WHEEL +5 -0
  154. triton_windows-3.2.0.post11.dist-info/top_level.txt +12 -0
triton/language/semantic.py
@@ -0,0 +1,1796 @@
+ from __future__ import annotations  # remove after python 3.11
+ import warnings
+
+ from typing import List, Optional, Sequence, Tuple, TypeVar
+ import numbers
+
+ from .._C.libtriton import ir
+ from . import core as tl
+ from . import math
+
+ T = TypeVar('T')
+
+
+ class IncompatibleTypeErrorImpl(Exception):
+
+     def __init__(self, type_a, type_b):
+         self.type_a = type_a
+         self.type_b = type_b
+         self.message = "invalid operands of type " + self.type_a.__repr__() + " and " + self.type_b.__repr__()
+         super(IncompatibleTypeErrorImpl, self).__init__(self.message)
+
+
+ # ===----------------------------------------------------------------------===//
+ # Programming Model
+ # ===----------------------------------------------------------------------===//
+
+
+ def program_id(axis: int, builder: ir.builder) -> tl.tensor:
+     if axis not in (0, 1, 2):
+         raise ValueError(f"program_id axis must be 0, 1, or 2 but got {axis}")
+     return tl.tensor(builder.create_get_program_id(axis), tl.int32)
+
+
+ def num_programs(axis: int, builder: ir.builder) -> tl.tensor:
+     if axis not in (0, 1, 2):
+         raise ValueError(f"num_programs axis must be 0, 1, or 2 but got {axis}")
+     return tl.tensor(builder.create_get_num_programs(axis), tl.int32)
+
+
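For orientation, these two helpers are what `tl.program_id` and `tl.num_programs` bottom out in. A minimal sketch using only the public Triton API (the kernel itself is illustrative, not from this file):

import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)                 # lowered through semantic.program_id
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements                 # guard the tail program
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)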
+ # ===----------------------------------------------------------------------===//
+ # Implicit Casting Utilities
+ # ===----------------------------------------------------------------------===//
+
+
+ def integer_promote_impl(a_ty: tl.dtype, b_ty: tl.dtype) -> tl.dtype:
+     a_rank = a_ty.int_bitwidth
+     b_rank = b_ty.int_bitwidth
+     a_sn = a_ty.int_signedness
+     b_sn = b_ty.int_signedness
+     # Rules for signedness taken from "Usual arithmetic conversions" on
+     # https://en.cppreference.com/w/c/language/conversion.
+     if a_sn == b_sn:
+         return a_ty if a_rank > b_rank else b_ty
+     elif a_sn == tl.dtype.SIGNEDNESS.UNSIGNED:
+         return a_ty if a_rank >= b_rank else b_ty
+     elif b_sn == tl.dtype.SIGNEDNESS.UNSIGNED:
+         return b_ty if b_rank >= a_rank else a_ty
+     raise TypeError(f"unexpected signedness {a_sn} and {b_sn}")
+
+
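Concretely, these C-style rules mean the wider type wins, and on a signedness mismatch of equal width the unsigned type wins. A few outcomes that follow directly from the code (dtype names are the public `tl` aliases):

# integer_promote_impl(tl.int32,  tl.int64)  -> tl.int64   (same signedness: wider rank wins)
# integer_promote_impl(tl.uint32, tl.int32)  -> tl.uint32  (equal rank: unsigned wins)
# integer_promote_impl(tl.uint16, tl.int32)  -> tl.int32   (signed operand is strictly wider)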
+ def computation_type_impl(a_ty: tl.dtype, a_is_scalar: bool, b_ty: tl.dtype, b_is_scalar: bool,
+                           div_or_mod: bool) -> tl.dtype:
+     # 0) For scalars we follow semantics similar to PyTorch, namely:
+     # - If the scalar is of a lower or equal kind (bool < uint < int < fp),
+     #   it doesn't participate in the promotion
+     if a_is_scalar != b_is_scalar:
+         scalar_ty, tensor_ty = (a_ty, b_ty) if a_is_scalar else (b_ty, a_ty)
+         if scalar_ty.kind().value <= tensor_ty.kind().value:
+             # Upcast because of 3) and 4) below!
+             if div_or_mod and (tensor_ty in (tl.float16, tl.bfloat16)):
+                 return tl.float32
+             return tensor_ty
+
+     # 1) if one operand is double, the other is implicitly
+     #    converted to double
+     if a_ty.is_fp64() or b_ty.is_fp64():
+         return tl.float64
+     # 2) if one operand is float, the other is implicitly
+     #    converted to float
+     if a_ty.is_fp32() or b_ty.is_fp32():
+         return tl.float32
+     # 3) if one operand is half, the other is implicitly converted to half
+     #    unless we're doing / or %, which do not exist natively in PTX for fp16.
+     #    Supported PTX ops: add, sub, mul, fma, neg, abs, min, max, tanh, ex2, setp
+     if a_ty.is_fp16() or b_ty.is_fp16():
+         if div_or_mod:
+             return tl.float32
+         else:
+             return tl.float16
+     # 4) return bf16 only if both operands are of bf16
+     if a_ty.is_bf16() or b_ty.is_bf16():
+         if div_or_mod:
+             return tl.float32
+         if a_ty.is_bf16() and b_ty.is_bf16():
+             return tl.bfloat16
+         return tl.float32
+     # 5) return fp16 if operands are different fp8
+     if a_ty.is_fp8() and b_ty.is_fp8():
+         return a_ty if a_ty == b_ty else tl.float16
+     if not a_ty.is_int() or not b_ty.is_int():
+         raise TypeError(f"unexpected type {a_ty} and {b_ty}")
+     # 6) both operands are integer and undergo
+     #    integer promotion
+     if div_or_mod and a_ty.int_signedness != b_ty.int_signedness:
+         raise TypeError("Cannot use /, //, or % with " + a_ty.__repr__() + " and " + b_ty.__repr__() +
+                         " because they have different signedness; "
+                         "this is unlikely to result in a useful answer. Cast them to the same signedness.")
+     return integer_promote_impl(a_ty, b_ty)
+
+
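Under these rules, mixed-precision division always lands in fp32 unless a double is involved. A few traces that follow from the branches above (illustrative, using the `tl` dtype aliases):

# computation_type_impl(tl.float16,  False, tl.float16, False, div_or_mod=True)  -> tl.float32
# computation_type_impl(tl.float16,  False, tl.int32,   False, div_or_mod=False) -> tl.float16
# computation_type_impl(tl.bfloat16, False, tl.float16, False, div_or_mod=False) -> tl.float16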
+ def to_tensor(x, builder, check_type: bool = True):
+     if isinstance(x, bool):
+         return tl.tensor(builder.get_int1(x), tl.int1)
+     # Note: compile-time const integers are represented by unsigned values
+     elif isinstance(x, int):
+         if -2**31 <= x < 2**31:
+             dtype = tl.int32
+         elif 2**31 <= x < 2**32:
+             dtype = tl.uint32
+         elif -2**63 <= x < 2**63:
+             dtype = tl.int64
+         elif 2**63 <= x < 2**64:
+             dtype = tl.uint64
+         else:
+             raise ValueError(f'Nonrepresentable integer {x}.')
+         return full((), x, dtype=dtype, builder=builder)
+     elif isinstance(x, float):
+         min_float32 = 2**-126
+         max_float32 = (2 - 2**-23) * 2**127
+         abs_x = __builtins__['abs'](x)
+         if abs_x == float("inf") or \
+            abs_x == 0.0 or \
+            x != x or \
+            min_float32 <= abs_x <= max_float32:
+             dtype = tl.float32
+         else:
+             dtype = tl.float64
+         return full((), x, dtype=dtype, builder=builder)
+
+     elif isinstance(x, tl.constexpr):
+         return to_tensor(x.value, builder)
+     elif isinstance(x, tl.tensor):
+         return x
+     if check_type:
+         raise TypeError(f"cannot convert {x} of type {type(x)} to tensor")
+     return x
+
+
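The integer branch picks the narrowest of int32 / uint32 / int64 / uint64 that can represent the literal, and the float branch keeps a Python float in fp32 whenever its magnitude fits (with zero, inf, and NaN counted as fitting), falling back to fp64 otherwise. Illustrative mappings:

# to_tensor(7, builder)            -> int32 scalar
# to_tensor(2**31, builder)        -> uint32 scalar  (too big for int32)
# to_tensor(-2**35, builder)       -> int64 scalar
# to_tensor(1e300, builder)        -> float64 scalar (magnitude exceeds fp32 range)
# to_tensor(float("nan"), builder) -> float32 scalar (the x != x branch)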
+ # ===----------------------------------------------------------------------===//
+ # Binary Operators
+ # ===----------------------------------------------------------------------===//
+
+
+ def check_ptr_type_impl(type_a: tl.dtype, type_b: tl.dtype, allow_ptr_a: bool) -> None:
+     if type_a.is_ptr():
+         if not allow_ptr_a:
+             raise IncompatibleTypeErrorImpl(type_a, type_b)
+         # T* + U* with T != U
+         if type_b.is_ptr() and (type_a != type_b):
+             raise IncompatibleTypeErrorImpl(type_a, type_b)
+         # T* + float
+         if type_b.is_floating():
+             raise IncompatibleTypeErrorImpl(type_a, type_b)
+
+
+ def binary_op_type_checking_impl(lhs: tl.tensor | numbers.Number, rhs: tl.tensor | numbers.Number, builder: ir.builder,
+                                  allow_lhs_ptr=False, allow_rhs_ptr=False, arithmetic_check=True,
+                                  div_or_mod=False) -> Tuple[tl.tensor, tl.tensor]:
+     lhs_is_scalar = isinstance(lhs, numbers.Number)
+     rhs_is_scalar = isinstance(rhs, numbers.Number)
+     if lhs_is_scalar:
+         lhs_scalar = lhs
+         lhs = to_tensor(lhs, builder)
+     if rhs_is_scalar:
+         rhs_scalar = rhs
+         rhs = to_tensor(rhs, builder)
+
+     # implicit typecasting
+     lhs_sca_ty = lhs.type.scalar
+     rhs_sca_ty = rhs.type.scalar
+     check_ptr_type_impl(lhs_sca_ty, rhs_sca_ty, allow_lhs_ptr)
+     check_ptr_type_impl(rhs_sca_ty, lhs_sca_ty, allow_rhs_ptr)
+     if arithmetic_check and not lhs_sca_ty.is_ptr() and not rhs_sca_ty.is_ptr():
+         ret_sca_ty = computation_type_impl(lhs_sca_ty, lhs_is_scalar, rhs_sca_ty, rhs_is_scalar, div_or_mod)
+         if (lhs_is_scalar and lhs_scalar < 0 and ret_sca_ty.is_int_unsigned()
+                 or rhs_is_scalar and rhs_scalar < 0 and ret_sca_ty.is_int_unsigned()):
+             raise ValueError("Cannot perform a binary operation between an unsigned tensor and a negative scalar. "
+                              "Perform an explicit cast on one of them.")
+         lhs = full(
+             (), lhs_scalar, dtype=ret_sca_ty, builder=builder) if lhs_is_scalar else cast(lhs, ret_sca_ty, builder)
+         rhs = full(
+             (), rhs_scalar, dtype=ret_sca_ty, builder=builder) if rhs_is_scalar else cast(rhs, ret_sca_ty, builder)
+
+     # implicit broadcasting
+     lhs, rhs = broadcast_impl_value(lhs, rhs, builder)
+     return lhs, rhs
+
+
+ def binary_op_sanitize_overflow_impl(lhs: tl.tensor, rhs: tl.tensor, builder: ir.builder, binary_op: callable):
+     if lhs.type.scalar.int_bitwidth >= 64 or not builder.options.sanitize_overflow:
+         return
+     lhs_sca_ty = lhs.type.scalar
+     rhs_sca_ty = rhs.type.scalar
+     assert lhs_sca_ty == rhs_sca_ty
+     assert lhs_sca_ty.is_int()
+     lhs = cast(lhs, tl.int64, builder)
+     rhs = cast(rhs, tl.int64, builder)
+     ret = binary_op(lhs, rhs, False, builder)
+     max_value = lhs_sca_ty.get_int_max_value()
+     max_value = tl.tensor(builder.get_int64(max_value), tl.int64)
+     min_value = lhs_sca_ty.get_int_min_value()
+     min_value = tl.tensor(builder.get_int64(min_value), tl.int64)
+     cond = and_(less_equal(ret, max_value, builder), greater_equal(ret, min_value, builder), builder)
+     msg = f"int{lhs_sca_ty.int_bitwidth} overflow detected for operation {binary_op.__name__}"
+     device_assert(cond, msg, builder)
+
+
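The check is simple: redo the narrow-integer op in int64 and device-assert that the wide result still fits the narrow type's range. A plain-Python sketch of the same predicate (`fits_after_op` is a hypothetical helper, not this file's API):

import operator

def fits_after_op(a: int, b: int, op, bits: int = 32) -> bool:
    # Same idea as the int64 widening above: compute the op exactly
    # (Python ints are arbitrary precision) and check the narrow bounds.
    lo, hi = -(1 << (bits - 1)), (1 << (bits - 1)) - 1
    return lo <= op(a, b) <= hi

assert fits_after_op(2**30, 2**30, operator.add) is False  # int32 add overflow caught
assert fits_after_op(5, 7, operator.mul) is True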
+ def add(input: tl.tensor | numbers.Number, other: tl.tensor | numbers.Number, sanitize_overflow: bool,
+         builder: ir.builder) -> tl.tensor:
+     input, other = binary_op_type_checking_impl(input, other, builder, True, True)
+     input_scalar_ty = input.type.scalar
+     other_scalar_ty = other.type.scalar
+     if input_scalar_ty.is_ptr() and other_scalar_ty.is_ptr():
+         raise TypeError("cannot add pointers together")
+
+     # offset + ptr
+     # ptr + offset
+     if other_scalar_ty.is_ptr() and not input_scalar_ty.is_ptr():
+         input, other = other, input
+         input_scalar_ty = input.type.scalar
+         other_scalar_ty = other.type.scalar
+     if input_scalar_ty.is_ptr():
+         return tl.tensor(builder.create_addptr(input.handle, other.handle), input.type)
+     # float + float
+     elif input_scalar_ty.is_floating():
+         return tl.tensor(builder.create_fadd(input.handle, other.handle), input.type)
+     # int + int
+     elif input_scalar_ty.is_int():
+         if sanitize_overflow:
+             binary_op_sanitize_overflow_impl(input, other, builder, add)
+         return tl.tensor(builder.create_add(input.handle, other.handle), input.type)
+     raise TypeError(f"unexpected type {input_scalar_ty}")
+
+
+ def sub(input: tl.tensor | numbers.Number, other: tl.tensor | numbers.Number, sanitize_overflow: bool,
+         builder: ir.builder) -> tl.tensor:
+     input, other = binary_op_type_checking_impl(input, other, builder, True, False)
+     scalar_ty = input.type.scalar
+     # ptr - offset
+     if scalar_ty.is_ptr():
+         return tl.tensor(builder.create_addptr(input.handle, minus(other, builder).handle), input.type)
+     # float - float
+     if scalar_ty.is_floating():
+         return tl.tensor(builder.create_fsub(input.handle, other.handle), input.type)
+     # int - int
+     elif scalar_ty.is_int():
+         if sanitize_overflow:
+             binary_op_sanitize_overflow_impl(input, other, builder, sub)
+         return tl.tensor(builder.create_sub(input.handle, other.handle), input.type)
+     raise TypeError(f"unexpected type {scalar_ty}")
+
+
+ def mul(input: tl.tensor | numbers.Number, other: tl.tensor | numbers.Number, sanitize_overflow: bool,
+         builder: ir.builder) -> tl.tensor:
+     input, other = binary_op_type_checking_impl(input, other, builder)
+     scalar_ty = input.type.scalar
+     # float * float
+     if scalar_ty.is_floating():
+         return tl.tensor(builder.create_fmul(input.handle, other.handle), input.type)
+     # int * int
+     elif scalar_ty.is_int():
+         if sanitize_overflow:
+             binary_op_sanitize_overflow_impl(input, other, builder, mul)
+         return tl.tensor(builder.create_mul(input.handle, other.handle), input.type)
+     raise TypeError(f"unexpected type {scalar_ty}")
+
+
+ def truediv(input: tl.tensor | numbers.Number, other: tl.tensor | numbers.Number, builder: ir.builder) -> tl.tensor:
+     input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True)
+     input_scalar_ty = input.type.scalar
+     other_scalar_ty = other.type.scalar
+     # float / int
+     if input_scalar_ty.is_floating() and other_scalar_ty.is_int():
+         other = cast(other, input_scalar_ty, builder)
+     # int / float
+     elif input_scalar_ty.is_int() and other_scalar_ty.is_floating():
+         input = cast(input, other_scalar_ty, builder)
+     # int / int (cast to tl.float32)
+     elif input_scalar_ty.is_int() and other_scalar_ty.is_int():
+         input = cast(input, tl.float32, builder)
+         other = cast(other, tl.float32, builder)
+     # float / float (cast to the highest exponent type)
+     elif input_scalar_ty.is_floating() and other_scalar_ty.is_floating():
+         if input_scalar_ty.fp_mantissa_width > other_scalar_ty.fp_mantissa_width:
+             other = cast(other, input_scalar_ty, builder)
+         else:
+             input = cast(input, other_scalar_ty, builder)
+     # unreachable
+     else:
+         raise TypeError(f"unexpected type {input_scalar_ty}")
+     return tl.tensor(builder.create_fdiv(input.handle, other.handle), input.type)
+
+
+ def floordiv(input: tl.tensor | numbers.Number, other: tl.tensor | numbers.Number, builder: ir.builder) -> tl.tensor:
+     input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True)
+     input_scalar_ty = input.type.scalar
+     other_scalar_ty = other.type.scalar
+     if input_scalar_ty.is_int() and other_scalar_ty.is_int():
+         ret_ty = integer_promote_impl(input_scalar_ty, other_scalar_ty)
+         input = cast(input, ret_ty, builder)
+         other = cast(other, ret_ty, builder)
+         if ret_ty.is_int_signed():
+             return tl.tensor(builder.create_sdiv(input.handle, other.handle), input.type)
+         else:
+             return tl.tensor(builder.create_udiv(input.handle, other.handle), input.type)
+     raise TypeError(f"unexpected type {input_scalar_ty}")
+
+
+ def fdiv(input: tl.tensor | numbers.Number, other: tl.tensor | numbers.Number, ieee_rounding: bool,
+          builder: ir.builder) -> tl.tensor:
+     input_scalar_ty = input.type.scalar
+     other_scalar_ty = other.type.scalar
+     if not input_scalar_ty.is_floating() or not other_scalar_ty.is_floating():
+         raise TypeError("both operands of fdiv must have floating scalar type")
+     input, other = binary_op_type_checking_impl(input, other, builder, False, False, False, True)
+     ret = builder.create_fdiv(input.handle, other.handle)
+     return tl.tensor(ret, input.type)
+
+
+ def mod(input: tl.tensor | numbers.Number, other: tl.tensor | numbers.Number, builder: ir.builder) -> tl.tensor:
+     input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True)
+     scalar_ty = input.type.scalar
+     other_scalar_ty = other.type.scalar
+     # float % float
+     if scalar_ty.is_floating():
+         # input - input.div(other, rounding_mode="floor") * other
+         floor = math.floor(fdiv(input, other, False, builder), _builder=builder)
+         ret = sub(input, mul(floor, other, True, builder), True, builder)
+         return ret
+     # int % int
+     elif scalar_ty.is_int():
+         if scalar_ty.int_signedness != other_scalar_ty.int_signedness:
+             raise TypeError("Cannot mod " + scalar_ty.__repr__() + " by " + other_scalar_ty.__repr__() + " "
+                             "because they have different signedness; "
+                             "this is unlikely to result in a useful answer. Cast them to the same signedness.")
+         if scalar_ty.is_int_signed():
+             return tl.tensor(builder.create_srem(input.handle, other.handle), input.type)
+         else:
+             return tl.tensor(builder.create_urem(input.handle, other.handle), input.type)
+     raise TypeError(f"unexpected type {scalar_ty}")
+
+
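Note that the floating branch implements floored (Python-style) modulo rather than C-style truncated remainder, so the result takes the sign of the divisor. The identity it builds, written out as a hypothetical plain-Python helper:

import math

def floored_mod(x: float, y: float) -> float:
    # same construction as above: x - floor(x / y) * y
    return x - math.floor(x / y) * y

assert floored_mod(5.0, 3.0) == 2.0
assert floored_mod(-5.0, 3.0) == 1.0  # sign follows the divisor, like Python's %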
+ ##############
+ # other arithmetic ops
+ ##############
+
+
+ def minimum(x: tl.tensor, y: tl.tensor, propagate_nan: tl.PropagateNan, builder: ir.builder):
+     x, y = binary_op_type_checking_impl(x, y, builder)
+     dtype = x.dtype
+     if dtype.is_floating():
+         if propagate_nan == tl.PropagateNan.ALL:
+             return tl.tensor(builder.create_minimumf(x.handle, y.handle), x.type)
+         elif propagate_nan == tl.PropagateNan.NONE:
+             return tl.tensor(builder.create_minnumf(x.handle, y.handle), x.type)
+         else:
+             raise ValueError(f"Unexpected propagate_nan {propagate_nan}")
+     elif dtype.is_int_signed():
+         return tl.tensor(builder.create_minsi(x.handle, y.handle), x.type)
+     elif dtype.is_int_unsigned():
+         return tl.tensor(builder.create_minui(x.handle, y.handle), x.type)
+     else:
+         raise TypeError(f"Unexpected dtype {dtype}")
+
+
+ def maximum(x: tl.tensor, y: tl.tensor, propagate_nan: tl.PropagateNan, builder: ir.builder):
+     x, y = binary_op_type_checking_impl(x, y, builder)
+     dtype = x.dtype
+     if dtype.is_floating():
+         if propagate_nan == tl.PropagateNan.ALL:
+             return tl.tensor(builder.create_maximumf(x.handle, y.handle), x.type)
+         elif propagate_nan == tl.PropagateNan.NONE:
+             return tl.tensor(builder.create_maxnumf(x.handle, y.handle), x.type)
+         else:
+             raise ValueError(f"Unexpected propagate_nan {propagate_nan}")
+     elif dtype.is_int_signed():
+         return tl.tensor(builder.create_maxsi(x.handle, y.handle), x.type)
+     elif dtype.is_int_unsigned():
+         return tl.tensor(builder.create_maxui(x.handle, y.handle), x.type)
+     else:
+         raise TypeError(f"Unexpected dtype {dtype}")
+
+
+ def clamp(x: tl.tensor, min: tl.tensor, max: tl.tensor, propagate_nan: tl.PropagateNan, builder: ir.builder):
+     min, max = binary_op_type_checking_impl(min, max, builder)
+     x, min = binary_op_type_checking_impl(x, min, builder)
+     x, max = binary_op_type_checking_impl(x, max, builder)
+
+     dtype = x.dtype
+     if dtype.is_floating():
+         return tl.tensor(builder.create_clampf(x.handle, min.handle, max.handle, propagate_nan), x.type)
+     else:
+         raise TypeError(f"Unexpected dtype {dtype}. Only floating point clamp is supported")
+
+
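At the user level these back `tl.minimum`, `tl.maximum`, and `tl.clamp`; `propagate_nan` selects between the NaN-propagating and NaN-ignoring float variants. A small usage sketch inside a kernel (illustrative, public `tl` API only):

import triton
import triton.language as tl

@triton.jit
def clamp_kernel(x_ptr, out_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    # floating-point only, per the clamp() type check above
    y = tl.clamp(x, -1.0, 1.0)
    tl.store(out_ptr + offs, y, mask=mask)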
+ ##############
+ # bitwise ops
+ ##############
+
+
+ def bitwise_op_type_checking_impl(input: tl.tensor, other: tl.tensor,
+                                   builder: ir.builder) -> Tuple[tl.tensor, tl.tensor]:
+     input, other = binary_op_type_checking_impl(input, other, builder)
+     input_sca_ty = input.type.scalar
+     other_sca_ty = other.type.scalar
+     if not input_sca_ty.is_int() or not other_sca_ty.is_int():
+         raise IncompatibleTypeErrorImpl(input_sca_ty, other_sca_ty)
+     ret_sca_ty = integer_promote_impl(input_sca_ty, other_sca_ty)
+     if ret_sca_ty != input_sca_ty:
+         input = cast(input, ret_sca_ty, builder)
+     if ret_sca_ty != other_sca_ty:
+         other = cast(other, ret_sca_ty, builder)
+     return input, other
+
+
+ def and_(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
+     input, other = bitwise_op_type_checking_impl(input, other, builder)
+     return tl.tensor(builder.create_and(input.handle, other.handle), input.type)
+
+
+ def or_(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
+     input, other = bitwise_op_type_checking_impl(input, other, builder)
+     return tl.tensor(builder.create_or(input.handle, other.handle), input.type)
+
+
+ def xor_(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
+     input, other = bitwise_op_type_checking_impl(input, other, builder)
+     return tl.tensor(builder.create_xor(input.handle, other.handle), input.type)
+
+
+ def logical_and(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
+     if not input.type.is_int1():
+         input = bitcast(input, tl.dtype("int1"), builder)
+     if not other.type.is_int1():
+         other = bitcast(other, tl.dtype("int1"), builder)
+     return and_(input, other, builder)
+
+
+ def logical_or(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
+     if not input.type.is_int1():
+         input = bitcast(input, tl.dtype("int1"), builder)
+     if not other.type.is_int1():
+         other = bitcast(other, tl.dtype("int1"), builder)
+     return or_(input, other, builder)
+
+
+ def not_(input: tl.tensor, builder: ir.builder):
+     if not input.type.is_int1():
+         input = bitcast(input, tl.dtype("int1"), builder)
+     return invert(input, builder)
+
+
+ def lshr(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
+     input, other = bitwise_op_type_checking_impl(input, other, builder)
+     return tl.tensor(builder.create_lshr(input.handle, other.handle), input.type)
+
+
+ def ashr(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
+     input, other = bitwise_op_type_checking_impl(input, other, builder)
+     return tl.tensor(builder.create_ashr(input.handle, other.handle), input.type)
+
+
+ def shl(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
+     input, other = bitwise_op_type_checking_impl(input, other, builder)
+     return tl.tensor(builder.create_shl(input.handle, other.handle), input.type)
+
+
+ # ===----------------------------------------------------------------------===//
+ # Unary Operators
+ # ===----------------------------------------------------------------------===//
+
+
+ def plus(input: tl.tensor) -> tl.tensor:
+     return input
+
+
+ def minus(input: tl.tensor, builder: ir.builder) -> tl.tensor:
+     input_sca_ty = input.type.scalar
+     if input_sca_ty.is_ptr():
+         raise ValueError("wrong type argument to unary minus (" + input_sca_ty.__repr__() + ")")
+     _0 = tl.tensor(builder.get_null_value(input_sca_ty.to_ir(builder)), input_sca_ty)
+     return sub(_0, input, True, builder)
+
+
+ def invert(input: tl.tensor, builder: ir.builder) -> tl.tensor:
+     input_sca_ty = input.type.scalar
+     if input_sca_ty.is_ptr() or input_sca_ty.is_floating():
+         raise ValueError("wrong type argument to unary invert (" + input_sca_ty.__repr__() + ")")
+     _1 = tl.tensor(builder.get_all_ones_value(input_sca_ty.to_ir(builder)), input_sca_ty)
+     return xor_(input, _1, builder)
+
+
+ # ===----------------------------------------------------------------------===//
+ # Comparison Operators
+ # ===----------------------------------------------------------------------===//
+ def _bool_like(v: tl.tensor) -> tl.block_type:
+     if not v.type.is_block():
+         return tl.int1
+     shape = v.type.shape
+     return tl.block_type(tl.int1, shape)
+
+
+ def greater_than(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
+     input, other = binary_op_type_checking_impl(input, other, builder)
+     scalar_ty = input.type.scalar
+     # float > float
+     if scalar_ty.is_floating():
+         return tl.tensor(builder.create_fcmpOGT(input.handle, other.handle), _bool_like(input))
+     # int > int
+     elif scalar_ty.is_int():
+         if scalar_ty.is_int_signed():
+             return tl.tensor(builder.create_icmpSGT(input.handle, other.handle), _bool_like(input))
+         else:
+             return tl.tensor(builder.create_icmpUGT(input.handle, other.handle), _bool_like(input))
+     raise TypeError(f"unexpected type {scalar_ty}")
+
+
+ def greater_equal(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
+     input, other = binary_op_type_checking_impl(input, other, builder)
+     scalar_ty = input.type.scalar
+     # float >= float
+     if scalar_ty.is_floating():
+         return tl.tensor(builder.create_fcmpOGE(input.handle, other.handle), _bool_like(input))
+     # int >= int
+     elif scalar_ty.is_int():
+         if scalar_ty.is_int_signed():
+             return tl.tensor(builder.create_icmpSGE(input.handle, other.handle), _bool_like(input))
+         else:
+             return tl.tensor(builder.create_icmpUGE(input.handle, other.handle), _bool_like(input))
+     raise TypeError(f"unexpected type {scalar_ty}")
+
+
+ def less_than(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
+     input, other = binary_op_type_checking_impl(input, other, builder)
+     scalar_ty = input.type.scalar
+     # float < float
+     if scalar_ty.is_floating():
+         return tl.tensor(builder.create_fcmpOLT(input.handle, other.handle), _bool_like(input))
+     # int < int
+     elif scalar_ty.is_int():
+         if scalar_ty.is_int_signed():
+             return tl.tensor(builder.create_icmpSLT(input.handle, other.handle), _bool_like(input))
+         else:
+             return tl.tensor(builder.create_icmpULT(input.handle, other.handle), _bool_like(input))
+     raise TypeError(f"unexpected type {scalar_ty}")
+
+
+ def less_equal(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
+     input, other = binary_op_type_checking_impl(input, other, builder)
+     scalar_ty = input.type.scalar
+     # float <= float
+     if scalar_ty.is_floating():
+         return tl.tensor(builder.create_fcmpOLE(input.handle, other.handle), _bool_like(input))
+     # int <= int
+     elif scalar_ty.is_int():
+         if scalar_ty.is_int_signed():
+             return tl.tensor(builder.create_icmpSLE(input.handle, other.handle), _bool_like(input))
+         else:
+             return tl.tensor(builder.create_icmpULE(input.handle, other.handle), _bool_like(input))
+     raise TypeError(f"unexpected type {scalar_ty}")
+
+
+ def equal(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
+     input, other = binary_op_type_checking_impl(input, other, builder)
+     scalar_ty = input.type.scalar
+     # float == float
+     if scalar_ty.is_floating():
+         return tl.tensor(builder.create_fcmpOEQ(input.handle, other.handle), _bool_like(input))
+     # int == int
+     elif scalar_ty.is_int():
+         return tl.tensor(builder.create_icmpEQ(input.handle, other.handle), _bool_like(input))
+     raise TypeError(f"unexpected type {scalar_ty}")
+
+
+ def not_equal(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
+     input, other = binary_op_type_checking_impl(input, other, builder)
+     scalar_ty = input.type.scalar
+     # float != float
+     if scalar_ty.is_floating():
+         return tl.tensor(builder.create_fcmpUNE(input.handle, other.handle), _bool_like(input))
+     # int != int
+     elif scalar_ty.is_int():
+         return tl.tensor(builder.create_icmpNE(input.handle, other.handle), _bool_like(input))
+     raise TypeError(f"unexpected type {scalar_ty}")
+
+
+ # ===----------------------------------------------------------------------===//
+ # Block Creation
+ # ===----------------------------------------------------------------------===//
+
+
+ def arange(start: int, end: int, builder: ir.builder) -> tl.tensor:
+     if not isinstance(start, int) or not isinstance(end, int):
+         raise ValueError("arange's arguments must be of type tl.constexpr")
+     is_start_int64 = bool(start >> 32)
+     is_end_int64 = bool(end >> 32)
+     if is_start_int64 or is_end_int64:
+         raise ValueError("arange must fit in int32")
+     if end <= start:
+         raise ValueError("arange's end argument must be greater than the start argument")
+     range = end - start
+     if (range & (range - 1)) != 0:
+         raise ValueError("arange's range must be a power of 2")
+     shape = [range]
+     ret_ty = tl.block_type(tl.int32, shape)
+     return tl.tensor(builder.create_make_range(start, end), ret_ty)
+
+
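So `tl.arange(start, end)` requires compile-time int32 bounds with `end > start` and a power-of-two length; the length check uses the usual bit trick `n & (n - 1) == 0`. For instance:

# tl.arange(0, 64)   -> OK: length 64 is a power of 2
# tl.arange(16, 48)  -> OK: length 32
# tl.arange(0, 48)   -> ValueError: length 48 is not a power of 2
# tl.arange(8, 8)    -> ValueError: end must be greater than start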
+ def full(shape: List[int], value, dtype: tl.dtype, builder: ir.builder) -> tl.tensor:
+     if isinstance(value, tl.tensor):
+         assert value.numel.value == 1, "only accepts size-1 tensor"
+         value = cast(value, dtype, builder)
+     else:
+         # scalar
+         if dtype is None:
+             raise ValueError("dtype must be specified when value is not a tensor")
+         if value == 0:
+             value = builder.get_null_value(dtype.to_ir(builder))
+         else:
+             get_value_fn = getattr(builder, f"get_{dtype.name}")
+             value = get_value_fn(value)
+         value = tl.tensor(value, dtype)
+
+     return splat(value, shape, builder)
+
+
+ # ===----------------------------------------------------------------------===//
+ # Shape Manipulation
+ # ===----------------------------------------------------------------------===//
+
+
+ def splat(value: tl.tensor, shape: List[int], builder: ir.builder) -> tl.tensor:
+     assert not value.type.is_block(), "Cannot splat a block tensor"
+     if len(shape) == 0:
+         return value
+     ret_ty = tl.block_type(value.dtype, shape)
+     return tl.tensor(builder.create_splat(value.handle, shape), ret_ty)
+
+
+ def reshape(input: tl.tensor, dst_shape: List[int], can_reorder: bool, builder: ir.builder) -> tl.tensor:
+     numel = 1
+     for s in dst_shape:
+         numel *= s
+     if input.type.numel != numel:
+         raise ValueError("reshape() cannot change total number of elements in tensor")
+     ret_ty = tl.block_type(input.type.scalar, dst_shape)
+     return tl.tensor(builder.create_reshape(input.handle, dst_shape, can_reorder), ret_ty)
+
+
+ def expand_dims(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
+     dst_shape = [tl._constexpr_to_value(x) for x in input.shape]
+     dst_shape.insert(axis, 1)
+
+     if not input.type.is_block():
+         return splat(input, shape=dst_shape, builder=builder)
+
+     ret_ty = tl.block_type(input.type.scalar, dst_shape)
+     return tl.tensor(builder.create_expand_dims(input.handle, axis), ret_ty)
+
+
+ def cat(lhs: tl.tensor, rhs: tl.tensor, can_reorder: bool, builder: ir.builder) -> tl.tensor:
+     assert can_reorder, "the current implementation of `cat` may always reorder elements"
+     assert len(lhs.shape) == 1
+     ret_type = tl.block_type(lhs.type.scalar, [lhs.shape[0] + rhs.shape[0]])
+     return tl.tensor(builder.create_cat(lhs.handle, rhs.handle), ret_type)
+
+
+ def join(a: tl.tensor, b: tl.tensor, builder: ir.builder) -> tl.tensor:
+     a, b = broadcast_impl_value(a, b, builder)
+
+     # The IR can't handle joining two scalars, so upcast them to 1D tensors,
+     # then downcast the result.
+     was_rank_1 = a.shape == []
+     if was_rank_1:
+         a = expand_dims(a, 0, builder)
+         b = expand_dims(b, 0, builder)
+
+     if isinstance(a.shape[-1], tl.constexpr):
+         two = tl.constexpr(2)
+     else:
+         two = 2
+     new_shape = a.shape + [two]
+
+     ret_type = tl.block_type(a.type.scalar, new_shape)
+     ret = tl.tensor(builder.create_join(a.handle, b.handle), ret_type)
+
+     if was_rank_1:
+         ret = reshape(ret, [2], can_reorder=False, builder=builder)
+
+     return ret
+
+
+ def split(a: tl.tensor, builder: ir.builder) -> Tuple[tl.tensor, tl.tensor]:
+     assert (len(a.shape) > 0)
+     assert (tl._constexpr_to_value(a.shape[-1]) == 2)
+
+     new_shape = a.shape[:-1]
+     ret_type = tl.block_type(a.type.scalar, new_shape)
+     outLHS, outRHS = builder.create_split(a.handle)
+     return (
+         tl.tensor(outLHS, ret_type),
+         tl.tensor(outRHS, ret_type),
+     )
+
+
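`join` stacks two broadcast-compatible tensors along a new trailing axis of size 2, and `split` is its inverse on a trailing axis that is exactly 2. Shape-wise (illustrative, via the public `tl.join`/`tl.split`):

# tl.join(a, b):  a, b of shape [4, 8]  ->  result shape [4, 8, 2]
# tl.join(x, y):  two scalars           ->  result shape [2] (via the rank-1 round-trip above)
# tl.split(t):    t of shape [4, 8, 2]  ->  two tensors of shape [4, 8]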
+ def permute(input: tl.tensor, dims: Tuple[int], builder: ir.builder) -> tl.tensor:
+     if len(input.shape) != len(dims):
+         raise ValueError("permute dims must have the same length as input shape")
+     if sorted(tl._constexpr_to_value(d) for d in dims) != list(range(len(dims))):
+         raise ValueError(f"permute dims must be a permutation of 0, 1, ..., n-1, but were {dims}")
+
+     ret_type = tl.block_type(input.type.scalar, [input.shape[d] for d in dims])
+     return tl.tensor(builder.create_trans(input.handle, dims), ret_type)
+
+
+ def broadcast_impl_shape(input: tl.tensor, shape: List[int], builder: ir.builder) -> tl.tensor:
+     if not input.type.is_block():
+         ret_ty = tl.block_type(input.type, shape)
+         return tl.tensor(builder.create_splat(input.handle, shape), ret_ty)
+     src_shape = input.type.get_block_shapes()
+     if len(src_shape) != len(shape):
+         raise ValueError(f"Cannot broadcast, rank mismatch: {src_shape}, {shape}")
+     if shape == src_shape:
+         return input
+     for i, item in enumerate(src_shape):
+         if shape[i] != item and item != 1:
+             raise ValueError(f"Cannot broadcast, the expanded size of the tensor ({shape[i]})"
+                              f" must match the existing size ({item}) at non-singleton dimension"
+                              f" {i}: {src_shape}, {shape}")
+     ret_ty = tl.block_type(input.type.scalar, shape)
+     return tl.tensor(builder.create_broadcast(input.handle, shape), ret_ty)
+
+
+ def broadcast_impl_value(lhs: tl.tensor, rhs: tl.tensor, builder: ir.builder) -> Tuple[tl.tensor, tl.tensor]:
+     lhs_ty = lhs.type
+     rhs_ty = rhs.type
+
+     # make_shape_compatible(block, scalar)
+     if lhs_ty.is_block() and not rhs_ty.is_block():
+         rhs_ty = tl.block_type(rhs_ty.scalar, lhs_ty.shape)
+         rhs = tl.tensor(builder.create_splat(rhs.handle, lhs_ty.get_block_shapes()), rhs_ty)
+     # make_shape_compatible(scalar, block)
+     elif not lhs_ty.is_block() and rhs_ty.is_block():
+         lhs_ty = tl.block_type(lhs_ty.scalar, rhs_ty.shape)
+         lhs = tl.tensor(builder.create_splat(lhs.handle, rhs_ty.get_block_shapes()), lhs_ty)
+     # make_shape_compatible(block, block)
+     elif lhs_ty.is_block() and rhs_ty.is_block():
+         lhs_shape = lhs_ty.get_block_shapes()
+         rhs_shape = rhs_ty.get_block_shapes()
+
+         if len(lhs_shape) < len(rhs_shape):
+             # Add new axes to lhs
+             for _ in range(len(lhs_shape), len(rhs_shape)):
+                 lhs = tl.tensor(builder.create_expand_dims(lhs.handle, 0),
+                                 tl.block_type(lhs_ty.scalar, [1] + lhs_shape))
+                 lhs_ty = lhs.type
+                 lhs_shape = lhs_ty.get_block_shapes()
+         elif len(rhs_shape) < len(lhs_shape):
+             # Add new axes to rhs
+             for _ in range(len(rhs_shape), len(lhs_shape)):
+                 rhs = tl.tensor(builder.create_expand_dims(rhs.handle, 0),
+                                 tl.block_type(rhs_ty.scalar, [1] + rhs_shape))
+                 rhs_ty = rhs.type
+                 rhs_shape = rhs_ty.get_block_shapes()
+         assert len(rhs_shape) == len(lhs_shape)
+
+         ret_shape = []
+         for i, left in enumerate(lhs_shape):
+             right = rhs_shape[i]
+             if left == 1:
+                 ret_shape.append(right)
+             elif (right == 1) or (right == left):
+                 ret_shape.append(left)
+             else:
+                 raise ValueError("Cannot make_shape_compatible: incompatible dimensions "
+                                  "at index " + str(i) + ": " + str(left) + " and " + str(right))
+         if lhs_shape != ret_shape:
+             ret_ty = tl.block_type(lhs_ty.scalar, ret_shape)
+             lhs = tl.tensor(builder.create_broadcast(lhs.handle, ret_shape), ret_ty)
+         if rhs_shape != ret_shape:
+             ret_ty = tl.block_type(rhs_ty.scalar, ret_shape)
+             rhs = tl.tensor(builder.create_broadcast(rhs.handle, ret_shape), ret_ty)
+     # (scalar, scalar) => returns original blocks
+     return lhs, rhs
+
+
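These are NumPy-style broadcasting rules: ranks are aligned by prepending size-1 axes on the left, then size-1 dimensions stretch to match. For example:

# [4, 1] with [1, 8]  ->  both broadcast to [4, 8]
# [8]    with [4, 8]  ->  left side becomes [1, 8], then [4, 8]
# [3]    with [4]     ->  ValueError: incompatible dimensions at index 0: 3 and 4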
+ #######
+ # cast
+ #######
+
+
+ def _str_to_rounding_mode(rounding_mode: Optional[str]):
+     if rounding_mode is None:
+         return None
+     if rounding_mode == 'rtne':
+         return ir.ROUNDING_MODE.RTNE
+     if rounding_mode == 'rtz':
+         return ir.ROUNDING_MODE.RTZ
+     raise ValueError(f"Invalid rounding mode: {rounding_mode}. Supported rounding modes are 'rtne' and 'rtz'.")
+
+
+ def bitcast(input: tl.tensor, dst_ty: tl.dtype, builder: ir.builder) -> tl.tensor:
+     src_ty = input.type
+     if src_ty.is_block():
+         dst_ty = tl.block_type(dst_ty.scalar, input.type.get_block_shapes())
+     if src_ty == dst_ty:
+         return input
+     src_sca_ty = src_ty.scalar
+     dst_sca_ty = dst_ty.scalar
+     if src_sca_ty.is_ptr() or dst_sca_ty.is_ptr():
+         return cast(input, dst_ty, builder)
+     # Bitcast
+     src_bits = src_sca_ty.primitive_bitwidth
+     dst_bits = dst_sca_ty.primitive_bitwidth
+     if src_bits != dst_bits:
+         raise ValueError("Cannot bitcast data-type of size " + str(src_bits) + " to "
+                          "data-type of size " + str(dst_bits))
+     return tl.tensor(builder.create_bitcast(input.handle, dst_ty.to_ir(builder)), dst_ty)
+
+
+ def cast(input: tl.tensor, dst_ty: tl.dtype, builder: ir.builder,
+          fp_downcast_rounding: Optional[str] = None) -> tl.tensor:
+     src_ty = input.type
+     if isinstance(dst_ty, tl.constexpr):
+         dst_ty = dst_ty.value
+     if isinstance(fp_downcast_rounding, tl.constexpr):
+         fp_downcast_rounding = fp_downcast_rounding.value
+     if src_ty.is_block():
+         dst_ty = tl.block_type(dst_ty.scalar, input.type.get_block_shapes())
+     if src_ty == dst_ty:
+         return input
+
+     src_sca_ty = src_ty.scalar
+     dst_sca_ty = dst_ty.scalar
+
+     # For fp downcasting the default rounding mode should be RTNE; for all other
+     # conversions it should not be set
+     fp_downcast_rounding = _str_to_rounding_mode(fp_downcast_rounding)
+     use_custom_rounding = False
+     if dst_sca_ty.is_floating() and src_sca_ty.is_floating() and \
+        dst_sca_ty.primitive_bitwidth < src_sca_ty.primitive_bitwidth:
+         if fp_downcast_rounding is None:
+             fp_downcast_rounding = ir.ROUNDING_MODE.RTNE
+         elif fp_downcast_rounding != ir.ROUNDING_MODE.RTNE:
+             use_custom_rounding = True
+     else:
+         if fp_downcast_rounding is not None:
+             raise ValueError("fp_downcast_rounding should be set only for truncating fp conversions. "
+                              "Source scalar type is " + str(src_sca_ty) + " and destination type is " + str(dst_sca_ty))
+
+     if (src_sca_ty.is_fp8e4b15() or dst_sca_ty.is_fp8e4b15()):
+         assert builder.codegen_fns.get(
+             "convert_custom_types") is not None, "target doesn't provide conversion for this type."
+         return builder.codegen_fns["convert_custom_types"](input, dst_ty, fp_downcast_rounding, _builder=builder)
+     # Casting with customized floating types involved: fp8 <=> bf16, fp16, fp32, fp64
+     # and non-default rounding modes for downcasting
+     if (src_sca_ty.is_fp8() and dst_sca_ty.is_floating()) or \
+        (src_sca_ty.is_floating() and dst_sca_ty.is_fp8()) or \
+        use_custom_rounding:
+         return tl.tensor(builder.create_fp_to_fp(input.handle, dst_ty.to_ir(builder), fp_downcast_rounding), dst_ty)
+
+     # fp16/bf16 => (anything but fp32): route through fp32
+     if (src_sca_ty.is_fp16() and not dst_sca_ty.is_fp32()) or \
+        (src_sca_ty.is_bf16() and not dst_sca_ty.is_fp32()):
+         return cast(cast(input, tl.float32, builder), dst_sca_ty, builder)
+
+     # Standard floating types' casting: truncation
+     #   fp64 => fp32, fp16, bf16
+     #   fp32 => fp16, bf16
+     truncate_fp = src_sca_ty.is_floating() and \
+         dst_sca_ty.is_floating() and \
+         src_sca_ty.primitive_bitwidth > dst_sca_ty.primitive_bitwidth
+     if truncate_fp:
+         return tl.tensor(builder.create_fp_trunc(input.handle, dst_ty.to_ir(builder)), dst_ty)
+
+     # Standard floating types' casting: extension
+     #   fp32 => fp64
+     #   fp16 => fp32, fp64
+     #   bf16 => fp32, fp64
+     ext_fp = src_sca_ty.is_floating() and \
+         dst_sca_ty.is_floating() and \
+         src_sca_ty.primitive_bitwidth < dst_sca_ty.primitive_bitwidth
+     if ext_fp:
+         return tl.tensor(builder.create_fp_ext(input.handle, dst_ty.to_ir(builder)), dst_ty)
+
+     # Casting between integer types
+     if src_sca_ty.is_int() and dst_sca_ty.is_int() and \
+        (src_sca_ty.int_bitwidth != dst_sca_ty.int_bitwidth or src_sca_ty.int_signedness != dst_sca_ty.int_signedness):
+         sign_extend = src_sca_ty.is_int_signed() and not src_sca_ty.is_bool()
+         if dst_sca_ty.is_bool():
+             ty = input.dtype.to_ir(builder)
+             _0 = tl.tensor(builder.get_null_value(ty), input.dtype)
+             return not_equal(input, _0, builder)
+         else:
+             return tl.tensor(builder.create_int_cast(input.handle, dst_ty.to_ir(builder), sign_extend), dst_ty)
+
+     # Casting standard floating types to integer types
+     if src_sca_ty.is_standard_floating() and dst_sca_ty.is_int():
+         if dst_sca_ty.is_bool():
+             ty = input.dtype.to_ir(builder)
+             _0 = tl.tensor(builder.get_null_value(ty), input.dtype)
+             return not_equal(input, _0, builder)
+         elif dst_sca_ty.is_int_signed():
+             return tl.tensor(builder.create_fp_to_si(input.handle, dst_ty.to_ir(builder)), dst_ty)
+         else:
+             return tl.tensor(builder.create_fp_to_ui(input.handle, dst_ty.to_ir(builder)), dst_ty)
+
+     # Casting integer types to standard floating types
+     if src_sca_ty.is_int() and dst_sca_ty.is_standard_floating():
+         if src_sca_ty.is_bool() or not src_sca_ty.is_int_signed():
+             return tl.tensor(builder.create_ui_to_fp(input.handle, dst_ty.to_ir(builder)), dst_ty)
+         else:
+             return tl.tensor(builder.create_si_to_fp(input.handle, dst_ty.to_ir(builder)), dst_ty)
+
+     # Casting pointer types to integer types
+     if src_sca_ty.is_ptr() and dst_sca_ty.is_int():
+         bitwidth = dst_sca_ty.int_bitwidth
+         if bitwidth == 64:
+             return tl.tensor(builder.create_ptr_to_int(input.handle, dst_ty.to_ir(builder)), dst_ty)
+         if bitwidth == 1:
+             return not_equal(cast(input, tl.int64, builder), tl.tensor(builder.get_int64(0), tl.int64), builder)
+
+     # Casting integer types to pointer types
+     if src_sca_ty.is_int() and dst_sca_ty.is_ptr():
+         return tl.tensor(builder.create_int_to_ptr(input.handle, dst_ty.to_ir(builder)), dst_ty)
+
+     # Casting pointer types to pointer types
+     if src_sca_ty.is_ptr() and dst_sca_ty.is_ptr():
+         return tl.tensor(builder.create_bitcast(input.handle, dst_ty.to_ir(builder)), dst_ty)
+
+     assert False, f'cannot cast {input} to {dst_ty}'
+
+
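One consequence of the routing above is that there is no direct fp16 <-> bf16 conversion: both legs go through fp32. A hedged sketch of how some paths resolve at the `tl` level (dtype names from the public API; the exact fp8 variant is illustrative):

# x_f16  = tl.full((BLOCK,), 1.0, dtype=tl.float16)
# x_bf16 = x_f16.to(tl.bfloat16)     # lowered as fp16 -> fp32 -> bf16
# x_f8   = x_f16.to(tl.float8e4nv)   # custom-float path: create_fp_to_fp, RTNE by default
# y      = x_f16.to(tl.float16)      # no-op: src type equals dst type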
+ # ===----------------------------------------------------------------------===//
+ # Memory Operators
+ # ===----------------------------------------------------------------------===//
+
+
+ def _str_to_load_cache_modifier(cache_modifier):
+     cache = ir.CACHE_MODIFIER.NONE  # default
+     if cache_modifier:
+         if cache_modifier == ".ca":
+             cache = ir.CACHE_MODIFIER.CA
+         elif cache_modifier == ".cg":
+             cache = ir.CACHE_MODIFIER.CG
+         elif cache_modifier == ".cv":
+             cache = ir.CACHE_MODIFIER.CV
+         else:
+             raise ValueError(f"Cache modifier {cache_modifier} not supported")
+     return cache
+
+
+ def _str_to_store_cache_modifier(cache_modifier):
+     cache = ir.CACHE_MODIFIER.NONE  # default
+     if cache_modifier:
+         if cache_modifier == ".wb":
+             cache = ir.CACHE_MODIFIER.WB
+         elif cache_modifier == ".cg":
+             cache = ir.CACHE_MODIFIER.CG
+         elif cache_modifier == ".cs":
+             cache = ir.CACHE_MODIFIER.CS
+         elif cache_modifier == ".wt":
+             cache = ir.CACHE_MODIFIER.WT
+         else:
+             raise ValueError(f"Cache modifier {cache_modifier} not supported")
+     return cache
+
+
+ def _str_to_eviction_policy(eviction_policy):
+     eviction = ir.EVICTION_POLICY.NORMAL  # default
+     if eviction_policy:
+         if eviction_policy == "evict_last":
+             eviction = ir.EVICTION_POLICY.EVICT_LAST
+         elif eviction_policy == "evict_first":
+             eviction = ir.EVICTION_POLICY.EVICT_FIRST
+         else:
+             raise ValueError(f"Eviction policy {eviction_policy} not supported")
+     return eviction
+
+
+ def _str_to_padding_option(padding_option):
+     padding = None  # default
+     if padding_option:
+         if padding_option == "zero":
+             padding = ir.PADDING_OPTION.PAD_ZERO
+         elif padding_option == "nan":
+             padding = ir.PADDING_OPTION.PAD_NAN
+         else:
+             raise ValueError(f"Padding option {padding_option} not supported")
+     return padding
+
+
+ def _str_to_sem(sem_option):
+     sem = ir.MEM_SEMANTIC.ACQUIRE_RELEASE
+     if sem_option:
+         if sem_option == "acquire":
+             sem = ir.MEM_SEMANTIC.ACQUIRE
+         elif sem_option == "release":
+             sem = ir.MEM_SEMANTIC.RELEASE
+         elif sem_option == "acq_rel":
+             sem = ir.MEM_SEMANTIC.ACQUIRE_RELEASE
+         elif sem_option == "relaxed":
+             sem = ir.MEM_SEMANTIC.RELAXED
+         else:
+             raise ValueError(f"Memory semantic {sem_option} not supported")
+     return sem
+
+
+ def _str_to_scope(scope_option):
+     scope = ir.MEM_SYNC_SCOPE.GPU
+     if scope_option:
+         if scope_option == "gpu":
+             scope = ir.MEM_SYNC_SCOPE.GPU
+         elif scope_option == "cta":
+             scope = ir.MEM_SYNC_SCOPE.CTA
+         elif scope_option == "sys":
+             scope = ir.MEM_SYNC_SCOPE.SYSTEM
+         else:
+             raise ValueError(f"Memory sync scope {scope_option} not supported")
+     return scope
+
+
+ def _canonicalize_boundary_check(boundary_check, block_shape):
+     if boundary_check:
+         if not hasattr(boundary_check, "__iter__"):
+             boundary_check = [boundary_check]
+         boundary_check = [elem.value if isinstance(elem, tl.constexpr) else elem for elem in boundary_check]
+         for dim in boundary_check:
+             assert isinstance(dim, int) and 0 <= dim < len(block_shape)
+         assert len(boundary_check) > 0
+         assert len(boundary_check) == len(set(boundary_check)), "Duplicate dimension in `boundary_check`"
+         return sorted(boundary_check)
+     return ()
+
+
+ def _load_block_pointer(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, builder):
+     # Load by a block pointer: `pointer_type<block_type<>>`
+     # Block pointers cannot have `mask` and `other` arguments
+     if mask is not None or other is not None:
+         raise ValueError("`mask` and `other` arguments cannot be specified for loading block pointers")
+
+     elt_ty = ptr.type.element_ty.element_ty
+     assert elt_ty != tl.int1, "`tl.int1` should be rewritten in `tl.make_block_ptr`"
+     if elt_ty.is_int() and padding == ir.PADDING_OPTION.PAD_NAN:
+         raise ValueError("Padding option `nan` is not supported for integer block pointers")
+
+     # `dst_ty` is the de-referenced type of the pointer type
+     dst_ty = ptr.type.element_ty
+
+     # Check the `boundary_check` argument
+     boundary_check = _canonicalize_boundary_check(boundary_check, dst_ty.get_block_shapes())
+
+     # Build IR
+     return tl.tensor(
+         builder.create_tensor_pointer_load(ptr.handle, boundary_check, padding, cache, eviction, is_volatile), dst_ty)
+
+
1066
+ def _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, builder):
1067
+ # Load by a tensor of pointers or a pointer of scalar: `block_type<pointer_type<>>` or `pointer_type<>`
1068
+ if not ptr.type.scalar.is_ptr():
1069
+ raise ValueError(f"Unsupported ptr type {ptr.type.__repr__()} in `tl.load`")
1070
+
1071
+ # Check `mask`, `other`, `boundary_check`, and `padding` arguments
1072
+ if mask is None and other is not None:
1073
+ raise ValueError("`other` cannot be provided without `mask`")
1074
+ if padding or boundary_check:
1075
+ raise ValueError("`padding_option` or `boundary_check` argument is not supported for loading a tensor of"
1076
+ "pointers or loading a scalar. Because the compiler does not know the boundary; please "
1077
+ "use block pointers (defined by `make_block_ptr`) instead")
1078
+
1079
+ # For a pointer of scalar, check the type of `mask` and `other`
1080
+ if not ptr.type.is_block():
1081
+ if mask and mask.type.is_block():
1082
+ raise ValueError("Mask argument cannot be block type if pointer argument is not a block")
1083
+ if other and other.type.is_block():
1084
+ raise ValueError("Other argument cannot be block type if pointer argument is not a block")
1085
+
1086
+ # Make `mask` and `other` into the same shape as `ptr`
1087
+ if ptr.type.is_block():
1088
+ if mask is not None:
1089
+ mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder)
1090
+ if other is not None:
1091
+ other = broadcast_impl_shape(other, ptr.type.get_block_shapes(), builder)
1092
+
1093
+ # Get `pointer_type<elt_ty>` and `elt_ty`
1094
+ ptr_ty = ptr.type.scalar
1095
+ elt_ty = ptr_ty.element_ty
1096
+
1097
+ # Treat `pointer_type<tl.int1>` as `pointer_type<tl.int8>`
1098
+ is_bool = elt_ty == tl.int1
1099
+ if is_bool:
1100
+ elt_ty = tl.int8
1101
+ ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space)
1102
+ ptr = cast(ptr, ptr_ty, builder)
1103
+
1104
+ # Cast `other` into `elt_ty` type
1105
+ if other is not None:
1106
+ other = cast(other, elt_ty, builder)
1107
+
1108
+ # Create loaded result type `dst_ty`
1109
+ if ptr.type.is_block():
1110
+ shape = ptr.type.get_block_shapes()
1111
+ dst_ty = tl.block_type(elt_ty, shape)
1112
+ else:
1113
+ # Load by de-referencing the pointer of scalar
1114
+ dst_ty = elt_ty
1115
+
1116
+ # Build IR
1117
+ if mask is None:
1118
+ ret = tl.tensor(builder.create_load(ptr.handle, cache, eviction, is_volatile), dst_ty)
1119
+ else:
1120
+ ret = tl.tensor(
1121
+ builder.create_masked_load(ptr.handle, mask.handle, other.handle if other else None, cache, eviction,
1122
+ is_volatile), dst_ty)
1123
+ if is_bool:
1124
+ ret = cast(ret, tl.int1, builder)
1125
+ return ret
1126
+
+
+ def load(ptr: tl.tensor, mask: Optional[tl.tensor], other: Optional[tl.tensor], boundary_check: Tuple,
+          padding_option: str, cache_modifier: str, eviction_policy: str, is_volatile: bool,
+          builder: ir.builder) -> tl.tensor:
+     # Cache, eviction and padding options
+     cache = _str_to_load_cache_modifier(cache_modifier)
+     eviction = _str_to_eviction_policy(eviction_policy)
+     padding = _str_to_padding_option(padding_option)
+
+     if ptr.type.is_ptr() and ptr.type.element_ty.is_block():
+         # Load by a block pointer: `pointer_type<block_type<>>`
+         return _load_block_pointer(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, builder)
+     else:
+         # Load by a tensor of pointers or a pointer of scalar: `block_type<pointer_type<>>` or `pointer_type<>`
+         return _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, builder)
+
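+ # At the `triton.language` level the two dispatch paths above correspond to the
+ # two ways of loading in a kernel; a minimal sketch (kernel and argument names
+ # are illustrative, the APIs are the public `tl.*` ones):
+ #
+ #     @triton.jit
+ #     def copy(src, dst, N, BLOCK: tl.constexpr):
+ #         offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
+ #         x = tl.load(src + offs, mask=offs < N, other=0)  # legacy path
+ #         p = tl.make_block_ptr(src, (N,), (1,), (0,), (BLOCK,), (0,))
+ #         y = tl.load(p, boundary_check=(0,), padding_option="zero")  # block-pointer path
+ #         tl.store(dst + offs, x + y, mask=offs < N)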
+
+ def descriptor_load(desc_ptr: tl.tensor, offsets, cache_modifier: str, eviction_policy: str, type,
+                     builder: ir.builder) -> tl.tensor:
+     offsets = _convert_to_ir_values(builder, offsets, require_i64=False)
+     x = builder.create_descriptor_load(desc_ptr.handle, offsets, type.to_ir(builder),
+                                        _str_to_load_cache_modifier(cache_modifier),
+                                        _str_to_eviction_policy(eviction_policy))
+     return tl.tensor(x, type)
+
+
+ def descriptor_store(desc_ptr: tl.tensor, value: tl.tensor, offsets, builder: ir.builder) -> tl.tensor:
+     offsets = _convert_to_ir_values(builder, offsets, require_i64=False)
+     return tl.tensor(builder.create_descriptor_store(desc_ptr.handle, value.handle, offsets), tl.void)
+
+
+ def tensormap_create(
+     desc_ptr: tl.tensor,
+     global_address: tl.tensor,
+     box_dim: List[tl.tensor],
+     global_dim: List[tl.tensor],
+     global_stride: List[tl.tensor],
+     element_stride: List[tl.tensor],
+     elem_type: int,
+     interleave_layout: int,
+     swizzle_mode: int,
+     fill_mode: int,
+     builder: ir.builder,
+ ) -> tl.tensor:
+     assert not global_stride or global_stride[0].dtype == tl.int64
+     return tl.tensor(
+         builder.create_tensormap_create(
+             desc_ptr.handle,
+             global_address.handle,
+             [x.handle for x in box_dim],
+             [x.handle for x in global_dim],
+             [x.handle for x in global_stride],
+             [x.handle for x in element_stride],
+             elem_type,
+             interleave_layout,
+             swizzle_mode,
+             fill_mode,
+         ),
+         tl.void,
+     )
+
+
+ def tensormap_fenceproxy_acquire(desc_ptr: tl.tensor, builder: ir.builder) -> tl.tensor:
+     return tl.tensor(builder.create_tensormap_fenceproxy_acquire(desc_ptr.handle), tl.void)
+
+
+ def _store_block_pointer(ptr, val, mask, boundary_check, cache, eviction, builder):
+     # Store by a block pointer: `pointer_type<block_type<>>`
+     # Block pointers cannot take the `mask` argument
+     if mask is not None:
+         raise ValueError("`mask` argument cannot be specified for storing block pointers")
+
+     # Check same shape and element type
+     block_shape = ptr.type.element_ty.get_block_shapes()
+     if not val.type.is_block():
+         val = broadcast_impl_shape(val, block_shape, builder)
+     assert val.type.is_block(), "Value argument must be block type or a scalar"
+     assert block_shape == val.type.get_block_shapes(), \
+         f"Block shape ({block_shape}) and value shape ({val.type.get_block_shapes()}) mismatch"
+     assert ptr.type.element_ty.element_ty == val.type.element_ty, \
+         f"Block element type ({ptr.type.element_ty.element_ty}) and value element type ({val.type.element_ty}) mismatch"
+
+     elt_ty = ptr.type.element_ty.element_ty
+     assert elt_ty != tl.int1, "`tl.int1` should be rewritten in `tl.make_block_ptr`"
+
+     # Check the `boundary_check` argument
+     boundary_check = _canonicalize_boundary_check(boundary_check, block_shape)
+
+     # Cast to target data type
+     val = cast(val, elt_ty, builder)
+
+     # Build IR
+     return tl.tensor(builder.create_tensor_pointer_store(ptr.handle, val.handle, boundary_check, cache, eviction),
+                      tl.void)
+
+
+ def _store_legacy(ptr, val, mask, boundary_check, cache, eviction, builder):
+     # Store by a tensor of pointers or a pointer of scalar: `block_type<pointer_type<>>` or `pointer_type<>`
+     if not ptr.type.scalar.is_ptr():
+         raise ValueError(f"Unsupported ptr type {ptr.type.__repr__()} in `tl.store`")
+
+     # Check the `boundary_check` argument
+     if boundary_check:
+         raise ValueError("`boundary_check` argument is not supported for storing a tensor of pointers or storing a "
+                          "scalar, because the compiler does not know the boundary; please use block pointers "
+                          "(defined by `make_block_ptr`) instead")
+
+     # For a pointer of scalar, check the type of `val` and `mask`
+     if not ptr.type.is_block():
+         if val.type.is_block():
+             raise ValueError("Value argument cannot be block type if pointer argument is not a block")
+         if mask and mask.type.is_block():
+             raise ValueError("Mask argument cannot be block type if pointer argument is not a block")
+
+     # Make `mask` and `val` into the same shape as `ptr`
+     if ptr.type.is_block():
+         val = broadcast_impl_shape(val, ptr.type.get_block_shapes(), builder)
+         if mask is not None:
+             mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder)
+
+     ptr_ty = ptr.type.scalar
+     elt_ty = ptr_ty.element_ty
+
+     # Treat `pointer_type<tl.int1>` as `pointer_type<tl.int8>`
+     if elt_ty == tl.int1:
+         elt_ty = tl.int8
+         ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space)
+         ptr = cast(ptr, ptr_ty, builder)
+
+     # Cast to target data type
+     val = cast(val, elt_ty, builder)
+
+     # Build IR
+     if not mask:
+         return tl.tensor(builder.create_store(ptr.handle, val.handle, cache, eviction), tl.void)
+     if not mask.type.scalar.is_bool():
+         raise ValueError("Mask must have boolean scalar type")
+     return tl.tensor(builder.create_masked_store(ptr.handle, val.handle, mask.handle, cache, eviction), tl.void)
+
+
+ def store(ptr: tl.tensor, val: tl.tensor, mask: Optional[tl.tensor], boundary_check, cache_modifier: str,
+           eviction_policy: str, builder: ir.builder) -> tl.tensor:
+     # Cache and eviction options
+     cache = _str_to_store_cache_modifier(cache_modifier)
+     eviction = _str_to_eviction_policy(eviction_policy)
+
+     if ptr.type.is_const() or ptr.type.scalar.is_const():
+         raise ValueError("Cannot store to a constant pointer")
+
+     if ptr.type.is_ptr() and ptr.type.element_ty.is_block():
+         # Store by a block pointer: `pointer_type<block_type<>>`
+         return _store_block_pointer(ptr, val, mask, boundary_check, cache, eviction, builder)
+     else:
+         # Store by a tensor of pointers or a pointer of scalar: `block_type<pointer_type<>>` or `pointer_type<>`
+         return _store_legacy(ptr, val, mask, boundary_check, cache, eviction, builder)
+
+
+ #########
+ # atomic
+ #########
+
+
+ def atomic_cas(ptr: tl.tensor, cmp: tl.tensor, val: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor:
+     sem = _str_to_sem(sem)
+     scope = _str_to_scope(scope)
+     element_ty = ptr.type.scalar.element_ty
+     if element_ty.primitive_bitwidth not in [16, 32, 64]:
+         raise ValueError("atomic_cas only supports elements with width {16, 32, 64}")
+     return tl.tensor(builder.create_atomic_cas(ptr.handle, cmp.handle, val.handle, sem, scope), val.type)
+
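+ # A common kernel-level use of the compare-and-swap built above is a spin lock
+ # (a sketch; `lock` is an illustrative int32 pointer argument):
+ #
+ #     while tl.atomic_cas(lock, 0, 1, sem="acquire", scope="gpu") == 1:
+ #         pass
+ #     ...  # critical section
+ #     tl.atomic_xchg(lock, 0, sem="release")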
+
+ def atom_red_typechecking_impl(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, op: str,
+                                builder: ir.builder) -> Tuple[tl.tensor, tl.tensor, tl.tensor]:
+     if not ptr.type.scalar.is_ptr():
+         raise ValueError("Pointer argument of atomic_" + op + " instruction is " + ptr.type.__repr__())
+     if ptr.type.is_const() or ptr.type.element_ty.is_const():
+         raise ValueError("Cannot store to a constant pointer")
+     element_ty = ptr.type.scalar.element_ty
+     if element_ty is tl.float16 and op != 'add':
+         raise ValueError("atomic_" + op + " does not support fp16")
+     if element_ty in [tl.int1, tl.int8, tl.int16, tl.bfloat16]:
+         raise ValueError("atomic_" + op + " does not support " + str(element_ty))
+     if ptr.type.is_block():
+         if mask is not None:
+             mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder)
+         if val is not None:
+             val = broadcast_impl_shape(val, ptr.type.get_block_shapes(), builder)
+     val = cast(val, ptr.type.scalar.element_ty, builder)
+     if not mask:
+         mask_ir = builder.get_int1(True)
+         mask_ty = tl.int1
+         if ptr.type.is_block():
+             mask_ir = builder.create_splat(mask_ir, ptr.type.get_block_shapes())
+             mask_ty = tl.block_type(tl.int1, ptr.type.get_block_shapes())
+         mask = tl.tensor(mask_ir, mask_ty)
+     return ptr, val, mask
+
+
+ def atomic_max(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor:
+     ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'max', builder)
+     sem = _str_to_sem(sem)
+     scope = _str_to_scope(scope)
+     sca_ty = val.type.scalar
+     # direct call to atomic_max for integers
+     if sca_ty.is_int():
+         if sca_ty.is_int_signed():
+             return tl.tensor(
+                 builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, ptr.handle, val.handle, mask.handle, sem, scope), val.type)
+         else:
+             return tl.tensor(
+                 builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX, ptr.handle, val.handle, mask.handle, sem, scope), val.type)
+     # For floats, reinterpret the bits as integers:
+     #   atomic_smax(i_ptr, i_val) if val >= 0
+     #   atomic_umin(ui_ptr, ui_val) if val < 0
+     if sca_ty not in {tl.float32, tl.float64}:
+         raise TypeError(f"atomic_max not supported for dtype {sca_ty}")
+
+     zero = full([], 0.0, sca_ty, builder)
+
+     i_type = tl.int32 if sca_ty == tl.float32 else tl.int64
+     i_val = bitcast(val, i_type, builder)
+     i_ptr = bitcast(ptr, tl.pointer_type(i_type, 1), builder)
+     ui_type = tl.uint32 if sca_ty == tl.float32 else tl.uint64
+     ui_val = bitcast(val, ui_type, builder)
+     ui_ptr = bitcast(ptr, tl.pointer_type(ui_type, 1), builder)
+     pos = greater_equal(val, zero, builder)
+     neg = less_than(val, zero, builder)
+     pos_ret = tl.tensor(
+         builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, i_ptr.handle, i_val.handle,
+                                   and_(mask, pos, builder).handle, sem, scope), i_val.type)
+     neg_ret = tl.tensor(
+         builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, ui_ptr.handle, ui_val.handle,
+                                   and_(mask, neg, builder).handle, sem, scope), ui_val.type)
+     ret = where(pos, pos_ret, neg_ret, builder)
+     return bitcast(ret, sca_ty, builder)
+
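+ # Why the bit trick above is valid: for IEEE-754 floats of the same sign, the
+ # ordering of the raw bit patterns matches the ordering of the values. For
+ # non-negative floats the pattern grows with the value, so a signed integer max
+ # works; for negative floats the pattern ordering is reversed (a larger float
+ # has a smaller unsigned pattern), so an unsigned min selects the float max.
+ # E.g. for float32:
+ #
+ #      1.0 -> 0x3f800000,  2.0 -> 0x40000000   (smax picks 2.0)
+ #     -1.0 -> 0xbf800000, -2.0 -> 0xc0000000   (umin picks -1.0)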
+
+ def atomic_min(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor:
+     ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'min', builder)
+     sem = _str_to_sem(sem)
+     scope = _str_to_scope(scope)
+     sca_ty = val.type.scalar
+     # direct call to atomic_min for integers
+     if sca_ty.is_int():
+         if sca_ty.is_int_signed():
+             return tl.tensor(
+                 builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, ptr.handle, val.handle, mask.handle, sem, scope), val.type)
+         else:
+             return tl.tensor(
+                 builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, ptr.handle, val.handle, mask.handle, sem, scope), val.type)
+     # For floats, reinterpret the bits as integers (mirror of `atomic_max`):
+     #   atomic_smin(i_ptr, i_val) if val >= 0
+     #   atomic_umax(ui_ptr, ui_val) if val < 0
+     if sca_ty not in {tl.float32, tl.float64}:
+         raise TypeError(f"atomic_min not supported for dtype {sca_ty}")
+
+     zero = full([], 0.0, sca_ty, builder)
+
+     i_type = tl.int32 if sca_ty == tl.float32 else tl.int64
+     i_val = bitcast(val, i_type, builder)
+     i_ptr = bitcast(ptr, tl.pointer_type(i_type, 1), builder)
+     ui_type = tl.uint32 if sca_ty == tl.float32 else tl.uint64
+     ui_val = bitcast(val, ui_type, builder)
+     ui_ptr = bitcast(ptr, tl.pointer_type(ui_type, 1), builder)
+     pos = greater_equal(val, zero, builder)
+     neg = less_than(val, zero, builder)
+     pos_ret = tl.tensor(
+         builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, i_ptr.handle, i_val.handle,
+                                   and_(mask, pos, builder).handle, sem, scope), i_val.type)
+     neg_ret = tl.tensor(
+         builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX, ui_ptr.handle, ui_val.handle,
+                                   and_(mask, neg, builder).handle, sem, scope), ui_val.type)
+     ret = where(pos, pos_ret, neg_ret, builder)
+     return bitcast(ret, sca_ty, builder)
+
+
+ def atomic_add(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor:
+     ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'add', builder)
+     sem = _str_to_sem(sem)
+     scope = _str_to_scope(scope)
+     sca_ty = val.type.scalar
+     op = ir.ATOMIC_OP.FADD if sca_ty.is_floating() else ir.ATOMIC_OP.ADD
+     return tl.tensor(builder.create_atomic_rmw(op, ptr.handle, val.handle, mask.handle, sem, scope), val.type)
+
+
+ def atomic_and(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor:
+     ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'and', builder)
+     sem = _str_to_sem(sem)
+     scope = _str_to_scope(scope)
+     return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.AND, ptr.handle, val.handle, mask.handle, sem, scope),
+                      val.type)
+
+
+ def atomic_or(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor:
+     ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'or', builder)
+     sem = _str_to_sem(sem)
+     scope = _str_to_scope(scope)
+     return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.OR, ptr.handle, val.handle, mask.handle, sem, scope),
+                      val.type)
+
+
+ def atomic_xor(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor:
+     ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'xor', builder)
+     sem = _str_to_sem(sem)
+     scope = _str_to_scope(scope)
+     return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XOR, ptr.handle, val.handle, mask.handle, sem, scope),
+                      val.type)
+
+
+ def atomic_xchg(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str,
+                 builder: ir.builder) -> tl.tensor:
+     ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'xchg', builder)
+     sem = _str_to_sem(sem)
+     scope = _str_to_scope(scope)
+     return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XCHG, ptr.handle, val.handle, mask.handle, sem, scope),
+                      val.type)
+
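+ # Kernel-level sketch of the RMW family above (a bincount via `tl.atomic_add`;
+ # argument names are illustrative):
+ #
+ #     @triton.jit
+ #     def bincount(x_ptr, hist_ptr, N, BLOCK: tl.constexpr):
+ #         offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
+ #         mask = offs < N
+ #         x = tl.load(x_ptr + offs, mask=mask, other=0)
+ #         tl.atomic_add(hist_ptr + x, 1, mask=mask)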
1443
+
1444
+ # ===----------------------------------------------------------------------===//
1445
+ # Linear Algebra
1446
+ # ===----------------------------------------------------------------------===//
1447
+
1448
+
1449
+ def _str_to_dot_input_precision(input_precision, builder):
1450
+ assert input_precision.lower() in builder.options.allowed_dot_input_precisions, \
1451
+ f"input_precision must be one of {builder.options.allowed_dot_input_precisions}. Got {input_precision}"
1452
+ input_precision = input_precision.upper()
1453
+ if input_precision == "TF32X3":
1454
+ input_precision = "TF32x3"
1455
+ return getattr(ir.INPUT_PRECISION, input_precision)
1456
+
+
+ def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, input_precision: Optional[str], max_num_imprecise_acc: int,
+         out_dtype: tl.dtype, builder: ir.builder) -> tl.tensor:
+     assert lhs.type.is_block() and rhs.type.is_block()
+
+     if lhs.dtype.is_fp8() and rhs.dtype.is_fp8():
+         # All combinations of supported fp8 x fp8 are permitted
+         pass
+     else:
+         assert lhs.dtype in (tl.int8, tl.uint8, tl.float16, tl.bfloat16,
+                              tl.float32), f"Unsupported lhs dtype {lhs.dtype}"
+         assert rhs.dtype in (tl.int8, tl.uint8, tl.float16, tl.bfloat16,
+                              tl.float32), f"Unsupported rhs dtype {rhs.dtype}"
+         assert lhs.dtype == rhs.dtype, f"Both operands must be same dtype. Got {lhs.dtype} and {rhs.dtype}"
+
+     if lhs.dtype.is_fp8e4b15() or rhs.dtype.is_fp8e4b15():
+         lhs = cast(lhs, tl.float16, builder)
+         rhs = cast(rhs, tl.float16, builder)
+
+     if input_precision is None:
+         input_precision = builder.options.default_dot_input_precision
+
+     input_precision = _str_to_dot_input_precision(input_precision, builder)
+
+     lhs_rank = len(lhs.shape)
+     rhs_rank = len(rhs.shape)
+     assert lhs_rank == rhs_rank == 2 or lhs_rank == rhs_rank == 3, \
+         f"Both inputs must be either 2D or 3D; (lhs: {lhs.shape} vs rhs: {rhs.shape})"
+     assert lhs.shape[-1].value == rhs.shape[-2].value, \
+         f"First input shape ({lhs.shape}) and second input shape ({rhs.shape}) are not compatible for matmul " \
+         f"(the last dimension of the first shape ({lhs.shape[-1].value}) must equal the second-to-last " \
+         f"dimension of the second shape ({rhs.shape[-2].value}))"
+     assert builder.codegen_fns.get("min_dot_size") is not None, "target doesn't provide lower shape bounds for dot."
+     min_dot_size = builder.codegen_fns["min_dot_size"](lhs.type, rhs.type)
+     assert lhs.shape[-2].value >= min_dot_size[0] and lhs.shape[-1].value >= min_dot_size[2] \
+         and rhs.shape[-1].value >= min_dot_size[1], \
+         f"Input shapes should have M >= {min_dot_size[0]}, N >= {min_dot_size[1]} and K >= {min_dot_size[2]}"
+     if lhs.type.scalar.is_int():
+         assert lhs.type.scalar == tl.int8, "only int8 supported!"
+         _0 = builder.get_int32(0)
+         ret_scalar_ty = tl.int32
+     elif out_dtype.is_bf16():
+         raise ValueError(
+             "out_dtype=bfloat16 is unsupported. Please use out_dtype=float32/float16 and cast with `.to(tl.bfloat16)`")
+     elif lhs.type.scalar.is_fp32() or lhs.type.scalar.is_bf16():
+         _0 = builder.get_fp32(0)
+         ret_scalar_ty = tl.float32
+     else:
+         _0 = builder.get_fp16(0) if out_dtype.is_fp16() else builder.get_fp32(0)
+         ret_scalar_ty = out_dtype
+
+     M = lhs.type.shape[-2]
+     N = rhs.type.shape[-1]
+     K = lhs.type.shape[-1]
+     B = lhs.type.shape[0] if lhs_rank == 3 else None
+     ret_ty = tl.block_type(ret_scalar_ty, [B, M, N] if B else [M, N])
+     if acc is None:
+         acc_handle = builder.create_splat(_0, [B, M, N] if B else [M, N])
+     else:
+         acc_handle = acc.handle
+         assert acc.type == ret_ty
+
+     # max_num_imprecise_acc only applies to fp8 -> fp32 dot on sm_90
+     if max_num_imprecise_acc is None:
+         if lhs.dtype.is_fp8() and rhs.dtype.is_fp8():
+             max_num_imprecise_acc = builder.options.max_num_imprecise_acc_default
+         else:
+             max_num_imprecise_acc = 0
+     else:
+         if lhs.dtype.is_fp8() and rhs.dtype.is_fp8() and max_num_imprecise_acc > K:
+             raise ValueError(f"max_num_imprecise_acc ({max_num_imprecise_acc}) must be <= K ({K})")
+
+     return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, acc_handle, input_precision, max_num_imprecise_acc),
+                      ret_ty)
+
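+ # Kernel-level sketch of `tl.dot`, which lowers through the checks above
+ # (a single-tile matmul; names and tile sizes are illustrative):
+ #
+ #     @triton.jit
+ #     def matmul_tile(a_ptr, b_ptr, c_ptr, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr):
+ #         rm, rn, rk = tl.arange(0, M), tl.arange(0, N), tl.arange(0, K)
+ #         a = tl.load(a_ptr + rm[:, None] * K + rk[None, :])
+ #         b = tl.load(b_ptr + rk[:, None] * N + rn[None, :])
+ #         acc = tl.dot(a, b)  # fp32 accumulator by default
+ #         tl.store(c_ptr + rm[:, None] * N + rn[None, :], acc)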
+
+ def _str_to_fp_type(float_format: Optional[str]):
+     if float_format == 'e4m3':
+         return ir.F8F6F4TY.E4M3
+     if float_format == 'e5m2':
+         return ir.F8F6F4TY.E5M2
+     if float_format == 'e2m3':
+         return ir.F8F6F4TY.E2M3
+     if float_format == 'e3m2':
+         return ir.F8F6F4TY.E3M2
+     if float_format == 'e2m1':
+         return ir.F8F6F4TY.E2M1
+     raise ValueError(f"Invalid float format: {float_format}.")
+
+
+ def dot_scaled(lhs: tl.tensor, lhs_scale: tl.tensor, lhs_format, rhs: tl.tensor, rhs_scale: Optional[tl.tensor],
+                rhs_format, acc: tl.tensor | None, out_dtype: tl.dtype, builder: ir.builder) -> tl.tensor:
+     assert lhs.type.is_block() and rhs.type.is_block()
+     # TODO: validate types.
+     lhs_rank = len(lhs.shape)
+     rhs_rank = len(rhs.shape)
+     assert lhs_rank == rhs_rank == 2 or lhs_rank == rhs_rank == 3, \
+         f"Both inputs must be either 2D or 3D; (lhs: {lhs.shape} vs rhs: {rhs.shape})"
+     lhs_format_enum = _str_to_fp_type(lhs_format)
+     rhs_format_enum = _str_to_fp_type(rhs_format)
+     assert lhs_format in ("e2m1", "e4m3", "e5m2"), f"NYI: lhs_format {lhs_format}"
+     assert rhs_format in ("e4m3", "e5m2"), f"NYI: rhs_format {rhs_format}"
+     rhs_scale_is_none = isinstance(rhs_scale, tl.constexpr) and rhs_scale.value is None
+     assert rhs_scale_is_none, "NYI: rhs_scale not supported"
+
+     M = lhs.type.shape[-2]
+     K, N = rhs.type.shape[-2:]
+     PACKED = 2 if lhs_format == "e2m1" else 1
+     assert K == PACKED * lhs.type.shape[-1], \
+         f"Reduction dimension should pack the same number of elements; (lhs: {lhs.shape} vs rhs: {rhs.shape})"
+     assert K >= 64, f"scaled_dot NYI for K < 64. Got {K=}"
+     B = lhs.type.shape[0] if lhs_rank == 3 else None
+
+     ret_ty = tl.block_type(out_dtype, [B, M, N] if B else [M, N])
+     _0 = builder.get_fp32(0)
+     if acc is None:
+         acc_handle = builder.create_splat(_0, [B, M, N] if B else [M, N])
+     else:
+         acc_handle = acc.handle
+         assert acc.type == ret_ty
+     rhs_scale_handle = None if rhs_scale_is_none else rhs_scale.handle
+     return tl.tensor(
+         builder.create_dot_scaled(lhs.handle, lhs_scale.handle, lhs_format_enum, rhs.handle, rhs_scale_handle,
+                                   rhs_format_enum, acc_handle), ret_ty)
+
+
+ # ===----------------------------------------------------------------------===//
+ # Indexing
+ # ===----------------------------------------------------------------------===//
+
+
+ def where(condition: tl.tensor, x: tl.tensor, y: tl.tensor, builder: ir.builder) -> tl.tensor:
+     if condition.dtype != tl.int1:
+         warnings.warn(
+             f"tl.where with a non-boolean condition is deprecated and will error out in a future triton release. Got {condition.dtype}"
+         )
+         condition = cast(condition, tl.int1, builder)
+     x, y = binary_op_type_checking_impl(x, y, builder, True, True)
+     # `x` and `y` are broadcasted
+     if condition.type.is_block():
+         condition, x = broadcast_impl_value(condition, x, builder)
+         x, y = broadcast_impl_value(x, y, builder)
+     else:
+         condition, _ = broadcast_impl_value(condition, x, builder)
+     ret_ty = x.type
+     return tl.tensor(builder.create_select(condition.handle, x.handle, y.handle), ret_ty)
+
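+ # For reference, this backs the public `tl.where`; e.g. a sketch of clamping
+ # negatives to zero (ReLU) on a block `x`:
+ #
+ #     y = tl.where(x > 0, x, 0.0)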
+
+ # ===----------------------------------------------------------------------===//
+ # Reduction
+ # ===----------------------------------------------------------------------===//
+
+
+ def wrap_tensor(x, scalar_ty, ret_shape):
+     if ret_shape:
+         res_ty = tl.block_type(scalar_ty, ret_shape)
+     else:
+         # 0d-tensor -> scalar
+         res_ty = scalar_ty
+     return tl.tensor(x, res_ty)
+
+
+ def reduction(inputs: Sequence[tl.tensor], axis: int, region_builder_fn, builder: ir.builder) -> Tuple[tl.tensor, ...]:
+     if axis is None:
+         inputs = tuple(reshape(t, [t.numel.value], can_reorder=True, builder=builder) for t in inputs)
+         axis = 0
+     # get result shape
+     shape = inputs[0].type.shape
+     rank = len(shape)
+     assert axis < rank, f"reduction axis must be < inputs rank ({rank})"
+     ret_shape = [s for i, s in enumerate(shape) if i != axis]
+     assert all(t.type.shape == shape for t in inputs), "all reduction inputs must have the same shape"
+
+     reduce_op = builder.create_reduce([t.handle for t in inputs], axis)
+     region_builder_fn(reduce_op)
+     reduce_op.verify()
+
+     return tuple(wrap_tensor(reduce_op.get_result(i), inputs[i].type.scalar, ret_shape) for i in range(len(inputs)))
+
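+ # The combinator is supplied by the caller through `region_builder_fn`; public
+ # reductions such as `tl.sum`, `tl.max` and `tl.min` are built on top of this
+ # primitive. A kernel-level sketch (axis semantics as above):
+ #
+ #     x = tl.load(ptr + tl.arange(0, 128))
+ #     total = tl.sum(x, axis=0)  # 1D block -> scalar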
+
+ # ===----------------------------------------------------------------------===//
+ # Associative Scan
+ # ===----------------------------------------------------------------------===//
+
+
+ def associative_scan(inputs: Sequence[tl.tensor], axis: int, region_builder_fn, reverse: bool,
+                      builder: ir.builder) -> Tuple[tl.tensor, ...]:
+     shape = inputs[0].type.shape
+     rank = len(shape)
+
+     assert -rank <= axis < rank, f"scan axis {axis} must be < inputs rank ({rank})"
+
+     if axis < 0:
+         axis += rank
+
+     for t in inputs:
+         assert t.type.shape == shape, "all scan inputs must have the same shape"
+
+     scan_op = builder.create_scan([t.handle for t in inputs], axis, reverse)
+     region_builder_fn(scan_op)
+     scan_op.verify()
+
+     return tuple(wrap_tensor(scan_op.get_result(i), inputs[i].type.scalar, shape) for i in range(len(inputs)))
+
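+ # Unlike `reduction`, a scan keeps the input shape: each element becomes the
+ # combination of all elements up to it along `axis`. The public entry points
+ # are `tl.cumsum`, `tl.cumprod` and `tl.associative_scan`; a sketch:
+ #
+ #     x = tl.load(ptr + tl.arange(0, 128))
+ #     prefix = tl.cumsum(x, axis=0)  # inclusive prefix sums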
+
+ # ===----------------------------------------------------------------------===//
+ # Histogram
+ # ===----------------------------------------------------------------------===//
+
+
+ def histogram(input: tl.tensor, num_bins: int, builder: ir.builder) -> tl.tensor:
+     assert len(input.shape) == 1, "histogram only supports 1D input"
+     assert input.dtype.is_int(), "histogram only supports integer input"
+     return tl.tensor(builder.create_histogram(input.handle, num_bins), tl.block_type(tl.int32, (num_bins, )))
+
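+ # Sketch of the corresponding kernel-level call (per the `tl.histogram` docs,
+ # bins have width 1 and start at 0):
+ #
+ #     x = tl.load(ptr + tl.arange(0, 1024))
+ #     h = tl.histogram(x, 16)  # int32 block of shape (16,)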
+
+ # ===----------------------------------------------------------------------===//
+ # Compiler Hints
+ # ===----------------------------------------------------------------------===//
+
+
+ def multiple_of(x: tl.tensor, values: List[int]) -> tl.tensor:
+     if max(1, len(x.shape)) != len(values):
+         raise ValueError("Shape of input to multiple_of does not match the length of values")
+     x.handle.set_attr("tt.divisibility", ir.make_attr(values, x.handle.get_context()))
+     return x
+
+
+ def max_contiguous(x: tl.tensor, values: List[int]) -> tl.tensor:
+     if len(x.shape) != len(values):
+         raise ValueError("Shape of input to max_contiguous does not match the length of values")
+     x.handle.set_attr("tt.contiguity", ir.make_attr(values, x.handle.get_context()))
+     return x
+
+
+ def max_constancy(x: tl.tensor, values: List[int]) -> tl.tensor:
+     if len(x.shape) != len(values):
+         raise ValueError("Shape of input to max_constancy does not match the length of values")
+     x.handle.set_attr("tt.constancy", ir.make_attr(values, x.handle.get_context()))
+     return x
+
+
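+ # These helpers attach optimizer hints as MLIR attributes rather than changing
+ # any values; a kernel-level sketch (the common idiom promising the offsets
+ # form 16-element contiguous runs starting at multiples of 16, which enables
+ # vectorized loads):
+ #
+ #     offs = tl.program_id(0) * 64 + tl.arange(0, 64)
+ #     offs = tl.max_contiguous(tl.multiple_of(offs, 16), 16)
+ #     x = tl.load(ptr + offs)
+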
+ def debug_barrier(builder: ir.builder) -> tl.tensor:
+     return tl.tensor(builder.create_barrier(), tl.void)
+
+
+ def device_print(prefix: str, args: List[tl.tensor], hex: bool, builder: ir.builder) -> tl.tensor:
+     # It makes sense visually for the prefix to end in ": "; make it so. Also,
+     # non-empty prefixes should start with " ".
+     if not prefix.endswith(" ") and args:
+         prefix += " "
+     if not prefix.endswith(": ") and args:
+         prefix = prefix[:-1] + ": "
+     if len(prefix) > 2 and not prefix.startswith(" "):
+         prefix = " " + prefix
+
+     new_args = [arg.handle for arg in args]
+     is_signed = [arg.dtype in (tl.int1, tl.int8, tl.int16, tl.int32, tl.int64) for arg in args]
+     return tl.tensor(builder.create_print(prefix, hex, new_args, is_signed), tl.void)
+
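+ # Sketch of the prefix normalization above: with non-empty `args`, a prefix of
+ # "x" becomes " x: ", so the kernel-level call `tl.device_print("x", x)` prints
+ # each value as " x: <value>".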
+
+ def device_assert(cond: tl.tensor, msg: str, builder: ir.builder) -> tl.tensor:
+     if not builder.options.debug:
+         return
+     return tl.tensor(builder.create_assert(cond.handle, msg), tl.void)
+
+
+ def assume(cond, builder: ir.builder) -> tl.tensor:
+     return tl.tensor(builder.create_assume(cond.handle), tl.void)
+
+
+ def _convert_elem_to_ir_value(builder, elem, require_i64):
+     if isinstance(elem, int):
+         elem = tl.constexpr(elem)
+     if isinstance(elem, tl.constexpr):
+         if require_i64:
+             assert -2**63 <= elem.value < 2**63, f"Block pointers only support 64 bit `shape/strides`, " \
+                 f"got a value {elem.value} which is out of range"
+             return builder.get_int64(elem.value)
+         else:
+             assert -2**31 <= elem.value < 2**31, f"Block pointers only support 32 bit `offsets/block_shape`, " \
+                 f"got a value {elem.value} which is out of range"
+             return builder.get_int32(elem.value)
+     elif isinstance(elem, tl.tensor):
+         assert elem.numel.value == 1, "Expected a scalar in shape/strides/offsets"
+         assert elem.dtype.is_int(), "Expected an integer scalar type in shape/strides/offsets"
+         if elem.dtype != tl.int64 and require_i64:
+             return builder.create_int_cast(elem.handle, builder.get_int64_ty(), elem.dtype.is_int_signed())
+         elif elem.dtype != tl.int32 and not require_i64:
+             assert False, "Block pointers only support 32 bit `offsets/block_shape`, " \
+                 "add a `.to(tl.int32)` or use regular indexing for 64 bit support"
+         return elem.handle
+     assert False, f"Unsupported element type in shape/strides/offsets: {type(elem)}"
+
+
+ def _convert_to_ir_values(builder, list_like, require_i64=True):
+     if hasattr(list_like, "__iter__"):
+         return [_convert_elem_to_ir_value(builder, elem, require_i64) for elem in list_like]
+     return [_convert_elem_to_ir_value(builder, list_like, require_i64)]
+
+
+ def make_block_ptr(base: tl.tensor, shape, strides, offsets, block_shape, order, builder: ir.builder) -> tl.tensor:
+     # Convert dynamic arguments to IR values
+     # NOTES(Chenggang): current `shape/strides` are `int64_t`, while `offsets/block_shape` are `int32_t`
+     shape = _convert_to_ir_values(builder, shape)
+     strides = _convert_to_ir_values(builder, strides)
+     offsets = _convert_to_ir_values(builder, offsets, require_i64=False)
+
+     # Check the `base` type
+     if not base.type.is_ptr() or base.type.element_ty.is_block():
+         raise ValueError("Expected `base` to be a pointer type (but not a block pointer type or others)")
+
+     # Treat `pointer_type<tl.int1>` as `pointer_type<tl.int8>`
+     if base.type.element_ty == tl.int1:
+         base = cast(base, tl.pointer_type(tl.int8, base.type.address_space), builder)
+
+     # Check whether `block_shape` is static
+     if not hasattr(block_shape, "__iter__"):
+         block_shape = [block_shape]
+     block_shape = [elem.value if isinstance(elem, tl.constexpr) else elem for elem in block_shape]
+     assert all(isinstance(elem, int) and -2**31 <= elem < 2**31 for elem in block_shape), \
+         "Expected a list of constant integers (`int32_t` range) in `block_shape`"
+
+     # Check `order`
+     if not hasattr(order, "__iter__"):
+         order = [order]
+     order = [elem.value if isinstance(elem, tl.constexpr) else elem for elem in order]
+     assert sorted(order) == list(range(len(order))), "Expected a permutation of (0, 1, ..., len(order)-1) in `order`"
+
+     # All descriptors must have the same length
+     assert all(len(block_shape) == len(list_like) for list_like in [shape, strides, offsets, order]), \
+         "Expected shape/strides/offsets/block_shape to have the same length"
+
+     # Build the value; the type is:
+     #   `pointer_type<blocked<shape, element_type>>` in Python
+     #   `tt.ptr<tensor<shape, element_type>>` in MLIR
+     handle = builder.create_make_block_ptr(base.handle, shape, strides, offsets, block_shape, order)
+     return tl.tensor(handle, tl.pointer_type(tl.block_type(base.type.element_ty, block_shape)))
+
+
+ def advance(base: tl.tensor, offsets, builder: ir.builder) -> tl.tensor:
+     # Convert dynamic offsets to IR values
+     offsets = _convert_to_ir_values(builder, offsets, require_i64=False)
+
+     # The advanced block pointer type is the same as before
+     return tl.tensor(builder.create_advance(base.handle, offsets), base.type)
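+
+ # Kernel-level sketch tying `make_block_ptr` and `advance` together (a K-loop
+ # over a row-major matrix; names are illustrative):
+ #
+ #     p = tl.make_block_ptr(a_ptr, shape=(M, K), strides=(K, 1), offsets=(pid_m * BM, 0),
+ #                           block_shape=(BM, BK), order=(1, 0))
+ #     for _ in range(0, K, BK):
+ #         a = tl.load(p, boundary_check=(0, 1))
+ #         p = tl.advance(p, (0, BK))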