tpu-inference 0.11.1.dev202511180814__py3-none-any.whl → 0.12.0.dev20251213__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (76)
  1. tests/kernels/fused_moe_v1_test.py +303 -34
  2. tests/kernels/mla_v1_test.py +129 -41
  3. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  4. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +3 -1
  5. tests/kernels/ragged_paged_attention_kernel_v3_test.py +3 -1
  6. tests/lora/test_layers.py +4 -7
  7. tests/lora/test_lora_perf.py +53 -0
  8. tests/lora/utils.py +0 -8
  9. tests/test_envs.py +110 -12
  10. tests/test_quantization.py +3 -0
  11. tests/test_utils.py +1 -2
  12. tpu_inference/__init__.py +22 -3
  13. tpu_inference/core/disagg_utils.py +6 -8
  14. tpu_inference/distributed/tpu_connector.py +3 -4
  15. tpu_inference/distributed/utils.py +3 -2
  16. tpu_inference/envs.py +93 -9
  17. tpu_inference/executors/ray_distributed_executor.py +9 -2
  18. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  19. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  20. tpu_inference/kernels/fused_moe/v1/kernel.py +712 -143
  21. tpu_inference/kernels/mla/v1/kernel.py +98 -120
  22. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  23. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  24. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  25. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +140 -67
  26. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +204 -120
  27. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +2 -1
  28. tpu_inference/kernels/ragged_paged_attention/v3/util.py +2 -1
  29. tpu_inference/layers/common/attention_interface.py +7 -1
  30. tpu_inference/layers/common/sharding.py +11 -7
  31. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +232 -64
  32. tpu_inference/layers/jax/attention/gpt_oss_attention.py +5 -5
  33. tpu_inference/layers/vllm/fused_moe.py +170 -208
  34. tpu_inference/layers/vllm/linear_common.py +43 -21
  35. tpu_inference/layers/vllm/quantization/common.py +11 -6
  36. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +4 -3
  37. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +74 -65
  38. tpu_inference/layers/vllm/quantization/mxfp4.py +140 -94
  39. tpu_inference/layers/vllm/quantization/unquantized.py +103 -80
  40. tpu_inference/layers/vllm/sharding.py +2 -2
  41. tpu_inference/lora/torch_punica_tpu.py +1 -2
  42. tpu_inference/models/common/model_loader.py +84 -28
  43. tpu_inference/models/jax/deepseek_v3.py +185 -64
  44. tpu_inference/models/jax/gpt_oss.py +3 -3
  45. tpu_inference/models/jax/llama3.py +2 -1
  46. tpu_inference/models/jax/llama_eagle3.py +8 -5
  47. tpu_inference/models/jax/llama_guard_4.py +361 -0
  48. tpu_inference/models/jax/qwen2.py +2 -1
  49. tpu_inference/models/jax/qwen2_5_vl.py +163 -48
  50. tpu_inference/models/jax/qwen3.py +2 -1
  51. tpu_inference/models/jax/utils/quantization/quantization_utils.py +7 -8
  52. tpu_inference/models/jax/utils/weight_utils.py +205 -144
  53. tpu_inference/models/vllm/vllm_model_wrapper.py +14 -8
  54. tpu_inference/platforms/tpu_platform.py +34 -50
  55. tpu_inference/runner/compilation_manager.py +144 -60
  56. tpu_inference/runner/kv_cache.py +40 -20
  57. tpu_inference/runner/kv_cache_manager.py +48 -33
  58. tpu_inference/runner/persistent_batch_manager.py +40 -2
  59. tpu_inference/runner/structured_decoding_manager.py +2 -3
  60. tpu_inference/runner/tpu_runner.py +280 -149
  61. tpu_inference/runner/utils.py +2 -2
  62. tpu_inference/spec_decode/jax/eagle3.py +71 -21
  63. tpu_inference/tpu_info.py +4 -3
  64. tpu_inference/utils.py +46 -18
  65. tpu_inference/worker/tpu_worker.py +197 -63
  66. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/METADATA +9 -10
  67. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/RECORD +70 -74
  68. tpu_inference/mock/__init__.py +0 -0
  69. tpu_inference/mock/vllm_config_utils.py +0 -28
  70. tpu_inference/mock/vllm_envs.py +0 -1219
  71. tpu_inference/mock/vllm_logger.py +0 -212
  72. tpu_inference/mock/vllm_logging_utils.py +0 -15
  73. tpu_inference/models/jax/phi3.py +0 -376
  74. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/WHEEL +0 -0
  75. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/licenses/LICENSE +0 -0
  76. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/top_level.txt +0 -0
tests/lora/test_layers.py CHANGED
@@ -18,7 +18,7 @@ from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
                               ReplicatedLinearWithLoRA,
                               RowParallelLinearWithLoRA)
 # yapf: enable
-from vllm.lora.models import LoRALayerWeights, PackedLoRALayerWeights
+from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
 from vllm.lora.punica_wrapper import get_punica_wrapper
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
@@ -91,7 +91,6 @@ def populate_loras(
     index_to_id: list[Optional[int]],
     lora_layer: BaseLayerWithLoRA,
     baselayer_weights: torch.Tensor,
-    generate_embeddings_tensor: int = 0,
     repeats: int = 1,
 ) -> tuple[dict[int, LoRALayerWeights], dict[int, list[LoRALayerWeights]]]:
     """This method populates the lora weights (lora_a and lora_b) in the lora layers (BaseLayerWithLoRA).
@@ -103,8 +102,6 @@ def populate_loras(
         lora_layer: the LoRAlayer to populate.
         baselayer_weights: the PyTorch tensor containing the layer's
             weights.
-        generate_embeddings_tensor: whether to generate an
-            embeddings tensor for each LoRA.
         repeats: must only be set for column parallel packed
             layers. Indicates the number of loras to compose
             together to create a single lora layer.
@@ -131,7 +128,6 @@ def populate_loras(
                 baselayer_weights.device).init_random_lora(
                     module_name=f"fake_{i}",
                     weight=baselayer_weights,
-                    generate_embeddings_tensor=generate_embeddings_tensor,
                 )
             sublora.lora_b = sublora.lora_b[(sublora_len *
                                              i):(sublora_len * (i + 1)), :]
@@ -147,7 +143,6 @@ def populate_loras(
                 slot_idx,
                 lora_a=lora.lora_a,
                 lora_b=lora.lora_b,
-                embeddings_tensor=lora.embeddings_tensor,
             )
 
         lora_dict[lora_id] = lora
@@ -210,6 +205,9 @@ def create_random_inputs(
 @pytest.mark.parametrize("repeats", [1, 2, 3])
 @pytest.mark.parametrize("stage", [True, False])
 def test_column_parallel_packed(dist_init, num_loras, repeats, stage) -> None:
+    # TODO(Qiliang Cui): Remove when issue is resolved.
+    if 'TPU7x' in jax.devices()[0].device_kind:
+        pytest.skip("Skipping test on TPU TPU7x.")
     set_random_seed(6)
 
     max_loras = 9
@@ -546,7 +544,6 @@ def _update_punica_wrapper_metadata(punica_wrapper, index_mapping,
         index_to_id,
         lora_config.max_loras,
         vocab_size=512,
-        extra_vocab_size=lora_config.lora_extra_vocab_size,
     )
     assert jax_view(punica_wrapper._lora_indices_per_batch).platform(
     ) == 'tpu', 'punica_wrapper._lora_indices_per_batch should have been moved to TPU.'
tests/lora/test_lora_perf.py ADDED
@@ -0,0 +1,53 @@
+import os
+import time
+
+import pytest
+import vllm
+from vllm.lora.request import LoRARequest
+
+TP = [2] if os.environ.get("USE_V6E8_QUEUE", False) else [1]
+
+
+@pytest.mark.parametrize("tp", TP)
+def test_lora_performance(tp):
+    prompt = "What is 1+1? \n"
+    llm_without_lora = vllm.LLM(
+        model="Qwen/Qwen2.5-3B-Instruct",
+        max_model_len=256,
+        max_num_batched_tokens=64,
+        max_num_seqs=8,
+        tensor_parallel_size=tp,
+    )
+    start_time = time.time()
+    llm_without_lora.generate(
+        prompt,
+        sampling_params=vllm.SamplingParams(max_tokens=16, temperature=0),
+    )[0].outputs[0].text
+    base_time = time.time() - start_time
+
+    del llm_without_lora
+    # Waiting for TPUs to be released
+    time.sleep(10)
+
+    llm_with_lora = vllm.LLM(model="Qwen/Qwen2.5-3B-Instruct",
+                             max_model_len=256,
+                             max_num_batched_tokens=64,
+                             max_num_seqs=8,
+                             tensor_parallel_size=tp,
+                             enable_lora=True,
+                             max_loras=1,
+                             max_lora_rank=8)
+    lora_request = LoRARequest(
+        "lora_adapter_2", 2,
+        "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_2_adapter")
+    start_time = time.time()
+    llm_with_lora.generate(prompt,
+                           sampling_params=vllm.SamplingParams(max_tokens=16,
+                                                               temperature=0),
+                           lora_request=lora_request)[0].outputs[0].text
+    lora_time = time.time() - start_time
+    print(f"Base time: {base_time}, LoRA time: {lora_time}")
+    assert (base_time /
+            lora_time) < 8, f"Base time: {base_time}, LoRA time: {lora_time}"
+
+    del llm_with_lora
tests/lora/utils.py CHANGED
@@ -24,7 +24,6 @@ class DummyLoRAManager:
         module_name: str,
         weight: torch.Tensor,
         rank: int = 8,
-        generate_embeddings_tensor: int = 0,
     ):
         lora = LoRALayerWeights(
             module_name,
@@ -37,13 +36,6 @@ class DummyLoRAManager:
                        dtype=weight.dtype,
                        device=self._device),
         )
-        if generate_embeddings_tensor:
-            lora.embeddings_tensor = torch.rand(
-                5,
-                generate_embeddings_tensor,
-                dtype=weight.dtype,
-                device=self._device,
-            )
         self.set_module_lora(module_name, lora)
 
         return lora
tests/test_envs.py CHANGED
@@ -56,6 +56,13 @@ def test_getattr_with_cache(monkeypatch: pytest.MonkeyPatch):
 
 
 def test_boolean_env_vars(monkeypatch: pytest.MonkeyPatch):
+    # Ensure clean environment for boolean vars by setting to default "0"
+    monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "0")
+    monkeypatch.setenv("VLLM_XLA_CHECK_RECOMPILATION", "0")
+    monkeypatch.setenv("NEW_MODEL_DESIGN", "0")
+    monkeypatch.setenv("ENABLE_QUANTIZED_MATMUL_KERNEL", "0")
+    monkeypatch.setenv("USE_MOE_EP_KERNEL", "0")
+
     # Test SKIP_JAX_PRECOMPILE (default False)
     assert envs.SKIP_JAX_PRECOMPILE is False
     monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "1")
@@ -63,6 +70,13 @@ def test_boolean_env_vars(monkeypatch: pytest.MonkeyPatch):
     monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "0")
     assert envs.SKIP_JAX_PRECOMPILE is False
 
+    # Test VLLM_XLA_CHECK_RECOMPILATION (default False)
+    assert envs.VLLM_XLA_CHECK_RECOMPILATION is False
+    monkeypatch.setenv("VLLM_XLA_CHECK_RECOMPILATION", "1")
+    assert envs.VLLM_XLA_CHECK_RECOMPILATION is True
+    monkeypatch.setenv("VLLM_XLA_CHECK_RECOMPILATION", "0")
+    assert envs.VLLM_XLA_CHECK_RECOMPILATION is False
+
     # Test NEW_MODEL_DESIGN (default False)
     assert envs.NEW_MODEL_DESIGN is False
     monkeypatch.setenv("NEW_MODEL_DESIGN", "1")
@@ -73,22 +87,110 @@ def test_boolean_env_vars(monkeypatch: pytest.MonkeyPatch):
     monkeypatch.setenv("USE_MOE_EP_KERNEL", "1")
     assert envs.USE_MOE_EP_KERNEL is True
 
+    # Test ENABLE_QUANTIZED_MATMUL_KERNEL (default False)
+    assert envs.ENABLE_QUANTIZED_MATMUL_KERNEL is False
+    monkeypatch.setenv("ENABLE_QUANTIZED_MATMUL_KERNEL", "1")
+    assert envs.ENABLE_QUANTIZED_MATMUL_KERNEL is True
+
+
+def test_boolean_env_vars_string_values(monkeypatch: pytest.MonkeyPatch):
+    """Test that boolean env vars accept string values like 'True' and 'False'"""
+
+    # Test NEW_MODEL_DESIGN with string "True"
+    monkeypatch.setenv("NEW_MODEL_DESIGN", "True")
+    assert envs.NEW_MODEL_DESIGN is True
+
+    monkeypatch.setenv("NEW_MODEL_DESIGN", "true")
+    assert envs.NEW_MODEL_DESIGN is True
+
+    monkeypatch.setenv("NEW_MODEL_DESIGN", "False")
+    assert envs.NEW_MODEL_DESIGN is False
+
+    monkeypatch.setenv("NEW_MODEL_DESIGN", "false")
+    assert envs.NEW_MODEL_DESIGN is False
+
+    # Test SKIP_JAX_PRECOMPILE with string values
+    monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "True")
+    assert envs.SKIP_JAX_PRECOMPILE is True
+
+    monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "false")
+    assert envs.SKIP_JAX_PRECOMPILE is False
+
+    # Test VLLM_XLA_CHECK_RECOMPILATION with string values
+    monkeypatch.setenv("VLLM_XLA_CHECK_RECOMPILATION", "TRUE")
+    assert envs.VLLM_XLA_CHECK_RECOMPILATION is True
+
+    monkeypatch.setenv("VLLM_XLA_CHECK_RECOMPILATION", "FALSE")
+    assert envs.VLLM_XLA_CHECK_RECOMPILATION is False
+
+    # Test USE_MOE_EP_KERNEL with string values
+    monkeypatch.setenv("USE_MOE_EP_KERNEL", "true")
+    assert envs.USE_MOE_EP_KERNEL is True
+
+    monkeypatch.setenv("USE_MOE_EP_KERNEL", "False")
+    assert envs.USE_MOE_EP_KERNEL is False
+
+
+def test_boolean_env_vars_invalid_values(monkeypatch: pytest.MonkeyPatch):
+    """Test that boolean env vars raise errors for invalid values"""
+
+    # Test invalid value for NEW_MODEL_DESIGN
+    monkeypatch.setenv("NEW_MODEL_DESIGN", "yes")
+    with pytest.raises(
+            ValueError,
+            match="Invalid boolean value 'yes' for NEW_MODEL_DESIGN"):
+        _ = envs.NEW_MODEL_DESIGN
+
+    monkeypatch.setenv("NEW_MODEL_DESIGN", "2")
+    with pytest.raises(ValueError,
+                       match="Invalid boolean value '2' for NEW_MODEL_DESIGN"):
+        _ = envs.NEW_MODEL_DESIGN
+
+    # Test invalid value for SKIP_JAX_PRECOMPILE
+    monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "invalid")
+    with pytest.raises(
+            ValueError,
+            match="Invalid boolean value 'invalid' for SKIP_JAX_PRECOMPILE"):
+        _ = envs.SKIP_JAX_PRECOMPILE
+
+
+def test_boolean_env_vars_empty_string(monkeypatch: pytest.MonkeyPatch):
+    """Test that empty string returns default value"""
+
+    monkeypatch.setenv("NEW_MODEL_DESIGN", "")
+    assert envs.NEW_MODEL_DESIGN is False  # Should return default
+
+    monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "")
+    assert envs.SKIP_JAX_PRECOMPILE is False  # Should return default
+
 
 def test_integer_env_vars(monkeypatch: pytest.MonkeyPatch):
+    # Ensure clean environment for integer vars by setting to defaults
+    monkeypatch.setenv("PYTHON_TRACER_LEVEL", "1")
+    monkeypatch.setenv("NUM_SLICES", "1")
+
     assert envs.PYTHON_TRACER_LEVEL == 1
     monkeypatch.setenv("PYTHON_TRACER_LEVEL", "3")
     assert envs.PYTHON_TRACER_LEVEL == 3
     monkeypatch.setenv("PYTHON_TRACER_LEVEL", "0")
     assert envs.PYTHON_TRACER_LEVEL == 0
 
+    # Test NUM_SLICES (default 1)
+    assert envs.NUM_SLICES == 1
+    monkeypatch.setenv("NUM_SLICES", "2")
+    assert envs.NUM_SLICES == 2
+    monkeypatch.setenv("NUM_SLICES", "4")
+    assert envs.NUM_SLICES == 4
 
-def test_lowercase_conversion(monkeypatch: pytest.MonkeyPatch):
-    monkeypatch.setenv("TPU_MULTIHOST_BACKEND", "GRPC")
-    assert envs.TPU_MULTIHOST_BACKEND == "grpc"
 
-    monkeypatch.setenv("MODEL_IMPL_TYPE", "FLAX_NNX")
+def test_model_impl_type_choices(monkeypatch: pytest.MonkeyPatch):
+    # Test case sensitive choices
+    monkeypatch.setenv("MODEL_IMPL_TYPE", "flax_nnx")
     assert envs.MODEL_IMPL_TYPE == "flax_nnx"
 
+    monkeypatch.setenv("MODEL_IMPL_TYPE", "vllm")
+    assert envs.MODEL_IMPL_TYPE == "vllm"
+
 
 def test_string_env_vars_defaults(monkeypatch: pytest.MonkeyPatch):
     monkeypatch.delenv("JAX_PLATFORMS", raising=False)
@@ -117,8 +219,6 @@ def test_ray_env_vars(monkeypatch: pytest.MonkeyPatch):
     assert envs.RAY_USAGE_STATS_ENABLED == "1"
 
     assert envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE == "shm"
-    monkeypatch.setenv("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE", "nccl")
-    assert envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE == "nccl"
 
 
 def test_invalid_attribute_raises_error():
@@ -134,6 +234,7 @@ def test_dir_returns_all_env_vars():
     assert "JAX_PLATFORMS" in env_vars
     assert "TPU_NAME" in env_vars
     assert "SKIP_JAX_PRECOMPILE" in env_vars
+    assert "VLLM_XLA_CHECK_RECOMPILATION" in env_vars
     assert "MODEL_IMPL_TYPE" in env_vars
 
 
@@ -141,11 +242,8 @@ def test_tpu_multihost_env_vars(monkeypatch: pytest.MonkeyPatch):
    monkeypatch.setenv("TPU_WORKER_ID", "0")
     assert envs.TPU_WORKER_ID == "0"
 
-    monkeypatch.setenv("TPU_MULTIHOST_BACKEND", "grpc")
-    assert envs.TPU_MULTIHOST_BACKEND == "grpc"
-
-    monkeypatch.setenv("TPU_MULTIHOST_BACKEND", "xla")
-    assert envs.TPU_MULTIHOST_BACKEND == "xla"
+    monkeypatch.setenv("TPU_MULTIHOST_BACKEND", "ray")
+    assert envs.TPU_MULTIHOST_BACKEND == "ray"
 
 
 def test_disaggregated_serving_env_vars(monkeypatch: pytest.MonkeyPatch):
@@ -158,7 +256,7 @@ def test_disaggregated_serving_env_vars(monkeypatch: pytest.MonkeyPatch):
 
 def test_model_impl_type_default(monkeypatch: pytest.MonkeyPatch):
     monkeypatch.delenv("MODEL_IMPL_TYPE", raising=False)
-    assert envs.MODEL_IMPL_TYPE == "flax_nnx"
+    assert envs.MODEL_IMPL_TYPE == "auto"
 
 
 def test_cache_preserves_values_across_env_changes(
tests/test_quantization.py CHANGED
@@ -112,6 +112,8 @@ class TestQwixQuantizeNnxModel(unittest.TestCase):
         self.mesh = Mesh(jax.devices(), ('model', ))
         self.rng = jax.random.PRNGKey(0)
         self.model = SimpleModel(rngs=nnx.Rngs(0))
+        self.model.vllm_config = MagicMock()
+        self.model.vllm_config.model_config.use_mla = False
 
         self.qwix_config = [
             {
@@ -131,6 +133,7 @@ class TestQwixQuantizeNnxModel(unittest.TestCase):
         """Test that qwix.quantize_model is called with the correct arguments."""
         quantized_model_mock = MagicMock(spec=nnx.Module)
         mock_quantize_model.return_value = quantized_model_mock
+        self.model.vllm_config.sharding_config.total_dp_size = 1
 
         with patch(
                 "tpu_inference.models.jax.utils.quantization.quantization_utils.init_logger",
tests/test_utils.py CHANGED
@@ -231,6 +231,5 @@ def test_get_jax_dtype_from_str_dtype():
     assert get_jax_dtype_from_str_dtype("int8") == jnp.int8
     assert get_jax_dtype_from_str_dtype("bfloat16") == jnp.bfloat16
     assert get_jax_dtype_from_str_dtype("fp8") == jnp.float8_e4m3fn
-    assert get_jax_dtype_from_str_dtype("fp8_e4m3") == jnp.float8_e4m3
+    assert get_jax_dtype_from_str_dtype("fp8_e4m3") == jnp.float8_e4m3fn
     assert get_jax_dtype_from_str_dtype("fp8_e5m2") == jnp.float8_e5m2
-    assert get_jax_dtype_from_str_dtype("auto") is None
tpu_inference/__init__.py CHANGED
@@ -1,21 +1,40 @@
-import os
-
 # The environment variables override should be imported before any other
 # modules to ensure that the environment variables are set before any
 # other modules are imported.
 import tpu_inference.env_override  # noqa: F401
+from tpu_inference import envs
 from tpu_inference import tpu_info as ti
 from tpu_inference.logger import init_logger
 
 logger = init_logger(__name__)
 
-if "proxy" in os.environ.get('JAX_PLATFORMS', '').lower():
+if "proxy" in envs.JAX_PLATFORMS:
     logger.info("Running vLLM on TPU via Pathways proxy.")
     # Must run pathwaysutils.initialize() before any JAX operations
     try:
+        import traceback
+
         import pathwaysutils
+        import vllm
+        from vllm.platforms import (resolve_current_platform_cls_qualname,
+                                    resolve_obj_by_qualname)
         pathwaysutils.initialize()
         logger.info("Module pathwaysutils is imported.")
+
+        # Pathways requires eager resolution of vllm.current_platform instead of
+        # lazy resolution in the normal code path. Since this part involves
+        # global topology discovery across multiple hosts, the platform
+        # resolution must happen before other components are loaded.
+        logger.info("Eagerly resolving vLLM current_platform for Pathways.")
+        platform_cls_qualname = resolve_current_platform_cls_qualname()
+        resolved_platform_instance = resolve_obj_by_qualname(
+            platform_cls_qualname)()
+        vllm.platforms._current_platform = resolved_platform_instance
+        vllm.platforms._init_trace = "".join(traceback.format_stack())
+        logger.info(
+            f"vLLM platform resolved to: {resolved_platform_instance.__class__.__name__}"
+        )
+
     except Exception as e:
         logger.error(
            f"Error occurred while importing pathwaysutils or logging TPU info: {e}"
tpu_inference/core/disagg_utils.py CHANGED
@@ -1,17 +1,15 @@
 # SPDX-License-Identifier: Apache-2.0
 
-import os
 from typing import Tuple
 
-PREFILL_SLICES = 'PREFILL_SLICES'
-DECODE_SLICES = 'DECODE_SLICES'
+from tpu_inference import envs
 
 
 def is_disagg_enabled() -> bool:
     # We triggrer our code path as long as prefill slices are set. This
     # allows us to test interleave mode effectively with the code path
     # for comparison purposes.
-    return PREFILL_SLICES in os.environ
+    return bool(envs.PREFILL_SLICES)
 
 
 def _parse_slices(slices_str: str) -> Tuple[int, ...]:
@@ -40,12 +38,12 @@ def _parse_slices(slices_str: str) -> Tuple[int, ...]:
 
 
 def get_prefill_slices() -> Tuple[int, ...]:
-    if PREFILL_SLICES not in os.environ:
+    if not envs.PREFILL_SLICES:
         return ()
-    return _parse_slices(os.environ[PREFILL_SLICES])
+    return _parse_slices(envs.PREFILL_SLICES)
 
 
 def get_decode_slices() -> Tuple[int, ...]:
-    if DECODE_SLICES not in os.environ:
+    if not envs.DECODE_SLICES:
         return ()
-    return _parse_slices(os.environ[DECODE_SLICES])
+    return _parse_slices(envs.DECODE_SLICES)
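Note the subtle semantic change in is_disagg_enabled(): the old gate tested for the presence of the PREFILL_SLICES key, while the new gate tests for a non-empty value (envs.PREFILL_SLICES defaults to ""). A standalone sketch of the two predicates (illustrative only, not the package's code):

import os

os.environ["PREFILL_SLICES"] = ""  # set but empty

# Old gate: key presence -- an empty PREFILL_SLICES still enabled disaggregation.
old_enabled = "PREFILL_SLICES" in os.environ
# New gate: truthiness -- an empty PREFILL_SLICES now falls back to interleaved mode.
new_enabled = bool(os.environ.get("PREFILL_SLICES", ""))

assert old_enabled is True
assert new_enabled is False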
tpu_inference/distributed/tpu_connector.py CHANGED
@@ -60,7 +60,6 @@ D workflow:
 
 import copy
 import functools
-import os
 import threading
 import time
 from concurrent.futures import Future, ThreadPoolExecutor
@@ -86,6 +85,7 @@ if TYPE_CHECKING:
     from vllm.v1.core.kv_cache_manager import KVCacheBlocks
     from vllm.v1.request import Request
 
+from tpu_inference import envs
 from tpu_inference.distributed.utils import (get_host_ip, get_kv_ips,
                                              get_kv_ports,
                                              get_kv_transfer_port, get_node_id,
@@ -441,8 +441,7 @@ class TPUConnectorWorker:
 
         self.runner: TPUModelRunner = None
         self.mesh: Mesh = None
-        self.multi_host = os.getenv("TPU_MULTIHOST_BACKEND",
-                                    "").lower() == "ray"
+        self.multi_host = envs.TPU_MULTIHOST_BACKEND == "ray"
         # NOTE(xiang): This can not be the worker rank set in RayDistributedExecutor.
         # The worker rank is assigned with vLLM's sorting logic, which does not work
         # for TPU host topology.
@@ -458,7 +457,6 @@ class TPUConnectorWorker:
         self.side_channel_port = get_side_channel_port()
 
         self.kv_transfer_server = None
-        self._maybe_start_p2p_server()
         self.zmq_cxt = zmq.Context()
         if self.is_producer:
             ready_event = threading.Event()
@@ -500,6 +498,7 @@ class TPUConnectorWorker:
         self.shape = list(kv_layer.shape)
         self.dtype = kv_layer.dtype
         self.sharding = kv_layer.sharding
+        self._maybe_start_p2p_server()
 
     def _maybe_start_p2p_server(self):
         if self.kv_transfer_server is not None:
tpu_inference/distributed/utils.py CHANGED
@@ -2,6 +2,7 @@ import os
 
 from vllm.utils.network_utils import get_ip
 
+from tpu_inference import envs
 from tpu_inference.logger import init_logger
 
 logger = init_logger(__name__)
@@ -17,7 +18,7 @@ def set_node_kv_ip_port(ip_port: tuple[int, str, int]):
 
 
 def get_kv_ips() -> str:
-    if os.getenv("TPU_MULTIHOST_BACKEND", "").lower() == "ray":
+    if envs.TPU_MULTIHOST_BACKEND == "ray":
         num_nodes = len(_NODES_KV_IP_PORT)
         ips = []
         for node_id in range(num_nodes):
@@ -28,7 +29,7 @@ def get_kv_ips() -> str:
 
 
 def get_kv_ports() -> str:
-    if os.getenv("TPU_MULTIHOST_BACKEND", "").lower() == "ray":
+    if envs.TPU_MULTIHOST_BACKEND == "ray":
         num_nodes = len(_NODES_KV_IP_PORT)
         ports = []
         for node_id in range(num_nodes):
tpu_inference/envs.py CHANGED
@@ -15,18 +15,93 @@ if TYPE_CHECKING:
     PREFILL_SLICES: str = ""
     DECODE_SLICES: str = ""
     SKIP_JAX_PRECOMPILE: bool = False
-    MODEL_IMPL_TYPE: str = "flax_nnx"
+    VLLM_XLA_CHECK_RECOMPILATION: bool = False
+    MODEL_IMPL_TYPE: str = "auto"
     NEW_MODEL_DESIGN: bool = False
     PHASED_PROFILING_DIR: str = ""
     PYTHON_TRACER_LEVEL: int = 1
     USE_MOE_EP_KERNEL: bool = False
+    NUM_SLICES: int = 1
     RAY_USAGE_STATS_ENABLED: str = "0"
     VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: str = "shm"
+    ENABLE_QUANTIZED_MATMUL_KERNEL: bool = False
+
+
+def env_with_choices(
+    env_name: str,
+    default: str | None,
+    choices: list[str] | Callable[[], list[str]],
+    case_sensitive: bool = True,
+) -> Callable[[], str | None]:
+    """
+    Create a lambda that validates environment variable against allowed choices
+
+    Args:
+        env_name: Name of the environment variable
+        default: Default value if not set (can be None)
+        choices: List of valid string options or callable that returns list
+        case_sensitive: Whether validation should be case sensitive
+
+    Returns:
+        Lambda function for environment_variables dict
+    """
+
+    def _get_validated_env() -> str | None:
+        value = os.getenv(env_name)
+        if value is None:
+            return default
+
+        # Resolve choices if it's a callable (for lazy loading)
+        actual_choices = choices() if callable(choices) else choices
+
+        if not case_sensitive:
+            check_value = value.lower()
+            check_choices = [choice.lower() for choice in actual_choices]
+        else:
+            check_value = value
+            check_choices = actual_choices
+
+        if check_value not in check_choices:
+            raise ValueError(f"Invalid value '{value}' for {env_name}. "
+                             f"Valid options: {actual_choices}.")
+
+        return value
+
+    return _get_validated_env
+
+
+def env_bool(env_name: str, default: bool = False) -> Callable[[], bool]:
+    """
+    Accepts both numeric strings ("0", "1") and boolean strings
+    ("true", "false", "True", "False").
+
+    Args:
+        env_name: Name of the environment variable
+        default: Default boolean value if not set
+    """
+
+    def _get_bool_env() -> bool:
+        value = os.getenv(env_name)
+        if value is None or value == "":
+            return default
+
+        value_lower = value.lower()
+        if value_lower in ("true", "1"):
+            return True
+        elif value_lower in ("false", "0"):
+            return False
+        else:
+            raise ValueError(
+                f"Invalid boolean value '{value}' for {env_name}. "
+                f"Valid options: '0', '1', 'true', 'false', 'True', 'False'.")
+
+    return _get_bool_env
+
 
 environment_variables: dict[str, Callable[[], Any]] = {
     # JAX platform selection (e.g., "tpu", "cpu", "proxy")
     "JAX_PLATFORMS":
-    lambda: os.getenv("JAX_PLATFORMS", ""),
+    lambda: os.getenv("JAX_PLATFORMS", "").lower(),
     # TPU accelerator type (e.g., "v5litepod-16", "v4-8")
     "TPU_ACCELERATOR_TYPE":
     lambda: os.getenv("TPU_ACCELERATOR_TYPE", None),
@@ -38,7 +113,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
     lambda: os.getenv("TPU_WORKER_ID", None),
     # Backend for multi-host communication on TPU
     "TPU_MULTIHOST_BACKEND":
-    lambda: os.getenv("TPU_MULTIHOST_BACKEND", "").lower(),
+    env_with_choices("TPU_MULTIHOST_BACKEND", "", ["ray"]),
     # Slice configuration for disaggregated prefill workers
     "PREFILL_SLICES":
     lambda: os.getenv("PREFILL_SLICES", ""),
@@ -47,28 +122,37 @@ environment_variables: dict[str, Callable[[], Any]] = {
     lambda: os.getenv("DECODE_SLICES", ""),
     # Skip JAX precompilation step during initialization
     "SKIP_JAX_PRECOMPILE":
-    lambda: bool(int(os.getenv("SKIP_JAX_PRECOMPILE", "0"))),
+    env_bool("SKIP_JAX_PRECOMPILE", default=False),
+    # Check for XLA recompilation during execution
+    "VLLM_XLA_CHECK_RECOMPILATION":
+    env_bool("VLLM_XLA_CHECK_RECOMPILATION", default=False),
     # Model implementation type (e.g., "flax_nnx")
     "MODEL_IMPL_TYPE":
-    lambda: os.getenv("MODEL_IMPL_TYPE", "flax_nnx").lower(),
+    env_with_choices("MODEL_IMPL_TYPE", "auto",
+                     ["auto", "vllm", "flax_nnx", "jetpack"]),
     # Enable new experimental model design
     "NEW_MODEL_DESIGN":
-    lambda: bool(int(os.getenv("NEW_MODEL_DESIGN", "0"))),
+    env_bool("NEW_MODEL_DESIGN", default=False),
     # Directory to store phased profiling output
     "PHASED_PROFILING_DIR":
     lambda: os.getenv("PHASED_PROFILING_DIR", ""),
     # Python tracer level for profiling
     "PYTHON_TRACER_LEVEL":
-    lambda: int(os.getenv("PYTHON_TRACER_LEVEL", "1")),
+    lambda: int(os.getenv("PYTHON_TRACER_LEVEL") or "1"),
     # Use custom expert-parallel kernel for MoE (Mixture of Experts)
     "USE_MOE_EP_KERNEL":
-    lambda: bool(int(os.getenv("USE_MOE_EP_KERNEL", "0"))),
+    env_bool("USE_MOE_EP_KERNEL", default=False),
+    # Number of TPU slices for multi-slice mesh
+    "NUM_SLICES":
+    lambda: int(os.getenv("NUM_SLICES") or "1"),
     # Enable/disable Ray usage statistics collection
     "RAY_USAGE_STATS_ENABLED":
     lambda: os.getenv("RAY_USAGE_STATS_ENABLED", "0"),
     # Ray compiled DAG channel type for TPU
     "VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE":
-    lambda: os.getenv("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE", "shm"),
+    env_with_choices("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE", "shm", ["shm"]),
+    "ENABLE_QUANTIZED_MATMUL_KERNEL":
+    lambda: bool(int(os.getenv("ENABLE_QUANTIZED_MATMUL_KERNEL") or "0")),
 }
 
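Both new helpers return getter closures rather than values, matching the lambda-based environment_variables registry above. A short sketch of their behavior, assuming tpu_inference is importable on the host (env_bool and env_with_choices are plain module-level functions in tpu_inference/envs.py; calling the returned getters directly bypasses the module's attribute cache):

import os

from tpu_inference.envs import env_bool, env_with_choices

get_new_model_design = env_bool("NEW_MODEL_DESIGN", default=False)
os.environ["NEW_MODEL_DESIGN"] = "True"
assert get_new_model_design() is True   # "1"/"true"/"True" are accepted
os.environ["NEW_MODEL_DESIGN"] = ""
assert get_new_model_design() is False  # empty string falls back to the default
os.environ["NEW_MODEL_DESIGN"] = "yes"
try:
    get_new_model_design()              # anything else raises
except ValueError as e:
    print(e)  # Invalid boolean value 'yes' for NEW_MODEL_DESIGN. ...

get_backend = env_with_choices("TPU_MULTIHOST_BACKEND", "", ["ray"])
os.environ["TPU_MULTIHOST_BACKEND"] = "ray"
assert get_backend() == "ray"
os.environ["TPU_MULTIHOST_BACKEND"] = "grpc"
try:
    get_backend()                       # "grpc" is no longer a valid backend
except ValueError as e:
    print(e)  # Invalid value 'grpc' for TPU_MULTIHOST_BACKEND. Valid options: ['ray'].

Note that ENABLE_QUANTIZED_MATMUL_KERNEL is still registered with the older bool(int(...)) lambda rather than env_bool, so it accepts only numeric strings ("0"/"1").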
tpu_inference/executors/ray_distributed_executor.py CHANGED
@@ -108,6 +108,9 @@ class RayDistributedExecutor(RayDistributedExecutorV1):
         ip_port = self.collective_rpc("get_node_kv_ip_port")
         for item in ip_port:
             set_node_kv_ip_port(item)
+        self.uses_sampler = self.vllm_config.model_config.runner_type != "pooling" and (
+            self.vllm_config.ec_transfer_config is None
+            or not self.vllm_config.ec_transfer_config.is_ec_producer)
 
     def _initialize_ray_cluster(self) -> None:
         """Initialize the distributed cluster with Ray.
@@ -133,10 +136,14 @@ class RayDistributedExecutor(RayDistributedExecutorV1):
 
         pp_size = self.parallel_config.pipeline_parallel_size
         placement_group_specs: List[Dict[str, float]] = []
+
+        ray_nodes = ray.nodes()
+        logger.info(f"RayDistributedExecutor | ray_nodes={ray_nodes}")
+
         if pp_size == 1:
             placement_group_specs = [{
                 device_str: node['Resources'][device_str]
-            } for node in ray.nodes()]
+            } for node in ray_nodes]
         else:
             num_devices_per_pp_rank = self.vllm_config.sharding_config.total_devices
             placement_group_specs = [{
@@ -352,7 +359,7 @@ class RayDistributedExecutor(RayDistributedExecutorV1):
         self.collective_rpc("init_worker", args=(all_kwargs, ))
         self.collective_rpc("init_device")
         if self.parallel_config.pipeline_parallel_size > 1:
-            self._run_workers("initialize_pp_transfer_connect")
+            self.collective_rpc("initialize_pp_transfer_connect")
         self.collective_rpc("load_model")
 
         if self.use_ray_spmd_worker: