tpu-inference 0.0.1rc1__py3-none-any.whl → 0.11.1.dev202511130813__py3-none-any.whl

This diff compares two publicly available versions of the package as released to a supported registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in the public registry.

Potentially problematic release.

Files changed (67)
  1. tests/kernels/fused_moe_v1_test.py +34 -303
  2. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +2 -2
  3. tests/lora/test_layers.py +6 -0
  4. tests/lora/utils.py +8 -0
  5. tests/test_utils.py +16 -24
  6. tpu_inference/__init__.py +3 -22
  7. tpu_inference/core/core_tpu.py +9 -17
  8. tpu_inference/core/disagg_utils.py +8 -6
  9. tpu_inference/distributed/tpu_connector.py +4 -3
  10. tpu_inference/distributed/utils.py +2 -3
  11. tpu_inference/envs.py +8 -61
  12. tpu_inference/executors/ray_distributed_executor.py +11 -31
  13. tpu_inference/kernels/fused_moe/v1/kernel.py +110 -641
  14. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +54 -77
  15. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +143 -287
  16. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +0 -7
  17. tpu_inference/layers/jax/attention/attention.py +1 -1
  18. tpu_inference/layers/{common → jax}/attention_interface.py +2 -8
  19. tpu_inference/layers/jax/sample/rejection_sampler.py +1 -1
  20. tpu_inference/layers/jax/sample/sampling.py +2 -2
  21. tpu_inference/layers/{common → jax}/sharding.py +5 -5
  22. tpu_inference/layers/vllm/attention.py +1 -1
  23. tpu_inference/layers/vllm/fused_moe.py +208 -170
  24. tpu_inference/layers/vllm/quantization/__init__.py +3 -7
  25. tpu_inference/layers/vllm/quantization/awq.py +3 -4
  26. tpu_inference/layers/vllm/quantization/common.py +1 -6
  27. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +2 -4
  28. tpu_inference/layers/vllm/quantization/unquantized.py +67 -62
  29. tpu_inference/layers/vllm/sharding.py +2 -2
  30. tpu_inference/lora/torch_punica_tpu.py +2 -1
  31. tpu_inference/mock/__init__.py +0 -0
  32. tpu_inference/mock/vllm_config_utils.py +28 -0
  33. tpu_inference/mock/vllm_envs.py +1219 -0
  34. tpu_inference/mock/vllm_logger.py +212 -0
  35. tpu_inference/mock/vllm_logging_utils.py +15 -0
  36. tpu_inference/models/common/model_loader.py +12 -46
  37. tpu_inference/models/jax/llama3.py +3 -4
  38. tpu_inference/models/jax/llama_eagle3.py +5 -8
  39. tpu_inference/models/jax/phi3.py +376 -0
  40. tpu_inference/models/jax/qwen2.py +2 -3
  41. tpu_inference/models/jax/qwen2_5_vl.py +50 -165
  42. tpu_inference/models/jax/qwen3.py +2 -3
  43. tpu_inference/models/jax/utils/quantization/quantization_utils.py +6 -3
  44. tpu_inference/models/jax/utils/weight_utils.py +143 -198
  45. tpu_inference/models/vllm/vllm_model_wrapper.py +14 -32
  46. tpu_inference/platforms/tpu_platform.py +34 -47
  47. tpu_inference/runner/compilation_manager.py +60 -145
  48. tpu_inference/runner/kv_cache.py +2 -2
  49. tpu_inference/runner/kv_cache_manager.py +18 -17
  50. tpu_inference/runner/persistent_batch_manager.py +2 -40
  51. tpu_inference/runner/structured_decoding_manager.py +3 -2
  52. tpu_inference/runner/tpu_runner.py +135 -283
  53. tpu_inference/runner/utils.py +2 -2
  54. tpu_inference/spec_decode/jax/eagle3.py +21 -71
  55. tpu_inference/tpu_info.py +3 -4
  56. tpu_inference/utils.py +15 -38
  57. tpu_inference/worker/tpu_worker.py +26 -163
  58. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/METADATA +3 -4
  59. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/RECORD +63 -61
  60. tests/test_envs.py +0 -203
  61. tpu_inference/layers/common/quant_methods.py +0 -8
  62. tpu_inference/layers/vllm/quantization/mxfp4.py +0 -331
  63. tpu_inference/models/jax/llama_guard_4.py +0 -361
  64. /tpu_inference/layers/{common → jax}/binary_search.py +0 -0
  65. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/WHEEL +0 -0
  66. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/licenses/LICENSE +0 -0
  67. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/top_level.txt +0 -0

tpu_inference/layers/vllm/quantization/awq.py
@@ -18,7 +18,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
     is_layer_skipped, unpack_quantized_values_into_int32)
 from vllm.scalar_type import scalar_types
 
-from tpu_inference.layers.common.quant_methods import AWQ, get_tpu_quant_method
 from tpu_inference.layers.vllm.linear_common import (
     slice_sharded_tensor_for_concatenation, torch_to_jax_param)
 from tpu_inference.layers.vllm.quantization.common import (
@@ -30,12 +29,12 @@ P = PartitionSpec
 logger = init_logger(__name__)
 
 
-@register_quantization_config(get_tpu_quant_method(AWQ))
+@register_quantization_config("jax-awq")
 class VllmAWQConfig(AWQConfig, JaxCommonConfig):
 
     @classmethod
-    def get_name(cls):
-        return AWQ
+    def get_name(cls) -> str:
+        return "jax-awq"
 
     def get_supported_act_dtypes(self) -> list[torch.dtype]:
         # NOTE: AWQ checkpoint was quantized with float16. But on TPUs, using
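
The AWQ config above, and the compressed-tensors and unquantized configs further down, all switch from the removed get_tpu_quant_method() indirection to literal registry names. A minimal sketch of the register-by-name pattern (using a made-up registry and DemoAWQConfig class, not vLLM's real implementation) shows why the decorator argument and get_name() must stay in sync:

# Illustrative registry only; names below are hypothetical.
_QUANT_REGISTRY: dict[str, type] = {}

def register_quantization_config(name: str):
    def _decorator(cls: type) -> type:
        _QUANT_REGISTRY[name] = cls  # key the class under its public name
        return cls
    return _decorator

@register_quantization_config("jax-awq")
class DemoAWQConfig:
    @classmethod
    def get_name(cls) -> str:
        return "jax-awq"  # must return the same string as the registered key

assert _QUANT_REGISTRY["jax-awq"] is DemoAWQConfig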

tpu_inference/layers/vllm/quantization/common.py
@@ -61,12 +61,7 @@ class JaxCommonLinearConfig
                 " bad performance.", type(layer))
 
         self.bias_sharding = P(self.weight_sharding[0])
-        if isinstance(self.weight_sharding[0], tuple):
-            self.n_shards = 1
-            for axis in self.weight_sharding[0]:
-                self.n_shards *= self.mesh.shape.get(axis, 1)
-        else:
-            self.n_shards = self.mesh.shape.get(self.weight_sharding[0], 1)
+        self.n_shards = self.mesh.shape.get(self.weight_sharding[0], 1)
 
     def get_input_sharding(self, x: torchax.tensor.Tensor):
         if self.enable_sequence_parallelism:
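
The deleted branch handled a PartitionSpec entry that is a tuple of mesh axis names (one tensor dimension sharded across several axes); the new code keeps only the single-axis case. A small sketch, using a plain dict and a hypothetical 2x2 mesh as a stand-in for mesh.shape, shows what each path computes:

# mesh.shape behaves like a mapping from axis name to axis size.
mesh_shape = {"data": 2, "model": 2}  # hypothetical 2x2 mesh

def n_shards_for(spec_entry):
    # Old behaviour: a tuple entry shards one tensor dim over several axes.
    if isinstance(spec_entry, tuple):
        n = 1
        for axis in spec_entry:
            n *= mesh_shape.get(axis, 1)
        return n
    # New behaviour keeps only the single-axis (or unsharded) case.
    return mesh_shape.get(spec_entry, 1)

assert n_shards_for("model") == 2            # single axis
assert n_shards_for(("data", "model")) == 4  # tuple of axes (old path only)
assert n_shards_for(None) == 1               # unsharded dimension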

tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py
@@ -16,8 +16,6 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
     find_matched_target, should_ignore_layer)
 
-from tpu_inference.layers.common.quant_methods import (COMPRESSED_TENSORS,
-                                                       get_tpu_quant_method)
 from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
 from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors_moe import \
     VllmCompressedTensorsW8A8Fp8MoEMethod
@@ -32,12 +30,12 @@ P = PartitionSpec
 logger = init_logger(__name__)
 
 
-@register_quantization_config(get_tpu_quant_method(COMPRESSED_TENSORS))
+@register_quantization_config("jax-compressed-tensors")
 class VllmCompressedTensorsConfig(CompressedTensorsConfig, JaxCommonConfig):
 
     @classmethod
     def get_name(cls) -> str:
-        return COMPRESSED_TENSORS
+        return "jax-compressed-tensors"
 
     def get_scheme(self,
                    layer: torch.nn.Module,

tpu_inference/layers/vllm/quantization/unquantized.py
@@ -23,9 +23,7 @@ from vllm.model_executor.layers.quantization.base_config import (
 
 from tpu_inference import envs
 from tpu_inference.kernels.fused_moe.v1.kernel import fused_ep_moe
-from tpu_inference.layers.common.quant_methods import (UNQUANTIZED,
-                                                       get_tpu_quant_method)
-from tpu_inference.layers.vllm.fused_moe import fused_moe_func
+from tpu_inference.layers.vllm.fused_moe import fused_moe_func_padded
 from tpu_inference.layers.vllm.linear_common import (
     reorder_concatenated_tensor_for_sharding,
     slice_sharded_tensor_for_concatenation, torch_to_jax_param)
@@ -36,12 +34,12 @@ P = PartitionSpec
 logger = init_logger(__name__)
 
 
-@register_quantization_config(get_tpu_quant_method(UNQUANTIZED))
+@register_quantization_config("jax-unquantized")
 class VllmUnquantizedConfig(QuantizationConfig, JaxCommonConfig):
 
     @classmethod
     def get_name(cls) -> str:
-        return UNQUANTIZED
+        return "jax-unquantized"
 
     @classmethod
     def get_supported_act_dtypes(cls) -> list[torch.dtype]:
@@ -108,8 +106,6 @@ class VllmUnquantizedLinearMethod(UnquantizedLinearMethod):
               layer: torch.nn.Module,
               x: torch.Tensor,
               bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-        assert isinstance(layer, LinearBase)
-
         with jax.named_scope(layer._get_name()):
             if in_sharding := self.jax_config.get_input_sharding(x):
                 x.shard_(NamedSharding(self.jax_config.mesh, in_sharding))
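
For readers unfamiliar with the two JAX APIs in this hunk: jax.named_scope only labels the enclosed operations in traces and profiles, while NamedSharding attaches a mesh and PartitionSpec to an array placement. A self-contained sketch on a single-device mesh (so it runs anywhere; all names and shapes below are illustrative):

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Single-device mesh for the sketch; on a TPU pod the "model" axis would span
# many chips.
mesh = Mesh(np.array(jax.devices()[:1]), axis_names=("model",))

x = jnp.ones((8, 16))
w = jnp.ones((16, 32))

# The named scope labels the matmul in traces; device_put applies the
# requested (here trivial) sharding, mirroring x.shard_(...) in the diff.
with jax.named_scope("demo_linear"):
    x = jax.device_put(x, NamedSharding(mesh, P(None, None)))
    y = x @ w
print(y.shape)  # (8, 32)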
@@ -168,18 +164,18 @@ class VllmUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
                  ep_axis_name: str = 'model'):
         super().__init__(moe)
         self.mesh = mesh
-        self.use_kernel = envs.USE_MOE_EP_KERNEL and moe.use_ep
+        self.use_kernel = envs.USE_MOE_EP_KERNEL
         self.ep_axis_name = ep_axis_name
         # TODO: Use autotune table once we have it.
         self.block_size = {
-            "bt": 64,
-            "bf": 1024,
-            "bd1": 1536,
-            "bd2": 1536,
-            "btc": 64,
-            "bfc": 1024,
-            "bd1c": 1536,
-            "bd2c": 1536,
+            "bt": 16,
+            "bf": 384,
+            "bd1": 512,
+            "bd2": 512,
+            "btc": 16,
+            "bfc": 384,
+            "bd1c": 256,
+            "bd2c": 256,
         }
 
     def select_gemm_impl(
@@ -193,11 +189,10 @@ class VllmUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         assert isinstance(layer, FusedMoE)
+
         w13_weight = t2j(layer.w13_weight, use_dlpack=False)
         w2_weight = t2j(layer.w2_weight, use_dlpack=False)
 
-        num_experts, hidden_size, intermediate_size = w2_weight.shape
-
         if self.moe.has_bias:
             w13_bias = t2j(layer.w13_bias, use_dlpack=False)
             w2_bias = t2j(layer.w2_bias, use_dlpack=False)
@@ -216,56 +211,76 @@ class VllmUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
             w3_bias = w13_bias[:, 1::2]
             w13_bias = jnp.concat([w1_bias, w3_bias], axis=1)
 
-        if self.use_kernel:
+        if self.use_kernel and layer.use_ep:
             # Kernel expects:
             # w13: (num_experts, 2, hidden_size, intermediate_size)
             # w2: (num_experts, intermediate_size, hidden_size)
             # Current format:
             # w13_weight: (num_experts, 2*intermediate_size, hidden_size)
             # w2_weight: (num_experts, hidden_size, intermediate_size)
+            num_experts = w13_weight.shape[0]
+            intermediate_size = w13_weight.shape[1] // 2
+            hidden_size = w13_weight.shape[2]
 
+            # Reshape and transpose w13_weight to (num_experts, 2, hidden_size, intermediate_size)
             w13_reshaped = w13_weight.reshape(num_experts, 2,
                                               intermediate_size, hidden_size)
+            w13_weight_transposed = jnp.transpose(w13_reshaped, (0, 1, 3, 2))
 
-            # Transpose non-contracting dim to right-most dim
-            w13_weight_transposed = jnp.swapaxes(w13_reshaped, 2, 3)
-            w2_weight_transposed = jnp.swapaxes(w2_weight, 1, 2)
+            # Transpose w2_weight to (num_experts, intermediate_size, hidden_size)
+            w2_weight_transposed = jnp.transpose(w2_weight, (0, 2, 1))
 
             # Apply EP sharding
-            ep_sharding = NamedSharding(self.mesh, P("model"))
-
             w13_weight = jax.device_put(
-                w13_weight_transposed, Format(Layout((0, 1, 2, 3)),
-                                              ep_sharding))
-            w2_weight = jax.device_put(w2_weight_transposed,
-                                       Format(Layout((0, 1, 2)), ep_sharding))
+                w13_weight_transposed,
+                Format(Layout((0, 1, 2, 3)),
+                       NamedSharding(self.mesh, P("model", None, None, None))))
+            w2_weight = jax.device_put(
+                w2_weight_transposed,
+                Format(Layout((0, 1, 2)),
+                       NamedSharding(self.mesh, P("model", None, None))))
 
             if self.moe.has_bias:
                 w13_bias = w13_bias.reshape(num_experts, 2, intermediate_size)
+
+                # Apply EP sharding
                 w13_bias = jax.device_put(
-                    w13_bias, Format(Layout((0, 1, 2)), ep_sharding))
-                w2_bias = jax.device_put(w2_bias,
-                                         Format(Layout((0, 1)), ep_sharding))
-        else:
+                    w13_bias,
+                    Format(Layout((0, 1, 2)),
+                           NamedSharding(self.mesh, P("model", None, None))))
+                w2_bias = jax.device_put(
+                    w2_bias,
+                    Format(Layout((0, 1)),
+                           NamedSharding(self.mesh, P("model", None))))
 
+        else:
+            # Original logic for non-kernel path
             if layer.use_ep:
-                ep_sharding = NamedSharding(self.mesh, P("model"))
                 w13_weight = jax.device_put(
-                    w13_weight, Format(Layout((0, 1, 2)), ep_sharding))
+                    w13_weight,
+                    Format(Layout((0, 1, 2)),
+                           NamedSharding(self.mesh, P("model", None, None))))
                 w2_weight = jax.device_put(
-                    w2_weight, Format(Layout((0, 1, 2)), ep_sharding))
+                    w2_weight,
+                    Format(Layout((0, 1, 2)),
+                           NamedSharding(self.mesh, P("model", None, None))))
 
                 if self.moe.has_bias:
                     w13_bias = jax.device_put(
-                        w13_bias, Format(Layout((0, 1)), ep_sharding))
+                        w13_bias,
+                        Format(Layout((0, 1)),
+                               NamedSharding(self.mesh, P("model", None))))
                     w2_bias = jax.device_put(
-                        w2_bias, Format(Layout((0, 1)), ep_sharding))
+                        w2_bias,
+                        Format(Layout((0, 1)),
+                               NamedSharding(self.mesh, P("model", None))))
 
             else:
+                intermediate_size = w13_weight.shape[1] // 2
+                assert intermediate_size == w2_weight.shape[-1]
                 output_sizes = [intermediate_size, intermediate_size]
                 n_shards = self.mesh.shape["model"]
                 assert intermediate_size % n_shards == 0
-
                 w13_weight = reorder_concatenated_tensor_for_sharding(
                     w13_weight, output_sizes, n_shards, dim=1)
                 w13_weight = jax.device_put(
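
The kernel path now derives num_experts, intermediate_size and hidden_size from w13_weight and reorders both weight tensors into the layout the EP kernel expects. A short sketch with dummy sizes, checking only the shape bookkeeping described in the comments above:

import jax.numpy as jnp

# Dummy sizes for illustration only.
E, I, H = 4, 6, 8   # num_experts, intermediate_size, hidden_size

w13 = jnp.zeros((E, 2 * I, H))   # loaded format: (E, 2*I, H)
w2 = jnp.zeros((E, H, I))        # loaded format: (E, H, I)

# Shapes are derived from w13 exactly as in the new code path.
num_experts = w13.shape[0]
intermediate_size = w13.shape[1] // 2
hidden_size = w13.shape[2]

# (E, 2*I, H) -> (E, 2, I, H) -> (E, 2, H, I): kernel layout for w13.
w13_kernel = jnp.transpose(
    w13.reshape(num_experts, 2, intermediate_size, hidden_size), (0, 1, 3, 2))
# (E, H, I) -> (E, I, H): kernel layout for w2.
w2_kernel = jnp.transpose(w2, (0, 2, 1))

assert w13_kernel.shape == (E, 2, H, I)
assert w2_kernel.shape == (E, I, H)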
@@ -326,40 +341,30 @@ class VllmUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
             raise NotImplementedError(
                 "Only softmax is supported for scoring_func")
 
-        x = jax_view(x)
-        w13_weight = jax_view(layer.w13_weight)
-        w2_weight = jax_view(layer.w2_weight)
-        w13_bias = w2_bias = None
-        if self.moe.has_bias:
-            w13_bias = jax_view(layer.w13_bias)
-            w2_bias = jax_view(layer.w2_bias)
-        gating_output = jax_view(router_logits)
-
         if self.use_kernel and layer.use_ep:
             output = fused_ep_moe(
                 mesh=self.mesh,
-                tokens=x,
-                w1=w13_weight,
-                w2=w2_weight,
-                b1=w13_bias,
-                b2=w2_bias,
-                gating_output=gating_output,
+                tokens=jax_view(x),
+                w1=jax_view(layer.w13_weight),
+                w2=jax_view(layer.w2_weight),
+                gating_output=jax_view(router_logits),
                 top_k=top_k,
                 ep_axis_name=self.ep_axis_name,
-                renormalize_topk_logits=renormalize,
-                act_fn=activation,
                 **self.block_size,
             )
         else:
-            output = fused_moe_func(
-                hidden_states=x,
-                w1=w13_weight,
-                w2=w2_weight,
-                w1_bias=w13_bias,
-                w2_bias=w2_bias,
-                gating_output=gating_output,
+            # Use the original implementation
+            output = fused_moe_func_padded(
+                jax_view(x),
+                jax_view(layer.w13_weight),
+                jax_view(layer.w2_weight),
+                jax_view(layer.w13_bias) if self.moe.has_bias else None,
+                jax_view(layer.w2_bias) if self.moe.has_bias else None,
+                jax_view(router_logits),
                 topk=top_k,
+                global_num_experts=global_num_experts,
                 renormalize=renormalize,
+                reduce_results=layer.reduce_results,
                 mesh=self.mesh,
                 use_ep=layer.use_ep,
                 activation=activation,
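
Both branches compute the same top-k MoE routing, either through the fused EP Pallas kernel or through fused_moe_func_padded. For intuition, here is a dense, unfused reference of that computation; reference_moe is a made-up name, the [gate; up] split of w13 and the SiLU gating are assumptions, and this illustrates the semantics rather than either fused implementation:

import jax
import jax.numpy as jnp

def reference_moe(x, w13, w2, gating_logits, top_k, renormalize=True):
    """Dense reference of top-k MoE routing (illustration only). Assumes w13
    stacks gate and up projections as (E, 2*I, H) with SiLU gating, and w2 is
    (E, H, I), matching the shapes named in the hunks above."""
    T, H = x.shape
    E = w13.shape[0]
    I = w13.shape[1] // 2

    probs = jax.nn.softmax(gating_logits, axis=-1)        # (T, E)
    topk_probs, topk_idx = jax.lax.top_k(probs, top_k)    # (T, k)
    if renormalize:
        topk_probs = topk_probs / topk_probs.sum(-1, keepdims=True)

    # Compute every expert for every token (dense; the fused kernels avoid this).
    h13 = jnp.einsum("th,eih->tei", x, w13)               # (T, E, 2I)
    gate, up = h13[..., :I], h13[..., I:]
    h = jax.nn.silu(gate) * up                            # (T, E, I)
    expert_out = jnp.einsum("tei,ehi->teh", h, w2)        # (T, E, H)

    # Keep only each token's top-k experts, weighted by routing probabilities.
    weights = jnp.zeros((T, E)).at[jnp.arange(T)[:, None], topk_idx].set(topk_probs)
    return jnp.einsum("te,teh->th", weights, expert_out)  # (T, H)

# Tiny smoke test with random weights.
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (3, 8))
w13 = jax.random.normal(key, (4, 2 * 16, 8))
w2 = jax.random.normal(key, (4, 8, 16))
logits = jax.random.normal(key, (3, 4))
print(reference_moe(x, w13, w2, logits, top_k=2).shape)  # (3, 8)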

tpu_inference/layers/vllm/sharding.py
@@ -19,7 +19,6 @@ from vllm.lora.layers.base_linear import BaseLinearLayerWithLoRA
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 
-from tpu_inference import envs
 from tpu_inference.logger import init_logger
 
 P = PartitionSpec
@@ -212,7 +211,8 @@ def _shard_module_to_tpu(model: torch.nn.Module, mesh: Mesh) -> None:
 def _sharded_device_put(tensor: jax.Array, sharding) -> jax.Array:
     if isinstance(tensor, tuple):
         return tuple(_sharded_device_put(t, sharding) for t in tensor)
-    multihost_backend = envs.TPU_MULTIHOST_BACKEND
+    import os
+    multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower()
     if multihost_backend != "ray":
         return jax.device_put(tensor, sharding)
 

tpu_inference/lora/torch_punica_tpu.py
@@ -239,6 +239,7 @@ class PunicaWrapperTPU(PunicaWrapperBase):
         lora_index_to_id: list[Optional[int]],
         max_loras: int,
         vocab_size: int,
+        extra_vocab_size: int,
     ):
         # Pad the prompt mapping to avoid running into recompiles on the TPU
         # TODO: Should this happen inside mapping internally? If so how can we
@@ -257,7 +258,7 @@ class PunicaWrapperTPU(PunicaWrapperBase):
             lora_index_to_id,
             max_loras,
             vocab_size,
-            0,  # extra_vocab_size
+            extra_vocab_size,
             "cpu",
         )
         with torchax.default_env():

tpu_inference/mock/vllm_config_utils.py (new file)
@@ -0,0 +1,28 @@
+from dataclasses import dataclass, field
+from typing import Any, List, Mapping
+
+
+@dataclass
+class ModelConfig():
+    max_model_len: int = 2048
+    max_prefill_len: int = 1024
+    prefill_batch_size: int = 1
+    decode_batch_size: int = 1
+    block_size: int = 16
+    num_layers: int = 32
+    num_kv_heads: int = 32
+    head_dim: int = 128
+    vocab_size: int = 32000
+    model: str = "llama3"
+    hf_config: str = ""
+    architectures: List[str] = field(default_factory=list)
+    override_generation_config: dict[str, Any] = field(default_factory=dict)
+    hf_overrides: dict[str, Any] = field(default_factory=dict)
+
+
+@dataclass
+class VllmConfig():
+    additional_config: Mapping[str, Any] = field(default_factory=dict)
+    # Set default max_model_len to turn off warnings.
+    model_config: ModelConfig = field(
+        default_factory=lambda: ModelConfig(max_model_len=1024))
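
These mock classes are lightweight stand-ins for vLLM's config objects. A hypothetical usage example, relying only on the defaults defined above (the "enable_foo" key is made up):

from tpu_inference.mock.vllm_config_utils import ModelConfig, VllmConfig

cfg = VllmConfig()
assert cfg.model_config.max_model_len == 1024   # set by the default_factory
assert cfg.additional_config == {}

custom = VllmConfig(
    additional_config={"enable_foo": True},     # hypothetical key
    model_config=ModelConfig(model="llama3", max_model_len=4096),
)
print(custom.model_config.block_size)  # 16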