tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.12.0.dev20251213__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of tpu-inference might be problematic.
- tests/kernels/fused_moe_v1_test.py +303 -34
- tests/kernels/mla_v1_test.py +129 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +3 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +3 -1
- tests/lora/test_layers.py +4 -1
- tests/lora/test_lora_perf.py +53 -0
- tests/test_envs.py +110 -12
- tests/test_quantization.py +3 -0
- tests/test_utils.py +1 -2
- tpu_inference/distributed/tpu_connector.py +1 -1
- tpu_inference/envs.py +92 -8
- tpu_inference/executors/ray_distributed_executor.py +5 -1
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/fused_moe/v1/kernel.py +712 -143
- tpu_inference/kernels/mla/v1/kernel.py +98 -120
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +82 -32
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +146 -85
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +2 -1
- tpu_inference/layers/common/attention_interface.py +7 -1
- tpu_inference/layers/common/sharding.py +11 -7
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +232 -64
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +5 -5
- tpu_inference/layers/vllm/fused_moe.py +170 -208
- tpu_inference/layers/vllm/linear_common.py +43 -21
- tpu_inference/layers/vllm/quantization/common.py +11 -6
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +4 -3
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +74 -65
- tpu_inference/layers/vllm/quantization/mxfp4.py +140 -94
- tpu_inference/layers/vllm/quantization/unquantized.py +103 -80
- tpu_inference/models/common/model_loader.py +78 -22
- tpu_inference/models/jax/deepseek_v3.py +185 -64
- tpu_inference/models/jax/gpt_oss.py +3 -3
- tpu_inference/models/jax/llama_eagle3.py +4 -5
- tpu_inference/models/jax/qwen2_5_vl.py +161 -47
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +7 -8
- tpu_inference/models/jax/utils/weight_utils.py +203 -155
- tpu_inference/models/vllm/vllm_model_wrapper.py +11 -5
- tpu_inference/platforms/tpu_platform.py +29 -48
- tpu_inference/runner/compilation_manager.py +112 -46
- tpu_inference/runner/kv_cache.py +40 -20
- tpu_inference/runner/kv_cache_manager.py +40 -31
- tpu_inference/runner/persistent_batch_manager.py +40 -2
- tpu_inference/runner/structured_decoding_manager.py +2 -3
- tpu_inference/runner/tpu_runner.py +94 -51
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/jax/eagle3.py +71 -22
- tpu_inference/utils.py +41 -14
- tpu_inference/worker/tpu_worker.py +43 -45
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/METADATA +8 -9
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/RECORD +59 -58
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/top_level.txt +0 -0
tpu_inference/models/jax/utils/weight_utils.py

@@ -13,10 +13,12 @@ from typing import Any, Optional
 import jax
 import jax.numpy as jnp
 import torch
+import torchax
 from flax import nnx
 from jax.sharding import Mesh, NamedSharding
 from jax.sharding import PartitionSpec as P
 from safetensors import safe_open
+from vllm.config import VllmConfig
 
 from tpu_inference import envs, utils
 from tpu_inference.logger import init_logger

@@ -65,7 +67,13 @@ def transpose_params(param_key: str, param_tensor: jax.Array, transpose_map):
 def reshape_params(param_key: str, param_tensor: jax.Array, shape_map):
     for key, new_shape in shape_map.items():
         if key in param_key:
-            return jnp.reshape(param_tensor, new_shape)
+            try:
+                #TODO:(gpolovets) Add validation on whether reshape preserves data layout.
+                return jnp.reshape(param_tensor, new_shape)
+            except TypeError:
+                raise TypeError(
+                    f"Cannot reshape for key={key}, new_shape={new_shape}, param_shape={param_tensor.shape}"
+                )
     return param_tensor  # Base case / no-op
 
 
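In effect, reshape_params now converts JAX's generic reshape failure into an error that names the offending key and both shapes. A minimal, hypothetical usage sketch (it assumes reshape_params from the hunk above is in scope; the shape_map entries are made up):

import jax.numpy as jnp

# Hypothetical map: q_proj weights get reshaped to (hidden, num_heads, head_dim).
shape_map = {"q_proj": (4096, 32, 128)}
param = jnp.zeros((4096, 4096))

# Matching key with a size-compatible shape: returns the reshaped tensor.
reshape_params("model.layers.0.self_attn.q_proj", param, shape_map)

# A size-incompatible shape now raises a TypeError that names the key and both
# shapes, instead of JAX's generic reshape error.
try:
    reshape_params("model.layers.0.self_attn.q_proj", param, {"q_proj": (4096, 32, 64)})
except TypeError as err:
    print(err)  # Cannot reshape for key=q_proj, new_shape=(4096, 32, 64), param_shape=(4096, 4096)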
@@ -265,15 +273,15 @@ def get_default_maps(model_config, mesh: Mesh,
         bias_pad_map=bias_pad_keys)
 
 
-def _load_hf_weights_on_thread(vllm_config,
-
-
-
-
-
-
-
-
+def _load_and_shard_weight(vllm_config,
+                           params: nnx.State,
+                           shardings: Any,
+                           metadata_map: MetadataMap,
+                           mesh: Mesh,
+                           hf_key: str,
+                           hf_weight: jax.Array,
+                           keep_original_dtype_keys_regex: list[str]
+                           | None = None):
     name_map = metadata_map.name_map
     reshape_keys = metadata_map.reshape_map
     bias_reshape_keys = metadata_map.bias_reshape_map

@@ -290,6 +298,118 @@ def _load_hf_weights_on_thread(vllm_config,
     head_dim = utils.get_padded_head_dim(head_dim_original)
     head_dim_pad = head_dim - head_dim_original
 
+    # Check if the key should retain its original dtype
+    keep_original_dtype = False
+    if keep_original_dtype_keys_regex:
+        for pattern in keep_original_dtype_keys_regex:
+            if re.match(pattern, hf_key):
+                keep_original_dtype = True
+                break
+
+    # Converting to config's dtype
+    if not keep_original_dtype and hf_weight.dtype != model_config.dtype:
+        logger.warning(
+            f"Converting dtype for {hf_key} from {hf_weight.dtype} to {model_config.dtype}"
+        )
+        hf_weight = hf_weight.astype(model_config.dtype)
+
+    if hf_key.endswith(".weight"):
+        hf_key = hf_key.removesuffix(".weight")
+
+    # Find the corresponding model key using the HF key
+    if "layers" in hf_key:
+        layer_num = re.search(r"layers\.(\d+)", hf_key).group(1)
+        layer_key = re.sub(r"layers\.\d+", "layers.*", hf_key)
+        model_key = name_map[layer_key]
+        model_key = re.sub(r"layers\.\*", f"layers.{layer_num}", model_key)
+    elif "blocks" in hf_key:
+        layer_num = re.search(r"blocks\.(\d+)", hf_key).group(1)
+        layer_key = re.sub(r"blocks\.\d+", "blocks.*", hf_key)
+        model_key = name_map[layer_key]
+        model_key = re.sub(r"blocks\.\*", f"blocks.{layer_num}", model_key)
+    else:
+        if hf_key not in name_map and hf_key == "lm_head":
+            logger.warning(f"Skip loading {hf_key} due to tie_word_embeddings")
+            return
+        if hf_key not in name_map and "t2d" in hf_key:
+            logger.warning(
+                f"Skip loading {hf_key} as it's not used in eagle-3 for now")
+            return
+        model_key = name_map.get(hf_key, hf_key)
+
+    model_weight, model_sharding = get_param_and_sharding(
+        params, shardings, model_key)
+
+    logger.debug(
+        "before transform | "
+        f"{hf_key}: {hf_weight.shape} --> {model_key}: {model_weight.value.shape} {model_sharding}"
+    )
+
+    if hf_key.endswith(".bias"):
+        for key in bias_reshape_keys:
+            if key in hf_key:
+                hf_weight = jnp.reshape(hf_weight, bias_reshape_keys[key])
+                if head_dim_pad > 0:
+                    hf_weight = jnp.pad(hf_weight, ((0, 0), (0, head_dim_pad)))
+                break
+    else:
+        for key in reshape_keys:
+            if key in hf_key:
+                hf_weight = jnp.reshape(hf_weight, reshape_keys[key])
+                if head_dim_pad > 0:
+                    if "o_proj" in key:
+                        hf_weight = jnp.pad(hf_weight, ((0, 0), (0, 0),
+                                                        (0, head_dim_pad)))
+                    else:
+                        hf_weight = jnp.pad(hf_weight,
+                                            ((0, 0), (0, head_dim_pad),
+                                             (0, 0)))
+                break
+        for key in transpose_keys:
+            if key in hf_key:
+                hf_weight = jnp.transpose(hf_weight, transpose_keys[key])
+                break
+
+    # Pad num-kv-heads
+    if hf_key.endswith(".bias"):
+        for key, value in bias_pad_keys.items():
+            dim = value[0]
+            dim_size = value[1]
+            if key in hf_key and dim_size != 0:
+                hf_weight = jnp.repeat(hf_weight, dim_size, axis=dim)
+                break
+    else:
+        for key, value in pad_keys.items():
+            dim = value[0]
+            dim_size = value[1]
+            if key in hf_key and dim_size != 0:
+                hf_weight = jnp.repeat(hf_weight, dim_size, axis=dim)
+                break
+
+    logger.debug(
+        "after transform | "
+        f"{hf_key}: {hf_weight.shape} --> {model_key}: {model_weight.value.shape} {model_sharding}"
+    )
+
+    if head_dim_pad == 0:
+        assert model_weight.value.shape == hf_weight.shape, f"{hf_key}: {model_weight.value.shape} != {hf_weight.shape}"
+
+    # Update the model weight
+    spec = model_weight.sharding.spec if isinstance(
+        model_weight.sharding, NamedSharding) else model_weight.sharding
+    model_weight.value = shard(hf_weight, spec)
+
+
+def _load_hf_weights_on_thread(
+    vllm_config: VllmConfig,
+    params: nnx.State,
+    metadata_map: "MetadataMap",
+    mesh: Mesh,
+    weights_file: str,
+    filter_regex: Optional[str] = None,
+    keep_original_dtype_keys_regex: Optional[list[str]] = None,
+):
+    """Loads weights from a single weights file."""
     try:
         shardings = nnx.get_named_sharding(params, mesh)
     except TypeError:
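The per-layer key remapping above leans on a "layers.*" wildcard in name_map. A small illustration of that round trip (the name_map entry and the key are hypothetical):

import re

# Hypothetical name_map entry: per-layer HF keys are collapsed onto a wildcard.
name_map = {"model.layers.*.self_attn.q_proj": "layers.*.attn.q_proj.kernel"}

hf_key = "model.layers.7.self_attn.q_proj"
layer_num = re.search(r"layers\.(\d+)", hf_key).group(1)   # "7"
layer_key = re.sub(r"layers\.\d+", "layers.*", hf_key)     # wildcard form of the key
model_key = name_map[layer_key]
model_key = re.sub(r"layers\.\*", f"layers.{layer_num}", model_key)
print(model_key)  # layers.7.attn.q_proj.kernel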
@@ -297,160 +417,88 @@ def _load_hf_weights_on_thread(vllm_config,
 
     for hf_key, hf_weight in model_weights_single_file_generator(
             weights_file, framework="flax", filter_regex=filter_regex):
+        _load_and_shard_weight(
+            vllm_config,
+            params,
+            shardings,
+            metadata_map,
+            mesh,
+            hf_key,
+            hf_weight,
+            keep_original_dtype_keys_regex,
+        )
 
-        # Check if the key should be excluded
-        if exclude_regex:
-            should_exclude = False
-            for pattern in exclude_regex:
-                if re.search(pattern, hf_key):
-                    logger.info(
-                        f"Excluding {hf_key} based on pattern {pattern}")
-                    should_exclude = True
-                    break
-            if should_exclude:
-                continue
-
-        # Check if the key should retain its original dtype
-        keep_original_dtype = False
-        if keep_original_dtype_keys_regex:
-            for pattern in keep_original_dtype_keys_regex:
-                if re.match(pattern, hf_key):
-                    keep_original_dtype = True
-                    break
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        if hf_key not in name_map and hf_key == "lm_head":
-            logger.warning(
-                f"Skip loading {hf_key} due to tie_word_embeddings")
-            continue
-        if hf_key not in name_map and "t2d" in hf_key:
-            logger.warning(
-                f"Skip loading {hf_key} as it's not used in eagle-3 for now"
-            )
+def load_hf_weights(
+    vllm_config: VllmConfig,
+    model: nnx.Module,
+    metadata_map: "MetadataMap",
+    mesh: Mesh,
+    filter_regex: Optional[str] = None,
+    is_draft_model: bool = False,
+    keep_original_dtype_keys_regex: Optional[list[str]] = None,
+):
+    """Load weights into a JAX model from either an iterator or files."""
+    params = nnx.state(model)
+    try:
+        shardings = nnx.get_named_sharding(params, mesh)
+    except TypeError:
+        shardings = params
+    weights_iterator = None
+    if hasattr(vllm_config.model_config, "model_weights_iterator"):
+        weights_iterator = vllm_config.model_config.model_weights_iterator
+    env = torchax.default_env()
+    # The weights_iterator is used in RunAI model streamer integration.
+    if weights_iterator is not None:
+        for hf_key, hf_weight in weights_iterator:
+            if filter_regex and not re.match(filter_regex, hf_key):
                 continue
-        model_key = name_map.get(hf_key, hf_key)
-        model_weight, model_sharding = get_param_and_sharding(
-            params, shardings, model_key)
-
-        logger.debug(
-            "before transform | "
-            f"{hf_key}: {hf_weight.shape} --> {model_key}: {model_weight.value.shape} {model_sharding}"
-        )
 
-        if hf_key.endswith(".bias"):
-            for key in bias_reshape_keys:
-                if key in hf_key:
-                    hf_weight = jnp.reshape(hf_weight, bias_reshape_keys[key])
-                    if head_dim_pad > 0:
-                        hf_weight = jnp.pad(hf_weight,
-                                            ((0, 0), (0, head_dim_pad)))
-                    break
-        else:
-            for key in reshape_keys:
-                if key in hf_key:
-                    hf_weight = jnp.reshape(hf_weight, reshape_keys[key])
-                    if head_dim_pad > 0:
-                        if "o_proj" in key:
-                            hf_weight = jnp.pad(hf_weight, ((0, 0), (0, 0),
-                                                            (0, head_dim_pad)))
-                        else:
-                            hf_weight = jnp.pad(hf_weight,
-                                                ((0, 0), (0, head_dim_pad),
-                                                 (0, 0)))
-                    break
-            for key in transpose_keys:
-                if key in hf_key:
-                    hf_weight = jnp.transpose(hf_weight, transpose_keys[key])
-                    break
-
-        # Pad num-kv-heads
-        if hf_key.endswith(".bias"):
-            for key, value in bias_pad_keys.items():
-                dim = value[0]
-                dim_size = value[1]
-                if key in hf_key and dim_size != 0:
-                    hf_weight = jnp.repeat(hf_weight, dim_size, axis=dim)
-                    break
-        else:
-            for key, value in pad_keys.items():
-                dim = value[0]
-                dim_size = value[1]
-                if key in hf_key and dim_size != 0:
-                    hf_weight = jnp.repeat(hf_weight, dim_size, axis=dim)
-                    break
-
-        logger.debug(
-            "after transform | "
-            f"{hf_key}: {hf_weight.shape} --> {model_key}: {model_weight.value.shape} {model_sharding}"
-        )
+            # Since the weights_iterator yields Pytorch tensors (torch.Tensor),
+            # we need to convert them to JAX arrays (jax.Array).
+            hf_weight_jax = env.t2j_copy(hf_weight)
 
-
-        assert model_weight.value.shape == hf_weight.shape, f"{hf_key}: {model_weight.value.shape} != {hf_weight.shape}"
-
-        # Update the model weight
-        spec = model_weight.sharding.spec if isinstance(
-            model_weight.sharding, NamedSharding) else model_weight.sharding
-        model_weight.value = shard(hf_weight, spec)
-
-
-def load_hf_weights(vllm_config,
-                    model: nnx.Module,
-                    metadata_map: MetadataMap,
-                    mesh: Mesh,
-                    filter_regex: str | None = None,
-                    is_draft_model: bool = False,
-                    keep_original_dtype_keys_regex: list[str] | None = None,
-                    exclude_regex: list[str] | None = None):
-    """Load weights from all model weights files to the model, run in multi threads."""
-    if is_draft_model:
-        model_path = vllm_config.speculative_config.draft_model_config.model
-    else:
-        model_path = vllm_config.model_config.model
-    weights_files = get_model_weights_files(
-        model_path, vllm_config.load_config.download_dir)
-    params = nnx.state(model)
-    max_workers = min(64, len(weights_files))
-    # NOTE(xiang): Disable multi-threading mode if running on multi-host.
-    # Because multi-threading would cause different JAX processes to load
-    # different weights at the same time.
-    if envs.TPU_MULTIHOST_BACKEND == "ray":
-        max_workers = 1
-    with ThreadPoolExecutor(max_workers=max_workers) as executor:
-        futures = [
-            executor.submit(
-                _load_hf_weights_on_thread,
+            _load_and_shard_weight(
                 vllm_config,
                 params,
+                shardings,
                 metadata_map,
                 mesh,
-
-
-                keep_original_dtype_keys_regex
-
-
-
-
+                hf_key,
+                hf_weight_jax,
+                keep_original_dtype_keys_regex,
+            )
+    else:
+        # File-based path (multi-threaded)
+        if is_draft_model:
+            model_path = vllm_config.speculative_config.draft_model_config.model
+        else:
+            model_path = vllm_config.model_config.model
+        weights_files = get_model_weights_files(
+            model_path, vllm_config.load_config.download_dir)
+        max_workers = min(64, len(weights_files))
+        # NOTE(xiang): Disable multi-threading mode if running on multi-host.
+        # Because multi-threading would cause different JAX processes to load
+        # different weights at the same time.
+        if envs.TPU_MULTIHOST_BACKEND == "ray":
+            max_workers = 1
+        with ThreadPoolExecutor(max_workers=max_workers) as executor:
+            futures = [
+                executor.submit(
+                    _load_hf_weights_on_thread,
+                    vllm_config,
+                    params,
+                    metadata_map,
+                    mesh,
+                    weights_file,
+                    filter_regex=filter_regex,
+                    keep_original_dtype_keys_regex=
+                    keep_original_dtype_keys_regex,
+                ) for weights_file in weights_files
+            ]
+            for future in futures:
+                future.result()
+
     check_all_loaded(params)
     nnx.update(model, params)
 
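One functional change worth noting: when vLLM attaches a model_weights_iterator (the RunAI model streamer path), weights arrive as torch.Tensor objects and are converted to jax.Array before sharding, via the torchax environment helper used in the diff. A hedged sketch of that conversion in isolation (shapes and dtype are arbitrary, and this assumes the same torchax version tpu-inference pins):

import torch
import torchax

# t2j_copy is the helper the loader above calls; illustrative only.
env = torchax.default_env()
host_tensor = torch.zeros(4, 4, dtype=torch.bfloat16)
jax_array = env.t2j_copy(host_tensor)
print(jax_array.shape, jax_array.dtype)  # (4, 4) bfloat16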
tpu_inference/models/vllm/vllm_model_wrapper.py

@@ -9,6 +9,7 @@ import jax
 import torch
 import torch.nn
 import torchax
+import vllm.envs as vllm_envs
 from flax.typing import PRNGKey
 from jax.sharding import Mesh, NamedSharding, PartitionSpec
 from torchax.interop import jax_view, torch_view

@@ -118,9 +119,16 @@ class VllmModelWrapper:
             "torch._sync",
             return_value=None) if use_random_weights else nullcontext()
 
+        # By default load weights to the CPU device first. If we are running
+        # under Pathways, this would cause weights to be loaded on a CPU-only
+        # node, so we'll need to remove this context.
+        jax_context = jax.default_device(
+            jax.devices("cpu")
+            [0]) if not vllm_envs.VLLM_TPU_USING_PATHWAYS else nullcontext()
+
         # Load the vLLM model and wrap it into a new model whose forward
         # function can calculate the hidden_state and logits.
-        with load_context,
+        with load_context, jax_context:
             vllm_model = vllm_get_model(vllm_config=vllm_config_for_load)
             lora_manager = None
             if vllm_config_for_load.lora_config is not None:
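The Pathways guard above relies on jax.default_device being a context manager: arrays created inside it are placed on the chosen device (here the first CPU device), which keeps large checkpoints in host RAM during loading. A minimal standalone sketch:

import jax
import jax.numpy as jnp

# Inside the context, new arrays land on host memory; outside it, placement
# reverts to the platform default (e.g. a TPU device when one is attached).
with jax.default_device(jax.devices("cpu")[0]):
    staged = jnp.ones((1024, 1024), dtype=jnp.bfloat16)

print(staged.devices())  # {CpuDevice(id=0)} (exact repr depends on the JAX version)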
@@ -213,7 +221,7 @@ class VllmModelWrapper:
         @functools.partial(
             jax.jit,
             out_shardings=(NamedSharding(self.mesh,
-                                         PartitionSpec(
+                                         PartitionSpec("data", "model"))),
         )
         def compute_logits_func(
             params_and_buffers: Any,

@@ -255,7 +263,6 @@ def load_lora_model(model: torch.nn.Module, vllm_config: VllmConfig,
         vllm_config,
         device,
         model.embedding_modules,
-        model.embedding_padding_modules,
     )
     return lora_manager, lora_manager.create_lora_manager(model)
 

@@ -269,10 +276,9 @@ def replace_set_lora(model):
         index: int,
         lora_a: torch.Tensor,
         lora_b: torch.Tensor,
-        embeddings_tensor: Optional[torch.Tensor],
     ):
         with torchax.default_env():
-            self._original_set_lora(index, lora_a, lora_b
+            self._original_set_lora(index, lora_a, lora_b)
 
     def _tpu_reset_lora(self, index: int):
         with torchax.default_env():
tpu_inference/platforms/tpu_platform.py

@@ -3,36 +3,32 @@
 from typing import TYPE_CHECKING, Any, Optional, Tuple, Union, cast
 
 import jax.numpy as jnp
+import torch
 import vllm.envs as vllm_envs
-from torchax.ops.mappings import j2t_dtype
 from tpu_info import device
 from vllm.inputs import ProcessorInputs, PromptType
 from vllm.platforms.interface import Platform, PlatformEnum
-from vllm.sampling_params import SamplingParams, SamplingType
 
 from tpu_inference import envs
 from tpu_inference.layers.common.sharding import ShardingConfigManager
 from tpu_inference.logger import init_logger
 
 if TYPE_CHECKING:
-    from vllm.attention.backends.registry import
+    from vllm.attention.backends.registry import AttentionBackendEnum
     from vllm.config import BlockSize, ModelConfig, VllmConfig
     from vllm.pooling_params import PoolingParams
+    from vllm.sampling_params import SamplingParams, SamplingType
 else:
     BlockSize = None
     ModelConfig = None
     VllmConfig = None
     PoolingParams = None
-
+    AttentionBackendEnum = None
+    SamplingParams = None
+    SamplingType = None
 
 logger = init_logger(__name__)
 
-_DTYPE: dict[str, jnp.dtype] = {
-    "bfloat16": jnp.bfloat16,
-    "float": jnp.float32,
-    "float32": jnp.float32,
-}
-
 
 class TpuPlatform(Platform):
     _enum = PlatformEnum.TPU

@@ -49,25 +45,22 @@ class TpuPlatform(Platform):
 
     additional_env_vars: list[str] = [
         "PHASED_PROFILING_DIR", "TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS",
-        "TPU_MULTIHOST_BACKEND", "VLLM_MLA_DISABLE", "TPU_BACKEND_TYPE"
+        "TPU_MULTIHOST_BACKEND", "VLLM_MLA_DISABLE", "TPU_BACKEND_TYPE",
+        "NEW_MODEL_DESIGN"
     ]
 
     @classmethod
-    def get_attn_backend_cls(cls, selected_backend: "
-                             dtype: jnp.dtype,
-
-                             has_sink: bool, use_sparse: bool,
-                             attn_type: Any) -> str:
-        from vllm.attention.backends.registry import
-        if selected_backend !=
+    def get_attn_backend_cls(cls, selected_backend: "AttentionBackendEnum",
+                             head_size: int, dtype: jnp.dtype,
+                             kv_cache_dtype: Optional[str], block_size: int,
+                             use_mla: bool, has_sink: bool, use_sparse: bool,
+                             use_mm_prefix: bool, attn_type: Any) -> str:
+        from vllm.attention.backends.registry import AttentionBackendEnum
+        if selected_backend != AttentionBackendEnum.PALLAS:
             logger.info("Cannot use %s backend on TPU.", selected_backend)
 
-
-
-            return "tpu_inference.layers.vllm.attention.PallasAttentionBackend"
-        else:
-            logger.info("Using Pallas backend.")
-            return "vllm.attention.backends.pallas.PallasAttentionBackend"
+        logger.info("Using Pallas V1 backend.")
+        return "tpu_inference.layers.vllm.attention.PallasAttentionBackend"
 
     @classmethod
     def get_device_name(cls, device_id: int = 0) -> str:

@@ -82,6 +75,14 @@ class TpuPlatform(Platform):
             logger.warning(f"Error getting device name: {e}")
             return 'TPU'
 
+    @classmethod
+    def fp8_dtype(cls) -> torch.dtype:
+        if cls.get_device_name().lower() == "tpu v6e":
+            logger.info(
+                "Automatically using fp8_e5m2 for FP8 KV cache on TPU v6e.")
+            return torch.float8_e5m2
+        return torch.float8_e4m3fn
+
     @classmethod
     def get_device_total_memory(cls, device_id: int = 0) -> int:
         raise NotImplementedError
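The new fp8_dtype hook makes the FP8 KV-cache format depend on the TPU generation. A hedged, self-contained sketch of the same decision (the helper name here is hypothetical; the real method is the classmethod on TpuPlatform above):

import torch

def pick_fp8_dtype(device_name: str) -> torch.dtype:
    # Mirrors the classmethod above: v6e prefers e5m2, other TPUs use e4m3fn.
    if device_name.lower() == "tpu v6e":
        return torch.float8_e5m2
    return torch.float8_e4m3fn

assert pick_fp8_dtype("TPU v6e") == torch.float8_e5m2
assert pick_fp8_dtype("TPU v5p") == torch.float8_e4m3fn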
@@ -132,6 +133,7 @@ class TpuPlatform(Platform):
         # For v0, the default block size is 16.
         if cache_config and cache_config.block_size is None:
             cache_config.block_size = cast(BlockSize, 16)
+
         compilation_config = vllm_config.compilation_config
 
         # TPU only supports DYNAMO_TRACE_ONCE compilation level

@@ -142,27 +144,6 @@ class TpuPlatform(Platform):
         if compilation_config.backend == "":
             compilation_config.backend = "openxla"
 
-        # If we use vLLM's model implementation in PyTorch, we should set it with torch version of the dtype.
-        impl = envs.MODEL_IMPL_TYPE
-
-        # NOTE(xiang): convert dtype to jnp.dtype
-        # NOTE(wenlong): skip this logic for mm model preprocessing
-        # For mm model preprocessors, it may need the output dtype to be torch.
-        # In order to avoid a PR to vLLM, we postpone the dtype checking during tpu_worker initialization
-        if not vllm_config.scheduler_config.is_multimodal_model or impl == "vllm":
-            if not isinstance(vllm_config.model_config.dtype, str):
-                logger.warning(
-                    "The model dtype is not properly set for JAX backend. "
-                    "Overwriting it to jnp.bfloat16")
-                vllm_config.model_config.dtype = jnp.bfloat16
-            else:
-                vllm_config.model_config.dtype = _DTYPE.get(
-                    vllm_config.model_config.dtype, jnp.bfloat16)
-
-        if impl == "vllm":
-            vllm_config.model_config.dtype = j2t_dtype(
-                vllm_config.model_config.dtype.dtype)
-
         # TODO(cuiq): remove this dependency.
         from vllm.v1.attention.backends.pallas import PallasAttentionBackend
         cache_config.block_size = PallasAttentionBackend.get_page_size(

@@ -170,8 +151,7 @@ class TpuPlatform(Platform):
         min_page_size = PallasAttentionBackend.get_min_page_size(vllm_config)
         if min_page_size > cache_config.block_size:
             logger.warning(
-                "Increase the page size from %s to %s to
-                "no SMEM OOM",
+                "Increase the page size from %s to %s to avoid SMEM OOM",
                 cache_config.block_size,
                 min_page_size,
             )

@@ -246,10 +226,11 @@ class TpuPlatform(Platform):
     def validate_request(
         cls,
         prompt: PromptType,
-        params: Union[SamplingParams, PoolingParams],
+        params: Union["SamplingParams", PoolingParams],
         processed_inputs: ProcessorInputs,
     ) -> None:
         """Raises if this request is unsupported on this platform"""
+        from vllm.sampling_params import SamplingParams, SamplingType
 
         if isinstance(params, SamplingParams):
             if params.sampling_type == SamplingType.RANDOM_SEED: