tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (257)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +317 -34
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +406 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +320 -0
  64. tests/layers/vllm/test_unquantized.py +662 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +26 -6
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +110 -12
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +2 -45
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +15 -10
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +92 -8
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +25 -4
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +807 -230
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +218 -137
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +25 -12
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/fused_moe_gmm.py +506 -0
  154. tpu_inference/layers/common/quant_methods.py +15 -0
  155. tpu_inference/layers/common/quantization.py +282 -0
  156. tpu_inference/layers/common/sharding.py +32 -9
  157. tpu_inference/layers/common/utils.py +94 -0
  158. tpu_inference/layers/jax/__init__.py +13 -0
  159. tpu_inference/layers/jax/attention/__init__.py +13 -0
  160. tpu_inference/layers/jax/attention/attention.py +19 -6
  161. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  162. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  163. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  164. tpu_inference/layers/jax/base.py +14 -0
  165. tpu_inference/layers/jax/constants.py +13 -0
  166. tpu_inference/layers/jax/layers.py +14 -0
  167. tpu_inference/layers/jax/misc.py +14 -0
  168. tpu_inference/layers/jax/moe/__init__.py +13 -0
  169. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  170. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  171. tpu_inference/layers/jax/moe/moe.py +43 -3
  172. tpu_inference/layers/jax/pp_utils.py +53 -0
  173. tpu_inference/layers/jax/rope.py +14 -0
  174. tpu_inference/layers/jax/rope_interface.py +14 -0
  175. tpu_inference/layers/jax/sample/__init__.py +13 -0
  176. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  177. tpu_inference/layers/jax/sample/sampling.py +15 -1
  178. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  179. tpu_inference/layers/jax/transformer_block.py +14 -0
  180. tpu_inference/layers/vllm/__init__.py +13 -0
  181. tpu_inference/layers/vllm/attention.py +4 -4
  182. tpu_inference/layers/vllm/fused_moe.py +101 -494
  183. tpu_inference/layers/vllm/linear.py +64 -0
  184. tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
  185. tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
  186. tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
  187. tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
  188. tpu_inference/layers/vllm/quantization/__init__.py +19 -3
  189. tpu_inference/layers/vllm/quantization/awq.py +96 -82
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  191. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +23 -8
  192. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +172 -176
  193. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  194. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
  195. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
  196. tpu_inference/layers/vllm/quantization/{common.py → configs.py} +42 -25
  197. tpu_inference/layers/vllm/quantization/fp8.py +119 -0
  198. tpu_inference/layers/vllm/quantization/mxfp4.py +137 -178
  199. tpu_inference/layers/vllm/quantization/unquantized.py +157 -233
  200. tpu_inference/lora/__init__.py +13 -0
  201. tpu_inference/lora/torch_lora_ops.py +8 -13
  202. tpu_inference/models/__init__.py +13 -0
  203. tpu_inference/models/common/__init__.py +13 -0
  204. tpu_inference/models/common/model_loader.py +112 -35
  205. tpu_inference/models/jax/__init__.py +13 -0
  206. tpu_inference/models/jax/deepseek_v3.py +267 -157
  207. tpu_inference/models/jax/gpt_oss.py +26 -10
  208. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  209. tpu_inference/models/jax/llama3.py +99 -36
  210. tpu_inference/models/jax/llama4.py +14 -0
  211. tpu_inference/models/jax/llama_eagle3.py +18 -5
  212. tpu_inference/models/jax/llama_guard_4.py +15 -1
  213. tpu_inference/models/jax/qwen2.py +17 -2
  214. tpu_inference/models/jax/qwen2_5_vl.py +179 -51
  215. tpu_inference/models/jax/qwen3.py +17 -2
  216. tpu_inference/models/jax/utils/__init__.py +13 -0
  217. tpu_inference/models/jax/utils/file_utils.py +14 -0
  218. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  219. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  220. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +92 -32
  221. tpu_inference/models/jax/utils/weight_utils.py +234 -155
  222. tpu_inference/models/vllm/__init__.py +13 -0
  223. tpu_inference/models/vllm/vllm_model_wrapper.py +32 -8
  224. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  225. tpu_inference/platforms/__init__.py +14 -0
  226. tpu_inference/platforms/tpu_platform.py +51 -72
  227. tpu_inference/runner/__init__.py +13 -0
  228. tpu_inference/runner/compilation_manager.py +180 -80
  229. tpu_inference/runner/kv_cache.py +54 -20
  230. tpu_inference/runner/kv_cache_manager.py +55 -33
  231. tpu_inference/runner/lora_utils.py +16 -1
  232. tpu_inference/runner/multimodal_manager.py +16 -2
  233. tpu_inference/runner/persistent_batch_manager.py +54 -2
  234. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  235. tpu_inference/runner/structured_decoding_manager.py +16 -3
  236. tpu_inference/runner/tpu_runner.py +124 -61
  237. tpu_inference/runner/utils.py +2 -2
  238. tpu_inference/spec_decode/__init__.py +13 -0
  239. tpu_inference/spec_decode/jax/__init__.py +13 -0
  240. tpu_inference/spec_decode/jax/eagle3.py +84 -22
  241. tpu_inference/tpu_info.py +14 -0
  242. tpu_inference/utils.py +72 -44
  243. tpu_inference/worker/__init__.py +13 -0
  244. tpu_inference/worker/tpu_worker.py +66 -52
  245. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +8 -9
  246. tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
  247. tpu_inference/layers/vllm/linear_common.py +0 -186
  248. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  249. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  250. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  251. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  252. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  253. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  254. tpu_inference-0.11.1.dev202511220812.dist-info/RECORD +0 -174
  255. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
  256. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
  257. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import math
 from functools import partial
 from typing import (Callable, List, Literal, NamedTuple, Optional, TypedDict,
@@ -486,6 +500,11 @@ class Qwen2_5_VisionTransformer(nnx.Module):
             dtype=dtype,
             rngs=rngs)
 
+        additional_config = getattr(vllm_config, "additional_config",
+                                    None) or {}
+        self.enable_dynamic_image_sizes = additional_config.get(
+            "enable_dynamic_image_sizes", False)
+
     def rotary_pos_emb_thw(self, t, h, w):
         hpos_ids, wpos_ids = jnp.indices((h, w))
         hpos_ids = hpos_ids.reshape(
@@ -579,21 +598,7 @@ class Qwen2_5_VisionTransformer(nnx.Module):
         seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
         return max_seqlen, seqlens
 
-    def __call__(self, x: jax.Array, grid_thw: tuple[tuple[int, int,
-                                                           int]]) -> jax.Array:
-        # x: pixel_values: jax.Array
-        # """Shape:
-        # `(num_patches, num_channels * patch_size * patch_size)`
-        # """
-
-        # grid_thw: image_grid_thw: jax.Array
-        # """Shape: `(num_images, 3)`
-        # This should be in `(grid_t, grid_h, grid_w)` format.
-        # """
-        hidden_states = self.patch_embed(x)
-
-        # num of patches
-        seq_len = x.shape[0]
+    def compute_aux_arrays(self, grid_thw: tuple[tuple[int, int, int]]):
         # num of images/videoes
         num_grids = len(grid_thw)
 
@@ -638,6 +643,42 @@ class Qwen2_5_VisionTransformer(nnx.Module):
         cu_seqlens = jnp.pad(cu_seqlens, ((1, 0), ),
                              mode='constant',
                              constant_values=0)
+        return window_index, rotary_pos_emb, cu_seqlens, cu_window_seqlens
+
+    def pad_inputs(self, x, window_index, rotary_pos_emb, cu_seqlens,
+                   cu_window_seqlens):
+        # padding
+        num_patches = int(rotary_pos_emb.shape[0])
+        bucket_num_patches = 1 << (num_patches - 1).bit_length()
+        num_tokens = window_index.shape[0]
+        bucket_num_tokens = bucket_num_patches // self.spatial_merge_unit
+        vit_merger_window_size = (self.window_size //
+                                  self.spatial_merge_size // self.patch_size)
+        max_windows = (bucket_num_tokens // vit_merger_window_size) + 2
+
+        rotary_pos_emb = jnp.pad(rotary_pos_emb,
+                                 ((0, bucket_num_patches - num_patches),
+                                  (0, 0)))
+        window_index = jnp.concatenate([
+            window_index,
+            jnp.arange(num_tokens, bucket_num_tokens, dtype=jnp.int32)
+        ])
+        cu_window_seqlens = jnp.append(cu_window_seqlens, bucket_num_patches)
+        pad_w = max(0, max_windows + 1 - cu_window_seqlens.shape[0])
+        cu_window_seqlens = jnp.pad(cu_window_seqlens, (0, pad_w), mode='edge')
+        cu_seqlens = jnp.append(cu_seqlens, bucket_num_patches)
+
+        x_padded = jnp.pad(x, ((0, bucket_num_patches - x.shape[0]), (0, 0)))
+
+        return x_padded, window_index, rotary_pos_emb, cu_seqlens, cu_window_seqlens, num_tokens
+
+    def compute_hidden_states(self, x: jax.Array, window_index: jax.Array,
+                              rotary_pos_emb: jax.Array, cu_seqlens: jax.Array,
+                              cu_window_seqlens: jax.Array) -> jax.Array:
+        hidden_states = self.patch_embed(x)
+
+        # num of patches
+        seq_len = x.shape[0]
 
         hidden_states = hidden_states.reshape(
             seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
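
The `pad_inputs` helper above buckets the patch count up to the next power of two, so the jitted encoder only ever compiles for a bounded set of shapes rather than once per image size. A minimal standalone sketch of the bucketing rule (the helper name here is hypothetical):

    # Round a patch count up to the next power of two, as pad_inputs does.
    def next_power_of_two(num_patches: int) -> int:
        # (n - 1).bit_length() is the exponent of the smallest power of two >= n.
        return 1 << (num_patches - 1).bit_length()

    assert next_power_of_two(1000) == 1024  # 1000 patches pad up to the 1024 bucket
    assert next_power_of_two(1024) == 1024  # exact powers of two are kept as-is
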
@@ -664,6 +705,48 @@ class Qwen2_5_VisionTransformer(nnx.Module):
         hidden_states = hidden_states[reverse_indices, :]
         return hidden_states
 
+    @jax.jit
+    def encode_padded_jit(self, x_padded, window_index, rotary_pos_emb,
+                          cu_seqlens, cu_window_seqlens):
+        return self.compute_hidden_states(x_padded, window_index,
+                                          rotary_pos_emb, cu_seqlens,
+                                          cu_window_seqlens)
+
+    @partial(
+        jax.jit,
+        static_argnames=("grid_thw", ),
+    )
+    def encode_jit(self, x, grid_thw):
+        window_index, rotary_pos_emb, cu_seqlens, cu_window_seqlens = self.compute_aux_arrays(
+            grid_thw)
+        return self.compute_hidden_states(x, window_index, rotary_pos_emb,
+                                          cu_seqlens, cu_window_seqlens)
+
+    def __call__(self, x: jax.Array, grid_thw: tuple[tuple[int, int,
+                                                           int]]) -> jax.Array:
+        # x: pixel_values: jax.Array
+        # """Shape:
+        # `(num_patches, num_channels * patch_size * patch_size)`
+        # """
+
+        # grid_thw: image_grid_thw: jax.Array
+        # """Shape: `(num_images, 3)`
+        # This should be in `(grid_t, grid_h, grid_w)` format.
+        # """
+        if self.enable_dynamic_image_sizes:
+            window_index, rotary_pos_emb, cu_seqlens, cu_window_seqlens = self.compute_aux_arrays(
+                grid_thw)
+            x_padded, window_index, rotary_pos_emb, cu_seqlens, cu_window_seqlens, num_tokens = self.pad_inputs(
+                x, window_index, rotary_pos_emb, cu_seqlens, cu_window_seqlens)
+
+            hidden_states = self.encode_padded_jit(x_padded, window_index,
+                                                   rotary_pos_emb, cu_seqlens,
+                                                   cu_window_seqlens)
+            return hidden_states[:num_tokens]
+
+        else:
+            return self.encode_jit(x, grid_thw)
+
 
 
 class Qwen2_5_VLForConditionalGeneration(nnx.Module):
@@ -888,10 +971,6 @@ class Qwen2_5_VLForConditionalGeneration(nnx.Module):
         # "video"] = self._parse_and_validate_video_input(**kwargs)
         return mm_input_by_modality
 
-    @partial(
-        jax.jit,
-        static_argnames=("image_grid_thw", ),
-    )
     def get_single_image_embedding(self, image_pixel_values, image_grid_thw):
         return self.visual(image_pixel_values, (image_grid_thw, ))
 
@@ -931,9 +1010,9 @@ class Qwen2_5_VLForConditionalGeneration(nnx.Module):
         split_indices = np.cumsum(sizes)[:-1]
         return tuple(jnp.split(image_embeds, split_indices))
 
-    def get_multimodal_embeddings(self, image_grid_thw: tuple[tuple[int, int,
-                                                                    int], ...],
-                                  **kwargs: object) -> MultiModalEmbeddings:
+    def embed_multimodal(self, image_grid_thw: tuple[tuple[int, int, int],
+                                                     ...],
+                         **kwargs: object) -> MultiModalEmbeddings:
 
         mm_input_by_modality = self._parse_and_validate_multimodal_inputs(
             image_grid_thw, **kwargs)
@@ -957,7 +1036,7 @@ class Qwen2_5_VLForConditionalGeneration(nnx.Module):
 
         return multimodal_embeddings
 
-    def get_input_embeddings(
+    def embed_input_ids(
             self, input_ids: jax.Array,
             multimodal_embeddings: Optional[jax.Array]) -> jax.Array:
 
@@ -1072,33 +1151,82 @@ class Qwen2_5_VLForConditionalGeneration(nnx.Module):
         self,
         run_compilation_fn: Callable,
     ) -> None:
-        image_shapes = []
-        if (warmup_config := self.vllm_config.additional_config.get(
-                "vision_warmup_config")):
-            image_shapes = warmup_config.get("image_shapes")
-
         vc = self.vllm_config.model_config.hf_config.vision_config
-        factor = vc.patch_size * vc.spatial_merge_size
-        for input_hw in image_shapes:
-            if not isinstance(input_hw, list) or len(input_hw) != 2:
-                logger.warning(f"Skipping invalid shape {input_hw}.")
-                continue
-            h_input, w_input = input_hw
-            h_processed = round(h_input / factor) * factor
-            w_processed = round(w_input / factor) * factor
-            t, h, w = 1, h_processed // vc.patch_size, w_processed // vc.patch_size
-            grid_thw = (t, h, w)
-            num_patches = t * h * w
-            patch_input_dim = vc.in_channels * vc.temporal_patch_size * vc.patch_size * vc.patch_size
-
-            dummy_pixel_values = jnp.ones(
-                (num_patches, patch_input_dim),
-                self.vllm_config.model_config.dtype,
-            )
-            dummy_grid_thw = grid_thw
+        patch_input_dim = vc.in_channels * vc.temporal_patch_size * vc.patch_size * vc.patch_size
+        if self.visual.enable_dynamic_image_sizes:
+            spatial_merge_unit = vc.spatial_merge_size**2
+            max_num_batched_tokens = self.vllm_config.scheduler_config.max_num_batched_tokens
+            mm_kwargs = self.vllm_config.model_config.multimodal_config.mm_processor_kwargs or {}
+            limit_pixels = float(mm_kwargs.get("max_pixels", float('inf')))
+
+            max_patches = int(
+                min(max_num_batched_tokens * spatial_merge_unit,
+                    limit_pixels / (vc.patch_size**2)))
+
+            num_patches_paddings = [
+                1 << i for i in range(4, (max_patches - 1).bit_length() + 1)
+            ]
+            rotary_dim = vc.hidden_size // vc.num_heads // 2
+            vit_merger_window_size = (vc.window_size //
+                                      vc.spatial_merge_size // vc.patch_size)
+
+            for num_patches in num_patches_paddings:
+                dummy_x_padded = jnp.ones(
+                    (num_patches, patch_input_dim),
+                    dtype=self.vllm_config.model_config.dtype)
+
+                num_tokens = num_patches // spatial_merge_unit
+                dummy_window_index = jnp.arange(num_tokens, dtype=jnp.int32)
+
+                dummy_rotary_pos_emb = jnp.ones(
+                    (num_patches, rotary_dim),
+                    dtype=self.vllm_config.model_config.dtype)
+
+                dummy_cu_seqlens = jnp.array([0, num_patches, num_patches],
+                                             dtype=jnp.int32)
+
+                max_windows = (num_tokens // vit_merger_window_size) + 2
+                patches_per_window = (vit_merger_window_size**
+                                      2) * spatial_merge_unit
+                dummy_cu_window_seqlens = jnp.arange(
+                    max_windows + 1, dtype=jnp.int32) * patches_per_window
+                dummy_cu_window_seqlens = jnp.minimum(dummy_cu_window_seqlens,
+                                                      num_patches)
+
+                run_compilation_fn("vision_encoder_padded",
+                                   self.visual.encode_padded_jit,
+                                   dummy_x_padded,
+                                   dummy_window_index,
+                                   dummy_rotary_pos_emb,
+                                   dummy_cu_seqlens,
+                                   dummy_cu_window_seqlens,
+                                   num_patches=num_patches)
+        else:
+            image_shapes = []
+            if (warmup_config := self.vllm_config.additional_config.get(
+                    "vision_warmup_config")):
+                image_shapes = warmup_config.get("image_shapes")
+
+            factor = vc.patch_size * vc.spatial_merge_size
+            for input_hw in image_shapes:
+                if not isinstance(input_hw, list) or len(input_hw) != 2:
+                    logger.warning(f"Skipping invalid shape {input_hw}.")
+                    continue
+                h_input, w_input = input_hw
+                h_processed = round(h_input / factor) * factor
+                w_processed = round(w_input / factor) * factor
+                t, h, w = 1, h_processed // vc.patch_size, w_processed // vc.patch_size
+                grid_thw = (t, h, w)
+                num_patches = t * h * w
+
+                dummy_pixel_values = jnp.ones(
+                    (num_patches, patch_input_dim),
+                    self.vllm_config.model_config.dtype,
+                )
+                dummy_grid_thw = (grid_thw, )
 
-            run_compilation_fn("single_image_encoder",
-                               self.get_single_image_embedding,
-                               dummy_pixel_values,
-                               dummy_grid_thw,
-                               image_shape=input_hw)
+                run_compilation_fn("vision_encoder",
+                                   self.visual.encode_jit,
+                                   dummy_pixel_values,
+                                   dummy_grid_thw,
+                                   image_shape=input_hw)
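
The dynamic-size warmup precompiles one encoder graph per power-of-two bucket, from 16 patches (`1 << 4`) up to the first bucket that covers `max_patches`. A worked example of the bucket list, mirroring the expression in the diff:

    max_patches = 5000
    buckets = [1 << i for i in range(4, (max_patches - 1).bit_length() + 1)]
    # (5000 - 1).bit_length() == 13, so buckets == [16, 32, 64, ..., 8192];
    # 8192 is also the bucket pad_inputs would pick for a 5000-patch image.
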
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from typing import List, Optional, Tuple
 
 import jax
@@ -10,6 +24,7 @@ from vllm.config import VllmConfig
 from tpu_inference import utils
 from tpu_inference.layers.common.attention_interface import attention
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.common.quantization import quantize_kv
 from tpu_inference.layers.jax.rope_interface import apply_rope
 from tpu_inference.logger import init_logger
 from tpu_inference.models.jax.qwen2 import Qwen2DecoderLayer
@@ -125,8 +140,8 @@ class Qwen3Attention(nnx.Module):
         # q_scale = self._q_scale
         k_scale = self._k_scale
         v_scale = self._v_scale
-        k, v = utils.quantize_kv(k, v, self.kv_cache_quantized_dtype,
-                                 k_scale, v_scale)
+        k, v = quantize_kv(self.kv_cache_quantized_dtype, k, v, k_scale,
+                           v_scale)
         new_kv_cache, outputs = attention(
             kv_cache,
             q,
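
`quantize_kv` has moved from `tpu_inference.utils` to `tpu_inference.layers.common.quantization`, with the target dtype promoted to the first argument. A call-site migration sketch, using the variable names from the diff:

    from tpu_inference.layers.common.quantization import quantize_kv

    # Old: k, v = utils.quantize_kv(k, v, kv_cache_quantized_dtype, k_scale, v_scale)
    k, v = quantize_kv(kv_cache_quantized_dtype, k, v, k_scale, v_scale)
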
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import glob
 import hashlib
 import os
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from typing import Union
 
 import jax
@@ -29,25 +43,25 @@ def sanity_check_mm_encoder_outputs(
 ) -> None:
     """
     Perform sanity checks for the result of
-    [`vllm.model_executor.models.SupportsMultiModal.get_multimodal_embeddings`][].
+    [`vllm.model_executor.models.SupportsMultiModal.embed_multimodal`][].
     """
     assert isinstance(mm_embeddings, (list, tuple, jax.Array)), (
         "Expected multimodal embeddings to be a list/tuple of 2D tensors, "
        f"or a single 3D tensor, but got {type(mm_embeddings)} "
         "instead. This is most likely due to incorrect implementation "
-        "of the model's `get_multimodal_embeddings` method.")
+        "of the model's `embed_multimodal` method.")
 
     assert len(mm_embeddings) == expected_num_items, (
         "Expected number of multimodal embeddings to match number of "
        f"input items: {expected_num_items}, but got {len(mm_embeddings)=} "
         "instead. This is most likely due to incorrect implementation "
-        "of the model's `get_multimodal_embeddings` method.")
+        "of the model's `embed_multimodal` method.")
 
     assert all(e.ndim == 2 for e in mm_embeddings), (
         "Expected multimodal embeddings to be a sequence of 2D tensors, "
        f"but got tensors with shapes {[e.shape for e in mm_embeddings]} "
         "instead. This is most likely due to incorrect implementation "
-        "of the model's `get_multimodal_embeddings` method.")
+        "of the model's `embed_multimodal` method.")
 
 
 def flatten_embeddings(embeddings: NestedTensors) -> jax.Array:
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import copy
 import functools
 import os
 from typing import TYPE_CHECKING, Callable, List
@@ -34,17 +35,43 @@ DEFAULT_NUM_TOKENS_FOR_MODEL_INPUTS = 512
 DEFAULT_MAX_NUM_SEQS_FOR_MODEL_INPUTS = 256
 DEFAULT_MAX_NUM_BLOCKS_PER_REQ = 16
 
-DEFAULT_DEEPSEEK_FP8_CONFIG = {
+DEFAULT_DEEPSEEK_FP4_MLP_MOE_FP8_ATTN_CONFIG = {
     "qwix": {
         "use_abstract_model":
         True,
         "scale_dtype":
         "bfloat16",
         "rules": [
+            # Exclude router from quantization
             {
                 "module_path": ".*.custom_module.router.*",
                 "weight_qtype": None,
             },
+            # Avoid the combine expert ops
+            {
+                "module_path": ".*combine_experts.*",
+                "weight_qtype": None,
+            },
+            # Attention layers: keep FP8 for weights and activations
+            {
+                "module_path": ".*.attn.*",
+                "weight_qtype": "float8_e4m3fn",
+                "act_qtype": "float8_e4m3fn",
+            },
+            # MoE experts: use FP4 for expert weights
+            {
+                "module_path": ".*.custom_module.*",
+                "weight_qtype": "float4_e2m1fn",
+                "act_qtype": "float8_e4m3fn",
+                "tile_size": 256,
+            },
+            # Shared experts: also FP4
+            {
+                "module_path": ".*.shared_experts.*",
+                "weight_qtype": "float4_e2m1fn",
+                "act_qtype": "float8_e4m3fn",
+                "tile_size": 256,
+            },
             {
                 "module_path": ".*",
                 "weight_qtype": "float8_e4m3fn",
@@ -154,12 +181,9 @@ def qwix_quantize_nnx_model(model: nnx.Module, qwix_config: List[dict],
     logger.info(f"Memory usage before applying quantization of params: "
                 f"hbm={utils.hbm_usage_gb(jax.local_devices())}Gb")
 
-    # TODO (jacobplatin): we should refactor this to pass a dtype (or config) directly
-    kv_cache_jnp_dtype = utils.get_jax_dtype_from_str_dtype(kv_cache_dtype)
-
-    # Handle the case where kv_cache_dtype is "auto"
-    if kv_cache_jnp_dtype is None:
-        assert kv_cache_dtype == "auto", "kv_cache_dtype must be 'auto' if kv_cache_jnp_dtype is None"
+    if kv_cache_dtype != "auto":
+        kv_cache_jnp_dtype = utils.to_jax_dtype(kv_cache_dtype)
+    else:
         kv_cache_jnp_dtype = DEFAULT_KV_CACHE_DTYPE
 
 
@@ -169,9 +193,11 @@
         head_size=kv_cache_head_size,
         mesh=mesh,
         layer_names=[f"layer.{i}" for i in range(num_hidden_layers)],
-        cache_dtype=kv_cache_jnp_dtype)
+        cache_dtype=kv_cache_jnp_dtype,
+        use_mla=model.vllm_config.model_config.use_mla,
+    )
 
-    dp_size = mesh.shape.get("data", 1) * mesh.shape.get("attn", 1)
+    dp_size = model.vllm_config.sharding_config.total_dp_size
 
     # NOTE: the inputs don't need to match the actual ones, as long as the consumed weights are the same
     input_ids = jax.random.randint(rng,
@@ -399,8 +425,7 @@ def apply_qwix_on_abstract_model(vllm_config: "VllmConfig") -> bool:
 
 
 def get_default_qwix_quantization_config(
-        model_type: str, quant_method: str,
-        skip_quantization: bool) -> dict | None:
+        hf_config: dict, skip_quantization: bool) -> dict | None:
     """
     Some models are pre-quantized and in those cases, we want to return a default set of
     Qwix quantization rules (instead of forcing the user to pass in a quantization config each time).
@@ -418,9 +443,42 @@
     """
     if skip_quantization:
         return None
-    # TODO (jacobplatin): remove this so that we can support various quantization types
+    model_type = hf_config.model_type.lower() if hasattr(
+        hf_config, "model_type") else None
+    quant_method = hf_config.quantization_config["quant_method"] if hasattr(
+        hf_config, "quantization_config") else None
+    # TODO (jacobplatin): remove this so that we can support various quantization types + make
+    # more flexible
+    # NOTE (jacobplatin): we'll default to mixed FP8 (attention) + FP4 (MoE experts)
+    # for DeepSeek
     if model_type == "deepseek_v3" and quant_method == "fp8":
-        return DEFAULT_DEEPSEEK_FP8_CONFIG
+        config = copy.deepcopy(DEFAULT_DEEPSEEK_FP4_MLP_MOE_FP8_ATTN_CONFIG)
+
+        # Dynamically fetch block size from HF config if available
+        # Config fmt: 'weight_block_size': [1, 512] -> we want the 2nd dim for tile_size
+        # NOTE: if the checkpoint is not 1D subchannel, we will throw an error
+        hf_quant_config = hf_config.quantization_config
+        assert "weight_block_size" in hf_quant_config, "Expected weight_block_size in quantization_config"
+        block_size = hf_quant_config["weight_block_size"]
+        if isinstance(block_size, (list, tuple)) and len(block_size) == 2:
+            assert block_size[
+                0] == 1, f"Expected first dimension to be 1 (unchanneled), but got {block_size[0]}! If you are trying to run quantized DeepSeek, we currently only support 1D-subchannel quantization and those models can be found here: https://huggingface.co/collections/jrplatin/deepseek-r1-1d-subchannel"
+            tile_size = block_size[1]
+            assert tile_size > 1, f"Expected tile_size > 1 for DeepSeek, but got {tile_size}"
+            logger.info(
+                f"Detected DeepSeek tile_size from config: {tile_size}")
+
+            # Update tile_size in the rules, since we might not always use a 1D subchannel size of
+            # 256
+            for rule in config["qwix"]["rules"]:
+                if "tile_size" in rule:
+                    rule["tile_size"] = tile_size
+        else:
+            raise ValueError(
+                f"Invalid weight_block_size config: {block_size}, expected a list/tuple of length 2"
+            )
+
+        return config
     elif model_type == "llama4" and quant_method == "compressed-tensors":
         return DEFAULT_LLAMA4_FP8_CONFIG
     # MXFP4 (GPT-OSS): provide a default configuration to quantize MoE experts via Qwix
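
The tile size is no longer pinned to 256; it is read from the checkpoint's `weight_block_size`, so a `[1, 512]` entry yields `tile_size = 512` in every rule that carries one. A standalone sketch of the same extraction (the config dict is a made-up example):

    # Hypothetical HF quantization_config with the keys the diff reads.
    hf_quant_config = {"quant_method": "fp8", "weight_block_size": [1, 512]}

    block_size = hf_quant_config["weight_block_size"]
    assert len(block_size) == 2 and block_size[0] == 1  # 1D-subchannel checkpoints only
    tile_size = block_size[1]  # 512, substituted into each rule that has a "tile_size"
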
@@ -439,14 +497,10 @@ def update_vllm_config_for_qwix_quantization(vllm_config: "VllmConfig"):
     # Qwix quantization config accordingly
     # NOTE: if a Qwix config is provided (via the`additional_config`), we'll
     # use that instead
-    model_type = vllm_config.model_config.hf_config.model_type.lower(
-    ) if hasattr(vllm_config.model_config.hf_config, "model_type") else None
-    quant_method = vllm_config.model_config.hf_config.quantization_config[
-        "quant_method"] if hasattr(vllm_config.model_config.hf_config,
-                                   "quantization_config") else None
+    hf_config = vllm_config.model_config.hf_config
     default_quantization_config = get_default_qwix_quantization_config(
-        model_type, quant_method,
-        vllm_config.additional_config.get("skip_quantization", False))
+        hf_config, vllm_config.additional_config.get("skip_quantization",
+                                                     False))
 
     maybe_existing_quantization_config = vllm_config.additional_config.get(
         "quantization")
@@ -503,7 +557,14 @@ def get_random_sharded_array(key: PRNGKey, mesh: Mesh, param: nnx.Param,
         maxval = jnp.array(jnp.iinfo(dtype).max, dtype=dtype)
         weight = jax.random.randint(key, param_shape, minval, maxval, dtype)
     else:
-        weight = jax.random.normal(key, param_shape, dtype)
+        # NOTE: _uniform() in random.py does not accept float4_e2m1fn
+        # Error: "TypeError: uniform only accepts 8-, 16-, 32-, or 64-bit dtypes, got float4_e2m1fn."
+        # Workaround: call function with dtype jnp.float8_e4m3fn and cast back to float4_e2m1fn
+        if dtype != "float4_e2m1fn":
+            weight = jax.random.normal(key, param_shape, dtype)
+        else:
+            weight = jax.random.normal(key, param_shape,
+                                       jnp.float8_e4m3fn).astype(dtype)
 
     def get_slice(index):
         return weight[index]
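
Because `jax.random.normal` rejects 4-bit dtypes, the new branch samples in `float8_e4m3fn` and casts down. A minimal standalone version of the workaround (assumes a JAX/ml_dtypes build that exposes `jnp.float4_e2m1fn`):

    import jax
    import jax.numpy as jnp

    key = jax.random.PRNGKey(0)
    # Sample in an 8-bit float dtype, then cast to the 4-bit target.
    w = jax.random.normal(key, (8, 8), jnp.float8_e4m3fn).astype(jnp.float4_e2m1fn)
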
@@ -538,18 +599,16 @@ def load_random_weights_into_qwix_abstract_model(rng: PRNGKey,
     logger.info("Initializing Qwix-quantized model with random weights...")
     # TODO (jacobplatin): clean up this logic
     scale_dtype = model.weight_loader.scale_dtype
-    scale_shape_map = model.weight_loader.scale_shap_map_for_random_weight_loading if hasattr(
+    scale_shape_map = model.weight_loader.scale_shape_map_for_random_weight_loading if hasattr(
         model.weight_loader,
-        'scale_shap_map_for_random_weight_loading') else {}
+        'scale_shape_map_for_random_weight_loading') else {}
     quantization_block_sizes = quantization_config["weight_block_size"]
     assert len(
         quantization_block_sizes
     ) == 2, f"Expected only 2 quantization block sizes but got {quantization_block_sizes}"
-    quantization_block_size_n, _ = quantization_block_sizes[
-        0], quantization_block_sizes[1]
 
     # Iterate through all variables and initialize them
-    prev_param_shape = None
+
     for path, param in nnx.iter_graph(model):
         if not isinstance(param, nnx.Variable):
             continue
@@ -559,16 +618,17 @@
         is_qwix_scale = (path[-1] == 'scale' and path[-2] == "array")
         param_dtype = scale_dtype if is_qwix_scale else param.value.dtype
         param_shape = param.value.shape
-        # TODO (jacobplatin): clean this up
         if is_qwix_scale:
-            param_shape = scale_shape_map.get(
-                path[3],
-                tuple(dim // quantization_block_size_n
-                      for dim in prev_param_shape))
+            key = f"{path[2]}.{path[3]}"
+
+            if key in scale_shape_map:
+                param_shape = scale_shape_map[key]
+            else:
+                raise ValueError(
+                    f"Scale shape for {key} not found in scale_shape_map.")
         param.value = get_random_sharded_array(
             rng, mesh, param, param_shape, param_dtype,
             ".".join([str(x) for x in path]))
-        prev_param_shape = param_shape
 
     # Handles the DeepSeek case, where this needs to be called to make the cache weights
     # concrete