torchft-nightly 2026.1.3__cp310-cp310-manylinux_2_24_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchft/__init__.py +34 -0
- torchft/_test/diloco_trainer.py +287 -0
- torchft/_test/managed_work_test.py +320 -0
- torchft/_test_utils.py +111 -0
- torchft/_torchft.cpython-310-x86_64-linux-gnu.so +0 -0
- torchft/_torchft.pyi +116 -0
- torchft/checkpointing/__init__.py +20 -0
- torchft/checkpointing/_rwlock.py +136 -0
- torchft/checkpointing/_serialization.py +39 -0
- torchft/checkpointing/http_transport.py +299 -0
- torchft/checkpointing/http_transport_bench.py +61 -0
- torchft/checkpointing/http_transport_test.py +146 -0
- torchft/checkpointing/pg_transport.py +306 -0
- torchft/checkpointing/pg_transport_bench.py +99 -0
- torchft/checkpointing/pg_transport_test.py +101 -0
- torchft/checkpointing/rwlock_test.py +58 -0
- torchft/checkpointing/transport.py +68 -0
- torchft/checkpointing/transport_test.py +161 -0
- torchft/collectives.py +415 -0
- torchft/collectives_test.py +212 -0
- torchft/coordination.py +39 -0
- torchft/coordination_test.py +29 -0
- torchft/data.py +77 -0
- torchft/data_test.py +39 -0
- torchft/ddp.py +105 -0
- torchft/ddp_test.py +68 -0
- torchft/diloco_regression_test.py +644 -0
- torchft/examples/slurm/README.md +34 -0
- torchft/examples/slurm/punisher.py +95 -0
- torchft/examples/slurm/runner.py +221 -0
- torchft/fsdp_test.py +102 -0
- torchft/futures.py +353 -0
- torchft/futures_test.py +140 -0
- torchft/http.py +13 -0
- torchft/lighthouse_test.py +163 -0
- torchft/local_sgd.py +796 -0
- torchft/local_sgd_integ_test.py +600 -0
- torchft/local_sgd_test.py +324 -0
- torchft/manager.py +1358 -0
- torchft/manager_integ_test.py +653 -0
- torchft/manager_test.py +911 -0
- torchft/multiprocessing.py +38 -0
- torchft/multiprocessing_dummy_context.py +135 -0
- torchft/multiprocessing_test.py +58 -0
- torchft/optim.py +63 -0
- torchft/optim_test.py +50 -0
- torchft/otel.py +134 -0
- torchft/parameter_server.py +195 -0
- torchft/parameter_server_test.py +47 -0
- torchft/process_group.py +2118 -0
- torchft/process_group_test.py +1028 -0
- torchft/quantization.py +686 -0
- torchft/quantization_test.py +131 -0
- torchft/torchx.py +89 -0
- torchft/utils.py +67 -0
- torchft/work.py +26 -0
- torchft_nightly-2026.1.3.dist-info/METADATA +308 -0
- torchft_nightly-2026.1.3.dist-info/RECORD +61 -0
- torchft_nightly-2026.1.3.dist-info/WHEEL +4 -0
- torchft_nightly-2026.1.3.dist-info/entry_points.txt +2 -0
- torchft_nightly-2026.1.3.dist-info/licenses/LICENSE +34 -0
torchft/quantization.py
ADDED
|
@@ -0,0 +1,686 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
# pyre-unsafe
|
|
8
|
+
import torch
|
|
9
|
+
import torch.cuda as cuda
|
|
10
|
+
|
|
11
|
+
# pyre-ignore[21]: Could not find a module corresponding to import `triton`
|
|
12
|
+
import triton
|
|
13
|
+
|
|
14
|
+
# pyre-ignore[21]: Could not find a module corresponding to import `triton.language`
|
|
15
|
+
import triton.language as tl
|
|
16
|
+
|
|
17
|
+
# pyre-ignore[21]: Could not find a module corresponding to import `triton.runtime`
|
|
18
|
+
import triton.runtime as tr
|
|
19
|
+
from torch.distributed import ReduceOp
|
|
20
|
+
|
|
21
|
+
# Torch dtype used for the per-row scale stored in front of every quantized
# row in the output buffer.
SCALE_DTYPE: torch.dtype = torch.float32
# Size in bytes of one scale value; also the granularity to which every
# quantized row is padded in the output buffer.
SCALE_DTYPE_BYTES: int = 4
# Triton-side counterparts of the two constants above, for use inside the
# jitted kernels.
SCALE_TL_DTYPE = tl.float32
SCALE_TL_DTYPE_BYTES = tl.constexpr(4)

