tpu-inference 0.12.0.dev20251222__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (260) hide show
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +67 -0
  6. tests/core/test_dp_scheduler.py +724 -0
  7. tests/core/test_init.py +63 -0
  8. tests/distributed/__init__.py +13 -0
  9. tests/distributed/test_distributed_utils.py +120 -0
  10. tests/distributed/test_tpu_connector.py +478 -0
  11. tests/e2e/__init__.py +13 -0
  12. tests/e2e/test_async_scheduler.py +211 -0
  13. tests/e2e/test_data_parallel.py +393 -0
  14. tests/e2e/test_local_disagg.py +257 -0
  15. tests/e2e/test_model_loader.py +268 -0
  16. tests/e2e/test_multi_modal_inference.py +111 -0
  17. tests/e2e/test_pipeline_parallel.py +265 -0
  18. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  19. tests/e2e/test_sampling_params.py +269 -0
  20. tests/e2e/test_speculative_decoding.py +291 -0
  21. tests/e2e/test_structured_decoding.py +46 -0
  22. tests/executors/__init__.py +13 -0
  23. tests/executors/test_ray_distributed_executor.py +199 -0
  24. tests/experimental/__init__.py +13 -0
  25. tests/experimental/test_llama3_jax_stashed.py +208 -0
  26. tests/kernels/__init__.py +13 -0
  27. tests/kernels/collectives/__init__.py +13 -0
  28. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  29. tests/kernels/fused_moe_v1_test.py +388 -0
  30. tests/kernels/gmm_test.py +205 -0
  31. tests/kernels/mla_v1_test.py +498 -0
  32. tests/kernels/quantized_matmul_kernel_test.py +159 -0
  33. tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
  34. tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
  35. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
  36. tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
  37. tests/layers/__init__.py +13 -0
  38. tests/layers/common/__init__.py +13 -0
  39. tests/layers/common/test_attention_interface.py +156 -0
  40. tests/layers/common/test_quantization.py +149 -0
  41. tests/layers/jax/__init__.py +13 -0
  42. tests/layers/jax/attention/__init__.py +13 -0
  43. tests/layers/jax/attention/test_common_attention.py +103 -0
  44. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  45. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  46. tests/layers/jax/moe/__init__.py +13 -0
  47. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  48. tests/layers/jax/sample/__init__.py +13 -0
  49. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  50. tests/layers/jax/sample/test_sampling.py +115 -0
  51. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  52. tests/layers/jax/test_layers.py +155 -0
  53. tests/layers/jax/test_qwix.py +969 -0
  54. tests/layers/jax/test_rope.py +93 -0
  55. tests/layers/jax/test_sharding.py +159 -0
  56. tests/layers/jax/test_transformer_block.py +152 -0
  57. tests/layers/vllm/__init__.py +13 -0
  58. tests/layers/vllm/test_attention.py +363 -0
  59. tests/layers/vllm/test_awq.py +405 -0
  60. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
  62. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
  63. tests/layers/vllm/test_fp8.py +17 -0
  64. tests/layers/vllm/test_mxfp4.py +297 -0
  65. tests/layers/vllm/test_unquantized.py +621 -0
  66. tests/layers/vllm/utils.py +72 -0
  67. tests/lora/__init__.py +13 -0
  68. tests/lora/conftest.py +46 -0
  69. tests/lora/test_bgmv.py +57 -0
  70. tests/lora/test_layers.py +666 -0
  71. tests/lora/test_lora.py +147 -0
  72. tests/lora/test_lora_perf.py +67 -0
  73. tests/lora/utils.py +88 -0
  74. tests/models/__init__.py +13 -0
  75. tests/models/common/__init__.py +13 -0
  76. tests/models/common/test_model_loader.py +455 -0
  77. tests/models/jax/__init__.py +13 -0
  78. tests/models/jax/test_deepseek_v3.py +401 -0
  79. tests/models/jax/test_llama3.py +184 -0
  80. tests/models/jax/test_llama4.py +298 -0
  81. tests/models/jax/test_llama_eagle3.py +197 -0
  82. tests/models/jax/test_llama_guard_4.py +242 -0
  83. tests/models/jax/test_qwen2.py +172 -0
  84. tests/models/jax/test_qwen2_5_vl.py +606 -0
  85. tests/models/jax/test_qwen3.py +169 -0
  86. tests/models/jax/test_weight_loading.py +180 -0
  87. tests/models/jax/utils/__init__.py +13 -0
  88. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  89. tests/platforms/__init__.py +13 -0
  90. tests/platforms/test_tpu_platform.py +54 -0
  91. tests/runner/__init__.py +13 -0
  92. tests/runner/test_block_table.py +395 -0
  93. tests/runner/test_input_batch.py +226 -0
  94. tests/runner/test_kv_cache.py +220 -0
  95. tests/runner/test_kv_cache_manager.py +498 -0
  96. tests/runner/test_multimodal_manager.py +429 -0
  97. tests/runner/test_persistent_batch_manager.py +84 -0
  98. tests/runner/test_speculative_decoding_manager.py +368 -0
  99. tests/runner/test_structured_decoding_manager.py +220 -0
  100. tests/runner/test_tpu_runner.py +202 -0
  101. tests/runner/test_tpu_runner_dp.py +1033 -0
  102. tests/runner/test_tpu_runner_mesh.py +200 -0
  103. tests/runner/test_utils.py +411 -0
  104. tests/spec_decode/__init__.py +13 -0
  105. tests/spec_decode/test_eagle3.py +311 -0
  106. tests/test_base.py +215 -0
  107. tests/test_envs.py +280 -0
  108. tests/test_tpu_info.py +134 -0
  109. tests/test_utils.py +193 -0
  110. tests/worker/__init__.py +13 -0
  111. tests/worker/tpu_worker_test.py +414 -0
  112. tpu_inference/__init__.py +67 -0
  113. tpu_inference/core/__init__.py +13 -0
  114. tpu_inference/core/core_tpu.py +786 -0
  115. tpu_inference/core/disagg_executor.py +118 -0
  116. tpu_inference/core/disagg_utils.py +49 -0
  117. tpu_inference/core/sched/__init__.py +13 -0
  118. tpu_inference/core/sched/dp_scheduler.py +814 -0
  119. tpu_inference/distributed/__init__.py +13 -0
  120. tpu_inference/distributed/jax_parallel_state.py +81 -0
  121. tpu_inference/distributed/tpu_connector.py +732 -0
  122. tpu_inference/distributed/utils.py +112 -0
  123. tpu_inference/env_override.py +9 -0
  124. tpu_inference/envs.py +191 -0
  125. tpu_inference/executors/__init__.py +13 -0
  126. tpu_inference/executors/ray_distributed_executor.py +399 -0
  127. tpu_inference/experimental/__init__.py +13 -0
  128. tpu_inference/experimental/llama3_jax_stashed.py +272 -0
  129. tpu_inference/kernels/__init__.py +13 -0
  130. tpu_inference/kernels/collectives/__init__.py +13 -0
  131. tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
  132. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
  133. tpu_inference/kernels/collectives/util.py +47 -0
  134. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  135. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  136. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  137. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  138. tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
  139. tpu_inference/kernels/megablox/__init__.py +13 -0
  140. tpu_inference/kernels/megablox/common.py +54 -0
  141. tpu_inference/kernels/megablox/gmm.py +646 -0
  142. tpu_inference/kernels/mla/__init__.py +13 -0
  143. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  144. tpu_inference/kernels/mla/v1/kernel.py +1340 -0
  145. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  146. tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
  147. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  148. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  149. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  150. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  151. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
  152. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
  153. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  154. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  155. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
  156. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
  157. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
  158. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
  159. tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
  160. tpu_inference/layers/__init__.py +13 -0
  161. tpu_inference/layers/common/__init__.py +13 -0
  162. tpu_inference/layers/common/attention_interface.py +403 -0
  163. tpu_inference/layers/common/attention_metadata.py +48 -0
  164. tpu_inference/layers/common/binary_search.py +295 -0
  165. tpu_inference/layers/common/quant_methods.py +23 -0
  166. tpu_inference/layers/common/quantization.py +270 -0
  167. tpu_inference/layers/common/sharding.py +600 -0
  168. tpu_inference/layers/jax/__init__.py +13 -0
  169. tpu_inference/layers/jax/attention/__init__.py +13 -0
  170. tpu_inference/layers/jax/attention/attention.py +268 -0
  171. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
  172. tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
  173. tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
  174. tpu_inference/layers/jax/base.py +165 -0
  175. tpu_inference/layers/jax/constants.py +101 -0
  176. tpu_inference/layers/jax/layers.py +315 -0
  177. tpu_inference/layers/jax/misc.py +30 -0
  178. tpu_inference/layers/jax/moe/__init__.py +13 -0
  179. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
  180. tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
  181. tpu_inference/layers/jax/moe/moe.py +249 -0
  182. tpu_inference/layers/jax/pp_utils.py +53 -0
  183. tpu_inference/layers/jax/rope.py +294 -0
  184. tpu_inference/layers/jax/rope_interface.py +228 -0
  185. tpu_inference/layers/jax/sample/__init__.py +13 -0
  186. tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
  187. tpu_inference/layers/jax/sample/sampling.py +110 -0
  188. tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
  189. tpu_inference/layers/jax/transformer_block.py +121 -0
  190. tpu_inference/layers/vllm/__init__.py +13 -0
  191. tpu_inference/layers/vllm/attention.py +221 -0
  192. tpu_inference/layers/vllm/fused_moe.py +502 -0
  193. tpu_inference/layers/vllm/linear_common.py +221 -0
  194. tpu_inference/layers/vllm/quantization/__init__.py +55 -0
  195. tpu_inference/layers/vllm/quantization/awq.py +221 -0
  196. tpu_inference/layers/vllm/quantization/common.py +124 -0
  197. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  198. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
  199. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
  200. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  201. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
  202. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
  203. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  204. tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
  205. tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
  206. tpu_inference/layers/vllm/sharding.py +244 -0
  207. tpu_inference/logger.py +10 -0
  208. tpu_inference/lora/__init__.py +13 -0
  209. tpu_inference/lora/torch_lora_ops.py +98 -0
  210. tpu_inference/lora/torch_punica_tpu.py +310 -0
  211. tpu_inference/models/__init__.py +13 -0
  212. tpu_inference/models/common/__init__.py +13 -0
  213. tpu_inference/models/common/model_loader.py +520 -0
  214. tpu_inference/models/jax/__init__.py +13 -0
  215. tpu_inference/models/jax/deepseek_v3.py +978 -0
  216. tpu_inference/models/jax/gpt_oss.py +508 -0
  217. tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
  218. tpu_inference/models/jax/llama3.py +436 -0
  219. tpu_inference/models/jax/llama4.py +643 -0
  220. tpu_inference/models/jax/llama_eagle3.py +350 -0
  221. tpu_inference/models/jax/llama_guard_4.py +375 -0
  222. tpu_inference/models/jax/qwen2.py +390 -0
  223. tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
  224. tpu_inference/models/jax/qwen3.py +318 -0
  225. tpu_inference/models/jax/utils/__init__.py +13 -0
  226. tpu_inference/models/jax/utils/file_utils.py +110 -0
  227. tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
  228. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  229. tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
  230. tpu_inference/models/jax/utils/weight_utils.py +621 -0
  231. tpu_inference/models/vllm/__init__.py +13 -0
  232. tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
  233. tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
  234. tpu_inference/platforms/__init__.py +16 -0
  235. tpu_inference/platforms/tpu_platform.py +258 -0
  236. tpu_inference/runner/__init__.py +13 -0
  237. tpu_inference/runner/block_table.py +122 -0
  238. tpu_inference/runner/compilation_manager.py +890 -0
  239. tpu_inference/runner/input_batch.py +435 -0
  240. tpu_inference/runner/kv_cache.py +166 -0
  241. tpu_inference/runner/kv_cache_manager.py +508 -0
  242. tpu_inference/runner/lora_utils.py +106 -0
  243. tpu_inference/runner/multimodal_manager.py +231 -0
  244. tpu_inference/runner/persistent_batch_manager.py +296 -0
  245. tpu_inference/runner/speculative_decoding_manager.py +262 -0
  246. tpu_inference/runner/structured_decoding_manager.py +101 -0
  247. tpu_inference/runner/tpu_runner.py +1768 -0
  248. tpu_inference/runner/utils.py +426 -0
  249. tpu_inference/spec_decode/__init__.py +13 -0
  250. tpu_inference/spec_decode/jax/__init__.py +13 -0
  251. tpu_inference/spec_decode/jax/eagle3.py +430 -0
  252. tpu_inference/tpu_info.py +92 -0
  253. tpu_inference/utils.py +345 -0
  254. tpu_inference/worker/__init__.py +13 -0
  255. tpu_inference/worker/tpu_worker.py +468 -0
  256. tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
  257. tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
  258. tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
  259. tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
  260. tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
