tpu_inference-0.0.1rc1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (174)
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +53 -0
  6. tests/core/test_dp_scheduler.py +899 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/fused_moe_v1_test.py +374 -0
  10. tests/kernels/mla_v1_test.py +396 -0
  11. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  12. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  13. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  14. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
  15. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  16. tests/lora/__init__.py +0 -0
  17. tests/lora/conftest.py +32 -0
  18. tests/lora/test_bgmv.py +43 -0
  19. tests/lora/test_layers.py +648 -0
  20. tests/lora/test_lora.py +133 -0
  21. tests/lora/utils.py +88 -0
  22. tests/test_base.py +201 -0
  23. tests/test_envs.py +203 -0
  24. tests/test_quantization.py +836 -0
  25. tests/test_tpu_info.py +120 -0
  26. tests/test_utils.py +235 -0
  27. tpu_inference/__init__.py +53 -0
  28. tpu_inference/core/__init__.py +0 -0
  29. tpu_inference/core/core_tpu.py +786 -0
  30. tpu_inference/core/disagg_executor.py +118 -0
  31. tpu_inference/core/disagg_utils.py +49 -0
  32. tpu_inference/core/sched/__init__.py +0 -0
  33. tpu_inference/core/sched/dp_scheduler.py +523 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/jax_parallel_state.py +67 -0
  36. tpu_inference/distributed/tpu_connector.py +727 -0
  37. tpu_inference/distributed/utils.py +60 -0
  38. tpu_inference/env_override.py +9 -0
  39. tpu_inference/envs.py +160 -0
  40. tpu_inference/executors/__init__.py +0 -0
  41. tpu_inference/executors/ray_distributed_executor.py +382 -0
  42. tpu_inference/experimental/__init__.py +0 -0
  43. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  44. tpu_inference/kernels/__init__.py +0 -0
  45. tpu_inference/kernels/collectives/__init__.py +0 -0
  46. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  47. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  48. tpu_inference/kernels/collectives/util.py +47 -0
  49. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  50. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  51. tpu_inference/kernels/fused_moe/__init__.py +0 -0
  52. tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
  53. tpu_inference/kernels/fused_moe/v1/kernel.py +1566 -0
  54. tpu_inference/kernels/mla/__init__.py +0 -0
  55. tpu_inference/kernels/mla/v1/__init__.py +0 -0
  56. tpu_inference/kernels/mla/v1/kernel.py +1349 -0
  57. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  58. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  59. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  60. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  61. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  62. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  63. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  64. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  65. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  66. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  67. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1501 -0
  68. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1603 -0
  69. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
  70. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
  71. tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
  72. tpu_inference/layers/__init__.py +0 -0
  73. tpu_inference/layers/common/__init__.py +0 -0
  74. tpu_inference/layers/common/attention_interface.py +396 -0
  75. tpu_inference/layers/common/attention_metadata.py +34 -0
  76. tpu_inference/layers/common/binary_search.py +295 -0
  77. tpu_inference/layers/common/quant_methods.py +8 -0
  78. tpu_inference/layers/common/sharding.py +582 -0
  79. tpu_inference/layers/jax/__init__.py +0 -0
  80. tpu_inference/layers/jax/attention/__init__.py +0 -0
  81. tpu_inference/layers/jax/attention/attention.py +255 -0
  82. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  83. tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
  84. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  85. tpu_inference/layers/jax/base.py +151 -0
  86. tpu_inference/layers/jax/constants.py +88 -0
  87. tpu_inference/layers/jax/layers.py +301 -0
  88. tpu_inference/layers/jax/misc.py +16 -0
  89. tpu_inference/layers/jax/moe/__init__.py +0 -0
  90. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  91. tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
  92. tpu_inference/layers/jax/moe/moe.py +209 -0
  93. tpu_inference/layers/jax/rope.py +280 -0
  94. tpu_inference/layers/jax/rope_interface.py +214 -0
  95. tpu_inference/layers/jax/sample/__init__.py +0 -0
  96. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  97. tpu_inference/layers/jax/sample/sampling.py +96 -0
  98. tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
  99. tpu_inference/layers/jax/transformer_block.py +107 -0
  100. tpu_inference/layers/vllm/__init__.py +0 -0
  101. tpu_inference/layers/vllm/attention.py +221 -0
  102. tpu_inference/layers/vllm/fused_moe.py +469 -0
  103. tpu_inference/layers/vllm/linear_common.py +186 -0
  104. tpu_inference/layers/vllm/quantization/__init__.py +39 -0
  105. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  106. tpu_inference/layers/vllm/quantization/common.py +110 -0
  107. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  108. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
  109. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
  110. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  111. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  112. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  113. tpu_inference/layers/vllm/quantization/mxfp4.py +331 -0
  114. tpu_inference/layers/vllm/quantization/unquantized.py +368 -0
  115. tpu_inference/layers/vllm/sharding.py +230 -0
  116. tpu_inference/logger.py +10 -0
  117. tpu_inference/lora/__init__.py +0 -0
  118. tpu_inference/lora/torch_lora_ops.py +103 -0
  119. tpu_inference/lora/torch_punica_tpu.py +310 -0
  120. tpu_inference/models/__init__.py +0 -0
  121. tpu_inference/models/common/__init__.py +0 -0
  122. tpu_inference/models/common/model_loader.py +478 -0
  123. tpu_inference/models/jax/__init__.py +0 -0
  124. tpu_inference/models/jax/deepseek_v3.py +868 -0
  125. tpu_inference/models/jax/gpt_oss.py +492 -0
  126. tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
  127. tpu_inference/models/jax/llama3.py +376 -0
  128. tpu_inference/models/jax/llama4.py +629 -0
  129. tpu_inference/models/jax/llama_eagle3.py +336 -0
  130. tpu_inference/models/jax/llama_guard_4.py +361 -0
  131. tpu_inference/models/jax/qwen2.py +376 -0
  132. tpu_inference/models/jax/qwen2_5_vl.py +1218 -0
  133. tpu_inference/models/jax/qwen3.py +303 -0
  134. tpu_inference/models/jax/utils/__init__.py +0 -0
  135. tpu_inference/models/jax/utils/file_utils.py +96 -0
  136. tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
  137. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  138. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
  139. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
  140. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
  141. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
  142. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
  143. tpu_inference/models/jax/utils/quantization/quantization_utils.py +650 -0
  144. tpu_inference/models/jax/utils/weight_utils.py +584 -0
  145. tpu_inference/models/vllm/__init__.py +0 -0
  146. tpu_inference/models/vllm/vllm_model_wrapper.py +293 -0
  147. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  148. tpu_inference/platforms/__init__.py +2 -0
  149. tpu_inference/platforms/tpu_platform.py +275 -0
  150. tpu_inference/runner/__init__.py +0 -0
  151. tpu_inference/runner/block_table.py +122 -0
  152. tpu_inference/runner/compilation_manager.py +865 -0
  153. tpu_inference/runner/input_batch.py +435 -0
  154. tpu_inference/runner/kv_cache.py +132 -0
  155. tpu_inference/runner/kv_cache_manager.py +478 -0
  156. tpu_inference/runner/lora_utils.py +92 -0
  157. tpu_inference/runner/multimodal_manager.py +217 -0
  158. tpu_inference/runner/persistent_batch_manager.py +282 -0
  159. tpu_inference/runner/speculative_decoding_manager.py +248 -0
  160. tpu_inference/runner/structured_decoding_manager.py +87 -0
  161. tpu_inference/runner/tpu_runner.py +1744 -0
  162. tpu_inference/runner/utils.py +426 -0
  163. tpu_inference/spec_decode/__init__.py +0 -0
  164. tpu_inference/spec_decode/jax/__init__.py +0 -0
  165. tpu_inference/spec_decode/jax/eagle3.py +417 -0
  166. tpu_inference/tpu_info.py +78 -0
  167. tpu_inference/utils.py +340 -0
  168. tpu_inference/worker/__init__.py +0 -0
  169. tpu_inference/worker/tpu_worker.py +458 -0
  170. tpu_inference-0.0.1rc1.dist-info/METADATA +108 -0
  171. tpu_inference-0.0.1rc1.dist-info/RECORD +174 -0
  172. tpu_inference-0.0.1rc1.dist-info/WHEEL +5 -0
  173. tpu_inference-0.0.1rc1.dist-info/licenses/LICENSE +201 -0
  174. tpu_inference-0.0.1rc1.dist-info/top_level.txt +2 -0
tpu_inference/models/jax/llama_eagle3.py
@@ -0,0 +1,336 @@
+ from typing import List, Tuple
+
+ import jax
+ import jax.numpy as jnp
+ from flax import nnx
+ from jax.sharding import Mesh
+ from transformers import LlamaConfig
+ from vllm.config import VllmConfig
+
+ from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+ from tpu_inference.logger import init_logger
+ from tpu_inference.models.jax.llama3 import LlamaDecoderLayer
+ from tpu_inference.models.jax.utils.weight_utils import (MetadataMap,
+                                                          get_default_maps,
+                                                          load_hf_weights)
+
+ logger = init_logger(__name__)
+
+ init_fn = nnx.initializers.uniform()
+
+
+ class Eagle3LlamaDecoderLayer(LlamaDecoderLayer):
+
+     def __init__(self, config: LlamaConfig, dtype: jnp.dtype, rng: nnx.Rngs,
+                  mesh: Mesh, kv_cache_dtype: str):
+         super().__init__(config,
+                          dtype=dtype,
+                          rng=rng,
+                          mesh=mesh,
+                          kv_cache_dtype=kv_cache_dtype)
+         self.config = config
+         # Override qkv
+         hidden_size = 2 * self.self_attn.hidden_size
+         self.self_attn.q_proj = nnx.Einsum(
+             "TD,DNH->TNH",
+             (hidden_size, self.self_attn.num_heads, self.self_attn.head_dim),
+             param_dtype=dtype,
+             dtype=dtype,
+             kernel_init=nnx.with_partitioning(init_fn, (None, "model", None)),
+             rngs=rng,
+         )
+         self.self_attn.k_proj = nnx.Einsum(
+             "TD,DKH->TKH",
+             (hidden_size, self.self_attn.num_kv_heads,
+              self.self_attn.head_dim),
+             param_dtype=dtype,
+             dtype=dtype,
+             kernel_init=nnx.with_partitioning(init_fn, (None, "model", None)),
+             rngs=rng,
+         )
+         self.self_attn.v_proj = nnx.Einsum(
+             "TD,DKH->TKH",
+             (hidden_size, self.self_attn.num_kv_heads,
+              self.self_attn.head_dim),
+             param_dtype=dtype,
+             dtype=dtype,
+             kernel_init=nnx.with_partitioning(init_fn, (None, "model", None)),
+             rngs=rng,
+         )
+         # Override input layernorm and specify dtype to avoid unexpected upcasting.
+         self.input_layernorm = nnx.RMSNorm(
+             config.hidden_size,
+             epsilon=config.rms_norm_eps,
+             param_dtype=dtype,
+             dtype=dtype,
+             scale_init=nnx.with_partitioning(init_fn, (None, )),
+             rngs=rng,
+         )
+         self.hidden_norm = nnx.RMSNorm(
+             config.hidden_size,
+             epsilon=config.rms_norm_eps,
+             param_dtype=dtype,
+             scale_init=nnx.with_partitioning(init_fn, (None, )),
+             rngs=rng,
+         )
+
+     def _norm_before_residual(
+             self, hidden_states: jax.Array) -> tuple[jax.Array, jax.Array]:
+         hidden_states = self.hidden_norm(hidden_states)
+         residual = hidden_states
+         return hidden_states, residual
+
+     def _norm_after_residual(
+             self, hidden_states: jax.Array) -> tuple[jax.Array, jax.Array]:
+         residual = hidden_states
+         hidden_states = self.hidden_norm(hidden_states)
+         return hidden_states, residual
+
+     def __call__(
+         self,
+         kv_cache: jax.Array,
+         embeds: jax.Array,
+         hidden_states: jax.Array,
+         attention_metadata: AttentionMetadata,
+     ) -> Tuple[jax.Array, jax.Array, jax.Array]:
+         embeds = self.input_layernorm(embeds)
+         if getattr(self.config, "norm_before_residual", False):
+             hidden_states, residual = self._norm_before_residual(
+                 hidden_states=hidden_states)
+         else:
+             hidden_states, residual = self._norm_after_residual(
+                 hidden_states=hidden_states)
+         hidden_states = jnp.concatenate([embeds, hidden_states], axis=-1)
+
+         kv_cache, attn_output = self.self_attn(
+             kv_cache,
+             hidden_states,
+             attention_metadata,
+         )
+
+         # TODO(ranlihao): Check if this residual connection is correct.
+         hidden_states = attn_output + residual
+         residual = hidden_states
+         hidden_states = self.post_attention_layernorm(hidden_states)
+         mlp_output = self.mlp(hidden_states)
+
+         return kv_cache, mlp_output, residual
+
+
+ class Eagle3LlamaModel(nnx.Module):
+
+     def __init__(self, vllm_config: VllmConfig, rng: nnx.Rngs, mesh: Mesh):
+         super().__init__()
+         hf_config = vllm_config.speculative_config.draft_model_config.hf_config
+         dtype: jnp.dtype = jnp.bfloat16
+
+         self.embed_tokens = nnx.Embed(
+             num_embeddings=hf_config.vocab_size,
+             features=hf_config.hidden_size,
+             param_dtype=dtype,
+             embedding_init=nnx.with_partitioning(init_fn, ("model", None)),
+             rngs=rng,
+         )
+
+         self.layers = [
+             Eagle3LlamaDecoderLayer(
+                 config=hf_config,
+                 dtype=dtype,
+                 rng=rng,
+                 mesh=mesh,
+                 # TODO (jacobplatin): we should refactor this to pass a dtype (or config) directly
+                 kv_cache_dtype=vllm_config.cache_config.cache_dtype)
+         ]
+
+         if hasattr(hf_config, "target_hidden_size"):
+             input_size = hf_config.target_hidden_size * 3
+         else:
+             input_size = hf_config.hidden_size * 3
+
+         self.fc = nnx.Linear(
+             in_features=input_size,
+             out_features=hf_config.hidden_size,
+             use_bias=False,
+             param_dtype=dtype,
+             kernel_init=nnx.with_partitioning(init_fn, (None, "model")),
+             rngs=rng,
+         )
+
+         self.norm = nnx.RMSNorm(
+             hf_config.hidden_size,
+             epsilon=hf_config.rms_norm_eps,
+             param_dtype=dtype,
+             scale_init=nnx.with_partitioning(init_fn, (None, )),
+             rngs=rng,
+         )
+
+     def __call__(
+         self,
+         kv_caches: List[jax.Array],
+         input_ids: jax.Array,
+         hidden_states: jax.Array,
+         attention_metadata: AttentionMetadata,
+     ) -> Tuple[List[jax.Array], jax.Array, List[jax.Array]]:
+         embeds = self.embed_tokens(input_ids)
+         assert hidden_states.shape[-1] == embeds.shape[-1]
+
+         assert len(self.layers) == 1
+         # The first N - 1 KV caches are for the target model, and the last one is for the draft model.
+         # N is the number of layers in the target model.
+         # The draft model has only 1 layer.
+         kv_caches[-1], hidden_states, residual = self.layers[0](
+             kv_caches[-1],
+             embeds,
+             hidden_states,
+             attention_metadata,
+         )
+
+         # TODO(ranlihao): Check if this residual connection is correct.
+         hidden_states = hidden_states + residual
+         residual = hidden_states
+         hidden_states = self.norm(hidden_states)
+         return kv_caches, hidden_states, [residual]
+
+
+ def update_reshape_map_for_eagle3(vllm_config: VllmConfig,
+                                   metadata_map: MetadataMap):
+     model_config = vllm_config.speculative_config.draft_model_config
+     hf_config = model_config.hf_config
+
+     num_heads = hf_config.num_attention_heads
+     num_kv_heads = hf_config.num_key_value_heads
+     hidden_size = hf_config.hidden_size
+     head_dim_original = model_config.get_head_size()
+
+     metadata_map.reshape_map.update({
+         "q_proj": (num_heads, head_dim_original, 2 * hidden_size),
+         "k_proj": (num_kv_heads, head_dim_original, 2 * hidden_size),
+         "v_proj": (num_kv_heads, head_dim_original, 2 * hidden_size),
+     })
+
+
+ class EagleLlama3ForCausalLM(nnx.Module):
+
+     def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array,
+                  mesh: Mesh):
+         nnx.Module.__init__(self)
+         self.vllm_config = vllm_config
+         self.rng = nnx.Rngs(rng_key)
+         self.mesh = mesh
+         dtype: jnp.dtype = jnp.bfloat16
+
+         spec_config = vllm_config.speculative_config
+         assert spec_config is not None
+         model_config = spec_config.draft_model_config
+         assert model_config is not None
+         hf_config = model_config.hf_config
+
+         self.model = Eagle3LlamaModel(
+             vllm_config=vllm_config,
+             rng=self.rng,
+             mesh=mesh,
+         )
+
+         self.lm_head = nnx.Linear(
+             hf_config.hidden_size,
+             hf_config.draft_vocab_size,
+             use_bias=False,
+             param_dtype=dtype,
+             kernel_init=nnx.with_partitioning(init_fn, (None, "model")),
+             rngs=self.rng,
+         )
+
+         self.draft_id_to_target_id = nnx.Param(jnp.zeros(
+             hf_config.draft_vocab_size, dtype=jnp.int32),
+                                                sharding=(None, ))
+
+     def __call__(
+         self,
+         kv_caches: List[jax.Array],
+         input_ids: jax.Array,
+         hidden_states: jax.Array,
+         attention_metadata: AttentionMetadata,
+     ) -> Tuple[List[jax.Array], jax.Array, List[jax.Array]]:
+         return self.model(
+             kv_caches,
+             input_ids,
+             hidden_states,
+             attention_metadata,
+         )
+
+     def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
+         logits = self.lm_head(hidden_states)
+
+         target_vocab_size = self.vllm_config.model_config.get_vocab_size()
+         draft_vocab_size = self.vllm_config.speculative_config.draft_model_config.hf_config.draft_vocab_size
+
+         base = jnp.arange(draft_vocab_size, dtype=jnp.int32)
+         targets = base + self.draft_id_to_target_id.value
+
+         logits_new = jnp.full((logits.shape[0], target_vocab_size),
+                               -jnp.inf,
+                               dtype=logits.dtype)
+
+         logits_new = logits_new.at[:, targets].set(logits)
+
+         return logits_new
+
+     def combine_hidden_states(self, hidden_states: jax.Array) -> jax.Array:
+         return self.model.fc(hidden_states)
+
+     def load_weights(self, rng_key: jax.Array):
+         # Create a new Rngs object for the draft model to avoid sharing RNG state
+         self.rng = jax.random.key(self.vllm_config.model_config.seed)
+         spec_config = self.vllm_config.speculative_config
+         assert spec_config is not None
+
+         mappings = {
+             "midlayer.input_layernorm": "model.layers.0.input_layernorm.scale",
+             "midlayer.hidden_norm": "model.layers.0.hidden_norm.scale",
+             "midlayer.mlp.down_proj": "model.layers.0.mlp.down_proj.kernel",
+             "midlayer.mlp.gate_proj": "model.layers.0.mlp.gate_proj.kernel",
+             "midlayer.mlp.up_proj": "model.layers.0.mlp.up_proj.kernel",
+             "midlayer.post_attention_layernorm":
+             "model.layers.0.post_attention_layernorm.scale",
+             "midlayer.self_attn.k_proj":
+             "model.layers.0.self_attn.k_proj.kernel",
+             "midlayer.self_attn.o_proj":
+             "model.layers.0.self_attn.o_proj.kernel",
+             "midlayer.self_attn.q_proj":
+             "model.layers.0.self_attn.q_proj.kernel",
+             "midlayer.self_attn.v_proj":
+             "model.layers.0.self_attn.v_proj.kernel",
+             "norm": "model.norm.scale",
+             "fc": "model.fc.kernel",
+             "lm_head": "lm_head.kernel",
+             "d2t": "draft_id_to_target_id",
+             "embed_tokens":
+             "model.embed_tokens.embedding",  # Some checkpoints need this
+         }
+
+         # Define keys to keep in original dtype (e.g., float32 for stability)
+         keep_original_dtype_keys_regex = [
+             r".*d2t.*",
+         ]
+
+         metadata_map = get_default_maps(
+             self.vllm_config.speculative_config.draft_model_config, self.mesh,
+             mappings)
+
+         update_reshape_map_for_eagle3(self.vllm_config, metadata_map)
+
+         load_hf_weights(
+             vllm_config=self.vllm_config,
+             model=self,
+             metadata_map=metadata_map,
+             mesh=self.mesh,
+             is_draft_model=True,
+             keep_original_dtype_keys_regex=keep_original_dtype_keys_regex)
+
+         # If the embedding is not initialized, initialize it with a dummy array here to pass jit compilation. The real weights will be shared from the target model in eagle3 class.
+         if isinstance(self.model.embed_tokens.embedding.value,
+                       jax.ShapeDtypeStruct):
+             self.model.embed_tokens.embedding.value = jnp.zeros(
+                 self.model.embed_tokens.embedding.shape,
+                 dtype=self.model.embed_tokens.embedding.dtype,
+             )
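
In compute_logits above, the draft head produces logits over a reduced draft vocabulary; each draft id i is then scattered to target id i + d2t[i] (the draft_id_to_target_id table), and every other target id is filled with -inf so the result can be sampled in the target vocabulary. A minimal standalone sketch of that index arithmetic, using toy vocabulary sizes and a made-up offset table rather than anything from this package:

import jax.numpy as jnp

# Toy sizes and offsets, for illustration only.
draft_vocab, target_vocab = 4, 10
d2t = jnp.array([0, 2, 5, 6], dtype=jnp.int32)    # hypothetical draft->target offsets
draft_logits = jnp.array([[0.1, 0.7, 0.2, 0.9]])  # [num_tokens=1, draft_vocab]

# Each draft id i lands at target id i + d2t[i]  ->  [0, 3, 7, 9] here.
targets = jnp.arange(draft_vocab, dtype=jnp.int32) + d2t
full = jnp.full((draft_logits.shape[0], target_vocab), -jnp.inf,
                dtype=draft_logits.dtype)
full = full.at[:, targets].set(draft_logits)
# Positions 0, 3, 7 and 9 now carry the draft logits; every other target id
# stays at -inf and can never be proposed by the draft model.
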
tpu_inference/models/jax/llama_guard_4.py
@@ -0,0 +1,361 @@
+ import re
+ from typing import Any, List, Optional, Tuple
+
+ import jax
+ import jax.numpy as jnp
+ import torch
+ from flax import nnx
+ from flax.typing import PRNGKey
+ from jax.sharding import Mesh
+ from jax.sharding import PartitionSpec as P
+ from vllm.config import VllmConfig
+
+ from tpu_inference.layers.jax.attention.attention import AttentionMetadata
+ from tpu_inference.layers.jax.attention.llama4_attention import Llama4Attention
+ from tpu_inference.layers.jax.constants import KVCacheType
+ from tpu_inference.layers.jax.layers import DenseFFW, Embedder, LMhead, RMSNorm
+ from tpu_inference.layers.jax.misc import shard_put
+ from tpu_inference.layers.jax.transformer_block import TransformerBlock
+ from tpu_inference.logger import init_logger
+ from tpu_inference.models.jax.utils.weight_utils import (
+     get_param, model_weights_generator, print_param_info, reshape_params,
+     transpose_params)
+
+ logger = init_logger(__name__)
+
+
+ class LlamaGuard4ForCausalLM(nnx.Module):
+
+     def __init__(self,
+                  vllm_config: VllmConfig,
+                  rng: PRNGKey,
+                  mesh: Mesh,
+                  force_random_weights: bool = False):
+         logger.warning(
+             "🚨🚨🚨WARNING🚨🚨🚨 🚨🚨🚨WARNING🚨🚨🚨 🚨🚨🚨WARNING🚨🚨🚨\n"
+             "Llama Guard 4 (JAX) is WIP: Only the text modality is currently implemented. "
+             "Multimodal inputs will fail.\n"
+             "🚨🚨🚨WARNING🚨🚨🚨 🚨🚨🚨WARNING🚨🚨🚨 🚨🚨🚨WARNING🚨🚨🚨")
+         assert mesh is not None
+
+         self.vllm_config = vllm_config
+         self.vllm_config.model_config.dtype = torch.bfloat16
+         model_config = vllm_config.model_config
+         text_config = model_config.hf_config.text_config
+
+         self.mesh = mesh
+         self.is_verbose = getattr(self.vllm_config.additional_config,
+                                   "is_verbose", False)
+
+         self.use_qk_norm = getattr(text_config, "use_qk_norm", True)
+
+         vocab_size = model_config.get_vocab_size()
+         self.hidden_size = model_config.get_hidden_size()
+
+         self.dtype: jnp.dtype = jnp.bfloat16
+
+         self.num_layers: int = getattr(text_config, "num_layers", 48)
+         hidden_act: str = getattr(text_config, "hidden_act", "silu")
+
+         rms_norm_eps = getattr(text_config, "rms_norm_eps", 1e-5)
+         self.num_attention_heads = getattr(text_config, "num_attention_heads",
+                                            40)
+         self.num_key_value_heads = getattr(text_config, "num_key_value_heads",
+                                            8)
+         self.head_dim = getattr(text_config, "head_dim", 128)
+
+         intermediate_size = getattr(text_config, "intermediate_size", 8192)
+
+         self.rope_theta_text = getattr(text_config, "rope_theta", 500000.0)
+         self.rope_scaling = getattr(text_config, "rope_scaling")
+
+         self.rng = nnx.Rngs(rng)
+
+         self.embedder = Embedder(
+             vocab_size=vocab_size,
+             hidden_size=self.hidden_size,
+             dtype=self.dtype,
+             vd_sharding=(('data', 'model'), None),
+             rngs=self.rng,
+             random_init=force_random_weights,
+         )
+
+         self.layers = []
+
+         for i in range(self.num_layers):
+             use_attention_rope = True
+
+             custom_module = DenseFFW(dtype=self.dtype,
+                                      hidden_act=hidden_act,
+                                      hidden_size=self.hidden_size,
+                                      intermediate_size=intermediate_size,
+                                      random_init=force_random_weights,
+                                      rngs=self.rng,
+                                      df_sharding=P(None, 'model'),
+                                      fd_sharding=P('model', None),
+                                      activation_ffw_td=P('data', None))
+
+             attn = Llama4Attention(
+                 hidden_size=self.hidden_size,
+                 dtype=self.dtype,
+                 num_attention_heads=self.num_attention_heads,
+                 num_key_value_heads=self.num_key_value_heads,
+                 head_dim=self.head_dim,
+                 rope_theta=self.rope_theta_text,
+                 rope_scaling={
+                     "scale_factor":
+                     self.rope_scaling["factor"],
+                     "low_freq_factor":
+                     self.rope_scaling["low_freq_factor"],
+                     "high_freq_factor":
+                     self.rope_scaling["high_freq_factor"],
+                     "original_max_position_embeddings":
+                     self.rope_scaling["original_max_position_embeddings"]
+                 },
+                 rngs=self.rng,
+                 rope_input_ordering="interleaved",
+                 # TODO (jacobplatin): we should refactor this to pass a dtype (or config) directly
+                 kv_cache_dtype=vllm_config.cache_config.cache_dtype,
+                 temperature_tuning=True,
+                 temperature_tuning_scale=0.1,
+                 temperature_tuning_floor_scale=8192,
+                 use_qk_norm=self.use_qk_norm,
+                 attention_chunk_size=None if use_attention_rope else 8192,
+                 mesh=self.mesh,
+                 random_init=force_random_weights,
+                 activation_attention_td=('data', 'model'),
+                 activation_q_td=('data', 'model'),
+                 query_tnh=P('data', 'model', None),
+                 keyvalue_skh=P('data', 'model', None),
+                 activation_attention_out_td=('data', 'model'),
+                 attn_o_tnh=P('data', 'model', None),
+                 dnh_sharding=(None, 'model', None),
+                 dkh_sharding=(None, 'model', None),
+                 nhd_sharding=('model', None, None),
+             )
+
+             pre_attention_norm = RMSNorm(
+                 dims=self.hidden_size,
+                 random_init=force_random_weights,
+                 epsilon=rms_norm_eps,
+                 rngs=self.rng,
+                 activation_ffw_td=('data', None),
+                 with_scale=True,
+                 dtype=self.dtype,
+             )
+
+             pre_mlp_norm = RMSNorm(
+                 dims=self.hidden_size,
+                 activation_ffw_td=('data', None),
+                 epsilon=rms_norm_eps,
+                 rngs=self.rng,
+                 with_scale=True,
+                 dtype=self.dtype,
+                 random_init=force_random_weights,
+             )
+
+             block = TransformerBlock(custom_module=custom_module,
+                                      attn=attn,
+                                      pre_attention_norm=pre_attention_norm,
+                                      pre_mlp_norm=pre_mlp_norm,
+                                      use_attention_rope=use_attention_rope)
+             self.layers.append(block)
+
+         self.final_norm = RMSNorm(
+             dims=self.hidden_size,
+             activation_ffw_td=P(),
+             epsilon=rms_norm_eps,
+             rngs=self.rng,
+             with_scale=True,
+             dtype=self.dtype,
+             random_init=force_random_weights,
+         )
+
+         self.lm_head = LMhead(vocab_size=vocab_size,
+                               hidden_size=self.hidden_size,
+                               dtype=self.dtype,
+                               rngs=self.rng,
+                               vd_sharding=(('data', 'model'), None),
+                               dv_sharding=(None, ('data', 'model')),
+                               random_init=force_random_weights)
+         if self.is_verbose:
+             self._print_model_architecture()
+
+     def _print_model_architecture(self):
+
+         logger.info("### Embedding ###")
+         nnx.display(self.embedder)
+
+         logger.info("\n### Layers ###")
+         for i, layer in enumerate(self.layers):
+             logger.info(f"\n--- Layer {i} ---")
+             nnx.display(layer)
+
+         logger.info("\n### LM Head ###")
+         nnx.display(self.lm_head)
+
+     def load_weights(self, rng: jax.Array, cache_dir: Optional[str] = None):
+         self.rng = nnx.Rngs(rng)
+
+         weight_loader = LlamaGuard4WeightLoader(
+             vllm_config=self.vllm_config,
+             hidden_size=self.hidden_size,
+             attn_heads=self.num_attention_heads,
+             num_key_value_heads=self.num_key_value_heads,
+             attn_head_dim=self.head_dim)
+         weight_loader.load_weights(self)
+
+     def __call__(
+         self,
+         kv_caches: List[jax.Array],
+         input_ids: jax.Array,
+         attention_metadata: AttentionMetadata,
+         inputs_embeds: Optional[jax.Array] = None,
+         layer_metadata_tuple: Optional[Tuple] = None,
+         lora_metadata: Optional[Any] = None,
+         *args,
+     ) -> Tuple[List[KVCacheType], jax.Array]:
+         is_prefill = False
+
+         if inputs_embeds is not None:
+             x_TD = inputs_embeds
+         elif input_ids is not None:
+             x_TD = self.embedder.encode(input_ids)
+         else:
+             raise ValueError(
+                 "Cannot run forward pass: Both input_ids and inputs_embeds are None."
+             )
+
+         for (i, block) in enumerate(self.layers):
+             kv_cache = kv_caches[i]
+             new_kv_cache, x_TD = block(x_TD, is_prefill, kv_cache,
+                                        attention_metadata)
+             jax.block_until_ready(x_TD)
+             kv_caches[i] = new_kv_cache
+
+         final_activation_TD = self.final_norm(x_TD)
+
+         return kv_caches, final_activation_TD, []
+
+     def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
+         logits_TV = jnp.dot(hidden_states,
+                             self.lm_head.input_embedding_table_DV.value)
+         return logits_TV
+
+     def get_input_embeddings(
+         self,
+         input_ids: jax.Array,
+         multimodal_embeddings: Optional[List[jax.Array]] = None
+     ) -> jax.Array:
+         """
+         Computes the embeddings for text input (used for input to fusion).
+         """
+         return self.embedder.encode(input_ids)
+
+
+ class LlamaGuard4WeightLoader:
+
+     def __init__(self, vllm_config: VllmConfig, hidden_size, attn_heads,
+                  num_key_value_heads, attn_head_dim):
+         self.names_and_weights_generator = model_weights_generator(
+             model_name_or_path=vllm_config.model_config.model,
+             framework="flax",
+             filter_regex="language_model",
+             download_dir=vllm_config.load_config.download_dir)
+         self.is_verbose = getattr(vllm_config.additional_config, "is_verbose",
+                                   False)
+         self._transpose_map = {
+             "q_proj": (2, 0, 1),
+             "k_proj": (2, 0, 1),
+             "v_proj": (2, 0, 1),
+             "o_proj": (1, 2, 0),
+             "lm_head": (1, 0),
+             "feed_forward.down_proj": (1, 0),
+             "feed_forward.gate_proj": (1, 0),
+             "feed_forward.up_proj": (1, 0),
+             "mlp.down_proj": (1, 0),
+             "mlp.gate_proj": (1, 0),
+             "mlp.up_proj": (1, 0),
+         }
+         self._weight_shape_map = {
+             "q_proj": (attn_heads, attn_head_dim, hidden_size),
+             "k_proj": (num_key_value_heads, attn_head_dim, hidden_size),
+             "v_proj": (num_key_value_heads, attn_head_dim, hidden_size),
+             "o_proj": (hidden_size, attn_heads, attn_head_dim),
+         }
+
+         self._loaded_to_standardized_keys = {
+             "language_model.model.embed_tokens.weight":
+             "embedder.input_embedding_table_VD",
+             "language_model.lm_head.weight":
+             "lm_head.input_embedding_table_DV",
+             "language_model.model.norm.weight":
+             "final_norm.scale",
+             "language_model.model.layers.*.input_layernorm.weight":
+             "layers.*.pre_attention_norm.scale",
+             "language_model.model.layers.*.post_attention_layernorm.weight":
+             "layers.*.pre_mlp_norm.scale",
+             "language_model.model.layers.*.self_attn.q_proj.weight":
+             "layers.*.attn.kernel_q_proj_DNH",
+             "language_model.model.layers.*.self_attn.k_proj.weight":
+             "layers.*.attn.kernel_k_proj_DKH",
+             "language_model.model.layers.*.self_attn.v_proj.weight":
+             "layers.*.attn.kernel_v_proj_DKH",
+             "language_model.model.layers.*.self_attn.o_proj.weight":
+             "layers.*.attn.kernel_o_proj_NHD",
+             "language_model.model.layers.*.feed_forward.gate_proj.weight":
+             "layers.*.custom_module.kernel_gating_DF",
+             "language_model.model.layers.*.feed_forward.up_proj.weight":
+             "layers.*.custom_module.kernel_up_proj_DF",
+             "language_model.model.layers.*.feed_forward.down_proj.weight":
+             "layers.*.custom_module.kernel_down_proj_FD",
+         }
+
+     def map_loaded_to_standardized_name(self, loaded_key: str) -> str:
+         if "layer" in loaded_key:
+             layer_num = re.search(r"layers\.(\d+)", loaded_key).group(1)
+             layer_key = re.sub(r"layers\.\d+", "layers.*", loaded_key)
+             mapped_key = self._loaded_to_standardized_keys.get(
+                 layer_key, loaded_key)
+             mapped_key = re.sub(r"layers\.\*", f"layers.{layer_num}",
+                                 mapped_key)
+         else:
+             mapped_key = self._loaded_to_standardized_keys.get(
+                 loaded_key, loaded_key)
+         return mapped_key
+
+     def load_weights(self, model_for_loading: nnx.Module):
+         model_params = nnx.state(model_for_loading)
+         with jax.default_device(jax.devices("cpu")[0]):
+             for loaded_name, loaded_weight in self.names_and_weights_generator:
+                 if loaded_name.endswith(".bias"):
+                     continue
+                 if "vision_model" in loaded_name or "multi_modal_projector" in loaded_name:
+                     continue
+
+                 mapped_name = self.map_loaded_to_standardized_name(loaded_name)
+                 model_weight = get_param(model_params, mapped_name)
+
+                 if not loaded_name.endswith(".bias"):
+                     # For other layers, continue to use the transpose_params helper.
+                     loaded_weight = reshape_params(loaded_name, loaded_weight,
+                                                    self._weight_shape_map)
+                     loaded_weight = transpose_params(loaded_name,
+                                                      loaded_weight,
+                                                      self._transpose_map)
+                 if model_weight.value.shape != loaded_weight.shape:
+                     raise ValueError(
+                         f"Loaded shape for {loaded_name}: {loaded_weight.shape} "
+                         f"does not match model shape for {mapped_name}: {model_weight.value.shape}!"
+                     )
+                 logger.debug(
+                     f"Transformed parameter {loaded_name} to {mapped_name}: {loaded_weight.shape} --> {model_weight.value.shape}"
+                 )
+
+                 model_weight.value = shard_put(loaded_weight,
+                                                model_weight.sharding,
+                                                mesh=model_for_loading.mesh)
+                 if self.is_verbose:
+                     print_param_info(model_weight, loaded_name)
+
+         nnx.update(model_for_loading, model_params)
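
map_loaded_to_standardized_name above resolves per-layer checkpoint keys by swapping the concrete layer index for a wildcard, looking the wildcard key up in _loaded_to_standardized_keys, and then substituting the index back into the mapped name. A minimal standalone sketch of that round trip, with a one-entry toy table (the helper name map_name is illustrative, not part of the package):

import re

# One wildcard entry, in the spirit of _loaded_to_standardized_keys.
table = {
    "language_model.model.layers.*.self_attn.q_proj.weight":
    "layers.*.attn.kernel_q_proj_DNH",
}

def map_name(loaded_key: str) -> str:
    if "layer" in loaded_key:
        layer_num = re.search(r"layers\.(\d+)", loaded_key).group(1)
        wildcard_key = re.sub(r"layers\.\d+", "layers.*", loaded_key)
        mapped = table.get(wildcard_key, loaded_key)
        # Put the concrete layer index back into the standardized name.
        return re.sub(r"layers\.\*", f"layers.{layer_num}", mapped)
    return table.get(loaded_key, loaded_key)

print(map_name("language_model.model.layers.7.self_attn.q_proj.weight"))
# -> layers.7.attn.kernel_q_proj_DNH
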