# Number of columns handled per loop iteration; passed as BLOCK_SIZE to the
# kernels launched from this file.
BLOCK_SIZE_T: int = 2048
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
# pyre-ignore[11]: Annotation `tl.constexpr` is not defined
def _get_fp8_max() -> tl.constexpr:
    """
    Returns the maximum magnitude used when scaling rows for quantization.

    The value depends on the compute capability reported for the current
    CUDA device: 448.0 for (9, 0) and newer (where `_get_fp8_type` selects
    `tl.float8e4nv`), 127 otherwise (where `_get_fp8_type` selects `tl.int8`).
    """
    has_native_fp8 = cuda.get_device_capability() >= (9, 0)
    return tl.constexpr(448.0 if has_native_fp8 else 127)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def _get_fp8_type() -> tl.constexpr:
    """
    Returns the Triton element type used to store quantized values.

    `tl.float8e4nv` is used when the current CUDA device reports compute
    capability (9, 0) or newer, `tl.int8` otherwise.
    """
    has_native_fp8 = cuda.get_device_capability() >= (9, 0)
    return tl.constexpr(tl.float8e4nv if has_native_fp8 else tl.int8)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@triton.jit
# pyre-ignore[11]: Annotation `tl.tensor` is not defined
def _kernel_calculate_scale(row_max, TL_FP8_MAX: tl.constexpr) -> tl.tensor:
    """
    Computes the multiplicative scale that maps a row with the given absolute
    maximum onto the range of the quantized type.

    Args:
        row_max: Absolute maximum of the row
        TL_FP8_MAX: Maximum magnitude of the quantized type

    Returns:
        TL_FP8_MAX / row_max, or 1.0 when that quotient is infinite (for
        example for an all-zero row).
    """
    scale = TL_FP8_MAX / row_max
    # An infinite scale would poison the quantized values, fall back to 1.0
    return tl.where(scale == float("inf"), 1.0, scale)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
