tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.12.0.dev20251213__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of tpu-inference might be problematic.
- tests/kernels/fused_moe_v1_test.py +303 -34
- tests/kernels/mla_v1_test.py +129 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +3 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +3 -1
- tests/lora/test_layers.py +4 -1
- tests/lora/test_lora_perf.py +53 -0
- tests/test_envs.py +110 -12
- tests/test_quantization.py +3 -0
- tests/test_utils.py +1 -2
- tpu_inference/distributed/tpu_connector.py +1 -1
- tpu_inference/envs.py +92 -8
- tpu_inference/executors/ray_distributed_executor.py +5 -1
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/fused_moe/v1/kernel.py +712 -143
- tpu_inference/kernels/mla/v1/kernel.py +98 -120
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +82 -32
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +146 -85
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +2 -1
- tpu_inference/layers/common/attention_interface.py +7 -1
- tpu_inference/layers/common/sharding.py +11 -7
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +232 -64
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +5 -5
- tpu_inference/layers/vllm/fused_moe.py +170 -208
- tpu_inference/layers/vllm/linear_common.py +43 -21
- tpu_inference/layers/vllm/quantization/common.py +11 -6
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +4 -3
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +74 -65
- tpu_inference/layers/vllm/quantization/mxfp4.py +140 -94
- tpu_inference/layers/vllm/quantization/unquantized.py +103 -80
- tpu_inference/models/common/model_loader.py +78 -22
- tpu_inference/models/jax/deepseek_v3.py +185 -64
- tpu_inference/models/jax/gpt_oss.py +3 -3
- tpu_inference/models/jax/llama_eagle3.py +4 -5
- tpu_inference/models/jax/qwen2_5_vl.py +161 -47
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +7 -8
- tpu_inference/models/jax/utils/weight_utils.py +203 -155
- tpu_inference/models/vllm/vllm_model_wrapper.py +11 -5
- tpu_inference/platforms/tpu_platform.py +29 -48
- tpu_inference/runner/compilation_manager.py +112 -46
- tpu_inference/runner/kv_cache.py +40 -20
- tpu_inference/runner/kv_cache_manager.py +40 -31
- tpu_inference/runner/persistent_batch_manager.py +40 -2
- tpu_inference/runner/structured_decoding_manager.py +2 -3
- tpu_inference/runner/tpu_runner.py +94 -51
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/jax/eagle3.py +71 -22
- tpu_inference/utils.py +41 -14
- tpu_inference/worker/tpu_worker.py +43 -45
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/METADATA +8 -9
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/RECORD +59 -58
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/top_level.txt +0 -0
tpu_inference/layers/vllm/quantization/unquantized.py

@@ -1,4 +1,4 @@
-from typing import Any,
+from typing import Any, Optional, Union
 
 import jax
 import jax.numpy as jnp
@@ -25,7 +25,7 @@ from tpu_inference import envs
 from tpu_inference.kernels.fused_moe.v1.kernel import fused_ep_moe
 from tpu_inference.layers.common.quant_methods import (UNQUANTIZED,
                                                         get_tpu_quant_method)
-from tpu_inference.layers.vllm.fused_moe import
+from tpu_inference.layers.vllm.fused_moe import fused_moe_func
 from tpu_inference.layers.vllm.linear_common import (
     reorder_concatenated_tensor_for_sharding,
     slice_sharded_tensor_for_concatenation, torch_to_jax_param)
@@ -36,6 +36,10 @@ P = PartitionSpec
 logger = init_logger(__name__)
 
 
+def align_to(a, b):
+    return (a + b - 1) // b * b
+
+
 @register_quantization_config(get_tpu_quant_method(UNQUANTIZED))
 class VllmUnquantizedConfig(QuantizationConfig, JaxCommonConfig):
 
@@ -108,6 +112,8 @@ class VllmUnquantizedLinearMethod(UnquantizedLinearMethod):
               layer: torch.nn.Module,
               x: torch.Tensor,
               bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        assert isinstance(layer, LinearBase)
+
         with jax.named_scope(layer._get_name()):
             if in_sharding := self.jax_config.get_input_sharding(x):
                 x.shard_(NamedSharding(self.jax_config.mesh, in_sharding))
@@ -166,18 +172,18 @@ class VllmUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
                  ep_axis_name: str = 'model'):
         super().__init__(moe)
         self.mesh = mesh
-        self.use_kernel = envs.USE_MOE_EP_KERNEL
+        self.use_kernel = envs.USE_MOE_EP_KERNEL and moe.use_ep
         self.ep_axis_name = ep_axis_name
         # TODO: Use autotune table once we have it.
         self.block_size = {
-            "bt":
-            "bf":
-            "bd1":
-            "bd2":
-            "btc":
-            "bfc":
-            "bd1c":
-            "bd2c":
+            "bt": 64,
+            "bf": 1024,
+            "bd1": 1536,
+            "bd2": 1536,
+            "btc": 64,
+            "bfc": 1024,
+            "bd1c": 1536,
+            "bd2c": 1536,
         }
 
     def select_gemm_impl(
@@ -194,6 +200,8 @@ class VllmUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
         w13_weight = t2j(layer.w13_weight, use_dlpack=False)
         w2_weight = t2j(layer.w2_weight, use_dlpack=False)
 
+        num_experts, hidden_size, intermediate_size = w2_weight.shape
+
         if self.moe.has_bias:
             w13_bias = t2j(layer.w13_bias, use_dlpack=False)
             w2_bias = t2j(layer.w2_bias, use_dlpack=False)
@@ -212,7 +220,7 @@ class VllmUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
             w3_bias = w13_bias[:, 1::2]
             w13_bias = jnp.concat([w1_bias, w3_bias], axis=1)
 
-        if self.use_kernel
+        if self.use_kernel:
             # Kernel expects:
             # w13: (num_experts, 2, hidden_size, intermediate_size)
             # w2: (num_experts, intermediate_size, hidden_size)
@@ -223,65 +231,82 @@ class VllmUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
             intermediate_size = w13_weight.shape[1] // 2
             hidden_size = w13_weight.shape[2]
 
-
-
-
-
+            padded_intermediate_size = align_to(intermediate_size, 256)
+            padded_hidden_size = align_to(hidden_size, 256)
+
+            w13_weight = w13_weight.reshape(num_experts, 2, intermediate_size,
+                                            hidden_size)
+            w13_weight = jnp.transpose(w13_weight, (0, 1, 3, 2))
 
             # Transpose w2_weight to (num_experts, intermediate_size, hidden_size)
-
+            w2_weight = jnp.transpose(w2_weight, (0, 2, 1))
+
+            w13_weight = jnp.pad(
+                w13_weight,
+                ((0, 0), (0, 0), (0, padded_hidden_size - hidden_size),
+                 (0, padded_intermediate_size - intermediate_size)),
+                constant_values=0)
+
+            w2_weight = jnp.pad(
+                w2_weight,
+                ((0, 0), (0, padded_intermediate_size - intermediate_size),
+                 (0, padded_hidden_size - hidden_size)),
+                constant_values=0)
 
             # Apply EP sharding
+            ep_sharding = NamedSharding(self.mesh, P("model"))
+
             w13_weight = jax.device_put(
-
+                w13_weight,
                 Format(Layout((0, 1, 2, 3)),
                        NamedSharding(self.mesh, P("model", None, None, None))))
             w2_weight = jax.device_put(
-
+                w2_weight,
                 Format(Layout((0, 1, 2)),
                        NamedSharding(self.mesh, P("model", None, None))))
 
             if self.moe.has_bias:
-                w13_bias = w13_bias.reshape(
+                w13_bias = w13_bias.astype(jnp.float32).reshape(
+                    num_experts, 2, 1, intermediate_size)
+                w2_bias = w2_bias.astype(jnp.float32).reshape(
+                    num_experts, 1, hidden_size)
+
+                w13_bias = jnp.pad(
+                    w13_bias,
+                    ((0, 0), (0, 0), (0, 0),
+                     (0, padded_intermediate_size - intermediate_size)),
+                    constant_values=0)
+
+                w2_bias = jnp.pad(w2_bias,
+                                  ((0, 0), (0, 0),
+                                   (0, padded_hidden_size - hidden_size)),
+                                  constant_values=0)
 
                 # Apply EP sharding
                 w13_bias = jax.device_put(
-                    w13_bias,
-                    Format(Layout((0, 1, 2)),
-                           NamedSharding(self.mesh, P("model", None, None))))
+                    w13_bias, Format(Layout((0, 1, 2, 3)), ep_sharding))
                 w2_bias = jax.device_put(
-                    w2_bias,
-                    Format(Layout((0, 1)),
-                           NamedSharding(self.mesh, P("model", None))))
-
+                    w2_bias, Format(Layout((0, 1, 2)), ep_sharding))
         else:
-
+
             if layer.use_ep:
+                ep_sharding = NamedSharding(self.mesh, P("model"))
                 w13_weight = jax.device_put(
-                    w13_weight,
-                    Format(Layout((0, 1, 2)),
-                           NamedSharding(self.mesh, P("model", None, None))))
+                    w13_weight, Format(Layout((0, 1, 2)), ep_sharding))
                 w2_weight = jax.device_put(
-                    w2_weight,
-                    Format(Layout((0, 1, 2)),
-                           NamedSharding(self.mesh, P("model", None, None))))
+                    w2_weight, Format(Layout((0, 1, 2)), ep_sharding))
 
                 if self.moe.has_bias:
                     w13_bias = jax.device_put(
-                        w13_bias,
-                        Format(Layout((0, 1)),
-                               NamedSharding(self.mesh, P("model", None))))
+                        w13_bias, Format(Layout((0, 1)), ep_sharding))
                     w2_bias = jax.device_put(
-                        w2_bias,
-                        Format(Layout((0, 1)),
-                               NamedSharding(self.mesh, P("model", None))))
+                        w2_bias, Format(Layout((0, 1)), ep_sharding))
 
             else:
-                intermediate_size = w13_weight.shape[1] // 2
-                assert intermediate_size == w2_weight.shape[-1]
                 output_sizes = [intermediate_size, intermediate_size]
                 n_shards = self.mesh.shape["model"]
                 assert intermediate_size % n_shards == 0
+
                 w13_weight = reorder_concatenated_tensor_for_sharding(
                     w13_weight, output_sizes, n_shards, dim=1)
                 w13_weight = jax.device_put(
@@ -319,56 +344,54 @@ class VllmUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
         layer: torch.nn.Module,
         x: torch.Tensor,
         router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        use_grouped_topk: bool = False,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        global_num_experts: int = -1,
-        expert_map: Optional[torch.Tensor] = None,
-        custom_routing_function: Optional[Callable] = None,
-        scoring_func: str = "softmax",
-        routed_scaling_factor: float = 1.0,
-        e_score_correction_bias: Optional[torch.Tensor] = None,
-        apply_router_weight_on_input: bool = False,
-        activation: str = "silu",
-        enable_eplb: bool = False,
-        expert_load_view: Optional[torch.Tensor] = None,
-        logical_to_physical_map: Optional[torch.Tensor] = None,
-        logical_replica_count: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
         assert isinstance(layer, FusedMoE)
-        if scoring_func != "softmax":
+        if layer.scoring_func != "softmax":
             raise NotImplementedError(
                 "Only softmax is supported for scoring_func")
 
-
+        x = jax_view(x)
+        w13_weight = jax_view(layer.w13_weight)
+        w2_weight = jax_view(layer.w2_weight)
+        w13_bias = w2_bias = None
+        if self.moe.has_bias:
+            w13_bias = jax_view(layer.w13_bias)
+            w2_bias = jax_view(layer.w2_bias)
+        gating_output = jax_view(router_logits)
+
+        if self.use_kernel:
+            actual_hidden_size = x.shape[-1]
+            padded_hidden_size = align_to(actual_hidden_size, 256)
+            x = jnp.pad(x,
+                        ((0, 0), (0, padded_hidden_size - actual_hidden_size)),
+                        constant_values=0)
             output = fused_ep_moe(
                 mesh=self.mesh,
-                tokens=
-                w1=
-                w2=
-
-
+                tokens=x,
+                w1=w13_weight,
+                w2=w2_weight,
+                b1=w13_bias,
+                b2=w2_bias,
+                gating_output=gating_output,
+                top_k=layer.top_k,
                 ep_axis_name=self.ep_axis_name,
+                renormalize_topk_logits=layer.renormalize,
+                act_fn=layer.activation,
                 **self.block_size,
-            )
+            )[:, :actual_hidden_size]
         else:
-
-
-
-
-
-
-
-
-
-                global_num_experts=global_num_experts,
-                renormalize=renormalize,
-                reduce_results=layer.reduce_results,
+            output = fused_moe_func(
+                hidden_states=x,
+                w1=w13_weight,
+                w2=w2_weight,
+                w1_bias=w13_bias,
+                w2_bias=w2_bias,
+                gating_output=gating_output,
+                topk=layer.top_k,
+                renormalize=layer.renormalize,
                 mesh=self.mesh,
                 use_ep=layer.use_ep,
-                activation=activation,
+                activation=layer.activation,
             )
 
         return torch_view(output)
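The fused-MoE kernel path above rounds the hidden and intermediate dimensions up to multiples of 256 with the new align_to helper, zero-pads the expert weights (and, in apply, the activations) to those sizes, and slices the extra hidden columns off the kernel output. The snippet below is a minimal standalone sketch of that arithmetic; align_to matches the helper added in the diff, while pad_w2_for_kernel and the example shapes are hypothetical and exist only to illustrate the pattern.

import jax.numpy as jnp


def align_to(a, b):
    # Round a up to the nearest multiple of b (same helper as in the diff).
    return (a + b - 1) // b * b


def pad_w2_for_kernel(w2):
    # Hypothetical helper: zero-pad a (num_experts, intermediate, hidden)
    # expert weight so both trailing dims become multiples of 256.
    _, intermediate_size, hidden_size = w2.shape
    padded_i = align_to(intermediate_size, 256)
    padded_h = align_to(hidden_size, 256)
    return jnp.pad(w2,
                   ((0, 0), (0, padded_i - intermediate_size),
                    (0, padded_h - hidden_size)),
                   constant_values=0)


w2 = jnp.zeros((8, 1400, 2000), dtype=jnp.bfloat16)
print(pad_w2_for_kernel(w2).shape)  # (8, 1536, 2048)

The same rounding is applied to the token activations before fused_ep_moe is called, which is why the kernel output is sliced back to the actual hidden size in apply.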
tpu_inference/models/common/model_loader.py

@@ -5,9 +5,11 @@ import jax
 import torch
 from flax import nnx
 from jax.sharding import Mesh, NamedSharding, PartitionSpec
-from torchax.ops.mappings import j2t_dtype
 from transformers import PretrainedConfig
 from vllm.config import VllmConfig
+from vllm.model_executor.model_loader import get_model_loader
+from vllm.model_executor.model_loader.runai_streamer_loader import \
+    RunaiModelStreamerLoader
 from vllm.utils.func_utils import supports_kw
 
 from tpu_inference import envs
@@ -16,11 +18,17 @@ from tpu_inference.logger import init_logger
 from tpu_inference.models.jax.utils.quantization.quantization_utils import (
     apply_qwix_on_abstract_model, apply_qwix_quantization,
     load_random_weights_into_qwix_abstract_model)
+from tpu_inference.utils import to_jax_dtype, to_torch_dtype
 
 logger = init_logger(__name__)
 
 _MODEL_REGISTRY = {}
 
+# List of architectures that are preferred to use "vllm" implementation over
+# "flax_nnx" implementation due to various factors such as performance.
+_VLLM_PREFERRED_ARCHITECTURES: frozenset[str] = frozenset(
+    {"GptOssForCausalLM"})
+
 
 class UnsupportedArchitectureError(ValueError):
     """Raised when a model architecture is not supported in the registry."""
@@ -177,7 +185,23 @@ def _get_nnx_model(
     # the model creation again, otherwise the model forward will have
     # non-trivial overhead in PjitFunction.
     with mesh:
-
+        loader = get_model_loader(vllm_config.load_config)
+        if isinstance(loader, RunaiModelStreamerLoader):
+            model_weights = vllm_config.model_config.model
+            if hasattr(vllm_config.model_config, "model_weights"):
+                model_weights = vllm_config.model_config.model_weights
+            weights_iterator = loader._get_weights_iterator(
+                model_weights, vllm_config.model_config.revision)
+            # We set the weights iterator at runtime, to prevent having to change
+            # every model's load_weights signature. This also prevents us from hitting
+            # a TypeError at runtime if you use the RunaiModelStreamerLoader with any
+            # flax_nnx model whose load_weights function does not accept the
+            # weights_iterator keyword argument.
+            vllm_config.model_config.model_weights_iterator = weights_iterator
+            model.load_weights(rng)
+            del vllm_config.model_config.model_weights_iterator
+        else:
+            model.load_weights(rng)
         jit_model = create_jit_model(
             model,
             use_qwix_on_abstract_model=should_apply_qwix_on_abstract_model)
@@ -191,6 +215,9 @@ def get_flax_model(
     mesh: Mesh,
     is_draft_model: bool = False,
 ) -> nnx.Module:
+    model_dtype = to_jax_dtype(vllm_config.model_config.dtype)
+    vllm_config.model_config.dtype = model_dtype
+
     if is_draft_model:
         model_class = _get_model_architecture(
             vllm_config.speculative_config.draft_model_config.hf_config)
@@ -199,7 +226,9 @@ def get_flax_model(
             vllm_config.model_config.hf_config)
     jit_model = _get_nnx_model(model_class, vllm_config, rng, mesh)
     kv_cache_sharding = NamedSharding(
-        mesh,
+        mesh,
+        PartitionSpec(ShardingAxisName.ATTN_DATA, None,
+                      ShardingAxisName.ATTN_HEAD))
     hidden_states_sharding = NamedSharding(mesh,
                                            PartitionSpec(
                                                ShardingAxisName.ATTN_DATA,
@@ -217,14 +246,17 @@ def get_flax_model(
             hidden_states_sharding,  # aux hidden states
         ),
         donate_argnums=2,  # 0 is graphdef, 1 is state, 2 is kv_cache
-        static_argnums=
+        static_argnums=(
+            7, 10, 11
+        ),  #7 is layer_name_to_kvcache_index, 10 is is_first_rank, 11 is is_last_rank
     )
     def run_model(graphdef, state, *args):
         model = nnx.merge(graphdef, state)
         return model(*args)
 
     logits_sharding = NamedSharding(
-        mesh,
+        mesh,
+        PartitionSpec(ShardingAxisName.MLP_DATA, ShardingAxisName.MLP_TENSOR))
 
     @functools.partial(
         jax.jit,
@@ -293,6 +325,8 @@ def get_vllm_model(
     rng: jax.Array,
    mesh: Mesh,
 ):
+    model_dtype = to_torch_dtype(vllm_config.model_config.dtype)
+    vllm_config.model_config.dtype = model_dtype
     from tpu_inference.models.vllm.vllm_model_wrapper import VllmModelWrapper
 
     model = VllmModelWrapper(
@@ -318,24 +352,34 @@ def get_model(
     impl = envs.MODEL_IMPL_TYPE
     logger.info(f"Loading model with MODEL_IMPL_TYPE={impl}")
 
-    if impl == "
-
-
-
-
-
-
-
-
-
-
-
+    if impl == "auto":
+        # Resolve "auto" based on architecture
+        architectures = getattr(vllm_config.model_config.hf_config,
+                                "architectures", [])
+        assert len(architectures) == 1, (
+            f"Expected exactly one architecture, got {len(architectures)}: "
+            f"{architectures}")
+        arch = architectures[0]
+        impl = "vllm" if arch in _VLLM_PREFERRED_ARCHITECTURES else "flax_nnx"
+        logger.info(f"Resolved MODEL_IMPL_TYPE 'auto' to '{impl}'")
+
+    match impl:
+        case "flax_nnx":
+            try:
+                # Try to load the flax model first
+                return get_flax_model(vllm_config, rng, mesh, is_draft_model)
+            except UnsupportedArchitectureError as e:
+                # Convert the error message to a string to check its contents
+                error_msg = str(e)
+
+                logger.warning(error_msg)
+
+                # Fall back to the vLLM model and updating the dtype accordingly
+                return get_vllm_model(vllm_config, rng, mesh)
+        case "vllm":
             return get_vllm_model(vllm_config, rng, mesh)
-
-
-    else:
-        raise NotImplementedError("Unsupported MODEL_IMPL_TYPE")
+        case _:
+            raise NotImplementedError(f"Unsupported MODEL_IMPL_TYPE: {impl}")
 
 
 def _validate_model_interface(model: Any) -> None:
@@ -421,6 +465,17 @@ def register_model(arch: str, model: Any) -> None:
             "This is a JAX model and does not implement the PyTorch forward method."
         )
 
+    # Same as `forward`, this is a dummy method to satisfy vLLM's type checks.
+    def unimplemented_get_input_embeddings(
+        self,
+        input_ids: "torch.Tensor",
+        positions: "torch.Tensor",
+        inputs_embeds: Optional["torch.Tensor"] = None,
+    ) -> "torch.Tensor":
+        raise NotImplementedError(
+            "This is a JAX model and does not implement the PyTorch get_input_embeddings method."
+        )
+
     # We need a custom __init__ that only calls torch.nn.Module's init,
     # to avoid triggering JAX logic when vLLM inspects the class.
     def wrapper_init(self, *args, **kwargs):
@@ -434,6 +489,7 @@ def register_model(arch: str, model: Any) -> None:
         {
             "__init__": wrapper_init,
             "forward": unimplemented_forward,
+            "get_input_embeddings": unimplemented_get_input_embeddings,
             # Prevent vLLM from trying to load weights into this dummy class.
             "load_weights": lambda self, *args, **kwargs: None,
         })
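The get_model hunk above replaces the old if/elif dispatch with a match statement and adds an "auto" implementation type that is resolved from the model architecture before dispatching. Below is a small self-contained sketch of that resolution rule; resolve_model_impl is a hypothetical helper written only for illustration, while the architecture set mirrors _VLLM_PREFERRED_ARCHITECTURES from the diff.

_VLLM_PREFERRED_ARCHITECTURES = frozenset({"GptOssForCausalLM"})


def resolve_model_impl(impl, architectures):
    # Hypothetical helper mirroring the "auto" resolution added to get_model().
    if impl != "auto":
        return impl
    assert len(architectures) == 1, architectures
    arch = architectures[0]
    # Architectures in the preferred set use the "vllm" (PyTorch wrapper)
    # implementation; everything else defaults to "flax_nnx".
    return "vllm" if arch in _VLLM_PREFERRED_ARCHITECTURES else "flax_nnx"


print(resolve_model_impl("auto", ["GptOssForCausalLM"]))  # vllm
print(resolve_model_impl("auto", ["LlamaForCausalLM"]))   # flax_nnx

In the wheel itself the resolved value then feeds the match impl: dispatch shown in the hunk, with the flax_nnx branch falling back to get_vllm_model when an UnsupportedArchitectureError is raised.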