triton-windows 3.3.0.post19__cp311-cp311-win_amd64.whl → 3.4.0.post20__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of triton-windows might be problematic. More details are linked from the original diff page.

Files changed (173)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +4 -1
  3. triton/_filecheck.py +87 -0
  4. triton/_internal_testing.py +26 -15
  5. triton/_utils.py +110 -21
  6. triton/backends/__init__.py +20 -23
  7. triton/backends/amd/__init__.py +0 -0
  8. triton/backends/amd/compiler.py +112 -78
  9. triton/backends/amd/driver.c +5 -2
  10. triton/backends/amd/driver.py +149 -47
  11. triton/backends/compiler.py +7 -21
  12. triton/backends/nvidia/bin/ptxas.exe +0 -0
  13. triton/backends/nvidia/compiler.py +92 -93
  14. triton/backends/nvidia/driver.c +90 -98
  15. triton/backends/nvidia/driver.py +303 -128
  16. triton/compiler/code_generator.py +212 -111
  17. triton/compiler/compiler.py +110 -25
  18. triton/experimental/__init__.py +0 -0
  19. triton/experimental/gluon/__init__.py +4 -0
  20. triton/experimental/gluon/_compiler.py +0 -0
  21. triton/experimental/gluon/_runtime.py +99 -0
  22. triton/experimental/gluon/language/__init__.py +18 -0
  23. triton/experimental/gluon/language/_core.py +312 -0
  24. triton/experimental/gluon/language/_layouts.py +230 -0
  25. triton/experimental/gluon/language/_math.py +12 -0
  26. triton/experimental/gluon/language/_semantic.py +287 -0
  27. triton/experimental/gluon/language/_standard.py +47 -0
  28. triton/experimental/gluon/language/nvidia/__init__.py +4 -0
  29. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +202 -0
  30. triton/experimental/gluon/language/nvidia/blackwell/tma.py +32 -0
  31. triton/experimental/gluon/language/nvidia/hopper/__init__.py +11 -0
  32. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +51 -0
  33. triton/experimental/gluon/language/nvidia/hopper/tma.py +96 -0
  34. triton/experimental/gluon/nvidia/__init__.py +4 -0
  35. triton/experimental/gluon/nvidia/blackwell.py +3 -0
  36. triton/experimental/gluon/nvidia/hopper.py +40 -0
  37. triton/knobs.py +481 -0
  38. triton/language/__init__.py +39 -14
  39. triton/language/core.py +794 -537
  40. triton/language/extra/cuda/__init__.py +10 -7
  41. triton/language/extra/cuda/gdc.py +42 -0
  42. triton/language/extra/cuda/libdevice.py +394 -394
  43. triton/language/extra/cuda/utils.py +21 -21
  44. triton/language/extra/hip/libdevice.py +113 -104
  45. triton/language/math.py +65 -66
  46. triton/language/random.py +12 -2
  47. triton/language/semantic.py +1706 -1770
  48. triton/language/standard.py +116 -51
  49. triton/runtime/autotuner.py +117 -59
  50. triton/runtime/build.py +76 -12
  51. triton/runtime/cache.py +18 -47
  52. triton/runtime/driver.py +32 -29
  53. triton/runtime/interpreter.py +72 -35
  54. triton/runtime/jit.py +146 -110
  55. triton/runtime/tcc/lib/python310.def +1610 -0
  56. triton/runtime/tcc/lib/python311.def +1633 -0
  57. triton/runtime/tcc/lib/python312.def +1703 -0
  58. triton/runtime/tcc/lib/python313.def +1651 -0
  59. triton/runtime/tcc/lib/python313t.def +1656 -0
  60. triton/runtime/tcc/lib/python39.def +1644 -0
  61. triton/runtime/tcc/lib/python3t.def +905 -0
  62. triton/testing.py +16 -12
  63. triton/tools/disasm.py +3 -4
  64. triton/tools/tensor_descriptor.py +36 -0
  65. triton/windows_utils.py +14 -6
  66. {triton_windows-3.3.0.post19.dist-info → triton_windows-3.4.0.post20.dist-info}/METADATA +7 -2
  67. triton_windows-3.4.0.post20.dist-info/RECORD +186 -0
  68. {triton_windows-3.3.0.post19.dist-info → triton_windows-3.4.0.post20.dist-info}/WHEEL +1 -1
  69. triton_windows-3.4.0.post20.dist-info/entry_points.txt +3 -0
  70. triton_windows-3.4.0.post20.dist-info/licenses/LICENSE +23 -0
  71. triton_windows-3.4.0.post20.dist-info/top_level.txt +1 -0
  72. triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +0 -358
  73. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +0 -1010
  74. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +0 -1638
  75. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +0 -1814
  76. triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +0 -293
  77. triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +0 -32
  78. triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +0 -174
  79. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +0 -835
  80. triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +0 -1809
  81. triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +0 -1391
  82. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +0 -108
  83. triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +0 -124
  84. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +0 -405
  85. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +0 -196
  86. triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +0 -565
  87. triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +0 -2226
  88. triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +0 -104
  89. triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +0 -244
  90. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +0 -538
  91. triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +0 -288
  92. triton/backends/amd/include/hip/amd_detail/concepts.hpp +0 -30
  93. triton/backends/amd/include/hip/amd_detail/device_library_decls.h +0 -133
  94. triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +0 -218
  95. triton/backends/amd/include/hip/amd_detail/grid_launch.h +0 -67
  96. triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +0 -50
  97. triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +0 -26
  98. triton/backends/amd/include/hip/amd_detail/helpers.hpp +0 -137
  99. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +0 -1446
  100. triton/backends/amd/include/hip/amd_detail/hip_assert.h +0 -101
  101. triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +0 -242
  102. triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +0 -254
  103. triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +0 -96
  104. triton/backends/amd/include/hip/amd_detail/hip_ldg.h +0 -100
  105. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +0 -10570
  106. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +0 -78
  107. triton/backends/amd/include/hip/amd_detail/host_defines.h +0 -184
  108. triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +0 -102
  109. triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +0 -798
  110. triton/backends/amd/include/hip/amd_detail/math_fwd.h +0 -698
  111. triton/backends/amd/include/hip/amd_detail/ockl_image.h +0 -177
  112. triton/backends/amd/include/hip/amd_detail/program_state.hpp +0 -107
  113. triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +0 -491
  114. triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +0 -478
  115. triton/backends/amd/include/hip/channel_descriptor.h +0 -39
  116. triton/backends/amd/include/hip/device_functions.h +0 -38
  117. triton/backends/amd/include/hip/driver_types.h +0 -468
  118. triton/backends/amd/include/hip/hip_bf16.h +0 -36
  119. triton/backends/amd/include/hip/hip_bfloat16.h +0 -44
  120. triton/backends/amd/include/hip/hip_common.h +0 -100
  121. triton/backends/amd/include/hip/hip_complex.h +0 -38
  122. triton/backends/amd/include/hip/hip_cooperative_groups.h +0 -46
  123. triton/backends/amd/include/hip/hip_deprecated.h +0 -95
  124. triton/backends/amd/include/hip/hip_ext.h +0 -161
  125. triton/backends/amd/include/hip/hip_fp16.h +0 -36
  126. triton/backends/amd/include/hip/hip_fp8.h +0 -33
  127. triton/backends/amd/include/hip/hip_gl_interop.h +0 -32
  128. triton/backends/amd/include/hip/hip_hcc.h +0 -24
  129. triton/backends/amd/include/hip/hip_math_constants.h +0 -36
  130. triton/backends/amd/include/hip/hip_profile.h +0 -27
  131. triton/backends/amd/include/hip/hip_runtime.h +0 -75
  132. triton/backends/amd/include/hip/hip_runtime_api.h +0 -9261
  133. triton/backends/amd/include/hip/hip_texture_types.h +0 -29
  134. triton/backends/amd/include/hip/hip_vector_types.h +0 -41
  135. triton/backends/amd/include/hip/hip_version.h +0 -17
  136. triton/backends/amd/include/hip/hiprtc.h +0 -421
  137. triton/backends/amd/include/hip/library_types.h +0 -78
  138. triton/backends/amd/include/hip/math_functions.h +0 -42
  139. triton/backends/amd/include/hip/surface_types.h +0 -63
  140. triton/backends/amd/include/hip/texture_types.h +0 -194
  141. triton/backends/amd/include/hsa/Brig.h +0 -1131
  142. triton/backends/amd/include/hsa/amd_hsa_common.h +0 -91
  143. triton/backends/amd/include/hsa/amd_hsa_elf.h +0 -462
  144. triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +0 -269
  145. triton/backends/amd/include/hsa/amd_hsa_queue.h +0 -109
  146. triton/backends/amd/include/hsa/amd_hsa_signal.h +0 -80
  147. triton/backends/amd/include/hsa/hsa.h +0 -5738
  148. triton/backends/amd/include/hsa/hsa_amd_tool.h +0 -91
  149. triton/backends/amd/include/hsa/hsa_api_trace.h +0 -579
  150. triton/backends/amd/include/hsa/hsa_api_trace_version.h +0 -68
  151. triton/backends/amd/include/hsa/hsa_ext_amd.h +0 -3146
  152. triton/backends/amd/include/hsa/hsa_ext_finalize.h +0 -531
  153. triton/backends/amd/include/hsa/hsa_ext_image.h +0 -1454
  154. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +0 -488
  155. triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +0 -667
  156. triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +0 -416
  157. triton/backends/amd/include/roctracer/ext/prof_protocol.h +0 -107
  158. triton/backends/amd/include/roctracer/hip_ostream_ops.h +0 -4515
  159. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +0 -1727
  160. triton/backends/amd/include/roctracer/hsa_prof_str.h +0 -3059
  161. triton/backends/amd/include/roctracer/roctracer.h +0 -779
  162. triton/backends/amd/include/roctracer/roctracer_ext.h +0 -81
  163. triton/backends/amd/include/roctracer/roctracer_hcc.h +0 -24
  164. triton/backends/amd/include/roctracer/roctracer_hip.h +0 -37
  165. triton/backends/amd/include/roctracer/roctracer_hsa.h +0 -112
  166. triton/backends/amd/include/roctracer/roctracer_plugin.h +0 -137
  167. triton/backends/amd/include/roctracer/roctracer_roctx.h +0 -67
  168. triton/backends/amd/include/roctracer/roctx.h +0 -229
  169. triton/language/_utils.py +0 -21
  170. triton/language/extra/cuda/_experimental_tma.py +0 -106
  171. triton/tools/experimental_descriptor.py +0 -32
  172. triton_windows-3.3.0.post19.dist-info/RECORD +0 -253
  173. triton_windows-3.3.0.post19.dist-info/top_level.txt +0 -14
@@ -1,13 +1,16 @@
1
1
  from __future__ import annotations # remove after python 3.11
2
2
  import warnings
3
3
 
4
- from typing import List, Optional, Sequence, Tuple, TypeVar
4
+ from typing import List, Optional, Sequence, Tuple, TypeVar, Generic, Type
5
5
  import numbers
6
6
 
7
+ from triton.runtime import driver
8
+
7
9
  from .._C.libtriton import ir
8
10
  from . import core as tl
9
11
 
10
12
  T = TypeVar('T')
13
+ TensorTy = TypeVar('TensorTy')
11
14
 
12
15
 
13
16
  class IncompatibleTypeErrorImpl(Exception):
@@ -19,1932 +22,1865 @@ class IncompatibleTypeErrorImpl(Exception):
19
22
  super(IncompatibleTypeErrorImpl, self).__init__(self.message)
20
23
 
21
24
 
22
- # ===----------------------------------------------------------------------===##
23
- # Programming Model
24
- # ===----------------------------------------------------------------------===##
25
+ class TritonSemantic(Generic[TensorTy]):
26
+ tensor: Type[TensorTy] = tl.tensor
27
+ lang = tl
25
28
 
29
+ builder: ir.builder
26
30
 
27
- def program_id(axis: int, builder: ir.builder) -> tl.tensor:
28
- if axis not in (0, 1, 2):
29
- raise ValueError(f"program_id axis must be 0, 1, or 2 but got {axis}")
30
- return tl.tensor(builder.create_get_program_id(axis), tl.int32)
31
+ def __init__(self, builder):
32
+ self.builder = builder
31
33
 
34
+ # ===----------------------------------------------------------------------===##
35
+ # Programming Model
36
+ # ===----------------------------------------------------------------------===##
32
37
 
33
- def num_programs(axis: int, builder: ir.builder) -> tl.tensor:
34
- if axis not in (0, 1, 2):
35
- raise ValueError(f"num_programs axis must be 0, 1, or 2 but got {axis}")
36
- return tl.tensor(builder.create_get_num_programs(axis), tl.int32)
38
+ def program_id(self, axis: int) -> TensorTy:
39
+ if axis not in (0, 1, 2):
40
+ raise ValueError(f"program_id axis must be 0, 1, or 2 but got {axis}")
41
+ return self.tensor(self.builder.create_get_program_id(axis), tl.int32)
37
42
 
43
+ def num_programs(self, axis: int) -> TensorTy:
44
+ if axis not in (0, 1, 2):
45
+ raise ValueError(f"num_programs axis must be 0, 1, or 2 but got {axis}")
46
+ return self.tensor(self.builder.create_get_num_programs(axis), tl.int32)
38
47
 
39
48
  # ===----------------------------------------------------------------------===//
40
49
  # Implicit Casting Utilities
41
50
  # ===----------------------------------------------------------------------===//
42
51
 
43
-
44
- def integer_promote_impl(a_ty: tl.dtype, b_ty: tl.dtype) -> tl.dtype:
45
- a_rank = a_ty.int_bitwidth
46
- b_rank = b_ty.int_bitwidth
47
- a_sn = a_ty.int_signedness
48
- b_sn = b_ty.int_signedness
49
- # Rules for signedness taken from "Usual arithmetic conversions" on
50
- # https://en.cppreference.com/w/c/language/conversion.
51
- if a_sn == b_sn:
52
- return a_ty if a_rank > b_rank else b_ty
53
- elif a_sn == tl.dtype.SIGNEDNESS.UNSIGNED:
54
- return a_ty if a_rank >= b_rank else b_ty
55
- elif b_sn == tl.dtype.SIGNEDNESS.UNSIGNED:
56
- return b_ty if b_rank >= a_rank else a_ty
57
- raise TypeError(f"unexpected signedness {a_sn} and {b_sn}")
58
-
59
-
60
- def computation_type_impl(a_ty: tl.dtype, a_is_scalar: bool, b_ty: tl.dtype, b_is_scalar: bool,
61
- div_or_mod: bool) -> tl.dtype:
62
- # 0) For scalars we follow semantics similar to PyTorch, namely:
63
- # - If the scalar is of a lower or equal kind (bool < uint < int < fp),
64
- # it doesn't participate in the promotion
65
- if a_is_scalar != b_is_scalar:
66
- scalar_ty, tensor_ty = (a_ty, b_ty) if a_is_scalar else (b_ty, a_ty)
67
- if scalar_ty.kind().value <= tensor_ty.kind().value:
68
- # Upcast because of 3) and 4) below!
69
- if div_or_mod and (tensor_ty in (tl.float16, tl.bfloat16)):
70
- return tl.float32
71
- return tensor_ty
72
-
73
- # 1) if one operand is double, the other is implicitly
74
- # converted to double
75
- if a_ty.is_fp64() or b_ty.is_fp64():
76
- return tl.float64
77
- # 2) if one operand is float, the other is implicitly
78
- # converted to float
79
- if a_ty.is_fp32() or b_ty.is_fp32():
80
- return tl.float32
81
- # 3 ) if one operand is half, the other is implicitly converted to half
82
- # unless we're doing / or %, which do not exist natively in PTX for fp16.
83
- # Supported PTX op: add, sub, mul, fma, neg, abs, min, max, tanh, ex2, setp
84
- if a_ty.is_fp16() or b_ty.is_fp16():
85
- if div_or_mod:
52
+ def integer_promote_impl(self, a_ty: tl.dtype, b_ty: tl.dtype) -> tl.dtype:
53
+ a_rank = a_ty.int_bitwidth
54
+ b_rank = b_ty.int_bitwidth
55
+ a_sn = a_ty.int_signedness
56
+ b_sn = b_ty.int_signedness
57
+ # Rules for signedness taken from "Usual arithmetic conversions" on
58
+ # https://en.cppreference.com/w/c/language/conversion.
59
+ if a_sn == b_sn:
60
+ return a_ty if a_rank > b_rank else b_ty
61
+ elif a_sn == tl.dtype.SIGNEDNESS.UNSIGNED:
62
+ return a_ty if a_rank >= b_rank else b_ty
63
+ elif b_sn == tl.dtype.SIGNEDNESS.UNSIGNED:
64
+ return b_ty if b_rank >= a_rank else a_ty
65
+ raise TypeError(f"unexpected signedness {a_sn} and {b_sn}")
66
+
67
+ def computation_type_impl(self, a_ty: tl.dtype, a_is_scalar: bool, b_ty: tl.dtype, b_is_scalar: bool,
68
+ div_or_mod: bool) -> tl.dtype:
69
+ # 0) For scalars we follow semantics similar to PyTorch, namely:
70
+ # - If the scalar is of a lower or equal kind (bool < uint < int < fp),
71
+ # it doesn't participate in the promotion
72
+ if a_is_scalar != b_is_scalar:
73
+ scalar_ty, tensor_ty = (a_ty, b_ty) if a_is_scalar else (b_ty, a_ty)
74
+ if scalar_ty.kind().value <= tensor_ty.kind().value:
75
+ # Upcast because of 3) and 4) below!
76
+ if div_or_mod and (tensor_ty in (tl.float16, tl.bfloat16)):
77
+ return tl.float32
78
+ return tensor_ty
79
+
80
+ # 1) if one operand is double, the other is implicitly
81
+ # converted to double
82
+ if a_ty.is_fp64() or b_ty.is_fp64():
83
+ return tl.float64
84
+ # 2) if one operand is float, the other is implicitly
85
+ # converted to float
86
+ if a_ty.is_fp32() or b_ty.is_fp32():
86
87
  return tl.float32
87
- else:
88
- return tl.float16
89
- # 4) return bf16 only if both operands are of bf16
90
- if a_ty.is_bf16() and b_ty.is_bf16():
91
- if div_or_mod:
88
+ # 3 ) if one operand is half, the other is implicitly converted to half
89
+ # unless we're doing / or %, which do not exist natively in PTX for fp16.
90
+ # Supported PTX op: add, sub, mul, fma, neg, abs, min, max, tanh, ex2, setp
91
+ if a_ty.is_fp16() or b_ty.is_fp16():
92
+ if div_or_mod:
93
+ return tl.float32
94
+ else:
95
+ return tl.float16
96
+ # 4) return bf16 only if both operands are of bf16
97
+ if a_ty.is_bf16() and b_ty.is_bf16():
98
+ if div_or_mod:
99
+ return tl.float32
100
+ else:
101
+ return tl.bfloat16
102
+ if a_ty.is_bf16() or b_ty.is_bf16():
92
103
  return tl.float32
93
- else:
94
- return tl.bfloat16
95
- if a_ty.is_bf16() or b_ty.is_bf16():
96
- return tl.float32
97
- # 5) return fp16 if operands are different fp8
98
- if a_ty.is_fp8() and b_ty.is_fp8():
99
- return a_ty if a_ty == b_ty else tl.float16
100
- if not a_ty.is_int() or not b_ty.is_int():
101
- raise TypeError(f"unexpected type {a_ty} and {b_ty}")
102
- # 6 ) both operands are integer and undergo
103
- # integer promotion
104
- if div_or_mod and a_ty.int_signedness != b_ty.int_signedness:
105
- raise TypeError("Cannot use /, #, or % with " + a_ty.__repr__() + " and " + b_ty.__repr__() +
106
- " because they have different signedness;"
107
- "this is unlikely to result in a useful answer. Cast them to the same signedness.")
108
- return integer_promote_impl(a_ty, b_ty)
109
-
110
-
111
- def to_tensor(x, builder, check_type: bool = True):
112
- if isinstance(x, bool):
113
- return tl.tensor(builder.get_int1(x), tl.int1)
114
- # Note: compile-time const integers are represented by unsigned values
115
- elif isinstance(x, int):
116
- if -2**31 <= x < 2**31:
117
- dtype = tl.int32
118
- elif 2**31 <= x < 2**32:
119
- dtype = tl.uint32
120
- elif -2**63 <= x < 2**63:
121
- dtype = tl.int64
122
- elif 2**63 <= x < 2**64:
123
- dtype = tl.uint64
124
- else:
125
- raise ValueError(f'Nonrepresentable integer {x}.')
126
- return full((), x, dtype=dtype, builder=builder)
127
- elif isinstance(x, float):
128
- min_float32 = 2**-126
129
- max_float32 = (2 - 2**-23) * 2**127
130
- abs_x = __builtins__['abs'](x)
131
- if abs_x == float("inf") or\
132
- abs_x == 0.0 or \
133
- x != x or \
134
- min_float32 <= abs_x <= max_float32:
135
- dtype = tl.float32
136
- else:
137
- dtype = tl.float64
138
- return full((), x, dtype=dtype, builder=builder)
139
-
140
- elif isinstance(x, tl.constexpr):
141
- return to_tensor(x.value, builder)
142
- elif isinstance(x, tl.tensor):
104
+ # 5) return fp16 if operands are different fp8
105
+ if a_ty.is_fp8() and b_ty.is_fp8():
106
+ return a_ty if a_ty == b_ty else tl.float16
107
+ if not a_ty.is_int() or not b_ty.is_int():
108
+ raise TypeError(f"unexpected type {a_ty} and {b_ty}")
109
+ # 6 ) both operands are integer and undergo
110
+ # integer promotion
111
+ if div_or_mod and a_ty.int_signedness != b_ty.int_signedness:
112
+ raise TypeError("Cannot use /, #, or % with " + a_ty.__repr__() + " and " + b_ty.__repr__() +
113
+ " because they have different signedness;"
114
+ "this is unlikely to result in a useful answer. Cast them to the same signedness.")
115
+ return self.integer_promote_impl(a_ty, b_ty)
116
+
117
+ def to_tensor(self, x, check_type: bool = True):
118
+ if isinstance(x, bool):
119
+ return self.tensor(self.builder.get_int1(x), tl.int1)
120
+ # Note: compile-time const integers are represented by unsigned values
121
+ elif isinstance(x, int):
122
+ if -2**31 <= x < 2**31:
123
+ dtype = tl.int32
124
+ elif 2**31 <= x < 2**32:
125
+ dtype = tl.uint32
126
+ elif -2**63 <= x < 2**63:
127
+ dtype = tl.int64
128
+ elif 2**63 <= x < 2**64:
129
+ dtype = tl.uint64
130
+ else:
131
+ raise ValueError(f'Nonrepresentable integer {x}.')
132
+ return self.scalar_constant(x, dtype=dtype)
133
+ elif isinstance(x, float):
134
+ min_float32 = 2**-126
135
+ max_float32 = (2 - 2**-23) * 2**127
136
+ abs_x = __builtins__['abs'](x)
137
+ if abs_x == float("inf") or\
138
+ abs_x == 0.0 or \
139
+ x != x or \
140
+ min_float32 <= abs_x <= max_float32:
141
+ dtype = tl.float32
142
+ else:
143
+ dtype = tl.float64
144
+ return self.scalar_constant(x, dtype=dtype)
145
+
146
+ elif isinstance(x, tl.constexpr):
147
+ return self.to_tensor(x.value)
148
+ elif isinstance(x, self.tensor):
149
+ return x
150
+ if check_type:
151
+ raise TypeError(f"cannot convert {x} of type {type(x)} to tensor")
143
152
  return x
144
- if check_type:
145
- raise TypeError(f"cannot convert {x} of type {type(x)} to tensor")
146
- return x
147
-
148
153
 
149
154
  # ===----------------------------------------------------------------------===//
150
155
  # Binary Operators
151
156
  # ===----------------------------------------------------------------------===//
152
157
 