@triton.jit
def _fused_kernel_quantize_into_fp8(
    i_ptrs,
    i_shapes,
    i_strides,
    i_offsets,
    i_dtype,
    o_ptr,
    o_size_bytes_per_rank,
    all_reduce_size,
    BLOCK_SIZE: tl.constexpr,
    TL_FP8_TYPE: tl.constexpr,
    TL_FP8_MAX: tl.constexpr,
):
    """
    Kernel to quantize a set of input tensors into fp8. The input tensors are
    expected to be two-dimensional and the output tensor is expected to be
    one-dimensional. The output tensor is expected to be large enough to hold
    the quantized input and scales for all input tensors. The quantized input
    and scales are interleaved in the output tensor: every row is written as
    one fp32 scale followed by the quantized values of that row, padded so
    that the row occupies a multiple of the scale size in bytes.

    Each program instance handles a single row of a single input tensor:
    program id 0 selects the row and program id 1 selects the input tensor.
    Instances whose row index is past the end of the tensor return
    immediately.

    Args:
        i_ptrs: Pointers to the input tensors to be quantized
        i_shapes: Shapes of the input tensors to be quantized, stored as
            (rows, columns) pairs
        i_strides: Strides of the input tensors to be quantized, stored as
            (row stride, column stride) pairs
        i_offsets: Offsets of the output tensors for each input tensor
        i_dtype: Dummy tensor that carries the dtype of the input tensors
        o_ptr: Pointer to the output tensor for the quantized input and scales
        o_size_bytes_per_rank: Size in bytes in the output tensor per rank
        all_reduce_size: Size of the all-reduce group
        BLOCK_SIZE: Block size for the quantization
        TL_FP8_TYPE: Element type the scaled values are converted to
        TL_FP8_MAX: Maximum magnitude used to compute the row scale
    """
    # Index of the row in the input tensor
    i_row_idx = tl.program_id(0)
    # Index of the input tensor
    i_idx = tl.program_id(1)

    # Number of rows and columns in the input tensor
    i_rows_num = tl.load(i_shapes + i_idx * 2)
    if i_row_idx >= i_rows_num:
        return
    i_cols_num = tl.load(i_shapes + i_idx * 2 + 1)

    # Stride to advance by a single row and column in the input tensor
    # NOTE(review): the original comment here said "assume contiguous
    # tensors", but both strides are loaded and applied explicitly below, so
    # contiguity does not look required - confirm.
    i_row_stride = tl.load(i_strides + i_idx * 2)
    i_col_stride = tl.load(i_strides + i_idx * 2 + 1)

    # Pointer to the input tensor: the address is stored as an integer in
    # i_ptrs and is cast to the pointer type carried by i_dtype
    i_ptr = tl.load(i_ptrs + i_idx).to(i_dtype.dtype)

    # Number of the rows in the input tensor that are processed by a single
    # rank
    i_row_slice_size = tl.cdiv(i_rows_num, all_reduce_size)
    # Index of the row slice in the input tensor
    i_row_slice_idx = i_row_idx // i_row_slice_size

    # Size in bytes of a single input tensor row quantized and written to the
    # output tensor: the quantized values rounded up to a multiple of the
    # scale size, plus one scale
    o_row_size_bytes = (
        tl.cdiv(i_cols_num, SCALE_TL_DTYPE_BYTES) + 1
    ) * SCALE_TL_DTYPE_BYTES

    # Byte offset of the current row in the output tensor: start of the
    # per-rank region of the row slice, plus the offset of this input tensor
    # inside a per-rank region, plus the position of the row inside the slice
    o_offset = (
        o_size_bytes_per_rank * i_row_slice_idx
        + tl.load(i_offsets + i_idx)
        + (i_row_idx % i_row_slice_size) * o_row_size_bytes
    )
    # Pointer to the output tensor where the scale and quantized row will
    # be written
    o_curr_ptr = o_ptr + o_offset
    o_scale_ptr = o_curr_ptr.to(tl.pointer_type(SCALE_TL_DTYPE))
    o_quant_ptr = (o_curr_ptr + SCALE_TL_DTYPE_BYTES).to(tl.pointer_type(TL_FP8_TYPE))  # type: ignore

    # Compute maximum for the current row block by block. Masked-out lanes
    # load 0.0 and therefore do not influence the absolute maximum.
    col_offsets = tl.arange(0, BLOCK_SIZE)
    col_maxes = tl.full((BLOCK_SIZE,), 0, dtype=tl.float32)
    for i_b in range(0, tl.cdiv(i_cols_num, BLOCK_SIZE)):
        i_row_block = tl.load(
            i_ptr + i_row_idx * i_row_stride + col_offsets * i_col_stride,
            mask=col_offsets < i_cols_num,
            other=0.0,
        )
        col_maxes = tl.maximum(tl.abs(i_row_block), col_maxes)
        col_offsets += BLOCK_SIZE

    # Compute and store scale for the current row
    i_row_max = tl.max(col_maxes)
    i_row_scale = _kernel_calculate_scale(i_row_max, TL_FP8_MAX)
    tl.store(o_scale_ptr, i_row_scale)

    # Scale and quantize current row block by block (second pass over the
    # row, now that the scale is known)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    for i_b in range(0, tl.cdiv(i_cols_num, BLOCK_SIZE)):
        i_row_block = tl.load(
            i_ptr + i_row_idx * i_row_stride + col_offsets * i_col_stride,
            mask=col_offsets < i_cols_num,
            other=0.0,
        )
        scaled_row_block = i_row_block * i_row_scale
        # NOTE(review): when TL_FP8_TYPE is an integer type this cast
        # presumably truncates instead of rounding - confirm.
        quantized_row_block = scaled_row_block.to(TL_FP8_TYPE)
        tl.store(
            o_quant_ptr + col_offsets,
            quantized_row_block,
            mask=col_offsets < i_cols_num,
        )
        col_offsets += BLOCK_SIZE
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
@triton.jit
def _fused_kernel_dequantize_from_fp8(
    i_ptrs,
    i_shapes,
    i_strides,
    i_offsets,
    i_dtype,
    o_ptr,
    o_size_bytes_per_rank,
    all_reduce_size,
    BLOCK_SIZE: tl.constexpr,
    TL_FP8_TYPE: tl.constexpr,
) -> None:
    """
    Kernel to dequantize a set of input tensors from fp8. The input tensors
    are expected to be of the same shape as those passed to the quantization.
    The result of the dequantization is stored in the input tensors.

    Each program instance handles a single row of a single input tensor:
    program id 0 selects the row and program id 1 selects the input tensor.
    Instances whose row index is past the end of the tensor return
    immediately.

    Args:
        i_ptrs: Pointers to the input tensors to be dequantized into
        i_shapes: Shapes of the input tensors to be dequantized into, stored
            as (rows, columns) pairs
        i_strides: Strides of the input tensors to be dequantized into,
            stored as (row stride, column stride) pairs
        i_offsets: Offsets of the output tensors for each input tensor
        i_dtype: Dummy tensor that carries the dtype of the input tensors
        o_ptr: Pointer to the tensor that contains output of the quantization
            or local reduction
        o_size_bytes_per_rank: Size in bytes in the output tensor per rank
        all_reduce_size: Size of the all-reduce group
        BLOCK_SIZE: Block size for the dequantization
        TL_FP8_TYPE: Element type of the quantized values stored in the
            output tensor
    """
    # Index of the row in the input tensor
    i_row_idx = tl.program_id(0)
    # Index of the input tensor
    i_idx = tl.program_id(1)

    # Number of rows and columns in the input tensor
    i_rows_num = tl.load(i_shapes + i_idx * 2)
    if i_row_idx >= i_rows_num:
        return
    i_cols_num = tl.load(i_shapes + i_idx * 2 + 1)

    # Stride to advance by a single row and column in the input tensor
    # NOTE(review): the original comment here said "assume contiguous
    # tensors", but both strides are loaded and applied explicitly below, so
    # contiguity does not look required - confirm.
    i_row_stride = tl.load(i_strides + i_idx * 2)
    i_col_stride = tl.load(i_strides + i_idx * 2 + 1)

    # Pointer to the input tensor: the address is stored as an integer in
    # i_ptrs and is cast to the pointer type carried by i_dtype
    i_ptr = tl.load(i_ptrs + i_idx).to(i_dtype.dtype)

    # Number of the rows in the input tensor that are processed by a single
    # rank
    i_row_slice_size = tl.cdiv(i_rows_num, all_reduce_size)
    # Index of the row slice in the input tensor
    i_row_slice_idx = i_row_idx // i_row_slice_size

    # Size in bytes of a single input tensor row quantized and written to the
    # output tensor: the quantized values rounded up to a multiple of the
    # scale size, plus one scale
    o_row_size_bytes = (
        tl.cdiv(i_cols_num, SCALE_TL_DTYPE_BYTES) + 1
    ) * SCALE_TL_DTYPE_BYTES

    # Byte offset of the current row in the output tensor: start of the
    # per-rank region of the row slice, plus the offset of this input tensor
    # inside a per-rank region, plus the position of the row inside the slice
    o_offset = (
        o_size_bytes_per_rank * i_row_slice_idx
        + tl.load(i_offsets + i_idx)
        + (i_row_idx % i_row_slice_size) * o_row_size_bytes
    )
    # Pointer to the output tensor from which the scale and quantized row
    # will be read
    o_curr_ptr = o_ptr + o_offset
    o_scale_ptr = o_curr_ptr.to(tl.pointer_type(SCALE_TL_DTYPE))
    o_quant_ptr = (o_curr_ptr + SCALE_TL_DTYPE_BYTES).to(tl.pointer_type(TL_FP8_TYPE))  # type: ignore

    # Load row scale
    i_row_scale = tl.load(o_scale_ptr)

    # Dequantize and store current row block by block
    col_offsets = tl.arange(0, BLOCK_SIZE)
    for i_b in range(0, tl.cdiv(i_cols_num, BLOCK_SIZE)):
        i_quant_row_block = tl.load(
            o_quant_ptr + col_offsets,
            mask=col_offsets < i_cols_num,
            other=0.0,
        )

        # Convert to the element type of the input tensors and undo the
        # scaling applied during quantization
        i_dequant_row_block = (
            i_quant_row_block.to(i_dtype.dtype.element_ty) / i_row_scale
        )
        tl.store(
            i_ptr + i_row_idx * i_row_stride + col_offsets * i_col_stride,
            i_dequant_row_block,
            mask=col_offsets < i_cols_num,
        )
        col_offsets += BLOCK_SIZE
