tpu-inference 0.11.1.dev202511150811__py3-none-any.whl → 0.11.1.dev202512030818__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of tpu-inference might be problematic.
Files changed (54)
  1. tests/kernels/fused_moe_v1_test.py +303 -34
  2. tests/lora/test_layers.py +0 -6
  3. tests/lora/utils.py +0 -8
  4. tests/test_envs.py +32 -11
  5. tests/test_utils.py +1 -2
  6. tpu_inference/__init__.py +22 -3
  7. tpu_inference/core/disagg_utils.py +6 -8
  8. tpu_inference/distributed/tpu_connector.py +3 -4
  9. tpu_inference/distributed/utils.py +3 -2
  10. tpu_inference/envs.py +61 -8
  11. tpu_inference/executors/ray_distributed_executor.py +31 -11
  12. tpu_inference/kernels/fused_moe/v1/kernel.py +641 -110
  13. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +77 -54
  14. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +213 -126
  15. tpu_inference/layers/common/attention_interface.py +7 -1
  16. tpu_inference/layers/common/sharding.py +5 -5
  17. tpu_inference/layers/vllm/fused_moe.py +74 -25
  18. tpu_inference/layers/vllm/quantization/common.py +6 -1
  19. tpu_inference/layers/vllm/quantization/mxfp4.py +137 -62
  20. tpu_inference/layers/vllm/quantization/unquantized.py +107 -113
  21. tpu_inference/layers/vllm/sharding.py +2 -2
  22. tpu_inference/lora/torch_punica_tpu.py +1 -2
  23. tpu_inference/models/common/model_loader.py +45 -11
  24. tpu_inference/models/jax/llama3.py +2 -1
  25. tpu_inference/models/jax/llama_eagle3.py +8 -5
  26. tpu_inference/models/jax/llama_guard_4.py +361 -0
  27. tpu_inference/models/jax/qwen2.py +2 -1
  28. tpu_inference/models/jax/qwen2_5_vl.py +163 -48
  29. tpu_inference/models/jax/qwen3.py +2 -1
  30. tpu_inference/models/jax/utils/quantization/quantization_utils.py +3 -6
  31. tpu_inference/models/jax/utils/weight_utils.py +198 -143
  32. tpu_inference/models/vllm/vllm_model_wrapper.py +14 -7
  33. tpu_inference/platforms/tpu_platform.py +28 -22
  34. tpu_inference/runner/compilation_manager.py +144 -59
  35. tpu_inference/runner/kv_cache_manager.py +17 -18
  36. tpu_inference/runner/persistent_batch_manager.py +40 -2
  37. tpu_inference/runner/structured_decoding_manager.py +2 -3
  38. tpu_inference/runner/tpu_runner.py +271 -147
  39. tpu_inference/runner/utils.py +2 -2
  40. tpu_inference/spec_decode/jax/eagle3.py +71 -21
  41. tpu_inference/tpu_info.py +4 -3
  42. tpu_inference/utils.py +36 -13
  43. tpu_inference/worker/tpu_worker.py +162 -25
  44. {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/METADATA +3 -2
  45. {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/RECORD +48 -53
  46. tpu_inference/mock/__init__.py +0 -0
  47. tpu_inference/mock/vllm_config_utils.py +0 -28
  48. tpu_inference/mock/vllm_envs.py +0 -1219
  49. tpu_inference/mock/vllm_logger.py +0 -212
  50. tpu_inference/mock/vllm_logging_utils.py +0 -15
  51. tpu_inference/models/jax/phi3.py +0 -376
  52. {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/WHEEL +0 -0
  53. {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/licenses/LICENSE +0 -0
  54. {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/top_level.txt +0 -0
tpu_inference/distributed/utils.py CHANGED
@@ -2,6 +2,7 @@ import os
 
 from vllm.utils.network_utils import get_ip
 
+from tpu_inference import envs
 from tpu_inference.logger import init_logger
 
 logger = init_logger(__name__)
@@ -17,7 +18,7 @@ def set_node_kv_ip_port(ip_port: tuple[int, str, int]):
 
 
 def get_kv_ips() -> str:
-    if os.getenv("TPU_MULTIHOST_BACKEND", "").lower() == "ray":
+    if envs.TPU_MULTIHOST_BACKEND == "ray":
         num_nodes = len(_NODES_KV_IP_PORT)
         ips = []
         for node_id in range(num_nodes):
@@ -28,7 +29,7 @@ def get_kv_ips() -> str:
 
 
 def get_kv_ports() -> str:
-    if os.getenv("TPU_MULTIHOST_BACKEND", "").lower() == "ray":
+    if envs.TPU_MULTIHOST_BACKEND == "ray":
         num_nodes = len(_NODES_KV_IP_PORT)
         ports = []
         for node_id in range(num_nodes):
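
Note: the change above replaces direct os.getenv("TPU_MULTIHOST_BACKEND", "").lower() calls with the centralized envs.TPU_MULTIHOST_BACKEND accessor, so parsing and validation live in tpu_inference/envs.py. The sketch below shows how an envs module built around an environment_variables table of callables is commonly exposed through a PEP 562 module-level __getattr__ (vLLM's envs.py uses this pattern); whether tpu_inference.envs wires it exactly this way is an assumption, not something visible in this diff.

# Sketch only (assumed wiring): attribute access such as envs.JAX_PLATFORMS or
# envs.TPU_MULTIHOST_BACKEND re-reads the environment through the table entry.
import os
from typing import Any, Callable

environment_variables: dict[str, Callable[[], Any]] = {
    "JAX_PLATFORMS":
    lambda: os.getenv("JAX_PLATFORMS", "").lower(),
}


def __getattr__(name: str) -> Any:
    # PEP 562: called for names not found in the module namespace.
    if name in environment_variables:
        return environment_variables[name]()
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
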
tpu_inference/envs.py CHANGED
@@ -15,18 +15,64 @@ if TYPE_CHECKING:
     PREFILL_SLICES: str = ""
     DECODE_SLICES: str = ""
     SKIP_JAX_PRECOMPILE: bool = False
+    VLLM_XLA_CHECK_RECOMPILATION: bool = False
     MODEL_IMPL_TYPE: str = "flax_nnx"
     NEW_MODEL_DESIGN: bool = False
     PHASED_PROFILING_DIR: str = ""
     PYTHON_TRACER_LEVEL: int = 1
     USE_MOE_EP_KERNEL: bool = False
+    NUM_SLICES: int = 1
     RAY_USAGE_STATS_ENABLED: str = "0"
     VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: str = "shm"
 
+
+def env_with_choices(
+    env_name: str,
+    default: str | None,
+    choices: list[str] | Callable[[], list[str]],
+    case_sensitive: bool = True,
+) -> Callable[[], str | None]:
+    """
+    Create a lambda that validates environment variable against allowed choices
+
+    Args:
+        env_name: Name of the environment variable
+        default: Default value if not set (can be None)
+        choices: List of valid string options or callable that returns list
+        case_sensitive: Whether validation should be case sensitive
+
+    Returns:
+        Lambda function for environment_variables dict
+    """
+
+    def _get_validated_env() -> str | None:
+        value = os.getenv(env_name)
+        if value is None:
+            return default
+
+        # Resolve choices if it's a callable (for lazy loading)
+        actual_choices = choices() if callable(choices) else choices
+
+        if not case_sensitive:
+            check_value = value.lower()
+            check_choices = [choice.lower() for choice in actual_choices]
+        else:
+            check_value = value
+            check_choices = actual_choices
+
+        if check_value not in check_choices:
+            raise ValueError(f"Invalid value '{value}' for {env_name}. "
+                             f"Valid options: {actual_choices}.")
+
+        return value
+
+    return _get_validated_env
+
+
 environment_variables: dict[str, Callable[[], Any]] = {
     # JAX platform selection (e.g., "tpu", "cpu", "proxy")
     "JAX_PLATFORMS":
-    lambda: os.getenv("JAX_PLATFORMS", ""),
+    lambda: os.getenv("JAX_PLATFORMS", "").lower(),
     # TPU accelerator type (e.g., "v5litepod-16", "v4-8")
     "TPU_ACCELERATOR_TYPE":
     lambda: os.getenv("TPU_ACCELERATOR_TYPE", None),
@@ -38,7 +84,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
     lambda: os.getenv("TPU_WORKER_ID", None),
     # Backend for multi-host communication on TPU
     "TPU_MULTIHOST_BACKEND":
-    lambda: os.getenv("TPU_MULTIHOST_BACKEND", "").lower(),
+    env_with_choices("TPU_MULTIHOST_BACKEND", "", ["ray"]),
     # Slice configuration for disaggregated prefill workers
     "PREFILL_SLICES":
     lambda: os.getenv("PREFILL_SLICES", ""),
@@ -47,28 +93,35 @@ environment_variables: dict[str, Callable[[], Any]] = {
     lambda: os.getenv("DECODE_SLICES", ""),
     # Skip JAX precompilation step during initialization
     "SKIP_JAX_PRECOMPILE":
-    lambda: bool(int(os.getenv("SKIP_JAX_PRECOMPILE", "0"))),
+    lambda: bool(int(os.getenv("SKIP_JAX_PRECOMPILE") or "0")),
+    # Check for XLA recompilation during execution
+    "VLLM_XLA_CHECK_RECOMPILATION":
+    lambda: bool(int(os.getenv("VLLM_XLA_CHECK_RECOMPILATION") or "0")),
     # Model implementation type (e.g., "flax_nnx")
     "MODEL_IMPL_TYPE":
-    lambda: os.getenv("MODEL_IMPL_TYPE", "flax_nnx").lower(),
+    env_with_choices("MODEL_IMPL_TYPE", "flax_nnx",
+                     ["vllm", "flax_nnx", "jetpack"]),
     # Enable new experimental model design
     "NEW_MODEL_DESIGN":
-    lambda: bool(int(os.getenv("NEW_MODEL_DESIGN", "0"))),
+    lambda: bool(int(os.getenv("NEW_MODEL_DESIGN") or "0")),
     # Directory to store phased profiling output
     "PHASED_PROFILING_DIR":
     lambda: os.getenv("PHASED_PROFILING_DIR", ""),
     # Python tracer level for profiling
     "PYTHON_TRACER_LEVEL":
-    lambda: int(os.getenv("PYTHON_TRACER_LEVEL", "1")),
+    lambda: int(os.getenv("PYTHON_TRACER_LEVEL") or "1"),
     # Use custom expert-parallel kernel for MoE (Mixture of Experts)
     "USE_MOE_EP_KERNEL":
-    lambda: bool(int(os.getenv("USE_MOE_EP_KERNEL", "0"))),
+    lambda: bool(int(os.getenv("USE_MOE_EP_KERNEL") or "0")),
+    # Number of TPU slices for multi-slice mesh
+    "NUM_SLICES":
+    lambda: int(os.getenv("NUM_SLICES") or "1"),
     # Enable/disable Ray usage statistics collection
     "RAY_USAGE_STATS_ENABLED":
     lambda: os.getenv("RAY_USAGE_STATS_ENABLED", "0"),
     # Ray compiled DAG channel type for TPU
     "VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE":
-    lambda: os.getenv("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE", "shm"),
+    env_with_choices("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE", "shm", ["shm"]),
 }
 
 
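
Note: two behavioral details in the new envs.py. Defaults for numeric flags now use os.getenv(NAME) or "0" instead of os.getenv(NAME, "0"), so an empty-string value (e.g. SKIP_JAX_PRECOMPILE=) falls back to the default rather than failing in int(""). And env_with_choices rejects unknown values at lookup time instead of silently passing them through. The snippet below is a standalone usage sketch of the helper defined above, not code from the package:

# Illustrates env_with_choices validation; the values here are only for demonstration.
import os

from tpu_inference.envs import env_with_choices

getter = env_with_choices("TPU_MULTIHOST_BACKEND", default="", choices=["ray"])

os.environ["TPU_MULTIHOST_BACKEND"] = "ray"
print(getter())          # ray

os.environ["TPU_MULTIHOST_BACKEND"] = "grpc"
try:
    getter()
except ValueError as err:
    print(err)           # Invalid value 'grpc' for TPU_MULTIHOST_BACKEND. Valid options: ['ray'].

del os.environ["TPU_MULTIHOST_BACKEND"]
print(repr(getter()))    # '' (the default when the variable is unset)
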
tpu_inference/executors/ray_distributed_executor.py CHANGED
@@ -108,6 +108,9 @@ class RayDistributedExecutor(RayDistributedExecutorV1):
         ip_port = self.collective_rpc("get_node_kv_ip_port")
         for item in ip_port:
             set_node_kv_ip_port(item)
+        self.uses_sampler = self.vllm_config.model_config.runner_type != "pooling" and (
+            self.vllm_config.ec_transfer_config is None
+            or not self.vllm_config.ec_transfer_config.is_ec_producer)
 
     def _initialize_ray_cluster(self) -> None:
         """Initialize the distributed cluster with Ray.
@@ -131,10 +134,21 @@ class RayDistributedExecutor(RayDistributedExecutorV1):
                 f"current platform {current_platform.device_name} does not "
                 "support ray.")
 
-        placement_group_specs: List[Dict[str, float]] = [{
-            device_str:
-            node['Resources'][device_str]
-        } for node in ray.nodes()]
+        pp_size = self.parallel_config.pipeline_parallel_size
+        placement_group_specs: List[Dict[str, float]] = []
+
+        ray_nodes = ray.nodes()
+        logger.info(f"RayDistributedExecutor | ray_nodes={ray_nodes}")
+
+        if pp_size == 1:
+            placement_group_specs = [{
+                device_str: node['Resources'][device_str]
+            } for node in ray_nodes]
+        else:
+            num_devices_per_pp_rank = self.vllm_config.sharding_config.total_devices
+            placement_group_specs = [{
+                device_str: num_devices_per_pp_rank
+            } for _ in range(pp_size)]
 
         # vLLM engine is also a worker to execute model with an accelerator,
         # so it requires to have the device in a current node. Check if
@@ -329,6 +343,8 @@ class RayDistributedExecutor(RayDistributedExecutorV1):
         all_kwargs = []
         for rank, (node_id, _) in enumerate(worker_node_and_tpu_ids):
             local_rank = node_workers[node_id].index(rank)
+            ip = sorted_worker_metadata[rank].ip
+            prev_ip = sorted_worker_metadata[rank - 1].ip if rank > 0 else ""
             kwargs = dict(
                 vllm_config=self.vllm_config,
                 local_rank=local_rank,
@@ -336,22 +352,26 @@ class RayDistributedExecutor(RayDistributedExecutorV1):
                 distributed_init_method=distributed_init_method,
                 is_driver_worker=(not self.parallel_config)
                 or (rank % self.parallel_config.tensor_parallel_size == 0),
+                ip=ip,
+                prev_worker_ip=prev_ip,
             )
             all_kwargs.append(kwargs)
         self.collective_rpc("init_worker", args=(all_kwargs, ))
         self.collective_rpc("init_device")
+        if self.parallel_config.pipeline_parallel_size > 1:
+            self.collective_rpc("initialize_pp_transfer_connect")
         self.collective_rpc("load_model")
 
         if self.use_ray_spmd_worker:
             for pp_rank in range(self.parallel_config.pipeline_parallel_size):
                 self.pp_tp_workers.append([])
-                for tp_rank in range(
-                        int(self.parallel_config.tensor_parallel_size //
-                            num_tpu_per_worker)):
-                    # PP=2, TP=4
-                    # pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]]
-                    rank = (pp_rank * self.parallel_config.tensor_parallel_size
-                            ) + tp_rank
+                num_tp_workers = int(
+                    self.parallel_config.tensor_parallel_size //
+                    num_tpu_per_worker)
+                for tp_rank in range(num_tp_workers):
+                    # PP=2, TP=4, num_tpu_per_worker=2
+                    # pp_tp_workers = [[0, 1], [2, 3]]
+                    rank = (pp_rank * num_tp_workers) + tp_rank
                     assert len(self.pp_tp_workers[pp_rank]) == tp_rank
                     assert pp_rank < len(self.pp_tp_workers)
                     self.pp_tp_workers[pp_rank].append(self.workers[rank])
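
Note: the pp_tp_workers fix changes the stride used to map (pp_rank, tp_rank) to a flat Ray worker rank. When each Ray worker owns several TPU chips, there are only tensor_parallel_size // num_tpu_per_worker workers per pipeline stage, so striding by tensor_parallel_size overshoots the worker list. The check below is plain arithmetic using the values from the updated comment (PP=2, TP=4, num_tpu_per_worker=2); it is illustrative and independent of the executor class.

# Reproduce the worker-rank layout from the updated comment.
pp_size, tp_size, num_tpu_per_worker = 2, 4, 2

num_tp_workers = tp_size // num_tpu_per_worker  # 2 Ray workers per PP stage

new_layout = [[pp_rank * num_tp_workers + tp_rank
               for tp_rank in range(num_tp_workers)]
              for pp_rank in range(pp_size)]
old_layout = [[pp_rank * tp_size + tp_rank
               for tp_rank in range(num_tp_workers)]
              for pp_rank in range(pp_size)]

print(new_layout)  # [[0, 1], [2, 3]] -> matches the 4 Ray workers that exist
print(old_layout)  # [[0, 1], [4, 5]] -> ranks 4 and 5 would index past self.workers
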