@@ -0,0 +1,350 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import List, Tuple
16
+
17
+ import jax
18
+ import jax.numpy as jnp
19
+ from flax import nnx
20
+ from jax.sharding import Mesh
21
+ from transformers import LlamaConfig
22
+ from vllm.config import VllmConfig
23
+
24
+ from tpu_inference.layers.common.attention_metadata import AttentionMetadata
25
+ from tpu_inference.logger import init_logger
26
+ from tpu_inference.models.jax.llama3 import LlamaDecoderLayer
27
+ from tpu_inference.models.jax.utils.weight_utils import (MetadataMap,
28
+ get_default_maps,
29
+ load_hf_weights)
30
+
31
# Module-level logger, named after this module.
logger = init_logger(__name__)

# Shared initializer handed to every parameter created in this file
# (kernel_init / scale_init / embedding_init).
init_fn = nnx.initializers.uniform()
34
+
35
+
36
class Eagle3LlamaDecoderLayer(LlamaDecoderLayer):
    """Decoder layer used by the EAGLE-3 Llama draft model.

    Builds on ``LlamaDecoderLayer`` but feeds attention with the normalized
    token embeddings concatenated with the normalized incoming hidden states.
    The q/k/v projections are therefore replaced by ones whose input width is
    twice the attention hidden size, and a dedicated ``hidden_norm`` is added
    for the hidden-state half of that input.
    """

    def __init__(self, config: LlamaConfig, dtype: jnp.dtype, rng: nnx.Rngs,
                 mesh: Mesh, kv_cache_dtype: str):
        """Create the base layer, then replace its q/k/v projections and norms.

        Args:
            config: HF Llama config of the draft model.
            dtype: Parameter/compute dtype for the replaced submodules.
            rng: RNG container used to initialize the new parameters.
            mesh: Device mesh, forwarded to the base layer.
            kv_cache_dtype: KV cache dtype string, forwarded to the base layer.
        """
        super().__init__(config,
                         dtype=dtype,
                         rng=rng,
                         mesh=mesh,
                         kv_cache_dtype=kv_cache_dtype)
        self.config = config

        attn = self.self_attn
        # Attention input is concat([embeds, hidden_states]) -> 2x the width.
        concat_width = 2 * attn.hidden_size

        def build_projection(equation: str, head_count: int) -> nnx.Einsum:
            # All three projections differ only in equation and head count.
            return nnx.Einsum(
                equation,
                (concat_width, head_count, attn.head_dim),
                param_dtype=dtype,
                dtype=dtype,
                kernel_init=nnx.with_partitioning(init_fn,
                                                  (None, "model", None)),
                rngs=rng,
            )

        # Creation order (q, k, v) is kept so RNG consumption is unchanged.
        attn.q_proj = build_projection("TD,DNH->TNH", attn.num_heads)
        attn.k_proj = build_projection("TD,DKH->TKH", attn.num_kv_heads)
        attn.v_proj = build_projection("TD,DKH->TKH", attn.num_kv_heads)

        # Replace the input layernorm with one that pins `dtype`, to avoid
        # unexpected upcasting.
        self.input_layernorm = nnx.RMSNorm(
            config.hidden_size,
            epsilon=config.rms_norm_eps,
            param_dtype=dtype,
            dtype=dtype,
            scale_init=nnx.with_partitioning(init_fn, (None, )),
            rngs=rng,
        )
        # NOTE(review): unlike input_layernorm, no explicit `dtype` is passed
        # here -- confirm this difference is intentional.
        self.hidden_norm = nnx.RMSNorm(
            config.hidden_size,
            epsilon=config.rms_norm_eps,
            param_dtype=dtype,
            scale_init=nnx.with_partitioning(init_fn, (None, )),
            rngs=rng,
        )

    def _norm_before_residual(
            self, hidden_states: jax.Array) -> tuple[jax.Array, jax.Array]:
        """Normalize first; the normalized tensor also serves as residual."""
        normed = self.hidden_norm(hidden_states)
        return normed, normed

    def _norm_after_residual(
            self, hidden_states: jax.Array) -> tuple[jax.Array, jax.Array]:
        """Keep the raw input as residual and return its normalized form."""
        return self.hidden_norm(hidden_states), hidden_states

    def __call__(
        self,
        kv_cache: jax.Array,
        embeds: jax.Array,
        hidden_states: jax.Array,
        attention_metadata: AttentionMetadata,
    ) -> Tuple[jax.Array, jax.Array, jax.Array]:
        """Run attention and MLP for one draft step.

        Args:
            kv_cache: KV cache of this layer.
            embeds: Token embeddings for the current inputs.
            hidden_states: Incoming hidden states.
            attention_metadata: Metadata forwarded to the attention module.

        Returns:
            ``(kv_cache, mlp_output, residual)`` where ``kv_cache`` is the
            value returned by attention, ``mlp_output`` is the MLP result and
            ``residual`` is the post-attention sum fed into the
            post-attention layernorm.
        """
        normed_embeds = self.input_layernorm(embeds)

        # The config flag decides whether the residual is taken before or
        # after `hidden_norm`; it defaults to "after" when absent.
        if getattr(self.config, "norm_before_residual", False):
            split_norm = self._norm_before_residual
        else:
            split_norm = self._norm_after_residual
        normed_hidden, residual = split_norm(hidden_states=hidden_states)

        attn_input = jnp.concatenate([normed_embeds, normed_hidden], axis=-1)
        kv_cache, attn_output = self.self_attn(
            kv_cache,
            attn_input,
            attention_metadata,
        )

        # TODO(ranlihao): Check if this residual connection is correct.
        post_attn = attn_output + residual
        mlp_output = self.mlp(self.post_attention_layernorm(post_attn))

        return kv_cache, mlp_output, post_attn