|
|
259
|
+
|
|
260
|
+
|
|
261
|
+
@triton.jit
def _fused_kernel_reduce_fp8(
    i_shapes,
    i_offsets,
    o_ptr,
    o_size_bytes_per_rank,
    all_reduce_size,
    all_reduce_rank,
    division_factor,
    BLOCK_SIZE: tl.constexpr,
    TL_FP8_TYPE: tl.constexpr,
    TL_FP8_MAX: tl.constexpr,
) -> None:
    """
    Reduces rows of the output tensor for the given rank. The output tensor
    is expected to be holding quantized rows and scales for all ranks. The
    quantized rows are dequantized, summed, divided by the division factor
    and quantized again. The result is stored in the output tensor for the
    given rank. After the reduction the row corresponding to the current rank
    can be shared with other ranks.

    Each program instance handles one row position of one input tensor:
    program id 0 selects the row position inside a per-rank region and
    program id 1 selects the input tensor. The row is processed in two
    passes, the first one determines the scale of the reduced row and the
    second one writes the re-quantized values.

    Args:
        i_shapes: Shapes of the input tensors to be reduced, used to compute
            the number and length of rows
        i_offsets: Offsets of the output tensors for each input tensor
        o_ptr: Pointer to the tensor that contains output of the quantization
            of all ranks for the rows corresponding to the current rank
        o_size_bytes_per_rank: Size in bytes in the output tensor per rank
        all_reduce_size: Size of the all-reduce group
        all_reduce_rank: Rank in the all-reduce group
        division_factor: Division factor for the reduction result
        BLOCK_SIZE: Block size for the reduction
        TL_FP8_TYPE: Element type of the quantized values stored in the
            output tensor
        TL_FP8_MAX: Maximum magnitude used to compute the row scale
    """
    # Index of the row inside a per-rank region of the output tensor
    i_row_block_idx = tl.program_id(0)
    # Index of the input tensor
    i_idx = tl.program_id(1)

    # Number of rows and columns in the input tensor
    i_rows_num = tl.load(i_shapes + i_idx * 2)
    if i_row_block_idx >= tl.cdiv(i_rows_num, all_reduce_size):
        return
    i_cols_num = tl.load(i_shapes + i_idx * 2 + 1)

    # Size in bytes of a single input tensor row quantized and written to the
    # output tensor
    o_row_size_bytes = (
        tl.cdiv(i_cols_num, SCALE_TL_DTYPE_BYTES) + 1
    ) * SCALE_TL_DTYPE_BYTES

    # Byte offset of the current row relative to the start of a per-rank
    # region of the output tensor
    o_offset = tl.load(i_offsets + i_idx) + i_row_block_idx * o_row_size_bytes

    # NOTE(review): the accumulator is re-created at the top of every loop
    # iteration below, so this initial value looks unused - confirm before
    # removing.
    o_row_block_acc = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    # First pass: compute the absolute maximum of the summed row, which is
    # needed to derive the scaling factor of the reduced row
    o_row_max = 0.0
    for o_b in range(0, tl.cdiv(i_cols_num, BLOCK_SIZE)):
        o_row_block_acc = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
        col_offsets_mask = col_offsets < i_cols_num
        # Load blocks of quantized rows, dequantize and accumulate
        o_row_block_acc = _fused_kernel_accumulate_block(
            o_row_block_acc,
            o_ptr + o_offset,
            all_reduce_size,
            all_reduce_rank,
            o_size_bytes_per_rank,
            col_offsets,
            col_offsets_mask,
            TL_FP8_TYPE,
        )

        # Compute maximum across accumulated blocks
        o_row_block_max = tl.max(tl.abs(o_row_block_acc))
        o_row_max = tl.maximum(o_row_block_max, o_row_max)

        col_offsets += BLOCK_SIZE

    # Compute scaling factor for the reduced row
    o_row_scale = _kernel_calculate_scale(o_row_max / division_factor, TL_FP8_MAX)

    # Location of the current row inside the region that belongs to the
    # current rank; the reduced row is written there
    o_rank_row_ptr = o_ptr + all_reduce_rank * o_size_bytes_per_rank + o_offset
    o_rank_scale_ptr = o_rank_row_ptr.to(tl.pointer_type(SCALE_TL_DTYPE))
    o_rank_quant_ptr = (o_rank_row_ptr + SCALE_TL_DTYPE_BYTES).to(
        tl.pointer_type(TL_FP8_TYPE)  # type: ignore
    )

    col_offsets = tl.arange(0, BLOCK_SIZE)
    # Second pass: reduce the row in blocks and write them out
    for o_b in range(0, tl.cdiv(i_cols_num, BLOCK_SIZE)):
        o_row_block_acc = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
        col_offsets_mask = col_offsets < i_cols_num
        # Load blocks of quantized rows, dequantize and accumulate
        o_row_block_acc = _fused_kernel_accumulate_block(
            o_row_block_acc,
            o_ptr + o_offset,
            all_reduce_size,
            all_reduce_rank,
            o_size_bytes_per_rank,
            col_offsets,
            col_offsets_mask,
            TL_FP8_TYPE,
        )
        o_row_block_acc = o_row_block_acc * o_row_scale / division_factor
        o_quant_row_block_acc = o_row_block_acc.to(TL_FP8_TYPE)
        tl.store(
            o_rank_quant_ptr + col_offsets,
            o_quant_row_block_acc,
            mask=col_offsets_mask,
        )

        col_offsets += BLOCK_SIZE

    # Write reduced row scale. This happens only after all blocks have been
    # processed because the accumulation above still reads the original scale
    # stored at this location for the current rank.
    tl.store(o_rank_scale_ptr, o_row_scale)
