tpu-inference 0.11.1.dev202511130813__py3-none-any.whl → 0.11.1.dev202511180814__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tests/test_envs.py +182 -0
- tests/test_utils.py +23 -14
- tpu_inference/core/core_tpu.py +17 -9
- tpu_inference/executors/ray_distributed_executor.py +24 -11
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +33 -10
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +7 -0
- tpu_inference/layers/{jax → common}/attention_interface.py +1 -1
- tpu_inference/layers/common/quant_methods.py +8 -0
- tpu_inference/layers/jax/attention/attention.py +1 -1
- tpu_inference/layers/jax/sample/rejection_sampler.py +1 -1
- tpu_inference/layers/jax/sample/sampling.py +2 -2
- tpu_inference/layers/vllm/attention.py +1 -1
- tpu_inference/layers/vllm/quantization/__init__.py +7 -3
- tpu_inference/layers/vllm/quantization/awq.py +4 -3
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +4 -2
- tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +4 -3
- tpu_inference/models/common/model_loader.py +3 -2
- tpu_inference/models/jax/llama3.py +2 -2
- tpu_inference/models/jax/phi3.py +1 -1
- tpu_inference/models/jax/qwen2.py +1 -1
- tpu_inference/models/jax/qwen2_5_vl.py +2 -2
- tpu_inference/models/jax/qwen3.py +1 -1
- tpu_inference/models/vllm/vllm_model_wrapper.py +22 -10
- tpu_inference/platforms/tpu_platform.py +12 -5
- tpu_inference/runner/compilation_manager.py +4 -2
- tpu_inference/runner/kv_cache.py +1 -1
- tpu_inference/runner/tpu_runner.py +31 -7
- tpu_inference/utils.py +2 -2
- tpu_inference/worker/tpu_worker.py +1 -1
- {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/METADATA +1 -1
- {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/RECORD +37 -34
- /tpu_inference/layers/{jax → common}/binary_search.py +0 -0
- /tpu_inference/layers/{jax → common}/sharding.py +0 -0
- {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/top_level.txt +0 -0
tests/test_envs.py
ADDED
@@ -0,0 +1,182 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the tpu-inference project
+
+import pytest
+
+import tpu_inference.envs as envs
+from tpu_inference.envs import enable_envs_cache, environment_variables
+
+
+def test_getattr_without_cache(monkeypatch: pytest.MonkeyPatch):
+    assert envs.JAX_PLATFORMS == ""
+    assert envs.PHASED_PROFILING_DIR == ""
+    monkeypatch.setenv("JAX_PLATFORMS", "tpu")
+    monkeypatch.setenv("PHASED_PROFILING_DIR", "/tmp/profiling")
+    assert envs.JAX_PLATFORMS == "tpu"
+    assert envs.PHASED_PROFILING_DIR == "/tmp/profiling"
+
+    assert envs.TPU_NAME is None
+    assert envs.TPU_ACCELERATOR_TYPE is None
+    monkeypatch.setenv("TPU_NAME", "my-tpu")
+    monkeypatch.setenv("TPU_ACCELERATOR_TYPE", "v5litepod-16")
+    assert envs.TPU_NAME == "my-tpu"
+    assert envs.TPU_ACCELERATOR_TYPE == "v5litepod-16"
+
+    # __getattr__ is not decorated with functools.cache
+    assert not hasattr(envs.__getattr__, "cache_info")
+
+
+def test_getattr_with_cache(monkeypatch: pytest.MonkeyPatch):
+    monkeypatch.setenv("JAX_PLATFORMS", "tpu")
+    monkeypatch.setenv("TPU_NAME", "my-tpu")
+
+    # __getattr__ is not decorated with functools.cache
+    assert not hasattr(envs.__getattr__, "cache_info")
+
+    enable_envs_cache()
+
+    # __getattr__ is decorated with functools.cache
+    assert hasattr(envs.__getattr__, "cache_info")
+    start_hits = envs.__getattr__.cache_info().hits
+
+    # 2 more hits due to JAX_PLATFORMS and TPU_NAME accesses
+    assert envs.JAX_PLATFORMS == "tpu"
+    assert envs.TPU_NAME == "my-tpu"
+    assert envs.__getattr__.cache_info().hits == start_hits + 2
+
+    # All environment variables are cached
+    for environment_variable in environment_variables:
+        envs.__getattr__(environment_variable)
+    assert envs.__getattr__.cache_info(
+    ).hits == start_hits + 2 + len(environment_variables)
+
+    # Reset envs.__getattr__ back to non-cached version to
+    # avoid affecting other tests
+    envs.__getattr__ = envs.__getattr__.__wrapped__
+
+
+def test_boolean_env_vars(monkeypatch: pytest.MonkeyPatch):
+    # Test SKIP_JAX_PRECOMPILE (default False)
+    assert envs.SKIP_JAX_PRECOMPILE is False
+    monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "1")
+    assert envs.SKIP_JAX_PRECOMPILE is True
+    monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "0")
+    assert envs.SKIP_JAX_PRECOMPILE is False
+
+    # Test NEW_MODEL_DESIGN (default False)
+    assert envs.NEW_MODEL_DESIGN is False
+    monkeypatch.setenv("NEW_MODEL_DESIGN", "1")
+    assert envs.NEW_MODEL_DESIGN is True
+
+    # Test USE_MOE_EP_KERNEL (default False)
+    assert envs.USE_MOE_EP_KERNEL is False
+    monkeypatch.setenv("USE_MOE_EP_KERNEL", "1")
+    assert envs.USE_MOE_EP_KERNEL is True
+
+
+def test_integer_env_vars(monkeypatch: pytest.MonkeyPatch):
+    assert envs.PYTHON_TRACER_LEVEL == 1
+    monkeypatch.setenv("PYTHON_TRACER_LEVEL", "3")
+    assert envs.PYTHON_TRACER_LEVEL == 3
+    monkeypatch.setenv("PYTHON_TRACER_LEVEL", "0")
+    assert envs.PYTHON_TRACER_LEVEL == 0
+
+
+def test_lowercase_conversion(monkeypatch: pytest.MonkeyPatch):
+    monkeypatch.setenv("TPU_MULTIHOST_BACKEND", "GRPC")
+    assert envs.TPU_MULTIHOST_BACKEND == "grpc"
+
+    monkeypatch.setenv("MODEL_IMPL_TYPE", "FLAX_NNX")
+    assert envs.MODEL_IMPL_TYPE == "flax_nnx"
+
+
+def test_string_env_vars_defaults(monkeypatch: pytest.MonkeyPatch):
+    monkeypatch.delenv("JAX_PLATFORMS", raising=False)
+    monkeypatch.delenv("PREFILL_SLICES", raising=False)
+    monkeypatch.delenv("DECODE_SLICES", raising=False)
+
+    assert envs.JAX_PLATFORMS == ""
+    assert envs.PREFILL_SLICES == ""
+    assert envs.DECODE_SLICES == ""
+    assert envs.PHASED_PROFILING_DIR == ""
+
+
+def test_none_default_env_vars(monkeypatch: pytest.MonkeyPatch):
+    monkeypatch.delenv("TPU_ACCELERATOR_TYPE", raising=False)
+    monkeypatch.delenv("TPU_NAME", raising=False)
+    monkeypatch.delenv("TPU_WORKER_ID", raising=False)
+
+    assert envs.TPU_ACCELERATOR_TYPE is None
+    assert envs.TPU_NAME is None
+    assert envs.TPU_WORKER_ID is None
+
+
+def test_ray_env_vars(monkeypatch: pytest.MonkeyPatch):
+    assert envs.RAY_USAGE_STATS_ENABLED == "0"
+    monkeypatch.setenv("RAY_USAGE_STATS_ENABLED", "1")
+    assert envs.RAY_USAGE_STATS_ENABLED == "1"
+
+    assert envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE == "shm"
+    monkeypatch.setenv("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE", "nccl")
+    assert envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE == "nccl"
+
+
+def test_invalid_attribute_raises_error():
+    with pytest.raises(AttributeError,
+                       match="has no attribute 'NONEXISTENT_VAR'"):
+        _ = envs.NONEXISTENT_VAR
+
+
+def test_dir_returns_all_env_vars():
+    env_vars = envs.__dir__()
+    assert isinstance(env_vars, list)
+    assert len(env_vars) == len(environment_variables)
+    assert "JAX_PLATFORMS" in env_vars
+    assert "TPU_NAME" in env_vars
+    assert "SKIP_JAX_PRECOMPILE" in env_vars
+    assert "MODEL_IMPL_TYPE" in env_vars
+
+
+def test_tpu_multihost_env_vars(monkeypatch: pytest.MonkeyPatch):
+    monkeypatch.setenv("TPU_WORKER_ID", "0")
+    assert envs.TPU_WORKER_ID == "0"
+
+    monkeypatch.setenv("TPU_MULTIHOST_BACKEND", "grpc")
+    assert envs.TPU_MULTIHOST_BACKEND == "grpc"
+
+    monkeypatch.setenv("TPU_MULTIHOST_BACKEND", "xla")
+    assert envs.TPU_MULTIHOST_BACKEND == "xla"
+
+
+def test_disaggregated_serving_env_vars(monkeypatch: pytest.MonkeyPatch):
+    monkeypatch.setenv("PREFILL_SLICES", "0,1,2,3")
+    assert envs.PREFILL_SLICES == "0,1,2,3"
+
+    monkeypatch.setenv("DECODE_SLICES", "4,5,6,7")
+    assert envs.DECODE_SLICES == "4,5,6,7"
+
+
+def test_model_impl_type_default(monkeypatch: pytest.MonkeyPatch):
+    monkeypatch.delenv("MODEL_IMPL_TYPE", raising=False)
+    assert envs.MODEL_IMPL_TYPE == "flax_nnx"
+
+
+def test_cache_preserves_values_across_env_changes(
+        monkeypatch: pytest.MonkeyPatch):
+    monkeypatch.setenv("JAX_PLATFORMS", "tpu")
+
+    enable_envs_cache()
+
+    assert envs.JAX_PLATFORMS == "tpu"
+
+    # Change environment variable
+    monkeypatch.setenv("JAX_PLATFORMS", "cpu")
+
+    # Cached value should still be "tpu"
+    assert envs.JAX_PLATFORMS == "tpu"
+
+    # Reset envs.__getattr__ back to non-cached version
+    envs.__getattr__ = envs.__getattr__.__wrapped__
+
+    # Now it should reflect the new value
+    assert envs.JAX_PLATFORMS == "cpu"
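The new test file relies on tpu_inference.envs exposing a module-level __getattr__ (PEP 562), an environment_variables mapping, and an enable_envs_cache() helper whose wrapper behaves like functools.cache (it has cache_info and __wrapped__). The envs module itself is not part of this diff, so the sketch below is only an illustration of the shape the tests imply; the variable set and defaults shown are examples, not the real table.

import functools
import os
import sys

environment_variables = {
    "JAX_PLATFORMS": lambda: os.getenv("JAX_PLATFORMS", ""),
    "TPU_NAME": lambda: os.getenv("TPU_NAME"),
    "TPU_MULTIHOST_BACKEND": lambda: os.getenv("TPU_MULTIHOST_BACKEND", "").lower(),
    "SKIP_JAX_PRECOMPILE": lambda: bool(int(os.getenv("SKIP_JAX_PRECOMPILE", "0"))),
    "PYTHON_TRACER_LEVEL": lambda: int(os.getenv("PYTHON_TRACER_LEVEL", "1")),
}

def __getattr__(name: str):
    # PEP 562 module-level __getattr__: resolve a variable lazily on access.
    if name in environment_variables:
        return environment_variables[name]()
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

def __dir__():
    return list(environment_variables.keys())

def enable_envs_cache():
    # Wrap __getattr__ with functools.cache so repeated accesses reuse the
    # first resolved value; the tests undo this via envs.__getattr__.__wrapped__,
    # which only works because the cache wrapper keeps the original function.
    module = sys.modules[__name__]
    module.__getattr__ = functools.cache(module.__getattr__)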
tests/test_utils.py
CHANGED
@@ -75,25 +75,34 @@ def test_hbm_usage_bytes_pathways_enabled(mock_devices, mock_live_arrays):
     mock_device2 = MagicMock()
     devices = [mock_device1, mock_device2]

-    # Create mock
-
-
-
+    # Create mock addressable shards with data property
+    mock_data1_dev1 = MagicMock()
+    mock_data1_dev1.device = mock_device1
+    mock_data1_dev1.nbytes = 2000  # 2000 bytes on device1

-
-
-
+    mock_data1_dev2 = MagicMock()
+    mock_data1_dev2.device = mock_device2
+    mock_data1_dev2.nbytes = 2000  # 2000 bytes on device2

-
-
-
+    mock_data2_dev1 = MagicMock()
+    mock_data2_dev1.device = mock_device1
+    mock_data2_dev1.nbytes = 1000  # 1000 bytes on device1

-
+    mock_shard1_dev1 = MagicMock()
+    mock_shard1_dev1.data = mock_data1_dev1
+
+    mock_shard1_dev2 = MagicMock()
+    mock_shard1_dev2.data = mock_data1_dev2
+
+    mock_shard2_dev1 = MagicMock()
+    mock_shard2_dev1.data = mock_data2_dev1
+
+    # Create mock arrays with addressable_shards
     mock_array1 = MagicMock()
-    mock_array1.
+    mock_array1.addressable_shards = [mock_shard1_dev1, mock_shard1_dev2]

     mock_array2 = MagicMock()
-    mock_array2.
+    mock_array2.addressable_shards = [mock_shard2_dev1]

     mock_live_arrays.return_value = [mock_array1, mock_array2]

@@ -159,7 +168,7 @@ def test_hbm_usage_bytes_pathways_no_arrays(mock_devices, mock_live_arrays):
     "head_dim, expected_padded_head_dim",
     [
         (1, 128),
-        (64,
+        (64, 64),
        (127, 128),
        (128, 128),
        (129, 256),
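The reworked mocks describe JAX arrays whose addressable_shards each carry a data buffer with a device and an nbytes size. The helper under test is not shown in this diff; the following is a minimal sketch, with illustrative names, of the per-device accounting that this mock layout exercises.

from collections import defaultdict

def hbm_usage_by_device(live_arrays):
    # Attribute each addressable shard's bytes to the device that holds it.
    usage = defaultdict(int)
    for array in live_arrays:
        for shard in array.addressable_shards:
            usage[shard.data.device] += shard.data.nbytes
    return usage

# With the mocks above: device1 holds 2000 + 1000 = 3000 bytes and
# device2 holds 2000 bytes.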
tpu_inference/core/core_tpu.py
CHANGED
@@ -29,6 +29,7 @@ from vllm.v1.request import Request, RequestStatus

 from tpu_inference import utils as common_utils
 from tpu_inference.core import disagg_executor, disagg_utils
+from tpu_inference.runner.tpu_runner import AsyncTPUModelRunnerOutput
 # ======================================================================================
 # Imports for _DisaggOrchestrator (decoupled from vLLM)
 # ======================================================================================
@@ -186,6 +187,8 @@ class _DisaggOrchestrator:
             if model_output is None:
                 model_output = prefill_engine.model_executor.sample_tokens(
                     grammar_output)
+            if isinstance(model_output, AsyncTPUModelRunnerOutput):
+                model_output = model_output.get_output()

             if scheduler_output.total_num_scheduled_tokens > 0:
                 logger.debug(f"Prefill result: {model_output}")
@@ -218,15 +221,16 @@ class _DisaggOrchestrator:
                 f"request-{req_id}: tokens={request.all_token_ids} after prefill"
             )
             # Remove request from the prefill engine.
+            if req_id in prefill_engine.scheduler.requests:
+                request = prefill_engine.scheduler.requests[req_id]
+                prefill_engine.scheduler.running.remove(request)
+                prefill_engine.scheduler.encoder_cache_manager.free(
+                    request)

-
-
-            prefill_engine.scheduler.encoder_cache_manager.free(
-                request)
+                prefill_engine.scheduler.kv_cache_manager.free(
+                    request)

-
-
-            prefill_engine.scheduler.requests.pop(req_id)
+                prefill_engine.scheduler.requests.pop(req_id)

             for output in (engine_core_outputs.items()
                            if engine_core_outputs else ()):
@@ -335,8 +339,10 @@ class _DisaggOrchestrator:
                 new_block_ids = kv_cache_manager.get_block_ids(req_id)
                 logger.debug(
                     f"inserting {req_id} new_block_ids {new_block_ids}")
-
-
+                if len(new_block_ids[0]) != math.ceil(
+                        prompt_tokens / self._config.cache_config.block_size):
+                    logger.warning("Running out of blocks in decode engine! ")
+                    break

                 decode_engine.model_executor.driver_worker.model_runner.insert_request_with_kv_cache(
                     vllm_request, kv_cache, new_block_ids)
@@ -366,6 +372,8 @@ class _DisaggOrchestrator:
             if model_output is None:
                 model_output = decode_engine.model_executor.sample_tokens(
                     grammar_output)
+            if isinstance(model_output, AsyncTPUModelRunnerOutput):
+                model_output = model_output.get_output()

             if scheduler_output.total_num_scheduled_tokens > 0:
                 logger.debug(f"Decode result: {model_output}")
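The new guard in the decode-side insertion loop compares how many KV-cache blocks the decode engine actually allocated against how many the prefilled prompt needs before transplanting the request; a worked example of that arithmetic with illustrative numbers:

import math

block_size = 16        # example cache_config.block_size
prompt_tokens = 100    # example prompt length after prefill
needed_blocks = math.ceil(prompt_tokens / block_size)   # ceil(100 / 16) = 7

allocated_blocks = 5   # e.g. len(new_block_ids[0]) when the block pool is tight
if allocated_blocks != needed_blocks:
    # core_tpu.py logs "Running out of blocks in decode engine!" and breaks out
    # of the insertion loop instead of inserting a request with a truncated cache.
    pass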
tpu_inference/executors/ray_distributed_executor.py
CHANGED
@@ -131,10 +131,17 @@ class RayDistributedExecutor(RayDistributedExecutorV1):
                 f"current platform {current_platform.device_name} does not "
                 "support ray.")

-
-
-
-
+        pp_size = self.parallel_config.pipeline_parallel_size
+        placement_group_specs: List[Dict[str, float]] = []
+        if pp_size == 1:
+            placement_group_specs = [{
+                device_str: node['Resources'][device_str]
+            } for node in ray.nodes()]
+        else:
+            num_devices_per_pp_rank = self.vllm_config.sharding_config.total_devices
+            placement_group_specs = [{
+                device_str: num_devices_per_pp_rank
+            } for _ in range(pp_size)]

         # vLLM engine is also a worker to execute model with an accelerator,
         # so it requires to have the device in a current node. Check if
@@ -329,6 +336,8 @@ class RayDistributedExecutor(RayDistributedExecutorV1):
         all_kwargs = []
         for rank, (node_id, _) in enumerate(worker_node_and_tpu_ids):
             local_rank = node_workers[node_id].index(rank)
+            ip = sorted_worker_metadata[rank].ip
+            prev_ip = sorted_worker_metadata[rank - 1].ip if rank > 0 else ""
             kwargs = dict(
                 vllm_config=self.vllm_config,
                 local_rank=local_rank,
@@ -336,22 +345,26 @@ class RayDistributedExecutor(RayDistributedExecutorV1):
                 distributed_init_method=distributed_init_method,
                 is_driver_worker=(not self.parallel_config)
                 or (rank % self.parallel_config.tensor_parallel_size == 0),
+                ip=ip,
+                prev_worker_ip=prev_ip,
             )
             all_kwargs.append(kwargs)
         self.collective_rpc("init_worker", args=(all_kwargs, ))
         self.collective_rpc("init_device")
+        if self.parallel_config.pipeline_parallel_size > 1:
+            self._run_workers("initialize_pp_transfer_connect")
         self.collective_rpc("load_model")

         if self.use_ray_spmd_worker:
             for pp_rank in range(self.parallel_config.pipeline_parallel_size):
                 self.pp_tp_workers.append([])
-
-
-
-
-                #
-
-
+                num_tp_workers = int(
+                    self.parallel_config.tensor_parallel_size //
+                    num_tpu_per_worker)
+                for tp_rank in range(num_tp_workers):
+                    # PP=2, TP=4, num_tpu_per_worker=2
+                    # pp_tp_workers = [[0, 1], [2, 3]]
+                    rank = (pp_rank * num_tp_workers) + tp_rank
                     assert len(self.pp_tp_workers[pp_rank]) == tp_rank
                     assert pp_rank < len(self.pp_tp_workers)
                     self.pp_tp_workers[pp_rank].append(self.workers[rank])
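The restored pp_tp_workers bookkeeping maps each (pp_rank, tp_rank) pair to a flat worker rank. Reproducing the layout from the in-diff comment (PP=2, TP=4, num_tpu_per_worker=2):

pipeline_parallel_size = 2   # PP
tensor_parallel_size = 4     # TP
num_tpu_per_worker = 2

num_tp_workers = tensor_parallel_size // num_tpu_per_worker   # 2 workers per PP stage
pp_tp_workers = [[pp_rank * num_tp_workers + tp_rank
                  for tp_rank in range(num_tp_workers)]
                 for pp_rank in range(pipeline_parallel_size)]
print(pp_tp_workers)   # [[0, 1], [2, 3]] -- matches the comment in the diff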
tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py
CHANGED
@@ -317,6 +317,20 @@ def _ragged_paged_attention_kernel(
     q_len = q_end - q_start
     kv_len = kv_lens_ref[seq_idx]

+    bkv_idx_start = 0 if sliding_window is None else jnp.maximum(
+        kv_len - sliding_window, 0) // bkv_sz
+
+    if sliding_window is None:
+        next_bkv_idx_start = 0
+    else:
+
+        def get_next_bkv_idx_start():
+            next_kv_len = kv_lens_ref[seq_idx + 1]
+            return jnp.maximum(next_kv_len - sliding_window, 0) // bkv_sz
+
+        next_bkv_idx_start = lax.cond(seq_idx + 1 < num_seqs,
+                                      get_next_bkv_idx_start, lambda: 0)
+
     def debug_print(msg, *args):
         if debug_mode:
             pl.debug_print(msg, *args)
@@ -353,8 +367,8 @@ def _ragged_paged_attention_kernel(
     head_acc_ref = acc_ref.at[kv_head_idx, :q.shape[0]]

     def load_with_init(ref, init_val):
-        return jnp.where(bkv_idx ==
-                         ref[...])
+        return jnp.where(bkv_idx == bkv_idx_start,
+                         jnp.full_like(ref, init_val), ref[...])

     # Follow FlashAttention-2 forward pass.
     if q_scale is not None:
@@ -378,9 +392,6 @@ def _ragged_paged_attention_kernel(
         num_q_heads_per_kv_head)
     k_span = bkv_idx * bkv_sz + lax.broadcasted_iota(jnp.int32, s.shape, 1)
     mask = q_span < k_span
-    # TODO(jevinjiang, xiowei): reduce pages_per_seq based on sliding_window.
-    if sliding_window is not None:
-        mask = jnp.logical_or(mask, q_span - sliding_window >= k_span)

     if soft_cap is not None:
         s = soft_cap * jnp.tanh(s / soft_cap)
@@ -391,7 +402,8 @@ def _ragged_paged_attention_kernel(
         sinks = attention_sink_ref[kv_head_idx]
         actual_bq_sz = q.shape[0] // num_q_heads_per_kv_head
         m_prev_init = jnp.concat([sinks] * actual_bq_sz, axis=0)
-        m_prev = jnp.where(bkv_idx ==
+        m_prev = jnp.where(bkv_idx == bkv_idx_start, m_prev_init,
+                           head_m_ref[...])
     else:
         m_prev = load_with_init(head_m_ref, -jnp.inf)

@@ -719,12 +731,19 @@ def _ragged_paged_attention_kernel(
     def get_next_bkv_ids(seq_idx, bq_idx, bkv_idx, bkv_sem_idx):
         next_bkv_idx = bkv_idx + 1
         is_last_bkv = next_bkv_idx == num_bkv
-        next_bkv_idx = lax.select(is_last_bkv, 0, next_bkv_idx)
         next_bq_idx = lax.select(is_last_bkv, bq_idx + 1, bq_idx)
         is_last_bq = next_bq_idx == num_bq
         next_bq_idx = lax.select(is_last_bq, 0, next_bq_idx)
         next_seq_idx = lax.select(is_last_bq, seq_idx + 1, seq_idx)
         next_bkv_sem_idx = lax.select(bkv_sem_idx == 0, 1, 0)
+
+        next_bkv_idx = lax.select(
+            is_last_bkv,
+            lax.select(
+                is_last_bq,
+                next_bkv_idx_start,
+                bkv_idx_start,
+            ), next_bkv_idx)
         return next_seq_idx, next_bq_idx, next_bkv_idx, next_bkv_sem_idx

     def compute_with_bq(bq_idx, _):
@@ -759,7 +778,7 @@ def _ragged_paged_attention_kernel(
                 next_bkv_sem_idx)

             # Wait for cur bq if not ready yet
-            @pl.when(bkv_idx ==
+            @pl.when(bkv_idx == bkv_idx_start)
             def wait_cur_bq():
                 wait_fetch_bq(seq_idx, bq_idx, bq_sem_idx)

@@ -808,7 +827,11 @@ def _ragged_paged_attention_kernel(
             kv_head_idx=kv_head_idx,
         )

-        lax.fori_loop(
+        lax.fori_loop(bkv_idx_start,
+                      num_bkv,
+                      compute_with_bkv,
+                      None,
+                      unroll=False)

         # Load acc and calculate final output.
         acc = acc_ref[...]
@@ -838,7 +861,7 @@ def _ragged_paged_attention_kernel(
     @pl.when(seq_idx == 0)
     def prologue():
         start_fetch_bq(0, 0, 0)
-        start_fetch_bkv(0,
+        start_fetch_bkv(0, bkv_idx_start, 0)

     @pl.when(seq_idx < decode_end)
     def process_decode():
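The kernel now derives a starting KV-block index from the sliding window instead of only masking out-of-window keys after the fact, so whole KV blocks that can never contribute are skipped. A plain-Python rendering of the new expression with illustrative numbers:

kv_len = 1000          # tokens currently cached for the sequence (example)
sliding_window = 256   # attention window size (example)
bkv_sz = 128           # KV block size used by the kernel (example)

bkv_idx_start = max(kv_len - sliding_window, 0) // bkv_sz   # 744 // 128 = 5
# Blocks 0..4 hold only tokens outside the window, so the main fori_loop now
# starts at block 5; with sliding_window=None the start index stays 0.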
tpu_inference/layers/{jax → common}/attention_interface.py
CHANGED
@@ -17,7 +17,7 @@ import tpu_inference.kernels.ragged_paged_attention.v3.kernel as rpa
 import tpu_inference.kernels.ragged_paged_attention.v3.kernel_hd64 as rpa_hd64
 from tpu_inference.kernels.flash_attention.kernel import flash_attention
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
-from tpu_inference.layers.
+from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.utils import get_megacore

 MAX_ALLOWED_PAGE_INDICES_N = (
tpu_inference/models/jax/attention layer (per the file list, tpu_inference/layers/jax/attention/attention.py)
CHANGED
@@ -13,9 +13,9 @@ from tpu_inference import utils
 from tpu_inference.kernels.ragged_paged_attention.v3.kernel import \
     ragged_paged_attention
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.layers.jax.base import create_param
 from tpu_inference.layers.jax.rope_interface import apply_rope
-from tpu_inference.layers.jax.sharding import ShardingAxisName

 KVCache = Tuple[jax.Array, jax.Array]

tpu_inference/layers/jax/sample/rejection_sampler.py
CHANGED
@@ -12,7 +12,7 @@ import jax
 import jax.numpy as jnp
 import numpy as np

-from tpu_inference.layers.
+from tpu_inference.layers.common.binary_search import topk_mask, topp_mask
 from tpu_inference.layers.jax.sample.sampling_metadata import \
     TPUSupportedSamplingMetadata

tpu_inference/layers/jax/sample/sampling.py
CHANGED
@@ -6,10 +6,10 @@ from jax.sharding import Mesh, NamedSharding
 from jax.sharding import PartitionSpec as P
 from vllm.v1.outputs import LogprobsTensors

-from tpu_inference.layers.
+from tpu_inference.layers.common.binary_search import topk_mask, topp_mask
+from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.layers.jax.sample.sampling_metadata import \
     TPUSupportedSamplingMetadata
-from tpu_inference.layers.jax.sharding import ShardingAxisName

 _SAMPLING_EPS = 1e-5

tpu_inference/layers/vllm/attention.py
CHANGED
@@ -13,8 +13,8 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer, AttentionType)

 from tpu_inference import utils
+from tpu_inference.layers.common.attention_interface import attention
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
-from tpu_inference.layers.jax.attention_interface import attention
 from tpu_inference.logger import init_logger
 from tpu_inference.models.vllm.vllm_model_wrapper_context import \
     get_vllm_model_wrapper_context
tpu_inference/layers/vllm/quantization/__init__.py
CHANGED
@@ -5,10 +5,12 @@ from vllm.config import VllmConfig
 from vllm.model_executor.layers.quantization.base_config import \
     QuantizationConfig

+from tpu_inference.layers.common import quant_methods
 from tpu_inference.layers.vllm.quantization.awq import VllmAWQConfig
 from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
 from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors import \
     VllmCompressedTensorsConfig  # noqa: E501
+from tpu_inference.layers.vllm.quantization.mxfp4 import VllmMxfp4Config
 from tpu_inference.layers.vllm.quantization.unquantized import \
     VllmUnquantizedConfig

@@ -19,8 +21,9 @@ def get_tpu_quantization_config(vllm_config: VllmConfig,
     # TODO(kyuyeunk): Add support for "tpu_int8".
     method_to_config: dict[str, str] = {
         None: VllmUnquantizedConfig,
-
-
+        quant_methods.COMPRESSED_TENSORS: VllmCompressedTensorsConfig,
+        quant_methods.AWQ: VllmAWQConfig,
+        quant_methods.MXFP4: VllmMxfp4Config,
     }
     if model_config.quantization not in method_to_config:
         raise NotImplementedError(
@@ -30,6 +33,7 @@ def get_tpu_quantization_config(vllm_config: VllmConfig,
     assert issubclass(quant_config, JaxCommonConfig)
     quant_config.set_configs(vllm_config, mesh)

-    model_config.quantization =
+    model_config.quantization = quant_methods.get_tpu_quant_method(
+        quant_config.get_name())
     return VllmConfig.get_quantization_config(model_config,
                                               vllm_config.load_config)
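The new quant_methods module itself (tpu_inference/layers/common/quant_methods.py, +8 lines) is not displayed in this diff. The sketch below is a hypothetical shape inferred only from these call sites and from awq.py / compressed_tensors.py; the constant values and the naming scheme used by get_tpu_quant_method are assumptions, not taken from the release.

# Hypothetical contents -- the real module may differ.
AWQ = "awq"
COMPRESSED_TENSORS = "compressed-tensors"
MXFP4 = "mxfp4"

def get_tpu_quant_method(method: str) -> str:
    # Derive the TPU-specific name passed to @register_quantization_config and
    # written back to model_config.quantization. The "tpu_" prefix shown here
    # is an assumption, not taken from the diff.
    return f"tpu_{method}"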
tpu_inference/layers/vllm/quantization/awq.py
CHANGED
@@ -18,6 +18,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
     is_layer_skipped, unpack_quantized_values_into_int32)
 from vllm.scalar_type import scalar_types

+from tpu_inference.layers.common.quant_methods import AWQ, get_tpu_quant_method
 from tpu_inference.layers.vllm.linear_common import (
     slice_sharded_tensor_for_concatenation, torch_to_jax_param)
 from tpu_inference.layers.vllm.quantization.common import (
@@ -29,12 +30,12 @@ P = PartitionSpec
 logger = init_logger(__name__)


-@register_quantization_config(
+@register_quantization_config(get_tpu_quant_method(AWQ))
 class VllmAWQConfig(AWQConfig, JaxCommonConfig):

     @classmethod
-    def get_name(cls)
-        return
+    def get_name(cls):
+        return AWQ

     def get_supported_act_dtypes(self) -> list[torch.dtype]:
         # NOTE: AWQ checkpoint was quantized with float16. But on TPUs, using
tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py
CHANGED
@@ -16,6 +16,8 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
     find_matched_target, should_ignore_layer)

+from tpu_inference.layers.common.quant_methods import (COMPRESSED_TENSORS,
+                                                       get_tpu_quant_method)
 from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
 from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors_moe import \
     VllmCompressedTensorsW8A8Fp8MoEMethod
@@ -30,12 +32,12 @@ P = PartitionSpec
 logger = init_logger(__name__)


-@register_quantization_config(
+@register_quantization_config(get_tpu_quant_method(COMPRESSED_TENSORS))
 class VllmCompressedTensorsConfig(CompressedTensorsConfig, JaxCommonConfig):

     @classmethod
     def get_name(cls) -> str:
-        return
+        return COMPRESSED_TENSORS

     def get_scheme(self,
                    layer: torch.nn.Module,
|