158
+ def check_ptr_type_impl(self, type_a: tl.dtype, type_b: tl.dtype, allow_ptr_a: bool) -> None:
159
+ if type_a.is_ptr():
160
+ if not allow_ptr_a:
161
+ raise IncompatibleTypeErrorImpl(type_a, type_b)
162
+ # T* + U* with T != U
163
+ if type_b.is_ptr() and (type_a != type_b):
164
+ raise IncompatibleTypeErrorImpl(type_a, type_b)
165
+ # T* + float
166
+ if type_b.is_floating():
167
+ raise IncompatibleTypeErrorImpl(type_a, type_b)
168
+
169
+ def binary_op_type_checking_impl(self, lhs: TensorTy | numbers.Number, rhs: TensorTy | numbers.Number,
170
+ allow_lhs_ptr=False, allow_rhs_ptr=False, arithmetic_check=True,
171
+ div_or_mod=False) -> Tuple[TensorTy, TensorTy]:
172
+ lhs_is_scalar = isinstance(lhs, numbers.Number)
173
+ rhs_is_scalar = isinstance(rhs, numbers.Number)
174
+ if lhs_is_scalar:
175
+ lhs_scalar = lhs
176
+ lhs = self.to_tensor(lhs)
177
+ if rhs_is_scalar:
178
+ rhs_scalar = rhs
179
+ rhs = self.to_tensor(rhs)
180
+
181
+ # implicit typecasting
182
+ lhs_sca_ty = lhs.type.scalar
183
+ rhs_sca_ty = rhs.type.scalar
184
+ self.check_ptr_type_impl(lhs_sca_ty, rhs_sca_ty, allow_lhs_ptr)
185
+ self.check_ptr_type_impl(rhs_sca_ty, lhs_sca_ty, allow_rhs_ptr)
186
+ if arithmetic_check and not lhs_sca_ty.is_ptr() and not rhs_sca_ty.is_ptr():
187
+ ret_sca_ty = self.computation_type_impl(lhs_sca_ty, lhs_is_scalar, rhs_sca_ty, rhs_is_scalar, div_or_mod)
188
+ if (lhs_is_scalar and lhs_scalar < 0 and ret_sca_ty.is_int_unsigned()
189
+ or rhs_is_scalar and rhs_scalar < 0 and ret_sca_ty.is_int_unsigned()):
190
+ raise ValueError("Cannot perform a binary operation between an unsigned tensor and a negative scalar. "
191
+ "Perform a explicit cast on one of them.")
192
+ if ret_sca_ty.is_int():
193
+ if lhs_is_scalar and not (ret_sca_ty.get_int_min_value() <= lhs_scalar <=
194
+ ret_sca_ty.get_int_max_value()):
195
+ raise ValueError(f"Scalar {lhs_scalar} is out of range for type {ret_sca_ty}")
196
+ if rhs_is_scalar and not (ret_sca_ty.get_int_min_value() <= rhs_scalar <=
197
+ ret_sca_ty.get_int_max_value()):
198
+ raise ValueError(f"Scalar {rhs_scalar} is out of range for type {ret_sca_ty}")
199
+ lhs = self.scalar_constant(lhs_scalar, dtype=ret_sca_ty) if lhs_is_scalar else self.cast(lhs, ret_sca_ty)
200
+ rhs = self.scalar_constant(rhs_scalar, dtype=ret_sca_ty) if rhs_is_scalar else self.cast(rhs, ret_sca_ty)
201
+
202
+ # implicit broadcasting
203
+ lhs, rhs = self.broadcast_impl_value(lhs, rhs)
204
+ return lhs, rhs
205
+
206
+ def binary_op_sanitize_overflow_impl(self, lhs: TensorTy, rhs: TensorTy, binary_op: callable):
207
+ if lhs.type.scalar.int_bitwidth >= 64 or not self.builder.options.sanitize_overflow:
208
+ return
209
+ lhs_sca_ty = lhs.type.scalar
210
+ rhs_sca_ty = rhs.type.scalar
211
+ assert lhs_sca_ty == rhs_sca_ty
212
+ assert lhs_sca_ty.is_int()
213
+ lhs = self.cast(lhs, tl.int64)
214
+ rhs = self.cast(rhs, tl.int64)
215
+ ret = binary_op(lhs, rhs, False)
216
+ max_value = lhs_sca_ty.get_int_max_value()
217
+ max_value = self.scalar_constant(max_value, tl.int64)
218
+ min_value = lhs_sca_ty.get_int_min_value()
219
+ min_value = self.scalar_constant(min_value, tl.int64)
220
+ cond = self.and_(self.less_equal(ret, max_value), self.greater_equal(ret, min_value))
221
+ msg = f"int{lhs_sca_ty.int_bitwidth} overflow detected for operation {binary_op.__name__}"
222
+ self.device_assert(cond, msg)
223
+
224
+ def add(self, input: TensorTy | numbers.Number, other: TensorTy | numbers.Number,
225
+ sanitize_overflow: bool) -> TensorTy:
226
+ input, other = self.binary_op_type_checking_impl(input, other, True, True)
227
+ input_scalar_ty = input.type.scalar
228
+ other_scalar_ty = other.type.scalar
229
+ if input_scalar_ty.is_ptr() and other_scalar_ty.is_ptr():
230
+ raise TypeError("cannot add pointers together")
231
+
232
+ # offset + ptr
233
+ # ptr + offset
234
+ if other_scalar_ty.is_ptr() and not input_scalar_ty.is_ptr():
235
+ input, other = other, input
236
+ input_scalar_ty = input.type.scalar
237
+ other_scalar_ty = other.type.scalar
238
+ if input_scalar_ty.is_ptr():
239
+ other_handle = other.handle
240
+ if other.dtype.is_int_unsigned() and other.dtype.int_bitwidth < 64:
241
+ # addptr treats offset as signed. Zero-extend unsigned offsets to ensure they're positive
242
+ i64_ty = other.type.with_element_ty(tl.int64).to_ir(self.builder)
243
+ other_handle = self.builder.create_int_cast(other.handle, i64_ty, False)
244
+ return self.tensor(self.builder.create_addptr(input.handle, other_handle), input.type)
245
+ # float + float
246
+ elif input_scalar_ty.is_floating():
247
+ return self.tensor(self.builder.create_fadd(input.handle, other.handle), input.type)
248
+ # int + int
249
+ elif input_scalar_ty.is_int():
250
+ if sanitize_overflow:
251
+ self.binary_op_sanitize_overflow_impl(input, other, self.add)
252
+ return self.tensor(self.builder.create_add(input.handle, other.handle), input.type)
253
+ raise TypeError(f"unexpected type {input_scalar_ty}")
153
254
 
154
- def check_ptr_type_impl(type_a: tl.dtype, type_b: tl.dtype, allow_ptr_a: bool) -> None:
155
- if type_a.is_ptr():
156
- if not allow_ptr_a:
157
- raise IncompatibleTypeErrorImpl(type_a, type_b)
158
- # T* + U* with T != U
159
- if type_b.is_ptr() and (type_a != type_b):
160
- raise IncompatibleTypeErrorImpl(type_a, type_b)
161
- # T* + float
162
- if type_b.is_floating():
163
- raise IncompatibleTypeErrorImpl(type_a, type_b)
164
-
165
-
166
- def binary_op_type_checking_impl(lhs: tl.tensor | numbers.Number, rhs: tl.tensor | numbers.Number, builder: ir.builder,
167
- allow_lhs_ptr=False, allow_rhs_ptr=False, arithmetic_check=True,
168
- div_or_mod=False) -> Tuple[tl.tensor, tl.tensor]:
169
- lhs_is_scalar = isinstance(lhs, numbers.Number)
170
- rhs_is_scalar = isinstance(rhs, numbers.Number)
171
- if lhs_is_scalar:
172
- lhs_scalar = lhs
173
- lhs = to_tensor(lhs, builder)
174
- if rhs_is_scalar:
175
- rhs_scalar = rhs
176
- rhs = to_tensor(rhs, builder)
177
-
178
- # implicit typecasting
179
- lhs_sca_ty = lhs.type.scalar
180
- rhs_sca_ty = rhs.type.scalar
181
- check_ptr_type_impl(lhs_sca_ty, rhs_sca_ty, allow_lhs_ptr)
182
- check_ptr_type_impl(rhs_sca_ty, lhs_sca_ty, allow_rhs_ptr)
183
- if arithmetic_check and not lhs_sca_ty.is_ptr() and not rhs_sca_ty.is_ptr():
184
- ret_sca_ty = computation_type_impl(lhs_sca_ty, lhs_is_scalar, rhs_sca_ty, rhs_is_scalar, div_or_mod)
185
- if (lhs_is_scalar and lhs_scalar < 0 and ret_sca_ty.is_int_unsigned()
186
- or rhs_is_scalar and rhs_scalar < 0 and ret_sca_ty.is_int_unsigned()):
187
- raise ValueError("Cannot perform a binary operation between an unsigned tensor and a negative scalar. "
188
- "Perform a explicit cast on one of them.")
189
- if ret_sca_ty.is_int():
190
- if lhs_is_scalar and not (ret_sca_ty.get_int_min_value() <= lhs_scalar <= ret_sca_ty.get_int_max_value()):
191
- raise ValueError(f"Scalar {lhs_scalar} is out of range for type {ret_sca_ty}")
192
- if rhs_is_scalar and not (ret_sca_ty.get_int_min_value() <= rhs_scalar <= ret_sca_ty.get_int_max_value()):
193
- raise ValueError(f"Scalar {rhs_scalar} is out of range for type {ret_sca_ty}")
194
- lhs = full(
195
- (), lhs_scalar, dtype=ret_sca_ty, builder=builder) if lhs_is_scalar else cast(lhs, ret_sca_ty, builder)
196
- rhs = full(
197
- (), rhs_scalar, dtype=ret_sca_ty, builder=builder) if rhs_is_scalar else cast(rhs, ret_sca_ty, builder)
198
-
199
- # implicit broadcasting
200
- lhs, rhs = broadcast_impl_value(lhs, rhs, builder)
201
- return lhs, rhs
202
-
203
-
204
- def binary_op_sanitize_overflow_impl(lhs: tl.tensor, rhs: tl.tensor, builder: ir.builder, binary_op: callable):
205
- if lhs.type.scalar.int_bitwidth >= 64 or not builder.options.sanitize_overflow:
206
- return
207
- lhs_sca_ty = lhs.type.scalar
208
- rhs_sca_ty = rhs.type.scalar
209
- assert lhs_sca_ty == rhs_sca_ty
210
- assert lhs_sca_ty.is_int()
211
- lhs = cast(lhs, tl.int64, builder)
212
- rhs = cast(rhs, tl.int64, builder)
213
- ret = binary_op(lhs, rhs, False, builder)
214
- max_value = lhs_sca_ty.get_int_max_value()
215
- max_value = tl.tensor(builder.get_int64(max_value), tl.int64)
216
- min_value = lhs_sca_ty.get_int_min_value()
217
- min_value = tl.tensor(builder.get_int64(min_value), tl.int64)
218
- cond = and_(less_equal(ret, max_value, builder), greater_equal(ret, min_value, builder), builder)
219
- msg = f"int{lhs_sca_ty.int_bitwidth} overflow detected for operation {binary_op.__name__}"
220
- device_assert(cond, msg, builder)
221
-
222
-
223
- def add(input: tl.tensor | numbers.Number, other: tl.tensor | numbers.Number, sanitize_overflow: bool,
224
- builder: ir.builder) -> tl.tensor:
225
- input, other = binary_op_type_checking_impl(input, other, builder, True, True)
226
- input_scalar_ty = input.type.scalar
227
- other_scalar_ty = other.type.scalar
228
- if input_scalar_ty.is_ptr() and other_scalar_ty.is_ptr():
229
- raise TypeError("cannot add pointers together")
230
-
231
- # offset + ptr
232
- # ptr + offset
233
- if other_scalar_ty.is_ptr() and not input_scalar_ty.is_ptr():
234
- input, other = other, input
255
+ def sub(self, input: TensorTy | numbers.Number, other: TensorTy | numbers.Number,
256
+ sanitize_overflow: bool) -> TensorTy:
257
+ input, other = self.binary_op_type_checking_impl(input, other, True, False)
258
+ scalar_ty = input.type.scalar
259
+ # ptr - offset
260
+ if scalar_ty.is_ptr():
261
+ return self.add(input, self.minus(other), sanitize_overflow=False)
262
+ # float - float
263
+ if scalar_ty.is_floating():
264
+ return self.tensor(self.builder.create_fsub(input.handle, other.handle), input.type)
265
+ # int - int
266
+ elif scalar_ty.is_int():
267
+ if sanitize_overflow:
268
+ self.binary_op_sanitize_overflow_impl(input, other, self.sub)
269
+ return self.tensor(self.builder.create_sub(input.handle, other.handle), input.type)
270
+ raise TypeError(f"unexpected type {scalar_ty}")
271
+
272
+ def mul(self, input: TensorTy | numbers.Number, other: TensorTy | numbers.Number,
273
+ sanitize_overflow: bool) -> TensorTy:
274
+ input, other = self.binary_op_type_checking_impl(input, other)
275
+ scalar_ty = input.type.scalar
276
+ # float * float
277
+ if scalar_ty.is_floating():
278
+ return self.tensor(self.builder.create_fmul(input.handle, other.handle), input.type)
279
+ # int * int
280
+ elif scalar_ty.is_int():
281
+ if sanitize_overflow:
282
+ self.binary_op_sanitize_overflow_impl(input, other, self.mul)
283
+ return self.tensor(self.builder.create_mul(input.handle, other.handle), input.type)
284
+ raise TypeError(f"unexpected type {scalar_ty}")
285
+
286
+ def truediv(self, input: TensorTy | numbers.Number, other: TensorTy | numbers.Number) -> TensorTy:
287
+ input, other = self.binary_op_type_checking_impl(input, other, False, False, True, True)
235
288
  input_scalar_ty = input.type.scalar
236
289
  other_scalar_ty = other.type.scalar
237
- if input_scalar_ty.is_ptr():
238
- other_handle = other.handle
239
- if other.dtype.is_int_unsigned() and other.dtype.int_bitwidth < 64:
240
- # addptr treats offset as signed. Zero-extend unsigned offsets to ensure they're positive
241
- if other.type.is_block():
242
- i64_ty = tl.block_type(tl.int64, other.type.get_block_shapes()).to_ir(builder)
290
+ # float / int
291
+ if input_scalar_ty.is_floating() and other_scalar_ty.is_int():
292
+ other = self.cast(other, input_scalar_ty)
293
+ # int / float
294
+ elif input_scalar_ty.is_int() and other_scalar_ty.is_floating():
295
+ input = self.cast(input, other_scalar_ty)
296
+ # int / int (cast to tl.float32)
297
+ elif input_scalar_ty.is_int() and other_scalar_ty.is_int():
298
+ input = self.cast(input, tl.float32)
299
+ other = self.cast(other, tl.float32)
300
+ # float / float (cast to the highest exponent type)
301
+ elif input_scalar_ty.is_floating() and other_scalar_ty.is_floating():
302
+ if input_scalar_ty.fp_mantissa_width > other_scalar_ty.fp_mantissa_width:
303
+ other = self.cast(other, input_scalar_ty)
243
304
  else:
244
- i64_ty = tl.int64.to_ir(builder)
245
- other_handle = builder.create_int_cast(other.handle, i64_ty, False)
246
- return tl.tensor(builder.create_addptr(input.handle, other_handle), input.type)
247
- # float + float
248
- elif input_scalar_ty.is_floating():
249
- return tl.tensor(builder.create_fadd(input.handle, other.handle), input.type)
250
- # int + int
251
- elif input_scalar_ty.is_int():
252
- if sanitize_overflow:
253
- binary_op_sanitize_overflow_impl(input, other, builder, add)
254
- return tl.tensor(builder.create_add(input.handle, other.handle), input.type)
255
- raise TypeError(f"unexpected type {input_scalar_ty}")
256
-
257
-
258
- def sub(input: tl.tensor | numbers.Number, other: tl.tensor | numbers.Number, sanitize_overflow: bool,
259
- builder: ir.builder) -> tl.tensor:
260
- input, other = binary_op_type_checking_impl(input, other, builder, True, False)
261
- scalar_ty = input.type.scalar
262
- # ptr - offset
263
- if scalar_ty.is_ptr():
264
- return tl.tensor(builder.create_addptr(input.handle, minus(other, builder).handle), input.type)
265
- # float - float
266
- if scalar_ty.is_floating():
267
- return tl.tensor(builder.create_fsub(input.handle, other.handle), input.type)
268
- # int - int
269
- elif scalar_ty.is_int():
270
- if sanitize_overflow:
271
- binary_op_sanitize_overflow_impl(input, other, builder, sub)
272
- return tl.tensor(builder.create_sub(input.handle, other.handle), input.type)
273
- raise TypeError(f"unexpected type {scalar_ty}")
274
-
275
-
276
- def mul(input: tl.tensor | numbers.Number, other: tl.tensor | numbers.Number, sanitize_overflow: bool,
277
- builder: ir.builder) -> tl.tensor:
278
- input, other = binary_op_type_checking_impl(input, other, builder)
279
- scalar_ty = input.type.scalar
280
- # float * float
281
- if scalar_ty.is_floating():
282
- return tl.tensor(builder.create_fmul(input.handle, other.handle), input.type)
283
- # int * int
284
- elif scalar_ty.is_int():
285
- if sanitize_overflow:
286
- binary_op_sanitize_overflow_impl(input, other, builder, mul)
287
- return tl.tensor(builder.create_mul(input.handle, other.handle), input.type)
288
- raise TypeError(f"unexpected type {scalar_ty}")
289
-
290
-
291
- def truediv(input: tl.tensor | numbers.Number, other: tl.tensor | numbers.Number, builder: ir.builder) -> tl.tensor:
292
- input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True)
293
- input_scalar_ty = input.type.scalar
294
- other_scalar_ty = other.type.scalar
295
- # float / int
296
- if input_scalar_ty.is_floating() and other_scalar_ty.is_int():
297
- other = cast(other, input_scalar_ty, builder)
298
- # int / float
299
- elif input_scalar_ty.is_int() and other_scalar_ty.is_floating():
300
- input = cast(input, other_scalar_ty, builder)
301
- # int / int (cast to tl.float32)
302
- elif input_scalar_ty.is_int() and other_scalar_ty.is_int():
303
- input = cast(input, tl.float32, builder)
304
- other = cast(other, tl.float32, builder)
305
- # float / float (cast to the highest exponent type)
306
- elif input_scalar_ty.is_floating() and other_scalar_ty.is_floating():
307
- if input_scalar_ty.fp_mantissa_width > other_scalar_ty.fp_mantissa_width:
308
- other = cast(other, input_scalar_ty, builder)
305
+ input = self.cast(input, other_scalar_ty)
306
+ # unreachable
309
307
  else:
310
- input = cast(input, other_scalar_ty, builder)
311
- # unreachable
312
- else:
308
+ raise TypeError(f"unexpected type {input_scalar_ty}")
309
+ return self.tensor(self.builder.create_fdiv(input.handle, other.handle), input.type)
310
+
311
+ def floordiv(self, input: TensorTy | numbers.Number, other: TensorTy | numbers.Number) -> TensorTy:
312
+ input, other = self.binary_op_type_checking_impl(input, other, False, False, True, True)
313
+ input_scalar_ty = input.type.scalar
314
+ other_scalar_ty = other.type.scalar
315
+ if input_scalar_ty.is_int() and other_scalar_ty.is_int():
316
+ ret_ty = self.integer_promote_impl(input_scalar_ty, other_scalar_ty)
317
+ input = self.cast(input, ret_ty)
318
+ other = self.cast(other, ret_ty)
319
+ if ret_ty.is_int_signed():
320
+ return self.tensor(self.builder.create_sdiv(input.handle, other.handle), input.type)
321
+ else:
322
+ return self.tensor(self.builder.create_udiv(input.handle, other.handle), input.type)
313
323
  raise TypeError(f"unexpected type {input_scalar_ty}")
314
- return tl.tensor(builder.create_fdiv(input.handle, other.handle), input.type)
315
-
316
-
317
- def floordiv(input: tl.tensor | numbers.Number, other: tl.tensor | numbers.Number, builder: ir.builder) -> tl.tensor:
318
- input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True)
319
- input_scalar_ty = input.type.scalar
320
- other_scalar_ty = other.type.scalar
321
- if input_scalar_ty.is_int() and other_scalar_ty.is_int():
322
- ret_ty = integer_promote_impl(input_scalar_ty, other_scalar_ty)
323
- input = cast(input, ret_ty, builder)
324
- other = cast(other, ret_ty, builder)
325
- if ret_ty.is_int_signed():
326
- return tl.tensor(builder.create_sdiv(input.handle, other.handle), input.type)
327
- else:
328
- return tl.tensor(builder.create_udiv(input.handle, other.handle), input.type)
329
- raise TypeError(f"unexpected type {input_scalar_ty}")
330
-
331
-
332
- def fdiv(input: tl.tensor | numbers.Number, other: tl.tensor | numbers.Number, ieee_rounding: bool,
333
- builder: ir.builder) -> tl.tensor:
334
- input_scalar_ty = input.type.scalar
335
- other_scalar_ty = other.type.scalar
336
- if not input_scalar_ty.is_floating() or not other_scalar_ty.is_floating():
337
- raise TypeError("both operands of fdiv must have floating scalar type")
338
- input, other = binary_op_type_checking_impl(input, other, builder, False, False, False, True)
339
- ret = builder.create_fdiv(input.handle, other.handle)
340
- return tl.tensor(ret, input.type)
341
-
342
-
343
- def mod(input: tl.tensor | numbers.Number, other: tl.tensor | numbers.Number, builder: ir.builder) -> tl.tensor:
344
- input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True)
345
- scalar_ty = input.type.scalar
346
- other_scalar_ty = other.type.scalar
347
- # float % float
348
- if scalar_ty.is_floating():
349
- return tl.tensor(builder.create_frem(input.handle, other.handle), input.type)
350
- # % int
351
- elif scalar_ty.is_int():
352
- if scalar_ty.int_signedness != other_scalar_ty.int_signedness:
353
- raise TypeError("Cannot mod " + scalar_ty.__repr__() + " by " + other_scalar_ty.__repr__() + " "
354
- "because they have different signedness;"
355
- "this is unlikely to result in a useful answer. Cast them to the same signedness.")
356
- if scalar_ty.is_int_signed():
357
- return tl.tensor(builder.create_srem(input.handle, other.handle), input.type)
358
- else:
359
- return tl.tensor(builder.create_urem(input.handle, other.handle), input.type)
360
- raise TypeError(f"unexpected type {scalar_ty}")
361
324
 
325
+ def fdiv(self, input: TensorTy | numbers.Number, other: TensorTy | numbers.Number, ieee_rounding: bool) -> TensorTy:
326
+ input_scalar_ty = input.type.scalar
327
+ other_scalar_ty = other.type.scalar
328
+ if not input_scalar_ty.is_floating() or not other_scalar_ty.is_floating():
329
+ raise TypeError("both operands of fdiv must have floating scalar type")
330
+ input, other = self.binary_op_type_checking_impl(input, other, False, False, False, True)
331
+ ret = self.builder.create_fdiv(input.handle, other.handle)
332
+ return self.tensor(ret, input.type)
333
+
334
+ def mod(self, input: TensorTy | numbers.Number, other: TensorTy | numbers.Number) -> TensorTy:
335
+ input, other = self.binary_op_type_checking_impl(input, other, False, False, True, True)
336
+ scalar_ty = input.type.scalar
337
+ other_scalar_ty = other.type.scalar
338
+ # float % float
339
+ if scalar_ty.is_floating():
340
+ return self.tensor(self.builder.create_frem(input.handle, other.handle), input.type)
341
+ # % int
342
+ elif scalar_ty.is_int():
343
+ if scalar_ty.int_signedness != other_scalar_ty.int_signedness:
344
+ raise TypeError("Cannot mod " + scalar_ty.__repr__() + " by " + other_scalar_ty.__repr__() + " "
345
+ "because they have different signedness;"
346
+ "this is unlikely to result in a useful answer. Cast them to the same signedness.")
347
+ if scalar_ty.is_int_signed():
348
+ return self.tensor(self.builder.create_srem(input.handle, other.handle), input.type)
349
+ else:
350
+ return self.tensor(self.builder.create_urem(input.handle, other.handle), input.type)
351
+ raise TypeError(f"unexpected type {scalar_ty}")
362
352
 
363
353
  ##############
364
354
  # other arithmetic ops
365
355
  ##############
366
356
 
367
-
368
- def minimum(x: tl.tensor, y: tl.tensor, propagate_nan: tl.PropagateNan, builder: ir.builder):
369
- x, y = binary_op_type_checking_impl(x, y, builder)
370
- dtype = x.dtype
371
- if dtype.is_floating():
372
- if propagate_nan == tl.PropagateNan.ALL:
373
- return tl.tensor(builder.create_minimumf(x.handle, y.handle), x.type)
374
- elif propagate_nan == tl.PropagateNan.NONE:
375
- return tl.tensor(builder.create_minnumf(x.handle, y.handle), x.type)
357
+ def minimum(self, x: TensorTy, y: TensorTy, propagate_nan: tl.PropagateNan):
358
+ x, y = self.binary_op_type_checking_impl(x, y)
359
+ dtype = x.dtype
360
+ if dtype.is_floating():
361
+ if propagate_nan == tl.PropagateNan.ALL:
362
+ return self.tensor(self.builder.create_minimumf(x.handle, y.handle), x.type)
363
+ elif propagate_nan == tl.PropagateNan.NONE:
364
+ return self.tensor(self.builder.create_minnumf(x.handle, y.handle), x.type)
365
+ else:
366
+ raise ValueError(f"Unexpected propagate_nan {propagate_nan}")
367
+ elif dtype.is_int_signed():
368
+ return self.tensor(self.builder.create_minsi(x.handle, y.handle), x.type)
369
+ elif dtype.is_int_unsigned():
370
+ return self.tensor(self.builder.create_minui(x.handle, y.handle), x.type)
376
371
  else:
377
- raise ValueError(f"Unexpected propagate_nan {propagate_nan}")
378
- elif dtype.is_int_signed():
379
- return tl.tensor(builder.create_minsi(x.handle, y.handle), x.type)
380
- elif dtype.is_int_unsigned():
381
- return tl.tensor(builder.create_minui(x.handle, y.handle), x.type)
382
- else:
383
- raise TypeError(f"Unexpected dtype {dtype}")
384
-
385
-
386
- def maximum(x: tl.tensor, y: tl.tensor, propagate_nan: tl.PropagateNan, builder: ir.builder):
387
- x, y = binary_op_type_checking_impl(x, y, builder)
388
- dtype = x.dtype
389
- if dtype.is_floating():
390
- if propagate_nan == tl.PropagateNan.ALL:
391
- return tl.tensor(builder.create_maximumf(x.handle, y.handle), x.type)
392
- elif propagate_nan == tl.PropagateNan.NONE:
393
- return tl.tensor(builder.create_maxnumf(x.handle, y.handle), x.type)
372
+ raise TypeError(f"Unexpected dtype {dtype}")
373
+
374
+ def maximum(self, x: TensorTy, y: TensorTy, propagate_nan: tl.PropagateNan):
375
+ x, y = self.binary_op_type_checking_impl(x, y)
376
+ dtype = x.dtype
377
+ if dtype.is_floating():
378
+ if propagate_nan == tl.PropagateNan.ALL:
379
+ return self.tensor(self.builder.create_maximumf(x.handle, y.handle), x.type)
380
+ elif propagate_nan == tl.PropagateNan.NONE:
381
+ return self.tensor(self.builder.create_maxnumf(x.handle, y.handle), x.type)
382
+ else:
383
+ raise ValueError(f"Unexpected propagate_nan {propagate_nan}")
384
+ elif dtype.is_int_signed():
385
+ return self.tensor(self.builder.create_maxsi(x.handle, y.handle), x.type)
386
+ elif dtype.is_int_unsigned():
387
+ return self.tensor(self.builder.create_maxui(x.handle, y.handle), x.type)
394
388
  else:
395
- raise ValueError(f"Unexpected propagate_nan {propagate_nan}")
396
- elif dtype.is_int_signed():
397
- return tl.tensor(builder.create_maxsi(x.handle, y.handle), x.type)
398
- elif dtype.is_int_unsigned():
399
- return tl.tensor(builder.create_maxui(x.handle, y.handle), x.type)
400
- else:
401
- raise TypeError(f"Unexpected dtype {dtype}")
389
+ raise TypeError(f"Unexpected dtype {dtype}")
402
390
 
391
+ def clamp(self, x: TensorTy, min: TensorTy, max: TensorTy, propagate_nan: tl.PropagateNan):
392
+ min, max = self.binary_op_type_checking_impl(min, max)
393
+ x, min = self.binary_op_type_checking_impl(x, min)
394
+ x, max = self.binary_op_type_checking_impl(x, max)
403
395
 
404
- def clamp(x: tl.tensor, min: tl.tensor, max: tl.tensor, propagate_nan: tl.PropagateNan, builder: ir.builder):
405
- min, max = binary_op_type_checking_impl(min, max, builder)
406
- x, min = binary_op_type_checking_impl(x, min, builder)
407
- x, max = binary_op_type_checking_impl(x, max, builder)
408
-
409
- dtype = x.dtype
410
- if dtype.is_floating():
411
- return tl.tensor(builder.create_clampf(x.handle, min.handle, max.handle, propagate_nan), x.type)
412
- else:
413
- raise TypeError(f"Unexpected dtype {dtype}. Only floating point clamp is supported")
414
-
396
+ dtype = x.dtype
397
+ if dtype.is_floating():
398
+ return self.tensor(self.builder.create_clampf(x.handle, min.handle, max.handle, propagate_nan), x.type)
399
+ else:
400
+ raise TypeError(f"Unexpected dtype {dtype}. Only floating point clamp is supported")
415
401
 
416
402
  ##############
417
403
  # bitwise ops
418
404
  ##############
419
405
 
