tpu-inference 0.11.1.dev202511270815__py3-none-any.whl → 0.13.0rc2.post7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic. Click here for more details.

Files changed (251)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +14 -0
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +405 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +312 -0
  64. tests/layers/vllm/test_unquantized.py +651 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +21 -3
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +110 -12
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +2 -45
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +15 -10
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +92 -8
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +22 -1
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +167 -97
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +26 -19
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/quant_methods.py +15 -0
  154. tpu_inference/layers/common/quantization.py +270 -0
  155. tpu_inference/layers/common/sharding.py +31 -9
  156. tpu_inference/layers/jax/__init__.py +13 -0
  157. tpu_inference/layers/jax/attention/__init__.py +13 -0
  158. tpu_inference/layers/jax/attention/attention.py +19 -6
  159. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  160. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  161. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  162. tpu_inference/layers/jax/base.py +14 -0
  163. tpu_inference/layers/jax/constants.py +13 -0
  164. tpu_inference/layers/jax/layers.py +14 -0
  165. tpu_inference/layers/jax/misc.py +14 -0
  166. tpu_inference/layers/jax/moe/__init__.py +13 -0
  167. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  168. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  169. tpu_inference/layers/jax/moe/moe.py +43 -3
  170. tpu_inference/layers/jax/pp_utils.py +53 -0
  171. tpu_inference/layers/jax/rope.py +14 -0
  172. tpu_inference/layers/jax/rope_interface.py +14 -0
  173. tpu_inference/layers/jax/sample/__init__.py +13 -0
  174. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  175. tpu_inference/layers/jax/sample/sampling.py +15 -1
  176. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  177. tpu_inference/layers/jax/transformer_block.py +14 -0
  178. tpu_inference/layers/vllm/__init__.py +13 -0
  179. tpu_inference/layers/vllm/attention.py +4 -4
  180. tpu_inference/layers/vllm/fused_moe.py +210 -260
  181. tpu_inference/layers/vllm/linear_common.py +57 -22
  182. tpu_inference/layers/vllm/quantization/__init__.py +16 -0
  183. tpu_inference/layers/vllm/quantization/awq.py +15 -1
  184. tpu_inference/layers/vllm/quantization/common.py +33 -18
  185. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  186. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
  187. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
  188. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  189. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
  191. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  192. tpu_inference/layers/vllm/quantization/mxfp4.py +280 -210
  193. tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
  194. tpu_inference/layers/vllm/sharding.py +21 -4
  195. tpu_inference/lora/__init__.py +13 -0
  196. tpu_inference/lora/torch_lora_ops.py +8 -13
  197. tpu_inference/models/__init__.py +13 -0
  198. tpu_inference/models/common/__init__.py +13 -0
  199. tpu_inference/models/common/model_loader.py +77 -36
  200. tpu_inference/models/jax/__init__.py +13 -0
  201. tpu_inference/models/jax/deepseek_v3.py +267 -157
  202. tpu_inference/models/jax/gpt_oss.py +26 -10
  203. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  204. tpu_inference/models/jax/llama3.py +99 -36
  205. tpu_inference/models/jax/llama4.py +14 -0
  206. tpu_inference/models/jax/llama_eagle3.py +14 -0
  207. tpu_inference/models/jax/llama_guard_4.py +15 -1
  208. tpu_inference/models/jax/qwen2.py +17 -2
  209. tpu_inference/models/jax/qwen2_5_vl.py +18 -4
  210. tpu_inference/models/jax/qwen3.py +17 -2
  211. tpu_inference/models/jax/utils/__init__.py +13 -0
  212. tpu_inference/models/jax/utils/file_utils.py +14 -0
  213. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  214. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  215. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +91 -31
  216. tpu_inference/models/jax/utils/weight_utils.py +39 -2
  217. tpu_inference/models/vllm/__init__.py +13 -0
  218. tpu_inference/models/vllm/vllm_model_wrapper.py +20 -4
  219. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  220. tpu_inference/platforms/__init__.py +14 -0
  221. tpu_inference/platforms/tpu_platform.py +47 -71
  222. tpu_inference/runner/__init__.py +13 -0
  223. tpu_inference/runner/compilation_manager.py +158 -63
  224. tpu_inference/runner/kv_cache.py +54 -20
  225. tpu_inference/runner/kv_cache_manager.py +53 -30
  226. tpu_inference/runner/lora_utils.py +14 -0
  227. tpu_inference/runner/multimodal_manager.py +15 -1
  228. tpu_inference/runner/persistent_batch_manager.py +54 -2
  229. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  230. tpu_inference/runner/structured_decoding_manager.py +14 -0
  231. tpu_inference/runner/tpu_runner.py +105 -57
  232. tpu_inference/runner/utils.py +2 -2
  233. tpu_inference/spec_decode/__init__.py +13 -0
  234. tpu_inference/spec_decode/jax/__init__.py +13 -0
  235. tpu_inference/spec_decode/jax/eagle3.py +65 -19
  236. tpu_inference/tpu_info.py +14 -0
  237. tpu_inference/utils.py +72 -44
  238. tpu_inference/worker/__init__.py +13 -0
  239. tpu_inference/worker/tpu_worker.py +65 -52
  240. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/METADATA +11 -9
  241. tpu_inference-0.13.0rc2.post7.dist-info/RECORD +261 -0
  242. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  243. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  244. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  245. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  246. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  247. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  248. tpu_inference-0.11.1.dev202511270815.dist-info/RECORD +0 -174
  249. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/WHEEL +0 -0
  250. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/licenses/LICENSE +0 -0
  251. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,605 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from functools import partial