|
|
376
|
+
|
|
377
|
+
|
|
378
|
+
@triton.jit
def _fused_kernel_accumulate_block(
    o_row_block_acc,
    o_ptr,
    o_row_num,
    o_row_start,
    o_row_stride,
    col_offsets,
    col_mask,
    TL_FP8_TYPE: tl.constexpr,
) -> tl.tensor:
    """
    Sums up blocks of quantized rows. The blocks are loaded from the output
    tensor, dequantized and accumulated into the row block accumulator.

    Args:
        o_row_block_acc: Row block accumulator
        o_ptr: Pointer to the first of the rows to be accumulated in the
            output tensor
        o_row_num: Number of rows to accumulate
        o_row_start: Start row index in the output tensor, used to ensure that
            accumulation happens in the correct order
        o_row_stride: Stride in bytes to advance by a single row in the
            output tensor
        col_offsets: Column offsets for the block of quantized rows
        col_mask: Column mask for the block of quantized rows, used to prevent
            going out of bounds
        TL_FP8_TYPE: Element type of the quantized values stored in the
            output tensor

    Returns:
        The row block accumulator with the dequantized blocks of all rows
        added to it.
    """
    # Load blocks of quantized rows, dequantize and accumulate
    for o_row_idx in range(o_row_num):
        # Start with the row that corresponds to the current rank and wrap
        # around once the last row has been reached
        o_row_idx_wrapped = (o_row_idx + o_row_start) % o_row_num

        o_row_ptr = o_ptr + o_row_idx_wrapped * o_row_stride

        # Load row scale and block of quantized row; the scale is stored in
        # front of the quantized values
        o_scale_ptr = o_row_ptr.to(tl.pointer_type(tl.float32))
        o_quant_ptr = (o_row_ptr + SCALE_TL_DTYPE_BYTES).to(
            tl.pointer_type(TL_FP8_TYPE)  # type: ignore
        )

        o_row_scale = tl.load(o_scale_ptr)
        # Ensure that we do not divide by zero when reducing "padding" rows
        o_row_scale = tl.where(o_row_scale == 0.0, 1.0, o_row_scale)
        o_row_quant_block = tl.load(
            o_quant_ptr + col_offsets,
            mask=col_mask,
            other=0.0,
        )

        # Dequantize in fp32 and add to the running sum
        o_row_block_acc += o_row_quant_block.to(tl.float32) / o_row_scale

    return o_row_block_acc
