triton-windows 3.5.0.post21__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of triton-windows might be problematic.

Files changed (217)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +82 -0
  3. triton/_filecheck.py +97 -0
  4. triton/_internal_testing.py +255 -0
  5. triton/_utils.py +126 -0
  6. triton/backends/__init__.py +47 -0
  7. triton/backends/amd/__init__.py +0 -0
  8. triton/backends/amd/compiler.py +461 -0
  9. triton/backends/amd/driver.c +283 -0
  10. triton/backends/amd/driver.py +724 -0
  11. triton/backends/amd/lib/asanrtl.bc +0 -0
  12. triton/backends/amd/lib/ockl.bc +0 -0
  13. triton/backends/amd/lib/ocml.bc +0 -0
  14. triton/backends/compiler.py +90 -0
  15. triton/backends/driver.py +66 -0
  16. triton/backends/nvidia/__init__.py +0 -0
  17. triton/backends/nvidia/bin/ptxas.exe +0 -0
  18. triton/backends/nvidia/compiler.py +533 -0
  19. triton/backends/nvidia/driver.c +517 -0
  20. triton/backends/nvidia/driver.py +799 -0
  21. triton/backends/nvidia/include/cuda.h +26280 -0
  22. triton/backends/nvidia/lib/libdevice.10.bc +0 -0
  23. triton/backends/nvidia/lib/x64/cuda.lib +0 -0
  24. triton/compiler/__init__.py +7 -0
  25. triton/compiler/code_generator.py +1614 -0
  26. triton/compiler/compiler.py +509 -0
  27. triton/compiler/errors.py +51 -0
  28. triton/compiler/make_launcher.py +0 -0
  29. triton/errors.py +5 -0
  30. triton/experimental/__init__.py +0 -0
  31. triton/experimental/gluon/__init__.py +5 -0
  32. triton/experimental/gluon/_compiler.py +0 -0
  33. triton/experimental/gluon/_runtime.py +102 -0
  34. triton/experimental/gluon/language/__init__.py +119 -0
  35. triton/experimental/gluon/language/_core.py +490 -0
  36. triton/experimental/gluon/language/_layouts.py +583 -0
  37. triton/experimental/gluon/language/_math.py +20 -0
  38. triton/experimental/gluon/language/_semantic.py +380 -0
  39. triton/experimental/gluon/language/_standard.py +80 -0
  40. triton/experimental/gluon/language/amd/__init__.py +4 -0
  41. triton/experimental/gluon/language/amd/_layouts.py +96 -0
  42. triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
  43. triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
  44. triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
  45. triton/experimental/gluon/language/extra/__init__.py +3 -0
  46. triton/experimental/gluon/language/nvidia/__init__.py +4 -0
  47. triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
  48. triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
  49. triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
  50. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +387 -0
  51. triton/experimental/gluon/language/nvidia/blackwell/tma.py +52 -0
  52. triton/experimental/gluon/language/nvidia/hopper/__init__.py +132 -0
  53. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +34 -0
  54. triton/experimental/gluon/language/nvidia/hopper/tma.py +97 -0
  55. triton/experimental/gluon/nvidia/__init__.py +4 -0
  56. triton/experimental/gluon/nvidia/blackwell.py +3 -0
  57. triton/experimental/gluon/nvidia/hopper.py +45 -0
  58. triton/knobs.py +546 -0
  59. triton/language/__init__.py +342 -0
  60. triton/language/core.py +3405 -0
  61. triton/language/extra/__init__.py +26 -0
  62. triton/language/extra/cuda/__init__.py +16 -0
  63. triton/language/extra/cuda/gdc.py +42 -0
  64. triton/language/extra/cuda/libdevice.py +1629 -0
  65. triton/language/extra/cuda/utils.py +109 -0
  66. triton/language/extra/hip/__init__.py +5 -0
  67. triton/language/extra/hip/libdevice.py +491 -0
  68. triton/language/extra/hip/utils.py +35 -0
  69. triton/language/extra/libdevice.py +790 -0
  70. triton/language/math.py +249 -0
  71. triton/language/random.py +218 -0
  72. triton/language/semantic.py +1939 -0
  73. triton/language/standard.py +534 -0
  74. triton/language/target_info.py +54 -0
  75. triton/runtime/__init__.py +23 -0
  76. triton/runtime/_allocation.py +44 -0
  77. triton/runtime/_async_compile.py +55 -0
  78. triton/runtime/autotuner.py +476 -0
  79. triton/runtime/build.py +168 -0
  80. triton/runtime/cache.py +317 -0
  81. triton/runtime/driver.py +38 -0
  82. triton/runtime/errors.py +36 -0
  83. triton/runtime/interpreter.py +1414 -0
  84. triton/runtime/jit.py +1107 -0
  85. triton/runtime/tcc/include/_mingw.h +168 -0
  86. triton/runtime/tcc/include/assert.h +62 -0
  87. triton/runtime/tcc/include/conio.h +409 -0
  88. triton/runtime/tcc/include/ctype.h +281 -0
  89. triton/runtime/tcc/include/dir.h +31 -0
  90. triton/runtime/tcc/include/direct.h +68 -0
  91. triton/runtime/tcc/include/dirent.h +135 -0
  92. triton/runtime/tcc/include/dos.h +55 -0
  93. triton/runtime/tcc/include/errno.h +75 -0
  94. triton/runtime/tcc/include/excpt.h +123 -0
  95. triton/runtime/tcc/include/fcntl.h +52 -0
  96. triton/runtime/tcc/include/fenv.h +108 -0
  97. triton/runtime/tcc/include/float.h +75 -0
  98. triton/runtime/tcc/include/inttypes.h +297 -0
  99. triton/runtime/tcc/include/io.h +418 -0
  100. triton/runtime/tcc/include/iso646.h +36 -0
  101. triton/runtime/tcc/include/limits.h +116 -0
  102. triton/runtime/tcc/include/locale.h +91 -0
  103. triton/runtime/tcc/include/malloc.h +181 -0
  104. triton/runtime/tcc/include/math.h +497 -0
  105. triton/runtime/tcc/include/mem.h +13 -0
  106. triton/runtime/tcc/include/memory.h +40 -0
  107. triton/runtime/tcc/include/process.h +176 -0
  108. triton/runtime/tcc/include/sec_api/conio_s.h +42 -0
  109. triton/runtime/tcc/include/sec_api/crtdbg_s.h +19 -0
  110. triton/runtime/tcc/include/sec_api/io_s.h +33 -0
  111. triton/runtime/tcc/include/sec_api/mbstring_s.h +52 -0
  112. triton/runtime/tcc/include/sec_api/search_s.h +25 -0
  113. triton/runtime/tcc/include/sec_api/stdio_s.h +145 -0
  114. triton/runtime/tcc/include/sec_api/stdlib_s.h +67 -0
  115. triton/runtime/tcc/include/sec_api/stralign_s.h +30 -0
  116. triton/runtime/tcc/include/sec_api/string_s.h +41 -0
  117. triton/runtime/tcc/include/sec_api/sys/timeb_s.h +34 -0
  118. triton/runtime/tcc/include/sec_api/tchar_s.h +266 -0
  119. triton/runtime/tcc/include/sec_api/time_s.h +61 -0
  120. triton/runtime/tcc/include/sec_api/wchar_s.h +128 -0
  121. triton/runtime/tcc/include/setjmp.h +160 -0
  122. triton/runtime/tcc/include/share.h +28 -0
  123. triton/runtime/tcc/include/signal.h +63 -0
  124. triton/runtime/tcc/include/stdalign.h +16 -0
  125. triton/runtime/tcc/include/stdarg.h +14 -0
  126. triton/runtime/tcc/include/stdatomic.h +171 -0
  127. triton/runtime/tcc/include/stdbool.h +11 -0
  128. triton/runtime/tcc/include/stddef.h +42 -0
  129. triton/runtime/tcc/include/stdint.h +212 -0
  130. triton/runtime/tcc/include/stdio.h +429 -0
  131. triton/runtime/tcc/include/stdlib.h +591 -0
  132. triton/runtime/tcc/include/stdnoreturn.h +7 -0
  133. triton/runtime/tcc/include/string.h +164 -0
  134. triton/runtime/tcc/include/sys/fcntl.h +13 -0
  135. triton/runtime/tcc/include/sys/file.h +14 -0
  136. triton/runtime/tcc/include/sys/locking.h +30 -0
  137. triton/runtime/tcc/include/sys/stat.h +290 -0
  138. triton/runtime/tcc/include/sys/time.h +69 -0
  139. triton/runtime/tcc/include/sys/timeb.h +133 -0
  140. triton/runtime/tcc/include/sys/types.h +123 -0
  141. triton/runtime/tcc/include/sys/unistd.h +14 -0
  142. triton/runtime/tcc/include/sys/utime.h +146 -0
  143. triton/runtime/tcc/include/tcc/tcc_libm.h +618 -0
  144. triton/runtime/tcc/include/tccdefs.h +342 -0
  145. triton/runtime/tcc/include/tcclib.h +80 -0
  146. triton/runtime/tcc/include/tchar.h +1102 -0
  147. triton/runtime/tcc/include/tgmath.h +89 -0
  148. triton/runtime/tcc/include/time.h +287 -0
  149. triton/runtime/tcc/include/uchar.h +33 -0
  150. triton/runtime/tcc/include/unistd.h +1 -0
  151. triton/runtime/tcc/include/vadefs.h +11 -0
  152. triton/runtime/tcc/include/values.h +4 -0
  153. triton/runtime/tcc/include/varargs.h +12 -0
  154. triton/runtime/tcc/include/wchar.h +873 -0
  155. triton/runtime/tcc/include/wctype.h +172 -0
  156. triton/runtime/tcc/include/winapi/basetsd.h +149 -0
  157. triton/runtime/tcc/include/winapi/basetyps.h +85 -0
  158. triton/runtime/tcc/include/winapi/guiddef.h +156 -0
  159. triton/runtime/tcc/include/winapi/poppack.h +8 -0
  160. triton/runtime/tcc/include/winapi/pshpack1.h +8 -0
  161. triton/runtime/tcc/include/winapi/pshpack2.h +8 -0
  162. triton/runtime/tcc/include/winapi/pshpack4.h +8 -0
  163. triton/runtime/tcc/include/winapi/pshpack8.h +8 -0
  164. triton/runtime/tcc/include/winapi/qos.h +72 -0
  165. triton/runtime/tcc/include/winapi/shellapi.h +59 -0
  166. triton/runtime/tcc/include/winapi/winbase.h +2958 -0
  167. triton/runtime/tcc/include/winapi/wincon.h +309 -0
  168. triton/runtime/tcc/include/winapi/windef.h +293 -0
  169. triton/runtime/tcc/include/winapi/windows.h +127 -0
  170. triton/runtime/tcc/include/winapi/winerror.h +3166 -0
  171. triton/runtime/tcc/include/winapi/wingdi.h +4080 -0
  172. triton/runtime/tcc/include/winapi/winnls.h +778 -0
  173. triton/runtime/tcc/include/winapi/winnt.h +5837 -0
  174. triton/runtime/tcc/include/winapi/winreg.h +272 -0
  175. triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
  176. triton/runtime/tcc/include/winapi/winuser.h +5651 -0
  177. triton/runtime/tcc/include/winapi/winver.h +160 -0
  178. triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
  179. triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
  180. triton/runtime/tcc/lib/cuda.def +697 -0
  181. triton/runtime/tcc/lib/gdi32.def +337 -0
  182. triton/runtime/tcc/lib/kernel32.def +770 -0
  183. triton/runtime/tcc/lib/libtcc1.a +0 -0
  184. triton/runtime/tcc/lib/msvcrt.def +1399 -0
  185. triton/runtime/tcc/lib/python3.def +810 -0
  186. triton/runtime/tcc/lib/python310.def +1610 -0
  187. triton/runtime/tcc/lib/python311.def +1633 -0
  188. triton/runtime/tcc/lib/python312.def +1703 -0
  189. triton/runtime/tcc/lib/python313.def +1651 -0
  190. triton/runtime/tcc/lib/python313t.def +1656 -0
  191. triton/runtime/tcc/lib/python314.def +1800 -0
  192. triton/runtime/tcc/lib/python314t.def +1809 -0
  193. triton/runtime/tcc/lib/python39.def +1644 -0
  194. triton/runtime/tcc/lib/python3t.def +905 -0
  195. triton/runtime/tcc/lib/user32.def +658 -0
  196. triton/runtime/tcc/libtcc.dll +0 -0
  197. triton/runtime/tcc/tcc.exe +0 -0
  198. triton/testing.py +543 -0
  199. triton/tools/__init__.py +0 -0
  200. triton/tools/build_extern.py +365 -0
  201. triton/tools/compile.py +210 -0
  202. triton/tools/disasm.py +143 -0
  203. triton/tools/extra/cuda/compile.c +70 -0
  204. triton/tools/extra/cuda/compile.h +14 -0
  205. triton/tools/extra/hip/compile.cpp +66 -0
  206. triton/tools/extra/hip/compile.h +13 -0
  207. triton/tools/link.py +322 -0
  208. triton/tools/mxfp.py +301 -0
  209. triton/tools/ragged_tma.py +92 -0
  210. triton/tools/tensor_descriptor.py +34 -0
  211. triton/windows_utils.py +405 -0
  212. triton_windows-3.5.0.post21.dist-info/METADATA +46 -0
  213. triton_windows-3.5.0.post21.dist-info/RECORD +217 -0
  214. triton_windows-3.5.0.post21.dist-info/WHEEL +5 -0
  215. triton_windows-3.5.0.post21.dist-info/entry_points.txt +3 -0
  216. triton_windows-3.5.0.post21.dist-info/licenses/LICENSE +23 -0
  217. triton_windows-3.5.0.post21.dist-info/top_level.txt +1 -0
triton/tools/mxfp.py ADDED
@@ -0,0 +1,301 @@
+"""
+Helper classes for working with low precision floating point types that
+align with the opencompute (OCP) microscaling (MX) specification.
+* MXFP4Tensor: 4-bit E2M1 floating point data
+* MXScaleTensor: 8-bit E8M0 floating point data
+Reference: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
+"""
+
+import torch
+
+
+class MXFP4Tensor:
+
+    def __init__(self, data=None, size=None, device=None):
+        """
+        Tensor class for working with four bit E2M1 floating point data as defined by the
+        opencompute microscaling specification.
+
+
+        Parameters:
+        - data: A torch tensor of float32 numbers to convert to fp4e2m1 microscaling format.
+        - size: The size of the tensor to create.
+        - device: The device on which to create the tensor.
+        """
+        self.device = device
+        if data is not None:
+            assert isinstance(data, torch.Tensor), "Parameter data must be a torch tensor"
+            self.device = data.device
+            self.data = self._from_float(data)
+        elif size is not None:
+            self.size = size if isinstance(size, tuple) else (size, )
+        else:
+            raise ValueError("Either parameter data or size must be provided")
+
+    def random(self):
+        S = torch.randint(0, 2, size=self.size, dtype=torch.uint8, device=self.device)
+        E = torch.randint(0, 4, size=self.size, dtype=torch.uint8, device=self.device)
+        M = torch.randint(0, 2, size=self.size, dtype=torch.uint8, device=self.device)
+
+        self.data = ((S << 3) | (E << 1) | M).type(torch.uint8)
+        return self
+
+    def to(self, dtype):
+        """
+        Convert fp4e2m1 data to float32.
+
+        Returns:
+        - A torch tensor of type dtype representing the fp4e2m1 data.
+        """
+        assert dtype == torch.float32, "Currently only float32 is supported for fp4e2m1 to float conversion"
+
+        data = self.data
+        S = ((data >> 3) & 0x1).type(dtype)
+        E = ((data >> 1) & 0x3).type(dtype)
+        M = (data & 0x1).type(dtype)
+
+        # The MXF4 E2M1 spec defines 0bS000 as zero
+        value = torch.zeros_like(S)
+        is_zero = (E == 0) & (M == 0)
+        non_zero_mask = ~is_zero
+        if non_zero_mask.any():
+            S_nz = S[non_zero_mask]
+            E_nz = E[non_zero_mask]
+            M_nz = M[non_zero_mask]
+
+            sign = torch.pow(-1, S_nz)
+            # Normal and subnormal handling for the exponent and mantissa
+            exponent = torch.where(E_nz == 0, E_nz, E_nz - 1)
+            mantissa = torch.where(E_nz == 0, M_nz * 0.5, 1.0 + M_nz * 0.5)
+            value_nz = sign * torch.pow(2, exponent) * mantissa
+
+            value[non_zero_mask] = value_nz
+
+        # For zeros, the values must remain zero with the correct sign
+        value[is_zero & (S == 1)] *= -1
+        return value.type(torch.float32)
+
+    def _from_float(self, values):
+        """
+        Convert float32 numbers to mxf4 e2m1 format.
+        * No encodings are reserved for Inf or NaN in mxf4.
+        * Conversion from float supports roundTiesToEven rounding mode.
+        * If a value exceeds the mxf4 representable range after rounding,
+          clamps to the maximum mxf4 magnitude, preserving the sign.
+        * If a value has magnitude less than the minimum subnormal magnitude
+          in mxf4 after rounding, converts to zero.
+
+        Parameters:
+        - values: A torch tensor of float32 numbers to convert to fp4 format.
+        """
+        S = torch.signbit(values).type(torch.uint8)
+        abs_values = torch.abs(values)
+
+        is_zero = (abs_values == 0)
+        is_invalid = torch.isnan(values) | torch.isinf(values)
+
+        # Enumerate all possible E2M1 exponent and mantissa values. We will
+        # use these to compare the distance between float32 and all possible
+        # E2M1 floats to find the nearest E2M1 representable value
+        E_bits = torch.tensor([0, 1, 2, 3], dtype=torch.uint8, device=self.device)
+        M_bits = torch.tensor([0, 1], dtype=torch.uint8, device=self.device)
+
+        candidate_values = []
+        candidate_E = []
+        candidate_M = []
+
+        for E in E_bits:
+            if E == 0:
+                # Subnormals
+                exponent = 0
+                for M in M_bits:
+                    significand = M * 0.5
+                    value = significand * (2**exponent)
+                    candidate_values.append(value)
+                    candidate_E.append(E)
+                    candidate_M.append(M)
+            else:
+                # Normals
+                exponent = E.item() - 1
+                for M in M_bits:
+                    significand = 1.0 + M * 0.5
+                    value = significand * (2**exponent)
+                    candidate_values.append(value)
+                    candidate_E.append(E)
+                    candidate_M.append(M)
+
+        candidates = torch.tensor(candidate_values, dtype=torch.float32, device=self.device)
+        candidate_E = torch.tensor(candidate_E, dtype=torch.uint8, device=self.device)
+        candidate_M = torch.tensor(candidate_M, dtype=torch.uint8, device=self.device)
+
+        abs_values_flat = abs_values.view(-1)
+        N = abs_values_flat.shape[0]
+        abs_values_expanded = abs_values_flat.unsqueeze(1)
+
+        # Clamp invalid values to the max e2m1 representable value
+        max_candidate_value = candidates.max().item()
+        abs_values_flat[is_invalid.view(-1)] = max_candidate_value
+
+        # Compute distance between all abs_values and candidate e2m1 values
+        errors = torch.abs(abs_values_expanded - candidates.unsqueeze(0))
+
+        # To implement roundTiesToEven, we need to break ties by preferring
+        # even mantissas (M == 0). We do so by adding an epsilon bias to shift
+        # the closest candidate with an even mantissa closer to the float value
+        min_errors, _ = torch.min(errors, dim=1, keepdim=True)
+        is_tie = (errors == min_errors)
+        # More than one candidate has the min error for some float value
+        if is_tie.sum() > 1:
+            M_bits_expanded = candidate_M.unsqueeze(0).expand(N, -1)
+            tie_breaker = (M_bits_expanded == 0).type(torch.int32)
+
+            errors = errors - (tie_breaker * 1e-6)
+
+        best_indices = torch.argmin(errors, dim=1)
+
+        E_selected = candidate_E[best_indices]
+        M_selected = candidate_M[best_indices]
+        E = E_selected.view(abs_values.shape)
+        M = M_selected.view(abs_values.shape)
+
+        E[is_zero] = 0
+        M[is_zero] = 0
+
+        return ((S << 3) | (E << 1) | M).type(torch.uint8)
+
+    def to_packed_tensor(self, dim):
+        """
+        Packs two e2m1 elements into a single uint8 along the specified dimension.
+
+        Parameters:
+        - dim: The dimension along which to pack the elements.
+
+        Returns:
+        - A torch tensor of dtype uint8 with two e2m1 elements packed into one uint8.
+        """
+        data = self.data
+        assert 0 <= dim < data.ndim, \
+            "The dimension to pack along is not within the range of tensor dimensions"
+
+        size_along_dim = data.size(dim)
+        new_size_along_dim = (size_along_dim + 1) // 2
+
+        # If the size is odd, we pad the data along dim with zeros at the end
+        if size_along_dim % 2 != 0:
+            pad_sizes = [0] * (2 * data.ndim)
+            pad_index = (data.ndim - dim - 1) * 2 + 1
+            pad_sizes[pad_index] = 1
+            data = torch.nn.functional.pad(data, pad_sizes, mode='constant', value=0)
+
+        new_shape = list(data.shape)
+        new_shape[dim] = new_size_along_dim
+        new_shape.insert(dim + 1, 2)  # packed dimension of length 2
+        data = data.reshape(*new_shape)
+
+        low = data.select(dim + 1, 0)
+        high = data.select(dim + 1, 1)
+        packed = (high << 4) | low
+
+        return packed
+
+    def unpack_packed_tensor(self, packed_tensor, dim, original_shape):
+        """
+        Unpacks a tensor where two fp4 elements are packed into a single uint8.
+
+        Parameters:
+        - packed_tensor: The packed tensor
+        - dim: The dimension along which the tensor was packed.
+        - original_shape: The shape of the original tensor before packing.
+
+        Returns:
+        - A tensor with the original data unpacked into uint8 elements containing one
+          fp4e2m1 element in the least significant bits.
+        """
+        high = (packed_tensor >> 4) & 0xF
+        low = packed_tensor & 0xF
+
+        stacked = torch.stack((low, high), dim=dim + 1)
+
+        # Flatten along dim and dim+1 and then merge
+        shape = list(stacked.shape)
+        new_shape = shape[:dim] + [shape[dim] * 2] + shape[dim + 2:]
+        data = stacked.reshape(*new_shape)
+
+        # Remove any padding
+        if original_shape[dim] % 2 != 0:
+            indices = [slice(None)] * data.ndim
+            indices[dim] = slice(0, original_shape[dim])
+            data = data[tuple(indices)]
+
+        return data.type(torch.uint8)
+
+
+class MXScaleTensor:
+
+    def __init__(self, data=None, size=None, device=None):
+        """
+        Tensor class for working with microscaling E8M0 block scale factors.
+
+        Parameters:
+        - data: A torch tensor of float32 numbers to convert to fp8e8m0 microscaling format.
+        - size: The size of the tensor to create.
+        - device: The device on which to create the tensor.
+        """
+        self.device = device
+        if data is not None:
+            assert isinstance(data, torch.Tensor), "Parameter data must be a torch tensor"
+            self.device = data.device
+            self.data = self._from_float(data)
+        elif size is not None:
+            self.size = size if isinstance(size, tuple) else (size, )
+        else:
+            raise ValueError("Either parameter data or size must be provided")
+
+    def random(self, low=None, high=None):
+        """
+        Generate random E8M0 data within a specified range.
+        * Excludes the NaN encoding (255).
+        """
+        bias = 127
+
+        min_exponent = 0 if low is None else max(0, int(torch.log2(torch.tensor(low))) + bias)
+        max_exponent = 254 if high is None else min(254, max(0, int(torch.log2(torch.tensor(high))) + bias))
+        assert min_exponent <= max_exponent, "Low must be less than or equal to high"
+
+        E = torch.randint(min_exponent, max_exponent + 1, size=self.size, dtype=torch.uint8, device=self.device)
+        self.data = E
+        return self
+
+    def to(self, dtype):
+        assert dtype == torch.float32, "Currently only float32 is supported for f8e8m0 to float conversion"
+        data = self.data.type(dtype)
+        is_nan = (data == 255)
+        e_biased = data.clone()
+        e_biased[is_nan] = 0
+        e = e_biased - 127
+        value = torch.pow(2.0, e)
+        value[is_nan] = torch.nan
+        return value.type(dtype)
+
+    def _from_float(self, values):
+        """
+        Convert float32 numbers to E8M0 format.
+        * Values <= 0, NaNs, and Infs are converted to the NaN encoding (255).
+        * Positive values are converted by computing the floor of log2(value) to get the exponent.
+
+        Parameters:
+        - values: A torch tensor of float32 numbers to convert to E8M0 format.
+        """
+        result = torch.empty_like(values, dtype=torch.uint8, device=self.device)
+
+        is_invalid = torch.isnan(values) | torch.isinf(values) | (values <= 0)
+        result[is_invalid] = 255
+
+        valid_values = values[~is_invalid]
+        e = torch.floor(torch.log2(valid_values))
+        e_biased = e + 127
+        e_biased_int = e_biased.type(torch.int32)
+        e_biased_clamped = torch.clamp(e_biased_int, 0, 254)
+        result[~is_invalid] = e_biased_clamped.type(torch.uint8)
+
+        return result
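The two classes above are host-side reference utilities for MX quantization. The snippet below is an illustrative usage sketch, not part of the packaged file; it only uses PyTorch and the helpers defined above.

# Illustrative sketch only: quantize float32 data to E2M1 / E8M0 and back.
import torch
from triton.tools.mxfp import MXFP4Tensor, MXScaleTensor

x = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, 7.0], dtype=torch.float32)

fp4 = MXFP4Tensor(data=x)                 # quantize to E2M1 (round-ties-to-even)
print(fp4.to(torch.float32))              # 7.0 maps to the max E2M1 magnitude 6.0

packed = fp4.to_packed_tensor(dim=0)      # two E2M1 codes per uint8; odd length is zero-padded
codes = fp4.unpack_packed_tensor(packed, dim=0, original_shape=x.shape)

scales = MXScaleTensor(data=torch.tensor([0.25, 1.0, 2.0**20], dtype=torch.float32))
print(scales.to(torch.float32))           # E8M0 scales are exact powers of two

The E2M1 candidate set enumerated in _from_float is {0, 0.5, 1, 1.5, 2, 3, 4, 6} per sign, which is why 7.0 lands on 6.0 in the sketch.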
triton/tools/ragged_tma.py ADDED
@@ -0,0 +1,92 @@
+import triton
+import triton.language as tl
+from triton.tools.tensor_descriptor import TensorDescriptor
+
+# fmt: off
+
+
+def create_ragged_descriptor(T, block_shape, ragged_dim=0):
+    """
+    Given a 2- or 3-dimensional tensor T, this creates a 'ragged descriptor'
+    which behaves like a concatenation (along the first axis) of subarrays
+    of potentially unequal size.
+
+    The load_ragged and store_ragged device functions can be used to read
+    and write from subarrays T[batch_offset : batch_offset + batch_size]
+    with hardware bounds-checking preventing any sort of leakage outside
+    the subarray.
+    """
+
+    block_shape = list(block_shape)
+    tensor_shape = list(T.shape)
+    rank = len(tensor_shape)
+
+    if ragged_dim < 0:
+        ragged_dim += rank
+
+    assert 0 <= ragged_dim < rank - 1, "last dimension cannot be ragged"
+    assert rank <= 3, "read-write ragged descriptors must have at most 3 dimensions"
+
+    assert len(block_shape) == rank, "block shape must have same length as tensor shape"
+
+    max_int = 0x7fff0000
+    billion = 0x40000000  # == 2**30
+
+    assert tensor_shape[ragged_dim] <= billion, "number of rows may not exceed 2**30"
+    tensor_shape[ragged_dim] = billion
+    ragged_stride = T.stride(ragged_dim)
+
+    # we prepend an extra two dimensions and rely on the fact that pointers
+    # have 64-bit wraparound semantics:
+    tma_stride = [2**34 - ragged_stride, ragged_stride] + [T.stride(i) for i in range(rank)]
+    tma_shape = [max_int, max_int] + tensor_shape
+    box_shape = [1, 1] + block_shape
+
+    return TensorDescriptor(T, tma_shape, tma_stride, box_shape)
+
+
+@triton.jit
+def to_ragged_indices(batch_offset, batch_size, row):
+    """
+    Helper function for load_ragged and store_ragged.
+    """
+
+    billion = 0x40000000  # == 2**30
+    x = billion - batch_size + row
+    y = batch_offset + batch_size
+
+    return billion, y, x
+
+
+@triton.jit
+def load_ragged(TMA, batch_offset, batch_size, coords, ragged_dim: tl.constexpr = 0):
+    """
+    Read from a subarray T[batch_offset : batch_offset + batch_size] with
+    hardware bounds-checking, where reading outside the subarray gives zeros.
+
+    Coords should be an appropriately-sized list of integers, just like in
+    TMA.load().
+    """
+
+    tl.static_assert(len(TMA.shape) == len(coords) + 2, "TMA must be a read-write ragged descriptor")
+
+    c0, c1, c2 = to_ragged_indices(batch_offset, batch_size, coords[ragged_dim])
+    data = TMA.load([c0, c1] + coords[:ragged_dim] + [c2] + coords[ragged_dim + 1:])
+    data = tl.reshape(data, data.shape[2:])
+    return data
+
+
+@triton.jit
+def store_ragged(TMA, batch_offset, batch_size, coords, data, ragged_dim: tl.constexpr = 0):
+    """
+    Write to a subarray T[batch_offset : batch_offset + batch_size] with
+    hardware bounds-checking, where writes outside the subarray are masked
+    correctly.
+
+    Coords should be an appropriately-sized list of integers, just like in
+    TMA.store().
+    """
+
+    c0, c1, c2 = to_ragged_indices(batch_offset, batch_size, coords[ragged_dim])
+    data = tl.reshape(data, [1, 1] + data.shape)
+    TMA.store([c0, c1] + coords[:ragged_dim] + [c2] + coords[ragged_dim + 1:], data)
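For orientation, here is a sketch of how these device helpers could be driven end to end. It is not part of the packaged file and rests on several assumptions: an NVIDIA GPU with TMA support (Hopper or newer), a Triton build that accepts host-side TensorDescriptor kernel arguments, and illustrative tile sizes; the kernel name copy_batch_kernel is hypothetical.

# Illustrative sketch only: copy rows [batch_offset, batch_offset + batch_size)
# of X into Y through ragged TMA descriptors. Assumes TMA-capable hardware.
import torch
import triton
import triton.language as tl
from triton.tools.ragged_tma import create_ragged_descriptor, load_ragged, store_ragged


@triton.jit
def copy_batch_kernel(x_desc, y_desc, batch_offset, batch_size, BLOCK_ROWS: tl.constexpr):
    row = tl.program_id(0) * BLOCK_ROWS
    # Rows outside the batch read back as zeros and are masked on the store.
    tile = load_ragged(x_desc, batch_offset, batch_size, [row, 0])
    store_ragged(y_desc, batch_offset, batch_size, [row, 0], tile)


X = torch.randn(1024, 64, device="cuda", dtype=torch.float16)
Y = torch.zeros_like(X)
x_desc = create_ragged_descriptor(X, [64, 64])   # 64x64 tiles, ragged along dim 0
y_desc = create_ragged_descriptor(Y, [64, 64])

batch_offset, batch_size = 100, 300
grid = (triton.cdiv(batch_size, 64), )
copy_batch_kernel[grid](x_desc, y_desc, batch_offset, batch_size, BLOCK_ROWS=64)

The two prepended dimensions and the 2**34 wraparound stride in create_ragged_descriptor are what let the hardware bound both ends of the ragged batch, so the kernel body itself carries no explicit masks.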
triton/tools/tensor_descriptor.py ADDED
@@ -0,0 +1,34 @@
+from dataclasses import dataclass
+from typing import List, Any
+from triton._utils import validate_block_shape
+
+
+@dataclass
+class TensorDescriptor:
+    base: Any
+    shape: List[int]
+    strides: List[int]
+    block_shape: List[int]
+    padding: str = "zero"
+
+    def __post_init__(self):
+        rank = len(self.shape)
+        assert len(self.strides) == rank, f"rank mismatch: {self}"
+        assert len(self.block_shape) == rank, f"rank mismatch: {self}"
+        assert rank > 0, "rank must not be zero"
+        assert rank <= 5, "rank cannot be more than 5"
+        ty = type(self.base)
+        if ty.__name__ not in ("FakeTensor", "FunctionalTensor"):
+            assert self.base.data_ptr() % 16 == 0, "base must be 16-byte aligned"
+        validate_block_shape(self.block_shape)
+        elem_bytes = self.base.dtype.itemsize
+        for stride in self.strides[:-1]:
+            assert (stride * elem_bytes) % 16 == 0, "strides must be 16-byte aligned"
+        assert self.strides[-1] == 1, "Last dimension must be contiguous"
+        assert self.padding == "zero" or self.padding == "nan", "Illegal value for padding"
+        if self.padding == "nan":
+            assert self.base.dtype.is_floating_point, "Padding option `nan` is only supported for floating point tensors"
+
+    @staticmethod
+    def from_tensor(tensor: Any, block_shape: List[int], padding="zero"):
+        return TensorDescriptor(tensor, tensor.shape, tensor.stride(), block_shape, padding)
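To illustrate how the dataclass is meant to be constructed, here is a brief sketch (not part of the packaged file); the matrix shape, dtype, and block shape are arbitrary values chosen to satisfy the __post_init__ checks.

# Illustrative sketch only: build a descriptor for a contiguous float16 matrix.
import torch
from triton.tools.tensor_descriptor import TensorDescriptor

A = torch.randn(512, 256, device="cuda", dtype=torch.float16)

# Shorthand for TensorDescriptor(A, A.shape, A.stride(), [64, 64]).
desc = TensorDescriptor.from_tensor(A, block_shape=[64, 64])

# __post_init__ has already verified: base pointer 16-byte aligned, last stride == 1,
# outer strides 16-byte aligned in bytes (256 elements * 2 bytes here), padding in {"zero", "nan"}.
assert desc.strides[-1] == 1 and desc.padding == "zero"

from_tensor is the convenient path for contiguous tensors; the full constructor is there for cases where shape and strides are supplied explicitly, as create_ragged_descriptor does above.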