tpu-inference 0.0.1rc1__py3-none-any.whl → 0.11.1.dev202511180814__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of tpu-inference might be problematic.
- tests/kernels/fused_moe_v1_test.py +34 -303
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +2 -2
- tests/lora/test_layers.py +6 -0
- tests/lora/utils.py +8 -0
- tests/test_envs.py +11 -32
- tests/test_utils.py +2 -1
- tpu_inference/__init__.py +3 -22
- tpu_inference/core/disagg_utils.py +8 -6
- tpu_inference/distributed/tpu_connector.py +4 -3
- tpu_inference/distributed/utils.py +2 -3
- tpu_inference/envs.py +8 -61
- tpu_inference/executors/ray_distributed_executor.py +2 -9
- tpu_inference/kernels/fused_moe/v1/kernel.py +110 -641
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +54 -77
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +145 -266
- tpu_inference/layers/common/attention_interface.py +1 -7
- tpu_inference/layers/common/sharding.py +5 -5
- tpu_inference/layers/vllm/fused_moe.py +208 -170
- tpu_inference/layers/vllm/quantization/common.py +1 -6
- tpu_inference/layers/vllm/quantization/mxfp4.py +73 -138
- tpu_inference/layers/vllm/quantization/unquantized.py +64 -58
- tpu_inference/layers/vllm/sharding.py +2 -2
- tpu_inference/lora/torch_punica_tpu.py +2 -1
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +28 -0
- tpu_inference/mock/vllm_envs.py +1219 -0
- tpu_inference/mock/vllm_logger.py +212 -0
- tpu_inference/mock/vllm_logging_utils.py +15 -0
- tpu_inference/models/common/model_loader.py +10 -43
- tpu_inference/models/jax/llama3.py +1 -2
- tpu_inference/models/jax/llama_eagle3.py +5 -8
- tpu_inference/models/jax/phi3.py +376 -0
- tpu_inference/models/jax/qwen2.py +1 -2
- tpu_inference/models/jax/qwen2_5_vl.py +48 -163
- tpu_inference/models/jax/qwen3.py +1 -2
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +6 -3
- tpu_inference/models/jax/utils/weight_utils.py +143 -198
- tpu_inference/models/vllm/vllm_model_wrapper.py +8 -14
- tpu_inference/platforms/tpu_platform.py +31 -37
- tpu_inference/runner/compilation_manager.py +58 -141
- tpu_inference/runner/kv_cache.py +1 -1
- tpu_inference/runner/kv_cache_manager.py +18 -17
- tpu_inference/runner/persistent_batch_manager.py +2 -40
- tpu_inference/runner/structured_decoding_manager.py +3 -2
- tpu_inference/runner/tpu_runner.py +147 -271
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/jax/eagle3.py +21 -71
- tpu_inference/tpu_info.py +3 -4
- tpu_inference/utils.py +13 -36
- tpu_inference/worker/tpu_worker.py +25 -162
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/METADATA +3 -4
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/RECORD +55 -50
- tpu_inference/models/jax/llama_guard_4.py +0 -361
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/WHEEL +0 -0
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/top_level.txt +0 -0
```diff
--- a/tpu_inference/models/jax/utils/weight_utils.py
+++ b/tpu_inference/models/jax/utils/weight_utils.py
@@ -13,14 +13,12 @@ from typing import Any, Optional
 import jax
 import jax.numpy as jnp
 import torch
-import torchax
 from flax import nnx
 from jax.sharding import Mesh, NamedSharding
 from jax.sharding import PartitionSpec as P
 from safetensors import safe_open
-from vllm.config import VllmConfig

-from tpu_inference import
+from tpu_inference import utils
 from tpu_inference.logger import init_logger
 from tpu_inference.models.jax.utils import file_utils

```
```diff
@@ -199,11 +197,12 @@ def shard_put(x: jax.Array, shardings, mesh: jax.sharding.Mesh) -> jax.Array:
     return jax.device_put(x, shardings)


-def get_default_maps(model_config, mesh: Mesh,
+def get_default_maps(vllm_config, mesh: Mesh,
                      name_map: dict[str, str]) -> MetadataMap:
     """Load weights from one model weights file to the model, run on single thread."""
     sharding_size = mesh.shape["model"]

+    model_config = vllm_config.model_config
     hf_config = model_config.hf_config

     num_heads = hf_config.num_attention_heads
```
```diff
@@ -267,15 +266,14 @@ def get_default_maps(model_config, mesh: Mesh,
                         bias_pad_map=bias_pad_keys)


-def _load_and_shard_weight(vllm_config,
-                           params,
-                           shardings,
-                           metadata_map,
-                           mesh,
-                           hf_key,
-                           hf_weight,
-                           keep_original_dtype_keys_regex: list[str]
-                           | None = None):
+def _load_hf_weights_on_thread(vllm_config,
+                               params: nnx.State,
+                               metadata_map: MetadataMap,
+                               mesh: Mesh,
+                               weights_file: str,
+                               filter_regex: str | None = None,
+                               keep_original_dtype_keys_regex: list[str]
+                               | None = None):
     name_map = metadata_map.name_map
     reshape_keys = metadata_map.reshape_map
     bias_reshape_keys = metadata_map.bias_reshape_map
```
```diff
@@ -292,118 +290,6 @@ def _load_and_shard_weight(vllm_config,
     head_dim = utils.get_padded_head_dim(head_dim_original)
     head_dim_pad = head_dim - head_dim_original

-    # Check if the key should retain its original dtype
-    keep_original_dtype = False
-    if keep_original_dtype_keys_regex:
-        for pattern in keep_original_dtype_keys_regex:
-            if re.match(pattern, hf_key):
-                keep_original_dtype = True
-                break
-
-    # Converting to config's dtype
-    if not keep_original_dtype and hf_weight.dtype != model_config.dtype:
-        logger.warning(
-            f"Converting dtype for {hf_key} from {hf_weight.dtype} to {model_config.dtype}"
-        )
-        hf_weight = hf_weight.astype(model_config.dtype)
-
-    if hf_key.endswith(".weight"):
-        hf_key = hf_key.removesuffix(".weight")
-
-    # Find the corresponding model key using the HF key
-    if "layers" in hf_key:
-        layer_num = re.search(r"layers\.(\d+)", hf_key).group(1)
-        layer_key = re.sub(r"layers\.\d+", "layers.*", hf_key)
-        model_key = name_map[layer_key]
-        model_key = re.sub(r"layers\.\*", f"layers.{layer_num}", model_key)
-    elif "blocks" in hf_key:
-        layer_num = re.search(r"blocks\.(\d+)", hf_key).group(1)
-        layer_key = re.sub(r"blocks\.\d+", "blocks.*", hf_key)
-        model_key = name_map[layer_key]
-        model_key = re.sub(r"blocks\.\*", f"blocks.{layer_num}", model_key)
-    else:
-        if hf_key not in name_map and hf_key == "lm_head":
-            logger.warning(f"Skip loading {hf_key} due to tie_word_embeddings")
-            return
-        if hf_key not in name_map and "t2d" in hf_key:
-            logger.warning(
-                f"Skip loading {hf_key} as it's not used in eagle-3 for now")
-            return
-        model_key = name_map.get(hf_key, hf_key)
-
-    model_weight, model_sharding = get_param_and_sharding(
-        params, shardings, model_key)
-
-    logger.debug(
-        "before transform | "
-        f"{hf_key}: {hf_weight.shape} --> {model_key}: {model_weight.value.shape} {model_sharding}"
-    )
-
-    if hf_key.endswith(".bias"):
-        for key in bias_reshape_keys:
-            if key in hf_key:
-                hf_weight = jnp.reshape(hf_weight, bias_reshape_keys[key])
-                if head_dim_pad > 0:
-                    hf_weight = jnp.pad(hf_weight, ((0, 0), (0, head_dim_pad)))
-                break
-    else:
-        for key in reshape_keys:
-            if key in hf_key:
-                hf_weight = jnp.reshape(hf_weight, reshape_keys[key])
-                if head_dim_pad > 0:
-                    if "o_proj" in key:
-                        hf_weight = jnp.pad(hf_weight, ((0, 0), (0, 0),
-                                                        (0, head_dim_pad)))
-                    else:
-                        hf_weight = jnp.pad(hf_weight,
-                                            ((0, 0), (0, head_dim_pad),
-                                             (0, 0)))
-                break
-    for key in transpose_keys:
-        if key in hf_key:
-            hf_weight = jnp.transpose(hf_weight, transpose_keys[key])
-            break
-
-    # Pad num-kv-heads
-    if hf_key.endswith(".bias"):
-        for key, value in bias_pad_keys.items():
-            dim = value[0]
-            dim_size = value[1]
-            if key in hf_key and dim_size != 0:
-                hf_weight = jnp.repeat(hf_weight, dim_size, axis=dim)
-                break
-    else:
-        for key, value in pad_keys.items():
-            dim = value[0]
-            dim_size = value[1]
-            if key in hf_key and dim_size != 0:
-                hf_weight = jnp.repeat(hf_weight, dim_size, axis=dim)
-                break
-
-    logger.debug(
-        "after transform | "
-        f"{hf_key}: {hf_weight.shape} --> {model_key}: {model_weight.value.shape} {model_sharding}"
-    )
-
-    if head_dim_pad == 0:
-        assert model_weight.value.shape == hf_weight.shape, f"{hf_key}: {model_weight.value.shape} != {hf_weight.shape}"
-
-    # Update the model weight
-    spec = model_weight.sharding.spec if isinstance(
-        model_weight.sharding, NamedSharding) else model_weight.sharding
-    model_weight.value = shard(hf_weight, spec)
-
-
-def _load_hf_weights_on_thread(
-    vllm_config: VllmConfig,
-    params: nnx.State,
-    metadata_map: "MetadataMap",
-    mesh: Mesh,
-    weights_file: str,
-    filter_regex: Optional[str] = None,
-    keep_original_dtype_keys_regex: Optional[list[str]] = None,
-):
-    """Loads weights from a single weights file."""
     try:
         shardings = nnx.get_named_sharding(params, mesh)
     except TypeError:
```
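Both the removed helper above and the inlined loop in the next hunk resolve a HuggingFace checkpoint key to a model parameter key the same way: the layer index is templated out, the wildcard form is looked up in `name_map`, and the index is substituted back. A minimal, self-contained sketch of that step; `hf_key_to_model_key` and the sample `name_map` entry are illustrative, not part of the package:

```python
import re

def hf_key_to_model_key(hf_key: str, name_map: dict[str, str]) -> str:
    """Map an HF checkpoint key to a model key via a layer-wildcard template."""
    if "layers" in hf_key:
        layer_num = re.search(r"layers\.(\d+)", hf_key).group(1)
        layer_key = re.sub(r"layers\.\d+", "layers.*", hf_key)
        model_key = name_map[layer_key]
        return re.sub(r"layers\.\*", f"layers.{layer_num}", model_key)
    return name_map.get(hf_key, hf_key)

name_map = {"model.layers.*.self_attn.q_proj": "layers.*.attn.q_proj"}
print(hf_key_to_model_key("model.layers.7.self_attn.q_proj", name_map))
# -> layers.7.attn.q_proj
```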
```diff
@@ -411,88 +297,147 @@ def _load_hf_weights_on_thread(

     for hf_key, hf_weight in model_weights_single_file_generator(
             weights_file, framework="flax", filter_regex=filter_regex):
-        _load_and_shard_weight(
-            vllm_config,
-            params,
-            shardings,
-            metadata_map,
-            mesh,
-            hf_key,
-            hf_weight,
-            keep_original_dtype_keys_regex,
-        )

+        # Check if the key should retain its original dtype
+        keep_original_dtype = False
+        if keep_original_dtype_keys_regex:
+            for pattern in keep_original_dtype_keys_regex:
+                if re.match(pattern, hf_key):
+                    keep_original_dtype = True
+                    break

-
-
-
-
-
-
-
-
-)
-
-
-
-
-
-
-
-
-
-
-
-
-
-    if
+        # Converting to config's dtype
+        if not keep_original_dtype and hf_weight.dtype != model_config.dtype:
+            logger.warning(
+                f"Converting dtype for {hf_key} from {hf_weight.dtype} to {model_config.dtype}"
+            )
+            hf_weight = hf_weight.astype(model_config.dtype)
+
+        if hf_key.endswith(".weight"):
+            hf_key = hf_key.removesuffix(".weight")
+
+        # Find the corresponding model key using the HF key
+        if "layers" in hf_key:
+            layer_num = re.search(r"layers\.(\d+)", hf_key).group(1)
+            layer_key = re.sub(r"layers\.\d+", "layers.*", hf_key)
+            model_key = name_map[layer_key]
+            model_key = re.sub(r"layers\.\*", f"layers.{layer_num}", model_key)
+        elif "blocks" in hf_key:
+            layer_num = re.search(r"blocks\.(\d+)", hf_key).group(1)
+            layer_key = re.sub(r"blocks\.\d+", "blocks.*", hf_key)
+            model_key = name_map[layer_key]
+            model_key = re.sub(r"blocks\.\*", f"blocks.{layer_num}", model_key)
+        else:
+            if hf_key not in name_map and hf_key == "lm_head":
+                logger.warning(
+                    f"Skip loading {hf_key} due to tie_word_embeddings")
+                continue
+            if hf_key not in name_map and "t2d" in hf_key:
+                logger.warning(
+                    f"Skip loading {hf_key} as it's not used in eagle-3 for now"
+                )
                 continue
+            model_key = name_map.get(hf_key, hf_key)
+        model_weight, model_sharding = get_param_and_sharding(
+            params, shardings, model_key)

-
-
-
+        logger.debug(
+            "before transform | "
+            f"{hf_key}: {hf_weight.shape} --> {model_key}: {model_weight.value.shape} {model_sharding}"
+        )
+
+        if hf_key.endswith(".bias"):
+            for key in bias_reshape_keys:
+                if key in hf_key:
+                    hf_weight = jnp.reshape(hf_weight, bias_reshape_keys[key])
+                    if head_dim_pad > 0:
+                        hf_weight = jnp.pad(hf_weight,
+                                            ((0, 0), (0, head_dim_pad)))
+                    break
+        else:
+            for key in reshape_keys:
+                if key in hf_key:
+                    hf_weight = jnp.reshape(hf_weight, reshape_keys[key])
+                    if head_dim_pad > 0:
+                        if "o_proj" in key:
+                            hf_weight = jnp.pad(hf_weight, ((0, 0), (0, 0),
+                                                            (0, head_dim_pad)))
+                        else:
+                            hf_weight = jnp.pad(hf_weight,
+                                                ((0, 0), (0, head_dim_pad),
+                                                 (0, 0)))
+                    break
+        for key in transpose_keys:
+            if key in hf_key:
+                hf_weight = jnp.transpose(hf_weight, transpose_keys[key])
+                break
+
+        # Pad num-kv-heads
+        if hf_key.endswith(".bias"):
+            for key, value in bias_pad_keys.items():
+                dim = value[0]
+                dim_size = value[1]
+                if key in hf_key and dim_size != 0:
+                    hf_weight = jnp.repeat(hf_weight, dim_size, axis=dim)
+                    break
+        else:
+            for key, value in pad_keys.items():
+                dim = value[0]
+                dim_size = value[1]
+                if key in hf_key and dim_size != 0:
+                    hf_weight = jnp.repeat(hf_weight, dim_size, axis=dim)
+                    break
+
+        logger.debug(
+            "after transform | "
+            f"{hf_key}: {hf_weight.shape} --> {model_key}: {model_weight.value.shape} {model_sharding}"
+        )

-
+        if head_dim_pad == 0:
+            assert model_weight.value.shape == hf_weight.shape, f"{hf_key}: {model_weight.value.shape} != {hf_weight.shape}"
+
+        # Update the model weight
+        spec = model_weight.sharding.spec if isinstance(
+            model_weight.sharding, NamedSharding) else model_weight.sharding
+        model_weight.value = shard(hf_weight, spec)
+
+
+def load_hf_weights(vllm_config,
+                    model: nnx.Module,
+                    metadata_map: MetadataMap,
+                    mesh: Mesh,
+                    filter_regex: str | None = None,
+                    is_draft_model: bool = False,
+                    keep_original_dtype_keys_regex: list[str] | None = None):
+    """Load weights from all model weights files to the model, run in multi threads."""
+    if is_draft_model:
+        model_path = vllm_config.speculative_config.draft_model_config.model
+    else:
+        model_path = vllm_config.model_config.model
+    weights_files = get_model_weights_files(
+        model_path, vllm_config.load_config.download_dir)
+    params = nnx.state(model)
+    max_workers = min(64, len(weights_files))
+    # NOTE(xiang): Disable multi-threading mode if running on multi-host.
+    # Because multi-threading would cause different JAX processes to load
+    # different weights at the same time.
+    if os.environ.get("TPU_MULTIHOST_BACKEND", "").lower() == "ray":
+        max_workers = 1
+    with ThreadPoolExecutor(max_workers=max_workers) as executor:
+        futures = [
+            executor.submit(
+                _load_hf_weights_on_thread,
                 vllm_config,
                 params,
-                shardings,
                 metadata_map,
                 mesh,
-
-
-                keep_original_dtype_keys_regex
-
-
-
-
-        model_path = vllm_config.speculative_config.draft_model_config.model
-    else:
-        model_path = vllm_config.model_config.model
-    weights_files = get_model_weights_files(
-        model_path, vllm_config.load_config.download_dir)
-    max_workers = min(64, len(weights_files))
-    # NOTE(xiang): Disable multi-threading mode if running on multi-host.
-    # Because multi-threading would cause different JAX processes to load
-    # different weights at the same time.
-    if envs.TPU_MULTIHOST_BACKEND == "ray":
-        max_workers = 1
-    with ThreadPoolExecutor(max_workers=max_workers) as executor:
-        futures = [
-            executor.submit(
-                _load_hf_weights_on_thread,
-                vllm_config,
-                params,
-                metadata_map,
-                mesh,
-                weights_file,
-                filter_regex=filter_regex,
-                keep_original_dtype_keys_regex=
-                keep_original_dtype_keys_regex,
-            ) for weights_file in weights_files
-        ]
-        for future in futures:
-            future.result()
-
+                weights_file,
+                filter_regex=filter_regex,
+                keep_original_dtype_keys_regex=keep_original_dtype_keys_regex)
+            for weights_file in weights_files
+        ]
+        for future in futures:
+            future.result()
     check_all_loaded(params)
     nnx.update(model, params)

```
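The rewritten `load_hf_weights` keeps the earlier fan-out/join structure: one worker thread per weights file, joined through `future.result()` so loader exceptions propagate, and a single-worker fallback on Ray-based multihost deployments where every JAX process must load the same file at the same time. A minimal sketch of the pattern, with `load_one_file` as a hypothetical stand-in for `_load_hf_weights_on_thread`:

```python
import os
from concurrent.futures import ThreadPoolExecutor

def load_one_file(path: str) -> None:
    # Stand-in for _load_hf_weights_on_thread: parse one safetensors
    # file and copy its tensors into the shared params state.
    print(f"loading {path}")

def load_all(weights_files: list[str]) -> None:
    max_workers = min(64, len(weights_files))
    # Multi-host Ray deployments must keep all JAX processes in lockstep,
    # so the fan-out degenerates to sequential loading there.
    if os.environ.get("TPU_MULTIHOST_BACKEND", "").lower() == "ray":
        max_workers = 1
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(load_one_file, f) for f in weights_files]
        for future in futures:
            future.result()  # re-raises any loader exception

load_all(["model-00001.safetensors", "model-00002.safetensors"])
```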
```diff
--- a/tpu_inference/models/vllm/vllm_model_wrapper.py
+++ b/tpu_inference/models/vllm/vllm_model_wrapper.py
@@ -9,7 +9,6 @@ import jax
 import torch
 import torch.nn
 import torchax
-import vllm.envs as vllm_envs
 from flax.typing import PRNGKey
 from jax.sharding import Mesh, NamedSharding, PartitionSpec
 from torchax.interop import jax_view, torch_view
```
```diff
@@ -119,16 +118,10 @@ class VllmModelWrapper:
             "torch._sync",
             return_value=None) if use_random_weights else nullcontext()

-        # By default load weights to the CPU device first. If we are running
-        # under Pathways, this would cause weights to be loaded on a CPU-only
-        # node, so we'll need to remove this context.
-        jax_context = jax.default_device(
-            jax.devices("cpu")
-            [0]) if not vllm_envs.VLLM_TPU_USING_PATHWAYS else nullcontext()
-
         # Load the vLLM model and wrap it into a new model whose forward
         # function can calculate the hidden_state and logits.
-        with load_context, jax_context:
+        available_devices = self.mesh.devices.flatten()
+        with load_context, jax.default_device(available_devices[0]):
             vllm_model = vllm_get_model(vllm_config=vllm_config_for_load)
         lora_manager = None
         if vllm_config_for_load.lora_config is not None:
```
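The change above drops the CPU/Pathways special case and instead pins weight initialization to the first device of the runner's mesh. `jax.default_device` is a context manager: arrays created inside it are committed to the given device. A minimal sketch, using the first addressable device in place of `self.mesh.devices.flatten()[0]`:

```python
import jax
import jax.numpy as jnp

# Arrays created inside jax.default_device(...) land on that device
# instead of the process-wide default.
first_device = jax.devices()[0]
with jax.default_device(first_device):
    w = jnp.zeros((1024, 1024))
print(w.devices())  # {first_device}
```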
```diff
@@ -169,7 +162,6 @@ class VllmModelWrapper:
         input_ids: jax.Array,
         attn_metadata: AttentionMetadata,
         input_embeds: jax.Array,
-        input_positions: jax.Array,
         layer_name_to_kvcache_index: Sequence[Tuple[str, int]],
         lora_metadata,
         intermediate_tensors: JaxIntermediateTensors = None,
@@ -196,8 +188,8 @@ class VllmModelWrapper:
             torch_view(params_and_buffers),
             kwargs={
                 "input_ids": torch_view(input_ids),
-                "positions": torch_view(input_positions),
-                "intermediate_tensors":
+                "positions": torch_view(attn_metadata.input_positions),
+                "intermediate_tensors": intermediate_tensors,
                 "inputs_embeds": None,
             },
             tie_weights=False,
```
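The kwargs above drive a stateless call of the wrapped torch model: parameters and buffers are supplied explicitly instead of being read off the module, which is what allows the wrapper to hand in torchax views of JAX arrays. A minimal plain-PyTorch sketch of the same calling convention using `torch.func.functional_call` (the actual wrapper routes tensors through `torch_view`/`jax_view`):

```python
import torch
from torch.func import functional_call

model = torch.nn.Linear(4, 2)
# Parameters and buffers are passed explicitly, mirroring how the wrapper
# supplies torch_view(params_and_buffers) to the traced forward.
params_and_buffers = dict(model.named_parameters()) | dict(model.named_buffers())
out = functional_call(
    model,
    params_and_buffers,
    args=(),
    kwargs={"input": torch.randn(3, 4)},
    tie_weights=False,
)
print(out.shape)  # torch.Size([3, 2])
```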
```diff
@@ -221,7 +213,7 @@ class VllmModelWrapper:
         @functools.partial(
             jax.jit,
             out_shardings=(NamedSharding(self.mesh,
-                                         PartitionSpec(
+                                         PartitionSpec(None, "model"))),
         )
         def compute_logits_func(
             params_and_buffers: Any,
```
```diff
@@ -263,6 +255,7 @@ def load_lora_model(model: torch.nn.Module, vllm_config: VllmConfig,
         vllm_config,
         device,
         model.embedding_modules,
+        model.embedding_padding_modules,
     )
     return lora_manager, lora_manager.create_lora_manager(model)

@@ -276,9 +269,10 @@ def replace_set_lora(model):
         index: int,
         lora_a: torch.Tensor,
         lora_b: torch.Tensor,
+        embeddings_tensor: Optional[torch.Tensor],
     ):
         with torchax.default_env():
-            self._original_set_lora(index, lora_a, lora_b)
+            self._original_set_lora(index, lora_a, lora_b, embeddings_tensor)

     def _tpu_reset_lora(self, index: int):
         with torchax.default_env():
```
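`replace_set_lora` swaps a layer's bound `set_lora` for a wrapper that enters `torchax.default_env()` and, after this change, forwards the new `embeddings_tensor` argument. A generic sketch of that monkey-patching pattern; `runtime_env` is a hypothetical stand-in for `torchax.default_env`, and `wrap_set_lora` is an illustrative helper name:

```python
import contextlib
import types

@contextlib.contextmanager
def runtime_env():
    # Stand-in for torchax.default_env(): whatever environment the
    # original method must run under.
    yield

def wrap_set_lora(layer) -> None:
    # Keep the original bound method, then install a wrapper with the
    # extended signature in its place.
    layer._original_set_lora = layer.set_lora

    def _tpu_set_lora(self, index, lora_a, lora_b, embeddings_tensor=None):
        with runtime_env():
            self._original_set_lora(index, lora_a, lora_b, embeddings_tensor)

    layer.set_lora = types.MethodType(_tpu_set_lora, layer)
```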
```diff
--- a/tpu_inference/platforms/tpu_platform.py
+++ b/tpu_inference/platforms/tpu_platform.py
@@ -1,10 +1,11 @@
 # SPDX-License-Identifier: Apache-2.0

+import os
 from typing import TYPE_CHECKING, Any, Optional, Tuple, Union, cast

 import jax.numpy as jnp
-import torch
 import vllm.envs as vllm_envs
+from torchax.ops.mappings import j2t_dtype
 from tpu_info import device
 from vllm.inputs import ProcessorInputs, PromptType
 from vllm.platforms.interface import Platform, PlatformEnum
```
```diff
@@ -13,10 +14,9 @@ from vllm.sampling_params import SamplingParams, SamplingType
 from tpu_inference import envs
 from tpu_inference.layers.common.sharding import ShardingConfigManager
 from tpu_inference.logger import init_logger
-from tpu_inference.utils import to_jax_dtype, to_torch_dtype

 if TYPE_CHECKING:
-    from vllm.attention.backends.registry import
+    from vllm.attention.backends.registry import _Backend
     from vllm.config import BlockSize, ModelConfig, VllmConfig
     from vllm.pooling_params import PoolingParams
 else:
```
```diff
@@ -24,10 +24,16 @@ else:
     ModelConfig = None
     VllmConfig = None
     PoolingParams = None
-
+    _Backend = None

 logger = init_logger(__name__)

+_DTYPE: dict[str, jnp.dtype] = {
+    "bfloat16": jnp.bfloat16,
+    "float": jnp.float32,
+    "float32": jnp.float32,
+}
+

 class TpuPlatform(Platform):
     _enum = PlatformEnum.TPU
```
```diff
@@ -48,13 +54,13 @@
     ]

     @classmethod
-    def get_attn_backend_cls(cls, selected_backend: "
-
-
-
-
-        from vllm.attention.backends.registry import
-        if selected_backend !=
+    def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int,
+                             dtype: jnp.dtype, kv_cache_dtype: Optional[str],
+                             block_size: int, use_v1: bool, use_mla: bool,
+                             has_sink: bool, use_sparse: bool,
+                             attn_type: Any) -> str:
+        from vllm.attention.backends.registry import _Backend
+        if selected_backend != _Backend.PALLAS:
             logger.info("Cannot use %s backend on TPU.", selected_backend)

         if use_v1:
```
```diff
@@ -77,14 +83,6 @@
             logger.warning(f"Error getting device name: {e}")
         return 'TPU'

-    @classmethod
-    def fp8_dtype(cls) -> torch.dtype:
-        if cls.get_device_name().lower() == "tpu v6e":
-            logger.info(
-                "Automatically using fp8_e5m2 for FP8 KV cache on TPU v6e.")
-            return torch.float8_e5m2
-        return torch.float8_e4m3fn
-
     @classmethod
     def get_device_total_memory(cls, device_id: int = 0) -> int:
         raise NotImplementedError
```
```diff
@@ -135,7 +133,6 @@
         # For v0, the default block size is 16.
         if cache_config and cache_config.block_size is None:
             cache_config.block_size = cast(BlockSize, 16)
-
         compilation_config = vllm_config.compilation_config

         # TPU only supports DYNAMO_TRACE_ONCE compilation level
```
```diff
@@ -152,19 +149,20 @@
         # NOTE(xiang): convert dtype to jnp.dtype
         # NOTE(wenlong): skip this logic for mm model preprocessing
         # For mm model preprocessors, it may need the output dtype to be torch.
-        # In order to avoid a PR to vLLM, we postpone the dtype checking during
-        # tpu_worker initialization
+        # In order to avoid a PR to vLLM, we postpone the dtype checking during tpu_worker initialization
         if not vllm_config.scheduler_config.is_multimodal_model or impl == "vllm":
-
-
-
-
-
-
-            dtype =
-
-
-
+            if not isinstance(vllm_config.model_config.dtype, str):
+                logger.warning(
+                    "The model dtype is not properly set for JAX backend. "
+                    "Overwriting it to jnp.bfloat16")
+                vllm_config.model_config.dtype = jnp.bfloat16
+            else:
+                vllm_config.model_config.dtype = _DTYPE.get(
+                    vllm_config.model_config.dtype, jnp.bfloat16)
+
+            if impl == "vllm":
+                vllm_config.model_config.dtype = j2t_dtype(
+                    vllm_config.model_config.dtype.dtype)

         # TODO(cuiq): remove this dependency.
         from vllm.v1.attention.backends.pallas import PallasAttentionBackend
```
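`check_and_update_config` now resolves the configured dtype itself: string names go through the `_DTYPE` table with `jnp.bfloat16` as the fallback, non-string values are overwritten to `jnp.bfloat16`, and the vLLM-native path converts the result back to a torch dtype with `j2t_dtype`. A minimal sketch of the lookup; `normalize_dtype` is an illustrative helper, not part of the module:

```python
import jax.numpy as jnp

_DTYPE = {"bfloat16": jnp.bfloat16, "float": jnp.float32, "float32": jnp.float32}

def normalize_dtype(dtype):
    # Anything that is not a recognized string collapses to bfloat16,
    # matching the warning-and-overwrite branch above.
    if not isinstance(dtype, str):
        return jnp.bfloat16
    return _DTYPE.get(dtype, jnp.bfloat16)

print(normalize_dtype("float32"))  # jnp.float32
print(normalize_dtype("float16"))  # falls back to jnp.bfloat16
```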
```diff
@@ -185,7 +183,7 @@
         parallel_config.worker_cls = \
             "tpu_inference.worker.tpu_worker.TPUWorker"

-        multihost_backend =
+        multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower()
         if not multihost_backend:  # Single host
             if parallel_config.pipeline_parallel_size == 1:
                 logger.info("Force using UniProcExecutor for JAX on \
```
```diff
@@ -269,7 +267,3 @@
         Returns if the current platform needs to sync weight loader.
         """
         return True
-
-    @classmethod
-    def support_hybrid_kv_cache(cls) -> bool:
-        return True
```