tpu-inference 0.0.1rc1__py3-none-any.whl → 0.11.1.dev202511130813__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in the registry.

Potentially problematic release.

Files changed (67)
  1. tests/kernels/fused_moe_v1_test.py +34 -303
  2. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +2 -2
  3. tests/lora/test_layers.py +6 -0
  4. tests/lora/utils.py +8 -0
  5. tests/test_utils.py +16 -24
  6. tpu_inference/__init__.py +3 -22
  7. tpu_inference/core/core_tpu.py +9 -17
  8. tpu_inference/core/disagg_utils.py +8 -6
  9. tpu_inference/distributed/tpu_connector.py +4 -3
  10. tpu_inference/distributed/utils.py +2 -3
  11. tpu_inference/envs.py +8 -61
  12. tpu_inference/executors/ray_distributed_executor.py +11 -31
  13. tpu_inference/kernels/fused_moe/v1/kernel.py +110 -641
  14. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +54 -77
  15. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +143 -287
  16. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +0 -7
  17. tpu_inference/layers/jax/attention/attention.py +1 -1
  18. tpu_inference/layers/{common → jax}/attention_interface.py +2 -8
  19. tpu_inference/layers/jax/sample/rejection_sampler.py +1 -1
  20. tpu_inference/layers/jax/sample/sampling.py +2 -2
  21. tpu_inference/layers/{common → jax}/sharding.py +5 -5
  22. tpu_inference/layers/vllm/attention.py +1 -1
  23. tpu_inference/layers/vllm/fused_moe.py +208 -170
  24. tpu_inference/layers/vllm/quantization/__init__.py +3 -7
  25. tpu_inference/layers/vllm/quantization/awq.py +3 -4
  26. tpu_inference/layers/vllm/quantization/common.py +1 -6
  27. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +2 -4
  28. tpu_inference/layers/vllm/quantization/unquantized.py +67 -62
  29. tpu_inference/layers/vllm/sharding.py +2 -2
  30. tpu_inference/lora/torch_punica_tpu.py +2 -1
  31. tpu_inference/mock/__init__.py +0 -0
  32. tpu_inference/mock/vllm_config_utils.py +28 -0
  33. tpu_inference/mock/vllm_envs.py +1219 -0
  34. tpu_inference/mock/vllm_logger.py +212 -0
  35. tpu_inference/mock/vllm_logging_utils.py +15 -0
  36. tpu_inference/models/common/model_loader.py +12 -46
  37. tpu_inference/models/jax/llama3.py +3 -4
  38. tpu_inference/models/jax/llama_eagle3.py +5 -8
  39. tpu_inference/models/jax/phi3.py +376 -0
  40. tpu_inference/models/jax/qwen2.py +2 -3
  41. tpu_inference/models/jax/qwen2_5_vl.py +50 -165
  42. tpu_inference/models/jax/qwen3.py +2 -3
  43. tpu_inference/models/jax/utils/quantization/quantization_utils.py +6 -3
  44. tpu_inference/models/jax/utils/weight_utils.py +143 -198
  45. tpu_inference/models/vllm/vllm_model_wrapper.py +14 -32
  46. tpu_inference/platforms/tpu_platform.py +34 -47
  47. tpu_inference/runner/compilation_manager.py +60 -145
  48. tpu_inference/runner/kv_cache.py +2 -2
  49. tpu_inference/runner/kv_cache_manager.py +18 -17
  50. tpu_inference/runner/persistent_batch_manager.py +2 -40
  51. tpu_inference/runner/structured_decoding_manager.py +3 -2
  52. tpu_inference/runner/tpu_runner.py +135 -283
  53. tpu_inference/runner/utils.py +2 -2
  54. tpu_inference/spec_decode/jax/eagle3.py +21 -71
  55. tpu_inference/tpu_info.py +3 -4
  56. tpu_inference/utils.py +15 -38
  57. tpu_inference/worker/tpu_worker.py +26 -163
  58. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/METADATA +3 -4
  59. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/RECORD +63 -61
  60. tests/test_envs.py +0 -203
  61. tpu_inference/layers/common/quant_methods.py +0 -8
  62. tpu_inference/layers/vllm/quantization/mxfp4.py +0 -331
  63. tpu_inference/models/jax/llama_guard_4.py +0 -361
  64. /tpu_inference/layers/{common → jax}/binary_search.py +0 -0
  65. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/WHEEL +0 -0
  66. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/licenses/LICENSE +0 -0
  67. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/top_level.txt +0 -0
tests/kernels/fused_moe_v1_test.py CHANGED
@@ -1,7 +1,7 @@
 import jax
 import jax.numpy as jnp
 import numpy as np
-from absl.testing import absltest, parameterized
+from absl.testing import absltest
 from jax._src import test_util as jtu
 from jax.sharding import Mesh
 
@@ -10,15 +10,6 @@ from tpu_inference.kernels.fused_moe.v1.kernel import fused_ep_moe, ref_moe
 jax.config.parse_flags_with_absl()
 
 
-def cdiv(a, b):
-    assert b != 0
-    return (a + b - 1) // b
-
-
-def align_to(x, a):
-    return cdiv(x, a) * a
-
-
 def gen_moe_inputs(
     dtype,
     top_k,
@@ -28,14 +19,11 @@ def gen_moe_inputs(
     num_tokens,
     *,
     seed=1234,
-    has_bias=False,
 ):
     key = jax.random.key(seed)
-    k0, k1, k2, k3, k4, k5, k6 = jax.random.split(key, 7)
-
+    k0, k1, k2, k4, k5 = jax.random.split(key, 5)
     a = jax.random.normal(k0, (num_tokens, hidden_size),
                           dtype=jnp.float32).astype(dtype) / 10
-
     w1 = (jax.random.normal(
         k1,
         (num_experts, 2, hidden_size, intermediate_size),
@@ -43,54 +31,21 @@ def gen_moe_inputs(
     ) / 10).astype(dtype)
     w2 = (jax.random.normal(k2, (num_experts, intermediate_size, hidden_size),
                             dtype=jnp.float32) / 10).astype(dtype)
-
-    if has_bias:
-        b1 = (jax.random.normal(k3, (num_experts, 2, intermediate_size),
-                                dtype=jnp.float32) / 10).astype(dtype)
-        b2 = (jax.random.normal(k4, (num_experts, hidden_size),
-                                dtype=jnp.float32) / 10).astype(dtype)
-    else:
-        b1 = b2 = None
-
     gating_output = (
-        jax.random.normal(k5, (num_tokens, num_experts), dtype=jnp.float32) +
+        jax.random.normal(k4, (num_tokens, num_experts), dtype=jnp.float32) +
         jnp.arange(num_tokens * num_experts, dtype=jnp.float32).reshape(
             num_tokens, num_experts) / 100)
-
     # To generate unique top-k!
-    top_k_indices = jax.random.randint(k6, (num_tokens, top_k),
+    top_k_indices = jax.random.randint(k5, (num_tokens, top_k),
                                        minval=0,
                                        maxval=num_experts - 1,
                                        dtype=jnp.int32)
-
     one_hot = (jnp.sum(
         jax.nn.one_hot(top_k_indices, num_experts, dtype=jnp.float32),
         axis=1,
-    ) * 30)
-
+    ) * 10)
     gating_output = (gating_output + one_hot).astype(dtype)
-
-    return a, w1, w2, b1, b2, gating_output
-
-
-def sub_channel_quantize(x, quant_dtype, wsz=256):
-    """Quantizes x with sub-channel quantization on the 2nd minor."""
-    if jnp.issubdtype(quant_dtype, jnp.floating):
-        dtype_info = jnp.finfo(quant_dtype)
-    else:
-        dtype_info = jnp.iinfo(quant_dtype)
-    dtype_max = float(dtype_info.max)
-    w_lst, scale_lst = [], []
-    assert len(x.shape) >= 2
-    assert x.shape[-2] % wsz == 0
-    for i in range(0, x.shape[-2], wsz):
-        y = x[..., i:i + wsz, :]
-        abs_max = jnp.abs(y).max(axis=-2, keepdims=True)
-        scale = (abs_max / dtype_max).astype(jnp.float32)
-        w = (y / scale).astype(quant_dtype)
-        w_lst.append(w)
-        scale_lst.append(scale)
-    return jnp.concat(w_lst, axis=-2), jnp.concat(scale_lst, axis=-2)
+    return a, w1, w2, gating_output
 
 
 @jtu.with_config(jax_numpy_dtype_promotion="standard")
@@ -108,266 +63,42 @@ class MoEKernelTest(jtu.JaxTestCase):
         self.mesh = Mesh(np.array(self.mesh_devices).reshape(1, -1),
                          axis_names=("data", "model"))
 
-    def _test_moe(
-        self,
-        dtype,
-        top_k,
-        num_experts,
-        hidden_size,
-        intermediate_size,
-        num_tokens,
-        seed,
-        renormalize_topk_logits,
-        bt,
-        bf,
-        bd1,
-        bd2,
-        btc,
-        bfc,
-        bd1c,
-        bd2c,
-        act_fn="silu",
-        w_dtype=None,
-        subc_quant_wsz=None,
-        has_bias=False,
-        atol=2e-1,
-        rtol=2e-1,
-    ):
-        a, w1, w2, b1, b2, gating_output = gen_moe_inputs(
+    def test_basic(self):
+        dtype = jnp.bfloat16
+        top_k = 2
+        num_experts = 16
+        hidden_size = 256
+        intermediate_size = 256
+        num_tokens = 8 * 2
+
+        a, w1, w2, gating_output = gen_moe_inputs(
             dtype,
             top_k,
             num_experts,
             hidden_size,
             intermediate_size,
             num_tokens,
-            seed=seed,
-            has_bias=has_bias,
         )
-        w1_scale = None
-        w2_scale = None
-        if w_dtype is not None:
-            if subc_quant_wsz is None:
-                subc_quant_wsz = 256
-            w1, w1_scale = sub_channel_quantize(w1, w_dtype, subc_quant_wsz)
-            w2, w2_scale = sub_channel_quantize(w2, w_dtype, subc_quant_wsz)
 
-        actual = fused_ep_moe(
-            mesh=self.mesh,
-            tokens=a,
-            w1=w1,
-            w2=w2,
-            gating_output=gating_output,
-            top_k=top_k,
-            renormalize_topk_logits=renormalize_topk_logits,
-            act_fn=act_fn,
-            subc_quant_wsz=subc_quant_wsz,
-            w1_scale=w1_scale,
-            w2_scale=w2_scale,
-            b1=b1,
-            b2=b2,
-            bt=bt,
-            bf=bf,
-            bd1=bd1,
-            bd2=bd2,
-            btc=btc,
-            bfc=bfc,
-            bd1c=bd1c,
-            bd2c=bd2c,
-        )
-        expected = ref_moe(
-            a,
-            w1,
-            w2,
-            gating_output,
-            top_k,
-            b1=b1,
-            b2=b2,
-            renormalize_topk_logits=renormalize_topk_logits,
-            activation=act_fn,
-            subc_quant_wsz=subc_quant_wsz,
-            w1_scale=w1_scale,
-            w2_scale=w2_scale,
-        )
-        self.assertAllClose(actual, expected, atol=atol, rtol=rtol)
-
-    @parameterized.product(renormalize_topk_logits=[True, False], )
-    def test_basic(self, renormalize_topk_logits):
-        dtype = jnp.bfloat16
-        top_k = 8
-        num_experts = 128
-        hidden_size = 1024
-        intermediate_size = 1024
-        num_tokens = 8 * 32
-        self._test_moe(
-            dtype=dtype,
-            top_k=top_k,
-            num_experts=num_experts,
-            hidden_size=hidden_size,
-            intermediate_size=intermediate_size,
-            num_tokens=num_tokens,
-            seed=1234,
-            renormalize_topk_logits=renormalize_topk_logits,
-            bt=32,
-            bf=1024,
-            bd1=1024,
-            bd2=1024,
-            btc=32,
-            bfc=256,
-            bd1c=256,
-            bd2c=256,
-        )
-
-    @parameterized.product(act_fn=["silu", "gelu", "swigluoai"], )
-    def test_activation(self, act_fn):
-        dtype = jnp.bfloat16
-        top_k = 8
-        num_experts = 128
-        hidden_size = 1024
-        intermediate_size = 1024
-        num_tokens = 8 * 32
-        self._test_moe(
-            dtype=dtype,
-            top_k=top_k,
-            num_experts=num_experts,
-            hidden_size=hidden_size,
-            intermediate_size=intermediate_size,
-            num_tokens=num_tokens,
-            seed=1234,
-            renormalize_topk_logits=True,
-            act_fn=act_fn,
-            bt=32,
-            bf=512,
-            bd1=512,
-            bd2=512,
-            btc=32,
-            bfc=256,
-            bd1c=256,
-            bd2c=256,
-        )
-
-    def test_benchmark_qwen_235(self):
-        num_experts = 128
-        top_k = 8
-        hidden_size = 4096
-        intermediate_size = 1536
-        dtype = jnp.bfloat16
-        num_tokens = 8 * 64
-        seed = 54321
-        renormalize_topk_logits = True
-        self._test_moe(
-            dtype=dtype,
-            top_k=top_k,
-            num_experts=num_experts,
-            hidden_size=hidden_size,
-            intermediate_size=intermediate_size,
-            num_tokens=num_tokens,
-            seed=seed,
-            renormalize_topk_logits=renormalize_topk_logits,
-            bt=64,
-            bf=768,
-            bd1=2048,
-            bd2=2048,
-            btc=64,
-            bfc=768,
-            bd1c=2048,
-            bd2c=2048,
-            act_fn="silu",
-            atol=5e-2,
-            rtol=5e-2,
-        )
-
-    def test_benchmark_qwen_30b_a3b(self):
-        num_experts = 128
-        top_k = 8
-        hidden_size = 2048
-        intermediate_size = 768
-        dtype = jnp.bfloat16
-        num_tokens = 512
-        seed = 54321
-        renormalize_topk_logits = True
-        self._test_moe(
-            dtype=dtype,
-            top_k=top_k,
-            num_experts=num_experts,
-            hidden_size=hidden_size,
-            intermediate_size=intermediate_size,
-            num_tokens=num_tokens,
-            seed=seed,
-            renormalize_topk_logits=renormalize_topk_logits,
-            bt=16,
-            bf=384,
-            bd1=512,
-            bd2=512,
-            btc=16,
-            bfc=384,
-            bd1c=256,
-            bd2c=256,
-            act_fn="silu",
-            atol=5e-2,
-            rtol=5e-2,
-        )
-
-    @parameterized.product(
-        w_dtype=[jnp.int8, jnp.float8_e5m2, jnp.float4_e2m1fn], )
-    def test_sub_channel_quantization(self, w_dtype):
-        if w_dtype in (
-                jnp.float8_e5m2,
-                jnp.float4_e2m1fn,
-        ) and not jtu.is_device_tpu_at_least(version=7):
-            self.skipTest("Expect TPUv7+")
-        dtype = jnp.bfloat16
-        top_k = 8
-        num_experts = 128
-        hidden_size = 1024
-        intermediate_size = 1024
-        num_tokens = 8 * 32
-        self._test_moe(
-            dtype=dtype,
-            top_k=top_k,
-            num_experts=num_experts,
-            hidden_size=hidden_size,
-            intermediate_size=intermediate_size,
-            num_tokens=num_tokens,
-            seed=1234,
-            renormalize_topk_logits=False,
-            w_dtype=w_dtype,
-            subc_quant_wsz=256,
-            bt=32,
-            bf=1024,
-            bd1=1024,
-            bd2=1024,
-            btc=32,
-            bfc=256,
-            bd1c=256,
-            bd2c=256,
-        )
-
-    def test_bias(self):
-        dtype = jnp.bfloat16
-        top_k = 8
-        num_experts = 128
-        hidden_size = 1024
-        intermediate_size = 1024
-        num_tokens = 8 * 32
-        self._test_moe(
-            dtype=dtype,
-            top_k=top_k,
-            num_experts=num_experts,
-            hidden_size=hidden_size,
-            intermediate_size=intermediate_size,
-            num_tokens=num_tokens,
-            seed=1234,
-            renormalize_topk_logits=False,
-            has_bias=True,
-            bt=32,
-            bf=512,
-            bd1=512,
-            bd2=512,
-            btc=32,
-            bfc=256,
-            bd1c=256,
-            bd2c=256,
-        )
+        actual = jax.block_until_ready(
+            fused_ep_moe(
+                mesh=self.mesh,
+                tokens=a,
+                w1=w1,
+                w2=w2,
+                gating_output=gating_output,
+                top_k=top_k,
+                bt=32,
+                bf=512,
+                bd1=512,
+                bd2=512,
+                btc=32,
+                bfc=256,
+                bd1c=256,
+                bd2c=256,
+            ))
+        expected = ref_moe(a, w1, w2, gating_output, top_k)
+        self.assertAllClose(expected, actual, atol=2e-2, rtol=2e-2)
 
 
 if __name__ == "__main__":
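Editor's note on the dropped quantization coverage: the removed `sub_channel_quantize` helper scaled each block of `wsz` rows along the second-minor axis by that block's abs-max, returning the quantized weights plus a `(..., num_blocks, N)` scale array. A minimal sketch of the implied dequantization, assuming that scale layout (the helper name here is hypothetical; the kernel consumes `w`/`scale` directly via `w1_scale`/`w2_scale`):

import jax.numpy as jnp

def sub_channel_dequantize(w, scale, wsz=256):
    # Inverse of the removed sub_channel_quantize: block j of wsz rows
    # along the second-minor axis shares the single scale row j.
    blocks = []
    for j, i in enumerate(range(0, w.shape[-2], wsz)):
        y = w[..., i:i + wsz, :].astype(jnp.float32)
        blocks.append(y * scale[..., j:j + 1, :])
    return jnp.concatenate(blocks, axis=-2)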
tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py CHANGED
@@ -99,7 +99,7 @@ class RaggedPagedAttentionHeadDim64KernelTest(jtu.JaxTestCase):
                 (0, 0),
                 (0, 0),
             ),
-            constant_values=0,
+            constant_values=jnp.nan,
         ).reshape(
             -1,
             page_size,
@@ -122,7 +122,7 @@ class RaggedPagedAttentionHeadDim64KernelTest(jtu.JaxTestCase):
            kv_cache,
            ((0, num_pages - kv_cache.shape[0]), (0, 0), (0, 0), (0, 0),
             (0, 0)),
-            constant_values=0,
+            constant_values=jnp.nan,
        )
        page_indices = jnp.stack(page_indices_list, axis=0)
        page_indices = jnp.pad(
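Switching the pad value from 0 to `jnp.nan` makes any kernel read that touches padding pages poison the output instead of silently contributing zeros. A standalone illustration of the effect (not from the kernel itself):

import jax.numpy as jnp

# Two real pages plus one NaN padding page.
kv = jnp.pad(jnp.ones((2, 4)), ((0, 1), (0, 0)), constant_values=jnp.nan)
print(jnp.isfinite(kv[:2].sum()))  # True: only real pages were read
print(jnp.isnan(kv.sum()))         # True: a padding page leaked in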
tests/lora/test_layers.py CHANGED
@@ -91,6 +91,7 @@ def populate_loras(
     index_to_id: list[Optional[int]],
     lora_layer: BaseLayerWithLoRA,
     baselayer_weights: torch.Tensor,
+    generate_embeddings_tensor: int = 0,
     repeats: int = 1,
 ) -> tuple[dict[int, LoRALayerWeights], dict[int, list[LoRALayerWeights]]]:
     """This method populates the lora weights (lora_a and lora_b) in the lora layers (BaseLayerWithLoRA).
@@ -102,6 +103,8 @@ def populate_loras(
         lora_layer: the LoRAlayer to populate.
         baselayer_weights: the PyTorch tensor containing the layer's
             weights.
+        generate_embeddings_tensor: whether to generate an
+            embeddings tensor for each LoRA.
         repeats: must only be set for column parallel packed
             layers. Indicates the number of loras to compose
             together to create a single lora layer.
@@ -128,6 +131,7 @@ def populate_loras(
                 baselayer_weights.device).init_random_lora(
                     module_name=f"fake_{i}",
                     weight=baselayer_weights,
+                    generate_embeddings_tensor=generate_embeddings_tensor,
                 )
             sublora.lora_b = sublora.lora_b[(sublora_len *
                                              i):(sublora_len * (i + 1)), :]
@@ -143,6 +147,7 @@ def populate_loras(
             slot_idx,
             lora_a=lora.lora_a,
             lora_b=lora.lora_b,
+            embeddings_tensor=lora.embeddings_tensor,
         )
 
         lora_dict[lora_id] = lora
@@ -541,6 +546,7 @@ def _update_punica_wrapper_metadata(punica_wrapper, index_mapping,
         index_to_id,
         lora_config.max_loras,
         vocab_size=512,
+        extra_vocab_size=lora_config.lora_extra_vocab_size,
     )
     assert jax_view(punica_wrapper._lora_indices_per_batch).platform(
     ) == 'tpu', 'punica_wrapper._lora_indices_per_batch should have been moved to TPU.'
tests/lora/utils.py CHANGED
@@ -24,6 +24,7 @@ class DummyLoRAManager:
         module_name: str,
         weight: torch.Tensor,
         rank: int = 8,
+        generate_embeddings_tensor: int = 0,
     ):
         lora = LoRALayerWeights(
             module_name,
@@ -36,6 +37,13 @@ class DummyLoRAManager:
                         dtype=weight.dtype,
                         device=self._device),
         )
+        if generate_embeddings_tensor:
+            lora.embeddings_tensor = torch.rand(
+                5,
+                generate_embeddings_tensor,
+                dtype=weight.dtype,
+                device=self._device,
+            )
         self.set_module_lora(module_name, lora)
 
         return lora
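Note that despite the boolean-sounding name (and the "whether to generate" wording in the docstring above), `generate_embeddings_tensor` is an int: 0 disables generation, and any other value becomes the second dimension of a `(5, dim)` embeddings tensor, as the `torch.rand` call shows. A self-contained sketch of the pattern (the helper name is hypothetical):

import torch

def maybe_embeddings(dim: int, dtype=torch.float32):
    # Mirrors DummyLoRAManager: dim == 0 means no embeddings tensor,
    # otherwise a random (5, dim) tensor is created.
    return torch.rand(5, dim, dtype=dtype) if dim else None

assert maybe_embeddings(0) is None
assert maybe_embeddings(256).shape == (5, 256)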
tests/test_utils.py CHANGED
@@ -75,34 +75,25 @@ def test_hbm_usage_bytes_pathways_enabled(mock_devices, mock_live_arrays):
     mock_device2 = MagicMock()
     devices = [mock_device1, mock_device2]
 
-    # Create mock addressable shards with data property
-    mock_data1_dev1 = MagicMock()
-    mock_data1_dev1.device = mock_device1
-    mock_data1_dev1.nbytes = 2000  # 2000 bytes on device1
+    # Create mock device buffers
+    mock_buffer1_dev1 = MagicMock()
+    mock_buffer1_dev1.device = mock_device1
+    mock_buffer1_dev1.nbytes = 2000  # 2000 bytes on device1
 
-    mock_data1_dev2 = MagicMock()
-    mock_data1_dev2.device = mock_device2
-    mock_data1_dev2.nbytes = 2000  # 2000 bytes on device2
+    mock_buffer1_dev2 = MagicMock()
+    mock_buffer1_dev2.device = mock_device2
+    mock_buffer1_dev2.nbytes = 2000  # 2000 bytes on device2
 
-    mock_data2_dev1 = MagicMock()
-    mock_data2_dev1.device = mock_device1
-    mock_data2_dev1.nbytes = 1000  # 1000 bytes on device1
+    mock_buffer2_dev1 = MagicMock()
+    mock_buffer2_dev1.device = mock_device1
+    mock_buffer2_dev1.nbytes = 1000  # 1000 bytes on device1
 
-    mock_shard1_dev1 = MagicMock()
-    mock_shard1_dev1.data = mock_data1_dev1
-
-    mock_shard1_dev2 = MagicMock()
-    mock_shard1_dev2.data = mock_data1_dev2
-
-    mock_shard2_dev1 = MagicMock()
-    mock_shard2_dev1.data = mock_data2_dev1
-
-    # Create mock arrays with addressable_shards
+    # Create mock arrays with device buffers
     mock_array1 = MagicMock()
-    mock_array1.addressable_shards = [mock_shard1_dev1, mock_shard1_dev2]
+    mock_array1.device_buffers = [mock_buffer1_dev1, mock_buffer1_dev2]
 
     mock_array2 = MagicMock()
-    mock_array2.addressable_shards = [mock_shard2_dev1]
+    mock_array2.device_buffers = [mock_buffer2_dev1]
 
     mock_live_arrays.return_value = [mock_array1, mock_array2]
 
@@ -168,7 +159,7 @@ def test_hbm_usage_bytes_pathways_no_arrays(mock_devices, mock_live_arrays):
     "head_dim, expected_padded_head_dim",
     [
         (1, 128),
-        (64, 64),
+        (64, 128),
         (127, 128),
         (128, 128),
         (129, 256),
@@ -231,5 +222,6 @@ def test_get_jax_dtype_from_str_dtype():
     assert get_jax_dtype_from_str_dtype("int8") == jnp.int8
     assert get_jax_dtype_from_str_dtype("bfloat16") == jnp.bfloat16
     assert get_jax_dtype_from_str_dtype("fp8") == jnp.float8_e4m3fn
-    assert get_jax_dtype_from_str_dtype("fp8_e4m3") == jnp.float8_e4m3fn
+    assert get_jax_dtype_from_str_dtype("fp8_e4m3") == jnp.float8_e4m3
     assert get_jax_dtype_from_str_dtype("fp8_e5m2") == jnp.float8_e5m2
+    assert get_jax_dtype_from_str_dtype("auto") is None
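The mocking change tracks a JAX API shift (shard objects exposing `.data` vs. raw `device_buffers`), but the accounting both versions of the test exercise is the same: sum `nbytes` per device across every live array. A sketch of that aggregation under the new shape of the mocks (the function body is an assumption, not the library code):

from collections import defaultdict

def hbm_usage_bytes(devices, live_arrays):
    usage = defaultdict(int)
    for array in live_arrays:
        for buf in array.device_buffers:
            if buf.device in devices:
                usage[buf.device] += buf.nbytes
    # With the mocks above: device1 -> 3000 bytes, device2 -> 2000 bytes.
    return dict(usage)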
tpu_inference/__init__.py CHANGED
@@ -1,40 +1,21 @@
+import os
+
 # The environment variables override should be imported before any other
 # modules to ensure that the environment variables are set before any
 # other modules are imported.
 import tpu_inference.env_override  # noqa: F401
-from tpu_inference import envs
 from tpu_inference import tpu_info as ti
 from tpu_inference.logger import init_logger
 
 logger = init_logger(__name__)
 
-if "proxy" in envs.JAX_PLATFORMS:
+if "proxy" in os.environ.get('JAX_PLATFORMS', '').lower():
     logger.info("Running vLLM on TPU via Pathways proxy.")
     # Must run pathwaysutils.initialize() before any JAX operations
     try:
-        import traceback
-
         import pathwaysutils
-        import vllm
-        from vllm.platforms import (resolve_current_platform_cls_qualname,
-                                    resolve_obj_by_qualname)
         pathwaysutils.initialize()
         logger.info("Module pathwaysutils is imported.")
-
-        # Pathways requires eager resolution of vllm.current_platform instead of
-        # lazy resolution in the normal code path. Since this part involves
-        # global topology discovery across multiple hosts, the platform
-        # resolution must happen before other components are loaded.
-        logger.info("Eagerly resolving vLLM current_platform for Pathways.")
-        platform_cls_qualname = resolve_current_platform_cls_qualname()
-        resolved_platform_instance = resolve_obj_by_qualname(
-            platform_cls_qualname)()
-        vllm.platforms._current_platform = resolved_platform_instance
-        vllm.platforms._init_trace = "".join(traceback.format_stack())
-        logger.info(
-            f"vLLM platform resolved to: {resolved_platform_instance.__class__.__name__}"
-        )
-
     except Exception as e:
         logger.error(
             f"Error occurred while importing pathwaysutils or logging TPU info: {e}"
tpu_inference/core/core_tpu.py CHANGED
@@ -29,7 +29,6 @@ from vllm.v1.request import Request, RequestStatus
 
 from tpu_inference import utils as common_utils
 from tpu_inference.core import disagg_executor, disagg_utils
-from tpu_inference.runner.tpu_runner import AsyncTPUModelRunnerOutput
 # ======================================================================================
 # Imports for _DisaggOrchestrator (decoupled from vLLM)
 # ======================================================================================
@@ -187,8 +186,6 @@ class _DisaggOrchestrator:
                 if model_output is None:
                     model_output = prefill_engine.model_executor.sample_tokens(
                         grammar_output)
-                if isinstance(model_output, AsyncTPUModelRunnerOutput):
-                    model_output = model_output.get_output()
 
                 if scheduler_output.total_num_scheduled_tokens > 0:
                     logger.debug(f"Prefill result: {model_output}")
@@ -221,16 +218,15 @@ class _DisaggOrchestrator:
                         f"request-{req_id}: tokens={request.all_token_ids} after prefill"
                     )
                     # Remove request from the prefill engine.
-                    if req_id in prefill_engine.scheduler.requests:
-                        request = prefill_engine.scheduler.requests[req_id]
-                        prefill_engine.scheduler.running.remove(request)
-                        prefill_engine.scheduler.encoder_cache_manager.free(
-                            request)
 
-                        prefill_engine.scheduler.kv_cache_manager.free(
-                            request)
+                    request = prefill_engine.scheduler.requests[req_id]
+                    prefill_engine.scheduler.running.remove(request)
+                    prefill_engine.scheduler.encoder_cache_manager.free(
+                        request)
 
-                        prefill_engine.scheduler.requests.pop(req_id)
+                    prefill_engine.scheduler.kv_cache_manager.free(request)
+
+                    prefill_engine.scheduler.requests.pop(req_id)
 
                 for output in (engine_core_outputs.items()
                                if engine_core_outputs else ()):
@@ -339,10 +335,8 @@ class _DisaggOrchestrator:
             new_block_ids = kv_cache_manager.get_block_ids(req_id)
             logger.debug(
                 f"inserting {req_id} new_block_ids {new_block_ids}")
-            if len(new_block_ids[0]) != math.ceil(
-                    prompt_tokens / self._config.cache_config.block_size):
-                logger.warning("Running out of blocks in decode engine! ")
-                break
+            assert (len(new_block_ids[0]) == math.ceil(
+                prompt_tokens / self._config.cache_config.block_size))
 
             decode_engine.model_executor.driver_worker.model_runner.insert_request_with_kv_cache(
                 vllm_request, kv_cache, new_block_ids)
@@ -372,8 +366,6 @@ class _DisaggOrchestrator:
                 if model_output is None:
                     model_output = decode_engine.model_executor.sample_tokens(
                         grammar_output)
-                if isinstance(model_output, AsyncTPUModelRunnerOutput):
-                    model_output = model_output.get_output()
 
                 if scheduler_output.total_num_scheduled_tokens > 0:
                     logger.debug(f"Decode result: {model_output}")
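The warn-and-break path becomes an assert on the block-count invariant: a prefill of `prompt_tokens` tokens must map to exactly ceil(prompt_tokens / block_size) KV cache blocks in the decode engine. For example:

import math

prompt_tokens, block_size = 1000, 16
assert math.ceil(prompt_tokens / block_size) == 63  # blocks required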
tpu_inference/core/disagg_utils.py CHANGED
@@ -1,15 +1,17 @@
 # SPDX-License-Identifier: Apache-2.0
 
+import os
 from typing import Tuple
 
-from tpu_inference import envs
+PREFILL_SLICES = 'PREFILL_SLICES'
+DECODE_SLICES = 'DECODE_SLICES'
 
 
 def is_disagg_enabled() -> bool:
     # We trigger our code path as long as prefill slices are set. This
     # allows us to test interleave mode effectively with the code path
     # for comparison purposes.
-    return bool(envs.PREFILL_SLICES)
+    return PREFILL_SLICES in os.environ
 
 
 def _parse_slices(slices_str: str) -> Tuple[int, ...]:
@@ -38,12 +40,12 @@ def _parse_slices(slices_str: str) -> Tuple[int, ...]:
 
 
 def get_prefill_slices() -> Tuple[int, ...]:
-    if not envs.PREFILL_SLICES:
+    if PREFILL_SLICES not in os.environ:
         return ()
-    return _parse_slices(envs.PREFILL_SLICES)
+    return _parse_slices(os.environ[PREFILL_SLICES])
 
 
 def get_decode_slices() -> Tuple[int, ...]:
-    if not envs.DECODE_SLICES:
+    if DECODE_SLICES not in os.environ:
         return ()
-    return _parse_slices(envs.DECODE_SLICES)
+    return _parse_slices(os.environ[DECODE_SLICES])
+ return _parse_slices(os.environ[DECODE_SLICES])