|
|
429
|
+
|
|
430
|
+
|
|
431
|
+
def _prepare_quantize_fp8(
    inputs: list[torch.Tensor], all_reduce_group_size: int
) -> tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    int,
    int,
    torch.device,
]:
    """
    Prepares the inputs for the quantization, dequantization and reduction kernels.

    One-dimensional inputs are treated as two-dimensional tensors with a
    single column. The `inputs` list itself is never modified.

    Args:
        inputs: List of input tensors to be quantized, dequantized or reduced.
            All tensors must live on the same device and share the same dtype
            (fp32, fp16 or bf16) and must be one- or two-dimensional.
        all_reduce_group_size: Size of the all-reduce group, must be positive

    Returns:
        d_i_ptrs: Pointers to the input tensors
        d_i_shapes: Shapes of the input tensors as (rows, columns) pairs
        d_i_strides: Strides of the input tensors as (row stride, column
            stride) pairs
        d_i_offsets: Offsets into the output tensor for each rank for each input
            tensor.
        d_i_dtype: Dummy tensor that carries the type of the input tensors
        output_size: Size of the output tensor in bytes including necessary padding
        i_max_row_num: Maximum number of rows in the input tensors, after
            padding the row count to a multiple of the all-reduce group size
        device: Device of the input tensors

    Raises:
        AssertionError: If the inputs are empty, are not on the same device,
            do not share the same supported dtype, are not one- or
            two-dimensional or if the group size is not positive.
    """

    i_num = len(inputs)
    assert i_num > 0, "At least one input tensor is required"
    assert all_reduce_group_size > 0, "All-reduce group size must be positive"
    device = inputs[0].device
    dtype = inputs[0].dtype
    for other in inputs[1:]:
        assert other.device == device, "All inputs must be on the same device"
        assert other.dtype == dtype, "All inputs must be on the same dtype"

    assert dtype in [
        torch.float32,
        torch.float16,
        torch.bfloat16,
    ], "Only fp32, fp16 and bf16 are supported"

    i_ptrs = []
    i_shapes = []
    i_strides = []
    i_offsets = []
    output_size = 0
    i_max_row_num = 0
    output_size_per_rank = 0
    for tensor in inputs:
        # View 1D tensors as a single column. Only the local name is rebound
        # so that the caller's list keeps its original tensors; the view
        # shares storage with the original, so the kernels still read from
        # and write to the caller's data.
        if tensor.dim() == 1:
            tensor = tensor.unsqueeze(1)
        assert tensor.dim() == 2, "Only 2D tensors are supported"

        i_ptrs.append(tensor.data_ptr())
        i_m, i_n = tensor.shape
        i_shapes.append([i_m, i_n])
        i_m_stride, i_n_stride = tensor.stride()
        i_strides.append([i_m_stride, i_n_stride])

        # Pad the number of rows so that they can be split evenly across the
        # ranks of the all-reduce group
        i_m_padded = triton.cdiv(i_m, all_reduce_group_size) * all_reduce_group_size
        i_max_row_num = max(i_max_row_num, i_m_padded)

        # Bytes taken by this tensor in the output: every row holds its
        # quantized values rounded up to a multiple of the scale size plus
        # one scale
        i_bytes_padded = (
            i_m_padded * (triton.cdiv(i_n, SCALE_DTYPE_BYTES) + 1) * SCALE_DTYPE_BYTES
        )
        # Offset of this tensor inside the region of a single rank
        i_offsets.append(output_size_per_rank)
        output_size_per_rank += i_bytes_padded // all_reduce_group_size
        output_size += i_bytes_padded

    # Copy the metadata to the device so that the kernels can index it by
    # the input tensor index
    d_i_ptrs = torch.empty(i_num, dtype=torch.int64, device=device)
    d_i_ptrs.copy_(torch.tensor(i_ptrs), non_blocking=True)

    d_i_shapes = torch.empty(i_num, 2, dtype=torch.int32, device=device)
    d_i_shapes.copy_(torch.tensor(i_shapes, dtype=torch.int32), non_blocking=True)

    d_i_strides = torch.empty(i_num, 2, dtype=torch.int32, device=device)
    d_i_strides.copy_(torch.tensor(i_strides, dtype=torch.int32), non_blocking=True)

    d_i_offsets = torch.empty(i_num, dtype=torch.int32, device=device)
    d_i_offsets.copy_(torch.tensor(i_offsets, dtype=torch.int32), non_blocking=True)

    # Contents are irrelevant, only the dtype of this tensor is used
    d_i_dtype = torch.empty(1, dtype=dtype, device=device)

    return (
        d_i_ptrs,
        d_i_shapes,
        d_i_strides,
        d_i_offsets,
        d_i_dtype,
        output_size,
        i_max_row_num,
        device,
    )
