tpu-inference 0.12.0.dev20251222__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (260)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +67 -0
  6. tests/core/test_dp_scheduler.py +724 -0
  7. tests/core/test_init.py +63 -0
  8. tests/distributed/__init__.py +13 -0
  9. tests/distributed/test_distributed_utils.py +120 -0
  10. tests/distributed/test_tpu_connector.py +478 -0
  11. tests/e2e/__init__.py +13 -0
  12. tests/e2e/test_async_scheduler.py +211 -0
  13. tests/e2e/test_data_parallel.py +393 -0
  14. tests/e2e/test_local_disagg.py +257 -0
  15. tests/e2e/test_model_loader.py +268 -0
  16. tests/e2e/test_multi_modal_inference.py +111 -0
  17. tests/e2e/test_pipeline_parallel.py +265 -0
  18. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  19. tests/e2e/test_sampling_params.py +269 -0
  20. tests/e2e/test_speculative_decoding.py +291 -0
  21. tests/e2e/test_structured_decoding.py +46 -0
  22. tests/executors/__init__.py +13 -0
  23. tests/executors/test_ray_distributed_executor.py +199 -0
  24. tests/experimental/__init__.py +13 -0
  25. tests/experimental/test_llama3_jax_stashed.py +208 -0
  26. tests/kernels/__init__.py +13 -0
  27. tests/kernels/collectives/__init__.py +13 -0
  28. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  29. tests/kernels/fused_moe_v1_test.py +388 -0
  30. tests/kernels/gmm_test.py +205 -0
  31. tests/kernels/mla_v1_test.py +498 -0
  32. tests/kernels/quantized_matmul_kernel_test.py +159 -0
  33. tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
  34. tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
  35. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
  36. tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
  37. tests/layers/__init__.py +13 -0
  38. tests/layers/common/__init__.py +13 -0
  39. tests/layers/common/test_attention_interface.py +156 -0
  40. tests/layers/common/test_quantization.py +149 -0
  41. tests/layers/jax/__init__.py +13 -0
  42. tests/layers/jax/attention/__init__.py +13 -0
  43. tests/layers/jax/attention/test_common_attention.py +103 -0
  44. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  45. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  46. tests/layers/jax/moe/__init__.py +13 -0
  47. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  48. tests/layers/jax/sample/__init__.py +13 -0
  49. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  50. tests/layers/jax/sample/test_sampling.py +115 -0
  51. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  52. tests/layers/jax/test_layers.py +155 -0
  53. tests/layers/jax/test_qwix.py +969 -0
  54. tests/layers/jax/test_rope.py +93 -0
  55. tests/layers/jax/test_sharding.py +159 -0
  56. tests/layers/jax/test_transformer_block.py +152 -0
  57. tests/layers/vllm/__init__.py +13 -0
  58. tests/layers/vllm/test_attention.py +363 -0
  59. tests/layers/vllm/test_awq.py +405 -0
  60. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
  62. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
  63. tests/layers/vllm/test_fp8.py +17 -0
  64. tests/layers/vllm/test_mxfp4.py +297 -0
  65. tests/layers/vllm/test_unquantized.py +621 -0
  66. tests/layers/vllm/utils.py +72 -0
  67. tests/lora/__init__.py +13 -0
  68. tests/lora/conftest.py +46 -0
  69. tests/lora/test_bgmv.py +57 -0
  70. tests/lora/test_layers.py +666 -0
  71. tests/lora/test_lora.py +147 -0
  72. tests/lora/test_lora_perf.py +67 -0
  73. tests/lora/utils.py +88 -0
  74. tests/models/__init__.py +13 -0
  75. tests/models/common/__init__.py +13 -0
  76. tests/models/common/test_model_loader.py +455 -0
  77. tests/models/jax/__init__.py +13 -0
  78. tests/models/jax/test_deepseek_v3.py +401 -0
  79. tests/models/jax/test_llama3.py +184 -0
  80. tests/models/jax/test_llama4.py +298 -0
  81. tests/models/jax/test_llama_eagle3.py +197 -0
  82. tests/models/jax/test_llama_guard_4.py +242 -0
  83. tests/models/jax/test_qwen2.py +172 -0
  84. tests/models/jax/test_qwen2_5_vl.py +606 -0
  85. tests/models/jax/test_qwen3.py +169 -0
  86. tests/models/jax/test_weight_loading.py +180 -0
  87. tests/models/jax/utils/__init__.py +13 -0
  88. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  89. tests/platforms/__init__.py +13 -0
  90. tests/platforms/test_tpu_platform.py +54 -0
  91. tests/runner/__init__.py +13 -0
  92. tests/runner/test_block_table.py +395 -0
  93. tests/runner/test_input_batch.py +226 -0
  94. tests/runner/test_kv_cache.py +220 -0
  95. tests/runner/test_kv_cache_manager.py +498 -0
  96. tests/runner/test_multimodal_manager.py +429 -0
  97. tests/runner/test_persistent_batch_manager.py +84 -0
  98. tests/runner/test_speculative_decoding_manager.py +368 -0
  99. tests/runner/test_structured_decoding_manager.py +220 -0
  100. tests/runner/test_tpu_runner.py +202 -0
  101. tests/runner/test_tpu_runner_dp.py +1033 -0
  102. tests/runner/test_tpu_runner_mesh.py +200 -0
  103. tests/runner/test_utils.py +411 -0
  104. tests/spec_decode/__init__.py +13 -0
  105. tests/spec_decode/test_eagle3.py +311 -0
  106. tests/test_base.py +215 -0
  107. tests/test_envs.py +280 -0
  108. tests/test_tpu_info.py +134 -0
  109. tests/test_utils.py +193 -0
  110. tests/worker/__init__.py +13 -0
  111. tests/worker/tpu_worker_test.py +414 -0
  112. tpu_inference/__init__.py +67 -0
  113. tpu_inference/core/__init__.py +13 -0
  114. tpu_inference/core/core_tpu.py +786 -0
  115. tpu_inference/core/disagg_executor.py +118 -0
  116. tpu_inference/core/disagg_utils.py +49 -0
  117. tpu_inference/core/sched/__init__.py +13 -0
  118. tpu_inference/core/sched/dp_scheduler.py +814 -0
  119. tpu_inference/distributed/__init__.py +13 -0
  120. tpu_inference/distributed/jax_parallel_state.py +81 -0
  121. tpu_inference/distributed/tpu_connector.py +732 -0
  122. tpu_inference/distributed/utils.py +112 -0
  123. tpu_inference/env_override.py +9 -0
  124. tpu_inference/envs.py +191 -0
  125. tpu_inference/executors/__init__.py +13 -0
  126. tpu_inference/executors/ray_distributed_executor.py +399 -0
  127. tpu_inference/experimental/__init__.py +13 -0
  128. tpu_inference/experimental/llama3_jax_stashed.py +272 -0
  129. tpu_inference/kernels/__init__.py +13 -0
  130. tpu_inference/kernels/collectives/__init__.py +13 -0
  131. tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
  132. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
  133. tpu_inference/kernels/collectives/util.py +47 -0
  134. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  135. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  136. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  137. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  138. tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
  139. tpu_inference/kernels/megablox/__init__.py +13 -0
  140. tpu_inference/kernels/megablox/common.py +54 -0
  141. tpu_inference/kernels/megablox/gmm.py +646 -0
  142. tpu_inference/kernels/mla/__init__.py +13 -0
  143. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  144. tpu_inference/kernels/mla/v1/kernel.py +1340 -0
  145. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  146. tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
  147. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  148. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  149. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  150. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  151. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
  152. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
  153. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  154. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  155. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
  156. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
  157. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
  158. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
  159. tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
  160. tpu_inference/layers/__init__.py +13 -0
  161. tpu_inference/layers/common/__init__.py +13 -0
  162. tpu_inference/layers/common/attention_interface.py +403 -0
  163. tpu_inference/layers/common/attention_metadata.py +48 -0
  164. tpu_inference/layers/common/binary_search.py +295 -0
  165. tpu_inference/layers/common/quant_methods.py +23 -0
  166. tpu_inference/layers/common/quantization.py +270 -0
  167. tpu_inference/layers/common/sharding.py +600 -0
  168. tpu_inference/layers/jax/__init__.py +13 -0
  169. tpu_inference/layers/jax/attention/__init__.py +13 -0
  170. tpu_inference/layers/jax/attention/attention.py +268 -0
  171. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
  172. tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
  173. tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
  174. tpu_inference/layers/jax/base.py +165 -0
  175. tpu_inference/layers/jax/constants.py +101 -0
  176. tpu_inference/layers/jax/layers.py +315 -0
  177. tpu_inference/layers/jax/misc.py +30 -0
  178. tpu_inference/layers/jax/moe/__init__.py +13 -0
  179. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
  180. tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
  181. tpu_inference/layers/jax/moe/moe.py +249 -0
  182. tpu_inference/layers/jax/pp_utils.py +53 -0
  183. tpu_inference/layers/jax/rope.py +294 -0
  184. tpu_inference/layers/jax/rope_interface.py +228 -0
  185. tpu_inference/layers/jax/sample/__init__.py +13 -0
  186. tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
  187. tpu_inference/layers/jax/sample/sampling.py +110 -0
  188. tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
  189. tpu_inference/layers/jax/transformer_block.py +121 -0
  190. tpu_inference/layers/vllm/__init__.py +13 -0
  191. tpu_inference/layers/vllm/attention.py +221 -0
  192. tpu_inference/layers/vllm/fused_moe.py +502 -0
  193. tpu_inference/layers/vllm/linear_common.py +221 -0
  194. tpu_inference/layers/vllm/quantization/__init__.py +55 -0
  195. tpu_inference/layers/vllm/quantization/awq.py +221 -0
  196. tpu_inference/layers/vllm/quantization/common.py +124 -0
  197. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  198. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
  199. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
  200. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  201. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
  202. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
  203. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  204. tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
  205. tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
  206. tpu_inference/layers/vllm/sharding.py +244 -0
  207. tpu_inference/logger.py +10 -0
  208. tpu_inference/lora/__init__.py +13 -0
  209. tpu_inference/lora/torch_lora_ops.py +98 -0
  210. tpu_inference/lora/torch_punica_tpu.py +310 -0
  211. tpu_inference/models/__init__.py +13 -0
  212. tpu_inference/models/common/__init__.py +13 -0
  213. tpu_inference/models/common/model_loader.py +520 -0
  214. tpu_inference/models/jax/__init__.py +13 -0
  215. tpu_inference/models/jax/deepseek_v3.py +978 -0
  216. tpu_inference/models/jax/gpt_oss.py +508 -0
  217. tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
  218. tpu_inference/models/jax/llama3.py +436 -0
  219. tpu_inference/models/jax/llama4.py +643 -0
  220. tpu_inference/models/jax/llama_eagle3.py +350 -0
  221. tpu_inference/models/jax/llama_guard_4.py +375 -0
  222. tpu_inference/models/jax/qwen2.py +390 -0
  223. tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
  224. tpu_inference/models/jax/qwen3.py +318 -0
  225. tpu_inference/models/jax/utils/__init__.py +13 -0
  226. tpu_inference/models/jax/utils/file_utils.py +110 -0
  227. tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
  228. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  229. tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
  230. tpu_inference/models/jax/utils/weight_utils.py +621 -0
  231. tpu_inference/models/vllm/__init__.py +13 -0
  232. tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
  233. tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
  234. tpu_inference/platforms/__init__.py +16 -0
  235. tpu_inference/platforms/tpu_platform.py +258 -0
  236. tpu_inference/runner/__init__.py +13 -0
  237. tpu_inference/runner/block_table.py +122 -0
  238. tpu_inference/runner/compilation_manager.py +890 -0
  239. tpu_inference/runner/input_batch.py +435 -0
  240. tpu_inference/runner/kv_cache.py +166 -0
  241. tpu_inference/runner/kv_cache_manager.py +508 -0
  242. tpu_inference/runner/lora_utils.py +106 -0
  243. tpu_inference/runner/multimodal_manager.py +231 -0
  244. tpu_inference/runner/persistent_batch_manager.py +296 -0
  245. tpu_inference/runner/speculative_decoding_manager.py +262 -0
  246. tpu_inference/runner/structured_decoding_manager.py +101 -0
  247. tpu_inference/runner/tpu_runner.py +1768 -0
  248. tpu_inference/runner/utils.py +426 -0
  249. tpu_inference/spec_decode/__init__.py +13 -0
  250. tpu_inference/spec_decode/jax/__init__.py +13 -0
  251. tpu_inference/spec_decode/jax/eagle3.py +430 -0
  252. tpu_inference/tpu_info.py +92 -0
  253. tpu_inference/utils.py +345 -0
  254. tpu_inference/worker/__init__.py +13 -0
  255. tpu_inference/worker/tpu_worker.py +468 -0
  256. tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
  257. tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
  258. tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
  259. tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
  260. tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
tpu_inference/models/jax/llama3.py
@@ -0,0 +1,436 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from itertools import islice
+from typing import List, Optional, Tuple
+
+import jax
+import jax.numpy as jnp
+from flax import nnx
+from jax.sharding import Mesh
+from transformers import LlamaConfig, modeling_flax_utils
+from vllm.config import VllmConfig
+
+from tpu_inference import utils
+from tpu_inference.distributed.jax_parallel_state import get_pp_group
+from tpu_inference.layers.common.attention_interface import attention
+from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.common.sharding import ShardingAxisName
+from tpu_inference.layers.jax.pp_utils import PPMissingLayer, make_layers
+from tpu_inference.layers.jax.rope_interface import apply_rope
+from tpu_inference.logger import init_logger
+from tpu_inference.models.jax.jax_intermediate_tensor import \
+    JaxIntermediateTensors
+from tpu_inference.models.jax.utils.weight_utils import (get_default_maps,
+                                                         load_hf_weights)
+
+logger = init_logger(__name__)
+
+init_fn = nnx.initializers.uniform()
+
+
+class LlamaMLP(nnx.Module):
+
+    def __init__(self, config: LlamaConfig, dtype: jnp.dtype, rng: nnx.Rngs):
+        hidden_size = config.hidden_size
+        intermediate_size = config.intermediate_size
+        act = config.hidden_act
+
+        self.gate_proj = nnx.Linear(
+            hidden_size,
+            intermediate_size,
+            use_bias=False,
+            param_dtype=dtype,
+            kernel_init=nnx.with_partitioning(
+                init_fn, (None, ShardingAxisName.MLP_TENSOR)),
+            rngs=rng,
+        )
+        self.up_proj = nnx.Linear(
+            hidden_size,
+            intermediate_size,
+            use_bias=False,
+            param_dtype=dtype,
+            kernel_init=nnx.with_partitioning(
+                init_fn, (None, ShardingAxisName.MLP_TENSOR)),
+            rngs=rng,
+        )
+        self.down_proj = nnx.Linear(
+            intermediate_size,
+            hidden_size,
+            use_bias=False,
+            param_dtype=dtype,
+            kernel_init=nnx.with_partitioning(
+                init_fn, (ShardingAxisName.MLP_TENSOR, None)),
+            rngs=rng,
+        )
+        self.act_fn = modeling_flax_utils.ACT2FN[act]
+
+    def __call__(self, x: jax.Array) -> jax.Array:
+        gate = self.act_fn(self.gate_proj(x))
+        up = self.up_proj(x)
+        fuse = gate * up
+        result = self.down_proj(fuse)
+        return result
+
+
+class LlamaAttention(nnx.Module):
+
+    def __init__(self, config: LlamaConfig, dtype: jnp.dtype, rng: nnx.Rngs,
+                 mesh: Mesh, kv_cache_dtype: str):
+        self.hidden_size = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.num_kv_heads = config.num_key_value_heads
+        self.rope_theta = config.rope_theta
+        self.rope_scaling = getattr(config, "rope_scaling", None)
+
+        self.head_dim_original = getattr(config, "head_dim",
+                                         self.hidden_size // self.num_heads)
+        self.head_dim = utils.get_padded_head_dim(self.head_dim_original)
+
+        sharding_size = mesh.shape["model"] * mesh.shape.get("attn_dp", 1)
+        self.num_heads = utils.get_padded_num_heads(self.num_heads,
+                                                    sharding_size)
+        self.num_kv_heads = utils.get_padded_num_heads(self.num_kv_heads,
+                                                       sharding_size)
+
+        self.mesh = mesh
+
+        self.q_proj = nnx.Einsum(
+            "TD,DNH->TNH",
+            (self.hidden_size, self.num_heads, self.head_dim),
+            param_dtype=dtype,
+            kernel_init=nnx.with_partitioning(
+                init_fn, (None, ShardingAxisName.ATTN_HEAD, None)),
+            rngs=rng,
+        )
+        self.k_proj = nnx.Einsum(
+            "TD,DKH->TKH",
+            (self.hidden_size, self.num_kv_heads, self.head_dim),
+            param_dtype=dtype,
+            kernel_init=nnx.with_partitioning(
+                init_fn, (None, ShardingAxisName.ATTN_HEAD, None)),
+            rngs=rng,
+        )
+        self.v_proj = nnx.Einsum(
+            "TD,DKH->TKH",
+            (self.hidden_size, self.num_kv_heads, self.head_dim),
+            param_dtype=dtype,
+            kernel_init=nnx.with_partitioning(
+                init_fn, (None, ShardingAxisName.ATTN_HEAD, None)),
+            rngs=rng,
+        )
+        self.o_proj = nnx.Einsum(
+            "TNH,NHD->TD",
+            (self.num_heads, self.head_dim, self.hidden_size),
+            param_dtype=dtype,
+            kernel_init=nnx.with_partitioning(
+                init_fn, (ShardingAxisName.ATTN_HEAD, None, None)),
+            rngs=rng,
+        )
+
+        self._q_scale = 1.0
+        self._k_scale = 1.0
+        self._v_scale = 1.0
+        self.kv_cache_quantized_dtype = None
+        if kv_cache_dtype != "auto":
+            self.kv_cache_quantized_dtype = utils.get_jax_dtype_from_str_dtype(
+                kv_cache_dtype)
+
+    def __call__(
+        self,
+        kv_cache: Optional[jax.Array],
+        x: jax.Array,
+        attention_metadata: AttentionMetadata,
+    ) -> Tuple[jax.Array, jax.Array]:
+        md = attention_metadata
+        # q: (T, N, H)
+        q = self.q_proj(x)
+        q = apply_rope(q, md.input_positions, self.head_dim_original,
+                       self.rope_theta, self.rope_scaling)
+        # k: (T, K, H)
+        k = self.k_proj(x)
+        k = apply_rope(k, md.input_positions, self.head_dim_original,
+                       self.rope_theta, self.rope_scaling)
+        # v: (T, K, H)
+        v = self.v_proj(x)
+        # o: (T, N, H)
+        q_scale = k_scale = v_scale = None
+        if self.kv_cache_quantized_dtype:
+            # TODO(kyuyeunk/jacobplatin): Enable w8a8 when VREG spill issue is resolved.
+            # q_scale = self._q_scale
+            k_scale = self._k_scale
+            v_scale = self._v_scale
+            k, v = utils.quantize_kv(k, v, self.kv_cache_quantized_dtype,
+                                     k_scale, v_scale)
+        new_kv_cache, outputs = attention(
+            kv_cache,
+            q,
+            k,
+            v,
+            attention_metadata,
+            self.mesh,
+            self.head_dim_original,
+            q_scale=q_scale,
+            k_scale=k_scale,
+            v_scale=v_scale,
+        )
+        # (T, D)
+        o = self.o_proj(outputs)
+        return new_kv_cache, o
+
+
+class LlamaDecoderLayer(nnx.Module):
+
+    def __init__(self, config: LlamaConfig, dtype: jnp.dtype, rng: nnx.Rngs,
+                 mesh: Mesh, kv_cache_dtype: str):
+        rms_norm_eps = config.rms_norm_eps
+        hidden_size = config.hidden_size
+
+        self.input_layernorm = nnx.RMSNorm(
+            hidden_size,
+            epsilon=rms_norm_eps,
+            param_dtype=dtype,
+            scale_init=nnx.with_partitioning(init_fn, (None, )),
+            rngs=rng,
+        )
+        self.self_attn = LlamaAttention(config=config,
+                                        dtype=dtype,
+                                        rng=rng,
+                                        mesh=mesh,
+                                        kv_cache_dtype=kv_cache_dtype)
+        self.post_attention_layernorm = nnx.RMSNorm(
+            hidden_size,
+            epsilon=rms_norm_eps,
+            param_dtype=dtype,
+            scale_init=nnx.with_partitioning(init_fn, (None, )),
+            rngs=rng,
+        )
+        self.mlp = LlamaMLP(
+            config=config,
+            dtype=dtype,
+            rng=rng,
+        )
+
+    def __call__(
+        self,
+        kv_cache: jax.Array,
+        x: jax.Array,
+        attention_metadata: AttentionMetadata,
+    ) -> Tuple[jax.Array, jax.Array]:
+        hidden_states = self.input_layernorm(x)
+        kv_cache, attn_output = self.self_attn(
+            kv_cache,
+            hidden_states,
+            attention_metadata,
+        )
+        attn_output += x
+
+        residual = attn_output
+        attn_output = self.post_attention_layernorm(attn_output)
+        outputs = self.mlp(attn_output)
+        outputs = residual + outputs
+        return kv_cache, outputs
+
+
+class LlamaModel(nnx.Module):
+
+    def __init__(self, vllm_config: VllmConfig, rng: nnx.Rngs,
+                 mesh: Mesh) -> None:
+        model_config = vllm_config.model_config
+        hf_config = model_config.hf_config
+        vocab_size = model_config.get_vocab_size()
+        dtype = model_config.dtype
+        rms_norm_eps = hf_config.rms_norm_eps
+        hidden_size = hf_config.hidden_size
+
+        self.is_first_rank = get_pp_group().is_first_rank
+        self.is_last_rank = get_pp_group().is_last_rank
+
+        if self.is_first_rank or (hf_config.tie_word_embeddings
+                                  and self.is_last_rank):
+            self.embed = nnx.Embed(
+                num_embeddings=vocab_size,
+                features=hidden_size,
+                param_dtype=dtype,
+                embedding_init=nnx.with_partitioning(
+                    init_fn, (ShardingAxisName.VOCAB, None)),
+                rngs=rng,
+            )
+        else:
+            self.embed = PPMissingLayer()
+
+        self.start_layer, self.end_layer, self.layers = make_layers(
+            hf_config.num_hidden_layers,
+            lambda: LlamaDecoderLayer(
+                config=hf_config,
+                dtype=dtype,
+                rng=rng,
+                mesh=mesh,
+                # TODO (jacobplatin): we should refactor this to pass a dtype (or config) directly
+                kv_cache_dtype=vllm_config.cache_config.cache_dtype))
+        if self.is_last_rank:
+            self.norm = nnx.RMSNorm(
+                hidden_size,
+                epsilon=rms_norm_eps,
+                param_dtype=dtype,
+                scale_init=nnx.with_partitioning(init_fn, (None, )),
+                rngs=rng,
+            )
+        else:
+            self.norm = PPMissingLayer()
+
+        if self.is_last_rank:
+            if model_config.hf_config.tie_word_embeddings:
+                self.lm_head = self.embed.embedding
+            else:
+                self.lm_head = nnx.Param(
+                    init_fn(rng.params(), (hidden_size, vocab_size), dtype),
+                    sharding=(None, ShardingAxisName.VOCAB),
+                )
+        else:
+            self.lm_head = PPMissingLayer()
+
+        self.aux_hidden_state_layers = []
+        if vllm_config.speculative_config and vllm_config.speculative_config.method == "eagle3":
+            self.aux_hidden_state_layers = self.get_eagle3_aux_hidden_state_layers(
+            )
+
+    def get_eagle3_aux_hidden_state_layers(self):
+        num_layers = len(self.layers)
+        return (2, num_layers // 2, num_layers - 3)
+
+    def __call__(
+        self,
+        kv_caches: List[jax.Array],
+        input_ids: jax.Array,
+        attention_metadata: AttentionMetadata,
+        intermediate_tensors: JaxIntermediateTensors | None,
+    ) -> Tuple[List[jax.Array], jax.Array, List[jax.Array]] | Tuple[
+            List[jax.Array], JaxIntermediateTensors]:
+        if self.is_first_rank:
+            x = self.embed(input_ids)
+        else:
+            assert intermediate_tensors is not None
+            x = intermediate_tensors["hidden_states"]
+
+        aux_hidden_states = []
+        for i, layer in enumerate(
+                islice(self.layers, self.start_layer, self.end_layer)):
+            if i in self.aux_hidden_state_layers:
+                aux_hidden_states.append(x)
+            kv_cache = kv_caches[i]
+            kv_cache, x = layer(
+                kv_cache,
+                x,
+                attention_metadata,
+            )
+            kv_caches[i] = kv_cache
+        if not self.is_last_rank:
+            # Note: add aux_hidden_states to make the output spec consistent.
+            return kv_caches, JaxIntermediateTensors({"hidden_states":
+                                                      x}), aux_hidden_states
+        x = self.norm(x)
+        return kv_caches, x, aux_hidden_states
+
+
+class LlamaForCausalLM(nnx.Module):
+
+    def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array,
+                 mesh: Mesh) -> None:
+        self.vllm_config = vllm_config
+        self.rng = nnx.Rngs(rng_key)
+        self.mesh = mesh
+
+        self.model = LlamaModel(
+            vllm_config=vllm_config,
+            rng=self.rng,
+            mesh=mesh,
+        )
+
+        self.pp_missing_layers = []
+        for path, module in nnx.iter_graph(self.model):
+            if isinstance(module, PPMissingLayer):
+                # the path should be sth like ('layers', '0')
+                self.pp_missing_layers.append('.'.join([str(s) for s in path]))
+
+    def __call__(
+        self,
+        kv_caches: List[jax.Array],
+        input_ids: jax.Array,
+        attention_metadata: AttentionMetadata,
+        _input_embeds,
+        _input_positions,
+        _layer_name_to_kv_cache,
+        _lora_metadata,
+        intermediate_tensors: JaxIntermediateTensors,
+        _is_first_rank: bool,
+        _is_last_rank: bool,
+        *args,
+    ) -> Tuple[List[jax.Array], jax.Array, List[jax.Array]] | Tuple[
+            List[jax.Array], JaxIntermediateTensors]:
+        return self.model(
+            kv_caches,
+            input_ids,
+            attention_metadata,
+            intermediate_tensors,
+        )
+
+    def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
+        if self.vllm_config.model_config.hf_config.tie_word_embeddings:
+            logits = jnp.dot(hidden_states, self.model.lm_head.value.T)
+        else:
+            logits = jnp.dot(hidden_states, self.model.lm_head.value)
+        return logits
+
+    def load_weights(self, rng_key: jax.Array):
+        # NOTE: Since we are using nnx.eval_shape to init the model,
+        # we have to pass dynamic arrays here for __call__'s usage.
+        self.rng = nnx.Rngs(rng_key)
+
+        # Key: path to a HF layer weight
+        # Value: path to a nnx layer weight
+        mappings = {
+            "model.embed_tokens": "model.embed.embedding",
+            "model.layers.*.input_layernorm":
+            "model.layers.*.input_layernorm.scale",
+            "model.layers.*.mlp.down_proj":
+            "model.layers.*.mlp.down_proj.kernel",
+            "model.layers.*.mlp.gate_proj":
+            "model.layers.*.mlp.gate_proj.kernel",
+            "model.layers.*.mlp.up_proj": "model.layers.*.mlp.up_proj.kernel",
+            "model.layers.*.post_attention_layernorm":
+            "model.layers.*.post_attention_layernorm.scale",
+            "model.layers.*.self_attn.k_proj":
+            "model.layers.*.self_attn.k_proj.kernel",
+            "model.layers.*.self_attn.o_proj":
+            "model.layers.*.self_attn.o_proj.kernel",
+            "model.layers.*.self_attn.q_proj":
+            "model.layers.*.self_attn.q_proj.kernel",
+            "model.layers.*.self_attn.v_proj":
+            "model.layers.*.self_attn.v_proj.kernel",
+            "model.norm": "model.norm.scale",
+        }
+        # Add lm_head mapping only if it's not tied to embeddings
+        if not self.vllm_config.model_config.hf_config.tie_word_embeddings:
+            mappings.update({
+                "lm_head": "model.lm_head",
+            })
+
+        metadata_map = get_default_maps(self.vllm_config.model_config,
+                                        self.mesh, mappings)
+        load_hf_weights(vllm_config=self.vllm_config,
+                        model=self,
+                        metadata_map=metadata_map,
+                        mesh=self.mesh,
+                        pp_missing_layers=self.pp_missing_layers)
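
For orientation, the LlamaMLP class in the diff above implements the standard gated (SwiGLU-style) feed-forward used by Llama models. The following minimal plain-JAX sketch, which is not part of the package and uses illustrative shapes and variable names, shows the same computation without the nnx module, sharding, and dtype machinery:

    import jax
    import jax.numpy as jnp

    # Illustrative sizes only; the real module reads config.hidden_size
    # and config.intermediate_size from the HF LlamaConfig.
    seq_len, hidden_size, intermediate_size = 4, 16, 64
    k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)

    x = jax.random.normal(k1, (seq_len, hidden_size))
    w_gate = jax.random.normal(k2, (hidden_size, intermediate_size))
    w_up = jax.random.normal(k3, (hidden_size, intermediate_size))
    w_down = jax.random.normal(k4, (intermediate_size, hidden_size))

    # out = (act(x @ W_gate) * (x @ W_up)) @ W_down
    # Llama configs set hidden_act="silu", so ACT2FN resolves to SiLU.
    gate = jax.nn.silu(x @ w_gate)
    up = x @ w_up
    out = (gate * up) @ w_down
    assert out.shape == (seq_len, hidden_size)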