torchft-nightly 2026.1.3__cp310-cp310-manylinux_2_24_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. torchft/__init__.py +34 -0
  2. torchft/_test/diloco_trainer.py +287 -0
  3. torchft/_test/managed_work_test.py +320 -0
  4. torchft/_test_utils.py +111 -0
  5. torchft/_torchft.cpython-310-x86_64-linux-gnu.so +0 -0
  6. torchft/_torchft.pyi +116 -0
  7. torchft/checkpointing/__init__.py +20 -0
  8. torchft/checkpointing/_rwlock.py +136 -0
  9. torchft/checkpointing/_serialization.py +39 -0
  10. torchft/checkpointing/http_transport.py +299 -0
  11. torchft/checkpointing/http_transport_bench.py +61 -0
  12. torchft/checkpointing/http_transport_test.py +146 -0
  13. torchft/checkpointing/pg_transport.py +306 -0
  14. torchft/checkpointing/pg_transport_bench.py +99 -0
  15. torchft/checkpointing/pg_transport_test.py +101 -0
  16. torchft/checkpointing/rwlock_test.py +58 -0
  17. torchft/checkpointing/transport.py +68 -0
  18. torchft/checkpointing/transport_test.py +161 -0
  19. torchft/collectives.py +415 -0
  20. torchft/collectives_test.py +212 -0
  21. torchft/coordination.py +39 -0
  22. torchft/coordination_test.py +29 -0
  23. torchft/data.py +77 -0
  24. torchft/data_test.py +39 -0
  25. torchft/ddp.py +105 -0
  26. torchft/ddp_test.py +68 -0
  27. torchft/diloco_regression_test.py +644 -0
  28. torchft/examples/slurm/README.md +34 -0
  29. torchft/examples/slurm/punisher.py +95 -0
  30. torchft/examples/slurm/runner.py +221 -0
  31. torchft/fsdp_test.py +102 -0
  32. torchft/futures.py +353 -0
  33. torchft/futures_test.py +140 -0
  34. torchft/http.py +13 -0
  35. torchft/lighthouse_test.py +163 -0
  36. torchft/local_sgd.py +796 -0
  37. torchft/local_sgd_integ_test.py +600 -0
  38. torchft/local_sgd_test.py +324 -0
  39. torchft/manager.py +1358 -0
  40. torchft/manager_integ_test.py +653 -0
  41. torchft/manager_test.py +911 -0
  42. torchft/multiprocessing.py +38 -0
  43. torchft/multiprocessing_dummy_context.py +135 -0
  44. torchft/multiprocessing_test.py +58 -0
  45. torchft/optim.py +63 -0
  46. torchft/optim_test.py +50 -0
  47. torchft/otel.py +134 -0
  48. torchft/parameter_server.py +195 -0
  49. torchft/parameter_server_test.py +47 -0
  50. torchft/process_group.py +2118 -0
  51. torchft/process_group_test.py +1028 -0
  52. torchft/quantization.py +686 -0
  53. torchft/quantization_test.py +131 -0
  54. torchft/torchx.py +89 -0
  55. torchft/utils.py +67 -0
  56. torchft/work.py +26 -0
  57. torchft_nightly-2026.1.3.dist-info/METADATA +308 -0
  58. torchft_nightly-2026.1.3.dist-info/RECORD +61 -0
  59. torchft_nightly-2026.1.3.dist-info/WHEEL +4 -0
  60. torchft_nightly-2026.1.3.dist-info/entry_points.txt +2 -0
  61. torchft_nightly-2026.1.3.dist-info/licenses/LICENSE +34 -0
torchft/quantization.py
@@ -0,0 +1,686 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-unsafe
8
+ import torch
9
+ import torch.cuda as cuda
10
+
11
+ # pyre-ignore[21]: Could not find a module corresponding to import `triton`
12
+ import triton
13
+
14
+ # pyre-ignore[21]: Could not find a module corresponding to import `triton.language`
15
+ import triton.language as tl
16
+
17
+ # pyre-ignore[21]: Could not find a module corresponding to import `triton.runtime`
18
+ import triton.runtime as tr
19
+ from torch.distributed import ReduceOp
20
+
21
+ SCALE_DTYPE: torch.dtype = torch.float32
22
+ SCALE_DTYPE_BYTES: int = 4
23
+ SCALE_TL_DTYPE = tl.float32
24
+ SCALE_TL_DTYPE_BYTES = tl.constexpr(4)
25
+
26
+ BLOCK_SIZE_T: int = 2048
27
+
28
+
29
+ # pyre-ignore[11]: Annotation `tl.constexpr` is not defined
30
+ def _get_fp8_max() -> tl.constexpr:
31
+ if cuda.get_device_capability() >= (9, 0):
32
+ return tl.constexpr(448.0)
33
+ else:
34
+ return tl.constexpr(127)
35
+
36
+
37
+ def _get_fp8_type() -> tl.constexpr:
38
+ if cuda.get_device_capability() >= (9, 0):
39
+ return tl.constexpr(tl.float8e4nv)
40
+ else:
41
+ return tl.constexpr(tl.int8)
42
+
43
+
44
+ @triton.jit
45
+ # pyre-ignore[11]: Annotation `tl.tensor` is not defined
46
+ def _kernel_calculate_scale(row_max, TL_FP8_MAX: tl.constexpr) -> tl.tensor:
47
+ row_scale = TL_FP8_MAX / row_max
48
+ is_inf = row_scale == float("inf")
49
+ row_scale = tl.where(is_inf, 1.0, row_scale)
50
+ return row_scale
51
+
52
+
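For orientation, the per-row scheme used throughout this file is plain symmetric absmax scaling: each row is scaled so that its largest magnitude maps to the fp8 maximum, and the fp32 scale is stored next to the quantized row so the values can be recovered by division. Below is a minimal pure-Python sketch of that round trip (an editor's illustration, not part of the packaged file; quantize_row and dequantize_row are hypothetical helpers, and FP8_MAX stands in for the 448.0 / 127 constants above).

    # Illustrative only: per-row absmax scaling as computed by _kernel_calculate_scale.
    FP8_MAX = 448.0  # float8e4nv limit on SM90+; the int8 fallback uses 127

    def quantize_row(row, fp8_max=FP8_MAX):
        row_max = max(abs(v) for v in row)
        scale = fp8_max / row_max if row_max > 0 else 1.0  # guard against inf, as in the kernel
        return scale, [v * scale for v in row]  # the kernel then casts these to fp8/int8

    def dequantize_row(scale, quantized):
        return [q / scale for q in quantized]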
53
+ @triton.jit
54
+ def _fused_kernel_quantize_into_fp8(
55
+ i_ptrs,
56
+ i_shapes,
57
+ i_strides,
58
+ i_offsets,
59
+ i_dtype,
60
+ o_ptr,
61
+ o_size_bytes_per_rank,
62
+ all_reduce_size,
63
+ BLOCK_SIZE: tl.constexpr,
64
+ TL_FP8_TYPE: tl.constexpr,
65
+ TL_FP8_MAX: tl.constexpr,
66
+ ):
67
+ """
68
+ Kernel to quantize a set of input tensors into fp8. The input tensors are
69
+ expected to be two-dimensional and the output tensor is expected to be
70
+ one-dimensional. The output tensor is expected to be large enough to hold
71
+ the quantized input and scales for all input tensors. The quantized input
72
+ and scales are interleaved in the output tensor. The quantized input
73
+ is stored as fp8 and the scales are stored as fp32.
74
+
75
+ Args:
76
+ i_ptrs: Pointers to the input tensors to be quantized
77
+ i_shapes: Shapes of the input tensors to be quantized
78
+ i_strides: Strides of the input tensors to be quantized
79
+ i_offsets: Offsets into the output tensor for each input tensor
80
+ i_dtype: Dummy tensor that carries the dtype of the input tensors
81
+ o_ptr: Pointer to the output tensor for the quantized input and scales
82
+ o_size_bytes_per_rank: Size in bytes in the output tensor per rank
83
+ all_reduce_size: Size of the all-reduce group
84
+ BLOCK_SIZE: Block size for the quantization
85
+ TL_FP8_TYPE, TL_FP8_MAX: Triton fp8 dtype and its maximum representable value
86
+ """
87
+ # Index of the row in the input tensor
88
+ i_row_idx = tl.program_id(0)
89
+ # Index of the input tensor
90
+ i_idx = tl.program_id(1)
91
+
92
+ # Number of rows and columns in the input tensor
93
+ i_rows_num = tl.load(i_shapes + i_idx * 2)
94
+ if i_row_idx >= i_rows_num:
95
+ return
96
+ i_cols_num = tl.load(i_shapes + i_idx * 2 + 1)
97
+
98
+ # Stride to advance by a single row and column in the input tensor
99
+ # assume contiguous tensors
100
+ i_row_stride = tl.load(i_strides + i_idx * 2)
101
+ i_col_stride = tl.load(i_strides + i_idx * 2 + 1)
102
+
103
+ # Pointer to the input tensor
104
+ i_ptr = tl.load(i_ptrs + i_idx).to(i_dtype.dtype)
105
+
106
+ # Number of rows in the input tensor that are processed by a single
107
+ # rank
108
+ i_row_slice_size = tl.cdiv(i_rows_num, all_reduce_size)
109
+ # Index of the row slice in the input tensor
110
+ i_row_slice_idx = i_row_idx // i_row_slice_size
111
+
112
+ # Size in bytes of a single input tensor row quantized and written to the
113
+ # output tensor
114
+ o_row_size_bytes = (
115
+ tl.cdiv(i_cols_num, SCALE_TL_DTYPE_BYTES) + 1
116
+ ) * SCALE_TL_DTYPE_BYTES
117
+
118
+ # Byte offset into the output tensor for the current row
119
+ o_offset = (
120
+ o_size_bytes_per_rank * i_row_slice_idx
121
+ + tl.load(i_offsets + i_idx)
122
+ + (i_row_idx % i_row_slice_size) * o_row_size_bytes
123
+ )
124
+ # Pointer to the output tensor where the scale and quantized row will
125
+ # be written
126
+ o_curr_ptr = o_ptr + o_offset
127
+ o_scale_ptr = o_curr_ptr.to(tl.pointer_type(SCALE_TL_DTYPE))
128
+ o_quant_ptr = (o_curr_ptr + SCALE_TL_DTYPE_BYTES).to(tl.pointer_type(TL_FP8_TYPE)) # type: ignore
129
+
130
+ # Compute maximum for the current row block by block
131
+ col_offsets = tl.arange(0, BLOCK_SIZE)
132
+ col_maxes = tl.full((BLOCK_SIZE,), 0, dtype=tl.float32)
133
+ for i_b in range(0, tl.cdiv(i_cols_num, BLOCK_SIZE)):
134
+ i_row_block = tl.load(
135
+ i_ptr + i_row_idx * i_row_stride + col_offsets * i_col_stride,
136
+ mask=col_offsets < i_cols_num,
137
+ other=0.0,
138
+ )
139
+ col_maxes = tl.maximum(tl.abs(i_row_block), col_maxes)
140
+ col_offsets += BLOCK_SIZE
141
+
142
+ # Compute and store scale for the current row
143
+ i_row_max = tl.max(col_maxes)
144
+ i_row_scale = _kernel_calculate_scale(i_row_max, TL_FP8_MAX)
145
+ tl.store(o_scale_ptr, i_row_scale)
146
+
147
+ # Scale and quantize current row block by block
148
+ col_offsets = tl.arange(0, BLOCK_SIZE)
149
+ for i_b in range(0, tl.cdiv(i_cols_num, BLOCK_SIZE)):
150
+ i_row_block = tl.load(
151
+ i_ptr + i_row_idx * i_row_stride + col_offsets * i_col_stride,
152
+ mask=col_offsets < i_cols_num,
153
+ other=0.0,
154
+ )
155
+ scaled_row_block = i_row_block * i_row_scale
156
+ quantized_row_block = scaled_row_block.to(TL_FP8_TYPE)
157
+ tl.store(
158
+ o_quant_ptr + col_offsets,
159
+ quantized_row_block,
160
+ mask=col_offsets < i_cols_num,
161
+ )
162
+ col_offsets += BLOCK_SIZE
163
+
164
+
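A worked example of the packed layout the kernel above writes (an editor's sketch; packed_row_size_bytes is a hypothetical helper, not an API of this module): every quantized row occupies (cdiv(n_cols, 4) + 1) * 4 bytes, i.e. a 4-byte fp32 scale, then n_cols fp8/int8 bytes, then up to 3 bytes of alignment padding.

    # Illustrative only: bytes taken by one packed row, mirroring o_row_size_bytes.
    def packed_row_size_bytes(n_cols: int, scale_bytes: int = 4) -> int:
        cdiv = -(-n_cols // scale_bytes)        # ceiling division, like tl.cdiv
        return (cdiv + 1) * scale_bytes

    assert packed_row_size_bytes(1000) == 1004  # 4-byte scale + 1000 fp8 bytes
    assert packed_row_size_bytes(1001) == 1008  # 4 + 1001 + 3 padding bytes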
165
+ @triton.jit
166
+ def _fused_kernel_dequantize_from_fp8(
167
+ i_ptrs,
168
+ i_shapes,
169
+ i_strides,
170
+ i_offsets,
171
+ i_dtype,
172
+ o_ptr,
173
+ o_size_bytes_per_rank,
174
+ all_reduce_size,
175
+ BLOCK_SIZE: tl.constexpr,
176
+ TL_FP8_TYPE: tl.constexpr,
177
+ ) -> None:
178
+ """
179
+ Kernel to dequantize a set of input tensors from fp8. The input tensors
180
+ are expected to be of the same shape as those passed to the quantization.
181
+ The result of the dequantization is stored in the input tensors.
182
+
183
+ Args:
184
+ i_ptrs: Pointers to the input tensors to be dequantized into
185
+ i_shapes: Shapes of the input tensors to be dequantized into
186
+ i_strides: Strides of the input tensors to be dequantized into
187
+ i_offsets: Offsets into the output tensor for each input tensor
188
+ i_dtype: Dummy tensor that carries the dtype of the input tensors
189
+ o_ptr: Pointer to the tensor that contains output of the quantization
190
+ or local reduction
191
+ o_size_bytes_per_rank: Size in bytes in the output tensor per rank
192
+ all_reduce_size: Size of the all-reduce group
193
+ BLOCK_SIZE: Block size for the dequantization
194
+ """
195
+ # Index of the row in the input tensor
196
+ i_row_idx = tl.program_id(0)
197
+ # Index of the input tensor
198
+ i_idx = tl.program_id(1)
199
+
200
+ # Number of rows and columns in the input tensor
201
+ i_rows_num = tl.load(i_shapes + i_idx * 2)
202
+ if i_row_idx >= i_rows_num:
203
+ return
204
+ i_cols_num = tl.load(i_shapes + i_idx * 2 + 1)
205
+
206
+ # Stride to advance by a single row and column in the input tensor
207
+ # assume contiguous tensors
208
+ i_row_stride = tl.load(i_strides + i_idx * 2)
209
+ i_col_stride = tl.load(i_strides + i_idx * 2 + 1)
210
+
211
+ # Pointer to the input tensor
212
+ i_ptr = tl.load(i_ptrs + i_idx).to(i_dtype.dtype)
213
+
214
+ # Number of rows in the input tensor that are processed by a single
215
+ # rank
216
+ i_row_slice_size = tl.cdiv(i_rows_num, all_reduce_size)
217
+ # Index of the row slice in the input tensor
218
+ i_row_slice_idx = i_row_idx // i_row_slice_size
219
+
220
+ # Size in bytes of a single input tensor row quantized and written to the
221
+ # output tensor
222
+ o_row_size_bytes = (
223
+ tl.cdiv(i_cols_num, SCALE_TL_DTYPE_BYTES) + 1
224
+ ) * SCALE_TL_DTYPE_BYTES
225
+
226
+ # Byte offset into the output tensor for the current row
227
+ o_offset = (
228
+ o_size_bytes_per_rank * i_row_slice_idx
229
+ + tl.load(i_offsets + i_idx)
230
+ + (i_row_idx % i_row_slice_size) * o_row_size_bytes
231
+ )
232
+ # Pointer to the output tensor where the scale and quantized row will be
233
+ # written
234
+ o_curr_ptr = o_ptr + o_offset
235
+ o_scale_ptr = o_curr_ptr.to(tl.pointer_type(SCALE_TL_DTYPE))
236
+ o_quant_ptr = (o_curr_ptr + SCALE_TL_DTYPE_BYTES).to(tl.pointer_type(TL_FP8_TYPE)) # type: ignore
237
+
238
+ # Load row scale
239
+ i_row_scale = tl.load(o_scale_ptr)
240
+
241
+ # Dequantize and store current row block by block
242
+ col_offsets = tl.arange(0, BLOCK_SIZE)
243
+ for i_b in range(0, tl.cdiv(i_cols_num, BLOCK_SIZE)):
244
+ i_quant_row_block = tl.load(
245
+ o_quant_ptr + col_offsets,
246
+ mask=col_offsets < i_cols_num,
247
+ other=0.0,
248
+ )
249
+
250
+ i_dequant_row_block = (
251
+ i_quant_row_block.to(i_dtype.dtype.element_ty) / i_row_scale
252
+ )
253
+ tl.store(
254
+ i_ptr + i_row_idx * i_row_stride + col_offsets * i_col_stride,
255
+ i_dequant_row_block,
256
+ mask=col_offsets < i_cols_num,
257
+ )
258
+ col_offsets += BLOCK_SIZE
259
+
260
+
261
+ @triton.jit
262
+ def _fused_kernel_reduce_fp8(
263
+ i_shapes,
264
+ i_offsets,
265
+ o_ptr,
266
+ o_size_bytes_per_rank,
267
+ all_reduce_size,
268
+ all_reduce_rank,
269
+ division_factor,
270
+ BLOCK_SIZE: tl.constexpr,
271
+ TL_FP8_TYPE: tl.constexpr,
272
+ TL_FP8_MAX: tl.constexpr,
273
+ ) -> None:
274
+ """
275
+ Reduces rows of the output tensor for the given rank. The output tensor
276
+ is expected to hold quantized rows and scales for all ranks. The
277
+ quantized rows are dequantized, reduced and quantized again. The result
278
+ is stored in the output tensor for the given rank. After the reduction,
279
+ the row corresponding to the current rank can be shared with other ranks.
280
+
281
+ Args:
282
+ i_shapes: Shapes of the input tensors to be reduced, used to compute
283
+ the number and length of rows
284
+ i_offsets: Offsets into the output tensor for each input tensor
285
+ o_ptr: Pointer to the tensor that contains output of the quantization
286
+ of all ranks for the rows corresponding to the current rank
287
+ o_size_bytes_per_rank: Size in bytes in the output tensor per rank
288
+ all_reduce_size: Size of the all-reduce group
289
+ all_reduce_rank: Rank in the all-reduce group
290
+ division_factor: Division factor for the reduction result
291
+ BLOCK_SIZE: Block size for the reduction
292
+ TL_FP8_TYPE, TL_FP8_MAX: Triton fp8 dtype and its maximum representable value
293
+ """
294
+ # Index of the row in the input tensor
295
+ i_row_block_idx = tl.program_id(0)
296
+ # Index of the input tensor
297
+ i_idx = tl.program_id(1)
298
+
299
+ # Number of rows and columns in the input tensor
300
+ i_rows_num = tl.load(i_shapes + i_idx * 2)
301
+ if i_row_block_idx >= tl.cdiv(i_rows_num, all_reduce_size):
302
+ return
303
+ i_cols_num = tl.load(i_shapes + i_idx * 2 + 1)
304
+
305
+ # Size in bytes of a single input tensor row quantized and written to the
306
+ # output tensor
307
+ o_row_size_bytes = (
308
+ tl.cdiv(i_cols_num, SCALE_TL_DTYPE_BYTES) + 1
309
+ ) * SCALE_TL_DTYPE_BYTES
310
+
311
+ # Byte offset of this row block within each rank's region of the output tensor
312
+ o_offset = tl.load(i_offsets + i_idx) + i_row_block_idx * o_row_size_bytes
313
+
314
+ o_row_block_acc = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
315
+ col_offsets = tl.arange(0, BLOCK_SIZE)
316
+ # Compute the maximum of the reduced row, used for its scaling factor
317
+ o_row_max = 0.0
318
+ for o_b in range(0, tl.cdiv(i_cols_num, BLOCK_SIZE)):
319
+ o_row_block_acc = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
320
+ col_offsets_mask = col_offsets < i_cols_num
321
+ # Load blocks of quantized rows, dequantize and accumulate
322
+ o_row_block_acc = _fused_kernel_accumulate_block(
323
+ o_row_block_acc,
324
+ o_ptr + o_offset,
325
+ all_reduce_size,
326
+ all_reduce_rank,
327
+ o_size_bytes_per_rank,
328
+ col_offsets,
329
+ col_offsets_mask,
330
+ TL_FP8_TYPE,
331
+ )
332
+
333
+ # Compute maximum across accumulated blocks
334
+ o_row_block_max = tl.max(tl.abs(o_row_block_acc))
335
+ o_row_max = tl.maximum(o_row_block_max, o_row_max)
336
+
337
+ col_offsets += BLOCK_SIZE
338
+
339
+ # Compute scaling factor for the reduced row
340
+ o_row_scale = _kernel_calculate_scale(o_row_max / division_factor, TL_FP8_MAX)
341
+
342
+ o_rank_row_ptr = o_ptr + all_reduce_rank * o_size_bytes_per_rank + o_offset
343
+ o_rank_scale_ptr = o_rank_row_ptr.to(tl.pointer_type(SCALE_TL_DTYPE))
344
+ o_rank_quant_ptr = (o_rank_row_ptr + SCALE_TL_DTYPE_BYTES).to(
345
+ tl.pointer_type(TL_FP8_TYPE) # type: ignore
346
+ )
347
+
348
+ col_offsets = tl.arange(0, BLOCK_SIZE)
349
+ # Reduce the row in blocks and write them out
350
+ for o_b in range(0, tl.cdiv(i_cols_num, BLOCK_SIZE)):
351
+ o_row_block_acc = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
352
+ col_offsets_mask = col_offsets < i_cols_num
353
+ # Load blocks of quantized rows, dequantize and accumulate
354
+ o_row_block_acc = _fused_kernel_accumulate_block(
355
+ o_row_block_acc,
356
+ o_ptr + o_offset,
357
+ all_reduce_size,
358
+ all_reduce_rank,
359
+ o_size_bytes_per_rank,
360
+ col_offsets,
361
+ col_offsets_mask,
362
+ TL_FP8_TYPE,
363
+ )
364
+ o_row_block_acc = o_row_block_acc * o_row_scale / division_factor
365
+ o_quant_row_block_acc = o_row_block_acc.to(TL_FP8_TYPE)
366
+ tl.store(
367
+ o_rank_quant_ptr + col_offsets,
368
+ o_quant_row_block_acc,
369
+ mask=col_offsets_mask,
370
+ )
371
+
372
+ col_offsets += BLOCK_SIZE
373
+
374
+ # Write reduced row scale
375
+ tl.store(o_rank_scale_ptr, o_row_scale)
376
+
377
+
378
+ @triton.jit
379
+ def _fused_kernel_accumulate_block(
380
+ o_row_block_acc,
381
+ o_ptr,
382
+ o_row_num,
383
+ o_row_start,
384
+ o_row_stride,
385
+ col_offsets,
386
+ col_mask,
387
+ TL_FP8_TYPE: tl.constexpr,
388
+ ) -> tl.tensor:
389
+ """
390
+ Sums up blocks of quantized rows. The blocks are loaded from the output
391
+ tensor, dequantized and accumulated into the row block accumulator.
392
+
393
+ Args:
394
+ o_row_block_acc: Row block accumulator
395
+ o_ptr: Pointer to the output tensor
396
+ o_row_num: Number of rows in the output tensor
397
+ o_row_start: Start row index in the output tensor, used to ensure that
398
+ accumulation happens in the correct order
399
+ o_row_stride: Stride to advance by a single row in the output tensor
400
+ col_offsets: Column offsets for the block of quantized rows
401
+ col_mask: Column mask for the block of quantized rows, used to prevent
402
+ going out of bounds
403
+ """
404
+ # Load blocks of quantized rows, dequantize and accumulate
405
+ for o_row_idx in range(o_row_num):
406
+ # Start with the row that corresponds to the current rank
407
+ o_row_idx_wrapped = (o_row_idx + o_row_start) % o_row_num
408
+
409
+ o_row_ptr = o_ptr + o_row_idx_wrapped * o_row_stride
410
+
411
+ # Load row scale and block of quantized row
412
+ o_scale_ptr = o_row_ptr.to(tl.pointer_type(tl.float32))
413
+ o_quant_ptr = (o_row_ptr + SCALE_TL_DTYPE_BYTES).to(
414
+ tl.pointer_type(TL_FP8_TYPE) # type: ignore
415
+ )
416
+
417
+ o_row_scale = tl.load(o_scale_ptr)
418
+ # Ensure that we do not divide by zero when reducing "padding" rows
419
+ o_row_scale = tl.where(o_row_scale == 0.0, 1.0, o_row_scale)
420
+ o_row_quant_block = tl.load(
421
+ o_quant_ptr + col_offsets,
422
+ mask=col_mask,
423
+ other=0.0,
424
+ )
425
+
426
+ o_row_block_acc += o_row_quant_block.to(tl.float32) / o_row_scale
427
+
428
+ return o_row_block_acc
429
+
430
+
431
+ def _prepare_quantize_fp8(
432
+ inputs: list[torch.Tensor], all_reduce_group_size: int
433
+ ) -> tuple[
434
+ torch.Tensor,
435
+ torch.Tensor,
436
+ torch.Tensor,
437
+ torch.Tensor,
438
+ torch.Tensor,
439
+ int,
440
+ int,
441
+ torch.device,
442
+ ]:
443
+ """
444
+ Prepares the inputs for the quantization, dequantization and reduction kernels.
445
+
446
+ Args:
447
+ inputs: List of input tensors to be quantized, dequantized or reduced
448
+ all_reduce_group_size: Size of the all-reduce group
449
+
450
+ Returns:
451
+ d_i_ptrs: Pointers to the input tensors
452
+ d_i_shapes: Shapes of the input tensors
453
+ d_i_strides: Row and column strides of the input tensors
454
+ d_i_offsets: Offsets into each rank's region of the output tensor, one
455
+ per input tensor.
456
+ d_i_dtype: The type of the input tensors
457
+ output_size: Size of the output tensor in bytes including necessary padding
458
+ i_max_row_num: Maximum number of rows in the input tensors
459
+ device: Device of the input tensors
460
+ """
461
+
462
+ i_num = len(inputs)
463
+ assert i_num > 0, "At least one input tensor is required"
464
+ device = inputs[0].device
465
+ dtype = inputs[0].dtype
466
+ for i in range(1, i_num):
467
+ assert (
468
+ inputs[i].device == inputs[i - 1].device
469
+ ), "All inputs must be on the same device"
470
+ assert (
471
+ inputs[i].dtype == inputs[i - 1].dtype
472
+ ), "All inputs must be on the same dtype"
473
+
474
+ assert dtype in [
475
+ torch.float32,
476
+ torch.float16,
477
+ torch.bfloat16,
478
+ ], "Only fp32, fp16 and bf16 are supported"
479
+ i_ptrs = []
480
+ i_shapes = []
481
+ i_strides = []
482
+ i_offsets = []
483
+ output_size = 0
484
+ i_max_row_num = 0
485
+ output_size_per_rank = 0
486
+ for i in range(i_num):
487
+ if len(inputs[i].shape) == 1:
488
+ inputs[i] = inputs[i].unsqueeze(1)
489
+ assert len(inputs[i].shape) == 2, "Only 2D tensors are supported"
490
+ i_ptrs.append(inputs[i].data_ptr())
491
+ i_m, i_n = inputs[i].shape
492
+ i_shapes.append([i_m, i_n])
493
+ i_m_stride, i_n_stride = inputs[i].stride()
494
+ i_strides.append([i_m_stride, i_n_stride])
495
+ i_m_padded = triton.cdiv(i_m, all_reduce_group_size) * all_reduce_group_size
496
+ i_max_row_num = max(i_max_row_num, i_m_padded)
497
+
498
+ i_n_padded = (
499
+ i_m_padded * (triton.cdiv(i_n, SCALE_DTYPE_BYTES) + 1) * SCALE_DTYPE_BYTES
500
+ )
501
+ i_offsets.append(output_size_per_rank)
502
+ output_size_per_rank += i_n_padded // all_reduce_group_size
503
+ output_size += i_n_padded
504
+
505
+ d_i_ptrs = torch.empty(i_num, dtype=torch.int64, device=device)
506
+ d_i_ptrs.copy_(torch.tensor(i_ptrs), non_blocking=True)
507
+
508
+ d_i_shapes = torch.empty(i_num, 2, dtype=torch.int32, device=device)
509
+ d_i_shapes.copy_(torch.tensor(i_shapes, dtype=torch.int32), non_blocking=True)
510
+
511
+ d_i_strides = torch.empty(i_num, 2, dtype=torch.int32, device=device)
512
+ d_i_strides.copy_(torch.tensor(i_strides, dtype=torch.int32), non_blocking=True)
513
+
514
+ d_i_offsets = torch.empty(i_num, dtype=torch.int32, device=device)
515
+ d_i_offsets.copy_(torch.tensor(i_offsets, dtype=torch.int32), non_blocking=True)
516
+
517
+ d_i_dtype = torch.empty(1, dtype=dtype, device=device)
518
+
519
+ return (
520
+ d_i_ptrs,
521
+ d_i_shapes,
522
+ d_i_strides,
523
+ d_i_offsets,
524
+ d_i_dtype,
525
+ output_size,
526
+ i_max_row_num,
527
+ device,
528
+ )
529
+
530
+
531
+ def fused_quantize_into_fp8(
532
+ inputs: list[torch.Tensor], all_reduce_group_size: int
533
+ ) -> torch.Tensor:
534
+ """
535
+ Quantizes a set of input tensors into fp8 where each row of each input
536
+ tensor is quantized individually. The result is stored in the output tensor.
537
+ Note that quantized rows and their scales are interleaved in the output
538
+ tensor. Conceptually the output tensor consists one row per rank in the all
539
+ reduce group. Each output row contains subset (input tensor rows are
540
+ divided by the all group size and padded if needed) of quantized rows from
541
+ the input tensors and their scales. The quantized rows are encoded as fp32
542
+ scale followed by fp8 values followed by padding to ensure aligned memory
543
+ access.
544
+
545
+ Args:
546
+ inputs: List of input tensors to be quantized
547
+ all_reduce_group_size: Size of the all-reduce group
548
+
549
+ Returns:
550
+ output: Output tensor that contains quantized rows and scales for all
551
+ ranks.
552
+ """
553
+
554
+ (
555
+ d_i_ptrs,
556
+ d_i_shapes,
557
+ d_i_strides,
558
+ d_i_offsets,
559
+ d_i_dtype,
560
+ output_size,
561
+ i_max_row_num,
562
+ device,
563
+ ) = _prepare_quantize_fp8(inputs, all_reduce_group_size)
564
+
565
+ # Allocate output tensor in scale dtype so that we can store scales by
566
+ # doing pointer arithmetic and do not get misaligned memory access.
567
+ output = torch.zeros(
568
+ output_size // SCALE_DTYPE_BYTES,
569
+ device=device,
570
+ dtype=SCALE_DTYPE,
571
+ ).view(torch.uint8)
572
+
573
+ grid = (i_max_row_num, len(inputs))
574
+ _fused_kernel_quantize_into_fp8[grid](
575
+ d_i_ptrs,
576
+ d_i_shapes,
577
+ d_i_strides,
578
+ d_i_offsets,
579
+ d_i_dtype,
580
+ output,
581
+ output_size // all_reduce_group_size,
582
+ all_reduce_group_size,
583
+ BLOCK_SIZE=BLOCK_SIZE_T, # type: ignore
584
+ TL_FP8_TYPE=_get_fp8_type(),
585
+ TL_FP8_MAX=_get_fp8_max(),
586
+ )
587
+
588
+ return output
589
+
590
+
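A small usage sketch for the function above (an editor's illustration, not part of the wheel; it assumes a CUDA device with triton available): quantize two bf16 tensors for a 4-rank group and check the size of the packed buffer.

    import torch
    from torchft.quantization import fused_quantize_into_fp8

    inputs = [
        torch.randn(8, 1000, dtype=torch.bfloat16, device="cuda"),
        torch.randn(4, 512, dtype=torch.bfloat16, device="cuda"),
    ]
    packed = fused_quantize_into_fp8(inputs, all_reduce_group_size=4)
    # Per row: a 4-byte fp32 scale plus the fp8 payload, padded to a 4-byte
    # boundary; row counts are padded to a multiple of the group size:
    # 8 * 1004 + 4 * 516 = 10096 bytes in the flat uint8 buffer.
    assert packed.dtype == torch.uint8 and packed.numel() == 10096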
591
+ def fused_dequantize_from_fp8(
592
+ inputs: list[torch.Tensor], output: torch.Tensor, all_reduce_group_size: int
593
+ ) -> None:
594
+ """
595
+ Dequantizes a set of input tensors from fp8 stored in the output tensor.
596
+ The input tensors are expected to be of the same shape as those passed to
597
+ the quantization. The result of the dequantization is stored in the input
598
+ tensors. Note that quantized rows and their scales are interleaved in the
599
+ output tensor. Conceptually the output tensor consists of one row per rank in
600
+ the all-reduce group. Each output row contains a subset (input tensor rows are
601
+ divided by the all-reduce group size and padded if needed) of quantized rows from
602
+ the input tensors and their scales.
603
+
604
+ Args:
605
+ inputs: List of input tensors to be dequantized into
606
+ output: Output tensor that contains quantized rows and scales for all
607
+ ranks.
608
+ all_reduce_group_size: Size of the all-reduce group
609
+ """
610
+ (
611
+ d_i_ptrs,
612
+ d_i_shapes,
613
+ d_i_strides,
614
+ d_i_offsets,
615
+ d_i_dtype,
616
+ output_size,
617
+ i_max_row_num,
618
+ device,
619
+ ) = _prepare_quantize_fp8(inputs, all_reduce_group_size)
620
+
621
+ assert output.shape[0] == output_size, "Output size does not match"
622
+
623
+ grid = (i_max_row_num, len(inputs))
624
+ _fused_kernel_dequantize_from_fp8[grid](
625
+ d_i_ptrs,
626
+ d_i_shapes,
627
+ d_i_strides,
628
+ d_i_offsets,
629
+ d_i_dtype,
630
+ output,
631
+ output_size // all_reduce_group_size,
632
+ all_reduce_group_size,
633
+ BLOCK_SIZE=BLOCK_SIZE_T, # type: ignore
634
+ TL_FP8_TYPE=_get_fp8_type(),
635
+ )
636
+
637
+
638
+ def fused_reduce_fp8(
639
+ inputs: list[torch.Tensor],
640
+ output: torch.Tensor,
641
+ all_reduce_group_size: int,
642
+ all_reduce_rank: int,
643
+ reduce_op: ReduceOp = ReduceOp.SUM,
644
+ ) -> None:
645
+ """
646
+ Reduces rows of the output tensor for the given rank. The output tensor
647
+ is expected to hold quantized rows and scales for all ranks. The
648
+ quantized rows are dequantized, reduced according to reduce_op and
649
+ quantized again. The result is stored in the output tensor for the given
650
+ rank. After the reduction, the row corresponding to the current rank can be shared with other
651
+ ranks.
652
+
653
+ Args:
654
+ inputs: List of input tensors to be reduced
655
+ output: Output tensor that contains quantized rows and scales for
656
+ all ranks.
657
+ all_reduce_group_size: Size of the all-reduce group
658
+ all_reduce_rank: Rank in the all-reduce group
659
+ """
660
+
661
+ (
662
+ d_i_ptrs,
663
+ d_i_shapes,
664
+ d_i_strides,
665
+ d_i_offsets,
666
+ d_i_dtype,
667
+ output_size,
668
+ i_max_row_num,
669
+ device,
670
+ ) = _prepare_quantize_fp8(inputs, all_reduce_group_size)
671
+
672
+ assert output.shape[0] == output_size, "Output size does not match"
673
+
674
+ grid = (i_max_row_num // all_reduce_group_size, len(inputs))
675
+ _fused_kernel_reduce_fp8[grid](
676
+ d_i_shapes,
677
+ d_i_offsets,
678
+ output,
679
+ output_size // all_reduce_group_size,
680
+ all_reduce_group_size,
681
+ all_reduce_rank,
682
+ 1.0 if reduce_op == ReduceOp.SUM else float(all_reduce_group_size),
683
+ BLOCK_SIZE=BLOCK_SIZE_T, # type: ignore
684
+ TL_FP8_TYPE=_get_fp8_type(),
685
+ TL_FP8_MAX=_get_fp8_max(),
686
+ )
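Finally, a minimal end-to-end sketch of how the three public functions compose (an editor's illustration, not part of the wheel; it assumes a single process with a CUDA device and triton installed). It uses group size 1, so no collective exchange is shown; in real use the packed buffer would be exchanged between ranks around the local reduce step, which happens elsewhere in the package.

    import torch
    from torch.distributed import ReduceOp
    from torchft.quantization import (
        fused_dequantize_from_fp8,
        fused_quantize_into_fp8,
        fused_reduce_fp8,
    )

    # Single-rank round trip: quantize, "reduce" locally, dequantize back in place.
    tensors = [torch.randn(4, 256, dtype=torch.float32, device="cuda")]
    reference = [t.clone() for t in tensors]

    packed = fused_quantize_into_fp8(tensors, all_reduce_group_size=1)
    fused_reduce_fp8(tensors, packed, all_reduce_group_size=1, all_reduce_rank=0,
                     reduce_op=ReduceOp.AVG)
    fused_dequantize_from_fp8(tensors, packed, all_reduce_group_size=1)

    # The inputs are overwritten with the dequantized result; the difference from
    # the reference is fp8 (or int8 fallback) rounding error only.
    print((tensors[0] - reference[0]).abs().max())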