triton-windows 3.2.0.post11__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of triton-windows might be problematic.

Files changed (154)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +85 -0
  3. triton/_internal_testing.py +123 -0
  4. triton/backends/__init__.py +50 -0
  5. triton/backends/amd/compiler.py +368 -0
  6. triton/backends/amd/driver.c +211 -0
  7. triton/backends/amd/driver.py +512 -0
  8. triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +358 -0
  9. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +1031 -0
  10. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +1612 -0
  11. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +1337 -0
  12. triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +293 -0
  13. triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +32 -0
  14. triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +174 -0
  15. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +829 -0
  16. triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +1809 -0
  17. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +108 -0
  18. triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +124 -0
  19. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +405 -0
  20. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +196 -0
  21. triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +565 -0
  22. triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +2226 -0
  23. triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +104 -0
  24. triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +244 -0
  25. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +494 -0
  26. triton/backends/amd/include/hip/amd_detail/concepts.hpp +30 -0
  27. triton/backends/amd/include/hip/amd_detail/device_library_decls.h +133 -0
  28. triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +218 -0
  29. triton/backends/amd/include/hip/amd_detail/grid_launch.h +67 -0
  30. triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +50 -0
  31. triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +26 -0
  32. triton/backends/amd/include/hip/amd_detail/helpers.hpp +137 -0
  33. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +1350 -0
  34. triton/backends/amd/include/hip/amd_detail/hip_assert.h +101 -0
  35. triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +242 -0
  36. triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +254 -0
  37. triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +96 -0
  38. triton/backends/amd/include/hip/amd_detail/hip_ldg.h +100 -0
  39. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +10169 -0
  40. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +77 -0
  41. triton/backends/amd/include/hip/amd_detail/host_defines.h +180 -0
  42. triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +102 -0
  43. triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +798 -0
  44. triton/backends/amd/include/hip/amd_detail/math_fwd.h +698 -0
  45. triton/backends/amd/include/hip/amd_detail/ockl_image.h +177 -0
  46. triton/backends/amd/include/hip/amd_detail/program_state.hpp +107 -0
  47. triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +491 -0
  48. triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +478 -0
  49. triton/backends/amd/include/hip/channel_descriptor.h +39 -0
  50. triton/backends/amd/include/hip/device_functions.h +38 -0
  51. triton/backends/amd/include/hip/driver_types.h +468 -0
  52. triton/backends/amd/include/hip/hip_bf16.h +36 -0
  53. triton/backends/amd/include/hip/hip_bfloat16.h +44 -0
  54. triton/backends/amd/include/hip/hip_common.h +100 -0
  55. triton/backends/amd/include/hip/hip_complex.h +38 -0
  56. triton/backends/amd/include/hip/hip_cooperative_groups.h +46 -0
  57. triton/backends/amd/include/hip/hip_deprecated.h +95 -0
  58. triton/backends/amd/include/hip/hip_ext.h +159 -0
  59. triton/backends/amd/include/hip/hip_fp16.h +36 -0
  60. triton/backends/amd/include/hip/hip_gl_interop.h +32 -0
  61. triton/backends/amd/include/hip/hip_hcc.h +24 -0
  62. triton/backends/amd/include/hip/hip_math_constants.h +36 -0
  63. triton/backends/amd/include/hip/hip_profile.h +27 -0
  64. triton/backends/amd/include/hip/hip_runtime.h +75 -0
  65. triton/backends/amd/include/hip/hip_runtime_api.h +8919 -0
  66. triton/backends/amd/include/hip/hip_texture_types.h +29 -0
  67. triton/backends/amd/include/hip/hip_vector_types.h +41 -0
  68. triton/backends/amd/include/hip/hip_version.h +17 -0
  69. triton/backends/amd/include/hip/hiprtc.h +421 -0
  70. triton/backends/amd/include/hip/library_types.h +78 -0
  71. triton/backends/amd/include/hip/math_functions.h +42 -0
  72. triton/backends/amd/include/hip/surface_types.h +63 -0
  73. triton/backends/amd/include/hip/texture_types.h +194 -0
  74. triton/backends/amd/include/hsa/Brig.h +1131 -0
  75. triton/backends/amd/include/hsa/amd_hsa_common.h +91 -0
  76. triton/backends/amd/include/hsa/amd_hsa_elf.h +436 -0
  77. triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +269 -0
  78. triton/backends/amd/include/hsa/amd_hsa_queue.h +109 -0
  79. triton/backends/amd/include/hsa/amd_hsa_signal.h +80 -0
  80. triton/backends/amd/include/hsa/hsa.h +5729 -0
  81. triton/backends/amd/include/hsa/hsa_amd_tool.h +91 -0
  82. triton/backends/amd/include/hsa/hsa_api_trace.h +566 -0
  83. triton/backends/amd/include/hsa/hsa_ext_amd.h +3090 -0
  84. triton/backends/amd/include/hsa/hsa_ext_finalize.h +531 -0
  85. triton/backends/amd/include/hsa/hsa_ext_image.h +1454 -0
  86. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +488 -0
  87. triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +667 -0
  88. triton/backends/amd/include/roctracer/ext/prof_protocol.h +107 -0
  89. triton/backends/amd/include/roctracer/hip_ostream_ops.h +4435 -0
  90. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +1467 -0
  91. triton/backends/amd/include/roctracer/hsa_prof_str.h +3027 -0
  92. triton/backends/amd/include/roctracer/roctracer.h +779 -0
  93. triton/backends/amd/include/roctracer/roctracer_ext.h +81 -0
  94. triton/backends/amd/include/roctracer/roctracer_hcc.h +24 -0
  95. triton/backends/amd/include/roctracer/roctracer_hip.h +37 -0
  96. triton/backends/amd/include/roctracer/roctracer_hsa.h +112 -0
  97. triton/backends/amd/include/roctracer/roctracer_plugin.h +137 -0
  98. triton/backends/amd/include/roctracer/roctracer_roctx.h +67 -0
  99. triton/backends/amd/include/roctracer/roctx.h +229 -0
  100. triton/backends/amd/lib/ockl.bc +0 -0
  101. triton/backends/amd/lib/ocml.bc +0 -0
  102. triton/backends/compiler.py +304 -0
  103. triton/backends/driver.py +48 -0
  104. triton/backends/nvidia/__init__.py +0 -0
  105. triton/backends/nvidia/bin/ptxas.exe +0 -0
  106. triton/backends/nvidia/compiler.py +410 -0
  107. triton/backends/nvidia/driver.c +451 -0
  108. triton/backends/nvidia/driver.py +524 -0
  109. triton/backends/nvidia/include/cuda.h +24359 -0
  110. triton/backends/nvidia/lib/libdevice.10.bc +0 -0
  111. triton/backends/nvidia/lib/x64/cuda.lib +0 -0
  112. triton/compiler/__init__.py +4 -0
  113. triton/compiler/code_generator.py +1303 -0
  114. triton/compiler/compiler.py +430 -0
  115. triton/compiler/errors.py +51 -0
  116. triton/compiler/make_launcher.py +0 -0
  117. triton/errors.py +5 -0
  118. triton/language/__init__.py +294 -0
  119. triton/language/_utils.py +21 -0
  120. triton/language/core.py +2694 -0
  121. triton/language/extra/__init__.py +26 -0
  122. triton/language/extra/cuda/__init__.py +13 -0
  123. triton/language/extra/cuda/_experimental_tma.py +108 -0
  124. triton/language/extra/cuda/libdevice.py +1629 -0
  125. triton/language/extra/cuda/utils.py +109 -0
  126. triton/language/extra/hip/__init__.py +3 -0
  127. triton/language/extra/hip/libdevice.py +475 -0
  128. triton/language/extra/libdevice.py +786 -0
  129. triton/language/math.py +250 -0
  130. triton/language/random.py +207 -0
  131. triton/language/semantic.py +1796 -0
  132. triton/language/standard.py +452 -0
  133. triton/runtime/__init__.py +23 -0
  134. triton/runtime/autotuner.py +408 -0
  135. triton/runtime/build.py +111 -0
  136. triton/runtime/cache.py +295 -0
  137. triton/runtime/driver.py +60 -0
  138. triton/runtime/errors.py +26 -0
  139. triton/runtime/interpreter.py +1235 -0
  140. triton/runtime/jit.py +951 -0
  141. triton/testing.py +511 -0
  142. triton/tools/__init__.py +0 -0
  143. triton/tools/build_extern.py +365 -0
  144. triton/tools/compile.c +67 -0
  145. triton/tools/compile.h +14 -0
  146. triton/tools/compile.py +155 -0
  147. triton/tools/disasm.py +144 -0
  148. triton/tools/experimental_descriptor.py +32 -0
  149. triton/tools/link.py +322 -0
  150. triton/windows_utils.py +375 -0
  151. triton_windows-3.2.0.post11.dist-info/METADATA +39 -0
  152. triton_windows-3.2.0.post11.dist-info/RECORD +154 -0
  153. triton_windows-3.2.0.post11.dist-info/WHEEL +5 -0
  154. triton_windows-3.2.0.post11.dist-info/top_level.txt +12 -0
triton/language/standard.py
@@ -0,0 +1,452 @@
+ from __future__ import annotations
+
+ from ..runtime.jit import jit
+ from . import core
+ from . import math
+
+ # constexpr utilities
+
+
+ def _log2(i: core.constexpr):
+     log2 = 0
+     n = i.value
+     while n > 1:
+         n >>= 1
+         log2 += 1
+     return core.constexpr(log2)
+
+
+ def _is_power_of_two(i: core.constexpr):
+     n = i.value
+     return core.constexpr((n & (n - 1)) == 0 and n != 0)
+
+
+ # -----------------------
+ # Standard library
+ # -----------------------
+
+
+ @core._tensor_member_fn
+ @jit
+ def cdiv(x, div):
+     """
+     Computes the ceiling division of :code:`x` by :code:`div`
+
+     :param x: the input number
+     :type x: Block
+     :param div: the divisor
+     :type div: Block
+     """
+     return (x + div - 1) // div
+
+
+ @core._tensor_member_fn
+ @jit
+ @math._add_math_1arg_docstr("sigmoid")
+ def sigmoid(x):
+     return 1 / (1 + math.exp(-x))
+
+
+ @core._tensor_member_fn
+ @jit
+ @math._add_math_1arg_docstr("softmax")
+ def softmax(x, ieee_rounding=False):
+     z = x - max(x, 0)
+     num = math.exp(z)
+     den = sum(num, 0)
+     return math.fdiv(num, den, ieee_rounding)
+
+
+ @core._tensor_member_fn
+ @jit
+ def ravel(x):
+     """
+     Returns a contiguous flattened view of :code:`x`.
+
+     :param x: the input tensor
+     :type x: Block
+     """
+     return core.reshape(x, [x.numel], can_reorder=True)
+
+
+ @jit
+ def swizzle2d(i, j, size_i, size_j, size_g):
+     """
+     Transforms the indices of a row-major `size_i * size_j` matrix into
+     the indices of a column-major matrix for each group of `size_g` rows.
+
+     For example, for :code:`size_i = size_j = 4` and :code:`size_g = 2`, it will
+     transform ::
+
+         [[0 , 1 , 2 , 3 ],
+          [4 , 5 , 6 , 7 ],
+          [8 , 9 , 10, 11],
+          [12, 13, 14, 15]]
+
+     into ::
+
+         [[0, 2,  4 , 6 ],
+          [1, 3,  5 , 7 ],
+          [8, 10, 12, 14],
+          [9, 11, 13, 15]]
+     """
+     # "unrolled index in array"
+     ij = i * size_j + j
+     # number of elements in `size_g` groups
+     # of `size_j` columns
+     size_gj = size_g * size_j
+     # index of the group in which (i,j) is
+     group_id = ij // size_gj
+     # row-index of the first element of this group
+     off_i = group_id * size_g
+     # last group may have fewer rows
+     size_g = core.minimum(size_i - off_i, size_g)
+     # linear index with respect to the first element in this group
+     ij = ij % size_gj
+     # new row and column indices
+     new_i = off_i + ij % size_g
+     new_j = ij // size_g
+     return new_i, new_j
+
+
+ @jit
+ def zeros(shape, dtype):
+     """
+     Returns a tensor filled with the scalar value 0 for the given :code:`shape` and :code:`dtype`.
+
+     :param shape: Shape of the new array, e.g., (8, 16) or (8, )
+     :type shape: tuple of ints
+     :param dtype: Data-type of the new array, e.g., :code:`tl.float16`
+     :type dtype: DType
+     """
+     return core.full(shape, 0, dtype)
+
+
+ @jit
+ def zeros_like(input):
+     """
+     Returns a tensor of zeros with the same shape and type as a given tensor.
+
+     :param input: input tensor
+     :type input: Tensor
+     """
+     return zeros(input.shape, input.dtype)
+
+
+ # max and argmax
+
+
+ @jit
+ def _argmax_combine(value1, index1, value2, index2, tie_break_left):
+     if tie_break_left:
+         tie = value1 == value2 and index1 < index2
+     else:
+         tie = False
+     gt = value1 > value2 or tie
+     v_ret = core.where(gt, value1, value2)
+     i_ret = core.where(gt, index1, index2)
+     return v_ret, i_ret
+
+
+ @jit
+ def _argmax_combine_tie_break_left(value1, index1, value2, index2):
+     return _argmax_combine(value1, index1, value2, index2, True)
+
+
+ @jit
+ def _argmax_combine_tie_break_fast(value1, index1, value2, index2):
+     return _argmax_combine(value1, index1, value2, index2, False)
+
+
+ @jit
+ def _elementwise_max(a, b):
+     return core.maximum(a, b)
+
+
+ @core._tensor_member_fn
+ @jit
+ @core._add_reduction_docstr("maximum", return_indices_arg="return_indices",
+                             tie_break_arg="return_indices_tie_break_left")
+ def max(input, axis=None, return_indices=False, return_indices_tie_break_left=True, keep_dims=False):
+     input = core._promote_bfloat16_to_float32(input)
+     if return_indices:
+         if return_indices_tie_break_left:
+             return core._reduce_with_indices(input, axis, _argmax_combine_tie_break_left, keep_dims=keep_dims)
+         else:
+             return core._reduce_with_indices(input, axis, _argmax_combine_tie_break_fast, keep_dims=keep_dims)
+     else:
+         if core.constexpr(input.dtype.primitive_bitwidth) < core.constexpr(32):
+             if core.constexpr(input.dtype.is_floating()):
+                 input = input.to(core.float32)
+             else:
+                 assert input.dtype.is_int(), "Expecting input to be integer type"
+                 input = input.to(core.int32)
+         return core.reduce(input, axis, _elementwise_max, keep_dims=keep_dims)
+
+
+ @core._tensor_member_fn
+ @jit
+ @core._add_reduction_docstr("maximum index", tie_break_arg="tie_break_left")
+ def argmax(input, axis, tie_break_left=True, keep_dims=False):
+     (_, ret) = max(input, axis, return_indices=True, return_indices_tie_break_left=tie_break_left, keep_dims=keep_dims)
+     return ret
+
+
+ # min and argmin
+
+
+ @jit
+ def _argmin_combine(value1, index1, value2, index2, tie_break_left):
+     if tie_break_left:
+         tie = value1 == value2 and index1 < index2
+     else:
+         tie = False
+     lt = value1 < value2 or tie
+     value_ret = core.where(lt, value1, value2)
+     index_ret = core.where(lt, index1, index2)
+     return value_ret, index_ret
+
+
+ @jit
+ def _argmin_combine_tie_break_left(value1, index1, value2, index2):
+     return _argmin_combine(value1, index1, value2, index2, True)
+
+
+ @jit
+ def _argmin_combine_tie_break_fast(value1, index1, value2, index2):
+     return _argmin_combine(value1, index1, value2, index2, False)
+
+
+ @jit
+ def _elementwise_min(a, b):
+     return core.minimum(a, b)
+
+
+ @core._tensor_member_fn
+ @jit
+ @core._add_reduction_docstr("minimum", return_indices_arg="return_indices",
+                             tie_break_arg="return_indices_tie_break_left")
+ def min(input, axis=None, return_indices=False, return_indices_tie_break_left=True, keep_dims=False):
+     input = core._promote_bfloat16_to_float32(input)
+     if return_indices:
+         if return_indices_tie_break_left:
+             return core._reduce_with_indices(input, axis, _argmin_combine_tie_break_left, keep_dims=keep_dims)
+         else:
+             return core._reduce_with_indices(input, axis, _argmin_combine_tie_break_fast, keep_dims=keep_dims)
+     else:
+         if core.constexpr(input.dtype.primitive_bitwidth) < 32:
+             if core.constexpr(input.dtype.is_floating()):
+                 input = input.to(core.float32)
+             else:
+                 assert input.dtype.is_int(), "Expecting input to be integer type"
+                 input = input.to(core.int32)
+         return core.reduce(input, axis, _elementwise_min, keep_dims=keep_dims)
+
+
+ @core._tensor_member_fn
+ @jit
+ @core._add_reduction_docstr("minimum index", tie_break_arg="tie_break_left")
+ def argmin(input, axis, tie_break_left=True, keep_dims=False):
+     _, ret = min(input, axis, return_indices=True, return_indices_tie_break_left=tie_break_left, keep_dims=keep_dims)
+     return ret
+
+
+ @jit
+ def _sum_combine(a, b):
+     return a + b
+
+
+ # sum
+
+
+ @core._tensor_member_fn
+ @jit
+ @core._add_reduction_docstr("sum")
+ def sum(input, axis=None, keep_dims=False):
+     input = core._promote_bfloat16_to_float32(input)
+     return core.reduce(input, axis, _sum_combine, keep_dims=keep_dims)
+
+
+ @jit
+ def _xor_combine(a, b):
+     return a ^ b
+
+
+ # xor sum
+
+
+ @core._tensor_member_fn
+ @core.builtin
+ @core._add_reduction_docstr("xor sum")
+ def xor_sum(input, axis=None, keep_dims=False, _builder=None, _generator=None):
+     scalar_ty = input.type.scalar
+     if not scalar_ty.is_int():
+         raise ValueError("xor_sum only supported for integers")
+
+     input = core._promote_bfloat16_to_float32(input, _builder=_builder)
+     return core.reduce(input, axis, _xor_combine, keep_dims=keep_dims, _builder=_builder, _generator=_generator)
+
+
+ # cumsum
+
+
+ @core._tensor_member_fn
+ @jit
+ @core._add_scan_docstr("cumsum")
+ def cumsum(input, axis=0, reverse=False):
+     # todo rename this to a generic function name
+     input = core._promote_bfloat16_to_float32(input)
+     return core.associative_scan(input, axis, _sum_combine, reverse)
+
+
+ # cumprod
+
+
+ @jit
+ def _prod_combine(a, b):
+     return a * b
+
+
+ @core._tensor_member_fn
+ @jit
+ @core._add_scan_docstr("cumprod")
+ def cumprod(input, axis=0, reverse=False):
+     # todo rename this to a generic function name
+     input = core._promote_bfloat16_to_float32(input)
+     return core.associative_scan(input, axis, _prod_combine, reverse)
+
+
+ # sort
+
+
+ @jit
+ def _compare_and_swap(x, flip, i: core.constexpr, n_dims: core.constexpr):
+     n_outer: core.constexpr = x.numel >> n_dims
+     shape: core.constexpr = [n_outer * 2**i, 2, 2**(n_dims - i - 1)]
+     y = core.reshape(x, shape)
+     # slice left/right with 'stride' 2**(n_dims - i - 1)
+     mask = core.arange(0, 2)[None, :, None]
+     left = core.broadcast_to(sum(y * (1 - mask), 1)[:, None, :], shape).to(y.dtype)
+     right = core.broadcast_to(sum(y * mask, 1)[:, None, :], shape).to(y.dtype)
+     left = core.reshape(left, x.shape)
+     right = core.reshape(right, x.shape)
+     # actual compare-and-swap
+     idtype = core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True)
+     ileft = left.to(idtype, bitcast=True)
+     iright = right.to(idtype, bitcast=True)
+     ix = x.to(idtype, bitcast=True)
+     ret = ix ^ core.where((left > right) != flip, ileft ^ iright, zeros_like(ix))
+     return ret.to(x.dtype, bitcast=True)
+
+
+ @jit
+ def _bitonic_merge(x, stage: core.constexpr, order: core.constexpr, n_dims: core.constexpr):
+     '''
+     order_type 0 == ascending
+     order_type 1 == descending
+     order_type 2 == alternating
+     '''
+     n_outer: core.constexpr = x.numel >> n_dims
+     core.static_assert(stage <= n_dims)
+     # flip denotes whether to re-arrange sub-sequences of elements in ascending or
+     # descending order.
+     # if flip = 00000000... then all elements will be re-arranged ascendingly at this stage
+     # if flip = 00110011... then all the elements will be re-arranged alternatingly (with
+     # a stride of 2) at this stage
+     if order == 2:
+         shape: core.constexpr = [n_outer * 2**(n_dims - 1 - stage), 2, 2**stage]
+         flip = core.reshape(core.broadcast_to(core.arange(0, 2)[None, :, None], shape), x.shape)
+     else:
+         flip = order
+     # perform `stage` rounds of `compare-and-swap`
+     for i in core.static_range(stage):
+         x = _compare_and_swap(x, flip, i + (n_dims - stage), n_dims)
+     return x
+
+
+ @core._tensor_member_fn
+ @jit
+ def sort(x, dim: core.constexpr = None, descending: core.constexpr = core.CONSTEXPR_0):
+     """
+     Sorts a tensor along a specified dimension.
+
+     :param x: The input tensor to be sorted.
+     :type x: Tensor
+     :param dim: The dimension along which to sort the tensor. If None, the tensor is sorted along the last dimension. Currently, only sorting along the last dimension is supported.
+     :type dim: int, optional
+     :param descending: If set to True, the tensor is sorted in descending order. If set to False, the tensor is sorted in ascending order.
+     :type descending: bool, optional
+     """
+     # handle default dimension or check that it is the most minor dim
+     _dim: core.constexpr = len(x.shape) - 1 if dim is None else dim
+     core.static_assert(_dim == len(x.shape) - 1, "only minor dimension is currently supported")
+     # iteratively run bitonic merge-sort steps
+     n_dims: core.constexpr = _log2(x.shape[_dim])
+     for i in core.static_range(1, n_dims + 1):
+         x = _bitonic_merge(x, i, 2 if i < n_dims else descending, n_dims)
+     return x
+
+
+ # flip
+
+
+ def _get_flip_dim(dim, shape):
+     dim = core._unwrap_if_constexpr(dim)
+     shape = core._unwrap_if_constexpr(shape)
+     if dim is None:
+         dim = len(shape) - 1
+     assert dim == len(shape) - 1, "Currently only support flipping the last dimension"
+     return core.constexpr(dim)
+
+
+ @core._tensor_member_fn
+ @jit
+ def flip(x, dim=None):
+     """
+     Flips a tensor `x` along the dimension `dim`.
+
+     :param x: the first input tensor
+     :type x: Block
+     :param dim: the dimension to flip along (currently only final dimension supported)
+     :type dim: int
+     """
+     core.static_assert(_is_power_of_two(x.shape[_get_flip_dim(dim, x.shape)]))
+     core.static_assert(_is_power_of_two(x.numel))
+     # # reshape the tensor to have all dimensions be 2.
+     # # TODO: We shouldn't have to change the dimensions not sorted.
+     steps: core.constexpr = _log2(x.numel)
+     start: core.constexpr = _log2(x.numel) - _log2(x.shape[_get_flip_dim(dim, x.shape)])
+     y = core.reshape(x, [2] * steps)
+     y = core.expand_dims(y, start)
+     flip = (core.arange(0, 2)[:, None] == 1 - core.arange(0, 2))
+     for i in core.static_range(start, steps):
+         flip2 = flip
+         for j in core.static_range(0, steps + 1):
+             if j != i and j != i + 1:
+                 flip2 = core.expand_dims(flip2, j)
+         y = sum(y * flip2, i + 1, keep_dims=True)
+     x = core.reshape(y, x.shape)
+     return x
+
+
+ @jit
+ def interleave(a, b):
+     """
+     Interleaves the values of two tensors along their last dimension. The two tensors must have the same shape.
+     Equivalent to `tl.join(a, b).reshape(a.shape[:-1] + [2 * a.shape[-1]])`
+
+     :param a: The first input tensor.
+     :type a: Tensor
+     :param b: The second input tensor.
+     :type b: Tensor
+     """
+     c = core.join(a, b)
+
+     if len(c.shape) == 1:
+         # We must have interleaved two scalars.
+         return c
+     else:
+         # This `else` is necessary because Triton's AST parser doesn't
+         # understand that if we take the `if` above we definitely don't run this
+         # `else`.
+         return core.reshape(c, c.shape[:-2] + [2 * c.shape[-2]])
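
For context, the helpers in this hunk are re-exported through `triton.language` (conventionally imported as `tl`). Below is a minimal usage sketch, not code shipped in the wheel: the kernel name, block size, and tensor sizes are invented for illustration and assume this wheel installed alongside a CUDA-enabled PyTorch.

import torch
import triton
import triton.language as tl


@triton.jit
def sigmoid_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # one program instance handles one block of elements
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    # `tl.sigmoid` is the `sigmoid` helper defined in the hunk above
    tl.store(out_ptr + offsets, tl.sigmoid(x), mask=mask)


x = torch.randn(10_000, device="cuda")
out = torch.empty_like(x)
# `triton.cdiv` mirrors the `cdiv` helper above: round the grid size up
grid = (triton.cdiv(x.numel(), 1024), )
sigmoid_kernel[grid](x, out, x.numel(), BLOCK_SIZE=1024)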

triton/runtime/__init__.py
@@ -0,0 +1,23 @@
+ from .autotuner import (Autotuner, Config, Heuristics, autotune, heuristics)
+ from .cache import RedisRemoteCacheBackend, RemoteCacheBackend
+ from .driver import driver
+ from .jit import JITFunction, KernelInterface, MockTensor, TensorWrapper, reinterpret
+ from .errors import OutOfResources, InterpreterError
+
+ __all__ = [
+     "autotune",
+     "Autotuner",
+     "Config",
+     "driver",
+     "Heuristics",
+     "heuristics",
+     "InterpreterError",
+     "JITFunction",
+     "KernelInterface",
+     "MockTensor",
+     "OutOfResources",
+     "RedisRemoteCacheBackend",
+     "reinterpret",
+     "RemoteCacheBackend",
+     "TensorWrapper",
+ ]
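
This module re-exports the autotuning entry points (`autotune`, `Config`, `heuristics`) alongside the JIT and cache machinery. A minimal sketch of how those exports are commonly wired together follows; the kernel, configuration values, and key are illustrative assumptions, not code from this package.

import triton
import triton.language as tl


# Illustrative only: the two candidate configs are benchmarked the first time
# the kernel runs for a new value of `n_elements`; the fastest one is cached.
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 256}, num_warps=4),
        triton.Config({"BLOCK_SIZE": 1024}, num_warps=8),
    ],
    key=["n_elements"],
)
@triton.jit
def scale_kernel(x_ptr, out_ptr, alpha, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, alpha * x, mask=mask)


# BLOCK_SIZE is injected by the autotuner, so the launch omits it; the grid
# callable receives the chosen meta-parameters, e.g.:
# grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]), )
# scale_kernel[grid](x, out, 2.0, n)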