tpu-inference 0.0.1rc1__py3-none-any.whl → 0.11.1.dev202511130813__py3-none-any.whl
This diff shows the contents of publicly released versions of the package as they appear in the supported public registries and is provided for informational purposes only.
Potentially problematic release: this version of tpu-inference might be problematic.
- tests/kernels/fused_moe_v1_test.py +34 -303
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +2 -2
- tests/lora/test_layers.py +6 -0
- tests/lora/utils.py +8 -0
- tests/test_utils.py +16 -24
- tpu_inference/__init__.py +3 -22
- tpu_inference/core/core_tpu.py +9 -17
- tpu_inference/core/disagg_utils.py +8 -6
- tpu_inference/distributed/tpu_connector.py +4 -3
- tpu_inference/distributed/utils.py +2 -3
- tpu_inference/envs.py +8 -61
- tpu_inference/executors/ray_distributed_executor.py +11 -31
- tpu_inference/kernels/fused_moe/v1/kernel.py +110 -641
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +54 -77
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +143 -287
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +0 -7
- tpu_inference/layers/jax/attention/attention.py +1 -1
- tpu_inference/layers/{common → jax}/attention_interface.py +2 -8
- tpu_inference/layers/jax/sample/rejection_sampler.py +1 -1
- tpu_inference/layers/jax/sample/sampling.py +2 -2
- tpu_inference/layers/{common → jax}/sharding.py +5 -5
- tpu_inference/layers/vllm/attention.py +1 -1
- tpu_inference/layers/vllm/fused_moe.py +208 -170
- tpu_inference/layers/vllm/quantization/__init__.py +3 -7
- tpu_inference/layers/vllm/quantization/awq.py +3 -4
- tpu_inference/layers/vllm/quantization/common.py +1 -6
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +2 -4
- tpu_inference/layers/vllm/quantization/unquantized.py +67 -62
- tpu_inference/layers/vllm/sharding.py +2 -2
- tpu_inference/lora/torch_punica_tpu.py +2 -1
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +28 -0
- tpu_inference/mock/vllm_envs.py +1219 -0
- tpu_inference/mock/vllm_logger.py +212 -0
- tpu_inference/mock/vllm_logging_utils.py +15 -0
- tpu_inference/models/common/model_loader.py +12 -46
- tpu_inference/models/jax/llama3.py +3 -4
- tpu_inference/models/jax/llama_eagle3.py +5 -8
- tpu_inference/models/jax/phi3.py +376 -0
- tpu_inference/models/jax/qwen2.py +2 -3
- tpu_inference/models/jax/qwen2_5_vl.py +50 -165
- tpu_inference/models/jax/qwen3.py +2 -3
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +6 -3
- tpu_inference/models/jax/utils/weight_utils.py +143 -198
- tpu_inference/models/vllm/vllm_model_wrapper.py +14 -32
- tpu_inference/platforms/tpu_platform.py +34 -47
- tpu_inference/runner/compilation_manager.py +60 -145
- tpu_inference/runner/kv_cache.py +2 -2
- tpu_inference/runner/kv_cache_manager.py +18 -17
- tpu_inference/runner/persistent_batch_manager.py +2 -40
- tpu_inference/runner/structured_decoding_manager.py +3 -2
- tpu_inference/runner/tpu_runner.py +135 -283
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/jax/eagle3.py +21 -71
- tpu_inference/tpu_info.py +3 -4
- tpu_inference/utils.py +15 -38
- tpu_inference/worker/tpu_worker.py +26 -163
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/METADATA +3 -4
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/RECORD +63 -61
- tests/test_envs.py +0 -203
- tpu_inference/layers/common/quant_methods.py +0 -8
- tpu_inference/layers/vllm/quantization/mxfp4.py +0 -331
- tpu_inference/models/jax/llama_guard_4.py +0 -361
- /tpu_inference/layers/{common → jax}/binary_search.py +0 -0
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/WHEEL +0 -0
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/top_level.txt +0 -0
tpu_inference/platforms/tpu_platform.py
CHANGED

@@ -1,22 +1,22 @@
 # SPDX-License-Identifier: Apache-2.0
 
-
+import os
+from typing import TYPE_CHECKING, Optional, Tuple, Union, cast
 
 import jax.numpy as jnp
-import torch
 import vllm.envs as vllm_envs
+from torchax.ops.mappings import j2t_dtype
 from tpu_info import device
 from vllm.inputs import ProcessorInputs, PromptType
 from vllm.platforms.interface import Platform, PlatformEnum
 from vllm.sampling_params import SamplingParams, SamplingType
 
 from tpu_inference import envs
-from tpu_inference.layers.
+from tpu_inference.layers.jax.sharding import ShardingConfigManager
 from tpu_inference.logger import init_logger
-from tpu_inference.utils import to_jax_dtype, to_torch_dtype
 
 if TYPE_CHECKING:
-    from vllm.attention.backends.registry import
+    from vllm.attention.backends.registry import _Backend
     from vllm.config import BlockSize, ModelConfig, VllmConfig
     from vllm.pooling_params import PoolingParams
 else:

@@ -24,10 +24,16 @@ else:
     ModelConfig = None
     VllmConfig = None
     PoolingParams = None
-
+    _Backend = None
 
 logger = init_logger(__name__)
 
+_DTYPE: dict[str, jnp.dtype] = {
+    "bfloat16": jnp.bfloat16,
+    "float": jnp.float32,
+    "float32": jnp.float32,
+}
+
 
 class TpuPlatform(Platform):
     _enum = PlatformEnum.TPU

@@ -48,13 +54,12 @@ class TpuPlatform(Platform):
     ]
 
     @classmethod
-    def get_attn_backend_cls(cls, selected_backend: "
-
-
-
-
-
-        if selected_backend != AttentionBackendEnum.PALLAS:
+    def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int,
+                             dtype: jnp.dtype, kv_cache_dtype: Optional[str],
+                             block_size: int, use_v1: bool, use_mla: bool,
+                             has_sink: bool, use_sparse: bool) -> str:
+        from vllm.attention.backends.registry import _Backend
+        if selected_backend != _Backend.PALLAS:
             logger.info("Cannot use %s backend on TPU.", selected_backend)
 
         if use_v1:

@@ -77,14 +82,6 @@ class TpuPlatform(Platform):
             logger.warning(f"Error getting device name: {e}")
         return 'TPU'
 
-    @classmethod
-    def fp8_dtype(cls) -> torch.dtype:
-        if cls.get_device_name().lower() == "tpu v6e":
-            logger.info(
-                "Automatically using fp8_e5m2 for FP8 KV cache on TPU v6e.")
-            return torch.float8_e5m2
-        return torch.float8_e4m3fn
-
     @classmethod
     def get_device_total_memory(cls, device_id: int = 0) -> int:
         raise NotImplementedError

@@ -135,7 +132,6 @@ class TpuPlatform(Platform):
         # For v0, the default block size is 16.
         if cache_config and cache_config.block_size is None:
             cache_config.block_size = cast(BlockSize, 16)
-
         compilation_config = vllm_config.compilation_config
 
         # TPU only supports DYNAMO_TRACE_ONCE compilation level

@@ -152,19 +148,20 @@ class TpuPlatform(Platform):
         # NOTE(xiang): convert dtype to jnp.dtype
         # NOTE(wenlong): skip this logic for mm model preprocessing
         # For mm model preprocessors, it may need the output dtype to be torch.
-        # In order to avoid a PR to vLLM, we postpone the dtype checking during
-        # tpu_worker initialization
+        # In order to avoid a PR to vLLM, we postpone the dtype checking during tpu_worker initialization
         if not vllm_config.scheduler_config.is_multimodal_model or impl == "vllm":
-
-
-
-
-
-            dtype =
-
-
-
+            if not isinstance(vllm_config.model_config.dtype, str):
+                logger.warning(
+                    "The model dtype is not properly set for JAX backend. "
+                    "Overwriting it to jnp.bfloat16")
+                vllm_config.model_config.dtype = jnp.bfloat16
+            else:
+                vllm_config.model_config.dtype = _DTYPE.get(
+                    vllm_config.model_config.dtype, jnp.bfloat16)
+
+            if impl == "vllm":
+                vllm_config.model_config.dtype = j2t_dtype(
+                    vllm_config.model_config.dtype.dtype)
 
         # TODO(cuiq): remove this dependency.
         from vllm.v1.attention.backends.pallas import PallasAttentionBackend

@@ -185,16 +182,10 @@ class TpuPlatform(Platform):
             parallel_config.worker_cls = \
                 "tpu_inference.worker.tpu_worker.TPUWorker"
 
-        multihost_backend =
+        multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower()
         if not multihost_backend:  # Single host
-
-
-                single host without pipeline parallelism.")
-                parallel_config.distributed_executor_backend = "uni"
-            else:
-                logger.info("Force using MultiprocExecutor for JAX on \
-                    single host with pipeline parallelism.")
-                parallel_config.distributed_executor_backend = "mp"
+            logger.info("Force using UniProcExecutor for JAX on single host.")
+            parallel_config.distributed_executor_backend = "uni"
         elif multihost_backend == "ray":
             from tpu_inference.executors.ray_distributed_executor import \
                 RayDistributedExecutor

@@ -269,7 +260,3 @@ class TpuPlatform(Platform):
         Returns if the current platform needs to sync weight loader.
         """
         return True
-
-    @classmethod
-    def support_hybrid_kv_cache(cls) -> bool:
-        return True
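For reference, a minimal sketch of the dtype-normalization rule introduced in the hunk above, using only the _DTYPE table shown there; normalize_model_dtype is a hypothetical helper for illustration, not part of tpu_platform.py:

# Hypothetical helper illustrating the mapping added in this diff:
# string dtypes are looked up in _DTYPE, everything else (or an
# unknown string) falls back to jnp.bfloat16.
import jax.numpy as jnp

_DTYPE = {
    "bfloat16": jnp.bfloat16,
    "float": jnp.float32,
    "float32": jnp.float32,
}

def normalize_model_dtype(dtype):
    if not isinstance(dtype, str):
        # e.g. a torch.dtype that leaked through the model config
        return jnp.bfloat16
    return _DTYPE.get(dtype, jnp.bfloat16)

assert normalize_model_dtype("float32") is jnp.float32
assert normalize_model_dtype("float16") is jnp.bfloat16  # unknown string falls back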
tpu_inference/runner/compilation_manager.py
CHANGED

@@ -1,22 +1,20 @@
+import os
 import time
-from typing import TYPE_CHECKING, Any, Callable,
+from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
 
 import jax
 import jax.numpy as jnp
 import numpy as np
-import vllm.envs as
+import vllm.envs as envs
 from jax.sharding import NamedSharding, PartitionSpec
 
-import tpu_inference.envs as envs
 from tpu_inference.core.disagg_utils import is_disagg_enabled
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
-from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.layers.jax.sample.sampling import sample
 from tpu_inference.layers.jax.sample.sampling_metadata import \
     TPUSupportedSamplingMetadata
+from tpu_inference.layers.jax.sharding import ShardingAxisName
 from tpu_inference.logger import init_logger
-from tpu_inference.models.jax.jax_intermediate_tensor import \
-    JaxIntermediateTensors
 from tpu_inference.utils import device_array
 
 if TYPE_CHECKING:

@@ -32,10 +30,10 @@ class CompilationManager:
 
     def __init__(self, runner: "TPUModelRunner"):
         self.runner = runner
-        if not
+        if not envs.VLLM_DISABLE_COMPILE_CACHE:
             logger.info("Enabling JAX compile cache.")
             jax.config.update("jax_compilation_cache_dir",
-
+                              envs.VLLM_XLA_CACHE_PATH)
 
     def _create_dummy_tensor(self,
                              shape: Tuple[int, ...],

@@ -69,7 +67,8 @@ class CompilationManager:
         logger.info("Compilation finished in %.2f [secs].", end - start)
 
     def capture_model(self) -> None:
-        if
+        if os.getenv("SKIP_JAX_PRECOMPILE",
+                     False) or self.runner.model_config.enforce_eager:
             return
         logger.info("Precompile all the subgraphs with possible input shapes.")
 

@@ -82,8 +81,6 @@ class CompilationManager:
         self._precompile_backbone_with_inputs_embeds()
         if self.runner.scheduler_config.async_scheduling:
             self._precompile_substitute_placeholder_token()
-        if not self.runner.is_last_rank:
-            return
         self._precompile_select_from_array()
         self._precompile_compute_logits()
         self._precompile_disagg_utils()

@@ -123,15 +120,8 @@ class CompilationManager:
             num_tokens=num_tokens,
         )
 
-    def _precompile_backbone_helper(self,
-
-                                    *,
-                                    input_ids,
-                                    positions,
-                                    inputs_embeds,
-                                    intermediate_tensors=None,
-                                    is_first_rank=True,
-                                    is_last_rank=True) -> None:
+    def _precompile_backbone_helper(self, name, *, input_ids, positions,
+                                    inputs_embeds) -> None:
         num_tokens = None
         if input_ids is not None:
             num_tokens = input_ids.shape[0]

@@ -145,6 +135,12 @@ class CompilationManager:
             ShardingAxisName.ATTN_DATA, )) if dp_size > 1 else None
 
         # Keep existing pattern for complex array operations
+        block_tables = self.runner.block_table_cpu[:self.runner.max_num_reqs]
+        block_tables = block_tables.reshape(-1)
+        block_tables = device_array(self.runner.mesh,
+                                    block_tables,
+                                    sharding=dp_sharding)
+
         seq_lens = self._create_dummy_tensor((self.runner.max_num_reqs, ),
                                              jnp.int32, dp_sharding)
         query_start_loc = self._create_dummy_tensor(

@@ -156,49 +152,26 @@ class CompilationManager:
             request_distribution,
             sharding=dp_sharding)
 
-
-
-
-
-
-
-
-        block_tables = device_array(self.runner.mesh,
-                                    block_tables,
-                                    sharding=dp_sharding)
-
-        attention_metadata_gid = AttentionMetadata(
-            input_positions=positions,
-            block_tables=block_tables,
-            seq_lens=seq_lens,
-            query_start_loc=query_start_loc,
-            request_distribution=request_distribution,
-        )
-        if not self.runner.use_hybrid_kvcache:
-            # all layers share the same attention metadata
-            uniform_attention_metadata = attention_metadata_gid
-        else:
-            for layer_name in kv_cache_group.layer_names:
-                attention_metadata_per_layer[
-                    layer_name] = attention_metadata_gid
+        attention_metadata = AttentionMetadata(
+            input_positions=positions,
+            block_tables=block_tables,
+            seq_lens=seq_lens,
+            query_start_loc=query_start_loc,
+            request_distribution=request_distribution,
+        )
 
         def model_fn_wrapper(
             state,
             kv_caches,
             input_ids,
             attention_metadata,
-            positions,
             inputs_embeds,
             layer_name_to_kvcache_index,
             lora_metadata,
-            intermediate_tensors,
-            is_first_rank,
-            is_last_rank,
         ):
             kv_caches, hidden_states, _ = self.runner.model_fn(
                 state, kv_caches, input_ids, attention_metadata, inputs_embeds,
-
-                intermediate_tensors, is_first_rank, is_last_rank)
+                layer_name_to_kvcache_index, lora_metadata)
             self.runner.kv_caches = kv_caches
             return hidden_states
 

@@ -206,10 +179,6 @@ class CompilationManager:
                 self.runner.lora_config, np.array([num_tokens],
                                                   dtype=np.int32)):
             lora_metadata = self.runner.lora_utils.extract_lora_metadata()
-            if self.runner.use_hybrid_kvcache:
-                attention_metadata = attention_metadata_per_layer
-            else:
-                attention_metadata = uniform_attention_metadata
             self._run_compilation(
                 name,
                 model_fn_wrapper,

@@ -217,13 +186,9 @@ class CompilationManager:
                 self.runner.kv_caches,
                 input_ids,
                 attention_metadata,
-                positions,
                 inputs_embeds,
                 tuple(self.runner.layer_name_to_kvcache_index.items()),
                 lora_metadata,
-                intermediate_tensors,
-                is_first_rank,
-                is_last_rank,
                 num_tokens=num_tokens,
             )
 

@@ -274,7 +239,6 @@ class CompilationManager:
         )
 
     def _precompile_backbone_text_only(self) -> None:
-        hidden_size = self.runner.model_config.get_hidden_size()
         for num_tokens in self.runner.num_tokens_paddings:
             dp_sharding = NamedSharding(
                 self.runner.mesh, PartitionSpec(ShardingAxisName.ATTN_DATA, )

@@ -284,28 +248,10 @@ class CompilationManager:
                                                   dp_sharding)
             positions = self._create_dummy_tensor((num_tokens, ), jnp.int32,
                                                   dp_sharding)
-
-
-
-
-            else:
-                hidden_states = self._create_dummy_tensor(
-                    (num_tokens, hidden_size), jnp.bfloat16)
-                residual = self._create_dummy_tensor((num_tokens, hidden_size),
-                                                     jnp.bfloat16)
-                intermediate_tensors = JaxIntermediateTensors(
-                    tensors={
-                        "hidden_states": hidden_states,
-                        "residual": residual
-                    })
-            self._precompile_backbone_helper(
-                f"worker{self.runner.rank} backbone",
-                input_ids=input_ids,
-                positions=positions,
-                inputs_embeds=None,
-                intermediate_tensors=intermediate_tensors,
-                is_first_rank=is_first_rank,
-                is_last_rank=is_last_rank)
+            self._precompile_backbone_helper("backbone",
+                                             input_ids=input_ids,
+                                             positions=positions,
+                                             inputs_embeds=None)
 
     def _precompile_backbone_with_inputs_embeds(self) -> None:
         hidden_size = self.runner.model_config.get_hidden_size()

@@ -319,28 +265,10 @@ class CompilationManager:
             else:
                 positions = self._create_dummy_tensor((num_tokens, ),
                                                       jnp.int32)
-
-
-
-
-                    (num_tokens, hidden_size), jnp.bfloat16)
-                residual = self._create_dummy_tensor((num_tokens, hidden_size),
-                                                     jnp.bfloat16)
-                intermediate_tensors = JaxIntermediateTensors(
-                    tensors={
-                        "hidden_states": hidden_states,
-                        "residual": residual
-                    })
-            else:
-                intermediate_tensors = None
-            self._precompile_backbone_helper(
-                f"worker{self.runner.rank} backbone with embeds",
-                input_ids=None,
-                positions=positions,
-                inputs_embeds=inputs_embeds,
-                intermediate_tensors=intermediate_tensors,
-                is_first_rank=is_first_rank,
-                is_last_rank=is_last_rank)
+            self._precompile_backbone_helper("backbone with embeds",
+                                             input_ids=None,
+                                             positions=positions,
+                                             inputs_embeds=inputs_embeds)
 
     def _precompile_select_from_array_helper(
         self,

@@ -404,23 +332,20 @@ class CompilationManager:
         index_paddings = self.runner.num_reqs_paddings
         dp_sharding = NamedSharding(self.runner.mesh,
                                     PartitionSpec(ShardingAxisName.ATTN_DATA))
-        hidden_states_sharding = NamedSharding(
-            self.runner.mesh, PartitionSpec(ShardingAxisName.ATTN_DATA, None))
         dp_size = self.runner.vllm_config.sharding_config.total_dp_size
         self._precompile_select_from_array_helper(
-            name=
+            name="select all logits",
             source_paddings=self.runner.num_tokens_paddings,
             indices_paddings=index_paddings,
             hidden_dim=hsize,
-            input_sharding=
+            input_sharding=dp_sharding,
             indices_sharding=dp_sharding if dp_size > 1 else None,
         )
 
         if self.runner.speculative_config:
            vocab_size = self.runner.model_config.get_vocab_size()
            self._precompile_select_from_array_helper(
-                name=
-                f"worker{self.runner.rank} select bonus tokens for spec decoding",
+                name="select bonus tokens for spec decoding",
                source_paddings=self.runner.num_logits_paddings,
                indices_paddings=self.runner.num_reqs_paddings,
                hidden_dim=vocab_size,

@@ -428,8 +353,7 @@ class CompilationManager:
                                               PartitionSpec(None, "model")),
             )
             self._precompile_select_from_array_helper(
-                name=
-                f"worker{self.runner.rank} select target tokens for spec decoding",
+                name="select target tokens for spec decoding",
                 source_paddings=self.runner.num_logits_paddings,
                 indices_paddings=self.runner.num_logits_paddings,
                 hidden_dim=vocab_size,

@@ -452,7 +376,7 @@ class CompilationManager:
                 np.array([num_reqs], dtype=np.int32)):
             lora_metadata = self.runner.lora_utils.extract_lora_metadata()
             self._run_compilation(
-
+                "compute_logits",
                 self.runner.compute_logits_fn,
                 self.runner.state,
                 hidden_states,

@@ -494,7 +418,7 @@ class CompilationManager:
                 do_sampling=do_sampling,
             )
             self._run_compilation(
-
+                "sample",
                 sample,
                 self.runner.rng_params_for_sampling,
                 self.runner.mesh,

@@ -535,7 +459,7 @@ class CompilationManager:
             logits = self._create_dummy_tensor((num_reqs, hsize), jnp.bfloat16)
             token_ids = self._create_dummy_tensor((num_reqs, ), jnp.int32)
             self._run_compilation(
-
+                "gather_logprobs",
                 self.runner._compute_and_gather_logprobs,
                 logits,
                 token_ids,

@@ -587,7 +511,7 @@ class CompilationManager:
                                      do_sampling=do_sampling)
 
             self._run_compilation(
-
+                compilation_name,
                 self.runner.rejection_sampler,
                 draft_token_ids,
                 num_draft_tokens,

@@ -604,9 +528,7 @@ class CompilationManager:
     def _precompile_eagle3_helpers(self) -> None:
         logger.info(
             "Compiling eagle3 jitted helpers with different input shapes.")
-
-        draft_hidden_size = self.runner.speculative_config.draft_model_config.get_hidden_size(
-        )
+        hidden_size = self.runner.model_config.get_hidden_size()
         dtype = self.runner.model_config.dtype
 
         num_kv_cache_groups = len(self.runner.kv_cache_config.kv_cache_groups)

@@ -653,11 +575,10 @@ class CompilationManager:
 
         for num_logits in self.runner.num_logits_paddings:
             hidden_states = self._create_dummy_tensor(
-                (num_logits,
+                (num_logits, hidden_size), jnp.bfloat16)
             self._run_compilation(
                 "eagle3_get_draft_token_ids",
                 self.runner.drafter._get_draft_token_ids,
-                self.runner.drafter.state,
                 hidden_states,
                 num_logits=num_logits,
             )

@@ -665,8 +586,8 @@ class CompilationManager:
         input_ids_loop = self._create_dummy_tensor(
             (self.runner.max_num_reqs, ), jnp.int32,
             NamedSharding(self.runner.mesh, PartitionSpec()))
-
-            (self.runner.max_num_reqs,
+        target_hidden_state_loop = self._create_dummy_tensor(
+            (self.runner.max_num_reqs, hidden_size), dtype,
             NamedSharding(self.runner.mesh, PartitionSpec(None, None)))
         next_token_ids = self._create_dummy_tensor(
             (self.runner.max_num_reqs, ), jnp.int32)

@@ -674,12 +595,9 @@ class CompilationManager:
             (self.runner.max_num_reqs, ), jnp.int32)
         for num_tokens in self.runner.num_tokens_paddings:
             aux_hidden_states = [
-                self._create_dummy_tensor((num_tokens,
-
-                self._create_dummy_tensor((num_tokens,
-                                          dtype),
-                self._create_dummy_tensor((num_tokens, target_hidden_size),
-                                          dtype),
+                self._create_dummy_tensor((num_tokens, hidden_size), dtype),
+                self._create_dummy_tensor((num_tokens, hidden_size), dtype),
+                self._create_dummy_tensor((num_tokens, hidden_size), dtype),
             ]
 
             positions = self._create_dummy_tensor((num_tokens, ), jnp.int32)

@@ -702,23 +620,23 @@ class CompilationManager:
                 num_reqs,
             ):
                 target_hidden_states, input_ids, last_token_indices, _ = self.runner.drafter._filter_token_and_prepare_initial_inputs(
-
-
-
+                    token_indices, query_start_loc, seq_lens, input_ids,
+                    aux_hidden_states, attention_metadata, next_token_ids,
+                    num_reqs)
                 return target_hidden_states, input_ids, last_token_indices
 
             input_ids = self._create_dummy_tensor((num_tokens, ), jnp.int32)
             aux_hidden_states = [
                 self._create_dummy_tensor(
-                    (num_tokens,
+                    (num_tokens, hidden_size), jnp.bfloat16,
                     NamedSharding(self.runner.mesh, PartitionSpec(None,
                                                                   None))),
                 self._create_dummy_tensor(
-                    (num_tokens,
+                    (num_tokens, hidden_size), jnp.bfloat16,
                     NamedSharding(self.runner.mesh, PartitionSpec(None,
                                                                   None))),
                 self._create_dummy_tensor(
-                    (num_tokens,
+                    (num_tokens, hidden_size), jnp.bfloat16,
                     NamedSharding(self.runner.mesh, PartitionSpec(None,
                                                                   None))),
             ]

@@ -750,17 +668,17 @@ class CompilationManager:
                 state,
                 kv_caches,
                 input_ids,
-
+                target_hidden_states,
                 attention_metadata,
             ):
                 kv_caches, hidden_states, _ = self.runner.drafter.model_fn(
-                    state, kv_caches, input_ids,
+                    state, kv_caches, input_ids, target_hidden_states,
                     attention_metadata)
                 self.runner.kv_caches = kv_caches
                 return hidden_states
 
-
-                (num_tokens,
+            target_hidden_states = self._create_dummy_tensor(
+                (num_tokens, hidden_size), dtype,
                 NamedSharding(self.runner.mesh, PartitionSpec(None, "model")))
             input_ids = self._create_dummy_tensor(
                 (num_tokens, ), jnp.int32,

@@ -771,7 +689,7 @@ class CompilationManager:
                 self.runner.drafter.state,
                 self.runner.kv_caches,
                 input_ids,
-
+                target_hidden_states,
                 attention_metadata,
                 num_tokens=num_tokens,
             )

@@ -781,7 +699,6 @@ class CompilationManager:
             self._run_compilation(
                 "eagle3_prepare_hidden_states_and_input_ids",
                 self.runner.drafter._prepare_hidden_states_and_input_ids,
-                self.runner.drafter.state,
                 aux_hidden_states,
                 query_start_loc,
                 target_token_ids,

@@ -804,19 +721,18 @@ class CompilationManager:
                 self.runner.drafter.state,
                 self.runner.kv_caches,
                 input_ids_loop,
-
+                target_hidden_state_loop,
                 attention_metadata,
                 num_tokens=num_tokens,
             )
 
             hidden_states = self._create_dummy_tensor(
-                (num_tokens,
+                (num_tokens, hidden_size), jnp.bfloat16,
                 NamedSharding(self.runner.mesh, PartitionSpec(None, None)))
 
             self._run_compilation(
                 "eagle3_select_inputs_for_loop_speculation",
                 self.runner.drafter._select_inputs_for_loop_speculation,
-                self.runner.drafter.state,
                 positions,
                 hidden_states,
                 hidden_states,

@@ -827,7 +743,6 @@ class CompilationManager:
             self._run_compilation(
                 "eagle3_select_draft_token_ids",
                 self.runner.drafter._select_draft_token_ids,
-                self.runner.drafter.state,
                 hidden_states,
                 last_token_indices,
                 num_tokens=num_tokens,
tpu_inference/runner/kv_cache.py
CHANGED

@@ -9,7 +9,7 @@ from torchax.ops.mappings import t2j_dtype
 
 import tpu_inference.kernels.ragged_paged_attention.v3.kernel as rpa
 import tpu_inference.kernels.ragged_paged_attention.v3.kernel_hd64 as rpa_hd64
-from tpu_inference.layers.
+from tpu_inference.layers.jax.sharding import ShardingAxisName
 from tpu_inference.logger import init_logger
 
 logger = init_logger(__name__)

@@ -82,7 +82,7 @@ def create_kv_caches(
                                  ShardingAxisName.ATTN_HEAD))
 
     def _allocate() -> jax.Array:
-        return jnp.
+        return jnp.empty(
             shape=cache_shape,
             dtype=cache_dtype,
         )