tpu-inference 0.11.1.dev202511180814__py3-none-any.whl → 0.11.1.dev202511220812__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.
Files changed (40)
  1. tests/lora/test_layers.py +0 -6
  2. tests/lora/utils.py +0 -8
  3. tpu_inference/__init__.py +22 -3
  4. tpu_inference/core/disagg_utils.py +6 -8
  5. tpu_inference/distributed/tpu_connector.py +2 -3
  6. tpu_inference/distributed/utils.py +3 -2
  7. tpu_inference/envs.py +1 -1
  8. tpu_inference/executors/ray_distributed_executor.py +4 -1
  9. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +77 -54
  10. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +77 -54
  11. tpu_inference/layers/vllm/sharding.py +2 -2
  12. tpu_inference/lora/torch_punica_tpu.py +1 -2
  13. tpu_inference/models/common/model_loader.py +9 -9
  14. tpu_inference/models/jax/llama3.py +2 -1
  15. tpu_inference/models/jax/llama_eagle3.py +9 -5
  16. tpu_inference/models/jax/llama_guard_4.py +361 -0
  17. tpu_inference/models/jax/qwen2.py +2 -1
  18. tpu_inference/models/jax/qwen2_5_vl.py +2 -1
  19. tpu_inference/models/jax/qwen3.py +2 -1
  20. tpu_inference/models/jax/utils/weight_utils.py +21 -8
  21. tpu_inference/models/vllm/vllm_model_wrapper.py +4 -4
  22. tpu_inference/platforms/tpu_platform.py +5 -2
  23. tpu_inference/runner/compilation_manager.py +33 -15
  24. tpu_inference/runner/kv_cache_manager.py +8 -2
  25. tpu_inference/runner/tpu_runner.py +187 -99
  26. tpu_inference/spec_decode/jax/eagle3.py +2 -1
  27. tpu_inference/tpu_info.py +4 -3
  28. tpu_inference/utils.py +5 -4
  29. tpu_inference/worker/tpu_worker.py +158 -22
  30. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/METADATA +2 -2
  31. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/RECORD +34 -39
  32. tpu_inference/mock/__init__.py +0 -0
  33. tpu_inference/mock/vllm_config_utils.py +0 -28
  34. tpu_inference/mock/vllm_envs.py +0 -1219
  35. tpu_inference/mock/vllm_logger.py +0 -212
  36. tpu_inference/mock/vllm_logging_utils.py +0 -15
  37. tpu_inference/models/jax/phi3.py +0 -376
  38. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/WHEEL +0 -0
  39. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/licenses/LICENSE +0 -0
  40. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/top_level.txt +0 -0
tests/lora/test_layers.py CHANGED
@@ -91,7 +91,6 @@ def populate_loras(
  index_to_id: list[Optional[int]],
  lora_layer: BaseLayerWithLoRA,
  baselayer_weights: torch.Tensor,
- generate_embeddings_tensor: int = 0,
  repeats: int = 1,
  ) -> tuple[dict[int, LoRALayerWeights], dict[int, list[LoRALayerWeights]]]:
  """This method populates the lora weights (lora_a and lora_b) in the lora layers (BaseLayerWithLoRA).
@@ -103,8 +102,6 @@ def populate_loras(
  lora_layer: the LoRAlayer to populate.
  baselayer_weights: the PyTorch tensor containing the layer's
  weights.
- generate_embeddings_tensor: whether to generate an
- embeddings tensor for each LoRA.
  repeats: must only be set for column parallel packed
  layers. Indicates the number of loras to compose
  together to create a single lora layer.
@@ -131,7 +128,6 @@ def populate_loras(
  baselayer_weights.device).init_random_lora(
  module_name=f"fake_{i}",
  weight=baselayer_weights,
- generate_embeddings_tensor=generate_embeddings_tensor,
  )
  sublora.lora_b = sublora.lora_b[(sublora_len *
  i):(sublora_len * (i + 1)), :]
@@ -147,7 +143,6 @@ def populate_loras(
  slot_idx,
  lora_a=lora.lora_a,
  lora_b=lora.lora_b,
- embeddings_tensor=lora.embeddings_tensor,
  )

  lora_dict[lora_id] = lora
@@ -546,7 +541,6 @@ def _update_punica_wrapper_metadata(punica_wrapper, index_mapping,
  index_to_id,
  lora_config.max_loras,
  vocab_size=512,
- extra_vocab_size=lora_config.lora_extra_vocab_size,
  )
  assert jax_view(punica_wrapper._lora_indices_per_batch).platform(
  ) == 'tpu', 'punica_wrapper._lora_indices_per_batch should have been moved to TPU.'
tests/lora/utils.py CHANGED
@@ -24,7 +24,6 @@ class DummyLoRAManager:
  module_name: str,
  weight: torch.Tensor,
  rank: int = 8,
- generate_embeddings_tensor: int = 0,
  ):
  lora = LoRALayerWeights(
  module_name,
@@ -37,13 +36,6 @@ class DummyLoRAManager:
  dtype=weight.dtype,
  device=self._device),
  )
- if generate_embeddings_tensor:
- lora.embeddings_tensor = torch.rand(
- 5,
- generate_embeddings_tensor,
- dtype=weight.dtype,
- device=self._device,
- )
  self.set_module_lora(module_name, lora)

  return lora
tpu_inference/__init__.py CHANGED
@@ -1,21 +1,40 @@
- import os
-
  # The environment variables override should be imported before any other
  # modules to ensure that the environment variables are set before any
  # other modules are imported.
  import tpu_inference.env_override # noqa: F401
+ from tpu_inference import envs
  from tpu_inference import tpu_info as ti
  from tpu_inference.logger import init_logger

  logger = init_logger(__name__)

- if "proxy" in os.environ.get('JAX_PLATFORMS', '').lower():
+ if "proxy" in envs.JAX_PLATFORMS:
  logger.info("Running vLLM on TPU via Pathways proxy.")
  # Must run pathwaysutils.initialize() before any JAX operations
  try:
+ import traceback
+
  import pathwaysutils
+ import vllm
+ from vllm.platforms import (resolve_current_platform_cls_qualname,
+ resolve_obj_by_qualname)
  pathwaysutils.initialize()
  logger.info("Module pathwaysutils is imported.")
+
+ # Pathways requires eager resolution of vllm.current_platform instead of
+ # lazy resolution in the normal code path. Since this part involves
+ # global topology discovery across multiple hosts, the platform
+ # resolution must happen before other components are loaded.
+ logger.info("Eagerly resolving vLLM current_platform for Pathways.")
+ platform_cls_qualname = resolve_current_platform_cls_qualname()
+ resolved_platform_instance = resolve_obj_by_qualname(
+ platform_cls_qualname)()
+ vllm.platforms._current_platform = resolved_platform_instance
+ vllm.platforms._init_trace = "".join(traceback.format_stack())
+ logger.info(
+ f"vLLM platform resolved to: {resolved_platform_instance.__class__.__name__}"
+ )
+
  except Exception as e:
  logger.error(
  f"Error occurred while importing pathwaysutils or logging TPU info: {e}"
tpu_inference/core/disagg_utils.py CHANGED
@@ -1,17 +1,15 @@
  # SPDX-License-Identifier: Apache-2.0

- import os
  from typing import Tuple

- PREFILL_SLICES = 'PREFILL_SLICES'
- DECODE_SLICES = 'DECODE_SLICES'
+ from tpu_inference import envs


  def is_disagg_enabled() -> bool:
  # We triggrer our code path as long as prefill slices are set. This
  # allows us to test interleave mode effectively with the code path
  # for comparison purposes.
- return PREFILL_SLICES in os.environ
+ return bool(envs.PREFILL_SLICES)


  def _parse_slices(slices_str: str) -> Tuple[int, ...]:
@@ -40,12 +38,12 @@ def _parse_slices(slices_str: str) -> Tuple[int, ...]:


  def get_prefill_slices() -> Tuple[int, ...]:
- if PREFILL_SLICES not in os.environ:
+ if not envs.PREFILL_SLICES:
  return ()
- return _parse_slices(os.environ[PREFILL_SLICES])
+ return _parse_slices(envs.PREFILL_SLICES)


  def get_decode_slices() -> Tuple[int, ...]:
- if DECODE_SLICES not in os.environ:
+ if not envs.DECODE_SLICES:
  return ()
- return _parse_slices(os.environ[DECODE_SLICES])
+ return _parse_slices(envs.DECODE_SLICES)
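With the module-level constants gone, disaggregated prefill/decode is toggled purely through the PREFILL_SLICES and DECODE_SLICES environment variables, now read via tpu_inference.envs. A minimal usage sketch follows; the slice-string format is whatever _parse_slices accepts and is not shown in this diff, so it is left unspecified here.

# Minimal usage sketch of the refactored helpers; assumes PREFILL_SLICES /
# DECODE_SLICES were exported before tpu_inference was imported.
from tpu_inference.core import disagg_utils

if disagg_utils.is_disagg_enabled():
    prefill_slices = disagg_utils.get_prefill_slices()  # Tuple[int, ...]
    decode_slices = disagg_utils.get_decode_slices()    # Tuple[int, ...]
    print(f"disaggregated serving: prefill={prefill_slices}, decode={decode_slices}")
else:
    print("interleaved (non-disaggregated) mode")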
tpu_inference/distributed/tpu_connector.py CHANGED
@@ -60,7 +60,6 @@ D workflow:

  import copy
  import functools
- import os
  import threading
  import time
  from concurrent.futures import Future, ThreadPoolExecutor
@@ -86,6 +85,7 @@ if TYPE_CHECKING:
  from vllm.v1.core.kv_cache_manager import KVCacheBlocks
  from vllm.v1.request import Request

+ from tpu_inference import envs
  from tpu_inference.distributed.utils import (get_host_ip, get_kv_ips,
  get_kv_ports,
  get_kv_transfer_port, get_node_id,
@@ -441,8 +441,7 @@ class TPUConnectorWorker:

  self.runner: TPUModelRunner = None
  self.mesh: Mesh = None
- self.multi_host = os.getenv("TPU_MULTIHOST_BACKEND",
- "").lower() == "ray"
+ self.multi_host = envs.TPU_MULTIHOST_BACKEND == "ray"
  # NOTE(xiang): This can not be the worker rank set in RayDistributedExecutor.
  # The worker rank is assigned with vLLM's sorting logic, which does not work
  # for TPU host topology.
tpu_inference/distributed/utils.py CHANGED
@@ -2,6 +2,7 @@ import os

  from vllm.utils.network_utils import get_ip

+ from tpu_inference import envs
  from tpu_inference.logger import init_logger

  logger = init_logger(__name__)
@@ -17,7 +18,7 @@ def set_node_kv_ip_port(ip_port: tuple[int, str, int]):


  def get_kv_ips() -> str:
- if os.getenv("TPU_MULTIHOST_BACKEND", "").lower() == "ray":
+ if envs.TPU_MULTIHOST_BACKEND == "ray":
  num_nodes = len(_NODES_KV_IP_PORT)
  ips = []
  for node_id in range(num_nodes):
@@ -28,7 +29,7 @@ def get_kv_ips() -> str:


  def get_kv_ports() -> str:
- if os.getenv("TPU_MULTIHOST_BACKEND", "").lower() == "ray":
+ if envs.TPU_MULTIHOST_BACKEND == "ray":
  num_nodes = len(_NODES_KV_IP_PORT)
  ports = []
  for node_id in range(num_nodes):
tpu_inference/envs.py CHANGED
@@ -26,7 +26,7 @@ if TYPE_CHECKING:
  environment_variables: dict[str, Callable[[], Any]] = {
  # JAX platform selection (e.g., "tpu", "cpu", "proxy")
  "JAX_PLATFORMS":
- lambda: os.getenv("JAX_PLATFORMS", ""),
+ lambda: os.getenv("JAX_PLATFORMS", "").lower(),
  # TPU accelerator type (e.g., "v5litepod-16", "v4-8")
  "TPU_ACCELERATOR_TYPE":
  lambda: os.getenv("TPU_ACCELERATOR_TYPE", None),
tpu_inference/executors/ray_distributed_executor.py CHANGED
@@ -108,6 +108,9 @@ class RayDistributedExecutor(RayDistributedExecutorV1):
  ip_port = self.collective_rpc("get_node_kv_ip_port")
  for item in ip_port:
  set_node_kv_ip_port(item)
+ self.uses_sampler = self.vllm_config.model_config.runner_type != "pooling" and (
+ self.vllm_config.ec_transfer_config is None
+ or not self.vllm_config.ec_transfer_config.is_ec_producer)

  def _initialize_ray_cluster(self) -> None:
  """Initialize the distributed cluster with Ray.
@@ -352,7 +355,7 @@ class RayDistributedExecutor(RayDistributedExecutorV1):
  self.collective_rpc("init_worker", args=(all_kwargs, ))
  self.collective_rpc("init_device")
  if self.parallel_config.pipeline_parallel_size > 1:
- self._run_workers("initialize_pp_transfer_connect")
+ self.collective_rpc("initialize_pp_transfer_connect")
  self.collective_rpc("load_model")

  if self.use_ray_spmd_worker:
tpu_inference/kernels/ragged_paged_attention/v3/kernel.py CHANGED
@@ -440,42 +440,54 @@ def _ragged_paged_attention_kernel(
  debug_print("[RPA debug] bkv_sz_frm_new={}", bkv_sz_frm_new)
  debug_print("[RPA debug] page_indices_offset={}", page_indices_offset)

- # Fetch effective kv from kv cache.
- def loop_body(i, offset):
- sz = jnp.minimum(page_size, kv_left_frm_cache - i * page_size)
- _async_copy(
- cache_hbm_ref.at[pl.ds(
- page_indices_ref[page_indices_offset + i] * page_size,
- sz)],
- vmem_ref.at[pl.ds(i * page_size, sz)],
- sem,
- wait,
+ if not wait:
+ # Fetch effective kv from kv cache.
+ def loop_body(i, offset):
+ sz = jnp.minimum(page_size, kv_left_frm_cache - i * page_size)
+ _async_copy(
+ cache_hbm_ref.at[pl.ds(
+ page_indices_ref[page_indices_offset + i] * page_size,
+ sz)],
+ vmem_ref.at[pl.ds(i * page_size, sz)],
+ sem,
+ wait=False,
+ )
+ debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
+ return offset + sz
+
+ offset = lax.fori_loop(
+ 0,
+ bkv_p_frm_cache,
+ loop_body,
+ 0, # offset
+ unroll=False,
  )
- debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
- return offset + sz
-
- offset = lax.fori_loop(
- 0,
- bkv_p_frm_cache,
- loop_body,
- 0, # offset
- unroll=False,
- )

- # Fetch kv directly from new kv.
- @pl.when(bkv_sz_frm_new > 0)
- def _fetch_bkv_from_new_kv():
- new_kv_len_start = q_end - kv_left_frm_new
- debug_print("[RPA debug] new_kv_len_start={}", new_kv_len_start)
- debug_print("[RPA debug] offset_in_bkv={}", offset)
+ # Fetch kv directly from new kv.
+ @pl.when(bkv_sz_frm_new > 0)
+ def _fetch_bkv_from_new_kv():
+ new_kv_len_start = q_end - kv_left_frm_new
+ debug_print("[RPA debug] new_kv_len_start={}",
+ new_kv_len_start)
+ debug_print("[RPA debug] offset_in_bkv={}", offset)
+ _async_copy(
+ kv_hbm_ref.at[pl.ds(new_kv_len_start, bkv_sz_frm_new)],
+ vmem_ref.at[pl.ds(offset, bkv_sz_frm_new)],
+ sem,
+ wait,
+ )
+
+ return kv_len_start + offset, bkv_sz_frm_new
+ else:
+ offset = jnp.minimum(kv_left_frm_cache, page_size * bkv_p)
+ dst = vmem_ref.at[pl.ds(0, offset + bkv_sz_frm_new)]
  _async_copy(
- kv_hbm_ref.at[pl.ds(new_kv_len_start, bkv_sz_frm_new)],
- vmem_ref.at[pl.ds(offset, bkv_sz_frm_new)],
- sem,
- wait,
+ src=dst,
+ dst=dst,
+ sem=sem,
+ wait=True,
  )
-
- return kv_len_start + offset, bkv_sz_frm_new
+ return kv_len_start + offset, bkv_sz_frm_new

  def _update_kv_cache(seq_idx,
  bkv_sem_idx,
@@ -511,30 +523,41 @@ def _ragged_paged_attention_kernel(
  debug_print("[RPA debug] p_ignore={}", p_ignore)
  debug_print("[RPA debug] page_indices_offset={}", page_indices_offset)

- def loop_body(i, states):
- update_sz, ignore = states
- sz = jnp.minimum(page_size - ignore, update_sz)
-
+ if not wait:
+
+ def loop_body(i, states):
+ update_sz, ignore = states
+ sz = jnp.minimum(page_size - ignore, update_sz)
+
+ _async_copy(
+ vmem_ref.at[pl.ds((p_ignore + i) * page_size + ignore,
+ sz)],
+ cache_hbm_ref.at[pl.ds(
+ page_indices_ref[page_indices_offset + i] * page_size +
+ ignore,
+ sz,
+ )],
+ sem,
+ wait=False,
+ )
+ debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
+ return update_sz - sz, 0
+
+ lax.fori_loop(
+ 0,
+ kv_p_end - kv_p_start,
+ loop_body,
+ (update_sz, ignore), # total transfer size
+ unroll=False,
+ )
+ else:
+ dst = cache_hbm_ref.at[pl.ds(0, update_sz)]
  _async_copy(
- vmem_ref.at[pl.ds((p_ignore + i) * page_size + ignore, sz)],
- cache_hbm_ref.at[pl.ds(
- page_indices_ref[page_indices_offset + i] * page_size +
- ignore,
- sz,
- )],
- sem,
- wait,
+ src=dst,
+ dst=dst,
+ sem=sem,
+ wait=True,
  )
- debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
- return update_sz - sz, 0
-
- lax.fori_loop(
- 0,
- kv_p_end - kv_p_start,
- loop_body,
- (update_sz, ignore), # total transfer size
- unroll=False,
- )

  def _fetch_bq(seq_idx, bq_idx, bq_sem_idx, *, wait=False):
  sem = sems.at[1, bq_sem_idx]
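The kernel refactor above splits each copy helper into an issue path (wait=False), which still starts one DMA per page, and a wait path (wait=True), which now blocks on a single descriptor covering the whole transferred extent instead of replaying the per-page loop. The standalone sketch below illustrates that issue/wait split using the standard jax.experimental.pallas TPU async-copy API; the helper name, refs, and sizes are illustrative placeholders rather than the kernel's actual arguments, and the byte-counting rationale is an interpretation of the diff, not text from it.

# Minimal sketch of the issue/wait split, assuming pltpu.make_async_copy semantics.
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu


def _async_copy(src, dst, sem, wait):
    """Start a DMA from src to dst, or wait on its semaphore for completion."""
    cp = pltpu.make_async_copy(src, dst, sem)
    if wait:
        cp.wait()
    else:
        cp.start()


def fetch_block(hbm_ref, vmem_ref, sem, size, *, wait=False):
    if not wait:
        # Issue path: kick off the copy without blocking.
        _async_copy(hbm_ref.at[pl.ds(0, size)],
                    vmem_ref.at[pl.ds(0, size)], sem, wait=False)
    else:
        # Wait path: the DMA semaphore tracks how much data has landed, so the
        # wait only needs a descriptor whose total extent matches what was
        # issued - here a single descriptor over the same slice (src == dst),
        # mirroring the `src=dst, dst=dst, wait=True` calls in the kernel above.
        dst = vmem_ref.at[pl.ds(0, size)]
        _async_copy(dst, dst, sem, wait=True)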
tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py CHANGED
@@ -475,42 +475,54 @@ def _ragged_paged_attention_kernel(
  debug_print("[RPA debug] bkv_sz_frm_new={}", bkv_sz_frm_new)
  debug_print("[RPA debug] page_indices_offset={}", page_indices_offset)

- # Fetch effective kv from kv cache.
- def loop_body(i, offset):
- sz = jnp.minimum(page_size, kv_left_frm_cache - i * page_size)
- _async_copy(
- cache_hbm_ref.at[pl.ds(
- page_indices_ref[page_indices_offset + i] * page_size,
- sz)],
- vmem_ref.at[pl.ds(i * page_size, sz)],
- sem,
- wait,
+ if not wait:
+ # Fetch effective kv from kv cache.
+ def loop_body(i, offset):
+ sz = jnp.minimum(page_size, kv_left_frm_cache - i * page_size)
+ _async_copy(
+ cache_hbm_ref.at[pl.ds(
+ page_indices_ref[page_indices_offset + i] * page_size,
+ sz)],
+ vmem_ref.at[pl.ds(i * page_size, sz)],
+ sem,
+ wait=False,
+ )
+ debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
+ return offset + sz
+
+ offset = lax.fori_loop(
+ 0,
+ bkv_p_frm_cache,
+ loop_body,
+ 0, # offset
+ unroll=False,
  )
- debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
- return offset + sz
-
- offset = lax.fori_loop(
- 0,
- bkv_p_frm_cache,
- loop_body,
- 0, # offset
- unroll=False,
- )

- # Fetch kv directly from new kv.
- @pl.when(bkv_sz_frm_new > 0)
- def _fetch_bkv_from_new_kv():
- new_kv_len_start = q_end - kv_left_frm_new
- debug_print("[RPA debug] new_kv_len_start={}", new_kv_len_start)
- debug_print("[RPA debug] offset_in_bkv={}", offset)
+ # Fetch kv directly from new kv.
+ @pl.when(bkv_sz_frm_new > 0)
+ def _fetch_bkv_from_new_kv():
+ new_kv_len_start = q_end - kv_left_frm_new
+ debug_print("[RPA debug] new_kv_len_start={}",
+ new_kv_len_start)
+ debug_print("[RPA debug] offset_in_bkv={}", offset)
+ _async_copy(
+ kv_hbm_ref.at[pl.ds(new_kv_len_start, bkv_sz_frm_new)],
+ vmem_ref.at[pl.ds(offset, bkv_sz_frm_new)],
+ sem,
+ wait,
+ )
+
+ return kv_len_start + offset, bkv_sz_frm_new
+ else:
+ offset = jnp.minimum(kv_left_frm_cache, page_size * bkv_p)
+ dst = vmem_ref.at[pl.ds(0, offset + bkv_sz_frm_new)]
  _async_copy(
- kv_hbm_ref.at[pl.ds(new_kv_len_start, bkv_sz_frm_new)],
- vmem_ref.at[pl.ds(offset, bkv_sz_frm_new)],
- sem,
- wait,
+ src=dst,
+ dst=dst,
+ sem=sem,
+ wait=True,
  )
-
- return kv_len_start + offset, bkv_sz_frm_new
+ return kv_len_start + offset, bkv_sz_frm_new

  def _update_kv_cache(seq_idx,
  bkv_sem_idx,
@@ -546,30 +558,41 @@ def _ragged_paged_attention_kernel(
  debug_print("[RPA debug] p_ignore={}", p_ignore)
  debug_print("[RPA debug] page_indices_offset={}", page_indices_offset)

- def loop_body(i, states):
- update_sz, ignore = states
- sz = jnp.minimum(page_size - ignore, update_sz)
-
+ if not wait:
+
+ def loop_body(i, states):
+ update_sz, ignore = states
+ sz = jnp.minimum(page_size - ignore, update_sz)
+
+ _async_copy(
+ vmem_ref.at[pl.ds((p_ignore + i) * page_size + ignore,
+ sz)],
+ cache_hbm_ref.at[pl.ds(
+ page_indices_ref[page_indices_offset + i] * page_size +
+ ignore,
+ sz,
+ )],
+ sem,
+ wait=False,
+ )
+ debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
+ return update_sz - sz, 0
+
+ lax.fori_loop(
+ 0,
+ kv_p_end - kv_p_start,
+ loop_body,
+ (update_sz, ignore), # total transfer size
+ unroll=False,
+ )
+ else:
+ dst = cache_hbm_ref.at[pl.ds(0, update_sz)]
  _async_copy(
- vmem_ref.at[pl.ds((p_ignore + i) * page_size + ignore, sz)],
- cache_hbm_ref.at[pl.ds(
- page_indices_ref[page_indices_offset + i] * page_size +
- ignore,
- sz,
- )],
- sem,
- wait,
+ src=dst,
+ dst=dst,
+ sem=sem,
+ wait=True,
  )
- debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
- return update_sz - sz, 0
-
- lax.fori_loop(
- 0,
- kv_p_end - kv_p_start,
- loop_body,
- (update_sz, ignore), # total transfer size
- unroll=False,
- )

  def _fetch_bq(seq_idx, bq_idx, bq_sem_idx, *, wait=False):
  sem = sems.at[1, bq_sem_idx]
tpu_inference/layers/vllm/sharding.py CHANGED
@@ -19,6 +19,7 @@ from vllm.lora.layers.base_linear import BaseLinearLayerWithLoRA
  from vllm.model_executor.layers.vocab_parallel_embedding import (
  ParallelLMHead, VocabParallelEmbedding)

+ from tpu_inference import envs
  from tpu_inference.logger import init_logger

  P = PartitionSpec
@@ -211,8 +212,7 @@ def _shard_module_to_tpu(model: torch.nn.Module, mesh: Mesh) -> None:
  def _sharded_device_put(tensor: jax.Array, sharding) -> jax.Array:
  if isinstance(tensor, tuple):
  return tuple(_sharded_device_put(t, sharding) for t in tensor)
- import os
- multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower()
+ multihost_backend = envs.TPU_MULTIHOST_BACKEND
  if multihost_backend != "ray":
  return jax.device_put(tensor, sharding)

tpu_inference/lora/torch_punica_tpu.py CHANGED
@@ -239,7 +239,6 @@ class PunicaWrapperTPU(PunicaWrapperBase):
  lora_index_to_id: list[Optional[int]],
  max_loras: int,
  vocab_size: int,
- extra_vocab_size: int,
  ):
  # Pad the prompt mapping to avoid running into recompiles on the TPU
  # TODO: Should this happen inside mapping internally? If so how can we
@@ -258,7 +257,7 @@ class PunicaWrapperTPU(PunicaWrapperBase):
  lora_index_to_id,
  max_loras,
  vocab_size,
- extra_vocab_size,
+ 0, # extra_vocab_size
  "cpu",
  )
  with torchax.default_env():
tpu_inference/models/common/model_loader.py CHANGED
@@ -36,19 +36,17 @@ def _get_model_architecture(config: PretrainedConfig) -> nnx.Module:
  from tpu_inference.models.jax.llama3 import LlamaForCausalLM
  from tpu_inference.models.jax.llama4 import Llama4ForCausalLM
  from tpu_inference.models.jax.llama_eagle3 import EagleLlama3ForCausalLM
- from tpu_inference.models.jax.phi3 import Phi3ForCausalLM
- from tpu_inference.models.jax.qwen2 import Qwen2ForCausalLM
+ from tpu_inference.models.jax.llama_guard_4 import LlamaGuard4ForCausalLM
  from tpu_inference.models.jax.qwen2_5_vl import \
  Qwen2_5_VLForConditionalGeneration
  from tpu_inference.models.jax.qwen3 import Qwen3ForCausalLM
  _MODEL_REGISTRY["Llama4ForCausalLM"] = Llama4ForCausalLM
  _MODEL_REGISTRY["DeepseekV3ForCausalLM"] = DeepSeekV3
  _MODEL_REGISTRY["LlamaForCausalLM"] = LlamaForCausalLM
- _MODEL_REGISTRY["Qwen2ForCausalLM"] = Qwen2ForCausalLM
+ _MODEL_REGISTRY["Llama4ForConditionalGeneration"] = LlamaGuard4ForCausalLM
  _MODEL_REGISTRY["Qwen3ForCausalLM"] = Qwen3ForCausalLM
  _MODEL_REGISTRY[
  "Qwen2_5_VLForConditionalGeneration"] = Qwen2_5_VLForConditionalGeneration
- _MODEL_REGISTRY["Phi3ForCausalLM"] = Phi3ForCausalLM
  _MODEL_REGISTRY["Eagle3LlamaForCausalLM"] = EagleLlama3ForCausalLM
  _MODEL_REGISTRY["GptOssForCausalLM"] = GptOss

@@ -57,8 +55,10 @@ def _get_model_architecture(config: PretrainedConfig) -> nnx.Module:
  if arch in _MODEL_REGISTRY:
  return _MODEL_REGISTRY[arch]
  raise UnsupportedArchitectureError(
- f"Model architectures {architectures} are not supported for now. "
- f"Supported architectures: {list(_MODEL_REGISTRY.keys())}")
+ f"Model architectures {architectures} not "
+ "registered in tpu-inference. Falling back to vLLM-native "
+ f"Pytorch definition. JAX-native architectures: {list(_MODEL_REGISTRY.keys())}"
+ )


  def _get_nnx_model(
@@ -217,7 +217,7 @@ def get_flax_model(
  hidden_states_sharding, # aux hidden states
  ),
  donate_argnums=2, # 0 is graphdef, 1 is state, 2 is kv_cache
- static_argnums=6, #6 is layer_name_to_kvcache_index
+ static_argnums=7, #7 is layer_name_to_kvcache_index
  )
  def run_model(graphdef, state, *args):
  model = nnx.merge(graphdef, state)
@@ -326,8 +326,8 @@ def get_model(
  # Convert the error message to a string to check its contents
  error_msg = str(e)

- logger.warning(f"Flax model failed with: '{error_msg}'. "
- "Falling back to vLLM implementation.")
+ logger.warning(error_msg)
+
  # Fall back to the vLLM model and updating the dtype accordingly
  vllm_config.model_config.dtype = j2t_dtype(
  vllm_config.model_config.dtype.dtype)
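The hunks above change the registry miss from a hard "not supported" error into a message that announces the fallback, which get_model then simply logs before switching to the vLLM-native PyTorch definition. The self-contained sketch below condenses that control flow; the registry contents, loader return values, and function name are illustrative placeholders, and only the exception-driven fallback and the new warning behaviour come from the diff.

# Condensed sketch of the lookup-then-fallback flow in model_loader.py.
import logging

logger = logging.getLogger("model_loader_sketch")


class UnsupportedArchitectureError(RuntimeError):
    pass


_MODEL_REGISTRY: dict[str, object] = {}  # populated lazily with JAX-native models


def load_model_sketch(arch: str):
    try:
        if arch not in _MODEL_REGISTRY:
            raise UnsupportedArchitectureError(
                f"Model architectures ['{arch}'] not registered in tpu-inference. "
                "Falling back to vLLM-native Pytorch definition. "
                f"JAX-native architectures: {list(_MODEL_REGISTRY.keys())}")
        return ("jax", _MODEL_REGISTRY[arch])
    except UnsupportedArchitectureError as e:
        # New behaviour: the message already explains the fallback, so the
        # caller just logs it and takes the PyTorch path.
        logger.warning(str(e))
        return ("torch-fallback", arch)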
tpu_inference/models/jax/llama3.py CHANGED
@@ -368,7 +368,8 @@ class LlamaForCausalLM(nnx.Module):
  "lm_head": "model.lm_head",
  })

- metadata_map = get_default_maps(self.vllm_config, self.mesh, mappings)
+ metadata_map = get_default_maps(self.vllm_config.model_config,
+ self.mesh, mappings)
  load_hf_weights(vllm_config=self.vllm_config,
  model=self,
  metadata_map=metadata_map,