|
|
529
|
+
|
|
530
|
+
|
|
531
|
+
def fused_quantize_into_fp8(
    inputs: list[torch.Tensor], all_reduce_group_size: int
) -> torch.Tensor:
    """
    Quantizes every row of every input tensor into fp8 and packs the result
    into a single newly allocated output tensor.

    Quantized rows and their scales are interleaved in the output tensor.
    Conceptually the output tensor consists of one row per rank in the
    all-reduce group. Each output row contains a subset of the quantized rows
    of all input tensors together with their scales (the rows of an input
    tensor are divided by the all-reduce group size and padded if needed).
    A quantized row is encoded as an fp32 scale followed by the fp8 values
    followed by padding to ensure aligned memory access.

    Args:
        inputs: List of input tensors to be quantized
        all_reduce_group_size: Size of the all-reduce group

    Returns:
        output: One-dimensional uint8 tensor that contains quantized rows and
            scales for all ranks.
    """

    (
        tensor_ptrs,
        tensor_shapes,
        tensor_strides,
        tensor_offsets,
        dtype_carrier,
        total_bytes,
        max_rows,
        device,
    ) = _prepare_quantize_fp8(inputs, all_reduce_group_size)

    # The buffer is allocated in the scale dtype and only then viewed as
    # bytes, so that scales stored via pointer arithmetic stay aligned.
    scale_slots = total_bytes // SCALE_DTYPE_BYTES
    output = torch.zeros(scale_slots, device=device, dtype=SCALE_DTYPE).view(
        torch.uint8
    )

    # One program per (padded) row and input tensor
    launch_grid = (max_rows, len(inputs))
    bytes_per_rank = total_bytes // all_reduce_group_size
    _fused_kernel_quantize_into_fp8[launch_grid](
        tensor_ptrs,
        tensor_shapes,
        tensor_strides,
        tensor_offsets,
        dtype_carrier,
        output,
        bytes_per_rank,
        all_reduce_group_size,
        BLOCK_SIZE=BLOCK_SIZE_T,  # type: ignore
        TL_FP8_TYPE=_get_fp8_type(),
        TL_FP8_MAX=_get_fp8_max(),
    )

    return output