16
+ from unittest.mock import MagicMock, patch
17
+
18
+ import jax
19
+ import jax.numpy as jnp
20
+ import numpy as np
21
+ import pytest
22
+ from flax import nnx
23
+ from flax.typing import PRNGKey
24
+ from jax.sharding import Mesh
25
+ from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import \
26
+ Qwen2_5_VLConfig
27
+ from vllm.config import (CacheConfig, DeviceConfig, MultiModalConfig,
28
+ ParallelConfig, SchedulerConfig)
29
+
30
+ # Import the module itself to allow patching
31
+ # Corrected imports for the code under test
32
+ from tpu_inference.models.jax.qwen2_5_vl import (
33
+ AttentionMetadata, Qwen2_5_VisionAttention, Qwen2_5_VisionBlock,
34
+ Qwen2_5_VisionMLP, Qwen2_5_VisionPatchEmbed, Qwen2_5_VisionPatchMerger,
35
+ Qwen2_5_VisionRotaryEmbedding, Qwen2_5_VisionTransformer,
36
+ Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLImagePixelInputs, SegmentIds,
37
+ apply_rotary_pos_emb_vision, generate_window_segment_ids)
38
+
39
+
40
+ # --- Configuration Mocking ---
41
class MockModelConfig:
    """Minimal stand-in for a model config used by the Qwen2.5-VL tests.

    Stores the Hugging Face config and dtype it is given, builds a
    ``MultiModalConfig`` from the config's ``image_token_id``, and exposes
    a few fixed placeholder attributes plus size accessors.
    """

    def __init__(self, hf_config, dtype):
        """Record ``hf_config`` / ``dtype`` and fill in placeholder fields."""
        mm_config = MultiModalConfig(
            image_input_type="pixel",
            image_token_id=hf_config.image_token_id,
            image_input_shape=None,
        )

        self.hf_config = hf_config
        self.dtype = dtype
        self.multimodal_config = mm_config
        self.model = "mock_qwen2_5_vl"

        # Placeholder values; extend if the code under test reads more fields.
        self.tokenizer = "mock_tokenizer"
        self.tokenizer_mode = "auto"
        self.trust_remote_code = True
        self.seed = 0

    def is_multimodal_model(self):
        """Always report the mocked model as multimodal."""
        return True

    def get_hidden_size(self):
        """Return ``hidden_size`` straight from the wrapped HF config."""
        return self.hf_config.hidden_size

    def get_head_size(self):
        """Return ``hidden_size`` floor-divided by ``num_attention_heads``."""
        cfg = self.hf_config
        return cfg.hidden_size // cfg.num_attention_heads
65
+
66
+
67
class MockVllmConfig:
    """Lightweight substitute for a vLLM config in the Qwen2.5-VL tests.

    Builds a tiny ``Qwen2_5_VLConfig`` (2 layers, hidden size 16), wraps it
    in a ``MockModelConfig`` with ``jnp.bfloat16`` as dtype, and attaches
    ``MagicMock`` placeholders for the remaining sub-configs.
    """

    def __init__(self, tie_word_embeddings: bool = False):
        """Create the mock; ``tie_word_embeddings`` is forwarded to the HF config."""
        vision_config = dict(
            hidden_size=16,
            intermediate_size=32,
            patch_size=14,
            image_size=28,
            temporal_patch_size=2,
            in_channels=3,
            window_size=28,
            spatial_merge_size=2,
            fullatt_block_indexes=[0],
            out_hidden_size=24,
            depth=2,
            hidden_act="gelu",
            num_heads=2,
        )
        text_kwargs = dict(
            hidden_size=16,
            num_hidden_layers=2,
            num_attention_heads=2,
            num_key_value_heads=2,
            intermediate_size=32,
            rms_norm_eps=1e-6,
            image_token_id=200000,
            video_token_id=200001,
            tie_word_embeddings=tie_word_embeddings,
            vocab_size=32000,
            rope_theta=1000000.0,
        )
        hf_config = Qwen2_5_VLConfig(vision_config=vision_config,
                                     **text_kwargs)
        self.model_config = MockModelConfig(hf_config, jnp.bfloat16)

        # Spec'd mocks: attribute access is restricted to what the real
        # config classes define.
        spec_by_attr = (
            ("cache_config", CacheConfig),
            ("parallelism_config", ParallelConfig),
            ("scheduler_config", SchedulerConfig),
            ("device_config", DeviceConfig),
        )
        for attr_name, spec_cls in spec_by_attr:
            setattr(self, attr_name, MagicMock(spec=spec_cls))

        self.load_config = MagicMock()
        self.extra_configs = {}
        self.additional_config = {}
