tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.12.0.dev20251213__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: the registry's scanner flagged this version of tpu-inference as possibly problematic. See the package's registry page for details.
Files changed (59)
  1. tests/kernels/fused_moe_v1_test.py +303 -34
  2. tests/kernels/mla_v1_test.py +129 -41
  3. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  4. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +3 -1
  5. tests/kernels/ragged_paged_attention_kernel_v3_test.py +3 -1
  6. tests/lora/test_layers.py +4 -1
  7. tests/lora/test_lora_perf.py +53 -0
  8. tests/test_envs.py +110 -12
  9. tests/test_quantization.py +3 -0
  10. tests/test_utils.py +1 -2
  11. tpu_inference/distributed/tpu_connector.py +1 -1
  12. tpu_inference/envs.py +92 -8
  13. tpu_inference/executors/ray_distributed_executor.py +5 -1
  14. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  15. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  16. tpu_inference/kernels/fused_moe/v1/kernel.py +712 -143
  17. tpu_inference/kernels/mla/v1/kernel.py +98 -120
  18. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  19. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  20. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  21. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +82 -32
  22. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +146 -85
  23. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +2 -1
  24. tpu_inference/kernels/ragged_paged_attention/v3/util.py +2 -1
  25. tpu_inference/layers/common/attention_interface.py +7 -1
  26. tpu_inference/layers/common/sharding.py +11 -7
  27. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +232 -64
  28. tpu_inference/layers/jax/attention/gpt_oss_attention.py +5 -5
  29. tpu_inference/layers/vllm/fused_moe.py +170 -208
  30. tpu_inference/layers/vllm/linear_common.py +43 -21
  31. tpu_inference/layers/vllm/quantization/common.py +11 -6
  32. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +4 -3
  33. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +74 -65
  34. tpu_inference/layers/vllm/quantization/mxfp4.py +140 -94
  35. tpu_inference/layers/vllm/quantization/unquantized.py +103 -80
  36. tpu_inference/models/common/model_loader.py +78 -22
  37. tpu_inference/models/jax/deepseek_v3.py +185 -64
  38. tpu_inference/models/jax/gpt_oss.py +3 -3
  39. tpu_inference/models/jax/llama_eagle3.py +4 -5
  40. tpu_inference/models/jax/qwen2_5_vl.py +161 -47
  41. tpu_inference/models/jax/utils/quantization/quantization_utils.py +7 -8
  42. tpu_inference/models/jax/utils/weight_utils.py +203 -155
  43. tpu_inference/models/vllm/vllm_model_wrapper.py +11 -5
  44. tpu_inference/platforms/tpu_platform.py +29 -48
  45. tpu_inference/runner/compilation_manager.py +112 -46
  46. tpu_inference/runner/kv_cache.py +40 -20
  47. tpu_inference/runner/kv_cache_manager.py +40 -31
  48. tpu_inference/runner/persistent_batch_manager.py +40 -2
  49. tpu_inference/runner/structured_decoding_manager.py +2 -3
  50. tpu_inference/runner/tpu_runner.py +94 -51
  51. tpu_inference/runner/utils.py +2 -2
  52. tpu_inference/spec_decode/jax/eagle3.py +71 -22
  53. tpu_inference/utils.py +41 -14
  54. tpu_inference/worker/tpu_worker.py +43 -45
  55. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/METADATA +8 -9
  56. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/RECORD +59 -58
  57. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/WHEEL +0 -0
  58. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/licenses/LICENSE +0 -0
  59. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/top_level.txt +0 -0
tpu_inference/layers/vllm/quantization/unquantized.py

@@ -1,4 +1,4 @@
- from typing import Any, Callable, Optional, Union
+ from typing import Any, Optional, Union

  import jax
  import jax.numpy as jnp
@@ -25,7 +25,7 @@ from tpu_inference import envs
  from tpu_inference.kernels.fused_moe.v1.kernel import fused_ep_moe
  from tpu_inference.layers.common.quant_methods import (UNQUANTIZED,
  get_tpu_quant_method)
- from tpu_inference.layers.vllm.fused_moe import fused_moe_func_padded
+ from tpu_inference.layers.vllm.fused_moe import fused_moe_func
  from tpu_inference.layers.vllm.linear_common import (
  reorder_concatenated_tensor_for_sharding,
  slice_sharded_tensor_for_concatenation, torch_to_jax_param)
@@ -36,6 +36,10 @@ P = PartitionSpec
  logger = init_logger(__name__)


+ def align_to(a, b):
+ return (a + b - 1) // b * b
+
+
  @register_quantization_config(get_tpu_quant_method(UNQUANTIZED))
  class VllmUnquantizedConfig(QuantizationConfig, JaxCommonConfig):

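For reference, a minimal standalone sketch (not part of the diff) of what the new align_to helper computes: it rounds its first argument up to the nearest multiple of the second. The example values below are illustrative.

    def align_to(a: int, b: int) -> int:
        # Ceiling division, then scale back up: the smallest multiple of b >= a.
        return (a + b - 1) // b * b

    assert align_to(300, 256) == 512    # rounded up to the next multiple of 256
    assert align_to(1024, 256) == 1024  # already aligned values are unchanged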
@@ -108,6 +112,8 @@ class VllmUnquantizedLinearMethod(UnquantizedLinearMethod):
  layer: torch.nn.Module,
  x: torch.Tensor,
  bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+ assert isinstance(layer, LinearBase)
+
  with jax.named_scope(layer._get_name()):
  if in_sharding := self.jax_config.get_input_sharding(x):
  x.shard_(NamedSharding(self.jax_config.mesh, in_sharding))
@@ -166,18 +172,18 @@ class VllmUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
  ep_axis_name: str = 'model'):
  super().__init__(moe)
  self.mesh = mesh
- self.use_kernel = envs.USE_MOE_EP_KERNEL
+ self.use_kernel = envs.USE_MOE_EP_KERNEL and moe.use_ep
  self.ep_axis_name = ep_axis_name
  # TODO: Use autotune table once we have it.
  self.block_size = {
- "bt": 16,
- "bf": 384,
- "bd1": 512,
- "bd2": 512,
- "btc": 16,
- "bfc": 384,
- "bd1c": 256,
- "bd2c": 256,
+ "bt": 64,
+ "bf": 1024,
+ "bd1": 1536,
+ "bd2": 1536,
+ "btc": 64,
+ "bfc": 1024,
+ "bd1c": 1536,
+ "bd2c": 1536,
  }

  def select_gemm_impl(
@@ -194,6 +200,8 @@ class VllmUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
  w13_weight = t2j(layer.w13_weight, use_dlpack=False)
  w2_weight = t2j(layer.w2_weight, use_dlpack=False)

+ num_experts, hidden_size, intermediate_size = w2_weight.shape
+
  if self.moe.has_bias:
  w13_bias = t2j(layer.w13_bias, use_dlpack=False)
  w2_bias = t2j(layer.w2_bias, use_dlpack=False)
@@ -212,7 +220,7 @@ class VllmUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
  w3_bias = w13_bias[:, 1::2]
  w13_bias = jnp.concat([w1_bias, w3_bias], axis=1)

- if self.use_kernel and layer.use_ep:
+ if self.use_kernel:
  # Kernel expects:
  # w13: (num_experts, 2, hidden_size, intermediate_size)
  # w2: (num_experts, intermediate_size, hidden_size)
@@ -223,65 +231,82 @@ class VllmUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
  intermediate_size = w13_weight.shape[1] // 2
  hidden_size = w13_weight.shape[2]

- # Reshape and transpose w13_weight to (num_experts, 2, hidden_size, intermediate_size)
- w13_reshaped = w13_weight.reshape(num_experts, 2,
- intermediate_size, hidden_size)
- w13_weight_transposed = jnp.transpose(w13_reshaped, (0, 1, 3, 2))
+ padded_intermediate_size = align_to(intermediate_size, 256)
+ padded_hidden_size = align_to(hidden_size, 256)
+
+ w13_weight = w13_weight.reshape(num_experts, 2, intermediate_size,
+ hidden_size)
+ w13_weight = jnp.transpose(w13_weight, (0, 1, 3, 2))

  # Transpose w2_weight to (num_experts, intermediate_size, hidden_size)
- w2_weight_transposed = jnp.transpose(w2_weight, (0, 2, 1))
+ w2_weight = jnp.transpose(w2_weight, (0, 2, 1))
+
+ w13_weight = jnp.pad(
+ w13_weight,
+ ((0, 0), (0, 0), (0, padded_hidden_size - hidden_size),
+ (0, padded_intermediate_size - intermediate_size)),
+ constant_values=0)
+
+ w2_weight = jnp.pad(
+ w2_weight,
+ ((0, 0), (0, padded_intermediate_size - intermediate_size),
+ (0, padded_hidden_size - hidden_size)),
+ constant_values=0)

  # Apply EP sharding
+ ep_sharding = NamedSharding(self.mesh, P("model"))
+
  w13_weight = jax.device_put(
- w13_weight_transposed,
+ w13_weight,
  Format(Layout((0, 1, 2, 3)),
  NamedSharding(self.mesh, P("model", None, None, None))))
  w2_weight = jax.device_put(
- w2_weight_transposed,
+ w2_weight,
  Format(Layout((0, 1, 2)),
  NamedSharding(self.mesh, P("model", None, None))))

  if self.moe.has_bias:
- w13_bias = w13_bias.reshape(num_experts, 2, intermediate_size)
+ w13_bias = w13_bias.astype(jnp.float32).reshape(
+ num_experts, 2, 1, intermediate_size)
+ w2_bias = w2_bias.astype(jnp.float32).reshape(
+ num_experts, 1, hidden_size)
+
+ w13_bias = jnp.pad(
+ w13_bias,
+ ((0, 0), (0, 0), (0, 0),
+ (0, padded_intermediate_size - intermediate_size)),
+ constant_values=0)
+
+ w2_bias = jnp.pad(w2_bias,
+ ((0, 0), (0, 0),
+ (0, padded_hidden_size - hidden_size)),
+ constant_values=0)

  # Apply EP sharding
  w13_bias = jax.device_put(
- w13_bias,
- Format(Layout((0, 1, 2)),
- NamedSharding(self.mesh, P("model", None, None))))
+ w13_bias, Format(Layout((0, 1, 2, 3)), ep_sharding))
  w2_bias = jax.device_put(
- w2_bias,
- Format(Layout((0, 1)),
- NamedSharding(self.mesh, P("model", None))))
-
+ w2_bias, Format(Layout((0, 1, 2)), ep_sharding))
  else:
- # Original logic for non-kernel path
+
  if layer.use_ep:
+ ep_sharding = NamedSharding(self.mesh, P("model"))
  w13_weight = jax.device_put(
- w13_weight,
- Format(Layout((0, 1, 2)),
- NamedSharding(self.mesh, P("model", None, None))))
+ w13_weight, Format(Layout((0, 1, 2)), ep_sharding))
  w2_weight = jax.device_put(
- w2_weight,
- Format(Layout((0, 1, 2)),
- NamedSharding(self.mesh, P("model", None, None))))
+ w2_weight, Format(Layout((0, 1, 2)), ep_sharding))

  if self.moe.has_bias:
  w13_bias = jax.device_put(
- w13_bias,
- Format(Layout((0, 1)),
- NamedSharding(self.mesh, P("model", None))))
+ w13_bias, Format(Layout((0, 1)), ep_sharding))
  w2_bias = jax.device_put(
- w2_bias,
- Format(Layout((0, 1)),
- NamedSharding(self.mesh, P("model", None))))
+ w2_bias, Format(Layout((0, 1)), ep_sharding))

  else:
- intermediate_size = w13_weight.shape[1] // 2
- assert intermediate_size == w2_weight.shape[-1]
  output_sizes = [intermediate_size, intermediate_size]
  n_shards = self.mesh.shape["model"]
  assert intermediate_size % n_shards == 0
+
  w13_weight = reorder_concatenated_tensor_for_sharding(
  w13_weight, output_sizes, n_shards, dim=1)
  w13_weight = jax.device_put(
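For reference, a standalone sketch of the weight-layout change in the hunk above: w13 goes from (num_experts, 2 * intermediate_size, hidden_size) to (num_experts, 2, padded_hidden_size, padded_intermediate_size), and w2 from (num_experts, hidden_size, intermediate_size) to (num_experts, padded_intermediate_size, padded_hidden_size), both zero-padded up to 256-aligned sizes. The shapes below are illustrative, not taken from any particular model.

    import jax.numpy as jnp

    def align_to(a, b):
        return (a + b - 1) // b * b

    num_experts, intermediate_size, hidden_size = 8, 768, 1000
    w13 = jnp.zeros((num_experts, 2 * intermediate_size, hidden_size))
    w2 = jnp.zeros((num_experts, hidden_size, intermediate_size))

    pad_i = align_to(intermediate_size, 256)  # 768, already aligned
    pad_h = align_to(hidden_size, 256)        # 1024

    # w13: (E, 2*I, H) -> (E, 2, I, H) -> (E, 2, H, I) -> zero-pad H and I.
    w13 = jnp.transpose(
        w13.reshape(num_experts, 2, intermediate_size, hidden_size), (0, 1, 3, 2))
    w13 = jnp.pad(w13, ((0, 0), (0, 0), (0, pad_h - hidden_size),
                        (0, pad_i - intermediate_size)))

    # w2: (E, H, I) -> (E, I, H) -> zero-pad I and H.
    w2 = jnp.pad(jnp.transpose(w2, (0, 2, 1)),
                 ((0, 0), (0, pad_i - intermediate_size), (0, pad_h - hidden_size)))

    assert w13.shape == (num_experts, 2, pad_h, pad_i)
    assert w2.shape == (num_experts, pad_i, pad_h)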
@@ -319,56 +344,54 @@ class VllmUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
  layer: torch.nn.Module,
  x: torch.Tensor,
  router_logits: torch.Tensor,
- top_k: int,
- renormalize: bool,
- use_grouped_topk: bool = False,
- topk_group: Optional[int] = None,
- num_expert_group: Optional[int] = None,
- global_num_experts: int = -1,
- expert_map: Optional[torch.Tensor] = None,
- custom_routing_function: Optional[Callable] = None,
- scoring_func: str = "softmax",
- routed_scaling_factor: float = 1.0,
- e_score_correction_bias: Optional[torch.Tensor] = None,
- apply_router_weight_on_input: bool = False,
- activation: str = "silu",
- enable_eplb: bool = False,
- expert_load_view: Optional[torch.Tensor] = None,
- logical_to_physical_map: Optional[torch.Tensor] = None,
- logical_replica_count: Optional[torch.Tensor] = None,
  ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
  assert isinstance(layer, FusedMoE)
- if scoring_func != "softmax":
+ if layer.scoring_func != "softmax":
  raise NotImplementedError(
  "Only softmax is supported for scoring_func")

- if self.use_kernel and layer.use_ep:
+ x = jax_view(x)
+ w13_weight = jax_view(layer.w13_weight)
+ w2_weight = jax_view(layer.w2_weight)
+ w13_bias = w2_bias = None
+ if self.moe.has_bias:
+ w13_bias = jax_view(layer.w13_bias)
+ w2_bias = jax_view(layer.w2_bias)
+ gating_output = jax_view(router_logits)
+
+ if self.use_kernel:
+ actual_hidden_size = x.shape[-1]
+ padded_hidden_size = align_to(actual_hidden_size, 256)
+ x = jnp.pad(x,
+ ((0, 0), (0, padded_hidden_size - actual_hidden_size)),
+ constant_values=0)
  output = fused_ep_moe(
  mesh=self.mesh,
- tokens=jax_view(x),
- w1=jax_view(layer.w13_weight),
- w2=jax_view(layer.w2_weight),
- gating_output=jax_view(router_logits),
- top_k=top_k,
+ tokens=x,
+ w1=w13_weight,
+ w2=w2_weight,
+ b1=w13_bias,
+ b2=w2_bias,
+ gating_output=gating_output,
+ top_k=layer.top_k,
  ep_axis_name=self.ep_axis_name,
+ renormalize_topk_logits=layer.renormalize,
+ act_fn=layer.activation,
  **self.block_size,
- )
+ )[:, :actual_hidden_size]
  else:
- # Use the original implementation
- output = fused_moe_func_padded(
- jax_view(x),
- jax_view(layer.w13_weight),
- jax_view(layer.w2_weight),
- jax_view(layer.w13_bias) if self.moe.has_bias else None,
- jax_view(layer.w2_bias) if self.moe.has_bias else None,
- jax_view(router_logits),
- topk=top_k,
- global_num_experts=global_num_experts,
- renormalize=renormalize,
- reduce_results=layer.reduce_results,
+ output = fused_moe_func(
+ hidden_states=x,
+ w1=w13_weight,
+ w2=w2_weight,
+ w1_bias=w13_bias,
+ w2_bias=w2_bias,
+ gating_output=gating_output,
+ topk=layer.top_k,
+ renormalize=layer.renormalize,
  mesh=self.mesh,
  use_ep=layer.use_ep,
- activation=activation,
+ activation=layer.activation,
  )

  return torch_view(output)
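The activation path in the hunk above mirrors the weight padding: tokens are zero-padded along the hidden dimension to the same 256-aligned width before the kernel call, and the output is sliced back to the original width. A minimal sketch of that pad-call-slice pattern, with fake_kernel standing in for fused_ep_moe and illustrative sizes:

    import jax.numpy as jnp

    def align_to(a, b):
        return (a + b - 1) // b * b

    def fake_kernel(tokens):
        # Stand-in for fused_ep_moe: returns an output with the padded hidden size.
        return tokens

    num_tokens, hidden_size = 4, 1000
    x = jnp.ones((num_tokens, hidden_size))

    padded = align_to(hidden_size, 256)                      # 1024
    x_pad = jnp.pad(x, ((0, 0), (0, padded - hidden_size)))  # zero-pad hidden dim
    out = fake_kernel(x_pad)[:, :hidden_size]                # slice back to 1000

    assert out.shape == (num_tokens, hidden_size)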
tpu_inference/models/common/model_loader.py

@@ -5,9 +5,11 @@ import jax
  import torch
  from flax import nnx
  from jax.sharding import Mesh, NamedSharding, PartitionSpec
- from torchax.ops.mappings import j2t_dtype
  from transformers import PretrainedConfig
  from vllm.config import VllmConfig
+ from vllm.model_executor.model_loader import get_model_loader
+ from vllm.model_executor.model_loader.runai_streamer_loader import \
+ RunaiModelStreamerLoader
  from vllm.utils.func_utils import supports_kw

  from tpu_inference import envs
@@ -16,11 +18,17 @@ from tpu_inference.logger import init_logger
  from tpu_inference.models.jax.utils.quantization.quantization_utils import (
  apply_qwix_on_abstract_model, apply_qwix_quantization,
  load_random_weights_into_qwix_abstract_model)
+ from tpu_inference.utils import to_jax_dtype, to_torch_dtype

  logger = init_logger(__name__)

  _MODEL_REGISTRY = {}

+ # List of architectures that are preferred to use "vllm" implementation over
+ # "flax_nnx" implementation due to various factors such as performance.
+ _VLLM_PREFERRED_ARCHITECTURES: frozenset[str] = frozenset(
+ {"GptOssForCausalLM"})
+

  class UnsupportedArchitectureError(ValueError):
  """Raised when a model architecture is not supported in the registry."""
@@ -177,7 +185,23 @@ def _get_nnx_model(
  # the model creation again, otherwise the model forward will have
  # non-trivial overhead in PjitFunction.
  with mesh:
- model.load_weights(rng)
+ loader = get_model_loader(vllm_config.load_config)
+ if isinstance(loader, RunaiModelStreamerLoader):
+ model_weights = vllm_config.model_config.model
+ if hasattr(vllm_config.model_config, "model_weights"):
+ model_weights = vllm_config.model_config.model_weights
+ weights_iterator = loader._get_weights_iterator(
+ model_weights, vllm_config.model_config.revision)
+ # We set the weights iterator at runtime, to prevent having to change
+ # every model's load_weights signature. This also prevents us from hitting
+ # a TypeError at runtime if you use the RunaiModelStreamerLoader with any
+ # flax_nnx model whose load_weights function does not accept the
+ # weights_iterator keyword argument.
+ vllm_config.model_config.model_weights_iterator = weights_iterator
+ model.load_weights(rng)
+ del vllm_config.model_config.model_weights_iterator
+ else:
+ model.load_weights(rng)
  jit_model = create_jit_model(
  model,
  use_qwix_on_abstract_model=should_apply_qwix_on_abstract_model)
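The hunk above threads the streamer's weights iterator through model_config so that only load_weights implementations that look for it need to change. A generic sketch of that attach/use/detach pattern follows; Config and load_weights here are illustrative stand-ins, not tpu-inference or vLLM APIs.

    # Hypothetical sketch of attaching an iterator to a config object,
    # consuming it during loading, then detaching it again.
    class Config:
        pass

    def load_weights(config: Config) -> list:
        # A loader that only consumes the iterator if it was injected.
        it = getattr(config, "model_weights_iterator", None)
        return list(it) if it is not None else []

    config = Config()
    config.model_weights_iterator = iter([("layer.weight", 1), ("layer.bias", 2)])
    weights = load_weights(config)
    del config.model_weights_iterator  # detach so later loads see a clean config

    assert weights == [("layer.weight", 1), ("layer.bias", 2)]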
@@ -191,6 +215,9 @@ def get_flax_model(
  mesh: Mesh,
  is_draft_model: bool = False,
  ) -> nnx.Module:
+ model_dtype = to_jax_dtype(vllm_config.model_config.dtype)
+ vllm_config.model_config.dtype = model_dtype
+
  if is_draft_model:
  model_class = _get_model_architecture(
  vllm_config.speculative_config.draft_model_config.hf_config)
@@ -199,7 +226,9 @@
  vllm_config.model_config.hf_config)
  jit_model = _get_nnx_model(model_class, vllm_config, rng, mesh)
  kv_cache_sharding = NamedSharding(
- mesh, PartitionSpec(ShardingAxisName.ATTN_DATA, None, "model"))
+ mesh,
+ PartitionSpec(ShardingAxisName.ATTN_DATA, None,
+ ShardingAxisName.ATTN_HEAD))
  hidden_states_sharding = NamedSharding(mesh,
  PartitionSpec(
  ShardingAxisName.ATTN_DATA,
@@ -217,14 +246,17 @@
  hidden_states_sharding, # aux hidden states
  ),
  donate_argnums=2, # 0 is graphdef, 1 is state, 2 is kv_cache
- static_argnums=7, #7 is layer_name_to_kvcache_index
+ static_argnums=(
+ 7, 10, 11
+ ), #7 is layer_name_to_kvcache_index, 10 is is_first_rank, 11 is is_last_rank
  )
  def run_model(graphdef, state, *args):
  model = nnx.merge(graphdef, state)
  return model(*args)

  logits_sharding = NamedSharding(
- mesh, PartitionSpec(ShardingAxisName.ATTN_DATA, "model"))
+ mesh,
+ PartitionSpec(ShardingAxisName.MLP_DATA, ShardingAxisName.MLP_TENSOR))

  @functools.partial(
  jax.jit,
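The static_argnums change in the hunk above marks arguments 7, 10, and 11 (layer_name_to_kvcache_index, is_first_rank, is_last_rank) as compile-time constants, so jax.jit specializes the compiled function per value instead of tracing those arguments. A small self-contained illustration of that JAX behaviour, with a toy function f:

    import functools

    import jax
    import jax.numpy as jnp

    @functools.partial(jax.jit, static_argnums=(1,))
    def f(x, flag):
        # `flag` is a concrete Python bool at trace time because it is static,
        # so ordinary Python control flow on it is allowed; each distinct value
        # of `flag` triggers its own compilation of `f`.
        return x + 1 if flag else x - 1

    print(f(jnp.ones(3), True))   # [2. 2. 2.]
    print(f(jnp.ones(3), False))  # [0. 0. 0.]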
@@ -293,6 +325,8 @@ def get_vllm_model(
  rng: jax.Array,
  mesh: Mesh,
  ):
+ model_dtype = to_torch_dtype(vllm_config.model_config.dtype)
+ vllm_config.model_config.dtype = model_dtype
  from tpu_inference.models.vllm.vllm_model_wrapper import VllmModelWrapper

  model = VllmModelWrapper(
@@ -318,24 +352,34 @@ def get_model(
  impl = envs.MODEL_IMPL_TYPE
  logger.info(f"Loading model with MODEL_IMPL_TYPE={impl}")

- if impl == "flax_nnx":
- try:
- # Try to load the flax model first
- return get_flax_model(vllm_config, rng, mesh, is_draft_model)
- except UnsupportedArchitectureError as e:
- # Convert the error message to a string to check its contents
- error_msg = str(e)
-
- logger.warning(error_msg)
-
- # Fall back to the vLLM model and updating the dtype accordingly
- vllm_config.model_config.dtype = j2t_dtype(
- vllm_config.model_config.dtype.dtype)
+ if impl == "auto":
+ # Resolve "auto" based on architecture
+ architectures = getattr(vllm_config.model_config.hf_config,
+ "architectures", [])
+ assert len(architectures) == 1, (
+ f"Expected exactly one architecture, got {len(architectures)}: "
+ f"{architectures}")
+ arch = architectures[0]
+ impl = "vllm" if arch in _VLLM_PREFERRED_ARCHITECTURES else "flax_nnx"
+ logger.info(f"Resolved MODEL_IMPL_TYPE 'auto' to '{impl}'")
+
+ match impl:
+ case "flax_nnx":
+ try:
+ # Try to load the flax model first
+ return get_flax_model(vllm_config, rng, mesh, is_draft_model)
+ except UnsupportedArchitectureError as e:
+ # Convert the error message to a string to check its contents
+ error_msg = str(e)
+
+ logger.warning(error_msg)
+
+ # Fall back to the vLLM model and updating the dtype accordingly
+ return get_vllm_model(vllm_config, rng, mesh)
+ case "vllm":
  return get_vllm_model(vllm_config, rng, mesh)
- elif impl == "vllm":
- return get_vllm_model(vllm_config, rng, mesh)
- else:
- raise NotImplementedError("Unsupported MODEL_IMPL_TYPE")
+ case _:
+ raise NotImplementedError(f"Unsupported MODEL_IMPL_TYPE: {impl}")


  def _validate_model_interface(model: Any) -> None:
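A condensed restatement of the new "auto" resolution rule from the hunk above: architectures listed in _VLLM_PREFERRED_ARCHITECTURES resolve to the vllm implementation, everything else to flax_nnx. The resolve_impl wrapper below is illustrative, not part of the package.

    _VLLM_PREFERRED_ARCHITECTURES = frozenset({"GptOssForCausalLM"})

    def resolve_impl(impl: str, architectures: list[str]) -> str:
        # Mirrors the "auto" branch of get_model; other values pass through.
        if impl != "auto":
            return impl
        assert len(architectures) == 1, architectures
        return ("vllm" if architectures[0] in _VLLM_PREFERRED_ARCHITECTURES
                else "flax_nnx")

    assert resolve_impl("auto", ["GptOssForCausalLM"]) == "vllm"
    assert resolve_impl("auto", ["LlamaForCausalLM"]) == "flax_nnx"
    assert resolve_impl("vllm", ["LlamaForCausalLM"]) == "vllm"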
@@ -421,6 +465,17 @@ def register_model(arch: str, model: Any) -> None:
  "This is a JAX model and does not implement the PyTorch forward method."
  )

+ # Same as `forward`, this is a dummy method to satisfy vLLM's type checks.
+ def unimplemented_get_input_embeddings(
+ self,
+ input_ids: "torch.Tensor",
+ positions: "torch.Tensor",
+ inputs_embeds: Optional["torch.Tensor"] = None,
+ ) -> "torch.Tensor":
+ raise NotImplementedError(
+ "This is a JAX model and does not implement the PyTorch get_input_embeddings method."
+ )
+
  # We need a custom __init__ that only calls torch.nn.Module's init,
  # to avoid triggering JAX logic when vLLM inspects the class.
  def wrapper_init(self, *args, **kwargs):
@@ -434,6 +489,7 @@ def register_model(arch: str, model: Any) -> None:
  {
  "__init__": wrapper_init,
  "forward": unimplemented_forward,
+ "get_input_embeddings": unimplemented_get_input_embeddings,
  # Prevent vLLM from trying to load weights into this dummy class.
  "load_weights": lambda self, *args, **kwargs: None,
  })