420
-
421
- def bitwise_op_type_checking_impl(input: tl.tensor, other: tl.tensor,
422
- builder: ir.builder) -> Tuple[tl.tensor, tl.tensor]:
423
- input, other = binary_op_type_checking_impl(input, other, builder)
424
- input_sca_ty = input.type.scalar
425
- other_sca_ty = other.type.scalar
426
- if not input_sca_ty.is_int() or not other_sca_ty.is_int():
427
- raise IncompatibleTypeErrorImpl(input_sca_ty, other_sca_ty)
428
- ret_sca_ty = integer_promote_impl(input_sca_ty, other_sca_ty)
429
- if ret_sca_ty != input_sca_ty:
430
- input = cast(input, ret_sca_ty, builder)
431
- if ret_sca_ty != other_sca_ty:
432
- other = cast(other, ret_sca_ty, builder)
433
- return input, other
434
-
435
-
436
- def and_(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
437
- input, other = bitwise_op_type_checking_impl(input, other, builder)
438
- return tl.tensor(builder.create_and(input.handle, other.handle), input.type)
439
-
440
-
441
- def or_(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
442
- input, other = bitwise_op_type_checking_impl(input, other, builder)
443
- return tl.tensor(builder.create_or(input.handle, other.handle), input.type)
444
-
445
-
446
- def xor_(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
447
- input, other = bitwise_op_type_checking_impl(input, other, builder)
448
- return tl.tensor(builder.create_xor(input.handle, other.handle), input.type)
449
-
450
-
451
- def logical_and(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
452
- if not input.type.is_int1():
453
- input = bitcast(input, tl.dtype("int1"), builder)
454
- if not other.type.is_int1():
455
- other = bitcast(other, tl.dtype("int1"), builder)
456
- return and_(input, other, builder)
457
-
458
-
459
- def logical_or(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
460
- if not input.type.is_int1():
461
- input = bitcast(input, tl.dtype("int1"), builder)
462
- if not other.type.is_int1():
463
- other = bitcast(other, tl.dtype("int1"), builder)
464
- return or_(input, other, builder)
465
-
466
-
467
- def not_(input: tl.tensor, builder: ir.builder):
468
- if not input.type.is_int1():
469
- input = bitcast(input, tl.dtype("int1"), builder)
470
- return invert(input, builder)
471
-
472
-
473
- def lshr(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
474
- input, other = bitwise_op_type_checking_impl(input, other, builder)
475
- return tl.tensor(builder.create_lshr(input.handle, other.handle), input.type)
476
-
477
-
478
- def ashr(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
479
- input, other = bitwise_op_type_checking_impl(input, other, builder)
480
- return tl.tensor(builder.create_ashr(input.handle, other.handle), input.type)
481
-
482
-
483
- def shl(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
484
- input, other = bitwise_op_type_checking_impl(input, other, builder)
485
- return tl.tensor(builder.create_shl(input.handle, other.handle), input.type)
486
-
406
+ def bitwise_op_type_checking_impl(self, input: TensorTy, other: TensorTy) -> Tuple[TensorTy, TensorTy]:
407
+ input, other = self.binary_op_type_checking_impl(input, other)
408
+ input_sca_ty = input.type.scalar
409
+ other_sca_ty = other.type.scalar
410
+ if not input_sca_ty.is_int() or not other_sca_ty.is_int():
411
+ raise IncompatibleTypeErrorImpl(input_sca_ty, other_sca_ty)
412
+ ret_sca_ty = self.integer_promote_impl(input_sca_ty, other_sca_ty)
413
+ if ret_sca_ty != input_sca_ty:
414
+ input = self.cast(input, ret_sca_ty)
415
+ if ret_sca_ty != other_sca_ty:
416
+ other = self.cast(other, ret_sca_ty)
417
+ return input, other
418
+
419
+ def and_(self, input: TensorTy, other: TensorTy) -> TensorTy:
420
+ input, other = self.bitwise_op_type_checking_impl(input, other)
421
+ return self.tensor(self.builder.create_and(input.handle, other.handle), input.type)
422
+
423
+ def or_(self, input: TensorTy, other: TensorTy) -> TensorTy:
424
+ input, other = self.bitwise_op_type_checking_impl(input, other)
425
+ return self.tensor(self.builder.create_or(input.handle, other.handle), input.type)
426
+
427
+ def xor_(self, input: TensorTy, other: TensorTy) -> TensorTy:
428
+ input, other = self.bitwise_op_type_checking_impl(input, other)
429
+ return self.tensor(self.builder.create_xor(input.handle, other.handle), input.type)
430
+
431
+ def logical_and(self, input: TensorTy, other: TensorTy) -> TensorTy:
432
+ if not input.type.is_int1():
433
+ input = self.bitcast(input, tl.int1)
434
+ if not other.type.is_int1():
435
+ other = self.bitcast(other, tl.int1)
436
+ return self.and_(input, other)
437
+
438
+ def logical_or(self, input: TensorTy, other: TensorTy) -> TensorTy:
439
+ if not input.type.is_int1():
440
+ input = self.bitcast(input, tl.int1)
441
+ if not other.type.is_int1():
442
+ other = self.bitcast(other, tl.int1)
443
+ return self.or_(input, other)
444
+
445
+ def not_(self, input: TensorTy):
446
+ if not input.type.is_int1():
447
+ input = self.bitcast(input, tl.int1)
448
+ return self.invert(input)
449
+
450
+ def lshr(self, input: TensorTy, other: TensorTy) -> TensorTy:
451
+ input, other = self.bitwise_op_type_checking_impl(input, other)
452
+ return self.tensor(self.builder.create_lshr(input.handle, other.handle), input.type)
453
+
454
+ def ashr(self, input: TensorTy, other: TensorTy) -> TensorTy:
455
+ input, other = self.bitwise_op_type_checking_impl(input, other)
456
+ return self.tensor(self.builder.create_ashr(input.handle, other.handle), input.type)
457
+
458
+ def shl(self, input: TensorTy, other: TensorTy) -> TensorTy:
459
+ input, other = self.bitwise_op_type_checking_impl(input, other)
460
+ return self.tensor(self.builder.create_shl(input.handle, other.handle), input.type)
487
461
 
488
462
  # ===----------------------------------------------------------------------===//
489
463
  # Unary Operators
490
464
  # ===----------------------------------------------------------------------===//
491
465
 
466
+ def plus(self, input: TensorTy) -> TensorTy:
467
+ return input
492
468
 
493
- def plus(input: tl.tensor) -> tl.tensor:
494
- return input
495
-
496
-
497
- def minus(input: tl.tensor, builder: ir.builder) -> tl.tensor:
498
- input_sca_ty = input.type.scalar
499
- if input_sca_ty.is_ptr():
500
- raise ValueError("wrong type argument to unary minus (" + input_sca_ty.__repr__() + ")")
501
- _0 = tl.tensor(builder.get_null_value(input_sca_ty.to_ir(builder)), input_sca_ty)
502
- return sub(_0, input, True, builder)
503
-
504
-
505
- def invert(input: tl.tensor, builder: tl.tensor) -> tl.tensor:
506
- input_sca_ty = input.type.scalar
507
- if input_sca_ty.is_ptr() or input_sca_ty.is_floating():
508
- raise ValueError("wrong type argument to unary invert (" + input_sca_ty.__repr__() + ")")
509
- _1 = tl.tensor(builder.get_all_ones_value(input_sca_ty.to_ir(builder)), input_sca_ty)
510
- return xor_(input, _1, builder)
469
+ def minus(self, input: TensorTy) -> TensorTy:
470
+ input_sca_ty = input.type.scalar
471
+ if input_sca_ty.is_ptr():
472
+ raise ValueError("wrong type argument to unary minus (" + input_sca_ty.__repr__() + ")")
473
+ _0 = self.tensor(self.builder.get_null_value(input_sca_ty.to_ir(self.builder)), input_sca_ty)
474
+ return self.sub(_0, input, True)
511
475
 
476
+ def invert(self, input: TensorTy) -> TensorTy:
477
+ input_sca_ty = input.type.scalar
478
+ if input_sca_ty.is_ptr() or input_sca_ty.is_floating():
479
+ raise ValueError("wrong type argument to unary invert (" + input_sca_ty.__repr__() + ")")
480
+ _1 = self.tensor(self.builder.get_all_ones_value(input_sca_ty.to_ir(self.builder)), input_sca_ty)
481
+ return self.xor_(input, _1)
512
482
 
513
483
  # ===----------------------------------------------------------------------===//
514
484
  # Comparison Operators
515
485
  # ===----------------------------------------------------------------------===//
516
- def _bool_like(v: tl.tensor) -> tl.block_type:
517
- if not v.type.is_block():
518
- return tl.int1
519
- shape = v.type.shape
520
- return tl.block_type(tl.int1, shape)
521
-
522
-
523
- def greater_than(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
524
- input, other = binary_op_type_checking_impl(input, other, builder)
525
- scalar_ty = input.type.scalar
526
- # float > float
527
- if scalar_ty.is_floating():
528
- return tl.tensor(builder.create_fcmpOGT(input.handle, other.handle), _bool_like(input))
529
- # > int
530
- elif scalar_ty.is_int():
531
- if scalar_ty.is_int_signed():
532
- return tl.tensor(builder.create_icmpSGT(input.handle, other.handle), _bool_like(input))
533
- else:
534
- return tl.tensor(builder.create_icmpUGT(input.handle, other.handle), _bool_like(input))
535
- raise TypeError(f"unexpected type {scalar_ty}")
536
-
537
-
538
- def greater_equal(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
539
- input, other = binary_op_type_checking_impl(input, other, builder)
540
- scalar_ty = input.type.scalar
541
- # float >= float
542
- if scalar_ty.is_floating():
543
- return tl.tensor(builder.create_fcmpOGE(input.handle, other.handle), _bool_like(input))
544
- # >= int
545
- elif scalar_ty.is_int():
546
- if scalar_ty.is_int_signed():
547
- return tl.tensor(builder.create_icmpSGE(input.handle, other.handle), _bool_like(input))
548
- else:
549
- return tl.tensor(builder.create_icmpUGE(input.handle, other.handle), _bool_like(input))
550
- raise TypeError(f"unexpected type {scalar_ty}")
551
-
552
-
553
- def less_than(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
554
- input, other = binary_op_type_checking_impl(input, other, builder)
555
- scalar_ty = input.type.scalar
556
- # float < float
557
- if scalar_ty.is_floating():
558
- return tl.tensor(builder.create_fcmpOLT(input.handle, other.handle), _bool_like(input))
559
- # < int
560
- elif scalar_ty.is_int():
561
- if scalar_ty.is_int_signed():
562
- return tl.tensor(builder.create_icmpSLT(input.handle, other.handle), _bool_like(input))
563
- else:
564
- return tl.tensor(builder.create_icmpULT(input.handle, other.handle), _bool_like(input))
565
- raise TypeError(f"unexpected type {scalar_ty}")
566
-
567
-
568
- def less_equal(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
569
- input, other = binary_op_type_checking_impl(input, other, builder)
570
- scalar_ty = input.type.scalar
571
- # float < float
572
- if scalar_ty.is_floating():
573
- return tl.tensor(builder.create_fcmpOLE(input.handle, other.handle), _bool_like(input))
574
- # < int
575
- elif scalar_ty.is_int():
576
- if scalar_ty.is_int_signed():
577
- return tl.tensor(builder.create_icmpSLE(input.handle, other.handle), _bool_like(input))
578
- else:
579
- return tl.tensor(builder.create_icmpULE(input.handle, other.handle), _bool_like(input))
580
- raise TypeError(f"unexpected type {scalar_ty}")
581
-
582
-
583
- def equal(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
584
- input, other = binary_op_type_checking_impl(input, other, builder)
585
- scalar_ty = input.type.scalar
586
- # float == float
587
- if scalar_ty.is_floating():
588
- return tl.tensor(builder.create_fcmpOEQ(input.handle, other.handle), _bool_like(input))
589
- # == int
590
- elif scalar_ty.is_int():
591
- return tl.tensor(builder.create_icmpEQ(input.handle, other.handle), _bool_like(input))
592
- raise TypeError(f"unexpected type {scalar_ty}")
593
-
594
-
595
- def not_equal(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
596
- input, other = binary_op_type_checking_impl(input, other, builder)
597
- scalar_ty = input.type.scalar
598
- # float == float
599
- if scalar_ty.is_floating():
600
- return tl.tensor(builder.create_fcmpUNE(input.handle, other.handle), _bool_like(input))
601
- # == int
602
- elif scalar_ty.is_int():
603
- return tl.tensor(builder.create_icmpNE(input.handle, other.handle), _bool_like(input))
604
- raise TypeError(f"unexpected type {scalar_ty}")
605
486
 
487
+ def _bool_like(self, v: TensorTy) -> tl.block_type:
488
+ return v.type.with_element_ty(tl.int1)
489
+
490
+ def greater_than(self, input: TensorTy, other: TensorTy) -> TensorTy:
491
+ input, other = self.binary_op_type_checking_impl(input, other)
492
+ scalar_ty = input.type.scalar
493
+ # float > float
494
+ if scalar_ty.is_floating():
495
+ return self.tensor(self.builder.create_fcmpOGT(input.handle, other.handle), self._bool_like(input))
496
+ # > int
497
+ elif scalar_ty.is_int():
498
+ if scalar_ty.is_int_signed():
499
+ return self.tensor(self.builder.create_icmpSGT(input.handle, other.handle), self._bool_like(input))
500
+ else:
501
+ return self.tensor(self.builder.create_icmpUGT(input.handle, other.handle), self._bool_like(input))
502
+ raise TypeError(f"unexpected type {scalar_ty}")
503
+
504
+ def greater_equal(self, input: TensorTy, other: TensorTy) -> TensorTy:
505
+ input, other = self.binary_op_type_checking_impl(input, other)
506
+ scalar_ty = input.type.scalar
507
+ # float >= float
508
+ if scalar_ty.is_floating():
509
+ return self.tensor(self.builder.create_fcmpOGE(input.handle, other.handle), self._bool_like(input))
510
+ # >= int
511
+ elif scalar_ty.is_int():
512
+ if scalar_ty.is_int_signed():
513
+ return self.tensor(self.builder.create_icmpSGE(input.handle, other.handle), self._bool_like(input))
514
+ else:
515
+ return self.tensor(self.builder.create_icmpUGE(input.handle, other.handle), self._bool_like(input))
516
+ raise TypeError(f"unexpected type {scalar_ty}")
517
+
518
+ def less_than(self, input: TensorTy, other: TensorTy) -> TensorTy:
519
+ input, other = self.binary_op_type_checking_impl(input, other)
520
+ scalar_ty = input.type.scalar
521
+ # float < float
522
+ if scalar_ty.is_floating():
523
+ return self.tensor(self.builder.create_fcmpOLT(input.handle, other.handle), self._bool_like(input))
524
+ # < int
525
+ elif scalar_ty.is_int():
526
+ if scalar_ty.is_int_signed():
527
+ return self.tensor(self.builder.create_icmpSLT(input.handle, other.handle), self._bool_like(input))
528
+ else:
529
+ return self.tensor(self.builder.create_icmpULT(input.handle, other.handle), self._bool_like(input))
530
+ raise TypeError(f"unexpected type {scalar_ty}")
531
+
532
+ def less_equal(self, input: TensorTy, other: TensorTy) -> TensorTy:
533
+ input, other = self.binary_op_type_checking_impl(input, other)
534
+ scalar_ty = input.type.scalar
535
+ # float < float
536
+ if scalar_ty.is_floating():
537
+ return self.tensor(self.builder.create_fcmpOLE(input.handle, other.handle), self._bool_like(input))
538
+ # < int
539
+ elif scalar_ty.is_int():
540
+ if scalar_ty.is_int_signed():
541
+ return self.tensor(self.builder.create_icmpSLE(input.handle, other.handle), self._bool_like(input))
542
+ else:
543
+ return self.tensor(self.builder.create_icmpULE(input.handle, other.handle), self._bool_like(input))
544
+ raise TypeError(f"unexpected type {scalar_ty}")
545
+
546
+ def equal(self, input: TensorTy, other: TensorTy) -> TensorTy:
547
+ input, other = self.binary_op_type_checking_impl(input, other)
548
+ scalar_ty = input.type.scalar
549
+ # float == float
550
+ if scalar_ty.is_floating():
551
+ return self.tensor(self.builder.create_fcmpOEQ(input.handle, other.handle), self._bool_like(input))
552
+ # == int
553
+ elif scalar_ty.is_int():
554
+ return self.tensor(self.builder.create_icmpEQ(input.handle, other.handle), self._bool_like(input))
555
+ raise TypeError(f"unexpected type {scalar_ty}")
556
+
557
+ def not_equal(self, input: TensorTy, other: TensorTy) -> TensorTy:
558
+ input, other = self.binary_op_type_checking_impl(input, other)
559
+ scalar_ty = input.type.scalar
560
+ # float == float
561
+ if scalar_ty.is_floating():
562
+ return self.tensor(self.builder.create_fcmpUNE(input.handle, other.handle), self._bool_like(input))
563
+ # == int
564
+ elif scalar_ty.is_int():
565
+ return self.tensor(self.builder.create_icmpNE(input.handle, other.handle), self._bool_like(input))
566
+ raise TypeError(f"unexpected type {scalar_ty}")
606
567
 
607
568
  # ===----------------------------------------------------------------------===//
608
569
  # Block Creation
609
570
  # ===----------------------------------------------------------------------===//
610
571
 
611
-
612
- def arange(start: int, end: int, builder: ir.builder) -> tl.tensor:
613
- if not isinstance(start, int) or not isinstance(end, int):
614
- raise ValueError("arange's arguments must be of type tl.constexpr")
615
- is_start_int64 = bool(start >> 32)
616
- is_end_int64 = bool(end >> 32)
617
- if is_start_int64 or is_end_int64:
618
- raise ValueError("arange must fit in int32")
619
- if end <= start:
620
- raise ValueError("arange's end argument must be greater than the start argument")
621
- range = end - start
622
- if (range & (range - 1)) != 0:
623
- raise ValueError("arange's range must be a power of 2")
624
- shape = [range]
625
- ret_ty = tl.block_type(tl.int32, shape)
626
- return tl.tensor(builder.create_make_range(start, end), ret_ty)
627
-
628
-
629
- def full(shape: List[int], value, dtype: tl.dtype, builder: ir.builder) -> tl.tensor:
630
- if isinstance(value, tl.tensor):
631
- assert value.numel.value == 1, "only accepts size-1 tensor"
632
- value = cast(value, dtype, builder)
633
- else:
572
def arange(self, start: int, end: int, *, ret_ty: tl.block_type = None) -> TensorTy:
    """Build a range tensor covering ``[start, end)``.

    Args:
        start: first value; must be a Python ``int``.
        end: one past the last value; must be a Python ``int`` greater
            than ``start``.
        ret_ty: optional result type.  When omitted, a one-dimensional
            ``tl.int32`` block of length ``end - start`` is used.

    Raises:
        ValueError: an argument is not an ``int``, has bits set at or above
            position 32, ``end <= start``, or the length is not a power of 2.
    """
    if not (isinstance(start, int) and isinstance(end, int)):
        raise ValueError("arange's arguments must be of type tl.constexpr")
    # A non-zero result of `>> 32` means the value has bits beyond the low 32.
    # Python's shift is arithmetic, so every negative value is rejected too.
    if (start >> 32) or (end >> 32):
        raise ValueError("arange must fit in int32")
    if end <= start:
        raise ValueError("arange's end argument must be greater than the start argument")
    length = end - start
    # A power of two has exactly one bit set, so clearing the lowest bit gives 0.
    if length & (length - 1):
        raise ValueError("arange's range must be a power of 2")
    if ret_ty is None:
        ret_ty = tl.block_type(tl.int32, [length])
    ret_ty_ir = ret_ty.to_ir(self.builder)
    handle = self.builder.create_make_range(ret_ty_ir, start, end)
    return self.tensor(handle, ret_ty)
589
+
590
def scalar_constant(self, value, dtype: tl.dtype) -> TensorTy:
    """Materialize the Python scalar ``value`` as a constant of type ``dtype``.

    A value comparing equal to 0 is emitted through ``get_null_value``; any
    other value goes through the builder method named ``get_<dtype.name>``.

    Raises:
        ValueError: ``dtype`` is ``None``.
    """
    if dtype is None:
        raise ValueError("dtype must be specified when value is not a tensor")
    if value == 0:
        handle = self.builder.get_null_value(dtype.to_ir(self.builder))
    else:
        # The builder exposes one constant factory per dtype name.
        make_constant = getattr(self.builder, f"get_{dtype.name}")
        handle = make_constant(value)
    return self.tensor(handle, dtype)
643
600
 
644
- return splat(value, shape, builder)
601
def make_scalar(self, value, dtype: tl.dtype) -> TensorTy:
    """Return ``value`` as a scalar of type ``dtype``.

    A ``tl.tensor`` argument must hold exactly one element and is cast to
    ``dtype``; anything else is turned into a constant via
    ``scalar_constant``.
    """
    if not isinstance(value, tl.tensor):
        # plain Python scalar
        return self.scalar_constant(value, dtype)
    assert value.numel.value == 1, "only accepts size-1 tensor"
    return self.cast(value, dtype)
645
607
 
608
def full(self, shape: List[int], value, dtype: tl.dtype) -> TensorTy:
    """Create a tensor of ``shape`` with every element set to ``value``.

    ``value`` is first converted to a scalar of ``dtype`` with
    ``make_scalar`` and then splatted to ``shape``.
    """
    fill = self.make_scalar(value, dtype)
    return self.splat(fill, shape)
646
610
 
647
611
  # ===----------------------------------------------------------------------===//
648
612
  # Shape Manipulation
649
613
  # ===----------------------------------------------------------------------===//
650
614
 
615
def splat(self, value: TensorTy, shape: List[int]) -> TensorTy:
    """Replicate the non-block ``value`` into a block of ``shape``.

    An empty ``shape`` returns ``value`` unchanged.  Passing a block tensor
    trips an assertion.
    """
    assert not value.type.is_block(), "Cannot splat a block tensor"
    if len(shape) == 0:
        return value
    block_ty = tl.block_type(value.dtype, shape)
    handle = self.builder.create_splat(block_ty.to_ir(self.builder), value.handle)
    return self.tensor(handle, block_ty)
621
+
622
def reshape(self, input: TensorTy, dst_shape: List[int], can_reorder: bool) -> TensorTy:
    """Reshape ``input`` to ``dst_shape``.

    ``can_reorder`` is forwarded unchanged to ``create_reshape``.

    Raises:
        ValueError: the product of ``dst_shape`` differs from
            ``input.type.numel``.
    """
    total = 1
    for extent in dst_shape:
        total *= extent
    if input.type.numel != total:
        raise ValueError("reshape() cannot change total number of elements in tensor")
    result_ty = tl.block_type(input.type.scalar, dst_shape)
    handle = self.builder.create_reshape(input.handle, dst_shape, can_reorder)
    return self.tensor(handle, result_ty)
630
+
631
def expand_dims(self, input: TensorTy, axis: int) -> TensorTy:
    """Insert a dimension of size 1 into ``input`` at position ``axis``.

    A non-block input is splatted to the new shape; a block input goes
    through ``create_expand_dims``.
    """
    new_shape = [tl._unwrap_if_constexpr(extent) for extent in input.shape]
    new_shape.insert(axis, 1)

    if input.type.is_block():
        result_ty = tl.block_type(input.type.scalar, new_shape)
        return self.tensor(self.builder.create_expand_dims(input.handle, axis), result_ty)

    return self.splat(input, shape=new_shape)
640
+
641
def cat(self, lhs: TensorTy, rhs: TensorTy, can_reorder: bool) -> TensorTy:
    """Concatenate two tensors along their first dimension.

    The result length is ``lhs.shape[0] + rhs.shape[0]`` and the element
    type is taken from ``lhs``.  ``can_reorder`` must be truthy and ``lhs``
    must be one-dimensional; both are checked with assertions.
    """
    assert can_reorder, "current implementation of `cat` always may reorder elements"
    assert len(lhs.shape) == 1
    total_len = lhs.shape[0] + rhs.shape[0]
    result_ty = tl.block_type(lhs.type.scalar, [total_len])
    return self.tensor(self.builder.create_cat(lhs.handle, rhs.handle), result_ty)
646
+
647
def join(self, a: TensorTy, b: TensorTy) -> TensorTy:
    """Join ``a`` and ``b`` along a new trailing dimension of size 2.

    The operands are broadcast against each other first.  When the
    broadcast operands have an empty shape they are expanded to one
    dimension before the join, and the result is reshaped to ``[2]``.
    """
    a, b = self.broadcast_impl_value(a, b)

    # The IR can't handle joining two scalars, so upcast them to 1D tensors,
    # then downcast the result.
    inputs_had_empty_shape = a.shape == []
    if inputs_had_empty_shape:
        a = self.expand_dims(a, 0)
        b = self.expand_dims(b, 0)

    # Keep the new extent in the same representation as the existing ones.
    two = tl.constexpr(2) if isinstance(a.shape[-1], tl.constexpr) else 2
    joined_shape = a.shape + [two]

    joined_ty = tl.block_type(a.type.scalar, joined_shape)
    joined = self.tensor(self.builder.create_join(a.handle, b.handle), joined_ty)

    if inputs_had_empty_shape:
        joined = self.reshape(joined, [2], can_reorder=False)

    return joined
734
670
 
671
def split(self, a: TensorTy) -> Tuple[TensorTy, TensorTy]:
    """Split ``a`` along its last dimension into two tensors.

    ``a`` must have at least one dimension and its last dimension must be
    2; both are checked with assertions.  Each result has the shape of
    ``a`` without the last dimension.
    """
    assert (len(a.shape) > 0)
    assert (tl._unwrap_if_constexpr(a.shape[-1]) == 2)

    half_shape = a.shape[:-1]
    half_ty = tl.block_type(a.type.scalar, half_shape)
    lhs_handle, rhs_handle = self.builder.create_split(a.handle)
    first = self.tensor(lhs_handle, half_ty)
    second = self.tensor(rhs_handle, half_ty)
    return (first, second)
805
682
 
683
def permute(self, input: TensorTy, dims: Tuple[int]) -> TensorTy:
    """Reorder the dimensions of ``input`` according to ``dims``.

    Raises:
        ValueError: ``dims`` has a different length than ``input.shape`` or
            is not a permutation of ``0 .. n-1``.
    """
    rank = len(dims)
    if len(input.shape) != rank:
        raise ValueError("permute dims must have the same length as input shape")
    ordered = sorted(tl._unwrap_if_constexpr(d) for d in dims)
    if ordered != list(range(rank)):
        raise ValueError(f"permute dims must be a permutation of 0, 1, ..., n-1, but were {dims}")

    permuted_shape = [input.shape[d] for d in dims]
    result_ty = tl.block_type(input.type.scalar, permuted_shape)
    return self.tensor(self.builder.create_trans(input.handle, dims), result_ty)
691
+
692
def broadcast_impl_shape(self, input: TensorTy, shape: Tuple[int]) -> TensorTy:
    """Broadcast ``input`` to ``shape``.

    A non-block input is splatted.  A block input must have the same rank
    as ``shape``; it is returned unchanged when the shapes already match.

    Raises:
        ValueError: the ranks differ, or a dimension of ``input`` is neither
            1 nor equal to the requested size.
    """
    if not input.type.is_block():
        return self.splat(input, shape)

    src_shape = input.type.get_block_shapes()
    if len(src_shape) != len(shape):
        raise ValueError(f"Cannot broadcast, rank mismatch: {src_shape}, {shape}")
    if shape == src_shape:
        return input

    # Only size-1 dimensions may be stretched.
    for i, item in enumerate(src_shape):
        if item != 1:
            if shape[i] != item:
                raise ValueError(f"Cannot broadcast, the expanded size of the tensor ({shape[i]})"
                                 f" must match the existing size ({item}) at non-singleton dimension"
                                 f" {i}: {src_shape}, {shape}")

    result_ty = tl.block_type(input.type.scalar, shape)
    return self.tensor(self.builder.create_broadcast(input.handle, shape), result_ty)
707
+
708
def broadcast_impl_value(self, lhs: TensorTy, rhs: TensorTy) -> TensorTy:
    """Make ``lhs`` and ``rhs`` shape-compatible and return them as a pair.

    * block / scalar: the scalar is splatted to the block's type.
    * block / block: the lower-rank operand gets leading size-1 axes, then
      both are broadcast to the common shape.
    * scalar / scalar: both are returned unchanged.

    Raises:
        ValueError: two block dimensions differ and neither is 1.
    """
    lhs_ty = lhs.type
    rhs_ty = rhs.type
    lhs_is_block = lhs_ty.is_block()
    rhs_is_block = rhs_ty.is_block()

    if lhs_is_block and not rhs_is_block:
        # make_shape_compatible(block, scalar)
        rhs_ty = lhs_ty.with_element_ty(rhs_ty.scalar)
        rhs = self.tensor(self.builder.create_splat(rhs_ty.to_ir(self.builder), rhs.handle), rhs_ty)
    elif rhs_is_block and not lhs_is_block:
        # make_shape_compatible(scalar, block)
        lhs_ty = rhs_ty.with_element_ty(lhs_ty.scalar)
        lhs = self.tensor(self.builder.create_splat(lhs_ty.to_ir(self.builder), lhs.handle), lhs_ty)
    elif lhs_is_block and rhs_is_block:
        # make_shape_compatible(block, block)
        lhs_shape = lhs_ty.get_block_shapes()
        rhs_shape = rhs_ty.get_block_shapes()

        rank_gap = len(rhs_shape) - len(lhs_shape)
        if rank_gap > 0:
            # Add new axes to lhs
            for _ in range(rank_gap):
                lhs = self.tensor(self.builder.create_expand_dims(lhs.handle, 0),
                                  tl.block_type(lhs_ty.scalar, [1] + lhs_shape.values))
                lhs_ty = lhs.type
                lhs_shape = lhs_ty.get_block_shapes()
        elif rank_gap < 0:
            # Add new axes to rhs
            for _ in range(-rank_gap):
                rhs = self.tensor(self.builder.create_expand_dims(rhs.handle, 0),
                                  tl.block_type(rhs_ty.scalar, [1] + rhs_shape.values))
                rhs_ty = rhs.type
                rhs_shape = rhs_ty.get_block_shapes()
        assert len(rhs_shape) == len(lhs_shape)

        # Per dimension: a size-1 side adopts the other side's extent.
        common_shape = []
        for i, left in enumerate(lhs_shape):
            right = rhs_shape[i]
            if left == 1:
                common_shape.append(right)
            elif (right == 1) or (right == left):
                common_shape.append(left)
            else:
                raise ValueError("Cannot make_shape_compatible: incompatible dimensions "
                                 f"at index {i!s}: {left!s} and {right!s}")
        if lhs_shape != common_shape:
            lhs_result_ty = tl.block_type(lhs_ty.scalar, common_shape)
            lhs = self.tensor(self.builder.create_broadcast(lhs.handle, common_shape), lhs_result_ty)
        if rhs_shape != common_shape:
            rhs_result_ty = tl.block_type(rhs_ty.scalar, common_shape)
            rhs = self.tensor(self.builder.create_broadcast(rhs.handle, common_shape), rhs_result_ty)
    # (scalar, scalar) => returns original blocks
    return lhs, rhs
806
759
 
807
760
  #######
808
761
  # cast
809
762
  #######
810
763
 
811
-
812
- def _str_to_rounding_mode(rounding_mode: Optional[str]):
813
- if rounding_mode is None:
814
- return None
815
- if rounding_mode == 'rtne':
816
- return ir.ROUNDING_MODE.RTNE
817
- if rounding_mode == 'rtz':
818
- return ir.ROUNDING_MODE.RTZ
819
- raise ValueError(f"Invalid rounding mode: {rounding_mode}. Supported rounding modes are 'rtne' and 'rtz'.")
820
-
821
-
822
- def bitcast(input: tl.tensor, dst_ty: tl.dtype, builder: ir.builder) -> tl.tensor:
823
- src_ty = input.type
824
- if src_ty.is_block():
825
- dst_ty = tl.block_type(dst_ty.scalar, input.type.get_block_shapes())
826
- if src_ty == dst_ty:
827
- return input
828
- src_sca_ty = src_ty.scalar
829
- dst_sca_ty = dst_ty.scalar
830
- if src_sca_ty.is_ptr() or dst_sca_ty.is_ptr():
831
- return cast(input, dst_ty, builder)
832
- # Bitcast
833
- src_bits = src_sca_ty.primitive_bitwidth
834
- dst_bits = dst_sca_ty.primitive_bitwidth
835
- if src_bits != dst_bits:
836
- raise ValueError("Cannot bitcast data-type of size " + str(src_bits) + " to "
837
- "data-type of size " + str(dst_bits))
838
- return tl.tensor(builder.create_bitcast(input.handle, dst_ty.to_ir(builder)), dst_ty)
839
-
840
-
841
- def cast(input: tl.tensor, dst_ty: tl.dtype, builder: ir.builder,
842
- fp_downcast_rounding: Optional[str] = None) -> tl.tensor:
843
- src_ty = input.type
844
- if src_ty.is_block():
845
- dst_ty = tl.block_type(dst_ty.scalar, input.type.get_block_shapes())
846
- if src_ty == dst_ty:
847
- return input
848
-
849
- src_sca_ty = src_ty.scalar
850
- dst_sca_ty = dst_ty.scalar
851
-
852
- # For fp downcasting default rounding mode should be RTNE, for all other conversions it should
853
- # not be set
854
- fp_downcast_rounding = _str_to_rounding_mode(fp_downcast_rounding)
855
- use_custom_rounding = False
856
- if dst_sca_ty.is_floating() and src_sca_ty.is_floating(
857
- ) and dst_sca_ty.primitive_bitwidth < src_sca_ty.primitive_bitwidth:
858
- if fp_downcast_rounding is None: fp_downcast_rounding = ir.ROUNDING_MODE.RTNE
859
- elif fp_downcast_rounding != ir.ROUNDING_MODE.RTNE: use_custom_rounding = True
860
- else:
861
- if fp_downcast_rounding is not None:
862
- raise ValueError("fp_downcast_rounding should be set only for truncating fp conversions. "
863
- "Source scalar type is " + str(src_sca_ty) + " and destination type is " + str(dst_sca_ty))
864
-
865
- if (src_sca_ty.is_fp8e4b15() or dst_sca_ty.is_fp8e4b15()):
866
- assert builder.codegen_fns.get(
867
- "convert_custom_types") is not None, "target doesn't provide conversion for this type."
868
- return builder.codegen_fns["convert_custom_types"](input, dst_ty, fp_downcast_rounding, _builder=builder)
869
- # Casting with customized floating types involved: fp8 <=> bf16, fp16, fp32, fp64
870
- # and non-default rounding modes for downcasting
871
- if (src_sca_ty.is_fp8() and dst_sca_ty.is_floating()) or \
872
- (src_sca_ty.is_floating() and dst_sca_ty.is_fp8()) or \
873
- use_custom_rounding:
874
- return tl.tensor(builder.create_fp_to_fp(input.handle, dst_ty.to_ir(builder), fp_downcast_rounding), dst_ty)
875
-
876
- # bf16 <=> (not fp32)
877
- if (src_sca_ty.is_fp16() and not dst_sca_ty.is_fp32()) or \
878
- (src_sca_ty.is_bf16() and not dst_sca_ty.is_fp32()):
879
- return cast(cast(input, tl.float32, builder), dst_sca_ty, builder)
880
-
881
- # Standard floating types' casting: truncation
882
- # fp64 => fp32, fp16, bf16
883
- # fp32 => fp16, bf16
884
- truncate_fp = src_sca_ty.is_floating() and \
885
- dst_sca_ty.is_floating() and \
886
- src_sca_ty.primitive_bitwidth > dst_sca_ty.primitive_bitwidth
887
- if truncate_fp:
888
- return tl.tensor(builder.create_fp_trunc(input.handle, dst_ty.to_ir(builder)), dst_ty)
889
-
890
- # Standard floating types' casting: extension
891
- # fp32 => fp64
892
- # fp16 => fp32, fp64
893
- # bf16 => fp32, fp64
894
- ext_fp = src_sca_ty.is_floating() and \
895
- dst_sca_ty.is_floating() and \
896
- src_sca_ty.primitive_bitwidth < dst_sca_ty.primitive_bitwidth
897
- if ext_fp:
898
- return tl.tensor(builder.create_fp_ext(input.handle, dst_ty.to_ir(builder)), dst_ty)
899
-
900
- # Casting between integer types
901
- if src_sca_ty.is_int() and dst_sca_ty.is_int() and \
902
- (src_sca_ty.int_bitwidth != dst_sca_ty.int_bitwidth or src_sca_ty.int_signedness != dst_sca_ty.int_signedness):
903
- sign_extend = src_sca_ty.is_int_signed() and not src_sca_ty.is_bool()
904
- if dst_sca_ty.is_bool():
905
- ty = input.dtype.to_ir(builder)
906
- _0 = tl.tensor(builder.get_null_value(ty), input.dtype)
907
- return not_equal(input, _0, builder)
764
def _str_to_rounding_mode(self, rounding_mode: Optional[str]):
    """Translate a rounding-mode name into its ``ir.ROUNDING_MODE`` value.

    ``None`` is passed through unchanged.

    Raises:
        ValueError: the name is neither ``'rtne'`` nor ``'rtz'``.
    """
    if rounding_mode is not None:
        if rounding_mode == 'rtne':
            return ir.ROUNDING_MODE.RTNE
        if rounding_mode == 'rtz':
            return ir.ROUNDING_MODE.RTZ
        raise ValueError(f"Invalid rounding mode: {rounding_mode}. Supported rounding modes are 'rtne' and 'rtz'.")
    return None
772
+
773
def bitcast(self, input: TensorTy, dst_ty: tl.dtype) -> TensorTy:
    """Reinterpret the bits of ``input`` as ``dst_ty``.

    For block inputs the destination keeps the block type with its element
    type replaced.  If the types already match, ``input`` is returned
    unchanged.  When either scalar type is a pointer, the work is delegated
    to ``cast``.

    Raises:
        ValueError: source and destination scalar bit widths differ.
    """
    src_ty = input.type
    if src_ty.is_block():
        dst_ty = src_ty.with_element_ty(dst_ty.scalar)
    if src_ty == dst_ty:
        return input

    src_scalar = src_ty.scalar
    dst_scalar = dst_ty.scalar
    if src_scalar.is_ptr() or dst_scalar.is_ptr():
        return self.cast(input, dst_ty)

    # Bitcast
    src_bits = src_scalar.primitive_bitwidth
    dst_bits = dst_scalar.primitive_bitwidth
    if src_bits != dst_bits:
        raise ValueError("Cannot bitcast data-type of size " + str(src_bits) + " to "
                         "data-type of size " + str(dst_bits))
    handle = self.builder.create_bitcast(input.handle, dst_ty.to_ir(self.builder))
    return self.tensor(handle, dst_ty)
790
+
791
def cast(self, input: TensorTy, dst_ty: tl.dtype, fp_downcast_rounding: Optional[str] = None) -> TensorTy:
    """Convert ``input`` to ``dst_ty``.

    If source and destination scalar types are equal, ``input`` is returned
    unchanged.  For block inputs the destination keeps the block type with
    its element type replaced.

    Args:
        input: tensor to convert.
        dst_ty: requested destination type.
        fp_downcast_rounding: ``'rtne'``, ``'rtz'`` or ``None``.  Only
            allowed for float-to-narrower-float conversions, where ``None``
            is replaced by RTNE.

    Raises:
        ValueError: ``fp_downcast_rounding`` is set for a conversion that is
            not a truncating float conversion, or names an unknown mode.
        AssertionError: no conversion rule matches, or an fp8e4b15
            conversion is requested and the target provides no
            ``convert_custom_types`` hook.  These are raised explicitly
            rather than with ``assert`` so they survive ``python -O``.
    """
    src_ty = input.type
    src_sca_ty = src_ty.scalar
    dst_sca_ty = dst_ty.scalar
    if src_sca_ty == dst_sca_ty:
        return input
    if src_ty.is_block():
        dst_ty = src_ty.with_element_ty(dst_sca_ty)

    # For fp downcasting default rounding mode should be RTNE, for all other conversions it should
    # not be set
    fp_downcast_rounding = self._str_to_rounding_mode(fp_downcast_rounding)
    use_custom_rounding = False
    if dst_sca_ty.is_floating() and src_sca_ty.is_floating(
    ) and dst_sca_ty.primitive_bitwidth < src_sca_ty.primitive_bitwidth:
        if fp_downcast_rounding is None:
            fp_downcast_rounding = ir.ROUNDING_MODE.RTNE
        elif fp_downcast_rounding != ir.ROUNDING_MODE.RTNE:
            use_custom_rounding = True
    else:
        if fp_downcast_rounding is not None:
            raise ValueError("fp_downcast_rounding should be set only for truncating fp conversions. "
                             "Source scalar type is " + str(src_sca_ty) + " and destination type is " +
                             str(dst_sca_ty))

    if (src_sca_ty.is_fp8e4b15() or dst_sca_ty.is_fp8e4b15()):
        # This type is converted by a target-supplied hook.
        convert_custom_types = self.builder.codegen_fns.get("convert_custom_types")
        if convert_custom_types is None:
            raise AssertionError("target doesn't provide conversion for this type.")
        return convert_custom_types(input, dst_ty, fp_downcast_rounding, _semantic=self)
    # Casting with customized floating types involved: fp8 <=> bf16, fp16, fp32, fp64
    # and non-default rounding modes for downcasting
    if (src_sca_ty.is_fp8() and dst_sca_ty.is_floating()) or \
       (src_sca_ty.is_floating() and dst_sca_ty.is_fp8()) or \
       use_custom_rounding:
        return self.tensor(
            self.builder.create_fp_to_fp(input.handle, dst_ty.to_ir(self.builder), fp_downcast_rounding), dst_ty)

    # fp16 / bf16 => anything but fp32 goes through an intermediate fp32
    if (src_sca_ty.is_fp16() and not dst_sca_ty.is_fp32()) or \
       (src_sca_ty.is_bf16() and not dst_sca_ty.is_fp32()):
        return self.cast(self.cast(input, tl.float32), dst_sca_ty)

    # Standard floating types' casting: truncation
    # fp64 => fp32, fp16, bf16
    # fp32 => fp16, bf16
    truncate_fp = src_sca_ty.is_floating() and \
        dst_sca_ty.is_floating() and \
        src_sca_ty.primitive_bitwidth > dst_sca_ty.primitive_bitwidth
    if truncate_fp:
        return self.tensor(self.builder.create_fp_trunc(input.handle, dst_ty.to_ir(self.builder)), dst_ty)

    # Standard floating types' casting: extension
    # fp32 => fp64
    # fp16 => fp32, fp64
    # bf16 => fp32, fp64
    ext_fp = src_sca_ty.is_floating() and \
        dst_sca_ty.is_floating() and \
        src_sca_ty.primitive_bitwidth < dst_sca_ty.primitive_bitwidth
    if ext_fp:
        return self.tensor(self.builder.create_fp_ext(input.handle, dst_ty.to_ir(self.builder)), dst_ty)

    # Casting between integer types
    if src_sca_ty.is_int() and dst_sca_ty.is_int() and \
       (src_sca_ty.int_bitwidth != dst_sca_ty.int_bitwidth or src_sca_ty.int_signedness != dst_sca_ty.int_signedness):
        sign_extend = src_sca_ty.is_int_signed() and not src_sca_ty.is_bool()
        if dst_sca_ty.is_bool():
            # int => bool is expressed as `input != 0`
            ty = input.dtype.to_ir(self.builder)
            _0 = self.tensor(self.builder.get_null_value(ty), input.dtype)
            return self.not_equal(input, _0)
        else:
            return self.tensor(self.builder.create_int_cast(input.handle, dst_ty.to_ir(self.builder), sign_extend),
                               dst_ty)

    # Casting standard floating types to integer types
    if src_sca_ty.is_standard_floating() and dst_sca_ty.is_int():
        if dst_sca_ty.is_bool():
            # float => bool is expressed as `input != 0`
            ty = input.dtype.to_ir(self.builder)
            _0 = self.tensor(self.builder.get_null_value(ty), input.dtype)
            return self.not_equal(input, _0)
        elif dst_sca_ty.is_int_signed():
            return self.tensor(self.builder.create_fp_to_si(input.handle, dst_ty.to_ir(self.builder)), dst_ty)
        else:
            return self.tensor(self.builder.create_fp_to_ui(input.handle, dst_ty.to_ir(self.builder)), dst_ty)

    # Casting integer types to standard floating types
    if src_sca_ty.is_int() and dst_sca_ty.is_standard_floating():
        if src_sca_ty.is_bool() or not src_sca_ty.is_int_signed():
            return self.tensor(self.builder.create_ui_to_fp(input.handle, dst_ty.to_ir(self.builder)), dst_ty)
        else:
            return self.tensor(self.builder.create_si_to_fp(input.handle, dst_ty.to_ir(self.builder)), dst_ty)

    # Casting pointer types to integer types
    if src_sca_ty.is_ptr() and dst_sca_ty.is_int():
        bitwidth = dst_sca_ty.int_bitwidth
        if bitwidth == 64:
            return self.tensor(self.builder.create_ptr_to_int(input.handle, dst_ty.to_ir(self.builder)), dst_ty)
        if bitwidth == 1:
            return self.not_equal(self.cast(input, tl.int64), self.tensor(self.builder.get_int64(0), tl.int64))
        # Other widths fall through to the final error below.

    # Casting integer types to pointer types
    if src_sca_ty.is_int() and dst_sca_ty.is_ptr():
        return self.tensor(self.builder.create_int_to_ptr(input.handle, dst_ty.to_ir(self.builder)), dst_ty)

    # Casting pointer types to pointer types
    if src_sca_ty.is_ptr() and dst_sca_ty.is_ptr():
        return self.tensor(self.builder.create_bitcast(input.handle, dst_ty.to_ir(self.builder)), dst_ty)

    # Explicit raise: `assert False` would be stripped under `python -O`
    # and the function would silently return None.
    raise AssertionError(f'cannot cast {input} to {dst_ty}')
947
897
 
948
898
  # ===----------------------------------------------------------------------===//
949
899
  # Memory Operators
950
900
  # ===----------------------------------------------------------------------===//
951
901
 
902
def _str_to_load_cache_modifier(self, cache_modifier):
    """Translate a load cache-modifier string into ``ir.CACHE_MODIFIER``.

    A falsy argument selects ``NONE``.  Recognized strings are ``".ca"``,
    ``".cg"`` and ``".cv"``.

    Raises:
        ValueError: any other non-empty value.
    """
    if not cache_modifier:
        return ir.CACHE_MODIFIER.NONE  # default
    if cache_modifier == ".ca":
        return ir.CACHE_MODIFIER.CA
    if cache_modifier == ".cg":
        return ir.CACHE_MODIFIER.CG
    if cache_modifier == ".cv":
        return ir.CACHE_MODIFIER.CV
    raise ValueError(f"Cache modifier {cache_modifier} not supported")
914
+
915
def _str_to_store_cache_modifier(self, cache_modifier):
    """Translate a store cache-modifier string into ``ir.CACHE_MODIFIER``.

    A falsy argument selects ``NONE``.  Recognized strings are ``".wb"``,
    ``".cg"``, ``".cs"`` and ``".wt"``.

    Raises:
        ValueError: any other non-empty value.
    """
    if not cache_modifier:
        return ir.CACHE_MODIFIER.NONE  # default
    if cache_modifier == ".wb":
        return ir.CACHE_MODIFIER.WB
    if cache_modifier == ".cg":
        return ir.CACHE_MODIFIER.CG
    if cache_modifier == ".cs":
        return ir.CACHE_MODIFIER.CS
    if cache_modifier == ".wt":
        return ir.CACHE_MODIFIER.WT
    raise ValueError(f"Cache modifier {cache_modifier} not supported")
929
+
930
def _str_to_eviction_policy(self, eviction_policy):
    """Translate an eviction-policy string into ``ir.EVICTION_POLICY``.

    A falsy argument selects ``NORMAL``.  Recognized strings are
    ``"evict_last"`` and ``"evict_first"``.

    Raises:
        ValueError: any other non-empty value.
    """
    if not eviction_policy:
        return ir.EVICTION_POLICY.NORMAL  # default
    if eviction_policy == "evict_last":
        return ir.EVICTION_POLICY.EVICT_LAST
    if eviction_policy == "evict_first":
        return ir.EVICTION_POLICY.EVICT_FIRST
    raise ValueError(f"Eviction policy {eviction_policy} not supported")
940
+
941
def _str_to_padding_option(self, padding_option):
    """Translate a padding-option string into ``ir.PADDING_OPTION``.

    A falsy argument yields ``None``.  Recognized strings are ``"zero"``
    and ``"nan"``.

    Raises:
        ValueError: any other non-empty value.
    """
    if not padding_option:
        return None  # default
    if padding_option == "zero":
        return ir.PADDING_OPTION.PAD_ZERO
    if padding_option == "nan":
        return ir.PADDING_OPTION.PAD_NAN
    raise ValueError(f"Padding option {padding_option} not supported")
951
+
952
def _str_to_sem(self, sem_option):
    """Translate a memory-semantic string into ``ir.MEM_SEMANTIC``.

    A falsy argument selects ``ACQUIRE_RELEASE``.  Recognized strings are
    ``"acquire"``, ``"release"``, ``"acq_rel"`` and ``"relaxed"``.

    Raises:
        ValueError: any other non-empty value.
    """
    if not sem_option:
        return ir.MEM_SEMANTIC.ACQUIRE_RELEASE
    if sem_option == "acquire":
        return ir.MEM_SEMANTIC.ACQUIRE
    if sem_option == "release":
        return ir.MEM_SEMANTIC.RELEASE
    if sem_option == "acq_rel":
        return ir.MEM_SEMANTIC.ACQUIRE_RELEASE
    if sem_option == "relaxed":
        return ir.MEM_SEMANTIC.RELAXED
    raise ValueError(f"Memory semantic {sem_option} not supported")
966
+
967
+ def _str_to_scope(self, scope_option):
968
+ scope = ir.MEM_SYNC_SCOPE.GPU
969
+ if scope_option:
970
+ if scope_option == "gpu":
971
+ scope = ir.MEM_SYNC_SCOPE.GPU
972
+ elif scope_option == "cta":
973
+ scope = ir.MEM_SYNC_SCOPE.CTA
974
+ elif scope_option == "sys":
975
+ scope = ir.MEM_SYNC_SCOPE.SYSTEM
976
+ else:
977
+ raise ValueError(f"Memory semantic {scope_option} not supported")
978
+ return scope
979
+
980
+ def _canonicalize_boundary_check(self, boundary_check, block_shape):
981
+ if boundary_check:
982
+ if not hasattr(boundary_check, "__iter__"):
983
+ boundary_check = [boundary_check]
984
+ boundary_check = [elem.value if isinstance(elem, tl.constexpr) else elem for elem in boundary_check]
985
+ for dim in boundary_check:
986
+ assert isinstance(dim, int) and 0 <= dim < len(block_shape)
987
+ assert len(boundary_check) > 0
988
+ assert len(boundary_check) == len(set(boundary_check)), "Duplicate dimension in `boundary_check`"
989
+ return sorted(boundary_check)
990
+ return ()
991
+
992
+ def _load_block_pointer(self, ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile):
993
+ # Load by a block pointer: `pointer_type<block_type<>>`
994
+ # Block pointer can not have `mask` and `other` arguments
995
+ if mask is not None or other is not None:
996
+ raise ValueError("`mask` and `other` arguments cannot be specified for loading block pointers")
952
997
 
953
- def _str_to_load_cache_modifier(cache_modifier):
954
- cache = ir.CACHE_MODIFIER.NONE # default
955
- if cache_modifier:
956
- if cache_modifier == ".ca":
957
- cache = ir.CACHE_MODIFIER.CA
958
- elif cache_modifier == ".cg":
959
- cache = ir.CACHE_MODIFIER.CG
960
- elif cache_modifier == ".cv":
961
- cache = ir.CACHE_MODIFIER.CV
962
- else:
963
- raise ValueError(f"Cache modifier {cache_modifier} not supported")
964
- return cache
965
-
966
-
967
- def _str_to_store_cache_modifier(cache_modifier):
968
- cache = ir.CACHE_MODIFIER.NONE # default
969
- if cache_modifier:
970
- if cache_modifier == ".wb":
971
- cache = ir.CACHE_MODIFIER.WB
972
- elif cache_modifier == ".cg":
973
- cache = ir.CACHE_MODIFIER.CG
974
- elif cache_modifier == ".cs":
975
- cache = ir.CACHE_MODIFIER.CS
976
- elif cache_modifier == ".wt":
977
- cache = ir.CACHE_MODIFIER.WT
978
- else:
979
- raise ValueError(f"Cache modifier {cache_modifier} not supported")
980
- return cache
998
+ elt_ty = ptr.type.element_ty.element_ty
999
+ assert elt_ty != tl.int1, "`tl.int1` should be rewritten in `tl.make_block_ptr`"
1000
+ if elt_ty.is_int() and padding == ir.PADDING_OPTION.PAD_NAN:
1001
+ raise ValueError("Padding option `nan` is not supported for integer block pointers")
981
1002
 
1003
+ # `dst_ty` is de-referenced type of the pointer type
1004
+ dst_ty = ptr.type.element_ty
982
1005
 
983
- def _str_to_eviction_policy(eviction_policy):
984
- eviction = ir.EVICTION_POLICY.NORMAL # default
985
- if eviction_policy:
986
- if eviction_policy == "evict_last":
987
- eviction = ir.EVICTION_POLICY.EVICT_LAST
988
- elif eviction_policy == "evict_first":
989
- eviction = ir.EVICTION_POLICY.EVICT_FIRST
990
- else:
991
- raise ValueError(f"Eviction policy {eviction_policy} not supported")
992
- return eviction
1006
+ # Check `boundary_check` argument
1007
+ boundary_check = self._canonicalize_boundary_check(boundary_check, dst_ty.get_block_shapes())
1008
+
1009
+ # Build IR
1010
+ return self.tensor(
1011
+ self.builder.create_tensor_pointer_load(ptr.handle, boundary_check, padding, cache, eviction, is_volatile),
1012
+ dst_ty)
993
1013
 
1014
+ def _load_legacy(self, ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile):
1015
+ # Load by a tensor of pointers or a pointer of scalar: `block_type<pointer_type<>>` or `pointer_type<>`
1016
+ if not ptr.type.scalar.is_ptr():
1017
+ raise ValueError(f"Unsupported ptr type {ptr.type.__repr__()} in `tl.load`")
1018
+
1019
+ # Check `mask`, `other`, `boundary_check`, and `padding` arguments
1020
+ if mask is None and other is not None:
1021
+ raise ValueError("`other` cannot be provided without `mask`")
1022
+ if padding or boundary_check:
1023
+ raise ValueError("`padding_option` or `boundary_check` argument is not supported for loading a tensor of"
1024
+ "pointers or loading a scalar. Because the compiler does not know the boundary; please "
1025
+ "use block pointers (defined by `make_block_ptr`) instead")
1026
+
1027
+ # For a pointer of scalar, check the type of `mask` and `other`
1028
+ if not ptr.type.is_block():
1029
+ if mask and mask.type.is_block():
1030
+ raise ValueError("Mask argument cannot be block type if pointer argument is not a block")
1031
+ if other and other.type.is_block():
1032
+ raise ValueError("Other argument cannot be block type if pointer argument is not a block")
1033
+
1034
+ # Make `mask` and `other` into the same shape as `ptr`
1035
+ if ptr.type.is_block():
1036
+ if mask is not None:
1037
+ mask = self.broadcast_impl_shape(mask, ptr.type.get_block_shapes())
1038
+ if other is not None:
1039
+ other = self.broadcast_impl_shape(other, ptr.type.get_block_shapes())
1040
+
1041
+ # Get `pointer_type<elt_ty>` and `elt_ty`
1042
+ ptr_ty = ptr.type.scalar
1043
+ elt_ty = ptr_ty.element_ty
1044
+
1045
+ # Treat `pointer_type<tl.int1>` as `pointer_type<tl.int8>`
1046
+ is_bool = elt_ty == tl.int1
1047
+ if is_bool:
1048
+ elt_ty = tl.int8
1049
+ ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space)
1050
+ ptr = self.cast(ptr, ptr_ty)
1051
+
1052
+ # Cast `other` into `elt_ty` type
1053
+ if other is not None:
1054
+ other = self.cast(other, elt_ty)
994
1055
 
995
- def _str_to_padding_option(padding_option):
996
- padding = None # default
997
- if padding_option:
998
- if padding_option == "zero":
999
- padding = ir.PADDING_OPTION.PAD_ZERO
1000
- elif padding_option == "nan":
1001
- padding = ir.PADDING_OPTION.PAD_NAN
1056
+ # Create loaded result type `dst_ty`
1057
+ if ptr.type.is_block():
1058
+ dst_ty = ptr.type.with_element_ty(elt_ty)
1002
1059
  else:
1003
- raise ValueError(f"Padding option {padding_option} not supported")
1004
- return padding
1005
-
1006
-
1007
- def _str_to_sem(sem_option):
1008
- sem = ir.MEM_SEMANTIC.ACQUIRE_RELEASE
1009
- if sem_option:
1010
- if sem_option == "acquire":
1011
- sem = ir.MEM_SEMANTIC.ACQUIRE
1012
- elif sem_option == "release":
1013
- sem = ir.MEM_SEMANTIC.RELEASE
1014
- elif sem_option == "acq_rel":
1015
- sem = ir.MEM_SEMANTIC.ACQUIRE_RELEASE
1016
- elif sem_option == "relaxed":
1017
- sem = ir.MEM_SEMANTIC.RELAXED
1060
+ # Load by de-referencing the pointer of scalar
1061
+ dst_ty = elt_ty
1062
+
1063
+ # Build IR
1064
+ if mask is None:
1065
+ ret = self.tensor(self.builder.create_load(ptr.handle, cache, eviction, is_volatile), dst_ty)
1018
1066
  else:
1019
- raise ValueError(f"Memory semantic {sem_option} not supported")
1020
- return sem
1021
-
1022
-
1023
- def _str_to_scope(scope_option):
1024
- scope = ir.MEM_SYNC_SCOPE.GPU
1025
- if scope_option:
1026
- if scope_option == "gpu":
1027
- scope = ir.MEM_SYNC_SCOPE.GPU
1028
- elif scope_option == "cta":
1029
- scope = ir.MEM_SYNC_SCOPE.CTA
1030
- elif scope_option == "sys":
1031
- scope = ir.MEM_SYNC_SCOPE.SYSTEM
1067
+ ret = self.tensor(
1068
+ self.builder.create_masked_load(ptr.handle, mask.handle, other.handle if other else None, cache,
1069
+ eviction, is_volatile), dst_ty)
1070
+ if is_bool:
1071
+ ret = self.cast(ret, tl.int1)
1072
+ return ret
1073
+
1074
+ def load(self, ptr: TensorTy, mask: Optional[TensorTy], other: Optional[TensorTy], boundary_check: Tuple,
1075
+ padding_option: str, cache_modifier: str, eviction_policy: str, is_volatile: bool) -> TensorTy:
1076
+ # Cache, eviction and padding options
1077
+ cache = self._str_to_load_cache_modifier(cache_modifier)
1078
+ eviction = self._str_to_eviction_policy(eviction_policy)
1079
+ padding = self._str_to_padding_option(padding_option)
1080
+
1081
+ if ptr.type.is_ptr() and ptr.type.element_ty.is_block():
1082
+ # Load by a block pointer: `pointer_type<block_type<>>`
1083
+ return self._load_block_pointer(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile)
1032
1084
  else:
1033
- raise ValueError(f"Memory semantic {scope_option} not supported")
1034
- return scope
1035
-
1036
-
1037
- def _canonicalize_boundary_check(boundary_check, block_shape):
1038
- if boundary_check:
1039
- if not hasattr(boundary_check, "__iter__"):
1040
- boundary_check = [boundary_check]
1041
- boundary_check = [elem.value if isinstance(elem, tl.constexpr) else elem for elem in boundary_check]
1042
- for dim in boundary_check:
1043
- assert isinstance(dim, int) and 0 <= dim < len(block_shape)
1044
- assert len(boundary_check) > 0
1045
- assert len(boundary_check) == len(set(boundary_check)), "Duplicate dimension in `boundary_check`"
1046
- return sorted(boundary_check)
1047
- return ()
1048
-
1049
-
1050
- def _load_block_pointer(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, builder):
1051
- # Load by a block pointer: `pointer_type<block_type<>>`
1052
- # Block pointer can not have `mask` and `other` arguments
1053
- if mask is not None or other is not None:
1054
- raise ValueError("`mask` and `other` arguments cannot be specified for loading block pointers")
1055
-
1056
- elt_ty = ptr.type.element_ty.element_ty
1057
- assert elt_ty != tl.int1, "`tl.int1` should be rewritten in `tl.make_block_ptr`"
1058
- if elt_ty.is_int() and padding == ir.PADDING_OPTION.PAD_NAN:
1059
- raise ValueError("Padding option `nan` is not supported for integer block pointers")
1060
-
1061
- # `dst_ty` is de-referenced type of the pointer type
1062
- dst_ty = ptr.type.element_ty
1063
-
1064
- # Check `boundary_check` argument
1065
- boundary_check = _canonicalize_boundary_check(boundary_check, dst_ty.get_block_shapes())
1066
-
1067
- # Build IR
1068
- return tl.tensor(
1069
- builder.create_tensor_pointer_load(ptr.handle, boundary_check, padding, cache, eviction, is_volatile), dst_ty)
1070
-
1071
-
1072
- def _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, builder):
1073
- # Load by a tensor of pointers or a pointer of scalar: `block_type<pointer_type<>>` or `pointer_type<>`
1074
- if not ptr.type.scalar.is_ptr():
1075
- raise ValueError(f"Unsupported ptr type {ptr.type.__repr__()} in `tl.load`")
1076
-
1077
- # Check `mask`, `other`, `boundary_check`, and `padding` arguments
1078
- if mask is None and other is not None:
1079
- raise ValueError("`other` cannot be provided without `mask`")
1080
- if padding or boundary_check:
1081
- raise ValueError("`padding_option` or `boundary_check` argument is not supported for loading a tensor of"
1082
- "pointers or loading a scalar. Because the compiler does not know the boundary; please "
1083
- "use block pointers (defined by `make_block_ptr`) instead")
1084
-
1085
- # For a pointer of scalar, check the type of `mask` and `other`
1086
- if not ptr.type.is_block():
1087
- if mask and mask.type.is_block():
1088
- raise ValueError("Mask argument cannot be block type if pointer argument is not a block")
1089
- if other and other.type.is_block():
1090
- raise ValueError("Other argument cannot be block type if pointer argument is not a block")
1091
-
1092
- # Make `mask` and `other` into the same shape as `ptr`
1093
- if ptr.type.is_block():
1094
- if mask is not None:
1095
- mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder)
1096
- if other is not None:
1097
- other = broadcast_impl_shape(other, ptr.type.get_block_shapes(), builder)
1098
-
1099
- # Get `pointer_type<elt_ty>` and `elt_ty`
1100
- ptr_ty = ptr.type.scalar
1101
- elt_ty = ptr_ty.element_ty
1102
-
1103
- # Treat `pointer_type<tl.int1>` as `pointer_type<tl.int8>`
1104
- is_bool = elt_ty == tl.int1
1105
- if is_bool:
1106
- elt_ty = tl.int8
1107
- ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space)
1108
- ptr = cast(ptr, ptr_ty, builder)
1109
-
1110
- # Cast `other` into `elt_ty` type
1111
- if other is not None:
1112
- other = cast(other, elt_ty, builder)
1113
-
1114
- # Create loaded result type `dst_ty`
1115
- if ptr.type.is_block():
1116
- shape = ptr.type.get_block_shapes()
1117
- dst_ty = tl.block_type(elt_ty, shape)
1118
- else:
1119
- # Load by de-referencing the pointer of scalar
1120
- dst_ty = elt_ty
1121
-
1122
- # Build IR
1123
- if mask is None:
1124
- ret = tl.tensor(builder.create_load(ptr.handle, cache, eviction, is_volatile), dst_ty)
1125
- else:
1126
- ret = tl.tensor(
1127
- builder.create_masked_load(ptr.handle, mask.handle, other.handle if other else None, cache, eviction,
1128
- is_volatile), dst_ty)
1129
- if is_bool:
1130
- ret = cast(ret, tl.int1, builder)
1131
- return ret
1132
-
1133
-
1134
- def load(ptr: tl.tensor, mask: Optional[tl.tensor], other: Optional[tl.tensor], boundary_check: Tuple,
1135
- padding_option: str, cache_modifier: str, eviction_policy: str, is_volatile: bool,
1136
- builder: ir.builder) -> tl.tensor:
1137
- # Cache, eviction and padding options
1138
- cache = _str_to_load_cache_modifier(cache_modifier)
1139
- eviction = _str_to_eviction_policy(eviction_policy)
1140
- padding = _str_to_padding_option(padding_option)
1141
-
1142
- if ptr.type.is_ptr() and ptr.type.element_ty.is_block():
1143
- # Load by a block pointer: `pointer_type<block_type<>>`
1144
- return _load_block_pointer(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, builder)
1145
- else:
1146
- # Load by a tensor of pointers or a pointer of scalar: `block_type<pointer_type<>>` or `pointer_type<>`
1147
- return _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, builder)
1148
-
1149
-
1150
- def reinterpret_tensor_descriptor(desc_ptr: tl.tensor, block_ty: tl.block_type, builder: ir.builder):
1151
- handle = builder.create_reinterpret_tensor_descriptor(desc_ptr.handle, block_ty.to_ir(builder))
1152
- return tl._experimental_tensor_descriptor_base(handle, block_ty)
1153
-
1154
-
1155
- def validate_descriptor_block(shape, dtype):
1156
- if len(shape) != 2:
1157
- return
1158
- # Due to limitations of the shared memory encoding, the TMA bounding box has
1159
- # to be at least as big as the swizzle tile.
1160
- assert shape[0] >= 8, f"tensor descriptor block shape must have at least 8 rows, but got {shape[0]}"
1161
- min_cols = 32 // dtype.primitive_bitwidth * 8
1162
- assert shape[
1163
- 1] >= min_cols, f"{dtype} tensor descriptor block shape must have at least {min_cols} columns, but got {shape[1]}"
1164
-
1165
-
1166
- def descriptor_load(desc: tl._experimental_tensor_desciptor_base, offsets, cache_modifier: str, eviction_policy: str,
1167
- builder: ir.builder) -> tl.tensor:
1168
- assert isinstance(desc, tl._experimental_tensor_descriptor_base)
1169
- validate_descriptor_block(desc.block_shape, desc.dtype)
1170
- ndim = len(desc.block_shape)
1171
- assert len(offsets) == ndim, f"expected {ndim} offsets, but got {len(offsets)}"
1172
-
1173
- offsets = _convert_to_ir_values(builder, offsets, require_i64=False)
1174
- x = builder.create_descriptor_load(desc.handle, offsets, _str_to_load_cache_modifier(cache_modifier),
1175
- _str_to_eviction_policy(eviction_policy))
1176
- return tl.tensor(x, desc.block_type)
1177
-
1178
-
1179
- def descriptor_store(desc: tl._experimental_tensor_descriptor_base, value: tl.tensor, offsets,
1180
- builder: ir.builder) -> tl.tensor:
1181
- assert isinstance(desc, tl._experimental_tensor_descriptor_base)
1182
- validate_descriptor_block(desc.block_shape, desc.dtype)
1183
- ndim = len(desc.block_shape)
1184
- assert len(offsets) == ndim, f"expected {ndim} offsets, but got {len(offsets)}"
1185
- assert value.shape == desc.block_shape
1186
-
1187
- offsets = _convert_to_ir_values(builder, offsets, require_i64=False)
1188
- return tl.tensor(builder.create_descriptor_store(desc.handle, value.handle, offsets), tl.void)
1189
-
1190
-
1191
- def descriptor_gather(desc, x_offsets, y_offset, cache_modifier: str, eviction_policy: str,
1192
- builder: ir.builder) -> tl.tensor:
1193
- assert isinstance(desc, tl._experimental_tensor_descriptor_base)
1194
- assert cache_modifier == "", "cache modifier is not supported yet"
1195
- assert eviction_policy == "", "eviction policy is not supported yet"
1196
-
1197
- # Validate descriptor.
1198
- assert len(desc.block_shape) == 2, f"descriptor must be 2D, but got {desc.block_shape}"
1199
- assert desc.block_shape[0] == 1, f"descriptor block must have 1 row, but got {desc.block_shape}"
1200
-
1201
- # Validate offsets.
1202
- assert len(x_offsets.shape) == 1, f"x offsets must be 1D, but got {x_offsets.shape}"
1203
-
1204
- # Validate minimum block size.
1205
- assert x_offsets.shape[0] >= 8, f"descriptor gather must have at least 8 rows, but got {x_offsets.shape}"
1206
- dtype = desc.dtype
1207
- min_cols = 32 // dtype.primitive_bitwidth * 8
1208
- assert desc.block_shape[
1209
- 1] >= min_cols, f"descriptor gather of {dtype} must have at least {min_cols} columns, but got {desc.block_shape[1]}"
1210
-
1211
- type = tl.block_type(desc.dtype, [x_offsets.shape[0], desc.block_shape[1]])
1212
- y_offset = _convert_to_ir_values(builder, (y_offset, ), require_i64=False)[0]
1213
- x = builder.create_descriptor_gather(desc.handle, x_offsets.handle, y_offset, type.to_ir(builder))
1214
- return tl.tensor(x, type)
1215
-
1216
-
1217
- def descriptor_scatter(desc, value: tl.tensor, x_offsets, y_offset, builder: ir.builder) -> tl.tensor:
1218
- assert isinstance(desc, tl._experimental_tensor_descriptor_base)
1219
-
1220
- # Validate descriptor.
1221
- assert len(desc.block_shape) == 2, f"descriptor must be 2D, but got {desc.block_shape}"
1222
- assert desc.block_shape[0] == 1, f"descriptor block must have 1 row, but got {desc.block_shape}"
1223
-
1224
- # Validate offsets.
1225
- assert len(x_offsets.shape) == 1, f"x offsets must be 1D, but got {x_offsets.shapae}"
1226
-
1227
- # Validate minimum block size.
1228
- assert x_offsets.shape[0] >= 8, f"descriptor scatter must have at least 8 rows, but got {x_offsets.shape}"
1229
- dtype = desc.dtype
1230
- min_cols = 32 // dtype.primitive_bitwidth * 8
1231
- assert desc.block_shape[
1232
- 1] >= min_cols, f"descriptor scatter of {dtype} must have at least {min_cols} columns, but got {desc.block_shape[1]}"
1233
-
1234
- y_offset = _convert_to_ir_values(builder, (y_offset, ), require_i64=False)[0]
1235
- builder.create_descriptor_scatter(desc.handle, value.handle, x_offsets.handle, y_offset)
1236
- return tl.tensor(None, tl.void)
1237
-
1238
-
1239
- def tensormap_create(
1240
- desc_ptr: tl.tensor,
1241
- global_address: tl.tensor,
1242
- box_dim: List[tl.tensor],
1243
- global_dim: List[tl.tensor],
1244
- global_stride: List[tl.tensor],
1245
- element_stride: List[tl.tensor],
1246
- elem_type: int,
1247
- interleave_layout: int,
1248
- swizzle_mode: int,
1249
- fill_mode: int,
1250
- builder: ir.builder,
1251
- ) -> tl.tensor:
1252
- assert not global_stride or global_stride[0].dtype == tl.int64
1253
- return tl.tensor(
1254
- builder.create_tensormap_create(
1255
- desc_ptr.handle,
1256
- global_address.handle,
1257
- [x.handle for x in box_dim],
1258
- [x.handle for x in global_dim],
1259
- [x.handle for x in global_stride],
1260
- [x.handle for x in element_stride],
1261
- elem_type,
1262
- interleave_layout,
1263
- swizzle_mode,
1264
- fill_mode,
1265
- ),
1266
- tl.void,
1267
- )
1268
-
1269
-
1270
- def tensormap_fenceproxy_acquire(desc_ptr: tl.tensor, builder: ir.builder) -> tl.tensor:
1271
- return tl.tensor(builder.create_tensormap_fenceproxy_acquire(desc_ptr.handle), tl.void)
1272
-
1273
-
1274
- def _store_block_pointer(ptr, val, mask, boundary_check, cache, eviction, builder):
1275
- # Store by a block pointer: `pointer_type<block_type<>>`
1276
- # Block pointers can not have the `mask` argument
1277
- if mask is not None:
1278
- raise ValueError("`mask` and `other` arguments cannot be specified for loading block pointers")
1279
-
1280
- # Check same shape and element type
1281
- block_shape = ptr.type.element_ty.get_block_shapes()
1282
- if not val.type.is_block():
1283
- val = broadcast_impl_shape(val, block_shape, builder)
1284
- assert val.type.is_block(), "Value argument must be block type or a scalar"
1285
- assert block_shape == val.type.get_block_shapes(
1286
- ), f"Block shape({block_shape}) and value shape({val.type.get_block_shapes()}) mismatch"
1287
- assert ptr.type.element_ty.element_ty == val.type.element_ty, f"Block element type({ptr.type.element_ty.element_ty}) and value element type({val.type.element_ty}) mismatch"
1288
-
1289
- elt_ty = ptr.type.element_ty.element_ty
1290
- assert elt_ty != tl.int1, "`tl.int1` should be rewritten in `tl.make_block_ptr`"
1291
-
1292
- # Check `boundary_check` argument
1293
- boundary_check = _canonicalize_boundary_check(boundary_check, block_shape)
1294
-
1295
- # Cast to target data type
1296
- val = cast(val, elt_ty, builder)
1297
-
1298
- # Build IR
1299
- return tl.tensor(builder.create_tensor_pointer_store(ptr.handle, val.handle, boundary_check, cache, eviction),
1300
- tl.void)
1301
-
1302
-
1303
- def _store_legacy(ptr, val, mask, boundary_check, cache, eviction, builder):
1304
- # Store by a tensor of pointers or a pointer of scalar: `block_type<pointer_type<>>` or `pointer_type<>`
1305
- if not ptr.type.scalar.is_ptr():
1306
- raise ValueError(f"Unsupported ptr type {ptr.type.__repr__()} in `tl.store`")
1307
-
1308
- # Check `boundary_check` argument
1309
- if boundary_check:
1310
- raise ValueError("`boundary_check` argument is not supported for storing a tensor of pointers or storing a "
1311
- "scalar. Because the compiler does not know the boundary; please use block pointers "
1312
- "(defined by `make_block_ptr`) instead")
1313
-
1314
- # For a pointer of scalar, check the type of `val` and `mask`
1315
- if not ptr.type.is_block():
1316
- if val.type.is_block():
1317
- raise ValueError("Value argument cannot be block type if pointer argument is not a block")
1318
- if mask and mask.type.is_block():
1319
- raise ValueError("Mask argument cannot be block type if pointer argument is not a block")
1320
-
1321
- # Make `mask` and `val` into the same shape as `ptr`
1322
- if ptr.type.is_block():
1323
- val = broadcast_impl_shape(val, ptr.type.get_block_shapes(), builder)
1085
+ # Load by a tensor of pointers or a pointer of scalar: `block_type<pointer_type<>>` or `pointer_type<>`
1086
+ return self._load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile)
1087
+
1088
def descriptor_load(self, desc: tl.tensor_descriptor_base, offsets, cache_modifier: str,
                    eviction_policy: str) -> TensorTy:
    """Emit a descriptor load and wrap the result as a tensor of `desc.block_type`.

    `offsets` must contain one entry per dimension of `desc.block_shape`.
    `cache_modifier` and `eviction_policy` are strings, converted by the
    corresponding `_str_to_*` helpers.
    """
    assert isinstance(desc, tl.tensor_descriptor_base)
    expected = len(desc.block_shape)
    assert len(offsets) == expected, f"expected {expected} offsets, but got {len(offsets)}"

    ir_offsets = self._convert_to_ir_values(offsets, require_i64=False)
    cache = self._str_to_load_cache_modifier(cache_modifier)
    eviction = self._str_to_eviction_policy(eviction_policy)
    loaded = self.builder.create_descriptor_load(desc.handle, ir_offsets, cache, eviction)
    return self.tensor(loaded, desc.block_type)
1098
+
1099
def validate_store_like(self, desc: tl.tensor_descriptor_base, value: TensorTy, offsets) -> None:
    """Assert that `desc`, `value` and `offsets` are consistent with each other.

    Checks that `desc` is a tensor descriptor, that there is one offset per
    block dimension, and that `value.shape` equals the descriptor's block shape.
    """
    assert isinstance(desc, tl.tensor_descriptor_base)
    block_shape = desc.block_shape
    rank = len(block_shape)
    assert len(offsets) == rank, f"expected {rank} offsets, but got {len(offsets)}"
    assert value.shape == block_shape
1104
+
1105
def descriptor_store(self, desc: tl.tensor_descriptor_base, value: TensorTy, offsets) -> TensorTy:
    """Emit a descriptor store of `value` at `offsets`; the result has type `tl.void`."""
    self.validate_store_like(desc, value, offsets)
    ir_offsets = self._convert_to_ir_values(offsets, require_i64=False)
    store_op = self.builder.create_descriptor_store(desc.handle, value.handle, ir_offsets)
    return self.tensor(store_op, tl.void)
1109
+
1110
def _descriptor_reduce(self, kind, desc, value, offsets):
    """Convert `offsets` to IR values and emit a descriptor reduce of `kind`.

    Shared tail of the `descriptor_atomic_*` methods; callers are expected
    to have validated `desc`, `value` and `offsets` already. The result has
    type `tl.void`.
    """
    ir_offsets = self._convert_to_ir_values(offsets, require_i64=False)
    reduce_op = self.builder.create_descriptor_reduce(kind, desc.handle, value.handle, ir_offsets)
    return self.tensor(reduce_op, tl.void)

def descriptor_atomic_add(self, desc: tl.tensor_descriptor_base, value: TensorTy, offsets) -> TensorTy:
    """Emit a descriptor reduce of kind ADD after validating arguments and dtype."""
    self.validate_store_like(desc, value, offsets)
    assert desc.dtype in {tl.uint32, tl.int32, tl.uint64, tl.float32, tl.float16, tl.bfloat16}, "Unsupported dtype"
    return self._descriptor_reduce(ir.DESCRIPTOR_REDUCE_KIND.ADD, desc, value, offsets)

def _has_native_tma(self):
    """Return True when the active target is the "cuda" backend with arch >= 90."""
    target = driver.active.get_current_target()
    return (target.backend == "cuda" and target.arch >= 90)

def _descriptor_atomic_min_max_supported(self, dtype):
    """Assert that `dtype` is accepted by the descriptor atomic min/max methods.

    float16 and bfloat16 are additionally asserted to require native TMA
    support on the active target.
    """
    assert dtype in {tl.uint32, tl.int32, tl.uint64, tl.int64, tl.float16, tl.bfloat16}, "Unsupported dtype"
    if dtype in {tl.float16, tl.bfloat16}:
        assert self._has_native_tma(), "16-bit float types require native tma support"

def descriptor_atomic_min(self, desc: tl.tensor_descriptor_base, value: TensorTy, offsets) -> TensorTy:
    """Emit a descriptor reduce of kind MIN after validating arguments and dtype."""
    self.validate_store_like(desc, value, offsets)
    self._descriptor_atomic_min_max_supported(desc.dtype)
    return self._descriptor_reduce(ir.DESCRIPTOR_REDUCE_KIND.MIN, desc, value, offsets)

def descriptor_atomic_max(self, desc: tl.tensor_descriptor_base, value: TensorTy, offsets) -> TensorTy:
    """Emit a descriptor reduce of kind MAX after validating arguments and dtype."""
    self.validate_store_like(desc, value, offsets)
    self._descriptor_atomic_min_max_supported(desc.dtype)
    return self._descriptor_reduce(ir.DESCRIPTOR_REDUCE_KIND.MAX, desc, value, offsets)

def descriptor_atomic_and(self, desc: tl.tensor_descriptor_base, value: TensorTy, offsets) -> TensorTy:
    """Emit a descriptor reduce of kind AND; only 32/64-bit integer dtypes are accepted."""
    self.validate_store_like(desc, value, offsets)
    assert desc.dtype in {tl.uint32, tl.int32, tl.uint64, tl.int64}, "Unsupported dtype"
    return self._descriptor_reduce(ir.DESCRIPTOR_REDUCE_KIND.AND, desc, value, offsets)

def descriptor_atomic_or(self, desc: tl.tensor_descriptor_base, value: TensorTy, offsets) -> TensorTy:
    """Emit a descriptor reduce of kind OR; only 32/64-bit integer dtypes are accepted."""
    self.validate_store_like(desc, value, offsets)
    assert desc.dtype in {tl.uint32, tl.int32, tl.uint64, tl.int64}, "Unsupported dtype"
    return self._descriptor_reduce(ir.DESCRIPTOR_REDUCE_KIND.OR, desc, value, offsets)

def descriptor_atomic_xor(self, desc: tl.tensor_descriptor_base, value: TensorTy, offsets) -> TensorTy:
    """Emit a descriptor reduce of kind XOR; only 32/64-bit integer dtypes are accepted."""
    self.validate_store_like(desc, value, offsets)
    assert desc.dtype in {tl.uint32, tl.int32, tl.uint64, tl.int64}, "Unsupported dtype"
    return self._descriptor_reduce(ir.DESCRIPTOR_REDUCE_KIND.XOR, desc, value, offsets)
1160
+
1161
def descriptor_gather(self, desc, x_offsets, y_offset, cache_modifier: str, eviction_policy: str) -> TensorTy:
    """Emit a descriptor gather and wrap the result as a 2D block tensor.

    The result block has `x_offsets.shape[0]` rows and `desc.block_shape[1]`
    columns of `desc.dtype`. `cache_modifier` and `eviction_policy` must be
    empty strings; other values are asserted against as unsupported.
    """
    assert isinstance(desc, tl.tensor_descriptor_base)
    assert cache_modifier == "", "cache modifier is not supported yet"
    assert eviction_policy == "", "eviction policy is not supported yet"

    # The descriptor must describe a 2D block with a single row.
    block_shape = desc.block_shape
    assert len(block_shape) == 2, f"descriptor must be 2D, but got {block_shape}"
    assert block_shape[0] == 1, f"descriptor block must have 1 row, but got {block_shape}"

    # The row offsets must form a 1D tensor.
    offsets_shape = x_offsets.shape
    assert len(offsets_shape) == 1, f"x offsets must be 1D, but got {offsets_shape}"

    # Enforce the minimum gather size: at least 8 rows, and a column count
    # derived from the element bit width.
    assert offsets_shape[0] >= 8, f"descriptor gather must have at least 8 rows, but got {offsets_shape}"
    dtype = desc.dtype
    min_cols = 32 // dtype.primitive_bitwidth * 8
    num_cols = block_shape[1]
    assert num_cols >= min_cols, f"descriptor gather of {dtype} must have at least {min_cols} columns, but got {num_cols}"

    result_ty = tl.block_type(dtype, [offsets_shape[0], num_cols])
    (ir_y_offset, ) = self._convert_to_ir_values((y_offset, ), require_i64=False)[:1]
    gathered = self.builder.create_descriptor_gather(desc.handle, x_offsets.handle, ir_y_offset,
                                                     result_ty.to_ir(self.builder))
    return self.tensor(gathered, result_ty)
1184
+
1185
def descriptor_scatter(self, desc, value: TensorTy, x_offsets, y_offset) -> TensorTy:
    """Emit a descriptor scatter of `value`; the result wraps `None` with type `tl.void`.

    The descriptor must describe a 2D block with exactly one row, and
    `x_offsets` must be a 1D tensor with at least 8 entries. The descriptor's
    column count must be at least ``32 // bitwidth * 8`` for its dtype.
    """
    assert isinstance(desc, tl.tensor_descriptor_base)

    # Validate descriptor.
    assert len(desc.block_shape) == 2, f"descriptor must be 2D, but got {desc.block_shape}"
    assert desc.block_shape[0] == 1, f"descriptor block must have 1 row, but got {desc.block_shape}"

    # Validate offsets. The message reads `.shape`; it previously read the
    # misspelled `.shapae`, so a failing check raised AttributeError instead
    # of the intended AssertionError.
    assert len(x_offsets.shape) == 1, f"x offsets must be 1D, but got {x_offsets.shape}"

    # Validate minimum block size.
    assert x_offsets.shape[0] >= 8, f"descriptor scatter must have at least 8 rows, but got {x_offsets.shape}"
    dtype = desc.dtype
    min_cols = 32 // dtype.primitive_bitwidth * 8
    assert desc.block_shape[
        1] >= min_cols, f"descriptor scatter of {dtype} must have at least {min_cols} columns, but got {desc.block_shape[1]}"

    y_offset = self._convert_to_ir_values((y_offset, ), require_i64=False)[0]
    self.builder.create_descriptor_scatter(desc.handle, value.handle, x_offsets.handle, y_offset)
    return self.tensor(None, tl.void)
1205
+
1206
+ def _store_block_pointer(self, ptr, val, mask, boundary_check, cache, eviction):
1207
+ # Store by a block pointer: `pointer_type<block_type<>>`
1208
+ # Block pointers can not have the `mask` argument
1324
1209
  if mask is not None:
1325
- mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder)
1210
+ raise ValueError("`mask` and `other` arguments cannot be specified for loading block pointers")
1326
1211
 
1327
- ptr_ty = ptr.type.scalar
1328
- elt_ty = ptr_ty.element_ty
1212
+ # Check same shape and element type
1213
+ block_shape = ptr.type.element_ty.get_block_shapes()
1214
+ if not val.type.is_block():
1215
+ val = self.broadcast_impl_shape(val, block_shape)
1216
+ assert val.type.is_block(), "Value argument must be block type or a scalar"
1217
+ assert block_shape == val.type.get_block_shapes(
1218
+ ), f"Block shape({block_shape}) and value shape({val.type.get_block_shapes()}) mismatch"
1219
+ assert ptr.type.element_ty.element_ty == val.type.element_ty, f"Block element type({ptr.type.element_ty.element_ty}) and value element type({val.type.element_ty}) mismatch"
1329
1220
 
1330
- # Treat `pointer_type<tl.int1>` as `pointer_type<tl.int8>`
1331
- if elt_ty == tl.int1:
1332
- elt_ty = tl.int8
1333
- ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space)
1334
- ptr = cast(ptr, ptr_ty, builder)
1221
+ elt_ty = ptr.type.element_ty.element_ty
1222
+ assert elt_ty != tl.int1, "`tl.int1` should be rewritten in `tl.make_block_ptr`"
1335
1223
 
1336
- # Cast to target data type
1337
- val = cast(val, elt_ty, builder)
1224
+ # Check `boundary_check` argument
1225
+ boundary_check = self._canonicalize_boundary_check(boundary_check, block_shape)
1338
1226
 
1339
- # Build IR
1340
- if mask is None:
1341
- return tl.tensor(builder.create_store(ptr.handle, val.handle, cache, eviction), tl.void)
1342
- if not mask.type.scalar.is_bool():
1343
- raise ValueError("Mask must have boolean scalar type")
1344
- return tl.tensor(builder.create_masked_store(ptr.handle, val.handle, mask.handle, cache, eviction), tl.void)
1227
+ # Cast to target data type
1228
+ val = self.cast(val, elt_ty)
1345
1229
 
1230
+ # Build IR
1231
+ return self.tensor(
1232
+ self.builder.create_tensor_pointer_store(ptr.handle, val.handle, boundary_check, cache, eviction), tl.void)
1346
1233
 
1347
- def store(ptr: tl.tensor, val: tl.tensor, mask: Optional[tl.tensor], boundary_check, cache_modifier: str,
1348
- eviction_policy: str, builder: ir.builder) -> tl.tensor:
1349
- # Cache and eviction options
1350
- cache = _str_to_store_cache_modifier(cache_modifier)
1351
- eviction = _str_to_eviction_policy(eviction_policy)
1352
-
1353
- if ptr.type.is_const() or ptr.type.scalar.is_const():
1354
- raise ValueError("Cannot store to a constant pointer")
1355
-
1356
- if ptr.type.is_ptr() and ptr.type.element_ty.is_block():
1357
- # Store by a block pointer: `pointer_type<block_type<>>`
1358
- return _store_block_pointer(ptr, val, mask, boundary_check, cache, eviction, builder)
1359
- else:
1234
+ def _store_legacy(self, ptr, val, mask, boundary_check, cache, eviction):
1360
1235
  # Store by a tensor of pointers or a pointer of scalar: `block_type<pointer_type<>>` or `pointer_type<>`
1361
- return _store_legacy(ptr, val, mask, boundary_check, cache, eviction, builder)
1362
-
1236
+ if not ptr.type.scalar.is_ptr():
1237
+ raise ValueError(f"Unsupported ptr type {ptr.type.__repr__()} in `tl.store`")
1238
+
1239
+ # Check `boundary_check` argument
1240
+ if boundary_check:
1241
+ raise ValueError("`boundary_check` argument is not supported for storing a tensor of pointers or storing a "
1242
+ "scalar. Because the compiler does not know the boundary; please use block pointers "
1243
+ "(defined by `make_block_ptr`) instead")
1244
+
1245
+ # For a pointer of scalar, check the type of `val` and `mask`
1246
+ if not ptr.type.is_block():
1247
+ if val.type.is_block():
1248
+ raise ValueError("Value argument cannot be block type if pointer argument is not a block")
1249
+ if mask and mask.type.is_block():
1250
+ raise ValueError("Mask argument cannot be block type if pointer argument is not a block")
1251
+
1252
+ # Make `mask` and `val` into the same shape as `ptr`
1253
+ if ptr.type.is_block():
1254
+ val = self.broadcast_impl_shape(val, ptr.type.get_block_shapes())
1255
+ if mask is not None:
1256
+ mask = self.broadcast_impl_shape(mask, ptr.type.get_block_shapes())
1257
+
1258
+ ptr_ty = ptr.type.scalar
1259
+ elt_ty = ptr_ty.element_ty
1260
+
1261
+ # Treat `pointer_type<tl.int1>` as `pointer_type<tl.int8>`
1262
+ if elt_ty == tl.int1:
1263
+ elt_ty = tl.int8
1264
+ ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space)
1265
+ ptr = self.cast(ptr, ptr_ty)
1266
+
1267
+ # Cast to target data type
1268
+ val = self.cast(val, elt_ty)
1269
+
1270
+ # Build IR
1271
+ if mask is None:
1272
+ return self.tensor(self.builder.create_store(ptr.handle, val.handle, cache, eviction), tl.void)
1273
+ if not mask.type.scalar.is_bool():
1274
+ raise ValueError("Mask must have boolean scalar type")
1275
+ return self.tensor(self.builder.create_masked_store(ptr.handle, val.handle, mask.handle, cache, eviction),
1276
+ tl.void)
1277
+
1278
+ def store(self, ptr: TensorTy, val: TensorTy, mask: Optional[TensorTy], boundary_check, cache_modifier: str,
1279
+ eviction_policy: str) -> TensorTy:
1280
+ # Cache and eviction options
1281
+ cache = self._str_to_store_cache_modifier(cache_modifier)
1282
+ eviction = self._str_to_eviction_policy(eviction_policy)
1283
+
1284
+ if ptr.type.is_const() or ptr.type.scalar.is_const():
1285
+ raise ValueError("Cannot store to a constant pointer")
1286
+
1287
+ if ptr.type.is_ptr() and ptr.type.element_ty.is_block():
1288
+ # Store by a block pointer: `pointer_type<block_type<>>`
1289
+ return self._store_block_pointer(ptr, val, mask, boundary_check, cache, eviction)
1290
+ else:
1291
+ # Store by a tensor of pointers or a pointer of scalar: `block_type<pointer_type<>>` or `pointer_type<>`
1292
+ return self._store_legacy(ptr, val, mask, boundary_check, cache, eviction)
1363
1293
 
1364
1294
  #########
1365
1295
  # atomic
1366
1296
  #########
1367
1297
 
1368
-
1369
- def atomic_cas(ptr: tl.tensor, cmp: tl.tensor, val: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor:
1370
- sem = _str_to_sem(sem)
1371
- scope = _str_to_scope(scope)
1372
- element_ty = ptr.type.scalar.element_ty
1373
- if element_ty.primitive_bitwidth not in [16, 32, 64]:
1374
- raise ValueError("atomic_cas only supports elements with width {16, 32, 64}")
1375
- return tl.tensor(builder.create_atomic_cas(ptr.handle, cmp.handle, val.handle, sem, scope), val.type)
1376
-
1377
-
1378
- def atom_red_typechecking_impl(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, op: str,
1379
- builder: ir.builder) -> Tuple[tl.tensor, tl.tensor, tl.tensor]:
1380
- if not ptr.type.scalar.is_ptr():
1381
- raise ValueError("Pointer argument of store instruction is " + ptr.type.__repr__())
1382
- if ptr.type.is_const() or ptr.type.element_ty.is_const():
1383
- raise ValueError("Cannot store to a constant pointer")
1384
- element_ty = ptr.type.scalar.element_ty
1385
- if element_ty is tl.float16 and op != 'add':
1386
- raise ValueError("atomic_" + op + " does not support fp16")
1387
- if element_ty in [tl.int1, tl.int8, tl.int16, tl.bfloat16]:
1388
- raise ValueError("atomic_" + op + " does not support " + str(element_ty))
1389
- if ptr.type.is_block():
1390
- if mask is not None:
1391
- mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder)
1392
- if val is not None:
1393
- val = broadcast_impl_shape(val, ptr.type.get_block_shapes(), builder)
1394
- val = cast(val, ptr.type.scalar.element_ty, builder)
1395
- if mask is None:
1396
- mask_ir = builder.get_int1(True)
1397
- mask_ty = tl.int1
1298
+ def atomic_cas(self, ptr: TensorTy, cmp: TensorTy, val: TensorTy, sem: str, scope: str) -> TensorTy:
1299
+ sem = self._str_to_sem(sem)
1300
+ scope = self._str_to_scope(scope)
1301
+ element_ty = ptr.type.scalar.element_ty
1302
+ if element_ty.primitive_bitwidth not in [16, 32, 64]:
1303
+ raise ValueError("atomic_cas only supports elements with width {16, 32, 64}")
1304
+ return self.tensor(self.builder.create_atomic_cas(ptr.handle, cmp.handle, val.handle, sem, scope), val.type)
1305
+
1306
+ def atom_red_typechecking_impl(self, ptr: TensorTy, val: TensorTy, mask: TensorTy,
1307
+ op: str) -> Tuple[TensorTy, TensorTy, TensorTy]:
1308
+ if not ptr.type.scalar.is_ptr():
1309
+ raise ValueError("Pointer argument of store instruction is " + ptr.type.__repr__())
1310
+ if ptr.type.is_const() or ptr.type.element_ty.is_const():
1311
+ raise ValueError("Cannot store to a constant pointer")
1312
+ element_ty = ptr.type.scalar.element_ty
1313
+ if element_ty is tl.float16 and op != 'add':
1314
+ raise ValueError("atomic_" + op + " does not support fp16")
1315
+ if element_ty is tl.bfloat16 and op != 'add':
1316
+ raise ValueError("atomic_" + op + " does not support bf16")
1317
+ if element_ty in [tl.int16, tl.uint16] or element_ty.primitive_bitwidth < 16:
1318
+ raise ValueError("atomic_" + op + " does not support " + str(element_ty))
1398
1319
  if ptr.type.is_block():
1399
- mask_ir = builder.create_splat(mask_ir, ptr.type.get_block_shapes())
1400
- mask_ty = tl.block_type(tl.int1, ptr.type.get_block_shapes())
1401
- mask = tl.tensor(mask_ir, mask_ty)
1402
- return ptr, val, mask
1403
-
1404
-
1405
- def atomic_max(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor:
1406
- ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'max', builder)
1407
- sem = _str_to_sem(sem)
1408
- scope = _str_to_scope(scope)
1409
- sca_ty = val.type.scalar
1410
- # direct call to atomic_max for integers
1411
- if sca_ty.is_int():
1412
- if sca_ty.is_int_signed():
1413
- return tl.tensor(
1414
- builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, ptr.handle, val.handle, mask.handle, sem, scope), val.type)
1415
- else:
1416
- return tl.tensor(
1417
- builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX, ptr.handle, val.handle, mask.handle, sem, scope), val.type)
1418
- # for float
1419
- # return atomic_smax(i_ptr, i_val) if val >= 0
1420
- # return atomic_umin(i_ptr, i_val) if val < 0
1421
- if sca_ty not in {tl.float32, tl.float64}:
1422
- raise TypeError(f"atomic_max not supported for dtype {sca_ty}")
1423
-
1424
- zero = full([], 0.0, sca_ty, builder)
1425
-
1426
- i_type = tl.int32 if sca_ty == tl.float32 else tl.int64
1427
- i_val = bitcast(val, i_type, builder)
1428
- i_ptr = bitcast(ptr, tl.pointer_type(i_type, 1), builder)
1429
- ui_type = tl.uint32 if sca_ty == tl.float32 else tl.uint64
1430
- ui_val = bitcast(val, ui_type, builder)
1431
- ui_ptr = bitcast(ptr, tl.pointer_type(ui_type, 1), builder)
1432
- pos = greater_equal(val, zero, builder)
1433
- neg = less_than(val, zero, builder)
1434
- pos_ret = tl.tensor(
1435
- builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, i_ptr.handle, i_val.handle,
1436
- and_(mask, pos, builder).handle, sem, scope), i_val.type)
1437
- neg_ret = tl.tensor(
1438
- builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, ui_ptr.handle, ui_val.handle,
1439
- and_(mask, neg, builder).handle, sem, scope), ui_val.type)
1440
- ret = where(pos, pos_ret, neg_ret, builder)
1441
- return bitcast(ret, sca_ty, builder)
1442
-
1443
-
1444
- def atomic_min(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor:
1445
- ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'min', builder)
1446
- sem = _str_to_sem(sem)
1447
- scope = _str_to_scope(scope)
1448
- sca_ty = val.type.scalar
1449
- # direct call to atomic_min for integers
1450
- if sca_ty.is_int():
1451
- if sca_ty.is_int_signed():
1452
- return tl.tensor(
1453
- builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, ptr.handle, val.handle, mask.handle, sem, scope), val.type)
1454
- else:
1455
- return tl.tensor(
1456
- builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, ptr.handle, val.handle, mask.handle, sem, scope), val.type)
1457
- # for float
1458
- # return atomic_smin(i_ptr, i_val) if val >= 0
1459
- # return atomic_umax(i_ptr, i_val) if val < 0
1460
- if sca_ty not in {tl.float32, tl.float64}:
1461
- raise TypeError(f"atomic_min not supported for dtype {sca_ty}")
1462
-
1463
- zero = full([], 0.0, sca_ty, builder)
1464
-
1465
- i_type = tl.int32 if sca_ty == tl.float32 else tl.int64
1466
- i_val = bitcast(val, i_type, builder)
1467
- i_ptr = bitcast(ptr, tl.pointer_type(i_type, 1), builder)
1468
- ui_type = tl.uint32 if sca_ty == tl.float32 else tl.uint64
1469
- ui_val = bitcast(val, ui_type, builder)
1470
- ui_ptr = bitcast(ptr, tl.pointer_type(ui_type, 1), builder)
1471
- pos = greater_equal(val, zero, builder)
1472
- neg = less_than(val, zero, builder)
1473
- pos_ret = tl.tensor(
1474
- builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, i_ptr.handle, i_val.handle,
1475
- and_(mask, pos, builder).handle, sem, scope), i_val.type)
1476
- neg_ret = tl.tensor(
1477
- builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX, ui_ptr.handle, ui_val.handle,
1478
- and_(mask, neg, builder).handle, sem, scope), ui_ptr.type)
1479
- ret = where(pos, pos_ret, neg_ret, builder)
1480
- return bitcast(ret, sca_ty, builder)
1481
-
1482
-
1483
- def atomic_add(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor:
1484
- ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'add', builder)
1485
- sem = _str_to_sem(sem)
1486
- scope = _str_to_scope(scope)
1487
- sca_ty = val.type.scalar
1488
- op = ir.ATOMIC_OP.FADD if sca_ty.is_floating() else ir.ATOMIC_OP.ADD
1489
- return tl.tensor(builder.create_atomic_rmw(op, ptr.handle, val.handle, mask.handle, sem, scope), val.type)
1490
-
1491
-
1492
- def atomic_and(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor:
1493
- ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'and', builder)
1494
- sem = _str_to_sem(sem)
1495
- scope = _str_to_scope(scope)
1496
- return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.AND, ptr.handle, val.handle, mask.handle, sem, scope),
1497
- val.type)
1498
-
1499
-
1500
- def atomic_or(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor:
1501
- ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'or', builder)
1502
- sem = _str_to_sem(sem)
1503
- scope = _str_to_scope(scope)
1504
- return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.OR, ptr.handle, val.handle, mask.handle, sem, scope),
1505
- val.type)
1506
-
1507
-
1508
- def atomic_xor(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor:
1509
- ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'xor', builder)
1510
- sem = _str_to_sem(sem)
1511
- scope = _str_to_scope(scope)
1512
- return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XOR, ptr.handle, val.handle, mask.handle, sem, scope),
1513
- val.type)
1514
-
1515
-
1516
- def atomic_xchg(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str,
1517
- builder: ir.builder) -> tl.tensor:
1518
- ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'xchg', builder)
1519
- sem = _str_to_sem(sem)
1520
- scope = _str_to_scope(scope)
1521
- return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XCHG, ptr.handle, val.handle, mask.handle, sem, scope),
1522
- val.type)
1523
-
1320
+ if mask is not None:
1321
+ mask = self.broadcast_impl_shape(mask, ptr.type.get_block_shapes())
1322
+ if val is not None:
1323
+ val = self.broadcast_impl_shape(val, ptr.type.get_block_shapes())
1324
+ val = self.cast(val, ptr.type.scalar.element_ty)
1325
+ if mask is None:
1326
+ mask_ir = self.builder.get_int1(True)
1327
+ mask_ty = tl.int1
1328
+ if ptr.type.is_block():
1329
+ mask_ty = ptr.type.with_element_ty(tl.int1)
1330
+ mask_ir = self.builder.create_splat(mask_ty.to_ir(self.builder), mask_ir)
1331
+ mask = self.tensor(mask_ir, mask_ty)
1332
+ return ptr, val, mask
1333
+
1334
+ def _signbit(self, x: TensorTy) -> TensorTy:
1335
+ bitwidth = x.dtype.primitive_bitwidth
1336
+ idtype = tl.get_int_dtype(bitwidth=bitwidth, signed=False)
1337
+ ix = self.bitcast(x, idtype)
1338
+ signbit = self.lshr(ix, bitwidth - 1)
1339
+ return self.cast(signbit, tl.int1)
1340
+
1341
+ def atomic_max(self, ptr: TensorTy, val: TensorTy, mask: TensorTy, sem: str, scope: str) -> TensorTy:
1342
+ ptr, val, mask = self.atom_red_typechecking_impl(ptr, val, mask, 'max')
1343
+ sem = self._str_to_sem(sem)
1344
+ scope = self._str_to_scope(scope)
1345
+ sca_ty = val.type.scalar
1346
+ # direct call to atomic_max for integers
1347
+ if sca_ty.is_int():
1348
+ if sca_ty.is_int_signed():
1349
+ return self.tensor(
1350
+ self.builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, ptr.handle, val.handle, mask.handle, sem, scope),
1351
+ val.type)
1352
+ else:
1353
+ return self.tensor(
1354
+ self.builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX, ptr.handle, val.handle, mask.handle, sem, scope),
1355
+ val.type)
1356
+ # for float
1357
+ # return atomic_smax(i_ptr, i_val) if val >= 0
1358
+ # return atomic_umin(i_ptr, i_val) if val < 0
1359
+ if sca_ty not in {tl.float32, tl.float64}:
1360
+ raise TypeError(f"atomic_max not supported for dtype {sca_ty}")
1361
+
1362
+ i_type = tl.int32 if sca_ty == tl.float32 else tl.int64
1363
+ i_val = self.bitcast(val, i_type)
1364
+ i_ptr = self.bitcast(ptr, tl.pointer_type(i_type, 1))
1365
+ ui_type = tl.uint32 if sca_ty == tl.float32 else tl.uint64
1366
+ ui_val = self.bitcast(val, ui_type)
1367
+ ui_ptr = self.bitcast(ptr, tl.pointer_type(ui_type, 1))
1368
+ neg = self._signbit(val)
1369
+ pos = self.not_(neg)
1370
+ pos_ret = self.tensor(
1371
+ self.builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, i_ptr.handle, i_val.handle,
1372
+ self.and_(mask, pos).handle, sem, scope), i_val.type)
1373
+ neg_ret = self.tensor(
1374
+ self.builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, ui_ptr.handle, ui_val.handle,
1375
+ self.and_(mask, neg).handle, sem, scope), ui_val.type)
1376
+ ret = self.where(pos, pos_ret, neg_ret)
1377
+ return self.bitcast(ret, sca_ty)
1378
+
1379
+ def atomic_min(self, ptr: TensorTy, val: TensorTy, mask: TensorTy, sem: str, scope: str) -> TensorTy:
1380
+ ptr, val, mask = self.atom_red_typechecking_impl(ptr, val, mask, 'min')
1381
+ sem = self._str_to_sem(sem)
1382
+ scope = self._str_to_scope(scope)
1383
+ sca_ty = val.type.scalar
1384
+ # direct call to atomic_min for integers
1385
+ if sca_ty.is_int():
1386
+ if sca_ty.is_int_signed():
1387
+ return self.tensor(
1388
+ self.builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, ptr.handle, val.handle, mask.handle, sem, scope),
1389
+ val.type)
1390
+ else:
1391
+ return self.tensor(
1392
+ self.builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, ptr.handle, val.handle, mask.handle, sem, scope),
1393
+ val.type)
1394
+ # for float
1395
+ # return atomic_smin(i_ptr, i_val) if val >= 0
1396
+ # return atomic_umax(i_ptr, i_val) if val < 0
1397
+ if sca_ty not in {tl.float32, tl.float64}:
1398
+ raise TypeError(f"atomic_min not supported for dtype {sca_ty}")
1399
+
1400
+ i_type = tl.int32 if sca_ty == tl.float32 else tl.int64
1401
+ i_val = self.bitcast(val, i_type)
1402
+ i_ptr = self.bitcast(ptr, tl.pointer_type(i_type, 1))
1403
+ ui_type = tl.uint32 if sca_ty == tl.float32 else tl.uint64
1404
+ ui_val = self.bitcast(val, ui_type)
1405
+ ui_ptr = self.bitcast(ptr, tl.pointer_type(ui_type, 1))
1406
+ neg = self._signbit(val)
1407
+ pos = self.not_(neg)
1408
+ pos_ret = self.tensor(
1409
+ self.builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, i_ptr.handle, i_val.handle,
1410
+ self.and_(mask, pos).handle, sem, scope), i_val.type)
1411
+ neg_ret = self.tensor(
1412
+ self.builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX, ui_ptr.handle, ui_val.handle,
1413
+ self.and_(mask, neg).handle, sem, scope), ui_ptr.type)
1414
+ ret = self.where(pos, pos_ret, neg_ret)
1415
+ return self.bitcast(ret, sca_ty)
1416
+
1417
+ def atomic_add(self, ptr: TensorTy, val: TensorTy, mask: TensorTy, sem: str, scope: str) -> TensorTy:
1418
+ ptr, val, mask = self.atom_red_typechecking_impl(ptr, val, mask, 'add')
1419
+ sem = self._str_to_sem(sem)
1420
+ scope = self._str_to_scope(scope)
1421
+ sca_ty = val.type.scalar
1422
+ op = ir.ATOMIC_OP.FADD if sca_ty.is_floating() else ir.ATOMIC_OP.ADD
1423
+ return self.tensor(self.builder.create_atomic_rmw(op, ptr.handle, val.handle, mask.handle, sem, scope),
1424
+ val.type)
1425
+
1426
+ def atomic_and(self, ptr: TensorTy, val: TensorTy, mask: TensorTy, sem: str, scope: str) -> TensorTy:
1427
+ ptr, val, mask = self.atom_red_typechecking_impl(ptr, val, mask, 'and')
1428
+ sem = self._str_to_sem(sem)
1429
+ scope = self._str_to_scope(scope)
1430
+ return self.tensor(
1431
+ self.builder.create_atomic_rmw(ir.ATOMIC_OP.AND, ptr.handle, val.handle, mask.handle, sem, scope), val.type)
1432
+
1433
+ def atomic_or(self, ptr: TensorTy, val: TensorTy, mask: TensorTy, sem: str, scope: str) -> TensorTy:
1434
+ ptr, val, mask = self.atom_red_typechecking_impl(ptr, val, mask, 'or')
1435
+ sem = self._str_to_sem(sem)
1436
+ scope = self._str_to_scope(scope)
1437
+ return self.tensor(
1438
+ self.builder.create_atomic_rmw(ir.ATOMIC_OP.OR, ptr.handle, val.handle, mask.handle, sem, scope), val.type)
1439
+
1440
+ def atomic_xor(self, ptr: TensorTy, val: TensorTy, mask: TensorTy, sem: str, scope: str) -> TensorTy:
1441
+ ptr, val, mask = self.atom_red_typechecking_impl(ptr, val, mask, 'xor')
1442
+ sem = self._str_to_sem(sem)
1443
+ scope = self._str_to_scope(scope)
1444
+ return self.tensor(
1445
+ self.builder.create_atomic_rmw(ir.ATOMIC_OP.XOR, ptr.handle, val.handle, mask.handle, sem, scope), val.type)
1446
+
1447
+ def atomic_xchg(self, ptr: TensorTy, val: TensorTy, mask: TensorTy, sem: str, scope: str) -> TensorTy:
1448
+ ptr, val, mask = self.atom_red_typechecking_impl(ptr, val, mask, 'xchg')
1449
+ sem = self._str_to_sem(sem)
1450
+ scope = self._str_to_scope(scope)
1451
+ return self.tensor(
1452
+ self.builder.create_atomic_rmw(ir.ATOMIC_OP.XCHG, ptr.handle, val.handle, mask.handle, sem, scope),
1453
+ val.type)
1524
1454
 
