triton-windows 3.2.0.post11__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of triton-windows might be problematic. Click here for more details.

Files changed (154) hide show
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +85 -0
  3. triton/_internal_testing.py +123 -0
  4. triton/backends/__init__.py +50 -0
  5. triton/backends/amd/compiler.py +368 -0
  6. triton/backends/amd/driver.c +211 -0
  7. triton/backends/amd/driver.py +512 -0
  8. triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +358 -0
  9. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +1031 -0
  10. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +1612 -0
  11. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +1337 -0
  12. triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +293 -0
  13. triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +32 -0
  14. triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +174 -0
  15. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +829 -0
  16. triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +1809 -0
  17. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +108 -0
  18. triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +124 -0
  19. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +405 -0
  20. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +196 -0
  21. triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +565 -0
  22. triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +2226 -0
  23. triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +104 -0
  24. triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +244 -0
  25. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +494 -0
  26. triton/backends/amd/include/hip/amd_detail/concepts.hpp +30 -0
  27. triton/backends/amd/include/hip/amd_detail/device_library_decls.h +133 -0
  28. triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +218 -0
  29. triton/backends/amd/include/hip/amd_detail/grid_launch.h +67 -0
  30. triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +50 -0
  31. triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +26 -0
  32. triton/backends/amd/include/hip/amd_detail/helpers.hpp +137 -0
  33. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +1350 -0
  34. triton/backends/amd/include/hip/amd_detail/hip_assert.h +101 -0
  35. triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +242 -0
  36. triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +254 -0
  37. triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +96 -0
  38. triton/backends/amd/include/hip/amd_detail/hip_ldg.h +100 -0
  39. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +10169 -0
  40. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +77 -0
  41. triton/backends/amd/include/hip/amd_detail/host_defines.h +180 -0
  42. triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +102 -0
  43. triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +798 -0
  44. triton/backends/amd/include/hip/amd_detail/math_fwd.h +698 -0
  45. triton/backends/amd/include/hip/amd_detail/ockl_image.h +177 -0
  46. triton/backends/amd/include/hip/amd_detail/program_state.hpp +107 -0
  47. triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +491 -0
  48. triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +478 -0
  49. triton/backends/amd/include/hip/channel_descriptor.h +39 -0
  50. triton/backends/amd/include/hip/device_functions.h +38 -0
  51. triton/backends/amd/include/hip/driver_types.h +468 -0
  52. triton/backends/amd/include/hip/hip_bf16.h +36 -0
  53. triton/backends/amd/include/hip/hip_bfloat16.h +44 -0
  54. triton/backends/amd/include/hip/hip_common.h +100 -0
  55. triton/backends/amd/include/hip/hip_complex.h +38 -0
  56. triton/backends/amd/include/hip/hip_cooperative_groups.h +46 -0
  57. triton/backends/amd/include/hip/hip_deprecated.h +95 -0
  58. triton/backends/amd/include/hip/hip_ext.h +159 -0
  59. triton/backends/amd/include/hip/hip_fp16.h +36 -0
  60. triton/backends/amd/include/hip/hip_gl_interop.h +32 -0
  61. triton/backends/amd/include/hip/hip_hcc.h +24 -0
  62. triton/backends/amd/include/hip/hip_math_constants.h +36 -0
  63. triton/backends/amd/include/hip/hip_profile.h +27 -0
  64. triton/backends/amd/include/hip/hip_runtime.h +75 -0
  65. triton/backends/amd/include/hip/hip_runtime_api.h +8919 -0
  66. triton/backends/amd/include/hip/hip_texture_types.h +29 -0
  67. triton/backends/amd/include/hip/hip_vector_types.h +41 -0
  68. triton/backends/amd/include/hip/hip_version.h +17 -0
  69. triton/backends/amd/include/hip/hiprtc.h +421 -0
  70. triton/backends/amd/include/hip/library_types.h +78 -0
  71. triton/backends/amd/include/hip/math_functions.h +42 -0
  72. triton/backends/amd/include/hip/surface_types.h +63 -0
  73. triton/backends/amd/include/hip/texture_types.h +194 -0
  74. triton/backends/amd/include/hsa/Brig.h +1131 -0
  75. triton/backends/amd/include/hsa/amd_hsa_common.h +91 -0
  76. triton/backends/amd/include/hsa/amd_hsa_elf.h +436 -0
  77. triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +269 -0
  78. triton/backends/amd/include/hsa/amd_hsa_queue.h +109 -0
  79. triton/backends/amd/include/hsa/amd_hsa_signal.h +80 -0
  80. triton/backends/amd/include/hsa/hsa.h +5729 -0
  81. triton/backends/amd/include/hsa/hsa_amd_tool.h +91 -0
  82. triton/backends/amd/include/hsa/hsa_api_trace.h +566 -0
  83. triton/backends/amd/include/hsa/hsa_ext_amd.h +3090 -0
  84. triton/backends/amd/include/hsa/hsa_ext_finalize.h +531 -0
  85. triton/backends/amd/include/hsa/hsa_ext_image.h +1454 -0
  86. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +488 -0
  87. triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +667 -0
  88. triton/backends/amd/include/roctracer/ext/prof_protocol.h +107 -0
  89. triton/backends/amd/include/roctracer/hip_ostream_ops.h +4435 -0
  90. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +1467 -0
  91. triton/backends/amd/include/roctracer/hsa_prof_str.h +3027 -0
  92. triton/backends/amd/include/roctracer/roctracer.h +779 -0
  93. triton/backends/amd/include/roctracer/roctracer_ext.h +81 -0
  94. triton/backends/amd/include/roctracer/roctracer_hcc.h +24 -0
  95. triton/backends/amd/include/roctracer/roctracer_hip.h +37 -0
  96. triton/backends/amd/include/roctracer/roctracer_hsa.h +112 -0
  97. triton/backends/amd/include/roctracer/roctracer_plugin.h +137 -0
  98. triton/backends/amd/include/roctracer/roctracer_roctx.h +67 -0
  99. triton/backends/amd/include/roctracer/roctx.h +229 -0
  100. triton/backends/amd/lib/ockl.bc +0 -0
  101. triton/backends/amd/lib/ocml.bc +0 -0
  102. triton/backends/compiler.py +304 -0
  103. triton/backends/driver.py +48 -0
  104. triton/backends/nvidia/__init__.py +0 -0
  105. triton/backends/nvidia/bin/ptxas.exe +0 -0
  106. triton/backends/nvidia/compiler.py +410 -0
  107. triton/backends/nvidia/driver.c +451 -0
  108. triton/backends/nvidia/driver.py +524 -0
  109. triton/backends/nvidia/include/cuda.h +24359 -0
  110. triton/backends/nvidia/lib/libdevice.10.bc +0 -0
  111. triton/backends/nvidia/lib/x64/cuda.lib +0 -0
  112. triton/compiler/__init__.py +4 -0
  113. triton/compiler/code_generator.py +1303 -0
  114. triton/compiler/compiler.py +430 -0
  115. triton/compiler/errors.py +51 -0
  116. triton/compiler/make_launcher.py +0 -0
  117. triton/errors.py +5 -0
  118. triton/language/__init__.py +294 -0
  119. triton/language/_utils.py +21 -0
  120. triton/language/core.py +2694 -0
  121. triton/language/extra/__init__.py +26 -0
  122. triton/language/extra/cuda/__init__.py +13 -0
  123. triton/language/extra/cuda/_experimental_tma.py +108 -0
  124. triton/language/extra/cuda/libdevice.py +1629 -0
  125. triton/language/extra/cuda/utils.py +109 -0
  126. triton/language/extra/hip/__init__.py +3 -0
  127. triton/language/extra/hip/libdevice.py +475 -0
  128. triton/language/extra/libdevice.py +786 -0
  129. triton/language/math.py +250 -0
  130. triton/language/random.py +207 -0
  131. triton/language/semantic.py +1796 -0
  132. triton/language/standard.py +452 -0
  133. triton/runtime/__init__.py +23 -0
  134. triton/runtime/autotuner.py +408 -0
  135. triton/runtime/build.py +111 -0
  136. triton/runtime/cache.py +295 -0
  137. triton/runtime/driver.py +60 -0
  138. triton/runtime/errors.py +26 -0
  139. triton/runtime/interpreter.py +1235 -0
  140. triton/runtime/jit.py +951 -0
  141. triton/testing.py +511 -0
  142. triton/tools/__init__.py +0 -0
  143. triton/tools/build_extern.py +365 -0
  144. triton/tools/compile.c +67 -0
  145. triton/tools/compile.h +14 -0
  146. triton/tools/compile.py +155 -0
  147. triton/tools/disasm.py +144 -0
  148. triton/tools/experimental_descriptor.py +32 -0
  149. triton/tools/link.py +322 -0
  150. triton/windows_utils.py +375 -0
  151. triton_windows-3.2.0.post11.dist-info/METADATA +39 -0
  152. triton_windows-3.2.0.post11.dist-info/RECORD +154 -0
  153. triton_windows-3.2.0.post11.dist-info/WHEEL +5 -0
  154. triton_windows-3.2.0.post11.dist-info/top_level.txt +12 -0