132
+
133
+
134
class Eagle3LlamaModel(nnx.Module):
    """Backbone of the EAGLE-3 Llama draft model.

    Holds a token embedding table, exactly one ``Eagle3LlamaDecoderLayer``,
    a bias-free ``fc`` projection down to ``hidden_size`` and a final RMSNorm.
    ``fc`` is only constructed here; ``__call__`` does not invoke it.
    """

    def __init__(self, vllm_config: VllmConfig, rng: nnx.Rngs, mesh: Mesh):
        """Build the submodules from the draft model's HF config.

        Args:
            vllm_config: Config whose ``speculative_config.draft_model_config``
                describes the draft model.
            rng: RNG container used for parameter initialization.
            mesh: Device mesh forwarded to the decoder layer.
        """
        super().__init__()
        draft_hf = vllm_config.speculative_config.draft_model_config.hf_config
        dtype: jnp.dtype = jnp.bfloat16
        width = draft_hf.hidden_size

        self.embed_tokens = nnx.Embed(
            num_embeddings=draft_hf.vocab_size,
            features=width,
            param_dtype=dtype,
            embedding_init=nnx.with_partitioning(init_fn, ("model", None)),
            rngs=rng,
        )

        draft_layer = Eagle3LlamaDecoderLayer(
            config=draft_hf,
            dtype=dtype,
            rng=rng,
            mesh=mesh,
            # TODO (jacobplatin): we should refactor this to pass a dtype (or config) directly
            kv_cache_dtype=vllm_config.cache_config.cache_dtype)
        self.layers = [draft_layer]

        # `fc` takes three times the target hidden size, falling back to the
        # draft hidden size when the config has no `target_hidden_size`.
        fc_in_features = getattr(draft_hf, "target_hidden_size", width) * 3
        self.fc = nnx.Linear(
            in_features=fc_in_features,
            out_features=width,
            use_bias=False,
            param_dtype=dtype,
            kernel_init=nnx.with_partitioning(init_fn, (None, "model")),
            rngs=rng,
        )

        self.norm = nnx.RMSNorm(
            width,
            epsilon=draft_hf.rms_norm_eps,
            param_dtype=dtype,
            scale_init=nnx.with_partitioning(init_fn, (None, )),
            rngs=rng,
        )

    def __call__(
        self,
        kv_caches: List[jax.Array],
        input_ids: jax.Array,
        hidden_states: jax.Array,
        attention_metadata: AttentionMetadata,
    ) -> Tuple[List[jax.Array], jax.Array, List[jax.Array]]:
        """Run the single draft layer and the final norm.

        Args:
            kv_caches: List of KV caches; only the last entry is read and
                overwritten (in place) by this call.
            input_ids: Token ids to embed.
            hidden_states: Hidden states whose last dimension must equal the
                embedding width.
            attention_metadata: Metadata forwarded to the decoder layer.

        Returns:
            ``(kv_caches, normed_hidden_states, [pre_norm_hidden_states])``.
        """
        embeds = self.embed_tokens(input_ids)
        assert hidden_states.shape[-1] == embeds.shape[-1]
        assert len(self.layers) == 1

        # The trailing cache belongs to the draft model's only layer; the
        # preceding entries are the target model's caches.
        draft_cache, layer_out, layer_residual = self.layers[0](
            kv_caches[-1],
            embeds,
            hidden_states,
            attention_metadata,
        )
        kv_caches[-1] = draft_cache

        # TODO(ranlihao): Check if this residual connection is correct.
        pre_norm = layer_out + layer_residual
        return kv_caches, self.norm(pre_norm), [pre_norm]