1525
1455
  # ===----------------------------------------------------------------------===//
1526
1456
  # Linear Algebra
1527
1457
  # ===----------------------------------------------------------------------===//
1528
1458
 
1459
+ def _str_to_dot_input_precision(self, input_precision):
1460
+ assert input_precision.lower() in self.builder.options.allowed_dot_input_precisions, \
1461
+ f"input_precision must be one of {self.builder.options.allowed_dot_input_precisions}. Got {input_precision}"
1462
+ input_precision = input_precision.upper()
1463
+ if input_precision == "TF32X3":
1464
+ input_precision = "TF32x3"
1465
+ return getattr(ir.INPUT_PRECISION, input_precision)
1466
+
1467
+ def dot(self, lhs: TensorTy, rhs: TensorTy, acc: TensorTy, input_precision: Optional[str],
1468
+ max_num_imprecise_acc: int, out_dtype: tl.dtype) -> TensorTy:
1469
+ assert lhs.type.is_block() and rhs.type.is_block()
1529
1470
 
1530
- def _str_to_dot_input_precision(input_precision, builder):
1531
- assert input_precision.lower() in builder.options.allowed_dot_input_precisions, \
1532
- f"input_precision must be one of {builder.options.allowed_dot_input_precisions}. Got {input_precision}"
1533
- input_precision = input_precision.upper()
1534
- if input_precision == "TF32X3":
1535
- input_precision = "TF32x3"
1536
- return getattr(ir.INPUT_PRECISION, input_precision)
1537
-
1538
-
1539
- def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, input_precision: Optional[str], max_num_imprecise_acc: int,
1540
- out_dtype: tl.dtype, builder: ir.builder) -> tl.tensor:
1541
- assert lhs.type.is_block() and rhs.type.is_block()
1542
-
1543
- if lhs.dtype.is_fp8() and rhs.dtype.is_fp8():
1544
- # All combinations of supported fp8 x fp8 are permitted
1545
- pass
1546
- else:
1547
- assert lhs.dtype in (tl.int8, tl.uint8, tl.float16, tl.bfloat16,
1548
- tl.float32), f"Unsupported lhs dtype {lhs.dtype}"
1549
- assert rhs.dtype in (tl.int8, tl.uint8, tl.float16, tl.bfloat16,
1550
- tl.float32), f"Unsupported rhs dtype {rhs.dtype}"
1551
- assert lhs.dtype == rhs.dtype, f"Both operands must be same dtype. Got {lhs.dtype} and {rhs.dtype}"
1552
-
1553
- if lhs.dtype.is_fp8e4b15() or rhs.dtype.is_fp8e4b15():
1554
- # We upcast because there's no fp8e4b15 type in MLIR
1555
- lhs = cast(lhs, tl.float16, builder)
1556
- rhs = cast(rhs, tl.float16, builder)
1557
-
1558
- if input_precision is None:
1559
- input_precision = builder.options.default_dot_input_precision
1560
-
1561
- input_precision = _str_to_dot_input_precision(input_precision, builder)
1562
-
1563
- lhs_rank = len(lhs.shape)
1564
- rhs_rank = len(rhs.shape)
1565
- assert lhs_rank == rhs_rank == 2 or lhs_rank == rhs_rank == 3, f"Both inputs must be either 2D or 3D; (lhs: {lhs.shape} vs rhs: {rhs.shape})"
1566
- assert lhs.shape[-1].value == rhs.shape[
1567
- -2].value, f"First input shape ({lhs.shape}) and second input shape {rhs.shape} are not compatible for matmul (second index of first shape ({lhs.shape[-1].value}) must be equal to first index of second shape ({rhs.shape[-2].value})"
1568
- assert builder.codegen_fns.get("min_dot_size") is not None, "target doesn't provide lower shape bounds for dot."
1569
- min_dot_size = builder.codegen_fns["min_dot_size"](lhs.type, rhs.type)
1570
- assert lhs.shape[-2].value >= min_dot_size[0] and lhs.shape[-1].value >= min_dot_size[2] \
1571
- and rhs.shape[-1].value >= min_dot_size[1], \
1572
- f"Input shapes should have M >= {min_dot_size[0]}, N >= {min_dot_size[1]} and K >= {min_dot_size[2]}"
1573
- if lhs.type.scalar.is_int():
1574
- assert lhs.type.scalar == tl.int8, "only int8 supported!"
1575
- _0 = builder.get_int32(0)
1576
- ret_scalar_ty = tl.int32
1577
- elif out_dtype.is_bf16():
1578
- raise ValueError(
1579
- "out_dtype=bfloat16 is unsupported. Please use out_dtype=float32/float16 and cast with `.to(tl.bfloat16)`")
1580
- elif lhs.type.scalar.is_fp32() or lhs.type.scalar.is_bf16():
1581
- _0 = builder.get_fp32(0)
1582
- ret_scalar_ty = tl.float32
1583
- else:
1584
- _0 = builder.get_fp16(0) if out_dtype.is_fp16() else builder.get_fp32(0)
1585
- ret_scalar_ty = out_dtype
1586
-
1587
- M = lhs.type.shape[-2]
1588
- N = rhs.type.shape[-1]
1589
- K = lhs.type.shape[-1]
1590
- B = lhs.type.shape[0] if lhs_rank == 3 else None
1591
- ret_ty = tl.block_type(ret_scalar_ty, [B, M, N] if B else [M, N])
1592
- if acc is None:
1593
- acc_handle = builder.create_splat(_0, [B, M, N] if B else [M, N])
1594
- else:
1595
- acc_handle = acc.handle
1596
- assert acc.type == ret_ty
1597
-
1598
- # max_num_imprecise_acc only applies to fp8 -> fp32 dot on sm_90
1599
- if max_num_imprecise_acc is None:
1600
1471
  if lhs.dtype.is_fp8() and rhs.dtype.is_fp8():
