tpu-inference 0.11.1.dev202511150811__py3-none-any.whl → 0.11.1.dev202512030818__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (54)
  1. tests/kernels/fused_moe_v1_test.py +303 -34
  2. tests/lora/test_layers.py +0 -6
  3. tests/lora/utils.py +0 -8
  4. tests/test_envs.py +32 -11
  5. tests/test_utils.py +1 -2
  6. tpu_inference/__init__.py +22 -3
  7. tpu_inference/core/disagg_utils.py +6 -8
  8. tpu_inference/distributed/tpu_connector.py +3 -4
  9. tpu_inference/distributed/utils.py +3 -2
  10. tpu_inference/envs.py +61 -8
  11. tpu_inference/executors/ray_distributed_executor.py +31 -11
  12. tpu_inference/kernels/fused_moe/v1/kernel.py +641 -110
  13. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +77 -54
  14. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +213 -126
  15. tpu_inference/layers/common/attention_interface.py +7 -1
  16. tpu_inference/layers/common/sharding.py +5 -5
  17. tpu_inference/layers/vllm/fused_moe.py +74 -25
  18. tpu_inference/layers/vllm/quantization/common.py +6 -1
  19. tpu_inference/layers/vllm/quantization/mxfp4.py +137 -62
  20. tpu_inference/layers/vllm/quantization/unquantized.py +107 -113
  21. tpu_inference/layers/vllm/sharding.py +2 -2
  22. tpu_inference/lora/torch_punica_tpu.py +1 -2
  23. tpu_inference/models/common/model_loader.py +45 -11
  24. tpu_inference/models/jax/llama3.py +2 -1
  25. tpu_inference/models/jax/llama_eagle3.py +8 -5
  26. tpu_inference/models/jax/llama_guard_4.py +361 -0
  27. tpu_inference/models/jax/qwen2.py +2 -1
  28. tpu_inference/models/jax/qwen2_5_vl.py +163 -48
  29. tpu_inference/models/jax/qwen3.py +2 -1
  30. tpu_inference/models/jax/utils/quantization/quantization_utils.py +3 -6
  31. tpu_inference/models/jax/utils/weight_utils.py +198 -143
  32. tpu_inference/models/vllm/vllm_model_wrapper.py +14 -7
  33. tpu_inference/platforms/tpu_platform.py +28 -22
  34. tpu_inference/runner/compilation_manager.py +144 -59
  35. tpu_inference/runner/kv_cache_manager.py +17 -18
  36. tpu_inference/runner/persistent_batch_manager.py +40 -2
  37. tpu_inference/runner/structured_decoding_manager.py +2 -3
  38. tpu_inference/runner/tpu_runner.py +271 -147
  39. tpu_inference/runner/utils.py +2 -2
  40. tpu_inference/spec_decode/jax/eagle3.py +71 -21
  41. tpu_inference/tpu_info.py +4 -3
  42. tpu_inference/utils.py +36 -13
  43. tpu_inference/worker/tpu_worker.py +162 -25
  44. {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/METADATA +3 -2
  45. {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/RECORD +48 -53
  46. tpu_inference/mock/__init__.py +0 -0
  47. tpu_inference/mock/vllm_config_utils.py +0 -28
  48. tpu_inference/mock/vllm_envs.py +0 -1219
  49. tpu_inference/mock/vllm_logger.py +0 -212
  50. tpu_inference/mock/vllm_logging_utils.py +0 -15
  51. tpu_inference/models/jax/phi3.py +0 -376
  52. {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/WHEEL +0 -0
  53. {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/licenses/LICENSE +0 -0
  54. {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/top_level.txt +0 -0
@@ -108,6 +108,8 @@ class VllmUnquantizedLinearMethod(UnquantizedLinearMethod):
               layer: torch.nn.Module,
               x: torch.Tensor,
               bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        assert isinstance(layer, LinearBase)
+
         with jax.named_scope(layer._get_name()):
             if in_sharding := self.jax_config.get_input_sharding(x):
                 x.shard_(NamedSharding(self.jax_config.mesh, in_sharding))
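Aside (not part of the diff): jax.named_scope, used in the context line above, is a standard JAX context manager that tags every operation traced inside it so the layer shows up under its own name in profiler traces and lowered HLO. A minimal sketch with an illustrative function name:

    import jax
    import jax.numpy as jnp

    @jax.jit
    def my_linear(x):
        # Ops traced inside the scope carry the "my_linear_matmul" prefix in
        # profiles and HLO dumps; the scope has no numerical effect.
        with jax.named_scope("my_linear_matmul"):
            return x @ x.T

    my_linear(jnp.ones((4, 4)))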
@@ -170,14 +172,14 @@ class VllmUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
         self.ep_axis_name = ep_axis_name
         # TODO: Use autotune table once we have it.
         self.block_size = {
-            "bt": 16,
-            "bf": 384,
-            "bd1": 512,
-            "bd2": 512,
-            "btc": 16,
-            "bfc": 384,
-            "bd1c": 256,
-            "bd2c": 256,
+            "bt": 64,
+            "bf": 1024,
+            "bd1": 1536,
+            "bd2": 1536,
+            "btc": 64,
+            "bfc": 1024,
+            "bd1c": 1536,
+            "bd2c": 1536,
         }
 
     def select_gemm_impl(
@@ -191,131 +193,119 @@ class VllmUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         assert isinstance(layer, FusedMoE)
-        available_devices = self.mesh.devices.flatten()
-        with jax.default_device(available_devices[0]):
-            w13_weight = t2j(layer.w13_weight, use_dlpack=False)
-            w2_weight = t2j(layer.w2_weight, use_dlpack=False)
+        w13_weight = t2j(layer.w13_weight, use_dlpack=False)
+        w2_weight = t2j(layer.w2_weight, use_dlpack=False)
 
-            if self.moe.has_bias:
-                w13_bias = t2j(layer.w13_bias, use_dlpack=False)
-                w2_bias = t2j(layer.w2_bias, use_dlpack=False)
-
-            if layer.activation == "swigluoai":
-                # When using swigluoai, vLLM splits gmm output in a interleaved way.
-                # However, interleaved split is not performant on TPU. Therefore,
-                # we preprocess the weight so that splitting gmm output by middle
-                # can still get the same result.
-                w1_weight = w13_weight[:, ::2, :]
-                w3_weight = w13_weight[:, 1::2, :]
-                w13_weight = jnp.concat([w1_weight, w3_weight], axis=1)
+        if self.moe.has_bias:
+            w13_bias = t2j(layer.w13_bias, use_dlpack=False)
+            w2_bias = t2j(layer.w2_bias, use_dlpack=False)
+
+        if layer.activation == "swigluoai":
+            # When using swigluoai, vLLM splits gmm output in a interleaved way.
+            # However, interleaved split is not performant on TPU. Therefore,
+            # we preprocess the weight so that splitting gmm output by middle
+            # can still get the same result.
+            w1_weight = w13_weight[:, ::2, :]
+            w3_weight = w13_weight[:, 1::2, :]
+            w13_weight = jnp.concat([w1_weight, w3_weight], axis=1)
 
-                if self.moe.has_bias:
-                    w1_bias = w13_bias[:, ::2]
-                    w3_bias = w13_bias[:, 1::2]
-                    w13_bias = jnp.concat([w1_bias, w3_bias], axis=1)
-
-            if self.use_kernel and layer.use_ep:
-                # Kernel expects:
-                # w13: (num_experts, 2, hidden_size, intermediate_size)
-                # w2: (num_experts, intermediate_size, hidden_size)
-                # Current format:
-                # w13_weight: (num_experts, 2*intermediate_size, hidden_size)
-                # w2_weight: (num_experts, hidden_size, intermediate_size)
-                num_experts = w13_weight.shape[0]
-                intermediate_size = w13_weight.shape[1] // 2
-                hidden_size = w13_weight.shape[2]
+            if self.moe.has_bias:
+                w1_bias = w13_bias[:, ::2]
+                w3_bias = w13_bias[:, 1::2]
+                w13_bias = jnp.concat([w1_bias, w3_bias], axis=1)
 
-                # Reshape and transpose w13_weight to (num_experts, 2, hidden_size, intermediate_size)
-                w13_reshaped = w13_weight.reshape(num_experts, 2,
-                                                  intermediate_size,
-                                                  hidden_size)
-                w13_weight_transposed = jnp.transpose(w13_reshaped,
-                                                      (0, 1, 3, 2))
+        if self.use_kernel and layer.use_ep:
+            # Kernel expects:
+            # w13: (num_experts, 2, hidden_size, intermediate_size)
+            # w2: (num_experts, intermediate_size, hidden_size)
+            # Current format:
+            # w13_weight: (num_experts, 2*intermediate_size, hidden_size)
+            # w2_weight: (num_experts, hidden_size, intermediate_size)
+            num_experts = w13_weight.shape[0]
+            intermediate_size = w13_weight.shape[1] // 2
+            hidden_size = w13_weight.shape[2]
+
+            # Reshape and transpose w13_weight to (num_experts, 2, hidden_size, intermediate_size)
+            w13_reshaped = w13_weight.reshape(num_experts, 2,
+                                              intermediate_size, hidden_size)
+            w13_weight_transposed = jnp.transpose(w13_reshaped, (0, 1, 3, 2))
+
+            # Transpose w2_weight to (num_experts, intermediate_size, hidden_size)
+            w2_weight_transposed = jnp.transpose(w2_weight, (0, 2, 1))
+
+            # Apply EP sharding
+            w13_weight = jax.device_put(
+                w13_weight_transposed,
+                Format(Layout((0, 1, 2, 3)),
+                       NamedSharding(self.mesh, P("model", None, None, None))))
+            w2_weight = jax.device_put(
+                w2_weight_transposed,
+                Format(Layout((0, 1, 2)),
+                       NamedSharding(self.mesh, P("model", None, None))))
 
-                # Transpose w2_weight to (num_experts, intermediate_size, hidden_size)
-                w2_weight_transposed = jnp.transpose(w2_weight, (0, 2, 1))
+            if self.moe.has_bias:
+                w13_bias = w13_bias.reshape(num_experts, 2, intermediate_size)
 
                 # Apply EP sharding
+                w13_bias = jax.device_put(
+                    w13_bias,
+                    Format(Layout((0, 1, 2)),
+                           NamedSharding(self.mesh, P("model", None, None))))
+                w2_bias = jax.device_put(
+                    w2_bias,
+                    Format(Layout((0, 1)),
+                           NamedSharding(self.mesh, P("model", None))))
+
+        else:
+            # Original logic for non-kernel path
+            if layer.use_ep:
                 w13_weight = jax.device_put(
-                    w13_weight_transposed,
-                    Format(
-                        Layout((0, 1, 2, 3)),
-                        NamedSharding(self.mesh, P("model", None, None,
-                                                   None))))
+                    w13_weight,
+                    Format(Layout((0, 1, 2)),
+                           NamedSharding(self.mesh, P("model", None, None))))
                 w2_weight = jax.device_put(
-                    w2_weight_transposed,
+                    w2_weight,
                     Format(Layout((0, 1, 2)),
                            NamedSharding(self.mesh, P("model", None, None))))
 
                 if self.moe.has_bias:
-                    w13_bias = w13_bias.reshape(num_experts, 2,
-                                                intermediate_size)
-
-                    # Apply EP sharding
                     w13_bias = jax.device_put(
                         w13_bias,
-                        Format(
-                            Layout((0, 1, 2)),
-                            NamedSharding(self.mesh, P("model", None, None))))
+                        Format(Layout((0, 1)),
+                               NamedSharding(self.mesh, P("model", None))))
                     w2_bias = jax.device_put(
                         w2_bias,
                         Format(Layout((0, 1)),
                                NamedSharding(self.mesh, P("model", None))))
 
             else:
-                # Original logic for non-kernel path
-                if layer.use_ep:
-                    w13_weight = jax.device_put(
-                        w13_weight,
-                        Format(
-                            Layout((0, 1, 2)),
-                            NamedSharding(self.mesh, P("model", None, None))))
-                    w2_weight = jax.device_put(
-                        w2_weight,
-                        Format(
-                            Layout((0, 1, 2)),
-                            NamedSharding(self.mesh, P("model", None, None))))
-
-                    if self.moe.has_bias:
-                        w13_bias = jax.device_put(
-                            w13_bias,
-                            Format(Layout((0, 1)),
-                                   NamedSharding(self.mesh, P("model", None))))
-                        w2_bias = jax.device_put(
-                            w2_bias,
-                            Format(Layout((0, 1)),
-                                   NamedSharding(self.mesh, P("model", None))))
-
-                else:
-                    intermediate_size = w13_weight.shape[1] // 2
-                    assert intermediate_size == w2_weight.shape[-1]
-                    output_sizes = [intermediate_size, intermediate_size]
-                    n_shards = self.mesh.shape["model"]
-                    assert intermediate_size % n_shards == 0
-                    w13_weight = reorder_concatenated_tensor_for_sharding(
-                        w13_weight, output_sizes, n_shards, dim=1)
-                    w13_weight = jax.device_put(
-                        w13_weight,
-                        Format(
-                            Layout((0, 1, 2)),
-                            NamedSharding(self.mesh, P(None, "model", None))))
-                    w2_weight = jax.device_put(
-                        w2_weight,
-                        Format(
-                            Layout((0, 1, 2)),
-                            NamedSharding(self.mesh, P(None, None, "model"))))
-
-                    if self.moe.has_bias:
-                        w13_bias = reorder_concatenated_tensor_for_sharding(
-                            w13_bias, output_sizes, n_shards, dim=1)
-                        w13_bias = jax.device_put(
-                            w13_bias,
-                            Format(Layout((0, 1)),
-                                   NamedSharding(self.mesh, P(None, "model"))))
-                        w2_bias = jax.device_put(
-                            w2_bias,
-                            Format(Layout((0, 1)),
-                                   NamedSharding(self.mesh, P(None, None))))
+                intermediate_size = w13_weight.shape[1] // 2
+                assert intermediate_size == w2_weight.shape[-1]
+                output_sizes = [intermediate_size, intermediate_size]
+                n_shards = self.mesh.shape["model"]
+                assert intermediate_size % n_shards == 0
+                w13_weight = reorder_concatenated_tensor_for_sharding(
+                    w13_weight, output_sizes, n_shards, dim=1)
+                w13_weight = jax.device_put(
+                    w13_weight,
+                    Format(Layout((0, 1, 2)),
+                           NamedSharding(self.mesh, P(None, "model", None))))
+                w2_weight = jax.device_put(
+                    w2_weight,
+                    Format(Layout((0, 1, 2)),
+                           NamedSharding(self.mesh, P(None, None, "model"))))
+
+                if self.moe.has_bias:
+                    w13_bias = reorder_concatenated_tensor_for_sharding(
+                        w13_bias, output_sizes, n_shards, dim=1)
+                    w13_bias = jax.device_put(
+                        w13_bias,
+                        Format(Layout((0, 1)),
+                               NamedSharding(self.mesh, P(None, "model"))))
+                    w2_bias = jax.device_put(
+                        w2_bias,
+                        Format(Layout((0, 1)),
+                               NamedSharding(self.mesh, P(None, None))))
 
         layer.w13_weight = Parameter(torch_view(w13_weight),
                                      requires_grad=False)
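The swigluoai branch in the hunk above de-interleaves w13 once at load time so the kernel can later split the gmm output down the middle instead of with a strided gather. A toy JAX sketch (illustrative only, made-up tiny shapes, not code from the package) of why the contiguous split then matches vLLM's interleaved split:

    import jax.numpy as jnp

    num_experts, intermediate_size, hidden_size = 2, 4, 8
    w13 = jnp.arange(num_experts * 2 * intermediate_size * hidden_size,
                     dtype=jnp.float32).reshape(num_experts,
                                                2 * intermediate_size,
                                                hidden_size)

    # vLLM's swigluoai convention: w1 and w3 rows are interleaved along dim 1.
    w1_interleaved = w13[:, ::2, :]
    w3_interleaved = w13[:, 1::2, :]

    # The preprocessing from the diff: de-interleave once when loading...
    w13_reordered = jnp.concat([w1_interleaved, w3_interleaved], axis=1)

    # ...so a cheap contiguous "split by middle" now yields the same halves.
    w1_middle, w3_middle = jnp.split(w13_reordered, 2, axis=1)
    assert jnp.array_equal(w1_middle, w1_interleaved)
    assert jnp.array_equal(w3_middle, w3_interleaved)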
@@ -360,9 +350,13 @@ class VllmUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
                 tokens=jax_view(x),
                 w1=jax_view(layer.w13_weight),
                 w2=jax_view(layer.w2_weight),
+                b1=jax_view(layer.w13_bias) if self.moe.has_bias else None,
+                b2=jax_view(layer.w2_bias) if self.moe.has_bias else None,
                 gating_output=jax_view(router_logits),
                 top_k=top_k,
                 ep_axis_name=self.ep_axis_name,
+                renormalize_topk_logits=renormalize,
+                act_fn=activation,
                 **self.block_size,
             )
         else:
@@ -19,6 +19,7 @@ from vllm.lora.layers.base_linear import BaseLinearLayerWithLoRA
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 
+from tpu_inference import envs
 from tpu_inference.logger import init_logger
 
 P = PartitionSpec
@@ -211,8 +212,7 @@ def _shard_module_to_tpu(model: torch.nn.Module, mesh: Mesh) -> None:
 def _sharded_device_put(tensor: jax.Array, sharding) -> jax.Array:
     if isinstance(tensor, tuple):
         return tuple(_sharded_device_put(t, sharding) for t in tensor)
-    import os
-    multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower()
+    multihost_backend = envs.TPU_MULTIHOST_BACKEND
     if multihost_backend != "ray":
         return jax.device_put(tensor, sharding)
 
@@ -239,7 +239,6 @@ class PunicaWrapperTPU(PunicaWrapperBase):
         lora_index_to_id: list[Optional[int]],
         max_loras: int,
         vocab_size: int,
-        extra_vocab_size: int,
     ):
         # Pad the prompt mapping to avoid running into recompiles on the TPU
         # TODO: Should this happen inside mapping internally? If so how can we
@@ -258,7 +257,7 @@ class PunicaWrapperTPU(PunicaWrapperBase):
             lora_index_to_id,
             max_loras,
             vocab_size,
-            extra_vocab_size,
+            0, # extra_vocab_size
             "cpu",
         )
         with torchax.default_env():
@@ -8,6 +8,9 @@ from jax.sharding import Mesh, NamedSharding, PartitionSpec
 from torchax.ops.mappings import j2t_dtype
 from transformers import PretrainedConfig
 from vllm.config import VllmConfig
+from vllm.model_executor.model_loader import get_model_loader
+from vllm.model_executor.model_loader.runai_streamer_loader import \
+    RunaiModelStreamerLoader
 from vllm.utils.func_utils import supports_kw
 
 from tpu_inference import envs
@@ -36,19 +39,17 @@ def _get_model_architecture(config: PretrainedConfig) -> nnx.Module:
     from tpu_inference.models.jax.llama3 import LlamaForCausalLM
     from tpu_inference.models.jax.llama4 import Llama4ForCausalLM
     from tpu_inference.models.jax.llama_eagle3 import EagleLlama3ForCausalLM
-    from tpu_inference.models.jax.phi3 import Phi3ForCausalLM
-    from tpu_inference.models.jax.qwen2 import Qwen2ForCausalLM
+    from tpu_inference.models.jax.llama_guard_4 import LlamaGuard4ForCausalLM
     from tpu_inference.models.jax.qwen2_5_vl import \
         Qwen2_5_VLForConditionalGeneration
     from tpu_inference.models.jax.qwen3 import Qwen3ForCausalLM
     _MODEL_REGISTRY["Llama4ForCausalLM"] = Llama4ForCausalLM
     _MODEL_REGISTRY["DeepseekV3ForCausalLM"] = DeepSeekV3
     _MODEL_REGISTRY["LlamaForCausalLM"] = LlamaForCausalLM
-    _MODEL_REGISTRY["Qwen2ForCausalLM"] = Qwen2ForCausalLM
+    _MODEL_REGISTRY["Llama4ForConditionalGeneration"] = LlamaGuard4ForCausalLM
     _MODEL_REGISTRY["Qwen3ForCausalLM"] = Qwen3ForCausalLM
     _MODEL_REGISTRY[
         "Qwen2_5_VLForConditionalGeneration"] = Qwen2_5_VLForConditionalGeneration
-    _MODEL_REGISTRY["Phi3ForCausalLM"] = Phi3ForCausalLM
     _MODEL_REGISTRY["Eagle3LlamaForCausalLM"] = EagleLlama3ForCausalLM
     _MODEL_REGISTRY["GptOssForCausalLM"] = GptOss
 
@@ -57,8 +58,10 @@ def _get_model_architecture(config: PretrainedConfig) -> nnx.Module:
     if arch in _MODEL_REGISTRY:
         return _MODEL_REGISTRY[arch]
     raise UnsupportedArchitectureError(
-        f"Model architectures {architectures} are not supported for now. "
-        f"Supported architectures: {list(_MODEL_REGISTRY.keys())}")
+        f"Model architectures {architectures} not "
+        "registered in tpu-inference. Falling back to vLLM-native "
+        f"Pytorch definition. JAX-native architectures: {list(_MODEL_REGISTRY.keys())}"
+    )
 
 
 def _get_nnx_model(
@@ -177,7 +180,23 @@ def _get_nnx_model(
     # the model creation again, otherwise the model forward will have
     # non-trivial overhead in PjitFunction.
     with mesh:
-        model.load_weights(rng)
+        loader = get_model_loader(vllm_config.load_config)
+        if isinstance(loader, RunaiModelStreamerLoader):
+            model_weights = vllm_config.model_config.model
+            if hasattr(vllm_config.model_config, "model_weights"):
+                model_weights = vllm_config.model_config.model_weights
+            weights_iterator = loader._get_weights_iterator(
+                model_weights, vllm_config.model_config.revision)
+            # We set the weights iterator at runtime, to prevent having to change
+            # every model's load_weights signature. This also prevents us from hitting
+            # a TypeError at runtime if you use the RunaiModelStreamerLoader with any
+            # flax_nnx model whose load_weights function does not accept the
+            # weights_iterator keyword argument.
+            vllm_config.model_config.model_weights_iterator = weights_iterator
+            model.load_weights(rng)
+            del vllm_config.model_config.model_weights_iterator
+        else:
+            model.load_weights(rng)
     jit_model = create_jit_model(
         model,
         use_qwix_on_abstract_model=should_apply_qwix_on_abstract_model)
@@ -217,7 +236,9 @@ def get_flax_model(
             hidden_states_sharding, # aux hidden states
         ),
         donate_argnums=2, # 0 is graphdef, 1 is state, 2 is kv_cache
-        static_argnums=6, #6 is layer_name_to_kvcache_index
+        static_argnums=(
+            7, 10, 11
+        ), #7 is layer_name_to_kvcache_index, 10 is is_first_rank, 11 is is_last_rank
     )
     def run_model(graphdef, state, *args):
         model = nnx.merge(graphdef, state)
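For context on the static_argnums change above: arguments listed there are treated as compile-time constants, so flags like is_first_rank and is_last_rank can drive ordinary Python control flow inside the jitted function, at the cost of one retrace per distinct value. A standalone sketch (illustrative names, not code from the package):

    import functools
    import jax
    import jax.numpy as jnp

    @functools.partial(jax.jit, static_argnums=(1, 2))
    def run_stage(x, is_first_rank, is_last_rank):
        # Static args are plain Python values at trace time, so `if` works;
        # jit compiles one program per distinct (bool, bool) combination.
        if is_first_rank and not is_last_rank:
            return x * 2.0
        return x + 1.0

    x = jnp.ones(4)
    run_stage(x, True, False)   # compiles for (True, False)
    run_stage(x, False, True)   # retraces for the new static combination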
@@ -242,10 +263,11 @@ def get_flax_model(
         model = nnx.merge(graphdef, state)
         return model.get_multimodal_embeddings(image_grid_thw, **kwargs)
 
+    embed_sharding = NamedSharding(mesh, PartitionSpec(None))
     # This function will calculates the embeddings of input texts and then merge with the image embeddings
     @functools.partial(
         jax.jit,
-        out_shardings=(logits_sharding),
+        out_shardings=(embed_sharding),
     )
     def run_get_input_embeddings(graphdef, state, *args, **kwargs):
         model = nnx.merge(graphdef, state)
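Context for the out_shardings change above: PartitionSpec(None) leaves the corresponding dimension unsharded, so NamedSharding(mesh, PartitionSpec(None)) replicates the embedding output across the mesh rather than splitting it the way a logits sharding over a "model" axis would. A small self-contained sketch (illustrative, not from the package):

    import numpy as np
    import jax
    import jax.numpy as jnp
    from jax.sharding import Mesh, NamedSharding, PartitionSpec

    mesh = Mesh(np.array(jax.devices()), axis_names=("model",))
    replicated = NamedSharding(mesh, PartitionSpec(None))      # same construction as embed_sharding above
    row_sharded = NamedSharding(mesh, PartitionSpec("model"))  # split dim 0 over the "model" axis

    x = jnp.ones((2 * mesh.size, 4))
    x_rep = jax.device_put(x, replicated)     # every device holds the full array
    x_split = jax.device_put(x, row_sharded)  # each device holds a slice of dim 0
    print(x_rep.sharding, x_split.sharding)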
@@ -325,8 +347,8 @@ def get_model(
         # Convert the error message to a string to check its contents
         error_msg = str(e)
 
-        logger.warning(f"Flax model failed with: '{error_msg}'. "
-                       "Falling back to vLLM implementation.")
+        logger.warning(error_msg)
+
         # Fall back to the vLLM model and updating the dtype accordingly
         vllm_config.model_config.dtype = j2t_dtype(
             vllm_config.model_config.dtype.dtype)
@@ -420,6 +442,17 @@ def register_model(arch: str, model: Any) -> None:
             "This is a JAX model and does not implement the PyTorch forward method."
         )
 
+    # Same as `forward`, this is a dummy method to satisfy vLLM's type checks.
+    def unimplemented_get_input_embeddings(
+        self,
+        input_ids: "torch.Tensor",
+        positions: "torch.Tensor",
+        inputs_embeds: Optional["torch.Tensor"] = None,
+    ) -> "torch.Tensor":
+        raise NotImplementedError(
+            "This is a JAX model and does not implement the PyTorch get_input_embeddings method."
+        )
+
     # We need a custom __init__ that only calls torch.nn.Module's init,
     # to avoid triggering JAX logic when vLLM inspects the class.
     def wrapper_init(self, *args, **kwargs):
@@ -433,6 +466,7 @@ def register_model(arch: str, model: Any) -> None:
         {
             "__init__": wrapper_init,
             "forward": unimplemented_forward,
+            "get_input_embeddings": unimplemented_get_input_embeddings,
             # Prevent vLLM from trying to load weights into this dummy class.
             "load_weights": lambda self, *args, **kwargs: None,
         })
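The two hunks above extend a placeholder class that is assembled from a dict of stub methods. A generic sketch of that pattern (hypothetical names, not the package's actual register_model code):

    import torch

    def make_placeholder(name: str) -> type:
        # __init__ only runs torch.nn.Module's setup, so the class can be
        # inspected and instantiated without touching any JAX machinery.
        def wrapper_init(self, *args, **kwargs):
            torch.nn.Module.__init__(self)

        def unimplemented(self, *args, **kwargs):
            raise NotImplementedError(f"{name} is JAX-backed; PyTorch hooks are stubs.")

        return type(name, (torch.nn.Module,), {
            "__init__": wrapper_init,
            "forward": unimplemented,
            "get_input_embeddings": unimplemented,
            # Weight loading happens on the JAX side, so this is a no-op here.
            "load_weights": lambda self, *args, **kwargs: None,
        })

    Dummy = make_placeholder("DummyJaxModel")
    model = Dummy()       # constructs cleanly
    model.load_weights()  # no-op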
@@ -368,7 +368,8 @@ class LlamaForCausalLM(nnx.Module):
             "lm_head": "model.lm_head",
         })
 
-        metadata_map = get_default_maps(self.vllm_config, self.mesh, mappings)
+        metadata_map = get_default_maps(self.vllm_config.model_config,
+                                        self.mesh, mappings)
         load_hf_weights(vllm_config=self.vllm_config,
                         model=self,
                         metadata_map=metadata_map,
@@ -194,13 +194,12 @@ class Eagle3LlamaModel(nnx.Module):
 
 def update_reshape_map_for_eagle3(vllm_config: VllmConfig,
                                   metadata_map: MetadataMap):
-    model_config = vllm_config.model_config
+    model_config = vllm_config.speculative_config.draft_model_config
     hf_config = model_config.hf_config
 
     num_heads = hf_config.num_attention_heads
     num_kv_heads = hf_config.num_key_value_heads
-    hidden_size = model_config.get_hidden_size()
-
+    hidden_size = hf_config.hidden_size
     head_dim_original = model_config.get_head_size()
 
     metadata_map.reshape_map.update({
@@ -305,6 +304,8 @@ class EagleLlama3ForCausalLM(nnx.Module):
             "fc": "model.fc.kernel",
             "lm_head": "lm_head.kernel",
             "d2t": "draft_id_to_target_id",
+            "embed_tokens":
+            "model.embed_tokens.embedding", # Some checkpoints need this
         }
 
         # Define keys to keep in original dtype (e.g., float32 for stability)
@@ -312,7 +313,9 @@ class EagleLlama3ForCausalLM(nnx.Module):
             r".*d2t.*",
         ]
 
-        metadata_map = get_default_maps(self.vllm_config, self.mesh, mappings)
+        metadata_map = get_default_maps(
+            self.vllm_config.speculative_config.draft_model_config, self.mesh,
+            mappings)
 
         update_reshape_map_for_eagle3(self.vllm_config, metadata_map)
 
@@ -324,7 +327,7 @@ class EagleLlama3ForCausalLM(nnx.Module):
             is_draft_model=True,
             keep_original_dtype_keys_regex=keep_original_dtype_keys_regex)
 
-        # If the embedding is not initialized, initialize it with a dummpy array here to pass jit compilation. The real weights will be shared from the target model in eagle3 class.
+        # If the embedding is not initialized, initialize it with a dummy array here to pass jit compilation. The real weights will be shared from the target model in eagle3 class.
         if isinstance(self.model.embed_tokens.embedding.value,
                       jax.ShapeDtypeStruct):
             self.model.embed_tokens.embedding.value = jnp.zeros(