tpu-inference 0.11.1.dev202511150811__py3-none-any.whl → 0.11.1.dev202511270815__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of tpu-inference might be problematic.
- tests/kernels/fused_moe_v1_test.py +303 -34
- tests/lora/test_layers.py +0 -6
- tests/lora/utils.py +0 -8
- tpu_inference/__init__.py +22 -3
- tpu_inference/core/disagg_utils.py +6 -8
- tpu_inference/distributed/tpu_connector.py +2 -3
- tpu_inference/distributed/utils.py +3 -2
- tpu_inference/envs.py +1 -1
- tpu_inference/executors/ray_distributed_executor.py +27 -11
- tpu_inference/kernels/fused_moe/v1/kernel.py +641 -110
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +77 -54
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +141 -107
- tpu_inference/layers/common/attention_interface.py +7 -1
- tpu_inference/layers/common/sharding.py +2 -1
- tpu_inference/layers/vllm/fused_moe.py +74 -25
- tpu_inference/layers/vllm/quantization/common.py +6 -1
- tpu_inference/layers/vllm/quantization/mxfp4.py +135 -61
- tpu_inference/layers/vllm/quantization/unquantized.py +107 -113
- tpu_inference/layers/vllm/sharding.py +2 -2
- tpu_inference/lora/torch_punica_tpu.py +1 -2
- tpu_inference/models/common/model_loader.py +43 -11
- tpu_inference/models/jax/llama3.py +2 -1
- tpu_inference/models/jax/llama_eagle3.py +8 -5
- tpu_inference/models/jax/llama_guard_4.py +361 -0
- tpu_inference/models/jax/qwen2.py +2 -1
- tpu_inference/models/jax/qwen2_5_vl.py +163 -48
- tpu_inference/models/jax/qwen3.py +2 -1
- tpu_inference/models/jax/utils/weight_utils.py +198 -143
- tpu_inference/models/vllm/vllm_model_wrapper.py +13 -5
- tpu_inference/platforms/tpu_platform.py +15 -2
- tpu_inference/runner/compilation_manager.py +58 -33
- tpu_inference/runner/kv_cache_manager.py +9 -3
- tpu_inference/runner/structured_decoding_manager.py +2 -3
- tpu_inference/runner/tpu_runner.py +203 -102
- tpu_inference/spec_decode/jax/eagle3.py +19 -2
- tpu_inference/tpu_info.py +4 -3
- tpu_inference/utils.py +5 -4
- tpu_inference/worker/tpu_worker.py +160 -23
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202511270815.dist-info}/METADATA +3 -2
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202511270815.dist-info}/RECORD +43 -48
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +0 -28
- tpu_inference/mock/vllm_envs.py +0 -1219
- tpu_inference/mock/vllm_logger.py +0 -212
- tpu_inference/mock/vllm_logging_utils.py +0 -15
- tpu_inference/models/jax/phi3.py +0 -376
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202511270815.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202511270815.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202511270815.dist-info}/top_level.txt +0 -0
tpu_inference/layers/vllm/quantization/unquantized.py

@@ -108,6 +108,8 @@ class VllmUnquantizedLinearMethod(UnquantizedLinearMethod):
               layer: torch.nn.Module,
               x: torch.Tensor,
               bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        assert isinstance(layer, LinearBase)
+
         with jax.named_scope(layer._get_name()):
             if in_sharding := self.jax_config.get_input_sharding(x):
                 x.shard_(NamedSharding(self.jax_config.mesh, in_sharding))
@@ -170,14 +172,14 @@ class VllmUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
         self.ep_axis_name = ep_axis_name
         # TODO: Use autotune table once we have it.
         self.block_size = {
-            "bt":
-            "bf":
-            "bd1":
-            "bd2":
-            "btc":
-            "bfc":
-            "bd1c":
-            "bd2c":
+            "bt": 64,
+            "bf": 1024,
+            "bd1": 1536,
+            "bd2": 1536,
+            "btc": 64,
+            "bfc": 1024,
+            "bd1c": 1536,
+            "bd2c": 1536,
         }

     def select_gemm_impl(
@@ -191,131 +193,119 @@ class VllmUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):

     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         assert isinstance(layer, FusedMoE)
-
-
-        w13_weight = t2j(layer.w13_weight, use_dlpack=False)
-        w2_weight = t2j(layer.w2_weight, use_dlpack=False)
+        w13_weight = t2j(layer.w13_weight, use_dlpack=False)
+        w2_weight = t2j(layer.w2_weight, use_dlpack=False)

-
-
-
-
-
-
-
-
-
-
-
-
+        if self.moe.has_bias:
+            w13_bias = t2j(layer.w13_bias, use_dlpack=False)
+            w2_bias = t2j(layer.w2_bias, use_dlpack=False)
+
+        if layer.activation == "swigluoai":
+            # When using swigluoai, vLLM splits gmm output in a interleaved way.
+            # However, interleaved split is not performant on TPU. Therefore,
+            # we preprocess the weight so that splitting gmm output by middle
+            # can still get the same result.
+            w1_weight = w13_weight[:, ::2, :]
+            w3_weight = w13_weight[:, 1::2, :]
+            w13_weight = jnp.concat([w1_weight, w3_weight], axis=1)

-
-
-
-
-
-        if self.use_kernel and layer.use_ep:
-            # Kernel expects:
-            # w13: (num_experts, 2, hidden_size, intermediate_size)
-            # w2: (num_experts, intermediate_size, hidden_size)
-            # Current format:
-            # w13_weight: (num_experts, 2*intermediate_size, hidden_size)
-            # w2_weight: (num_experts, hidden_size, intermediate_size)
-            num_experts = w13_weight.shape[0]
-            intermediate_size = w13_weight.shape[1] // 2
-            hidden_size = w13_weight.shape[2]
+            if self.moe.has_bias:
+                w1_bias = w13_bias[:, ::2]
+                w3_bias = w13_bias[:, 1::2]
+                w13_bias = jnp.concat([w1_bias, w3_bias], axis=1)

-
-
-
-
-
-
+        if self.use_kernel and layer.use_ep:
+            # Kernel expects:
+            # w13: (num_experts, 2, hidden_size, intermediate_size)
+            # w2: (num_experts, intermediate_size, hidden_size)
+            # Current format:
+            # w13_weight: (num_experts, 2*intermediate_size, hidden_size)
+            # w2_weight: (num_experts, hidden_size, intermediate_size)
+            num_experts = w13_weight.shape[0]
+            intermediate_size = w13_weight.shape[1] // 2
+            hidden_size = w13_weight.shape[2]
+
+            # Reshape and transpose w13_weight to (num_experts, 2, hidden_size, intermediate_size)
+            w13_reshaped = w13_weight.reshape(num_experts, 2,
+                                              intermediate_size, hidden_size)
+            w13_weight_transposed = jnp.transpose(w13_reshaped, (0, 1, 3, 2))
+
+            # Transpose w2_weight to (num_experts, intermediate_size, hidden_size)
+            w2_weight_transposed = jnp.transpose(w2_weight, (0, 2, 1))
+
+            # Apply EP sharding
+            w13_weight = jax.device_put(
+                w13_weight_transposed,
+                Format(Layout((0, 1, 2, 3)),
+                       NamedSharding(self.mesh, P("model", None, None, None))))
+            w2_weight = jax.device_put(
+                w2_weight_transposed,
+                Format(Layout((0, 1, 2)),
+                       NamedSharding(self.mesh, P("model", None, None))))

-
-
+            if self.moe.has_bias:
+                w13_bias = w13_bias.reshape(num_experts, 2, intermediate_size)

                 # Apply EP sharding
+                w13_bias = jax.device_put(
+                    w13_bias,
+                    Format(Layout((0, 1, 2)),
+                           NamedSharding(self.mesh, P("model", None, None))))
+                w2_bias = jax.device_put(
+                    w2_bias,
+                    Format(Layout((0, 1)),
+                           NamedSharding(self.mesh, P("model", None))))
+
+        else:
+            # Original logic for non-kernel path
+            if layer.use_ep:
                 w13_weight = jax.device_put(
-
-                    Format(
-
-                    NamedSharding(self.mesh, P("model", None, None,
-                                               None))))
+                    w13_weight,
+                    Format(Layout((0, 1, 2)),
+                           NamedSharding(self.mesh, P("model", None, None))))
                 w2_weight = jax.device_put(
-
+                    w2_weight,
                     Format(Layout((0, 1, 2)),
                            NamedSharding(self.mesh, P("model", None, None))))

                 if self.moe.has_bias:
-                    w13_bias = w13_bias.reshape(num_experts, 2,
-                                                intermediate_size)
-
-                    # Apply EP sharding
                     w13_bias = jax.device_put(
                         w13_bias,
-                        Format(
-
-                        NamedSharding(self.mesh, P("model", None, None))))
+                        Format(Layout((0, 1)),
+                               NamedSharding(self.mesh, P("model", None))))
                     w2_bias = jax.device_put(
                         w2_bias,
                         Format(Layout((0, 1)),
                                NamedSharding(self.mesh, P("model", None))))

             else:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                n_shards = self.mesh.shape["model"]
-                assert intermediate_size % n_shards == 0
-                w13_weight = reorder_concatenated_tensor_for_sharding(
-                    w13_weight, output_sizes, n_shards, dim=1)
-                w13_weight = jax.device_put(
-                    w13_weight,
-                    Format(
-                        Layout((0, 1, 2)),
-                        NamedSharding(self.mesh, P(None, "model", None))))
-                w2_weight = jax.device_put(
-                    w2_weight,
-                    Format(
-                        Layout((0, 1, 2)),
-                        NamedSharding(self.mesh, P(None, None, "model"))))
-
-                if self.moe.has_bias:
-                    w13_bias = reorder_concatenated_tensor_for_sharding(
-                        w13_bias, output_sizes, n_shards, dim=1)
-                    w13_bias = jax.device_put(
-                        w13_bias,
-                        Format(Layout((0, 1)),
-                               NamedSharding(self.mesh, P(None, "model"))))
-                    w2_bias = jax.device_put(
-                        w2_bias,
-                        Format(Layout((0, 1)),
-                               NamedSharding(self.mesh, P(None, None))))
+                intermediate_size = w13_weight.shape[1] // 2
+                assert intermediate_size == w2_weight.shape[-1]
+                output_sizes = [intermediate_size, intermediate_size]
+                n_shards = self.mesh.shape["model"]
+                assert intermediate_size % n_shards == 0
+                w13_weight = reorder_concatenated_tensor_for_sharding(
+                    w13_weight, output_sizes, n_shards, dim=1)
+                w13_weight = jax.device_put(
+                    w13_weight,
+                    Format(Layout((0, 1, 2)),
+                           NamedSharding(self.mesh, P(None, "model", None))))
+                w2_weight = jax.device_put(
+                    w2_weight,
+                    Format(Layout((0, 1, 2)),
+                           NamedSharding(self.mesh, P(None, None, "model"))))
+
+                if self.moe.has_bias:
+                    w13_bias = reorder_concatenated_tensor_for_sharding(
+                        w13_bias, output_sizes, n_shards, dim=1)
+                    w13_bias = jax.device_put(
+                        w13_bias,
+                        Format(Layout((0, 1)),
+                               NamedSharding(self.mesh, P(None, "model"))))
+                    w2_bias = jax.device_put(
+                        w2_bias,
+                        Format(Layout((0, 1)),
+                               NamedSharding(self.mesh, P(None, None))))

         layer.w13_weight = Parameter(torch_view(w13_weight),
                                      requires_grad=False)
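The swigluoai preprocessing above rests on a simple identity: de-interleaving the fused gate/up weight once at load time lets the grouped-matmul output be split with a contiguous slice instead of a strided one, with identical results. A minimal sketch with toy shapes (a plain einsum stands in for the actual gmm kernel; none of these names come from the wheel):

import jax.numpy as jnp
import numpy as np

E, INTER, H, T = 2, 4, 8, 3  # experts, intermediate, hidden, tokens (toy sizes)
rng = np.random.default_rng(0)
w13 = jnp.asarray(rng.normal(size=(E, 2 * INTER, H)))  # gate/up rows interleaved
x = jnp.asarray(rng.normal(size=(T, H)))

# Interleaved split of the matmul output (strided slice, slow on TPU).
y = jnp.einsum("th,eoh->eto", x, w13)
gate_ref, up_ref = y[..., ::2], y[..., 1::2]

# De-interleave the weight once, then split the output down the middle.
w13_deint = jnp.concatenate([w13[:, ::2, :], w13[:, 1::2, :]], axis=1)
y2 = jnp.einsum("th,eoh->eto", x, w13_deint)
gate, up = y2[..., :INTER], y2[..., INTER:]

assert jnp.allclose(gate, gate_ref) and jnp.allclose(up, up_ref)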
@@ -360,9 +350,13 @@ class VllmUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
                 tokens=jax_view(x),
                 w1=jax_view(layer.w13_weight),
                 w2=jax_view(layer.w2_weight),
+                b1=jax_view(layer.w13_bias) if self.moe.has_bias else None,
+                b2=jax_view(layer.w2_bias) if self.moe.has_bias else None,
                 gating_output=jax_view(router_logits),
                 top_k=top_k,
                 ep_axis_name=self.ep_axis_name,
+                renormalize_topk_logits=renormalize,
+                act_fn=activation,
                 **self.block_size,
             )
         else:
tpu_inference/layers/vllm/sharding.py

@@ -19,6 +19,7 @@ from vllm.lora.layers.base_linear import BaseLinearLayerWithLoRA
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)

+from tpu_inference import envs
 from tpu_inference.logger import init_logger

 P = PartitionSpec
@@ -211,8 +212,7 @@ def _shard_module_to_tpu(model: torch.nn.Module, mesh: Mesh) -> None:
 def _sharded_device_put(tensor: jax.Array, sharding) -> jax.Array:
     if isinstance(tensor, tuple):
         return tuple(_sharded_device_put(t, sharding) for t in tensor)
-
-    multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower()
+    multihost_backend = envs.TPU_MULTIHOST_BACKEND
     if multihost_backend != "ray":
         return jax.device_put(tensor, sharding)
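The hunk above swaps a direct os.environ read for the centralized envs module. The actual tpu_inference/envs.py is not reproduced in this diff; the following is only a generic sketch of the pattern it implies, assuming a PEP 562 module-level __getattr__ that re-reads and normalizes the variable on each access:

import os

_ENV_READERS = {
    "TPU_MULTIHOST_BACKEND":
    lambda: os.environ.get("TPU_MULTIHOST_BACKEND", "").lower(),
}

def __getattr__(name: str):
    # envs.TPU_MULTIHOST_BACKEND resolves lazily through this hook,
    # so parsing and defaulting live in one place instead of at call sites.
    try:
        return _ENV_READERS[name]()
    except KeyError:
        raise AttributeError(name) from None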
tpu_inference/lora/torch_punica_tpu.py

@@ -239,7 +239,6 @@ class PunicaWrapperTPU(PunicaWrapperBase):
         lora_index_to_id: list[Optional[int]],
         max_loras: int,
         vocab_size: int,
-        extra_vocab_size: int,
     ):
         # Pad the prompt mapping to avoid running into recompiles on the TPU
         # TODO: Should this happen inside mapping internally? If so how can we

@@ -258,7 +257,7 @@ class PunicaWrapperTPU(PunicaWrapperBase):
             lora_index_to_id,
             max_loras,
             vocab_size,
-            extra_vocab_size
+            0,  # extra_vocab_size
             "cpu",
         )
         with torchax.default_env():
tpu_inference/models/common/model_loader.py

@@ -8,6 +8,9 @@ from jax.sharding import Mesh, NamedSharding, PartitionSpec
 from torchax.ops.mappings import j2t_dtype
 from transformers import PretrainedConfig
 from vllm.config import VllmConfig
+from vllm.model_executor.model_loader import get_model_loader
+from vllm.model_executor.model_loader.runai_streamer_loader import \
+    RunaiModelStreamerLoader
 from vllm.utils.func_utils import supports_kw

 from tpu_inference import envs
@@ -36,19 +39,17 @@ def _get_model_architecture(config: PretrainedConfig) -> nnx.Module:
     from tpu_inference.models.jax.llama3 import LlamaForCausalLM
     from tpu_inference.models.jax.llama4 import Llama4ForCausalLM
     from tpu_inference.models.jax.llama_eagle3 import EagleLlama3ForCausalLM
-    from tpu_inference.models.jax.
-    from tpu_inference.models.jax.qwen2 import Qwen2ForCausalLM
+    from tpu_inference.models.jax.llama_guard_4 import LlamaGuard4ForCausalLM
     from tpu_inference.models.jax.qwen2_5_vl import \
         Qwen2_5_VLForConditionalGeneration
     from tpu_inference.models.jax.qwen3 import Qwen3ForCausalLM
     _MODEL_REGISTRY["Llama4ForCausalLM"] = Llama4ForCausalLM
     _MODEL_REGISTRY["DeepseekV3ForCausalLM"] = DeepSeekV3
     _MODEL_REGISTRY["LlamaForCausalLM"] = LlamaForCausalLM
-    _MODEL_REGISTRY["
+    _MODEL_REGISTRY["Llama4ForConditionalGeneration"] = LlamaGuard4ForCausalLM
     _MODEL_REGISTRY["Qwen3ForCausalLM"] = Qwen3ForCausalLM
     _MODEL_REGISTRY[
         "Qwen2_5_VLForConditionalGeneration"] = Qwen2_5_VLForConditionalGeneration
-    _MODEL_REGISTRY["Phi3ForCausalLM"] = Phi3ForCausalLM
     _MODEL_REGISTRY["Eagle3LlamaForCausalLM"] = EagleLlama3ForCausalLM
     _MODEL_REGISTRY["GptOssForCausalLM"] = GptOss
@@ -57,8 +58,10 @@ def _get_model_architecture(config: PretrainedConfig) -> nnx.Module:
     if arch in _MODEL_REGISTRY:
         return _MODEL_REGISTRY[arch]
     raise UnsupportedArchitectureError(
-        f"Model architectures {architectures}
-
+        f"Model architectures {architectures} not "
+        "registered in tpu-inference. Falling back to vLLM-native "
+        f"Pytorch definition. JAX-native architectures: {list(_MODEL_REGISTRY.keys())}"
+    )


 def _get_nnx_model(
@@ -177,7 +180,23 @@ def _get_nnx_model(
     # the model creation again, otherwise the model forward will have
     # non-trivial overhead in PjitFunction.
     with mesh:
-
+        loader = get_model_loader(vllm_config.load_config)
+        if isinstance(loader, RunaiModelStreamerLoader):
+            model_weights = vllm_config.model_config.model
+            if hasattr(vllm_config.model_config, "model_weights"):
+                model_weights = vllm_config.model_config.model_weights
+            weights_iterator = loader._get_weights_iterator(
+                model_weights, vllm_config.model_config.revision)
+            # We set the weights iterator at runtime, to prevent having to change
+            # every model's load_weights signature. This also prevents us from hitting
+            # a TypeError at runtime if you use the RunaiModelStreamerLoader with any
+            # flax_nnx model whose load_weights function does not accept the
+            # weights_iterator keyword argument.
+            vllm_config.model_config.model_weights_iterator = weights_iterator
+            model.load_weights(rng)
+            del vllm_config.model_config.model_weights_iterator
+        else:
+            model.load_weights(rng)
         jit_model = create_jit_model(
             model,
             use_qwix_on_abstract_model=should_apply_qwix_on_abstract_model)
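The comment in the hunk above describes the design: the streaming iterator is attached to model_config at call time so every model's load_weights can keep a fixed signature. A self-contained toy sketch of that pattern (hypothetical names, not the wheel's model code):

from types import SimpleNamespace

def load_weights(model_config):
    # The streaming path is discovered via an optional attribute rather than
    # an extra keyword argument on every model.
    it = getattr(model_config, "model_weights_iterator", None)
    if it is None:
        return "loaded from the default checkpoint path"
    return f"streamed {sum(1 for _ in it)} tensors"

cfg = SimpleNamespace()
print(load_weights(cfg))                                 # default path
cfg.model_weights_iterator = iter([("w", 0), ("b", 1)])  # injected by the loader
print(load_weights(cfg))                                 # streamed 2 tensors
del cfg.model_weights_iterator                           # cleaned up afterwards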
@@ -217,7 +236,7 @@ def get_flax_model(
             hidden_states_sharding,  # aux hidden states
         ),
         donate_argnums=2,  # 0 is graphdef, 1 is state, 2 is kv_cache
-        static_argnums=
+        static_argnums=7,  # 7 is layer_name_to_kvcache_index
     )
     def run_model(graphdef, state, *args):
         model = nnx.merge(graphdef, state)
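For context on the static_argnums change above: a static argument must be hashable and becomes part of the compilation cache key, so a new value triggers a retrace. A standalone JAX illustration, unrelated to the runner code:

import functools

import jax
import jax.numpy as jnp

@functools.partial(jax.jit, static_argnums=1)  # argument 1 (`factor`) is static
def scale(x, factor):
    return x * factor

print(scale(jnp.arange(3), 2))  # traces and compiles for factor=2
print(scale(jnp.arange(3), 3))  # a different static value triggers a retrace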
@@ -242,10 +261,11 @@ def get_flax_model(
         model = nnx.merge(graphdef, state)
         return model.get_multimodal_embeddings(image_grid_thw, **kwargs)

+    embed_sharding = NamedSharding(mesh, PartitionSpec(None))
     # This function will calculates the embeddings of input texts and then merge with the image embeddings
     @functools.partial(
         jax.jit,
-        out_shardings=(
+        out_shardings=(embed_sharding),
     )
     def run_get_input_embeddings(graphdef, state, *args, **kwargs):
         model = nnx.merge(graphdef, state)
|
|
|
325
345
|
# Convert the error message to a string to check its contents
|
|
326
346
|
error_msg = str(e)
|
|
327
347
|
|
|
328
|
-
logger.warning(
|
|
329
|
-
|
|
348
|
+
logger.warning(error_msg)
|
|
349
|
+
|
|
330
350
|
# Fall back to the vLLM model and updating the dtype accordingly
|
|
331
351
|
vllm_config.model_config.dtype = j2t_dtype(
|
|
332
352
|
vllm_config.model_config.dtype.dtype)
|
|
@@ -420,6 +440,17 @@ def register_model(arch: str, model: Any) -> None:
             "This is a JAX model and does not implement the PyTorch forward method."
         )

+    # Same as `forward`, this is a dummy method to satisfy vLLM's type checks.
+    def unimplemented_get_input_embeddings(
+        self,
+        input_ids: "torch.Tensor",
+        positions: "torch.Tensor",
+        inputs_embeds: Optional["torch.Tensor"] = None,
+    ) -> "torch.Tensor":
+        raise NotImplementedError(
+            "This is a JAX model and does not implement the PyTorch get_input_embeddings method."
+        )
+
     # We need a custom __init__ that only calls torch.nn.Module's init,
     # to avoid triggering JAX logic when vLLM inspects the class.
     def wrapper_init(self, *args, **kwargs):
@@ -433,6 +464,7 @@ def register_model(arch: str, model: Any) -> None:
         {
             "__init__": wrapper_init,
             "forward": unimplemented_forward,
+            "get_input_embeddings": unimplemented_get_input_embeddings,
             # Prevent vLLM from trying to load weights into this dummy class.
             "load_weights": lambda self, *args, **kwargs: None,
         })
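The registry entry above completes the dummy wrapper class that register_model builds dynamically. A reduced, self-contained sketch of that trick with toy names (it omits the custom __init__ and the actual registration step shown in the diff):

import torch

def _unimplemented_forward(self, *args, **kwargs):
    raise NotImplementedError("JAX model: no PyTorch forward")

def _unimplemented_get_input_embeddings(self, *args, **kwargs):
    raise NotImplementedError("JAX model: no PyTorch get_input_embeddings")

DummyWrapper = type(
    "DummyWrapper",
    (torch.nn.Module,),
    {
        "forward": _unimplemented_forward,
        "get_input_embeddings": _unimplemented_get_input_embeddings,
        # Weight loading is a no-op; the real parameters live on the JAX side.
        "load_weights": lambda self, *args, **kwargs: None,
    },
)

m = DummyWrapper()       # constructible, satisfies vLLM's isinstance checks
print(m.load_weights())  # None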
tpu_inference/models/jax/llama3.py

@@ -368,7 +368,8 @@ class LlamaForCausalLM(nnx.Module):
             "lm_head": "model.lm_head",
         })

-        metadata_map = get_default_maps(self.vllm_config
+        metadata_map = get_default_maps(self.vllm_config.model_config,
+                                        self.mesh, mappings)
         load_hf_weights(vllm_config=self.vllm_config,
                         model=self,
                         metadata_map=metadata_map,
tpu_inference/models/jax/llama_eagle3.py

@@ -194,13 +194,12 @@ class Eagle3LlamaModel(nnx.Module):

 def update_reshape_map_for_eagle3(vllm_config: VllmConfig,
                                   metadata_map: MetadataMap):
-    model_config = vllm_config.
+    model_config = vllm_config.speculative_config.draft_model_config
     hf_config = model_config.hf_config

     num_heads = hf_config.num_attention_heads
     num_kv_heads = hf_config.num_key_value_heads
-    hidden_size =
-
+    hidden_size = hf_config.hidden_size
     head_dim_original = model_config.get_head_size()

     metadata_map.reshape_map.update({
@@ -305,6 +304,8 @@ class EagleLlama3ForCausalLM(nnx.Module):
             "fc": "model.fc.kernel",
             "lm_head": "lm_head.kernel",
             "d2t": "draft_id_to_target_id",
+            "embed_tokens":
+            "model.embed_tokens.embedding",  # Some checkpoints need this
         }

         # Define keys to keep in original dtype (e.g., float32 for stability)
@@ -312,7 +313,9 @@ class EagleLlama3ForCausalLM(nnx.Module):
             r".*d2t.*",
         ]

-        metadata_map = get_default_maps(
+        metadata_map = get_default_maps(
+            self.vllm_config.speculative_config.draft_model_config, self.mesh,
+            mappings)

         update_reshape_map_for_eagle3(self.vllm_config, metadata_map)
@@ -324,7 +327,7 @@ class EagleLlama3ForCausalLM(nnx.Module):
             is_draft_model=True,
             keep_original_dtype_keys_regex=keep_original_dtype_keys_regex)

-        # If the embedding is not initialized, initialize it with a
+        # If the embedding is not initialized, initialize it with a dummy array here to pass jit compilation. The real weights will be shared from the target model in eagle3 class.
         if isinstance(self.model.embed_tokens.embedding.value,
                       jax.ShapeDtypeStruct):
             self.model.embed_tokens.embedding.value = jnp.zeros(