tpu-inference 0.11.1.dev202511270815__py3-none-any.whl → 0.11.1.dev202512030818__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tests/test_envs.py +32 -11
- tests/test_utils.py +1 -2
- tpu_inference/distributed/tpu_connector.py +1 -1
- tpu_inference/envs.py +60 -7
- tpu_inference/executors/ray_distributed_executor.py +5 -1
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +72 -19
- tpu_inference/layers/common/sharding.py +3 -4
- tpu_inference/layers/vllm/quantization/mxfp4.py +2 -1
- tpu_inference/models/common/model_loader.py +3 -1
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +3 -6
- tpu_inference/models/vllm/vllm_model_wrapper.py +1 -2
- tpu_inference/platforms/tpu_platform.py +13 -20
- tpu_inference/runner/compilation_manager.py +87 -27
- tpu_inference/runner/kv_cache_manager.py +8 -15
- tpu_inference/runner/persistent_batch_manager.py +40 -2
- tpu_inference/runner/tpu_runner.py +68 -45
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/jax/eagle3.py +52 -19
- tpu_inference/utils.py +31 -9
- tpu_inference/worker/tpu_worker.py +2 -2
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/METADATA +1 -1
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/RECORD +25 -25
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/top_level.txt +0 -0
tests/test_envs.py
CHANGED
@@ -56,6 +56,12 @@ def test_getattr_with_cache(monkeypatch: pytest.MonkeyPatch):


 def test_boolean_env_vars(monkeypatch: pytest.MonkeyPatch):
+    # Ensure clean environment for boolean vars by setting to default "0"
+    monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "0")
+    monkeypatch.setenv("VLLM_XLA_CHECK_RECOMPILATION", "0")
+    monkeypatch.setenv("NEW_MODEL_DESIGN", "0")
+    monkeypatch.setenv("USE_MOE_EP_KERNEL", "0")
+
     # Test SKIP_JAX_PRECOMPILE (default False)
     assert envs.SKIP_JAX_PRECOMPILE is False
     monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "1")

@@ -63,6 +69,13 @@ def test_boolean_env_vars(monkeypatch: pytest.MonkeyPatch):
     monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "0")
     assert envs.SKIP_JAX_PRECOMPILE is False

+    # Test VLLM_XLA_CHECK_RECOMPILATION (default False)
+    assert envs.VLLM_XLA_CHECK_RECOMPILATION is False
+    monkeypatch.setenv("VLLM_XLA_CHECK_RECOMPILATION", "1")
+    assert envs.VLLM_XLA_CHECK_RECOMPILATION is True
+    monkeypatch.setenv("VLLM_XLA_CHECK_RECOMPILATION", "0")
+    assert envs.VLLM_XLA_CHECK_RECOMPILATION is False
+
     # Test NEW_MODEL_DESIGN (default False)
     assert envs.NEW_MODEL_DESIGN is False
     monkeypatch.setenv("NEW_MODEL_DESIGN", "1")

@@ -75,20 +88,32 @@ def test_boolean_env_vars(monkeypatch: pytest.MonkeyPatch):


 def test_integer_env_vars(monkeypatch: pytest.MonkeyPatch):
+    # Ensure clean environment for integer vars by setting to defaults
+    monkeypatch.setenv("PYTHON_TRACER_LEVEL", "1")
+    monkeypatch.setenv("NUM_SLICES", "1")
+
     assert envs.PYTHON_TRACER_LEVEL == 1
     monkeypatch.setenv("PYTHON_TRACER_LEVEL", "3")
     assert envs.PYTHON_TRACER_LEVEL == 3
     monkeypatch.setenv("PYTHON_TRACER_LEVEL", "0")
     assert envs.PYTHON_TRACER_LEVEL == 0

+    # Test NUM_SLICES (default 1)
+    assert envs.NUM_SLICES == 1
+    monkeypatch.setenv("NUM_SLICES", "2")
+    assert envs.NUM_SLICES == 2
+    monkeypatch.setenv("NUM_SLICES", "4")
+    assert envs.NUM_SLICES == 4

-def test_lowercase_conversion(monkeypatch: pytest.MonkeyPatch):
-    monkeypatch.setenv("TPU_MULTIHOST_BACKEND", "GRPC")
-    assert envs.TPU_MULTIHOST_BACKEND == "grpc"

-
+def test_model_impl_type_choices(monkeypatch: pytest.MonkeyPatch):
+    # Test case sensitive choices
+    monkeypatch.setenv("MODEL_IMPL_TYPE", "flax_nnx")
     assert envs.MODEL_IMPL_TYPE == "flax_nnx"

+    monkeypatch.setenv("MODEL_IMPL_TYPE", "vllm")
+    assert envs.MODEL_IMPL_TYPE == "vllm"
+

 def test_string_env_vars_defaults(monkeypatch: pytest.MonkeyPatch):
     monkeypatch.delenv("JAX_PLATFORMS", raising=False)

@@ -117,8 +142,6 @@ def test_ray_env_vars(monkeypatch: pytest.MonkeyPatch):
     assert envs.RAY_USAGE_STATS_ENABLED == "1"

     assert envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE == "shm"
-    monkeypatch.setenv("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE", "nccl")
-    assert envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE == "nccl"


 def test_invalid_attribute_raises_error():

@@ -134,6 +157,7 @@ def test_dir_returns_all_env_vars():
     assert "JAX_PLATFORMS" in env_vars
     assert "TPU_NAME" in env_vars
     assert "SKIP_JAX_PRECOMPILE" in env_vars
+    assert "VLLM_XLA_CHECK_RECOMPILATION" in env_vars
     assert "MODEL_IMPL_TYPE" in env_vars


@@ -141,11 +165,8 @@ def test_tpu_multihost_env_vars(monkeypatch: pytest.MonkeyPatch):
     monkeypatch.setenv("TPU_WORKER_ID", "0")
     assert envs.TPU_WORKER_ID == "0"

-    monkeypatch.setenv("TPU_MULTIHOST_BACKEND", "
-    assert envs.TPU_MULTIHOST_BACKEND == "
-
-    monkeypatch.setenv("TPU_MULTIHOST_BACKEND", "xla")
-    assert envs.TPU_MULTIHOST_BACKEND == "xla"
+    monkeypatch.setenv("TPU_MULTIHOST_BACKEND", "ray")
+    assert envs.TPU_MULTIHOST_BACKEND == "ray"


 def test_disaggregated_serving_env_vars(monkeypatch: pytest.MonkeyPatch):
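The boolean variables exercised above are parsed with the bool(int(os.getenv(NAME) or "0")) pattern that the tpu_inference/envs.py hunks later in this diff introduce. A minimal standalone sketch of that parsing behavior (the _bool_env helper here is hypothetical, not part of the package):

import os

def _bool_env(name: str, default: str = "0") -> bool:
    # Unset and empty-string values both fall back to the default; a plain
    # os.getenv(name, default) would hand an empty string to int() and raise.
    return bool(int(os.getenv(name) or default))

os.environ["SKIP_JAX_PRECOMPILE"] = "1"
assert _bool_env("SKIP_JAX_PRECOMPILE") is True
os.environ["SKIP_JAX_PRECOMPILE"] = ""
assert _bool_env("SKIP_JAX_PRECOMPILE") is False
os.environ.pop("SKIP_JAX_PRECOMPILE", None)
assert _bool_env("SKIP_JAX_PRECOMPILE") is False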
tests/test_utils.py
CHANGED
@@ -231,6 +231,5 @@ def test_get_jax_dtype_from_str_dtype():
     assert get_jax_dtype_from_str_dtype("int8") == jnp.int8
     assert get_jax_dtype_from_str_dtype("bfloat16") == jnp.bfloat16
     assert get_jax_dtype_from_str_dtype("fp8") == jnp.float8_e4m3fn
-    assert get_jax_dtype_from_str_dtype("fp8_e4m3") == jnp.
+    assert get_jax_dtype_from_str_dtype("fp8_e4m3") == jnp.float8_e4m3fn
     assert get_jax_dtype_from_str_dtype("fp8_e5m2") == jnp.float8_e5m2
-    assert get_jax_dtype_from_str_dtype("auto") is None
tpu_inference/distributed/tpu_connector.py
CHANGED

@@ -457,7 +457,6 @@ class TPUConnectorWorker:
         self.side_channel_port = get_side_channel_port()

         self.kv_transfer_server = None
-        self._maybe_start_p2p_server()
         self.zmq_cxt = zmq.Context()
         if self.is_producer:
             ready_event = threading.Event()

@@ -499,6 +498,7 @@ class TPUConnectorWorker:
         self.shape = list(kv_layer.shape)
         self.dtype = kv_layer.dtype
         self.sharding = kv_layer.sharding
+        self._maybe_start_p2p_server()

     def _maybe_start_p2p_server(self):
         if self.kv_transfer_server is not None:
tpu_inference/envs.py
CHANGED
@@ -15,14 +15,60 @@ if TYPE_CHECKING:
     PREFILL_SLICES: str = ""
     DECODE_SLICES: str = ""
     SKIP_JAX_PRECOMPILE: bool = False
+    VLLM_XLA_CHECK_RECOMPILATION: bool = False
     MODEL_IMPL_TYPE: str = "flax_nnx"
     NEW_MODEL_DESIGN: bool = False
     PHASED_PROFILING_DIR: str = ""
     PYTHON_TRACER_LEVEL: int = 1
     USE_MOE_EP_KERNEL: bool = False
+    NUM_SLICES: int = 1
     RAY_USAGE_STATS_ENABLED: str = "0"
     VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: str = "shm"

+
+def env_with_choices(
+    env_name: str,
+    default: str | None,
+    choices: list[str] | Callable[[], list[str]],
+    case_sensitive: bool = True,
+) -> Callable[[], str | None]:
+    """
+    Create a lambda that validates environment variable against allowed choices
+
+    Args:
+        env_name: Name of the environment variable
+        default: Default value if not set (can be None)
+        choices: List of valid string options or callable that returns list
+        case_sensitive: Whether validation should be case sensitive
+
+    Returns:
+        Lambda function for environment_variables dict
+    """
+
+    def _get_validated_env() -> str | None:
+        value = os.getenv(env_name)
+        if value is None:
+            return default
+
+        # Resolve choices if it's a callable (for lazy loading)
+        actual_choices = choices() if callable(choices) else choices
+
+        if not case_sensitive:
+            check_value = value.lower()
+            check_choices = [choice.lower() for choice in actual_choices]
+        else:
+            check_value = value
+            check_choices = actual_choices
+
+        if check_value not in check_choices:
+            raise ValueError(f"Invalid value '{value}' for {env_name}. "
+                             f"Valid options: {actual_choices}.")
+
+        return value
+
+    return _get_validated_env
+
+
 environment_variables: dict[str, Callable[[], Any]] = {
     # JAX platform selection (e.g., "tpu", "cpu", "proxy")
     "JAX_PLATFORMS":

@@ -38,7 +84,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
     lambda: os.getenv("TPU_WORKER_ID", None),
     # Backend for multi-host communication on TPU
     "TPU_MULTIHOST_BACKEND":
-
+    env_with_choices("TPU_MULTIHOST_BACKEND", "", ["ray"]),
     # Slice configuration for disaggregated prefill workers
     "PREFILL_SLICES":
     lambda: os.getenv("PREFILL_SLICES", ""),

@@ -47,28 +93,35 @@ environment_variables: dict[str, Callable[[], Any]] = {
     lambda: os.getenv("DECODE_SLICES", ""),
     # Skip JAX precompilation step during initialization
     "SKIP_JAX_PRECOMPILE":
-    lambda: bool(int(os.getenv("SKIP_JAX_PRECOMPILE"
+    lambda: bool(int(os.getenv("SKIP_JAX_PRECOMPILE") or "0")),
+    # Check for XLA recompilation during execution
+    "VLLM_XLA_CHECK_RECOMPILATION":
+    lambda: bool(int(os.getenv("VLLM_XLA_CHECK_RECOMPILATION") or "0")),
     # Model implementation type (e.g., "flax_nnx")
     "MODEL_IMPL_TYPE":
-
+    env_with_choices("MODEL_IMPL_TYPE", "flax_nnx",
+                     ["vllm", "flax_nnx", "jetpack"]),
     # Enable new experimental model design
     "NEW_MODEL_DESIGN":
-    lambda: bool(int(os.getenv("NEW_MODEL_DESIGN"
+    lambda: bool(int(os.getenv("NEW_MODEL_DESIGN") or "0")),
     # Directory to store phased profiling output
     "PHASED_PROFILING_DIR":
     lambda: os.getenv("PHASED_PROFILING_DIR", ""),
     # Python tracer level for profiling
     "PYTHON_TRACER_LEVEL":
-    lambda: int(os.getenv("PYTHON_TRACER_LEVEL"
+    lambda: int(os.getenv("PYTHON_TRACER_LEVEL") or "1"),
     # Use custom expert-parallel kernel for MoE (Mixture of Experts)
     "USE_MOE_EP_KERNEL":
-    lambda: bool(int(os.getenv("USE_MOE_EP_KERNEL"
+    lambda: bool(int(os.getenv("USE_MOE_EP_KERNEL") or "0")),
+    # Number of TPU slices for multi-slice mesh
+    "NUM_SLICES":
+    lambda: int(os.getenv("NUM_SLICES") or "1"),
     # Enable/disable Ray usage statistics collection
     "RAY_USAGE_STATS_ENABLED":
     lambda: os.getenv("RAY_USAGE_STATS_ENABLED", "0"),
     # Ray compiled DAG channel type for TPU
     "VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE":
-
+    env_with_choices("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE", "shm", ["shm"]),
 }

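For orientation, here is how a getter produced by the new env_with_choices helper behaves once wired into environment_variables. This is a condensed standalone mirror of the helper shown above, for illustration only; the real definition lives in tpu_inference/envs.py:

import os

def env_with_choices(env_name, default, choices, case_sensitive=True):
    def _get_validated_env():
        value = os.getenv(env_name)
        if value is None:
            return default
        actual = choices() if callable(choices) else choices
        v = value if case_sensitive else value.lower()
        allowed = actual if case_sensitive else [c.lower() for c in actual]
        if v not in allowed:
            raise ValueError(f"Invalid value '{value}' for {env_name}. "
                             f"Valid options: {actual}.")
        return value
    return _get_validated_env

get_impl = env_with_choices("MODEL_IMPL_TYPE", "flax_nnx",
                            ["vllm", "flax_nnx", "jetpack"])
print(get_impl())                       # unset -> default "flax_nnx"
os.environ["MODEL_IMPL_TYPE"] = "vllm"
print(get_impl())                       # -> "vllm"
os.environ["MODEL_IMPL_TYPE"] = "pytorch"
# get_impl()                            # would raise ValueError listing the valid options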
tpu_inference/executors/ray_distributed_executor.py
CHANGED

@@ -136,10 +136,14 @@ class RayDistributedExecutor(RayDistributedExecutorV1):

         pp_size = self.parallel_config.pipeline_parallel_size
         placement_group_specs: List[Dict[str, float]] = []
+
+        ray_nodes = ray.nodes()
+        logger.info(f"RayDistributedExecutor | ray_nodes={ray_nodes}")
+
         if pp_size == 1:
             placement_group_specs = [{
                 device_str: node['Resources'][device_str]
-            } for node in
+            } for node in ray_nodes]
         else:
             num_devices_per_pp_rank = self.vllm_config.sharding_config.total_devices
             placement_group_specs = [{
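The new code snapshots ray.nodes() once and logs it before building the placement groups. A small illustration of how the pp_size == 1 branch derives one placement bundle per Ray node; the node dicts here are mocked stand-ins for what ray.nodes() returns, and the "TPU" resource name is assumed for the example:

ray_nodes = [
    {"NodeID": "node-a", "Resources": {"TPU": 4.0, "CPU": 120.0}},
    {"NodeID": "node-b", "Resources": {"TPU": 4.0, "CPU": 120.0}},
]
device_str = "TPU"  # illustrative device resource name

# One bundle per node, requesting every device the node advertises.
placement_group_specs = [{device_str: node["Resources"][device_str]}
                         for node in ray_nodes]
print(placement_group_specs)  # [{'TPU': 4.0}, {'TPU': 4.0}]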
tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py
CHANGED

@@ -352,7 +352,7 @@ def _ragged_paged_attention_kernel(
     debug_print("[RPA debug] q_len={}", q_len)
     debug_print("[RPA debug] kv_len={}", kv_len)

-    def
+    def flash_attention_step1_qk_softmax(
         q,  # [actual_bq_sz * num_q_heads_per_kv_head, actual_head_dim_x2]
         kv,  # [bkv_sz, actual_head_dim_x2]
         *,

@@ -366,7 +366,6 @@ def _ragged_paged_attention_kernel(
         assert kv.shape == (bkv_sz, actual_head_dim_x2)
         head_l_ref = l_ref.at[kv_head_idx, :q.shape[0]]
         head_m_ref = m_ref.at[kv_head_idx, :q.shape[0]]
-        head_acc_ref = acc_ref.at[kv_head_idx, :q.shape[0]]

         def load_with_init(ref, init_val):
             return jnp.where(bkv_idx == bkv_idx_start,

@@ -416,15 +415,33 @@ def _ragged_paged_attention_kernel(
         head_m_ref[...] = m_curr
         p = jnp.exp(s - broadcast_minor(m_curr, s.shape))

-        pv = jnp.einsum("nm,md->nd", p, kv, preferred_element_type=jnp.float32)
-        if v_scale is not None:
-            pv *= v_scale
-
         p_rowsum = jnp.sum(p, axis=1, keepdims=True)
         exp_m_diff = jnp.exp(m_prev - m_curr)
         l_prev = load_with_init(head_l_ref, 1.0)
         l_curr = exp_m_diff * l_prev + p_rowsum
         head_l_ref[...] = l_curr
+
+        return p, exp_m_diff
+
+    def flash_attention_step2_pv(
+        q_shape_0,
+        kv,  # [bkv_sz, actual_head_dim_x2]
+        p,  # from step1
+        exp_m_diff,  # from step1
+        *,
+        bkv_idx,
+        kv_head_idx,
+    ):
+        head_acc_ref = acc_ref.at[kv_head_idx, :q_shape_0]
+
+        def load_with_init(ref, init_val):
+            return jnp.where(bkv_idx == bkv_idx_start,
+                             jnp.full_like(ref, init_val), ref[...])
+
+        pv = jnp.einsum("nm,md->nd", p, kv, preferred_element_type=jnp.float32)
+        if v_scale is not None:
+            pv *= v_scale
+
         o_prev = load_with_init(head_acc_ref, 0.0)
         o_curr = broadcast_minor(exp_m_diff, o_prev.shape) * o_prev + pv
         head_acc_ref[...] = o_curr

@@ -835,6 +852,11 @@ def _ragged_paged_attention_kernel(
     return

     # Flash attention with cur bkv and bq
+    prev_bq_shape_0 = None
+    prev_kv_head_bkv = None
+    prev_kv_head_idx = None
+    prev_kv_head_p = None
+    prev_kv_head_exp_m_diff = None
     for kv_head_start in range(0, actual_num_kv_heads, kv_packing):
         bkv_lst = strided_load_bkv(
             bkv_sem_idx,

@@ -844,20 +866,51 @@ def _ragged_paged_attention_kernel(
         )
         assert len(bkv_lst) == kv_packing
         for i in range(kv_packing):
-
-            if
+            cur_kv_head_idx = kv_head_start + i
+            if cur_kv_head_idx >= actual_num_kv_heads:
                 break
-
-
-
-
-
-
-
-
-
-
-
+            cur_kv_head_bq = load_bq(bq_sem_idx,
+                                     cur_kv_head_idx,
+                                     actual_bq_sz=actual_bq_sz)
+            cur_kv_head__bkv = bkv_lst[i]
+            # FlashAttention is divided into `flash_attention_step1_qk_softmax`
+            # and `flash_attention_step2_pv` to pipeline the computation.
+            # `step2_pv` for the previous KV head, which depends on the softmax
+            # output, is overlapped with `step1_qk_softmax` for the current KV
+            # head, reducing overall wait times.
+            cur_kv_head_p, cur_kv_head_exp_m_diff = (
+                flash_attention_step1_qk_softmax(
+                    cur_kv_head_bq,
+                    cur_kv_head__bkv,
+                    bq_idx=bq_idx,
+                    bkv_idx=bkv_idx,
+                    kv_head_idx=cur_kv_head_idx,
+                ))
+            if prev_bq_shape_0 is not None:
+                flash_attention_step2_pv(
+                    prev_bq_shape_0,
+                    prev_kv_head_bkv,
+                    prev_kv_head_p,
+                    prev_kv_head_exp_m_diff,
+                    bkv_idx=bkv_idx,
+                    kv_head_idx=prev_kv_head_idx,
+                )
+            prev_bq_shape_0 = cur_kv_head_bq.shape[0]
+            prev_kv_head_bkv = cur_kv_head__bkv
+            prev_kv_head_p = cur_kv_head_p
+            prev_kv_head_exp_m_diff = cur_kv_head_exp_m_diff
+            prev_kv_head_idx = cur_kv_head_idx
+
+    # Execute pv of last attention head.
+    assert prev_bq_shape_0 is not None
+    flash_attention_step2_pv(
+        prev_bq_shape_0,
+        prev_kv_head_bkv,
+        prev_kv_head_p,
+        prev_kv_head_exp_m_diff,
+        bkv_idx=bkv_idx,
+        kv_head_idx=prev_kv_head_idx,
+    )

     lax.fori_loop(bkv_idx_start,
                   num_bkv,
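The restructuring above splits FlashAttention into a qk+softmax stage and a pv stage so that the pv work for one KV head overlaps the softmax work for the next. A language-level sketch of that software-pipelining pattern (plain Python, not the Pallas kernel itself):

def pipelined_attention(kv_heads, step1_qk_softmax, step2_pv):
    prev = None
    for head in kv_heads:
        p = step1_qk_softmax(head)      # stage 1 for the current head
        if prev is not None:
            step2_pv(*prev)             # stage 2 for the previous head overlaps stage 1 above
        prev = (head, p)
    if prev is not None:
        step2_pv(*prev)                 # flush: stage 2 for the last head

pipelined_attention(
    range(4),
    step1_qk_softmax=lambda h: f"softmax_{h}",
    step2_pv=lambda h, p: print(f"pv for head {h} using {p}"),
)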
tpu_inference/layers/common/sharding.py
CHANGED

@@ -1,6 +1,5 @@
 import json
 import math
-import os
 from dataclasses import asdict, dataclass
 from typing import TYPE_CHECKING, List, Optional

@@ -8,7 +7,7 @@ import jax.numpy as jnp
 import numpy as np
 from jax.sharding import Mesh

-from tpu_inference import utils
+from tpu_inference import envs, utils

 if TYPE_CHECKING:
     from vllm.v1.configs.vllm_config import VllmConfig

@@ -48,7 +47,7 @@ class ShardingAxisName2D:


 try:
-    _use_base_sharding =
+    _use_base_sharding = envs.NEW_MODEL_DESIGN
     if _use_base_sharding:
         ShardingAxisName = ShardingAxisNameBase
     else:

@@ -167,7 +166,7 @@ class ShardingConfigManager:
                 f"(DP size: {total_dp_size}). Please disable LoRA or "
                 f"set data parallelism to 1.")
         if sharding_strategy.attention_data_parallelism > 1:
-            if not
+            if not envs.NEW_MODEL_DESIGN:
                 raise ValueError(
                     "Must run Attention DP with NEW_MODEL_DESIGN enabled. Please set the "
                     "NEW_MODEL_DESIGN=True.")
tpu_inference/layers/vllm/quantization/mxfp4.py
CHANGED

@@ -95,7 +95,8 @@ class VllmMxfp4Config(Mxfp4Config, JaxCommonConfig):
                 "UnquantizedLinearMethod.")
             return VllmUnquantizedLinearMethod(linear_config)
         elif isinstance(layer, FusedMoE):
-
+            moe_config = self.get_moe_config(layer)
+            return VllmMxfp4MoEMethod(moe_config, self.mesh)
         elif isinstance(layer, Attention):
             # TODO: Add support for MXFP4 Attention.
             logger.warning_once("MXFP4 attention layer is not implemented. "
tpu_inference/models/common/model_loader.py
CHANGED

@@ -236,7 +236,9 @@ def get_flax_model(
             hidden_states_sharding,  # aux hidden states
         ),
         donate_argnums=2,  # 0 is graphdef, 1 is state, 2 is kv_cache
-        static_argnums=
+        static_argnums=(
+            7, 10, 11
+        ),  # 7 is layer_name_to_kvcache_index, 10 is is_first_rank, 11 is is_last_rank
     )
     def run_model(graphdef, state, *args):
         model = nnx.merge(graphdef, state)
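For context on the static_argnums change: jax.jit treats static arguments as compile-time constants, so values such as is_first_rank and is_last_rank select a specialized compiled program rather than being traced as array inputs. A minimal, generic example of the mechanism (not the model loader itself; the function and argument names are illustrative):

import functools
import jax
import jax.numpy as jnp

@functools.partial(jax.jit, static_argnums=(1,))
def maybe_scale(x, apply_scale):
    # `apply_scale` is a static Python bool: each value compiles its own branch.
    return x * 2.0 if apply_scale else x

x = jnp.arange(4.0)
print(maybe_scale(x, True))   # [0. 2. 4. 6.]
print(maybe_scale(x, False))  # [0. 1. 2. 3.]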
tpu_inference/models/jax/utils/quantization/quantization_utils.py
CHANGED

@@ -154,12 +154,9 @@ def qwix_quantize_nnx_model(model: nnx.Module, qwix_config: List[dict],
     logger.info(f"Memory usage before applying quantization of params: "
                 f"hbm={utils.hbm_usage_gb(jax.local_devices())}Gb")

-
-
-
-    # Handle the case where kv_cache_dtype is "auto"
-    if kv_cache_jnp_dtype is None:
-        assert kv_cache_dtype == "auto", "kv_cache_dtype must be 'auto' if kv_cache_jnp_dtype is None"
+    if kv_cache_dtype != "auto":
+        kv_cache_jnp_dtype = utils.to_jax_dtype(kv_cache_dtype)
+    else:
         kv_cache_jnp_dtype = DEFAULT_KV_CACHE_DTYPE

     kv_caches = create_kv_caches(
tpu_inference/models/vllm/vllm_model_wrapper.py
CHANGED

@@ -221,7 +221,7 @@ class VllmModelWrapper:
         @functools.partial(
             jax.jit,
             out_shardings=(NamedSharding(self.mesh,
-                                         PartitionSpec(
+                                         PartitionSpec("data", "model"))),
         )
         def compute_logits_func(
             params_and_buffers: Any,

@@ -263,7 +263,6 @@ def load_lora_model(model: torch.nn.Module, vllm_config: VllmConfig,
         vllm_config,
         device,
         model.embedding_modules,
-        model.embedding_padding_modules,
     )
     return lora_manager, lora_manager.create_lora_manager(model)

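The logits out_shardings now names both mesh axes explicitly. A small hedged example of what NamedSharding(mesh, PartitionSpec("data", "model")) denotes on a toy mesh; the axis sizes and the (batch, vocab) interpretation of the logits array are illustrative assumptions, not the runner's real mesh:

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Toy mesh: whatever devices are available, arranged as (data, model) = (n, 1).
devices = np.array(jax.devices()).reshape(-1, 1)
mesh = Mesh(devices, axis_names=("data", "model"))

# Dim 0 of a 2-D logits array is split over "data", dim 1 over "model".
logits_sharding = NamedSharding(mesh, PartitionSpec("data", "model"))
print(logits_sharding)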
tpu_inference/platforms/tpu_platform.py
CHANGED

@@ -5,7 +5,6 @@ from typing import TYPE_CHECKING, Any, Optional, Tuple, Union, cast
 import jax.numpy as jnp
 import torch
 import vllm.envs as vllm_envs
-from torchax.ops.mappings import j2t_dtype
 from tpu_info import device
 from vllm.inputs import ProcessorInputs, PromptType
 from vllm.platforms.interface import Platform, PlatformEnum

@@ -14,6 +13,7 @@ from vllm.sampling_params import SamplingParams, SamplingType
 from tpu_inference import envs
 from tpu_inference.layers.common.sharding import ShardingConfigManager
 from tpu_inference.logger import init_logger
+from tpu_inference.utils import to_jax_dtype, to_torch_dtype

 if TYPE_CHECKING:
     from vllm.attention.backends.registry import _Backend

@@ -28,12 +28,6 @@ else:

 logger = init_logger(__name__)

-_DTYPE: dict[str, jnp.dtype] = {
-    "bfloat16": jnp.bfloat16,
-    "float": jnp.float32,
-    "float32": jnp.float32,
-}
-

 class TpuPlatform(Platform):
     _enum = PlatformEnum.TPU

@@ -158,20 +152,19 @@ class TpuPlatform(Platform):
         # NOTE(xiang): convert dtype to jnp.dtype
         # NOTE(wenlong): skip this logic for mm model preprocessing
         # For mm model preprocessors, it may need the output dtype to be torch.
-        # In order to avoid a PR to vLLM, we postpone the dtype checking during
+        # In order to avoid a PR to vLLM, we postpone the dtype checking during
+        # tpu_worker initialization
         if not vllm_config.scheduler_config.is_multimodal_model or impl == "vllm":
-
-
-
-
-
-
-
-
-
-
-            vllm_config.model_config.dtype = j2t_dtype(
-                vllm_config.model_config.dtype.dtype)
+            model_dtype = vllm_config.model_config.dtype
+            try:
+                dtype = to_jax_dtype(model_dtype)
+            except ValueError:
+                logger.warning(f"{model_dtype=} is not supported. "
+                               "Falling back to jnp.bfloat16")
+                dtype = jnp.bfloat16
+            if impl == "vllm":
+                dtype = to_torch_dtype(dtype)
+            vllm_config.model_config.dtype = dtype

         # TODO(cuiq): remove this dependency.
         from vllm.v1.attention.backends.pallas import PallasAttentionBackend
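To summarize the new dtype handling in TpuPlatform: unsupported model dtypes now fall back to jnp.bfloat16 with a warning instead of going through torchax's j2t_dtype, and only the vLLM (torch) implementation converts the result back to a torch dtype. A simplified standalone sketch of that control flow; to_jax_dtype_sketch and its mapping are illustrative stand-ins for tpu_inference.utils.to_jax_dtype, and the torch conversion step is only described in a comment:

import jax.numpy as jnp

# Illustrative stand-in for tpu_inference.utils.to_jax_dtype.
_SUPPORTED = {"bfloat16": jnp.bfloat16, "float32": jnp.float32, "float": jnp.float32}

def to_jax_dtype_sketch(name: str):
    if name not in _SUPPORTED:
        raise ValueError(f"Unsupported dtype: {name}")
    return _SUPPORTED[name]

def normalize_model_dtype(model_dtype: str, impl: str):
    try:
        dtype = to_jax_dtype_sketch(model_dtype)
    except ValueError:
        dtype = jnp.bfloat16  # fall back instead of failing, as in the hunk above
    # In the real code, impl == "vllm" additionally maps the JAX dtype back to a
    # torch dtype via to_torch_dtype; omitted here to keep the sketch torch-free.
    return dtype

print(normalize_model_dtype("float32", impl="flax_nnx"))       # jnp.float32
print(normalize_model_dtype("not_a_dtype", impl="flax_nnx"))   # jnp.bfloat16 fallback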