tpu-inference 0.12.0.dev20251213__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (248)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +14 -0
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +14 -0
  31. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  32. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_test.py +14 -0
  35. tests/layers/__init__.py +13 -0
  36. tests/layers/common/__init__.py +13 -0
  37. tests/layers/common/test_attention_interface.py +156 -0
  38. tests/layers/common/test_quantization.py +149 -0
  39. tests/layers/jax/__init__.py +13 -0
  40. tests/layers/jax/attention/__init__.py +13 -0
  41. tests/layers/jax/attention/test_common_attention.py +103 -0
  42. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  43. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  44. tests/layers/jax/moe/__init__.py +13 -0
  45. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  46. tests/layers/jax/sample/__init__.py +13 -0
  47. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  48. tests/layers/jax/sample/test_sampling.py +115 -0
  49. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  50. tests/layers/jax/test_layers.py +155 -0
  51. tests/{test_quantization.py → layers/jax/test_qwix.py} +180 -50
  52. tests/layers/jax/test_rope.py +93 -0
  53. tests/layers/jax/test_sharding.py +159 -0
  54. tests/layers/jax/test_transformer_block.py +152 -0
  55. tests/layers/vllm/__init__.py +13 -0
  56. tests/layers/vllm/test_attention.py +363 -0
  57. tests/layers/vllm/test_awq.py +406 -0
  58. tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
  59. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
  61. tests/layers/vllm/test_fp8.py +17 -0
  62. tests/layers/vllm/test_mxfp4.py +320 -0
  63. tests/layers/vllm/test_unquantized.py +662 -0
  64. tests/layers/vllm/utils.py +87 -0
  65. tests/lora/__init__.py +13 -0
  66. tests/lora/conftest.py +14 -0
  67. tests/lora/test_bgmv.py +14 -0
  68. tests/lora/test_layers.py +25 -8
  69. tests/lora/test_lora.py +15 -1
  70. tests/lora/test_lora_perf.py +14 -0
  71. tests/models/__init__.py +13 -0
  72. tests/models/common/__init__.py +13 -0
  73. tests/models/common/test_model_loader.py +455 -0
  74. tests/models/jax/__init__.py +13 -0
  75. tests/models/jax/test_deepseek_v3.py +401 -0
  76. tests/models/jax/test_llama3.py +184 -0
  77. tests/models/jax/test_llama4.py +298 -0
  78. tests/models/jax/test_llama_eagle3.py +197 -0
  79. tests/models/jax/test_llama_guard_4.py +242 -0
  80. tests/models/jax/test_qwen2.py +172 -0
  81. tests/models/jax/test_qwen2_5_vl.py +605 -0
  82. tests/models/jax/test_qwen3.py +169 -0
  83. tests/models/jax/test_weight_loading.py +180 -0
  84. tests/models/jax/utils/__init__.py +13 -0
  85. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  86. tests/platforms/__init__.py +13 -0
  87. tests/platforms/test_tpu_platform.py +54 -0
  88. tests/runner/__init__.py +13 -0
  89. tests/runner/test_block_table.py +395 -0
  90. tests/runner/test_input_batch.py +226 -0
  91. tests/runner/test_kv_cache.py +220 -0
  92. tests/runner/test_kv_cache_manager.py +498 -0
  93. tests/runner/test_multimodal_manager.py +429 -0
  94. tests/runner/test_persistent_batch_manager.py +84 -0
  95. tests/runner/test_speculative_decoding_manager.py +368 -0
  96. tests/runner/test_structured_decoding_manager.py +220 -0
  97. tests/runner/test_tpu_runner.py +261 -0
  98. tests/runner/test_tpu_runner_dp.py +1099 -0
  99. tests/runner/test_tpu_runner_mesh.py +200 -0
  100. tests/runner/test_utils.py +411 -0
  101. tests/spec_decode/__init__.py +13 -0
  102. tests/spec_decode/test_eagle3.py +311 -0
  103. tests/test_base.py +14 -0
  104. tests/test_tpu_info.py +14 -0
  105. tests/test_utils.py +1 -43
  106. tests/worker/__init__.py +13 -0
  107. tests/worker/tpu_worker_test.py +414 -0
  108. tpu_inference/__init__.py +14 -0
  109. tpu_inference/core/__init__.py +13 -0
  110. tpu_inference/core/sched/__init__.py +13 -0
  111. tpu_inference/core/sched/dp_scheduler.py +372 -56
  112. tpu_inference/distributed/__init__.py +13 -0
  113. tpu_inference/distributed/jax_parallel_state.py +14 -0
  114. tpu_inference/distributed/tpu_connector.py +14 -9
  115. tpu_inference/distributed/utils.py +56 -4
  116. tpu_inference/executors/__init__.py +13 -0
  117. tpu_inference/executors/ray_distributed_executor.py +20 -3
  118. tpu_inference/experimental/__init__.py +13 -0
  119. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  120. tpu_inference/kernels/__init__.py +13 -0
  121. tpu_inference/kernels/collectives/__init__.py +13 -0
  122. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  123. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  124. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  125. tpu_inference/kernels/fused_moe/v1/kernel.py +171 -163
  126. tpu_inference/kernels/megablox/__init__.py +13 -0
  127. tpu_inference/kernels/megablox/common.py +54 -0
  128. tpu_inference/kernels/megablox/gmm.py +646 -0
  129. tpu_inference/kernels/mla/__init__.py +13 -0
  130. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  131. tpu_inference/kernels/mla/v1/kernel.py +20 -26
  132. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  133. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  134. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  135. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  136. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +112 -69
  137. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +85 -65
  138. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  139. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +374 -194
  140. tpu_inference/kernels/ragged_paged_attention/v3/util.py +13 -0
  141. tpu_inference/layers/__init__.py +13 -0
  142. tpu_inference/layers/common/__init__.py +13 -0
  143. tpu_inference/layers/common/attention_interface.py +26 -19
  144. tpu_inference/layers/common/attention_metadata.py +14 -0
  145. tpu_inference/layers/common/fused_moe_gmm.py +506 -0
  146. tpu_inference/layers/common/quant_methods.py +15 -0
  147. tpu_inference/layers/common/quantization.py +282 -0
  148. tpu_inference/layers/common/sharding.py +22 -3
  149. tpu_inference/layers/common/utils.py +94 -0
  150. tpu_inference/layers/jax/__init__.py +13 -0
  151. tpu_inference/layers/jax/attention/__init__.py +13 -0
  152. tpu_inference/layers/jax/attention/attention.py +19 -6
  153. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +52 -27
  154. tpu_inference/layers/jax/attention/gpt_oss_attention.py +19 -6
  155. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  156. tpu_inference/layers/jax/base.py +14 -0
  157. tpu_inference/layers/jax/constants.py +13 -0
  158. tpu_inference/layers/jax/layers.py +14 -0
  159. tpu_inference/layers/jax/misc.py +14 -0
  160. tpu_inference/layers/jax/moe/__init__.py +13 -0
  161. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  162. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  163. tpu_inference/layers/jax/moe/moe.py +43 -3
  164. tpu_inference/layers/jax/pp_utils.py +53 -0
  165. tpu_inference/layers/jax/rope.py +14 -0
  166. tpu_inference/layers/jax/rope_interface.py +14 -0
  167. tpu_inference/layers/jax/sample/__init__.py +13 -0
  168. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  169. tpu_inference/layers/jax/sample/sampling.py +15 -1
  170. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  171. tpu_inference/layers/jax/transformer_block.py +14 -0
  172. tpu_inference/layers/vllm/__init__.py +13 -0
  173. tpu_inference/layers/vllm/attention.py +4 -4
  174. tpu_inference/layers/vllm/fused_moe.py +100 -455
  175. tpu_inference/layers/vllm/linear.py +64 -0
  176. tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
  177. tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
  178. tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
  179. tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
  180. tpu_inference/layers/vllm/quantization/__init__.py +19 -3
  181. tpu_inference/layers/vllm/quantization/awq.py +96 -82
  182. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  183. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +19 -5
  184. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +119 -132
  185. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  186. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
  187. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
  188. tpu_inference/layers/vllm/quantization/{common.py → configs.py} +38 -26
  189. tpu_inference/layers/vllm/quantization/fp8.py +119 -0
  190. tpu_inference/layers/vllm/quantization/mxfp4.py +133 -220
  191. tpu_inference/layers/vllm/quantization/unquantized.py +154 -253
  192. tpu_inference/lora/__init__.py +13 -0
  193. tpu_inference/lora/torch_lora_ops.py +8 -13
  194. tpu_inference/models/__init__.py +13 -0
  195. tpu_inference/models/common/__init__.py +13 -0
  196. tpu_inference/models/common/model_loader.py +37 -16
  197. tpu_inference/models/jax/__init__.py +13 -0
  198. tpu_inference/models/jax/deepseek_v3.py +113 -124
  199. tpu_inference/models/jax/gpt_oss.py +23 -7
  200. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  201. tpu_inference/models/jax/llama3.py +99 -36
  202. tpu_inference/models/jax/llama4.py +14 -0
  203. tpu_inference/models/jax/llama_eagle3.py +14 -0
  204. tpu_inference/models/jax/llama_guard_4.py +15 -1
  205. tpu_inference/models/jax/qwen2.py +17 -2
  206. tpu_inference/models/jax/qwen2_5_vl.py +18 -4
  207. tpu_inference/models/jax/qwen3.py +17 -2
  208. tpu_inference/models/jax/utils/__init__.py +13 -0
  209. tpu_inference/models/jax/utils/file_utils.py +14 -0
  210. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  211. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  212. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +85 -24
  213. tpu_inference/models/jax/utils/weight_utils.py +32 -1
  214. tpu_inference/models/vllm/__init__.py +13 -0
  215. tpu_inference/models/vllm/vllm_model_wrapper.py +22 -4
  216. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  217. tpu_inference/platforms/__init__.py +14 -0
  218. tpu_inference/platforms/tpu_platform.py +27 -29
  219. tpu_inference/runner/__init__.py +13 -0
  220. tpu_inference/runner/compilation_manager.py +69 -35
  221. tpu_inference/runner/kv_cache.py +14 -0
  222. tpu_inference/runner/kv_cache_manager.py +15 -2
  223. tpu_inference/runner/lora_utils.py +16 -1
  224. tpu_inference/runner/multimodal_manager.py +16 -2
  225. tpu_inference/runner/persistent_batch_manager.py +14 -0
  226. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  227. tpu_inference/runner/structured_decoding_manager.py +14 -0
  228. tpu_inference/runner/tpu_runner.py +30 -10
  229. tpu_inference/spec_decode/__init__.py +13 -0
  230. tpu_inference/spec_decode/jax/__init__.py +13 -0
  231. tpu_inference/spec_decode/jax/eagle3.py +13 -0
  232. tpu_inference/tpu_info.py +14 -0
  233. tpu_inference/utils.py +31 -30
  234. tpu_inference/worker/__init__.py +13 -0
  235. tpu_inference/worker/tpu_worker.py +23 -7
  236. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +1 -1
  237. tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
  238. tpu_inference/layers/vllm/linear_common.py +0 -208
  239. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  240. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  241. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  242. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  243. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  244. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  245. tpu_inference-0.12.0.dev20251213.dist-info/RECORD +0 -175
  246. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
  247. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
  248. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import functools
 from typing import Any, Optional
 