108
+
109
+
110
@pytest.fixture(scope="module")
def mesh():
    """Build the device mesh shared by every test in this module.

    All local devices are placed along the ``data`` axis; the ``attn_dp``
    and ``model`` axes have size 1. The requesting test is skipped when JAX
    reports no devices.
    """
    if len(jax.devices()) == 0:
        pytest.skip("No JAX devices available for mesh creation.")
    local_devices = np.array(jax.local_devices())
    device_grid = local_devices.reshape((len(local_devices), 1, 1))
    return Mesh(device_grid, axis_names=('data', 'attn_dp', 'model'))
118
+
119
+
120
@pytest.fixture
def rng() -> PRNGKey:
    """Return a JAX PRNG key created from the fixed seed 42."""
    seed = 42
    return jax.random.PRNGKey(seed)
124
+
125
+
126
@pytest.fixture
def mock_vllm_config() -> MockVllmConfig:
    """Return a new ``MockVllmConfig`` built with its default arguments."""
    config = MockVllmConfig()
    return config
129
+
130
+
131
@pytest.fixture
def rngs(rng: PRNGKey) -> nnx.Rngs:
    """Wrap the ``rng`` fixture's key in an ``nnx.Rngs`` under ``params``."""
    param_rngs = nnx.Rngs(params=rng)
    return param_rngs
