tpu-inference 0.0.1rc1__py3-none-any.whl → 0.11.1.dev202511130813__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tests/kernels/fused_moe_v1_test.py +34 -303
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +2 -2
- tests/lora/test_layers.py +6 -0
- tests/lora/utils.py +8 -0
- tests/test_utils.py +16 -24
- tpu_inference/__init__.py +3 -22
- tpu_inference/core/core_tpu.py +9 -17
- tpu_inference/core/disagg_utils.py +8 -6
- tpu_inference/distributed/tpu_connector.py +4 -3
- tpu_inference/distributed/utils.py +2 -3
- tpu_inference/envs.py +8 -61
- tpu_inference/executors/ray_distributed_executor.py +11 -31
- tpu_inference/kernels/fused_moe/v1/kernel.py +110 -641
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +54 -77
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +143 -287
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +0 -7
- tpu_inference/layers/jax/attention/attention.py +1 -1
- tpu_inference/layers/{common → jax}/attention_interface.py +2 -8
- tpu_inference/layers/jax/sample/rejection_sampler.py +1 -1
- tpu_inference/layers/jax/sample/sampling.py +2 -2
- tpu_inference/layers/{common → jax}/sharding.py +5 -5
- tpu_inference/layers/vllm/attention.py +1 -1
- tpu_inference/layers/vllm/fused_moe.py +208 -170
- tpu_inference/layers/vllm/quantization/__init__.py +3 -7
- tpu_inference/layers/vllm/quantization/awq.py +3 -4
- tpu_inference/layers/vllm/quantization/common.py +1 -6
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +2 -4
- tpu_inference/layers/vllm/quantization/unquantized.py +67 -62
- tpu_inference/layers/vllm/sharding.py +2 -2
- tpu_inference/lora/torch_punica_tpu.py +2 -1
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +28 -0
- tpu_inference/mock/vllm_envs.py +1219 -0
- tpu_inference/mock/vllm_logger.py +212 -0
- tpu_inference/mock/vllm_logging_utils.py +15 -0
- tpu_inference/models/common/model_loader.py +12 -46
- tpu_inference/models/jax/llama3.py +3 -4
- tpu_inference/models/jax/llama_eagle3.py +5 -8
- tpu_inference/models/jax/phi3.py +376 -0
- tpu_inference/models/jax/qwen2.py +2 -3
- tpu_inference/models/jax/qwen2_5_vl.py +50 -165
- tpu_inference/models/jax/qwen3.py +2 -3
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +6 -3
- tpu_inference/models/jax/utils/weight_utils.py +143 -198
- tpu_inference/models/vllm/vllm_model_wrapper.py +14 -32
- tpu_inference/platforms/tpu_platform.py +34 -47
- tpu_inference/runner/compilation_manager.py +60 -145
- tpu_inference/runner/kv_cache.py +2 -2
- tpu_inference/runner/kv_cache_manager.py +18 -17
- tpu_inference/runner/persistent_batch_manager.py +2 -40
- tpu_inference/runner/structured_decoding_manager.py +3 -2
- tpu_inference/runner/tpu_runner.py +135 -283
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/jax/eagle3.py +21 -71
- tpu_inference/tpu_info.py +3 -4
- tpu_inference/utils.py +15 -38
- tpu_inference/worker/tpu_worker.py +26 -163
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/METADATA +3 -4
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/RECORD +63 -61
- tests/test_envs.py +0 -203
- tpu_inference/layers/common/quant_methods.py +0 -8
- tpu_inference/layers/vllm/quantization/mxfp4.py +0 -331
- tpu_inference/models/jax/llama_guard_4.py +0 -361
- /tpu_inference/layers/{common → jax}/binary_search.py +0 -0
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/WHEEL +0 -0
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/top_level.txt +0 -0
tpu_inference/models/jax/utils/quantization/quantization_utils.py

@@ -154,9 +154,12 @@ def qwix_quantize_nnx_model(model: nnx.Module, qwix_config: List[dict],
     logger.info(f"Memory usage before applying quantization of params: "
                 f"hbm={utils.hbm_usage_gb(jax.local_devices())}Gb")
 
-
-
-
+    # TODO (jacobplatin): we should refactor this to pass a dtype (or config) directly
+    kv_cache_jnp_dtype = utils.get_jax_dtype_from_str_dtype(kv_cache_dtype)
+
+    # Handle the case where kv_cache_dtype is "auto"
+    if kv_cache_jnp_dtype is None:
+        assert kv_cache_dtype == "auto", "kv_cache_dtype must be 'auto' if kv_cache_jnp_dtype is None"
         kv_cache_jnp_dtype = DEFAULT_KV_CACHE_DTYPE
 
     kv_caches = create_kv_caches(
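For reference, a minimal standalone sketch of the "auto" handling added above. The resolver table and the DEFAULT_KV_CACHE_DTYPE value here are assumptions for illustration only; the real code uses utils.get_jax_dtype_from_str_dtype and the module's own default.

import jax.numpy as jnp

DEFAULT_KV_CACHE_DTYPE = jnp.bfloat16  # assumed default, for illustration only

# Stand-in for utils.get_jax_dtype_from_str_dtype: returns None for "auto"/unknown.
_STR_TO_JNP_DTYPE = {
    "bfloat16": jnp.bfloat16,
    "float16": jnp.float16,
    "float32": jnp.float32,
}

def resolve_kv_cache_dtype(kv_cache_dtype: str):
    kv_cache_jnp_dtype = _STR_TO_JNP_DTYPE.get(kv_cache_dtype)
    if kv_cache_jnp_dtype is None:
        # Only "auto" is allowed to fall through to the default dtype.
        assert kv_cache_dtype == "auto", (
            "kv_cache_dtype must be 'auto' if kv_cache_jnp_dtype is None")
        kv_cache_jnp_dtype = DEFAULT_KV_CACHE_DTYPE
    return kv_cache_jnp_dtype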
tpu_inference/models/jax/utils/weight_utils.py

@@ -13,14 +13,12 @@ from typing import Any, Optional
 import jax
 import jax.numpy as jnp
 import torch
-import torchax
 from flax import nnx
 from jax.sharding import Mesh, NamedSharding
 from jax.sharding import PartitionSpec as P
 from safetensors import safe_open
-from vllm.config import VllmConfig
 
-from tpu_inference import
+from tpu_inference import utils
 from tpu_inference.logger import init_logger
 from tpu_inference.models.jax.utils import file_utils
 
@@ -199,11 +197,12 @@ def shard_put(x: jax.Array, shardings, mesh: jax.sharding.Mesh) -> jax.Array:
     return jax.device_put(x, shardings)
 
 
-def get_default_maps(model_config, mesh: Mesh,
+def get_default_maps(vllm_config, mesh: Mesh,
                      name_map: dict[str, str]) -> MetadataMap:
     """Load weights from one model weights file to the model, run on single thread."""
     sharding_size = mesh.shape["model"]
 
+    model_config = vllm_config.model_config
     hf_config = model_config.hf_config
 
     num_heads = hf_config.num_attention_heads
@@ -267,15 +266,14 @@ def get_default_maps(model_config, mesh: Mesh,
                      bias_pad_map=bias_pad_keys)
 
 
-def _load_and_shard_weight(vllm_config,
-
-
-
-
-
-
-
-                           | None = None):
+def _load_hf_weights_on_thread(vllm_config,
+                               params: nnx.State,
+                               metadata_map: MetadataMap,
+                               mesh: Mesh,
+                               weights_file: str,
+                               filter_regex: str | None = None,
+                               keep_original_dtype_keys_regex: list[str]
+                               | None = None):
     name_map = metadata_map.name_map
     reshape_keys = metadata_map.reshape_map
     bias_reshape_keys = metadata_map.bias_reshape_map
@@ -292,118 +290,6 @@ def _load_and_shard_weight(vllm_config,
     head_dim = utils.get_padded_head_dim(head_dim_original)
     head_dim_pad = head_dim - head_dim_original
 
-    # Check if the key should retain its original dtype
-    keep_original_dtype = False
-    if keep_original_dtype_keys_regex:
-        for pattern in keep_original_dtype_keys_regex:
-            if re.match(pattern, hf_key):
-                keep_original_dtype = True
-                break
-
-    # Converting to config's dtype
-    if not keep_original_dtype and hf_weight.dtype != model_config.dtype:
-        logger.warning(
-            f"Converting dtype for {hf_key} from {hf_weight.dtype} to {model_config.dtype}"
-        )
-        hf_weight = hf_weight.astype(model_config.dtype)
-
-    if hf_key.endswith(".weight"):
-        hf_key = hf_key.removesuffix(".weight")
-
-    # Find the corresponding model key using the HF key
-    if "layers" in hf_key:
-        layer_num = re.search(r"layers\.(\d+)", hf_key).group(1)
-        layer_key = re.sub(r"layers\.\d+", "layers.*", hf_key)
-        model_key = name_map[layer_key]
-        model_key = re.sub(r"layers\.\*", f"layers.{layer_num}", model_key)
-    elif "blocks" in hf_key:
-        layer_num = re.search(r"blocks\.(\d+)", hf_key).group(1)
-        layer_key = re.sub(r"blocks\.\d+", "blocks.*", hf_key)
-        model_key = name_map[layer_key]
-        model_key = re.sub(r"blocks\.\*", f"blocks.{layer_num}", model_key)
-    else:
-        if hf_key not in name_map and hf_key == "lm_head":
-            logger.warning(f"Skip loading {hf_key} due to tie_word_embeddings")
-            return
-        if hf_key not in name_map and "t2d" in hf_key:
-            logger.warning(
-                f"Skip loading {hf_key} as it's not used in eagle-3 for now")
-            return
-        model_key = name_map.get(hf_key, hf_key)
-
-    model_weight, model_sharding = get_param_and_sharding(
-        params, shardings, model_key)
-
-    logger.debug(
-        "before transform | "
-        f"{hf_key}: {hf_weight.shape} --> {model_key}: {model_weight.value.shape} {model_sharding}"
-    )
-
-    if hf_key.endswith(".bias"):
-        for key in bias_reshape_keys:
-            if key in hf_key:
-                hf_weight = jnp.reshape(hf_weight, bias_reshape_keys[key])
-                if head_dim_pad > 0:
-                    hf_weight = jnp.pad(hf_weight, ((0, 0), (0, head_dim_pad)))
-                break
-    else:
-        for key in reshape_keys:
-            if key in hf_key:
-                hf_weight = jnp.reshape(hf_weight, reshape_keys[key])
-                if head_dim_pad > 0:
-                    if "o_proj" in key:
-                        hf_weight = jnp.pad(hf_weight, ((0, 0), (0, 0),
-                                                        (0, head_dim_pad)))
-                    else:
-                        hf_weight = jnp.pad(hf_weight,
-                                            ((0, 0), (0, head_dim_pad),
-                                             (0, 0)))
-                break
-        for key in transpose_keys:
-            if key in hf_key:
-                hf_weight = jnp.transpose(hf_weight, transpose_keys[key])
-                break
-
-    # Pad num-kv-heads
-    if hf_key.endswith(".bias"):
-        for key, value in bias_pad_keys.items():
-            dim = value[0]
-            dim_size = value[1]
-            if key in hf_key and dim_size != 0:
-                hf_weight = jnp.repeat(hf_weight, dim_size, axis=dim)
-                break
-    else:
-        for key, value in pad_keys.items():
-            dim = value[0]
-            dim_size = value[1]
-            if key in hf_key and dim_size != 0:
-                hf_weight = jnp.repeat(hf_weight, dim_size, axis=dim)
-                break
-
-    logger.debug(
-        "after transform | "
-        f"{hf_key}: {hf_weight.shape} --> {model_key}: {model_weight.value.shape} {model_sharding}"
-    )
-
-    if head_dim_pad == 0:
-        assert model_weight.value.shape == hf_weight.shape, f"{hf_key}: {model_weight.value.shape} != {hf_weight.shape}"
-
-    # Update the model weight
-    spec = model_weight.sharding.spec if isinstance(
-        model_weight.sharding, NamedSharding) else model_weight.sharding
-    model_weight.value = shard(hf_weight, spec)
-
-
-def _load_hf_weights_on_thread(
-    vllm_config: VllmConfig,
-    params: nnx.State,
-    metadata_map: "MetadataMap",
-    mesh: Mesh,
-    weights_file: str,
-    filter_regex: Optional[str] = None,
-    keep_original_dtype_keys_regex: Optional[list[str]] = None,
-):
-    """Loads weights from a single weights file."""
     try:
         shardings = nnx.get_named_sharding(params, mesh)
     except TypeError:
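The helper removed above (and the inlined version in the next hunk) maps HuggingFace checkpoint keys onto model parameter keys by generalizing the layer index to a "layers.*" template, looking that template up in name_map, and re-specializing it to the concrete layer. A self-contained sketch of just that mapping, with a toy name_map in place of the real one built by get_default_maps:

import re

# Toy mapping for illustration; the real name_map comes from get_default_maps.
name_map = {"model.layers.*.self_attn.q_proj": "layers.*.attn.q_proj"}

def map_hf_key_to_model_key(hf_key: str) -> str:
    layer_num = re.search(r"layers\.(\d+)", hf_key).group(1)
    layer_key = re.sub(r"layers\.\d+", "layers.*", hf_key)
    model_key = name_map[layer_key]
    return re.sub(r"layers\.\*", f"layers.{layer_num}", model_key)

print(map_hf_key_to_model_key("model.layers.3.self_attn.q_proj"))
# -> layers.3.attn.q_proj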
@@ -411,88 +297,147 @@ def _load_hf_weights_on_thread(
 
     for hf_key, hf_weight in model_weights_single_file_generator(
             weights_file, framework="flax", filter_regex=filter_regex):
-        _load_and_shard_weight(
-            vllm_config,
-            params,
-            shardings,
-            metadata_map,
-            mesh,
-            hf_key,
-            hf_weight,
-            keep_original_dtype_keys_regex,
-        )
 
+        # Check if the key should retain its original dtype
+        keep_original_dtype = False
+        if keep_original_dtype_keys_regex:
+            for pattern in keep_original_dtype_keys_regex:
+                if re.match(pattern, hf_key):
+                    keep_original_dtype = True
+                    break
 
-
-
-
-
-
-
-
-        )
-
-
-
-
-
-
-
-
-
-
-
-
-
-    if
+        # Converting to config's dtype
+        if not keep_original_dtype and hf_weight.dtype != model_config.dtype:
+            logger.warning(
+                f"Converting dtype for {hf_key} from {hf_weight.dtype} to {model_config.dtype}"
+            )
+            hf_weight = hf_weight.astype(model_config.dtype)
+
+        if hf_key.endswith(".weight"):
+            hf_key = hf_key.removesuffix(".weight")
+
+        # Find the corresponding model key using the HF key
+        if "layers" in hf_key:
+            layer_num = re.search(r"layers\.(\d+)", hf_key).group(1)
+            layer_key = re.sub(r"layers\.\d+", "layers.*", hf_key)
+            model_key = name_map[layer_key]
+            model_key = re.sub(r"layers\.\*", f"layers.{layer_num}", model_key)
+        elif "blocks" in hf_key:
+            layer_num = re.search(r"blocks\.(\d+)", hf_key).group(1)
+            layer_key = re.sub(r"blocks\.\d+", "blocks.*", hf_key)
+            model_key = name_map[layer_key]
+            model_key = re.sub(r"blocks\.\*", f"blocks.{layer_num}", model_key)
+        else:
+            if hf_key not in name_map and hf_key == "lm_head":
+                logger.warning(
+                    f"Skip loading {hf_key} due to tie_word_embeddings")
+                continue
+            if hf_key not in name_map and "t2d" in hf_key:
+                logger.warning(
+                    f"Skip loading {hf_key} as it's not used in eagle-3 for now"
+                )
                 continue
+            model_key = name_map.get(hf_key, hf_key)
+        model_weight, model_sharding = get_param_and_sharding(
+            params, shardings, model_key)
 
-
-
-
+        logger.debug(
+            "before transform | "
+            f"{hf_key}: {hf_weight.shape} --> {model_key}: {model_weight.value.shape} {model_sharding}"
+        )
+
+        if hf_key.endswith(".bias"):
+            for key in bias_reshape_keys:
+                if key in hf_key:
+                    hf_weight = jnp.reshape(hf_weight, bias_reshape_keys[key])
+                    if head_dim_pad > 0:
+                        hf_weight = jnp.pad(hf_weight,
+                                            ((0, 0), (0, head_dim_pad)))
+                    break
+        else:
+            for key in reshape_keys:
+                if key in hf_key:
+                    hf_weight = jnp.reshape(hf_weight, reshape_keys[key])
+                    if head_dim_pad > 0:
+                        if "o_proj" in key:
+                            hf_weight = jnp.pad(hf_weight, ((0, 0), (0, 0),
+                                                            (0, head_dim_pad)))
+                        else:
+                            hf_weight = jnp.pad(hf_weight,
+                                                ((0, 0), (0, head_dim_pad),
+                                                 (0, 0)))
+                    break
+            for key in transpose_keys:
+                if key in hf_key:
+                    hf_weight = jnp.transpose(hf_weight, transpose_keys[key])
+                    break
+
+        # Pad num-kv-heads
+        if hf_key.endswith(".bias"):
+            for key, value in bias_pad_keys.items():
+                dim = value[0]
+                dim_size = value[1]
+                if key in hf_key and dim_size != 0:
+                    hf_weight = jnp.repeat(hf_weight, dim_size, axis=dim)
+                    break
+        else:
+            for key, value in pad_keys.items():
+                dim = value[0]
+                dim_size = value[1]
+                if key in hf_key and dim_size != 0:
+                    hf_weight = jnp.repeat(hf_weight, dim_size, axis=dim)
+                    break
+
+        logger.debug(
+            "after transform | "
+            f"{hf_key}: {hf_weight.shape} --> {model_key}: {model_weight.value.shape} {model_sharding}"
+        )
 
-
+        if head_dim_pad == 0:
+            assert model_weight.value.shape == hf_weight.shape, f"{hf_key}: {model_weight.value.shape} != {hf_weight.shape}"
+
+        # Update the model weight
+        spec = model_weight.sharding.spec if isinstance(
+            model_weight.sharding, NamedSharding) else model_weight.sharding
+        model_weight.value = shard(hf_weight, spec)
+
+
+def load_hf_weights(vllm_config,
+                    model: nnx.Module,
+                    metadata_map: MetadataMap,
+                    mesh: Mesh,
+                    filter_regex: str | None = None,
+                    is_draft_model: bool = False,
+                    keep_original_dtype_keys_regex: list[str] | None = None):
+    """Load weights from all model weights files to the model, run in multi threads."""
+    if is_draft_model:
+        model_path = vllm_config.speculative_config.draft_model_config.model
+    else:
+        model_path = vllm_config.model_config.model
+    weights_files = get_model_weights_files(
+        model_path, vllm_config.load_config.download_dir)
+    params = nnx.state(model)
+    max_workers = min(64, len(weights_files))
+    # NOTE(xiang): Disable multi-threading mode if running on multi-host.
+    # Because multi-threading would cause different JAX processes to load
+    # different weights at the same time.
+    if os.environ.get("TPU_MULTIHOST_BACKEND", "").lower() == "ray":
+        max_workers = 1
+    with ThreadPoolExecutor(max_workers=max_workers) as executor:
+        futures = [
+            executor.submit(
+                _load_hf_weights_on_thread,
                 vllm_config,
                 params,
-                shardings,
                 metadata_map,
                 mesh,
-
-
-                keep_original_dtype_keys_regex
-
-
-
-
-        model_path = vllm_config.speculative_config.draft_model_config.model
-    else:
-        model_path = vllm_config.model_config.model
-    weights_files = get_model_weights_files(
-        model_path, vllm_config.load_config.download_dir)
-    max_workers = min(64, len(weights_files))
-    # NOTE(xiang): Disable multi-threading mode if running on multi-host.
-    # Because multi-threading would cause different JAX processes to load
-    # different weights at the same time.
-    if envs.TPU_MULTIHOST_BACKEND == "ray":
-        max_workers = 1
-    with ThreadPoolExecutor(max_workers=max_workers) as executor:
-        futures = [
-            executor.submit(
-                _load_hf_weights_on_thread,
-                vllm_config,
-                params,
-                metadata_map,
-                mesh,
-                weights_file,
-                filter_regex=filter_regex,
-                keep_original_dtype_keys_regex=
-                keep_original_dtype_keys_regex,
-            ) for weights_file in weights_files
-        ]
-        for future in futures:
-            future.result()
-
+                weights_file,
+                filter_regex=filter_regex,
+                keep_original_dtype_keys_regex=keep_original_dtype_keys_regex)
+            for weights_file in weights_files
+        ]
+        for future in futures:
+            future.result()
     check_all_loaded(params)
     nnx.update(model, params)
 
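The new load_hf_weights above fans out one loading task per weights file. A minimal sketch of just that threading pattern; the per-file loader here is a stand-in for _load_hf_weights_on_thread, and the environment-variable check mirrors the one in the diff:

import os
from concurrent.futures import ThreadPoolExecutor

def load_all_weight_files(weights_files, load_one_file):
    max_workers = min(64, len(weights_files))
    # On multi-host Ray deployments, fall back to a single worker so every
    # JAX process loads the files in the same order at the same time.
    if os.environ.get("TPU_MULTIHOST_BACKEND", "").lower() == "ray":
        max_workers = 1
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(load_one_file, f) for f in weights_files]
        for future in futures:
            future.result()  # surface any exception raised in a worker thread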
tpu_inference/models/vllm/vllm_model_wrapper.py

@@ -9,7 +9,6 @@ import jax
 import torch
 import torch.nn
 import torchax
-import vllm.envs as vllm_envs
 from flax.typing import PRNGKey
 from jax.sharding import Mesh, NamedSharding, PartitionSpec
 from torchax.interop import jax_view, torch_view
@@ -26,8 +25,6 @@ from tpu_inference.layers.common.attention_metadata import AttentionMetadata
 from tpu_inference.layers.vllm.quantization import get_tpu_quantization_config
 from tpu_inference.layers.vllm.sharding import shard_model_to_tpu
 from tpu_inference.logger import init_logger
-from tpu_inference.models.jax.jax_intermediate_tensor import \
-    JaxIntermediateTensors
 from tpu_inference.models.vllm.vllm_model_wrapper_context import (
     get_vllm_model_wrapper_context, set_vllm_model_wrapper_context)
 from tpu_inference.runner.lora_utils import replace_lora_metadata
@@ -92,14 +89,13 @@ class VllmModelWrapper:
             slice_config = self.vllm_config.device_config.slice
             modified_slice_config = True
             self.vllm_config.device_config.slice = None
-        self.vllm_config.compilation_config.static_forward_context.clear()
-
         vllm_config_for_load = copy.deepcopy(self.vllm_config)
         if modified_slice_config:
             self.vllm_config.device_config.slice = slice_config
         assert self.vllm_config.model_config.dtype in TORCH_DTYPE_TO_JAX, "The model_config.dtype must be a PyTorch dtype."
         vllm_config_for_load.device_config.device = "cpu"
         # Clearing the cached compilation config, otherwise vllm model init will fail
+        vllm_config_for_load.compilation_config.static_forward_context.clear()
 
         # When expert parallelism is enabled, vLLM loads weight in sharding
         # aware manner. Since tpu-inference has its own sharding logic, this
@@ -119,16 +115,9 @@
             "torch._sync",
             return_value=None) if use_random_weights else nullcontext()
 
-        # By default load weights to the CPU device first. If we are running
-        # under Pathways, this would cause weights to be loaded on a CPU-only
-        # node, so we'll need to remove this context.
-        jax_context = jax.default_device(
-            jax.devices("cpu")
-            [0]) if not vllm_envs.VLLM_TPU_USING_PATHWAYS else nullcontext()
-
         # Load the vLLM model and wrap it into a new model whose forward
         # function can calculate the hidden_state and logits.
-        with load_context
+        with load_context:
             vllm_model = vllm_get_model(vllm_config=vllm_config_for_load)
             lora_manager = None
             if vllm_config_for_load.lora_config is not None:
@@ -160,8 +149,7 @@
                 "xla_tpu_reduce_scatter_collective_matmul_mode":
                 "post_spmd_conservative"
             },
-            static_argnames=("layer_name_to_kvcache_index",
-                             "is_last_rank"),
+            static_argnames=("layer_name_to_kvcache_index", ),
         )
         def step_fun(
             params_and_buffers,  # This has been wrapped into torchax TorchValue
@@ -169,12 +157,8 @@
             input_ids: jax.Array,
             attn_metadata: AttentionMetadata,
             input_embeds: jax.Array,
-            input_positions: jax.Array,
             layer_name_to_kvcache_index: Sequence[Tuple[str, int]],
             lora_metadata,
-            intermediate_tensors: JaxIntermediateTensors = None,
-            is_first_rank: bool = True,
-            is_last_rank: bool = True,
             *args,
         ) -> Tuple[List[jax.Array], jax.Array]:
             layer_name_to_kvcache_index = dict(layer_name_to_kvcache_index)
@@ -189,14 +173,12 @@
             # torch_view in order to call the Torch function.
             original_lora_metadata = replace_lora_metadata(
                 self.model, lora_metadata, self.vllm_config.lora_config)
-
-            intermediate_tensors = intermediate_tensors.to_torch()
-            output_from_torch = torch.func.functional_call(
+            hidden_states = torch.func.functional_call(
                 self.model,
                 torch_view(params_and_buffers),
                 kwargs={
                     "input_ids": torch_view(input_ids),
-                    "positions": torch_view(input_positions),
+                    "positions": torch_view(attn_metadata.input_positions),
                     "intermediate_tensors": None,
                     "inputs_embeds": None,
                 },
@@ -206,13 +188,11 @@
                 self.vllm_config.lora_config)
             vllm_model_wrapper_context = get_vllm_model_wrapper_context()
             new_kv_caches = vllm_model_wrapper_context.kv_caches
-            # Wrap the
-            #
-
-
-
-            output = jax_view(output_from_torch)
-            return new_kv_caches, output, []
+            # Wrap the hidden_states from torch land into a JaxValue for the jax
+            # code to consume.
+            hidden_states = jax_view(hidden_states)
+
+            return new_kv_caches, hidden_states, []
 
         return step_fun
 
@@ -221,7 +201,7 @@
         @functools.partial(
             jax.jit,
             out_shardings=(NamedSharding(self.mesh,
-                                         PartitionSpec(
+                                         PartitionSpec(None, "model"))),
         )
         def compute_logits_func(
             params_and_buffers: Any,
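The one-line change above pins the sharding of the compute_logits output: its last (vocab) dimension is split across the 1-D "model" mesh axis. A self-contained sketch of that jax.jit + out_shardings pattern, independent of tpu-inference; the mesh construction and shapes here are illustrative assumptions:

import functools

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()), axis_names=("model",))

@functools.partial(
    jax.jit,
    out_shardings=NamedSharding(mesh, PartitionSpec(None, "model")),
)
def compute_logits(hidden):
    # Stand-in for the real hidden-states -> logits projection.
    return hidden

logits = compute_logits(jnp.ones((4, 8)))  # last axis sharded over "model"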
@@ -263,6 +243,7 @@ def load_lora_model(model: torch.nn.Module, vllm_config: VllmConfig,
         vllm_config,
         device,
         model.embedding_modules,
+        model.embedding_padding_modules,
     )
     return lora_manager, lora_manager.create_lora_manager(model)
 
@@ -276,9 +257,10 @@ def replace_set_lora(model):
         index: int,
         lora_a: torch.Tensor,
         lora_b: torch.Tensor,
+        embeddings_tensor: Optional[torch.Tensor],
     ):
         with torchax.default_env():
-            self._original_set_lora(index, lora_a, lora_b)
+            self._original_set_lora(index, lora_a, lora_b, embeddings_tensor)
 
     def _tpu_reset_lora(self, index: int):
         with torchax.default_env():