tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.12.0.dev20251213__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tests/kernels/fused_moe_v1_test.py +303 -34
- tests/kernels/mla_v1_test.py +129 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +3 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +3 -1
- tests/lora/test_layers.py +4 -1
- tests/lora/test_lora_perf.py +53 -0
- tests/test_envs.py +110 -12
- tests/test_quantization.py +3 -0
- tests/test_utils.py +1 -2
- tpu_inference/distributed/tpu_connector.py +1 -1
- tpu_inference/envs.py +92 -8
- tpu_inference/executors/ray_distributed_executor.py +5 -1
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/fused_moe/v1/kernel.py +712 -143
- tpu_inference/kernels/mla/v1/kernel.py +98 -120
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +82 -32
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +146 -85
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +2 -1
- tpu_inference/layers/common/attention_interface.py +7 -1
- tpu_inference/layers/common/sharding.py +11 -7
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +232 -64
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +5 -5
- tpu_inference/layers/vllm/fused_moe.py +170 -208
- tpu_inference/layers/vllm/linear_common.py +43 -21
- tpu_inference/layers/vllm/quantization/common.py +11 -6
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +4 -3
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +74 -65
- tpu_inference/layers/vllm/quantization/mxfp4.py +140 -94
- tpu_inference/layers/vllm/quantization/unquantized.py +103 -80
- tpu_inference/models/common/model_loader.py +78 -22
- tpu_inference/models/jax/deepseek_v3.py +185 -64
- tpu_inference/models/jax/gpt_oss.py +3 -3
- tpu_inference/models/jax/llama_eagle3.py +4 -5
- tpu_inference/models/jax/qwen2_5_vl.py +161 -47
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +7 -8
- tpu_inference/models/jax/utils/weight_utils.py +203 -155
- tpu_inference/models/vllm/vllm_model_wrapper.py +11 -5
- tpu_inference/platforms/tpu_platform.py +29 -48
- tpu_inference/runner/compilation_manager.py +112 -46
- tpu_inference/runner/kv_cache.py +40 -20
- tpu_inference/runner/kv_cache_manager.py +40 -31
- tpu_inference/runner/persistent_batch_manager.py +40 -2
- tpu_inference/runner/structured_decoding_manager.py +2 -3
- tpu_inference/runner/tpu_runner.py +94 -51
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/jax/eagle3.py +71 -22
- tpu_inference/utils.py +41 -14
- tpu_inference/worker/tpu_worker.py +43 -45
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/METADATA +8 -9
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/RECORD +59 -58
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/top_level.txt +0 -0
tpu_inference/runner/kv_cache_manager.py
CHANGED

@@ -7,8 +7,8 @@ import numpy as np
 import vllm.envs as envs
 from jax.sharding import NamedSharding, PartitionSpec
 from torchax.ops.mappings import t2j_dtype
-from vllm.attention import Attention
 from vllm.attention.backends.abstract import AttentionType
+from vllm.attention.layer import Attention
 from vllm.config import get_layers_from_vllm_config
 from vllm.utils.math_utils import cdiv
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
@@ -39,20 +39,30 @@ class KVCacheManager:
         # means this layer will perform attention using the keys and values
         # from the KV cache of `shared_kv_cache_layers[layer_name]`.
         self.shared_kv_cache_layers: dict[str, str] = {}
+        self.use_mla = self.runner.model_config.use_mla

     def get_kv_cache_spec(self):
         # TODO(xiang): this hack tricks engine core to init successfully
         block_size = self.runner.cache_config.block_size
-        use_mla = self.runner.model_config.use_mla
         kv_cache_spec: dict[str, KVCacheSpec] = {}

         # If use pure jax (MODEL_IMPL_TYPE=flax_nnx), we don't register
         # attention into compilation config.
         # Use FullAttentionSpec for each layer
         # TODO(pooyam): Is it possible to merge the logic for vllm and non-vllm models?
+        model_config = self.runner.model_config
+        if self.use_mla:
+            # Individually pad the RopE and latents
+            qk_rope_head_dim = getattr(model_config.hf_text_config,
+                                       "qk_rope_head_dim", 0)
+            padded_kv_lora_rank = common_utils.align_to(
+                model_config.hf_text_config.kv_lora_rank, 128)
+            padded_qk_rope_head_dim = common_utils.align_to(
+                qk_rope_head_dim, 128)
+            mla_head_size = padded_kv_lora_rank + padded_qk_rope_head_dim
+
         if len(self.runner.vllm_config.compilation_config.
                static_forward_context) == 0:
-            model_config = self.runner.model_config
             parallel_config = self.runner.parallel_config
             # Pad num_kv_heads to multiple of TP size.
             num_kv_heads = common_utils.get_padded_num_heads(
@@ -61,11 +71,11 @@ class KVCacheManager:
             head_size = common_utils.get_padded_head_dim(
                 model_config.get_head_size())
             for i in range(model_config.get_num_layers(parallel_config)):
-                if use_mla:
+                if self.use_mla:
                     kv_cache_spec[f"layer.{i}"] = MLAAttentionSpec(
                         block_size=block_size,
-                        num_kv_heads=
-                        head_size=
+                        num_kv_heads=1,
+                        head_size=mla_head_size,
                         dtype=self.runner.kv_cache_dtype,
                         cache_dtype_str=self.runner.vllm_config.cache_config.
                         cache_dtype)
@@ -83,14 +93,13 @@ class KVCacheManager:
                 self.runner.mesh.shape["model"])
             head_size = common_utils.get_padded_head_dim(
                 hf_config.hidden_size // hf_config.num_attention_heads)
-
             # Eagle3 has only 1 layer
             for i in range(1):
-                if use_mla:
-                    kv_cache_spec[f"
+                if self.use_mla:
+                    kv_cache_spec[f"draft_layer.{i}"] = MLAAttentionSpec(
                         block_size=block_size,
-                        num_kv_heads=
-                        head_size=
+                        num_kv_heads=1,
+                        head_size=mla_head_size,
                         dtype=self.runner.kv_cache_dtype,
                         cache_dtype_str=self.runner.vllm_config.
                         cache_config.cache_dtype)
@@ -104,6 +113,7 @@ class KVCacheManager:
         # Else propagate attention modules from compilation config.
         layers = get_layers_from_vllm_config(self.runner.vllm_config,
                                              Attention)
+        logger.warning(f"Compilation num_layers = {len(layers.items())}")
         for layer_name, attn_module in layers.items():
             if (kv_tgt_layer :=
                     attn_module.kv_sharing_target_layer_name) is not None:
@@ -127,11 +137,11 @@ class KVCacheManager:
                         attn_module.head_size),
                     dtype=self.runner.kv_cache_dtype,
                     sliding_window=attn_module.sliding_window)
-            elif use_mla:
-                kv_cache_spec[
+            elif self.use_mla:
+                kv_cache_spec[layer_name] = MLAAttentionSpec(
                     block_size=block_size,
-                    num_kv_heads=
-                    head_size=
+                    num_kv_heads=1,
+                    head_size=mla_head_size,
                     dtype=self.runner.kv_cache_dtype,
                     cache_dtype_str=self.runner.vllm_config.
                     cache_config.cache_dtype)
@@ -198,14 +208,20 @@ class KVCacheManager:
             # num_blocks must be a multiple of dp_size
             num_blocks = (num_blocks // dp_size) * dp_size
             # NOTE: we'll multiply the num_kv_heads by 2 in the function
+            if self.use_mla:
+                head_size = self.runner.model_config.hf_config.kv_lora_rank + \
+                    self.runner.model_config.hf_config.qk_rope_head_dim
+            else:
+                head_size = representative_spec.head_size
             kv_cache = create_kv_caches(
                 num_blocks=num_blocks,
                 block_size=representative_spec.block_size,
                 num_kv_heads=representative_spec.num_kv_heads,
-                head_size=
+                head_size=head_size,
                 mesh=self.runner.mesh,
                 layer_names=[f'kv_cache_tensor.{i}'],
                 cache_dtype=t2j_dtype(representative_spec.dtype),
+                use_mla=self.use_mla,
             )[0]
             kv_caches.append(kv_cache)
             num_blocks_list.append(num_blocks)
@@ -289,13 +305,8 @@ class KVCacheManager:

         def _update_layer(cache, slices):
             """The function to apply to each layer's cache and slices."""
-            reshaped_slices = slices.reshape(-1,
-
-            for (i, block_idx) in enumerate(block_numbers):
-                cache = jax.lax.dynamic_update_slice_in_dim(cache,
-                                                            reshaped_slices[i],
-                                                            block_idx,
-                                                            axis=0)
+            reshaped_slices = slices.reshape(-1, block_size, *slices.shape[1:])
+            cache.at[block_numbers].set(reshaped_slices)
             return cache

         return jax.tree.map(_update_layer, kv_caches, kv_cache_slices)
@@ -348,16 +359,12 @@ class KVCacheManager:
         """
         if block_ids == list(range(block_ids[0],
                                    block_ids[0] + len(block_ids))):
-
-
-            batched_kv_cache_per_layer = self._jitted_gather_continuous_kv_cache(
-                self.runner.kv_caches, block_ids[0], len(block_ids))
+            batched_kv_cache_per_layer = self._jitted_gather_continuous_kv_cache(
+                self.runner.kv_caches, block_ids[0], len(block_ids))

         else:
-
-
-            batched_kv_cache_per_layer = self._jitted_gather_kv_cache(
-                self.runner.kv_caches, jnp.array(block_ids))
+            batched_kv_cache_per_layer = self._jitted_gather_kv_cache(
+                self.runner.kv_caches, jnp.array(block_ids))
         return batched_kv_cache_per_layer

     def transfer_kv_cache(self,
@@ -446,6 +453,7 @@ class KVCacheManager:
                     kv_cache_slices,
                     start_block,
                 )
+                jax.block_until_ready(self.runner.kv_caches)
         else:
             with runner_utils.LatencyTracker(
                     f"JittedInsertKVCache-b{len(block_numbers)}"):
@@ -457,6 +465,7 @@ class KVCacheManager:
                     kv_cache_slices,
                     jnp.array(block_numbers),
                 )
+                jax.block_until_ready(self.runner.kv_caches)

         logger.debug(
             f"Updated kv cache entries cnt={len(self.runner.kv_caches)}")
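The MLA spec above sizes the cache for a single KV head whose head dim is the latent rank plus the RoPE dim, each rounded up to the 128-lane TPU tile. A minimal illustration of that arithmetic, assuming align_to rounds up to the next multiple (the helper name mirrors the common_utils.align_to call in the diff; the dimensions below are illustrative, not read from any model config):

def align_to(x: int, multiple: int) -> int:
    # Round x up to the next multiple of `multiple`.
    return ((x + multiple - 1) // multiple) * multiple

# Illustrative DeepSeek-style MLA dimensions (assumed values).
kv_lora_rank = 512
qk_rope_head_dim = 64

# Pad the latent and RoPE parts separately, then store both in one head.
mla_head_size = align_to(kv_lora_rank, 128) + align_to(qk_rope_head_dim, 128)
print(mla_head_size)  # 640, used with num_kv_heads=1 in the MLAAttentionSpec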

tpu_inference/runner/persistent_batch_manager.py
CHANGED

@@ -14,12 +14,13 @@ class PersistentBatchManager:
     def __init__(self, requests: Dict[str, CachedRequestState],
                  input_batch: InputBatch, encoder_cache: Dict[str,
                                                               'jax.Array'],
-                 uses_mrope: bool, model_config):
+                 uses_mrope: bool, model_config, is_last_rank: bool):
         self.requests = requests
         self.input_batch = input_batch
         self.encoder_cache = encoder_cache
         self.uses_mrope = uses_mrope
         self.model_config = model_config
+        self.is_last_rank = is_last_rank

     def _reorder_batch(self, scheduler_output: "VllmSchedulerOutput") -> int:
         """ Reorder the sheduled requests to RPA kernel friendly distribution
@@ -179,9 +180,35 @@ class PersistentBatchManager:
             num_computed_tokens = req_data.num_computed_tokens[i]
             new_block_ids = req_data.new_block_ids[i]
             resumed_from_preemption = req_data.resumed_from_preemption[i]
+            num_output_tokens = req_data.num_output_tokens[i]

             # Update the cached states.
             req_state.num_computed_tokens = num_computed_tokens
+            req_index = self.input_batch.req_id_to_index.get(req_id)
+
+            if not self.is_last_rank:
+                # When using PP, the scheduler sends the sampled tokens back,
+                # because there's no direct communication between the first-
+                # stage worker and the last-stage worker.
+                new_token_ids = req_data.new_token_ids[i]
+                # Add the sampled token(s) from the previous step (if any).
+                # This doesn't include "unverified" tokens like spec tokens.
+                num_new_tokens = (num_computed_tokens + len(new_token_ids) -
+                                  req_state.num_tokens)
+                if num_new_tokens == 1:
+                    req_state.output_token_ids.append(new_token_ids[-1])
+                elif num_new_tokens > 0:
+                    req_state.output_token_ids.extend(
+                        new_token_ids[-num_new_tokens:])
+                elif num_output_tokens < len(req_state.output_token_ids):
+                    del req_state.output_token_ids[num_output_tokens:]
+                if req_index is not None:
+                    end_idx = (self.input_batch.num_prompt_tokens[req_index] +
+                               num_output_tokens)
+                    self.input_batch.num_tokens[req_index] = end_idx
+                    self.input_batch.num_tokens_no_spec[req_index] = end_idx
+
+            # Update the block IDs.
             if not resumed_from_preemption:
                 if new_block_ids is not None:
                     # Append the new blocks to the existing block IDs.
@@ -194,7 +221,6 @@ class PersistentBatchManager:
                     # Replace the existing block IDs with the new ones.
                     req_state.block_ids = new_block_ids

-            req_index = self.input_batch.req_id_to_index.get(req_id)
             if req_index is None:
                 # The request is not in the persistent batch.
                 # The request was either preempted and resumed later, or was not
@@ -209,6 +235,18 @@ class PersistentBatchManager:
                 self.input_batch.block_table.append_row(
                     new_block_ids, req_index)

+            # For the last rank, we don't need to update the token_ids_cpu
+            # because the sampled tokens are already cached.
+            if not self.is_last_rank:
+                start_token_index = num_computed_tokens
+                end_token_index = num_computed_tokens + len(new_token_ids)
+                self.input_batch.token_ids_cpu[
+                    req_index,
+                    start_token_index:end_token_index] = new_token_ids
+                self.input_batch.num_tokens_no_spec[
+                    req_index] = end_token_index
+                self.input_batch.num_tokens[req_index] = end_token_index
+
             # Add spec_token_ids to token_ids_cpu.
             spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
                 req_id, ())
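With pipeline parallelism, only the last rank samples, so earlier ranks rebuild their per-request output state from the token ids the scheduler echoes back (the `not self.is_last_rank` branches above). A self-contained sketch of that append/extend/rollback decision, using plain lists and a hypothetical helper name instead of the real CachedRequestState and InputBatch objects:

def apply_scheduler_echo(output_token_ids: list, num_tokens: int,
                         num_computed_tokens: int, new_token_ids: list,
                         num_output_tokens: int) -> None:
    # Tokens newly sampled since this rank last saw the request.
    num_new_tokens = num_computed_tokens + len(new_token_ids) - num_tokens
    if num_new_tokens == 1:
        output_token_ids.append(new_token_ids[-1])                 # ordinary decode step
    elif num_new_tokens > 0:
        output_token_ids.extend(new_token_ids[-num_new_tokens:])   # accepted spec tokens
    elif num_output_tokens < len(output_token_ids):
        del output_token_ids[num_output_tokens:]                   # roll back rejected tokens

toks = [11, 12]
apply_scheduler_echo(toks, num_tokens=5, num_computed_tokens=5,
                     new_token_ids=[13], num_output_tokens=3)
print(toks)  # [11, 12, 13]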

tpu_inference/runner/structured_decoding_manager.py
CHANGED

@@ -61,11 +61,10 @@ class StructuredDecodingManager:
         self.runner.require_structured_out_cpu.fill(0)

         sorted_struct_requests = sorted(
-            grammar_output.structured_output_request_ids
-            key=lambda item: item[1])
+            grammar_output.structured_output_request_ids)

         cumulative_mask_idx = 0
-        for req_id
+        for req_id in sorted_struct_requests:
             if req_id not in self.runner.input_batch.req_id_to_index:
                 continue
             batch_index = self.runner.input_batch.req_id_to_index[req_id]

tpu_inference/runner/tpu_runner.py
CHANGED

@@ -1,6 +1,6 @@
 import copy
 import functools
-import
+import logging
 import random
 from contextlib import nullcontext
 from dataclasses import dataclass
@@ -10,17 +10,15 @@ import jax
 import jax.numpy as jnp
 import jaxtyping
 import numpy as np
-import
-import vllm.envs as envs
+import vllm.envs as vllm_envs
 from flax import nnx
 from jax.experimental import mesh_utils
 from jax.sharding import NamedSharding, PartitionSpec
-from torchax.ops.mappings import j2t_dtype
 from vllm.config import VllmConfig
+from vllm.distributed import get_pp_group
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group)
 from vllm.forward_context import set_forward_context
-from vllm.sequence import IntermediateTensors
 from vllm.tasks import SupportedTask
 from vllm.utils.math_utils import cdiv
 from vllm.v1.core.sched.output import GrammarOutput
@@ -35,6 +33,7 @@ from vllm.v1.worker.kv_connector_model_runner_mixin import \
     KVConnectorModelRunnerMixin
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin

+import tpu_inference.envs as envs
 from tpu_inference import utils as common_utils
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
 from tpu_inference.layers.common.sharding import (MESH_AXIS_NAMES,
@@ -48,6 +47,8 @@ from tpu_inference.layers.jax.sample.sampling_metadata import \
     TPUSupportedSamplingMetadata
 from tpu_inference.logger import init_logger
 from tpu_inference.models.common.model_loader import get_model
+from tpu_inference.models.jax.jax_intermediate_tensor import \
+    JaxIntermediateTensors
 from tpu_inference.models.jax.utils.weight_utils import (
     shard_put, transfer_state_with_mappings)
 from tpu_inference.runner import utils as runner_utils
@@ -64,10 +65,12 @@ from tpu_inference.runner.structured_decoding_manager import \
     StructuredDecodingManager
 from tpu_inference.spec_decode.jax.eagle3 import Eagle3Proposer
 from tpu_inference.utils import (device_array, make_optimized_mesh,
-                                 time_function)
+                                 time_function, to_jax_dtype, to_torch_dtype)

 logger = init_logger(__name__)

+logging.getLogger("torchax.tensor").setLevel(logging.ERROR)
+
 INVALID_TOKEN_ID = -1
 # Smallest output size
 MIN_NUM_SEQS = 8
@@ -78,17 +81,6 @@ DUMMY_METADATA = AttentionMetadata(
     request_distribution=[0, 0, 0],
 )

-TPU_STR_DTYPE_TO_TORCH_DTYPE = {
-    "half": torch.half,
-    "bfloat16": torch.bfloat16,
-    "float": torch.float,
-    "fp8": torch.float8_e4m3fn,
-    "fp8_e4m3": torch.float8_e4m3fn,
-    "fp8_e5m2": torch.float8_e5m2,
-    "int8": torch.int8,
-    "uint8": torch.uint8,
-}
-

 class AsyncTPUModelRunnerOutput(AsyncModelRunnerOutput):
     """Holds asynchronous model output specifically from a TPU runner.
@@ -243,6 +235,9 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         self.maybe_forbid_compile = runner_utils.ForbidCompile(
         ) if envs.VLLM_XLA_CHECK_RECOMPILATION else nullcontext()
         self.dp_size = self.vllm_config.sharding_config.total_dp_size
+        self.rank = rank
+        self.is_first_rank = is_first_rank
+        self.is_last_rank = is_last_rank

         self._init_random()
         self._init_mesh()
@@ -253,31 +248,21 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):

         # Delegate functions to specific manager classes.
         self.compilation_manager = CompilationManager(self)
-        self.
-
+        if self.is_last_rank:
+            self.speculative_decoding_manager = SpeculativeDecodingManager(
+                self)
+            self.structured_decoding_manager = StructuredDecodingManager(self)
         self.kv_cache_manager = KVCacheManager(self)
         self.mm_manager = MultiModalManager(self)
         self.persistent_batch_manager = PersistentBatchManager(
             self.requests, self.input_batch, self.encoder_cache,
-            self.uses_mrope, self.model_config)
+            self.uses_mrope, self.model_config, self.is_last_rank)
         self.lora_utils = LoraUtils(self)

-
-        if
-
-
-                self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
-            elif isinstance(getattr(model_dtype, 'dtype', None), jnp.dtype):
-                self.kv_cache_dtype = j2t_dtype(model_dtype.dtype)
-            elif isinstance(model_dtype, torch.dtype):
-                self.kv_cache_dtype = model_dtype
-            else:
-                raise ValueError(
-                    "KV cache is unsupported for model_dtype of %s",
-                    model_dtype)
-        else:
-            self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[
-                cache_config.cache_dtype]
+        cache_dtype = self.cache_config.cache_dtype
+        if cache_dtype == "auto":
+            cache_dtype = self.dtype
+        self.kv_cache_dtype = to_torch_dtype(cache_dtype)

         self._pre_async_results: AsyncPreResults | None = None
         self._substitute_placeholder_token_fn = _substitute_placeholder_token
@@ -291,7 +276,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         self.rng_key = jax.random.key(self.model_config.seed)

     def _init_mesh(self) -> None:
-        if
+        if envs.NEW_MODEL_DESIGN:
             self.mesh = self._create_new_model_mesh()
         else:
             # NOTE(wenxindongwork): The new MoE kernel expects a 2D mesh, so we need
@@ -302,7 +287,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         logger.info(f"Init mesh | mesh={self.mesh}")

     def _create_new_model_mesh(self) -> jax.sharding.Mesh:
-        num_slices =
+        num_slices = envs.NUM_SLICES

         logger.info(f"Creating new model mesh | devices={len(self.devices)}, "
                     f"num_slices={num_slices}")
@@ -371,7 +356,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             devices=self.devices)

     def _init_phased_profiling(self) -> None:
-        self.phased_profiling_dir =
+        self.phased_profiling_dir = envs.PHASED_PROFILING_DIR
         self.phase_based_profiler = None
         if self.phased_profiling_dir:
             self.phase_based_profiler = runner_utils.PhasedBasedProfiler(
@@ -413,7 +398,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             min_token_size=max(16, self.dp_size),
             max_token_size=scheduler_config.max_num_batched_tokens *
             self.dp_size,
-            padding_gap=
+            padding_gap=vllm_envs.VLLM_TPU_BUCKET_PADDING_GAP)
         self.num_tokens_paddings_per_dp = [
             padding // self.dp_size for padding in self.num_tokens_paddings
         ]
@@ -555,12 +540,12 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
     def execute_model(
         self,
         scheduler_output: "VllmSchedulerOutput",
-        intermediate_tensors: Optional[
-    ) -> ModelRunnerOutput | None:
+        intermediate_tensors: Optional[JaxIntermediateTensors] = None,
+    ) -> ModelRunnerOutput | JaxIntermediateTensors | None:
         if self.execute_model_state is not None:
             raise RuntimeError("State error: sample_tokens() must be called "
                                "after execute_model() returns None.")
-        _, output = self._execute_model(scheduler_output)
+        _, output = self._execute_model(scheduler_output, intermediate_tensors)
         return output

     def sample_tokens(
@@ -686,7 +671,9 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
     def _execute_model(
         self,
         scheduler_output: "VllmSchedulerOutput",
-
+        intermediate_tensors: Optional[JaxIntermediateTensors] = None,
+    ) -> tuple[AttentionMetadata, JaxIntermediateTensors | ModelRunnerOutput
+               | None]:
         self.persistent_batch_manager.update_states(
             scheduler_output, self.get_mrope_input_positions_fn)
         if not scheduler_output.total_num_scheduled_tokens:
@@ -764,7 +751,6 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
                 scheduler_output) as kv_connector_output:
             # NOTE(Wenlong): It takes both `input_ids` and `inputs_embeds`,
             # but one of them would be `None`
-
             (self.kv_caches, hidden_states,
              aux_hidden_states) = self.model_fn(
                  self.state,
@@ -775,8 +761,14 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
                  input_positions,
                  tuple(self.layer_name_to_kvcache_index.items()),
                  lora_metadata,
+                 intermediate_tensors,
+                 self.is_first_rank,
+                 self.is_last_rank,
              )
-
+            if not get_pp_group().is_last_rank:
+                assert isinstance(hidden_states, JaxIntermediateTensors)
+                hidden_states.kv_connector_output = kv_connector_output
+                return attn_metadata, hidden_states
             hidden_states = self._select_from_array_fn(hidden_states,
                                                        logits_indices)
             logits = self.compute_logits_fn(
@@ -822,18 +814,31 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):

         tpu_sampling_metadata = TPUSupportedSamplingMetadata.from_input_batch(
             self.mesh, self.input_batch, padded_num_reqs, sharding=sharding)
+
+        # TODO(pooyam): Should we move this to `_prepare_inputs`?
+        if tpu_sampling_metadata.do_sampling:
+            self.rng_params_for_sampling, step_rng = jax.random.split(
+                self.rng_params_for_sampling)
+        else:
+            step_rng = self.rng_params_for_sampling
+
         if spec_decode_metadata is None:
             next_tokens = sample(
-
+                step_rng,
                 self.mesh,
                 logits,
                 tpu_sampling_metadata,
             )
         else:
+            if tpu_sampling_metadata.do_sampling:
+                bonus_rng, rejection_rng = jax.random.split(step_rng)
+            else:
+                bonus_rng = step_rng
+                rejection_rng = step_rng
             bonus_logits = self._select_from_array_fn(
                 logits, spec_decode_metadata.bonus_logits_indices)
             bonus_token_ids = sample(
-
+                bonus_rng,
                 self.mesh,
                 bonus_logits,
                 tpu_sampling_metadata,
@@ -847,7 +852,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
                 target_logits=target_logits,
                 bonus_token_ids=bonus_token_ids,
                 sampling_metadata=tpu_sampling_metadata,
-                key=
+                key=rejection_rng,
             )

         if tpu_sampling_metadata.logprobs:
@@ -1332,7 +1337,14 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         _request_distribution = []
         for dp_rank in range(dp_size):
             _num_reqs = num_req_per_dp_rank[dp_rank]
-
+            # The batch has been reordered by _reorder_batch so decode requests come first
+            # Count decode requests (those with num_scheduled_tokens == 1) in this DP rank
+            num_decode_in_dp_rank = 0
+            for req_id in req_ids_dp[dp_rank]:
+                if scheduler_output.num_scheduled_tokens[req_id] == 1:
+                    num_decode_in_dp_rank += 1
+            _request_distribution.append(
+                [num_decode_in_dp_rank, num_decode_in_dp_rank, _num_reqs])
         request_distribution = np.array(_request_distribution).ravel()

         use_spec_decode = len(
@@ -1391,7 +1403,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             block_tables[
                 req_offset:req_offset + _num_reqs, :self.
                 max_num_blocks_per_req] = self.input_batch.block_table[
-
+                    kv_cache_gid].get_cpu_tensor()[req_indices_dp[dp_rank]]
             # Convert block_tables to 1D on cpu.
             block_tables = block_tables.reshape(-1)
             block_tables = device_array(
@@ -1706,3 +1718,34 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             mappings=mappings,
             transpose_keys=transpose_keys,
             shard=shard)
+
+    def get_intermediate_tensor_spec(self, num_tokens: int):
+        jax_dtype = to_jax_dtype(self.dtype)
+        num_padded_tokens = runner_utils.get_padded_token_len(
+            self.num_tokens_paddings, num_tokens)
+        sharding = NamedSharding(self.mesh, PartitionSpec())
+        hidden_size = self.model_config.get_hidden_size()
+        spec = jax.ShapeDtypeStruct(shape=(num_padded_tokens, hidden_size),
+                                    dtype=jax_dtype,
+                                    sharding=sharding)
+        tensor_spec = {"hidden_states": spec, "residual": spec}
+        return tensor_spec
+
+    def get_uuid_for_jax_transfer(self,
+                                  scheduler_output: "VllmSchedulerOutput",
+                                  rank: int, step: int) -> int:
+        '''
+        Get a uuid for jax.transfer, here we use the hash of
+        scheduler_output + counter_step + sender's rank
+        '''
+        scheduler_output_str = ""
+        if not scheduler_output.num_scheduled_tokens:
+            scheduler_output_str = "empty_batch"
+        else:
+            scheduler_output_str = str(
+                sorted(scheduler_output.num_scheduled_tokens.items()))
+        unique_str = f'{scheduler_output_str} {step} {rank}'
+        import hashlib
+        hasher = hashlib.sha1()
+        hasher.update(unique_str.encode('utf-8'))
+        return int.from_bytes(hasher.digest()[:8], 'big')
tpu_inference/runner/utils.py
CHANGED

@@ -15,6 +15,7 @@ import jax
 from jax._src.interpreters import pxla
 from vllm.v1.core.sched.output import SchedulerOutput as VllmSchedulerOutput

+from tpu_inference import envs
 from tpu_inference.logger import init_logger
 from tpu_inference.runner.input_batch import InputBatch

@@ -306,8 +307,7 @@ class PhasedBasedProfiler:
             InferencePhase.BALANCED: False
         }
         self.default_profiling_options = jax.profiler.ProfileOptions()
-        self.default_profiling_options.python_tracer_level =
-            "PYTHON_TRACER_LEVEL", 0)
+        self.default_profiling_options.python_tracer_level = envs.PYTHON_TRACER_LEVEL

         self.current_phase: str = ""