207
+
208
+
209
def update_reshape_map_for_eagle3(vllm_config: VllmConfig,
                                  metadata_map: MetadataMap):
    """Override the q/k/v reshape entries for the EAGLE-3 draft layer.

    Each entry becomes ``(heads, head_size, 2 * hidden_size)``, where the
    values come from the draft model config. ``metadata_map.reshape_map`` is
    updated in place; nothing is returned.

    Args:
        vllm_config: Config providing ``speculative_config.draft_model_config``.
        metadata_map: Map whose ``reshape_map`` dict receives the overrides.
    """
    draft_model_config = vllm_config.speculative_config.draft_model_config
    draft_hf = draft_model_config.hf_config

    q_heads = draft_hf.num_attention_heads
    kv_heads = draft_hf.num_key_value_heads
    # The draft attention input is twice as wide as the hidden size.
    concat_width = 2 * draft_hf.hidden_size
    head_size = draft_model_config.get_head_size()

    heads_by_projection = {
        "q_proj": q_heads,
        "k_proj": kv_heads,
        "v_proj": kv_heads,
    }
    metadata_map.reshape_map.update({
        name: (head_count, head_size, concat_width)
        for name, head_count in heads_by_projection.items()
    })
224
+
225
+
226
class EagleLlama3ForCausalLM(nnx.Module):
    """EAGLE-3 Llama draft model with an LM head over the draft vocabulary.

    Wraps ``Eagle3LlamaModel`` and adds ``lm_head`` plus the
    ``draft_id_to_target_id`` table used by ``compute_logits`` to place draft
    logits into target-vocabulary positions.
    """

    def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array,
                 mesh: Mesh):
        """Construct the backbone, LM head and draft-to-target id table.

        Args:
            vllm_config: Config; its speculative and draft model configs must
                not be ``None``.
            rng_key: Key wrapped in ``nnx.Rngs`` for parameter initialization.
            mesh: Device mesh, stored and forwarded to the backbone.
        """
        nnx.Module.__init__(self)
        self.vllm_config = vllm_config
        self.rng = nnx.Rngs(rng_key)
        self.mesh = mesh
        dtype: jnp.dtype = jnp.bfloat16

        spec_config = vllm_config.speculative_config
        assert spec_config is not None
        draft_model_config = spec_config.draft_model_config
        assert draft_model_config is not None
        draft_hf = draft_model_config.hf_config

        self.model = Eagle3LlamaModel(
            vllm_config=vllm_config,
            rng=self.rng,
            mesh=mesh,
        )

        self.lm_head = nnx.Linear(
            draft_hf.hidden_size,
            draft_hf.draft_vocab_size,
            use_bias=False,
            param_dtype=dtype,
            kernel_init=nnx.with_partitioning(init_fn, (None, "model")),
            rngs=self.rng,
        )

        # One int32 entry per draft token id; zero until weights are loaded.
        id_offsets = jnp.zeros(draft_hf.draft_vocab_size, dtype=jnp.int32)
        self.draft_id_to_target_id = nnx.Param(id_offsets, sharding=(None, ))

    def __call__(
        self,
        kv_caches: List[jax.Array],
        input_ids: jax.Array,
        hidden_states: jax.Array,
        attention_metadata: AttentionMetadata,
    ) -> Tuple[List[jax.Array], jax.Array, List[jax.Array]]:
        """Delegate the forward pass to the wrapped ``Eagle3LlamaModel``."""
        return self.model(
            kv_caches,
            input_ids,
            hidden_states,
            attention_metadata,
        )

    def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
        """Project to draft logits and scatter them into target-vocab slots.

        The target position of draft id ``i`` is
        ``i + draft_id_to_target_id[i]``. Every other target position is
        filled with ``-inf``.

        Args:
            hidden_states: Input to ``lm_head``.

        Returns:
            Array of shape ``(logits.shape[0], target_vocab_size)`` with the
            same dtype as the draft logits.
        """
        draft_logits = self.lm_head(hidden_states)

        target_vocab_size = self.vllm_config.model_config.get_vocab_size()
        draft_hf = (
            self.vllm_config.speculative_config.draft_model_config.hf_config)

        draft_ids = jnp.arange(draft_hf.draft_vocab_size, dtype=jnp.int32)
        target_ids = draft_ids + self.draft_id_to_target_id.value

        full_logits = jnp.full(
            (draft_logits.shape[0], target_vocab_size),
            -jnp.inf,
            dtype=draft_logits.dtype,
        )
        return full_logits.at[:, target_ids].set(draft_logits)

    def combine_hidden_states(self, hidden_states: jax.Array) -> jax.Array:
        """Apply the backbone's ``fc`` projection to ``hidden_states``."""
        return self.model.fc(hidden_states)

    def load_weights(self, rng_key: jax.Array):
        """Load the draft checkpoint into this module.

        Maps checkpoint names to module paths, overrides the q/k/v reshape
        entries for EAGLE-3 and calls ``load_hf_weights`` in draft-model mode.
        If the embedding is still an abstract ``jax.ShapeDtypeStruct``
        afterwards, it is replaced by zeros.

        Args:
            rng_key: Not read by this method.
        """
        # NOTE(review): this replaces `self.rng` with a raw key derived from
        # the model config seed (not an `nnx.Rngs`) and ignores `rng_key` --
        # confirm this is intended.
        self.rng = jax.random.key(self.vllm_config.model_config.seed)
        spec_config = self.vllm_config.speculative_config
        assert spec_config is not None

        # Checkpoint parameter name -> path inside this module.
        mappings = {
            "midlayer.input_layernorm": "model.layers.0.input_layernorm.scale",
            "midlayer.hidden_norm": "model.layers.0.hidden_norm.scale",
            "midlayer.mlp.down_proj": "model.layers.0.mlp.down_proj.kernel",
            "midlayer.mlp.gate_proj": "model.layers.0.mlp.gate_proj.kernel",
            "midlayer.mlp.up_proj": "model.layers.0.mlp.up_proj.kernel",
            "midlayer.post_attention_layernorm":
            "model.layers.0.post_attention_layernorm.scale",
            "midlayer.self_attn.k_proj":
            "model.layers.0.self_attn.k_proj.kernel",
            "midlayer.self_attn.o_proj":
            "model.layers.0.self_attn.o_proj.kernel",
            "midlayer.self_attn.q_proj":
            "model.layers.0.self_attn.q_proj.kernel",
            "midlayer.self_attn.v_proj":
            "model.layers.0.self_attn.v_proj.kernel",
            "norm": "model.norm.scale",
            "fc": "model.fc.kernel",
            "lm_head": "lm_head.kernel",
            "d2t": "draft_id_to_target_id",
            "embed_tokens":
            "model.embed_tokens.embedding",  # Some checkpoints need this
        }

        # Checkpoint keys that keep their stored dtype (the d2t id table).
        keep_original_dtype_keys_regex = [
            r".*d2t.*",
        ]

        metadata_map = get_default_maps(spec_config.draft_model_config,
                                        self.mesh, mappings)
        update_reshape_map_for_eagle3(self.vllm_config, metadata_map)

        load_hf_weights(
            vllm_config=self.vllm_config,
            model=self,
            metadata_map=metadata_map,
            mesh=self.mesh,
            is_draft_model=True,
            keep_original_dtype_keys_regex=keep_original_dtype_keys_regex)

        # If the embedding is not initialized, initialize it with a dummy
        # array here to pass jit compilation. The real weights will be shared
        # from the target model in the eagle3 class.
        embed = self.model.embed_tokens
        if isinstance(embed.embedding.value, jax.ShapeDtypeStruct):
            embed.embedding.value = jnp.zeros(
                embed.embedding.shape,
                dtype=embed.embedding.dtype,
            )