|
|
589
|
+
|
|
590
|
+
|
|
591
|
+
def fused_dequantize_from_fp8(
    inputs: list[torch.Tensor], output: torch.Tensor, all_reduce_group_size: int
) -> None:
    """
    Dequantizes the fp8 rows stored in the output tensor back into the input
    tensors.

    The input tensors are expected to be of the same shape as those passed to
    the quantization and are overwritten with the dequantized values.
    Quantized rows and their scales are interleaved in the output tensor.
    Conceptually the output tensor consists of one row per rank in the
    all-reduce group. Each output row contains a subset of the quantized rows
    of all input tensors together with their scales (the rows of an input
    tensor are divided by the all-reduce group size and padded if needed).

    Args:
        inputs: List of input tensors to be dequantized into
        output: Output tensor that contains quantized rows and scales for all
            ranks.
        all_reduce_group_size: Size of the all-reduce group
    """
    (
        tensor_ptrs,
        tensor_shapes,
        tensor_strides,
        tensor_offsets,
        dtype_carrier,
        total_bytes,
        max_rows,
        device,
    ) = _prepare_quantize_fp8(inputs, all_reduce_group_size)

    # The buffer has to match the layout computed for these inputs
    assert output.shape[0] == total_bytes, "Output size does not match"

    # One program per (padded) row and input tensor
    launch_grid = (max_rows, len(inputs))
    bytes_per_rank = total_bytes // all_reduce_group_size
    _fused_kernel_dequantize_from_fp8[launch_grid](
        tensor_ptrs,
        tensor_shapes,
        tensor_strides,
        tensor_offsets,
        dtype_carrier,
        output,
        bytes_per_rank,
        all_reduce_group_size,
        BLOCK_SIZE=BLOCK_SIZE_T,  # type: ignore
        TL_FP8_TYPE=_get_fp8_type(),
    )
|
|
636
|
+
|
|
637
|
+
|
|
638
|
+
def fused_reduce_fp8(
    inputs: list[torch.Tensor],
    output: torch.Tensor,
    all_reduce_group_size: int,
    all_reduce_rank: int,
    reduce_op: ReduceOp = ReduceOp.SUM,
) -> None:
    """
    Reduces rows of the output tensor for the given rank. The output tensor
    is expected to be holding quantized rows and scales for all ranks. The
    quantized rows are dequantized, summed (and divided by the group size when
    averaging) and quantized again. The result is stored in the output tensor
    for the given rank. After the reduction the row corresponding to the
    current rank can be shared with other ranks.

    Args:
        inputs: List of input tensors to be reduced, only used to derive the
            layout of the output tensor
        output: Output tensor that contains quantized rows and scales for
            all ranks.
        all_reduce_group_size: Size of the all-reduce group
        all_reduce_rank: Rank in the all-reduce group
        reduce_op: Reduction to apply, either `ReduceOp.SUM` or
            `ReduceOp.AVG`

    Raises:
        ValueError: If `reduce_op` is neither `ReduceOp.SUM` nor
            `ReduceOp.AVG`.
        AssertionError: If the size of `output` does not match the layout
            derived from `inputs`.
    """

    # The kernel always sums the rows and divides the sum by this factor.
    # Other reductions cannot be expressed this way, so reject them instead
    # of silently computing an average.
    if reduce_op == ReduceOp.SUM:
        division_factor = 1.0
    elif reduce_op == ReduceOp.AVG:
        division_factor = float(all_reduce_group_size)
    else:
        raise ValueError(
            f"Unsupported reduce_op {reduce_op}, only SUM and AVG are supported"
        )

    # The reduction operates purely on the output tensor, so only the shapes
    # and offsets are needed from the prepared metadata
    (
        _d_i_ptrs,
        d_i_shapes,
        _d_i_strides,
        d_i_offsets,
        _d_i_dtype,
        output_size,
        i_max_row_num,
        _device,
    ) = _prepare_quantize_fp8(inputs, all_reduce_group_size)

    assert output.shape[0] == output_size, "Output size does not match"

    # One program per row owned by the current rank and input tensor
    grid = (i_max_row_num // all_reduce_group_size, len(inputs))
    _fused_kernel_reduce_fp8[grid](
        d_i_shapes,
        d_i_offsets,
        output,
        output_size // all_reduce_group_size,
        all_reduce_group_size,
        all_reduce_rank,
        division_factor,
        BLOCK_SIZE=BLOCK_SIZE_T,  # type: ignore
        TL_FP8_TYPE=_get_fp8_type(),
        TL_FP8_MAX=_get_fp8_max(),
    )
|