tpu_inference-0.11.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of tpu-inference might be problematic.

Files changed (168)
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_adapters.py +83 -0
  4. tests/core/test_core_tpu.py +523 -0
  5. tests/core/test_disagg_executor.py +60 -0
  6. tests/core/test_disagg_utils.py +53 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  10. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  11. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  12. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  13. tests/lora/__init__.py +0 -0
  14. tests/lora/test_lora.py +123 -0
  15. tests/test_base.py +201 -0
  16. tests/test_quantization.py +836 -0
  17. tests/test_tpu_info.py +120 -0
  18. tests/test_utils.py +218 -0
  19. tests/tpu_backend_test.py +59 -0
  20. tpu_inference/__init__.py +30 -0
  21. tpu_inference/adapters/__init__.py +0 -0
  22. tpu_inference/adapters/vllm_adapters.py +42 -0
  23. tpu_inference/adapters/vllm_config_adapters.py +134 -0
  24. tpu_inference/backend.py +69 -0
  25. tpu_inference/core/__init__.py +0 -0
  26. tpu_inference/core/adapters.py +153 -0
  27. tpu_inference/core/core_tpu.py +776 -0
  28. tpu_inference/core/disagg_executor.py +117 -0
  29. tpu_inference/core/disagg_utils.py +51 -0
  30. tpu_inference/di/__init__.py +0 -0
  31. tpu_inference/di/abstracts.py +28 -0
  32. tpu_inference/di/host.py +76 -0
  33. tpu_inference/di/interfaces.py +51 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/tpu_connector.py +699 -0
  36. tpu_inference/distributed/utils.py +59 -0
  37. tpu_inference/executors/__init__.py +0 -0
  38. tpu_inference/executors/ray_distributed_executor.py +346 -0
  39. tpu_inference/experimental/__init__.py +0 -0
  40. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  41. tpu_inference/interfaces/__init__.py +0 -0
  42. tpu_inference/interfaces/cache.py +31 -0
  43. tpu_inference/interfaces/config.py +47 -0
  44. tpu_inference/interfaces/config_parts.py +117 -0
  45. tpu_inference/interfaces/engine.py +51 -0
  46. tpu_inference/interfaces/outputs.py +22 -0
  47. tpu_inference/interfaces/params.py +21 -0
  48. tpu_inference/interfaces/platform.py +74 -0
  49. tpu_inference/interfaces/request.py +39 -0
  50. tpu_inference/interfaces/scheduler.py +31 -0
  51. tpu_inference/kernels/__init__.py +0 -0
  52. tpu_inference/kernels/collectives/__init__.py +0 -0
  53. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  54. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  55. tpu_inference/kernels/collectives/util.py +47 -0
  56. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  57. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  58. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  59. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  60. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  61. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  62. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  63. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  64. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  65. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  66. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  67. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  68. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1447 -0
  69. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3834 -0
  70. tpu_inference/kernels/ragged_paged_attention/v3/util.py +47 -0
  71. tpu_inference/layers/__init__.py +0 -0
  72. tpu_inference/layers/common/__init__.py +0 -0
  73. tpu_inference/layers/common/attention_metadata.py +34 -0
  74. tpu_inference/layers/jax/__init__.py +0 -0
  75. tpu_inference/layers/jax/attention/__init__.py +0 -0
  76. tpu_inference/layers/jax/attention/attention.py +254 -0
  77. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  78. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  79. tpu_inference/layers/jax/attention_interface.py +356 -0
  80. tpu_inference/layers/jax/base.py +151 -0
  81. tpu_inference/layers/jax/binary_search.py +295 -0
  82. tpu_inference/layers/jax/constants.py +88 -0
  83. tpu_inference/layers/jax/layers.py +301 -0
  84. tpu_inference/layers/jax/misc.py +16 -0
  85. tpu_inference/layers/jax/moe/__init__.py +0 -0
  86. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  87. tpu_inference/layers/jax/moe/moe.py +209 -0
  88. tpu_inference/layers/jax/rope.py +172 -0
  89. tpu_inference/layers/jax/rope_interface.py +214 -0
  90. tpu_inference/layers/jax/sample/__init__.py +0 -0
  91. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  92. tpu_inference/layers/jax/sample/sampling.py +95 -0
  93. tpu_inference/layers/jax/sample/sampling_metadata.py +69 -0
  94. tpu_inference/layers/jax/sharding.py +406 -0
  95. tpu_inference/layers/jax/transformer_block.py +76 -0
  96. tpu_inference/layers/vllm/__init__.py +0 -0
  97. tpu_inference/layers/vllm/attention.py +184 -0
  98. tpu_inference/layers/vllm/fused_moe.py +399 -0
  99. tpu_inference/layers/vllm/linear_common.py +186 -0
  100. tpu_inference/layers/vllm/quantization/__init__.py +34 -0
  101. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  102. tpu_inference/layers/vllm/quantization/common.py +105 -0
  103. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  104. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +121 -0
  105. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  106. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  107. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  108. tpu_inference/layers/vllm/quantization/unquantized.py +263 -0
  109. tpu_inference/layers/vllm/sharding.py +151 -0
  110. tpu_inference/logger.py +10 -0
  111. tpu_inference/lora/__init__.py +0 -0
  112. tpu_inference/lora/torch_lora_ops.py +103 -0
  113. tpu_inference/lora/torch_punica_tpu.py +308 -0
  114. tpu_inference/mock/__init__.py +0 -0
  115. tpu_inference/mock/vllm_config_utils.py +28 -0
  116. tpu_inference/mock/vllm_envs.py +1233 -0
  117. tpu_inference/mock/vllm_logger.py +212 -0
  118. tpu_inference/mock/vllm_logging_utils.py +15 -0
  119. tpu_inference/models/__init__.py +0 -0
  120. tpu_inference/models/common/__init__.py +0 -0
  121. tpu_inference/models/common/model_loader.py +433 -0
  122. tpu_inference/models/jax/__init__.py +0 -0
  123. tpu_inference/models/jax/deepseek_v3.py +868 -0
  124. tpu_inference/models/jax/llama3.py +366 -0
  125. tpu_inference/models/jax/llama4.py +473 -0
  126. tpu_inference/models/jax/llama_eagle3.py +333 -0
  127. tpu_inference/models/jax/phi3.py +376 -0
  128. tpu_inference/models/jax/qwen2.py +375 -0
  129. tpu_inference/models/jax/qwen2_5_vl.py +976 -0
  130. tpu_inference/models/jax/qwen3.py +302 -0
  131. tpu_inference/models/jax/utils/__init__.py +0 -0
  132. tpu_inference/models/jax/utils/file_utils.py +96 -0
  133. tpu_inference/models/jax/utils/multi_modal_utils.py +164 -0
  134. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  135. tpu_inference/models/jax/utils/quantization/quantization_utils.py +588 -0
  136. tpu_inference/models/jax/utils/weight_utils.py +510 -0
  137. tpu_inference/models/vllm/__init__.py +0 -0
  138. tpu_inference/models/vllm/vllm_model_wrapper.py +272 -0
  139. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  140. tpu_inference/platforms/__init__.py +2 -0
  141. tpu_inference/platforms/tpu_jax.py +257 -0
  142. tpu_inference/runner/__init__.py +0 -0
  143. tpu_inference/runner/block_table_jax.py +122 -0
  144. tpu_inference/runner/compilation_manager.py +672 -0
  145. tpu_inference/runner/input_batch_jax.py +435 -0
  146. tpu_inference/runner/kv_cache.py +119 -0
  147. tpu_inference/runner/kv_cache_manager.py +460 -0
  148. tpu_inference/runner/lora_utils.py +92 -0
  149. tpu_inference/runner/multimodal_manager.py +208 -0
  150. tpu_inference/runner/persistent_batch_manager.py +244 -0
  151. tpu_inference/runner/speculative_decoding_manager.py +250 -0
  152. tpu_inference/runner/structured_decoding_manager.py +89 -0
  153. tpu_inference/runner/tpu_jax_runner.py +771 -0
  154. tpu_inference/runner/utils.py +426 -0
  155. tpu_inference/spec_decode/__init__.py +0 -0
  156. tpu_inference/spec_decode/jax/__init__.py +0 -0
  157. tpu_inference/spec_decode/jax/eagle3.py +334 -0
  158. tpu_inference/tpu_info.py +77 -0
  159. tpu_inference/utils.py +294 -0
  160. tpu_inference/worker/__init__.py +0 -0
  161. tpu_inference/worker/_temporary_vllm_compat.py +129 -0
  162. tpu_inference/worker/base.py +100 -0
  163. tpu_inference/worker/tpu_worker_jax.py +321 -0
  164. tpu_inference-0.11.1.dist-info/METADATA +101 -0
  165. tpu_inference-0.11.1.dist-info/RECORD +168 -0
  166. tpu_inference-0.11.1.dist-info/WHEEL +5 -0
  167. tpu_inference-0.11.1.dist-info/licenses/LICENSE +201 -0
  168. tpu_inference-0.11.1.dist-info/top_level.txt +2 -0
tpu_inference/models/jax/deepseek_v3.py
@@ -0,0 +1,868 @@
+ import re
+ from dataclasses import dataclass
+ from typing import List, Optional, Tuple
+
+ import jax
+ import jax.numpy as jnp
+ import torch
+ from flax import nnx
+ from flax.typing import PRNGKey
+ from jax.sharding import Mesh, NamedSharding
+ from jax.sharding import PartitionSpec as P
+ from torchax.ops.mappings import j2t_dtype
+ from vllm.config import VllmConfig
+
+ from tpu_inference import utils
+ from tpu_inference.layers.jax.attention.attention import AttentionMetadata
+ from tpu_inference.layers.jax.attention.deepseek_v3_attention import MLA
+ from tpu_inference.layers.jax.constants import KVCacheType
+ from tpu_inference.layers.jax.layers import DenseFFW, Embedder, LMhead, RMSNorm
+ from tpu_inference.layers.jax.moe.deepseek_v3_moe import (DeepSeekV3Router,
+                                                           SparseMoE)
+ from tpu_inference.layers.jax.moe.moe import MoE
+ from tpu_inference.layers.jax.transformer_block import (
+     SharedExpertsTransformerBlock, TransformerBlock)
+ from tpu_inference.logger import init_logger
+ from tpu_inference.models.jax.utils.quantization.quantization_utils import \
+     get_quant_dtype_from_qwix_config
+ from tpu_inference.models.jax.utils.weight_utils import (
+     get_param, model_weights_generator, print_param_info, reshape_params)
+
+ logger = init_logger(__name__)
+
+ # A map from JAX dtype to the corresponding PyTorch integer dtype for raw memory viewing.
+ DTYPE_VIEW_MAP = {
+     jnp.dtype(jnp.float8_e4m3fn): torch.uint8,
+     jnp.dtype(jnp.bfloat16): torch.uint16,
+     jnp.dtype(jnp.float32): torch.uint32,
+ }
+
+
+ @dataclass
+ class DeepSeekV3(nnx.Module):
+
+     def __init__(self,
+                  vllm_config: VllmConfig,
+                  rng: jax.Array,
+                  mesh: Mesh,
+                  force_random_weights: bool = False):
+         assert mesh is not None
+
+         self.vllm_config = vllm_config
+         self.rng = nnx.Rngs(rng)
+
+         # NOTE: the default is 61
+         num_layers: int = vllm_config.model_config.hf_config.num_hidden_layers
+         num_local_experts: int = 256
+
+         vocab_size: int = 129280
+         hidden_size: int = 7168
+         # NOTE: this dtype may be implicitly overridden if using Qwix to load the quantized weights
+         dtype: jnp.dtype = jnp.bfloat16
+         num_attention_heads: int = 128
+         num_key_value_heads: int = 128
+         ffw_intermediate_size: int = 18432
+         moe_intermediate_size: int = 2048
+         num_experts_per_token: int = 8
+         n_group: int = 8
+         interleave_moe_layer_step: int = 1  # Deepseek V3 has moe_layer_freq=1 in hf config.
+         hidden_act: str = "silu"
+         rms_norm_eps: float = 1e-06
+         first_k_dense_replace: int = 3  # replace the first few MoE layers with dense layers.
+
+         num_shared_experts = 1
+         rope_theta = 10000
+         rope_scaling = {
+             "beta_fast": 32,
+             "beta_slow": 1,
+             "factor": 40,
+             "mscale": 1.0,
+             "mscale_all_dim": 1.0,
+             "original_max_position_embeddings": 4096,
+             "type": "yarn"
+         }
+         q_lora_rank = 1536
+         kv_lora_rank = 512
+         qk_nope_head_dim = 128
+         qk_rope_head_dim = 64
+         v_head_dim = 128
+
+         self.random_init = force_random_weights or self.vllm_config.additional_config.get(
+             "random_weights", False)
+         self.sparse_matmul = self.vllm_config.additional_config.get(
+             "sparse_matmul", False)
+
+         if isinstance(self.sparse_matmul, str):
+             self.sparse_matmul = self.sparse_matmul.lower() == "true"
+         else:
+             self.sparse_matmul = bool(self.sparse_matmul)
+
+         if self.sparse_matmul:
+             logger.info("sparse matmul is enabled")
+         else:
+             logger.info("sparse matmul is disabled, using dense matmul")
+         self.mesh = mesh
+
+         self.weight_loader = DeepSeekV3WeightLoader(
+             vllm_config=vllm_config,
+             num_layers=num_layers,
+             hidden_size=hidden_size,
+             q_lora_rank=q_lora_rank,
+             kv_lora_rank=kv_lora_rank,
+             attn_heads=num_attention_heads,
+             qk_nope_head_dim=qk_nope_head_dim,
+             qk_rope_head_dim=qk_rope_head_dim,
+             v_head_dim=v_head_dim,
+             num_local_experts=num_local_experts,
+             model_dtype=dtype)
+
+         self.embedder = Embedder(vocab_size=vocab_size,
+                                  hidden_size=hidden_size,
+                                  dtype=dtype,
+                                  rngs=self.rng,
+                                  vd_sharding=(('data', 'expert', 'model'),
+                                               None),
+                                  random_init=self.random_init)
+
+         self.layers = []
+
+         def _create_mla() -> MLA:
+             return MLA(
+                 rope_theta=rope_theta,
+                 rope_scaling=rope_scaling,
+                 q_lora_rank=q_lora_rank,
+                 kv_lora_rank=kv_lora_rank,
+                 qk_nope_head_dim=qk_nope_head_dim,
+                 qk_rope_head_dim=qk_rope_head_dim,
+                 rms_norm_eps=rms_norm_eps,
+                 v_head_dim=v_head_dim,
+                 mesh=self.mesh,
+                 random_init=self.random_init,
+                 hidden_size=hidden_size,
+                 num_attention_heads=num_attention_heads,
+                 num_key_value_heads=num_key_value_heads,
+                 head_dim=v_head_dim,  # MLA uses v_head_dim as head_dim
+                 dtype=dtype,
+                 # TODO (jacobplatin): we should refactor this to pass a dtype (or config) directly
+                 kv_cache_dtype=vllm_config.cache_config.cache_dtype,
+                 rngs=self.rng,
+                 activation_attention_td=(None, None),
+                 activation_q_td=(None, None),
+                 query_tnh=P(None, 'model', None),
+                 keyvalue_skh=P(None, 'model', None),
+                 activation_attention_out_td=(None, None),
+                 attn_o_tnh=P(None, 'model', None),
+                 q_da_sharding=(None, 'model'),
+                 anh_sharding=(None, 'model', None),
+                 kv_da_sharding=(None, 'model'),
+                 nhd_sharding=('model', None, None))
+
+         for i in range(first_k_dense_replace):
+             block = TransformerBlock(
+                 pre_attention_norm=RMSNorm(
+                     dims=hidden_size,
+                     random_init=self.random_init,
+                     epsilon=rms_norm_eps,
+                     with_scale=True,
+                     dtype=dtype,
+                     rngs=self.rng,
+                 ),
+                 pre_mlp_norm=RMSNorm(
+                     dims=hidden_size,
+                     random_init=self.random_init,
+                     epsilon=rms_norm_eps,
+                     with_scale=True,
+                     dtype=dtype,
+                     rngs=self.rng,
+                 ),
+                 attn=_create_mla(),
+                 custom_module=DenseFFW(dtype=dtype,
+                                        hidden_act=hidden_act,
+                                        hidden_size=hidden_size,
+                                        intermediate_size=ffw_intermediate_size,
+                                        rngs=self.rng,
+                                        df_sharding=(None, ('model', 'expert')),
+                                        fd_sharding=(('model', 'expert'), None),
+                                        random_init=self.random_init))
+
+             self.layers.append(block)
+
+         for i in range(first_k_dense_replace, num_layers):
+             is_moe_layer = ((i + 1) % interleave_moe_layer_step == 0)
+             router = DeepSeekV3Router(
+                 random_init=self.random_init,
+                 hidden_size=hidden_size,
+                 num_experts=num_local_experts,
+                 num_experts_per_tok=num_experts_per_token,
+                 n_groups=n_group,
+                 topk_groups=4,
+                 norm_topk_prob=True,
+                 rngs=self.rng,
+                 routed_scaling_factor=2.5,
+                 dtype=dtype,
+                 activation_ffw_td=('data', None),
+                 ed_sharding=('model', None),
+                 e_sharding=('model', ))
+             if self.sparse_matmul:
+                 # TODO: organize the SparseMoE and DenseMoE better given they share most interfaces
+                 custom_module = SparseMoE(
+                     dtype=dtype,
+                     num_local_experts=num_local_experts,
+                     apply_expert_weight_before_computation=False,
+                     hidden_size=hidden_size,
+                     intermediate_size_moe=moe_intermediate_size,
+                     num_experts_per_tok=num_experts_per_token,
+                     mesh=self.mesh,
+                     hidden_act=hidden_act,
+                     rngs=self.rng,
+                     random_init=self.random_init,
+                     activation_ffw_td=('data', None),
+                     activation_ffw_ted=('data', None, None),
+                     edf_sharding=('model', None, None),
+                     efd_sharding=('model', None, None),
+                     quantized_dtype=self.weight_loader.quant_dtype
+                     if self.weight_loader.is_model_quantized else None,
+                     router=router) if is_moe_layer else DenseFFW(
+                         dtype=dtype,
+                         hidden_act=hidden_act,
+                         hidden_size=hidden_size,
+                         intermediate_size=ffw_intermediate_size,
+                         rngs=self.rng,
+                         random_init=self.random_init,
+                         df_sharding=(None, ('model', 'expert')),
+                         fd_sharding=(('model', 'expert'), None))
+             else:
+                 custom_module = MoE(
+                     dtype=dtype,
+                     num_local_experts=num_local_experts,
+                     apply_expert_weight_before_computation=False,
+                     hidden_size=hidden_size,
+                     intermediate_size_moe=moe_intermediate_size,
+                     hidden_act=hidden_act,
+                     rngs=self.rng,
+                     random_init=self.random_init,
+                     activation_ffw_td=('data', None),
+                     activation_ffw_ted=('data', None, None),
+                     edf_sharding=('model', None, None),
+                     efd_sharding=('model', None, None),
+                     router=router) if is_moe_layer else DenseFFW(
+                         dtype=dtype,
+                         hidden_act=hidden_act,
+                         hidden_size=hidden_size,
+                         intermediate_size=ffw_intermediate_size,
+                         rngs=self.rng,
+                         random_init=self.random_init,
+                         df_sharding=(None, ('model', 'expert')),
+                         fd_sharding=(('model', 'expert'), None))
+
+             shared_experts = DenseFFW(dtype=dtype,
+                                       hidden_act=hidden_act,
+                                       hidden_size=hidden_size,
+                                       intermediate_size=num_shared_experts *
+                                       moe_intermediate_size,
+                                       rngs=self.rng,
+                                       random_init=self.random_init,
+                                       df_sharding=(None, ('model', 'expert')),
+                                       fd_sharding=(('model', 'expert'), None))
+
+             pre_attention_norm = RMSNorm(
+                 dims=hidden_size,
+                 rngs=self.rng,
+                 random_init=self.random_init,
+                 epsilon=rms_norm_eps,
+                 with_scale=True,
+                 dtype=dtype,
+             )
+
+             pre_mlp_norm = RMSNorm(
+                 dims=hidden_size,
+                 rngs=self.rng,
+                 random_init=self.random_init,
+                 epsilon=rms_norm_eps,
+                 with_scale=True,
+                 dtype=dtype,
+             )
+
+             block = SharedExpertsTransformerBlock(
+                 custom_module=custom_module,
+                 attn=_create_mla(),
+                 pre_attention_norm=pre_attention_norm,
+                 pre_mlp_norm=pre_mlp_norm,
+                 shared_experts=shared_experts)
+             self.layers.append(block)
+
+         self.final_norm = RMSNorm(
+             dims=hidden_size,
+             rngs=self.rng,
+             random_init=self.random_init,
+             epsilon=rms_norm_eps,
+             with_scale=True,
+             dtype=dtype,
+         )
+
+         self.lm_head = LMhead(vocab_size=vocab_size,
+                               hidden_size=hidden_size,
+                               dtype=dtype,
+                               rngs=self.rng,
+                               vd_sharding=(('data', 'expert', 'model'), None),
+                               dv_sharding=(None, ('data', 'expert', 'model')),
+                               random_init=self.random_init)
+
+     # For compatibility with flax.
+     def apply(self, variables, *args, **kwargs):
+         return self.__call__(*args, **kwargs)
+
+     def load_weights(self, rng: PRNGKey, cache_dir: Optional[str] = None):
+         # NOTE: Since we are using nnx.eval_shape to init the model,
+         # we have to pass dynamic arrays here for __call__'s usage.
+         self.rng = nnx.Rngs(rng)
+         self.weight_loader.load_weights(self)
+         self.initialize_cache()
+
+     def initialize_cache(self):
+         # Initialize RoPE caches after weights are loaded and before JIT compilation.
+         for layer in self.layers:
+             if hasattr(layer, 'attn') and hasattr(layer.attn, 'rope'):
+                 if hasattr(layer.attn.rope, 'initialize_cache'):
+                     layer.attn.rope.initialize_cache()
+
+     def __call__(
+         self,
+         kv_caches: List[jax.Array],
+         input_ids: jax.Array,
+         attention_metadata: AttentionMetadata,
+         *args,
+     ) -> Tuple[List[KVCacheType], jax.Array, List[jax.Array]]:
+         is_prefill = False
+         x = self.embedder.encode(input_ids)
+         for (i, block) in enumerate(self.layers):
+             kv_cache = kv_caches[i]
+             new_kv_cache, x = block(x, is_prefill, kv_cache,
+                                     attention_metadata)
+             kv_caches[i] = new_kv_cache
+
+         final_activation = self.final_norm(x)
+
+         return kv_caches, final_activation, []
+
+     def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
+         return self.lm_head.decode(hidden_states)
+
+
+ @dataclass
+ class DeepSeekV3WeightLoader:
+
+     def __init__(self, vllm_config: VllmConfig, num_layers, hidden_size,
+                  q_lora_rank, kv_lora_rank, attn_heads, qk_nope_head_dim,
+                  qk_rope_head_dim, v_head_dim, num_local_experts, model_dtype):
+
+         self.num_layers = num_layers
+         self.names_and_weights_generator = model_weights_generator(
+             model_name_or_path=vllm_config.model_config.model,
+             framework="pt",
+             download_dir=vllm_config.load_config.download_dir)
+         self.is_verbose = vllm_config.additional_config.get(
+             "is_verbose", None) is not None
+         self.num_routed_experts = num_local_experts
+         self.model_dtype = model_dtype
+
+         self._transpose_map = {
+             # dense mlp
+             r"mlp\.down_proj": (1, 0),
+             r"mlp\.gate_proj": (1, 0),
+             r"mlp\.up_proj": (1, 0),
+             # mla
+             r"q_a_proj": (1, 0),
+             r"q_b_proj": (2, 0, 1),
+             r"kv_a_proj_with_mqa": (1, 0),
+             r"kv_b_proj": (2, 0, 1),
+             r"o_proj": (1, 2, 0),
+             # moe
+             r"mlp\.gate\.weight": (1, 0),
+             r"mlp\.experts\.\d+\.gate_proj": (0, 2, 1),
+             r"mlp\.experts\.\d+\.down_proj": (0, 2, 1),
+             r"mlp\.experts\.\d+\.up_proj": (0, 2, 1),
+             r"mlp\.shared_experts\.down_proj": (1, 0),
+             r"mlp\.shared_experts\.gate_proj": (1, 0),
+             r"mlp\.shared_experts\.up_proj": (1, 0),
+             # lm_head
+             r"lm_head\.weight": (1, 0)
+         }
+         self._weight_shape_map = {
+             "q_b_proj":
+             (attn_heads, qk_nope_head_dim + qk_rope_head_dim, q_lora_rank),
+             "kv_b_proj":
+             (attn_heads, qk_nope_head_dim + v_head_dim, kv_lora_rank),
+             "o_proj": (hidden_size, attn_heads, v_head_dim)
+         }
+
+         # Set the mappings from loaded parameter keys to standardized names.
+         self._loaded_to_standardized_keys = {
+             # encode & decode
+             "model.embed_tokens.weight":
+             "embedder.input_embedding_table_VD",
+             "lm_head.weight":
+             "lm_head.input_embedding_table_DV",
+             # final norm
+             "model.norm.weight":
+             "final_norm.scale",
+             # norm in transformer blocks
+             "model.layers.*.input_layernorm.weight":
+             "layers.*.pre_attention_norm.scale",
+             "model.layers.*.post_attention_layernorm.weight":
+             "layers.*.pre_mlp_norm.scale",
+             # attention (MLA)
+             "model.layers.*.self_attn.q_a_layernorm.weight":
+             "layers.*.attn.q_rms_norm.scale",
+             "model.layers.*.self_attn.kv_a_layernorm.weight":
+             "layers.*.attn.kv_rms_norm.scale",
+             "model.layers.*.self_attn.q_a_proj.weight":
+             "layers.*.attn.kernel_q_down_proj_DA",
+             "model.layers.*.self_attn.q_b_proj.weight":
+             "layers.*.attn.kernel_q_up_proj_ANH",
+             "model.layers.*.self_attn.kv_a_proj_with_mqa.weight":
+             "layers.*.attn.kernel_kv_down_proj_DA",
+             "model.layers.*.self_attn.kv_b_proj.weight":
+             "layers.*.attn.kernel_kv_up_proj_ANH",
+             "model.layers.*.self_attn.o_proj.weight":
+             "layers.*.attn.kernel_o_proj_NHD",
+             # Dense ffw
+             "model.layers.*.mlp.gate_proj.weight":
+             "layers.*.custom_module.kernel_gating_DF",
+             "model.layers.*.mlp.up_proj.weight":
+             "layers.*.custom_module.kernel_up_proj_DF",
+             "model.layers.*.mlp.down_proj.weight":
+             "layers.*.custom_module.kernel_down_proj_FD",
+             # MOE(routed experts)
+             "model.layers.*.mlp.gate.weight":
+             "layers.*.custom_module.router.kernel_DE",
+             "model.layers.*.mlp.gate.e_score_correction_bias":
+             "layers.*.custom_module.router.bias_E",
+             "model.layers.*.mlp.experts.*.gate_proj.weight":
+             "layers.*.custom_module.kernel_gating_EDF",
+             "model.layers.*.mlp.experts.*.down_proj.weight":
+             "layers.*.custom_module.kernel_down_proj_EFD",
+             "model.layers.*.mlp.experts.*.up_proj.weight":
+             "layers.*.custom_module.kernel_up_proj_EDF",
+             # MOE(shared experts)
+             "model.layers.*.mlp.shared_experts.down_proj.weight":
+             "layers.*.shared_experts.kernel_down_proj_FD",
+             "model.layers.*.mlp.shared_experts.gate_proj.weight":
+             "layers.*.shared_experts.kernel_gating_DF",
+             "model.layers.*.mlp.shared_experts.up_proj.weight":
+             "layers.*.shared_experts.kernel_up_proj_DF",
+         }
+
+         # TODO (jacobplatin): we shouldn't hard-code this, but the logic to obtain the true quantized dtype
+         # is non-trivial and the default checkpoints all use this dtype
+         self.quant_dtype = jnp.float8_e4m3fn
+
+         self.is_model_quantized = not vllm_config.additional_config.get(
+             "skip_quantization", False)
+         if self.is_model_quantized:
+             # TODO (jacobplatin): expand support eventually
+             quantization_type = vllm_config.model_config.hf_config.quantization_config[
+                 "quant_method"]
+             assert quantization_type == "fp8", "DeepSeek only supports the fp8 quantization method for now"
+             self.scale_dtype, self.quant_dtype = get_quant_dtype_from_qwix_config(
+                 vllm_config)
+
+             logger.info(
+                 f"Quantizing DeepSeek with quantization dtype: {self.quant_dtype} and scale dtype: {self.scale_dtype}"
+             )
+
+             quantization_block_sizes = vllm_config.model_config.hf_config.quantization_config[
+                 "weight_block_size"]
+             assert len(
+                 quantization_block_sizes
+             ) == 2, f"Expected only 2 quantization block sizes but got {quantization_block_sizes}"
+             self.quantization_block_size_n = quantization_block_sizes[0]
+             self.quantization_block_size_k = quantization_block_sizes[1]
+             # TODO (jacobplatin): remove this check in the future
+             assert self.quantization_block_size_n == self.quantization_block_size_k, "Quantization block size n and k must be the same!"
+             # NOTE: this is only needed for pre-quantized models
+             self._scale_shape_map = {
+                 "q_b_proj": (1, qk_nope_head_dim + qk_rope_head_dim,
+                              q_lora_rank // self.quantization_block_size_n),
+                 "kv_b_proj": (attn_heads, (qk_nope_head_dim + v_head_dim) //
+                               self.quantization_block_size_n,
+                               kv_lora_rank // self.quantization_block_size_n),
+                 "o_proj":
+                 (hidden_size // self.quantization_block_size_n, attn_heads,
+                  v_head_dim // self.quantization_block_size_n),
+             }
+             # NOTE: this is only needed for pre-quantized models when doing random weight loading
+             # TODO (jacobplatin): remove or clean this up
+             self.scale_shap_map_for_random_weight_loading = {
+                 "kernel_kv_down_proj_DA": (56, 576),
+                 "kernel_kv_up_proj_ANH": (4, 128, 2),
+                 "kernel_q_up_proj_ANH": (12, 1, 192),
+                 "kernel_o_proj_NHD": (128, 1, 56),
+                 "kernel_down_proj_EFD": (256, 16, 56),
+                 "kernel_up_proj_EDF": (256, 56, 16),
+                 "kernel_gating_EDF": (256, 56, 16),
+             }
+
+     def map_loaded_to_standardized_name(self, loaded_key: str) -> str:
+         # Find the corresponding model key using the HF key
+         if "layer" in loaded_key:
+             # extract layer number and replace it with *
+             layer_num = re.search(r"layers\.(\d+)", loaded_key).group(1)
+             layer_key = re.sub(r"layers\.\d+", "layers.*", loaded_key)
+             # extract expert number if exists and replace it with *
+             if "experts" in loaded_key and "shared_experts" not in loaded_key:
+                 layer_key = re.sub(r"experts\.\d+", "experts.*", layer_key)
+             # get standardized key and replace * with layer number.
+             mapped_key = self._loaded_to_standardized_keys.get(
+                 layer_key, loaded_key)
+             mapped_key = re.sub(r"layers\.\*", f"layers.{layer_num}",
+                                 mapped_key)
+         else:
+             mapped_key = self._loaded_to_standardized_keys.get(
+                 loaded_key, loaded_key)
+         return mapped_key
+
+     def _transpose_params(self, param_key: str, param_tensor: jax.Array):
+         for key, value in self._transpose_map.items():
+             if re.search(key, param_key):
+                 return jnp.transpose(param_tensor, value)
+         return param_tensor  # Base case / no-op
+
+     def _process_moe_weights(self, loaded_name, loaded_weight, weights_dict):
+         layer_num = re.search(r"layers\.(\d+)", loaded_name).group(1)
+         expert_num_str = re.search(r"experts\.(\d+)", loaded_name).group(1)
+         expert_idx = int(expert_num_str)
+
+         if layer_num not in weights_dict:
+             weights_dict[layer_num] = ([None] * self.num_routed_experts, 0)
+
+         expert_list, count = weights_dict[layer_num]
+
+         expert_list[expert_idx] = loaded_weight
+         count += 1
+         weights_dict[layer_num] = (expert_list, count)
+
+         if count == self.num_routed_experts:
+             stacked_weights = torch.stack(expert_list, axis=0)
+             del weights_dict[layer_num]
+             return stacked_weights
+         return None
+
+     def _load_individual_weight(self,
+                                 name,
+                                 weight,
+                                 model_params,
+                                 model_mesh,
+                                 scale=None) -> Tuple[int, int]:
+         """
+         Loads a single weight into the model.
+
+         NOTE: if using the base quantized model, it is assumed that the Qwix abstract
+         pass has been run and that the model weights are thus QArrays, which we
+         will then load the weights/scales into.
+
+         Args:
+             name: The name of the weight.
+             weight: The weight to load.
+             model_params: The model parameters.
+             model_mesh: The model mesh.
+             scale: The scale of the weight (if using the pre-quantized model).
+
+         Returns:
+             Tuple[int, int]: The size (in GB) for the given layer overall and per shard.
+             NOTE: if using the pre-quantized model (with Qwix), we'll include the scale size as well.
+         """
+         mapped_name = self.map_loaded_to_standardized_name(name)
+         base_model_weight = get_param(model_params, mapped_name)
+         model_weight = base_model_weight.array.qvalue if hasattr(
+             base_model_weight, "array") else base_model_weight
+         sharding = base_model_weight.array.qvalue.sharding if hasattr(
+             base_model_weight, "array") else base_model_weight.sharding
+
+         # Convert weights from torch into numpy
+         cast_type = model_weight.value.dtype
+
+         torch_view_type = DTYPE_VIEW_MAP.get(jnp.dtype(cast_type))
+
+         if torch_view_type:
+             # Avoid unnecessary upcasting and mem copy by viewing the tensor's
+             # raw data as integers before converting to a JAX array.
+             weight_np = jnp.array(
+                 weight.view(torch_view_type).numpy()).view(cast_type)
+         else:
+             raise ValueError(
+                 f"Unsupported dtype for tensor conversion: {cast_type}")
+
+         if scale is not None:
+             scale = scale.to(torch.float32).numpy().astype(self.scale_dtype)
+
+         # Reshape and transpose weights if necessary.
+         weight_np = reshape_params(name, weight_np, self._weight_shape_map)
+         if scale is not None:
+             scale = reshape_params(name, scale, self._scale_shape_map)
+         weight_np = self._transpose_params(name, weight_np)
+         if scale is not None:
+             scale = self._transpose_params(name, scale)
+             weight_shape = weight_np.shape
+             scale_shape = scale.shape
+             assert len(weight_shape) == len(scale_shape)
+             for idx, (weight_dim,
+                       scale_dim) in enumerate(zip(weight_shape, scale_shape)):
+                 if weight_dim // self.quantization_block_size_n != scale_dim and weight_dim // scale_dim != 1:
+                     old_scale_shape = scale.shape
+                     scale = scale.repeat(self.quantization_block_size_n,
+                                          axis=idx)[:, :weight_dim]
+                     logger.warning(
+                         f"Got a weight with shape {weight_shape} and scale with shape {old_scale_shape} "
+                         f"where the scale_dim {scale_dim} does not match the weight_dim {weight_dim} "
+                         f"multiplied by the quantization block size {self.quantization_block_size_n}. "
+                         f"Repeating the scale to new shape {scale.shape} along axis {idx} with repeat size {self.quantization_block_size_n}."
+                     )
+                     break
+
+         if model_weight.value.shape != weight_np.shape:
+             raise ValueError(
+                 f"Loaded shape for {name}: {weight_np.shape} "
+                 f"does not match model shape for {mapped_name}: {model_weight.value.shape}!"
+             )
+
+         def get_slice(index):
+             return weight_np[index]
+
+         def get_slice_scale(index):
+             # ruff: noqa: F821
+             return scale[index]
+
+         sharded_array = jax.make_array_from_callback(
+             weight_np.shape, NamedSharding(model_mesh, P(*sharding)),
+             get_slice)
+
+         if scale is not None:
+             maybe_sharded_scale = scale
+             # Since, by default, we'll use the same sharding as the weights, we might
+             # encounter an error where the smaller/different sharding dim isn't divisible
+             # by the parallel size
+             # NOTE: Qwix expert dangyi@ mentioned that we don't need to worry about accuracy
+             # impacts when sharing scales
+             try:
+                 maybe_sharded_scale = jax.make_array_from_callback(
+                     scale.shape, NamedSharding(model_mesh, P(*sharding)),
+                     get_slice_scale)
+             except ValueError:
+                 logger.warning(
+                     f"Could not create sharded scale for {name} with shape {scale.shape} and sharding {sharding}, skipping sharding..."
+                 )
+             # NOTE: Despite the fact that scale has the name `scale_inv` in it, we don't need to
+             # inverse it
+             assert base_model_weight.array.scale.value.dtype == maybe_sharded_scale.dtype, f"Expected dtype for model weight scale with name {mapped_name} and dtype ({base_model_weight.array.scale.value.dtype}) to match that of the incoming weight scale ({maybe_sharded_scale.dtype})"
+             assert base_model_weight.array.qvalue.value.dtype == sharded_array.dtype, f"Expected dtype for model weight with name {mapped_name} and dtype ({base_model_weight.array.qvalue.value.dtype}) to match that of the incoming weight ({sharded_array.dtype})"
+             base_model_weight.array.scale.value = maybe_sharded_scale
+             base_model_weight.array.qvalue.value = sharded_array
+         else:
+             assert model_weight.value.dtype == sharded_array.dtype, f"Expected dtype for model weight with name {mapped_name} and dtype ({model_weight.value.dtype}) to match that of the incoming weight ({sharded_array.dtype})"
+             model_weight.value = sharded_array
+
+         model_weight_size_bytes = model_weight.nbytes / 1e9
+         model_weight_local_size_bytes = model_weight.addressable_shards[
+             0].data.nbytes / 1e9
+
+         if scale is not None:
+             model_weight_size_bytes += base_model_weight.array.scale.nbytes / 1e9
+             model_weight_local_size_bytes += base_model_weight.array.scale.addressable_shards[
+                 0].data.nbytes / 1e9
+
+         if self.is_verbose:
+             logger.info(f"Memory usage after loading in {name}: "
+                         f"hbm={utils.hbm_usage_gb(jax.local_devices())}Gb")
+             print_param_info(model_weight, name)
+             if scale is not None:
+                 print_param_info(base_model_weight.array.scale,
+                                  "scale for " + name)
+
+         del weight, scale
+         return model_weight_size_bytes, model_weight_local_size_bytes
+
+     def load_weights(self, model_for_loading: nnx.Module):
+         model_params = nnx.state(model_for_loading)
+         logger.warning(
+             f"loaded_to_standardized_keys: {self._loaded_to_standardized_keys}"
+         )
+         cumulative_global_memory = 0
+         cumulative_local_memory = 0
+         mlp_experts_gate_proj_weights = {}
+         mlp_experts_gate_proj_scales = {}
+         mlp_experts_up_proj_weights = {}
+         mlp_experts_up_proj_scales = {}
+         mlp_experts_down_proj_weights = {}
+         mlp_experts_down_proj_scales = {}
+         quantized_weights = {}
+         quantized_scales = {}
+         with jax.default_device(jax.devices("cpu")[0]):
+             for loaded_name, loaded_weight in self.names_and_weights_generator:
+                 # Skip if the model has fewer layers than original.
+                 if re.search(r"layers\.(\d+)", loaded_name):
+                     layer_num = re.search(r"layers\.(\d+)",
+                                           loaded_name).group(1)
+                     if int(layer_num) >= self.num_layers:
+                         del loaded_weight
+                         continue
+                 if 'layers.61' in loaded_name:
+                     # skip loading MTP module.
+                     del loaded_weight
+                     continue
+                 if re.search(r"experts\.(\d+)", loaded_name):
+                     expert_num = re.search(r"experts\.(\d+)",
+                                            loaded_name).group(1)
+                     if int(expert_num) >= self.num_routed_experts:
+                         del loaded_weight
+                         continue
+                 # NOTE: loaded_weight will be a Torch tensor, so we need to convert it to the
+                 # equivalent jnp dtype
+                 # TODO (jacobplatin): refactor this so that we instead change / update `model_weights_generator`
+                 # instead of checking "weight_scale_inv" and assuming quantization method is fp8
+                 scale = None
+                 if loaded_weight.dtype == j2t_dtype(self.quant_dtype.dtype):
+                     if self.is_model_quantized:
+                         scale_name = loaded_name.replace(
+                             ".weight", ".weight_scale_inv")
+                         if scale_name in quantized_scales:
+                             scale = quantized_scales[scale_name]
+                             del quantized_scales[scale_name]
+                         else:
+                             quantized_weights[loaded_name] = loaded_weight
+                             continue
+                     else:
+                         quantized_weights[loaded_name] = loaded_weight
+                         continue
+
+                 if loaded_name.endswith(".weight_scale_inv"):
+                     if self.is_model_quantized:
+                         weight_name = loaded_name.replace(
+                             ".weight_scale_inv", ".weight")
+                         if weight_name in quantized_weights:
+                             scale = loaded_weight
+                             loaded_weight = quantized_weights[weight_name]
+                             loaded_name = weight_name
+                             del quantized_weights[weight_name]
+                         else:
+                             quantized_scales[loaded_name] = loaded_weight
+                             continue
+                     # In the case that we don't want to use the quantized weights,
+                     # we'll dequantize the weights using the loaded scale on-the-fly
+                     else:
+                         # assuming weights are loaded before scales.
+                         weight_name = loaded_name.replace(
+                             ".weight_scale_inv", ".weight")
+                         loaded_weight = weights_dequant_cpu(
+                             quantized_weights[weight_name], loaded_weight,
+                             self.model_dtype)
+                         loaded_name = weight_name
+                         del quantized_weights[weight_name]
+                 # concat mlp.experts weights
+                 stacked_scales = None
+                 stacked_weights = None
+                 if "mlp.experts" in loaded_name:
+                     if "down_proj" in loaded_name:
+                         stacked_weights = self._process_moe_weights(
+                             loaded_name, loaded_weight,
+                             mlp_experts_down_proj_weights)
+                         if scale is not None:
+                             stacked_scales = self._process_moe_weights(
+                                 loaded_name, scale,
+                                 mlp_experts_down_proj_scales)
+                     if "gate_proj" in loaded_name:
+                         stacked_weights = self._process_moe_weights(
+                             loaded_name, loaded_weight,
+                             mlp_experts_gate_proj_weights)
+                         if scale is not None:
+                             stacked_scales = self._process_moe_weights(
+                                 loaded_name, scale,
+                                 mlp_experts_gate_proj_scales)
+                     if "up_proj" in loaded_name:
+                         stacked_weights = self._process_moe_weights(
+                             loaded_name, loaded_weight,
+                             mlp_experts_up_proj_weights)
+                         if scale is not None:
+                             stacked_scales = self._process_moe_weights(
+                                 loaded_name, scale, mlp_experts_up_proj_scales)
+                     if stacked_weights is not None:
+                         weight_bytes, weight_shards = self._load_individual_weight(
+                             loaded_name,
+                             stacked_weights,
+                             model_params,
+                             model_for_loading.mesh,
+                             scale=stacked_scales)
+                         if self.is_verbose:
+                             cumulative_global_memory += weight_bytes
+                             cumulative_local_memory += weight_shards
+                             logger.info(
+                                 f"Cumulative global memory: {cumulative_global_memory} GB"
+                             )
+                             logger.info(
+                                 f"Cumulative local memory: {cumulative_local_memory} GB"
+                             )
+                 else:
+                     weight_bytes, weight_shards = self._load_individual_weight(
+                         loaded_name,
+                         loaded_weight,
+                         model_params,
+                         model_for_loading.mesh,
+                         scale=scale)
+                     if self.is_verbose:
+                         cumulative_global_memory += weight_bytes
+                         cumulative_local_memory += weight_shards
+                         logger.info(
+                             f"Cumulative global memory: {cumulative_global_memory} GB"
+                         )
+                         logger.info(
+                             f"Cumulative local memory: {cumulative_local_memory} GB"
+                         )
+
+         del mlp_experts_gate_proj_weights
+         del mlp_experts_up_proj_weights
+         del mlp_experts_down_proj_weights
+         del quantized_weights
+         del quantized_scales
+         # TODO: validate that all of the model_params were accounted for as well.
+         nnx.update(model_for_loading, model_params)
+
+
+ def weights_dequant_cpu(x: torch.Tensor,
+                         s: torch.Tensor,
+                         output_dtype: jnp.dtype,
+                         block_size: int = 128) -> torch.Tensor:
+     assert x.dim() == 2 and s.dim() == 2, "Both x and s must be 2D tensors"
+     M, N = x.shape
+
+     x = x.to(torch.float32)
+     s = s.to(torch.float32)
+     y = torch.empty_like(x)
+
+     M_main = (M // block_size) * block_size
+     N_main = (N // block_size) * block_size
+
+     if M_main > 0 and N_main > 0:
+         x_main = x[:M_main, :N_main]
+         s_main = s[:(M // block_size), :(N // block_size)]
+
+         x_reshaped = x_main.view(M // block_size, block_size, N // block_size,
+                                  block_size).permute(0, 2, 1, 3)
+         s_reshaped = s_main.view(M // block_size, N // block_size, 1, 1)
+         y_main = (x_reshaped * s_reshaped).permute(0, 2, 1,
+                                                    3).reshape(M_main, N_main)
+
+         y[:M_main, :N_main] = y_main
+
+     if N_main < N:
+         for i in range(0, M_main, block_size):
+             block = x[i:i + block_size, N_main:N]
+             scale = s[i // block_size, N // block_size]
+             y[i:i + block_size, N_main:N] = block * scale
+
+     if M_main < M:
+         for j in range(0, N, block_size):
+             block = x[M_main:M, j:j + block_size]
+             scale = s[M // block_size, j // block_size]
+             y[M_main:M, j:j + block_size] = block * scale
+
+     return y.to(j2t_dtype(jnp.dtype(output_dtype)))
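
Note on the dtype-view conversion used above: _load_individual_weight moves each PyTorch tensor into JAX without an intermediate upcast by viewing its raw bits as the integer type listed in DTYPE_VIEW_MAP and reinterpreting them on the JAX side. A minimal standalone sketch of the same idea (not part of the packaged file; it assumes a PyTorch build that exposes torch.uint16, which the loader itself relies on):

    import jax.numpy as jnp
    import torch

    # Reinterpret the tensor's raw bfloat16 bits as uint16, hand those bits to
    # JAX, then view them back as bfloat16 -- no float32 upcast of the values.
    t = torch.randn(4, 8, dtype=torch.bfloat16)
    bits = t.view(torch.uint16).numpy()
    j = jnp.asarray(bits).view(jnp.bfloat16)
    assert j.dtype == jnp.bfloat16 and j.shape == (4, 8)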
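
Note on weights_dequant_cpu: it performs block-wise dequantization, multiplying every element by the single scale of the 128x128 tile it falls in, i.e. y[i, j] = x[i, j] * s[i // block_size, j // block_size], with ragged edge blocks handled by the two trailing loops. The following is an illustrative reference of the same arithmetic (a sketch for comparison only, not part of the wheel; it returns float32 rather than applying the final dtype cast):

    import torch

    def blockwise_dequant_reference(x: torch.Tensor, s: torch.Tensor,
                                    block_size: int = 128) -> torch.Tensor:
        # Broadcast each scale over its block_size x block_size tile, then trim
        # the expanded scale grid to x's (possibly ragged) shape.
        m, n = x.shape
        s_full = s.repeat_interleave(block_size, dim=0)[:m]
        s_full = s_full.repeat_interleave(block_size, dim=1)[:, :n]
        return x.to(torch.float32) * s_full.to(torch.float32)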