1601
- max_num_imprecise_acc = builder.options.max_num_imprecise_acc_default
1472
+ # All combinations of supported fp8 x fp8 are permitted
1473
+ pass
1602
1474
  else:
1603
- max_num_imprecise_acc = 0
1604
- else:
1605
- if lhs.dtype.is_fp8() and rhs.dtype.is_fp8() and max_num_imprecise_acc > K:
1606
- raise ValueError(f"max_num_imprecise_acc ({max_num_imprecise_acc}) must be <= K ({K})")
1607
-
1608
- return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, acc_handle, input_precision, max_num_imprecise_acc),
1609
- ret_ty)
1610
-
1611
-
1612
- def _str_to_fp_type(float_format: str):
1613
- ty_enum = getattr(ir.ScaleDotElemTypeTY, float_format.upper(), None)
1614
- if ty_enum is None:
1615
- raise ValueError(f"Invalid float format: {float_format}.")
1616
- return ty_enum
1617
-
1618
-
1619
- def _bitcast_to_fp_type(val: tl.tensor, float_format: str, builder: ir.builder):
1620
- """
1621
- If float_format is subbyte, make sure it's packed as uint8 and return it.
1622
- Otherwise, return a tensor (perhaps bitcasting) of the specified float format.
1623
- """
1624
- triton_ty = {"e5m2": tl.float8e5, "e4m3": tl.float8e4nv, "bf16": tl.bfloat16, "fp16": tl.float16}.get(float_format)
1625
- if triton_ty is None:
1626
- assert float_format == "e2m1", f"Internal Error: Unexpected float format: {float_format}"
1627
- assert val.dtype == tl.uint8, f"e2m1 format must be packed as uint8. Got {val.dtype}"
1628
- return val
1629
- if val.dtype == triton_ty:
1630
- return val
1631
- else:
1632
- unsigned_ty = {"e5m2": tl.uint8, "e4m3": tl.uint8, "bf16": tl.uint16, "fp16": tl.uint16}[float_format]
1633
- assert val.dtype == unsigned_ty, f"Unexpected dtype for {float_format}. Got {val.dtype}"
1634
- return bitcast(val, triton_ty, builder)
1635
-
1636
-
1637
- def dot_scaled(lhs: tl.tensor, lhs_scale: tl.tensor, lhs_format: str, rhs: tl.tensor, rhs_scale: Optional[tl.tensor],
1638
- rhs_format: str, acc: tl.tensor | None, fast_math: bool, out_dtype: tl.dtype,
1639
- builder: ir.builder) -> tl.tensor:
1640
- assert lhs.type.is_block() and rhs.type.is_block()
1641
- #TODO: validate types.
1642
- lhs_rank = len(lhs.shape)
1643
- rhs_rank = len(rhs.shape)
1644
- assert lhs_rank == rhs_rank == 2 or lhs_rank == rhs_rank == 3, f"Both inputs must be either 2D or 3D; (lhs: {lhs.shape} vs rhs: {rhs.shape})"
1645
- lhs_format: str = lhs_format.value
1646
- rhs_format: str = rhs_format.value
1647
- lhs_format_enum = _str_to_fp_type(lhs_format)
1648
- rhs_format_enum = _str_to_fp_type(rhs_format)
1649
- allowed_formats = {"e2m1", "e4m3", "e5m2", "bf16", "fp16"}
1650
- assert lhs_format in allowed_formats, f"NYI: lhs_format {lhs_format}"
1651
- assert rhs_format in allowed_formats, f"NYI: rhs_format {rhs_format}"
1652
- rhs_scale_is_none = rhs_scale is None or (isinstance(rhs_scale, tl.constexpr) and rhs_scale.value is None)
1653
- lhs_scale_is_none = lhs_scale is None or (isinstance(lhs_scale, tl.constexpr) and lhs_scale.value is None)
1654
- lhs = _bitcast_to_fp_type(lhs, lhs_format, builder)
1655
- rhs = _bitcast_to_fp_type(rhs, rhs_format, builder)
1656
-
1657
- M = lhs.type.shape[-2]
1658
- K, N = rhs.type.shape[-2:]
1659
- PACKED_A = 2 if lhs_format == "e2m1" else 1
1660
- PACKED_B = 2 if rhs_format == "e2m1" else 1
1661
- assert K * PACKED_B == PACKED_A * lhs.type.shape[
1662
- -1], f"Reduction dimension should pack the same number of elements; (lhs: {lhs.shape} vs rhs: {rhs.shape})"
1663
- #assert K * PACKED_B >= 64, f"scaled_dot NYI for K < 64. Got {K=}"
1664
- B = lhs.type.shape[0] if lhs_rank == 3 else None
1665
-
1666
- ret_ty = tl.block_type(out_dtype, [B, M, N] if B else [M, N])
1667
- _0 = builder.get_fp32(0)
1668
- if acc is None:
1669
- acc_handle = builder.create_splat(_0, [B, M, N] if B else [M, N])
1670
- else:
1671
- acc_handle = acc.handle
1672
- assert acc.type == ret_ty
1673
- rhs_scale_handle = None if rhs_scale_is_none else rhs_scale.handle
1674
- lhs_scale_handle = None if lhs_scale_is_none else lhs_scale.handle
1675
- return tl.tensor(
1676
- builder.create_dot_scaled(lhs.handle, lhs_scale_handle, lhs_format_enum, rhs.handle, rhs_scale_handle,
1677
- rhs_format_enum, fast_math, acc_handle), ret_ty)
1475
+ assert lhs.dtype in (tl.int8, tl.uint8, tl.float16, tl.bfloat16,
1476
+ tl.float32), f"Unsupported lhs dtype {lhs.dtype}"
1477
+ assert rhs.dtype in (tl.int8, tl.uint8, tl.float16, tl.bfloat16,
1478
+ tl.float32), f"Unsupported rhs dtype {rhs.dtype}"
1479
+ assert lhs.dtype == rhs.dtype, f"Both operands must be same dtype. Got {lhs.dtype} and {rhs.dtype}"
1480
+
1481
+ if lhs.dtype.is_fp8e4b15() or rhs.dtype.is_fp8e4b15():
1482
+ if "fp8e4b15" in self.builder.options.deprecated_fp8_dot_operand_dtypes:
1483
+ warnings.warn(
1484
+ "the use of fp8e4b15 is deprecated on Hopper and later architectures and can cause significant slow down. It will be removed in a future triton release"
1485
+ )
1486
+ # We upcast because there's no fp8e4b15 type in MLIR
1487
+ lhs = self.cast(lhs, tl.float16)
1488
+ rhs = self.cast(rhs, tl.float16)
1489
+
1490
+ if input_precision is None:
1491
+ input_precision = self.builder.options.default_dot_input_precision
1492
+
1493
+ input_precision = self._str_to_dot_input_precision(input_precision)
1494
+
1495
+ lhs_rank = len(lhs.shape)
1496
+ rhs_rank = len(rhs.shape)
1497
+ assert lhs_rank == rhs_rank == 2 or lhs_rank == rhs_rank == 3, f"Both inputs must be either 2D or 3D; (lhs: {lhs.shape} vs rhs: {rhs.shape})"
1498
+ assert lhs.shape[-1].value == rhs.shape[
1499
+ -2].value, f"First input shape ({lhs.shape}) and second input shape {rhs.shape} are not compatible for matmul (second index of first shape ({lhs.shape[-1].value}) must be equal to first index of second shape ({rhs.shape[-2].value})"
1500
+ assert self.builder.codegen_fns.get(
1501
+ "min_dot_size") is not None, "target doesn't provide lower shape bounds for dot."
1502
+ min_dot_size = self.builder.codegen_fns["min_dot_size"](lhs.type, rhs.type)
1503
+ assert lhs.shape[-2].value >= min_dot_size[0] and lhs.shape[-1].value >= min_dot_size[2] \
1504
+ and rhs.shape[-1].value >= min_dot_size[1], \
1505
+ f"Input shapes should have M >= {min_dot_size[0]}, N >= {min_dot_size[1]} and K >= {min_dot_size[2]}"
1506
+ if lhs.type.scalar.is_int():
1507
+ assert lhs.type.scalar == tl.int8, "only int8 supported!"
1508
+ _0 = self.builder.get_int32(0)
1509
+ ret_scalar_ty = tl.int32
1510
+ elif out_dtype.is_bf16():
1511
+ raise ValueError(
1512
+ "out_dtype=bfloat16 is unsupported. Please use out_dtype=float32/float16 and cast with `.to(tl.bfloat16)`"
1513
+ )
1514
+ elif lhs.type.scalar.is_fp32() or lhs.type.scalar.is_bf16():
1515
+ _0 = self.builder.get_fp32(0)
1516
+ ret_scalar_ty = tl.float32
1517
+ else:
1518
+ _0 = self.builder.get_fp16(0) if out_dtype.is_fp16() else self.builder.get_fp32(0)
1519
+ ret_scalar_ty = out_dtype
1520
+
1521
+ M = lhs.type.shape[-2]
1522
+ N = rhs.type.shape[-1]
1523
+ K = lhs.type.shape[-1]
1524
+ B = lhs.type.shape[0] if lhs_rank == 3 else None
1525
+ ret_ty = tl.block_type(ret_scalar_ty, [B, M, N] if B else [M, N])
1526
+ if acc is None:
1527
+ acc_handle = self.builder.create_splat(ret_ty.to_ir(self.builder), _0)
1528
+ else:
1529
+ acc_handle = acc.handle
1530
+ assert acc.type == ret_ty
1678
1531
 
