tpu-inference 0.11.1.dev202511150811__py3-none-any.whl → 0.11.1.dev202512030818__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic; see the package registry's page for this release for more details.

Files changed (54):
  1. tests/kernels/fused_moe_v1_test.py +303 -34
  2. tests/lora/test_layers.py +0 -6
  3. tests/lora/utils.py +0 -8
  4. tests/test_envs.py +32 -11
  5. tests/test_utils.py +1 -2
  6. tpu_inference/__init__.py +22 -3
  7. tpu_inference/core/disagg_utils.py +6 -8
  8. tpu_inference/distributed/tpu_connector.py +3 -4
  9. tpu_inference/distributed/utils.py +3 -2
  10. tpu_inference/envs.py +61 -8
  11. tpu_inference/executors/ray_distributed_executor.py +31 -11
  12. tpu_inference/kernels/fused_moe/v1/kernel.py +641 -110
  13. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +77 -54
  14. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +213 -126
  15. tpu_inference/layers/common/attention_interface.py +7 -1
  16. tpu_inference/layers/common/sharding.py +5 -5
  17. tpu_inference/layers/vllm/fused_moe.py +74 -25
  18. tpu_inference/layers/vllm/quantization/common.py +6 -1
  19. tpu_inference/layers/vllm/quantization/mxfp4.py +137 -62
  20. tpu_inference/layers/vllm/quantization/unquantized.py +107 -113
  21. tpu_inference/layers/vllm/sharding.py +2 -2
  22. tpu_inference/lora/torch_punica_tpu.py +1 -2
  23. tpu_inference/models/common/model_loader.py +45 -11
  24. tpu_inference/models/jax/llama3.py +2 -1
  25. tpu_inference/models/jax/llama_eagle3.py +8 -5
  26. tpu_inference/models/jax/llama_guard_4.py +361 -0
  27. tpu_inference/models/jax/qwen2.py +2 -1
  28. tpu_inference/models/jax/qwen2_5_vl.py +163 -48
  29. tpu_inference/models/jax/qwen3.py +2 -1
  30. tpu_inference/models/jax/utils/quantization/quantization_utils.py +3 -6
  31. tpu_inference/models/jax/utils/weight_utils.py +198 -143
  32. tpu_inference/models/vllm/vllm_model_wrapper.py +14 -7
  33. tpu_inference/platforms/tpu_platform.py +28 -22
  34. tpu_inference/runner/compilation_manager.py +144 -59
  35. tpu_inference/runner/kv_cache_manager.py +17 -18
  36. tpu_inference/runner/persistent_batch_manager.py +40 -2
  37. tpu_inference/runner/structured_decoding_manager.py +2 -3
  38. tpu_inference/runner/tpu_runner.py +271 -147
  39. tpu_inference/runner/utils.py +2 -2
  40. tpu_inference/spec_decode/jax/eagle3.py +71 -21
  41. tpu_inference/tpu_info.py +4 -3
  42. tpu_inference/utils.py +36 -13
  43. tpu_inference/worker/tpu_worker.py +162 -25
  44. {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/METADATA +3 -2
  45. {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/RECORD +48 -53
  46. tpu_inference/mock/__init__.py +0 -0
  47. tpu_inference/mock/vllm_config_utils.py +0 -28
  48. tpu_inference/mock/vllm_envs.py +0 -1219
  49. tpu_inference/mock/vllm_logger.py +0 -212
  50. tpu_inference/mock/vllm_logging_utils.py +0 -15
  51. tpu_inference/models/jax/phi3.py +0 -376
  52. {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/WHEEL +0 -0
  53. {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/licenses/LICENSE +0 -0
  54. {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/top_level.txt +0 -0
tests/kernels/fused_moe_v1_test.py CHANGED
@@ -1,7 +1,7 @@
1
1
  import jax
2
2
  import jax.numpy as jnp
3
3
  import numpy as np
4
- from absl.testing import absltest
4
+ from absl.testing import absltest, parameterized
5
5
  from jax._src import test_util as jtu
6
6
  from jax.sharding import Mesh
7
7
 
@@ -10,6 +10,15 @@ from tpu_inference.kernels.fused_moe.v1.kernel import fused_ep_moe, ref_moe
10
10
  jax.config.parse_flags_with_absl()
11
11
 
12
12
 
13
+ def cdiv(a, b):
14
+ assert b != 0
15
+ return (a + b - 1) // b
16
+
17
+
18
+ def align_to(x, a):
19
+ return cdiv(x, a) * a
20
+
21
+
13
22
  def gen_moe_inputs(
14
23
  dtype,
15
24
  top_k,
@@ -19,11 +28,14 @@ def gen_moe_inputs(
19
28
  num_tokens,
20
29
  *,
21
30
  seed=1234,
31
+ has_bias=False,
22
32
  ):
23
33
  key = jax.random.key(seed)
24
- k0, k1, k2, k4, k5 = jax.random.split(key, 5)
34
+ k0, k1, k2, k3, k4, k5, k6 = jax.random.split(key, 7)
35
+
25
36
  a = jax.random.normal(k0, (num_tokens, hidden_size),
26
37
  dtype=jnp.float32).astype(dtype) / 10
38
+
27
39
  w1 = (jax.random.normal(
28
40
  k1,
29
41
  (num_experts, 2, hidden_size, intermediate_size),
@@ -31,21 +43,54 @@ def gen_moe_inputs(
31
43
  ) / 10).astype(dtype)
32
44
  w2 = (jax.random.normal(k2, (num_experts, intermediate_size, hidden_size),
33
45
  dtype=jnp.float32) / 10).astype(dtype)
46
+
47
+ if has_bias:
48
+ b1 = (jax.random.normal(k3, (num_experts, 2, intermediate_size),
49
+ dtype=jnp.float32) / 10).astype(dtype)
50
+ b2 = (jax.random.normal(k4, (num_experts, hidden_size),
51
+ dtype=jnp.float32) / 10).astype(dtype)
52
+ else:
53
+ b1 = b2 = None
54
+
34
55
  gating_output = (
35
- jax.random.normal(k4, (num_tokens, num_experts), dtype=jnp.float32) +
56
+ jax.random.normal(k5, (num_tokens, num_experts), dtype=jnp.float32) +
36
57
  jnp.arange(num_tokens * num_experts, dtype=jnp.float32).reshape(
37
58
  num_tokens, num_experts) / 100)
59
+
38
60
  # To generate unique top-k!
39
- top_k_indices = jax.random.randint(k5, (num_tokens, top_k),
61
+ top_k_indices = jax.random.randint(k6, (num_tokens, top_k),
40
62
  minval=0,
41
63
  maxval=num_experts - 1,
42
64
  dtype=jnp.int32)
65
+
43
66
  one_hot = (jnp.sum(
44
67
  jax.nn.one_hot(top_k_indices, num_experts, dtype=jnp.float32),
45
68
  axis=1,
46
- ) * 10)
69
+ ) * 30)
70
+
47
71
  gating_output = (gating_output + one_hot).astype(dtype)
48
- return a, w1, w2, gating_output
72
+
73
+ return a, w1, w2, b1, b2, gating_output
74
+
75
+
76
+ def sub_channel_quantize(x, quant_dtype, wsz=256):
77
+ """Quantizes x with sub-channel quantization on the 2nd minor."""
78
+ if jnp.issubdtype(quant_dtype, jnp.floating):
79
+ dtype_info = jnp.finfo(quant_dtype)
80
+ else:
81
+ dtype_info = jnp.iinfo(quant_dtype)
82
+ dtype_max = float(dtype_info.max)
83
+ w_lst, scale_lst = [], []
84
+ assert len(x.shape) >= 2
85
+ assert x.shape[-2] % wsz == 0
86
+ for i in range(0, x.shape[-2], wsz):
87
+ y = x[..., i:i + wsz, :]
88
+ abs_max = jnp.abs(y).max(axis=-2, keepdims=True)
89
+ scale = (abs_max / dtype_max).astype(jnp.float32)
90
+ w = (y / scale).astype(quant_dtype)
91
+ w_lst.append(w)
92
+ scale_lst.append(scale)
93
+ return jnp.concat(w_lst, axis=-2), jnp.concat(scale_lst, axis=-2)
49
94
 
50
95
 
51
96
  @jtu.with_config(jax_numpy_dtype_promotion="standard")
@@ -63,42 +108,266 @@ class MoEKernelTest(jtu.JaxTestCase):
63
108
  self.mesh = Mesh(np.array(self.mesh_devices).reshape(1, -1),
64
109
  axis_names=("data", "model"))
65
110
 
66
- def test_basic(self):
67
- dtype = jnp.bfloat16
68
- top_k = 2
69
- num_experts = 16
70
- hidden_size = 256
71
- intermediate_size = 256
72
- num_tokens = 8 * 2
73
-
74
- a, w1, w2, gating_output = gen_moe_inputs(
111
+ def _test_moe(
112
+ self,
113
+ dtype,
114
+ top_k,
115
+ num_experts,
116
+ hidden_size,
117
+ intermediate_size,
118
+ num_tokens,
119
+ seed,
120
+ renormalize_topk_logits,
121
+ bt,
122
+ bf,
123
+ bd1,
124
+ bd2,
125
+ btc,
126
+ bfc,
127
+ bd1c,
128
+ bd2c,
129
+ act_fn="silu",
130
+ w_dtype=None,
131
+ subc_quant_wsz=None,
132
+ has_bias=False,
133
+ atol=2e-1,
134
+ rtol=2e-1,
135
+ ):
136
+ a, w1, w2, b1, b2, gating_output = gen_moe_inputs(
75
137
  dtype,
76
138
  top_k,
77
139
  num_experts,
78
140
  hidden_size,
79
141
  intermediate_size,
80
142
  num_tokens,
143
+ seed=seed,
144
+ has_bias=has_bias,
81
145
  )
146
+ w1_scale = None
147
+ w2_scale = None
148
+ if w_dtype is not None:
149
+ if subc_quant_wsz is None:
150
+ subc_quant_wsz = 256
151
+ w1, w1_scale = sub_channel_quantize(w1, w_dtype, subc_quant_wsz)
152
+ w2, w2_scale = sub_channel_quantize(w2, w_dtype, subc_quant_wsz)
82
153
 
83
- actual = jax.block_until_ready(
84
- fused_ep_moe(
85
- mesh=self.mesh,
86
- tokens=a,
87
- w1=w1,
88
- w2=w2,
89
- gating_output=gating_output,
90
- top_k=top_k,
91
- bt=32,
92
- bf=512,
93
- bd1=512,
94
- bd2=512,
95
- btc=32,
96
- bfc=256,
97
- bd1c=256,
98
- bd2c=256,
99
- ))
100
- expected = ref_moe(a, w1, w2, gating_output, top_k)
101
- self.assertAllClose(expected, actual, atol=2e-2, rtol=2e-2)
154
+ actual = fused_ep_moe(
155
+ mesh=self.mesh,
156
+ tokens=a,
157
+ w1=w1,
158
+ w2=w2,
159
+ gating_output=gating_output,
160
+ top_k=top_k,
161
+ renormalize_topk_logits=renormalize_topk_logits,
162
+ act_fn=act_fn,
163
+ subc_quant_wsz=subc_quant_wsz,
164
+ w1_scale=w1_scale,
165
+ w2_scale=w2_scale,
166
+ b1=b1,
167
+ b2=b2,
168
+ bt=bt,
169
+ bf=bf,
170
+ bd1=bd1,
171
+ bd2=bd2,
172
+ btc=btc,
173
+ bfc=bfc,
174
+ bd1c=bd1c,
175
+ bd2c=bd2c,
176
+ )
177
+ expected = ref_moe(
178
+ a,
179
+ w1,
180
+ w2,
181
+ gating_output,
182
+ top_k,
183
+ b1=b1,
184
+ b2=b2,
185
+ renormalize_topk_logits=renormalize_topk_logits,
186
+ activation=act_fn,
187
+ subc_quant_wsz=subc_quant_wsz,
188
+ w1_scale=w1_scale,
189
+ w2_scale=w2_scale,
190
+ )
191
+ self.assertAllClose(actual, expected, atol=atol, rtol=rtol)
192
+
193
+ @parameterized.product(renormalize_topk_logits=[True, False], )
194
+ def test_basic(self, renormalize_topk_logits):
195
+ dtype = jnp.bfloat16
196
+ top_k = 8
197
+ num_experts = 128
198
+ hidden_size = 1024
199
+ intermediate_size = 1024
200
+ num_tokens = 8 * 32
201
+ self._test_moe(
202
+ dtype=dtype,
203
+ top_k=top_k,
204
+ num_experts=num_experts,
205
+ hidden_size=hidden_size,
206
+ intermediate_size=intermediate_size,
207
+ num_tokens=num_tokens,
208
+ seed=1234,
209
+ renormalize_topk_logits=renormalize_topk_logits,
210
+ bt=32,
211
+ bf=1024,
212
+ bd1=1024,
213
+ bd2=1024,
214
+ btc=32,
215
+ bfc=256,
216
+ bd1c=256,
217
+ bd2c=256,
218
+ )
219
+
220
+ @parameterized.product(act_fn=["silu", "gelu", "swigluoai"], )
221
+ def test_activation(self, act_fn):
222
+ dtype = jnp.bfloat16
223
+ top_k = 8
224
+ num_experts = 128
225
+ hidden_size = 1024
226
+ intermediate_size = 1024
227
+ num_tokens = 8 * 32
228
+ self._test_moe(
229
+ dtype=dtype,
230
+ top_k=top_k,
231
+ num_experts=num_experts,
232
+ hidden_size=hidden_size,
233
+ intermediate_size=intermediate_size,
234
+ num_tokens=num_tokens,
235
+ seed=1234,
236
+ renormalize_topk_logits=True,
237
+ act_fn=act_fn,
238
+ bt=32,
239
+ bf=512,
240
+ bd1=512,
241
+ bd2=512,
242
+ btc=32,
243
+ bfc=256,
244
+ bd1c=256,
245
+ bd2c=256,
246
+ )
247
+
248
+ def test_benchmark_qwen_235(self):
249
+ num_experts = 128
250
+ top_k = 8
251
+ hidden_size = 4096
252
+ intermediate_size = 1536
253
+ dtype = jnp.bfloat16
254
+ num_tokens = 8 * 64
255
+ seed = 54321
256
+ renormalize_topk_logits = True
257
+ self._test_moe(
258
+ dtype=dtype,
259
+ top_k=top_k,
260
+ num_experts=num_experts,
261
+ hidden_size=hidden_size,
262
+ intermediate_size=intermediate_size,
263
+ num_tokens=num_tokens,
264
+ seed=seed,
265
+ renormalize_topk_logits=renormalize_topk_logits,
266
+ bt=64,
267
+ bf=768,
268
+ bd1=2048,
269
+ bd2=2048,
270
+ btc=64,
271
+ bfc=768,
272
+ bd1c=2048,
273
+ bd2c=2048,
274
+ act_fn="silu",
275
+ atol=5e-2,
276
+ rtol=5e-2,
277
+ )
278
+
279
+ def test_benchmark_qwen_30b_a3b(self):
280
+ num_experts = 128
281
+ top_k = 8
282
+ hidden_size = 2048
283
+ intermediate_size = 768
284
+ dtype = jnp.bfloat16
285
+ num_tokens = 512
286
+ seed = 54321
287
+ renormalize_topk_logits = True
288
+ self._test_moe(
289
+ dtype=dtype,
290
+ top_k=top_k,
291
+ num_experts=num_experts,
292
+ hidden_size=hidden_size,
293
+ intermediate_size=intermediate_size,
294
+ num_tokens=num_tokens,
295
+ seed=seed,
296
+ renormalize_topk_logits=renormalize_topk_logits,
297
+ bt=16,
298
+ bf=384,
299
+ bd1=512,
300
+ bd2=512,
301
+ btc=16,
302
+ bfc=384,
303
+ bd1c=256,
304
+ bd2c=256,
305
+ act_fn="silu",
306
+ atol=5e-2,
307
+ rtol=5e-2,
308
+ )
309
+
310
+ @parameterized.product(
311
+ w_dtype=[jnp.int8, jnp.float8_e5m2, jnp.float4_e2m1fn], )
312
+ def test_sub_channel_quantization(self, w_dtype):
313
+ if w_dtype in (
314
+ jnp.float8_e5m2,
315
+ jnp.float4_e2m1fn,
316
+ ) and not jtu.is_device_tpu_at_least(version=7):
317
+ self.skipTest("Expect TPUv7+")
318
+ dtype = jnp.bfloat16
319
+ top_k = 8
320
+ num_experts = 128
321
+ hidden_size = 1024
322
+ intermediate_size = 1024
323
+ num_tokens = 8 * 32
324
+ self._test_moe(
325
+ dtype=dtype,
326
+ top_k=top_k,
327
+ num_experts=num_experts,
328
+ hidden_size=hidden_size,
329
+ intermediate_size=intermediate_size,
330
+ num_tokens=num_tokens,
331
+ seed=1234,
332
+ renormalize_topk_logits=False,
333
+ w_dtype=w_dtype,
334
+ subc_quant_wsz=256,
335
+ bt=32,
336
+ bf=1024,
337
+ bd1=1024,
338
+ bd2=1024,
339
+ btc=32,
340
+ bfc=256,
341
+ bd1c=256,
342
+ bd2c=256,
343
+ )
344
+
345
+ def test_bias(self):
346
+ dtype = jnp.bfloat16
347
+ top_k = 8
348
+ num_experts = 128
349
+ hidden_size = 1024
350
+ intermediate_size = 1024
351
+ num_tokens = 8 * 32
352
+ self._test_moe(
353
+ dtype=dtype,
354
+ top_k=top_k,
355
+ num_experts=num_experts,
356
+ hidden_size=hidden_size,
357
+ intermediate_size=intermediate_size,
358
+ num_tokens=num_tokens,
359
+ seed=1234,
360
+ renormalize_topk_logits=False,
361
+ has_bias=True,
362
+ bt=32,
363
+ bf=512,
364
+ bd1=512,
365
+ bd2=512,
366
+ btc=32,
367
+ bfc=256,
368
+ bd1c=256,
369
+ bd2c=256,
370
+ )
102
371
 
103
372
 
104
373
  if __name__ == "__main__":
tests/lora/test_layers.py CHANGED
@@ -91,7 +91,6 @@ def populate_loras(
91
91
  index_to_id: list[Optional[int]],
92
92
  lora_layer: BaseLayerWithLoRA,
93
93
  baselayer_weights: torch.Tensor,
94
- generate_embeddings_tensor: int = 0,
95
94
  repeats: int = 1,
96
95
  ) -> tuple[dict[int, LoRALayerWeights], dict[int, list[LoRALayerWeights]]]:
97
96
  """This method populates the lora weights (lora_a and lora_b) in the lora layers (BaseLayerWithLoRA).
@@ -103,8 +102,6 @@ def populate_loras(
103
102
  lora_layer: the LoRAlayer to populate.
104
103
  baselayer_weights: the PyTorch tensor containing the layer's
105
104
  weights.
106
- generate_embeddings_tensor: whether to generate an
107
- embeddings tensor for each LoRA.
108
105
  repeats: must only be set for column parallel packed
109
106
  layers. Indicates the number of loras to compose
110
107
  together to create a single lora layer.
@@ -131,7 +128,6 @@ def populate_loras(
131
128
  baselayer_weights.device).init_random_lora(
132
129
  module_name=f"fake_{i}",
133
130
  weight=baselayer_weights,
134
- generate_embeddings_tensor=generate_embeddings_tensor,
135
131
  )
136
132
  sublora.lora_b = sublora.lora_b[(sublora_len *
137
133
  i):(sublora_len * (i + 1)), :]
@@ -147,7 +143,6 @@ def populate_loras(
147
143
  slot_idx,
148
144
  lora_a=lora.lora_a,
149
145
  lora_b=lora.lora_b,
150
- embeddings_tensor=lora.embeddings_tensor,
151
146
  )
152
147
 
153
148
  lora_dict[lora_id] = lora
@@ -546,7 +541,6 @@ def _update_punica_wrapper_metadata(punica_wrapper, index_mapping,
546
541
  index_to_id,
547
542
  lora_config.max_loras,
548
543
  vocab_size=512,
549
- extra_vocab_size=lora_config.lora_extra_vocab_size,
550
544
  )
551
545
  assert jax_view(punica_wrapper._lora_indices_per_batch).platform(
552
546
  ) == 'tpu', 'punica_wrapper._lora_indices_per_batch should have been moved to TPU.'
tests/lora/utils.py CHANGED
@@ -24,7 +24,6 @@ class DummyLoRAManager:
24
24
  module_name: str,
25
25
  weight: torch.Tensor,
26
26
  rank: int = 8,
27
- generate_embeddings_tensor: int = 0,
28
27
  ):
29
28
  lora = LoRALayerWeights(
30
29
  module_name,
@@ -37,13 +36,6 @@ class DummyLoRAManager:
37
36
  dtype=weight.dtype,
38
37
  device=self._device),
39
38
  )
40
- if generate_embeddings_tensor:
41
- lora.embeddings_tensor = torch.rand(
42
- 5,
43
- generate_embeddings_tensor,
44
- dtype=weight.dtype,
45
- device=self._device,
46
- )
47
39
  self.set_module_lora(module_name, lora)
48
40
 
49
41
  return lora
tests/test_envs.py CHANGED
@@ -56,6 +56,12 @@ def test_getattr_with_cache(monkeypatch: pytest.MonkeyPatch):
56
56
 
57
57
 
58
58
  def test_boolean_env_vars(monkeypatch: pytest.MonkeyPatch):
59
+ # Ensure clean environment for boolean vars by setting to default "0"
60
+ monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "0")
61
+ monkeypatch.setenv("VLLM_XLA_CHECK_RECOMPILATION", "0")
62
+ monkeypatch.setenv("NEW_MODEL_DESIGN", "0")
63
+ monkeypatch.setenv("USE_MOE_EP_KERNEL", "0")
64
+
59
65
  # Test SKIP_JAX_PRECOMPILE (default False)
60
66
  assert envs.SKIP_JAX_PRECOMPILE is False
61
67
  monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "1")
@@ -63,6 +69,13 @@ def test_boolean_env_vars(monkeypatch: pytest.MonkeyPatch):
63
69
  monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "0")
64
70
  assert envs.SKIP_JAX_PRECOMPILE is False
65
71
 
72
+ # Test VLLM_XLA_CHECK_RECOMPILATION (default False)
73
+ assert envs.VLLM_XLA_CHECK_RECOMPILATION is False
74
+ monkeypatch.setenv("VLLM_XLA_CHECK_RECOMPILATION", "1")
75
+ assert envs.VLLM_XLA_CHECK_RECOMPILATION is True
76
+ monkeypatch.setenv("VLLM_XLA_CHECK_RECOMPILATION", "0")
77
+ assert envs.VLLM_XLA_CHECK_RECOMPILATION is False
78
+
66
79
  # Test NEW_MODEL_DESIGN (default False)
67
80
  assert envs.NEW_MODEL_DESIGN is False
68
81
  monkeypatch.setenv("NEW_MODEL_DESIGN", "1")
@@ -75,20 +88,32 @@ def test_boolean_env_vars(monkeypatch: pytest.MonkeyPatch):
75
88
 
76
89
 
77
90
  def test_integer_env_vars(monkeypatch: pytest.MonkeyPatch):
91
+ # Ensure clean environment for integer vars by setting to defaults
92
+ monkeypatch.setenv("PYTHON_TRACER_LEVEL", "1")
93
+ monkeypatch.setenv("NUM_SLICES", "1")
94
+
78
95
  assert envs.PYTHON_TRACER_LEVEL == 1
79
96
  monkeypatch.setenv("PYTHON_TRACER_LEVEL", "3")
80
97
  assert envs.PYTHON_TRACER_LEVEL == 3
81
98
  monkeypatch.setenv("PYTHON_TRACER_LEVEL", "0")
82
99
  assert envs.PYTHON_TRACER_LEVEL == 0
83
100
 
101
+ # Test NUM_SLICES (default 1)
102
+ assert envs.NUM_SLICES == 1
103
+ monkeypatch.setenv("NUM_SLICES", "2")
104
+ assert envs.NUM_SLICES == 2
105
+ monkeypatch.setenv("NUM_SLICES", "4")
106
+ assert envs.NUM_SLICES == 4
84
107
 
85
- def test_lowercase_conversion(monkeypatch: pytest.MonkeyPatch):
86
- monkeypatch.setenv("TPU_MULTIHOST_BACKEND", "GRPC")
87
- assert envs.TPU_MULTIHOST_BACKEND == "grpc"
88
108
 
89
- monkeypatch.setenv("MODEL_IMPL_TYPE", "FLAX_NNX")
109
+ def test_model_impl_type_choices(monkeypatch: pytest.MonkeyPatch):
110
+ # Test case sensitive choices
111
+ monkeypatch.setenv("MODEL_IMPL_TYPE", "flax_nnx")
90
112
  assert envs.MODEL_IMPL_TYPE == "flax_nnx"
91
113
 
114
+ monkeypatch.setenv("MODEL_IMPL_TYPE", "vllm")
115
+ assert envs.MODEL_IMPL_TYPE == "vllm"
116
+
92
117
 
93
118
  def test_string_env_vars_defaults(monkeypatch: pytest.MonkeyPatch):
94
119
  monkeypatch.delenv("JAX_PLATFORMS", raising=False)
@@ -117,8 +142,6 @@ def test_ray_env_vars(monkeypatch: pytest.MonkeyPatch):
117
142
  assert envs.RAY_USAGE_STATS_ENABLED == "1"
118
143
 
119
144
  assert envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE == "shm"
120
- monkeypatch.setenv("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE", "nccl")
121
- assert envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE == "nccl"
122
145
 
123
146
 
124
147
  def test_invalid_attribute_raises_error():
@@ -134,6 +157,7 @@ def test_dir_returns_all_env_vars():
134
157
  assert "JAX_PLATFORMS" in env_vars
135
158
  assert "TPU_NAME" in env_vars
136
159
  assert "SKIP_JAX_PRECOMPILE" in env_vars
160
+ assert "VLLM_XLA_CHECK_RECOMPILATION" in env_vars
137
161
  assert "MODEL_IMPL_TYPE" in env_vars
138
162
 
139
163
 
@@ -141,11 +165,8 @@ def test_tpu_multihost_env_vars(monkeypatch: pytest.MonkeyPatch):
141
165
  monkeypatch.setenv("TPU_WORKER_ID", "0")
142
166
  assert envs.TPU_WORKER_ID == "0"
143
167
 
144
- monkeypatch.setenv("TPU_MULTIHOST_BACKEND", "grpc")
145
- assert envs.TPU_MULTIHOST_BACKEND == "grpc"
146
-
147
- monkeypatch.setenv("TPU_MULTIHOST_BACKEND", "xla")
148
- assert envs.TPU_MULTIHOST_BACKEND == "xla"
168
+ monkeypatch.setenv("TPU_MULTIHOST_BACKEND", "ray")
169
+ assert envs.TPU_MULTIHOST_BACKEND == "ray"
149
170
 
150
171
 
151
172
  def test_disaggregated_serving_env_vars(monkeypatch: pytest.MonkeyPatch):
tests/test_utils.py CHANGED
@@ -231,6 +231,5 @@ def test_get_jax_dtype_from_str_dtype():
231
231
  assert get_jax_dtype_from_str_dtype("int8") == jnp.int8
232
232
  assert get_jax_dtype_from_str_dtype("bfloat16") == jnp.bfloat16
233
233
  assert get_jax_dtype_from_str_dtype("fp8") == jnp.float8_e4m3fn
234
- assert get_jax_dtype_from_str_dtype("fp8_e4m3") == jnp.float8_e4m3
234
+ assert get_jax_dtype_from_str_dtype("fp8_e4m3") == jnp.float8_e4m3fn
235
235
  assert get_jax_dtype_from_str_dtype("fp8_e5m2") == jnp.float8_e5m2
236
- assert get_jax_dtype_from_str_dtype("auto") is None
tpu_inference/__init__.py CHANGED
@@ -1,21 +1,40 @@
1
- import os
2
-
3
1
  # The environment variables override should be imported before any other
4
2
  # modules to ensure that the environment variables are set before any
5
3
  # other modules are imported.
6
4
  import tpu_inference.env_override # noqa: F401
5
+ from tpu_inference import envs
7
6
  from tpu_inference import tpu_info as ti
8
7
  from tpu_inference.logger import init_logger
9
8
 
10
9
  logger = init_logger(__name__)
11
10
 
12
- if "proxy" in os.environ.get('JAX_PLATFORMS', '').lower():
11
+ if "proxy" in envs.JAX_PLATFORMS:
13
12
  logger.info("Running vLLM on TPU via Pathways proxy.")
14
13
  # Must run pathwaysutils.initialize() before any JAX operations
15
14
  try:
15
+ import traceback
16
+
16
17
  import pathwaysutils
18
+ import vllm
19
+ from vllm.platforms import (resolve_current_platform_cls_qualname,
20
+ resolve_obj_by_qualname)
17
21
  pathwaysutils.initialize()
18
22
  logger.info("Module pathwaysutils is imported.")
23
+
24
+ # Pathways requires eager resolution of vllm.current_platform instead of
25
+ # lazy resolution in the normal code path. Since this part involves
26
+ # global topology discovery across multiple hosts, the platform
27
+ # resolution must happen before other components are loaded.
28
+ logger.info("Eagerly resolving vLLM current_platform for Pathways.")
29
+ platform_cls_qualname = resolve_current_platform_cls_qualname()
30
+ resolved_platform_instance = resolve_obj_by_qualname(
31
+ platform_cls_qualname)()
32
+ vllm.platforms._current_platform = resolved_platform_instance
33
+ vllm.platforms._init_trace = "".join(traceback.format_stack())
34
+ logger.info(
35
+ f"vLLM platform resolved to: {resolved_platform_instance.__class__.__name__}"
36
+ )
37
+
19
38
  except Exception as e:
20
39
  logger.error(
21
40
  f"Error occurred while importing pathwaysutils or logging TPU info: {e}"
tpu_inference/core/disagg_utils.py CHANGED
@@ -1,17 +1,15 @@
1
1
  # SPDX-License-Identifier: Apache-2.0
2
2
 
3
- import os
4
3
  from typing import Tuple
5
4
 
6
- PREFILL_SLICES = 'PREFILL_SLICES'
7
- DECODE_SLICES = 'DECODE_SLICES'
5
+ from tpu_inference import envs
8
6
 
9
7
 
10
8
  def is_disagg_enabled() -> bool:
11
9
  # We triggrer our code path as long as prefill slices are set. This
12
10
  # allows us to test interleave mode effectively with the code path
13
11
  # for comparison purposes.
14
- return PREFILL_SLICES in os.environ
12
+ return bool(envs.PREFILL_SLICES)
15
13
 
16
14
 
17
15
  def _parse_slices(slices_str: str) -> Tuple[int, ...]:
@@ -40,12 +38,12 @@ def _parse_slices(slices_str: str) -> Tuple[int, ...]:
40
38
 
41
39
 
42
40
  def get_prefill_slices() -> Tuple[int, ...]:
43
- if PREFILL_SLICES not in os.environ:
41
+ if not envs.PREFILL_SLICES:
44
42
  return ()
45
- return _parse_slices(os.environ[PREFILL_SLICES])
43
+ return _parse_slices(envs.PREFILL_SLICES)
46
44
 
47
45
 
48
46
  def get_decode_slices() -> Tuple[int, ...]:
49
- if DECODE_SLICES not in os.environ:
47
+ if not envs.DECODE_SLICES:
50
48
  return ()
51
- return _parse_slices(os.environ[DECODE_SLICES])
49
+ return _parse_slices(envs.DECODE_SLICES)
tpu_inference/distributed/tpu_connector.py CHANGED
@@ -60,7 +60,6 @@ D workflow:
60
60
 
61
61
  import copy
62
62
  import functools
63
- import os
64
63
  import threading
65
64
  import time
66
65
  from concurrent.futures import Future, ThreadPoolExecutor
@@ -86,6 +85,7 @@ if TYPE_CHECKING:
86
85
  from vllm.v1.core.kv_cache_manager import KVCacheBlocks
87
86
  from vllm.v1.request import Request
88
87
 
88
+ from tpu_inference import envs
89
89
  from tpu_inference.distributed.utils import (get_host_ip, get_kv_ips,
90
90
  get_kv_ports,
91
91
  get_kv_transfer_port, get_node_id,
@@ -441,8 +441,7 @@ class TPUConnectorWorker:
441
441
 
442
442
  self.runner: TPUModelRunner = None
443
443
  self.mesh: Mesh = None
444
- self.multi_host = os.getenv("TPU_MULTIHOST_BACKEND",
445
- "").lower() == "ray"
444
+ self.multi_host = envs.TPU_MULTIHOST_BACKEND == "ray"
446
445
  # NOTE(xiang): This can not be the worker rank set in RayDistributedExecutor.
447
446
  # The worker rank is assigned with vLLM's sorting logic, which does not work
448
447
  # for TPU host topology.
@@ -458,7 +457,6 @@ class TPUConnectorWorker:
458
457
  self.side_channel_port = get_side_channel_port()
459
458
 
460
459
  self.kv_transfer_server = None
461
- self._maybe_start_p2p_server()
462
460
  self.zmq_cxt = zmq.Context()
463
461
  if self.is_producer:
464
462
  ready_event = threading.Event()
@@ -500,6 +498,7 @@ class TPUConnectorWorker:
500
498
  self.shape = list(kv_layer.shape)
501
499
  self.dtype = kv_layer.dtype
502
500
  self.sharding = kv_layer.sharding
501
+ self._maybe_start_p2p_server()
503
502
 
504
503
  def _maybe_start_p2p_server(self):
505
504
  if self.kv_transfer_server is not None: