triton-windows 3.3.1.post19__cp313-cp313-win_amd64.whl → 3.5.0.post21__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of triton-windows might be problematic.

Files changed (225)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +11 -2
  3. triton/_filecheck.py +97 -0
  4. triton/_internal_testing.py +95 -18
  5. triton/_utils.py +112 -21
  6. triton/backends/__init__.py +20 -23
  7. triton/backends/amd/__init__.py +0 -0
  8. triton/backends/amd/compiler.py +161 -119
  9. triton/backends/amd/driver.c +118 -46
  10. triton/backends/amd/driver.py +274 -96
  11. triton/backends/compiler.py +7 -21
  12. triton/backends/driver.py +13 -0
  13. triton/backends/nvidia/bin/ptxas.exe +0 -0
  14. triton/backends/nvidia/compiler.py +163 -106
  15. triton/backends/nvidia/driver.c +166 -101
  16. triton/backends/nvidia/driver.py +384 -202
  17. triton/compiler/__init__.py +5 -2
  18. triton/compiler/code_generator.py +439 -231
  19. triton/compiler/compiler.py +152 -84
  20. triton/experimental/__init__.py +0 -0
  21. triton/experimental/gluon/__init__.py +5 -0
  22. triton/experimental/gluon/_compiler.py +0 -0
  23. triton/experimental/gluon/_runtime.py +102 -0
  24. triton/experimental/gluon/language/__init__.py +119 -0
  25. triton/experimental/gluon/language/_core.py +490 -0
  26. triton/experimental/gluon/language/_layouts.py +583 -0
  27. triton/experimental/gluon/language/_math.py +20 -0
  28. triton/experimental/gluon/language/_semantic.py +380 -0
  29. triton/experimental/gluon/language/_standard.py +80 -0
  30. triton/experimental/gluon/language/amd/__init__.py +4 -0
  31. triton/experimental/gluon/language/amd/_layouts.py +96 -0
  32. triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
  33. triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
  34. triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
  35. triton/experimental/gluon/language/extra/__init__.py +3 -0
  36. triton/experimental/gluon/language/nvidia/__init__.py +4 -0
  37. triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
  38. triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
  39. triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
  40. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +387 -0
  41. triton/experimental/gluon/language/nvidia/blackwell/tma.py +52 -0
  42. triton/experimental/gluon/language/nvidia/hopper/__init__.py +132 -0
  43. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +34 -0
  44. triton/experimental/gluon/language/nvidia/hopper/tma.py +97 -0
  45. triton/experimental/gluon/nvidia/__init__.py +4 -0
  46. triton/experimental/gluon/nvidia/blackwell.py +3 -0
  47. triton/experimental/gluon/nvidia/hopper.py +45 -0
  48. triton/knobs.py +546 -0
  49. triton/language/__init__.py +50 -19
  50. triton/language/core.py +909 -572
  51. triton/language/extra/cuda/__init__.py +10 -7
  52. triton/language/extra/cuda/gdc.py +42 -0
  53. triton/language/extra/cuda/libdevice.py +394 -394
  54. triton/language/extra/cuda/utils.py +21 -21
  55. triton/language/extra/hip/__init__.py +3 -1
  56. triton/language/extra/hip/libdevice.py +120 -104
  57. triton/language/extra/hip/utils.py +35 -0
  58. triton/language/extra/libdevice.py +4 -0
  59. triton/language/math.py +65 -66
  60. triton/language/random.py +12 -2
  61. triton/language/semantic.py +1757 -1768
  62. triton/language/standard.py +127 -62
  63. triton/language/target_info.py +54 -0
  64. triton/runtime/_allocation.py +15 -3
  65. triton/runtime/_async_compile.py +55 -0
  66. triton/runtime/autotuner.py +117 -60
  67. triton/runtime/build.py +83 -17
  68. triton/runtime/cache.py +61 -47
  69. triton/runtime/driver.py +25 -47
  70. triton/runtime/interpreter.py +95 -50
  71. triton/runtime/jit.py +445 -248
  72. triton/runtime/tcc/include/_mingw.h +8 -10
  73. triton/runtime/tcc/include/assert.h +5 -0
  74. triton/runtime/tcc/include/errno.h +1 -1
  75. triton/runtime/tcc/include/float.h +21 -3
  76. triton/runtime/tcc/include/iso646.h +36 -0
  77. triton/runtime/tcc/include/limits.h +5 -0
  78. triton/runtime/tcc/include/malloc.h +2 -2
  79. triton/runtime/tcc/include/math.h +21 -261
  80. triton/runtime/tcc/include/stdalign.h +16 -0
  81. triton/runtime/tcc/include/stdarg.h +5 -70
  82. triton/runtime/tcc/include/stdatomic.h +171 -0
  83. triton/runtime/tcc/include/stddef.h +7 -19
  84. triton/runtime/tcc/include/stdlib.h +15 -4
  85. triton/runtime/tcc/include/stdnoreturn.h +7 -0
  86. triton/runtime/tcc/include/sys/stat.h +2 -2
  87. triton/runtime/tcc/include/sys/types.h +5 -0
  88. triton/runtime/tcc/include/tcc/tcc_libm.h +444 -27
  89. triton/runtime/tcc/include/tccdefs.h +342 -0
  90. triton/runtime/tcc/include/tgmath.h +89 -0
  91. triton/runtime/tcc/include/uchar.h +33 -0
  92. triton/runtime/tcc/include/unistd.h +1 -0
  93. triton/runtime/tcc/include/winapi/qos.h +72 -0
  94. triton/runtime/tcc/include/winapi/shellapi.h +59 -0
  95. triton/runtime/tcc/include/winapi/winbase.h +9 -2
  96. triton/runtime/tcc/include/winapi/wincon.h +8 -0
  97. triton/runtime/tcc/include/winapi/windows.h +1 -1
  98. triton/runtime/tcc/include/winapi/winnls.h +778 -0
  99. triton/runtime/tcc/include/winapi/winnt.h +9 -7
  100. triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
  101. triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
  102. triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
  103. triton/runtime/tcc/lib/libtcc1.a +0 -0
  104. triton/runtime/tcc/lib/python314.def +1800 -0
  105. triton/runtime/tcc/lib/python314t.def +1809 -0
  106. triton/runtime/tcc/libtcc.dll +0 -0
  107. triton/runtime/tcc/tcc.exe +0 -0
  108. triton/testing.py +16 -12
  109. triton/tools/compile.py +62 -14
  110. triton/tools/disasm.py +3 -4
  111. triton/tools/extra/cuda/compile.c +1 -0
  112. triton/tools/extra/hip/compile.cpp +66 -0
  113. triton/tools/extra/hip/compile.h +13 -0
  114. triton/tools/ragged_tma.py +92 -0
  115. triton/tools/tensor_descriptor.py +34 -0
  116. triton/windows_utils.py +52 -81
  117. {triton_windows-3.3.1.post19.dist-info → triton_windows-3.5.0.post21.dist-info}/METADATA +8 -4
  118. triton_windows-3.5.0.post21.dist-info/RECORD +217 -0
  119. triton_windows-3.5.0.post21.dist-info/entry_points.txt +3 -0
  120. triton_windows-3.5.0.post21.dist-info/licenses/LICENSE +23 -0
  121. triton_windows-3.5.0.post21.dist-info/top_level.txt +1 -0
  122. triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +0 -358
  123. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +0 -1010
  124. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +0 -1638
  125. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +0 -1814
  126. triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +0 -293
  127. triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +0 -32
  128. triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +0 -174
  129. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +0 -835
  130. triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +0 -1809
  131. triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +0 -1391
  132. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +0 -108
  133. triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +0 -124
  134. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +0 -405
  135. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +0 -196
  136. triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +0 -565
  137. triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +0 -2226
  138. triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +0 -104
  139. triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +0 -244
  140. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +0 -538
  141. triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +0 -288
  142. triton/backends/amd/include/hip/amd_detail/concepts.hpp +0 -30
  143. triton/backends/amd/include/hip/amd_detail/device_library_decls.h +0 -133
  144. triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +0 -218
  145. triton/backends/amd/include/hip/amd_detail/grid_launch.h +0 -67
  146. triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +0 -50
  147. triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +0 -26
  148. triton/backends/amd/include/hip/amd_detail/helpers.hpp +0 -137
  149. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +0 -1446
  150. triton/backends/amd/include/hip/amd_detail/hip_assert.h +0 -101
  151. triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +0 -242
  152. triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +0 -254
  153. triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +0 -96
  154. triton/backends/amd/include/hip/amd_detail/hip_ldg.h +0 -100
  155. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +0 -10570
  156. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +0 -78
  157. triton/backends/amd/include/hip/amd_detail/host_defines.h +0 -184
  158. triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +0 -102
  159. triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +0 -798
  160. triton/backends/amd/include/hip/amd_detail/math_fwd.h +0 -698
  161. triton/backends/amd/include/hip/amd_detail/ockl_image.h +0 -177
  162. triton/backends/amd/include/hip/amd_detail/program_state.hpp +0 -107
  163. triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +0 -491
  164. triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +0 -478
  165. triton/backends/amd/include/hip/channel_descriptor.h +0 -39
  166. triton/backends/amd/include/hip/device_functions.h +0 -38
  167. triton/backends/amd/include/hip/driver_types.h +0 -468
  168. triton/backends/amd/include/hip/hip_bf16.h +0 -36
  169. triton/backends/amd/include/hip/hip_bfloat16.h +0 -44
  170. triton/backends/amd/include/hip/hip_common.h +0 -100
  171. triton/backends/amd/include/hip/hip_complex.h +0 -38
  172. triton/backends/amd/include/hip/hip_cooperative_groups.h +0 -46
  173. triton/backends/amd/include/hip/hip_deprecated.h +0 -95
  174. triton/backends/amd/include/hip/hip_ext.h +0 -161
  175. triton/backends/amd/include/hip/hip_fp16.h +0 -36
  176. triton/backends/amd/include/hip/hip_fp8.h +0 -33
  177. triton/backends/amd/include/hip/hip_gl_interop.h +0 -32
  178. triton/backends/amd/include/hip/hip_hcc.h +0 -24
  179. triton/backends/amd/include/hip/hip_math_constants.h +0 -36
  180. triton/backends/amd/include/hip/hip_profile.h +0 -27
  181. triton/backends/amd/include/hip/hip_runtime.h +0 -75
  182. triton/backends/amd/include/hip/hip_runtime_api.h +0 -9261
  183. triton/backends/amd/include/hip/hip_texture_types.h +0 -29
  184. triton/backends/amd/include/hip/hip_vector_types.h +0 -41
  185. triton/backends/amd/include/hip/hip_version.h +0 -17
  186. triton/backends/amd/include/hip/hiprtc.h +0 -421
  187. triton/backends/amd/include/hip/library_types.h +0 -78
  188. triton/backends/amd/include/hip/math_functions.h +0 -42
  189. triton/backends/amd/include/hip/surface_types.h +0 -63
  190. triton/backends/amd/include/hip/texture_types.h +0 -194
  191. triton/backends/amd/include/hsa/Brig.h +0 -1131
  192. triton/backends/amd/include/hsa/amd_hsa_common.h +0 -91
  193. triton/backends/amd/include/hsa/amd_hsa_elf.h +0 -462
  194. triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +0 -269
  195. triton/backends/amd/include/hsa/amd_hsa_queue.h +0 -109
  196. triton/backends/amd/include/hsa/amd_hsa_signal.h +0 -80
  197. triton/backends/amd/include/hsa/hsa.h +0 -5738
  198. triton/backends/amd/include/hsa/hsa_amd_tool.h +0 -91
  199. triton/backends/amd/include/hsa/hsa_api_trace.h +0 -579
  200. triton/backends/amd/include/hsa/hsa_api_trace_version.h +0 -68
  201. triton/backends/amd/include/hsa/hsa_ext_amd.h +0 -3146
  202. triton/backends/amd/include/hsa/hsa_ext_finalize.h +0 -531
  203. triton/backends/amd/include/hsa/hsa_ext_image.h +0 -1454
  204. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +0 -488
  205. triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +0 -667
  206. triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +0 -416
  207. triton/backends/amd/include/roctracer/ext/prof_protocol.h +0 -107
  208. triton/backends/amd/include/roctracer/hip_ostream_ops.h +0 -4515
  209. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +0 -1727
  210. triton/backends/amd/include/roctracer/hsa_prof_str.h +0 -3059
  211. triton/backends/amd/include/roctracer/roctracer.h +0 -779
  212. triton/backends/amd/include/roctracer/roctracer_ext.h +0 -81
  213. triton/backends/amd/include/roctracer/roctracer_hcc.h +0 -24
  214. triton/backends/amd/include/roctracer/roctracer_hip.h +0 -37
  215. triton/backends/amd/include/roctracer/roctracer_hsa.h +0 -112
  216. triton/backends/amd/include/roctracer/roctracer_plugin.h +0 -137
  217. triton/backends/amd/include/roctracer/roctracer_roctx.h +0 -67
  218. triton/backends/amd/include/roctracer/roctx.h +0 -229
  219. triton/language/_utils.py +0 -21
  220. triton/language/extra/cuda/_experimental_tma.py +0 -106
  221. triton/runtime/tcc/lib/libtcc1-64.a +0 -0
  222. triton/tools/experimental_descriptor.py +0 -32
  223. triton_windows-3.3.1.post19.dist-info/RECORD +0 -260
  224. triton_windows-3.3.1.post19.dist-info/top_level.txt +0 -14
  225. {triton_windows-3.3.1.post19.dist-info → triton_windows-3.5.0.post21.dist-info}/WHEEL +0 -0
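The headline addition in this release is the experimental Gluon dialect under triton/experimental/gluon/. The hunk below reproduces the new triton/experimental/gluon/language/_core.py in full (entry 25 above, +490 lines). For orientation only, a minimal sketch of how the new namespace would be reached, assuming the import paths implied by the file listing; this is not an example shipped with the package:

    from triton.experimental import gluon                  # new experimental package
    from triton.experimental.gluon import language as gl   # re-exports the _core.py builtins shown below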
triton/experimental/gluon/language/_core.py
@@ -0,0 +1,490 @@
+from __future__ import annotations
+import math
+from typing import TypeVar, List, TYPE_CHECKING, Tuple
+from functools import wraps
+
+if TYPE_CHECKING:
+    from triton._C.libtriton.gluon_ir import GluonOpBuilder
+    from ._semantic import GluonSemantic
+
+from ._layouts import SharedLayout, DistributedLayout
+from triton._C.libtriton import ir
+import triton.language.core as tl_core
+from triton.language.core import (
+    constexpr,
+    base_value,
+    base_type,
+    dtype,
+    block_type,  # TODO: block type with layout info
+    pointer_type,
+    void,
+    int1,
+    int8,
+    int16,
+    int32,
+    int64,
+    uint8,
+    uint16,
+    uint32,
+    uint64,
+    float8e5,
+    float8e5b16,
+    float8e4nv,
+    float8e4b8,
+    float8e4b15,
+    float16,
+    bfloat16,
+    float32,
+    float64,
+    _unwrap_if_constexpr,
+    _unwrap_shape,
+    static_range,
+    tensor,
+    tuple,
+    tuple_type,
+)
+
+# We define __all__ only to appease the python linter; these are not used in
+# this file, but we want to import them anyway so they are importable from here.
+__all__ = [
+    "constexpr",
+    "pointer_type",
+    "void",
+    "int1",
+    "int8",
+    "int16",
+    "int32",
+    "int64",
+    "uint8",
+    "uint16",
+    "uint32",
+    "uint64",
+    "float8e5",
+    "float8e5b16",
+    "float8e4nv",
+    "float8e4b8",
+    "float8e4b15",
+    "float16",
+    "bfloat16",
+    "float32",
+    "float64",
+    "static_range",
+    "tuple",
+    "tuple_type",
+]
+
+T = TypeVar("T")
+
+# TODO: split these
+GLUON_BUILTIN = "__triton_builtin__"
+
+
+def builtin(fn: T) -> T:
+    """Mark a function as a builtin."""
+    assert callable(fn)
+
+    @wraps(fn)
+    def wrapper(*args, **kwargs):
+        if "_semantic" not in kwargs or kwargs["_semantic"] is None:
+            raise ValueError("Did you forget to add @triton.gluon.jit ? "
+                             "(`_semantic` argument must be provided outside of JIT functions.)")
+        return fn(*args, **kwargs)
+
+    setattr(wrapper, GLUON_BUILTIN, True)
+
+    return wrapper
+
+
+# Explicitly import forwarded Triton language symbols so mypy sees them.
+associative_scan = builtin(tl_core.associative_scan)
+atomic_add = builtin(tl_core.atomic_add)
+atomic_and = builtin(tl_core.atomic_and)
+atomic_cas = builtin(tl_core.atomic_cas)
+atomic_max = builtin(tl_core.atomic_max)
+atomic_min = builtin(tl_core.atomic_min)
+atomic_or = builtin(tl_core.atomic_or)
+atomic_xchg = builtin(tl_core.atomic_xchg)
+atomic_xor = builtin(tl_core.atomic_xor)
+broadcast = builtin(tl_core.broadcast)
+device_assert = builtin(tl_core.device_assert)
+expand_dims = builtin(tl_core.expand_dims)
+inline_asm_elementwise = builtin(tl_core.inline_asm_elementwise)
+join = builtin(tl_core.join)
+load = builtin(tl_core.load)
+map_elementwise = builtin(tl_core.map_elementwise)
+max_constancy = builtin(tl_core.max_constancy)
+max_contiguous = builtin(tl_core.max_contiguous)
+maximum = builtin(tl_core.maximum)
+minimum = builtin(tl_core.minimum)
+multiple_of = builtin(tl_core.multiple_of)
+num_programs = builtin(tl_core.num_programs)
+permute = builtin(tl_core.permute)
+program_id = builtin(tl_core.program_id)
+reduce = builtin(tl_core.reduce)
+reshape = builtin(tl_core.reshape)
+split = builtin(tl_core.split)
+static_assert = builtin(tl_core.static_assert)
+static_print = builtin(tl_core.static_print)
+store = builtin(tl_core.store)
+to_tensor = builtin(tl_core.to_tensor)
+where = builtin(tl_core.where)
+
+
+class distributed_type(block_type):
+
+    def __init__(self, element_ty: dtype, shape: List[int], layout):
+        super().__init__(element_ty, shape)
+        self.layout = layout
+        self.name = f"<{self.shape}, {self.element_ty}, {self.layout}>"
+        assert isinstance(layout, DistributedLayout)
+
+    def to_ir(self, builder: ir.builder) -> ir.type:
+        elem_ty = self.element_ty.to_ir(builder)
+        layout = self.layout._to_ir(builder)
+        return builder.get_distributed_ty(elem_ty, self.shape, layout)
+
+    def mangle(self) -> str:
+        elt = self.scalar.mangle()
+        shape = "_".join(map(str, self.shape))
+        layout = self.layout.mangle()
+        return f"{elt}S{shape}SL{layout}L"
+
+    def with_element_ty(self, scalar_ty: dtype) -> block_type:
+        return distributed_type(scalar_ty, self.shape, self.layout)
+
+    def __eq__(self, other) -> bool:
+        if not isinstance(other, distributed_type):
+            return False
+        return super().__eq__(other) and self.layout == other.layout
+
+
+class shared_memory_descriptor_type(base_type):
+
+    def __init__(self, element_ty, shape, layout, alloc_shape):
+        self.element_ty = element_ty
+        self.shape = shape
+        self.layout = layout
+        self.alloc_shape = alloc_shape
+        assert isinstance(layout, SharedLayout)
+
+    def to_ir(self, builder: GluonOpBuilder) -> None:
+        return builder.get_shared_mem_desc_ty(
+            self.element_ty.to_ir(builder),
+            self.shape,
+            self.layout._to_ir(builder),
+            self.alloc_shape,
+        )
+
+    def _unflatten_ir(self, handles: List[ir.Value], cursor: int) -> Tuple[shared_memory_descriptor, int]:
+        value = shared_memory_descriptor(handles[cursor], self.element_ty, self.shape, self.layout, self.alloc_shape)
+        return value, cursor + 1
+
+    def _flatten_ir_types(self, builder: GluonOpBuilder, out: List[ir.type]) -> None:
+        out.append(self.to_ir(builder))
+
+    def __str__(self) -> str:
+        return f"shared_memory_descriptor<{self.element_ty}, {self.shape}, {self.layout}, {self.alloc_shape}>"
+
+    def __eq__(self, other) -> bool:
+        return (type(self) is type(other) and self.shape == other.shape and self.layout == other.layout
+                and self.alloc_shape == other.alloc_shape)
+
+    def __ne__(self, other) -> bool:
+        return not (self == other)
+
+    def mangle(self) -> str:
+        shape_str = "_".join([str(s) for s in self.shape])
+        return f"MD{self.element_ty.mangle()}S{shape_str}SL{self.layout.mangle()}LAS{self.alloc_shape}ASMD"
+
+
+class shared_memory_descriptor(base_value):
+    """
+    Represents a handle to a shared memory allocation in Gluon IR.
+    """
+
+    def __init__(self, handle, element_ty, shape, layout, alloc_shape):
+        self.handle = handle
+        self.type = shared_memory_descriptor_type(element_ty, shape, layout, alloc_shape)
+
+    def _flatten_ir(self, handles: List[ir.value]) -> None:
+        handles.append(self.handle)
+
+    @property
+    def dtype(self):
+        return self.type.element_ty
+
+    @property
+    def shape(self):
+        return self.type.shape
+
+    @property
+    def rank(self):
+        return len(self.shape)
+
+    @property
+    def numel(self) -> int:
+        return math.prod(self.shape)
+
+    @property
+    def layout(self):
+        return self.type.layout
+
+    def __str__(self) -> str:
+        return str(self.type)
+
+    @builtin
+    def load(self, layout, _semantic: GluonSemantic = None) -> tensor:
+        """
+        Load a tensor from shared memory.
+
+        Args:
+            layout (DistributedLayout): The destination layout of the tensor.
+
+        Returns:
+            tensor: A Gluon tensor containing the loaded data.
+        """
+        layout = _unwrap_if_constexpr(layout)
+        return _semantic.shared_load(self, layout)
+
+    @builtin
+    def store(self, value, _semantic: GluonSemantic = None) -> None:
+        """
+        Store a tensor into shared memory.
+
+        Args:
+            value (tensor): The tensor whose contents to store.
+        """
+        return _semantic.shared_store(self, value)
+
+    @builtin
+    def slice(self, start, length, dim=0, _semantic: GluonSemantic = None) -> shared_memory_descriptor:
+        """
+        Create a subview of shared memory by slicing along a given dimension.
+
+        Args:
+            start (int): The starting index of the slice.
+            length (int): The length of the slice.
+            dim (int): The dimension to slice (default: 0).
+
+        Returns:
+            shared_memory_descriptor: Descriptor for the sliced subview.
+        """
+        start = _unwrap_if_constexpr(start)
+        length = _unwrap_if_constexpr(length)
+        dim = _unwrap_if_constexpr(dim)
+        return _semantic.memdesc_slice(self, start, length, dim)
+
+    @builtin
+    def index(self, index, _semantic: GluonSemantic = None) -> shared_memory_descriptor:
+        """
+        Create a subview of shared memory by indexing along the first dimension.
+
+        Args:
+            index (int): The index at which to take the subview.
+
+        Returns:
+            shared_memory_descriptor: Descriptor for the indexed subview.
+        """
+        index = _unwrap_if_constexpr(index)
+        return _semantic.memdesc_index(self, index)
+
+    @builtin
+    def permute(self, order, _semantic: GluonSemantic = None) -> shared_memory_descriptor:
+        """
+        Permute the dimensions of the shared memory descriptor.
+
+        Args:
+            order (List[int]): The new ordering of dimensions.
+
+        Returns:
+            shared_memory_descriptor: Descriptor with permuted dimensions.
+        """
+        order = [_unwrap_if_constexpr(o) for o in order]
+        return _semantic.memdesc_trans(self, order)
+
+    @builtin
+    def reshape(self, shape, _semantic: GluonSemantic = None) -> shared_memory_descriptor:
+        """
+        Reshape the shared memory descriptor to a new shape and layout.
+
+        Args:
+            shape (List[int]): The target shape.
+
+        Returns:
+            shared_memory_descriptor: Descriptor with the new shape and layout.
+        """
+        shape = [_unwrap_if_constexpr(s) for s in shape]
+
+        return _semantic.memdesc_reshape(self, shape)
+
+    @builtin
+    def _reinterpret(self, dtype, shape, layout, _semantic: GluonSemantic = None) -> shared_memory_descriptor:
+        """
+        Reinterpret the shared memory descriptor as a different dtype, shape, or layout.
+
+        Args:
+            dtype (dtype): The new data type.
+            shape (List[int]): The new shape.
+            layout (SharedLayout): The new layout.
+
+        Returns:
+            shared_memory_descriptor: Descriptor with updated type and layout.
+        """
+        dtype = _unwrap_if_constexpr(dtype)
+        shape = [_unwrap_if_constexpr(s) for s in shape]
+        layout = _unwrap_if_constexpr(layout)
+
+        return _semantic.memdesc_reinterpret(self, dtype, shape, layout)
+
+    @builtin
+    def _keep_alive(self, _semantic: GluonSemantic = None) -> None:
+        """
+        Dummy use to keep the shared memory descriptor alive.
+        """
+        return _semantic.shared_dealloc(self)
+
+
+@builtin
+def arange(start, end, layout=None, _semantic=None):
+    """
+    Generate a sequence tensor with values in [start, end) using a specified layout.
+
+    Args:
+        start (int): Inclusive start of the sequence.
+        end (int): Exclusive end of the sequence.
+        layout (DistributedLayout): The layout of the output tensor. Defaults to AutoLayout.
+
+    Returns:
+        tensor: A 1D tensor containing sequential values.
+    """
+    start = _unwrap_if_constexpr(start)
+    end = _unwrap_if_constexpr(end)
+    layout = _unwrap_if_constexpr(layout)
+    return _semantic.arange(start, end, layout)
+
+
+@builtin
+def convert_layout(value, layout, assert_trivial=False, _semantic=None):
+    """
+    Convert a tensor to a different distributed layout.
+
+    Args:
+        value (tensor): The input tensor.
+        layout (DistributedLayout): The target layout.
+        assert_trivial (bool): If True, asserts that the conversion is trivial (no data movement).
+
+    Returns:
+        tensor: The tensor with the new layout.
+    """
+    layout = _unwrap_if_constexpr(layout)
+    return _semantic.convert_layout(value, layout, assert_trivial)
+
+
+@builtin
+def full(shape, value, dtype, layout=None, _semantic=None):
+    """
+    Create a tensor filled with a scalar value, with specified shape, dtype, and layout.
+
+    Args:
+        shape (Sequence[int]): The shape of the tensor.
+        value (int or float): The fill value.
+        dtype (dtype): The data type for the tensor.
+        layout (Optional[DistributedLayout]): The layout of the output tensor, defaults to AutoLayout().
+
+    Returns:
+        tensor: A tensor where every element equals value.
+    """
+    shape = _unwrap_shape(shape)
+    value = _unwrap_if_constexpr(value)
+    dtype = _unwrap_if_constexpr(dtype)
+    layout = _unwrap_if_constexpr(layout)
+    return _semantic.full(shape, value, dtype, layout)
+
+
+@builtin
+def histogram(input, num_bins, mask=None, layout=None, _semantic=None, _generator=None):
+    """
+    Compute a histogram of a 1D integer tensor.
+
+    Args:
+        input (tensor): 1D tensor of integer values.
+        num_bins (int): Number of bins. Bins have width 1 and start at 0.
+        mask (Optional[tensor]): Boolean mask to exclude elements when False.
+        layout (DistributedLayout): Destination layout of the output histogram.
+
+    Returns:
+        tensor: 1D int32 tensor of length `num_bins` with the requested layout.
+    """
+    num_bins = _unwrap_if_constexpr(num_bins)
+    layout = _unwrap_if_constexpr(layout)
+    if mask is not None:
+        mask = _semantic.to_tensor(mask)
+    return _semantic.histogram(input, num_bins, mask, layout)
+
+
+@builtin
+def allocate_shared_memory(element_ty, shape, layout, value=None, _semantic=None) -> shared_memory_descriptor:
+    """
+    Allocate shared memory for a tensor with the given element type, shape, and layout.
+
+    Args:
+        element_ty (dtype): The element data type.
+        shape (Sequence[int]): The dimensions of the shared memory.
+        layout (SharedLayout): The shared memory layout.
+        value (tensor, optional): Initial value to copy into shared memory.
+
+    Returns:
+        shared_memory_descriptor: Descriptor for the allocated memory.
+    """
+    element_ty = _unwrap_if_constexpr(element_ty)
+    shape = _unwrap_if_constexpr(shape)
+    shape = [_unwrap_if_constexpr(s) for s in shape]
+    layout = _unwrap_if_constexpr(layout)
+    return _semantic.allocate_shared(element_ty, shape, layout, value)
+
+
+@builtin
+def set_auto_layout(value, layout, _semantic=None):
+    """
+    Set a tensor with AutoLayout to a concrete layout.
+
+    Args:
+        value (tensor): The input tensor.
+        layout (DistributedLayout): The target layout.
+
+    Returns:
+        tensor: The tensor with the new layout.
+    """
+    layout = _unwrap_if_constexpr(layout)
+    return _semantic.set_auto_layout(value, layout)
+
+
+@builtin
+def warp_specialize(default_args, default_partition, worker_args, worker_partitions, worker_num_warps, worker_num_regs,
+                    _semantic=None, _generator=None):
+    """
+    Create a warp-specialized execution region, partitioning work across warps.
+
+    Args:
+        default_args (List[Any]): Arguments for the default region.
+        default_partition (callable): Function to build the default execution region.
+        worker_args (List[Any]): Arguments for each warp partition.
+        worker_partitions (List[callable]): Functions for each warp partition.
+        worker_num_warps (List[int]): Number of warps per partition.
+        worker_num_regs (List[int]): Number of registers per partition.
+
+    Returns:
+        Tuple[Any, ...]: Results from the default region.
+    """
+    worker_num_warps = [_unwrap_if_constexpr(w) for w in worker_num_warps]
+    worker_num_regs = [_unwrap_if_constexpr(r) for r in worker_num_regs]
+    return _semantic.warp_specialize(default_args, default_partition, worker_args, worker_partitions, worker_num_warps,
+                                     worker_num_regs, _generator)
+
+
+@builtin
+def thread_barrier(_semantic=None):
+    """
+    Insert a barrier to synchronize threads within a CTA.
+    """
+    return _semantic.debug_barrier()
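
To make the API above concrete, here is a minimal, hypothetical kernel sketch combining the builtins defined in _core.py: arange, load, allocate_shared_memory, shared_memory_descriptor.load, and store. The @gluon.jit decorator and the BlockedLayout/SwizzledSharedLayout constructors are assumptions based on _runtime.py and _layouts.py in the file listing, not verified against this exact build:

    from triton.experimental import gluon
    from triton.experimental.gluon import language as gl

    @gluon.jit
    def staged_copy(src, dst, BLOCK: gl.constexpr):
        # Gluon makes layouts explicit: one for registers, one for shared memory.
        blocked: gl.constexpr = gl.BlockedLayout(size_per_thread=[1], threads_per_warp=[32],
                                                 warps_per_cta=[4], order=[0])
        smem_layout: gl.constexpr = gl.SwizzledSharedLayout(vec=1, per_phase=1, max_phase=1, order=[0])
        offs = gl.program_id(0) * BLOCK + gl.arange(0, BLOCK, layout=blocked)
        x = gl.load(src + offs)                                  # tl-style load, re-exported as a Gluon builtin
        smem = gl.allocate_shared_memory(x.dtype, [BLOCK], smem_layout, x)
        y = smem.load(blocked)                                   # shared_memory_descriptor.load
        gl.store(dst + offs, y)

Unlike triton.language, where the compiler picks layouts, every tensor and shared-memory allocation here carries an explicit layout, which is the central design difference visible throughout the _core.py hunk.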