@@ -0,0 +1,294 @@
1
+ """isort:skip_file"""
2
+ # Import order is significant here.
3
+
4
+ from . import math
5
+ from . import extra
6
+ from .standard import (
7
+ argmax,
8
+ argmin,
9
+ cdiv,
10
+ cumprod,
11
+ cumsum,
12
+ flip,
13
+ interleave,
14
+ max,
15
+ min,
16
+ ravel,
17
+ sigmoid,
18
+ softmax,
19
+ sort,
20
+ sum,
21
+ swizzle2d,
22
+ xor_sum,
23
+ zeros,
24
+ zeros_like,
25
+ )
26
+ from .core import (
27
+ PropagateNan,
28
+ TRITON_MAX_TENSOR_NUMEL,
29
+ _experimental_descriptor_load,
30
+ _experimental_descriptor_store,
31
+ add,
32
+ advance,
33
+ arange,
34
+ associative_scan,
35
+ assume,
36
+ atomic_add,
37
+ atomic_and,
38
+ atomic_cas,
39
+ atomic_max,
40
+ atomic_min,
41
+ atomic_or,
42
+ atomic_xchg,
43
+ atomic_xor,
44
+ bfloat16,
45
+ block_type,
46
+ broadcast,
47
+ broadcast_to,
48
+ cat,
49
+ cast,
50
+ clamp,
51
+ const,
52
+ constexpr,
53
+ debug_barrier,
54
+ device_assert,
55
+ device_print,
56
+ dot,
57
+ dot_scaled,
58
+ dtype,
59
+ expand_dims,
60
+ float16,
61
+ float32,
62
+ float64,
63
+ float8e4b15,
64
+ float8e4nv,
65
+ float8e4b8,
66
+ float8e5,
67
+ float8e5b16,
68
+ full,
69
+ function_type,
70
+ histogram,
71
+ inline_asm_elementwise,
72
+ int1,
73
+ int16,
74
+ int32,
75
+ int64,
76
+ int8,
77
+ join,
78
+ load,
79
+ make_block_ptr,
80
+ max_constancy,
81
+ max_contiguous,
82
+ maximum,
83
+ minimum,
84
+ multiple_of,
85
+ num_programs,
86
+ permute,
87
+ pi32_t,
88
+ pointer_type,
89
+ nv_tma_desc_type,
90
+ program_id,
91
+ range,
92
+ reduce,
93
+ reshape,
94
+ split,
95
+ static_assert,
96
+ static_print,
97
+ static_range,
98
+ store,
99
+ tensor,
100
+ trans,
101
+ uint16,
102
+ uint32,
103
+ uint64,
104
+ uint8,
105
+ view,
106
+ void,
107
+ where,
108
+ )
109
+ from .math import (umulhi, exp, exp2, fma, log, log2, cos, rsqrt, sin, sqrt, sqrt_rn, abs, fdiv, div_rn, erf, floor,
110
+ ceil)
111
+ from .random import (
112
+ pair_uniform_to_normal,
113
+ philox,
114
+ philox_impl,
115
+ rand,
116
+ rand4x,
117
+ randint,
118
+ randint4x,
119
+ randn,
120
+ randn4x,
121
+ uint_to_uniform_float,
122
+ )
123
+
124
+ __all__ = [
125
+ "PropagateNan",
126
+ "TRITON_MAX_TENSOR_NUMEL",
127
+ "_experimental_descriptor_load",
128
+ "_experimental_descriptor_store",
129
+ "abs",
130
+ "add",
131
+ "advance",
132
+ "arange",
133
+ "argmax",
134
+ "argmin",
135
+ "associative_scan",
136
+ "assume",
137
+ "atomic_add",
138
+ "atomic_and",
139
+ "atomic_cas",
140
+ "atomic_max",
141
+ "atomic_min",
142
+ "atomic_or",
143
+ "atomic_xchg",
144
+ "atomic_xor",
145
+ "bfloat16",
146
+ "block_type",
147
+ "broadcast",
148
+ "broadcast_to",
149
+ "builtin",
150
+ "cat",
151
+ "cast",
152
+ "cdiv",
153
+ "ceil",
154
+ "clamp",
155
+ "const",
156
+ "constexpr",
157
+ "cos",
158
+ "cumprod",
159
+ "cumsum",
160
+ "debug_barrier",
161
+ "device_assert",
162
+ "device_print",
163
+ "div_rn",
164
+ "dot",
165
+ "dot_scaled",
166
+ "dtype",
167
+ "erf",
168
+ "exp",
169
+ "exp2",
170
+ "expand_dims",
171
+ "extra",
172
+ "fdiv",
173
+ "flip",
174
+ "float16",
175
+ "float32",
176
+ "float64",
177
+ "float8e4b15",
178
+ "float8e4nv",
179
+ "float8e4b8",
180
+ "float8e5",
181
+ "float8e5b16",
182
+ "floor",
183
+ "fma",
184
+ "full",
185
+ "function_type",
186
+ "histogram",
187
+ "inline_asm_elementwise",
188
+ "interleave",
189
+ "int1",
190
+ "int16",
191
+ "int32",
192
+ "int64",
193
+ "int8",
194
+ "ir",
195
+ "join",
196
+ "load",
197
+ "log",
198
+ "log2",
199
+ "make_block_ptr",
200
+ "math",
201
+ "max",
202
+ "max_constancy",
203
+ "max_contiguous",
204
+ "maximum",
205
+ "min",
206
+ "minimum",
207
+ "multiple_of",
208
+ "num_programs",
209
+ "pair_uniform_to_normal",
210
+ "permute",
211
+ "philox",
212
+ "philox_impl",
213
+ "pi32_t",
214
+ "pointer_type",
215
+ "nv_tma_desc_type",
216
+ "program_id",
217
+ "rand",
218
+ "rand4x",
219
+ "randint",
220
+ "randint4x",
221
+ "randn",
222
+ "randn4x",
223
+ "range",
224
+ "ravel",
225
+ "reduce",
226
+ "reshape",
227
+ "rsqrt",
228
+ "sigmoid",
229
+ "sin",
230
+ "softmax",
231
+ "sort",
232
+ "split",
233
+ "sqrt",
234
+ "sqrt_rn",
235
+ "static_assert",
236
+ "static_print",
237
+ "static_range",
238
+ "store",
239
+ "sum",
240
+ "swizzle2d",
241
+ "tensor",
242
+ "trans",
243
+ "triton",
244
+ "uint16",
245
+ "uint32",
246
+ "uint64",
247
+ "uint8",
248
+ "uint_to_uniform_float",
249
+ "umulhi",
250
+ "view",
251
+ "void",
252
+ "where",
253
+ "xor_sum",
254
+ "zeros",
255
+ "zeros_like",
256
+ ]
257
+
258
+
259
+ def str_to_ty(name):
260
+ if name[0] == "*":
261
+ name = name[1:]
262
+ const = False
263
+ if name[0] == "k":
264
+ name = name[1:]
265
+ const = True
266
+ ty = str_to_ty(name)
267
+ return pointer_type(element_ty=ty, const=const)
268
+
269
+ if name == "nvTmaDesc":
270
+ return nv_tma_desc_type()
271
+
272
+ tys = {
273
+ "fp8e4nv": float8e4nv,
274
+ "fp8e4b8": float8e4b8,
275
+ "fp8e5": float8e5,
276
+ "fp8e5b16": float8e5b16,
277
+ "fp8e4b15": float8e4b15,
278
+ "fp16": float16,
279
+ "bf16": bfloat16,
280
+ "fp32": float32,
281
+ "fp64": float64,
282
+ "i1": int1,
283
+ "i8": int8,
284
+ "i16": int16,
285
+ "i32": int32,
286
+ "i64": int64,
287
+ "u1": int1,
288
+ "u8": uint8,
289
+ "u16": uint16,
290
+ "u32": uint32,
291
+ "u64": uint64,
292
+ "B": int1,
293
+ }
294
+ return tys[name]
@@ -0,0 +1,21 @@
1
+ from typing import List
2
+
3
+ TRITON_MAX_TENSOR_NUMEL = 1048576
4
+
5
+
6
+ def is_power_of_two(x):
7
+ return (x & (x - 1)) == 0
8
+
9
+
10
+ def validate_block_shape(shape: List[int]):
11
+ numel = 1
12
+ for i, d in enumerate(shape):
13
+ if not isinstance(d, int):
14
+ raise TypeError(f"Shape element {i} must have type `constexpr[int]`, got `constexpr[{type(d)}]")
15
+ if not is_power_of_two(d):
16
+ raise ValueError(f"Shape element {i} must be a power of 2")
17
+ numel *= d
18
+
19
+ if numel > TRITON_MAX_TENSOR_NUMEL:
20
+ raise ValueError(f"numel ({numel}) exceeds triton maximum tensor numel ({TRITON_MAX_TENSOR_NUMEL})")
21
+ return numel