tpu-inference 0.11.1.dev202511270815__py3-none-any.whl → 0.11.1.dev202512030818__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference has been flagged as possibly problematic.

Files changed (25)
  1. tests/test_envs.py +32 -11
  2. tests/test_utils.py +1 -2
  3. tpu_inference/distributed/tpu_connector.py +1 -1
  4. tpu_inference/envs.py +60 -7
  5. tpu_inference/executors/ray_distributed_executor.py +5 -1
  6. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +72 -19
  7. tpu_inference/layers/common/sharding.py +3 -4
  8. tpu_inference/layers/vllm/quantization/mxfp4.py +2 -1
  9. tpu_inference/models/common/model_loader.py +3 -1
  10. tpu_inference/models/jax/utils/quantization/quantization_utils.py +3 -6
  11. tpu_inference/models/vllm/vllm_model_wrapper.py +1 -2
  12. tpu_inference/platforms/tpu_platform.py +13 -20
  13. tpu_inference/runner/compilation_manager.py +87 -27
  14. tpu_inference/runner/kv_cache_manager.py +8 -15
  15. tpu_inference/runner/persistent_batch_manager.py +40 -2
  16. tpu_inference/runner/tpu_runner.py +68 -45
  17. tpu_inference/runner/utils.py +2 -2
  18. tpu_inference/spec_decode/jax/eagle3.py +52 -19
  19. tpu_inference/utils.py +31 -9
  20. tpu_inference/worker/tpu_worker.py +2 -2
  21. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/METADATA +1 -1
  22. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/RECORD +25 -25
  23. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/WHEEL +0 -0
  24. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/licenses/LICENSE +0 -0
  25. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/top_level.txt +0 -0
tests/test_envs.py CHANGED
@@ -56,6 +56,12 @@ def test_getattr_with_cache(monkeypatch: pytest.MonkeyPatch):
 
 
 def test_boolean_env_vars(monkeypatch: pytest.MonkeyPatch):
+    # Ensure clean environment for boolean vars by setting to default "0"
+    monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "0")
+    monkeypatch.setenv("VLLM_XLA_CHECK_RECOMPILATION", "0")
+    monkeypatch.setenv("NEW_MODEL_DESIGN", "0")
+    monkeypatch.setenv("USE_MOE_EP_KERNEL", "0")
+
     # Test SKIP_JAX_PRECOMPILE (default False)
     assert envs.SKIP_JAX_PRECOMPILE is False
     monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "1")
@@ -63,6 +69,13 @@ def test_boolean_env_vars(monkeypatch: pytest.MonkeyPatch):
     monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "0")
     assert envs.SKIP_JAX_PRECOMPILE is False
 
+    # Test VLLM_XLA_CHECK_RECOMPILATION (default False)
+    assert envs.VLLM_XLA_CHECK_RECOMPILATION is False
+    monkeypatch.setenv("VLLM_XLA_CHECK_RECOMPILATION", "1")
+    assert envs.VLLM_XLA_CHECK_RECOMPILATION is True
+    monkeypatch.setenv("VLLM_XLA_CHECK_RECOMPILATION", "0")
+    assert envs.VLLM_XLA_CHECK_RECOMPILATION is False
+
     # Test NEW_MODEL_DESIGN (default False)
     assert envs.NEW_MODEL_DESIGN is False
     monkeypatch.setenv("NEW_MODEL_DESIGN", "1")
@@ -75,20 +88,32 @@ def test_boolean_env_vars(monkeypatch: pytest.MonkeyPatch):
 
 
 def test_integer_env_vars(monkeypatch: pytest.MonkeyPatch):
+    # Ensure clean environment for integer vars by setting to defaults
+    monkeypatch.setenv("PYTHON_TRACER_LEVEL", "1")
+    monkeypatch.setenv("NUM_SLICES", "1")
+
     assert envs.PYTHON_TRACER_LEVEL == 1
     monkeypatch.setenv("PYTHON_TRACER_LEVEL", "3")
     assert envs.PYTHON_TRACER_LEVEL == 3
     monkeypatch.setenv("PYTHON_TRACER_LEVEL", "0")
     assert envs.PYTHON_TRACER_LEVEL == 0
 
+    # Test NUM_SLICES (default 1)
+    assert envs.NUM_SLICES == 1
+    monkeypatch.setenv("NUM_SLICES", "2")
+    assert envs.NUM_SLICES == 2
+    monkeypatch.setenv("NUM_SLICES", "4")
+    assert envs.NUM_SLICES == 4
 
-def test_lowercase_conversion(monkeypatch: pytest.MonkeyPatch):
-    monkeypatch.setenv("TPU_MULTIHOST_BACKEND", "GRPC")
-    assert envs.TPU_MULTIHOST_BACKEND == "grpc"
 
-    monkeypatch.setenv("MODEL_IMPL_TYPE", "FLAX_NNX")
+def test_model_impl_type_choices(monkeypatch: pytest.MonkeyPatch):
+    # Test case sensitive choices
+    monkeypatch.setenv("MODEL_IMPL_TYPE", "flax_nnx")
     assert envs.MODEL_IMPL_TYPE == "flax_nnx"
 
+    monkeypatch.setenv("MODEL_IMPL_TYPE", "vllm")
+    assert envs.MODEL_IMPL_TYPE == "vllm"
+
 
 
 def test_string_env_vars_defaults(monkeypatch: pytest.MonkeyPatch):
     monkeypatch.delenv("JAX_PLATFORMS", raising=False)
@@ -117,8 +142,6 @@ def test_ray_env_vars(monkeypatch: pytest.MonkeyPatch):
     assert envs.RAY_USAGE_STATS_ENABLED == "1"
 
     assert envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE == "shm"
-    monkeypatch.setenv("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE", "nccl")
-    assert envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE == "nccl"
 
 
 def test_invalid_attribute_raises_error():
@@ -134,6 +157,7 @@ def test_dir_returns_all_env_vars():
     assert "JAX_PLATFORMS" in env_vars
     assert "TPU_NAME" in env_vars
     assert "SKIP_JAX_PRECOMPILE" in env_vars
+    assert "VLLM_XLA_CHECK_RECOMPILATION" in env_vars
     assert "MODEL_IMPL_TYPE" in env_vars
 
 
@@ -141,11 +165,8 @@ def test_tpu_multihost_env_vars(monkeypatch: pytest.MonkeyPatch):
     monkeypatch.setenv("TPU_WORKER_ID", "0")
     assert envs.TPU_WORKER_ID == "0"
 
-    monkeypatch.setenv("TPU_MULTIHOST_BACKEND", "grpc")
-    assert envs.TPU_MULTIHOST_BACKEND == "grpc"
-
-    monkeypatch.setenv("TPU_MULTIHOST_BACKEND", "xla")
-    assert envs.TPU_MULTIHOST_BACKEND == "xla"
+    monkeypatch.setenv("TPU_MULTIHOST_BACKEND", "ray")
+    assert envs.TPU_MULTIHOST_BACKEND == "ray"
 
 
 def test_disaggregated_serving_env_vars(monkeypatch: pytest.MonkeyPatch):
tests/test_utils.py CHANGED
@@ -231,6 +231,5 @@ def test_get_jax_dtype_from_str_dtype():
     assert get_jax_dtype_from_str_dtype("int8") == jnp.int8
     assert get_jax_dtype_from_str_dtype("bfloat16") == jnp.bfloat16
     assert get_jax_dtype_from_str_dtype("fp8") == jnp.float8_e4m3fn
-    assert get_jax_dtype_from_str_dtype("fp8_e4m3") == jnp.float8_e4m3
+    assert get_jax_dtype_from_str_dtype("fp8_e4m3") == jnp.float8_e4m3fn
     assert get_jax_dtype_from_str_dtype("fp8_e5m2") == jnp.float8_e5m2
-    assert get_jax_dtype_from_str_dtype("auto") is None
tpu_inference/distributed/tpu_connector.py CHANGED
@@ -457,7 +457,6 @@ class TPUConnectorWorker:
         self.side_channel_port = get_side_channel_port()
 
         self.kv_transfer_server = None
-        self._maybe_start_p2p_server()
         self.zmq_cxt = zmq.Context()
         if self.is_producer:
             ready_event = threading.Event()
@@ -499,6 +498,7 @@ class TPUConnectorWorker:
         self.shape = list(kv_layer.shape)
         self.dtype = kv_layer.dtype
         self.sharding = kv_layer.sharding
+        self._maybe_start_p2p_server()
 
     def _maybe_start_p2p_server(self):
         if self.kv_transfer_server is not None:
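These two hunks defer `_maybe_start_p2p_server()` from construction time to the point where the KV layer's shape, dtype, and sharding are known. A minimal sketch of that deferred-start pattern, using a hypothetical class rather than the real TPUConnectorWorker:

class LazyKVTransferServer:
    """Hypothetical stand-in for a worker that must know the KV layout first."""

    def __init__(self):
        self.kv_transfer_server = None
        self.shape = None

    def register_kv_layer(self, kv_layer):
        # Record the layout, then start the server exactly once.
        self.shape = list(kv_layer.shape)
        self._maybe_start_p2p_server()

    def _maybe_start_p2p_server(self):
        if self.kv_transfer_server is not None:
            return  # already started; the guard keeps this idempotent
        self.kv_transfer_server = ("p2p-server", tuple(self.shape))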
tpu_inference/envs.py CHANGED
@@ -15,14 +15,60 @@ if TYPE_CHECKING:
     PREFILL_SLICES: str = ""
     DECODE_SLICES: str = ""
     SKIP_JAX_PRECOMPILE: bool = False
+    VLLM_XLA_CHECK_RECOMPILATION: bool = False
     MODEL_IMPL_TYPE: str = "flax_nnx"
     NEW_MODEL_DESIGN: bool = False
     PHASED_PROFILING_DIR: str = ""
     PYTHON_TRACER_LEVEL: int = 1
     USE_MOE_EP_KERNEL: bool = False
+    NUM_SLICES: int = 1
     RAY_USAGE_STATS_ENABLED: str = "0"
     VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: str = "shm"
 
+
+def env_with_choices(
+    env_name: str,
+    default: str | None,
+    choices: list[str] | Callable[[], list[str]],
+    case_sensitive: bool = True,
+) -> Callable[[], str | None]:
+    """
+    Create a lambda that validates environment variable against allowed choices
+
+    Args:
+        env_name: Name of the environment variable
+        default: Default value if not set (can be None)
+        choices: List of valid string options or callable that returns list
+        case_sensitive: Whether validation should be case sensitive
+
+    Returns:
+        Lambda function for environment_variables dict
+    """
+
+    def _get_validated_env() -> str | None:
+        value = os.getenv(env_name)
+        if value is None:
+            return default
+
+        # Resolve choices if it's a callable (for lazy loading)
+        actual_choices = choices() if callable(choices) else choices
+
+        if not case_sensitive:
+            check_value = value.lower()
+            check_choices = [choice.lower() for choice in actual_choices]
+        else:
+            check_value = value
+            check_choices = actual_choices
+
+        if check_value not in check_choices:
+            raise ValueError(f"Invalid value '{value}' for {env_name}. "
+                             f"Valid options: {actual_choices}.")
+
+        return value
+
+    return _get_validated_env
+
+
 environment_variables: dict[str, Callable[[], Any]] = {
     # JAX platform selection (e.g., "tpu", "cpu", "proxy")
     "JAX_PLATFORMS":
@@ -38,7 +84,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
     lambda: os.getenv("TPU_WORKER_ID", None),
     # Backend for multi-host communication on TPU
     "TPU_MULTIHOST_BACKEND":
-    lambda: os.getenv("TPU_MULTIHOST_BACKEND", "").lower(),
+    env_with_choices("TPU_MULTIHOST_BACKEND", "", ["ray"]),
     # Slice configuration for disaggregated prefill workers
     "PREFILL_SLICES":
     lambda: os.getenv("PREFILL_SLICES", ""),
@@ -47,28 +93,35 @@ environment_variables: dict[str, Callable[[], Any]] = {
     lambda: os.getenv("DECODE_SLICES", ""),
     # Skip JAX precompilation step during initialization
     "SKIP_JAX_PRECOMPILE":
-    lambda: bool(int(os.getenv("SKIP_JAX_PRECOMPILE", "0"))),
+    lambda: bool(int(os.getenv("SKIP_JAX_PRECOMPILE") or "0")),
+    # Check for XLA recompilation during execution
+    "VLLM_XLA_CHECK_RECOMPILATION":
+    lambda: bool(int(os.getenv("VLLM_XLA_CHECK_RECOMPILATION") or "0")),
     # Model implementation type (e.g., "flax_nnx")
     "MODEL_IMPL_TYPE":
-    lambda: os.getenv("MODEL_IMPL_TYPE", "flax_nnx").lower(),
+    env_with_choices("MODEL_IMPL_TYPE", "flax_nnx",
+                     ["vllm", "flax_nnx", "jetpack"]),
     # Enable new experimental model design
     "NEW_MODEL_DESIGN":
-    lambda: bool(int(os.getenv("NEW_MODEL_DESIGN", "0"))),
+    lambda: bool(int(os.getenv("NEW_MODEL_DESIGN") or "0")),
     # Directory to store phased profiling output
     "PHASED_PROFILING_DIR":
     lambda: os.getenv("PHASED_PROFILING_DIR", ""),
     # Python tracer level for profiling
     "PYTHON_TRACER_LEVEL":
-    lambda: int(os.getenv("PYTHON_TRACER_LEVEL", "1")),
+    lambda: int(os.getenv("PYTHON_TRACER_LEVEL") or "1"),
     # Use custom expert-parallel kernel for MoE (Mixture of Experts)
     "USE_MOE_EP_KERNEL":
-    lambda: bool(int(os.getenv("USE_MOE_EP_KERNEL", "0"))),
+    lambda: bool(int(os.getenv("USE_MOE_EP_KERNEL") or "0")),
+    # Number of TPU slices for multi-slice mesh
+    "NUM_SLICES":
+    lambda: int(os.getenv("NUM_SLICES") or "1"),
     # Enable/disable Ray usage statistics collection
     "RAY_USAGE_STATS_ENABLED":
     lambda: os.getenv("RAY_USAGE_STATS_ENABLED", "0"),
     # Ray compiled DAG channel type for TPU
     "VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE":
-    lambda: os.getenv("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE", "shm"),
+    env_with_choices("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE", "shm", ["shm"]),
 }
 
 
tpu_inference/executors/ray_distributed_executor.py CHANGED
@@ -136,10 +136,14 @@ class RayDistributedExecutor(RayDistributedExecutorV1):
 
         pp_size = self.parallel_config.pipeline_parallel_size
         placement_group_specs: List[Dict[str, float]] = []
+
+        ray_nodes = ray.nodes()
+        logger.info(f"RayDistributedExecutor | ray_nodes={ray_nodes}")
+
         if pp_size == 1:
             placement_group_specs = [{
                 device_str: node['Resources'][device_str]
-            } for node in ray.nodes()]
+            } for node in ray_nodes]
         else:
             num_devices_per_pp_rank = self.vllm_config.sharding_config.total_devices
             placement_group_specs = [{
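For context, `placement_group_specs` in the pp_size == 1 branch is just one bundle per Ray node, sized by that node's TPU resource. An illustrative, self-contained sketch with made-up node data (not a real `ray.nodes()` call):

# Hypothetical node listing shaped like ray.nodes() output; values are made up.
ray_nodes = [
    {"NodeID": "a", "Resources": {"TPU": 4.0, "CPU": 96.0}},
    {"NodeID": "b", "Resources": {"TPU": 4.0, "CPU": 96.0}},
]
device_str = "TPU"

placement_group_specs = [{
    device_str: node["Resources"][device_str]
} for node in ray_nodes]

assert placement_group_specs == [{"TPU": 4.0}, {"TPU": 4.0}]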
tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py CHANGED
@@ -352,7 +352,7 @@ def _ragged_paged_attention_kernel(
     debug_print("[RPA debug] q_len={}", q_len)
     debug_print("[RPA debug] kv_len={}", kv_len)
 
-    def flash_attention(
+    def flash_attention_step1_qk_softmax(
         q,  # [actual_bq_sz * num_q_heads_per_kv_head, actual_head_dim_x2]
         kv,  # [bkv_sz, actual_head_dim_x2]
         *,
@@ -366,7 +366,6 @@ def _ragged_paged_attention_kernel(
         assert kv.shape == (bkv_sz, actual_head_dim_x2)
         head_l_ref = l_ref.at[kv_head_idx, :q.shape[0]]
         head_m_ref = m_ref.at[kv_head_idx, :q.shape[0]]
-        head_acc_ref = acc_ref.at[kv_head_idx, :q.shape[0]]
 
         def load_with_init(ref, init_val):
             return jnp.where(bkv_idx == bkv_idx_start,
@@ -416,15 +415,33 @@ def _ragged_paged_attention_kernel(
         head_m_ref[...] = m_curr
         p = jnp.exp(s - broadcast_minor(m_curr, s.shape))
 
-        pv = jnp.einsum("nm,md->nd", p, kv, preferred_element_type=jnp.float32)
-        if v_scale is not None:
-            pv *= v_scale
-
         p_rowsum = jnp.sum(p, axis=1, keepdims=True)
         exp_m_diff = jnp.exp(m_prev - m_curr)
         l_prev = load_with_init(head_l_ref, 1.0)
         l_curr = exp_m_diff * l_prev + p_rowsum
         head_l_ref[...] = l_curr
+
+        return p, exp_m_diff
+
+    def flash_attention_step2_pv(
+        q_shape_0,
+        kv,  # [bkv_sz, actual_head_dim_x2]
+        p,  # from step1
+        exp_m_diff,  # from step1
+        *,
+        bkv_idx,
+        kv_head_idx,
+    ):
+        head_acc_ref = acc_ref.at[kv_head_idx, :q_shape_0]
+
+        def load_with_init(ref, init_val):
+            return jnp.where(bkv_idx == bkv_idx_start,
+                             jnp.full_like(ref, init_val), ref[...])
+
+        pv = jnp.einsum("nm,md->nd", p, kv, preferred_element_type=jnp.float32)
+        if v_scale is not None:
+            pv *= v_scale
+
         o_prev = load_with_init(head_acc_ref, 0.0)
         o_curr = broadcast_minor(exp_m_diff, o_prev.shape) * o_prev + pv
         head_acc_ref[...] = o_curr
@@ -835,6 +852,11 @@ def _ragged_paged_attention_kernel(
            return
 
        # Flash attention with cur bkv and bq
+       prev_bq_shape_0 = None
+       prev_kv_head_bkv = None
+       prev_kv_head_idx = None
+       prev_kv_head_p = None
+       prev_kv_head_exp_m_diff = None
        for kv_head_start in range(0, actual_num_kv_heads, kv_packing):
            bkv_lst = strided_load_bkv(
                bkv_sem_idx,
@@ -844,20 +866,51 @@ def _ragged_paged_attention_kernel(
            )
            assert len(bkv_lst) == kv_packing
            for i in range(kv_packing):
-               kv_head_idx = kv_head_start + i
-               if kv_head_idx >= actual_num_kv_heads:
+               cur_kv_head_idx = kv_head_start + i
+               if cur_kv_head_idx >= actual_num_kv_heads:
                    break
-               bq = load_bq(bq_sem_idx,
-                            kv_head_idx,
-                            actual_bq_sz=actual_bq_sz)
-               bkv = bkv_lst[i]
-               flash_attention(
-                   bq,
-                   bkv,
-                   bq_idx=bq_idx,
-                   bkv_idx=bkv_idx,
-                   kv_head_idx=kv_head_idx,
-               )
+               cur_kv_head_bq = load_bq(bq_sem_idx,
+                                        cur_kv_head_idx,
+                                        actual_bq_sz=actual_bq_sz)
+               cur_kv_head__bkv = bkv_lst[i]
+               # FlashAttention is divided into `flash_attention_step1_qk_softmax`
+               # and `flash_attention_step2_pv` to pipeline the computation.
+               # `step2_pv` for the previous KV head, which depends on the softmax
+               # output, is overlapped with `step1_qk_softmax` for the current KV
+               # head, reducing overall wait times.
+               cur_kv_head_p, cur_kv_head_exp_m_diff = (
+                   flash_attention_step1_qk_softmax(
+                       cur_kv_head_bq,
+                       cur_kv_head__bkv,
+                       bq_idx=bq_idx,
+                       bkv_idx=bkv_idx,
+                       kv_head_idx=cur_kv_head_idx,
+                   ))
+               if prev_bq_shape_0 is not None:
+                   flash_attention_step2_pv(
+                       prev_bq_shape_0,
+                       prev_kv_head_bkv,
+                       prev_kv_head_p,
+                       prev_kv_head_exp_m_diff,
+                       bkv_idx=bkv_idx,
+                       kv_head_idx=prev_kv_head_idx,
+                   )
+               prev_bq_shape_0 = cur_kv_head_bq.shape[0]
+               prev_kv_head_bkv = cur_kv_head__bkv
+               prev_kv_head_p = cur_kv_head_p
+               prev_kv_head_exp_m_diff = cur_kv_head_exp_m_diff
+               prev_kv_head_idx = cur_kv_head_idx
+
+       # Execute pv of last attention head.
+       assert prev_bq_shape_0 is not None
+       flash_attention_step2_pv(
+           prev_bq_shape_0,
+           prev_kv_head_bkv,
+           prev_kv_head_p,
+           prev_kv_head_exp_m_diff,
+           bkv_idx=bkv_idx,
+           kv_head_idx=prev_kv_head_idx,
+       )
 
        lax.fori_loop(bkv_idx_start,
                      num_bkv,
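The kernel change splits flash attention into a softmax stage and a PV stage so that, inside the Pallas kernel, the PV matmul for head h-1 can be scheduled against the QK/softmax work for head h. The toy sketch below only illustrates the prev/cur hand-off and the final flush; it runs eagerly (no real overlap), uses single-block softmax rather than the online-softmax accumulators above, and all names are illustrative:

import jax.numpy as jnp

def step1_qk_softmax(q, k):
    # q: [n, d], k: [m, d]; returns unnormalized probs and their row sums.
    s = q @ k.T
    m_curr = jnp.max(s, axis=1, keepdims=True)
    p = jnp.exp(s - m_curr)
    return p, jnp.sum(p, axis=1, keepdims=True)

def step2_pv(p, l, v):
    # v: [m, d]; normalized attention output for one head.
    return (p @ v) / l

def attention_pipelined(qs, ks, vs):
    outputs, prev = [], None
    for h in range(len(qs)):
        cur = (*step1_qk_softmax(qs[h], ks[h]), vs[h])  # step1 for head h
        if prev is not None:
            outputs.append(step2_pv(*prev))             # step2 for head h-1
        prev = cur
    outputs.append(step2_pv(*prev))                     # flush the last head
    return outputs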
tpu_inference/layers/common/sharding.py CHANGED
@@ -1,6 +1,5 @@
 import json
 import math
-import os
 from dataclasses import asdict, dataclass
 from typing import TYPE_CHECKING, List, Optional
 
@@ -8,7 +7,7 @@ import jax.numpy as jnp
 import numpy as np
 from jax.sharding import Mesh
 
-from tpu_inference import utils
+from tpu_inference import envs, utils
 
 if TYPE_CHECKING:
     from vllm.v1.configs.vllm_config import VllmConfig
@@ -48,7 +47,7 @@ class ShardingAxisName2D:
 
 
 try:
-    _use_base_sharding = os.getenv("NEW_MODEL_DESIGN", False)
+    _use_base_sharding = envs.NEW_MODEL_DESIGN
     if _use_base_sharding:
        ShardingAxisName = ShardingAxisNameBase
    else:
@@ -167,7 +166,7 @@ class ShardingConfigManager:
                f"(DP size: {total_dp_size}). Please disable LoRA or "
                f"set data parallelism to 1.")
        if sharding_strategy.attention_data_parallelism > 1:
-           if not os.environ.get("NEW_MODEL_DESIGN", False):
+           if not envs.NEW_MODEL_DESIGN:
                raise ValueError(
                    "Must run Attention DP with NEW_MODEL_DESIGN enabled. Please set the "
                    "NEW_MODEL_DESIGN=True.")
tpu_inference/layers/vllm/quantization/mxfp4.py CHANGED
@@ -95,7 +95,8 @@ class VllmMxfp4Config(Mxfp4Config, JaxCommonConfig):
                "UnquantizedLinearMethod.")
            return VllmUnquantizedLinearMethod(linear_config)
        elif isinstance(layer, FusedMoE):
-           return VllmMxfp4MoEMethod(layer.moe_config, self.mesh)
+           moe_config = self.get_moe_config(layer)
+           return VllmMxfp4MoEMethod(moe_config, self.mesh)
        elif isinstance(layer, Attention):
            # TODO: Add support for MXFP4 Attention.
            logger.warning_once("MXFP4 attention layer is not implemented. "
tpu_inference/models/common/model_loader.py CHANGED
@@ -236,7 +236,9 @@ def get_flax_model(
            hidden_states_sharding,  # aux hidden states
        ),
        donate_argnums=2,  # 0 is graphdef, 1 is state, 2 is kv_cache
-       static_argnums=7,  #7 is layer_name_to_kvcache_index
+       static_argnums=(
+           7, 10, 11
+       ),  #7 is layer_name_to_kvcache_index, 10 is is_first_rank, 11 is is_last_rank
    )
    def run_model(graphdef, state, *args):
        model = nnx.merge(graphdef, state)
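The widened `static_argnums` marks the new `is_first_rank`/`is_last_rank` arguments as compile-time constants, so they can steer Python-level control flow inside the jitted function; each distinct combination triggers its own compilation. A minimal, self-contained illustration (hypothetical function, not the real `run_model`):

import functools
import jax
import jax.numpy as jnp

@functools.partial(jax.jit, static_argnums=(1, 2))
def stage(x, is_first_rank, is_last_rank):
    # Static booleans are baked into the trace, so plain `if` is fine here.
    if is_first_rank:
        x = x * 2.0
    if is_last_rank:
        x = x + 1.0
    return x

print(stage(jnp.ones(4), True, False))   # compiled for (True, False)
print(stage(jnp.ones(4), False, True))   # recompiled for (False, True)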
tpu_inference/models/jax/utils/quantization/quantization_utils.py CHANGED
@@ -154,12 +154,9 @@ def qwix_quantize_nnx_model(model: nnx.Module, qwix_config: List[dict],
    logger.info(f"Memory usage before applying quantization of params: "
                f"hbm={utils.hbm_usage_gb(jax.local_devices())}Gb")
 
-   # TODO (jacobplatin): we should refactor this to pass a dtype (or config) directly
-   kv_cache_jnp_dtype = utils.get_jax_dtype_from_str_dtype(kv_cache_dtype)
-
-   # Handle the case where kv_cache_dtype is "auto"
-   if kv_cache_jnp_dtype is None:
-       assert kv_cache_dtype == "auto", "kv_cache_dtype must be 'auto' if kv_cache_jnp_dtype is None"
+   if kv_cache_dtype != "auto":
+       kv_cache_jnp_dtype = utils.to_jax_dtype(kv_cache_dtype)
+   else:
        kv_cache_jnp_dtype = DEFAULT_KV_CACHE_DTYPE
 
    kv_caches = create_kv_caches(
tpu_inference/models/vllm/vllm_model_wrapper.py CHANGED
@@ -221,7 +221,7 @@ class VllmModelWrapper:
        @functools.partial(
            jax.jit,
            out_shardings=(NamedSharding(self.mesh,
-                                        PartitionSpec(None, "model"))),
+                                        PartitionSpec("data", "model"))),
        )
        def compute_logits_func(
            params_and_buffers: Any,
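The out_shardings change shards the logits over both the "data" and "model" mesh axes instead of replicating the first dimension. A small sketch of what `PartitionSpec("data", "model")` means, using a throwaway mesh over whatever devices are available (not the wrapper's real mesh):

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

devices = np.array(jax.devices()).reshape(1, -1)       # "data" axis of size 1 here
mesh = Mesh(devices, axis_names=("data", "model"))
sharding = NamedSharding(mesh, PartitionSpec("data", "model"))

logits = jax.device_put(np.zeros((8, 128), np.float32), sharding)
print(logits.sharding)   # dim 0 split over "data", dim 1 over "model"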
@@ -263,7 +263,6 @@ def load_lora_model(model: torch.nn.Module, vllm_config: VllmConfig,
        vllm_config,
        device,
        model.embedding_modules,
-       model.embedding_padding_modules,
    )
    return lora_manager, lora_manager.create_lora_manager(model)
 
tpu_inference/platforms/tpu_platform.py CHANGED
@@ -5,7 +5,6 @@ from typing import TYPE_CHECKING, Any, Optional, Tuple, Union, cast
 import jax.numpy as jnp
 import torch
 import vllm.envs as vllm_envs
-from torchax.ops.mappings import j2t_dtype
 from tpu_info import device
 from vllm.inputs import ProcessorInputs, PromptType
 from vllm.platforms.interface import Platform, PlatformEnum
@@ -14,6 +13,7 @@ from vllm.sampling_params import SamplingParams, SamplingType
 from tpu_inference import envs
 from tpu_inference.layers.common.sharding import ShardingConfigManager
 from tpu_inference.logger import init_logger
+from tpu_inference.utils import to_jax_dtype, to_torch_dtype
 
 if TYPE_CHECKING:
     from vllm.attention.backends.registry import _Backend
@@ -28,12 +28,6 @@ else:
 
 logger = init_logger(__name__)
 
-_DTYPE: dict[str, jnp.dtype] = {
-    "bfloat16": jnp.bfloat16,
-    "float": jnp.float32,
-    "float32": jnp.float32,
-}
-
 
 class TpuPlatform(Platform):
     _enum = PlatformEnum.TPU
@@ -158,20 +152,19 @@ class TpuPlatform(Platform):
        # NOTE(xiang): convert dtype to jnp.dtype
        # NOTE(wenlong): skip this logic for mm model preprocessing
        # For mm model preprocessors, it may need the output dtype to be torch.
-       # In order to avoid a PR to vLLM, we postpone the dtype checking during tpu_worker initialization
+       # In order to avoid a PR to vLLM, we postpone the dtype checking during
+       # tpu_worker initialization
        if not vllm_config.scheduler_config.is_multimodal_model or impl == "vllm":
-           if not isinstance(vllm_config.model_config.dtype, str):
-               logger.warning(
-                   "The model dtype is not properly set for JAX backend. "
-                   "Overwriting it to jnp.bfloat16")
-               vllm_config.model_config.dtype = jnp.bfloat16
-           else:
-               vllm_config.model_config.dtype = _DTYPE.get(
-                   vllm_config.model_config.dtype, jnp.bfloat16)
-
-           if impl == "vllm":
-               vllm_config.model_config.dtype = j2t_dtype(
-                   vllm_config.model_config.dtype.dtype)
+           model_dtype = vllm_config.model_config.dtype
+           try:
+               dtype = to_jax_dtype(model_dtype)
+           except ValueError:
+               logger.warning(f"{model_dtype=} is not supported. "
+                              "Falling back to jnp.bfloat16")
+               dtype = jnp.bfloat16
+           if impl == "vllm":
+               dtype = to_torch_dtype(dtype)
+           vllm_config.model_config.dtype = dtype
 
        # TODO(cuiq): remove this dependency.
        from vllm.v1.attention.backends.pallas import PallasAttentionBackend
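The hunk above replaces the local `_DTYPE` table and `j2t_dtype` with the shared `to_jax_dtype`/`to_torch_dtype` helpers from `tpu_inference.utils`. A hypothetical sketch that mirrors only the try/except fallback flow (the table and helper name here are illustrative, not the real utils implementation):

import jax.numpy as jnp

# Illustrative mapping; the real to_jax_dtype/to_torch_dtype live in tpu_inference.utils.
_JAX_DTYPES = {
    "bfloat16": jnp.bfloat16,
    "float32": jnp.float32,
}

def to_jax_dtype_sketch(dtype):
    try:
        return _JAX_DTYPES[dtype]
    except KeyError:
        raise ValueError(f"Unsupported dtype: {dtype!r}") from None

try:
    dtype = to_jax_dtype_sketch("float16")   # unsupported in this toy table
except ValueError:
    dtype = jnp.bfloat16                     # same fallback as TpuPlatform above
print(dtype)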