triton-windows 3.2.0.post12__cp39-cp39-win_amd64.whl → 3.3.0a0.post12__cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of triton-windows might be problematic.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +3 -3
- triton/_internal_testing.py +59 -4
- triton/_utils.py +35 -0
- triton/backends/amd/compiler.py +121 -74
- triton/backends/amd/driver.py +77 -43
- triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +28 -49
- triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +35 -9
- triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +761 -284
- triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +9 -3
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +1391 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +3 -3
- triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +44 -0
- triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +288 -0
- triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +110 -14
- triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +504 -103
- triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +2 -1
- triton/backends/amd/include/hip/amd_detail/host_defines.h +4 -0
- triton/backends/amd/include/hip/hip_ext.h +4 -2
- triton/backends/amd/include/hip/hip_fp8.h +33 -0
- triton/backends/amd/include/hip/hip_runtime_api.h +375 -33
- triton/backends/amd/include/hip/hip_version.h +3 -3
- triton/backends/amd/include/hip/hiprtc.h +25 -25
- triton/backends/amd/include/hsa/amd_hsa_elf.h +40 -14
- triton/backends/amd/include/hsa/hsa.h +11 -2
- triton/backends/amd/include/hsa/hsa_api_trace.h +30 -17
- triton/backends/amd/include/hsa/hsa_api_trace_version.h +68 -0
- triton/backends/amd/include/hsa/hsa_ext_amd.h +83 -27
- triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +46 -46
- triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +416 -0
- triton/backends/amd/include/roctracer/hip_ostream_ops.h +84 -4
- triton/backends/amd/include/roctracer/hsa_ostream_ops.h +260 -0
- triton/backends/amd/include/roctracer/hsa_prof_str.h +51 -19
- triton/backends/amd/lib/asanrtl.bc +0 -0
- triton/backends/compiler.py +25 -225
- triton/backends/driver.py +7 -2
- triton/backends/nvidia/bin/ptxas.exe +0 -0
- triton/backends/nvidia/compiler.py +135 -90
- triton/backends/nvidia/driver.c +0 -1
- triton/backends/nvidia/driver.py +135 -49
- triton/backends/nvidia/include/cuda.h +2162 -241
- triton/backends/nvidia/lib/x64/cuda.lib +0 -0
- triton/compiler/__init__.py +2 -2
- triton/compiler/code_generator.py +334 -231
- triton/compiler/compiler.py +77 -66
- triton/language/__init__.py +22 -5
- triton/language/core.py +448 -74
- triton/language/extra/cuda/_experimental_tma.py +3 -5
- triton/language/math.py +1 -1
- triton/language/random.py +2 -1
- triton/language/semantic.py +206 -52
- triton/language/standard.py +35 -18
- triton/runtime/_allocation.py +32 -0
- triton/runtime/autotuner.py +27 -32
- triton/runtime/build.py +1 -48
- triton/runtime/cache.py +6 -6
- triton/runtime/errors.py +10 -0
- triton/runtime/interpreter.py +179 -45
- triton/runtime/jit.py +149 -190
- triton/testing.py +39 -11
- triton/tools/compile.py +27 -20
- triton/tools/{compile.c → extra/cuda/compile.c} +1 -0
- triton/tools/mxfp.py +301 -0
- {triton_windows-3.2.0.post12.dist-info → triton_windows-3.3.0a0.post12.dist-info}/METADATA +5 -2
- {triton_windows-3.2.0.post12.dist-info → triton_windows-3.3.0a0.post12.dist-info}/RECORD +68 -59
- {triton_windows-3.2.0.post12.dist-info → triton_windows-3.3.0a0.post12.dist-info}/top_level.txt +2 -0
- /triton/tools/{compile.h → extra/cuda/compile.h} +0 -0
- {triton_windows-3.2.0.post12.dist-info → triton_windows-3.3.0a0.post12.dist-info}/WHEEL +0 -0
triton/tools/compile.py
CHANGED
@@ -8,7 +8,6 @@ from typing import List
 
 import triton
 import triton.backends
-from triton.compiler.code_generator import kernel_suffix
 from triton.backends.nvidia.driver import ty_to_cpp
 
 desc = """
@@ -91,28 +90,29 @@ if __name__ == "__main__":
             pass
         return None
 
-    hints = {i: constexpr(s.split(":")[1]) for i, s in enumerate(signature) if ":" in s}
+    hints = {(i, ): constexpr(s.split(":")[1]) for i, s in enumerate(signature) if ":" in s}
     hints = {k: v for k, v in hints.items() if v is not None}
    constants = {kernel.arg_names[i]: constexpr(s) for i, s in enumerate(signature)}
     constants = {k: v for k, v in constants.items() if v is not None}
-
-
-
-
-
+    for key, value in hints.items():
+        if value == 1:
+            constants[kernel.arg_names[key[0]]] = value
+    signature = {kernel.arg_names[i]: s.split(":")[0] for i, s in enumerate(signature)}
+    for key in constants:
+        signature[key] = 'constexpr'
     const_sig = 'x'.join([str(v) for v in constants.values()])
     doc_string = [f"{k}={v}" for k, v in constants.items()]
     doc_string += [f"num_warps={args.num_warps}", f"num_stages={args.num_stages}"]
-
     # compile ast into cubin
     for h in hints.values():
         assert h in [1, 16], f"Only 1 and 16 are valid hints, got {h}"
-    attrs =
-
-        constants.update({kernel.arg_names[p]: v})
-    src = triton.compiler.ASTSource(fn=kernel, constants=constants, signature=signature, attrs=attrs)
+    attrs = {k: [["tt.divisibility", 16]] for k, v in hints.items() if v == 16}
+    src = triton.compiler.ASTSource(fn=kernel, constexprs=constants, signature=signature, attrs=attrs)
     opts = {"num_warps": args.num_warps, "num_stages": args.num_stages}
     ccinfo = triton.compile(src, options=opts)
+    if ccinfo.metadata.global_scratch_size > 0:
+        raise RuntimeError("AOT compiling kernels with global scratch requirements is not yet implemented")
+
     arg_names = []
     arg_types = []
     arg_names_not_1 = []
@@ -123,23 +123,30 @@ if __name__ == "__main__":
             arg_types.append(signature[arg_name])
             arg_names_not_1.append(arg_name)
             arg_types_not_1.append(signature[arg_name])
-        elif i
+        elif hints.get((i, ), None) == 1:
             arg_names.append(arg_name)
-            arg_types.append(
+            arg_types.append("i32")
 
     # dump C stub code
-    suffix =
+    suffix = ''
+    for i, ty in enumerate(signature.values()):
+        suffix += str(i)
+        if hints.get((i, ), None) == 1:
+            suffix += 'c'
+        if hints.get((i, ), None) == 16:
+            suffix += 'd'
     func_name = '_'.join([out_name, sig_hash, suffix])
-
+    asm = ccinfo.asm["cubin"]  # store binary data once
+    hex_ = str(binascii.hexlify(asm))[2:-1]
     params = {
         "kernel_name": func_name,
         "triton_kernel_name": args.kernel_name,
-        "bin_size": len(
+        "bin_size": len(asm),
         "bin_data": ", ".join([f"0x{x}{y}" for x, y in zip(hex_[::2], hex_[1::2])]),
         "signature": ", ".join([f"{ty_to_cpp(ty)} {name}" for name, ty in zip(arg_names_not_1, arg_types_not_1)]),
         "full_signature": ", ".join([f"{ty_to_cpp(ty)} {name}" for name, ty in zip(arg_names, arg_types)]),
-        "arg_pointers": ", ".join([f"&{arg}" for arg in arg_names_not_1]),
-        "num_args": len(arg_names_not_1),
+        "arg_pointers": ", ".join([f"&{arg}" for arg in arg_names_not_1] + ["&global_scratch"]),
+        "num_args": len(arg_names_not_1) + 1,
         "kernel_docstring": doc_string,
         "shared": ccinfo.metadata.shared,
         "num_warps": args.num_warps,
@@ -150,6 +157,6 @@ if __name__ == "__main__":
         "_placeholder": "",
     }
     for ext in ['h', 'c']:
-        template_path = Path(__file__).parent / f"compile.{ext}"
+        template_path = Path(__file__).parent / "extra" / "cuda" / f"compile.{ext}"
         with out_path.with_suffix(f".{sig_hash}_{suffix}.{ext}").open("w") as fp:
             fp.write(Path(template_path).read_text().format(**params))
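The hint-handling rewrite above is the heart of the 3.3 AOT interface change: hint keys become one-element tuples, equal-to-1 hints fold their arguments into constants, and those arguments' signature slots are rewritten to 'constexpr'. Below is a minimal standalone sketch of that parsing step; the argument names, the example signature strings, and the simplified constexpr helper are illustrative stand-ins, not the package's code.

    # Sketch of the 3.3-style parsing of AOT signature strings "<type>[:<hint>]".
    # arg_names stands in for kernel.arg_names; this constexpr only parses ints.
    arg_names = ["x_ptr", "y_ptr", "n"]
    raw_signature = ["*fp32:16", "*fp32:16", "i32:1"]

    def constexpr(s):
        try:
            return int(s)
        except ValueError:
            return None

    hints = {(i, ): constexpr(s.split(":")[1]) for i, s in enumerate(raw_signature) if ":" in s}
    hints = {k: v for k, v in hints.items() if v is not None}

    constants = {}
    for key, value in hints.items():
        if value == 1:
            constants[arg_names[key[0]]] = value  # equal-to-1 args become constexpr
    signature = {arg_names[i]: s.split(":")[0] for i, s in enumerate(raw_signature)}
    for key in constants:
        signature[key] = 'constexpr'

    print(hints)      # {(0,): 16, (1,): 16, (2,): 1}
    print(signature)  # {'x_ptr': '*fp32', 'y_ptr': '*fp32', 'n': 'constexpr'}

The one-element tuple keys match the hints.get((i, ), None) lookups later in the file, and the 'c'/'d' suffix letters derived from these hints (equal-to-1 and divisible-by-16, respectively) are what land in the generated file name.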
triton/tools/{compile.c → extra/cuda/compile.c}
RENAMED
@@ -60,6 +60,7 @@ CUresult {kernel_name}(CUstream stream, {signature}) {{
   unsigned int gX = {gridX};
   unsigned int gY = {gridY};
   unsigned int gZ = {gridZ};
+  CUdeviceptr global_scratch = 0;
   void *args[{num_args}] = {{ {arg_pointers} }};
   // TODO: shared memory
   if(gX * gY * gZ > 0)
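Together with the arg_pointers/num_args change in compile.py above, every generated stub now declares a null global_scratch device pointer and appends it to the kernel argument array, even though kernels that actually request global scratch are rejected at compile time. A toy rendering of the affected template line, assuming a hypothetical two-argument kernel, shows how the placeholders expand:

    # Toy rendering of the stub template line; the argument names are hypothetical.
    arg_names_not_1 = ["x_ptr", "n"]
    params = {
        "arg_pointers": ", ".join([f"&{a}" for a in arg_names_not_1] + ["&global_scratch"]),
        "num_args": len(arg_names_not_1) + 1,
    }
    print("void *args[{num_args}] = {{ {arg_pointers} }};".format(**params))
    # void *args[3] = { &x_ptr, &n, &global_scratch };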
triton/tools/mxfp.py
ADDED
@@ -0,0 +1,301 @@
+"""
+Helper classes for working with low precision floating point types that
+align with the opencompute (OCP) microscaling (MX) specification.
+* MXFP4Tensor: 4-bit E2M1 floating point data
+* MXScaleTensor: 8-bit E8M0 floating point data
+Reference: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
+"""
+
+import torch
+
+
+class MXFP4Tensor:
+
+    def __init__(self, data=None, size=None, device=None):
+        """
+        Tensor class for working with four bit E2M1 floating point data as defined by the
+        opencompute microscaling specification.
+
+
+        Parameters:
+        - data: A torch tensor of float32 numbers to convert to fp4e2m1 microscaling format.
+        - size: The size of the tensor to create.
+        - device: The device on which to create the tensor.
+        """
+        self.device = device
+        if data is not None:
+            assert isinstance(data, torch.Tensor), "Parameter data must be a torch tensor"
+            self.device = data.device
+            self.data = self._from_float(data)
+        elif size is not None:
+            self.size = size if isinstance(size, tuple) else (size, )
+        else:
+            raise ValueError("Either parameter data or size must be provided")
+
+    def random(self):
+        S = torch.randint(0, 2, size=self.size, dtype=torch.uint8, device=self.device)
+        E = torch.randint(0, 4, size=self.size, dtype=torch.uint8, device=self.device)
+        M = torch.randint(0, 2, size=self.size, dtype=torch.uint8, device=self.device)
+
+        self.data = ((S << 3) | (E << 1) | M).type(torch.uint8)
+        return self
+
+    def to(self, dtype):
+        """
+        Convert fp4e2m1 data to float32.
+
+        Returns:
+        - A torch tensor of type dtype representing the fp4e2m1 data.
+        """
+        assert dtype == torch.float32, "Currently only float32 is supported for fp4e2m1 to float conversion"
+
+        data = self.data
+        S = ((data >> 3) & 0x1).type(dtype)
+        E = ((data >> 1) & 0x3).type(dtype)
+        M = (data & 0x1).type(dtype)
+
+        # The MXF4 E2M1 spec defines 0bS000 as zero
+        value = torch.zeros_like(S)
+        is_zero = (E == 0) & (M == 0)
+        non_zero_mask = ~is_zero
+        if non_zero_mask.any():
+            S_nz = S[non_zero_mask]
+            E_nz = E[non_zero_mask]
+            M_nz = M[non_zero_mask]
+
+            sign = torch.pow(-1, S_nz)
+            # Normal and subnormal handling for the exponent and mantissa
+            exponent = torch.where(E_nz == 0, E_nz, E_nz - 1)
+            mantissa = torch.where(E_nz == 0, M_nz * 0.5, 1.0 + M_nz * 0.5)
+            value_nz = sign * torch.pow(2, exponent) * mantissa
+
+            value[non_zero_mask] = value_nz
+
+        # For zeros, the values must remain zero with the correct sign
+        value[is_zero & (S == 1)] *= -1
+        return value.type(torch.float32)
+
+    def _from_float(self, values):
+        """
+        Convert float32 numbers to mxf4 e2m1 format.
+        * No encodings are reserved for Inf or NaN in mxf4.
+        * Conversion from float supports roundTiesToEven rounding mode.
+        * If a value exceeds the mxf4 representable range after rounding,
+          clamps to the maximum mxf4 magnitude, preserving the sign.
+        * If a value has magnitude less than the minimum subnormal magnitude
+          in mxf4 after rounding, converts to zero.
+
+        Parameters:
+        - values: A torch tensor of float32 numbers to convert to fp4 format.
+        """
+        S = torch.signbit(values).type(torch.uint8)
+        abs_values = torch.abs(values)
+
+        is_zero = (abs_values == 0)
+        is_invalid = torch.isnan(values) | torch.isinf(values)
+
+        # Enumerate all possible E2M1 exponent and mantissa values. We will
+        # use these to compare the distance between float32 and all possible
+        # E2M1 floats to find the nearest E2M1 representable value
+        E_bits = torch.tensor([0, 1, 2, 3], dtype=torch.uint8, device=self.device)
+        M_bits = torch.tensor([0, 1], dtype=torch.uint8, device=self.device)
+
+        candidate_values = []
+        candidate_E = []
+        candidate_M = []
+
+        for E in E_bits:
+            if E == 0:
+                # Subnormals
+                exponent = 0
+                for M in M_bits:
+                    significand = M * 0.5
+                    value = significand * (2**exponent)
+                    candidate_values.append(value)
+                    candidate_E.append(E)
+                    candidate_M.append(M)
+            else:
+                # Normals
+                exponent = E.item() - 1
+                for M in M_bits:
+                    significand = 1.0 + M * 0.5
+                    value = significand * (2**exponent)
+                    candidate_values.append(value)
+                    candidate_E.append(E)
+                    candidate_M.append(M)
+
+        candidates = torch.tensor(candidate_values, dtype=torch.float32, device=self.device)
+        candidate_E = torch.tensor(candidate_E, dtype=torch.uint8, device=self.device)
+        candidate_M = torch.tensor(candidate_M, dtype=torch.uint8, device=self.device)
+
+        abs_values_flat = abs_values.view(-1)
+        N = abs_values_flat.shape[0]
+        abs_values_expanded = abs_values_flat.unsqueeze(1)
+
+        # Clamp invalid values to the max e2m1 representable value
+        max_candidate_value = candidates.max().item()
+        abs_values_flat[is_invalid.view(-1)] = max_candidate_value
+
+        # Compute distance between all abs_values and candidate e2m1 values
+        errors = torch.abs(abs_values_expanded - candidates.unsqueeze(0))
+
+        # To implement roundTiesToEven, we need to break ties by preferring
+        # even mantissas (M == 0). We do so by adding an epsilon bias to shift
+        # the closest candidate with an even mantissa closer to the float value
+        min_errors, _ = torch.min(errors, dim=1, keepdim=True)
+        is_tie = (errors == min_errors)
+        # More than one candidate has the min error for some float value
+        if is_tie.sum() > 1:
+            M_bits_expanded = candidate_M.unsqueeze(0).expand(N, -1)
+            tie_breaker = (M_bits_expanded == 0).type(torch.int32)
+
+            errors = errors - (tie_breaker * 1e-6)
+
+        best_indices = torch.argmin(errors, dim=1)
+
+        E_selected = candidate_E[best_indices]
+        M_selected = candidate_M[best_indices]
+        E = E_selected.view(abs_values.shape)
+        M = M_selected.view(abs_values.shape)
+
+        E[is_zero] = 0
+        M[is_zero] = 0
+
+        return ((S << 3) | (E << 1) | M).type(torch.uint8)
+
+    def to_packed_tensor(self, dim):
+        """
+        Packs two e2m1 elements into a single uint8 along the specified dimension.
+
+        Parameters:
+        - dim: The dimension along which to pack the elements.
+
+        Returns:
+        - A torch tensor of dtype uint8 with two e2m1 elements packed into one uint8.
+        """
+        data = self.data
+        assert 0 <= dim < data.ndim, \
+            "The dimension to pack along is not within the range of tensor dimensions"
+
+        size_along_dim = data.size(dim)
+        new_size_along_dim = (size_along_dim + 1) // 2
+
+        # If the size is odd, we pad the data along dim with zeros at the end
+        if size_along_dim % 2 != 0:
+            pad_sizes = [0] * (2 * data.ndim)
+            pad_index = (data.ndim - dim - 1) * 2 + 1
+            pad_sizes[pad_index] = 1
+            data = torch.nn.functional.pad(data, pad_sizes, mode='constant', value=0)
+
+        new_shape = list(data.shape)
+        new_shape[dim] = new_size_along_dim
+        new_shape.insert(dim + 1, 2)  # packed dimension of length 2
+        data = data.reshape(*new_shape)
+
+        low = data.select(dim + 1, 0)
+        high = data.select(dim + 1, 1)
+        packed = (high << 4) | low
+
+        return packed
+
+    def unpack_packed_tensor(self, packed_tensor, dim, original_shape):
+        """
+        Unpacks a tensor where two fp4 elements are packed into a single uint8.
+
+        Parameters:
+        - packed_tensor: The packed tensor
+        - dim: The dimension along which the tensor was packed.
+        - original_shape: The shape of the original tensor before packing.
+
+        Returns:
+        - A tensor with the original data unpacked into uint8 elements containing one
+          fp4e2m1 element in the least significant bits.
+        """
+        high = (packed_tensor >> 4) & 0xF
+        low = packed_tensor & 0xF
+
+        stacked = torch.stack((low, high), dim=dim + 1)
+
+        # Flatten along dim and dim+1 and then merge
+        shape = list(stacked.shape)
+        new_shape = shape[:dim] + [shape[dim] * 2] + shape[dim + 2:]
+        data = stacked.reshape(*new_shape)
+
+        # Remove any padding
+        if original_shape[dim] % 2 != 0:
+            indices = [slice(None)] * data.ndim
+            indices[dim] = slice(0, original_shape[dim])
+            data = data[tuple(indices)]
+
+        return data.type(torch.uint8)
+
+
+class MXScaleTensor:
+
+    def __init__(self, data=None, size=None, device=None):
+        """
+        Tensor class for working with microscaling E8M0 block scale factors.
+
+        Parameters:
+        - data: A torch tensor of float32 numbers to convert to fp8e8m0 microscaling format.
+        - size: The size of the tensor to create.
+        - device: The device on which to create the tensor.
+        """
+        self.device = device
+        if data is not None:
+            assert isinstance(data, torch.Tensor), "Parameter data must be a torch tensor"
+            self.device = data.device
+            self.data = self._from_float(data)
+        elif size is not None:
+            self.size = size if isinstance(size, tuple) else (size, )
+        else:
+            raise ValueError("Either parameter data or size must be provided")
+
+    def random(self, low=None, high=None):
+        """
+        Generate random E8M0 data within a specified range.
+        * Excludes the NaN encoding (255).
+        """
+        bias = 127
+
+        min_exponent = 0 if low is None else max(0, int(torch.log2(torch.tensor(low))) + bias)
+        max_exponent = 254 if high is None else min(254, max(0, int(torch.log2(torch.tensor(high))) + bias))
+        assert min_exponent <= max_exponent, "Low must be less than or equal to high"
+
+        E = torch.randint(min_exponent, max_exponent + 1, size=self.size, dtype=torch.uint8, device=self.device)
+        self.data = E
+        return self
+
+    def to(self, dtype):
+        assert dtype == torch.float32, "Currently only float32 is supported for f8e8m0 to float conversion"
+        data = self.data.type(dtype)
+        is_nan = (data == 255)
+        e_biased = data.clone()
+        e_biased[is_nan] = 0
+        e = e_biased - 127
+        value = torch.pow(2.0, e)
+        value[is_nan] = torch.nan
+        return value.type(dtype)
+
+    def _from_float(self, values):
+        """
+        Convert float32 numbers to E8M0 format.
+        * Values <= 0, NaNs, and Infs are converted to the NaN encoding (255).
+        * Positive values are converted by computing the floor of log2(value) to get the exponent.
+
+        Parameters:
+        - values: A torch tensor of float32 numbers to convert to E8M0 format.
+        """
+        result = torch.empty_like(values, dtype=torch.uint8, device=self.device)
+
+        is_invalid = torch.isnan(values) | torch.isinf(values) | (values <= 0)
+        result[is_invalid] = 255
+
+        valid_values = values[~is_invalid]
+        e = torch.floor(torch.log2(valid_values))
+        e_biased = e + 127
+        e_biased_int = e_biased.type(torch.int32)
+        e_biased_clamped = torch.clamp(e_biased_int, 0, 254)
+        result[~is_invalid] = e_biased_clamped.type(torch.uint8)
+
+        return result
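Since mxfp.py ships as a plain helper module, a short usage sketch may help orient readers. It assumes only the two classes above plus PyTorch; the import path is simply the file's location in the wheel.

    import torch
    from triton.tools.mxfp import MXFP4Tensor, MXScaleTensor

    # Decode all sixteen E2M1 code points (bit layout 0bSEEM)
    # by setting the raw 4-bit codes directly:
    t = MXFP4Tensor(size=(16, ), device="cpu")
    t.data = torch.arange(16, dtype=torch.uint8)
    print(t.to(torch.float32))
    # expected: 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, then the same values negated

    # Round-trip through the packed layout (two 4-bit codes per uint8):
    x = MXFP4Tensor(size=(4, 8), device="cpu").random()
    packed = x.to_packed_tensor(dim=1)  # shape (4, 4)
    restored = x.unpack_packed_tensor(packed, dim=1, original_shape=(4, 8))
    assert torch.equal(restored, x.data)

    # E8M0 block scales: random powers of two in [2**-2, 2**2]:
    s = MXScaleTensor(size=(8, ), device="cpu").random(low=0.25, high=4.0)
    print(s.to(torch.float32))

The sixteen decoded values follow directly from the class's subnormal/normal split: E == 0 yields ±0 and ±0.5, while E in {1, 2, 3} yields ±(1.0 or 1.5) · 2^(E-1), so the largest representable magnitude is 6.0, which matches the clamping behavior described in _from_float.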
{triton_windows-3.2.0.post12.dist-info → triton_windows-3.3.0a0.post12.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: triton-windows
-Version: 3.2.0.post12
+Version: 3.3.0a0.post12
 Summary: A language and compiler for custom Deep Learning operations
 Home-page: https://github.com/woct0rdho/triton-windows
 Author: Philippe Tillet, Dian Wu
@@ -15,15 +15,17 @@ Classifier: Programming Language :: Python :: 3.10
 Classifier: Programming Language :: Python :: 3.11
 Classifier: Programming Language :: Python :: 3.12
 Classifier: Programming Language :: Python :: 3.13
+Requires-Dist: setuptools>=40.8.0
 Provides-Extra: build
 Requires-Dist: cmake>=3.20; extra == "build"
 Requires-Dist: lit; extra == "build"
 Provides-Extra: tests
 Requires-Dist: autopep8; extra == "tests"
-Requires-Dist: flake8; extra == "tests"
 Requires-Dist: isort; extra == "tests"
 Requires-Dist: numpy; extra == "tests"
 Requires-Dist: pytest; extra == "tests"
+Requires-Dist: pytest-forked; extra == "tests"
+Requires-Dist: pytest-xdist; extra == "tests"
 Requires-Dist: scipy>=1.7.1; extra == "tests"
 Requires-Dist: llnl-hatchet; extra == "tests"
 Provides-Extra: tutorials
@@ -36,4 +38,5 @@ Dynamic: classifier
 Dynamic: home-page
 Dynamic: keywords
 Dynamic: provides-extra
+Dynamic: requires-dist
 Dynamic: summary