tpu-inference 0.11.1.dev202511180814__py3-none-any.whl → 0.11.1.dev202511220812__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of tpu-inference might be problematic.
- tests/lora/test_layers.py +0 -6
- tests/lora/utils.py +0 -8
- tpu_inference/__init__.py +22 -3
- tpu_inference/core/disagg_utils.py +6 -8
- tpu_inference/distributed/tpu_connector.py +2 -3
- tpu_inference/distributed/utils.py +3 -2
- tpu_inference/envs.py +1 -1
- tpu_inference/executors/ray_distributed_executor.py +4 -1
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +77 -54
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +77 -54
- tpu_inference/layers/vllm/sharding.py +2 -2
- tpu_inference/lora/torch_punica_tpu.py +1 -2
- tpu_inference/models/common/model_loader.py +9 -9
- tpu_inference/models/jax/llama3.py +2 -1
- tpu_inference/models/jax/llama_eagle3.py +9 -5
- tpu_inference/models/jax/llama_guard_4.py +361 -0
- tpu_inference/models/jax/qwen2.py +2 -1
- tpu_inference/models/jax/qwen2_5_vl.py +2 -1
- tpu_inference/models/jax/qwen3.py +2 -1
- tpu_inference/models/jax/utils/weight_utils.py +21 -8
- tpu_inference/models/vllm/vllm_model_wrapper.py +4 -4
- tpu_inference/platforms/tpu_platform.py +5 -2
- tpu_inference/runner/compilation_manager.py +33 -15
- tpu_inference/runner/kv_cache_manager.py +8 -2
- tpu_inference/runner/tpu_runner.py +187 -99
- tpu_inference/spec_decode/jax/eagle3.py +2 -1
- tpu_inference/tpu_info.py +4 -3
- tpu_inference/utils.py +5 -4
- tpu_inference/worker/tpu_worker.py +158 -22
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/METADATA +2 -2
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/RECORD +34 -39
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +0 -28
- tpu_inference/mock/vllm_envs.py +0 -1219
- tpu_inference/mock/vllm_logger.py +0 -212
- tpu_inference/mock/vllm_logging_utils.py +0 -15
- tpu_inference/models/jax/phi3.py +0 -376
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/top_level.txt +0 -0
tests/lora/test_layers.py
CHANGED
@@ -91,7 +91,6 @@ def populate_loras(
     index_to_id: list[Optional[int]],
     lora_layer: BaseLayerWithLoRA,
     baselayer_weights: torch.Tensor,
-    generate_embeddings_tensor: int = 0,
     repeats: int = 1,
 ) -> tuple[dict[int, LoRALayerWeights], dict[int, list[LoRALayerWeights]]]:
     """This method populates the lora weights (lora_a and lora_b) in the lora layers (BaseLayerWithLoRA).
@@ -103,8 +102,6 @@ def populate_loras(
         lora_layer: the LoRAlayer to populate.
         baselayer_weights: the PyTorch tensor containing the layer's
             weights.
-        generate_embeddings_tensor: whether to generate an
-            embeddings tensor for each LoRA.
         repeats: must only be set for column parallel packed
             layers. Indicates the number of loras to compose
             together to create a single lora layer.
@@ -131,7 +128,6 @@ def populate_loras(
                 baselayer_weights.device).init_random_lora(
                     module_name=f"fake_{i}",
                     weight=baselayer_weights,
-                    generate_embeddings_tensor=generate_embeddings_tensor,
                 )
             sublora.lora_b = sublora.lora_b[(sublora_len *
                                              i):(sublora_len * (i + 1)), :]
@@ -147,7 +143,6 @@ def populate_loras(
                 slot_idx,
                 lora_a=lora.lora_a,
                 lora_b=lora.lora_b,
-                embeddings_tensor=lora.embeddings_tensor,
             )

         lora_dict[lora_id] = lora
@@ -546,7 +541,6 @@ def _update_punica_wrapper_metadata(punica_wrapper, index_mapping,
        index_to_id,
        lora_config.max_loras,
        vocab_size=512,
-        extra_vocab_size=lora_config.lora_extra_vocab_size,
    )
    assert jax_view(punica_wrapper._lora_indices_per_batch).platform(
    ) == 'tpu', 'punica_wrapper._lora_indices_per_batch should have been moved to TPU.'
tests/lora/utils.py
CHANGED
@@ -24,7 +24,6 @@ class DummyLoRAManager:
         module_name: str,
         weight: torch.Tensor,
         rank: int = 8,
-        generate_embeddings_tensor: int = 0,
     ):
         lora = LoRALayerWeights(
             module_name,
@@ -37,13 +36,6 @@ class DummyLoRAManager:
                 dtype=weight.dtype,
                 device=self._device),
         )
-        if generate_embeddings_tensor:
-            lora.embeddings_tensor = torch.rand(
-                5,
-                generate_embeddings_tensor,
-                dtype=weight.dtype,
-                device=self._device,
-            )
         self.set_module_lora(module_name, lora)

         return lora
tpu_inference/__init__.py
CHANGED
@@ -1,21 +1,40 @@
-import os
-
 # The environment variables override should be imported before any other
 # modules to ensure that the environment variables are set before any
 # other modules are imported.
 import tpu_inference.env_override  # noqa: F401
+from tpu_inference import envs
 from tpu_inference import tpu_info as ti
 from tpu_inference.logger import init_logger

 logger = init_logger(__name__)

-if "proxy" in
+if "proxy" in envs.JAX_PLATFORMS:
     logger.info("Running vLLM on TPU via Pathways proxy.")
     # Must run pathwaysutils.initialize() before any JAX operations
     try:
+        import traceback
+
         import pathwaysutils
+        import vllm
+        from vllm.platforms import (resolve_current_platform_cls_qualname,
+                                    resolve_obj_by_qualname)
         pathwaysutils.initialize()
         logger.info("Module pathwaysutils is imported.")
+
+        # Pathways requires eager resolution of vllm.current_platform instead of
+        # lazy resolution in the normal code path. Since this part involves
+        # global topology discovery across multiple hosts, the platform
+        # resolution must happen before other components are loaded.
+        logger.info("Eagerly resolving vLLM current_platform for Pathways.")
+        platform_cls_qualname = resolve_current_platform_cls_qualname()
+        resolved_platform_instance = resolve_obj_by_qualname(
+            platform_cls_qualname)()
+        vllm.platforms._current_platform = resolved_platform_instance
+        vllm.platforms._init_trace = "".join(traceback.format_stack())
+        logger.info(
+            f"vLLM platform resolved to: {resolved_platform_instance.__class__.__name__}"
+        )
+
     except Exception as e:
         logger.error(
             f"Error occurred while importing pathwaysutils or logging TPU info: {e}"
tpu_inference/core/disagg_utils.py
CHANGED
@@ -1,17 +1,15 @@
 # SPDX-License-Identifier: Apache-2.0

-import os
 from typing import Tuple

-
-DECODE_SLICES = 'DECODE_SLICES'
+from tpu_inference import envs


 def is_disagg_enabled() -> bool:
     # We triggrer our code path as long as prefill slices are set. This
     # allows us to test interleave mode effectively with the code path
     # for comparison purposes.
-    return PREFILL_SLICES
+    return bool(envs.PREFILL_SLICES)


 def _parse_slices(slices_str: str) -> Tuple[int, ...]:
@@ -40,12 +38,12 @@ def _parse_slices(slices_str: str) -> Tuple[int, ...]:


 def get_prefill_slices() -> Tuple[int, ...]:
-    if
+    if not envs.PREFILL_SLICES:
         return ()
-    return _parse_slices(
+    return _parse_slices(envs.PREFILL_SLICES)


 def get_decode_slices() -> Tuple[int, ...]:
-    if
+    if not envs.DECODE_SLICES:
         return ()
-    return _parse_slices(
+    return _parse_slices(envs.DECODE_SLICES)
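The practical effect of this change is that disaggregated serving is now gated on the PREFILL_SLICES environment variable as read through tpu_inference.envs rather than through module-level constants. A minimal, self-contained sketch of that gating behaviour; the value "2x4" below is purely illustrative, since the accepted slice format is defined by _parse_slices, whose body is not shown in this diff:

import os


def is_disagg_enabled() -> bool:
    # Mirrors the updated disagg_utils.is_disagg_enabled(): disaggregation is
    # triggered whenever prefill slices are configured. The real code reads the
    # value through envs.PREFILL_SLICES instead of os.getenv.
    return bool(os.getenv("PREFILL_SLICES", ""))


os.environ["PREFILL_SLICES"] = "2x4"  # hypothetical value, for illustration only
print(is_disagg_enabled())  # True
del os.environ["PREFILL_SLICES"]
print(is_disagg_enabled())  # False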
tpu_inference/distributed/tpu_connector.py
CHANGED
@@ -60,7 +60,6 @@ D workflow:

 import copy
 import functools
-import os
 import threading
 import time
 from concurrent.futures import Future, ThreadPoolExecutor
@@ -86,6 +85,7 @@ if TYPE_CHECKING:
     from vllm.v1.core.kv_cache_manager import KVCacheBlocks
     from vllm.v1.request import Request

+from tpu_inference import envs
 from tpu_inference.distributed.utils import (get_host_ip, get_kv_ips,
                                              get_kv_ports,
                                              get_kv_transfer_port, get_node_id,
@@ -441,8 +441,7 @@ class TPUConnectorWorker:

         self.runner: TPUModelRunner = None
         self.mesh: Mesh = None
-        self.multi_host =
-            "").lower() == "ray"
+        self.multi_host = envs.TPU_MULTIHOST_BACKEND == "ray"
         # NOTE(xiang): This can not be the worker rank set in RayDistributedExecutor.
         # The worker rank is assigned with vLLM's sorting logic, which does not work
         # for TPU host topology.
tpu_inference/distributed/utils.py
CHANGED
@@ -2,6 +2,7 @@ import os

 from vllm.utils.network_utils import get_ip

+from tpu_inference import envs
 from tpu_inference.logger import init_logger

 logger = init_logger(__name__)
@@ -17,7 +18,7 @@ def set_node_kv_ip_port(ip_port: tuple[int, str, int]):


 def get_kv_ips() -> str:
-    if
+    if envs.TPU_MULTIHOST_BACKEND == "ray":
         num_nodes = len(_NODES_KV_IP_PORT)
         ips = []
         for node_id in range(num_nodes):
@@ -28,7 +29,7 @@ def get_kv_ips() -> str:


 def get_kv_ports() -> str:
-    if
+    if envs.TPU_MULTIHOST_BACKEND == "ray":
         num_nodes = len(_NODES_KV_IP_PORT)
         ports = []
         for node_id in range(num_nodes):
tpu_inference/envs.py
CHANGED
@@ -26,7 +26,7 @@ if TYPE_CHECKING:
 environment_variables: dict[str, Callable[[], Any]] = {
     # JAX platform selection (e.g., "tpu", "cpu", "proxy")
     "JAX_PLATFORMS":
-    lambda: os.getenv("JAX_PLATFORMS", ""),
+    lambda: os.getenv("JAX_PLATFORMS", "").lower(),
     # TPU accelerator type (e.g., "v5litepod-16", "v4-8")
     "TPU_ACCELERATOR_TYPE":
     lambda: os.getenv("TPU_ACCELERATOR_TYPE", None),
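Several files in this release stop reading os.environ directly and go through tpu_inference.envs instead (disagg_utils, tpu_connector, distributed/utils, sharding). envs.py keeps each variable as a zero-argument lambda in environment_variables, which suggests values are resolved lazily on attribute access. A hedged sketch of how such a table is commonly exposed via a module-level __getattr__ (PEP 562); the actual wiring in envs.py is not part of this diff, so treat the mechanism and the extra entries below as assumptions:

import os
from typing import Any, Callable

environment_variables: dict[str, Callable[[], Any]] = {
    # JAX platform selection; the .lower() added in this release makes the
    # "proxy" check in tpu_inference/__init__.py case-insensitive.
    "JAX_PLATFORMS": lambda: os.getenv("JAX_PLATFORMS", "").lower(),
    # Illustrative entries for variables referenced elsewhere in this diff.
    "TPU_MULTIHOST_BACKEND": lambda: os.getenv("TPU_MULTIHOST_BACKEND", "").lower(),
    "PREFILL_SLICES": lambda: os.getenv("PREFILL_SLICES", ""),
    "DECODE_SLICES": lambda: os.getenv("DECODE_SLICES", ""),
}


def __getattr__(name: str) -> Any:
    # Module-level __getattr__: evaluating the lambda on every access means
    # envs.JAX_PLATFORMS always reflects the current process environment.
    if name in environment_variables:
        return environment_variables[name]()
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")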
tpu_inference/executors/ray_distributed_executor.py
CHANGED
@@ -108,6 +108,9 @@ class RayDistributedExecutor(RayDistributedExecutorV1):
         ip_port = self.collective_rpc("get_node_kv_ip_port")
         for item in ip_port:
             set_node_kv_ip_port(item)
+        self.uses_sampler = self.vllm_config.model_config.runner_type != "pooling" and (
+            self.vllm_config.ec_transfer_config is None
+            or not self.vllm_config.ec_transfer_config.is_ec_producer)

     def _initialize_ray_cluster(self) -> None:
         """Initialize the distributed cluster with Ray.
@@ -352,7 +355,7 @@ class RayDistributedExecutor(RayDistributedExecutorV1):
         self.collective_rpc("init_worker", args=(all_kwargs, ))
         self.collective_rpc("init_device")
         if self.parallel_config.pipeline_parallel_size > 1:
-            self.
+            self.collective_rpc("initialize_pp_transfer_connect")
         self.collective_rpc("load_model")

         if self.use_ray_spmd_worker:
tpu_inference/kernels/ragged_paged_attention/v3/kernel.py
CHANGED
@@ -440,42 +440,54 @@ def _ragged_paged_attention_kernel(
         debug_print("[RPA debug] bkv_sz_frm_new={}", bkv_sz_frm_new)
         debug_print("[RPA debug] page_indices_offset={}", page_indices_offset)

-
-
-
-
-
-
-
-
-
-
+        if not wait:
+            # Fetch effective kv from kv cache.
+            def loop_body(i, offset):
+                sz = jnp.minimum(page_size, kv_left_frm_cache - i * page_size)
+                _async_copy(
+                    cache_hbm_ref.at[pl.ds(
+                        page_indices_ref[page_indices_offset + i] * page_size,
+                        sz)],
+                    vmem_ref.at[pl.ds(i * page_size, sz)],
+                    sem,
+                    wait=False,
+                )
+                debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
+                return offset + sz
+
+            offset = lax.fori_loop(
+                0,
+                bkv_p_frm_cache,
+                loop_body,
+                0,  # offset
+                unroll=False,
             )
-            debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
-            return offset + sz
-
-        offset = lax.fori_loop(
-            0,
-            bkv_p_frm_cache,
-            loop_body,
-            0,  # offset
-            unroll=False,
-        )

-
-
-
-
-
-
+            # Fetch kv directly from new kv.
+            @pl.when(bkv_sz_frm_new > 0)
+            def _fetch_bkv_from_new_kv():
+                new_kv_len_start = q_end - kv_left_frm_new
+                debug_print("[RPA debug] new_kv_len_start={}",
+                            new_kv_len_start)
+                debug_print("[RPA debug] offset_in_bkv={}", offset)
+                _async_copy(
+                    kv_hbm_ref.at[pl.ds(new_kv_len_start, bkv_sz_frm_new)],
+                    vmem_ref.at[pl.ds(offset, bkv_sz_frm_new)],
+                    sem,
+                    wait,
+                )
+
+            return kv_len_start + offset, bkv_sz_frm_new
+        else:
+            offset = jnp.minimum(kv_left_frm_cache, page_size * bkv_p)
+            dst = vmem_ref.at[pl.ds(0, offset + bkv_sz_frm_new)]
             _async_copy(
-
-
-                sem,
-                wait,
+                src=dst,
+                dst=dst,
+                sem=sem,
+                wait=True,
             )
-
-            return kv_len_start + offset, bkv_sz_frm_new
+            return kv_len_start + offset, bkv_sz_frm_new

     def _update_kv_cache(seq_idx,
                          bkv_sem_idx,
@@ -511,30 +523,41 @@ def _ragged_paged_attention_kernel(
         debug_print("[RPA debug] p_ignore={}", p_ignore)
         debug_print("[RPA debug] page_indices_offset={}", page_indices_offset)

-
-
-
-
+        if not wait:
+
+            def loop_body(i, states):
+                update_sz, ignore = states
+                sz = jnp.minimum(page_size - ignore, update_sz)
+
+                _async_copy(
+                    vmem_ref.at[pl.ds((p_ignore + i) * page_size + ignore,
+                                      sz)],
+                    cache_hbm_ref.at[pl.ds(
+                        page_indices_ref[page_indices_offset + i] * page_size +
+                        ignore,
+                        sz,
+                    )],
+                    sem,
+                    wait=False,
+                )
+                debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
+                return update_sz - sz, 0
+
+            lax.fori_loop(
+                0,
+                kv_p_end - kv_p_start,
+                loop_body,
+                (update_sz, ignore),  # total transfer size
+                unroll=False,
+            )
+        else:
+            dst = cache_hbm_ref.at[pl.ds(0, update_sz)]
             _async_copy(
-
-
-
-
-                sz,
-            )],
-                sem,
-                wait,
+                src=dst,
+                dst=dst,
+                sem=sem,
+                wait=True,
             )
-            debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
-            return update_sz - sz, 0
-
-        lax.fori_loop(
-            0,
-            kv_p_end - kv_p_start,
-            loop_body,
-            (update_sz, ignore),  # total transfer size
-            unroll=False,
-        )

     def _fetch_bq(seq_idx, bq_idx, bq_sem_idx, *, wait=False):
         sem = sems.at[1, bq_sem_idx]
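Both attention kernels receive the same restructuring: _async_copy is invoked once with wait=False to start each per-page DMA and once with wait=True to block on completion, and the new code moves the start-side loop under `if not wait:` while the wait side issues a single copy descriptor over the full destination slice (src=dst, dst=dst, wait=True) instead of replaying the per-page loop. A hedged sketch of what a start/wait helper of this shape can look like on top of Pallas TPU async copies; the real _async_copy lives in kernel.py and may differ in signature and details:

from jax.experimental.pallas import tpu as pltpu


def _async_copy(src, dst, sem, wait):
    """Start a DMA from src to dst, or block on the DMA tracked by sem."""
    copy = pltpu.make_async_copy(src, dst, sem)
    if wait:
        # Waiting consumes the semaphore for the given transfer size, so one
        # wait over the whole destination slice can stand in for the
        # individually started per-page copies.
        copy.wait()
    else:
        copy.start()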
tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py
CHANGED
@@ -475,42 +475,54 @@ def _ragged_paged_attention_kernel(
         debug_print("[RPA debug] bkv_sz_frm_new={}", bkv_sz_frm_new)
         debug_print("[RPA debug] page_indices_offset={}", page_indices_offset)

-
-
-
-
-
-
-
-
-
-
+        if not wait:
+            # Fetch effective kv from kv cache.
+            def loop_body(i, offset):
+                sz = jnp.minimum(page_size, kv_left_frm_cache - i * page_size)
+                _async_copy(
+                    cache_hbm_ref.at[pl.ds(
+                        page_indices_ref[page_indices_offset + i] * page_size,
+                        sz)],
+                    vmem_ref.at[pl.ds(i * page_size, sz)],
+                    sem,
+                    wait=False,
+                )
+                debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
+                return offset + sz
+
+            offset = lax.fori_loop(
+                0,
+                bkv_p_frm_cache,
+                loop_body,
+                0,  # offset
+                unroll=False,
             )
-            debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
-            return offset + sz
-
-        offset = lax.fori_loop(
-            0,
-            bkv_p_frm_cache,
-            loop_body,
-            0,  # offset
-            unroll=False,
-        )

-
-
-
-
-
-
+            # Fetch kv directly from new kv.
+            @pl.when(bkv_sz_frm_new > 0)
+            def _fetch_bkv_from_new_kv():
+                new_kv_len_start = q_end - kv_left_frm_new
+                debug_print("[RPA debug] new_kv_len_start={}",
+                            new_kv_len_start)
+                debug_print("[RPA debug] offset_in_bkv={}", offset)
+                _async_copy(
+                    kv_hbm_ref.at[pl.ds(new_kv_len_start, bkv_sz_frm_new)],
+                    vmem_ref.at[pl.ds(offset, bkv_sz_frm_new)],
+                    sem,
+                    wait,
+                )
+
+            return kv_len_start + offset, bkv_sz_frm_new
+        else:
+            offset = jnp.minimum(kv_left_frm_cache, page_size * bkv_p)
+            dst = vmem_ref.at[pl.ds(0, offset + bkv_sz_frm_new)]
             _async_copy(
-
-
-                sem,
-                wait,
+                src=dst,
+                dst=dst,
+                sem=sem,
+                wait=True,
             )
-
-            return kv_len_start + offset, bkv_sz_frm_new
+            return kv_len_start + offset, bkv_sz_frm_new

     def _update_kv_cache(seq_idx,
                          bkv_sem_idx,
@@ -546,30 +558,41 @@ def _ragged_paged_attention_kernel(
         debug_print("[RPA debug] p_ignore={}", p_ignore)
         debug_print("[RPA debug] page_indices_offset={}", page_indices_offset)

-
-
-
-
+        if not wait:
+
+            def loop_body(i, states):
+                update_sz, ignore = states
+                sz = jnp.minimum(page_size - ignore, update_sz)
+
+                _async_copy(
+                    vmem_ref.at[pl.ds((p_ignore + i) * page_size + ignore,
+                                      sz)],
+                    cache_hbm_ref.at[pl.ds(
+                        page_indices_ref[page_indices_offset + i] * page_size +
+                        ignore,
+                        sz,
+                    )],
+                    sem,
+                    wait=False,
+                )
+                debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
+                return update_sz - sz, 0
+
+            lax.fori_loop(
+                0,
+                kv_p_end - kv_p_start,
+                loop_body,
+                (update_sz, ignore),  # total transfer size
+                unroll=False,
+            )
+        else:
+            dst = cache_hbm_ref.at[pl.ds(0, update_sz)]
             _async_copy(
-
-
-
-
-                sz,
-            )],
-                sem,
-                wait,
+                src=dst,
+                dst=dst,
+                sem=sem,
+                wait=True,
             )
-            debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
-            return update_sz - sz, 0
-
-        lax.fori_loop(
-            0,
-            kv_p_end - kv_p_start,
-            loop_body,
-            (update_sz, ignore),  # total transfer size
-            unroll=False,
-        )

     def _fetch_bq(seq_idx, bq_idx, bq_sem_idx, *, wait=False):
         sem = sems.at[1, bq_sem_idx]
tpu_inference/layers/vllm/sharding.py
CHANGED
@@ -19,6 +19,7 @@ from vllm.lora.layers.base_linear import BaseLinearLayerWithLoRA
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)

+from tpu_inference import envs
 from tpu_inference.logger import init_logger

 P = PartitionSpec
@@ -211,8 +212,7 @@ def _shard_module_to_tpu(model: torch.nn.Module, mesh: Mesh) -> None:
 def _sharded_device_put(tensor: jax.Array, sharding) -> jax.Array:
     if isinstance(tensor, tuple):
         return tuple(_sharded_device_put(t, sharding) for t in tensor)
-
-    multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower()
+    multihost_backend = envs.TPU_MULTIHOST_BACKEND
     if multihost_backend != "ray":
         return jax.device_put(tensor, sharding)

tpu_inference/lora/torch_punica_tpu.py
CHANGED
@@ -239,7 +239,6 @@ class PunicaWrapperTPU(PunicaWrapperBase):
         lora_index_to_id: list[Optional[int]],
         max_loras: int,
         vocab_size: int,
-        extra_vocab_size: int,
     ):
         # Pad the prompt mapping to avoid running into recompiles on the TPU
         # TODO: Should this happen inside mapping internally? If so how can we
@@ -258,7 +257,7 @@ class PunicaWrapperTPU(PunicaWrapperBase):
             lora_index_to_id,
             max_loras,
             vocab_size,
-            extra_vocab_size
+            0,  # extra_vocab_size
             "cpu",
         )
         with torchax.default_env():
tpu_inference/models/common/model_loader.py
CHANGED
@@ -36,19 +36,17 @@ def _get_model_architecture(config: PretrainedConfig) -> nnx.Module:
     from tpu_inference.models.jax.llama3 import LlamaForCausalLM
     from tpu_inference.models.jax.llama4 import Llama4ForCausalLM
     from tpu_inference.models.jax.llama_eagle3 import EagleLlama3ForCausalLM
-    from tpu_inference.models.jax.
-    from tpu_inference.models.jax.qwen2 import Qwen2ForCausalLM
+    from tpu_inference.models.jax.llama_guard_4 import LlamaGuard4ForCausalLM
     from tpu_inference.models.jax.qwen2_5_vl import \
         Qwen2_5_VLForConditionalGeneration
     from tpu_inference.models.jax.qwen3 import Qwen3ForCausalLM
     _MODEL_REGISTRY["Llama4ForCausalLM"] = Llama4ForCausalLM
     _MODEL_REGISTRY["DeepseekV3ForCausalLM"] = DeepSeekV3
     _MODEL_REGISTRY["LlamaForCausalLM"] = LlamaForCausalLM
-    _MODEL_REGISTRY["
+    _MODEL_REGISTRY["Llama4ForConditionalGeneration"] = LlamaGuard4ForCausalLM
     _MODEL_REGISTRY["Qwen3ForCausalLM"] = Qwen3ForCausalLM
     _MODEL_REGISTRY[
         "Qwen2_5_VLForConditionalGeneration"] = Qwen2_5_VLForConditionalGeneration
-    _MODEL_REGISTRY["Phi3ForCausalLM"] = Phi3ForCausalLM
     _MODEL_REGISTRY["Eagle3LlamaForCausalLM"] = EagleLlama3ForCausalLM
     _MODEL_REGISTRY["GptOssForCausalLM"] = GptOss

@@ -57,8 +55,10 @@ def _get_model_architecture(config: PretrainedConfig) -> nnx.Module:
     if arch in _MODEL_REGISTRY:
         return _MODEL_REGISTRY[arch]
     raise UnsupportedArchitectureError(
-        f"Model architectures {architectures}
-
+        f"Model architectures {architectures} not "
+        "registered in tpu-inference. Falling back to vLLM-native "
+        f"Pytorch definition. JAX-native architectures: {list(_MODEL_REGISTRY.keys())}"
+    )


 def _get_nnx_model(
@@ -217,7 +217,7 @@ def get_flax_model(
             hidden_states_sharding,  # aux hidden states
         ),
         donate_argnums=2,  # 0 is graphdef, 1 is state, 2 is kv_cache
-        static_argnums=
+        static_argnums=7,  #7 is layer_name_to_kvcache_index
     )
     def run_model(graphdef, state, *args):
         model = nnx.merge(graphdef, state)
@@ -326,8 +326,8 @@ def get_model(
         # Convert the error message to a string to check its contents
         error_msg = str(e)

-        logger.warning(
-
+        logger.warning(error_msg)
+
         # Fall back to the vLLM model and updating the dtype accordingly
         vllm_config.model_config.dtype = j2t_dtype(
             vllm_config.model_config.dtype.dtype)
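The hunks above show the registry-then-fallback flow: _get_model_architecture returns the JAX-native class when the architecture is registered and otherwise raises UnsupportedArchitectureError, which get_model catches, logs, and uses to fall back to the vLLM PyTorch model. A self-contained sketch of that control flow; everything here other than _MODEL_REGISTRY and UnsupportedArchitectureError is illustrative, not the package's actual API:

class UnsupportedArchitectureError(ValueError):
    pass


_MODEL_REGISTRY: dict[str, type] = {}  # e.g. "Qwen3ForCausalLM" -> JAX model class


def get_model_class(arch: str) -> type:
    if arch in _MODEL_REGISTRY:
        return _MODEL_REGISTRY[arch]
    raise UnsupportedArchitectureError(
        f"Model architectures ['{arch}'] not registered in tpu-inference. "
        "Falling back to vLLM-native Pytorch definition. "
        f"JAX-native architectures: {list(_MODEL_REGISTRY.keys())}")


def load_model(arch: str):
    try:
        return get_model_class(arch)
    except UnsupportedArchitectureError as e:
        # get_model() in the diff logs the message and then loads the vLLM
        # (PyTorch) model wrapper instead of failing hard.
        print(f"warning: {e}")
        return None  # placeholder for the vLLM fallback path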
tpu_inference/models/jax/llama3.py
CHANGED
@@ -368,7 +368,8 @@ class LlamaForCausalLM(nnx.Module):
             "lm_head": "model.lm_head",
         })

-        metadata_map = get_default_maps(self.vllm_config
+        metadata_map = get_default_maps(self.vllm_config.model_config,
+                                        self.mesh, mappings)
         load_hf_weights(vllm_config=self.vllm_config,
                         model=self,
                         metadata_map=metadata_map,