@@ -15,9 +29,10 @@ from vllm.utils.func_utils import supports_kw
 from tpu_inference import envs
 from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.logger import init_logger
-from tpu_inference.models.jax.utils.quantization.quantization_utils import (
+from tpu_inference.models.jax.utils.qwix.qwix_utils import (
     apply_qwix_on_abstract_model, apply_qwix_quantization,
-    load_random_weights_into_qwix_abstract_model)
+    load_random_weights_into_qwix_abstract_model,
+    update_vllm_config_for_qwix_quantization)
 from tpu_inference.utils import to_jax_dtype, to_torch_dtype
 
 logger = init_logger(__name__)
@@ -218,6 +233,10 @@ def get_flax_model(
     model_dtype = to_jax_dtype(vllm_config.model_config.dtype)
     vllm_config.model_config.dtype = model_dtype
 
+    # Only perform qwix quantization if it is jax model.
+    if vllm_config.model_config:
+        update_vllm_config_for_qwix_quantization(vllm_config)
+
     if is_draft_model:
         model_class = _get_model_architecture(
             vllm_config.speculative_config.draft_model_config.hf_config)
@@ -269,10 +288,9 @@ def get_flax_model(
 
     # Multi-modal support only
     # This function calculates the image token's embeddings by VIT
-    def run_get_multimodal_embeddings(graphdef, state, image_grid_thw,
-                                      **kwargs):
+    def run_embed_multimodal(graphdef, state, image_grid_thw, **kwargs):
         model = nnx.merge(graphdef, state)
-        return model.get_multimodal_embeddings(image_grid_thw, **kwargs)
+        return model.embed_multimodal(image_grid_thw, **kwargs)
 
     embed_sharding = NamedSharding(mesh, PartitionSpec(None))
     # This function will calculates the embeddings of input texts and then merge with the image embeddings
@@ -280,9 +298,9 @@ def get_flax_model(
         jax.jit,
         out_shardings=(embed_sharding),
     )
-    def run_get_input_embeddings(graphdef, state, *args, **kwargs):
+    def run_embed_input_ids(graphdef, state, *args, **kwargs):
         model = nnx.merge(graphdef, state)
-        return model.get_input_embeddings(*args, **kwargs)
+        return model.embed_input_ids(*args, **kwargs)
 
     # For models that want to work with EAGLE-3 speculative decoding
     @functools.partial(
@@ -298,10 +316,8 @@ def get_flax_model(
             None)
     model_fn = functools.partial(run_model, graphdef)
     compute_logits_fn = functools.partial(run_compute_logits, graphdef)
-    get_multimodal_embeddings_fn = functools.partial(
-        run_get_multimodal_embeddings, graphdef)
-    get_input_embeddings_fn = functools.partial(run_get_input_embeddings,
-                                                graphdef)
+    embed_multimodal_fn = functools.partial(run_embed_multimodal, graphdef)
+    embed_input_ids_fn = functools.partial(run_embed_input_ids, graphdef)
     lora_manager, model = None, None
     combine_hidden_states_fn = functools.partial(combine_hidden_states,
                                                  graphdef)
@@ -312,8 +328,8 @@ def get_flax_model(
 
     multimodal_fns = {
        "precompile_vision_encoder_fn": precompile_vision_encoder_fn,
-        "get_multimodal_embeddings_fn": get_multimodal_embeddings_fn,
-        "get_input_embeddings_fn": get_input_embeddings_fn,
+        "embed_multimodal_fn": embed_multimodal_fn,
+        "embed_input_ids_fn": embed_input_ids_fn,
        "get_mrope_input_positions_fn": get_mrope_input_positions_fn,
     }
 
@@ -365,6 +381,11 @@ def get_model(
 
     match impl:
         case "flax_nnx":
+            if vllm_config.parallel_config.pipeline_parallel_size > 1:
+                logger.warning(
+                    "PP is not fully supported on Jax flax_nnx models yet, fallback to vllm models."
+                )
+                return get_vllm_model(vllm_config, rng, mesh)
            try:
                # Try to load the flax model first
                return get_flax_model(vllm_config, rng, mesh, is_draft_model)
@@ -466,14 +487,14 @@ def register_model(arch: str, model: Any) -> None:
        )
 
     # Same as `forward`, this is a dummy method to satisfy vLLM's type checks.
-    def unimplemented_get_input_embeddings(
+    def unimplemented_embed_input_ids(
         self,
         input_ids: "torch.Tensor",
         positions: "torch.Tensor",
         inputs_embeds: Optional["torch.Tensor"] = None,
     ) -> "torch.Tensor":
         raise NotImplementedError(
-            "This is a JAX model and does not implement the PyTorch get_input_embeddings method."
+            "This is a JAX model and does not implement the PyTorch embed_input_ids method."
         )
 
     # We need a custom __init__ that only calls torch.nn.Module's init,
@@ -489,7 +510,7 @@ def register_model(arch: str, model: Any) -> None:
        {
            "__init__": wrapper_init,
            "forward": unimplemented_forward,
-            "get_input_embeddings": unimplemented_get_input_embeddings,
+            "embed_input_ids": unimplemented_embed_input_ids,
            # Prevent vLLM from trying to load weights into this dummy class.
            "load_weights": lambda self, *args, **kwargs: None,
        })
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import os
 import re
 from dataclasses import dataclass
@@ -14,6 +28,7 @@ from torchax.ops.mappings import j2t_dtype
 from vllm.config import VllmConfig
 
 from tpu_inference import utils
+from tpu_inference.layers.common.quantization import u8_unpack_e2m1
 from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.layers.jax.attention.attention import AttentionMetadata
 from tpu_inference.layers.jax.attention.deepseek_v3_attention import MLA
@@ -25,10 +40,8 @@ from tpu_inference.layers.jax.moe.moe import MoE
 from tpu_inference.layers.jax.transformer_block import (
     SharedExpertsTransformerBlock, TransformerBlock)
 from tpu_inference.logger import init_logger
-from tpu_inference.models.jax.utils.quantization.quantization_utils import \
-    get_quant_dtype_from_qwix_config
 from tpu_inference.models.jax.utils.weight_utils import (
-    get_param, model_weights_generator, print_param_info, reshape_params)
+    get_param, model_weights_generator, print_param_info)
 
 logger = init_logger(__name__)
 
@@ -73,6 +86,8 @@ class DeepSeekV3(nnx.Module):
         first_k_dense_replace: int = 3  # replace the first few MOE layers to dense layer.
         self.use_mla_kernel: bool = self.vllm_config.model_config.use_mla
 
+        logger.info(f"Is using MLA kernel in DeepSeek: {self.use_mla_kernel}")
+
         num_shared_experts = 1
         rope_theta = 10000
         rope_scaling = {
@@ -169,9 +184,10 @@ class DeepSeekV3(nnx.Module):
             activation_attention_out_td=(None, None),
             attn_o_tnh=attn_o_tnh_spec,
             q_da_sharding=(None, ShardingAxisName.VOCAB),
+            ap_sharding=(None, ShardingAxisName.MLP_TENSOR),
             anh_sharding=(None, ShardingAxisName.MLP_TENSOR, None),
             kv_da_sharding=(None, ShardingAxisName.VOCAB),
-            nhd_sharding=(ShardingAxisName.MLP_TENSOR, None, None))
+            rd_sharding=(ShardingAxisName.MLP_TENSOR, None))
 
         for i in range(first_k_dense_replace):
             block = TransformerBlock(
@@ -422,12 +438,12 @@ class DeepSeekV3WeightLoader:
            r"mlp\.up_proj": (1, 0),
            # mla
            r"q_a_proj": (1, 0),
-            r"q_b_proj": (2, 0, 1),
+            r"q_b_proj": (1, 0),
            r"kv_a_proj_with_mqa": (1, 0),
-            r"kv_b_proj": (2, 0, 1),
+            r"kv_b_proj": (1, 0),
            r"k_b_proj": (2, 0, 1),  # used for MLA kernel
            r"v_b_proj": (2, 0, 1),  # used for MLA kernel
-            r"o_proj": (1, 2, 0),
+            r"o_proj": (1, 0),
            # moe
            r"mlp\.gate\.weight": (1, 0),
            r"mlp\.experts\.\d+\.gate_proj": (0, 2, 1),
@@ -439,15 +455,6 @@ class DeepSeekV3WeightLoader:
            # lm_head
            r"lm_head\.weight": (1, 0)
        }
-        self._weight_shape_map = {
-            "q_b_proj":
-            (attn_heads, qk_nope_head_dim + qk_rope_head_dim, q_lora_rank),
-            "kv_b_proj":
-            (attn_heads, qk_nope_head_dim + v_head_dim, kv_lora_rank),
-            "k_b_proj": (attn_heads, qk_nope_head_dim, kv_lora_rank),
-            "v_b_proj": (attn_heads, v_head_dim, kv_lora_rank),
-            "o_proj": (hidden_size, attn_heads, v_head_dim)
-        }
 
        # Set the mappings from loaded parameter keys to standardized names.
        self._loaded_to_standardized_keys = {
@@ -472,13 +479,13 @@ class DeepSeekV3WeightLoader:
            "model.layers.*.self_attn.q_a_proj.weight":
            "layers.*.attn.kernel_q_down_proj_DA",
            "model.layers.*.self_attn.q_b_proj.weight":
-            "layers.*.attn.kernel_q_up_proj_ANH",
+            "layers.*.attn.kernel_q_up_proj_AP",
            "model.layers.*.self_attn.kv_a_proj_with_mqa.weight":
            "layers.*.attn.kernel_kv_down_proj_DA",
            "model.layers.*.self_attn.kv_b_proj.weight":
-            "layers.*.attn.kernel_kv_up_proj_ANH",
+            "layers.*.attn.kernel_kv_up_proj_AL",
            "model.layers.*.self_attn.o_proj.weight":
-            "layers.*.attn.kernel_o_proj_NHD",
+            "layers.*.attn.kernel_o_proj_RD",
            # Dense ffw
            "model.layers.*.mlp.gate_proj.weight":
            "layers.*.custom_module.kernel_gating_DF",
@@ -512,66 +519,43 @@ class DeepSeekV3WeightLoader:
                "model.layers.*.self_attn.v_b_proj.weight":
                "layers.*.attn.kernel_v_up_proj_ANH",
            })
-
-        # TODO (jacobplatin): we shouldn't hard-code this, but the logic to obtain the true quantized dtype
-        # is non-trivial and the default checkpoints all use this dtype
-        self.quant_dtype = jnp.float8_e4m3fn
+        # TODO (jacobplatin): we should not be hard-coding these
+        self.scale_dtype, self.quant_dtype = jnp.bfloat16, jnp.float8_e4m3fn
 
        self.is_model_quantized = not vllm_config.additional_config.get(
            "skip_quantization", False)
-        if self.is_model_quantized:
-            # TODO (jacobplatin): expand support eventually
-            quantization_type = vllm_config.model_config.hf_config.quantization_config[
-                "quant_method"]
-            assert quantization_type == "fp8", "DeepSeek only supports the fp8 quantization method for now"
-            self.scale_dtype, self.quant_dtype = get_quant_dtype_from_qwix_config(
-                vllm_config)
-
-            logger.info(
-                f"Quantizing DeepSeek with quantization dtype: {self.quant_dtype} and scale dtype: {self.scale_dtype}"
-            )
 
-            quantization_block_sizes = vllm_config.model_config.hf_config.quantization_config[
-                "weight_block_size"]
-            assert len(
-                quantization_block_sizes
-            ) == 2, f"Expected only 2 quantization block sizes but got {quantization_block_sizes}"
-            self.quantization_block_size_n = quantization_block_sizes[0]
-            self.quantization_block_size_k = quantization_block_sizes[1]
-            # TODO (jacobplatin): remove this check in the future
-            assert self.quantization_block_size_n == self.quantization_block_size_k, "Quantization block size n and k must be the same!"
-            # NOTE: this is only needed for pre-quantized models
-            self._scale_shape_map = {
-                "q_b_proj": (1, qk_nope_head_dim + qk_rope_head_dim,
-                             q_lora_rank // self.quantization_block_size_n),
-                "kv_b_proj": (attn_heads, (qk_nope_head_dim + v_head_dim) //
-                              self.quantization_block_size_n,
-                              kv_lora_rank // self.quantization_block_size_n),
-                # used for MLA kernel
-                "k_b_proj":
-                (attn_heads,
-                 qk_nope_head_dim // self.quantization_block_size_n,
-                 kv_lora_rank // self.quantization_block_size_n),
-                # used for MLA kernel
-                "v_b_proj":
-                (attn_heads, v_head_dim // self.quantization_block_size_n,
-                 kv_lora_rank // self.quantization_block_size_n),
-                "o_proj":
-                (hidden_size // self.quantization_block_size_n, attn_heads,
-                 v_head_dim // self.quantization_block_size_n),
-            }
+        if self.is_model_quantized:
            # NOTE: this is only needed for pre-quantized models when doing random weight loading
+            # because the scales that Qwix configures by default don't necessarily match the
+            # scales in practice
            # TODO (jacobplatin): remove or clean this up
-            self.scale_shap_map_for_random_weight_loading = {
-                "kernel_kv_down_proj_DA": (56, 576),
-                "kernel_kv_up_proj_ANH": (4, 128, 2),
-                "kernel_q_up_proj_ANH": (12, 1, 192),
-                "kernel_o_proj_NHD": (128, 1, 56),
-                "kernel_down_proj_EFD": (256, 16, 56),
-                "kernel_up_proj_EDF": (256, 56, 16),
-                "kernel_gating_EDF": (256, 56, 16),
+            self.scale_shape_map_for_random_weight_loading = {
+                # MoE experts (3D)
+                "custom_module.kernel_down_proj_EFD": (256, 8, 7168),
+                "custom_module.kernel_gating_EDF": (256, 28, 2048),
+                "custom_module.kernel_up_proj_EDF": (256, 28, 2048),
+                # Shared experts (2D)
+                "shared_experts.kernel_down_proj_FD": (8, 7168),
+                "shared_experts.kernel_gating_DF": (28, 2048),
+                "shared_experts.kernel_up_proj_DF": (28, 2048),
+                # Dense FFW (2D)
+                "custom_module.kernel_gating_DF": (28, 18432),
+                "custom_module.kernel_up_proj_DF": (28, 18432),
+                "custom_module.kernel_down_proj_FD": (72, 7168),
+                # Attention (3D for MLA, 2D for the rest)
+                "attn.kernel_q_down_proj_DA": (28, 1536),
+                "attn.kernel_q_up_proj_AP": (6, 24576),
+                "attn.kernel_kv_down_proj_DA": (28, 576),
+                "attn.kernel_kv_up_proj_AL": (2, 32768),
+                "attn.kernel_o_proj_RD": (64, 7168),
+                "attn.kernel_k_up_proj_ANH": (2, 128, 128),  # MLA
+                "attn.kernel_v_up_proj_ANH": (2, 128, 128),  # MLA
            }
 
+        # TODO (jacobplatin): remove this check eventually!
+        assert self.quant_dtype == jnp.float8_e4m3fn, f"Expected quant_dtype to be float8_e4m3fn for DeepSeek but got {self.quant_dtype}"
+
    def map_loaded_to_standardized_name(self, loaded_key: str) -> str:
        # Find the corresponding model key using the HF key
        if "layer" in loaded_key:
@@ -649,45 +633,56 @@ class DeepSeekV3WeightLoader:
            base_model_weight, "array") else base_model_weight.sharding
 
        # Convert weights from torch into numpy
-        cast_type = model_weight.value.dtype
-
-        torch_view_type = DTYPE_VIEW_MAP.get(jnp.dtype(cast_type))
-
-        if torch_view_type:
-            # Avoid unnecessary upcasting and mem copy by viewing the tensor's
-            # raw data as integers before converting to a JAX array.
-            weight_np = jnp.array(
-                weight.view(torch_view_type).numpy()).view(cast_type)
+        if weight.dtype == torch.uint8 and scale is not None:
+            # Assume packed FP4 format when uint8 weights with scale provided
+            weight_jax_u8 = jnp.array(weight.cpu().numpy())
+            weight_np = u8_unpack_e2m1(weight_jax_u8)
+            scale = scale.to(torch.float32).numpy().astype(self.scale_dtype)
        else:
-            raise ValueError(
-                f"Unsupported dtype for tensor conversion: {cast_type}")
+            cast_type = model_weight.value.dtype
+            # Special-case: FP4 values stored as FP8 for compatibility.
+            # If the model expects float4_e2m1fn but the checkpoint provides FP8,
+            # convert by numeric value (float32) then cast to float4.
+            if cast_type == jnp.float4_e2m1fn and weight.dtype == torch.float8_e4m3fn:
+                weight_np = jnp.array(weight.float().numpy()).astype(cast_type)
+            else:
+                torch_view_type = DTYPE_VIEW_MAP.get(jnp.dtype(cast_type))
 
-        if scale is not None:
-            scale = scale.to(torch.float32).numpy().astype(self.scale_dtype)
+                if torch_view_type:
+                    # Avoid unnecessary upcasting and mem copy by viewing the tensor's
+                    # raw data as integers before converting to a JAX array.
+                    weight_np = jnp.array(
+                        weight.view(torch_view_type).numpy()).view(cast_type)
+                else:
+                    raise ValueError(
+                        f"Unsupported dtype for tensor conversion: {cast_type}"
+                    )
 
-        # Reshape and transpose weights if necessary.
-        weight_np = reshape_params(name, weight_np, self._weight_shape_map)
-        if scale is not None:
-            scale = reshape_params(name, scale, self._scale_shape_map)
+            if scale is not None:
+                scale = scale.to(torch.float32).numpy().astype(
+                    self.scale_dtype)
        weight_np = self._transpose_params(name, weight_np)
        if scale is not None:
            scale = self._transpose_params(name, scale)
+            # Ensure scale is broadcastable to weight_np by repeating per-axis.
            weight_shape = weight_np.shape
            scale_shape = scale.shape
-            assert len(weight_shape) == len(scale_shape)
-            for idx, (weight_dim,
-                      scale_dim) in enumerate(zip(weight_shape, scale_shape)):
-                if weight_dim // self.quantization_block_size_n != scale_dim and weight_dim // scale_dim != 1:
-                    old_scale_shape = scale.shape
-                    scale = scale.repeat(self.quantization_block_size_n,
-                                         axis=idx)[:, :weight_dim]
+            if len(weight_shape) == len(scale_shape):
+                new_scale = scale
+                for wdim, sdim in zip(weight_shape, scale_shape):
+                    if (wdim % sdim != 0):
+                        raise ValueError(
+                            f"Weight dim {wdim} is not divisible by scale dim {sdim} for weight {name} with shape {weight_shape} and scale {scale_shape}!"
+                        )
+                if scale_shape != new_scale.shape:
                    logger.warning(
-                        f"Got a weight with shape {weight_shape} and scale with shape {old_scale_shape} "
-                        f"where the scale_dim {scale_dim} does not match the weight_dim {weight_dim} "
-                        f"multiplied by the quantization block size {self.quantization_block_size_n}. "
-                        f"Repeating the scale to new shape {scale.shape} along axis {idx} with repeat size {self.quantization_block_size_n}."
+                        f"Adjusted scale shape {scale_shape} to {new_scale.shape} to match weight {weight_shape}"
                    )
-                    break
+                scale = new_scale
+            else:
+                raise ValueError(
+                    f"Scale rank {scale_shape} does not match weight rank {weight_shape}"
+                )
 
        if model_weight.value.shape != weight_np.shape:
            raise ValueError(
@@ -721,10 +716,8 @@ class DeepSeekV3WeightLoader:
                logger.warning(
                    f"Could not create sharded scale for {name} with shape {scale.shape} and sharding {sharding}, skipping sharding..."
                )
-            # NOTE: Despite the fact that scale has the name `scale_inv` in it, we don't need to
-            # inverse it
-            assert base_model_weight.array.scale.value.dtype == maybe_sharded_scale.dtype, "Expected dtype for model weight scale with name {mapped_name} and dtype ({base_model_weight.array.scale.value.dtype}) to match that of the incoming weight scale ({maybe_sharded_scale.dtype})"
-            assert base_model_weight.array.qvalue.value.dtype == sharded_array.dtype, "Expected dtype for model weight with name {mapped_name} and dtype ({base_model_weight.array.qvalue.value.dtype}) to match that of the incoming weight ({sharded_array.dtype})"
+            assert base_model_weight.array.scale.value.dtype == maybe_sharded_scale.dtype, f"Expected dtype for model weight scale with name {mapped_name} and dtype ({base_model_weight.array.scale.value.dtype}) to match that of the incoming weight scale ({maybe_sharded_scale.dtype})"
+            assert base_model_weight.array.qvalue.value.dtype == sharded_array.dtype, f"Expected dtype for model weight with name {mapped_name} and dtype ({base_model_weight.array.qvalue.value.dtype}) to match that of the incoming weight ({sharded_array.dtype})"
            base_model_weight.array.scale.value = maybe_sharded_scale
            base_model_weight.array.qvalue.value = sharded_array
        else:
@@ -790,7 +783,11 @@ class DeepSeekV3WeightLoader:
            # TODO (jacobplatin): refactor this so that we instead change / update `model_weights_generator`
            # instead of checking "weight_scale_inv" and assuming quantization method is fp8
            scale = None
-            if loaded_weight.dtype == j2t_dtype(self.quant_dtype.dtype):
+            # Mixed quantization: accept both fp8 and packed fp4 (uint8) tensors
+            allowed_quant_dtypes = {
+                j2t_dtype(self.quant_dtype.dtype), torch.uint8
+            }
+            if loaded_weight.dtype in allowed_quant_dtypes:
                if self.is_model_quantized:
                    scale_name = loaded_name.replace(
                        ".weight", ".weight_scale_inv")
@@ -880,11 +877,9 @@ class DeepSeekV3WeightLoader:
                            self.qk_nope_head_dim + self.v_head_dim,
                            self.kv_lora_rank)
                        k_weight = weight_reshaped[:, :self.
-                                                   qk_nope_head_dim, :].reshape(
-                                                       -1, self.kv_lora_rank)
-                        v_weight = weight_reshaped[:, self.
-                                                   qk_nope_head_dim:, :].reshape(
-                                                       -1, self.kv_lora_rank)
+                                                   qk_nope_head_dim, :]
+                        v_weight = weight_reshaped[:,
+                                                   self.qk_nope_head_dim:, :]
 
                        loaded_weights_list = [k_weight, v_weight]
                        loaded_names = [
@@ -894,25 +889,19 @@ class DeepSeekV3WeightLoader:
 
                        scales_list = [None, None]
                        if scale is not None:
-                            bn = self.quantization_block_size_n
-                            bk = self.quantization_block_size_k
+                            assert loaded_weight.shape[0] == scale.shape[0]
+                            block_size_k = loaded_weight.shape[
+                                1] // scale.shape[1]
+                            assert block_size_k > 0, f"Expected non-zero block size but got {block_size_k}!"
                            scale_reshaped = scale.view(
                                self.attn_heads,
-                                (self.qk_nope_head_dim + self.v_head_dim) //
-                                bn, self.kv_lora_rank // bk)
+                                (self.qk_nope_head_dim + self.v_head_dim),
+                                self.kv_lora_rank // block_size_k)
 
                            k_scale = scale_reshaped[:, :self.
-                                                     qk_nope_head_dim //
-                                                     bn, :].reshape(
-                                                         -1,
-                                                         self.kv_lora_rank //
-                                                         bk)
+                                                     qk_nope_head_dim, :]
                            v_scale = scale_reshaped[:,
-                                                     self.qk_nope_head_dim //
-                                                     bn:, :].reshape(
-                                                         -1,
-                                                         self.kv_lora_rank //
-                                                         bk)
+                                                     self.qk_nope_head_dim:, :]
                            scales_list = [k_scale, v_scale]
 
                        else:
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import re
 from dataclasses import dataclass
 from typing import List, Optional, Tuple
@@ -11,6 +25,9 @@ from jax.sharding import Mesh, NamedSharding
 from jax.sharding import PartitionSpec as P
 from vllm.config import VllmConfig
 
+from tpu_inference.layers.common.quant_methods import MXFP4
+from tpu_inference.layers.common.quantization import (
+    dequantize_tensor_from_mxfp4_packed, e8m0_to_fp32, u8_unpack_e2m1)
 from tpu_inference.layers.jax.attention.gpt_oss_attention import (
     AttentionMetadata, GptOssAttention)
 from tpu_inference.layers.jax.constants import KVCacheType
@@ -18,8 +35,6 @@ from tpu_inference.layers.jax.layers import Embedder, LMhead, RMSNorm
 from tpu_inference.layers.jax.moe.gpt_oss_moe import GptOssMoE, GptOssRouter
 from tpu_inference.layers.jax.transformer_block import TransformerBlock
 from tpu_inference.logger import init_logger
-from tpu_inference.models.jax.utils.quantization.mxfp4_utils import (
-    MXFP4_QUANT_METHOD, dequant_mxfp4_to_bf16, unpack_mxfp4_to_fp32)
 from tpu_inference.models.jax.utils.weight_utils import (
     get_param, model_weights_generator, print_param_info)
 
@@ -205,7 +220,7 @@ class GptOss(nnx.Module):
 
        # MXFP4 checkpoints swap last two dims for MoE to place packed dim at most minor
        swap_mlp_transform = transforms[
-            "swap_last2"] if quant_method == MXFP4_QUANT_METHOD else None
+            "swap_last2"] if quant_method == MXFP4 else None
 
        mappings = {
            # Embeddings, Norms, and LM Head
@@ -285,7 +300,7 @@ class GptOss(nnx.Module):
        # Build a pool of weights with MXFP4 experts combined if neededs
        pool: dict[str, torch.Tensor | tuple] = (self._build_mxfp4_pool(
            names_and_weights_generator,
-            mappings) if quant_method == MXFP4_QUANT_METHOD else {
+            mappings) if quant_method == MXFP4 else {
                loaded_name: loaded_weight
                for loaded_name, loaded_weight in names_and_weights_generator
            })
@@ -316,8 +331,9 @@ class GptOss(nnx.Module):
                blocks_u8, scales_u8 = loaded_weight
                # Quantized param (QArray): set qvalue/scale directly and skip regular path
                if hasattr(model_weight, "array"):  # QArray check
-                    codes_fp32_t, scales_fp32_t = unpack_mxfp4_to_fp32(
-                        blocks_u8, scales_u8)
+                    codes_fp32_t = u8_unpack_e2m1(blocks_u8).astype(
+                        jnp.float32)
+                    scales_fp32_t = e8m0_to_fp32(scales_u8)
                    self._load_mxfp4(
                        model_weight=model_weight,
                        codes_fp32_t=codes_fp32_t,
@@ -328,7 +344,7 @@ class GptOss(nnx.Module):
                    print_param_info(model_weight, loaded_name)
                    continue
                # Not a QArray: dequantize MXFP4 to BF16 full weights
-                prepared_weight = dequant_mxfp4_to_bf16(
+                prepared_weight = dequantize_tensor_from_mxfp4_packed(
                    blocks_u8, scales_u8)
 
                # Single regular-tensor load call (BF16 or dequantized MXFP4)
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Dict, Union