tpu-inference 0.11.1.dev202511150811__py3-none-any.whl → 0.11.1.dev202511270815__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tests/kernels/fused_moe_v1_test.py +303 -34
- tests/lora/test_layers.py +0 -6
- tests/lora/utils.py +0 -8
- tpu_inference/__init__.py +22 -3
- tpu_inference/core/disagg_utils.py +6 -8
- tpu_inference/distributed/tpu_connector.py +2 -3
- tpu_inference/distributed/utils.py +3 -2
- tpu_inference/envs.py +1 -1
- tpu_inference/executors/ray_distributed_executor.py +27 -11
- tpu_inference/kernels/fused_moe/v1/kernel.py +641 -110
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +77 -54
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +141 -107
- tpu_inference/layers/common/attention_interface.py +7 -1
- tpu_inference/layers/common/sharding.py +2 -1
- tpu_inference/layers/vllm/fused_moe.py +74 -25
- tpu_inference/layers/vllm/quantization/common.py +6 -1
- tpu_inference/layers/vllm/quantization/mxfp4.py +135 -61
- tpu_inference/layers/vllm/quantization/unquantized.py +107 -113
- tpu_inference/layers/vllm/sharding.py +2 -2
- tpu_inference/lora/torch_punica_tpu.py +1 -2
- tpu_inference/models/common/model_loader.py +43 -11
- tpu_inference/models/jax/llama3.py +2 -1
- tpu_inference/models/jax/llama_eagle3.py +8 -5
- tpu_inference/models/jax/llama_guard_4.py +361 -0
- tpu_inference/models/jax/qwen2.py +2 -1
- tpu_inference/models/jax/qwen2_5_vl.py +163 -48
- tpu_inference/models/jax/qwen3.py +2 -1
- tpu_inference/models/jax/utils/weight_utils.py +198 -143
- tpu_inference/models/vllm/vllm_model_wrapper.py +13 -5
- tpu_inference/platforms/tpu_platform.py +15 -2
- tpu_inference/runner/compilation_manager.py +58 -33
- tpu_inference/runner/kv_cache_manager.py +9 -3
- tpu_inference/runner/structured_decoding_manager.py +2 -3
- tpu_inference/runner/tpu_runner.py +203 -102
- tpu_inference/spec_decode/jax/eagle3.py +19 -2
- tpu_inference/tpu_info.py +4 -3
- tpu_inference/utils.py +5 -4
- tpu_inference/worker/tpu_worker.py +160 -23
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202511270815.dist-info}/METADATA +3 -2
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202511270815.dist-info}/RECORD +43 -48
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +0 -28
- tpu_inference/mock/vllm_envs.py +0 -1219
- tpu_inference/mock/vllm_logger.py +0 -212
- tpu_inference/mock/vllm_logging_utils.py +0 -15
- tpu_inference/models/jax/phi3.py +0 -376
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202511270815.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202511270815.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202511270815.dist-info}/top_level.txt +0 -0
tpu_inference/models/jax/utils/weight_utils.py

@@ -13,12 +13,14 @@ from typing import Any, Optional
 import jax
 import jax.numpy as jnp
 import torch
+import torchax
 from flax import nnx
 from jax.sharding import Mesh, NamedSharding
 from jax.sharding import PartitionSpec as P
 from safetensors import safe_open
+from vllm.config import VllmConfig

-from tpu_inference import utils
+from tpu_inference import envs, utils
 from tpu_inference.logger import init_logger
 from tpu_inference.models.jax.utils import file_utils

@@ -197,12 +199,11 @@ def shard_put(x: jax.Array, shardings, mesh: jax.sharding.Mesh) -> jax.Array:
     return jax.device_put(x, shardings)


-def get_default_maps(vllm_config, mesh: Mesh,
+def get_default_maps(model_config, mesh: Mesh,
                      name_map: dict[str, str]) -> MetadataMap:
     """Load weights from one model weights file to the model, run on single thread."""
     sharding_size = mesh.shape["model"]

-    model_config = vllm_config.model_config
     hf_config = model_config.hf_config

     num_heads = hf_config.num_attention_heads
@@ -266,14 +267,15 @@ def get_default_maps(vllm_config, mesh: Mesh,
         bias_pad_map=bias_pad_keys)


-def _load_hf_weights_on_thread(vllm_config,
-
-
-
-
-
-
-
+def _load_and_shard_weight(vllm_config,
+                           params: nnx.State,
+                           shardings: Any,
+                           metadata_map: MetadataMap,
+                           mesh: Mesh,
+                           hf_key: str,
+                           hf_weight: jax.Array,
+                           keep_original_dtype_keys_regex: list[str]
+                           | None = None):
     name_map = metadata_map.name_map
     reshape_keys = metadata_map.reshape_map
     bias_reshape_keys = metadata_map.bias_reshape_map
@@ -290,6 +292,118 @@ def _load_hf_weights_on_thread(vllm_config,
     head_dim = utils.get_padded_head_dim(head_dim_original)
     head_dim_pad = head_dim - head_dim_original

+    # Check if the key should retain its original dtype
+    keep_original_dtype = False
+    if keep_original_dtype_keys_regex:
+        for pattern in keep_original_dtype_keys_regex:
+            if re.match(pattern, hf_key):
+                keep_original_dtype = True
+                break
+
+    # Converting to config's dtype
+    if not keep_original_dtype and hf_weight.dtype != model_config.dtype:
+        logger.warning(
+            f"Converting dtype for {hf_key} from {hf_weight.dtype} to {model_config.dtype}"
+        )
+        hf_weight = hf_weight.astype(model_config.dtype)
+
+    if hf_key.endswith(".weight"):
+        hf_key = hf_key.removesuffix(".weight")
+
+    # Find the corresponding model key using the HF key
+    if "layers" in hf_key:
+        layer_num = re.search(r"layers\.(\d+)", hf_key).group(1)
+        layer_key = re.sub(r"layers\.\d+", "layers.*", hf_key)
+        model_key = name_map[layer_key]
+        model_key = re.sub(r"layers\.\*", f"layers.{layer_num}", model_key)
+    elif "blocks" in hf_key:
+        layer_num = re.search(r"blocks\.(\d+)", hf_key).group(1)
+        layer_key = re.sub(r"blocks\.\d+", "blocks.*", hf_key)
+        model_key = name_map[layer_key]
+        model_key = re.sub(r"blocks\.\*", f"blocks.{layer_num}", model_key)
+    else:
+        if hf_key not in name_map and hf_key == "lm_head":
+            logger.warning(f"Skip loading {hf_key} due to tie_word_embeddings")
+            return
+        if hf_key not in name_map and "t2d" in hf_key:
+            logger.warning(
+                f"Skip loading {hf_key} as it's not used in eagle-3 for now")
+            return
+        model_key = name_map.get(hf_key, hf_key)
+
+    model_weight, model_sharding = get_param_and_sharding(
+        params, shardings, model_key)
+
+    logger.debug(
+        "before transform | "
+        f"{hf_key}: {hf_weight.shape} --> {model_key}: {model_weight.value.shape} {model_sharding}"
+    )
+
+    if hf_key.endswith(".bias"):
+        for key in bias_reshape_keys:
+            if key in hf_key:
+                hf_weight = jnp.reshape(hf_weight, bias_reshape_keys[key])
+                if head_dim_pad > 0:
+                    hf_weight = jnp.pad(hf_weight, ((0, 0), (0, head_dim_pad)))
+                break
+    else:
+        for key in reshape_keys:
+            if key in hf_key:
+                hf_weight = jnp.reshape(hf_weight, reshape_keys[key])
+                if head_dim_pad > 0:
+                    if "o_proj" in key:
+                        hf_weight = jnp.pad(hf_weight, ((0, 0), (0, 0),
+                                                        (0, head_dim_pad)))
+                    else:
+                        hf_weight = jnp.pad(hf_weight,
+                                            ((0, 0), (0, head_dim_pad),
+                                             (0, 0)))
+                break
+        for key in transpose_keys:
+            if key in hf_key:
+                hf_weight = jnp.transpose(hf_weight, transpose_keys[key])
+                break
+
+    # Pad num-kv-heads
+    if hf_key.endswith(".bias"):
+        for key, value in bias_pad_keys.items():
+            dim = value[0]
+            dim_size = value[1]
+            if key in hf_key and dim_size != 0:
+                hf_weight = jnp.repeat(hf_weight, dim_size, axis=dim)
+                break
+    else:
+        for key, value in pad_keys.items():
+            dim = value[0]
+            dim_size = value[1]
+            if key in hf_key and dim_size != 0:
+                hf_weight = jnp.repeat(hf_weight, dim_size, axis=dim)
+                break
+
+    logger.debug(
+        "after transform | "
+        f"{hf_key}: {hf_weight.shape} --> {model_key}: {model_weight.value.shape} {model_sharding}"
+    )
+
+    if head_dim_pad == 0:
+        assert model_weight.value.shape == hf_weight.shape, f"{hf_key}: {model_weight.value.shape} != {hf_weight.shape}"
+
+    # Update the model weight
+    spec = model_weight.sharding.spec if isinstance(
+        model_weight.sharding, NamedSharding) else model_weight.sharding
+    model_weight.value = shard(hf_weight, spec)
+
+
+def _load_hf_weights_on_thread(
+    vllm_config: VllmConfig,
+    params: nnx.State,
+    metadata_map: "MetadataMap",
+    mesh: Mesh,
+    weights_file: str,
+    filter_regex: Optional[str] = None,
+    keep_original_dtype_keys_regex: Optional[list[str]] = None,
+):
+    """Loads weights from a single weights file."""
     try:
         shardings = nnx.get_named_sharding(params, mesh)
     except TypeError:
@@ -297,147 +411,88 @@ def _load_hf_weights_on_thread(vllm_config,

     for hf_key, hf_weight in model_weights_single_file_generator(
             weights_file, framework="flax", filter_regex=filter_regex):
+        _load_and_shard_weight(
+            vllm_config,
+            params,
+            shardings,
+            metadata_map,
+            mesh,
+            hf_key,
+            hf_weight,
+            keep_original_dtype_keys_regex,
+        )

-        # Check if the key should retain its original dtype
-        keep_original_dtype = False
-        if keep_original_dtype_keys_regex:
-            for pattern in keep_original_dtype_keys_regex:
-                if re.match(pattern, hf_key):
-                    keep_original_dtype = True
-                    break

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        if
-            logger.warning(
-                f"Skip loading {hf_key} due to tie_word_embeddings")
-            continue
-        if hf_key not in name_map and "t2d" in hf_key:
-            logger.warning(
-                f"Skip loading {hf_key} as it's not used in eagle-3 for now"
-            )
+def load_hf_weights(
+    vllm_config: VllmConfig,
+    model: nnx.Module,
+    metadata_map: "MetadataMap",
+    mesh: Mesh,
+    filter_regex: Optional[str] = None,
+    is_draft_model: bool = False,
+    keep_original_dtype_keys_regex: Optional[list[str]] = None,
+):
+    """Load weights into a JAX model from either an iterator or files."""
+    params = nnx.state(model)
+    try:
+        shardings = nnx.get_named_sharding(params, mesh)
+    except TypeError:
+        shardings = params
+    weights_iterator = None
+    if hasattr(vllm_config.model_config, "model_weights_iterator"):
+        weights_iterator = vllm_config.model_config.model_weights_iterator
+    env = torchax.default_env()
+    # The weights_iterator is used in RunAI model streamer integration.
+    if weights_iterator is not None:
+        for hf_key, hf_weight in weights_iterator:
+            if filter_regex and not re.match(filter_regex, hf_key):
                continue
-        model_key = name_map.get(hf_key, hf_key)
-        model_weight, model_sharding = get_param_and_sharding(
-            params, shardings, model_key)

-
-
-
-            )
-
-        if hf_key.endswith(".bias"):
-            for key in bias_reshape_keys:
-                if key in hf_key:
-                    hf_weight = jnp.reshape(hf_weight, bias_reshape_keys[key])
-                    if head_dim_pad > 0:
-                        hf_weight = jnp.pad(hf_weight,
-                                            ((0, 0), (0, head_dim_pad)))
-                    break
-        else:
-            for key in reshape_keys:
-                if key in hf_key:
-                    hf_weight = jnp.reshape(hf_weight, reshape_keys[key])
-                    if head_dim_pad > 0:
-                        if "o_proj" in key:
-                            hf_weight = jnp.pad(hf_weight, ((0, 0), (0, 0),
-                                                            (0, head_dim_pad)))
-                        else:
-                            hf_weight = jnp.pad(hf_weight,
-                                                ((0, 0), (0, head_dim_pad),
-                                                 (0, 0)))
-                    break
-            for key in transpose_keys:
-                if key in hf_key:
-                    hf_weight = jnp.transpose(hf_weight, transpose_keys[key])
-                    break
-
-        # Pad num-kv-heads
-        if hf_key.endswith(".bias"):
-            for key, value in bias_pad_keys.items():
-                dim = value[0]
-                dim_size = value[1]
-                if key in hf_key and dim_size != 0:
-                    hf_weight = jnp.repeat(hf_weight, dim_size, axis=dim)
-                    break
-        else:
-            for key, value in pad_keys.items():
-                dim = value[0]
-                dim_size = value[1]
-                if key in hf_key and dim_size != 0:
-                    hf_weight = jnp.repeat(hf_weight, dim_size, axis=dim)
-                    break
-
-        logger.debug(
-            "after transform | "
-            f"{hf_key}: {hf_weight.shape} --> {model_key}: {model_weight.value.shape} {model_sharding}"
-        )
+            # Since the weights_iterator yields Pytorch tensors (torch.Tensor),
+            # we need to convert them to JAX arrays (jax.Array).
+            hf_weight_jax = env.t2j_copy(hf_weight)

-
-        assert model_weight.value.shape == hf_weight.shape, f"{hf_key}: {model_weight.value.shape} != {hf_weight.shape}"
-
-        # Update the model weight
-        spec = model_weight.sharding.spec if isinstance(
-            model_weight.sharding, NamedSharding) else model_weight.sharding
-        model_weight.value = shard(hf_weight, spec)
-
-
-def load_hf_weights(vllm_config,
-                    model: nnx.Module,
-                    metadata_map: MetadataMap,
-                    mesh: Mesh,
-                    filter_regex: str | None = None,
-                    is_draft_model: bool = False,
-                    keep_original_dtype_keys_regex: list[str] | None = None):
-    """Load weights from all model weights files to the model, run in multi threads."""
-    if is_draft_model:
-        model_path = vllm_config.speculative_config.draft_model_config.model
-    else:
-        model_path = vllm_config.model_config.model
-    weights_files = get_model_weights_files(
-        model_path, vllm_config.load_config.download_dir)
-    params = nnx.state(model)
-    max_workers = min(64, len(weights_files))
-    # NOTE(xiang): Disable multi-threading mode if running on multi-host.
-    # Because multi-threading would cause different JAX processes to load
-    # different weights at the same time.
-    if os.environ.get("TPU_MULTIHOST_BACKEND", "").lower() == "ray":
-        max_workers = 1
-    with ThreadPoolExecutor(max_workers=max_workers) as executor:
-        futures = [
-            executor.submit(
-                _load_hf_weights_on_thread,
+            _load_and_shard_weight(
                vllm_config,
                params,
+                shardings,
                metadata_map,
                mesh,
-
-
-                keep_original_dtype_keys_regex
-
-
-
-
+                hf_key,
+                hf_weight_jax,
+                keep_original_dtype_keys_regex,
+            )
+    else:
+        # File-based path (multi-threaded)
+        if is_draft_model:
+            model_path = vllm_config.speculative_config.draft_model_config.model
+        else:
+            model_path = vllm_config.model_config.model
+        weights_files = get_model_weights_files(
+            model_path, vllm_config.load_config.download_dir)
+        max_workers = min(64, len(weights_files))
+        # NOTE(xiang): Disable multi-threading mode if running on multi-host.
+        # Because multi-threading would cause different JAX processes to load
+        # different weights at the same time.
+        if envs.TPU_MULTIHOST_BACKEND == "ray":
+            max_workers = 1
+        with ThreadPoolExecutor(max_workers=max_workers) as executor:
+            futures = [
+                executor.submit(
+                    _load_hf_weights_on_thread,
+                    vllm_config,
+                    params,
+                    metadata_map,
+                    mesh,
+                    weights_file,
+                    filter_regex=filter_regex,
+                    keep_original_dtype_keys_regex=
+                    keep_original_dtype_keys_regex,
+                ) for weights_file in weights_files
+            ]
+            for future in futures:
+                future.result()
+
    check_all_loaded(params)
    nnx.update(model, params)

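The refactor above concentrates all per-weight handling in _load_and_shard_weight, and the step easiest to misread is the wildcard key mapping: the concrete layer index in a HF checkpoint key is collapsed to "layers.*" (or "blocks.*") before the name_map lookup, then substituted back into the resulting model key. Below is a minimal standalone sketch of that lookup; the name_map entry and the key names are hypothetical stand-ins for the real metadata.

import re

# Hypothetical name_map entry in the style used above: HF keys with the layer
# index collapsed to "*" map onto model-side parameter paths.
name_map = {
    "model.layers.*.self_attn.q_proj": "layers.*.attn.q_proj.kernel",
}

hf_key = "model.layers.7.self_attn.q_proj.weight"
hf_key = hf_key.removesuffix(".weight")

# Collapse the concrete layer index, look up the wildcard key, then restore it.
layer_num = re.search(r"layers\.(\d+)", hf_key).group(1)
layer_key = re.sub(r"layers\.\d+", "layers.*", hf_key)
model_key = re.sub(r"layers\.\*", f"layers.{layer_num}", name_map[layer_key])

print(model_key)  # layers.7.attn.q_proj.kernel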
tpu_inference/models/vllm/vllm_model_wrapper.py

@@ -9,6 +9,7 @@ import jax
 import torch
 import torch.nn
 import torchax
+import vllm.envs as vllm_envs
 from flax.typing import PRNGKey
 from jax.sharding import Mesh, NamedSharding, PartitionSpec
 from torchax.interop import jax_view, torch_view
@@ -118,9 +119,16 @@ class VllmModelWrapper:
             "torch._sync",
             return_value=None) if use_random_weights else nullcontext()

+        # By default load weights to the CPU device first. If we are running
+        # under Pathways, this would cause weights to be loaded on a CPU-only
+        # node, so we'll need to remove this context.
+        jax_context = jax.default_device(
+            jax.devices("cpu")
+            [0]) if not vllm_envs.VLLM_TPU_USING_PATHWAYS else nullcontext()
+
         # Load the vLLM model and wrap it into a new model whose forward
         # function can calculate the hidden_state and logits.
-        with load_context:
+        with load_context, jax_context:
             vllm_model = vllm_get_model(vllm_config=vllm_config_for_load)
             lora_manager = None
             if vllm_config_for_load.lora_config is not None:
@@ -161,6 +169,7 @@ class VllmModelWrapper:
         input_ids: jax.Array,
         attn_metadata: AttentionMetadata,
         input_embeds: jax.Array,
+        input_positions: jax.Array,
         layer_name_to_kvcache_index: Sequence[Tuple[str, int]],
         lora_metadata,
         intermediate_tensors: JaxIntermediateTensors = None,
@@ -187,8 +196,8 @@ class VllmModelWrapper:
             torch_view(params_and_buffers),
             kwargs={
                 "input_ids": torch_view(input_ids),
-                "positions": torch_view(
-                "intermediate_tensors":
+                "positions": torch_view(input_positions),
+                "intermediate_tensors": None,
                 "inputs_embeds": None,
             },
             tie_weights=False,
@@ -268,10 +277,9 @@ def replace_set_lora(model):
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
-        embeddings_tensor: Optional[torch.Tensor],
    ):
        with torchax.default_env():
-            self._original_set_lora(index, lora_a, lora_b
+            self._original_set_lora(index, lora_a, lora_b)

    def _tpu_reset_lora(self, index: int):
        with torchax.default_env():
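The new jax_context in the loader above is just a conditional jax.default_device scope: outside Pathways, weights are first materialized on the host CPU device; under Pathways, where that would land them on a CPU-only node, the context is skipped. A minimal sketch of the pattern, with a plain boolean standing in for vllm_envs.VLLM_TPU_USING_PATHWAYS:

from contextlib import nullcontext

import jax
import jax.numpy as jnp

using_pathways = False  # stand-in for vllm_envs.VLLM_TPU_USING_PATHWAYS

# Pin array creation to the host CPU device unless running under Pathways.
jax_context = (nullcontext() if using_pathways else
               jax.default_device(jax.devices("cpu")[0]))

with jax_context:
    w = jnp.zeros((4, 4))  # created on the CPU device while the context is active

print(w.devices())  # e.g. {CpuDevice(id=0)}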
tpu_inference/platforms/tpu_platform.py

@@ -1,9 +1,9 @@
 # SPDX-License-Identifier: Apache-2.0

-import os
 from typing import TYPE_CHECKING, Any, Optional, Tuple, Union, cast

 import jax.numpy as jnp
+import torch
 import vllm.envs as vllm_envs
 from torchax.ops.mappings import j2t_dtype
 from tpu_info import device
@@ -83,6 +83,14 @@ class TpuPlatform(Platform):
             logger.warning(f"Error getting device name: {e}")
             return 'TPU'

+    @classmethod
+    def fp8_dtype(cls) -> torch.dtype:
+        if cls.get_device_name().lower() == "tpu v6e":
+            logger.info(
+                "Automatically using fp8_e5m2 for FP8 KV cache on TPU v6e.")
+            return torch.float8_e5m2
+        return torch.float8_e4m3fn
+
     @classmethod
     def get_device_total_memory(cls, device_id: int = 0) -> int:
         raise NotImplementedError
@@ -133,6 +141,7 @@ class TpuPlatform(Platform):
         # For v0, the default block size is 16.
         if cache_config and cache_config.block_size is None:
             cache_config.block_size = cast(BlockSize, 16)
+
         compilation_config = vllm_config.compilation_config

         # TPU only supports DYNAMO_TRACE_ONCE compilation level
@@ -183,7 +192,7 @@ class TpuPlatform(Platform):
             parallel_config.worker_cls = \
                 "tpu_inference.worker.tpu_worker.TPUWorker"

-        multihost_backend =
+        multihost_backend = envs.TPU_MULTIHOST_BACKEND
         if not multihost_backend: # Single host
             if parallel_config.pipeline_parallel_size == 1:
                 logger.info("Force using UniProcExecutor for JAX on \
@@ -267,3 +276,7 @@ class TpuPlatform(Platform):
         Returns if the current platform needs to sync weight loader.
         """
         return True
+
+    @classmethod
+    def support_hybrid_kv_cache(cls) -> bool:
+        return True