tpu-inference 0.0.1rc1__py3-none-any.whl → 0.11.1.dev202511130813__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (67)
  1. tests/kernels/fused_moe_v1_test.py +34 -303
  2. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +2 -2
  3. tests/lora/test_layers.py +6 -0
  4. tests/lora/utils.py +8 -0
  5. tests/test_utils.py +16 -24
  6. tpu_inference/__init__.py +3 -22
  7. tpu_inference/core/core_tpu.py +9 -17
  8. tpu_inference/core/disagg_utils.py +8 -6
  9. tpu_inference/distributed/tpu_connector.py +4 -3
  10. tpu_inference/distributed/utils.py +2 -3
  11. tpu_inference/envs.py +8 -61
  12. tpu_inference/executors/ray_distributed_executor.py +11 -31
  13. tpu_inference/kernels/fused_moe/v1/kernel.py +110 -641
  14. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +54 -77
  15. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +143 -287
  16. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +0 -7
  17. tpu_inference/layers/jax/attention/attention.py +1 -1
  18. tpu_inference/layers/{common → jax}/attention_interface.py +2 -8
  19. tpu_inference/layers/jax/sample/rejection_sampler.py +1 -1
  20. tpu_inference/layers/jax/sample/sampling.py +2 -2
  21. tpu_inference/layers/{common → jax}/sharding.py +5 -5
  22. tpu_inference/layers/vllm/attention.py +1 -1
  23. tpu_inference/layers/vllm/fused_moe.py +208 -170
  24. tpu_inference/layers/vllm/quantization/__init__.py +3 -7
  25. tpu_inference/layers/vllm/quantization/awq.py +3 -4
  26. tpu_inference/layers/vllm/quantization/common.py +1 -6
  27. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +2 -4
  28. tpu_inference/layers/vllm/quantization/unquantized.py +67 -62
  29. tpu_inference/layers/vllm/sharding.py +2 -2
  30. tpu_inference/lora/torch_punica_tpu.py +2 -1
  31. tpu_inference/mock/__init__.py +0 -0
  32. tpu_inference/mock/vllm_config_utils.py +28 -0
  33. tpu_inference/mock/vllm_envs.py +1219 -0
  34. tpu_inference/mock/vllm_logger.py +212 -0
  35. tpu_inference/mock/vllm_logging_utils.py +15 -0
  36. tpu_inference/models/common/model_loader.py +12 -46
  37. tpu_inference/models/jax/llama3.py +3 -4
  38. tpu_inference/models/jax/llama_eagle3.py +5 -8
  39. tpu_inference/models/jax/phi3.py +376 -0
  40. tpu_inference/models/jax/qwen2.py +2 -3
  41. tpu_inference/models/jax/qwen2_5_vl.py +50 -165
  42. tpu_inference/models/jax/qwen3.py +2 -3
  43. tpu_inference/models/jax/utils/quantization/quantization_utils.py +6 -3
  44. tpu_inference/models/jax/utils/weight_utils.py +143 -198
  45. tpu_inference/models/vllm/vllm_model_wrapper.py +14 -32
  46. tpu_inference/platforms/tpu_platform.py +34 -47
  47. tpu_inference/runner/compilation_manager.py +60 -145
  48. tpu_inference/runner/kv_cache.py +2 -2
  49. tpu_inference/runner/kv_cache_manager.py +18 -17
  50. tpu_inference/runner/persistent_batch_manager.py +2 -40
  51. tpu_inference/runner/structured_decoding_manager.py +3 -2
  52. tpu_inference/runner/tpu_runner.py +135 -283
  53. tpu_inference/runner/utils.py +2 -2
  54. tpu_inference/spec_decode/jax/eagle3.py +21 -71
  55. tpu_inference/tpu_info.py +3 -4
  56. tpu_inference/utils.py +15 -38
  57. tpu_inference/worker/tpu_worker.py +26 -163
  58. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/METADATA +3 -4
  59. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/RECORD +63 -61
  60. tests/test_envs.py +0 -203
  61. tpu_inference/layers/common/quant_methods.py +0 -8
  62. tpu_inference/layers/vllm/quantization/mxfp4.py +0 -331
  63. tpu_inference/models/jax/llama_guard_4.py +0 -361
  64. /tpu_inference/layers/{common → jax}/binary_search.py +0 -0
  65. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/WHEEL +0 -0
  66. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/licenses/LICENSE +0 -0
  67. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/top_level.txt +0 -0
tpu_inference/distributed/tpu_connector.py CHANGED
@@ -60,6 +60,7 @@
 
 import copy
 import functools
+import os
 import threading
 import time
 from concurrent.futures import Future, ThreadPoolExecutor
@@ -85,7 +86,6 @@ if TYPE_CHECKING:
     from vllm.v1.core.kv_cache_manager import KVCacheBlocks
     from vllm.v1.request import Request
 
-from tpu_inference import envs
 from tpu_inference.distributed.utils import (get_host_ip, get_kv_ips,
                                              get_kv_ports,
                                              get_kv_transfer_port, get_node_id,
@@ -441,7 +441,8 @@ class TPUConnectorWorker:
 
         self.runner: TPUModelRunner = None
         self.mesh: Mesh = None
-        self.multi_host = envs.TPU_MULTIHOST_BACKEND == "ray"
+        self.multi_host = os.getenv("TPU_MULTIHOST_BACKEND",
+                                    "").lower() == "ray"
         # NOTE(xiang): This can not be the worker rank set in RayDistributedExecutor.
         # The worker rank is assigned with vLLM's sorting logic, which does not work
         # for TPU host topology.
@@ -457,6 +458,7 @@ class TPUConnectorWorker:
         self.side_channel_port = get_side_channel_port()
 
         self.kv_transfer_server = None
+        self._maybe_start_p2p_server()
         self.zmq_cxt = zmq.Context()
         if self.is_producer:
             ready_event = threading.Event()
@@ -498,7 +500,6 @@ class TPUConnectorWorker:
         self.shape = list(kv_layer.shape)
         self.dtype = kv_layer.dtype
         self.sharding = kv_layer.sharding
-        self._maybe_start_p2p_server()
 
     def _maybe_start_p2p_server(self):
         if self.kv_transfer_server is not None:
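The two hunks above move the _maybe_start_p2p_server() call from the KV-layer registration path into the worker's constructor, immediately after self.kv_transfer_server is initialized to None. The method opens with a check on self.kv_transfer_server, which presumably makes it idempotent, so starting it eagerly is safe even if the method runs again later. A minimal sketch of that pattern (start_server is a hypothetical stand-in for the startup logic the diff does not show):

class KVTransferWorker:
    def __init__(self):
        self.kv_transfer_server = None
        self._maybe_start_p2p_server()  # now started eagerly during __init__

    def _maybe_start_p2p_server(self):
        if self.kv_transfer_server is not None:
            return  # already running; repeated calls are no-ops
        self.kv_transfer_server = start_server()  # hypothetical startup logic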
tpu_inference/distributed/utils.py CHANGED
@@ -2,7 +2,6 @@ import os
 
 from vllm.utils.network_utils import get_ip
 
-from tpu_inference import envs
 from tpu_inference.logger import init_logger
 
 logger = init_logger(__name__)
@@ -18,7 +17,7 @@ def set_node_kv_ip_port(ip_port: tuple[int, str, int]):
 
 
 def get_kv_ips() -> str:
-    if envs.TPU_MULTIHOST_BACKEND == "ray":
+    if os.getenv("TPU_MULTIHOST_BACKEND", "").lower() == "ray":
         num_nodes = len(_NODES_KV_IP_PORT)
         ips = []
         for node_id in range(num_nodes):
@@ -29,7 +28,7 @@ def get_kv_ips() -> str:
 
 
 def get_kv_ports() -> str:
-    if envs.TPU_MULTIHOST_BACKEND == "ray":
+    if os.getenv("TPU_MULTIHOST_BACKEND", "").lower() == "ray":
         num_nodes = len(_NODES_KV_IP_PORT)
         ports = []
         for node_id in range(num_nodes):
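Both files above replace the tpu_inference.envs indirection with a direct os.getenv read, inlining the same case-insensitive handling that the updated envs.py accessor performs. A minimal sketch of the resulting check, with an illustrative value:

import os

# Matches "ray" case-insensitively; an unset variable falls back to ""
# and leaves the multi-host path disabled.
os.environ["TPU_MULTIHOST_BACKEND"] = "Ray"  # illustrative value
multi_host = os.getenv("TPU_MULTIHOST_BACKEND", "").lower() == "ray"
assert multi_host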
tpu_inference/envs.py CHANGED
@@ -15,64 +15,18 @@ if TYPE_CHECKING:
     PREFILL_SLICES: str = ""
     DECODE_SLICES: str = ""
     SKIP_JAX_PRECOMPILE: bool = False
-    VLLM_XLA_CHECK_RECOMPILATION: bool = False
     MODEL_IMPL_TYPE: str = "flax_nnx"
     NEW_MODEL_DESIGN: bool = False
     PHASED_PROFILING_DIR: str = ""
     PYTHON_TRACER_LEVEL: int = 1
     USE_MOE_EP_KERNEL: bool = False
-    NUM_SLICES: int = 1
     RAY_USAGE_STATS_ENABLED: str = "0"
     VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: str = "shm"
 
-
-def env_with_choices(
-    env_name: str,
-    default: str | None,
-    choices: list[str] | Callable[[], list[str]],
-    case_sensitive: bool = True,
-) -> Callable[[], str | None]:
-    """
-    Create a lambda that validates environment variable against allowed choices
-
-    Args:
-        env_name: Name of the environment variable
-        default: Default value if not set (can be None)
-        choices: List of valid string options or callable that returns list
-        case_sensitive: Whether validation should be case sensitive
-
-    Returns:
-        Lambda function for environment_variables dict
-    """
-
-    def _get_validated_env() -> str | None:
-        value = os.getenv(env_name)
-        if value is None:
-            return default
-
-        # Resolve choices if it's a callable (for lazy loading)
-        actual_choices = choices() if callable(choices) else choices
-
-        if not case_sensitive:
-            check_value = value.lower()
-            check_choices = [choice.lower() for choice in actual_choices]
-        else:
-            check_value = value
-            check_choices = actual_choices
-
-        if check_value not in check_choices:
-            raise ValueError(f"Invalid value '{value}' for {env_name}. "
-                             f"Valid options: {actual_choices}.")
-
-        return value
-
-    return _get_validated_env
-
-
 environment_variables: dict[str, Callable[[], Any]] = {
     # JAX platform selection (e.g., "tpu", "cpu", "proxy")
     "JAX_PLATFORMS":
-    lambda: os.getenv("JAX_PLATFORMS", "").lower(),
+    lambda: os.getenv("JAX_PLATFORMS", ""),
     # TPU accelerator type (e.g., "v5litepod-16", "v4-8")
     "TPU_ACCELERATOR_TYPE":
     lambda: os.getenv("TPU_ACCELERATOR_TYPE", None),
@@ -84,7 +38,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
     lambda: os.getenv("TPU_WORKER_ID", None),
     # Backend for multi-host communication on TPU
     "TPU_MULTIHOST_BACKEND":
-    env_with_choices("TPU_MULTIHOST_BACKEND", "", ["ray"]),
+    lambda: os.getenv("TPU_MULTIHOST_BACKEND", "").lower(),
     # Slice configuration for disaggregated prefill workers
     "PREFILL_SLICES":
     lambda: os.getenv("PREFILL_SLICES", ""),
@@ -93,35 +47,28 @@ environment_variables: dict[str, Callable[[], Any]] = {
     lambda: os.getenv("DECODE_SLICES", ""),
     # Skip JAX precompilation step during initialization
     "SKIP_JAX_PRECOMPILE":
-    lambda: bool(int(os.getenv("SKIP_JAX_PRECOMPILE") or "0")),
-    # Check for XLA recompilation during execution
-    "VLLM_XLA_CHECK_RECOMPILATION":
-    lambda: bool(int(os.getenv("VLLM_XLA_CHECK_RECOMPILATION") or "0")),
+    lambda: bool(int(os.getenv("SKIP_JAX_PRECOMPILE", "0"))),
     # Model implementation type (e.g., "flax_nnx")
     "MODEL_IMPL_TYPE":
-    env_with_choices("MODEL_IMPL_TYPE", "flax_nnx",
-                     ["vllm", "flax_nnx", "jetpack"]),
+    lambda: os.getenv("MODEL_IMPL_TYPE", "flax_nnx").lower(),
     # Enable new experimental model design
     "NEW_MODEL_DESIGN":
-    lambda: bool(int(os.getenv("NEW_MODEL_DESIGN") or "0")),
+    lambda: bool(int(os.getenv("NEW_MODEL_DESIGN", "0"))),
     # Directory to store phased profiling output
     "PHASED_PROFILING_DIR":
     lambda: os.getenv("PHASED_PROFILING_DIR", ""),
     # Python tracer level for profiling
     "PYTHON_TRACER_LEVEL":
-    lambda: int(os.getenv("PYTHON_TRACER_LEVEL") or "1"),
+    lambda: int(os.getenv("PYTHON_TRACER_LEVEL", "1")),
     # Use custom expert-parallel kernel for MoE (Mixture of Experts)
     "USE_MOE_EP_KERNEL":
-    lambda: bool(int(os.getenv("USE_MOE_EP_KERNEL") or "0")),
-    # Number of TPU slices for multi-slice mesh
-    "NUM_SLICES":
-    lambda: int(os.getenv("NUM_SLICES") or "1"),
+    lambda: bool(int(os.getenv("USE_MOE_EP_KERNEL", "0"))),
     # Enable/disable Ray usage statistics collection
     "RAY_USAGE_STATS_ENABLED":
     lambda: os.getenv("RAY_USAGE_STATS_ENABLED", "0"),
     # Ray compiled DAG channel type for TPU
     "VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE":
-    env_with_choices("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE", "shm", ["shm"]),
+    lambda: os.getenv("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE", "shm"),
 }
 
tpu_inference/executors/ray_distributed_executor.py CHANGED
@@ -108,9 +108,6 @@ class RayDistributedExecutor(RayDistributedExecutorV1):
         ip_port = self.collective_rpc("get_node_kv_ip_port")
         for item in ip_port:
             set_node_kv_ip_port(item)
-        self.uses_sampler = self.vllm_config.model_config.runner_type != "pooling" and (
-            self.vllm_config.ec_transfer_config is None
-            or not self.vllm_config.ec_transfer_config.is_ec_producer)
 
     def _initialize_ray_cluster(self) -> None:
         """Initialize the distributed cluster with Ray.
@@ -134,21 +131,10 @@ class RayDistributedExecutor(RayDistributedExecutorV1):
                 f"current platform {current_platform.device_name} does not "
                 "support ray.")
 
-        pp_size = self.parallel_config.pipeline_parallel_size
-        placement_group_specs: List[Dict[str, float]] = []
-
-        ray_nodes = ray.nodes()
-        logger.info(f"RayDistributedExecutor | ray_nodes={ray_nodes}")
-
-        if pp_size == 1:
-            placement_group_specs = [{
-                device_str: node['Resources'][device_str]
-            } for node in ray_nodes]
-        else:
-            num_devices_per_pp_rank = self.vllm_config.sharding_config.total_devices
-            placement_group_specs = [{
-                device_str: num_devices_per_pp_rank
-            } for _ in range(pp_size)]
+        placement_group_specs: List[Dict[str, float]] = [{
+            device_str:
+            node['Resources'][device_str]
+        } for node in ray.nodes()]
 
         # vLLM engine is also a worker to execute model with an accelerator,
         # so it requires to have the device in a current node. Check if
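The rewritten _initialize_ray_cluster drops the pipeline-parallel branch and always requests one placement-group bundle per Ray node, sized to that node's TPU resource count. A sketch of the resulting specs for a hypothetical two-node cluster with four chips per node (the node dicts are illustrative; ray.nodes() returns one dict per node with a 'Resources' mapping):

device_str = "TPU"
# Illustrative stand-in for ray.nodes() on a 2-node cluster.
ray_nodes = [
    {"Resources": {"TPU": 4.0, "CPU": 96.0}},
    {"Resources": {"TPU": 4.0, "CPU": 96.0}},
]

placement_group_specs = [{
    device_str: node["Resources"][device_str]
} for node in ray_nodes]

print(placement_group_specs)  # [{'TPU': 4.0}, {'TPU': 4.0}]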
@@ -343,8 +329,6 @@ class RayDistributedExecutor(RayDistributedExecutorV1):
         all_kwargs = []
         for rank, (node_id, _) in enumerate(worker_node_and_tpu_ids):
             local_rank = node_workers[node_id].index(rank)
-            ip = sorted_worker_metadata[rank].ip
-            prev_ip = sorted_worker_metadata[rank - 1].ip if rank > 0 else ""
             kwargs = dict(
                 vllm_config=self.vllm_config,
                 local_rank=local_rank,
@@ -352,26 +336,22 @@ class RayDistributedExecutor(RayDistributedExecutorV1):
                 distributed_init_method=distributed_init_method,
                 is_driver_worker=(not self.parallel_config)
                 or (rank % self.parallel_config.tensor_parallel_size == 0),
-                ip=ip,
-                prev_worker_ip=prev_ip,
             )
             all_kwargs.append(kwargs)
         self.collective_rpc("init_worker", args=(all_kwargs, ))
         self.collective_rpc("init_device")
-        if self.parallel_config.pipeline_parallel_size > 1:
-            self.collective_rpc("initialize_pp_transfer_connect")
         self.collective_rpc("load_model")
 
         if self.use_ray_spmd_worker:
             for pp_rank in range(self.parallel_config.pipeline_parallel_size):
                 self.pp_tp_workers.append([])
-                num_tp_workers = int(
-                    self.parallel_config.tensor_parallel_size //
-                    num_tpu_per_worker)
-                for tp_rank in range(num_tp_workers):
-                    # PP=2, TP=4, num_tpu_per_worker=2
-                    # pp_tp_workers = [[0, 1], [2, 3]]
-                    rank = (pp_rank * num_tp_workers) + tp_rank
+                for tp_rank in range(
+                        int(self.parallel_config.tensor_parallel_size //
+                            num_tpu_per_worker)):
+                    # PP=2, TP=4
+                    # pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]]
+                    rank = (pp_rank * self.parallel_config.tensor_parallel_size
+                            ) + tp_rank
                     assert len(self.pp_tp_workers[pp_rank]) == tp_rank
                     assert pp_rank < len(self.pp_tp_workers)
                     self.pp_tp_workers[pp_rank].append(self.workers[rank])
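The new rank arithmetic offsets each pipeline stage by a full tensor_parallel_size rather than by the per-stage worker count, which is what produces the [[0, 1, 2, 3], [4, 5, 6, 7]] layout in the updated comment. A worked example under that comment's PP=2, TP=4 setting, assuming num_tpu_per_worker = 1:

pipeline_parallel_size = 2  # PP
tensor_parallel_size = 4    # TP
num_tpu_per_worker = 1      # assumed for this example

pp_tp_workers = []
for pp_rank in range(pipeline_parallel_size):
    pp_tp_workers.append([])
    for tp_rank in range(int(tensor_parallel_size // num_tpu_per_worker)):
        # Each pipeline stage is offset by a full TP group of ranks.
        rank = pp_rank * tensor_parallel_size + tp_rank
        pp_tp_workers[pp_rank].append(rank)

print(pp_tp_workers)  # [[0, 1, 2, 3], [4, 5, 6, 7]]

Note that with num_tpu_per_worker > 1 the inner loop shrinks while the per-stage offset stays at tensor_parallel_size, so rank indices become non-contiguous across stages (e.g. [[0, 1], [4, 5]] for num_tpu_per_worker = 2).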