1532
+ # max_num_imprecise_acc only applies to fp8 -> fp32 dot on sm_90
1533
+ if max_num_imprecise_acc is None:
1534
+ if lhs.dtype.is_fp8() and rhs.dtype.is_fp8():
1535
+ max_num_imprecise_acc = self.builder.options.max_num_imprecise_acc_default
1536
+ else:
1537
+ max_num_imprecise_acc = 0
1538
+ else:
1539
+ if lhs.dtype.is_fp8() and rhs.dtype.is_fp8() and max_num_imprecise_acc > K:
1540
+ raise ValueError(f"max_num_imprecise_acc ({max_num_imprecise_acc}) must be <= K ({K})")
1541
+
1542
+ return self.tensor(
1543
+ self.builder.create_dot(lhs.handle, rhs.handle, acc_handle, input_precision, max_num_imprecise_acc), ret_ty)
1544
+
1545
def _str_to_fp_type(self, float_format: str):
    """Look up the ``ir.ScaleDotElemTypeTY`` member named by ``float_format``.

    The name is upper-cased before it is resolved as an attribute of
    ``ir.ScaleDotElemTypeTY``.

    Raises:
        ValueError: if the upper-cased name does not resolve to a member.
    """
    resolved = getattr(ir.ScaleDotElemTypeTY, float_format.upper(), None)
    if resolved is not None:
        return resolved
    raise ValueError(f"Invalid float format: {float_format}.")