134
+
135
+
136
+ # --- Test Classes ---
137
class TestUtils:
    """Tests for the stand-alone helpers ``apply_rotary_pos_emb_vision`` and
    ``generate_window_segment_ids``."""

    def test_apply_rotary_pos_emb_vision(self, rng: PRNGKey):
        """The rotated tensor must keep the shape of the input tensor."""
        B, T, N, H = 1, 10, 2, 8
        # Draw the two inputs from independent subkeys; reusing a single
        # JAX key for several random calls produces correlated samples.
        x_key, rope_key = jax.random.split(rng)
        x = jax.random.normal(x_key, (B, T, N, H))
        rotary_pos_emb = jax.random.normal(rope_key, (T, H // 2))
        x_rotated = apply_rotary_pos_emb_vision(x, rotary_pos_emb)
        assert x_rotated.shape == (B, T, N, H)

    def test_generate_window_segment_ids(self):
        """Two 5-token windows get ids 1 and 2; the padded tail gets id 0."""
        cu_seqlens = jnp.array([0, 5, 10])
        seq_len = 10
        padded_seq_len = 16
        segment_ids = generate_window_segment_ids(cu_seqlens, seq_len,
                                                  padded_seq_len)
        assert isinstance(segment_ids, SegmentIds)
        assert segment_ids.q.shape == (1, padded_seq_len)
        assert segment_ids.kv.shape == (1, padded_seq_len)
        # q and kv are expected to carry identical segment ids.
        expected_q = np.array(
            [[1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0]])
        np.testing.assert_array_equal(segment_ids.q, expected_q)
        np.testing.assert_array_equal(segment_ids.kv, expected_q)
159
+
160
+
161
class TestQwen2_5_VisionMLP:
    """Forward-pass check for ``Qwen2_5_VisionMLP``."""

    def test_forward(self, mock_vllm_config: MockVllmConfig, rngs: nnx.Rngs):
        """The MLP output must match the input in both shape and dtype."""
        model_config = mock_vllm_config.model_config
        vision_cfg = model_config.hf_config.vision_config
        param_dtype = model_config.dtype
        num_rows = 5

        mlp = Qwen2_5_VisionMLP(vision_cfg, param_dtype, rngs)
        inputs = jnp.ones((num_rows, vision_cfg.hidden_size),
                          dtype=param_dtype)
        outputs = mlp(inputs)

        assert outputs.shape == (num_rows, vision_cfg.hidden_size)
        assert outputs.dtype == param_dtype
171
+
172
+
173
class TestQwen2_5_VisionAttention:
    """Forward-pass tests for ``Qwen2_5_VisionAttention``.

    The flash-attention callable stored on the module is replaced with a
    ``MagicMock`` so that only the surrounding shape handling is exercised.
    """

    # Sequence length of the mocked kernel output and of the segment ids the
    # tests expect to be passed to it (presumably the padded length used by
    # the module — confirm against the implementation).
    PADDED_SEQ_LEN = 128

    @staticmethod
    def _build_attention(mock_vllm_config, rngs, mesh):
        """Construct the attention module from the mock configuration."""
        model_config = mock_vllm_config.model_config
        return Qwen2_5_VisionAttention(model_config.hf_config,
                                       model_config.dtype, rngs, mesh)

    @classmethod
    def _install_fake_kernel(cls, attn_module, batch_size):
        """Swap ``attn_module.flash_attention`` for a mock and return it."""
        fake_kernel = MagicMock(return_value=jnp.ones((
            batch_size,
            attn_module.num_heads,
            cls.PADDED_SEQ_LEN,
            attn_module.head_dim,
        )))
        attn_module.flash_attention = fake_kernel
        return fake_kernel

    @staticmethod
    def _random_inputs(rng, T, B, D, head_dim):
        """Draw ``x`` and ``rotary_pos_emb`` from independent subkeys."""
        # A JAX key must not be reused for several random calls.
        x_key, rope_key = jax.random.split(rng)
        x = jax.random.normal(x_key, (T, B, D))
        rotary_pos_emb = jax.random.normal(rope_key, (T, head_dim // 2))
        return x, rotary_pos_emb

    @patch('tpu_inference.models.jax.qwen2_5_vl.sharded_flash_attention')
    def test_forward_fullattn(self, mock_flash_attention: MagicMock,
                              mock_vllm_config: MockVllmConfig, rngs: nnx.Rngs,
                              mesh: Mesh, rng: PRNGKey):
        """Full attention: output keeps the input shape, kernel called once."""
        attn_module = self._build_attention(mock_vllm_config, rngs, mesh)
        B, T, D = 1, 10, attn_module.hidden_size
        # sharded_flash_attention is a factory, so we mock the returned
        # function rather than the factory itself.
        mock_attn_fn = self._install_fake_kernel(attn_module, B)
        x, rotary_pos_emb = self._random_inputs(rng, T, B, D,
                                                attn_module.head_dim)
        # NOTE(review): cu_seqlens ends at 5 while T is 10 — confirm this is
        # intentional for the full-attention path.
        cu_seqlens = jnp.array([0, 5])

        y_full = attn_module(x,
                             rotary_pos_emb,
                             cu_window_seqlens=cu_seqlens,
                             use_fullattn=True)
        assert y_full.shape == (T, B, D)
        mock_attn_fn.assert_called_once()
        # The fourth positional argument carries the segment ids.
        assert mock_attn_fn.call_args[0][3].q.shape == (1,
                                                        self.PADDED_SEQ_LEN)

    @patch('tpu_inference.models.jax.qwen2_5_vl.sharded_flash_attention')
    def test_forward_windowed(self, mock_flash_attention: MagicMock,
                              mock_vllm_config: MockVllmConfig, rngs: nnx.Rngs,
                              mesh: Mesh, rng: PRNGKey):
        """Windowed attention: output keeps the input shape, kernel called once."""
        attn_module = self._build_attention(mock_vllm_config, rngs, mesh)
        B, T, D = 1, 10, attn_module.hidden_size
        mock_attn_fn = self._install_fake_kernel(attn_module, B)
        x, rotary_pos_emb = self._random_inputs(rng, T, B, D,
                                                attn_module.head_dim)
        cu_window_seqlens = jnp.array([0, 5, 10])

        y_window = attn_module(x,
                               rotary_pos_emb,
                               cu_window_seqlens=cu_window_seqlens,
                               use_fullattn=False)
        assert y_window.shape == (T, B, D)
        mock_attn_fn.assert_called_once()
        assert mock_attn_fn.call_args[0][3].q.shape == (1,
                                                        self.PADDED_SEQ_LEN)

    def test_batch_fail(self, mock_vllm_config: MockVllmConfig, rngs: nnx.Rngs,
                        mesh: Mesh, rng: PRNGKey):
        """A batch size other than 1 must be rejected with an AssertionError."""
        attn_module = self._build_attention(mock_vllm_config, rngs, mesh)
        T, B, D = 10, 2, attn_module.hidden_size
        x, rotary_pos_emb = self._random_inputs(rng, T, B, D,
                                                attn_module.head_dim)
        with pytest.raises(
                AssertionError,
                match="Vision attention currently only supports batch size 1"):
            attn_module(x, rotary_pos_emb, use_fullattn=True)
238
+
239
+
240
class TestQwen2_5_VisionBlock:
    """Forward-pass test for ``Qwen2_5_VisionBlock`` with its attention and
    MLP sub-modules replaced by autospec'd mocks."""

    # Decorators apply bottom-up, so the attention mock is injected first
    # and the MLP mock second.
    @patch('tpu_inference.models.jax.qwen2_5_vl.Qwen2_5_VisionMLP',
           autospec=True)
    @patch('tpu_inference.models.jax.qwen2_5_vl.Qwen2_5_VisionAttention',
           autospec=True)
    def test_forward(self, MockAttention: MagicMock, MockMLP: MagicMock,
                     mock_vllm_config: MockVllmConfig, rngs: nnx.Rngs,
                     mesh: Mesh, rng: PRNGKey):
        """The block keeps the input shape and calls each sub-module once."""
        config = mock_vllm_config.model_config.hf_config
        dtype = mock_vllm_config.model_config.dtype
        vision_cfg = config.vision_config
        D = vision_cfg.hidden_size
        T, B = 10, 1

        # Both mocked sub-modules return zeros of the input shape.
        mock_attn_instance = MockAttention.return_value
        mock_attn_instance.return_value = jnp.zeros((T, B, D), dtype=dtype)
        mock_mlp_instance = MockMLP.return_value
        mock_mlp_instance.return_value = jnp.zeros((T, B, D), dtype=dtype)

        block = Qwen2_5_VisionBlock(config, dtype, rngs, mesh)

        # Independent subkeys: a JAX key must not be reused across draws.
        x_key, rope_key = jax.random.split(rng)
        x = jax.random.normal(x_key, (T, B, D))
        rotary_pos_emb = jax.random.normal(
            rope_key,
            (T, vision_cfg.hidden_size // vision_cfg.num_heads // 2))

        y = block(x, rotary_pos_emb, use_fullattn=True)
        assert y.shape == (T, B, D)
        mock_attn_instance.assert_called_once()
        mock_mlp_instance.assert_called_once()
269
+
270
+
271
class TestQwen2_5_VisionPatchEmbed:
    """Forward-pass check for ``Qwen2_5_VisionPatchEmbed``."""

    def test_forward(self, mock_vllm_config: MockVllmConfig, rngs: nnx.Rngs,
                     rng: PRNGKey):
        """Each flattened patch row is mapped to a ``hidden_size`` vector."""
        model_config = mock_vllm_config.model_config
        vc = model_config.hf_config.vision_config
        embed_kwargs = dict(
            patch_size=vc.patch_size,
            temporal_patch_size=vc.temporal_patch_size,
            in_channels=vc.in_channels,
            hidden_size=vc.hidden_size,
            dtype=model_config.dtype,
        )
        patch_embed = Qwen2_5_VisionPatchEmbed(rngs, **embed_kwargs)

        num_patches = 4
        # Flattened size of one patch: channels x temporal x height x width.
        patch_dim = (vc.in_channels * vc.temporal_patch_size *
                     vc.patch_size * vc.patch_size)
        patches = jax.random.normal(rng, (num_patches, patch_dim))

        embedded = patch_embed(patches)
        assert embedded.shape == (num_patches, vc.hidden_size)
289
+
290
+
291
class TestQwen2_5_VisionPatchMerger:
    """Forward-pass check for ``Qwen2_5_VisionPatchMerger``."""

    def test_forward(self, mock_vllm_config: MockVllmConfig, rngs: nnx.Rngs,
                     rng: PRNGKey):
        """Each group of ``spatial_merge_size**2`` rows yields one output row
        of width ``out_hidden_size``."""
        model_config = mock_vllm_config.model_config
        vc = model_config.hf_config.vision_config
        rms_norm = partial(nnx.RMSNorm, epsilon=1e-6)

        merger = Qwen2_5_VisionPatchMerger(
            d_model=vc.out_hidden_size,
            context_dim=vc.hidden_size,
            norm_layer=rms_norm,
            spatial_merge_size=vc.spatial_merge_size,
            dtype=model_config.dtype,
            rngs=rngs,
        )

        num_groups = 5
        group_shape = (num_groups, vc.spatial_merge_size**2, vc.hidden_size)
        merged = merger(jax.random.normal(rng, group_shape))
        assert merged.shape == (num_groups, vc.out_hidden_size)
308
+
309
+
310
class TestQwen2_5_VisionRotaryEmbedding:
    """Shape check for ``Qwen2_5_VisionRotaryEmbedding``."""

    def test_forward(self):
        """Calling the embedding with a length returns ``(seqlen, dim // 2)``."""
        dim, seqlen = 16, 10
        table = Qwen2_5_VisionRotaryEmbedding(dim=dim)(seqlen)
        assert table.shape == (seqlen, dim // 2)
318
+
319
+
320
class TestQwen2_5_VisionTransformer:
    """Shape-level tests for ``Qwen2_5_VisionTransformer`` and its helpers."""

    @pytest.fixture
    def vision_transformer(self, mock_vllm_config: MockVllmConfig,
                           rngs: nnx.Rngs, mesh: Mesh):
        """Build a transformer from the mock config for the helper tests."""
        transformer = Qwen2_5_VisionTransformer(mock_vllm_config, rngs, mesh)
        return transformer

    def test_rotary_pos_emb_thw(self,
                                vision_transformer: Qwen2_5_VisionTransformer):
        """``rotary_pos_emb_thw`` returns (merged positions, sm*sm, head_dim/2)."""
        t, h, w = 2, 4, 4
        vc = vision_transformer.config
        sm = vc.spatial_merge_size
        half_head_dim = (vc.hidden_size // vc.num_heads) // 2
        merged_positions = t * (h // sm) * (w // sm)

        emb = vision_transformer.rotary_pos_emb_thw(t, h, w)
        assert emb.shape == (merged_positions, sm * sm, half_head_dim)

    def test_get_window_index_thw(
            self, vision_transformer: Qwen2_5_VisionTransformer):
        """The window index has one non-negative entry per merged position."""
        grid_t, grid_h, grid_w = 1, 8, 8
        sm = vision_transformer.config.spatial_merge_size
        num_valid_indices = grid_t * (grid_h // sm) * (grid_w // sm)

        # Only the index is checked; the second return value is ignored.
        index_new, _ = vision_transformer.get_window_index_thw(
            grid_t, grid_h, grid_w)
        assert index_new.shape == (num_valid_indices, )
        assert jnp.all(index_new >= 0)

    def test_get_rope_by_thw(self,
                             vision_transformer: Qwen2_5_VisionTransformer):
        """``get_rope_by_thw`` returns a 4-tuple with a consistent rope shape."""
        t, h, w = 1, 8, 8
        res = vision_transformer.get_rope_by_thw(t, h, w)
        assert isinstance(res, tuple)
        assert len(res) == 4
        rotary_pos_emb_thw, window_index_thw, _, _ = res

        vc = vision_transformer.config
        sm = vc.spatial_merge_size
        # Each position's rotary embedding has head_dim // 2 entries.
        head_dim_rope = (vc.hidden_size // vc.num_heads) // 2
        expected_len = window_index_thw.shape[0] * sm * sm
        assert rotary_pos_emb_thw.shape == (expected_len, head_dim_rope)

    @pytest.mark.parametrize("enable_dynamic_image_sizes", [False, True])
    def test_call(self, mock_vllm_config: MockVllmConfig, rngs: nnx.Rngs,
                  mesh: Mesh, rng: PRNGKey, enable_dynamic_image_sizes: bool):
        """End-to-end call yields one ``out_hidden_size`` row per merged token."""
        mock_vllm_config.additional_config = {
            "enable_dynamic_image_sizes": enable_dynamic_image_sizes
        }
        vision_transformer = Qwen2_5_VisionTransformer(mock_vllm_config, rngs,
                                                       mesh)

        # Stub out flash attention to avoid sharding errors in the test
        # environment; the stub returns ones shaped like the query.
        def _ones_like_query(q, k, v, seg):
            return jnp.ones_like(q)

        for block in vision_transformer.blocks:
            block.attn.flash_attention = MagicMock(
                side_effect=_ones_like_query)

        vc = vision_transformer.config
        t_pix, h_pix, w_pix = 2, 84, 28

        # grid_thw is expressed in patch-grid units, not pixels.
        t_grid = t_pix // vc.temporal_patch_size
        h_grid = h_pix // vc.patch_size
        w_grid = w_pix // vc.patch_size
        grid_thw = ((t_grid, h_grid, w_grid), )

        # One input row per patch of the image/video.
        num_patches = t_grid * h_grid * w_grid
        patch_dim = (vc.in_channels * vc.temporal_patch_size *
                     vc.patch_size * vc.patch_size)
        x = jax.random.normal(rng, (num_patches, patch_dim))

        embeddings = vision_transformer(x, grid_thw)

        # Output token count follows from the grid and the spatial merge size.
        merge = vc.spatial_merge_size
        expected_len = t_grid * (h_grid // merge) * (w_grid // merge)
        assert embeddings.shape == (expected_len, vc.out_hidden_size)
400
+
401
+
402
+ class TestQwen2_5_VLForConditionalGeneration:
403
+
404
+ @pytest.fixture
405
+ def model(self, mock_vllm_config: MockVllmConfig, rng: PRNGKey,
406
+ mesh: Mesh):
407
+ with patch('tpu_inference.models.jax.qwen2_5_vl.Qwen2_5_VisionTransformer', autospec=True) as MockVision, \
408
+ patch('tpu_inference.models.jax.qwen2_5_vl.Qwen2ForCausalLM', autospec=True) as MockLM:
409
+ mock_visual = MockVision.return_value
410
+ mock_visual.dtype = mock_vllm_config.model_config.dtype
411
+ mock_visual.config = mock_vllm_config.model_config.hf_config.vision_config
412
+ mock_visual.spatial_merge_size = mock_vllm_config.model_config.hf_config.vision_config.spatial_merge_size
413
+
414
+ model = Qwen2_5_VLForConditionalGeneration(mock_vllm_config, rng,
415
+ mesh)
416
+ # Directly assign mocked instances
417
+ model.visual = mock_visual
418
+ model.language_model = MockLM.return_value
419
+ yield model
420
+
421
+ def test_validate_and_reshape_mm_tensor(
422
+ self, model: Qwen2_5_VLForConditionalGeneration):
423
+ data_list = [np.ones((2, 4)), np.ones((3, 4))]
424
+ reshaped_list = model._validate_and_reshape_mm_tensor(
425
+ data_list, "test_list")
426
+ assert reshaped_list.shape == (5, 4)
427
+ assert isinstance(reshaped_list, jax.Array)
428
+
429
+ data_2d = np.ones((5, 4))
430
+ reshaped_2d = model._validate_and_reshape_mm_tensor(data_2d, "test_2d")
431
+ assert reshaped_2d.shape == (5, 4)
432
+
433
+ data_3d = np.ones((2, 5, 4))
434
+ reshaped_3d = model._validate_and_reshape_mm_tensor(data_3d, "test_3d")
435
+ assert reshaped_3d.shape == (10, 4)
436
+
437
+ with pytest.raises(ValueError, match="Incorrect type of test_invalid"):
438
+ model._validate_and_reshape_mm_tensor("invalid", "test_invalid")
439
+
440
def test_parse_and_validate_image_input(
        self, model: Qwen2_5_VLForConditionalGeneration):
    """Check ``_parse_and_validate_image_input`` with and without pixel values."""
    vision_cfg = model.config.vision_config
    feature_dim = (vision_cfg.in_channels * vision_cfg.temporal_patch_size *
                   vision_cfg.patch_size * vision_cfg.patch_size)
    image_grid = ((2, 28, 28), )
    pixels = np.ones((4, feature_dim))

    # With pixel values, a "pixel_values" record carrying both inputs is returned.
    result = model._parse_and_validate_image_input(image_grid,
                                                   pixel_values=pixels)
    assert result is not None
    assert result['type'] == "pixel_values"
    assert result['pixel_values'].shape == (4, feature_dim)
    assert result['image_grid_thw'] == image_grid

    # With only the grid, nothing is parsed.
    without_pixels = model._parse_and_validate_image_input(image_grid)
    assert without_pixels is None
456
+
457
def test_parse_and_validate_multimodal_inputs(
        self, model: Qwen2_5_VLForConditionalGeneration):
    """Check ``_parse_and_validate_multimodal_inputs`` with and without pixels."""
    vision_cfg = model.config.vision_config
    feature_dim = (vision_cfg.in_channels * vision_cfg.temporal_patch_size *
                   vision_cfg.patch_size * vision_cfg.patch_size)
    image_grid = ((2, 28, 28), )
    pixels = np.ones((4, feature_dim))

    # Pixel values produce an "image" entry of type "pixel_values".
    parsed = model._parse_and_validate_multimodal_inputs(image_grid,
                                                         pixel_values=pixels)
    assert "image" in parsed
    assert parsed["image"]['type'] == "pixel_values"

    # The grid alone produces an empty (falsy) result.
    parsed_without_pixels = model._parse_and_validate_multimodal_inputs(
        image_grid)
    assert not parsed_without_pixels
471
+
472
def test_process_image_input_pixels(
        self, model: Qwen2_5_VLForConditionalGeneration):
    """Check ``_process_image_input`` returns one embedding per grid entry.

    The vision tower is mocked, so each call returns the same stub array;
    the test verifies the tuple length, the per-image shapes, and that the
    vision tower was invoked twice for the two grid entries.
    """
    vision_cfg = model.config.vision_config
    feature_dim = (vision_cfg.in_channels * vision_cfg.temporal_patch_size *
                   vision_cfg.patch_size * vision_cfg.patch_size)
    patch_count = 8  # 4 per image
    image_grids = ((2, 28, 28), (2, 28, 28))
    image_input = Qwen2_5_VLImagePixelInputs(
        type="pixel_values",
        pixel_values=jnp.ones((patch_count, feature_dim)),
        image_grid_thw=image_grids)

    tokens_per_image = (2 * 28 * 28) // (vision_cfg.spatial_merge_size**2)
    expected_shape = (tokens_per_image, vision_cfg.out_hidden_size)
    # Every call to the mocked vision tower returns this stub output.
    model.visual.return_value = jnp.ones(expected_shape)

    outputs = model._process_image_input(image_input)
    assert isinstance(outputs, tuple)
    assert len(outputs) == 2
    for per_image in outputs:
        assert per_image.shape == expected_shape
    assert model.visual.call_count == 2
493
+
494
def test_embed_multimodal(self, model: Qwen2_5_VLForConditionalGeneration):
    """Check ``embed_multimodal`` with pixel values and with the grid alone."""
    vision_cfg = model.config.vision_config
    feature_dim = (vision_cfg.in_channels * vision_cfg.temporal_patch_size *
                   vision_cfg.patch_size * vision_cfg.patch_size)
    image_grid = ((2, 28, 28), )
    pixels = np.ones((4, feature_dim))
    tokens_per_image = (2 * 28 * 28) // (vision_cfg.spatial_merge_size**2)
    expected_shape = (tokens_per_image, vision_cfg.out_hidden_size)
    stub_output = jnp.ones(expected_shape)

    # Stub the image-processing step so only embed_multimodal is exercised.
    process_patcher = patch.object(model,
                                   '_process_image_input',
                                   return_value=(stub_output, ))
    with process_patcher as process_mock:
        with_pixels = model.embed_multimodal(image_grid, pixel_values=pixels)
        process_mock.assert_called_once()
        assert isinstance(with_pixels, tuple)
        assert len(with_pixels) == 1
        assert with_pixels[0].shape == expected_shape

    # Without pixel values the result is empty.
    without_pixels = model.embed_multimodal(image_grid)
    assert len(without_pixels) == 0
514
+
515
@patch('tpu_inference.models.jax.qwen2_5_vl.merge_multimodal_embeddings')
def test_embed_input_ids(self, mock_merge_embeddings: MagicMock,
                         model: Qwen2_5_VLForConditionalGeneration,
                         rng: PRNGKey):
    """Check ``embed_input_ids`` for None, empty and non-empty multimodal input.

    With ``None`` or a zero-row multimodal array the text embeddings are
    returned and the merge helper is not called; with non-empty multimodal
    embeddings the merge helper's return value is returned.
    """
    hidden = model.config.hidden_size
    token_ids = jax.random.randint(rng, (1, 10), 0, model.config.vocab_size)
    text_embeds = jnp.ones((1, 10, hidden))

    # Stub the language model's embedding lookup.
    inner_lm = MagicMock()
    inner_lm.embed = MagicMock(return_value=text_embeds)
    model.language_model.model = inner_lm

    # Case 1: no multimodal embeddings at all.
    text_only = model.embed_input_ids(token_ids, None)
    np.testing.assert_array_equal(text_only, text_embeds)
    mock_merge_embeddings.assert_not_called()

    # Case 2: a multimodal array with zero rows.
    zero_rows = jnp.ones((0, hidden), )
    with_zero_rows = model.embed_input_ids(token_ids, zero_rows)
    np.testing.assert_array_equal(with_zero_rows, text_embeds)
    mock_merge_embeddings.assert_not_called()

    # Case 3: real multimodal embeddings are handed to the merge helper.
    image_embeds = jnp.ones((5, hidden))
    merged_stub = jnp.ones((1, 15, hidden))
    mock_merge_embeddings.return_value = merged_stub

    with_images = model.embed_input_ids(token_ids, image_embeds)
    np.testing.assert_array_equal(with_images, merged_stub)
    placeholder_ids = [
        model.config.image_token_id, model.config.video_token_id
    ]
    mock_merge_embeddings.assert_called_once_with(token_ids, text_embeds,
                                                  image_embeds,
                                                  placeholder_ids)
544
+
545
def test_call(self, model: Qwen2_5_VLForConditionalGeneration,
              rng: PRNGKey):
    """Check that calling the model forwards to the (mocked) language model."""
    caches = [MagicMock()]
    token_ids = jax.random.randint(rng, (1, 10), 0, model.config.vocab_size)
    metadata = MagicMock(spec=AttentionMetadata)
    hidden_shape = (1, 10, model.config.hidden_size)

    # The mocked language model returns (kv caches, hidden states, aux states).
    model.language_model.return_value = ([MagicMock()],
                                         jnp.ones(hidden_shape), [])

    new_kvs, x, aux_hidden_states = model(caches, token_ids, metadata)

    model.language_model.assert_called_once_with(kv_caches=caches,
                                                 input_ids=token_ids,
                                                 attention_metadata=metadata,
                                                 inputs_embeds=None)
    assert len(new_kvs) == 1
    assert x.shape == hidden_shape
    assert len(aux_hidden_states) == 0
564
+
565
def test_compute_logits(self, model: Qwen2_5_VLForConditionalGeneration,
                        rng: PRNGKey):
    """Check ``compute_logits`` returns the (mocked) language model's logits."""
    lm_compute_logits = model.language_model.compute_logits
    states = jnp.ones((1, 10, model.config.hidden_size))
    stub_logits = jnp.ones((1, 10, model.config.vocab_size))
    lm_compute_logits.return_value = stub_logits

    result = model.compute_logits(states)

    np.testing.assert_array_equal(result, stub_logits)
    lm_compute_logits.assert_called_once_with(states)
575
+
576
@patch('tpu_inference.models.jax.qwen2_5_vl.load_hf_weights')
def test_load_weights(self, mock_load_weights: MagicMock,
                      model: Qwen2_5_VLForConditionalGeneration,
                      mock_vllm_config: MockVllmConfig, rng: PRNGKey,
                      mesh: Mesh):
    """Check the keyword arguments ``load_weights`` passes to ``load_hf_weights``.

    Also checks that the model ends up with an ``nnx.Rngs`` instance that is
    shared with its language model.
    """
    model.load_weights(rng)

    mock_load_weights.assert_called_once()
    call_kwargs = mock_load_weights.call_args.kwargs
    assert call_kwargs['vllm_config'] == mock_vllm_config
    assert call_kwargs['model'] is model

    name_map = call_kwargs['metadata_map'].name_map
    assert "model.embed_tokens" in name_map
    assert "lm_head" in name_map  # Should be present when not tied
    assert call_kwargs['mesh'] is mesh

    assert isinstance(model.rng, nnx.Rngs)
    assert model.language_model.rng is model.rng
592
+
593
@patch('tpu_inference.models.jax.qwen2_5_vl.load_hf_weights')
def test_load_weights_tied(self, mock_load_weights: MagicMock,
                           rng: PRNGKey, mesh: Mesh):
    """Check that ``lm_head`` is not in the name map when embeddings are tied."""
    tied_config = MockVllmConfig(tie_word_embeddings=True)

    # Build the model with the vision tower and language model classes mocked.
    vision_patcher = patch(
        'tpu_inference.models.jax.qwen2_5_vl.Qwen2_5_VisionTransformer',
        autospec=True)
    lm_patcher = patch('tpu_inference.models.jax.qwen2_5_vl.Qwen2ForCausalLM',
                       autospec=True)
    with vision_patcher, lm_patcher:
        tied_model = Qwen2_5_VLForConditionalGeneration(
            tied_config, rng, mesh)

    tied_model.load_weights(rng)

    mock_load_weights.assert_called_once()
    name_map = mock_load_weights.call_args.kwargs['metadata_map'].name_map
    assert "lm_head" not in name_map