fbgemm-gpu-genai-nightly 2025.12.19__cp310-cp310-manylinux_2_28_x86_64.whl

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects changes between package versions as they appear in their public registries.

Potentially problematic release: this version of fbgemm-gpu-genai-nightly might be problematic.

Files changed (127)
  1. fbgemm_gpu/__init__.py +186 -0
  2. fbgemm_gpu/asmjit.so +0 -0
  3. fbgemm_gpu/batched_unary_embeddings_ops.py +87 -0
  4. fbgemm_gpu/config/__init__.py +9 -0
  5. fbgemm_gpu/config/feature_list.py +88 -0
  6. fbgemm_gpu/docs/__init__.py +18 -0
  7. fbgemm_gpu/docs/common.py +9 -0
  8. fbgemm_gpu/docs/examples.py +73 -0
  9. fbgemm_gpu/docs/jagged_tensor_ops.py +259 -0
  10. fbgemm_gpu/docs/merge_pooled_embedding_ops.py +36 -0
  11. fbgemm_gpu/docs/permute_pooled_embedding_ops.py +108 -0
  12. fbgemm_gpu/docs/quantize_ops.py +41 -0
  13. fbgemm_gpu/docs/sparse_ops.py +616 -0
  14. fbgemm_gpu/docs/target.genai.json.py +6 -0
  15. fbgemm_gpu/enums.py +24 -0
  16. fbgemm_gpu/experimental/example/__init__.py +29 -0
  17. fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
  18. fbgemm_gpu/experimental/example/utils.py +20 -0
  19. fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py +15 -0
  20. fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +5654 -0
  21. fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +4422 -0
  22. fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +1192 -0
  23. fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py +232 -0
  24. fbgemm_gpu/experimental/gemm/triton_gemm/utils.py +130 -0
  25. fbgemm_gpu/experimental/gen_ai/__init__.py +56 -0
  26. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py +46 -0
  27. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +333 -0
  28. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +552 -0
  29. fbgemm_gpu/experimental/gen_ai/bench/__init__.py +13 -0
  30. fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py +257 -0
  31. fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py +348 -0
  32. fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +707 -0
  33. fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +3483 -0
  34. fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
  35. fbgemm_gpu/experimental/gen_ai/moe/README.md +15 -0
  36. fbgemm_gpu/experimental/gen_ai/moe/__init__.py +66 -0
  37. fbgemm_gpu/experimental/gen_ai/moe/activation.py +292 -0
  38. fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py +740 -0
  39. fbgemm_gpu/experimental/gen_ai/moe/layers.py +1272 -0
  40. fbgemm_gpu/experimental/gen_ai/moe/shuffling.py +421 -0
  41. fbgemm_gpu/experimental/gen_ai/quantize.py +307 -0
  42. fbgemm_gpu/fbgemm.so +0 -0
  43. fbgemm_gpu/metrics.py +160 -0
  44. fbgemm_gpu/permute_pooled_embedding_modules.py +142 -0
  45. fbgemm_gpu/permute_pooled_embedding_modules_split.py +85 -0
  46. fbgemm_gpu/quantize/__init__.py +43 -0
  47. fbgemm_gpu/quantize/quantize_ops.py +64 -0
  48. fbgemm_gpu/quantize_comm.py +315 -0
  49. fbgemm_gpu/quantize_utils.py +246 -0
  50. fbgemm_gpu/runtime_monitor.py +237 -0
  51. fbgemm_gpu/sll/__init__.py +189 -0
  52. fbgemm_gpu/sll/cpu/__init__.py +80 -0
  53. fbgemm_gpu/sll/cpu/cpu_sll.py +1001 -0
  54. fbgemm_gpu/sll/meta/__init__.py +35 -0
  55. fbgemm_gpu/sll/meta/meta_sll.py +337 -0
  56. fbgemm_gpu/sll/triton/__init__.py +127 -0
  57. fbgemm_gpu/sll/triton/common.py +38 -0
  58. fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py +72 -0
  59. fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +221 -0
  60. fbgemm_gpu/sll/triton/triton_jagged_bmm.py +418 -0
  61. fbgemm_gpu/sll/triton/triton_jagged_bmm_jagged_out.py +553 -0
  62. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +52 -0
  63. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py +175 -0
  64. fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +861 -0
  65. fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +667 -0
  66. fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py +73 -0
  67. fbgemm_gpu/sll/triton/triton_jagged_softmax.py +463 -0
  68. fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +751 -0
  69. fbgemm_gpu/sparse_ops.py +1455 -0
  70. fbgemm_gpu/split_embedding_configs.py +452 -0
  71. fbgemm_gpu/split_embedding_inference_converter.py +175 -0
  72. fbgemm_gpu/split_embedding_optimizer_ops.py +21 -0
  73. fbgemm_gpu/split_embedding_utils.py +29 -0
  74. fbgemm_gpu/split_table_batched_embeddings_ops.py +73 -0
  75. fbgemm_gpu/split_table_batched_embeddings_ops_common.py +484 -0
  76. fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +2042 -0
  77. fbgemm_gpu/split_table_batched_embeddings_ops_training.py +4600 -0
  78. fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +146 -0
  79. fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +26 -0
  80. fbgemm_gpu/tbe/__init__.py +6 -0
  81. fbgemm_gpu/tbe/bench/__init__.py +55 -0
  82. fbgemm_gpu/tbe/bench/bench_config.py +156 -0
  83. fbgemm_gpu/tbe/bench/bench_runs.py +709 -0
  84. fbgemm_gpu/tbe/bench/benchmark_click_interface.py +187 -0
  85. fbgemm_gpu/tbe/bench/eeg_cli.py +137 -0
  86. fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +149 -0
  87. fbgemm_gpu/tbe/bench/eval_compression.py +119 -0
  88. fbgemm_gpu/tbe/bench/reporter.py +35 -0
  89. fbgemm_gpu/tbe/bench/tbe_data_config.py +137 -0
  90. fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +323 -0
  91. fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +289 -0
  92. fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +170 -0
  93. fbgemm_gpu/tbe/bench/utils.py +48 -0
  94. fbgemm_gpu/tbe/cache/__init__.py +11 -0
  95. fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
  96. fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +48 -0
  97. fbgemm_gpu/tbe/ssd/__init__.py +15 -0
  98. fbgemm_gpu/tbe/ssd/common.py +46 -0
  99. fbgemm_gpu/tbe/ssd/inference.py +586 -0
  100. fbgemm_gpu/tbe/ssd/training.py +4908 -0
  101. fbgemm_gpu/tbe/ssd/utils/__init__.py +7 -0
  102. fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +273 -0
  103. fbgemm_gpu/tbe/stats/__init__.py +10 -0
  104. fbgemm_gpu/tbe/stats/bench_params_reporter.py +339 -0
  105. fbgemm_gpu/tbe/utils/__init__.py +13 -0
  106. fbgemm_gpu/tbe/utils/common.py +42 -0
  107. fbgemm_gpu/tbe/utils/offsets.py +65 -0
  108. fbgemm_gpu/tbe/utils/quantize.py +251 -0
  109. fbgemm_gpu/tbe/utils/requests.py +556 -0
  110. fbgemm_gpu/tbe_input_multiplexer.py +108 -0
  111. fbgemm_gpu/triton/__init__.py +22 -0
  112. fbgemm_gpu/triton/common.py +77 -0
  113. fbgemm_gpu/triton/jagged/__init__.py +8 -0
  114. fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +824 -0
  115. fbgemm_gpu/triton/quantize.py +647 -0
  116. fbgemm_gpu/triton/quantize_ref.py +286 -0
  117. fbgemm_gpu/utils/__init__.py +11 -0
  118. fbgemm_gpu/utils/filestore.py +211 -0
  119. fbgemm_gpu/utils/loader.py +36 -0
  120. fbgemm_gpu/utils/torch_library.py +132 -0
  121. fbgemm_gpu/uvm.py +40 -0
  122. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/METADATA +62 -0
  123. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/RECORD +127 -0
  124. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/WHEEL +5 -0
  125. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/top_level.txt +2 -0
  126. list_versions/__init__.py +12 -0
  127. list_versions/cli_run.py +163 -0
@@ -0,0 +1,65 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-strict
+
+ from typing import Callable, Optional
+
+ import numpy as np
+ import torch
+
+ from .common import to_device
+
+
+ # Merged indices with shape (T, B, L) -> (flattened indices with shape
+ # (T * B * L), offsets with shape (T * B + 1))
+ def get_table_batched_offsets_from_dense(
+     merged_indices: torch.Tensor,
+     L: Optional[int] = None,
+     total_B: Optional[int] = None,
+     use_cpu: bool = False,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+     if L is None and total_B is None:
+         (T, B, L) = merged_indices.size()
+         total_B = T * B
+     # pyre-fixme[6]: For 1st argument expected `Union[Sequence[SupportsIndex],
+     # SupportsIndex]` but got `Optional[int]`.
+     lengths = np.ones(total_B) * L
+     return (
+         to_device(merged_indices.contiguous().view(-1), use_cpu),
+         to_device(
+             torch.tensor(([0] + np.cumsum(lengths).tolist())).long(),
+             use_cpu,
+         ),
+     )
+
+
+ def get_offsets_from_dense(indices: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+     (B, L) = indices.size()
+     return (
+         indices.contiguous().view(-1),
+         torch.tensor(
+             np.cumsum(np.asarray([0] + [L for _ in range(B)])[:-1]).astype(np.int64)
+         ),
+     )
+
+
+ def b_indices(
+     b: Callable[..., torch.Tensor],
+     x: torch.Tensor,
+     per_sample_weights: Optional[torch.Tensor] = None,
+     use_cpu: bool = False,
+     do_pooling: bool = True,
+ ) -> torch.Tensor:
+     (indices, offsets) = get_offsets_from_dense(x)
+     if do_pooling:
+         return b(
+             to_device(indices, use_cpu),
+             to_device(offsets, use_cpu),
+             per_sample_weights=per_sample_weights,
+         )
+     else:
+         return b(to_device(indices, use_cpu))
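For context, `get_offsets_from_dense` turns a dense `(B, L)` indices tensor into the flattened indices plus the per-bag offsets that pooled-embedding kernels consume, and `b_indices` simply forwards that pair to a lookup callable. Below is a minimal plain-PyTorch sketch of the same layout; it does not import the fbgemm_gpu helpers above, so no particular import path is assumed.

    import numpy as np
    import torch

    # Dense lookup indices: B=2 bags, L=3 lookups per bag.
    dense = torch.tensor([[1, 4, 7], [2, 5, 8]])
    B, L = dense.shape

    # Same layout get_offsets_from_dense() produces: row-major flattened
    # indices plus the starting offset of each bag.
    indices = dense.contiguous().view(-1)  # tensor([1, 4, 7, 2, 5, 8])
    offsets = torch.tensor(
        np.cumsum(np.asarray([0] + [L] * B)[:-1]).astype(np.int64)
    )  # tensor([0, 3])

    # An EmbeddingBag consumes exactly this (indices, offsets) pair, which is
    # what b_indices() passes to the callable `b` when do_pooling=True.
    emb = torch.nn.EmbeddingBag(num_embeddings=16, embedding_dim=4, mode="sum")
    pooled = emb(indices, offsets)
    print(pooled.shape)  # torch.Size([2, 4])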
@@ -0,0 +1,251 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-strict
+ # pyre-ignore-all-errors[61]
+
+ from typing import Optional
+
+ import torch
+
+ from .common import to_device
+ from fbgemm_gpu.split_embedding_configs import (
+     FP8QuantizationConfig,
+     SparseType,
+ )  # usort:skip
+
+
+ def quantize_embs(
+     weight: torch.Tensor,
+     weight_ty: SparseType,
+     fp8_config: Optional[FP8QuantizationConfig] = None,
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+     weight = weight.detach()
+     if weight_ty == SparseType.FP32:
+         q_weight = weight.float()
+         res_weight = q_weight.view(torch.uint8)
+         return (res_weight, None)
+
+     elif weight_ty == SparseType.FP16:
+         q_weight = weight.half()
+         res_weight = q_weight.view(torch.uint8)
+         return (res_weight, None)
+
+     elif weight_ty == SparseType.FP8:
+         assert fp8_config is not None
+         # Quantize FP32 to HFP8
+         res_weight = torch.ops.fbgemm.FloatToHFP8Quantized(
+             weight.float(),
+             fp8_config.get("exponent_bits"),
+             fp8_config.get("exponent_bias"),
+             fp8_config.get("max_position"),
+         )
+         return (res_weight, None)
+
+     elif weight_ty == SparseType.INT8:
+         # Note that FloatToFused8BitRowwiseQuantized might have additional padding
+         # for alignment if embedding dimension is not a multiple of 4:
+         # https://fburl.com/code/z009xsy6
+         q_weight = torch.ops.fbgemm.FloatToFused8BitRowwiseQuantized(weight)
+         res_weight = q_weight[:, :-8].view(torch.uint8)
+         res_scale_shift = torch.tensor(
+             q_weight[:, -8:].view(torch.float32).to(torch.float16).view(torch.uint8)
+         )  # [-4, -2]: scale; [-2:]: bias
+         return (res_weight, res_scale_shift)
+
+     elif weight_ty == SparseType.INT4 or weight_ty == SparseType.INT2:
+         # Note that the FP32 -> INT4/INT2 conversion op below might have additional padding
+         # for alignment: https://fburl.com/code/xx9kkduf
+         q_weight = torch.ops.fbgemm.FloatToFusedNBitRowwiseQuantizedSBHalf(
+             weight,
+             bit_rate=weight_ty.bit_rate(),
+         )
+         res_weight = q_weight[:, :-4].view(torch.uint8)
+         res_scale_shift = torch.tensor(
+             q_weight[:, -4:].view(torch.uint8)
+         )  # [-4, -2]: scale; [-2:]: bias
+         return (res_weight, res_scale_shift)
+
+     else:
+         raise RuntimeError("Unsupported SparseType: {}".format(weight_ty))
+
+
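For reference, the INT8 branch above relies on the fused rowwise layout produced by torch.ops.fbgemm.FloatToFused8BitRowwiseQuantized, where each row carries its own scale and bias that quantize_embs() splits off as the scale_shift tensor. The sketch below illustrates the rowwise idea in plain PyTorch; it is not the fbgemm kernel, and the choice of (scale, min) packed in fp16 is an assumption drawn from the comments above.

    import torch

    def rowwise_int8_quantize(w: torch.Tensor):
        # Per-row min/max -> scale and bias; each row is mapped to [0, 255].
        w = w.float()
        row_min = w.min(dim=1, keepdim=True).values
        row_max = w.max(dim=1, keepdim=True).values
        scale = (row_max - row_min) / 255.0
        scale = torch.where(scale == 0, torch.ones_like(scale), scale)
        payload = torch.clamp(torch.round((w - row_min) / scale), 0, 255).to(torch.uint8)
        # Per-row (scale, bias) kept in fp16, analogous to the scale_shift
        # tail that quantize_embs() splits off the fused output.
        scale_shift = torch.cat([scale, row_min], dim=1).to(torch.float16)
        return payload, scale_shift

    def rowwise_int8_dequantize(payload: torch.Tensor, scale_shift: torch.Tensor):
        scale = scale_shift[:, 0:1].float()
        bias = scale_shift[:, 1:2].float()
        return payload.float() * scale + bias

    w = torch.randn(4, 8)
    q, ss = rowwise_int8_quantize(w)
    print((w - rowwise_int8_dequantize(q, ss)).abs().max())  # small reconstruction error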
+ def dequantize_embs(
+     weights: torch.Tensor,
+     scale_shift: torch.Tensor,
+     weight_ty: SparseType,
+     use_cpu: bool,
+     fp8_config: Optional[FP8QuantizationConfig] = None,
+     # pyre-fixme[7]: Expected `Tensor` but got implicit return value of `None`.
+ ) -> torch.Tensor:
+     print(f"weight_ty: {weight_ty}")
+     assert (
+         weights.dtype == torch.uint8
+     ), "The input tensor for dequantize_embs function needs to be byte tensor"
+     th_weights = weights
+
+     if scale_shift is not None:
+         th_scale_shift: torch.Tensor = scale_shift.view(torch.float16).to(torch.float32)
+
+     if weight_ty == SparseType.INT4:
+         (E, D_2) = th_weights.shape
+         D = D_2 * 2
+
+         def comp(i: int) -> torch.Tensor:
+             subs = th_weights.view(torch.uint8) >> (i * 4)
+             sub_mask = subs & 0xF
+             result = sub_mask.to(torch.float32) * th_scale_shift[:, 0].reshape(
+                 -1, 1
+             ).to(torch.float32) + th_scale_shift[:, 1].reshape(-1, 1).to(torch.float32)
+             return result.to(torch.float32)
+
+         comps = [comp(i) for i in range(2)]
+         comps = torch.stack(comps)
+         comps = comps.permute(1, 2, 0)
+         comps = comps.reshape(E, D)
+         return to_device(torch.tensor(comps), use_cpu)
+
+     elif weight_ty == SparseType.INT2:
+         (E, D_4) = th_weights.shape
+         D = D_4 * 4
+
+         # pyre-fixme[53]: Captured variable `scale_shift` is not annotated.
+         # pyre-fixme[53]: Captured variable `weights` is not annotated.
+         def comp(i: int) -> torch.Tensor:
+             subs = th_weights.view(torch.uint8) >> (i * 2)
+             sub_mask = subs & 0x3
+             result = sub_mask.to(torch.float32) * th_scale_shift[:, 0].reshape(
+                 -1, 1
+             ).to(torch.float32) + th_scale_shift[:, 1].reshape(-1, 1).to(torch.float32)
+             return result.to(torch.float32)
+
+         comps = [comp(i) for i in range(4)]
+         comps = torch.stack(comps)
+         comps = comps.permute(1, 2, 0)
+         comps = comps.reshape(E, D)
+         return to_device(torch.tensor(comps), use_cpu)
+
+     elif weight_ty == SparseType.INT8:
+         (E, D) = th_weights.shape
+         comps = th_weights.to(torch.float32) * th_scale_shift[:, 0].reshape(-1, 1).to(
+             torch.float32
+         ) + th_scale_shift[:, 1].reshape(-1, 1).to(torch.float32)
+         return to_device(torch.tensor(comps), use_cpu)
+
+     elif weight_ty == SparseType.FP8:
+         assert fp8_config is not None
+         assert scale_shift is None
+         # Dequantize HFP8 to FP32
+         comps = torch.ops.fbgemm.HFP8QuantizedToFloat(
+             weights,
+             fp8_config.get("exponent_bits"),
+             fp8_config.get("exponent_bias"),
+         )
+         return to_device(comps, use_cpu)
+
+     elif weight_ty == SparseType.FP16:
+         assert scale_shift is None
+         comps = th_weights.view(torch.half)
+         return to_device(torch.tensor(comps), use_cpu)
+
+     elif weight_ty == SparseType.FP32:
+         assert scale_shift is None
+         comps = th_weights.view(torch.float32)
+         # pyre-fixme[7]: Expected `Tensor` but got implicit return value of `None`.
+         return to_device(torch.tensor(comps), use_cpu)
+
+
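To make the INT4 branch above concrete: each stored byte packs two 4-bit values, and comp(i) extracts slot i of every byte by shifting and masking; the subsequent stack/permute/reshape interleaves the slots back into their original column order. A small self-contained sketch of that unpacking follows (illustration only, with no scale/bias applied):

    import torch

    # One row of four INT4 values packed into two bytes:
    # 0x21 -> low nibble 1, high nibble 2; 0x43 -> low nibble 3, high nibble 4.
    packed = torch.tensor([[0x21, 0x43]], dtype=torch.uint8)

    # comp(i): shift right by i * 4 bits and mask the low nibble.
    nibbles = [(packed >> (i * 4)) & 0xF for i in range(2)]

    # stack -> (2, E, D/2), permute -> (E, D/2, 2), reshape -> (E, D),
    # which restores the original element order.
    unpacked = torch.stack(nibbles).permute(1, 2, 0).reshape(1, -1)
    print(unpacked)  # values 1, 2, 3, 4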
+ def fake_quantize_embs(
+     weights: torch.Tensor,
+     scale_shift: Optional[torch.Tensor],
+     dequant_weights: torch.Tensor,
+     weight_ty: SparseType,
+     use_cpu: bool,
+     fp8_config: Optional[FP8QuantizationConfig] = None,
+ ) -> None:
+     assert (
+         weights.dtype == torch.uint8
+     ), "The input tensor for dequantize_embs function needs to be byte tensor"
+     th_weights = weights
+
+     if scale_shift is not None:
+         th_scale_shift: torch.Tensor = (
+             scale_shift.contiguous().view(torch.float16).to(torch.float32)
+         )
+
+     if weight_ty == SparseType.INT4:
+         (E, D_2) = th_weights.shape
+         D = D_2 * 2
+
+         def comp(i: int) -> torch.Tensor:
+             subs = th_weights.view(torch.uint8) >> (i * 4)
+             sub_mask = subs & 0xF
+             result = sub_mask.to(torch.float32) * th_scale_shift[:, 0].reshape(
+                 -1, 1
+             ).to(torch.float32) + th_scale_shift[:, 1].reshape(-1, 1).to(torch.float32)
+             return result.to(torch.float32)
+
+         comps = [comp(i) for i in range(2)]
+         comps = torch.stack(comps)
+         comps = comps.permute(1, 2, 0)
+         comps = comps.reshape(E, D)
+         dequant_weights.copy_(to_device(comps, use_cpu))
+
+     elif weight_ty == SparseType.INT2:
+         (E, D_4) = th_weights.shape
+         D = D_4 * 4
+
+         # pyre-fixme[53]: Captured variable `scale_shift` is not annotated.
+         # pyre-fixme[53]: Captured variable `weights` is not annotated.
+         def comp(i: int) -> torch.Tensor:
+             subs = th_weights.view(torch.uint8) >> (i * 2)
+             sub_mask = subs & 0x3
+             result = sub_mask.to(torch.float32) * th_scale_shift[:, 0].reshape(
+                 -1, 1
+             ).to(torch.float32) + th_scale_shift[:, 1].reshape(-1, 1).to(torch.float32)
+             return result.to(torch.float32)
+
+         comps = [comp(i) for i in range(4)]
+         comps = torch.stack(comps)
+         comps = comps.permute(1, 2, 0)
+         comps = comps.reshape(E, D)
+         dequant_weights.copy_(to_device(comps, use_cpu))
+
+     elif weight_ty == SparseType.INT8:
+         (E, D) = th_weights.shape
+         comps = th_weights.to(torch.float32) * th_scale_shift[:, 0].reshape(-1, 1).to(
+             torch.float32
+         ) + th_scale_shift[:, 1].reshape(-1, 1).to(torch.float32)
+         dequant_weights.copy_(to_device(comps, use_cpu))
+
+     elif weight_ty == SparseType.FP8:
+         assert fp8_config is not None
+         assert scale_shift is None
+         # Quantize FP32 to HFP8
+         comps = torch.ops.fbgemm.FloatToHFP8Quantized(
+             dequant_weights.detach().float(),
+             fp8_config.get("exponent_bits"),
+             fp8_config.get("exponent_bias"),
+             fp8_config.get("max_position"),
+         )
+         weights.copy_(comps)
+
+         # Dequantize HFP8 to FP32
+         comps = torch.ops.fbgemm.HFP8QuantizedToFloat(
+             comps,
+             fp8_config.get("exponent_bits"),
+             fp8_config.get("exponent_bias"),
+         )
+         dequant_weights.copy_(to_device(comps, use_cpu))
+
+     elif weight_ty == SparseType.FP16:
+         assert scale_shift is None
+         comps = dequant_weights.detach().half().view(torch.uint8)
+         weights.copy_(comps)
+     elif weight_ty == SparseType.FP32:
+         assert scale_shift is None
+         comps = dequant_weights.detach().float().view(torch.uint8)
+         weights.copy_(comps)
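The function above implements fake quantization: weights always holds the quantized byte representation, while dequant_weights holds the float values a training step would actually see, and the two are kept in sync with in-place copy_ calls. The following is a minimal plain-PyTorch sketch of that quantize-then-dequantize round trip, using FP16 as a stand-in storage format since the HFP8 and fused INT ops require the fbgemm extension; it illustrates the idea rather than reproducing any single branch above.

    import torch

    dequant_weights = torch.randn(4, 8)              # float view used in training
    weights = torch.empty(4, 16, dtype=torch.uint8)  # byte view; 8 fp16 values = 16 bytes per row

    # Quantize: cast to fp16 and store the raw bytes into the byte view.
    weights.copy_(dequant_weights.detach().half().view(torch.uint8))

    # Dequantize: reinterpret the bytes as fp16 and write the floats back,
    # so dequant_weights now carries the quantization error of the format.
    dequant_weights.copy_(weights.view(torch.float16).float())

    print(weights.dtype, dequant_weights.dtype)  # torch.uint8 torch.float32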