1550
+
1551
def _bitcast_to_fp_type(self, val: TensorTy, float_format: str):
    """Return ``val`` typed as ``float_format``, bitcasting when necessary.

    For ``e5m2``, ``e4m3``, ``bf16`` and ``fp16`` the value is returned as-is
    when it already has the matching float dtype; otherwise it must have the
    unsigned integer dtype listed for that format and is bitcast to the float
    dtype.  Any other format is required to be ``e2m1``, in which case ``val``
    must be ``uint8`` and is returned unchanged.
    """
    # format name -> (float dtype, unsigned integer dtype accepted as its storage)
    layouts = {
        "e5m2": (tl.float8e5, tl.uint8),
        "e4m3": (tl.float8e4nv, tl.uint8),
        "bf16": (tl.bfloat16, tl.uint16),
        "fp16": (tl.float16, tl.uint16),
    }
    entry = layouts.get(float_format)
    if entry is None:
        assert float_format == "e2m1", f"Internal Error: Unexpected float format: {float_format}"
        assert val.dtype == tl.uint8, f"e2m1 format must be packed as uint8. Got {val.dtype}"
        return val

    target_ty, storage_ty = entry
    if val.dtype == target_ty:
        return val
    assert val.dtype == storage_ty, f"Unexpected dtype for {float_format}. Got {val.dtype}"
    return self.bitcast(val, target_ty)
1568
+
1569
def dot_scaled(self, lhs: TensorTy, lhs_scale: TensorTy, lhs_format: str, rhs: TensorTy,
               rhs_scale: Optional[TensorTy], rhs_format: str, acc: TensorTy | None, fast_math: bool,
               lhs_k_pack: bool, rhs_k_pack: bool, out_dtype: tl.dtype) -> TensorTy:
    """Emit a scaled dot through ``builder.create_dot_scaled``.

    Both operands must be block tensors of the same rank (2 or 3).
    ``lhs_format`` / ``rhs_format`` are read through their ``.value`` attribute
    and must name one of ``e2m1``, ``e4m3``, ``e5m2``, ``bf16`` or ``fp16``.
    A scale that is ``None`` (or a constexpr wrapping ``None``) is forwarded
    to the builder as ``None``.

    ``e2m1`` operands get a packing factor of 2.  With the matching
    ``*_k_pack`` flag set the factor applies to the reduction dimension;
    otherwise it multiplies M (for lhs) or N (for rhs) in the result shape.

    When ``acc`` is ``None`` the accumulator is a splat of an fp32 zero with
    the result type; otherwise ``acc`` must already have the result type.
    """

    def _scale_missing(scale) -> bool:
        # Absent means either a literal None or a constexpr holding None.
        if scale is None:
            return True
        return isinstance(scale, tl.constexpr) and scale.value is None

    assert lhs.type.is_block() and rhs.type.is_block()
    # TODO: validate types.
    lhs_rank, rhs_rank = len(lhs.shape), len(rhs.shape)
    assert lhs_rank == rhs_rank and lhs_rank in (2, 3), f"Both inputs must be either 2D or 3D; (lhs: {lhs.shape} vs rhs: {rhs.shape})"

    lhs_format = lhs_format.value
    rhs_format = rhs_format.value
    lhs_format_enum = self._str_to_fp_type(lhs_format)
    rhs_format_enum = self._str_to_fp_type(rhs_format)
    supported = {"e2m1", "e4m3", "e5m2", "bf16", "fp16"}
    assert lhs_format in supported, f"NYI: lhs_format {lhs_format}"
    assert rhs_format in supported, f"NYI: rhs_format {rhs_format}"

    lhs = self._bitcast_to_fp_type(lhs, lhs_format)
    rhs = self._bitcast_to_fp_type(rhs, rhs_format)

    k_pack_msg = "only mxfp4 inputs can be packed along a dimension different than K"
    assert lhs_k_pack or lhs_format == "e2m1", k_pack_msg
    assert rhs_k_pack or rhs_format == "e2m1", k_pack_msg

    M, K_LHS = lhs.type.shape[-2:]
    K_RHS, N = rhs.type.shape[-2:]
    lhs_factor = 2 if lhs_format == "e2m1" else 1
    rhs_factor = 2 if rhs_format == "e2m1" else 1
    lhs_k_elems = lhs_factor * K_LHS if lhs_k_pack else K_LHS
    rhs_k_elems = rhs_factor * K_RHS if rhs_k_pack else K_RHS
    assert rhs_k_elems == lhs_k_elems, f"Reduction dimension should pack the same number of elements; (lhs: {lhs.shape} vs rhs: {rhs.shape})"
    # NOTE: no minimum-K check (K * PACKED_B >= 64) is enforced here.

    # Operands packed along a non-K dimension widen that dimension of the result.
    if not lhs_k_pack:
        M = M * lhs_factor
    if not rhs_k_pack:
        N = N * rhs_factor
    B = lhs.type.shape[0] if lhs_rank == 3 else None
    ret_ty = tl.block_type(out_dtype, [B, M, N] if B else [M, N])

    # The zero constant is requested from the builder whether or not it is used.
    zero = self.builder.get_fp32(0)
    if acc is not None:
        acc_handle = acc.handle
        assert acc.type == ret_ty
    else:
        acc_handle = self.builder.create_splat(ret_ty.to_ir(self.builder), zero)

    rhs_scale_handle = None if _scale_missing(rhs_scale) else rhs_scale.handle
    lhs_scale_handle = None if _scale_missing(lhs_scale) else lhs_scale.handle
    product = self.builder.create_dot_scaled(lhs.handle, lhs_scale_handle, lhs_format_enum, rhs.handle,
                                             rhs_scale_handle, rhs_format_enum, fast_math, lhs_k_pack,
                                             rhs_k_pack, acc_handle)
    return self.tensor(product, ret_ty)
1679
1616
 
1680
1617
  # ===----------------------------------------------------------------------===//
1681
1618
  # Indexing
1682
1619
  # ===----------------------------------------------------------------------===//
1683
1620
 
1684
-
1685
- def where(condition: tl.tensor, x: tl.tensor, y: tl.tensor, builder: ir.builder) -> tl.tensor:
1686
- if condition.dtype != tl.int1:
1687
- warnings.warn(
1688
- f"tl.where with a non-boolean condition is deprecated and will error out in a future triton release. Got {condition.dtype}"
1689
- )
1690
- condition = cast(condition, tl.int1, builder)
1691
- x, y = binary_op_type_checking_impl(x, y, builder, True, True)
1692
- # x, y are broadcasted
1693
- if condition.type.is_block():
1694
- condition, x = broadcast_impl_value(condition, x, builder)
1695
- x, y = broadcast_impl_value(x, y, builder)
1696
- else:
1697
- condition, _ = broadcast_impl_value(condition, x, builder)
1698
- ret_ty = x.type
1699
- return tl.tensor(builder.create_select(condition.handle, x.handle, y.handle), ret_ty)
1700
-
1621
def where(self, condition: TensorTy, x: TensorTy, y: TensorTy) -> TensorTy:
    """Build a select between ``x`` and ``y`` driven by ``condition``.

    A condition whose dtype is not ``tl.int1`` triggers a deprecation warning
    and is cast to ``tl.int1``.  ``x`` and ``y`` go through binary-op type
    checking, then the operands are broadcast against each other before the
    select is created.  The result has the type of the broadcast ``x``.
    """
    if condition.dtype != tl.int1:
        deprecation = f"tl.where with a non-boolean condition is deprecated and will error out in a future triton release. Got {condition.dtype}"
        warnings.warn(deprecation)
        condition = self.cast(condition, tl.int1)

    # After this call x and y are already broadcast against each other.
    x, y = self.binary_op_type_checking_impl(x, y, True, True)

    if not condition.type.is_block():
        # Non-block condition: only the condition is broadcast against x.
        condition, _ = self.broadcast_impl_value(condition, x)
    else:
        condition, x = self.broadcast_impl_value(condition, x)
        x, y = self.broadcast_impl_value(x, y)

    selected = self.builder.create_select(condition.handle, x.handle, y.handle)
    return self.tensor(selected, x.type)
1701
1636
 
1702
1637
  # ===----------------------------------------------------------------------===//
1703
1638
  # Reduction
1704
1639
  # ===----------------------------------------------------------------------===
1705
1640
 
