tpu-inference 0.11.1.dev202511130813__py3-none-any.whl → 0.11.1.dev202511220812__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (58)
  1. tests/lora/test_layers.py +0 -6
  2. tests/lora/utils.py +0 -8
  3. tests/test_envs.py +182 -0
  4. tests/test_utils.py +23 -14
  5. tpu_inference/__init__.py +22 -3
  6. tpu_inference/core/core_tpu.py +17 -9
  7. tpu_inference/core/disagg_utils.py +6 -8
  8. tpu_inference/distributed/tpu_connector.py +2 -3
  9. tpu_inference/distributed/utils.py +3 -2
  10. tpu_inference/envs.py +1 -1
  11. tpu_inference/executors/ray_distributed_executor.py +27 -11
  12. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +77 -54
  13. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +110 -64
  14. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +7 -0
  15. tpu_inference/layers/{jax → common}/attention_interface.py +1 -1
  16. tpu_inference/layers/common/quant_methods.py +8 -0
  17. tpu_inference/layers/jax/attention/attention.py +1 -1
  18. tpu_inference/layers/jax/sample/rejection_sampler.py +1 -1
  19. tpu_inference/layers/jax/sample/sampling.py +2 -2
  20. tpu_inference/layers/vllm/attention.py +1 -1
  21. tpu_inference/layers/vllm/quantization/__init__.py +7 -3
  22. tpu_inference/layers/vllm/quantization/awq.py +4 -3
  23. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +4 -2
  24. tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
  25. tpu_inference/layers/vllm/quantization/unquantized.py +4 -3
  26. tpu_inference/layers/vllm/sharding.py +2 -2
  27. tpu_inference/lora/torch_punica_tpu.py +1 -2
  28. tpu_inference/models/common/model_loader.py +12 -11
  29. tpu_inference/models/jax/llama3.py +4 -3
  30. tpu_inference/models/jax/llama_eagle3.py +9 -5
  31. tpu_inference/models/jax/llama_guard_4.py +361 -0
  32. tpu_inference/models/jax/qwen2.py +3 -2
  33. tpu_inference/models/jax/qwen2_5_vl.py +4 -3
  34. tpu_inference/models/jax/qwen3.py +3 -2
  35. tpu_inference/models/jax/utils/weight_utils.py +21 -8
  36. tpu_inference/models/vllm/vllm_model_wrapper.py +22 -10
  37. tpu_inference/platforms/tpu_platform.py +17 -7
  38. tpu_inference/runner/compilation_manager.py +37 -17
  39. tpu_inference/runner/kv_cache.py +1 -1
  40. tpu_inference/runner/kv_cache_manager.py +8 -2
  41. tpu_inference/runner/tpu_runner.py +199 -87
  42. tpu_inference/spec_decode/jax/eagle3.py +2 -1
  43. tpu_inference/tpu_info.py +4 -3
  44. tpu_inference/utils.py +7 -6
  45. tpu_inference/worker/tpu_worker.py +159 -23
  46. {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/METADATA +2 -2
  47. {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/RECORD +52 -54
  48. tpu_inference/mock/__init__.py +0 -0
  49. tpu_inference/mock/vllm_config_utils.py +0 -28
  50. tpu_inference/mock/vllm_envs.py +0 -1219
  51. tpu_inference/mock/vllm_logger.py +0 -212
  52. tpu_inference/mock/vllm_logging_utils.py +0 -15
  53. tpu_inference/models/jax/phi3.py +0 -376
  54. /tpu_inference/layers/{jax → common}/binary_search.py +0 -0
  55. /tpu_inference/layers/{jax → common}/sharding.py +0 -0
  56. {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/WHEEL +0 -0
  57. {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/licenses/LICENSE +0 -0
  58. {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/top_level.txt +0 -0
tpu_inference/layers/vllm/quantization/mxfp4.py
@@ -0,0 +1,266 @@
+from typing import Callable, Optional, Union
+
+import jax
+import jax.numpy as jnp
+import torch
+from jax.experimental.layout import Format, Layout
+from jax.sharding import Mesh, NamedSharding, PartitionSpec
+from torch.nn.parameter import Parameter
+from torchax.interop import jax_view, torch_view
+from torchax.ops.mappings import t2j
+from vllm.logger import init_logger
+from vllm.model_executor.layers.fused_moe.config import (
+    FusedMoEConfig, FusedMoEQuantConfig, biased_moe_quant_config)
+from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
+                                                         FusedMoEMethodBase)
+from vllm.model_executor.layers.linear import LinearBase
+from vllm.model_executor.layers.quantization import \
+    register_quantization_config
+from vllm.model_executor.layers.quantization.base_config import \
+    QuantizeMethodBase
+from vllm.model_executor.layers.quantization.mxfp4 import (Mxfp4Backend,
+                                                            Mxfp4Config,
+                                                            Mxfp4MoEMethod)
+from vllm.model_executor.layers.quantization.utils.quant_utils import \
+    is_layer_skipped
+
+from tpu_inference.layers.common.quant_methods import (MXFP4,
+                                                        get_tpu_quant_method)
+from tpu_inference.layers.vllm.fused_moe import fused_moe_func_padded
+from tpu_inference.layers.vllm.linear_common import \
+    reorder_concatenated_tensor_for_sharding
+from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
+from tpu_inference.layers.vllm.quantization.unquantized import \
+    VllmUnquantizedLinearMethod
+
+MXFP4_BLOCK_SIZE = 32
+
+P = PartitionSpec
+logger = init_logger(__name__)
+
+
+# TODO(kyuyeunk): Move these functions into a common utility file.
+def u8_unpack_e2m1(u8_packed_e2m1: jax.Array) -> jax.Array:
+    assert u8_packed_e2m1.dtype == jnp.uint8
+    e2m1 = jax.lax.bitcast_convert_type(u8_packed_e2m1, jnp.float4_e2m1fn)
+    # bitcast creates one more dimension that splits 8 bits into two e2m1.
+    # we flatten them with the last dim.
+    return jnp.reshape(e2m1, e2m1.shape[:-2] + (-1, ))
+
+
+def e8m0_to_fp32(u8: jax.Array) -> jax.Array:
+    e8_finfo = jnp.finfo(jnp.float8_e8m0fnu)
+    exponents = u8.astype(jnp.int32) + e8_finfo.minexp
+    ones = jnp.ones_like(u8, dtype=jnp.float32)
+    return jnp.ldexp(ones, exponents)
+
+
+def dequantize_block_weight(weight: jax.Array,
+                            scale: jax.Array,
+                            block_size: int,
+                            out_dtype: jnp.dtype = jnp.bfloat16) -> jax.Array:
+    orig_shape = weight.shape
+    weight_block = weight.reshape(orig_shape[:-1] + (-1, block_size))
+    weight_dequantized = weight_block.astype(jnp.float32) * jnp.expand_dims(
+        scale, -1)
+    return weight_dequantized.reshape(orig_shape).astype(out_dtype)
+
+
+@register_quantization_config(get_tpu_quant_method(MXFP4))
+class VllmMxfp4Config(Mxfp4Config, JaxCommonConfig):
+
+    @classmethod
+    def get_name(cls):
+        return MXFP4
+
+    def get_quant_method(self, layer: torch.nn.Module,
+                         prefix: str) -> Optional["QuantizeMethodBase"]:
+        from vllm.attention.layer import Attention  # Avoid circular import
+
+        if isinstance(layer, LinearBase):
+            linear_config = self.get_linear_config(layer)
+            if self.ignored_layers and is_layer_skipped(
+                    prefix=prefix,
+                    ignored_layers=self.ignored_layers,
+                    fused_mapping=self.packed_modules_mapping,
+            ):
+                return VllmUnquantizedLinearMethod(linear_config)
+            # TODO: Add support for MXFP4 Linear Method.
+            # MXFP4 LinearMethod is available in AMD-Quark, refer to that
+            # implementation if you are interested in enabling MXFP4 here.
+            logger.warning_once(
+                "MXFP4 linear layer is not implemented - falling back to "
+                "UnquantizedLinearMethod.")
+            return VllmUnquantizedLinearMethod(linear_config)
+        elif isinstance(layer, FusedMoE):
+            return VllmMxfp4MoEMethod(layer.moe_config, self.mesh)
+        elif isinstance(layer, Attention):
+            # TODO: Add support for MXFP4 Attention.
+            logger.warning_once("MXFP4 attention layer is not implemented. "
+                                "Skipping quantization for this layer.")
+        return None
+
+
+class VllmMxfp4MoEMethod(Mxfp4MoEMethod):
+
+    def __init__(self, moe: FusedMoEConfig, mesh: Mesh):
+        FusedMoEMethodBase.__init__(self, moe)
+
+        # We piggyback on triton implementation as it applies minimal hardware
+        # specific post processing to the weights.
+        self.mxfp4_backend = Mxfp4Backend.TRITON
+        self.mesh = mesh
+
+    def get_fused_moe_quant_config(
+            self, layer: torch.nn.Module) -> FusedMoEQuantConfig | None:
+        # Because we have dequantized weights, we only need biased moe config.
+        # TODO(kyuyeunk): Add native support for MXFP4.
+        return biased_moe_quant_config(
+            layer.w13_bias,
+            layer.w2_bias,
+        )
+
+    def process_weights_after_loading(self, layer: torch.nn.Module):
+        assert isinstance(layer, FusedMoE)
+
+        w13_weight = u8_unpack_e2m1(t2j(layer.w13_weight, use_dlpack=False))
+        w13_weight_scale = e8m0_to_fp32(
+            t2j(layer.w13_weight_scale, use_dlpack=False))
+        w13_bias = t2j(layer.w13_bias, use_dlpack=False)
+
+        w2_weight = u8_unpack_e2m1(t2j(layer.w2_weight, use_dlpack=False))
+        w2_weight_scale = e8m0_to_fp32(
+            t2j(layer.w2_weight_scale, use_dlpack=False))
+        w2_bias = t2j(layer.w2_bias, use_dlpack=False)
+
+        # We dequantize fp4 weights into bf16.
+        # TODO(kyuyeunk): Add native support for MXFP4.
+        w13_weight = dequantize_block_weight(w13_weight, w13_weight_scale,
+                                             MXFP4_BLOCK_SIZE, jnp.bfloat16)
+        w2_weight = dequantize_block_weight(w2_weight, w2_weight_scale,
+                                            MXFP4_BLOCK_SIZE, jnp.bfloat16)
+
+        # Because we have dequantized weights, scales are not used anymore.
+        delattr(layer, "w13_weight_scale")
+        delattr(layer, "w2_weight_scale")
+
+        if layer.activation == "swigluoai":
+            # When using swigluoai, vLLM splits gmm output in a interleaved way.
+            # However, interleaved split is not performant on TPU. Therefore,
+            # we preprocess the weight so that splitting gmm output by middle
+            # can still get the same result.
+            w1_weight = w13_weight[:, ::2, :]
+            w3_weight = w13_weight[:, 1::2, :]
+            w13_weight = jnp.concat([w1_weight, w3_weight], axis=1)
+
+            w1_bias = w13_bias[:, ::2]
+            w3_bias = w13_bias[:, 1::2]
+            w13_bias = jnp.concat([w1_bias, w3_bias], axis=1)
+
+        # TODO(kyuyeunk): Add weight processing logic for the new kernel.
+        if layer.use_ep:
+            w13_weight = jax.device_put(
+                w13_weight,
+                Format(Layout((0, 1, 2)),
+                       NamedSharding(self.mesh, P("model", None, None))))
+            w2_weight = jax.device_put(
+                w2_weight,
+                Format(Layout((0, 1, 2)),
+                       NamedSharding(self.mesh, P("model", None, None))))
+
+            w13_bias = jax.device_put(
+                w13_bias,
+                Format(Layout((0, 1)),
+                       NamedSharding(self.mesh, P("model", None))))
+            w2_bias = jax.device_put(
+                w2_bias,
+                Format(Layout((0, 1)),
+                       NamedSharding(self.mesh, P("model", None))))
+
+        else:
+            intermediate_size = w13_weight.shape[1] // 2
+            assert intermediate_size == w2_weight.shape[-1]
+            output_sizes = [intermediate_size, intermediate_size]
+            n_shards = self.mesh.shape["model"]
+            assert intermediate_size % n_shards == 0
+            w13_weight = reorder_concatenated_tensor_for_sharding(w13_weight,
+                                                                  output_sizes,
+                                                                  n_shards,
+                                                                  dim=1)
+            w13_weight = jax.device_put(
+                w13_weight,
+                Format(Layout((0, 1, 2)),
+                       NamedSharding(self.mesh, P(None, "model", None))))
+            w2_weight = jax.device_put(
+                w2_weight,
+                Format(Layout((0, 1, 2)),
+                       NamedSharding(self.mesh, P(None, None, "model"))))
+
+            w13_bias = reorder_concatenated_tensor_for_sharding(w13_bias,
+                                                                output_sizes,
+                                                                n_shards,
+                                                                dim=1)
+            w13_bias = jax.device_put(
+                w13_bias,
+                Format(Layout((0, 1)),
+                       NamedSharding(self.mesh, P(None, "model"))))
+            w2_bias = jax.device_put(
+                w2_bias,
+                Format(Layout((0, 1)), NamedSharding(self.mesh, P(None,
+                                                                  None))))
+
+        layer.w13_weight = Parameter(torch_view(w13_weight),
+                                     requires_grad=False)
+        layer.w13_bias = Parameter(torch_view(w13_bias), requires_grad=False)
+
+        layer.w2_weight = Parameter(torch_view(w2_weight), requires_grad=False)
+        layer.w2_bias = Parameter(torch_view(w2_bias), requires_grad=False)
+
+        pass
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool = False,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        routed_scaling_factor: float = 1.0,
+        e_score_correction_bias: Optional[torch.Tensor] = None,
+        apply_router_weight_on_input: bool = False,
+        activation: str = "silu",
+        enable_eplb: bool = False,
+        expert_load_view: Optional[torch.Tensor] = None,
+        logical_to_physical_map: Optional[torch.Tensor] = None,
+        logical_replica_count: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+        assert isinstance(layer, FusedMoE)
+        if scoring_func != "softmax":
+            raise NotImplementedError(
+                "Only softmax is supported for scoring_func")
+
+        # Use the original implementation
+        output = fused_moe_func_padded(
+            jax_view(x),
+            jax_view(layer.w13_weight),
+            jax_view(layer.w2_weight),
+            jax_view(layer.w13_bias) if self.moe.has_bias else None,
+            jax_view(layer.w2_bias) if self.moe.has_bias else None,
+            jax_view(router_logits),
+            topk=top_k,
+            global_num_experts=global_num_experts,
+            renormalize=renormalize,
+            reduce_results=layer.reduce_results,
+            mesh=self.mesh,
+            use_ep=layer.use_ep,
+            activation=activation,
+        )
+
+        return torch_view(output)
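
Note: the new VllmMxfp4MoEMethod does not run MXFP4 natively. It unpacks the two e2m1 values stored in each uint8, expands the per-block e8m0 exponents to fp32 scales, and multiplies each 32-value block by its scale to get bf16 weights. Below is a minimal, self-contained sketch of that unpack-and-dequantize path; shapes and test values are illustrative, and it assumes a JAX build whose ml_dtypes exposes float4_e2m1fn and float8_e8m0fnu.

# Sketch of the MXFP4 dequantization path used above (illustrative shapes).
import jax
import jax.numpy as jnp

MXFP4_BLOCK_SIZE = 32

def unpack_e2m1(packed: jax.Array) -> jax.Array:
    # Each uint8 carries two e2m1 nibbles; the bitcast adds a trailing dim of
    # size 2, which is folded back into the last axis.
    e2m1 = jax.lax.bitcast_convert_type(packed, jnp.float4_e2m1fn)
    return jnp.reshape(e2m1, e2m1.shape[:-2] + (-1,))

def e8m0_scale_to_fp32(scale_u8: jax.Array) -> jax.Array:
    # e8m0 stores only a biased exponent, i.e. the scale is 2**(u8 - 127).
    minexp = jnp.finfo(jnp.float8_e8m0fnu).minexp  # -127
    return jnp.ldexp(jnp.ones_like(scale_u8, dtype=jnp.float32),
                     scale_u8.astype(jnp.int32) + minexp)

def dequantize(packed, scale_u8, block_size=MXFP4_BLOCK_SIZE):
    w = unpack_e2m1(packed).astype(jnp.float32)
    blocks = w.reshape(w.shape[:-1] + (-1, block_size))
    scaled = blocks * jnp.expand_dims(e8m0_scale_to_fp32(scale_u8), -1)
    return scaled.reshape(w.shape).astype(jnp.bfloat16)

# 2 experts, 8 rows, 64 columns: 32 packed bytes and 2 block scales per row.
packed = jnp.zeros((2, 8, 64 // 2), dtype=jnp.uint8)
scales = jnp.full((2, 8, 64 // MXFP4_BLOCK_SIZE), 127, dtype=jnp.uint8)  # 1.0
print(dequantize(packed, scales).shape)  # (2, 8, 64), bfloat16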
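
Note: the swigluoai branch above de-interleaves w13 so that the gate/up halves of the grouped-matmul output become contiguous. A small check with made-up shapes (the gate/up naming follows the comment in the hunk) that splitting the output down the middle after the reordering matches vLLM's interleaved [::2]/[1::2] split:

# Illustrative check of the w13 de-interleaving above (shapes made up).
import jax.numpy as jnp
import numpy as np

I, K, T = 4, 8, 3  # intermediate size, hidden size, tokens
rng = np.random.default_rng(0)
w13 = jnp.asarray(rng.standard_normal((2 * I, K)), dtype=jnp.float32)
x = jnp.asarray(rng.standard_normal((T, K)), dtype=jnp.float32)

# Interleaved layout: gate/up rows alternate, so the output splits [::2]/[1::2].
y = x @ w13.T
gate_a, up_a = y[:, ::2], y[:, 1::2]

# De-interleave the weight once; the output can then be split at the midpoint.
w13_mid = jnp.concatenate([w13[::2, :], w13[1::2, :]], axis=0)
y_mid = x @ w13_mid.T
gate_b, up_b = y_mid[:, :I], y_mid[:, I:]

assert jnp.allclose(gate_a, gate_b) and jnp.allclose(up_a, up_b)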
tpu_inference/layers/vllm/quantization/unquantized.py
@@ -23,6 +23,8 @@ from vllm.model_executor.layers.quantization.base_config import (
 
 from tpu_inference import envs
 from tpu_inference.kernels.fused_moe.v1.kernel import fused_ep_moe
+from tpu_inference.layers.common.quant_methods import (UNQUANTIZED,
+                                                        get_tpu_quant_method)
 from tpu_inference.layers.vllm.fused_moe import fused_moe_func_padded
 from tpu_inference.layers.vllm.linear_common import (
     reorder_concatenated_tensor_for_sharding,
@@ -34,12 +36,12 @@ P = PartitionSpec
 logger = init_logger(__name__)
 
 
-@register_quantization_config("jax-unquantized")
+@register_quantization_config(get_tpu_quant_method(UNQUANTIZED))
 class VllmUnquantizedConfig(QuantizationConfig, JaxCommonConfig):
 
     @classmethod
     def get_name(cls) -> str:
-        return "jax-unquantized"
+        return UNQUANTIZED
 
     @classmethod
     def get_supported_act_dtypes(cls) -> list[torch.dtype]:
@@ -189,7 +191,6 @@ class VllmUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         assert isinstance(layer, FusedMoE)
-
         w13_weight = t2j(layer.w13_weight, use_dlpack=False)
         w2_weight = t2j(layer.w2_weight, use_dlpack=False)
 
tpu_inference/layers/vllm/sharding.py
@@ -19,6 +19,7 @@ from vllm.lora.layers.base_linear import BaseLinearLayerWithLoRA
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 
+from tpu_inference import envs
 from tpu_inference.logger import init_logger
 
 P = PartitionSpec
@@ -211,8 +212,7 @@ def _shard_module_to_tpu(model: torch.nn.Module, mesh: Mesh) -> None:
 def _sharded_device_put(tensor: jax.Array, sharding) -> jax.Array:
     if isinstance(tensor, tuple):
         return tuple(_sharded_device_put(t, sharding) for t in tensor)
-    import os
-    multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower()
+    multihost_backend = envs.TPU_MULTIHOST_BACKEND
     if multihost_backend != "ray":
         return jax.device_put(tensor, sharding)
 
tpu_inference/lora/torch_punica_tpu.py
@@ -239,7 +239,6 @@ class PunicaWrapperTPU(PunicaWrapperBase):
         lora_index_to_id: list[Optional[int]],
         max_loras: int,
         vocab_size: int,
-        extra_vocab_size: int,
     ):
         # Pad the prompt mapping to avoid running into recompiles on the TPU
         # TODO: Should this happen inside mapping internally? If so how can we
@@ -258,7 +257,7 @@ class PunicaWrapperTPU(PunicaWrapperBase):
             lora_index_to_id,
             max_loras,
             vocab_size,
-            extra_vocab_size,
+            0,  # extra_vocab_size
             "cpu",
         )
         with torchax.default_env():
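
Note: the _sharded_device_put hunk in sharding.py above swaps a raw os.environ.get("TPU_MULTIHOST_BACKEND", "").lower() lookup for envs.TPU_MULTIHOST_BACKEND. The contents of tpu_inference/envs.py are not shown in this diff; the sketch below only illustrates the lazy env-variable module pattern (as popularized by vLLM's envs.py) that such an attribute typically maps to, and is not the package's actual implementation.

# Hypothetical sketch of a lazy env-variable module (pattern only).
# Each entry is re-read from os.environ on attribute access, so
# monkeypatched environments are picked up by later reads.
import os
from typing import Any, Callable

environment_variables: dict[str, Callable[[], Any]] = {
    "TPU_MULTIHOST_BACKEND":
    lambda: os.environ.get("TPU_MULTIHOST_BACKEND", "").lower(),
}

def __getattr__(name: str) -> Any:  # PEP 562 module-level __getattr__
    if name in environment_variables:
        return environment_variables[name]()
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")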
tpu_inference/models/common/model_loader.py
@@ -11,7 +11,7 @@ from vllm.config import VllmConfig
 from vllm.utils.func_utils import supports_kw
 
 from tpu_inference import envs
-from tpu_inference.layers.jax.sharding import ShardingAxisName
+from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.logger import init_logger
 from tpu_inference.models.jax.utils.quantization.quantization_utils import (
     apply_qwix_on_abstract_model, apply_qwix_quantization,
@@ -36,19 +36,17 @@ def _get_model_architecture(config: PretrainedConfig) -> nnx.Module:
     from tpu_inference.models.jax.llama3 import LlamaForCausalLM
     from tpu_inference.models.jax.llama4 import Llama4ForCausalLM
     from tpu_inference.models.jax.llama_eagle3 import EagleLlama3ForCausalLM
-    from tpu_inference.models.jax.phi3 import Phi3ForCausalLM
-    from tpu_inference.models.jax.qwen2 import Qwen2ForCausalLM
+    from tpu_inference.models.jax.llama_guard_4 import LlamaGuard4ForCausalLM
     from tpu_inference.models.jax.qwen2_5_vl import \
         Qwen2_5_VLForConditionalGeneration
     from tpu_inference.models.jax.qwen3 import Qwen3ForCausalLM
     _MODEL_REGISTRY["Llama4ForCausalLM"] = Llama4ForCausalLM
     _MODEL_REGISTRY["DeepseekV3ForCausalLM"] = DeepSeekV3
     _MODEL_REGISTRY["LlamaForCausalLM"] = LlamaForCausalLM
-    _MODEL_REGISTRY["Qwen2ForCausalLM"] = Qwen2ForCausalLM
+    _MODEL_REGISTRY["Llama4ForConditionalGeneration"] = LlamaGuard4ForCausalLM
     _MODEL_REGISTRY["Qwen3ForCausalLM"] = Qwen3ForCausalLM
     _MODEL_REGISTRY[
         "Qwen2_5_VLForConditionalGeneration"] = Qwen2_5_VLForConditionalGeneration
-    _MODEL_REGISTRY["Phi3ForCausalLM"] = Phi3ForCausalLM
     _MODEL_REGISTRY["Eagle3LlamaForCausalLM"] = EagleLlama3ForCausalLM
     _MODEL_REGISTRY["GptOssForCausalLM"] = GptOss
 
@@ -57,8 +55,10 @@ def _get_model_architecture(config: PretrainedConfig) -> nnx.Module:
         if arch in _MODEL_REGISTRY:
             return _MODEL_REGISTRY[arch]
     raise UnsupportedArchitectureError(
-        f"Model architectures {architectures} are not supported for now. "
-        f"Supported architectures: {list(_MODEL_REGISTRY.keys())}")
+        f"Model architectures {architectures} not "
+        "registered in tpu-inference. Falling back to vLLM-native "
+        f"Pytorch definition. JAX-native architectures: {list(_MODEL_REGISTRY.keys())}"
+    )
 
 
 def _get_nnx_model(
@@ -217,7 +217,7 @@ def get_flax_model(
             hidden_states_sharding,  # aux hidden states
         ),
         donate_argnums=2,  # 0 is graphdef, 1 is state, 2 is kv_cache
-        static_argnums=6,  #6 is layer_name_to_kvcache_index
+        static_argnums=7,  #7 is layer_name_to_kvcache_index
     )
     def run_model(graphdef, state, *args):
         model = nnx.merge(graphdef, state)
@@ -242,10 +242,11 @@ def get_flax_model(
         model = nnx.merge(graphdef, state)
         return model.get_multimodal_embeddings(image_grid_thw, **kwargs)
 
+    embed_sharding = NamedSharding(mesh, PartitionSpec(None))
     # This function will calculates the embeddings of input texts and then merge with the image embeddings
     @functools.partial(
         jax.jit,
-        out_shardings=(logits_sharding),
+        out_shardings=(embed_sharding),
     )
     def run_get_input_embeddings(graphdef, state, *args, **kwargs):
         model = nnx.merge(graphdef, state)
@@ -325,8 +326,8 @@ def get_model(
         # Convert the error message to a string to check its contents
         error_msg = str(e)
 
-        logger.warning(f"Flax model failed with: '{error_msg}'. "
-                       "Falling back to vLLM implementation.")
+        logger.warning(error_msg)
+
         # Fall back to the vLLM model and updating the dtype accordingly
         vllm_config.model_config.dtype = j2t_dtype(
             vllm_config.model_config.dtype.dtype)
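
Note: the static_argnums=6 to static_argnums=7 bump in get_flax_model reflects a new positional argument being inserted ahead of layer_name_to_kvcache_index; jax.jit identifies static and donated arguments purely by position. A toy example (names made up) of how the two options behave:

# Illustrative only: jax.jit's donate_argnums / static_argnums are positional,
# so inserting a new argument before a static one requires bumping its index.
import functools
import jax
import jax.numpy as jnp

@functools.partial(jax.jit, donate_argnums=0, static_argnums=2)
def scale_and_add(buf, x, mode):
    # `mode` must be static (hashable, known at trace time) to branch on it.
    if mode == "double":
        return buf * 2 + x
    return buf + x

print(scale_and_add(jnp.ones(4), jnp.arange(4.0), "double"))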
tpu_inference/models/jax/llama3.py
@@ -8,10 +8,10 @@ from transformers import LlamaConfig, modeling_flax_utils
 from vllm.config import VllmConfig
 
 from tpu_inference import utils
+from tpu_inference.layers.common.attention_interface import attention
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
-from tpu_inference.layers.jax.attention_interface import attention
+from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.layers.jax.rope_interface import apply_rope
-from tpu_inference.layers.jax.sharding import ShardingAxisName
 from tpu_inference.logger import init_logger
 from tpu_inference.models.jax.utils.weight_utils import (get_default_maps,
                                                          load_hf_weights)
@@ -368,7 +368,8 @@ class LlamaForCausalLM(nnx.Module):
             "lm_head": "model.lm_head",
         })
 
-        metadata_map = get_default_maps(self.vllm_config, self.mesh, mappings)
+        metadata_map = get_default_maps(self.vllm_config.model_config,
+                                        self.mesh, mappings)
         load_hf_weights(vllm_config=self.vllm_config,
                         model=self,
                         metadata_map=metadata_map,
tpu_inference/models/jax/llama_eagle3.py
@@ -194,13 +194,12 @@ class Eagle3LlamaModel(nnx.Module):
 
 def update_reshape_map_for_eagle3(vllm_config: VllmConfig,
                                   metadata_map: MetadataMap):
-    model_config = vllm_config.model_config
+    model_config = vllm_config.speculative_config.draft_model_config
     hf_config = model_config.hf_config
 
     num_heads = hf_config.num_attention_heads
    num_kv_heads = hf_config.num_key_value_heads
-    hidden_size = model_config.get_hidden_size()
-
+    hidden_size = hf_config.hidden_size
     head_dim_original = model_config.get_head_size()
 
     metadata_map.reshape_map.update({
@@ -312,7 +311,11 @@ class EagleLlama3ForCausalLM(nnx.Module):
             r".*d2t.*",
         ]
 
-        metadata_map = get_default_maps(self.vllm_config, self.mesh, mappings)
+        # `embed_tokens` is shared between target and draft.
+        exclude_regex = [r".*embed_tokens.*"]
+        metadata_map = get_default_maps(
+            self.vllm_config.speculative_config.draft_model_config, self.mesh,
+            mappings)
 
         update_reshape_map_for_eagle3(self.vllm_config, metadata_map)
 
@@ -322,7 +325,8 @@ class EagleLlama3ForCausalLM(nnx.Module):
                         metadata_map=metadata_map,
                         mesh=self.mesh,
                         is_draft_model=True,
-                        keep_original_dtype_keys_regex=keep_original_dtype_keys_regex)
+                        keep_original_dtype_keys_regex=keep_original_dtype_keys_regex,
+                        exclude_regex=exclude_regex if exclude_regex else None)
 
         # If the embedding is not initialized, initialize it with a dummpy array here to pass jit compilation. The real weights will be shared from the target model in eagle3 class.
         if isinstance(self.model.embed_tokens.embedding.value,
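
Note: the new exclude_regex=[r".*embed_tokens.*"] keeps load_hf_weights from loading the embedding table into the draft model, since Eagle3 shares it with the target. A hedged illustration of the filtering idea (key names invented; the real load_hf_weights has its own logic):

# Hypothetical illustration of regex-based key exclusion during weight loading.
import re

checkpoint_keys = [
    "model.embed_tokens.weight",              # shared with target -> skipped
    "model.layers.0.self_attn.q_proj.weight",
    "model.norm.weight",
]
exclude_regex = [r".*embed_tokens.*"]

to_load = [
    k for k in checkpoint_keys
    if not any(re.match(p, k) for p in exclude_regex)
]
print(to_load)  # ['model.layers.0.self_attn.q_proj.weight', 'model.norm.weight']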