1706
-
1707
- def wrap_tensor(x, scalar_ty, ret_shape):
1708
- if ret_shape:
1709
- res_ty = tl.block_type(scalar_ty, ret_shape)
1710
- else:
1711
- # 0d-tensor -> scalar
1712
- res_ty = scalar_ty
1713
- return tl.tensor(x, res_ty)
1714
-
1715
-
1716
- def reduction(inputs: Sequence[tl.tensor], axis: int, region_builder_fn, builder: ir.builder) -> Tuple[tl.tensor, ...]:
1717
- if axis is None:
1718
- inputs = tuple(reshape(t, [t.numel.value], can_reorder=True, builder=builder) for t in inputs)
1719
- axis = 0
1720
- # get result shape
1721
- shape = inputs[0].type.shape
1722
- rank = len(shape)
1723
- assert axis < rank, f"reduction axis must be < inputs rank ({rank})"
1724
- ret_shape = [s for i, s in enumerate(shape) if i != axis]
1725
- assert all(t.type.shape == shape for t in inputs), "all reduction inputs must have the same shape"
1726
-
1727
- reduce_op = builder.create_reduce([t.handle for t in inputs], axis)
1728
- region_builder_fn(reduce_op)
1729
- reduce_op.verify()
1730
-
1731
- return tuple(wrap_tensor(reduce_op.get_result(i), inputs[i].type.scalar, ret_shape) for i in range(len(inputs)))
1732
-
1641
def wrap_tensor(self, x, scalar_ty, ret_shape):
    """Wrap handle ``x`` as a tensor of ``scalar_ty`` with shape ``ret_shape``.

    An empty ``ret_shape`` (0-d result) yields a tensor typed with the plain
    scalar type instead of a block type.
    """
    res_ty = tl.block_type(scalar_ty, ret_shape) if ret_shape else scalar_ty
    return self.tensor(x, res_ty)
1648
+
1649
def reduction(self, inputs: Sequence[TensorTy], axis: Optional[int], region_builder_fn) -> Tuple[TensorTy, ...]:
    """Create a reduce op over ``inputs`` along ``axis``.

    Args:
        inputs: tensors to reduce together; all must share one shape.
        axis: dimension to reduce.  ``None`` reshapes every input to 1-D
            (reordering allowed) and reduces axis 0.  Negative values count
            from the last dimension, as in ``associative_scan`` and ``gather``.
        region_builder_fn: called with the new reduce op to fill in its body.

    Returns:
        One tensor per input, with the reduced dimension removed (a scalar
        type when no dimensions remain).
    """
    if axis is None:
        inputs = tuple(self.reshape(t, [t.numel.value], can_reorder=True) for t in inputs)
        axis = 0

    shape = inputs[0].type.shape
    rank = len(shape)
    assert -rank <= axis < rank, f"reduction axis must be < inputs rank ({rank})"
    # Without this, a negative axis would leave ret_shape equal to the input shape.
    if axis < 0:
        axis += rank

    ret_shape = [dim for i, dim in enumerate(shape) if i != axis]
    assert all(t.type.shape == shape for t in inputs), "all reduction inputs must have the same shape"

    reduce_op = self.builder.create_reduce([t.handle for t in inputs], axis)
    region_builder_fn(reduce_op)
    assert reduce_op.verify()

    return tuple(
        self.wrap_tensor(reduce_op.get_result(i), t.type.scalar, ret_shape) for i, t in enumerate(inputs))
1733
1666
 
1734
1667
  # ===----------------------------------------------------------------------===
1735
1668
  # Associative Scan
1736
1669
  # ===----------------------------------------------------------------------===
1737
1670
 
1671
def associative_scan(self, inputs: Sequence[TensorTy], axis: int, region_builder_fn,
                     reverse: bool) -> Tuple[TensorTy, ...]:
    """Create a scan op over ``inputs`` along ``axis``.

    ``axis`` may be negative, counting from the last dimension.  All inputs
    must share the shape of the first one.  ``region_builder_fn`` is called
    with the new scan op to populate it, after which the op is verified.
    Returns one tensor per input, each with the (unchanged) input shape.
    """
    shape = inputs[0].type.shape
    rank = len(shape)

    assert -rank <= axis < rank, f"scan axis {axis} must be < inputs rank ({rank})"
    axis = axis + rank if axis < 0 else axis

    assert all(t.type.shape == shape for t in inputs), "all scan inputs must have the same shape"

    handles = [t.handle for t in inputs]
    scan_op = self.builder.create_scan(handles, axis, reverse)
    region_builder_fn(scan_op)
    assert scan_op.verify()

    return tuple(self.wrap_tensor(scan_op.get_result(i), t.type.scalar, shape) for i, t in enumerate(inputs))
1758
1689
 
1759
1690
  # ===----------------------------------------------------------------------===
1760
1691
  # Gather
1761
1692
  # ===----------------------------------------------------------------------===
1762
1693
 
1694
def gather(self, src: TensorTy, index: TensorTy, axis: int) -> TensorTy:
    """Create a gather of ``src`` by ``index`` along ``axis``.

    ``index`` must have an integer dtype and the same rank as ``src``.
    ``axis`` may be negative, counting from the last dimension.  Every
    dimension other than ``axis`` must have the same size in ``index`` and
    ``src``.  The result has the scalar type of ``src`` and the shape of
    ``index``.
    """
    assert index.dtype.is_int(), "index must be an integer tensor"

    rank = len(src.type.shape)
    assert len(index.type.shape) == rank, "source and index tensors must have the same rank"

    assert -rank <= axis < rank, f"gather axis {axis} must be < source rank ({rank})"
    if axis < 0:
        axis += rank

    for d in range(rank):
        if d == axis:
            continue
        # Report the dimension that actually mismatches, not the gather axis.
        assert index.type.shape[d] == src.type.shape[d], f"index dim {d} must match the corresponding source dim"

    gathered = self.builder.create_gather(src.handle, index.handle, axis)
    return self.wrap_tensor(gathered, src.type.scalar, index.type.shape)
1781
1711
 
1782
1712
 
1783
1713
  # ===----------------------------------------------------------------------===
1784
1714
  # Histogram
1785
1715
  # ===----------------------------------------------------------------------===
1786
1716
 
1717
def histogram(self, input: TensorTy, num_bins: int, mask: Optional[TensorTy]) -> TensorTy:
    """Create a histogram op over a 1-D integer ``input`` with ``num_bins`` bins.

    When ``mask`` is given it is broadcast to the input shape and must have a
    boolean scalar type; its handle is passed to the builder.  Otherwise
    ``None`` is passed.  The result is an ``int32`` block of shape
    ``[num_bins]``.

    Raises:
        ValueError: if the broadcast mask is not boolean.
    """
    assert len(input.shape) == 1, "histogram only supports 1D input"
    assert input.dtype.is_int(), "histogram only supports integer input"

    mask_handle = None
    if mask is not None:
        broadcast_mask = self.broadcast_impl_shape(mask, input.shape)
        if not broadcast_mask.type.scalar.is_bool():
            raise ValueError("Mask must have boolean scalar type")
        mask_handle = broadcast_mask.handle

    hist = self.builder.create_histogram(input.handle, num_bins, mask_handle)
    return self.tensor(hist, tl.block_type(tl.int32, [num_bins]))
1727
+
1728
def multiple_of(self, x: TensorTy, values: List[int]) -> TensorTy:
    """Attach ``values`` to ``x`` as its ``tt.divisibility`` attribute and return ``x``.

    ``values`` needs one entry per dimension of ``x``; a tensor with an empty
    shape is treated as having one dimension for this check.

    Raises:
        ValueError: if the number of values does not match.
    """
    expected_len = max(1, len(x.shape))
    if expected_len != len(values):
        raise ValueError("Shape of input to multiple_of does not match the length of values")
    attr = ir.make_attr(values, x.handle.get_context())
    x.handle.set_attr("tt.divisibility", attr)
    return x
1787
1733
 
1788
- def histogram(input: tl.tensor, num_bins: int, builder: ir.builder) -> tl.tensor:
1789
- assert len(input.shape) == 1, "histogram only supports 1D input"
1790
- assert input.dtype.is_int(), "histogram only supports integer input"
1791
- return tl.tensor(builder.create_histogram(input.handle, num_bins), tl.block_type(tl.int32, [num_bins]))
1792
-
1793
-
1794
- def multiple_of(x: tl.tensor, values: List[int]) -> tl.tensor:
1795
- if max(1, len(x.shape)) != len(values):
1796
- raise ValueError("Shape of input to multiple_of does not match the length of values")
1797
- x.handle.set_attr("tt.divisibility", ir.make_attr(values, x.handle.get_context()))
1798
- return x
1799
-
1800
-
1801
- def max_contiguous(x: tl.tensor, values: List[int]) -> tl.tensor:
1802
- if len(x.shape) != len(values):
1803
- raise ValueError("Shape of input to max_contiguous does not match the length of values")
1804
- x.handle.set_attr("tt.contiguity", ir.make_attr(values, x.handle.get_context()))
1805
- return x
1806
-
1807
-
1808
- def max_constancy(x: tl.tensor, values: List[int]) -> tl.tensor:
1809
- if len(x.shape) != len(values):
1810
- raise ValueError("Shape of input to max_constancy does not match the length of values")
1811
- x.handle.set_attr("tt.constancy", ir.make_attr(values, x.handle.get_context()))
1812
- return x
1813
-
1814
-
1815
- def debug_barrier(builder: ir.builder) -> tl.tensor:
1816
- return tl.tensor(builder.create_barrier(), tl.void)
1817
-
1818
-
1819
- def device_print(prefix: str, args: List[tl.tensor], hex: bool, builder: ir.builder) -> tl.tensor:
1820
- # It makes sense visually for prefix to end in ": "; make it so. Also,
1821
- # non-empty prefixes should start with " ".
1822
- if not prefix.endswith(" ") and args:
1823
- prefix += " "
1824
- if not prefix.endswith(": ") and args:
1825
- prefix = prefix[:-1] + ": "
1826
- if len(prefix) > 2 and not prefix.startswith(" "):
1827
- prefix = " " + prefix
1828
-
1829
- new_args = [arg.handle for arg in args]
1830
- is_signed = [arg.dtype in (tl.int1, tl.int8, tl.int16, tl.int32, tl.int64) for arg in args]
1831
- return tl.tensor(builder.create_print(prefix, hex, new_args, is_signed), tl.void)
1832
-
1833
-
1834
- def device_assert(cond: tl.tensor, msg: str, builder: ir.builder) -> tl.tensor:
1835
- if not builder.options.debug:
1836
- return
1837
- return tl.tensor(builder.create_assert(cond.handle, msg), tl.void)
1838
-
1839
-
1840
- def assume(cond, builder: ir.builder) -> tl.tensor:
1841
- return tl.tensor(builder.create_assume(cond.handle), tl.void)
1734
def max_contiguous(self, x: TensorTy, values: List[int]) -> TensorTy:
    """Attach ``values`` to ``x`` as its ``tt.contiguity`` attribute and return ``x``.

    Raises:
        ValueError: if ``values`` does not have one entry per dimension of ``x``.
    """
    if len(values) != len(x.shape):
        raise ValueError("Shape of input to max_contiguous does not match the length of values")
    attr = ir.make_attr(values, x.handle.get_context())
    x.handle.set_attr("tt.contiguity", attr)
    return x
1842
1739
 
1740
def max_constancy(self, x: TensorTy, values: List[int]) -> TensorTy:
    """Attach ``values`` to ``x`` as its ``tt.constancy`` attribute and return ``x``.

    Raises:
        ValueError: if ``values`` does not have one entry per dimension of ``x``.
    """
    if len(values) != len(x.shape):
        raise ValueError("Shape of input to max_constancy does not match the length of values")
    attr = ir.make_attr(values, x.handle.get_context())
    x.handle.set_attr("tt.constancy", attr)
    return x
1843
1745
 
1844
- def _convert_elem_to_ir_value(builder, elem, require_i64):
1845
- if isinstance(elem, int):
1846
- elem = tl.constexpr(elem)
1847
- if isinstance(elem, tl.constexpr):
1848
- if require_i64:
1849
- assert -2**63 <= elem.value < 2**63, f"Block pointers only support 64 bit `shape/strides`, " \
1850
- f"got a value {elem.value} which is out of the range"
1851
- return builder.get_int64(elem.value)
1852
- else:
1853
- assert -2**31 <= elem.value < 2**31, f"Block pointers only support 32 bit `offsets/block_shape`, " \
1854
- f"got a value {elem.value} which is out of the range"
1855
- return builder.get_int32(elem.value)
1856
- elif isinstance(elem, tl.tensor):
1857
- assert elem.numel.value == 1, "Expected a scalar in shape/strides/offsets"
1858
- assert elem.dtype.is_int(), "Expected an integer scalar type in shape/strides/offsets"
1859
- if elem.dtype != tl.int64 and require_i64:
1860
- return builder.create_int_cast(elem.handle, builder.get_int64_ty(), elem.dtype.is_int_signed())
1861
- elif elem.dtype != tl.int32 and not require_i64:
1862
- assert False, "Block pointers only support 32 bit `offsets/block_shape`, " \
1863
- "add a `.to(tl.int32)` or use regular indexing for 64 bit support"
1864
- return elem.handle
1865
- assert False, f"Unsupported element type in shape/strides/offsets: {type(elem)}"
1866
-
1867
-
1868
- def _convert_to_ir_values(builder, list_like, require_i64=True):
1869
- if hasattr(list_like, "__iter__"):
1870
- return [_convert_elem_to_ir_value(builder, elem, require_i64) for elem in list_like]
1871
- return [_convert_elem_to_ir_value(builder, list_like, require_i64)]
1872
-
1873
-
1874
- def make_block_ptr(base: tl.tensor, shape, strides, offsets, block_shape, order, builder: ir.builder) -> tl.tensor:
1875
- # Convert dynamic arguments to IR values
1876
- # NOTES(Chenggang): current `shape/strides` are `int64_t`, while `offsets/block_shape` are `int32_t`
1877
- shape = _convert_to_ir_values(builder, shape)
1878
- strides = _convert_to_ir_values(builder, strides)
1879
- offsets = _convert_to_ir_values(builder, offsets, require_i64=False)
1880
-
1881
- # Check `base` type
1882
- if not base.type.is_ptr() or base.type.element_ty.is_block():
1883
- raise ValueError("Expected `base` to be a pointer type (but not a block pointer type or others)")
1884
-
1885
- # Treat `pointer_type<tl.int1>` as `pointer_type<tl.int8>`
1886
- if base.type.element_ty == tl.int1:
1887
- base = cast(base, tl.pointer_type(tl.int8, base.type.address_space), builder)
1888
-
1889
- # Check whether `block_shape` is static
1890
- if not hasattr(block_shape, "__iter__"):
1891
- block_shape = [block_shape]
1892
- block_shape = [elem.value if isinstance(elem, tl.constexpr) else elem for elem in block_shape]
1893
- assert all(isinstance(elem, int) and -2**31 <= elem < 2**31 for elem in block_shape), \
1894
- "Expected a list of constant integers (`int32_t` range) in `block_shape`"
1895
-
1896
- # Check `order`
1897
- if not hasattr(order, "__iter__"):
1898
- order = [order]
1899
- order = [elem.value if isinstance(elem, tl.constexpr) else elem for elem in order]
1900
- assert sorted(order) == list(range(len(order))), "Expected a permutation of (0, 1, ..., len(order)-1) in order"
1901
-
1902
- # Must have same length
1903
- assert all(len(block_shape) == len(list_like) for list_like in [shape, strides, offsets, order]), \
1904
- "Expected shape/strides/offsets/block_shape to have the same length"
1905
-
1906
- # Build value, the type is:
1907
- # `pointer_type<blocked<shape, element_type>>` in Python
1908
- # `tt.ptr<tensor<shape, element_type>>` in MLIR
1909
- handle = builder.create_make_block_ptr(base.handle, shape, strides, offsets, block_shape, order)
1910
- return tl.tensor(handle, tl.pointer_type(tl.block_type(base.type.element_ty, block_shape)))
1911
-
1912
-
1913
- def advance(base: tl.tensor, offsets, builder: ir.builder) -> tl.tensor:
1914
- # Convert dynamic offsets to IR values
1915
- offsets = _convert_to_ir_values(builder, offsets, require_i64=False)
1916
-
1917
- # Advanced block pointer type is the same as before
1918
- return tl.tensor(builder.create_advance(base.handle, offsets), base.type)
1919
-
1920
-
1921
- def make_tensor_descriptor(
1922
- base: tl.tensor,
1923
- shape: List[tl.tensor],
1924
- strides: List[tl.tensor],
1925
- block_shape: List[tl.constexpr],
1926
- builder: ir.builder,
1927
- ) -> tl._experimental_tensor_descriptor:
1928
- ndim = len(shape)
1929
- if not (2 <= ndim <= 5):
1930
- raise ValueError(f"Expected 2 <= ndim <= 5 but got {ndim} dimensions")
1931
- if len(strides) != ndim:
1932
- raise ValueError(f"Expected {ndim} strides but got {len(strides)}")
1933
- if len(block_shape) != ndim:
1934
- raise ValueError(f"Expected block_shape to have {ndim} dimensions but got {len(strides)}")
1935
-
1936
- strides[-1] = tl._constexpr_to_value(strides[-1])
1937
- if strides[-1] != 1:
1938
- raise ValueError(f"Tensor descriptor last dim must be 1 but got {strides[-1]}")
1939
-
1940
- shape = [to_tensor(x, builder) for x in shape]
1941
- strides = [to_tensor(x, builder).to(tl.int64, _builder=builder) for x in strides]
1942
-
1943
- # Check whether `block_shape` is static
1944
- block_shape = tl._unwrap_shape(block_shape)
1945
-
1946
- assert isinstance(base.type, tl.pointer_type)
1947
- type = tl.block_type(base.type.element_ty, block_shape)
1948
- handle = builder.create_make_tensor_descriptor(base.handle, [s.handle for s in shape], [s.handle for s in strides],
1949
- block_shape)
1950
- return tl._experimental_tensor_descriptor(handle, shape, strides, type)
1746
def debug_barrier(self) -> TensorTy:
    """Create a barrier op through the builder and wrap it as a ``tl.void`` tensor."""
    barrier = self.builder.create_barrier()
    return self.tensor(barrier, tl.void)
1748
+
1749
def device_print(self, prefix: str, args: List[TensorTy], hex: bool) -> TensorTy:
    """Create a print op for ``args`` with a normalized ``prefix``.

    When there are arguments, the prefix is made to end in ``": "``.  Any
    prefix longer than two characters is made to start with a space.  Each
    argument contributes its handle and whether its dtype is a signed
    integer.  The op is wrapped as a ``tl.void`` tensor.
    """
    if args:
        # Make the prefix end in a space, then turn that ending into ": ".
        if not prefix.endswith(" "):
            prefix = prefix + " "
        if not prefix.endswith(": "):
            prefix = prefix[:-1] + ": "
    if len(prefix) > 2 and not prefix.startswith(" "):
        prefix = " " + prefix

    handles = [arg.handle for arg in args]
    signedness = [arg.dtype.is_int_signed() for arg in args]
    print_op = self.builder.create_print(prefix, hex, handles, signedness)
    return self.tensor(print_op, tl.void)
1762
+
1763
def device_assert(self, cond: TensorTy, msg: str) -> TensorTy:
    """Create an assert op on ``cond`` with message ``msg``.

    Nothing is emitted and ``None`` is returned when ``builder.options.debug``
    is falsy; otherwise the op is wrapped as a ``tl.void`` tensor.
    """
    if self.builder.options.debug:
        assert_op = self.builder.create_assert(cond.handle, msg)
        return self.tensor(assert_op, tl.void)
    return None
1767
+
1768
def assume(self, cond) -> TensorTy:
    """Create an assume op on ``cond`` and wrap it as a ``tl.void`` tensor."""
    assume_op = self.builder.create_assume(cond.handle)
    return self.tensor(assume_op, tl.void)
1770
+
1771
def _convert_elem_to_ir_value(self, elem, require_i64):
    """Convert one shape/stride/offset element to an IR value.

    Plain ints are first wrapped in ``tl.constexpr``.  A constexpr holding a
    ``bool`` becomes an int1 constant; any other constexpr becomes an int64
    or int32 constant depending on ``require_i64``, after a range check.  A
    ``tl.tensor`` must be a single-element integer tensor: it is widened to
    int64 when ``require_i64`` is set and it is not already int64, and must
    be int32 when ``require_i64`` is not set.  Any other element type fails
    an assertion.
    """
    if isinstance(elem, int):
        elem = tl.constexpr(elem)

    if isinstance(elem, tl.constexpr):
        value = elem.value
        # bool is checked first; it is emitted as an int1 regardless of require_i64.
        if isinstance(value, bool):
            return self.builder.get_int1(value)
        if require_i64:
            assert -2**63 <= value < 2**63, f"Block pointers only support 64 bit `shape/strides`, " \
                f"got a value {value} which is out of the range"
            return self.builder.get_int64(value)
        assert -2**31 <= value < 2**31, f"Block pointers only support 32 bit `offsets/block_shape`, " \
            f"got a value {value} which is out of the range"
        return self.builder.get_int32(value)

    if isinstance(elem, tl.tensor):
        assert elem.numel.value == 1, "Expected a scalar in shape/strides/offsets"
        assert elem.dtype.is_int(), "Expected an integer scalar type in shape/strides/offsets"
        if require_i64:
            if elem.dtype != tl.int64:
                return self.builder.create_int_cast(elem.handle, self.builder.get_int64_ty(),
                                                    elem.dtype.is_int_signed())
        elif elem.dtype != tl.int32:
            assert False, "Block pointers only support 32 bit `offsets/block_shape`, " \
                "add a `.to(tl.int32)` or use regular indexing for 64 bit support"
        return elem.handle

    assert False, f"Unsupported element type in shape/strides/offsets: {type(elem)}"
1796
+
1797
def _convert_to_ir_values(self, list_like, require_i64=True):
    """Convert an iterable of elements, or a single element, to a list of IR values.

    Anything exposing ``__iter__`` is converted element by element; any other
    object is treated as a one-element list.
    """
    elems = list_like if hasattr(list_like, "__iter__") else [list_like]
    return [self._convert_elem_to_ir_value(elem, require_i64) for elem in elems]
1801
+
1802
def make_block_ptr(self, base: TensorTy, shape, strides, offsets, block_shape, order) -> TensorTy:
    """Create a block pointer over ``base``.

    ``shape`` and ``strides`` are converted to 64-bit IR values, ``offsets``
    to 32-bit IR values.  ``base`` must be a pointer whose element type is
    not a block; a pointer to ``tl.int1`` is cast to a pointer to ``tl.int8``.
    ``block_shape`` must consist of constant ints in the int32 range and
    ``order`` must be a permutation of ``0 .. len(order) - 1``.  All five
    sequences must have the same length.

    Raises:
        ValueError: if ``base`` is not an acceptable pointer.
    """

    def _plain_list(seq):
        # Accept a single item or an iterable, and strip constexpr wrappers.
        items = seq if hasattr(seq, "__iter__") else [seq]
        return [item.value if isinstance(item, tl.constexpr) else item for item in items]

    # Dynamic arguments become IR values: int64 for shape/strides, int32 for offsets.
    shape = self._convert_to_ir_values(shape)
    strides = self._convert_to_ir_values(strides)
    offsets = self._convert_to_ir_values(offsets, require_i64=False)

    if not base.type.is_ptr() or base.type.element_ty.is_block():
        raise ValueError("Expected `base` to be a pointer type (but not a block pointer type or others)")

    # `pointer_type<tl.int1>` is handled as `pointer_type<tl.int8>`.
    if base.type.element_ty == tl.int1:
        base = self.cast(base, tl.pointer_type(tl.int8, base.type.address_space))

    # `block_shape` has to be static.
    block_shape = _plain_list(block_shape)
    assert all(isinstance(dim, int) and -2**31 <= dim < 2**31 for dim in block_shape), \
        "Expected a list of constant integers (`int32_t` range) in `block_shape`"

    order = _plain_list(order)
    assert sorted(order) == list(range(len(order))), "Expected a permutation of (0, 1, ..., len(order)-1) in order"

    rank = len(block_shape)
    assert all(rank == len(seq) for seq in (shape, strides, offsets, order)), \
        "Expected shape/strides/offsets/block_shape to have the same length"

    # Resulting type:
    #   `pointer_type<blocked<shape, element_type>>` in Python
    #   `tt.ptr<tensor<shape, element_type>>` in MLIR
    handle = self.builder.create_make_block_ptr(base.handle, shape, strides, offsets, block_shape, order)
    result_ty = tl.pointer_type(tl.block_type(base.type.element_ty, block_shape))
    return self.tensor(handle, result_ty)
1839
+
1840
def advance(self, base: TensorTy, offsets) -> TensorTy:
    """Create an advance op moving block pointer ``base`` by ``offsets``.

    ``offsets`` are converted to 32-bit IR values.  The result keeps the type
    of ``base``.
    """
    ir_offsets = self._convert_to_ir_values(offsets, require_i64=False)
    advanced = self.builder.create_advance(base.handle, ir_offsets)
    return self.tensor(advanced, base.type)
1846
+
1847
def make_tensor_descriptor(
    self,
    base: TensorTy,
    shape: List[TensorTy],
    strides: List[TensorTy],
    block_shape: List[tl.constexpr],
) -> tl.tensor_descriptor:
    """Create a tensor descriptor over ``base``.

    Args:
        base: tensor whose dtype and type are both ``tl.pointer_type``.
        shape: 1 to 5 dimension sizes; each is converted to an int32 scalar.
        strides: one stride per dimension; each is converted to an int64
            scalar.  The last stride must equal 1.  The caller's list is
            not modified.
        block_shape: one static block size per dimension.  The last block
            dimension times the element size must be at least 16 bytes.

    Returns:
        A ``tl.tensor_descriptor`` built from the new handle, the converted
        shape and strides, and the block type.

    Raises:
        ValueError: on a rank outside 1..5, a strides/block_shape length
            mismatch, a last block dimension under 16 bytes, or a last
            stride other than 1.
    """
    ndim = len(shape)
    if not (1 <= ndim <= 5):
        raise ValueError(f"Expected 1 <= ndim <= 5 but got {ndim} dimensions")
    if len(strides) != ndim:
        raise ValueError(f"Expected {ndim} strides but got {len(strides)}")
    if len(block_shape) != ndim:
        # Report the length of block_shape itself, not that of strides.
        raise ValueError(f"Expected block_shape to have {ndim} dimensions but got {len(block_shape)}")

    assert isinstance(base.dtype, tl.pointer_type)
    elem_size = base.dtype.element_ty.primitive_bitwidth // 8
    contig_dim_size = tl._unwrap_if_constexpr(block_shape[-1])
    if contig_dim_size * elem_size < 16:
        raise ValueError(
            f"Descriptor block shape must have at least 16 bytes in the last dimension, but got {contig_dim_size} * {elem_size} = {contig_dim_size * elem_size} bytes"
        )

    # Unwrap into a local so the caller's `strides` list is left untouched.
    last_stride = tl._unwrap_if_constexpr(strides[-1])
    if last_stride != 1:
        raise ValueError(f"Tensor descriptor last dim must be 1 but got {last_stride}")

    shape = [self.make_scalar(x, tl.int32) for x in shape]
    last = ndim - 1
    strides = [self.make_scalar(last_stride if i == last else s, tl.int64) for i, s in enumerate(strides)]

    # `block_shape` has to be static.
    block_shape = tl._unwrap_shape(block_shape)

    assert isinstance(base.type, tl.pointer_type)
    desc_ty = tl.block_type(base.type.element_ty, block_shape)
    is_signed_int = base.type.element_ty.is_int_signed()

    handle = self.builder.create_make_tensor_descriptor(base.handle, [s.handle for s in shape],
                                                        [s.handle for s in strides], block_shape, is_signed_int)
    return tl.tensor_descriptor(handle, shape, strides, desc_ty)